Skip to content

Commit

Permalink
Add backpressure to Channel receivers (#412)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene <x@null.page>
Co-authored-by: Eugene <inbox@null.page>
  • Loading branch information
3 people authored Dec 12, 2024
1 parent e456efe commit f89c19c
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 91 deletions.
8 changes: 4 additions & 4 deletions russh/src/channels/channel_ref.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::Sender;

use super::WindowSizeRef;
use crate::ChannelMsg;
Expand All @@ -7,12 +7,12 @@ use crate::ChannelMsg;
/// to it and update it's `window_size`.
#[derive(Debug)]
pub struct ChannelRef {
pub(super) sender: UnboundedSender<ChannelMsg>,
pub(super) sender: Sender<ChannelMsg>,
pub(super) window_size: WindowSizeRef,
}

impl ChannelRef {
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
pub fn new(sender: Sender<ChannelMsg>) -> Self {
Self {
sender,
window_size: WindowSizeRef::new(0),
Expand All @@ -25,7 +25,7 @@ impl ChannelRef {
}

impl std::ops::Deref for ChannelRef {
type Target = UnboundedSender<ChannelMsg>;
type Target = Sender<ChannelMsg>;

fn deref(&self) -> &Self::Target {
&self.sender
Expand Down
7 changes: 4 additions & 3 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{Mutex, Notify};

use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};
Expand Down Expand Up @@ -143,7 +143,7 @@ impl WindowSizeRef {
pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
pub(crate) id: ChannelId,
pub(crate) sender: Sender<Send>,
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
pub(crate) receiver: Receiver<ChannelMsg>,
pub(crate) max_packet_size: u32,
pub(crate) window_size: WindowSizeRef,
}
Expand All @@ -160,8 +160,9 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
sender: Sender<S>,
max_packet_size: u32,
window_size: u32,
channel_buffer_size: usize,
) -> (Self, ChannelRef) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size);
let window_size = WindowSizeRef::new(window_size);

(
Expand Down
48 changes: 28 additions & 20 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ impl Session {
max_packet_size: msg.maximum_packet_size,
window_size: msg.initial_window_size,
})
.await
.unwrap_or(());
} else {
error!("no channel for id {local_id:?}");
Expand Down Expand Up @@ -461,7 +462,7 @@ impl Session {
debug!("channel_eof");
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Eof);
let _ = chan.send(ChannelMsg::Eof).await;
}
client.channel_eof(channel_num, self).await
}
Expand All @@ -477,7 +478,7 @@ impl Session {
}

if let Some(sender) = self.channels.remove(&channel_num) {
let _ = sender.send(ChannelMsg::OpenFailure(reason_code));
let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await;
}

let _ = self.sender.send(Reply::ChannelOpenFailure);
Expand All @@ -502,9 +503,11 @@ impl Session {
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Data {
data: CryptoVec::from_slice(&data),
});
let _ = chan
.send(ChannelMsg::Data {
data: CryptoVec::from_slice(&data),
})
.await;
}

client.data(channel_num, &data, self).await
Expand All @@ -526,10 +529,12 @@ impl Session {
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExtendedData {
ext: extended_code,
data: CryptoVec::from_slice(&data),
});
let _ = chan
.send(ChannelMsg::ExtendedData {
ext: extended_code,
data: CryptoVec::from_slice(&data),
})
.await;
}

