Skip to content

Commit

Permalink
simplify asyncio_runnable
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Nov 1, 2023
1 parent 5f5f7ae commit 48dad3d
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,23 @@ class CoroutineRunnableSink : public mrc::node::WritableProvider<T>,
public mrc::node::SinkChannelOwner<T>
{
protected:
CoroutineRunnableSink()
CoroutineRunnableSink() :
m_reader([this](T& value) {

Check warning on line 147 in python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp#L147

Added line #L147 was not covered by tests
return this->get_readable_edge()->await_read(value);
})
{
// Set the default channel
this->set_channel(std::make_unique<mrc::channel::BufferedChannel<T>>());
}

auto build_readable_generator(std::stop_token stop_token) -> mrc::coroutines::AsyncGenerator<T>
{
auto read_awaiter = BoostFutureReader<T>([this](T& value) {
return this->get_readable_edge()->await_read(value);
});

while (!stop_token.stop_requested())
{
T value;

// Pull a message off of the upstream channel
auto status = co_await read_awaiter.async_read(std::ref(value));
auto status = co_await m_reader.async_read(std::ref(value));

if (status != mrc::channel::Status::success)
{
Expand All @@ -172,6 +171,9 @@ class CoroutineRunnableSink : public mrc::node::WritableProvider<T>,

co_return;
}

private:
BoostFutureReader<T> m_reader;
};

template <typename T>
Expand All @@ -184,25 +186,19 @@ class CoroutineRunnableSource : public mrc::node::WritableAcceptor<T>,
{
// Set the default channel
this->set_channel(std::make_unique<mrc::channel::BufferedChannel<T>>());
}

// auto build_readable_generator(std::stop_token stop_token)
// -> mrc::coroutines::AsyncGenerator<mrc::coroutines::detail::VoidValue>
// {
// while (!stop_token.stop_requested())
// {
// co_yield mrc::coroutines::detail::VoidValue{};
// }

// co_return;
// }

auto build_writable_receiver() -> std::shared_ptr<BoostFutureWriter<T>>
{
return std::make_shared<BoostFutureWriter<T>>([this](T&& value) {
m_writer = std::make_shared<BoostFutureWriter<T>>([this](T&& value) {
return this->get_writable_edge()->await_write(std::move(value));
});
}

auto get_writable_receiver() -> std::shared_ptr<BoostFutureWriter<T>>
{
return m_writer;
}

private:
std::shared_ptr<BoostFutureWriter<T>> m_writer;
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -257,7 +253,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
{
// Get the generator and receiver
auto input_generator = CoroutineRunnableSink<InputT>::build_readable_generator(m_stop_source.get_token());
auto output_receiver = CoroutineRunnableSource<OutputT>::build_writable_receiver();
auto output_receiver = CoroutineRunnableSource<OutputT>::get_writable_receiver();

// Create the task buffer to limit the number of running tasks
task_buffer_t task_buffer{{.capacity = m_concurrency}};
Expand Down

0 comments on commit 48dad3d

Please sign in to comment.