Skip to content

Commit

Permalink
Add backpressure to Channel receivers
Browse files Browse the repository at this point in the history
Fixes #392.

Simply put, this changes its `UnboundedReceiver` from
`unbounded_channel` with a `Receiver` from `channel(...)`, with a
configurable value in `channel_buffer_size` of 100. This means that, for
each channel that's created, it can hold up to 100 `Msg` objects, which
are at most as big as the window size plus metadata/indirection
pointers.

This is enough to force messages to be read from the underlying
`TcpStream` only as quickly as the server/client is able to handle them.
This has been tested with the example given in #392, including a
modification to slowly read into a non-expanding buffer, and appears to
fix the issue that's been pointed out.

Some other considerations:

- It is still possible to cause memory to explode by opening multiple
  channels, although that might be intentional on the user's part. In
  this case, it's up to the user to handle when channels get created or
  not.
- The limited buffering also means that control `Msg`s (eg. `Eof`) will
  also be delayed by the transport layer backpressure, although this
  might be expected by the SSH standard. Reading through RFC 4254, I
  couldn't find any mention of non-`Data`/`ExtendedData` messages having
  to be dealt differently, but I don't think it's even possible to do so
  over TCP. If this assumption is wrong, this might require a separate
  `mpsc::unbounded_channel` just for control messages.

**BREAKING CHANGES**:

- This removes `fn send_channel_msg`, which doesn't seem
  to be used anywhere in this library, but is part of the publicly
  exposed API nonetheless.
- This adds the configuration value `channel_buffer_size` to both
  servers and clients that allows setting how big the buffer size should
  be.
  • Loading branch information
EpicEric committed Dec 9, 2024
1 parent 5511842 commit 0512b24
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 91 deletions.
8 changes: 4 additions & 4 deletions russh/src/channels/channel_ref.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::Sender;

use super::WindowSizeRef;
use crate::ChannelMsg;
Expand All @@ -7,12 +7,12 @@ use crate::ChannelMsg;
/// to it and update it's `window_size`.
#[derive(Debug)]
pub struct ChannelRef {
pub(super) sender: UnboundedSender<ChannelMsg>,
pub(super) sender: Sender<ChannelMsg>,
pub(super) window_size: WindowSizeRef,
}

impl ChannelRef {
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
pub fn new(sender: Sender<ChannelMsg>) -> Self {
Self {
sender,
window_size: WindowSizeRef::new(0),
Expand All @@ -25,7 +25,7 @@ impl ChannelRef {
}

impl std::ops::Deref for ChannelRef {
type Target = UnboundedSender<ChannelMsg>;
type Target = Sender<ChannelMsg>;

fn deref(&self) -> &Self::Target {
&self.sender
Expand Down
7 changes: 4 additions & 3 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{Mutex, Notify};

use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};
Expand Down Expand Up @@ -143,7 +143,7 @@ impl WindowSizeRef {
pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
pub(crate) id: ChannelId,
pub(crate) sender: Sender<Send>,
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
pub(crate) receiver: Receiver<ChannelMsg>,
pub(crate) max_packet_size: u32,
pub(crate) window_size: WindowSizeRef,
}
Expand All @@ -160,8 +160,9 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
sender: Sender<S>,
max_packet_size: u32,
window_size: u32,
channel_buffer_size: usize,
) -> (Self, ChannelRef) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size);
let window_size = WindowSizeRef::new(window_size);

(
Expand Down
48 changes: 28 additions & 20 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ impl Session {
max_packet_size: msg.maximum_packet_size,
window_size: msg.initial_window_size,
})
.await
.unwrap_or(());
} else {
error!("no channel for id {local_id:?}");
Expand Down Expand Up @@ -457,7 +458,7 @@ impl Session {
debug!("channel_eof");
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Eof);
let _ = chan.send(ChannelMsg::Eof).await;
}
client.channel_eof(channel_num, self).await
}
Expand All @@ -473,7 +474,7 @@ impl Session {
}

if let Some(sender) = self.channels.remove(&channel_num) {
let _ = sender.send(ChannelMsg::OpenFailure(reason_code));
let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await;
}

let _ = self.sender.send(Reply::ChannelOpenFailure);
Expand All @@ -498,9 +499,11 @@ impl Session {
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Data {
data: CryptoVec::from_slice(&data),
});
let _ = chan
.send(ChannelMsg::Data {
data: CryptoVec::from_slice(&data),
})
.await;
}

client.data(channel_num, &data, self).await
Expand All @@ -522,10 +525,12 @@ impl Session {
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExtendedData {
ext: extended_code,
data: CryptoVec::from_slice(&data),
});
let _ = chan
.send(ChannelMsg::ExtendedData {
ext: extended_code,
data: CryptoVec::from_slice(&data),
})
.await;
}

