Skip to content

Commit

Permalink
asyncio runnable first test
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Oct 27, 2023
1 parent 64a346b commit dc8b0a4
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx)
scheduler->run_until_complete(this->main_task(scheduler));

// Need to drop the output edges
mrc::node::SourceProperties<InputT>::release_edge_connection();
mrc::node::SinkProperties<OutputT>::release_edge_connection();
mrc::node::SourceProperties<OutputT>::release_edge_connection();
mrc::node::SinkProperties<InputT>::release_edge_connection();
}

template <typename InputT, typename OutputT>
Expand Down
1 change: 1 addition & 0 deletions python/mrc/_pymrc/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_subdirectory(coro)

# Keep all source files sorted!!!
add_executable(test_pymrc
test_asyncio_runnable.cpp
test_codable_pyobject.cpp
test_executor.cpp
test_main.cpp
Expand Down
116 changes: 116 additions & 0 deletions python/mrc/_pymrc/tests/test_asyncio_runnable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "test_pymrc.hpp"

#include "pymrc/asyncio_runnable.hpp"
#include "pymrc/executor.hpp"
#include "pymrc/pipeline.hpp"
#include "pymrc/port_builders.hpp"
#include "pymrc/types.hpp"

#include "mrc/node/rx_node.hpp"
#include "mrc/node/rx_sink.hpp"
#include "mrc/node/rx_sink_base.hpp"
#include "mrc/node/rx_source.hpp"
#include "mrc/node/rx_source_base.hpp"
#include "mrc/options/options.hpp"
#include "mrc/options/topology.hpp"
#include "mrc/segment/builder.hpp"
#include "mrc/segment/object.hpp"
#include "mrc/types.hpp"

#include <boost/fiber/future/future.hpp>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <pybind11/cast.h>
#include <pybind11/gil.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h> // IWYU pragma: keep
#include <rxcpp/rx.hpp>

#include <atomic>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace py = pybind11;
namespace pymrc = mrc::pymrc;
using namespace std::string_literals;
using namespace py::literals;

PYMRC_TEST_CLASS(AsyncioRunnable);

class MyAsyncioRunnable : public pymrc::AsyncioRunnable<int, unsigned int>
{
mrc::coroutines::AsyncGenerator<unsigned int> on_data(int&& value) override
{
co_yield value;
co_yield value * 2;
};
};

TEST_F(TestAsyncioRunnable, Execute)
{
std::atomic<unsigned int> counter = 0;
pymrc::Pipeline p;

pybind11::module_::import("mrc.core.coro");

auto init = [&counter](mrc::segment::IBuilder& seg) {
auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
if (s.is_subscribed())
{
s.on_next(5);
s.on_next(10);
s.on_next(7);
}

s.on_completed();
});

auto internal = seg.construct_object<MyAsyncioRunnable>("internal");

auto sink = seg.make_sink<unsigned int>("sink", [&counter](unsigned int x) {
counter.fetch_add(x, std::memory_order_relaxed);
});

seg.make_edge(src, internal);
seg.make_edge(internal, sink);
};

p.make_segment("seg1"s, init);

auto options = std::make_shared<mrc::Options>();
options->topology().user_cpuset("0");

pymrc::Executor exec{options};
exec.register_pipeline(p);

exec.start();
exec.join();

EXPECT_EQ(counter, 66);
}

0 comments on commit dc8b0a4

Please sign in to comment.