diff --git a/mls-rs-core/src/group/group_state.rs b/mls-rs-core/src/group/group_state.rs index 42ece05f..659901c0 100644 --- a/mls-rs-core/src/group/group_state.rs +++ b/mls-rs-core/src/group/group_state.rs @@ -2,22 +2,51 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) +use core::fmt::{self, Debug}; + use crate::error::IntoAnyError; #[cfg(mls_build_async)] use alloc::boxed::Box; use alloc::vec::Vec; -use mls_rs_codec::{MlsDecode, MlsEncode}; /// Generic representation of a group's state. -pub trait GroupState { +#[derive(Clone, PartialEq, Eq)] +pub struct GroupState { /// A unique group identifier. - fn id(&self) -> Vec; + pub id: Vec, + pub data: Vec, +} + +impl Debug for GroupState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GroupState") + .field("id", &crate::debug::pretty_bytes(&self.id)) + .field("data", &crate::debug::pretty_bytes(&self.data)) + .finish() + } } /// Generic representation of a prior epoch. -pub trait EpochRecord { +#[derive(Clone, PartialEq, Eq)] +pub struct EpochRecord { /// A unique epoch identifier within a particular group. - fn id(&self) -> u64; + pub id: u64, + pub data: Vec, +} + +impl Debug for EpochRecord { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EpochRecord") + .field("id", &self.id) + .field("data", &crate::debug::pretty_bytes(&self.data)) + .finish() + } +} + +impl EpochRecord { + pub fn new(id: u64, data: Vec) -> Self { + Self { id, data } + } } /// Storage that can persist and reload a group state. @@ -41,14 +70,10 @@ pub trait GroupStateStorage: Send + Sync { type Error: IntoAnyError; /// Fetch a group state from storage. - async fn state(&self, group_id: &[u8]) -> Result, Self::Error> - where - T: GroupState + MlsEncode + MlsDecode; + async fn state(&self, group_id: &[u8]) -> Result>, Self::Error>; /// Lazy load cached epoch data from a particular group. - async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> - where - T: EpochRecord + MlsEncode + MlsDecode; + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result>, Self::Error>; /// Write pending state updates. /// @@ -69,15 +94,12 @@ pub trait GroupStateStorage: Send + Sync { /// of this trait. Calls to [`write`](GroupStateStorage::write) should /// optimally be a single atomic transaction in order to avoid partial writes /// that may corrupt the group state. - async fn write( + async fn write( &mut self, - state: ST, - epoch_inserts: Vec, - epoch_updates: Vec, - ) -> Result<(), Self::Error> - where - ST: GroupState + MlsEncode + MlsDecode + Send + Sync, - ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync; + state: GroupState, + epoch_inserts: Vec, + epoch_updates: Vec, + ) -> Result<(), Self::Error>; /// The [`EpochRecord::id`] value that is associated with a stored /// prior epoch for a particular group. diff --git a/mls-rs-provider-sqlite/src/group_state.rs b/mls-rs-provider-sqlite/src/group_state.rs index 122c580d..62bd4b25 100644 --- a/mls-rs-provider-sqlite/src/group_state.rs +++ b/mls-rs-provider-sqlite/src/group_state.rs @@ -2,13 +2,10 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) -use mls_rs_core::{ - group::{EpochRecord, GroupState, GroupStateStorage}, - mls_rs_codec::{MlsDecode, MlsEncode}, -}; +use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; use rusqlite::{params, Connection, OptionalExtension}; use std::{ - fmt::{self, Debug}, + fmt::Debug, sync::{Arc, Mutex}, }; @@ -16,27 +13,6 @@ use crate::SqLiteDataStorageError; pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: u64 = 3; -#[derive(Clone)] -struct StoredEpoch { - data: Vec, - id: u64, -} - -impl Debug for StoredEpoch { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("StoredEpoch") - .field("data", &mls_rs_core::debug::pretty_bytes(&self.data)) - .field("id", &self.id) - .finish() - } -} - -impl StoredEpoch { - fn new(id: u64, data: Vec) -> Self { - Self { id, data } - } -} - #[derive(Debug, Clone)] /// SQLite Storage for MLS group states. pub struct SqLiteGroupStateStorage { @@ -141,17 +117,13 @@ impl SqLiteGroupStateStorage { .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } - fn update_group_state( + fn update_group_state( &self, group_id: &[u8], group_snapshot: Vec, - inserts: I, - mut updates: U, - ) -> Result<(), SqLiteDataStorageError> - where - I: Iterator>, - U: Iterator>, - { + inserts: Vec, + updates: Vec, + ) -> Result<(), SqLiteDataStorageError> { let mut max_epoch_id = None; let mut connection = self.connection.lock().unwrap(); @@ -167,7 +139,6 @@ impl SqLiteGroupStateStorage { // Insert new epochs as needed for epoch in inserts { - let epoch = epoch.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?; max_epoch_id = Some(epoch.id); transaction @@ -180,9 +151,7 @@ impl SqLiteGroupStateStorage { } // Update existing epochs as needed - updates.try_for_each(|epoch| { - let epoch = epoch.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?; - + updates.into_iter().try_for_each(|epoch| { transaction .execute( "UPDATE epoch SET epoch_data = ? WHERE group_id = ? AND epoch_id = ?", @@ -218,63 +187,28 @@ impl SqLiteGroupStateStorage { impl GroupStateStorage for SqLiteGroupStateStorage { type Error = SqLiteDataStorageError; - async fn write( + async fn write( &mut self, - state: ST, - epoch_inserts: Vec, - epoch_updates: Vec, - ) -> Result<(), Self::Error> - where - ST: GroupState + MlsEncode + MlsDecode + Send + Sync, - ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync, - { - let group_id = state.id(); - - let snapshot_data = state - .mls_encode_to_vec() - .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?; - - let inserts = epoch_inserts.iter().map(|e| { - Ok(StoredEpoch::new( - e.id(), - e.mls_encode_to_vec() - .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?, - )) - }); - - let updates = epoch_updates.iter().map(|e| { - Ok(StoredEpoch::new( - e.id(), - e.mls_encode_to_vec() - .map_err(|err| SqLiteDataStorageError::DataConversionError(err.into()))?, - )) - }); - - self.update_group_state(group_id.as_slice(), snapshot_data, inserts, updates) + state: GroupState, + inserts: Vec, + updates: Vec, + ) -> Result<(), Self::Error> { + let group_id = state.id; + let snapshot_data = state.data; + + self.update_group_state(&group_id, snapshot_data, inserts, updates) } - async fn state(&self, group_id: &[u8]) -> Result, Self::Error> - where - T: GroupState + MlsEncode + MlsDecode, - { - self.get_snapshot_data(group_id)? - .map(|v| T::mls_decode(&mut v.as_slice())) - .transpose() - .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into())) + async fn state(&self, group_id: &[u8]) -> Result>, Self::Error> { + self.get_snapshot_data(group_id) } async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { self.max_epoch_id(group_id) } - async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> - where - T: EpochRecord + MlsEncode + MlsDecode, - { - self.get_epoch_data(group_id, epoch_id)? - .map(|v| T::mls_decode(&mut v.as_slice())) - .transpose() - .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into())) + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result>, Self::Error> { + self.get_epoch_data(group_id, epoch_id) } } @@ -302,8 +236,8 @@ mod tests { gen_rand_bytes(1024) } - fn test_epoch(id: u64) -> StoredEpoch { - StoredEpoch { + fn test_epoch(id: u64) -> EpochRecord { + EpochRecord { data: gen_rand_bytes(256), id, } @@ -313,7 +247,7 @@ mod tests { storage: SqLiteGroupStateStorage, snapshot: Vec, group_id: Vec, - epoch_0: StoredEpoch, + epoch_0: EpochRecord, } fn setup_group_storage_test() -> TestData { @@ -326,8 +260,8 @@ mod tests { .update_group_state( &test_group_id, test_snapshot.clone(), - vec![test_epoch_0.clone()].into_iter().map(Ok), - vec![].into_iter(), + vec![test_epoch_0.clone()], + vec![], ) .unwrap(); @@ -370,8 +304,8 @@ mod tests { .update_group_state( &test_data.group_id, test_snapshot.clone(), - vec![].into_iter(), - vec![Ok(epoch_update.clone())].into_iter(), + vec![], + vec![epoch_update.clone()], ) .unwrap(); @@ -410,8 +344,8 @@ mod tests { .update_group_state( &test_data.group_id, test_snapshot(), - test_epochs.clone().into_iter().map(Ok), - vec![].into_iter(), + test_epochs.clone(), + vec![], ) .unwrap(); @@ -440,8 +374,8 @@ mod tests { .update_group_state( &test_data.group_id, test_snapshot(), - vec![test_epoch(1)].into_iter().map(Ok), - vec![].into_iter(), + vec![test_epoch(1)], + vec![], ) .unwrap(); @@ -453,8 +387,8 @@ mod tests { .update_group_state( &test_data.group_id, test_snapshot(), - test_epochs.clone().into_iter().map(Ok), - vec![Ok(new_epoch_1.clone())].into_iter(), + test_epochs.clone(), + vec![new_epoch_1.clone()], ) .unwrap(); @@ -480,12 +414,7 @@ mod tests { let group_id = b"test"; storage - .update_group_state( - group_id, - vec![0, 1, 2], - vec![].into_iter(), - vec![].into_iter().map(Ok), - ) + .update_group_state(group_id, vec![0, 1, 2], vec![], vec![]) .unwrap(); let res = storage.max_epoch_id(group_id).unwrap(); @@ -502,8 +431,8 @@ mod tests { .update_group_state( &test_data.group_id, test_snapshot(), - (1..10).map(test_epoch).map(Ok), - vec![].into_iter().map(Ok), + (1..10).map(test_epoch).collect(), + vec![], ) .unwrap(); @@ -529,8 +458,8 @@ mod tests { .update_group_state( &new_group, test_snapshot(), - vec![new_group_epoch.clone()].into_iter().map(Ok), - vec![].into_iter(), + vec![new_group_epoch.clone()], + vec![], ) .unwrap(); diff --git a/mls-rs-uniffi/src/config.rs b/mls-rs-uniffi/src/config.rs index 953317a3..8e177541 100644 --- a/mls-rs-uniffi/src/config.rs +++ b/mls-rs-uniffi/src/config.rs @@ -7,7 +7,7 @@ use mls_rs::{ use mls_rs_core::error::IntoAnyError; use mls_rs_crypto_openssl::OpensslCryptoProvider; -use self::group_state::GroupStateStorageWrapper; +use self::group_state::{GroupStateStorage, GroupStateStorageWrapper}; mod group_state; @@ -39,5 +39,5 @@ pub type UniFFIConfig = client_builder::WithIdentityProvider< #[derive(Debug, Clone, uniffi::Record)] pub struct ClientConfig { - pub group_state_storage: Arc, + pub group_state_storage: Arc, } diff --git a/mls-rs-uniffi/src/config/group_state.rs b/mls-rs-uniffi/src/config/group_state.rs index aa381725..dff3f187 100644 --- a/mls-rs-uniffi/src/config/group_state.rs +++ b/mls-rs-uniffi/src/config/group_state.rs @@ -1,30 +1,40 @@ use std::{fmt::Debug, sync::Arc}; -use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode}; - use super::FFICallbackError; -#[derive(Clone, Debug, uniffi::Record)] +// TODO(mulmarta): we'd like to use GroupState and EpochRecord from mls-rs-core +// but this breaks python tests because using 2 crates makes uniffi generate +// a python module which must be in a subdirectory of the directory with test scripts +// which is not supported by the script we use. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, uniffi::Record)] pub struct GroupState { + /// A unique group identifier. pub id: Vec, pub data: Vec, } -impl mls_rs_core::group::GroupState for GroupState { - fn id(&self) -> Vec { - self.id.clone() - } -} - -#[derive(Clone, Debug, uniffi::Record)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, uniffi::Record)] pub struct EpochRecord { + /// A unique epoch identifier within a particular group. pub id: u64, pub data: Vec, } -impl mls_rs_core::group::EpochRecord for EpochRecord { - fn id(&self) -> u64 { - self.id +impl From for GroupState { + fn from(value: mls_rs_core::group::GroupState) -> Self { + Self { + id: value.id, + data: value.data, + } + } +} + +impl From for EpochRecord { + fn from(value: mls_rs_core::group::EpochRecord) -> Self { + Self { + id: value.id, + data: value.data, + } } } @@ -63,65 +73,25 @@ impl From> for GroupStateStorageWrapper { impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper { type Error = FFICallbackError; - async fn state(&self, group_id: &[u8]) -> Result, Self::Error> - where - T: mls_rs_core::group::GroupState + MlsEncode + MlsDecode, - { - let state_data = self.0.state(group_id.to_vec())?; - - state_data - .as_deref() - .map(|v| T::mls_decode(&mut &*v)) - .transpose() - .map_err(Into::into) + async fn state(&self, group_id: &[u8]) -> Result>, Self::Error> { + self.0.state(group_id.to_vec()) } - async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> - where - T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode, - { - let epoch_data = self.0.epoch(group_id.to_vec(), epoch_id)?; - - epoch_data - .as_deref() - .map(|v| T::mls_decode(&mut &*v)) - .transpose() - .map_err(Into::into) + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result>, Self::Error> { + self.0.epoch(group_id.to_vec(), epoch_id) } - async fn write( + async fn write( &mut self, - state: ST, - epoch_inserts: Vec, - epoch_updates: Vec, - ) -> Result<(), Self::Error> - where - ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync, - ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync, - { - let state = GroupState { - id: state.id(), - data: state.mls_encode_to_vec()?, - }; - - let epoch_to_record = |v: ET| -> Result<_, Self::Error> { - Ok(EpochRecord { - id: v.id(), - data: v.mls_encode_to_vec()?, - }) - }; - - let inserts = epoch_inserts - .into_iter() - .map(epoch_to_record) - .collect::, _>>()?; - - let updates = epoch_updates - .into_iter() - .map(epoch_to_record) - .collect::, _>>()?; - - self.0.write(state, inserts, updates) + state: mls_rs_core::group::GroupState, + inserts: Vec, + updates: Vec, + ) -> Result<(), Self::Error> { + self.0.write( + state.into(), + inserts.into_iter().map(Into::into).collect(), + updates.into_iter().map(Into::into).collect(), + ) } async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index 0799ada7..eea6f66c 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -13,12 +13,13 @@ use crate::group::{ message_signature::AuthenticatedContent, proposal::{AddProposal, Proposal}, }; -use crate::group::{ExportedTree, Group, NewMemberInfo}; +use crate::group::{snapshot::Snapshot, ExportedTree, Group, NewMemberInfo}; use crate::identity::SigningIdentity; use crate::key_package::{KeyPackageGeneration, KeyPackageGenerator}; use crate::protocol_version::ProtocolVersion; use crate::tree_kem::node::NodeIndex; use alloc::vec::Vec; +use mls_rs_codec::MlsDecode; use mls_rs_core::crypto::{CryptoProvider, SignatureSecretKey}; use mls_rs_core::error::{AnyError, IntoAnyError}; use mls_rs_core::extension::{ExtensionError, ExtensionList, ExtensionType}; @@ -614,6 +615,8 @@ where .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? .ok_or(MlsError::GroupNotFound)?; + let snapshot = Snapshot::mls_decode(&mut &*snapshot)?; + Group::from_snapshot(self.config.clone(), snapshot).await } diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs index 05798ca7..5d56c36b 100644 --- a/mls-rs/src/group/snapshot.rs +++ b/mls-rs/src/group/snapshot.rs @@ -39,7 +39,7 @@ use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRe #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct Snapshot { version: u16, - state: RawGroupState, + pub(crate) state: RawGroupState, private_tree: TreeKemPrivate, epoch_secrets: EpochSecrets, key_schedule: KeySchedule, @@ -51,12 +51,6 @@ pub(crate) struct Snapshot { signer: SignatureSecretKey, } -impl Snapshot { - pub(crate) fn group_id(&self) -> &[u8] { - &self.state.context.group_id - } -} - #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct RawGroupState { diff --git a/mls-rs/src/group/state_repo.rs b/mls-rs/src/group/state_repo.rs index 60846324..f6c0c18d 100644 --- a/mls-rs/src/group/state_repo.rs +++ b/mls-rs/src/group/state_repo.rs @@ -8,6 +8,8 @@ use crate::{group::PriorEpoch, key_package::KeyPackageRef}; use alloc::collections::VecDeque; use alloc::vec::Vec; use core::fmt::{self, Debug}; +use mls_rs_codec::{MlsDecode, MlsEncode}; +use mls_rs_core::group::{EpochRecord, GroupState}; use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage}; use super::snapshot::Snapshot; @@ -126,10 +128,11 @@ where // Search the stored cache self.storage - .epoch::(&psk_id.psk_group_id.0, psk_id.psk_epoch) + .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch) .await - .map_err(|e| MlsError::GroupStorageError(e.into_any_error())) - .map(|e| e.map(|e| e.secrets.resumption_secret)) + .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? + .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret)) + .transpose() } #[cfg(feature = "private_message")] @@ -150,18 +153,24 @@ where // Look in the cached updates map, and if not found look in disk storage // and insert into the updates map for future caching - Ok(match self.find_pending(epoch_id) { - Some(i) => self.pending_commit.updates.get_mut(i), + match self.find_pending(epoch_id) { + Some(i) => self.pending_commit.updates.get_mut(i).map(Ok), None => self .storage .epoch(&self.group_id, epoch_id) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? .and_then(|epoch| { - self.pending_commit.updates.push(epoch); - self.pending_commit.updates.last_mut() + PriorEpoch::mls_decode(&mut &*epoch) + .map(|epoch| { + self.pending_commit.updates.push(epoch); + self.pending_commit.updates.last_mut() + }) + .transpose() }), - }) + } + .transpose() + .map_err(Into::into) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -185,11 +194,27 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> { - let inserts = self.pending_commit.inserts.iter().cloned().collect(); - let updates = self.pending_commit.updates.clone(); + let inserts = self + .pending_commit + .inserts + .iter() + .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) + .collect::>()?; + + let updates = self + .pending_commit + .updates + .iter() + .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) + .collect::>()?; + + let group_state = GroupState { + data: group_snapshot.mls_encode_to_vec()?, + id: group_snapshot.state.context.group_id, + }; self.storage - .write(group_snapshot, inserts, updates) + .write(group_state, inserts, updates) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?; @@ -227,10 +252,7 @@ mod tests { test_utils::{random_bytes, test_member, TEST_GROUP}, PskGroupId, ResumptionPSKUsage, }, - storage_provider::{ - group_state::EpochData, - in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage}, - }, + storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage}, }; use super::*; @@ -321,7 +343,10 @@ mod tests { assert_eq!( stored.epoch_data.back().unwrap(), - &EpochData::new(test_epoch).unwrap() + &EpochRecord::new( + test_epoch.epoch_id(), + test_epoch.mls_encode_to_vec().unwrap() + ) ); } @@ -386,7 +411,7 @@ mod tests { assert_eq!( stored.epoch_data.back().unwrap(), - &EpochData::new(to_update).unwrap() + &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) ); } @@ -434,12 +459,15 @@ mod tests { assert_eq!( stored.epoch_data.front().unwrap(), - &EpochData::new(to_update).unwrap() + &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) ); assert_eq!( stored.epoch_data.back().unwrap(), - &EpochData::new(test_epoch_1).unwrap() + &EpochRecord::new( + test_epoch_1.epoch_id(), + test_epoch_1.mls_encode_to_vec().unwrap() + ) ); } diff --git a/mls-rs/src/group/state_repo_light.rs b/mls-rs/src/group/state_repo_light.rs index cc4f5b1b..ef823738 100644 --- a/mls-rs/src/group/state_repo_light.rs +++ b/mls-rs/src/group/state_repo_light.rs @@ -6,10 +6,10 @@ use crate::client::MlsError; use crate::key_package::KeyPackageRef; use alloc::vec::Vec; -use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; +use mls_rs_codec::MlsEncode; use mls_rs_core::{ error::IntoAnyError, - group::{EpochRecord, GroupStateStorage}, + group::{GroupState, GroupStateStorage}, key_package::KeyPackageStorage, }; @@ -47,8 +47,13 @@ where #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> { + let group_state = GroupState { + data: group_snapshot.mls_encode_to_vec()?, + id: group_snapshot.state.context.group_id, + }; + self.storage - .write(group_snapshot, Vec::::new(), Vec::new()) + .write(group_state, Vec::new(), Vec::new()) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?; @@ -63,15 +68,6 @@ where } } -#[derive(MlsSize, MlsEncode, MlsDecode)] -struct PriorEpoch {} - -impl EpochRecord for PriorEpoch { - fn id(&self) -> u64 { - 0 - } -} - #[cfg(test)] mod tests { use crate::{ diff --git a/mls-rs/src/storage_provider.rs b/mls-rs/src/storage_provider.rs index cb6e62ce..ffe8cd92 100644 --- a/mls-rs/src/storage_provider.rs +++ b/mls-rs/src/storage_provider.rs @@ -2,12 +2,10 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) -pub(crate) mod group_state; /// Storage providers that operate completely in memory. pub mod in_memory; pub(crate) mod key_package; -pub use group_state::*; pub use key_package::*; #[cfg(feature = "sqlite")] diff --git a/mls-rs/src/storage_provider/group_state.rs b/mls-rs/src/storage_provider/group_state.rs deleted file mode 100644 index 297a739a..00000000 --- a/mls-rs/src/storage_provider/group_state.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// Copyright by contributors to this project. -// SPDX-License-Identifier: (Apache-2.0 OR MIT) - -use alloc::vec::Vec; -use core::fmt::{self, Debug}; -use mls_rs_codec::MlsEncode; -pub use mls_rs_core::group::{EpochRecord, GroupState}; - -use crate::group::snapshot::Snapshot; - -#[cfg(feature = "prior_epoch")] -use crate::group::epoch::PriorEpoch; - -#[cfg(feature = "prior_epoch")] -impl EpochRecord for PriorEpoch { - fn id(&self) -> u64 { - self.epoch_id() - } -} - -impl GroupState for Snapshot { - fn id(&self) -> Vec { - self.group_id().to_vec() - } -} - -#[derive(Clone, PartialEq, Eq)] -pub(crate) struct EpochData { - pub(crate) id: u64, - pub(crate) data: Vec, -} - -impl Debug for EpochData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("EpochData") - .field("id", &self.id) - .field("data", &mls_rs_core::debug::pretty_bytes(&self.data)) - .finish() - } -} - -impl EpochData { - pub(crate) fn new(value: T) -> Result - where - T: MlsEncode + EpochRecord, - { - Ok(Self { - id: value.id(), - data: value.mls_encode_to_vec()?, - }) - } -} diff --git a/mls-rs/src/storage_provider/in_memory/group_state_storage.rs b/mls-rs/src/storage_provider/in_memory/group_state_storage.rs index c2eb9edc..5999ed03 100644 --- a/mls-rs/src/storage_provider/in_memory/group_state_storage.rs +++ b/mls-rs/src/storage_provider/in_memory/group_state_storage.rs @@ -10,13 +10,15 @@ use alloc::sync::Arc; #[cfg(mls_build_async)] use alloc::boxed::Box; use alloc::vec::Vec; -use core::fmt::{self, Debug}; -use mls_rs_codec::{MlsDecode, MlsEncode}; +use core::{ + convert::Infallible, + fmt::{self, Debug}, +}; use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; -use crate::{client::MlsError, storage_provider::group_state::EpochData}; +use crate::client::MlsError; #[cfg(feature = "std")] use std::collections::{hash_map::Entry, HashMap}; @@ -35,7 +37,7 @@ pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3; #[derive(Clone)] pub(crate) struct InMemoryGroupData { pub(crate) state_data: Vec, - pub(crate) epoch_data: VecDeque, + pub(crate) epoch_data: VecDeque, } impl Debug for InMemoryGroupData { @@ -64,24 +66,24 @@ impl InMemoryGroupData { .and_then(|e| epoch_id.checked_sub(e.id)) } - pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochData> { + pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> { self.get_epoch_data_index(epoch_id) .and_then(|i| self.epoch_data.get(i as usize)) } - pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochData> { + pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> { self.get_epoch_data_index(epoch_id) .and_then(|i| self.epoch_data.get_mut(i as usize)) } - pub fn insert_epoch(&mut self, epoch: EpochData) { + pub fn insert_epoch(&mut self, epoch: EpochRecord) { self.epoch_data.push_back(epoch) } // This function does not fail if an update can't be made. If the epoch // is not in the store, then it can no longer be accessed by future // get_epoch calls and is no longer relevant. - pub fn update_epoch(&mut self, epoch: EpochData) { + pub fn update_epoch(&mut self, epoch: EpochRecord) { if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) { *existing_epoch = epoch } @@ -176,7 +178,7 @@ impl Default for InMemoryGroupStateStorage { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(mls_build_async, maybe_async::must_be_async)] impl GroupStateStorage for InMemoryGroupStateStorage { - type Error = mls_rs_codec::Error; + type Error = Infallible; async fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { Ok(self @@ -185,60 +187,44 @@ impl GroupStateStorage for InMemoryGroupStateStorage { .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id))) } - async fn state(&self, group_id: &[u8]) -> Result, Self::Error> - where - T: mls_rs_core::group::GroupState + MlsDecode, - { - self.lock() + async fn state(&self, group_id: &[u8]) -> Result>, Self::Error> { + Ok(self + .lock() .get(group_id) - .map(|v| T::mls_decode(&mut v.state_data.as_slice())) - .transpose() - .map_err(Into::into) + .map(|data| data.state_data.clone())) } - async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> - where - T: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode, - { - self.lock() + async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result>, Self::Error> { + Ok(self + .lock() .get(group_id) - .and_then(|group_data| group_data.get_epoch(epoch_id)) - .map(|v| T::mls_decode(&mut v.data.as_slice())) - .transpose() - .map_err(Into::into) + .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone()))) } - async fn write( + async fn write( &mut self, - state: ST, - epoch_inserts: Vec, - epoch_updates: Vec, - ) -> Result<(), Self::Error> - where - ST: GroupState + MlsEncode + MlsDecode + Send + Sync, - ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync, - { + state: GroupState, + epoch_inserts: Vec, + epoch_updates: Vec, + ) -> Result<(), Self::Error> { let mut group_map = self.lock(); - let state_data = state.mls_encode_to_vec()?; - let group_data = match group_map.entry(state.id()) { + let group_data = match group_map.entry(state.id) { Entry::Occupied(entry) => { let data = entry.into_mut(); - data.state_data = state_data; + data.state_data = state.data; data } - Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state_data)), + Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)), }; - epoch_inserts.into_iter().try_for_each(|e| { - group_data.insert_epoch(EpochData::new(e)?); - Ok::<_, Self::Error>(()) - })?; + epoch_inserts + .into_iter() + .for_each(|e| group_data.insert_epoch(e)); - epoch_updates.into_iter().try_for_each(|e| { - group_data.update_epoch(EpochData::new(e)?); - Ok::<_, Self::Error>(()) - })?; + epoch_updates + .into_iter() + .for_each(|e| group_data.update_epoch(e)); group_data.trim_epochs(self.max_epoch_retention); @@ -248,22 +234,13 @@ impl GroupStateStorage for InMemoryGroupStateStorage { #[cfg(all(test, feature = "prior_epoch"))] mod tests { - use alloc::{vec, vec::Vec}; + use alloc::{format, vec, vec::Vec}; use assert_matches::assert_matches; use super::{InMemoryGroupData, InMemoryGroupStateStorage}; - use crate::{ - client::{test_utils::TEST_CIPHER_SUITE, MlsError}, - group::{ - epoch::{test_utils::get_test_epoch_with_id, PriorEpoch}, - snapshot::{test_utils::get_test_snapshot, Snapshot}, - test_utils::TEST_GROUP, - }, - storage_provider::EpochData, - }; - - use mls_rs_codec::MlsEncode; - use mls_rs_core::group::GroupStateStorage; + use crate::{client::MlsError, group::test_utils::TEST_GROUP}; + + use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; impl InMemoryGroupStateStorage { fn test_data(&self) -> InMemoryGroupData { @@ -275,13 +252,15 @@ mod tests { InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit) } - fn test_epoch(epoch_id: u64) -> PriorEpoch { - get_test_epoch_with_id(Vec::new(), TEST_CIPHER_SUITE, epoch_id) + fn test_epoch(epoch_id: u64) -> EpochRecord { + EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec()) } - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn test_snapshot(epoch_id: u64) -> Snapshot { - get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await + fn test_snapshot(epoch_id: u64) -> GroupState { + GroupState { + id: TEST_GROUP.into(), + data: format!("snapshot {epoch_id}").as_bytes().to_vec(), + } } #[test] @@ -296,7 +275,7 @@ mod tests { let epoch_inserts = vec![test_epoch(0), test_epoch(1)]; storage - .write(test_snapshot(1).await, epoch_inserts, Vec::new()) + .write(test_snapshot(0), epoch_inserts, Vec::new()) .await .unwrap(); @@ -307,7 +286,7 @@ mod tests { let epoch_inserts = vec![test_epoch(3), test_epoch(4)]; storage - .write(test_snapshot(1).await, epoch_inserts, Vec::new()) + .write(test_snapshot(1), epoch_inserts, Vec::new()) .await .unwrap(); @@ -321,7 +300,7 @@ mod tests { let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(3), test_epoch(4)]; storage - .write(test_snapshot(1).await, epoch_inserts, Vec::new()) + .write(test_snapshot(1), epoch_inserts, Vec::new()) .await .unwrap(); @@ -332,7 +311,7 @@ mod tests { let epoch_inserts = vec![test_epoch(5)]; storage - .write(test_snapshot(1).await, epoch_inserts, Vec::new()) + .write(test_snapshot(1), epoch_inserts, Vec::new()) .await .unwrap(); @@ -357,7 +336,7 @@ mod tests { let updates = with_update .then_some(vec![test_epoch(0)]) .unwrap_or_default(); - let snapshot = test_snapshot(1).await; + let snapshot = test_snapshot(1); storage .write(snapshot.clone(), epoch_inserts.clone(), updates) @@ -366,10 +345,10 @@ mod tests { let stored = storage.test_data(); - assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap()); + assert_eq!(stored.state_data, snapshot.data); assert_eq!(stored.epoch_data.len(), 1); - let expected = EpochData::new(epoch_inserts.pop().unwrap()).unwrap(); + let expected = epoch_inserts.pop().unwrap(); assert_eq!(stored.epoch_data[0], expected); } }