Skip to content

Commit

Permalink
feat: implement ping extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
KolbyML committed Dec 28, 2024
1 parent b73bee1 commit 5febe0e
Show file tree
Hide file tree
Showing 31 changed files with 1,288 additions and 213 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions bin/trin/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use ethportal_api::{
network::Subnetwork,
portal_wire::{NetworkSpec, MAINNET},
},
version::VERSION,
version::{APP_NAME, VERSION},
};
use portalnet::{
bootnodes::Bootnodes,
Expand All @@ -32,7 +32,6 @@ const DEFAULT_SUBNETWORKS: &str = "history";
pub const DEFAULT_STORAGE_CAPACITY_MB: &str = "1000";
pub const DEFAULT_WEB3_TRANSPORT: &str = "ipc";

const APP_NAME: &str = "trin";
#[derive(Parser, Debug, PartialEq, Clone)]
#[command(name = APP_NAME,
author = "https://github.com/ethereum/trin/graphs/contributors",
Expand Down
1 change: 1 addition & 0 deletions ethportal-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ethereum_serde_utils = "0.7.0"
ethereum_ssz.workspace = true
ethereum_ssz_derive.workspace = true
hex.workspace = true
itertools.workspace = true
jsonrpsee = { workspace = true, features = ["async-client", "client", "macros", "server"]}
keccak-hash.workspace = true
lazy_static.workspace = true
Expand Down
4 changes: 3 additions & 1 deletion ethportal-api/src/types/distance.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::{fmt, ops::Deref};

use alloy::primitives::U256;
use ssz_derive::{Decode, Encode};

pub type DataRadius = U256;

/// Represents a distance between two keys in the DHT key space.
#[derive(Copy, Clone, PartialEq, Eq, Default, PartialOrd, Ord, Debug)]
#[derive(Copy, Clone, PartialEq, Eq, Default, PartialOrd, Ord, Debug, Encode, Decode)]
#[ssz(struct_behaviour = "transparent")]
pub struct Distance(U256);

impl fmt::Display for Distance {
Expand Down
1 change: 1 addition & 0 deletions ethportal-api/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod execution;
pub mod jsonrpc;
pub mod network;
pub mod node_id;
pub mod ping_extensions;
pub mod portal;
pub mod portal_wire;
pub mod query_trace;
Expand Down
74 changes: 74 additions & 0 deletions ethportal-api/src/types/ping_extensions/decode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use anyhow::bail;
use ssz::Decode;

use super::{
extensions::{
type_0::ClientInfoRadiusCapabilities, type_1::BasicRadius, type_2::HistoryRadius,
type_65535::PingError,
},
CustomPayloadExtensionsFormat, Extensions,
};
use crate::types::portal_wire::CustomPayload;

#[derive(Debug, Clone)]
pub enum DecodedExtension {
Capabilities(ClientInfoRadiusCapabilities),
BasicRadius(BasicRadius),
HistoryRadius(HistoryRadius),
Error(PingError),
}

impl From<DecodedExtension> for Extensions {
fn from(value: DecodedExtension) -> Self {
match value {
DecodedExtension::Capabilities(_) => Extensions::Capabilities,
DecodedExtension::BasicRadius(_) => Extensions::BasicRadius,
DecodedExtension::HistoryRadius(_) => Extensions::HistoryRadius,
DecodedExtension::Error(_) => Extensions::Error,
}
}
}

impl TryFrom<CustomPayload> for DecodedExtension {
type Error = anyhow::Error;

fn try_from(value: CustomPayload) -> Result<Self, anyhow::Error> {
let Ok(ping_custom_payload): anyhow::Result<CustomPayloadExtensionsFormat> =
value.try_into()
else {
bail!("Failed to decode CustomPayloadExtensionsFormat");
};

let Ok(extension_type) = Extensions::try_from(ping_custom_payload.r#type) else {
bail!("Failed to decode extension type");
};

match extension_type {
Extensions::Capabilities => {
let capabilities =
ClientInfoRadiusCapabilities::from_ssz_bytes(&ping_custom_payload.payload)
.map_err(|err| {
anyhow::anyhow!(
"Failed to decode ClientInfoRadiusCapabilities: {err:?}"
)
})?;
Ok(DecodedExtension::Capabilities(capabilities))
}
Extensions::BasicRadius => {
let basic_radius = BasicRadius::from_ssz_bytes(&ping_custom_payload.payload)
.map_err(|err| anyhow::anyhow!("Failed to decode BasicRadius: {err:?}"))?;
Ok(DecodedExtension::BasicRadius(basic_radius))
}
Extensions::HistoryRadius => {
let history_radius = HistoryRadius::from_ssz_bytes(&ping_custom_payload.payload)
.map_err(|err| anyhow::anyhow!("Failed to decode HistoryRadius: {err:?}"))?;
Ok(DecodedExtension::HistoryRadius(history_radius))
}
Extensions::Error => {
let error = PingError::from_ssz_bytes(&ping_custom_payload.payload)
.map_err(|err| anyhow::anyhow!("Failed to decode PingError: {err:?}"))?;
Ok(DecodedExtension::Error(error))
}
}
}
}
4 changes: 4 additions & 0 deletions ethportal-api/src/types/ping_extensions/extensions/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod type_0;
pub mod type_1;
pub mod type_2;
pub mod type_65535;
213 changes: 213 additions & 0 deletions ethportal-api/src/types/ping_extensions/extensions/type_0.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use std::str::FromStr;

use anyhow::bail;
use itertools::Itertools;
use ssz::{Decode, Encode};
use ssz_derive::{Decode, Encode};
use ssz_types::{
typenum::{U200, U400},
VariableList,
};

use crate::{
types::{
distance::Distance,
ping_extensions::{CustomPayloadExtensionsFormat, ExtensionError, Extensions},
portal_wire::CustomPayload,
},
version::{
APP_NAME, BUILD_ARCHITECTURE, BUILD_OPERATING_SYSTEM, PROGRAMMING_LANGUAGE_VERSION,
TRIN_SHORT_COMMIT, TRIN_VERSION,
},
};

#[derive(PartialEq, Debug, Clone, Encode, Decode)]
pub struct ClientInfoRadiusCapabilities {
pub client_info: Option<ClientInfo>,
pub data_radius: Distance,
capabilities: VariableList<u16, U400>,
}

impl ClientInfoRadiusCapabilities {
pub fn new(radius: Distance, capabilities: Vec<u16>) -> Self {
Self {
client_info: Some(ClientInfo::trin_client_info()),
data_radius: radius,
capabilities: VariableList::from(capabilities),
}
}

pub fn capabilities(&self) -> Result<Vec<Extensions>, ExtensionError> {
self.capabilities
.iter()
.map(|&value| Extensions::try_from(value))
.collect::<Result<Vec<_>, _>>()
}
}

impl From<ClientInfoRadiusCapabilities> for CustomPayload {
fn from(client_info_radius_capacities: ClientInfoRadiusCapabilities) -> Self {
CustomPayload::from(
CustomPayloadExtensionsFormat {
r#type: 0,
payload: client_info_radius_capacities.as_ssz_bytes().into(),
}
.as_ssz_bytes(),
)
}
}

/// Information about the client.
/// example: trin/v0.1.1-892ad575/linux-x86_64/rustc1.81.0
#[derive(PartialEq, Debug, Clone)]
pub struct ClientInfo {
pub client_name: String,
pub client_version: String,
pub short_commit: String,
pub operating_system: String,
pub cpu_architecture: String,
pub programming_language_version: String,
}

impl ClientInfo {
pub fn trin_client_info() -> Self {
Self {
client_name: APP_NAME.to_string(),
client_version: TRIN_VERSION.to_string(),
short_commit: TRIN_SHORT_COMMIT.to_string(),
operating_system: BUILD_OPERATING_SYSTEM.to_string(),
cpu_architecture: BUILD_ARCHITECTURE.to_string(),
programming_language_version: format!("rustc{PROGRAMMING_LANGUAGE_VERSION}"),
}
}

pub fn string(&self) -> String {
format!(
"{}/{}-{}/{}-{}/{}",
self.client_name,
self.client_version,
self.short_commit,
self.operating_system,
self.cpu_architecture,
self.programming_language_version
)
}
}

impl FromStr for ClientInfo {
type Err = anyhow::Error;

fn from_str(string: &str) -> Result<Self, anyhow::Error> {
let parts: Vec<&str> = string.split('/').collect();

if parts.len() != 4 {
bail!("Invalid client info string: should have 4 /'s {}", string);
}

let client_name = parts[0];

let Some((client_version, short_commit)) = parts[1].split('-').collect_tuple() else {
bail!(
"Invalid client info string: should look like 0.1.1-2b00d730 got {}",
parts[1]
);
};

let Some((operating_system, cpu_architecture)) = parts[2].split('-').collect_tuple() else {
bail!(
"Invalid client info string: should look like linux-x86_64 got {}",
parts[2]
);
};

Ok(Self {
client_name: client_name.to_string(),
client_version: client_version.to_string(),
short_commit: short_commit.to_string(),
operating_system: operating_system.to_string(),
cpu_architecture: cpu_architecture.to_string(),
programming_language_version: parts[3].to_string(),
})
}
}

impl Encode for ClientInfo {
fn is_ssz_fixed_len() -> bool {
false
}

fn ssz_append(&self, buf: &mut Vec<u8>) {
let bytes: Vec<u8> = self.string().as_bytes().to_vec();
let byte_list: VariableList<u8, U200> = VariableList::from(bytes);
buf.extend_from_slice(&byte_list);
}

fn ssz_bytes_len(&self) -> usize {
self.as_ssz_bytes().len()
}
}

impl Decode for ClientInfo {
fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
let byte_list = VariableList::<u8, U200>::from_ssz_bytes(bytes)?;
let string = String::from_utf8(byte_list.to_vec()).map_err(|_| {
ssz::DecodeError::BytesInvalid(format!("Invalid utf8 string: {byte_list:?}"))
})?;
Self::from_str(&string).map_err(|err| {
ssz::DecodeError::BytesInvalid(format!("Failed to parse client info: {err:?}"))
})
}

fn is_ssz_fixed_len() -> bool {
false
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_client_info_round_trip() {
let client_info = ClientInfo::trin_client_info();
let bytes = client_info.as_ssz_bytes();
let decoded = ClientInfo::from_ssz_bytes(&bytes).unwrap();
assert_eq!(client_info, decoded);
}

#[test]
fn test_client_info_from_str() {
let client_info = ClientInfo::trin_client_info();
let string = client_info.string();
let decoded = ClientInfo::from_str(&string).unwrap();
assert_eq!(client_info, decoded);
}

#[test]
fn test_client_info_from_str_invalid() {
let string = "trin/0.1.1-2b00d730/linux-x86_64";
let decoded = ClientInfo::from_str(string);
assert!(decoded.is_err());
}

#[test]
fn test_client_info_from_str_invalid_parts() {
let string = "trin/0.1.1-2b00d730/linux-x86_64/rustc1.81.0/extra";
let decoded = ClientInfo::from_str(string);
assert!(decoded.is_err());
}

#[test]
fn test_client_info_from_str_invalid_version() {
let string = "trin/0.1.1/linux-x86_64/rustc1.81.0";
let decoded = ClientInfo::from_str(string);
assert!(decoded.is_err());
}

#[test]
fn test_client_info_from_str_invalid_os() {
let string = "trin/0.1.1-2b00d730/linux/rustc1.81.0";
let decoded = ClientInfo::from_str(string);
assert!(decoded.is_err());
}
}
29 changes: 29 additions & 0 deletions ethportal-api/src/types/ping_extensions/extensions/type_1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use ssz::Encode;
use ssz_derive::{Decode, Encode};

use crate::types::{
distance::Distance, ping_extensions::CustomPayloadExtensionsFormat, portal_wire::CustomPayload,
};

#[derive(PartialEq, Debug, Clone, Encode, Decode)]
pub struct BasicRadius {
pub data_radius: Distance,
}

impl BasicRadius {
pub fn new(data_radius: Distance) -> Self {
Self { data_radius }
}
}

impl From<BasicRadius> for CustomPayload {
fn from(basic_radius: BasicRadius) -> Self {
CustomPayload::from(
CustomPayloadExtensionsFormat {
r#type: 1,
payload: basic_radius.as_ssz_bytes().into(),
}
.as_ssz_bytes(),
)
}
}
Loading

0 comments on commit 5febe0e

Please sign in to comment.