diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs index e51c5a57..d7f937cd 100644 --- a/russh/src/channels/channel_ref.rs +++ b/russh/src/channels/channel_ref.rs @@ -1,4 +1,4 @@ -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::Sender; use super::WindowSizeRef; use crate::ChannelMsg; @@ -7,12 +7,12 @@ use crate::ChannelMsg; /// to it and update it's `window_size`. #[derive(Debug)] pub struct ChannelRef { - pub(super) sender: UnboundedSender, + pub(super) sender: Sender, pub(super) window_size: WindowSizeRef, } impl ChannelRef { - pub fn new(sender: UnboundedSender) -> Self { + pub fn new(sender: Sender) -> Self { Self { sender, window_size: WindowSizeRef::new(0), @@ -25,7 +25,7 @@ impl ChannelRef { } impl std::ops::Deref for ChannelRef { - type Target = UnboundedSender; + type Target = Sender; fn deref(&self) -> &Self::Target { &self.sender diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index c963a0b1..e08cceb6 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::{Mutex, Notify}; use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig}; @@ -143,7 +143,7 @@ impl WindowSizeRef { pub struct Channel> { pub(crate) id: ChannelId, pub(crate) sender: Sender, - pub(crate) receiver: UnboundedReceiver, + pub(crate) receiver: Receiver, pub(crate) max_packet_size: u32, pub(crate) window_size: WindowSizeRef, } @@ -160,8 +160,9 @@ impl + Send + Sync + 'static> Channel { sender: Sender, max_packet_size: u32, window_size: u32, + channel_buffer_size: usize, ) -> (Self, ChannelRef) { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size); let window_size = WindowSizeRef::new(window_size); ( diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 83d1f4d2..a4d50ae8 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -428,6 +428,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {local_id:?}"); @@ -457,7 +458,7 @@ impl Session { debug!("channel_eof"); let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Eof); + let _ = chan.send(ChannelMsg::Eof).await; } client.channel_eof(channel_num, self).await } @@ -473,7 +474,7 @@ impl Session { } if let Some(sender) = self.channels.remove(&channel_num) { - let _ = sender.send(ChannelMsg::OpenFailure(reason_code)); + let _ = sender.send(ChannelMsg::OpenFailure(reason_code)).await; } let _ = self.sender.send(Reply::ChannelOpenFailure); @@ -498,9 +499,11 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::Data { + data: CryptoVec::from_slice(&data), + }) + .await; } client.data(channel_num, &data, self).await @@ -522,10 +525,12 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExtendedData { - ext: extended_code, - data: CryptoVec::from_slice(&data), - }); + let _ = chan + .send(ChannelMsg::ExtendedData { + ext: extended_code, + data: CryptoVec::from_slice(&data), + }) + .await; } client @@ -541,7 +546,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let client_can_do = map_err!(u8::decode(&mut r))? != 0; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::XonXoff { client_can_do }); + let _ = chan.send(ChannelMsg::XonXoff { client_can_do }).await; } client.xon_xoff(channel_num, client_can_do, self).await } @@ -549,7 +554,7 @@ impl Session { map_err!(u8::decode(&mut r))?; // should be 0. let exit_status = map_err!(u32::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitStatus { exit_status }); + let _ = chan.send(ChannelMsg::ExitStatus { exit_status }).await; } client.exit_status(channel_num, exit_status, self).await } @@ -561,12 +566,14 @@ impl Session { let error_message = map_err!(String::decode(&mut r))?; let lang_tag = map_err!(String::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::ExitSignal { - signal_name: signal_name.clone(), - core_dumped, - error_message: error_message.to_string(), - lang_tag: lang_tag.to_string(), - }); + let _ = chan + .send(ChannelMsg::ExitSignal { + signal_name: signal_name.clone(), + core_dumped, + error_message: error_message.to_string(), + lang_tag: lang_tag.to_string(), + }) + .await; } client .exit_signal( @@ -632,7 +639,7 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { chan.window_size().update(new_size).await; - let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); + let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }).await; } client.window_adjusted(channel_num, new_size, self).await } @@ -682,14 +689,14 @@ impl Session { Some((&msg::CHANNEL_SUCCESS, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Success); + let _ = chan.send(ChannelMsg::Success).await; } client.channel_success(channel_num, self).await } Some((&msg::CHANNEL_FAILURE, mut r)) => { let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Failure); + let _ = chan.send(ChannelMsg::Failure).await; } client.channel_failure(channel_num, self).await } @@ -884,6 +891,7 @@ impl Session { self.inbound_channel_sender.clone(), msg.recipient_maximum_packet_size, msg.recipient_window_size, + self.common.config.channel_buffer_size, ); self.channels.insert(id, channel_ref); diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 7bc5b1fa..f831f89e 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -224,6 +224,7 @@ pub struct Handle { sender: Sender, receiver: UnboundedReceiver, join: russh_util::runtime::JoinHandle>, + channel_buffer_size: usize, } impl Drop for Handle { @@ -427,7 +428,7 @@ impl Handle { /// Wait for confirmation that a channel is open async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, crate::Error> { loop { @@ -466,7 +467,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -484,7 +485,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -515,7 +516,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -537,7 +538,7 @@ impl Handle { &self, socket_path: S, ) -> Result, crate::Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -749,6 +750,7 @@ where config.maximum_packet_size ); } + let channel_buffer_size = config.channel_buffer_size; let mut session = Session::new( config.window_size, CommonSession { @@ -789,6 +791,7 @@ where sender: handle_sender, receiver: handle_receiver, join, + channel_buffer_size, }) } @@ -1273,16 +1276,6 @@ impl Session { } Ok(()) } - - /// Send a `ChannelMsg` from the background handler to the client. - pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool { - if let Some(chan) = self.channels.get(&channel) { - chan.send(msg).unwrap_or(()); - true - } else { - false - } - } } thread_local! { @@ -1482,6 +1475,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for created channels. + pub channel_buffer_size: usize, /// Lists of preferred algorithms. pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. @@ -1505,6 +1500,7 @@ impl Default for Config { limits: Limits::default(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, preferred: Default::default(), inactivity_timeout: None, keepalive_interval: None, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index f2c9ffb0..b496c85b 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -701,7 +701,7 @@ impl Session { msg::CHANNEL_EOF => { let channel_num = map_err!(ChannelId::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - chan.send(ChannelMsg::Eof).unwrap_or(()) + chan.send(ChannelMsg::Eof).await.unwrap_or(()) } debug!("handler.channel_eof {:?}", channel_num); handler.channel_eof(channel_num, self).await @@ -733,6 +733,7 @@ impl Session { ext, data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.extended_data(channel_num, ext, &data, self).await @@ -741,6 +742,7 @@ impl Session { chan.send(ChannelMsg::Data { data: CryptoVec::from_slice(&data), }) + .await .unwrap_or(()) } handler.data(channel_num, &data, self).await @@ -766,6 +768,7 @@ impl Session { chan.window_size().update(new_size).await; chan.send(ChannelMsg::WindowAdjusted { new_size }) + .await .unwrap_or(()) } debug!("handler.window_adjusted {:?}", channel_num); @@ -795,6 +798,7 @@ impl Session { max_packet_size: msg.maximum_packet_size, window_size: msg.initial_window_size, }) + .await .unwrap_or(()); } else { error!("no channel for id {:?}", local_id); @@ -853,15 +857,17 @@ impl Session { } if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestPty { - want_reply: true, - term: term.clone(), - col_width, - row_height, - pix_width, - pix_height, - terminal_modes: modes.into(), - }); + let _ = chan + .send(ChannelMsg::RequestPty { + want_reply: true, + term: term.clone(), + col_width, + row_height, + pix_width, + pix_height, + terminal_modes: modes.into(), + }) + .await; } debug!("handler.pty_request {:?}", channel_num); @@ -886,13 +892,15 @@ impl Session { let x11_screen_number = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestX11 { - want_reply: true, - single_connection, - x11_authentication_cookie: x11_auth_cookie.clone(), - x11_authentication_protocol: x11_auth_protocol.clone(), - x11_screen_number, - }); + let _ = chan + .send(ChannelMsg::RequestX11 { + want_reply: true, + single_connection, + x11_authentication_cookie: x11_auth_cookie.clone(), + x11_authentication_protocol: x11_auth_protocol.clone(), + x11_screen_number, + }) + .await; } debug!("handler.x11_request {:?}", channel_num); handler @@ -911,11 +919,13 @@ impl Session { let env_value = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::SetEnv { - want_reply: true, - variable_name: env_variable.clone(), - variable_value: env_value.clone(), - }); + let _ = chan + .send(ChannelMsg::SetEnv { + want_reply: true, + variable_name: env_variable.clone(), + variable_value: env_value.clone(), + }) + .await; } debug!("handler.env_request {:?}", channel_num); @@ -925,14 +935,18 @@ impl Session { } "shell" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); + let _ = chan + .send(ChannelMsg::RequestShell { want_reply: true }) + .await; } debug!("handler.shell_request {:?}", channel_num); handler.shell_request(channel_num, self).await } "auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); + let _ = chan + .send(ChannelMsg::AgentForward { want_reply: true }) + .await; } debug!("handler.agent_request {:?}", channel_num); @@ -947,10 +961,12 @@ impl Session { "exec" => { let req = map_err!(Bytes::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::Exec { - want_reply: true, - command: req.to_vec(), - }); + let _ = chan + .send(ChannelMsg::Exec { + want_reply: true, + command: req.to_vec(), + }) + .await; } debug!("handler.exec_request {:?}", channel_num); handler.exec_request(channel_num, &req, self).await @@ -959,10 +975,12 @@ impl Session { let name = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::RequestSubsystem { - want_reply: true, - name: name.clone(), - }); + let _ = chan + .send(ChannelMsg::RequestSubsystem { + want_reply: true, + name: name.clone(), + }) + .await; } debug!("handler.subsystem_request {:?}", channel_num); handler.subsystem_request(channel_num, &name, self).await @@ -974,12 +992,14 @@ impl Session { let pix_height = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowChange { - col_width, - row_height, - pix_width, - pix_height, - }); + let _ = chan + .send(ChannelMsg::WindowChange { + col_width, + row_height, + pix_width, + pix_height, + }) + .await; } debug!("handler.window_change {:?}", channel_num); @@ -1000,6 +1020,7 @@ impl Session { chan.send(ChannelMsg::Signal { signal: signal.clone(), }) + .await .unwrap_or(()) } debug!("handler.signal {:?} {:?}", channel_num, signal); @@ -1110,6 +1131,7 @@ impl Session { if let Some(channel_sender) = self.channels.remove(&channel_num) { channel_sender .send(ChannelMsg::OpenFailure(reason)) + .await .map_err(|_| crate::Error::SendError)?; } @@ -1204,6 +1226,7 @@ impl Session { self.sender.sender.clone(), channel_params.recipient_maximum_packet_size, channel_params.recipient_window_size, + self.common.config.channel_buffer_size, ); match &msg.typ { diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index 9c41078e..8f2ed31b 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -79,6 +79,8 @@ pub struct Config { pub window_size: u32, /// The maximal size of a single packet. pub maximum_packet_size: u32, + /// Buffer size for created channels. + pub channel_buffer_size: usize, /// Internal event buffer size pub event_buffer_size: usize, /// Lists of preferred algorithms. @@ -108,6 +110,7 @@ impl Default for Config { keys: Vec::new(), window_size: 2097152, maximum_packet_size: 32768, + channel_buffer_size: 100, event_buffer_size: 10, limits: Limits::default(), preferred: Default::default(), @@ -134,6 +137,7 @@ impl Debug for Config { .field("keys", &"***") .field("window_size", &self.window_size) .field("maximum_packet_size", &self.maximum_packet_size) + .field("channel_buffer_size", &self.channel_buffer_size) .field("event_buffer_size", &self.event_buffer_size) .field("limits", &self.limits) .field("preferred", &self.preferred) @@ -748,10 +752,10 @@ pub trait Server { match accept_result { Ok((socket, _)) => { let config = config.clone(); - let handler = self.new_client(socket.peer_addr().ok()); + let handler = self.new_client(socket.peer_addr().ok()); let error_tx = error_tx.clone(); russh_util::runtime::spawn(async move { - let session = match run_stream(config, socket, handler).await { + let session = match run_stream(config, socket, handler).await { Ok(s) => s, Err(e) => { debug!("Connection setup failed"); @@ -856,8 +860,11 @@ where // Reading SSH id and allocating a session. let mut stream = SshRead::new(stream); let (sender, receiver) = tokio::sync::mpsc::channel(config.event_buffer_size); + let handle = server::session::Handle { + sender, + channel_buffer_size: config.channel_buffer_size, + }; let common = read_ssh_id(config, &mut stream).await?; - let handle = server::session::Handle { sender }; let session = Session { target_window_size: common.config.window_size, common, diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 4f379e7c..30968120 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -7,7 +7,7 @@ use negotiation::parse_kex_algo_list; use russh_keys::helpers::NameList; use russh_keys::map_err; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use super::*; @@ -90,6 +90,7 @@ impl From<(ChannelId, ChannelMsg)> for Msg { /// the request/response cycle. pub struct Handle { pub(crate) sender: Sender, + pub(crate) channel_buffer_size: usize, } impl Handle { @@ -217,7 +218,7 @@ impl Handle { /// confirmed that it allows agent forwarding. See /// [PROTOCOL.agent](https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent). pub async fn channel_open_agent(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -236,7 +237,7 @@ impl Handle { /// usable when it's confirmed by the server, as indicated by the /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -261,7 +262,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -286,7 +287,7 @@ impl Handle { originator_address: B, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -308,7 +309,7 @@ impl Handle { &self, server_socket_path: A, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -328,7 +329,7 @@ impl Handle { originator_address: A, originator_port: u32, ) -> Result, Error> { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = channel(self.channel_buffer_size); let channel_ref = ChannelRef::new(sender); let window_size_ref = channel_ref.window_size().clone(); @@ -346,7 +347,7 @@ impl Handle { async fn wait_channel_confirmation( &self, - mut receiver: UnboundedReceiver, + mut receiver: Receiver, window_size_ref: WindowSizeRef, ) -> Result, Error> { loop {