Skip to content

Commit

Permalink
Remove generics from GroupStateStorage (#94)
Browse files Browse the repository at this point in the history
* Remove generics from GroupStateStorage

* wip

* Fixup

* Revert "wip"

This reverts commit edba2f3.

---------

Co-authored-by: Marta Mularczyk <mulmarta@amazon.com>
  • Loading branch information
mulmarta and Marta Mularczyk authored Mar 5, 2024
1 parent 1daa244 commit 2f7d6d2
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 363 deletions.
60 changes: 41 additions & 19 deletions mls-rs-core/src/group/group_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,51 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use core::fmt::{self, Debug};

use crate::error::IntoAnyError;
#[cfg(mls_build_async)]
use alloc::boxed::Box;
use alloc::vec::Vec;
use mls_rs_codec::{MlsDecode, MlsEncode};

/// Generic representation of a group's state.
pub trait GroupState {
#[derive(Clone, PartialEq, Eq)]
pub struct GroupState {
/// A unique group identifier.
fn id(&self) -> Vec<u8>;
pub id: Vec<u8>,
pub data: Vec<u8>,
}

impl Debug for GroupState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GroupState")
.field("id", &crate::debug::pretty_bytes(&self.id))
.field("data", &crate::debug::pretty_bytes(&self.data))
.finish()
}
}

/// Generic representation of a prior epoch.
pub trait EpochRecord {
#[derive(Clone, PartialEq, Eq)]
pub struct EpochRecord {
/// A unique epoch identifier within a particular group.
fn id(&self) -> u64;
pub id: u64,
pub data: Vec<u8>,
}

impl Debug for EpochRecord {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("EpochRecord")
.field("id", &self.id)
.field("data", &crate::debug::pretty_bytes(&self.data))
.finish()
}
}

impl EpochRecord {
pub fn new(id: u64, data: Vec<u8>) -> Self {
Self { id, data }
}
}

