diff --git a/src/protocol/libp2p/kademlia/handle.rs b/src/protocol/libp2p/kademlia/handle.rs
index 15903237..a48fd051 100644
--- a/src/protocol/libp2p/kademlia/handle.rs
+++ b/src/protocol/libp2p/kademlia/handle.rs
@@ -254,12 +254,12 @@ pub enum KademliaEvent {
 
 /// The type of the DHT records.
 #[derive(Debug, Clone)]
 pub enum RecordsType {
-    /// Record was found in the local store.
+    /// Record was found in the local store and [`Quorum::One`] was used.
    ///
    /// This contains only a single result.
    LocalStore(Record),
 
-    /// Records found in the network.
+    /// Records found in the network. This can include the locally found record.
    Network(Vec<PeerRecord>),
 }
diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs
index e1321ba6..421904d1 100644
--- a/src/protocol/libp2p/kademlia/mod.rs
+++ b/src/protocol/libp2p/kademlia/mod.rs
@@ -1118,7 +1118,7 @@ impl Kademlia {
                        .closest(&Key::new(key), self.replication_factor)
                        .into(),
                    quorum,
-                    if record.is_some() { 1 } else { 0 },
+                    record.cloned(),
                );
            }
        }
diff --git a/src/protocol/libp2p/kademlia/query/get_record.rs b/src/protocol/libp2p/kademlia/query/get_record.rs
index 12ea8293..019ece17 100644
--- a/src/protocol/libp2p/kademlia/query/get_record.rs
+++ b/src/protocol/libp2p/kademlia/query/get_record.rs
@@ -106,7 +106,11 @@ pub struct GetRecordContext {
 
 impl GetRecordContext {
    /// Create new [`GetRecordContext`].
-    pub fn new(config: GetRecordConfig, in_peers: VecDeque<KademliaPeer>) -> Self {
+    pub fn new(
+        config: GetRecordConfig,
+        in_peers: VecDeque<KademliaPeer>,
+        found_records: Vec<PeerRecord>,
+    ) -> Self {
        let mut candidates = BTreeMap::new();
 
        for candidate in &in_peers {
@@ -123,7 +127,7 @@ impl GetRecordContext {
            candidates,
            pending: HashMap::new(),
            queried: HashSet::new(),
-            found_records: Vec::new(),
+            found_records,
        }
    }
 
@@ -378,7 +382,7 @@ mod tests {
    #[test]
    fn completes_when_no_candidates() {
        let config = default_config();
-        let mut context = GetRecordContext::new(config, VecDeque::new());
+        let mut context = GetRecordContext::new(config, VecDeque::new(), Vec::new());
        assert!(context.is_done());
        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QueryFailed { query: QueryId(0) });
@@ -387,7 +391,7 @@
            known_records: 1,
            ..default_config()
        };
-        let mut context = GetRecordContext::new(config, VecDeque::new());
+        let mut context = GetRecordContext::new(config, VecDeque::new(), Vec::new());
        assert!(context.is_done());
        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
@@ -405,7 +409,7 @@
        assert_eq!(in_peers_set.len(), 3);
 
        let in_peers = in_peers_set.iter().map(|peer| peer_to_kad(*peer)).collect();
-        let mut context = GetRecordContext::new(config, in_peers);
+        let mut context = GetRecordContext::new(config, in_peers, Vec::new());
 
        for num in 0..3 {
            let event = context.next_action().unwrap();
@@ -444,7 +448,7 @@
        assert_eq!(in_peers_set.len(), 3);
 
        let in_peers = [peer_a, peer_b, peer_c].iter().map(|peer| peer_to_kad(*peer)).collect();
-        let mut context = GetRecordContext::new(config, in_peers);
+        let mut context = GetRecordContext::new(config, in_peers, Vec::new());
 
        // Schedule peer queries.
        for num in 0..3 {
diff --git a/src/protocol/libp2p/kademlia/query/mod.rs b/src/protocol/libp2p/kademlia/query/mod.rs
index b933ec5b..9f14a6de 100644
--- a/src/protocol/libp2p/kademlia/query/mod.rs
+++ b/src/protocol/libp2p/kademlia/query/mod.rs
@@ -318,7 +318,7 @@ impl QueryEngine {
        target: RecordKey,
        candidates: VecDeque<KademliaPeer>,
        quorum: Quorum,
-        count: usize,
+        local_record: Option<Record>,
    ) -> QueryId {
        tracing::debug!(
            target: LOG_TARGET,
@@ -331,7 +331,7 @@
        let target = Key::new(target);
        let config = GetRecordConfig {
            local_peer_id: self.local_peer_id,
-            known_records: count,
+            known_records: if local_record.is_some() { 1 } else { 0 },
            quorum,
            replication_factor: self.replication_factor,
            parallelism_factor: self.parallelism_factor,
@@ -339,10 +339,18 @@
            target,
        };
 
+        let found_records = local_record
+            .into_iter()
+            .map(|record| PeerRecord {
+                peer: self.local_peer_id,
+                record,
+            })
+            .collect();
+
        self.queries.insert(
            query_id,
            QueryType::GetRecord {
-                context: GetRecordContext::new(config, candidates),
+                context: GetRecordContext::new(config, candidates, found_records),
            },
        );
 
@@ -883,7 +891,7 @@ mod tests {
            ]
            .into(),
            Quorum::All,
-            3,
+            None,
        );
 
        for _ in 0..4 {
diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs
index d69e1603..0cf5e255 100644
--- a/src/transport/quic/mod.rs
+++ b/src/transport/quic/mod.rs
@@ -43,7 +43,7 @@
 use multiaddr::{Multiaddr, Protocol};
 use quinn::{ClientConfig, Connecting, Connection, Endpoint, IdleTimeout};
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    pin::Pin,
    sync::Arc,
@@ -120,9 +120,9 @@ pub(crate) struct QuicTransport {
    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
    opened_raw: HashMap,
 
-    /// Canceled raw connections.
-    canceled: HashSet<ConnectionId>,
-
+    /// Handles for cancelling raw connection futures.
+    ///
+    /// These cancel the futures stored in `Self::pending_raw_connections`.
    cancel_futures: HashMap<ConnectionId, AbortHandle>,
 }
@@ -235,7 +235,6 @@ impl TransportBuilder for QuicTransport {
            context,
            config,
            listener,
-            canceled: HashSet::new(),
            opened_raw: HashMap::new(),
            pending_open: HashMap::new(),
            pending_dials: HashMap::new(),
@@ -477,8 +476,11 @@ impl Transport for QuicTransport {
 
    /// Cancel opening connections.
    fn cancel(&mut self, connection_id: ConnectionId) {
-        self.canceled.insert(connection_id);
-        self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
+        // Cancel the future if it exists.
+        // State clean-up happens inside `poll_next`.
+        if let Some(handle) = self.cancel_futures.get(&connection_id) {
+            handle.abort();
+        }
    }
 }
 
@@ -510,27 +512,57 @@ impl Stream for QuicTransport {
                    connection_id,
                    address,
                    stream,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?address,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        self.opened_raw.insert(connection_id, (stream, address.clone()));
                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
                            connection_id,
                            address,
                        }));
-                    },
+                    }
+                }
+
                RawConnectionResult::Failed {
                    connection_id,
                    errors,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?errors,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        return Poll::Ready(Some(TransportEvent::OpenFailure {
                            connection_id,
                            errors,
                        }));
-                    },
+                    }
+                }
+
                RawConnectionResult::Canceled { connection_id } => {
-                    self.canceled.remove(&connection_id);
+                    if self.cancel_futures.remove(&connection_id).is_none() {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            "cancelled raw connection without a cancel handle",
+                        );
+                    }
                }
            }
        }
 
diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs
index 4ef52104..ff8c1655 100644
--- a/src/transport/tcp/mod.rs
+++ b/src/transport/tcp/mod.rs
@@ -46,7 +46,7 @@
 use socket2::{Domain, Socket, Type};
 use tokio::net::TcpStream;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
    net::SocketAddr,
    pin::Pin,
    task::{Context, Poll},
@@ -121,9 +121,9 @@ pub(crate) struct TcpTransport {
    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
    opened_raw: HashMap,
 
-    /// Canceled raw connections.
-    canceled: HashSet<ConnectionId>,
-
+    /// Handles for cancelling raw connection futures.
+    ///
+    /// These cancel the futures stored in `Self::pending_raw_connections`.
    cancel_futures: HashMap<ConnectionId, AbortHandle>,
 
    /// Connections which have been opened and negotiated but are being validated by the
@@ -291,7 +291,6 @@ impl TransportBuilder for TcpTransport {
            config,
            context,
            dial_addresses,
-            canceled: HashSet::new(),
            opened_raw: HashMap::new(),
            pending_open: HashMap::new(),
            pending_dials: HashMap::new(),
@@ -516,8 +515,11 @@ impl Transport for TcpTransport {
    }
 
    fn cancel(&mut self, connection_id: ConnectionId) {
-        self.canceled.insert(connection_id);
-        self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
+        // Cancel the future if it exists.
+        // State clean-up happens inside `poll_next`.
+        if let Some(handle) = self.cancel_futures.get(&connection_id) {
+            handle.abort();
+        }
    }
 }
 
@@ -560,27 +562,56 @@ impl Stream for TcpTransport {
                    connection_id,
                    address,
                    stream,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?address,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        self.opened_raw.insert(connection_id, (stream, address.clone()));
                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
                            connection_id,
                            address,
                        }));
-                    },
+                    }
+                }
+
                RawConnectionResult::Failed {
                    connection_id,
                    errors,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?errors,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        return Poll::Ready(Some(TransportEvent::OpenFailure {
                            connection_id,
                            errors,
                        }));
-                    },
+                    }
+                }
                RawConnectionResult::Canceled { connection_id } => {
-                    self.canceled.remove(&connection_id);
+                    if self.cancel_futures.remove(&connection_id).is_none() {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            "cancelled raw connection without a cancel handle",
+                        );
+                    }
                }
            }
        }
 
diff --git a/src/transport/websocket/mod.rs b/src/transport/websocket/mod.rs
index bcf37002..d7b374a9 100644
--- a/src/transport/websocket/mod.rs
+++ b/src/transport/websocket/mod.rs
@@ -50,7 +50,7 @@
 use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
 use url::Url;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
@@ -125,9 +125,9 @@ pub(crate) struct WebSocketTransport {
    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
    opened_raw: HashMap<ConnectionId, (WebSocketStream<MaybeTlsStream<TcpStream>>, Multiaddr)>,
 
-    /// Canceled raw connections.
-    canceled: HashSet<ConnectionId>,
-
+    /// Handles for cancelling raw connection futures.
+    ///
+    /// These cancel the futures stored in `Self::pending_raw_connections`.
    cancel_futures: HashMap<ConnectionId, AbortHandle>,
 
    /// Negotiated connections waiting validation.
@@ -321,7 +321,6 @@ impl TransportBuilder for WebSocketTransport {
            config,
            context,
            dial_addresses,
-            canceled: HashSet::new(),
            opened_raw: HashMap::new(),
            pending_open: HashMap::new(),
            pending_dials: HashMap::new(),
@@ -562,8 +561,11 @@ impl Transport for WebSocketTransport {
    }
 
    fn cancel(&mut self, connection_id: ConnectionId) {
-        self.canceled.insert(connection_id);
-        self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
+        // Cancel the future if it exists.
+        // State clean-up happens inside `poll_next`.
+        if let Some(handle) = self.cancel_futures.get(&connection_id) {
+            handle.abort();
+        }
    }
 }
 
@@ -600,27 +602,56 @@ impl Stream for WebSocketTransport {
                    connection_id,
                    address,
                    stream,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?address,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        self.opened_raw.insert(connection_id, (stream, address.clone()));
                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
                            connection_id,
                            address,
                        }));
-                    },
+                    }
+                }
+
                RawConnectionResult::Failed {
                    connection_id,
                    errors,
-                } =>
-                    if !self.canceled.remove(&connection_id) {
+                } => {
+                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            ?errors,
+                            "raw connection without a cancel handle",
+                        );
+                        continue;
+                    };
+
+                    if !handle.is_aborted() {
                        return Poll::Ready(Some(TransportEvent::OpenFailure {
                            connection_id,
                            errors,
                        }));
-                    },
+                    }
+                }
                RawConnectionResult::Canceled { connection_id } => {
-                    self.canceled.remove(&connection_id);
+                    if self.cancel_futures.remove(&connection_id).is_none() {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?connection_id,
+                            "cancelled raw connection without a cancel handle",
+                        );
+                    }
                }
            }
        }
 
diff --git a/tests/protocol/kademlia.rs b/tests/protocol/kademlia.rs
index bd121c03..d3d62910 100644
--- a/tests/protocol/kademlia.rs
+++ b/tests/protocol/kademlia.rs
@@ -464,6 +464,108 @@ async fn get_record_retrieves_remote_records() {
    }
 }
 
+#[tokio::test]
+async fn get_record_retrieves_local_and_remote_records() {
+    let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build();
+    let (kad_config2, mut kad_handle2) = KademliaConfigBuilder::new().build();
+
+    let config1 = ConfigBuilder::new()
+        .with_tcp(TcpConfig {
+            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
+            ..Default::default()
+        })
+        .with_libp2p_kademlia(kad_config1)
+        .build();
+
+    let config2 = ConfigBuilder::new()
+        .with_tcp(TcpConfig {
+            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
+            ..Default::default()
+        })
+        .with_libp2p_kademlia(kad_config2)
+        .build();
+
+    let mut litep2p1 = Litep2p::new(config1).unwrap();
+    let mut litep2p2 = Litep2p::new(config2).unwrap();
+
+    // Let peers know about each other.
+    kad_handle1
+        .add_known_peer(
+            *litep2p2.local_peer_id(),
+            litep2p2.listen_addresses().cloned().collect(),
+        )
+        .await;
+    kad_handle2
+        .add_known_peer(
+            *litep2p1.local_peer_id(),
+            litep2p1.listen_addresses().cloned().collect(),
+        )
+        .await;
+
+    // Store the record on `litep2p1`.
+    let original_record = Record::new(vec![1, 2, 3], vec![0x01]);
+    let query1 = kad_handle1.put_record(original_record.clone()).await;
+
+    let (mut peer1_stored, mut peer2_stored) = (false, false);
+    let mut query3 = None;
+
+    loop {
+        tokio::select! {
+            _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {
+                panic!("record was not retrieved in 10 secs")
+            }
+            event = litep2p1.next_event() => {}
+            event = litep2p2.next_event() => {}
+            event = kad_handle1.next() => {}
+            event = kad_handle2.next() => {
+                match event {
+                    Some(KademliaEvent::IncomingRecord { record: got_record }) => {
+                        assert_eq!(got_record.key, original_record.key);
+                        assert_eq!(got_record.value, original_record.value);
+                        assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id());
+                        assert!(got_record.expires.is_some());
+
+                        // Get record.
+                        let query_id = kad_handle2
+                            .get_record(RecordKey::from(vec![1, 2, 3]), Quorum::All).await;
+                        query3 = Some(query_id);
+                    }
+                    Some(KademliaEvent::GetRecordSuccess { query_id: _, records }) => {
+                        match records {
+                            RecordsType::LocalStore(_) => {
+                                panic!("the record was retrieved only from peer2")
+                            }
+                            RecordsType::Network(records) => {
+                                assert_eq!(records.len(), 2);
+
+                                // Locally retrieved record goes first.
+                                assert_eq!(records[0].peer, *litep2p2.local_peer_id());
+                                assert_eq!(records[0].record.key, original_record.key);
+                                assert_eq!(records[0].record.value, original_record.value);
+                                assert_eq!(records[0].record.publisher.unwrap(), *litep2p1.local_peer_id());
+                                assert!(records[0].record.expires.is_some());
+
+                                // Remote record from peer 1.
+                                assert_eq!(records[1].peer, *litep2p1.local_peer_id());
+                                assert_eq!(records[1].record.key, original_record.key);
+                                assert_eq!(records[1].record.value, original_record.value);
+                                assert_eq!(records[1].record.publisher.unwrap(), *litep2p1.local_peer_id());
+                                assert!(records[1].record.expires.is_some());
+
+                                break
+                            }
+                        }
+                    }
+                    Some(KademliaEvent::QueryFailed { query_id: _ }) => {
+                        panic!("peer2 query failed")
+                    }
+                    _ => {}
+                }
+            }
+        }
+    }
+}
+
 #[tokio::test]
 async fn provider_retrieved_by_remote_node() {
    let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build();
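
Reviewer note (illustrative only, not part of the patch): below is a minimal sketch of how a caller might consume the updated `RecordsType` semantics. It assumes the same public API the new test exercises (`KademliaHandle::get_record`, `KademliaEvent::GetRecordSuccess`, `Quorum`) and that `KademliaHandle` yields events as a stream; the function name and control flow are assumptions, not code from this change.

    use futures::StreamExt;

    // Sketch only: issue a GET_VALUE query and react to its result.
    async fn print_records(handle: &mut KademliaHandle, key: RecordKey) {
        let _query_id = handle.get_record(key, Quorum::All).await;

        while let Some(event) = handle.next().await {
            match event {
                // Returned only when the record came from the local store and
                // `Quorum::One` was requested.
                KademliaEvent::GetRecordSuccess { records: RecordsType::LocalStore(record), .. } => {
                    println!("local record for key {:?}", record.key);
                    break;
                }
                // With this change the network results may also include the locally
                // found record; the test above expects it to be ordered first.
                KademliaEvent::GetRecordSuccess { records: RecordsType::Network(records), .. } => {
                    println!("found {} record(s)", records.len());
                    break;
                }
                KademliaEvent::QueryFailed { .. } => break,
                _ => {}
            }
        }
    }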