Skip to content

Commit

Permalink
An attempt at #401 - removing TX busywait (#408)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Rodrigues Pires <eric@eric.dev.br>
  • Loading branch information
Eugeny and EpicEric authored Dec 7, 2024
1 parent ac441a6 commit a5c4adc
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 52 deletions.
10 changes: 4 additions & 6 deletions russh/src/channels/channel_ref.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
use std::sync::Arc;

use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;

use super::WindowSizeRef;
use crate::ChannelMsg;

/// A handle to the [`super::Channel`]'s to be able to transmit messages
/// to it and update it's `window_size`.
#[derive(Debug)]
pub struct ChannelRef {
pub(super) sender: UnboundedSender<ChannelMsg>,
pub(super) window_size: Arc<Mutex<u32>>,
pub(super) window_size: WindowSizeRef,
}

impl ChannelRef {
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
Self {
sender,
window_size: Default::default(),
window_size: WindowSizeRef::new(0),
}
}

pub fn window_size(&self) -> &Arc<Mutex<u32>> {
pub(crate) fn window_size(&self) -> &WindowSizeRef {
&self.window_size
}
}
Expand Down
91 changes: 74 additions & 17 deletions russh/src/channels/io/tx.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::convert::TryFrom;
use std::future::Future;
use std::io;
use std::num::NonZero;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
Expand All @@ -7,7 +11,7 @@ use futures::FutureExt;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{self, OwnedPermit};
use tokio::sync::{Mutex, OwnedMutexGuard};
use tokio::sync::{Mutex, Notify, OwnedMutexGuard};

use super::ChannelMsg;
use crate::{ChannelId, CryptoVec};
Expand All @@ -16,13 +20,34 @@ type BoxedThreadsafeFuture<T> = Pin<Box<dyn Sync + Send + std::future::Future<Ou
type OwnedPermitFuture<S> =
BoxedThreadsafeFuture<Result<(OwnedPermit<S>, ChannelMsg, usize), SendError<()>>>;

struct WatchNotification(Pin<Box<dyn Sync + Send + Future<Output = ()>>>);

/// A single future that becomes ready once the window size
/// changes to a positive value
impl WatchNotification {
fn new(n: Arc<Notify>) -> Self {
Self(Box::pin(async move { n.notified().await }))
}
}

impl Future for WatchNotification {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.deref_mut().0.as_mut();
ready!(inner.poll(cx));
Poll::Ready(())
}
}

pub struct ChannelTx<S> {
sender: mpsc::Sender<S>,
send_fut: Option<OwnedPermitFuture<S>>,
id: ChannelId,

window_size_fut: Option<BoxedThreadsafeFuture<OwnedMutexGuard<u32>>>,
window_size: Arc<Mutex<u32>>,
notify: Arc<Notify>,
window_size_notication: WatchNotification,
max_packet_size: u32,
ext: Option<u32>,
}
Expand All @@ -35,43 +60,62 @@ where
sender: mpsc::Sender<S>,
id: ChannelId,
window_size: Arc<Mutex<u32>>,
window_size_notification: Arc<Notify>,
max_packet_size: u32,
ext: Option<u32>,
) -> Self {
Self {
sender,
send_fut: None,
id,
notify: Arc::clone(&window_size_notification),
window_size_notication: WatchNotification::new(window_size_notification),
window_size,
window_size_fut: None,
max_packet_size,
ext,
}
}