/// Storage that can persist and reload a group state.
Expand All @@ -41,14 +70,10 @@ pub trait GroupStateStorage: Send + Sync {
type Error: IntoAnyError;

/// Fetch a group state from storage.
async fn state<T>(&self, group_id: &[u8]) -> Result<Option<T>, Self::Error>
where
T: GroupState + MlsEncode + MlsDecode;
async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;

/// Lazy load cached epoch data from a particular group.
async fn epoch<T>(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<T>, Self::Error>
where
T: EpochRecord + MlsEncode + MlsDecode;
async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error>;

/// Write pending state updates.
///
Expand All @@ -69,15 +94,12 @@ pub trait GroupStateStorage: Send + Sync {
/// of this trait. Calls to [`write`](GroupStateStorage::write) should
/// optimally be a single atomic transaction in order to avoid partial writes
/// that may corrupt the group state.
async fn write<ST, ET>(
async fn write(
&mut self,
state: ST,
epoch_inserts: Vec<ET>,
epoch_updates: Vec<ET>,
) -> Result<(), Self::Error>
where
ST: GroupState + MlsEncode + MlsDecode + Send + Sync,
ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync;
state: GroupState,
epoch_inserts: Vec<EpochRecord>,
epoch_updates: Vec<EpochRecord>,
) -> Result<(), Self::Error>;

/// The [`EpochRecord::id`] value that is associated with a stored
/// prior epoch for a particular group.
Expand Down
147 changes: 38 additions & 109 deletions mls-rs-provider-sqlite/src/group_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,17 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use mls_rs_core::{
group::{EpochRecord, GroupState, GroupStateStorage},
mls_rs_codec::{MlsDecode, MlsEncode},
};
use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
use rusqlite::{params, Connection, OptionalExtension};
use std::{
fmt::{self, Debug},
fmt::Debug,
sync::{Arc, Mutex},
};

use crate::SqLiteDataStorageError;

pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: u64 = 3;

#[derive(Clone)]
struct StoredEpoch {
data: Vec<u8>,
id: u64,
}

impl Debug for StoredEpoch {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StoredEpoch")
.field("data", &mls_rs_core::debug::pretty_bytes(&self.data))
.field("id", &self.id)
.finish()
}
}

impl StoredEpoch {
fn new(id: u64, data: Vec<u8>) -> Self {
Self { id, data }
}
}

#[derive(Debug, Clone)]
/// SQLite Storage for MLS group states.
pub struct SqLiteGroupStateStorage {
Expand Down Expand Up @@ -141,17 +117,13 @@ impl SqLiteGroupStateStorage {
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
}

fn update_group_state<I, U>(
fn update_group_state(
&self,
group_id: &[u8],
group_snapshot: Vec<u8>,
inserts: I,
mut updates: U,
) -> Result<(), SqLiteDataStorageError>
where
I: Iterator<Item = Result<StoredEpoch, SqLiteDataStorageError>>,
U: Iterator<Item = Result<StoredEpoch, SqLiteDataStorageError>>,
{
inserts: Vec<EpochRecord>,
updates: Vec<EpochRecord>,
) -> Result<(), SqLiteDataStorageError> {
let mut max_epoch_id = None;

let mut connection = self.connection.lock().unwrap();
Expand All @@ -167,7 +139,6 @@ impl SqLiteGroupStateStorage {

// Insert new epochs as needed
for epoch in inserts {
let epoch = epoch.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
max_epoch_id = Some(epoch.id);

transaction
Expand All @@ -180,9 +151,7 @@ impl SqLiteGroupStateStorage {
}

// Update existing epochs as needed
updates.try_for_each(|epoch| {
let epoch = epoch.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;

updates.into_iter().try_for_each(|epoch| {
transaction
.execute(
"UPDATE epoch SET epoch_data = ? WHERE group_id = ? AND epoch_id = ?",
Expand Down Expand Up @@ -218,63 +187,28 @@ impl SqLiteGroupStateStorage {
impl GroupStateStorage for SqLiteGroupStateStorage {
type Error = SqLiteDataStorageError;

async fn write<ST, ET>(
async fn write(
&mut self,
state: ST,
epoch_inserts: Vec<ET>,
epoch_updates: Vec<ET>,
) -> Result<(), Self::Error>
where
ST: GroupState + MlsEncode + MlsDecode + Send + Sync,
ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync,
{
let group_id = state.id();

let snapshot_data = state
.mls_encode_to_vec()
.map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?;

let inserts = epoch_inserts.iter().map(|e| {
Ok(StoredEpoch::new(
e.id(),
e.mls_encode_to_vec()
.map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?,
))
});

let updates = epoch_updates.iter().map(|e| {
Ok(StoredEpoch::new(
e.id(),
e.mls_encode_to_vec()
.map_err(|err| SqLiteDataStorageError::DataConversionError(err.into()))?,
))
});

self.update_group_state(group_id.as_slice(), snapshot_data, inserts, updates)
state: GroupState,
inserts: Vec<EpochRecord>,
updates: Vec<EpochRecord>,
) -> Result<(), Self::Error> {
let group_id = state.id;
let snapshot_data = state.data;

self.update_group_state(&group_id, snapshot_data, inserts, updates)
}

async fn state<T>(&self, group_id: &[u8]) -> Result<Option<T>, Self::Error>
where
T: GroupState + MlsEncode + MlsDecode,
{
self.get_snapshot_data(group_id)?
.map(|v| T::mls_decode(&mut v.as_slice()))
.transpose()
.map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))
async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
self.get_snapshot_data(group_id)
}

async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
self.max_epoch_id(group_id)
}

async fn epoch<T>(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<T>, Self::Error>
where
T: EpochRecord + MlsEncode + MlsDecode,
{
self.get_epoch_data(group_id, epoch_id)?
.map(|v| T::mls_decode(&mut v.as_slice()))
.transpose()
.map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))
async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
self.get_epoch_data(group_id, epoch_id)
}
}

Expand Down Expand Up @@ -302,8 +236,8 @@ mod tests {
gen_rand_bytes(1024)
}

fn test_epoch(id: u64) -> StoredEpoch {
StoredEpoch {
fn test_epoch(id: u64) -> EpochRecord {
EpochRecord {
data: gen_rand_bytes(256),
id,
}
Expand All @@ -313,7 +247,7 @@ mod tests {
storage: SqLiteGroupStateStorage,
snapshot: Vec<u8>,
group_id: Vec<u8>,
epoch_0: StoredEpoch,
epoch_0: EpochRecord,
}

fn setup_group_storage_test() -> TestData {
Expand All @@ -326,8 +260,8 @@ mod tests {
.update_group_state(
&test_group_id,
test_snapshot.clone(),
vec![test_epoch_0.clone()].into_iter().map(Ok),
vec![].into_iter(),
vec![test_epoch_0.clone()],
vec![],
)
.unwrap();

Expand Down Expand Up @@ -370,8 +304,8 @@ mod tests {
.update_group_state(
&test_data.group_id,
test_snapshot.clone(),
vec![].into_iter(),
vec![Ok(epoch_update.clone())].into_iter(),
vec![],
vec![epoch_update.clone()],
)
.unwrap();

Expand Down Expand Up @@ -410,8 +344,8 @@ mod tests {
.update_group_state(
&test_data.group_id,
test_snapshot(),
test_epochs.clone().into_iter().map(Ok),
vec![].into_iter(),
test_epochs.clone(),
vec![],
)
.unwrap();

Expand Down Expand Up @@ -440,8 +374,8 @@ mod tests {
.update_group_state(
&test_data.group_id,
test_snapshot(),
vec![test_epoch(1)].into_iter().map(Ok),
vec![].into_iter(),
vec![test_epoch(1)],
vec![],
)
.unwrap();

Expand All @@ -453,8 +387,8 @@ mod tests {
.update_group_state(
&test_data.group_id,
test_snapshot(),
test_epochs.clone().into_iter().map(Ok),
vec![Ok(new_epoch_1.clone())].into_iter(),
test_epochs.clone(),
vec![new_epoch_1.clone()],
)
.unwrap();

Expand All @@ -480,12 +414,7 @@ mod tests {
let group_id = b"test";

storage
.update_group_state(
group_id,
vec![0, 1, 2],
vec![].into_iter(),
vec![].into_iter().map(Ok),
)
.update_group_state(group_id, vec![0, 1, 2], vec![], vec![])
.unwrap();

let res = storage.max_epoch_id(group_id).unwrap();
Expand All @@ -502,8 +431,8 @@ mod tests {
.update_group_state(
&test_data.group_id,
test_snapshot(),
(1..10).map(test_epoch).map(Ok),
vec![].into_iter().map(Ok),
(1..10).map(test_epoch).collect(),
vec![],
)
.unwrap();

Expand All @@ -529,8 +458,8 @@ mod tests {
.update_group_state(
&new_group,
test_snapshot(),
vec![new_group_epoch.clone()].into_iter().map(Ok),
vec![].into_iter(),
vec![new_group_epoch.clone()],
vec![],
)
.unwrap();

Expand Down
4 changes: 2 additions & 2 deletions mls-rs-uniffi/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use mls_rs::{
use mls_rs_core::error::IntoAnyError;
use mls_rs_crypto_openssl::OpensslCryptoProvider;

use self::group_state::GroupStateStorageWrapper;
use self::group_state::{GroupStateStorage, GroupStateStorageWrapper};

mod group_state;

Expand Down Expand Up @@ -39,5 +39,5 @@ pub type UniFFIConfig = client_builder::WithIdentityProvider<

#[derive(Debug, Clone, uniffi::Record)]
pub struct ClientConfig {
pub group_state_storage: Arc<dyn group_state::GroupStateStorage>,
pub group_state_storage: Arc<dyn GroupStateStorage>,
}
Loading

0 comments on commit 2f7d6d2

Please sign in to comment.