Skip to content

Commit

Permalink
Improve keepalive and inactivity timers
Browse files Browse the repository at this point in the history
* Add an analogue of OpenSSH's `ServerAliveCountMax`.
* Use disjunctive futures for cleanly making these timers optional.
* Use the `Session` to pass information back to the main bg loop from
  the plaintext packet reader, so that only nontrivial data transfer
  will reset the inactivity timer. (And so that `ServerAliveCountMax`
  will be judged correctly.)
  • Loading branch information
mmirate committed Nov 13, 2023
1 parent 43bdc07 commit 50f0f1a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 26 deletions.
13 changes: 13 additions & 0 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ impl Session {
);
match req {
b"xon-xoff" => {
self.activity = false;
r.read_byte().map_err(crate::Error::from)?; // should be 0.
let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0;
if let Some(chan) = self.channels.get(&channel_num) {
Expand Down Expand Up @@ -572,6 +573,7 @@ impl Session {
.await
}
b"keepalive@openssh.com" => {
self.activity = false;
let wants_reply = r.read_byte().map_err(crate::Error::from)?;
if wants_reply == 1 {
if let Some(ref mut enc) = self.common.encrypted {
Expand All @@ -591,6 +593,7 @@ impl Session {
Ok((client, self))
}
_ => {
self.activity = false;
let wants_reply = r.read_byte().map_err(crate::Error::from)?;
if wants_reply == 1 {
if let Some(ref mut enc) = self.common.encrypted {
Expand Down Expand Up @@ -690,6 +693,7 @@ impl Session {
push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE))
}
}
self.activity = false;
Ok((client, self))
}
Some(&msg::CHANNEL_SUCCESS) => {
Expand Down Expand Up @@ -802,7 +806,16 @@ impl Session {
Err(crate::Error::Inconsistent.into())
}
}
Some(&msg::REQUEST_SUCCESS | &msg::REQUEST_FAILURE)
if self.server_alive_timeouts > 0 =>
{
self.activity = false;
// TODO what other things might need to happen in response to these two opcodes?
self.server_alive_timeouts = 0;
Ok((client, self))
}
_ => {
self.activity = false;
info!("Unhandled packet: {:?}", buf);
Ok((client, self))
}
Expand Down
82 changes: 57 additions & 25 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ use crate::key::PubKey;
use crate::session::{CommonSession, EncryptedState, Exchange, Kex, KexDhDone, KexInit, NewKeys};
use crate::ssh_read::SshRead;
use crate::sshbuffer::{SSHBuffer, SshId};
use crate::{
auth, msg, negotiation, timeout, ChannelId, ChannelOpenFailure, Disconnect, Limits, Sig,
};
use crate::{auth, msg, negotiation, ChannelId, ChannelOpenFailure, Disconnect, Limits, Sig};

mod encrypted;
mod kex;
Expand All @@ -126,6 +124,8 @@ pub struct Session {
pending_len: u32,
inbound_channel_sender: Sender<Msg>,
inbound_channel_receiver: Receiver<Msg>,
server_alive_timeouts: usize,
activity: bool,
}

impl Drop for Session {
Expand Down Expand Up @@ -723,6 +723,16 @@ async fn start_reading<R: AsyncRead + Unpin>(
Ok((n, stream_read, buffer, cipher))
}

fn future_or_pending<F: futures::Future, T>(
val: Option<T>,
f: impl FnOnce(T) -> F,
) -> futures::future::Either<futures::future::Pending<<F as futures::Future>::Output>, F> {
val.map_or(
futures::future::Either::Left(futures::future::pending()),
|x| futures::future::Either::Right(f(x)),
)
}

impl Session {
fn new(
target_window_size: u32,
Expand All @@ -741,6 +751,8 @@ impl Session {
channels: HashMap::new(),
pending_reads: Vec::new(),
pending_len: 0,
server_alive_timeouts: 0,
activity: false,
}
}

Expand Down Expand Up @@ -769,19 +781,32 @@ impl Session {
let mut opening_cipher = Box::new(clear::Key) as Box<dyn OpeningKey + Send>;
std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);

let time_for_keepalive = tokio::time::sleep_until(self.common.config.keepalive_deadline());
let keepalive_timer =
future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep);
pin!(keepalive_timer);

let inactivity_timer =
future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep);
pin!(inactivity_timer);

let reading = start_reading(stream_read, buffer, opening_cipher);
pin!(reading);
pin!(time_for_keepalive);

let delay = self.common.config.inactivity_timeout;

#[allow(clippy::panic)] // false positive in select! macro
while !self.common.disconnected {
self.activity = false;
tokio::select! {
() = &mut time_for_keepalive => {
time_for_keepalive.as_mut().reset(self.common.config.keepalive_deadline());
() = &mut keepalive_timer => {
self.send_keepalive(true);
if self.common.config.keepalive_max != 0 && self.server_alive_timeouts > self.common.config.keepalive_max {
debug!("Timeout, server not responding to keepalives");
break
}
self.server_alive_timeouts = self.server_alive_timeouts.saturating_add(1);
}
() = &mut inactivity_timer => {
debug!("timeout");
break
}
r = &mut reading => {
let (stream_read, buffer, mut opening_cipher) = match r {
Expand Down Expand Up @@ -814,6 +839,7 @@ impl Session {
if buf[0] == crate::msg::DISCONNECT {
break;
} else if buf[0] > 4 {
self.activity = true;
let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?;
handler = h;
self = s;
Expand All @@ -822,7 +848,6 @@ impl Session {

std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);
reading.set(start_reading(stream_read, buffer, opening_cipher));
time_for_keepalive.as_mut().reset(self.common.config.keepalive_deadline());
}
msg = self.receiver.recv(), if !self.is_rekeying() => {
match msg {
Expand All @@ -840,7 +865,6 @@ impl Session {
Err(_) => break
}
}
time_for_keepalive.as_mut().reset(self.common.config.keepalive_deadline());
}
msg = self.inbound_channel_receiver.recv(), if !self.is_rekeying() => {
match msg {
Expand All @@ -856,17 +880,15 @@ impl Session {
}
}
}
_ = timeout(delay) => {
debug!("timeout");
break
},
}
};

self.flush()?;
if !self.common.write_buffer.buffer.is_empty() {
trace!(
"writing to stream: {:?} bytes",
self.common.write_buffer.buffer.len()
);
self.activity = true;
stream_write
.write_all(&self.common.write_buffer.buffer)
.await
Expand All @@ -880,6 +902,22 @@ impl Session {
enc.state = EncryptedState::Authenticated;
}
}

if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
keepalive_timer.as_mut().as_pin_mut(),
self.common.config.keepalive_interval,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
}

if self.activity {
if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
inactivity_timer.as_mut().as_pin_mut(),
self.common.config.inactivity_timeout,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
}
}
}
debug!("disconnected");
self.receiver.close();
Expand Down Expand Up @@ -1279,19 +1317,12 @@ pub struct Config {
pub inactivity_timeout: Option<std::time::Duration>,
/// If nothing is sent or received for this amount of time, send a keepalive message.
pub keepalive_interval: Option<std::time::Duration>,
/// If this many keepalives have been sent without reply, close the connection.
pub keepalive_max: usize,
/// Whether to expect and wait for an authentication call.
pub anonymous: bool,
}

impl Config {
fn keepalive_deadline(&self) -> tokio::time::Instant {
tokio::time::Instant::now()
+ self
.keepalive_interval
.unwrap_or(std::time::Duration::from_secs(86400 * 365))
}
}

impl Default for Config {
fn default() -> Config {
Config {
Expand All @@ -1306,6 +1337,7 @@ impl Default for Config {
preferred: Default::default(),
inactivity_timeout: None,
keepalive_interval: None,
keepalive_max: 3,
anonymous: false,
}
}
Expand Down
2 changes: 1 addition & 1 deletion russh/src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl Session {
if let Some(ref mut enc) = self.common.encrypted {
push_packet!(enc.write, {
enc.write.push(msg::GLOBAL_REQUEST);
enc.write.extend_ssh_string(b"keepalive@libssh2.org");
enc.write.extend_ssh_string(b"keepalive@openssh.org");
enc.write.push(want_reply as u8);
});
}
Expand Down

0 comments on commit 50f0f1a

Please sign in to comment.