diff --git a/mls-rs/Cargo.toml b/mls-rs/Cargo.toml index 94d99a54..c0154388 100644 --- a/mls-rs/Cargo.toml +++ b/mls-rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mls-rs" -version = "0.42.1" +version = "0.42.2" edition = "2021" description = "An implementation of Messaging Layer Security (RFC 9420)" homepage = "https://github.com/awslabs/mls-rs" diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 1e9bfcb8..4053f871 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -4173,17 +4173,27 @@ mod tests { #[cfg(feature = "by_ref_proposal")] let receiver = receiver.with_extensions(extensions); - let (receiver, proposals, proposer) = if by_ref { + let (receiver, proposals, proposer, source) = if by_ref { let proposal_ref = make_proposal_ref(proposal, proposer).await; let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer); - (receiver, vec![ProposalOrRef::from(proposal_ref)], proposer) + ( + receiver, + vec![ProposalOrRef::from(proposal_ref.clone())], + proposer, + ProposalSource::ByReference(proposal_ref), + ) } else { - (receiver, vec![proposal.clone().into()], committer) + ( + receiver, + vec![proposal.clone().into()], + committer, + ProposalSource::Local, + ) }; let res = receiver.receive(proposals).await; - if proposer_can_propose(proposer, proposal.proposal_type(), by_ref).is_err() { + if proposer_can_propose(proposer, proposal.proposal_type(), &source).is_err() { assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender)); } else { let is_self_update = proposal.proposal_type() == ProposalType::UPDATE diff --git a/mls-rs/src/group/proposal_filter/filtering.rs b/mls-rs/src/group/proposal_filter/filtering.rs index e49869e8..88c32645 100644 --- a/mls-rs/src/group/proposal_filter/filtering.rs +++ b/mls-rs/src/group/proposal_filter/filtering.rs @@ -21,7 +21,10 @@ use crate::{ CipherSuiteProvider, ExtensionList, }; -use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier}; +use super::{ + filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier}, + ProposalSource, +}; #[cfg(feature = "by_ref_proposal")] use crate::extension::ExternalSendersExt; @@ -439,10 +442,10 @@ fn filter_out_external_init( pub(crate) fn proposer_can_propose( proposer: Sender, proposal_type: ProposalType, - by_ref: bool, + source: &ProposalSource, ) -> Result<(), MlsError> { - let can_propose = match (proposer, by_ref) { - (Sender::Member(_), false) => matches!( + let can_propose = match (proposer, source) { + (Sender::Member(_), ProposalSource::ByValue | ProposalSource::Local) => matches!( proposal_type, ProposalType::ADD | ProposalType::REMOVE @@ -450,7 +453,7 @@ pub(crate) fn proposer_can_propose( | ProposalType::RE_INIT | ProposalType::GROUP_CONTEXT_EXTENSIONS ), - (Sender::Member(_), true) => matches!( + (Sender::Member(_), ProposalSource::ByReference(_)) => matches!( proposal_type, ProposalType::ADD | ProposalType::UPDATE @@ -460,9 +463,9 @@ pub(crate) fn proposer_can_propose( | ProposalType::GROUP_CONTEXT_EXTENSIONS ), #[cfg(feature = "by_ref_proposal")] - (Sender::External(_), false) => false, + (Sender::External(_), ProposalSource::ByValue) => false, #[cfg(feature = "by_ref_proposal")] - (Sender::External(_), true) => matches!( + (Sender::External(_), _) => matches!( proposal_type, ProposalType::ADD | ProposalType::REMOVE @@ -470,13 +473,15 @@ pub(crate) fn proposer_can_propose( | ProposalType::PSK | ProposalType::GROUP_CONTEXT_EXTENSIONS ), - (Sender::NewMemberCommit, false) => matches!( + (Sender::NewMemberCommit, ProposalSource::ByValue | ProposalSource::Local) => matches!( proposal_type, ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT ), - (Sender::NewMemberCommit, true) => false, - (Sender::NewMemberProposal, false) => false, - (Sender::NewMemberProposal, true) => matches!(proposal_type, ProposalType::ADD), + (Sender::NewMemberCommit, ProposalSource::ByReference(_)) => false, + (Sender::NewMemberProposal, ProposalSource::ByValue | ProposalSource::Local) => false, + (Sender::NewMemberProposal, ProposalSource::ByReference(_)) => { + matches!(proposal_type, ProposalType::ADD) + } }; can_propose @@ -490,7 +495,7 @@ pub(crate) fn filter_out_invalid_proposers( ) -> Result { for i in (0..proposals.add_proposals().len()).rev() { let p = &proposals.add_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::ADD, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -499,7 +504,7 @@ pub(crate) fn filter_out_invalid_proposers( for i in (0..proposals.update_proposals().len()).rev() { let p = &proposals.update_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::UPDATE, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -509,7 +514,7 @@ pub(crate) fn filter_out_invalid_proposers( for i in (0..proposals.remove_proposals().len()).rev() { let p = &proposals.remove_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::REMOVE, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -519,7 +524,7 @@ pub(crate) fn filter_out_invalid_proposers( #[cfg(feature = "psk")] for i in (0..proposals.psk_proposals().len()).rev() { let p = &proposals.psk_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::PSK, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -528,7 +533,7 @@ pub(crate) fn filter_out_invalid_proposers( for i in (0..proposals.reinit_proposals().len()).rev() { let p = &proposals.reinit_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -537,7 +542,7 @@ pub(crate) fn filter_out_invalid_proposers( for i in (0..proposals.external_init_proposals().len()).rev() { let p = &proposals.external_init_proposals()[i]; - let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference()); + let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i); @@ -547,7 +552,7 @@ pub(crate) fn filter_out_invalid_proposers( for i in (0..proposals.group_context_ext_proposals().len()).rev() { let p = &proposals.group_context_ext_proposals()[i]; let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS; - let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference()); + let res = proposer_can_propose(p.sender, gce_type, &p.source); if !apply_strategy(strategy, p.is_by_reference(), res)? { proposals.remove::(i);