fn poll_mk_msg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<(ChannelMsg, usize)> {
fn poll_writable(&mut self, cx: &mut Context<'_>, buf_len: usize) -> Poll<NonZero<usize>> {
let window_size = self.window_size.clone();
let window_size_fut = self
.window_size_fut
.get_or_insert_with(|| Box::pin(window_size.lock_owned()));
let mut window_size = ready!(window_size_fut.poll_unpin(cx));
self.window_size_fut.take();

let writable = (self.max_packet_size)
.min(*window_size)
.min(buf.len() as u32) as usize;
if writable == 0 {
// TODO fix this busywait
cx.waker().wake_by_ref();
return Poll::Pending;
let writable = (self.max_packet_size).min(*window_size).min(buf_len as u32) as usize;

match NonZero::try_from(writable) {
Ok(w) => {
*window_size -= writable as u32;
if *window_size > 0 {
self.notify.notify_one();
}
Poll::Ready(w)
}
Err(_) => {
drop(window_size);
ready!(self.window_size_notication.poll_unpin(cx));
self.window_size_notication = WatchNotification::new(Arc::clone(&self.notify));
cx.waker().wake_by_ref();
Poll::Pending
}
}
let mut data = CryptoVec::new_zeroed(writable);
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
data.copy_from_slice(&buf[..writable]);
data.resize(writable);
}

fn poll_mk_msg(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<(ChannelMsg, NonZero<usize>)> {
let writable = ready!(self.poll_writable(cx, buf.len()));

*window_size -= writable as u32;
drop(window_size);
let mut data = CryptoVec::new_zeroed(writable.into());
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.poll_writable`
data.copy_from_slice(&buf[..writable.into()]);
data.resize(writable.into());

let msg = match self.ext {
None => ChannelMsg::Data { data },
Expand Down Expand Up @@ -116,11 +160,17 @@ where
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
if buf.is_empty() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"cannot send empty buffer",
)));
}
let send_fut = if let Some(x) = self.send_fut.as_mut() {
x
} else {
let (msg, writable) = ready!(self.poll_mk_msg(cx, buf));
self.activate(msg, writable)
self.activate(msg, writable.into())
};
let r = ready!(send_fut.as_mut().poll_unpin(cx));
Poll::Ready(self.handle_write_result(r))
Expand All @@ -143,3 +193,10 @@ where
Poll::Ready(self.handle_write_result(r).map(drop))
}
}

impl<S> Drop for ChannelTx<S> {
fn drop(&mut self) {
// Allow other writers to make progress
self.notify.notify_one();
}
}
40 changes: 34 additions & 6 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::Mutex;
use tokio::sync::{Mutex, Notify};

use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};

Expand Down Expand Up @@ -112,6 +112,31 @@ pub enum ChannelMsg {
OpenFailure(ChannelOpenFailure),
}

#[derive(Clone, Debug)]
pub(crate) struct WindowSizeRef {
value: Arc<Mutex<u32>>,
notifier: Arc<Notify>,
}

impl WindowSizeRef {
pub(crate) fn new(initial: u32) -> Self {
let notifier = Arc::new(Notify::new());
Self {
value: Arc::new(Mutex::new(initial)),
notifier,
}
}

pub(crate) async fn update(&self, value: u32) {
*self.value.lock().await = value;
self.notifier.notify_one();
}

pub(crate) fn subscribe(&self) -> Arc<Notify> {
Arc::clone(&self.notifier)
}
}

/// A handle to a session channel.
///
/// Allows you to read and write from a channel without borrowing the session
Expand All @@ -120,7 +145,7 @@ pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
pub(crate) sender: Sender<Send>,
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
pub(crate) max_packet_size: u32,
pub(crate) window_size: Arc<Mutex<u32>>,
pub(crate) window_size: WindowSizeRef,
}

impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
Expand All @@ -137,7 +162,7 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
window_size: u32,
) -> (Self, ChannelRef) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let window_size = Arc::new(Mutex::new(window_size));
let window_size = WindowSizeRef::new(window_size);

(
Self {
Expand All @@ -157,7 +182,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
/// Returns the min between the maximum packet size and the
/// remaining window size in the channel.
pub async fn writable_packet_size(&self) -> usize {
self.max_packet_size.min(*self.window_size.lock().await) as usize
self.max_packet_size
.min(*self.window_size.value.lock().await) as usize
}

pub fn id(&self) -> ChannelId {
Expand Down Expand Up @@ -337,7 +363,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
io::ChannelTx::new(
self.sender.clone(),
self.id,
self.window_size.clone(),
self.window_size.value.clone(),
self.window_size.subscribe(),
self.max_packet_size,
None,
),
Expand Down Expand Up @@ -369,7 +396,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
io::ChannelTx::new(
self.sender.clone(),
self.id,
self.window_size.clone(),
self.window_size.value.clone(),
self.window_size.subscribe(),
self.max_packet_size,
ext,
)
Expand Down
2 changes: 1 addition & 1 deletion russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ impl Session {
new_size -= enc.flush_pending(channel_num)? as u32;
}
if let Some(chan) = self.channels.get(&channel_num) {
*chan.window_size().lock().await = new_size;
chan.window_size().update(new_size).await;

let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
}
Expand Down
8 changes: 4 additions & 4 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ use tokio::pin;
use tokio::sync::mpsc::{
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
};
use tokio::sync::{oneshot, Mutex};
use tokio::sync::oneshot;

use crate::channels::{Channel, ChannelMsg, ChannelRef};
use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSizeRef};
use crate::cipher::{self, clear, CipherPair, OpeningKey};
use crate::keys::key::parse_public_key;
use crate::session::{
Expand Down Expand Up @@ -428,7 +428,7 @@ impl<H: Handler> Handle<H> {
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
window_size_ref: Arc<Mutex<u32>>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, crate::Error> {
loop {
match receiver.recv().await {
Expand All @@ -437,7 +437,7 @@ impl<H: Handler> Handle<H> {
max_packet_size,
window_size,
}) => {
*window_size_ref.lock().await = window_size;
window_size_ref.update(window_size).await;

return Ok(Channel {
id,
Expand Down
2 changes: 1 addition & 1 deletion russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ impl Session {
enc.flush_pending(channel_num)?;
}
if let Some(chan) = self.channels.get(&channel_num) {
*chan.window_size().lock().await = new_size;
chan.window_size().update(new_size).await;

chan.send(ChannelMsg::WindowAdjusted { new_size })
.unwrap_or(())
Expand Down
7 changes: 4 additions & 3 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use channels::WindowSizeRef;
use log::debug;
use negotiation::parse_kex_algo_list;
use russh_keys::helpers::NameList;
use russh_keys::map_err;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver};
use tokio::sync::{oneshot, Mutex};
use tokio::sync::oneshot;

use super::*;
use crate::channels::{Channel, ChannelMsg, ChannelRef};
Expand Down Expand Up @@ -346,7 +347,7 @@ impl Handle {
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
window_size_ref: Arc<Mutex<u32>>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, Error> {
loop {
match receiver.recv().await {
Expand All @@ -355,7 +356,7 @@ impl Handle {
max_packet_size,
window_size,
}) => {
*window_size_ref.lock().await = window_size;
window_size_ref.update(window_size).await;

return Ok(Channel {
id,
Expand Down
7 changes: 0 additions & 7 deletions russh/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,6 @@ impl Encrypted {
Ok(())
}

/*
pub fn authenticated(&mut self) {
self.server_compression.init_compress(&mut self.compress);
self.state = EncryptedState::Authenticated;
}
*/

pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> {
if let Some(channel) = self.has_pending_data_mut(channel) {
channel.pending_eof = true;
Expand Down
Loading

0 comments on commit a5c4adc

Please sign in to comment.