Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into donovan/executor
Browse files Browse the repository at this point in the history
  • Loading branch information
Tjemmmic committed Dec 19, 2024
2 parents 9670950 + 415b234 commit 97f55ea
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 47 deletions.
9 changes: 7 additions & 2 deletions crates/networking/src/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ pub struct GossipHandle {
pub connected_peers: Arc<AtomicUsize>,
pub public_key_to_libp2p_id: Arc<RwLock<BTreeMap<PublicKey, PeerId>>>,
pub recent_messages: parking_lot::Mutex<LruCache<[u8; 32], ()>>,
pub my_id: PeerId,
pub my_id: PublicKey,
}

impl GossipHandle {
Expand Down Expand Up @@ -379,7 +379,8 @@ impl Network for GossipHandle {
}
}

async fn send_message(&self, message: ProtocolMessage) -> Result<(), Error> {
async fn send_message(&self, mut message: ProtocolMessage) -> Result<(), Error> {
message.sender.public_key = Some(self.my_id);
let message_type = if let Some(ParticipantInfo {
public_key: Some(to),
..
Expand Down Expand Up @@ -425,4 +426,8 @@ impl Network for GossipHandle {
.send(payload)
.map_err(|e| Error::NetworkError(format!("Failed to send intra-node payload: {e}")))
}

fn public_id(&self) -> PublicKey {
self.my_id
}
}
8 changes: 2 additions & 6 deletions crates/networking/src/handlers/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::gossip::{MyBehaviourRequest, NetworkService};
use crate::key_types::Curve;
use gadget_crypto::{hashing::blake3_256, KeyType};
use gadget_crypto::KeyType;
use itertools::Itertools;
use libp2p::PeerId;

Expand All @@ -23,11 +23,7 @@ impl NetworkService<'_> {
{
let my_peer_id = *self.swarm.local_peer_id();
let msg = my_peer_id.to_bytes();
let hash = blake3_256(&msg);
match <Curve as KeyType>::sign_with_secret_pre_hashed(
&mut self.secret_key.clone(),
&hash,
) {
match <Curve as KeyType>::sign_with_secret(&mut self.secret_key.clone(), &msg) {
Ok(signature) => {
let handshake = MyBehaviourRequest::Handshake {
public_key: self.secret_key.public(),
Expand Down
2 changes: 1 addition & 1 deletion crates/networking/src/handlers/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl NetworkService<'_> {
.find(|r| r.0.to_string() == topic)
{
if let Err(e) = tx.send(raw_payload) {
gadget_logging::error!("Failed to send message to worker: {e}");
gadget_logging::warn!("Failed to send message to worker: {e}");
}
} else {
gadget_logging::error!("No registered worker for topic: {topic}!");
Expand Down
21 changes: 8 additions & 13 deletions crates/networking/src/handlers/p2p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use crate::gossip::{MyBehaviourRequest, MyBehaviourResponse, NetworkService};
use crate::key_types::Curve;
use gadget_crypto::hashing::blake3_256;
use gadget_crypto::KeyType;
use gadget_std::string::ToString;
use libp2p::gossipsub::IdentTopic;
Expand Down Expand Up @@ -91,10 +90,9 @@ impl NetworkService<'_> {
gadget_logging::trace!("Received handshake from peer: {peer}");
// Verify the signature
let msg = peer.to_bytes();
let hash = blake3_256(&msg);
let valid = <Curve as KeyType>::verify(&public_key, &hash, &signature);
let valid = <Curve as KeyType>::verify(&public_key, &msg, &signature);
if !valid {
gadget_logging::warn!("Invalid signature from peer: {peer}");
gadget_logging::warn!("Invalid initial handshake signature from peer: {peer}");
let _ = self.swarm.disconnect_peer_id(peer);
return;
}
Expand All @@ -110,11 +108,7 @@ impl NetworkService<'_> {
// Send response with our public key
let my_peer_id = self.swarm.local_peer_id();
let msg = my_peer_id.to_bytes();
let hash = blake3_256(&msg);
match <Curve as KeyType>::sign_with_secret_pre_hashed(
&mut self.secret_key.clone(),
&hash,
) {
match <Curve as KeyType>::sign_with_secret(&mut self.secret_key.clone(), &msg) {
Ok(signature) => self.swarm.behaviour_mut().p2p.send_response(
channel,
MyBehaviourResponse::Handshaked {
Expand All @@ -141,7 +135,7 @@ impl NetworkService<'_> {
.find(|r| r.0.to_string() == topic.to_string())
{
if let Err(e) = tx.send(raw_payload) {
gadget_logging::error!("Failed to send message to worker: {e}");
gadget_logging::warn!("Failed to send message to worker: {e}");
}
} else {
gadget_logging::error!("No registered worker for topic: {topic}!");
Expand Down Expand Up @@ -172,10 +166,11 @@ impl NetworkService<'_> {
} => {
gadget_logging::trace!("Received handshake-ack message from peer: {peer}");
let msg = peer.to_bytes();
let hash = blake3_256(&msg);
let valid = <Curve as KeyType>::verify(&public_key, &hash, &signature);
let valid = <Curve as KeyType>::verify(&public_key, &msg, &signature);
if !valid {
gadget_logging::warn!("Invalid signature from peer: {peer}");
gadget_logging::warn!(
"Invalid handshake-acknowledgement signature from peer: {peer}"
);
// TODO: report this peer.
self.public_key_to_libp2p_id
.write()
Expand Down
83 changes: 60 additions & 23 deletions crates/networking/src/networking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,24 @@ pub trait Network: Send + Sync + 'static {
async fn next_message(&self) -> Option<ProtocolMessage>;
async fn send_message(&self, message: ProtocolMessage) -> Result<(), Error>;

fn public_id(&self) -> PublicKey;

fn build_protocol_message<Payload: Serialize>(
&self,
identifier_info: IdentifierInfo,
from: UserID,
to: Option<UserID>,
payload: &Payload,
from_account_id: Option<PublicKey>,
to_network_id: Option<PublicKey>,
) -> ProtocolMessage {
assert!(
(u8::from(to.is_none()) + u8::from(to_network_id.is_none()) != 1),
"Either `to` must be Some AND `to_network_id` is Some, or, both None"
);

let sender_participant_info = ParticipantInfo {
user_id: from,
public_key: from_account_id,
public_key: Some(self.public_id()),
};
let receiver_participant_info = to.map(|to| ParticipantInfo {
user_id: to,
Expand Down Expand Up @@ -141,6 +148,7 @@ pub struct NetworkMultiplexer {
unclaimed_receiving_streams: Arc<DashMap<StreamKey, MultiplexedReceiver>>,
tx_to_networking_layer: MultiplexedSender,
sequence_numbers: Arc<DashMap<CompoundStreamKey, u64>>,
my_id: PublicKey,
}

type ActiveStreams = Arc<DashMap<StreamKey, tokio::sync::mpsc::UnboundedSender<ProtocolMessage>>>;
Expand Down Expand Up @@ -250,6 +258,7 @@ impl NetworkMultiplexer {
pub fn new<N: Network>(network: N) -> Self {
let (tx_to_networking_layer, mut rx_from_substreams) =
tokio::sync::mpsc::unbounded_channel();
let my_id = network.public_id();
let this = NetworkMultiplexer {
to_receiving_streams: Arc::new(DashMap::new()),
unclaimed_receiving_streams: Arc::new(DashMap::new()),
Expand All @@ -258,6 +267,7 @@ impl NetworkMultiplexer {
stream_id: StreamKey::default(),
},
sequence_numbers: Arc::new(DashMap::new()),
my_id,
};

let active_streams = this.to_receiving_streams.clone();
Expand Down Expand Up @@ -317,6 +327,16 @@ impl NetworkMultiplexer {
let mut expected_seqs: HashMap<CompoundStreamKey, u64> = HashMap::default();

while let Some(mut msg) = network_clone.next_message().await {
if let Some(recv) = msg.recipient.as_ref() {
if let Some(recv_pk) = &recv.public_key {
if recv_pk != &my_id {
gadget_logging::warn!(
"Received a message not intended for the local user"
);
}
}
}

if let Ok(multiplexed_message) =
bincode::deserialize::<MultiplexedMessage>(&msg.payload)
{
Expand Down Expand Up @@ -419,12 +439,14 @@ impl NetworkMultiplexer {

pub fn multiplex(&self, id: impl Into<StreamKey>) -> SubNetwork {
let id = id.into();
let my_id = self.my_id;
let mut tx_to_networking_layer = self.tx_to_networking_layer.clone();
if let Some(unclaimed) = self.unclaimed_receiving_streams.remove(&id) {
tx_to_networking_layer.stream_id = id;
return SubNetwork {
tx: tx_to_networking_layer,
rx: Some(unclaimed.1.into()),
my_id,
};
}

Expand All @@ -437,6 +459,7 @@ impl NetworkMultiplexer {
SubNetwork {
tx,
rx: Some(rx.into()),
my_id,
}
}

Expand Down Expand Up @@ -510,6 +533,7 @@ impl<N: Network> From<N> for NetworkMultiplexer {
pub struct SubNetwork {
tx: MultiplexedSender,
rx: Option<Mutex<MultiplexedReceiver>>,
my_id: PublicKey,
}

impl SubNetwork {
Expand Down Expand Up @@ -542,6 +566,10 @@ impl Network for SubNetwork {
async fn send_message(&self, message: ProtocolMessage) -> Result<(), Error> {
self.send(message)
}

fn public_id(&self) -> PublicKey {
self.my_id
}
}

#[cfg(test)]
Expand Down Expand Up @@ -637,9 +665,14 @@ mod tests {

wait_for_nodes_connected(&nodes).await;

let mut mapping = BTreeMap::new();
for (i, node) in nodes.iter().enumerate() {
mapping.insert(i as u16, node.my_id);
}

let mut tasks = Vec::new();
for (i, node) in nodes.into_iter().enumerate() {
let task = tokio::spawn(run_protocol(node, i as u16));
let task = tokio::spawn(run_protocol(node, i as u16, mapping.clone()));
tasks.push(task);
}
// Wait for all tasks to finish
Expand All @@ -654,7 +687,11 @@ mod tests {
}

#[allow(clippy::too_many_lines)]
async fn run_protocol<N: Network>(node: N, i: u16) -> Result<(), crate::Error> {
async fn run_protocol<N: Network>(
node: N,
i: u16,
mapping: BTreeMap<u16, crate::PublicKey>,
) -> Result<(), crate::Error> {
let task_hash = [0u8; 32];
// Safety note: We should be passed a NetworkMultiplexer, and all uses of the N: Network
// used throughout the program must also use the multiplexer to prevent mixed messages.
Expand Down Expand Up @@ -684,8 +721,7 @@ mod tests {
armor: i + 2,
name: format!("Player {}", i),
};

GossipHandle::build_protocol_message(
round1_network.build_protocol_message(
IdentifierInfo {
message_id: 0,
round_id: 0,
Expand All @@ -694,7 +730,6 @@ mod tests {
None,
&Msg::Round1(round),
None,
None,
)
};

Expand Down Expand Up @@ -737,16 +772,16 @@ mod tests {
let msgs = (0..NODE_COUNT)
.filter(|&j| j != i)
.map(|j| {
GossipHandle::build_protocol_message(
let peer_pk = mapping.get(&j).copied().unwrap();
round2_network.build_protocol_message(
IdentifierInfo {
message_id: 0,
round_id: 0,
},
i,
Some(j),
&Msg::Round2(msg.clone()),
None,
None,
Some(peer_pk),
)
})
.collect::<Vec<_>>();
Expand All @@ -762,7 +797,14 @@ mod tests {
let mut msgs = BTreeMap::new();
while let Some(msg) = round2_network.recv().await {
let m = deserialize::<Msg>(&msg.payload).unwrap();
gadget_logging::debug!(from = %msg.sender.user_id, ?m, "Received message");
gadget_logging::info!(
"[Node {}] Received message from {} | Intended Recipient: {}",
i,
msg.sender.user_id,
msg.recipient
.as_ref()
.map_or_else(|| "Broadcast".into(), |r| r.user_id.to_string())
);
// Expecting Round2 message
assert!(
matches!(m, Msg::Round2(_)),
Expand Down Expand Up @@ -790,7 +832,7 @@ mod tests {
rotation: i * 30,
velocity: (i + 1, i + 2, i + 3),
};
GossipHandle::build_protocol_message(
round3_network.build_protocol_message(
IdentifierInfo {
message_id: 0,
round_id: 0,
Expand All @@ -799,7 +841,6 @@ mod tests {
None,
&Msg::Round3(round),
None,
None,
)
};

Expand Down Expand Up @@ -907,12 +948,11 @@ mod tests {

let send_task = async move {
for i in 0..MESSAGE_COUNT {
let msg = GossipHandle::build_protocol_message(
let msg = sub0.build_protocol_message(
IdentifierInfo::default(),
0,
Some(1),
&StressTestPayload { value: i },
Some(public0),
Some(public1),
);
sub0.send(msg).unwrap();
Expand Down Expand Up @@ -942,12 +982,11 @@ mod tests {

let send_task = async move {
for i in 0..MESSAGE_COUNT {
let msg = GossipHandle::build_protocol_message(
let msg = sub1.build_protocol_message(
IdentifierInfo::default(),
1,
Some(0),
&StressTestPayload { value: i },
Some(public1),
Some(public0),
);
sub1.send(msg).unwrap();
Expand Down Expand Up @@ -1008,27 +1047,25 @@ mod tests {

// Send a message in the subnetwork0 to subnetwork1 and vice versa, assert values of message
let payload = vec![1, 2, 3];
let msg = GossipHandle::build_protocol_message(
let msg = subnetwork0.build_protocol_message(
IdentifierInfo::default(),
0,
Some(1),
&payload,
None,
None,
Some(subnetwork1.public_id()),
);

subnetwork0.send(msg.clone()).unwrap();

let received_msg = subnetwork1.recv().await.unwrap();
assert_eq!(received_msg.payload, msg.payload);

let msg = GossipHandle::build_protocol_message(
let msg = subnetwork1.build_protocol_message(
IdentifierInfo::default(),
1,
Some(0),
&payload,
None,
None,
Some(subnetwork0.public_id()),
);

subnetwork1.send(msg.clone()).unwrap();
Expand Down
Loading

0 comments on commit 97f55ea

Please sign in to comment.