From 48d17a17bbfa176a4cc650a7782a38224605d7f3 Mon Sep 17 00:00:00 2001 From: David Gardner <96306125+dagardner-nv@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:10:21 -0700 Subject: [PATCH] Pass a `mrc.Subscription` object to sources rather than a `mrc.Subscriber` (#499) * Remove the `make_source_subscriber` method in favor of inspecting the Python function signature. * Since the `make_source_subscriber` method was never part of a release I think this can still be considered a non-breaking change. Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/MRC/pull/499 --- python/mrc/_pymrc/include/pymrc/segment.hpp | 4 -- .../mrc/_pymrc/include/pymrc/subscriber.hpp | 8 +++- python/mrc/_pymrc/src/segment.cpp | 38 ++++++++++++++----- python/mrc/_pymrc/src/subscriber.cpp | 8 +++- python/mrc/core/segment.cpp | 6 --- python/mrc/core/subscriber.cpp | 3 +- python/tests/test_executor.py | 6 +-- python/tests/test_node.py | 36 +++++++++++++++++- 8 files changed, 83 insertions(+), 26 deletions(-) diff --git a/python/mrc/_pymrc/include/pymrc/segment.hpp b/python/mrc/_pymrc/include/pymrc/segment.hpp index 2ceae6f25..94bce476e 100644 --- a/python/mrc/_pymrc/include/pymrc/segment.hpp +++ b/python/mrc/_pymrc/include/pymrc/segment.hpp @@ -143,10 +143,6 @@ class BuilderProxy const std::string& name, pybind11::function gen_factory); - static std::shared_ptr make_source_subscriber(mrc::segment::IBuilder& self, - const std::string& name, - pybind11::function gen_factory); - static std::shared_ptr make_source_component(mrc::segment::IBuilder& self, const std::string& name, pybind11::iterator source_iterator); diff --git a/python/mrc/_pymrc/include/pymrc/subscriber.hpp b/python/mrc/_pymrc/include/pymrc/subscriber.hpp index 5a079906f..6cc793dd5 100644 --- a/python/mrc/_pymrc/include/pymrc/subscriber.hpp +++ b/python/mrc/_pymrc/include/pymrc/subscriber.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -47,6 +47,12 @@ class SubscriberProxy static bool is_subscribed(PyObjectSubscriber* self); }; +class SubscriptionProxy +{ + public: + static bool is_subscribed(PySubscription* self); +}; + class ObservableProxy { public: diff --git a/python/mrc/_pymrc/src/segment.cpp b/python/mrc/_pymrc/src/segment.cpp index 6afb01967..ec78dc927 100644 --- a/python/mrc/_pymrc/src/segment.cpp +++ b/python/mrc/_pymrc/src/segment.cpp @@ -277,8 +277,9 @@ class SubscriberFuncWrapper : public mrc::pymrc::PythonSource { DVLOG(10) << ctx.info() << " Starting source"; py::gil_scoped_acquire gil; - py::object py_sub = py::cast(subscriber); - auto py_iter = m_gen_factory.operator()(std::move(py_sub)); + PySubscription subscription = subscriber.get_subscription(); + py::object py_sub = py::cast(subscription); + auto py_iter = m_gen_factory.operator()(std::move(py_sub)); PyIteratorWrapper iter_wrapper{std::move(py_iter)}; for (auto next_val : iter_wrapper) @@ -360,14 +361,33 @@ std::shared_ptr BuilderProxy::make_source(mrc::s const std::string& name, py::function gen_factory) { - return build_source(self, name, PyIteratorWrapper(std::move(gen_factory))); -} + // Determine if the gen_factory is expecting to receive a subscription object + auto inspect_mod = py::module::import("inspect"); + auto signature = inspect_mod.attr("signature")(gen_factory); + auto params = signature.attr("parameters"); + auto num_params = py::len(params); + bool expects_subscription = false; + + if (num_params > 0) + { + // We know there is at least one parameter. Check if the first parameter is a subscription object + // Note, when we receive a function that has been bound with `functools.partial(fn, arg1=some_value)`, the + // parameter is still visible in the signature of the partial object. + auto mrc_mod = py::module::import("mrc"); + auto param_values = params.attr("values")(); + auto first_param = py::iter(param_values); + auto type_hint = py::object((*first_param).attr("annotation")); + expects_subscription = (type_hint.is(mrc_mod.attr("Subscription")) || + type_hint.equal(py::str("mrc.Subscription")) || + type_hint.equal(py::str("Subscription"))); + } -std::shared_ptr BuilderProxy::make_source_subscriber(mrc::segment::IBuilder& self, - const std::string& name, - py::function gen_factory) -{ - return self.construct_object(name, std::move(gen_factory)); + if (expects_subscription) + { + return self.construct_object(name, std::move(gen_factory)); + } + + return build_source(self, name, PyIteratorWrapper(std::move(gen_factory))); } std::shared_ptr BuilderProxy::make_source_component(mrc::segment::IBuilder& self, diff --git a/python/mrc/_pymrc/src/subscriber.cpp b/python/mrc/_pymrc/src/subscriber.cpp index c00aaa187..6d94efff9 100644 --- a/python/mrc/_pymrc/src/subscriber.cpp +++ b/python/mrc/_pymrc/src/subscriber.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -115,6 +115,12 @@ bool SubscriberProxy::is_subscribed(PyObjectSubscriber* self) return self->is_subscribed(); } +bool SubscriptionProxy::is_subscribed(PySubscription* self) +{ + // No GIL here + return self->is_subscribed(); +} + PySubscription ObservableProxy::subscribe(PyObjectObservable* self, PyObjectObserver& observer) { // Call the internal subscribe function diff --git a/python/mrc/core/segment.cpp b/python/mrc/core/segment.cpp index 2224fb2e4..6c1898d33 100644 --- a/python/mrc/core/segment.cpp +++ b/python/mrc/core/segment.cpp @@ -134,12 +134,6 @@ PYBIND11_MODULE(segment, py_mod) const std::string&, py::function)>(&BuilderProxy::make_source)); - Builder.def("make_source_subscriber", - static_cast (*)(mrc::segment::IBuilder&, - const std::string&, - py::function)>( - &BuilderProxy::make_source_subscriber)); - Builder.def("make_source_component", static_cast (*)(mrc::segment::IBuilder&, const std::string&, diff --git a/python/mrc/core/subscriber.cpp b/python/mrc/core/subscriber.cpp index d435c4edf..8d6de717a 100644 --- a/python/mrc/core/subscriber.cpp +++ b/python/mrc/core/subscriber.cpp @@ -50,7 +50,8 @@ PYBIND11_MODULE(subscriber, py_mod) // Common must be first in every module pymrc::import(py_mod, "mrc.core.common"); - py::class_(py_mod, "Subscription"); + py::class_(py_mod, "Subscription") + .def("is_subscribed", &SubscriptionProxy::is_subscribed, py::call_guard()); py::class_(py_mod, "Observer") .def("on_next", diff --git a/python/tests/test_executor.py b/python/tests/test_executor.py index eb0e3596f..46381d285 100644 --- a/python/tests/test_executor.py +++ b/python/tests/test_executor.py @@ -69,12 +69,12 @@ def blocking_source(): def build(builder: mrc.Builder): - def gen_data(subscriber: mrc.Subscriber): + def gen_data(subscription: mrc.Subscription): yield 1 - while subscriber.is_subscribed(): + while subscription.is_subscribed(): time.sleep(0.1) - return builder.make_source_subscriber("blocking_source", gen_data) + return builder.make_source("blocking_source", gen_data) return build diff --git a/python/tests/test_node.py b/python/tests/test_node.py index a520e9c65..a59e11eef 100644 --- a/python/tests/test_node.py +++ b/python/tests/test_node.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -489,5 +489,39 @@ def on_completed(): assert on_completed_count == 1 +def test_source_with_bound_value(): + """ + This test ensures that the bound values isn't confused with a subscription object + """ + on_next_value = None + + def segment_init(seg: mrc.Builder): + + def source_gen(a): + yield a + + bound_gen = functools.partial(source_gen, a=1) + source = seg.make_source("my_src", bound_gen) + + def on_next(x: int): + nonlocal on_next_value + on_next_value = x + + sink = seg.make_sink("sink", on_next) + seg.make_edge(source, sink) + + pipeline = mrc.Pipeline() + pipeline.make_segment("my_seg", segment_init) + + options = mrc.Options() + executor = mrc.Executor(options) + executor.register_pipeline(pipeline) + + executor.start() + executor.join() + + assert on_next_value == 1 + + if (__name__ == "__main__"): test_launch_options_properties()