diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs index e51c5a57..d7f937cd 100644 --- a/russh/src/channels/channel_ref.rs +++ b/russh/src/channels/channel_ref.rs @@ -1,4 +1,4 @@ -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::Sender; use super::WindowSizeRef; use crate::ChannelMsg; @@ -7,12 +7,12 @@ use crate::ChannelMsg; /// to it and update it's `window_size`. #[derive(Debug)] pub struct ChannelRef { - pub(super) sender: UnboundedSender, + pub(super) sender: Sender, pub(super) window_size: WindowSizeRef, } impl ChannelRef { - pub fn new(sender: UnboundedSender) -> Self { + pub fn new(sender: Sender) -> Self { Self { sender, window_size: WindowSizeRef::new(0), @@ -25,7 +25,7 @@ impl ChannelRef { } impl std::ops::Deref for ChannelRef { - type Target = UnboundedSender; + type Target = Sender; fn deref(&self) -> &Self::Target { &self.sender diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index c963a0b1..e08cceb6 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::{Mutex, Notify}; use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig}; @@ -143,7 +143,7 @@ impl WindowSizeRef { pub struct Channel> { pub(crate) id: ChannelId, pub(crate) sender: Sender, - pub(crate) receiver: UnboundedReceiver, + pub(crate) receiver: Receiver, pub(crate) max_packet_size: u32, pub(crate) window_size: WindowSizeRef, } @@ -160,8 +160,9 @@ impl + Send + Sync + 'static> Channel { sender: Sender, max_packet_size: u32, window_size: u32, + channel_buffer_size: usize, ) -> (Self, ChannelRef) { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size); let window_size = WindowSizeRef::new(window_size); ( diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 65ebc9bd..ec714e62 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -432,6 +432,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {local_id:?}"); @@ -461,7 +462,7 @@ impl Session { debug!("channel_eof"); let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Eof); + let _ = chan.send(ChannelMsg::Eof).await; } client.channel_eof(channel_num, self).await } @@ -477,7 +478,7 @@ impl Session { } if let Some(sender) = self.channels.remove(&channel_num) { - let _ = sender.send(ChannelMsg::OpenFailure(reason_code)); + let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await; } let _ = self.sender.send(Reply::ChannelOpenFailure); @@ -502,9 +503,11 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::Data { + data: CryptoVec::from_slice(&data), + }) + .await; } client.data(channel_num, &data, self).await @@ -526,10 +529,12 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExtendedData { - ext: extended_code, - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::ExtendedData { + ext: extended_code, + data: CryptoVec::from_slice(&data), + }) + .await; } client @@ -545,7 +550,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let client_can_do = map_err!(u8::decode(&mut r))? != 0; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::XonXoff { client_can_do }); + let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await; } client.xon_xoff(channel_num, client_can_do, self).await } @@ -553,7 +558,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let exit_status = map_err!(u32::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitStatus { exit_status }); + let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await; } client.exit_status(channel_num, exit_status, self).await } @@ -565,12 +570,14 @@ impl Session { let error_message = map_err!(String::decode(&mut r))?; let lang_tag = map_err!(String::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitSignal { - signal_name: signal_name.clone(), - core_dumped, - error_message: error_message.to_string(), - lang_tag: lang_tag.to_string(), - }); + let _ = chan + .send(ChannelMsg::ExitSignal { + signal_name: signal_name.clone(), + core_dumped, + error_message: error_message.to_string(), + lang_tag: lang_tag.to_string(), + }) + .await; } client .exit_signal( @@ -636,7 +643,7 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { chan.window_size().update(new_size).await; - let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); + let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await; } client.window_adjusted(channel_num, new_size, self).await } @@ -686,14 +693,14 @@ impl Session { Some((&msg::CHANNEL_SUCCESS, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Success); + let _ = chan.send(ChannelMsg::Success).await; } client.channel_success(channel_num, self).await } Some((&msg::CHANNEL_FAILURE, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Failure); + let _ = chan.send(ChannelMsg::Failure).await; } client.channel_failure(channel_num, self).await } @@ -888,6 +895,7 @@ impl Session { self.inbound_channel_sender.clone(), msg.recipient_maximum_packet_size, msg.recipient_window_size, + self.common.config.channel_buffer_size, ); self.channels.insert(id, channel_ref); diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 7b4e1a55..8bdff8b0 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -225,6 +225,7 @@ pub struct Handle { sender: Sender, receiver: UnboundedReceiver, join: russh_util::runtime::JoinHandle>, + channel_buffer_size: usize, } impl Drop for Handle { @@ -428,7 +429,7 @@ impl Handle { /// Wait for confirmation that a channel is open async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, crate::Error> { loop { @@ -467,7 +468,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -485,7 +486,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -516,7 +517,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -538,7 +539,7 @@ impl Handle { &self, socket_path: S, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -750,6 +751,7 @@ where config.maximum_packet_size ); } + let channel_buffer_size = config.channel_buffer_size; let mut session = Session::new( config.window_size, CommonSession { @@ -790,6 +792,7 @@ where sender: handle_sender, receiver: handle_receiver, join, + channel_buffer_size, }) } @@ -1274,16 +1277,6 @@ impl Session { } Ok(()) } - - /// Send a `ChannelMsg` from the background handler to the client. - pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool { - if let Some(chan) = self.channels.get(&channel) { - chan.send(msg).unwrap_or(()); - true - } else { - false - } - } } thread_local! { @@ -1483,6 +1476,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) + pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. @@ -1506,6 +1501,7 @@ impl Default for Config { limits: Limits::default(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, preferred: Default::default(), inactivity_timeout: None, keepalive_interval: None, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index 8bd369d8..315fa169 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -701,7 +701,7 @@ impl Session { msg::CHANNEL_EOF => { let channel_num = map_err!(ChannelId::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - chan.send(ChannelMsg::Eof).unwrap_or(()) + chan.send(ChannelMsg::Eof).await.unwrap_or(()) } debug!("handler.channel_eof {:?}", channel_num); handler.channel_eof(channel_num, self).await @@ -733,6 +733,7 @@ impl Session { ext, data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.extended_data(channel_num, ext, &data, self).await @@ -741,6 +742,7 @@ impl Session { chan.send(ChannelMsg::Data { data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.data(channel_num, &data, self).await @@ -766,6 +768,7 @@ impl Session { chan.window_size().update(new_size).await; chan.send(ChannelMsg::WindowAdjusted { new_size }) + .await .unwrap_or(()) } debug!("handler.window_adjusted {:?}", channel_num); @@ -795,6 +798,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {:?}", local_id); @@ -853,15 +857,17 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestPty { - want_reply: true, - term: term.clone(), - col_width, - row_height, - pix_width, - pix_height, - terminal_modes: modes.into(), - }); + let _ = chan + .send(ChannelMsg::RequestPty { + want_reply: true, + term: term.clone(), + col_width, + row_height, + pix_width, + pix_height, + terminal_modes: modes.into(), + }) + .await; } debug!("handler.pty_request {:?}", channel_num); @@ -886,13 +892,15 @@ impl Session { let x11_screen_number = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestX11 { - want_reply: true, - single_connection, - x11_authentication_cookie: x11_auth_cookie.clone(), - x11_authentication_protocol: x11_auth_protocol.clone(), - x11_screen_number, - }); + let _ = chan + .send(ChannelMsg::RequestX11 { + want_reply: true, + single_connection, + x11_authentication_cookie: x11_auth_cookie.clone(), + x11_authentication_protocol: x11_auth_protocol.clone(), + x11_screen_number, + }) + .await; } debug!("handler.x11_request {:?}", channel_num); handler @@ -911,11 +919,13 @@ impl Session { let env_value = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::SetEnv { - want_reply: true, - variable_name: env_variable.clone(), - variable_value: env_value.clone(), - }); + let _ = chan + .send(ChannelMsg::SetEnv { + want_reply: true, + variable_name: env_variable.clone(), + variable_value: env_value.clone(), + }) + .await; } debug!("handler.env_request {:?}", channel_num); @@ -925,14 +935,18 @@ impl Session { } "shell" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); + let _ = chan + .send(ChannelMsg::RequestShell { want_reply: true }) + .await; } debug!("handler.shell_request {:?}", channel_num); handler.shell_request(channel_num, self).await } "auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); + let _ = chan + .send(ChannelMsg::AgentForward { want_reply: true }) + .await; } debug!("handler.agent_request {:?}", channel_num); @@ -947,10 +961,12 @@ impl Session { "exec" => { let req = map_err!(Bytes::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Exec { - want_reply: true, - command: req.to_vec(), - }); + let _ = chan + .send(ChannelMsg::Exec { + want_reply: true, + command: req.to_vec(), + }) + .await; } debug!("handler.exec_request {:?}", channel_num); handler.exec_request(channel_num, &req, self).await @@ -959,10 +975,12 @@ impl Session { let name = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestSubsystem { - want_reply: true, - name: name.clone(), - }); + let _ = chan + .send(ChannelMsg::RequestSubsystem { + want_reply: true, + name: name.clone(), + }) + .await; } debug!("handler.subsystem_request {:?}", channel_num); handler.subsystem_request(channel_num, &name, self).await @@ -974,12 +992,14 @@ impl Session { let pix_height = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowChange { - col_width, - row_height, - pix_width, - pix_height, - }); + let _ = chan + .send(ChannelMsg::WindowChange { + col_width, + row_height, + pix_width, + pix_height, + }) + .await; } debug!("handler.window_change {:?}", channel_num); @@ -1000,6 +1020,7 @@ impl Session { chan.send(ChannelMsg::Signal { signal: signal.clone(), }) + .await .unwrap_or(()) } debug!("handler.signal {:?} {:?}", channel_num, signal); @@ -1110,6 +1131,7 @@ impl Session { if let Some(channel_sender) = self.channels.remove(&channel_num) { channel_sender .send(ChannelMsg::OpenFailure(reason)) + .await .map_err(|_| crate::Error::SendError)?; } @@ -1204,6 +1226,7 @@ impl Session { self.sender.sender.clone(), channel_params.recipient_maximum_packet_size, channel_params.recipient_window_size, + self.common.config.channel_buffer_size, ); match &msg.typ { diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 9c41078e..fd3b683d 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -79,6 +79,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream) + pub channel_buffer_size: usize, /// Internal event buffer size pub event_buffer_size: usize, /// Lists of preferred algorithms. @@ -108,6 +110,7 @@ impl Default for Config { keys: Vec::new(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, event_buffer_size: 10, limits: Limits::default(), preferred: Default::default(), @@ -134,6 +137,7 @@ impl Debug for Config { .field("keys", &"***") .field("window_size", &self.window_size) .field("maximum_packet_size", &self.maximum_packet_size) + .field("channel_buffer_size", &self.channel_buffer_size) .field("event_buffer_size", &self.event_buffer_size) .field("limits", &self.limits) .field("preferred", &self.preferred) @@ -748,10 +752,10 @@ pub trait Server { match accept_result { Ok((socket, _)) => { let config = config.clone(); - let handler = self.new_client(socket.peer_addr().ok()); + let handler = self.new_client(socket.peer_addr().ok()); let error_tx = error_tx.clone(); russh_util::runtime::spawn(async move { - let session = match run_stream(config, socket, handler).await { + let session = match run_stream(config, socket, handler).await { Ok(s) => s, Err(e) => { debug!("Connection setup failed"); @@ -856,8 +860,11 @@ where // Reading SSH id and allocating a session. let mut stream = SshRead::new(stream); let (sender, receiver) = tokio::sync::mpsc::channel(config.event_buffer_size); + let handle = server::session::Handle { + sender, + channel_buffer_size: config.channel_buffer_size, + }; let common = read_ssh_id(config, &mut stream).await?; - let handle = server::session::Handle { sender }; let session = Session { target_window_size: common.config.window_size, common, diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 4f379e7c..30968120 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -7,7 +7,7 @@ use negotiation::parse_kex_algo_list; use russh_keys::helpers::NameList; use russh_keys::map_err; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use super::*; @@ -90,6 +90,7 @@ impl From<(ChannelId, ChannelMsg)> for Msg { /// the request/response cycle. pub struct Handle { pub(crate) sender: Sender, + pub(crate) channel_buffer_size: usize, } impl Handle { @@ -217,7 +218,7 @@ impl Handle { /// confirmed that it allows agent forwarding. See /// [PROTOCOL.agent](https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent). pub async fn channel_open_agent(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -236,7 +237,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -261,7 +262,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -286,7 +287,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -308,7 +309,7 @@ impl Handle { &self, server_socket_path: A, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -328,7 +329,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -346,7 +347,7 @@ impl Handle { async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, Error> { loop { diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs new file mode 100644 index 00000000..55688ed7 --- /dev/null +++ b/russh/tests/test_backpressure.rs @@ -0,0 +1,152 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +use futures::FutureExt; +use rand::RngCore; +use rand_core::OsRng; +use russh::server::{self, Auth, Msg, Server as _, Session}; +use russh::{client, Channel, ChannelMsg}; +use russh_keys::key::PrivateKeyWithHashAlg; +use ssh_key::PrivateKey; +use tokio::io::AsyncWriteExt; +use tokio::sync::watch; +use tokio::time::sleep; + +pub const WINDOW_SIZE: usize = 8 * 2048; +pub const CHANNEL_BUFFER_SIZE: usize = 10; + +#[tokio::test] +async fn test_backpressure() -> Result<(), anyhow::Error> { + env_logger::init(); + + let addr = addr(); + let data = data(); + let (tx, rx) = watch::channel(()); + + tokio::spawn(Server::run(addr, rx)); + + // Wait until the server is started + while TcpStream::connect(addr).is_err() { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + stream(addr, &data, tx).await?; + + Ok(()) +} + +async fn stream(addr: SocketAddr, data: &[u8], tx: watch::Sender<()>) -> Result<(), anyhow::Error> { + let config = Arc::new(client::Config::default()); + let key = Arc::new(PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()); + + let mut session = russh::client::connect(config, addr, Client).await?; + let channel = match session + .authenticate_publickey("user", PrivateKeyWithHashAlg::new(key, None).unwrap()) + .await + { + Ok(true) => session.channel_open_session().await?, + Ok(false) => panic!("Authentication failed"), + Err(err) => return Err(err.into()), + }; + + let mut writer = channel.make_writer(); + + // TCP listener will buffer one extra message + for _ in 0..=CHANNEL_BUFFER_SIZE { + assert!(writer.write(data).await.is_ok()); + } + let pending_write = async { writer.write(data).await.unwrap() }; + sleep(std::time::Duration::from_millis(100)).await; + assert_eq!(pending_write.now_or_never(), None); + // Make space on the buffer + tx.send(()).unwrap(); + assert!(writer.write(data).await.is_ok()); + + Ok(()) +} + +fn data() -> Vec { + let mut rng = rand::thread_rng(); + + let mut data = vec![0u8; WINDOW_SIZE]; // Check whether the window_size resizing works + rng.fill_bytes(&mut data); + + data +} + +/// Find a unused local address to bind our server to +fn addr() -> SocketAddr { + TcpListener::bind(("127.0.0.1", 0)) + .unwrap() + .local_addr() + .unwrap() +} + +#[derive(Clone)] +struct Server { + rx: Option>, +} + +impl Server { + async fn run(addr: SocketAddr, rx: watch::Receiver<()>) { + let config = Arc::new(server::Config { + keys: vec![PrivateKey::random(&mut OsRng, ssh_key::Algorithm::Ed25519).unwrap()], + window_size: WINDOW_SIZE as u32, + channel_buffer_size: CHANNEL_BUFFER_SIZE, + ..Default::default() + }); + let mut sh = Server { rx: Some(rx) }; + + sh.run_on_address(config, addr).await.unwrap(); + } +} + +impl russh::server::Server for Server { + type Handler = Self; + + fn new_client(&mut self, _: Option) -> Self::Handler { + self.clone() + } +} + +#[async_trait::async_trait] +impl russh::server::Handler for Server { + type Error = anyhow::Error; + + async fn auth_publickey( + &mut self, + _: &str, + _: &ssh_key::PublicKey, + ) -> Result { + Ok(Auth::Accept) + } + + async fn channel_open_session( + &mut self, + mut channel: Channel, + _session: &mut Session, + ) -> Result { + let mut rx = self.rx.take().unwrap(); + tokio::spawn(async move { + while let Ok(_) = rx.changed().await { + match channel.wait().await { + Some(ChannelMsg::Data { .. }) => (), + other => panic!("unexpected message {:?}", other), + } + } + }); + + Ok(true) + } +} + +struct Client; + +#[async_trait::async_trait] +impl russh::client::Handler for Client { + type Error = anyhow::Error; + + async fn check_server_key(&mut self, _: &ssh_key::PublicKey) -> Result { + Ok(true) + } +}