Skip to content

Commit

Permalink
Merge pull request #143 from MathiasKoch/enhancement/cancel-safe-read
Browse files Browse the repository at this point in the history
enhancement(async): Make RecordReaders read fn cancel-safe
  • Loading branch information
lulf authored Jun 6, 2024
2 parents f48952f + ff80fdc commit 9de49c2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 53 deletions.
4 changes: 2 additions & 2 deletions src/asynch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ where
delegate: Socket,
opened: bool,
key_schedule: KeySchedule<CipherSuite>,
record_reader: RecordReader<'a, CipherSuite>,
record_reader: RecordReader<'a>,
record_write_buf: WriteBuffer<'a>,
decrypted: DecryptedBufferInfo,
}
Expand Down Expand Up @@ -365,7 +365,7 @@ where
state: State,
delegate: Socket,
key_schedule: ReadKeySchedule<CipherSuite>,
record_reader: RecordReader<'a, CipherSuite>,
record_reader: RecordReader<'a>,
decrypted: DecryptedBufferInfo,
}

Expand Down
4 changes: 2 additions & 2 deletions src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ where
delegate: Socket,
opened: bool,
key_schedule: KeySchedule<CipherSuite>,
record_reader: RecordReader<'a, CipherSuite>,
record_reader: RecordReader<'a>,
record_write_buf: WriteBuffer<'a>,
decrypted: DecryptedBufferInfo,
}
Expand Down Expand Up @@ -356,7 +356,7 @@ where
state: State,
delegate: Socket,
key_schedule: ReadKeySchedule<CipherSuite>,
record_reader: RecordReader<'a, CipherSuite>,
record_reader: RecordReader<'a>,
decrypted: DecryptedBufferInfo,
}

Expand Down
4 changes: 2 additions & 2 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<'a> State {
self,
transport: &mut Transport,
handshake: &mut Handshake<Provider::CipherSuite>,
record_reader: &mut RecordReader<'_, Provider::CipherSuite>,
record_reader: &mut RecordReader<'_>,
tx_buf: &mut WriteBuffer<'_>,
key_schedule: &mut KeySchedule<Provider::CipherSuite>,
config: &TlsConfig<'a>,
Expand Down Expand Up @@ -237,7 +237,7 @@ impl<'a> State {
self,
transport: &mut Transport,
handshake: &mut Handshake<Provider::CipherSuite>,
record_reader: &mut RecordReader<'_, Provider::CipherSuite>,
record_reader: &mut RecordReader<'_>,
tx_buf: &mut WriteBuffer,
key_schedule: &mut KeySchedule<Provider::CipherSuite>,
config: &TlsConfig<'a>,
Expand Down
2 changes: 2 additions & 0 deletions src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ pub struct RecordHeader {
}

impl RecordHeader {
pub const LEN: usize = 5;

pub fn content_type(&self) -> ContentType {
// Content type already validated in read
unwrap!(ContentType::of(self.header[0]))
Expand Down
92 changes: 45 additions & 47 deletions src/record_reader.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use core::marker::PhantomData;

use crate::key_schedule::ReadKeySchedule;
use embedded_io::{Error, Read as BlockingRead};
use embedded_io_async::Read as AsyncRead;
Expand All @@ -10,22 +8,15 @@ use crate::{
TlsError,
};

pub struct RecordReader<'a, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub struct RecordReader<'a> {
pub(crate) buf: &'a mut [u8],
/// The number of decoded bytes in the buffer
decoded: usize,
/// The number of read but not yet decoded bytes in the buffer
pending: usize,
cipher_suite: PhantomData<CipherSuite>,
}

impl<'a, CipherSuite> RecordReader<'a, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
impl<'a> RecordReader<'a> {
pub fn new(buf: &'a mut [u8]) -> Self {
if buf.len() < 16640 {
warn!("Read buffer is smaller than 16640 bytes, which may cause problems!");
Expand All @@ -34,33 +25,26 @@ where
buf,
decoded: 0,
pending: 0,
cipher_suite: PhantomData,
}
}

pub async fn read<'m>(
pub async fn read<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl AsyncRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
let header = self.advance(transport, 5).await?;
let header = RecordHeader::decode(unwrap!(header.try_into().ok()))?;

let content_length = header.content_length();
debug!(
"advance: {:?} - content_length = {} bytes",
header.content_type(),
content_length
);
let data = self.advance(transport, content_length).await?;
ServerRecord::decode(header, data, key_schedule.transcript_hash())
self.advance(transport, RecordHeader::LEN).await?;
let header = self.record_header()?;
self.advance(transport, RecordHeader::LEN + header.content_length())
.await?;
self.consume(header, key_schedule.transcript_hash())
}

async fn advance<'m>(
&'m mut self,
transport: &mut impl AsyncRead,
amount: usize,
) -> Result<&'m mut [u8], TlsError> {
) -> Result<(), TlsError> {
self.ensure_contiguous(amount)?;

while self.pending < amount {
Expand All @@ -74,27 +58,25 @@ where
self.pending += read;
}

Ok(self.consume(amount))
Ok(())
}