client
Expand All @@ -541,15 +546,15 @@ impl Session {
map_err!(u8::decode(&mut r))?; // should be 0.
let client_can_do = map_err!(u8::decode(&mut r))? != 0;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::XonXoff { client_can_do });
let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await;
}
client.xon_xoff(channel_num, client_can_do, self).await
}
"exit-status" => {
map_err!(u8::decode(&mut r))?; // should be 0.
let exit_status = map_err!(u32::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitStatus { exit_status });
let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await;
}
client.exit_status(channel_num, exit_status, self).await
}
Expand All @@ -561,12 +566,14 @@ impl Session {
let error_message = map_err!(String::decode(&mut r))?;
let lang_tag = map_err!(String::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitSignal {
signal_name: signal_name.clone(),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
});
let _ = chan
.send(ChannelMsg::ExitSignal {
signal_name: signal_name.clone(),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
})
.await;
}
client
.exit_signal(
Expand Down Expand Up @@ -632,7 +639,7 @@ impl Session {
if let Some(chan) = self.channels.get(&channel_num) {
chan.window_size().update(new_size).await;

let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await;
}
client.window_adjusted(channel_num, new_size, self).await
}
Expand Down Expand Up @@ -682,14 +689,14 @@ impl Session {
Some((&msg::CHANNEL_SUCCESS, mut r)) => {
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Success);
let _ = chan.send(ChannelMsg::Success).await;
}
client.channel_success(channel_num, self).await
}
Some((&msg::CHANNEL_FAILURE, mut r)) => {
let channel_num = map_err!(ChannelId::decode(&mut r))?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Failure);
let _ = chan.send(ChannelMsg::Failure).await;
}
client.channel_failure(channel_num, self).await
}
Expand Down Expand Up @@ -884,6 +891,7 @@ impl Session {
self.inbound_channel_sender.clone(),
msg.recipient_maximum_packet_size,
msg.recipient_window_size,
self.common.config.channel_buffer_size,
);

self.channels.insert(id, channel_ref);
Expand Down
26 changes: 11 additions & 15 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ pub struct Handle<H: Handler> {
sender: Sender<Msg>,
receiver: UnboundedReceiver<Reply>,
join: russh_util::runtime::JoinHandle<Result<(), H::Error>>,
channel_buffer_size: usize,
}

impl<H: Handler> Drop for Handle<H> {
Expand Down Expand Up @@ -427,7 +428,7 @@ impl<H: Handler> Handle<H> {
/// Wait for confirmation that a channel is open
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
mut receiver: Receiver<ChannelMsg>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, crate::Error> {
loop {
Expand Down Expand Up @@ -466,7 +467,7 @@ impl<H: Handler> Handle<H> {
/// usable when it's confirmed by the server, as indicated by the
/// `confirmed` field of the corresponding `Channel`.
pub async fn channel_open_session(&self) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand All @@ -484,7 +485,7 @@ impl<H: Handler> Handle<H> {
originator_address: A,
originator_port: u32,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand Down Expand Up @@ -515,7 +516,7 @@ impl<H: Handler> Handle<H> {
originator_address: B,
originator_port: u32,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand All @@ -537,7 +538,7 @@ impl<H: Handler> Handle<H> {
&self,
socket_path: S,
) -> Result<Channel<Msg>, crate::Error> {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = channel(self.channel_buffer_size);
let channel_ref = ChannelRef::new(sender);
let window_size_ref = channel_ref.window_size().clone();

Expand Down Expand Up @@ -749,6 +750,7 @@ where
config.maximum_packet_size
);
}
let channel_buffer_size = config.channel_buffer_size;
let mut session = Session::new(
config.window_size,
CommonSession {
Expand Down Expand Up @@ -789,6 +791,7 @@ where
sender: handle_sender,
receiver: handle_receiver,
join,
channel_buffer_size,
})
}

Expand Down Expand Up @@ -1273,16 +1276,6 @@ impl Session {
}
Ok(())
}

/// Send a `ChannelMsg` from the background handler to the client.
pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool {
if let Some(chan) = self.channels.get(&channel) {
chan.send(msg).unwrap_or(());
true
} else {
false
}
}
}

thread_local! {
Expand Down Expand Up @@ -1482,6 +1475,8 @@ pub struct Config {
pub window_size: u32,
/// The maximal size of a single packet.
pub maximum_packet_size: u32,
/// Buffer size for created channels.
pub channel_buffer_size: usize,
/// Lists of preferred algorithms.
pub preferred: negotiation::Preferred,
/// Time after which the connection is garbage-collected.
Expand All @@ -1505,6 +1500,7 @@ impl Default for Config {
limits: Limits::default(),
window_size: 2097152,
maximum_packet_size: 32768,
channel_buffer_size: 100,
preferred: Default::default(),
inactivity_timeout: None,
keepalive_interval: None,
Expand Down
Loading

0 comments on commit 0512b24

Please sign in to comment.