Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize and simplify BitReaderReversed #81

Merged
merged 4 commits into from
Dec 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 93 additions & 167 deletions src/decoding/bit_reader_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,223 +1,149 @@
use crate::io::Read;
use core::convert::TryInto;

/// Zstandard encodes some types of data in a way that the data must be read
/// back to front to decode it properly. `BitReaderReversed` provides a
/// convenient interface to do that.
pub struct BitReaderReversed<'s> {
idx: isize, //index counts bits already read
/// The index of the last read byte in the source.
index: usize,

/// How many bits have been consumed from `bit_container`.
bits_consumed: u8,

/// How many bits have been consumed past the end of the input. Will be zero until all the input
/// has been read.
extra_bits: usize,

/// The source data to read from.
source: &'s [u8],
/// The reader doesn't read directly from the source,
/// it reads bits from here, and the container is
/// "refilled" as it's emptied.

/// The reader doesn't read directly from the source, it reads bits from here, and the container
/// is "refilled" as it's emptied.
bit_container: u64,
bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
/// How many bits are left to read by the reader.
pub fn bits_remaining(&self) -> isize {
self.idx + self.bits_in_container as isize
self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
}

pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
BitReaderReversed {
idx: source.len() as isize * 8,
index: source.len(),
bits_consumed: 64,
source,
bit_container: 0,
bits_in_container: 0,
extra_bits: 0,
}
}