client
Expand All @@ -545,15 +550,15 @@ impl Session {
map_err!(u8::decode(&mut r))?; // should be 0.
let client_can_do = map_err!(u8::decode(&mut r))? != 0;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::XonXoff { client_can_do });
let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await;
}
client.xon_xoff(channel_num, client_can_do, self).await
}
"exit-status" => {
map_err!(u8::decode(&mut r))?; // should be 0.
let exit_status = map_err!(u32::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitStatus { exit_status });
let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await;
}
client.exit_status(channel_num, exit_status, self).await
}
Expand All @@ -565,12 +570,14 @@ impl Session {
let error_message = map_err!(String::decode(&mut r))?;
let lang_tag = map_err!(String::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitSignal {
signal_name: signal_name.clone(),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
});
let _ = chan
.send(ChannelMsg::ExitSignal {
signal_name: signal_name.clone(),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
})
.await;
}
client
.exit_signal(
Expand Down Expand Up @@ -636,7 +643,7 @@ impl Session {
if let Some(chan) = self.channels.get(&channel_num) {
chan.window_size().update(new_size).await;

let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await;
}
client.window_adjusted(channel_num, new_size, self).await
}
Expand Down Expand Up @@ -686,14 +693,14 @@ impl Session {
Some((&msg::CHANNEL_SUCCESS, mut r)) => {
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Success);
let _ = chan.send(ChannelMsg::Success).await;
}
client.channel_success(channel_num, self).await
}
Some((&msg::CHANNEL_FAILURE, mut r)) => {
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Failure);
let _ = chan.send(ChannelMsg::Failure).await;
}
client.channel_failure(channel_num, self).await
}
Expand Down Expand Up @@ -888,6 +895,7 @@ impl Session {
self.inbound_channel_sender.clone(),
msg.recipient_maximum_packet_size,
msg.recipient_window_size,
self.common.config.channel_buffer_size,
);

self.channels.insert(id, channel_ref);
Expand Down
26 changes: 11 additions & 15 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ pub struct Handle<H: Handler> {
sender: Sender<Msg>,
receiver: UnboundedReceiver<Reply>,
join: russh_util::runtime::JoinHandle<Result<(), H::Error>>,
channel_buffer_size: usize,
}

impl<H: Handler> Drop for Handle<H> {
Expand Down Expand Up @@ -428,7 +429,7 @@ impl<H: Handler> Handle<H> {
/// Wait for confirmation that a channel is open
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
mut receiver: Receiver<ChannelMsg>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, crate::Error> {
loop {
Expand Down Expand Up @@ -467,7 +468,7 @@ impl<H: Handler> Handle<H> {
/// usable when it's confirmed by the server, as indicated by the
/// `confirmed` field of the corresponding `Channel`.
pub async fn channel_open_session(&self) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand All @@ -485,7 +486,7 @@ impl<H: Handler> Handle<H> {
originator_address: A,
originator_port: u32,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand Down Expand Up @@ -516,7 +517,7 @@ impl<H: Handler> Handle<H> {
originator_address: B,
originator_port: u32,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand All @@ -538,7 +539,7 @@ impl<H: Handler> Handle<H> {
&self,
socket_path: S,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand Down Expand Up @@ -750,6 +751,7 @@ where
config.maximum_packet_size
);
}
let channel_buffer_size = config.channel_buffer_size;
let mut session = Session::new(
config.window_size,
CommonSession {
Expand Down Expand Up @@ -790,6 +792,7 @@ where
sender: handle_sender,
receiver: handle_receiver,
join,
channel_buffer_size,
})
}

Expand Down Expand Up @@ -1274,16 +1277,6 @@ impl Session {
}
Ok(())
}

/// Send a `ChannelMsg` from the background handler to the client.
pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool {
if let Some(chan) = self.channels.get(&channel) {
chan.send(msg).unwrap_or(());
true
} else {
false
}
}
}

thread_local! {
Expand Down Expand Up @@ -1483,6 +1476,8 @@ pub struct Config {
pub window_size: u32,
/// The maximal size of a single packet.
pub maximum_packet_size: u32,
/// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream)
pub channel_buffer_size: usize,
/// Lists of preferred algorithms.
pub preferred: negotiation::Preferred,
/// Time after which the connection is garbage-collected.
Expand All @@ -1506,6 +1501,7 @@ impl Default for Config {
limits: Limits::default(),
window_size: 2097152,
maximum_packet_size: 32768,
channel_buffer_size: 100,
preferred: Default::default(),
inactivity_timeout: None,
keepalive_interval: None,
Expand Down
Loading

0 comments on commit f89c19c

Please sign in to comment.