From 0512b24835be2bd845972197678c7a6c0a00b3e2 Mon Sep 17 00:00:00 2001 From: Eric Rodrigues Pires Date: Sat, 7 Dec 2024 19:12:16 -0300 Subject: [PATCH 1/6] Add backpressure to Channel receivers Fixes #392. Simply put, this changes its `UnboundedReceiver` from `unbounded_channel` with a `Receiver` from `channel(...)`, with a configurable value in `channel_buffer_size` of 100. This means that, for each channel that's created, it can hold up to 100 `Msg` objects, which are at most as big as the window size plus metadata/indirection pointers. This is enough to force messages to be read from the underlying `TcpStream` only as quickly as the server/client is able to handle them. This has been tested with the example given in #392, including a modification to slowly read into a non-expanding buffer, and appears to fix the issue that's been pointed out. Some other considerations: - It is still possible to cause memory to explode by opening multiple channels, although that might be intentional on the user's part. In this case, it's up to the user to handle when channels get created or not. - The limited buffering also means that control `Msg`s (eg. `Eof`) will also be delayed by the transport layer backpressure, although this might be expected by the SSH standard. Reading through RFC 4254, I couldn't find any mention of non-`Data`/`ExtendedData` messages having to be dealt differently, but I don't think it's even possible to do so over TCP. If this assumption is wrong, this might require a separate `mpsc::unbounded_channel` just for control messages. **BREAKING CHANGES**: - This removes `fn send_channel_msg`, which doesn't seem to be used anywhere in this library, but is part of the publicly exposed API nonetheless. - This adds the configuration value `channel_buffer_size` to both servers and clients that allows setting how big the buffer size should be. --- russh/src/channels/channel_ref.rs | 8 +-- russh/src/channels/mod.rs | 7 ++- russh/src/client/encrypted.rs | 48 ++++++++------- russh/src/client/mod.rs | 26 ++++---- russh/src/server/encrypted.rs | 99 +++++++++++++++++++------------ russh/src/server/mod.rs | 13 +++- russh/src/server/session.rs | 17 +++--- 7 files changed, 127 insertions(+), 91 deletions(-) diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs index e51c5a57..d7f937cd 100644 --- a/russh/src/channels/channel_ref.rs +++ b/russh/src/channels/channel_ref.rs @@ -1,4 +1,4 @@ -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::Sender; use super::WindowSizeRef; use crate::ChannelMsg; @@ -7,12 +7,12 @@ use crate::ChannelMsg; /// to it and update it's `window_size`. #[derive(Debug)] pub struct ChannelRef { - pub(super) sender: UnboundedSender, + pub(super) sender: Sender, pub(super) window_size: WindowSizeRef, } impl ChannelRef { - pub fn new(sender: UnboundedSender) -> Self { + pub fn new(sender: Sender) -> Self { Self { sender, window_size: WindowSizeRef::new(0), @@ -25,7 +25,7 @@ impl ChannelRef { } impl std::ops::Deref for ChannelRef { - type Target = UnboundedSender; + type Target = Sender; fn deref(&self) -> &Self::Target { &self.sender diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index c963a0b1..e08cceb6 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::{Mutex, Notify}; use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig}; @@ -143,7 +143,7 @@ impl WindowSizeRef { pub struct Channel> { pub(crate) id: ChannelId, pub(crate) sender: Sender, - pub(crate) receiver: UnboundedReceiver, + pub(crate) receiver: Receiver, pub(crate) max_packet_size: u32, pub(crate) window_size: WindowSizeRef, } @@ -160,8 +160,9 @@ impl + Send + Sync + 'static> Channel { sender: Sender, max_packet_size: u32, window_size: u32, + channel_buffer_size: usize, ) -> (Self, ChannelRef) { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size); let window_size = WindowSizeRef::new(window_size); ( diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 83d1f4d2..a4d50ae8 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -428,6 +428,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {local_id:?}"); @@ -457,7 +458,7 @@ impl Session { debug!("channel_eof"); let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Eof); + let _ = chan.send(ChannelMsg::Eof).await; } client.channel_eof(channel_num, self).await } @@ -473,7 +474,7 @@ impl Session { } if let Some(sender) = self.channels.remove(&channel_num) { - let _ = sender.send(ChannelMsg::OpenFailure(reason_code)); + let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await; } let _ = self.sender.send(Reply::ChannelOpenFailure); @@ -498,9 +499,11 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::Data { + data: CryptoVec::from_slice(&data), + }) + .await; } client.data(channel_num, &data, self).await @@ -522,10 +525,12 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExtendedData { - ext: extended_code, - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::ExtendedData { + ext: extended_code, + data: CryptoVec::from_slice(&data), + }) + .await; } client @@ -541,7 +546,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let client_can_do = map_err!(u8::decode(&mut r))? != 0; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::XonXoff { client_can_do }); + let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await; } client.xon_xoff(channel_num, client_can_do, self).await } @@ -549,7 +554,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let exit_status = map_err!(u32::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitStatus { exit_status }); + let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await; } client.exit_status(channel_num, exit_status, self).await } @@ -561,12 +566,14 @@ impl Session { let error_message = map_err!(String::decode(&mut r))?; let lang_tag = map_err!(String::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitSignal { - signal_name: signal_name.clone(), - core_dumped, - error_message: error_message.to_string(), - lang_tag: lang_tag.to_string(), - }); + let _ = chan + .send(ChannelMsg::ExitSignal { + signal_name: signal_name.clone(), + core_dumped, + error_message: error_message.to_string(), + lang_tag: lang_tag.to_string(), + }) + .await; } client .exit_signal( @@ -632,7 +639,7 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { chan.window_size().update(new_size).await; - let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); + let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await; } client.window_adjusted(channel_num, new_size, self).await } @@ -682,14 +689,14 @@ impl Session { Some((&msg::CHANNEL_SUCCESS, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Success); + let _ = chan.send(ChannelMsg::Success).await; } client.channel_success(channel_num, self).await } Some((&msg::CHANNEL_FAILURE, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Failure); + let _ = chan.send(ChannelMsg::Failure).await; } client.channel_failure(channel_num, self).await } @@ -884,6 +891,7 @@ impl Session { self.inbound_channel_sender.clone(), msg.recipient_maximum_packet_size, msg.recipient_window_size, + self.common.config.channel_buffer_size, ); self.channels.insert(id, channel_ref); diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 7bc5b1fa..f831f89e 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -224,6 +224,7 @@ pub struct Handle { sender: Sender, receiver: UnboundedReceiver, join: russh_util::runtime::JoinHandle>, + channel_buffer_size: usize, } impl Drop for Handle { @@ -427,7 +428,7 @@ impl Handle { /// Wait for confirmation that a channel is open async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, crate::Error> { loop { @@ -466,7 +467,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -484,7 +485,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -515,7 +516,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -537,7 +538,7 @@ impl Handle { &self, socket_path: S, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -749,6 +750,7 @@ where config.maximum_packet_size ); } + let channel_buffer_size = config.channel_buffer_size; let mut session = Session::new( config.window_size, CommonSession { @@ -789,6 +791,7 @@ where sender: handle_sender, receiver: handle_receiver, join, + channel_buffer_size, }) } @@ -1273,16 +1276,6 @@ impl Session { } Ok(()) } - - /// Send a `ChannelMsg` from the background handler to the client. - pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool { - if let Some(chan) = self.channels.get(&channel) { - chan.send(msg).unwrap_or(()); - true - } else { - false - } - } } thread_local! { @@ -1482,6 +1475,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for created channels. + pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. @@ -1505,6 +1500,7 @@ impl Default for Config { limits: Limits::default(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, preferred: Default::default(), inactivity_timeout: None, keepalive_interval: None, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index f2c9ffb0..b496c85b 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -701,7 +701,7 @@ impl Session { msg::CHANNEL_EOF => { let channel_num = map_err!(ChannelId::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - chan.send(ChannelMsg::Eof).unwrap_or(()) + chan.send(ChannelMsg::Eof).await.unwrap_or(()) } debug!("handler.channel_eof {:?}", channel_num); handler.channel_eof(channel_num, self).await @@ -733,6 +733,7 @@ impl Session { ext, data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.extended_data(channel_num, ext, &data, self).await @@ -741,6 +742,7 @@ impl Session { chan.send(ChannelMsg::Data { data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.data(channel_num, &data, self).await @@ -766,6 +768,7 @@ impl Session { chan.window_size().update(new_size).await; chan.send(ChannelMsg::WindowAdjusted { new_size }) + .await .unwrap_or(()) } debug!("handler.window_adjusted {:?}", channel_num); @@ -795,6 +798,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {:?}", local_id); @@ -853,15 +857,17 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestPty { - want_reply: true, - term: term.clone(), - col_width, - row_height, - pix_width, - pix_height, - terminal_modes: modes.into(), - }); + let _ = chan + .send(ChannelMsg::RequestPty { + want_reply: true, + term: term.clone(), + col_width, + row_height, + pix_width, + pix_height, + terminal_modes: modes.into(), + }) + .await; } debug!("handler.pty_request {:?}", channel_num); @@ -886,13 +892,15 @@ impl Session { let x11_screen_number = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestX11 { - want_reply: true, - single_connection, - x11_authentication_cookie: x11_auth_cookie.clone(), - x11_authentication_protocol: x11_auth_protocol.clone(), - x11_screen_number, - }); + let _ = chan + .send(ChannelMsg::RequestX11 { + want_reply: true, + single_connection, + x11_authentication_cookie: x11_auth_cookie.clone(), + x11_authentication_protocol: x11_auth_protocol.clone(), + x11_screen_number, + }) + .await; } debug!("handler.x11_request {:?}", channel_num); handler @@ -911,11 +919,13 @@ impl Session { let env_value = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::SetEnv { - want_reply: true, - variable_name: env_variable.clone(), - variable_value: env_value.clone(), - }); + let _ = chan + .send(ChannelMsg::SetEnv { + want_reply: true, + variable_name: env_variable.clone(), + variable_value: env_value.clone(), + }) + .await; } debug!("handler.env_request {:?}", channel_num); @@ -925,14 +935,18 @@ impl Session { } "shell" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); + let _ = chan + .send(ChannelMsg::RequestShell { want_reply: true }) + .await; } debug!("handler.shell_request {:?}", channel_num); handler.shell_request(channel_num, self).await } "auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); + let _ = chan + .send(ChannelMsg::AgentForward { want_reply: true }) + .await; } debug!("handler.agent_request {:?}", channel_num); @@ -947,10 +961,12 @@ impl Session { "exec" => { let req = map_err!(Bytes::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Exec { - want_reply: true, - command: req.to_vec(), - }); + let _ = chan + .send(ChannelMsg::Exec { + want_reply: true, + command: req.to_vec(), + }) + .await; } debug!("handler.exec_request {:?}", channel_num); handler.exec_request(channel_num, &req, self).await @@ -959,10 +975,12 @@ impl Session { let name = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestSubsystem { - want_reply: true, - name: name.clone(), - }); + let _ = chan + .send(ChannelMsg::RequestSubsystem { + want_reply: true, + name: name.clone(), + }) + .await; } debug!("handler.subsystem_request {:?}", channel_num); handler.subsystem_request(channel_num, &name, self).await @@ -974,12 +992,14 @@ impl Session { let pix_height = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowChange { - col_width, - row_height, - pix_width, - pix_height, - }); + let _ = chan + .send(ChannelMsg::WindowChange { + col_width, + row_height, + pix_width, + pix_height, + }) + .await; } debug!("handler.window_change {:?}", channel_num); @@ -1000,6 +1020,7 @@ impl Session { chan.send(ChannelMsg::Signal { signal: signal.clone(), }) + .await .unwrap_or(()) } debug!("handler.signal {:?} {:?}", channel_num, signal); @@ -1110,6 +1131,7 @@ impl Session { if let Some(channel_sender) = self.channels.remove(&channel_num) { channel_sender .send(ChannelMsg::OpenFailure(reason)) + .await .map_err(|_| crate::Error::SendError)?; } @@ -1204,6 +1226,7 @@ impl Session { self.sender.sender.clone(), channel_params.recipient_maximum_packet_size, channel_params.recipient_window_size, + self.common.config.channel_buffer_size, ); match &msg.typ { diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 9c41078e..8f2ed31b 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -79,6 +79,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for created channels. + pub channel_buffer_size: usize, /// Internal event buffer size pub event_buffer_size: usize, /// Lists of preferred algorithms. @@ -108,6 +110,7 @@ impl Default for Config { keys: Vec::new(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, event_buffer_size: 10, limits: Limits::default(), preferred: Default::default(), @@ -134,6 +137,7 @@ impl Debug for Config { .field("keys", &"***") .field("window_size", &self.window_size) .field("maximum_packet_size", &self.maximum_packet_size) + .field("channel_buffer_size", &self.channel_buffer_size) .field("event_buffer_size", &self.event_buffer_size) .field("limits", &self.limits) .field("preferred", &self.preferred) @@ -748,10 +752,10 @@ pub trait Server { match accept_result { Ok((socket, _)) => { let config = config.clone(); - let handler = self.new_client(socket.peer_addr().ok()); + let handler = self.new_client(socket.peer_addr().ok()); let error_tx = error_tx.clone(); russh_util::runtime::spawn(async move { - let session = match run_stream(config, socket, handler).await { + let session = match run_stream(config, socket, handler).await { Ok(s) => s, Err(e) => { debug!("Connection setup failed"); @@ -856,8 +860,11 @@ where // Reading SSH id and allocating a session. let mut stream = SshRead::new(stream); let (sender, receiver) = tokio::sync::mpsc::channel(config.event_buffer_size); + let handle = server::session::Handle { + sender, + channel_buffer_size: config.channel_buffer_size, + }; let common = read_ssh_id(config, &mut stream).await?; - let handle = server::session::Handle { sender }; let session = Session { target_window_size: common.config.window_size, common, diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 4f379e7c..30968120 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -7,7 +7,7 @@ use negotiation::parse_kex_algo_list; use russh_keys::helpers::NameList; use russh_keys::map_err; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use super::*; @@ -90,6 +90,7 @@ impl From<(ChannelId, ChannelMsg)> for Msg { /// the request/response cycle. pub struct Handle { pub(crate) sender: Sender, + pub(crate) channel_buffer_size: usize, } impl Handle { @@ -217,7 +218,7 @@ impl Handle { /// confirmed that it allows agent forwarding. See /// [PROTOCOL.agent](https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent). pub async fn channel_open_agent(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -236,7 +237,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -261,7 +262,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -286,7 +287,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -308,7 +309,7 @@ impl Handle { &self, server_socket_path: A, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -328,7 +329,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -346,7 +347,7 @@ impl Handle { async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, Error> { loop { From 6116afca64a59c0499094a7046708ac743698485 Mon Sep 17 00:00:00 2001 From: Eric Rodrigues Pires Date: Tue, 10 Dec 2024 09:53:54 -0300 Subject: [PATCH 2/6] Add backpressure test --- russh/tests/test_backpressure.rs | 148 +++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 russh/tests/test_backpressure.rs diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs new file mode 100644 index 00000000..4c035652 --- /dev/null +++ b/russh/tests/test_backpressure.rs @@ -0,0 +1,148 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +use futures::FutureExt; +use rand::RngCore; +use rand_core::OsRng; +use russh::server::{self, Auth, Msg, Server as _, Session}; +use russh::{client, Channel, ChannelMsg}; +use ssh_key::PrivateKey; +use tokio::io::AsyncWriteExt; +use tokio::sync::watch; +use tokio::time::sleep; + +pub const WINDOW_SIZE: usize = 8 * 2048; +pub const CHANNEL_BUFFER_SIZE: usize = 10; + +#[tokio::test] +async fn test_backpressure() -> Result<(), anyhow::Error> { + env_logger::init(); + + let addr = addr(); + let data = data(); + let (tx, rx) = watch::channel(()); + + tokio::spawn(Server::run(addr, rx)); + + // Wait until the server is started + while TcpStream::connect(addr).is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + stream(addr, &data, tx).await?; + + Ok(()) +} + +async fn stream(addr: SocketAddr, data: &[u8], tx: watch::Sender<()>) -> Result<(), anyhow::Error> { + let config = Arc::new(client::Config::default()); + let key = Arc::new(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + + let mut session = russh::client::connect(config, addr, Client).await?; + let channel = match session.authenticate_publickey("user", key).await { + Ok(true) => session.channel_open_session().await?, + Ok(false) => panic!("Authentication failed"), + Err(err) => return Err(err.into()), + }; + + let mut writer = channel.make_writer(); + + // TCP listener will buffer one extra message + for _ in 0..=CHANNEL_BUFFER_SIZE { + assert!(writer.write(data).await.is_ok()); + } + let pending_write = async { writer.write(data).await.unwrap() }; + sleep(std::time::Duration::from_millis(100)).await; + assert_eq!(pending_write.now_or_never(), None); + // Make space on the buffer + tx.send(()).unwrap(); + assert!(writer.write(data).await.is_ok()); + + Ok(()) +} + +fn data() -> Vec { + let mut rng = rand::thread_rng(); + + let mut data = vec![0u8; WINDOW_SIZE]; // Check whether the window_size resizing works + rng.fill_bytes(&mut data); + + data +} + +/// Find a unused local address to bind our server to +fn addr() -> SocketAddr { + TcpListener::bind(("127.0.0.1", 0)) + .unwrap() + .local_addr() + .unwrap() +} + +#[derive(Clone)] +struct Server { + rx: Option>, +} + +impl Server { + async fn run(addr: SocketAddr, rx: watch::Receiver<()>) { + let config = Arc::new(server::Config { + keys: vec![PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()], + window_size: WINDOW_SIZE as u32, + channel_buffer_size: CHANNEL_BUFFER_SIZE, + ..Default::default() + }); + let mut sh = Server { rx: Some(rx) }; + + sh.run_on_address(config, addr).await.unwrap(); + } +} + +impl russh::server::Server for Server { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self::Handler { + self.clone() + } +} + +#[async_trait::async_trait] +impl russh::server::Handler for Server { + type Error = anyhow::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &ssh_key::PublicKey, + ) -> Result { + Ok(Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + let mut rx = self.rx.take().unwrap(); + tokio::spawn(async move { + while let Ok(_) = rx.changed().await { + match channel.wait().await { + Some(ChannelMsg::Data { .. }) => (), + other => panic!("unexpected message {:?}", other), + } + } + }); + + Ok(true) + } +} + +struct Client; + +#[async_trait::async_trait] +impl russh::client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result { + Ok(true) + } +} From 9bed133dc64a884da0881c2e839a82ee56f10fe4 Mon Sep 17 00:00:00 2001 From: Eugene Date: Thu, 12 Dec 2024 17:24:06 +0100 Subject: [PATCH 3/6] Update russh/src/server/mod.rs --- russh/src/server/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 8f2ed31b..fd3b683d 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -79,7 +79,7 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, - /// Buffer size for created channels. + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) pub channel_buffer_size: usize, /// Internal event buffer size pub event_buffer_size: usize, From 4aa7925b7d59474b826c17f9f52f771f6b616b23 Mon Sep 17 00:00:00 2001 From: Eugene Date: Thu, 12 Dec 2024 17:25:12 +0100 Subject: [PATCH 4/6] Update russh/src/client/mod.rs --- russh/src/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index f831f89e..8bf2ae0e 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -1475,7 +1475,7 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, - /// Buffer size for created channels. + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, From d4ed6c9cf2d1824fd422d63a930262b3381886c1 Mon Sep 17 00:00:00 2001 From: Eugene Date: Thu, 12 Dec 2024 17:39:01 +0100 Subject: [PATCH 5/6] Update mod.rs --- russh/src/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 8bf2ae0e..a00b6aa3 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -1475,7 +1475,7 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, - /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, From e42a7aa04dd910fc6e40a9e0fd609648e45ffbcd Mon Sep 17 00:00:00 2001 From: Eugene Date: Thu, 12 Dec 2024 17:42:02 +0100 Subject: [PATCH 6/6] Update test_backpressure.rs --- russh/tests/test_backpressure.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs index 4c035652..55688ed7 100644 --- a/russh/tests/test_backpressure.rs +++ b/russh/tests/test_backpressure.rs @@ -6,6 +6,7 @@ use rand::RngCore; use rand_core::OsRng; use russh::server::{self, Auth, Msg, Server as _, Session}; use russh::{client, Channel, ChannelMsg}; +use russh_keys::key::PrivateKeyWithHashAlg; use ssh_key::PrivateKey; use tokio::io::AsyncWriteExt; use tokio::sync::watch; @@ -39,7 +40,10 @@ async fn stream(addr: SocketAddr, data: &[u8], tx: watch::Sender<()>) -> Result< let key = Arc::new(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); let mut session = russh::client::connect(config, addr, Client).await?; - let channel = match session.authenticate_publickey("user", key).await { + let channel = match session + .authenticate_publickey("user", PrivateKeyWithHashAlg::new(key, None).unwrap()) + .await + { Ok(true) => session.channel_open_session().await?, Ok(false) => panic!("Authentication failed"), Err(err) => return Err(err.into()),