pub fn read_blocking<'m>(
pub fn read_blocking<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl BlockingRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
let header = self.advance_blocking(transport, 5)?;
let header = RecordHeader::decode(unwrap!(header.try_into().ok()))?;

let content_length = header.content_length();
let data = self.advance_blocking(transport, content_length)?;
ServerRecord::decode(header, data, key_schedule.transcript_hash())
self.advance_blocking(transport, RecordHeader::LEN)?;
let header = self.record_header()?;
self.advance_blocking(transport, RecordHeader::LEN + header.content_length())?;
self.consume(header, key_schedule.transcript_hash())
}

fn advance_blocking<'m>(
&'m mut self,
transport: &mut impl BlockingRead,
amount: usize,
) -> Result<&'m mut [u8], TlsError> {
) -> Result<(), TlsError> {
self.ensure_contiguous(amount)?;

while self.pending < amount {
Expand All @@ -107,14 +89,30 @@ where
self.pending += read;
}

Ok(self.consume(amount))
Ok(())
}

fn consume(&mut self, amount: usize) -> &mut [u8] {
let slice = &mut self.buf[self.decoded..self.decoded + amount];
self.decoded += amount;
self.pending -= amount;
slice
fn record_header(&self) -> Result<RecordHeader, TlsError> {
RecordHeader::decode(unwrap!(self.buf
[self.decoded..self.decoded + RecordHeader::LEN]
.try_into()
.ok()))
}

fn consume<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
header: RecordHeader,
digest: &mut CipherSuite::Hash,
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
let content_len = header.content_length();

let slice = &mut self.buf
[self.decoded + RecordHeader::LEN..self.decoded + RecordHeader::LEN + content_len];

self.decoded += RecordHeader::LEN + content_len;
self.pending -= RecordHeader::LEN + content_len;

ServerRecord::decode(header, slice, digest)
}

fn ensure_contiguous(&mut self, len: usize) -> Result<(), TlsError> {
Expand Down Expand Up @@ -207,7 +205,7 @@ mod tests {
);

let mut buf = [0; 32];
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();

{
Expand Down Expand Up @@ -265,8 +263,8 @@ mod tests {
]
.as_slice();

let mut buf = [0; 5]; // This buffer is so small that it cannot contain both the header and data
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
let mut buf = [0; 9]; // This buffer is so small that it cannot contain both the header and data
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();

{
Expand All @@ -279,8 +277,8 @@ mod tests {
panic!("Wrong server record");
}

assert_eq!(4, reader.decoded); // The buffer is rotated after decoding the header
assert_eq!(1, reader.pending);
assert_eq!(9, reader.decoded); // The buffer is rotated after decoding the header
assert_eq!(0, reader.pending);
}

{
Expand All @@ -293,7 +291,7 @@ mod tests {
panic!("Wrong server record");
}

assert_eq!(2, reader.decoded);
assert_eq!(7, reader.decoded);
assert_eq!(0, reader.pending);
}
}
Expand All @@ -318,7 +316,7 @@ mod tests {
.as_slice();

let mut buf = [0; 32];
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();

{
Expand Down

0 comments on commit 9de49c2

Please sign in to comment.