/// We refill the container in full bytes, shifting the still unread portion to the left, and filling the lower bits with new data
#[inline(always)]
fn refill_container(&mut self) {
let byte_idx = self.byte_idx() as usize;

let retain_bytes = (self.bits_in_container + 7) / 8;
let want_to_read_bits = 64 - (retain_bytes * 8);

// If there are >= 8 bytes left to read we take a fast path:
// The slice looks something like this: |U..UCCCCCCCCR..R| where U are unread bytes, C are the bytes in the container, and R are already-read bytes.
// We shift the container a few bytes to the left by simply reading a u64 from the correct position, rereading the portion we have not yet returned from the container.
// Technically this would still work for positions lower than 8, but this guarantees that enough bytes are in the source and generally makes for fewer edge cases.
if byte_idx >= 8 {
self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
} else {
// In the slow path we just read however many bytes we can
self.refill_slow(byte_idx, want_to_read_bits)
#[cold]
fn refill(&mut self) {
let bytes_consumed = self.bits_consumed as usize / 8;
if bytes_consumed == 0 {
return;
}
}

#[inline(always)]
fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
.try_into()
.unwrap();
let refill = u64::from_le_bytes(tmp_bytes);
self.bit_container = refill;
self.bits_in_container += want_to_read_bits;
self.idx -= want_to_read_bits as isize;
}

#[cold]
fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
let can_read_bytes = can_read_bits / 8;
let mut tmp_bytes = [0u8; 8];
let offset @ 1..=8 = can_read_bytes as usize else {
unreachable!()
};
let bits_read = offset * 8;

let _ = (&self.source[byte_idx - (offset - 1)..]).read_exact(&mut tmp_bytes[0..offset]);
self.bits_in_container += bits_read as u8;
self.idx -= bits_read as isize;
if offset < 8 {
self.bit_container <<= bits_read;
self.bit_container |= u64::from_le_bytes(tmp_bytes);
if self.index >= bytes_consumed {
self.index -= bytes_consumed;
self.bits_consumed &= 7;
self.bit_container =
u64::from_le_bytes((&self.source[self.index..][..8]).try_into().unwrap());
} else if self.index > 0 {
if self.source.len() >= 8 {
self.bit_container = u64::from_le_bytes((&self.source[..8]).try_into().unwrap());
} else {
let mut value = [0; 8];
value[..self.source.len()].copy_from_slice(self.source);
self.bit_container = u64::from_le_bytes(value);
}

self.bits_consumed -= 8 * self.index as u8;
self.index = 0;

self.bit_container <<= self.bits_consumed;
self.extra_bits += self.bits_consumed as usize;
self.bits_consumed = 0;
} else if self.bits_consumed < 64 {
self.bit_container <<= self.bits_consumed;
self.extra_bits += self.bits_consumed as usize;
self.bits_consumed = 0;
} else {
self.bit_container = u64::from_le_bytes(tmp_bytes);
self.extra_bits += self.bits_consumed as usize;
self.bits_consumed = 0;
self.bit_container = 0;
}
}

/// Next byte that should be read into the container
/// Negative values mean that the source buffer has been read into the container completely.
fn byte_idx(&self) -> isize {
(self.idx - 1) / 8
// Assert that at least `56 = 64 - 8` bits are available to read.
debug_assert!(self.bits_consumed < 8);
}

/// Read `n` number of bits from the source. Will read at most 56 bits.
/// If there are no more bits to be read from the source zero bits will be returned instead.
#[inline(always)]
pub fn get_bits(&mut self, n: u8) -> u64 {
if n == 0 {
return 0;
}
if self.bits_in_container >= n {
return self.get_bits_unchecked(n);
if self.bits_consumed + n > 64 {
self.refill();
}

self.get_bits_cold(n)
let value = self.peek_bits(n);
self.consume(n);
value
}

#[cold]
fn get_bits_cold(&mut self, n: u8) -> u64 {
let n = u8::min(n, 56);
let signed_n = n as isize;

if self.bits_remaining() <= 0 {
self.idx -= signed_n;
/// Get the next `n` bits from the source without consuming them.
#[inline(always)]
pub fn peek_bits(&mut self, n: u8) -> u64 {
if n == 0 {
return 0;
}

if self.bits_remaining() < signed_n {
let emulated_read_shift = signed_n - self.bits_remaining();
let v = self.get_bits(self.bits_remaining() as u8);
debug_assert!(self.idx == 0);
let value = v.wrapping_shl(emulated_read_shift as u32);
self.idx -= emulated_read_shift;
return value;
}

while (self.bits_in_container < n) && self.idx > 0 {
self.refill_container();
}

debug_assert!(self.bits_in_container >= n);

//if we reach this point there are enough bits in the container
let mask = (1u64 << n) - 1u64;
let shift_by = 64 - self.bits_consumed - n;
(self.bit_container >> shift_by) & mask
}

self.get_bits_unchecked(n)
/// Consume `n` bits from the source.
///
/// This only advances the `bits_consumed` counter by `n`; it neither reads from
/// `bit_container` nor refills it.
#[inline(always)]
pub fn consume(&mut self, n: u8) {
self.bits_consumed += n;
// Checked in debug builds only: the counter must never exceed 64, the width in bits of the
// `u64` bit container.
debug_assert!(self.bits_consumed <= 64);
}

/// Same as calling get_bits three times but slightly more performant
#[inline(always)]
pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
let sum = n1 as usize + n2 as usize + n3 as usize;
if sum == 0 {
return (0, 0, 0);
}
if sum > 56 {
// try and get the values separately
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}
let sum = sum as u8;
let sum = n1 + n2 + n3;
if sum <= 56 {
self.refill();

if self.bits_in_container >= sum {
let v1 = if n1 == 0 {
0
} else {
self.get_bits_unchecked(n1)
};
let v2 = if n2 == 0 {
0
} else {
self.get_bits_unchecked(n2)
};
let v3 = if n3 == 0 {
0
} else {
self.get_bits_unchecked(n3)
};
let v1 = self.peek_bits(n1);
self.consume(n1);
let v2 = self.peek_bits(n2);
self.consume(n2);
let v3 = self.peek_bits(n3);
self.consume(n3);

return (v1, v2, v3);
}

self.get_bits_triple_cold(n1, n2, n3, sum)
(self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
}
}

#[cold]
fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
let sum_signed = sum as isize;

if self.bits_remaining() <= 0 {
self.idx -= sum_signed;
return (0, 0, 0);
}

if self.bits_remaining() < sum_signed {
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}

while (self.bits_in_container < sum) && self.idx > 0 {
self.refill_container();
}

debug_assert!(self.bits_in_container >= sum);

//if we reach this point there are enough bits in the container

let v1 = if n1 == 0 {
0
} else {
self.get_bits_unchecked(n1)
};
let v2 = if n2 == 0 {
0
} else {
self.get_bits_unchecked(n2)
};
let v3 = if n3 == 0 {
0
} else {
self.get_bits_unchecked(n3)
};

(v1, v2, v3)
}

#[inline(always)]
fn get_bits_unchecked(&mut self, n: u8) -> u64 {
let shift_by = self.bits_in_container - n;
let mask = (1u64 << n) - 1u64;

let value = self.bit_container >> shift_by;
self.bits_in_container -= n;
let value_masked = value & mask;
debug_assert!(value_masked < (1 << n));

value_masked
#[cfg(test)]
mod test {

/// Checks the bit order on a two-byte input: the expected values below correspond to reading
/// the last byte first, most significant bit first, then continuing into the preceding byte.
#[test]
fn it_works() {
let data = [0b10101010, 0b01010101];
let mut br = super::BitReaderReversed::new(&data);
// The top three bits of the last byte (0b01010101), one at a time.
assert_eq!(br.get_bits(1), 0);
assert_eq!(br.get_bits(1), 1);
assert_eq!(br.get_bits(1), 0);
// The next four bits of the last byte.
assert_eq!(br.get_bits(4), 0b1010);
// Crosses the byte boundary: the lowest bit of the last byte followed by the top three bits
// of the first byte (0b10101010).
assert_eq!(br.get_bits(4), 0b1101);
}
}
Loading