diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..e63673cab --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +* text eol=lf +*.bat text eol=crlf +*.png binary diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 000000000..85cd693ae --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,55 @@ +name: Bug Report +description: Create a bug report for jvm-libp2p + +body: + - type: markdown + attributes: + value: | + Thank you for filing a bug report! + - type: textarea + attributes: + label: Summary + description: Please provide a short summary of the bug, along with any information you feel relevant to replicate the bug. + validations: + required: true + - type: textarea + attributes: + label: Expected behavior + description: Describe what you expect to happen. + validations: + required: true + - type: textarea + attributes: + label: Actual behavior + description: Describe what actually happens. + validations: + required: true + - type: textarea + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false + - type: textarea + attributes: + label: Possible Solution + description: Suggest a fix/reason for the bug, or ideas how to implement the addition or change. + validations: + required: false + - type: textarea + attributes: + label: Version + description: Which version of libp2p are you using? libp2p version (version number, commit, or branch) + validations: + required: false + - type: dropdown + attributes: + label: Would you like to work on fixing this bug ? + description: Any contribution towards fixing the bug is greatly appreciated. We are more than happy to provide help on the process. + options: + - "Yes" + - "No" + - Maybe + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..5842c85fe --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: true +contact_links: + - name: Technical Questions + url: https://github.com/libp2p/jvm-libp2p/discussions/new?category=q-a + about: Please ask technical questions in the jvm-libp2p Github Discussions forum. + - name: Community-wide libp2p Discussion + url: https://discuss.libp2p.io + about: Discussions and questions about the libp2p community. diff --git a/.github/ISSUE_TEMPLATE/enhancement.yml b/.github/ISSUE_TEMPLATE/enhancement.yml new file mode 100644 index 000000000..65def1ad7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/enhancement.yml @@ -0,0 +1,31 @@ +name: Enhancement +description: Suggest an improvement to an existing jvm-libp2p feature. +body: + - type: textarea + attributes: + label: Description + description: Describe the enhancement that you are proposing. + validations: + required: true + - type: textarea + attributes: + label: Motivation + description: Explain why this enhancement is beneficial. + validations: + required: true + - type: textarea + attributes: + label: Current Implementation + description: Describe the current implementation. + validations: + required: true + - type: dropdown + attributes: + label: Are you planning to do it yourself in a pull request ? + description: Any contribution is greatly appreciated. We are more than happy to provide help on the process. + options: + - "Yes" + - "No" + - Maybe + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 000000000..411c43e38 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,42 @@ +name: Feature request +description: Suggest a new feature in jvm-libp2p +body: + - type: markdown + attributes: + value: | + If you'd like to suggest a feature related to libp2p but not specifically related to the JVM implementation, please file an issue at https://github.com/libp2p/specs instead. + - type: textarea + attributes: + label: Description + description: Briefly describe the feature that you are requesting. + validations: + required: true + - type: textarea + attributes: + label: Motivation + description: Explain why this feature is needed. + validations: + required: true + - type: textarea + attributes: + label: Requirements + description: Write a list of what you want this feature to do. + placeholder: "1." + validations: + required: true + - type: textarea + attributes: + label: Open questions + description: Use this section to ask any questions that are related to the feature. + validations: + required: false + - type: dropdown + attributes: + label: Are you planning to do it yourself in a pull request ? + description: Any contribution is greatly appreciated. We are more than happy to provide help on the process. + options: + - "Yes" + - "No" + - Maybe + validations: + required: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 000000000..0eeea244e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,20 @@ +name: publish +on: + push: + branches: + - "develop" +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-java@v3 + with: + distribution: temurin + java-version: 11 + + - name: Setup Gradle + uses: gradle/gradle-build-action@v2 + + - name: Publish to Cloudsmith + run: ./gradlew publish -PcloudsmithUser=${{ secrets.CLOUDSMITH_USER }} -PcloudsmithApiKey=${{ secrets.CLOUDSMITH_API_KEY }} \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 6f6d895d1..16d65d721 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -2,25 +2,12 @@ name: Close and mark stale issue on: schedule: - - cron: '0 0 * * *' + - cron: '0 0 * * *' + +permissions: + issues: write + pull-requests: write jobs: stale: - - runs-on: ubuntu-latest - permissions: - issues: write - pull-requests: write - - steps: - - uses: actions/stale@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: 'Oops, seems like we needed more information for this issue, please comment with more details or this issue will be closed in 7 days.' - close-issue-message: 'This issue was closed because it is missing author input.' - stale-issue-label: 'kind/stale' - any-of-labels: 'need/author-input' - exempt-issue-labels: 'need/triage,need/community-input,need/maintainer-input,need/maintainers-input,need/analysis,status/blocked,status/in-progress,status/ready,status/deferred,status/inactive' - days-before-issue-stale: 6 - days-before-issue-close: 7 - enable-statistics: true + uses: pl-strflt/.github/.github/workflows/reusable-stale-issue.yml@v0.3 diff --git a/README.md b/README.md index 593dadd51..b7ec0ae09 100644 --- a/README.md +++ b/README.md @@ -6,58 +6,45 @@ ![Build Status](https://github.com/libp2p/jvm-libp2p/actions/workflows/build.yml/badge.svg?branch=master) [![Discourse posts](https://img.shields.io/discourse/https/discuss.libp2p.io/posts.svg)](https://discuss.libp2p.io) -> a libp2p implementation for the JVM, written in Kotlin 🔥 - -**⚠️ This is heavy work in progress! ⚠️** - -## Roadmap - -The endeavour to build jvm-libp2p is split in two phases: - -* **minimal phase (v0.x):** aims to provide the bare minimum stack that will - allow JVM-based Ethereum 2.0 clients to interoperate with other clients that - rely on fully-fledged libp2p stacks written in other languages. - * To achieve this, we have to be wire-compliant, but don't need to fulfill - the complete catalogue of libp2p abstractions. - * This effort will act as a starting point to evolve this project into a - fully-fledged libp2p stack for JVM environments, including Android - runtimes. - * We are shooting for Aug/early Sept 2019. - * Only Java-friendly façade. - -* **maturity phase (v1.x):** upgrades the minimal version to a flexible and - versatile stack adhering to the key design principles of modularity and - pluggability that define the libp2p project. It adds features present in - mature implementations like go-libp2p, rust-libp2p, js-libp2p. - * will offer: pluggable peerstore, connection manager, QUIC transport, - circuit relay, AutoNAT, AutoRelay, NAT traversal, etc. - * Android-friendly. - * Kotlin coroutine-based façade, possibly a Reactive Streams façade too. - * work will begin after the minimal phase concludes. - -## minimal phase (v0.x): Definition of Done - -We have identified the following components on the path to attaining a minimal -implementation: - -- [X] multistream-select 1.0 -- [X] multiformats: [multiaddr](https://github.com/multiformats/multiaddr) -- [X] crypto (RSA, ed25519, secp256k1) -- [X] [secio](https://github.com/libp2p/specs/pull/106) -- [X] [connection bootstrapping](https://github.com/libp2p/specs/pull/168) -- [X] mplex as a multiplexer -- [X] stream multiplexing -- [X] TCP transport (dialing and listening) -- [X] Identify protocol -- [X] Ping protocol -- [X] [peer ID](https://github.com/libp2p/specs/pull/100) -- [X] noise security protocol -- [X] MDNS -- [X] Gossip 1.1 pubsub - -We are explicitly leaving out the peerstore, DHT, pubsub, connection manager, -etc. and other subsystems or concepts that are internal to implementations and -do not impact the ability to hold communications with other libp2p processes. +[Libp2p](https://libp2p.io/) implementation for the JVM, written in Kotlin 🔥 + +## Components + +List of components in the Libp2p spec and their JVM implementation status + +| | Component | Status | +|--------------------------|-------------------------------------------------------------------------------------------------|:----------------:| +| **Transport** | tcp | :green_apple: | +| | [quic](https://github.com/libp2p/specs/tree/master/quic) | :tomato: | +| | websocket | :lemon: | +| | [webtransport](https://github.com/libp2p/specs/tree/master/webtransport) | | +| | [webrtc-browser-to-server](https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md) | | +| | [webrtc-private-to-private](https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md) | | +| **Secure Communication** | [noise](https://github.com/libp2p/specs/blob/master/noise/) | :green_apple: | +| | [tls](https://github.com/libp2p/specs/blob/master/tls/tls.md) | :lemon: | +| | [plaintext](https://github.com/libp2p/specs/blob/master/plaintext/README.md) | :lemon: | +| | [secio](https://github.com/libp2p/specs/blob/master/secio/README.md) **(deprecated)** | :green_apple: | +| **Protocol Select** | [multistream](https://github.com/multiformats/multistream-select) | :green_apple: | +| **Stream Multiplexing** | [yamux](https://github.com/libp2p/specs/blob/master/yamux/README.md) | :lemon: | +| | [mplex](https://github.com/libp2p/specs/blob/master/mplex/README.md) | :green_apple: | +| **NAT Traversal** | [circuit-relay-v2](https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md) | :lemon: | +| | [autonat](https://github.com/libp2p/specs/tree/master/autonat) | :lemon: | +| | [hole-punching](https://github.com/libp2p/specs/blob/master/connections/hole-punching.md) | | +| **Discovery** | [bootstrap](https://github.com/libp2p/specs/blob/master/kad-dht/README.md#bootstrap-process) | | +| | random-walk | | +| | [mdns-discovery](https://github.com/libp2p/specs/blob/master/discovery/mdns.md) | :lemon: | +| | [rendezvous](https://github.com/libp2p/specs/blob/master/rendezvous/README.md) | | +| **Peer Routing** | [kad-dht](https://github.com/libp2p/specs/blob/master/kad-dht/README.md) | | +| **Publish/Subscribe** | floodsub | :lemon: | +| | [gossipsub](https://github.com/libp2p/specs/tree/master/pubsub/gossipsub) | :green_apple: | +| **Storage** | record | | +| **Other protocols** | [ping](https://github.com/libp2p/specs/blob/master/ping/ping.md) | :green_apple: | +| | [identify](https://github.com/libp2p/specs/blob/master/identify/README.md) | :green_apple: | + +Legend: +- :green_apple: - tested in production +- :lemon: - prototype or beta, not tested in production +- :tomato: - in progress ## Gossip simulator @@ -65,57 +52,69 @@ Deterministic Gossip simulator which may simulate networks as large as 10000 of Please check the Simulator [README](tools/simulator/README.md) for more details +## Android support + +The library is basically being developed with Android compatibility in mind. +However we are not aware of anyone using it in production. + +The `examples/android-chatter` module contains working sample Android application. This module is ignored by the Gradle +build when no Android SDK is installed. +To include the Android module define a valid SDK location with an `ANDROID_HOME` environment variable +or by setting the `sdk.dir` path in your project's local properties file local.properties. + +Importing the project into Android Studio should work out of the box. + ## Adding as a dependency to your project Hosting of artefacts is graciously provided by [Cloudsmith](https://cloudsmith.com). -[![Latest version of 'jvm-libp2p-minimal' @ Cloudsmith](https://api-prd.cloudsmith.io/v1/badges/version/libp2p/jvm-libp2p/maven/jvm-libp2p-minimal/latest/a=noarch;xg=io.libp2p/?render=true&show_latest=true)](https://cloudsmith.io/~libp2p/repos/jvm-libp2p/packages/detail/maven/jvm-libp2p-minimal/latest/a=noarch;xg=io.libp2p/) +[![Latest version of 'jvm-libp2p' @ Cloudsmith](https://api-prd.cloudsmith.io/v1/badges/version/libp2p/jvm-libp2p/maven/jvm-libp2p/latest/a=noarch;xg=io.libp2p/?render=true&show_latest=true)](https://cloudsmith.io/~libp2p/repos/jvm-libp2p/packages/detail/maven/jvm-libp2p/latest/a=noarch;xg=io.libp2p/) As an alternative, artefacts are also available on [JitPack](https://jitpack.io/). [![](https://jitpack.io/v/libp2p/jvm-libp2p.svg)](https://jitpack.io/#libp2p/jvm-libp2p) ### Using Gradle -Add the Cloudsmith repository to the `repositories` section of your Gradle file. +Add the required repositories to the `repositories` section of your Gradle file. ```groovy repositories { // ... maven { url "https://dl.cloudsmith.io/public/libp2p/jvm-libp2p/maven/" } + maven { url "https://jitpack.io" } + maven { url "https://artifacts.consensys.net/public/maven/maven/" } } ``` Add the library to the `implementation` part of your Gradle file. ```groovy dependencies { // ... - implementation 'io.libp2p:jvm-libp2p-minimal:X.Y.Z-RELEASE' + implementation 'io.libp2p:jvm-libp2p:X.Y.Z-RELEASE' } ``` ### Using Maven -Add the repository to the `dependencyManagement` section of the pom file: +Add the required repositories to the `dependencyManagement` section of the pom file: ```xml libp2p-jvm-libp2p https://dl.cloudsmith.io/public/libp2p/jvm-libp2p/maven/ - - true - always - - - true - always - + + + JitPack + https://jitpack.io + + + Consensys + https://artifacts.consensys.net/public/maven/maven/ ``` - -And then add jvm-libp2p as a dependency: +Add the library to the `dependencies` section of the pom file: ``` xml io.libp2p - jvm-libp2p-minimal + jvm-libp2p X.Y.Z-RELEASE - pom ``` @@ -138,7 +137,15 @@ To build the library from the `jvm-libp2p` folder, run: ./gradlew build ``` -After the build is complete you may find the library `.jar` file here: `jvm-libp2p/build/libs/jvm-libp2p-minimal-0.x.y-RELEASE.jar` +After the build is complete you may find the library `.jar` file here: `jvm-libp2p/build/libs/jvm-libp2p-X.Y.Z-RELEASE.jar` + +## Notable users + +- [Teku](https://github.com/Consensys/teku) - Ethereum Consensus Layer client +- [Nabu](https://github.com/peergos/nabu) - minimal Java implementation of IPFS +- [Peergos](https://github.com/peergos/peergos) - peer-to-peer encrypted global filesystem + +(Please open a pull request if you want your project to be added here) ## License diff --git a/build.gradle.kts b/build.gradle.kts index b6e9a7271..b312d3b81 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,26 +1,26 @@ +import com.diffplug.gradle.spotless.SpotlessExtension import org.gradle.api.tasks.testing.logging.TestExceptionFormat import org.jetbrains.kotlin.gradle.tasks.KotlinCompile -import java.net.URL - +import java.net.URI // To publish the release artifact to CloudSmith repo run the following : // ./gradlew publish -PcloudsmithUser= -PcloudsmithApiKey= -description = "a minimal implementation of libp2p for the jvm" +description = "a libp2p implementation for the JVM, written in Kotlin" plugins { val kotlinVersion = "1.6.21" id("org.jetbrains.kotlin.jvm") version kotlinVersion apply false - id("com.github.ben-manes.versions").version("0.44.0") + id("com.github.ben-manes.versions").version("0.51.0") id("idea") id("io.gitlab.arturbosch.detekt").version("1.22.0") id("java") id("maven-publish") - id("org.jetbrains.dokka").version("1.7.20") - id("org.jmailen.kotlinter").version("3.10.0") + id("org.jetbrains.dokka").version("1.9.20") + id("com.diffplug.spotless").version("6.25.0") id("java-test-fixtures") - id("io.spring.dependency-management").version("1.1.0") + id("io.spring.dependency-management").version("1.1.6") id("org.jetbrains.kotlin.android") version kotlinVersion apply false id("com.android.application") version "7.4.2" apply false @@ -45,7 +45,7 @@ configure( apply(plugin = "io.gitlab.arturbosch.detekt") apply(plugin = "maven-publish") apply(plugin = "org.jetbrains.dokka") - apply(plugin = "org.jmailen.kotlinter") + apply(plugin = "com.diffplug.spotless") apply(plugin = "java-test-fixtures") apply(plugin = "io.spring.dependency-management") apply(from = "$rootDir/versions.gradle") @@ -57,7 +57,6 @@ configure( implementation("com.google.guava:guava") implementation("org.slf4j:slf4j-api") - implementation("com.github.multiformats:java-multibase:v1.1.1") testFixturesImplementation("com.google.guava:guava") testFixturesImplementation("org.slf4j:slf4j-api") @@ -108,8 +107,25 @@ configure( // Runtime.getRuntime().availableProcessors().div(2)) } - kotlinter { - disabledRules = arrayOf("no-wildcard-imports", "enum-entry-name-case") + configure { + kotlin { + ktlint().editorConfigOverride( + mapOf( + "ktlint_standard_no-wildcard-imports" to "disabled", + "ktlint_standard_enum-entry-name-case" to "disabled", + "ktlint_standard_trailing-comma-on-call-site" to "disabled", + "ktlint_standard_trailing-comma-on-declaration-site" to "disabled", + "ktlint_standard_value-parameter-comment" to "disabled", + "ktlint_standard_value-argument-comment" to "disabled", + "ktlint_standard_property-naming" to "disabled", + "ktlint_standard_function-naming" to "disabled" + ) + ) + } + java { + targetExclude("**/generated/**/proto/**") + googleJavaFormat() + } } val sourcesJar by tasks.registering(Jar::class) { @@ -118,13 +134,13 @@ configure( } tasks.dokkaHtml.configure { - outputDirectory.set(buildDir.resolve("dokka")) + outputDirectory.set(getLayout().buildDirectory.dir("dokka")) dokkaSourceSets { configureEach { jdkVersion.set(11) reportUndocumented.set(false) externalDocumentationLink { - url.set(URL("https://netty.io/4.1/api/")) + url.set(URI.create("https://netty.io/4.1/api/").toURL()) } } } @@ -153,6 +169,14 @@ configure( publications { register("mavenJava", MavenPublication::class) { from(components["java"]) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") + } + usage("java-runtime") { + fromResolutionResult() + } + } artifact(sourcesJar.get()) artifact(dokkaJar.get()) groupId = "io.libp2p" diff --git a/examples/android-chatter/build.gradle b/examples/android-chatter/build.gradle index fdf829839..41dd1bfcd 100644 --- a/examples/android-chatter/build.gradle +++ b/examples/android-chatter/build.gradle @@ -24,6 +24,7 @@ android { packagingOptions { exclude 'META-INF/io.netty.versions.properties' exclude 'META-INF/INDEX.LIST' + exclude 'META-INF/versions/9/OSGI-INF/MANIFEST.MF' } kotlinOptions { jvmTarget = "11" diff --git a/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatNode.kt b/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatNode.kt index 45ccf0545..b9b9617e3 100644 --- a/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatNode.kt +++ b/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatNode.kt @@ -50,8 +50,9 @@ class ChatNode(private val printMsg: OnMessage) { fun send(message: String) { peers.values.forEach { it.controller.send(message) } - if (message.startsWith("alias ")) + if (message.startsWith("alias ")) { currentAlias = message.substring(6).trim() + } } // send fun stop() { @@ -83,8 +84,9 @@ class ChatNode(private val printMsg: OnMessage) { if ( info.peerId == chatHost.peerId || knownNodes.contains(info.peerId) - ) + ) { return + } knownNodes.add(info.peerId) @@ -126,10 +128,11 @@ class ChatNode(private val printMsg: OnMessage) { .filterIsInstance() .filter { it.isSiteLocalAddress } .sortedBy { it.hostAddress } - return if (addresses.isNotEmpty()) + return if (addresses.isNotEmpty()) { addresses[0] - else + } else { InetAddress.getLoopbackAddress() + } } } } // class ChatNode diff --git a/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatProtocol.kt b/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatProtocol.kt index 4517cde87..6674f8db8 100644 --- a/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatProtocol.kt +++ b/examples/chatter/src/main/kotlin/io/libp2p/example/chat/ChatProtocol.kt @@ -19,9 +19,9 @@ typealias OnChatMessage = (PeerId, String) -> Unit class Chat(chatCallback: OnChatMessage) : ChatBinding(ChatProtocol(chatCallback)) -const val protocolId: ProtocolId = "/example/chat/0.1.0" +const val PROTOCOL_ID: ProtocolId = "/example/chat/0.1.0" -open class ChatBinding(echo: ChatProtocol) : StrictProtocolBinding(protocolId, echo) +open class ChatBinding(echo: ChatProtocol) : StrictProtocolBinding(PROTOCOL_ID, echo) open class ChatProtocol( private val chatCallback: OnChatMessage diff --git a/examples/cli-chatter/src/main/kotlin/io/libp2p/example/chat/Chatter.kt b/examples/cli-chatter/src/main/kotlin/io/libp2p/example/chat/Chatter.kt index e296280cd..e5a450d7e 100644 --- a/examples/cli-chatter/src/main/kotlin/io/libp2p/example/chat/Chatter.kt +++ b/examples/cli-chatter/src/main/kotlin/io/libp2p/example/chat/Chatter.kt @@ -17,8 +17,9 @@ fun main() { print(">> ") message = readLine()?.trim() - if (message == null || message.isEmpty()) + if (message == null || message.isEmpty()) { continue + } node.send(message) } while ("bye" != message) diff --git a/examples/pinger/src/main/java/io/libp2p/example/ping/Pinger.java b/examples/pinger/src/main/java/io/libp2p/example/ping/Pinger.java index 16b1b1407..441540f57 100644 --- a/examples/pinger/src/main/java/io/libp2p/example/ping/Pinger.java +++ b/examples/pinger/src/main/java/io/libp2p/example/ping/Pinger.java @@ -5,18 +5,13 @@ import io.libp2p.core.multiformats.Multiaddr; import io.libp2p.protocol.Ping; import io.libp2p.protocol.PingController; - import java.util.concurrent.ExecutionException; public class Pinger { - public static void main(String[] args) - throws ExecutionException, InterruptedException { + public static void main(String[] args) throws ExecutionException, InterruptedException { // Create a libp2p node and configure it // to accept TCP connections on a random port - Host node = new HostBuilder() - .protocol(new Ping()) - .listen("/ip4/127.0.0.1/tcp/0") - .build(); + Host node = new HostBuilder().protocol(new Ping()).listen("/ip4/127.0.0.1/tcp/0").build(); // start listening node.start().get(); @@ -26,10 +21,7 @@ public static void main(String[] args) if (args.length > 0) { Multiaddr address = Multiaddr.fromString(args[0]); - PingController pinger = new Ping().dial( - node, - address - ).getController().get(); + PingController pinger = new Ping().dial(node, address).getController().get(); System.out.println("Sending 5 ping messages to " + address.toString()); for (int i = 1; i <= 5; ++i) { diff --git a/funding.json b/funding.json new file mode 100644 index 000000000..020558ba5 --- /dev/null +++ b/funding.json @@ -0,0 +1,5 @@ +{ + "opRetro": { + "projectId": "0x0be3a0fa062180bdfbfdefa993b09acd9edcae93ba0d8d5829dd01c138268f40" + } +} diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index e708b1c02..ccebba771 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 02292eac4..6a93cb7a1 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,7 @@ -#Thu May 11 18:05:55 GST 2023 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip +distributionSha256Sum=31c55713e40233a8303827ceb42ca48a47267a0ad4bab9177123121e71524c26 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip +networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 4f906e0c8..79a61d421 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,101 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -106,80 +140,105 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index ac1b06f93..93e3f59f1 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,89 +1,92 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. -@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/libp2p/build.gradle.kts b/libp2p/build.gradle.kts index 53bbd41ca..9a901b939 100644 --- a/libp2p/build.gradle.kts +++ b/libp2p/build.gradle.kts @@ -1,9 +1,13 @@ - plugins { - id("com.google.protobuf").version("0.9.2") - id("me.champeau.jmh").version("0.6.8") + id("com.google.protobuf").version("0.9.4") + id("me.champeau.jmh").version("0.7.2") } +// https://docs.gradle.org/current/userguide/java_testing.html#ex-disable-publishing-of-test-fixtures-variants +val javaComponent = components["java"] as AdhocComponentWithVariants +javaComponent.withVariantsFromConfiguration(configurations["testFixturesApiElements"]) { skip() } +javaComponent.withVariantsFromConfiguration(configurations["testFixturesRuntimeElements"]) { skip() } + dependencies { api("io.netty:netty-common") api("io.netty:netty-buffer") @@ -15,11 +19,12 @@ dependencies { api("com.google.protobuf:protobuf-java") + implementation("com.github.multiformats:java-multibase") implementation("tech.pegasys:noise-java") - implementation("org.bouncycastle:bcprov-jdk15on") - implementation("org.bouncycastle:bcpkix-jdk15on") - implementation("org.bouncycastle:bctls-jdk15on") + implementation("org.bouncycastle:bcprov-jdk18on") + implementation("org.bouncycastle:bcpkix-jdk18on") + implementation("org.bouncycastle:bctls-jdk18on") testImplementation(project(":tools:schedulers")) @@ -53,4 +58,3 @@ protobuf { } } } - diff --git a/libp2p/gradle.properties b/libp2p/gradle.properties index 0ef1e083b..3d6bf87a8 100644 --- a/libp2p/gradle.properties +++ b/libp2p/gradle.properties @@ -1 +1 @@ -mavenArtifactId=jvm-libp2p-minimal \ No newline at end of file +mavenArtifactId=jvm-libp2p \ No newline at end of file diff --git a/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java b/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java index 90481e7f0..ce87b32d0 100644 --- a/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java +++ b/libp2p/src/jmh/java/io/libp2p/pubsub/gossip/GossipScoreBenchmark.java @@ -7,10 +7,6 @@ import io.libp2p.tools.schedulers.ControlledExecutorServiceImpl; import io.libp2p.tools.schedulers.TimeController; import io.libp2p.tools.schedulers.TimeControllerImpl; -import org.openjdk.jmh.annotations.*; -import org.openjdk.jmh.infra.Blackhole; -import pubsub.pb.Rpc; - import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -18,6 +14,9 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import pubsub.pb.Rpc; @State(Scope.Thread) @Fork(5) @@ -25,92 +24,95 @@ @Measurement(iterations = 10, time = 1000, timeUnit = TimeUnit.MILLISECONDS) public class GossipScoreBenchmark { - private final int peerCount = 5000; - private final int connectedCount = 2000; - private final int topicCount = 128; - - private final List topics = IntStream - .range(0, topicCount) - .mapToObj(i -> "Topic-" + i) - .collect(Collectors.toList()); - - private final List peerIds = Stream.generate(PeerId::random).limit(peerCount).collect(Collectors.toList()); - private final List peerAddresses = IntStream - .range(0, peerCount) - .mapToObj(idx -> - Multiaddr.empty() - .withComponent(Protocol.IP4, new byte[]{(byte) (idx >>> 8 & 0xFF), (byte) (idx & 0xFF), 0, 0}) - .withComponent(Protocol.TCP, new byte[]{0x23, 0x28})) - .collect(Collectors.toList()); - - private final TimeController timeController = new TimeControllerImpl(); - private final ControlledExecutorServiceImpl controlledExecutor = new ControlledExecutorServiceImpl(); - private final GossipScoreParams gossipScoreParams; - private final DefaultGossipScore score; - - public GossipScoreBenchmark() { - Map topicParamMap = topics.stream() - .collect(Collectors.toMap(Function.identity(), __ -> new GossipTopicScoreParams())); - GossipTopicsScoreParams gossipTopicsScoreParams = new GossipTopicsScoreParams(new GossipTopicScoreParams(), topicParamMap); - - gossipScoreParams = new GossipScoreParams(new GossipPeerScoreParams(), gossipTopicsScoreParams, 0, 0, 0, 0, 0); - controlledExecutor.setTimeController(timeController); - score = new DefaultGossipScore(gossipScoreParams, controlledExecutor, timeController::getTime); - - for (int i = 0; i < peerCount; i++) { - PeerId peerId = peerIds.get(i); - score.notifyConnected(peerId, peerAddresses.get(i)); - for (String topic : topics) { - notifyUnseenMessage(peerId, topic); - } - } - - for (int i = connectedCount; i < peerCount; i++) { - score.notifyDisconnected(peerIds.get(i)); - } + private final int peerCount = 5000; + private final int connectedCount = 2000; + private final int topicCount = 128; + + private final List topics = + IntStream.range(0, topicCount).mapToObj(i -> "Topic-" + i).collect(Collectors.toList()); + + private final List peerIds = + Stream.generate(PeerId::random).limit(peerCount).collect(Collectors.toList()); + private final List peerAddresses = + IntStream.range(0, peerCount) + .mapToObj( + idx -> + Multiaddr.empty() + .withComponent( + Protocol.IP4, + new byte[] {(byte) (idx >>> 8 & 0xFF), (byte) (idx & 0xFF), 0, 0}) + .withComponent(Protocol.TCP, new byte[] {0x23, 0x28})) + .collect(Collectors.toList()); + + private final TimeController timeController = new TimeControllerImpl(); + private final ControlledExecutorServiceImpl controlledExecutor = + new ControlledExecutorServiceImpl(); + private final GossipScoreParams gossipScoreParams; + private final DefaultGossipScore score; + + public GossipScoreBenchmark() { + Map topicParamMap = + topics.stream() + .collect(Collectors.toMap(Function.identity(), __ -> new GossipTopicScoreParams())); + GossipTopicsScoreParams gossipTopicsScoreParams = + new GossipTopicsScoreParams(new GossipTopicScoreParams(), topicParamMap); + + gossipScoreParams = + new GossipScoreParams(new GossipPeerScoreParams(), gossipTopicsScoreParams, 0, 0, 0, 0, 0); + controlledExecutor.setTimeController(timeController); + score = new DefaultGossipScore(gossipScoreParams, controlledExecutor, timeController::getTime); + + for (int i = 0; i < peerCount; i++) { + PeerId peerId = peerIds.get(i); + score.notifyConnected(peerId, peerAddresses.get(i)); + for (String topic : topics) { + notifyUnseenMessage(peerId, topic); + } } - private void notifyUnseenMessage(PeerId peerId, String topic) { - Rpc.Message message = Rpc.Message.newBuilder() - .addTopicIDs(topic) - .build(); - score.notifyUnseenValidMessage(peerId, new DefaultPubsubMessage(message)); + for (int i = connectedCount; i < peerCount; i++) { + score.notifyDisconnected(peerIds.get(i)); } - - @Benchmark - public void scoresDelay0(Blackhole bh) { - for (int i = 0; i < connectedCount; i++) { - double s = score.score(peerIds.get(i)); - bh.consume(s); - } + } + + private void notifyUnseenMessage(PeerId peerId, String topic) { + Rpc.Message message = Rpc.Message.newBuilder().addTopicIDs(topic).build(); + score.notifyUnseenValidMessage(peerId, new DefaultPubsubMessage(message)); + } + + @Benchmark + public void scoresDelay0(Blackhole bh) { + for (int i = 0; i < connectedCount; i++) { + double s = score.score(peerIds.get(i)); + bh.consume(s); } + } - @Benchmark - public void scoresDelay100(Blackhole bh) { - timeController.addTime(100); + @Benchmark + public void scoresDelay100(Blackhole bh) { + timeController.addTime(100); - for (int i = 0; i < connectedCount; i++) { - double s = score.score(peerIds.get(i)); - bh.consume(s); - } + for (int i = 0; i < connectedCount; i++) { + double s = score.score(peerIds.get(i)); + bh.consume(s); } + } - @Benchmark - public void scoresDelay10000(Blackhole bh) { - timeController.addTime(10000); + @Benchmark + public void scoresDelay10000(Blackhole bh) { + timeController.addTime(10000); - for (int i = 0; i < connectedCount; i++) { - double s = score.score(peerIds.get(i)); - bh.consume(s); - } + for (int i = 0; i < connectedCount; i++) { + double s = score.score(peerIds.get(i)); + bh.consume(s); } - - /** - * Uncomment for debugging - */ -// public static void main(String[] args) { -// GossipScoreBenchmark benchmark = new GossipScoreBenchmark(); -// Blackhole blackhole = new Blackhole("Today's password is swordfish. I understand instantiating Blackholes directly is dangerous."); -// benchmark.scoresDelay0(blackhole); -// } + } + + /** Uncomment for debugging */ + // public static void main(String[] args) { + // GossipScoreBenchmark benchmark = new GossipScoreBenchmark(); + // Blackhole blackhole = new Blackhole("Today's password is swordfish. I understand + // instantiating Blackholes directly is dangerous."); + // benchmark.scoresDelay0(blackhole); + // } } diff --git a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java index 819d2348b..4c431b2dc 100644 --- a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java +++ b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java @@ -7,36 +7,37 @@ import io.libp2p.core.security.SecureChannel; import io.libp2p.core.transport.Transport; import io.libp2p.transport.ConnectionUpgrader; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.*; public class HostBuilder { - public HostBuilder() { this(DefaultMode.Standard); } - public HostBuilder(DefaultMode defaultMode) { - defaultMode_ = defaultMode; + public HostBuilder() { + this(DefaultMode.Standard); + } + + public HostBuilder(DefaultMode defaultMode) { + defaultMode_ = defaultMode; + } + + public enum DefaultMode { + None, + Standard; + + private Builder.Defaults asBuilderDefault() { + if (this.equals(None)) { + return Builder.Defaults.None; + } + return Builder.Defaults.Standard; } + }; - public enum DefaultMode { - None, - Standard; - - private Builder.Defaults asBuilderDefault() { - if (this.equals(None)) { - return Builder.Defaults.None; - } - return Builder.Defaults.Standard; - } - }; - - @SafeVarargs - public final HostBuilder transport( - Function... transports) { - transports_.addAll(Arrays.asList(transports)); - return this; - } + @SafeVarargs + public final HostBuilder transport(Function... transports) { + transports_.addAll(Arrays.asList(transports)); + return this; + } @SafeVarargs public final HostBuilder secureTransport( @@ -45,69 +46,67 @@ public final HostBuilder secureTransport( return this; } - @SafeVarargs - public final HostBuilder secureChannel( - BiFunction, SecureChannel>... secureChannels) { - secureChannels_.addAll(Arrays.asList(secureChannels)); - return this; - } - - @SafeVarargs - public final HostBuilder muxer( - Supplier... muxers) { - muxers_.addAll(Arrays.asList(muxers)); - return this; - } - - public final HostBuilder protocol( - ProtocolBinding... protocols) { - protocols_.addAll(Arrays.asList(protocols)); - return this; - } - - public final HostBuilder listen( - String... addresses) { - listenAddresses_.addAll(Arrays.asList(addresses)); - return this; - } - - - public Host build() { - return BuilderJKt.hostJ( - defaultMode_.asBuilderDefault(), - b -> { - IdentityBuilder identity = b.getIdentity(); + @SafeVarargs + public final HostBuilder secureChannel( + BiFunction, SecureChannel>... secureChannels) { + secureChannels_.addAll(Arrays.asList(secureChannels)); + return this; + } + + @SafeVarargs + public final HostBuilder muxer(Supplier... muxers) { + muxers_.addAll(Arrays.asList(muxers)); + return this; + } + + public final HostBuilder protocol(ProtocolBinding... protocols) { + protocols_.addAll(Arrays.asList(protocols)); + return this; + } + + public final HostBuilder listen(String... addresses) { + listenAddresses_.addAll(Arrays.asList(addresses)); + return this; + } + + public final HostBuilder builderModifier(Consumer builderModifier) { + this.builderModifier = builderModifier; + return this; + } + + @SuppressWarnings("unchecked") + public Host build() { + return BuilderJKt.hostJ( + defaultMode_.asBuilderDefault(), + b -> { + IdentityBuilder identity = b.getIdentity(); identity.random(KEY_TYPE.ED25519); PrivKey peerId = identity.getFactory().invoke(); identity.setFactory(() -> peerId); - secureTransports_.forEach(t -> + secureTransports_.forEach(t -> b.getTransports().add(c -> t.apply(identity.getFactory().invoke(), protocols_)) ); transports_.forEach(t -> - b.getTransports().add(t::apply) - ); - secureChannels_.forEach(sc -> - b.getSecureChannels().add((k, m) -> sc.apply(k, (List)m)) - ); - muxers_.forEach(m -> - b.getMuxers().add(m.get()) - ); - b.getProtocols().addAll(protocols_); - listenAddresses_.forEach(a -> - b.getNetwork().listen(a) - ); - } - ); - } // build - - private DefaultMode defaultMode_; - private List>, Transport>> secureTransports_ = new ArrayList<>(); + b.getTransports().add(t::apply)); + secureChannels_.forEach( + sc -> b.getSecureChannels().add((k, m) -> sc.apply(k, (List) m))); + muxers_.forEach(m -> b.getMuxers().add(m.get())); + b.getProtocols().addAll(protocols_); + listenAddresses_.forEach(a -> b.getNetwork().listen(a)); + builderModifier.accept(b); + }); + } // build + + private DefaultMode defaultMode_; + private List>, Transport>> secureTransports_ = new ArrayList<>(); private List> transports_ = new ArrayList<>(); - private List, SecureChannel>> secureChannels_ = new ArrayList<>(); - private List> muxers_ = new ArrayList<>(); - private List> protocols_ = new ArrayList<>(); - private List listenAddresses_ = new ArrayList<>(); + private List, SecureChannel>> secureChannels_ = + new ArrayList<>(); + private List> muxers_ = new ArrayList<>(); + private List> protocols_ = new ArrayList<>(); + private List listenAddresses_ = new ArrayList<>(); + private Consumer builderModifier = b -> {}; } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/AnswerListener.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/AnswerListener.java index 54dfc5170..f7844b95e 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/AnswerListener.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/AnswerListener.java @@ -1,10 +1,9 @@ -package io.libp2p.discovery.mdns; - -import io.libp2p.discovery.mdns.impl.DNSRecord; - -import java.util.EventListener; -import java.util.List; - -public interface AnswerListener extends EventListener { - void answersReceived(List answers); -} +package io.libp2p.discovery.mdns; + +import io.libp2p.discovery.mdns.impl.DNSRecord; +import java.util.EventListener; +import java.util.List; + +public interface AnswerListener extends EventListener { + void answersReceived(List answers); +} diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/JmDNS.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/JmDNS.java index 8bc6d6259..697854802 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/JmDNS.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/JmDNS.java @@ -5,73 +5,75 @@ package io.libp2p.discovery.mdns; import io.libp2p.discovery.mdns.impl.JmDNSImpl; - import java.io.IOException; import java.net.InetAddress; /** - * Based on code by Arthur van Hoff, Rick Blair, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Scott Lewis, Scott Cytacki + * Based on code by Arthur van Hoff, Rick Blair, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, + * Scott Lewis, Scott Cytacki */ public abstract class JmDNS { - /** - *

- * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. - *

- *

- * Note: This is a convenience method. The preferred constructor is {@link #create(InetAddress, String)}.
- * Check that your platform correctly handle the default localhost IP address and the local hostname. In doubt use the explicit constructor.
- * This call is equivalent to create(addr, null). - *

- * - * @see #create(InetAddress, String) - * @param addr - * IP address to bind to. - * @return jmDNS instance - */ - public static JmDNS create(final InetAddress addr) { - return new JmDNSImpl(addr, null); - } + /** + * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. + * + *

Note: This is a convenience method. The preferred constructor is {@link + * #create(InetAddress, String)}.
+ * Check that your platform correctly handle the default localhost IP address and the local + * hostname. In doubt use the explicit constructor.
+ * This call is equivalent to create(addr, null). + * + * @see #create(InetAddress, String) + * @param addr IP address to bind to. + * @return jmDNS instance + */ + public static JmDNS create(final InetAddress addr) { + return new JmDNSImpl(addr, null); + } + + /** + * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. + * If addr parameter is null this method will try to resolve to a local IP address of + * the machine using a network discovery: + * + *

    + *
  1. Check the system property net.mdns.interface + *
  2. Check the JVM local host + *
  3. In the last resort bind to the loopback address. This is non functional in most cases. + *
+ * + * If name parameter is null will use the hostname. The hostname is determined by the + * following algorithm: + * + *
    + *
  1. Get the hostname from the InetAdress obtained before. + *
  2. If the hostname is a reverse lookup default to JmDNS name or computer + * if null. + *
  3. If the name contains '.' replace them by '-' + *
  4. Add .local. at the end of the name. + *
+ * + * @param addr IP address to bind to. + * @param name name of the newly created JmDNS + * @return jmDNS instance + * @exception IOException if an exception occurs during the socket creation + */ + public static JmDNS create(final InetAddress addr, final String name) { + return new JmDNSImpl(addr, name); + } + + public abstract void start() throws IOException; - /** - *

- * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. - *

- * If addr parameter is null this method will try to resolve to a local IP address of the machine using a network discovery: - *
    - *
  1. Check the system property net.mdns.interface
  2. - *
  3. Check the JVM local host
  4. - *
  5. In the last resort bind to the loopback address. This is non functional in most cases.
  6. - *
- * If name parameter is null will use the hostname. The hostname is determined by the following algorithm: - *
    - *
  1. Get the hostname from the InetAdress obtained before.
  2. - *
  3. If the hostname is a reverse lookup default to JmDNS name or computer if null.
  4. - *
  5. If the name contains '.' replace them by '-'
  6. - *
  7. Add .local. at the end of the name.
  8. - *
- * - * @param addr - * IP address to bind to. - * @param name - * name of the newly created JmDNS - * @return jmDNS instance - * @exception IOException - * if an exception occurs during the socket creation - */ - public static JmDNS create(final InetAddress addr, final String name) { - return new JmDNSImpl(addr, name); - } + public abstract void stop(); - public abstract void start() throws IOException; - public abstract void stop(); + /** + * Return the name of the JmDNS instance. This is an arbitrary string that is useful for + * distinguishing instances. + * + * @return name of the JmDNS + */ + public abstract String getName(); - /** - * Return the name of the JmDNS instance. This is an arbitrary string that is useful for distinguishing instances. - * - * @return name of the JmDNS - */ - public abstract String getName(); + public abstract void addAnswerListener(String type, int queryInterval, AnswerListener listener); - public abstract void addAnswerListener(String type, int queryInterval, AnswerListener listener); - public abstract void registerService(ServiceInfo info) throws IOException; + public abstract void registerService(ServiceInfo info) throws IOException; } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/ServiceInfo.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/ServiceInfo.java index 482daba5e..80dd0f9c0 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/ServiceInfo.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/ServiceInfo.java @@ -4,226 +4,207 @@ package io.libp2p.discovery.mdns; import io.libp2p.discovery.mdns.impl.ServiceInfoImpl; - import java.net.Inet4Address; import java.net.Inet6Address; import java.util.List; import java.util.Map; /** - *

* The fully qualified service name is build using up to 5 components with the following structure: - * + * *

  *            <app>.<protocol>.<servicedomain>.<parentdomain>.
* <Instance>.<app>.<protocol>.<servicedomain>.<parentdomain>.
* <sub>._sub.<app>.<protocol>.<servicedomain>.<parentdomain>. *
- * + * *
    - *
  1. <servicedomain>.<parentdomain>: This is the domain scope of the service typically "local.", but this can also be something similar to "in-addr.arpa." or "ip6.arpa."
  2. - *
  3. <protocol>: This is either "_tcp" or "_udp"
  4. - *
  5. <app>: This define the application protocol. Typical example are "_http", "_ftp", etc.
  6. - *
  7. <Instance>: This is the service name
  8. - *
  9. <sub>: This is the subtype for the application protocol
  10. + *
  11. <servicedomain>.<parentdomain>: This is the domain scope of the service + * typically "local.", but this can also be something similar to "in-addr.arpa." or + * "ip6.arpa." + *
  12. <protocol>: This is either "_tcp" or "_udp" + *
  13. <app>: This define the application protocol. Typical example are "_http", "_ftp", + * etc. + *
  14. <Instance>: This is the service name + *
  15. <sub>: This is the subtype for the application protocol *
- *

*/ public abstract class ServiceInfo implements Cloneable { - /** - * Fields for the fully qualified map. - */ - public enum Fields { - /** - * Domain Field. - */ - Domain, - /** - * Protocol Field. - */ - Protocol, - /** - * Application Field. - */ - Application, - /** - * Instance Field. - */ - Instance, - /** - * Subtype Field. - */ - Subtype - } - - /** - * Construct a service description for registering with JmDNS. - * - * @param type - * fully qualified service type name, such as _http._tcp.local.. - * @param name - * unqualified service instance name, such as foobar - * @param port - * the local port on which the service runs - * @param text - * string describing the service - * @return new service info - */ - public static ServiceInfo create( - final String type, - final String name, - final int port, - final String text, - final List ip4Addresses, - final List ip6Addresses) { - ServiceInfoImpl si = new ServiceInfoImpl(type, name, "", port, 0, 0, text); - for (Inet4Address a : ip4Addresses) - si.addAddress(a); - for (Inet6Address a : ip6Addresses) - si.addAddress(a); - return si; - } - - /** - * Returns true if the service info is filled with data. - * - * @return true if the service info has data, false otherwise. - */ - public abstract boolean hasData(); - - /** - * Fully qualified service type name, such as _http._tcp.local. - * - * @return service type name - */ - public abstract String getType(); - - /** - * Fully qualified service type name with the subtype if appropriate, such as _printer._sub._http._tcp.local. - * - * @return service type name - */ - public abstract String getTypeWithSubtype(); - - /** - * Unqualified service instance name, such as foobar . - * - * @return service name - */ - public abstract String getName(); - - /** - * The key is used to retrieve service info in hash tables.
- * The key is the lower case qualified name. - * - * @return the key - */ - public abstract String getKey(); - - /** - * Fully qualified service name, such as foobar._http._tcp.local. . - * - * @return qualified service name - */ - public abstract String getQualifiedName(); - - /** - * Get the name of the server. - * - * @return server name - */ - public abstract String getServer(); - - /** - * Returns a list of all IPv4 InetAddresses that can be used for this service. - *

- * In a multi-homed environment service info can be associated with more than one address. - *

- * - * @return list of InetAddress objects - */ - public abstract Inet4Address[] getInet4Addresses(); - - /** - * Returns a list of all IPv6 InetAddresses that can be used for this service. - *

- * In a multi-homed environment service info can be associated with more than one address. - *

- * - * @return list of InetAddress objects - */ - public abstract Inet6Address[] getInet6Addresses(); - - /** - * Get the port for the service. - * - * @return service port - */ - public abstract int getPort(); - - /** - * Get the priority of the service. - * - * @return service priority - */ - public abstract int getPriority(); - - /** - * Get the weight of the service. - * - * @return service weight - */ - public abstract int getWeight(); - - /** - * Get the text for the service as raw bytes. - * - * @return raw service text - */ - public abstract byte[] getTextBytes(); - - /** - * Returns the domain of the service info suitable for printing. - * - * @return service domain - */ - public abstract String getDomain(); - - /** - * Returns the protocol of the service info suitable for printing. - * - * @return service protocol - */ - public abstract String getProtocol(); - - /** - * Returns the application of the service info suitable for printing. - * - * @return service application - */ - public abstract String getApplication(); - - /** - * Returns the sub type of the service info suitable for printing. - * - * @return service sub type - */ - public abstract String getSubtype(); - - /** - * Returns a dictionary of the fully qualified name component of this service. - * - * @return dictionary of the fully qualified name components - */ - public abstract Map getQualifiedNameMap(); - - /* - * (non-Javadoc) - * @see java.lang.Object#clone() - */ - @Override - public ServiceInfo clone() throws CloneNotSupportedException { - return (ServiceInfo) super.clone(); - } + /** Fields for the fully qualified map. */ + public enum Fields { + /** Domain Field. */ + Domain, + /** Protocol Field. */ + Protocol, + /** Application Field. */ + Application, + /** Instance Field. */ + Instance, + /** Subtype Field. */ + Subtype + } + + /** + * Construct a service description for registering with JmDNS. + * + * @param type fully qualified service type name, such as _http._tcp.local.. + * @param name unqualified service instance name, such as foobar + * @param port the local port on which the service runs + * @param text string describing the service + * @return new service info + */ + public static ServiceInfo create( + final String type, + final String name, + final int port, + final String text, + final List ip4Addresses, + final List ip6Addresses) { + ServiceInfoImpl si = new ServiceInfoImpl(type, name, "", port, 0, 0, text); + for (Inet4Address a : ip4Addresses) si.addAddress(a); + for (Inet6Address a : ip6Addresses) si.addAddress(a); + return si; + } + + /** + * Returns true if the service info is filled with data. + * + * @return true if the service info has data, false otherwise. + */ + public abstract boolean hasData(); + + /** + * Fully qualified service type name, such as _http._tcp.local. + * + * @return service type name + */ + public abstract String getType(); + + /** + * Fully qualified service type name with the subtype if appropriate, such as + * _printer._sub._http._tcp.local. + * + * @return service type name + */ + public abstract String getTypeWithSubtype(); + + /** + * Unqualified service instance name, such as foobar . + * + * @return service name + */ + public abstract String getName(); + + /** + * The key is used to retrieve service info in hash tables.
+ * The key is the lower case qualified name. + * + * @return the key + */ + public abstract String getKey(); + + /** + * Fully qualified service name, such as foobar._http._tcp.local. . + * + * @return qualified service name + */ + public abstract String getQualifiedName(); + + /** + * Get the name of the server. + * + * @return server name + */ + public abstract String getServer(); + + /** + * Returns a list of all IPv4 InetAddresses that can be used for this service. + * + *

In a multi-homed environment service info can be associated with more than one address. + * + * @return list of InetAddress objects + */ + public abstract Inet4Address[] getInet4Addresses(); + + /** + * Returns a list of all IPv6 InetAddresses that can be used for this service. + * + *

In a multi-homed environment service info can be associated with more than one address. + * + * @return list of InetAddress objects + */ + public abstract Inet6Address[] getInet6Addresses(); + + /** + * Get the port for the service. + * + * @return service port + */ + public abstract int getPort(); + + /** + * Get the priority of the service. + * + * @return service priority + */ + public abstract int getPriority(); + + /** + * Get the weight of the service. + * + * @return service weight + */ + public abstract int getWeight(); + + /** + * Get the text for the service as raw bytes. + * + * @return raw service text + */ + public abstract byte[] getTextBytes(); + + /** + * Returns the domain of the service info suitable for printing. + * + * @return service domain + */ + public abstract String getDomain(); + + /** + * Returns the protocol of the service info suitable for printing. + * + * @return service protocol + */ + public abstract String getProtocol(); + + /** + * Returns the application of the service info suitable for printing. + * + * @return service application + */ + public abstract String getApplication(); + + /** + * Returns the sub type of the service info suitable for printing. + * + * @return service sub type + */ + public abstract String getSubtype(); + + /** + * Returns a dictionary of the fully qualified name component of this service. + * + * @return dictionary of the fully qualified name components + */ + public abstract Map getQualifiedNameMap(); + + /* + * (non-Javadoc) + * @see java.lang.Object#clone() + */ + @Override + public ServiceInfo clone() throws CloneNotSupportedException { + return (ServiceInfo) super.clone(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSEntry.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSEntry.java index 1f90e988c..a49b9934f 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSEntry.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSEntry.java @@ -4,214 +4,226 @@ package io.libp2p.discovery.mdns.impl; +import io.libp2p.discovery.mdns.ServiceInfo.Fields; import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; - import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; import java.util.Collections; import java.util.Map; -import io.libp2p.discovery.mdns.ServiceInfo.Fields; - /** * DNS entry with a name, type, and class. This is the base class for questions and records. * * @author Arthur van Hoff, Pierre Frisch, Rick Blair */ public abstract class DNSEntry { - // private static Logger logger = LoggerFactory.getLogger(DNSEntry.class.getName()); - private final String _key; - - private final String _name; - - private final String _type; - - private final DNSRecordType _recordType; - - private final DNSRecordClass _dnsClass; - - private final boolean _unique; - - final Map _qualifiedNameMap; - - /** - * Create an entry. - */ - DNSEntry(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { - _name = name; - // _key = (name != null ? name.trim().toLowerCase() : null); - _recordType = type; - _dnsClass = recordClass; - _unique = unique; - _qualifiedNameMap = ServiceInfoImpl.decodeQualifiedNameMapForType(this.getName()); - String domain = _qualifiedNameMap.get(Fields.Domain); - String protocol = _qualifiedNameMap.get(Fields.Protocol); - String application = _qualifiedNameMap.get(Fields.Application); - String instance = _qualifiedNameMap.get(Fields.Instance).toLowerCase(); - _type = (application.length() > 0 ? "_" + application + "." : "") + (protocol.length() > 0 ? "_" + protocol + "." : "") + domain + "."; - _key = ((instance.length() > 0 ? instance + "." : "") + _type).toLowerCase(); - } - - /* - * (non-Javadoc) - * @see java.lang.Object#equals(java.lang.Object) - */ - @Override - public boolean equals(Object obj) { - boolean result = false; - if (obj instanceof DNSEntry) { - DNSEntry other = (DNSEntry) obj; - result = this.getKey().equals(other.getKey()) && this.getRecordType().equals(other.getRecordType()) && this.getRecordClass() == other.getRecordClass(); - } - return result; - } - - /** - * Check if two entries have exactly the same name, type, and class. - * - * @param entry - * @return true if the two entries have are for the same record, false otherwise - */ - public boolean isSameEntry(DNSEntry entry) { - return this.getKey().equals(entry.getKey()) && this.matchRecordType(entry.getRecordType()) && this.matchRecordClass(entry.getRecordClass()); - } - - /** - * Check if the requested record class match the current record class - * - * @param recordClass - * @return true if the two entries have compatible class, false otherwise - */ - public boolean matchRecordClass(DNSRecordClass recordClass) { - return (DNSRecordClass.CLASS_ANY == recordClass) || (DNSRecordClass.CLASS_ANY == this.getRecordClass()) || this.getRecordClass().equals(recordClass); - } - - /** - * Check if the requested record tyep match the current record type - * - * @param recordType - * @return true if the two entries have compatible type, false otherwise - */ - public boolean matchRecordType(DNSRecordType recordType) { - return this.getRecordType().equals(recordType); - } - - /** - * Returns the subtype of this entry - * - * @return subtype of this entry - */ - public String getSubtype() { - String subtype = this.getQualifiedNameMap().get(Fields.Subtype); - return (subtype != null ? subtype : ""); - } - - /** - * Returns the name of this entry - * - * @return name of this entry - */ - public String getName() { - return (_name != null ? _name : ""); - } - - /** - * @return the type - */ - public String getType() { - return (_type != null ? _type : ""); - } - - /** - * Returns the key for this entry. The key is the lower case name. - * - * @return key for this entry - */ - public String getKey() { - return (_key != null ? _key : ""); - } - - /** - * @return record type - */ - public DNSRecordType getRecordType() { - return (_recordType != null ? _recordType : DNSRecordType.TYPE_IGNORE); - } - - /** - * @return record class - */ - public DNSRecordClass getRecordClass() { - return (_dnsClass != null ? _dnsClass : DNSRecordClass.CLASS_UNKNOWN); - } - - /** - * @return true if unique - */ - public boolean isUnique() { - return _unique; - } - - public Map getQualifiedNameMap() { - return Collections.unmodifiableMap(_qualifiedNameMap); - } - - /** - * Check if the record is expired. - * - * @param now - * update date - * @return true is the record is expired, false otherwise. - */ - public abstract boolean isExpired(long now); - - /** - * @param dout - * @exception IOException - */ - protected void toByteArray(DataOutputStream dout) throws IOException { - dout.write(this.getName().getBytes("UTF8")); - dout.writeShort(this.getRecordType().indexValue()); - dout.writeShort(this.getRecordClass().indexValue()); - } - - /** - * Creates a byte array representation of this record. This is needed for tie-break tests according to draft-cheshire-dnsext-multicastdns-04.txt chapter 9.2. - * - * @return byte array representation - */ - protected byte[] toByteArray() { - try { - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - DataOutputStream dout = new DataOutputStream(bout); - this.toByteArray(dout); - dout.close(); - return bout.toByteArray(); - } catch (IOException e) { - throw new InternalError(); - } - } - - /** - * Overriden, to return a value which is consistent with the value returned by equals(Object). - */ - @Override - public int hashCode() { - return this.getKey().hashCode() + this.getRecordType().indexValue() + this.getRecordClass().indexValue(); - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(200); - sb.append('[').append(this.getClass().getSimpleName()).append('@').append(System.identityHashCode(this)); - sb.append(" type: ").append(this.getRecordType()); - sb.append(", class: ").append(this.getRecordClass()); - sb.append((_unique ? "-unique," : ",")); - sb.append(" name: ").append( _name); - sb.append(']'); - - return sb.toString(); - } + // private static Logger logger = LoggerFactory.getLogger(DNSEntry.class.getName()); + private final String _key; + + private final String _name; + + private final String _type; + + private final DNSRecordType _recordType; + + private final DNSRecordClass _dnsClass; + + private final boolean _unique; + + final Map _qualifiedNameMap; + + /** Create an entry. */ + DNSEntry(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { + _name = name; + // _key = (name != null ? name.trim().toLowerCase() : null); + _recordType = type; + _dnsClass = recordClass; + _unique = unique; + _qualifiedNameMap = ServiceInfoImpl.decodeQualifiedNameMapForType(this.getName()); + String domain = _qualifiedNameMap.get(Fields.Domain); + String protocol = _qualifiedNameMap.get(Fields.Protocol); + String application = _qualifiedNameMap.get(Fields.Application); + String instance = _qualifiedNameMap.get(Fields.Instance).toLowerCase(); + _type = + (application.length() > 0 ? "_" + application + "." : "") + + (protocol.length() > 0 ? "_" + protocol + "." : "") + + domain + + "."; + _key = ((instance.length() > 0 ? instance + "." : "") + _type).toLowerCase(); + } + + /* + * (non-Javadoc) + * @see java.lang.Object#equals(java.lang.Object) + */ + @Override + public boolean equals(Object obj) { + boolean result = false; + if (obj instanceof DNSEntry) { + DNSEntry other = (DNSEntry) obj; + result = + this.getKey().equals(other.getKey()) + && this.getRecordType().equals(other.getRecordType()) + && this.getRecordClass() == other.getRecordClass(); + } + return result; + } + + /** + * Check if two entries have exactly the same name, type, and class. + * + * @param entry + * @return true if the two entries have are for the same record, false + * otherwise + */ + public boolean isSameEntry(DNSEntry entry) { + return this.getKey().equals(entry.getKey()) + && this.matchRecordType(entry.getRecordType()) + && this.matchRecordClass(entry.getRecordClass()); + } + + /** + * Check if the requested record class match the current record class + * + * @param recordClass + * @return true if the two entries have compatible class, false + * otherwise + */ + public boolean matchRecordClass(DNSRecordClass recordClass) { + return (DNSRecordClass.CLASS_ANY == recordClass) + || (DNSRecordClass.CLASS_ANY == this.getRecordClass()) + || this.getRecordClass().equals(recordClass); + } + + /** + * Check if the requested record tyep match the current record type + * + * @param recordType + * @return true if the two entries have compatible type, false otherwise + */ + public boolean matchRecordType(DNSRecordType recordType) { + return this.getRecordType().equals(recordType); + } + + /** + * Returns the subtype of this entry + * + * @return subtype of this entry + */ + public String getSubtype() { + String subtype = this.getQualifiedNameMap().get(Fields.Subtype); + return (subtype != null ? subtype : ""); + } + + /** + * Returns the name of this entry + * + * @return name of this entry + */ + public String getName() { + return (_name != null ? _name : ""); + } + + /** + * @return the type + */ + public String getType() { + return (_type != null ? _type : ""); + } + + /** + * Returns the key for this entry. The key is the lower case name. + * + * @return key for this entry + */ + public String getKey() { + return (_key != null ? _key : ""); + } + + /** + * @return record type + */ + public DNSRecordType getRecordType() { + return (_recordType != null ? _recordType : DNSRecordType.TYPE_IGNORE); + } + + /** + * @return record class + */ + public DNSRecordClass getRecordClass() { + return (_dnsClass != null ? _dnsClass : DNSRecordClass.CLASS_UNKNOWN); + } + + /** + * @return true if unique + */ + public boolean isUnique() { + return _unique; + } + + public Map getQualifiedNameMap() { + return Collections.unmodifiableMap(_qualifiedNameMap); + } + + /** + * Check if the record is expired. + * + * @param now update date + * @return true is the record is expired, false otherwise. + */ + public abstract boolean isExpired(long now); + + /** + * @param dout + * @exception IOException + */ + protected void toByteArray(DataOutputStream dout) throws IOException { + dout.write(this.getName().getBytes("UTF8")); + dout.writeShort(this.getRecordType().indexValue()); + dout.writeShort(this.getRecordClass().indexValue()); + } + + /** + * Creates a byte array representation of this record. This is needed for tie-break tests + * according to draft-cheshire-dnsext-multicastdns-04.txt chapter 9.2. + * + * @return byte array representation + */ + protected byte[] toByteArray() { + try { + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + DataOutputStream dout = new DataOutputStream(bout); + this.toByteArray(dout); + dout.close(); + return bout.toByteArray(); + } catch (IOException e) { + throw new InternalError(); + } + } + + /** Overriden, to return a value which is consistent with the value returned by equals(Object). */ + @Override + public int hashCode() { + return this.getKey().hashCode() + + this.getRecordType().indexValue() + + this.getRecordClass().indexValue(); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(200); + sb.append('[') + .append(this.getClass().getSimpleName()) + .append('@') + .append(System.identityHashCode(this)); + sb.append(" type: ").append(this.getRecordType()); + sb.append(", class: ").append(this.getRecordClass()); + sb.append((_unique ? "-unique," : ",")); + sb.append(" name: ").append(_name); + sb.append(']'); + + return sb.toString(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSIncoming.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSIncoming.java index f71bf1e29..078832553 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSIncoming.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSIncoming.java @@ -4,594 +4,666 @@ package io.libp2p.discovery.mdns.impl; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; +import io.libp2p.discovery.mdns.impl.constants.DNSLabel; +import io.libp2p.discovery.mdns.impl.constants.DNSOptionCode; +import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; +import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; +import io.libp2p.discovery.mdns.impl.constants.DNSResultCode; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.DatagramPacket; import java.net.InetAddress; import java.util.HashMap; import java.util.Map; - -import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; -import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; -import io.libp2p.discovery.mdns.impl.constants.DNSResultCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; -import io.libp2p.discovery.mdns.impl.constants.DNSLabel; -import io.libp2p.discovery.mdns.impl.constants.DNSOptionCode; - /** * Parse an incoming DNS message into its components. * * @author Arthur van Hoff, Werner Randelshofer, Pierre Frisch, Daniel Bobbert */ public final class DNSIncoming extends DNSMessage { - private static Logger logger = LoggerFactory.getLogger(DNSIncoming.class.getName()); - - // This is a hack to handle a bug in the BonjourConformanceTest - // It is sending out target strings that don't follow the "domain name" format. - public static boolean USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET = true; - - public static class MessageInputStream extends ByteArrayInputStream { - private static Logger logger1 = LoggerFactory.getLogger(MessageInputStream.class.getName()); + private static Logger logger = LoggerFactory.getLogger(DNSIncoming.class.getName()); - final Map _names; + // This is a hack to handle a bug in the BonjourConformanceTest + // It is sending out target strings that don't follow the "domain name" format. + public static boolean USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET = true; - public MessageInputStream(byte[] buffer, int length) { - this(buffer, 0, length); - } - - /** - * @param buffer - * @param offset - * @param length - */ - public MessageInputStream(byte[] buffer, int offset, int length) { - super(buffer, offset, length); - _names = new HashMap(); - } + public static class MessageInputStream extends ByteArrayInputStream { + private static Logger logger1 = LoggerFactory.getLogger(MessageInputStream.class.getName()); - public int readByte() { - return this.read(); - } + final Map _names; - public int readUnsignedByte() { - return (this.read() & 0xFF); - } + public MessageInputStream(byte[] buffer, int length) { + this(buffer, 0, length); + } - public int readUnsignedShort() { - return (this.readUnsignedByte() << 8) | this.readUnsignedByte(); - } + /** + * @param buffer + * @param offset + * @param length + */ + public MessageInputStream(byte[] buffer, int offset, int length) { + super(buffer, offset, length); + _names = new HashMap(); + } - public int readInt() { - return (this.readUnsignedShort() << 16) | this.readUnsignedShort(); - } + public int readByte() { + return this.read(); + } - public byte[] readBytes(int len) { - byte bytes[] = new byte[len]; - this.read(bytes, 0, len); - return bytes; - } + public int readUnsignedByte() { + return (this.read() & 0xFF); + } - public String readUTF(int len) { - final StringBuilder sb = new StringBuilder(len); - for (int index = 0; index < len; index++) { - int ch = this.readUnsignedByte(); - switch (ch >> 4) { - case 0: - case 1: - case 2: - case 3: - case 4: - case 5: - case 6: - case 7: - // 0xxxxxxx - break; - case 12: - case 13: - // 110x xxxx 10xx xxxx - ch = ((ch & 0x1F) << 6) | (this.readUnsignedByte() & 0x3F); - index++; - break; - case 14: - // 1110 xxxx 10xx xxxx 10xx xxxx - ch = ((ch & 0x0f) << 12) | ((this.readUnsignedByte() & 0x3F) << 6) | (this.readUnsignedByte() & 0x3F); - index++; - index++; - break; - default: - // 10xx xxxx, 1111 xxxx - ch = ((ch & 0x3F) << 4) | (this.readUnsignedByte() & 0x0f); - index++; - break; - } - sb.append((char) ch); - } - return sb.toString(); - } + public int readUnsignedShort() { + return (this.readUnsignedByte() << 8) | this.readUnsignedByte(); + } - protected synchronized int peek() { - return (pos < count) ? (buf[pos] & 0xff) : -1; - } + public int readInt() { + return (this.readUnsignedShort() << 16) | this.readUnsignedShort(); + } - public String readName() { - Map names = new HashMap(); - final StringBuilder sb = new StringBuilder(); - boolean finished = false; - while (!finished) { - int len = this.readUnsignedByte(); - if (len == 0) { - finished = true; - break; - } - switch (DNSLabel.labelForByte(len)) { - case Standard: - int offset = pos - 1; - String label = this.readUTF(len) + "."; - sb.append(label); - for (StringBuilder previousLabel : names.values()) { - previousLabel.append(label); - } - names.put(Integer.valueOf(offset), new StringBuilder(label)); - break; - case Compressed: - int index = (DNSLabel.labelValue(len) << 8) | this.readUnsignedByte(); - String compressedLabel = _names.get(Integer.valueOf(index)); - if (compressedLabel == null) { - logger1.warn("Bad domain name: possible circular name detected. Bad offset: 0x{} at 0x{}", - Integer.toHexString(index), - Integer.toHexString(pos - 2) - ); - compressedLabel = ""; - } - sb.append(compressedLabel); - for (StringBuilder previousLabel : names.values()) { - previousLabel.append(compressedLabel); - } - finished = true; - break; - case Extended: - // int extendedLabelClass = DNSLabel.labelValue(len); - logger1.debug("Extended label are not currently supported."); - break; - case Unknown: - default: - logger1.warn("Unsupported DNS label type: '{}'", Integer.toHexString(len & 0xC0) ); - } - } - for (final Map.Entry entry : names.entrySet()) { - final Integer index = entry.getKey(); - _names.put(index, entry.getValue().toString()); - } - return sb.toString(); - } + public byte[] readBytes(int len) { + byte bytes[] = new byte[len]; + this.read(bytes, 0, len); + return bytes; + } - public String readNonNameString() { - int len = this.readUnsignedByte(); - return this.readUTF(len); + public String readUTF(int len) { + final StringBuilder sb = new StringBuilder(len); + for (int index = 0; index < len; index++) { + int ch = this.readUnsignedByte(); + switch (ch >> 4) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + // 0xxxxxxx + break; + case 12: + case 13: + // 110x xxxx 10xx xxxx + ch = ((ch & 0x1F) << 6) | (this.readUnsignedByte() & 0x3F); + index++; + break; + case 14: + // 1110 xxxx 10xx xxxx 10xx xxxx + ch = + ((ch & 0x0f) << 12) + | ((this.readUnsignedByte() & 0x3F) << 6) + | (this.readUnsignedByte() & 0x3F); + index++; + index++; + break; + default: + // 10xx xxxx, 1111 xxxx + ch = ((ch & 0x3F) << 4) | (this.readUnsignedByte() & 0x0f); + index++; + break; } - + sb.append((char) ch); + } + return sb.toString(); } - private final DatagramPacket _packet; - - private final long _receivedTime; - - private final MessageInputStream _messageInputStream; - - private int _senderUDPPayload; - - /** - * Parse a message from a datagram packet. - * - * @param packet - * @exception IOException - */ - public DNSIncoming(DatagramPacket packet) throws IOException { - super(0, 0, packet.getPort() == DNSConstants.MDNS_PORT); - this._packet = packet; - InetAddress source = packet.getAddress(); - this._messageInputStream = new MessageInputStream(packet.getData(), packet.getLength()); - this._receivedTime = System.currentTimeMillis(); - this._senderUDPPayload = DNSConstants.MAX_MSG_TYPICAL; - - try { - this.setId(_messageInputStream.readUnsignedShort()); - this.setFlags(_messageInputStream.readUnsignedShort()); - if (this.getOperationCode() > 0) { - throw new IOException("Received a message with a non standard operation code. Currently unsupported in the specification."); - } - int numQuestions = _messageInputStream.readUnsignedShort(); - int numAnswers = _messageInputStream.readUnsignedShort(); - int numAuthorities = _messageInputStream.readUnsignedShort(); - int numAdditionals = _messageInputStream.readUnsignedShort(); - - logger.debug("DNSIncoming() questions:{} answers:{} authorities:{} additionals:{}", - numQuestions, - numAnswers, - numAuthorities, - numAdditionals - ); - - // We need some sanity checks - // A question is at least 5 bytes and answer 11 so check what we have - - if ((numQuestions * 5 + (numAnswers + numAuthorities + numAdditionals) * 11) > packet.getLength()) { - throw new IOException("questions:" + numQuestions + " answers:" + numAnswers + " authorities:" + numAuthorities + " additionals:" + numAdditionals); - } - - // parse questions - for (int i = 0; i < numQuestions; i++) { - DNSQuestion question = this.readQuestion(); - if (question != null) - _questions.add(question); - } - - // parse answers - for (int i = 0; i < numAnswers; i++) { - DNSRecord rec = this.readAnswer(); - if (rec != null) { - // Add a record, if we were able to create one. - _answers.add(rec); - } - } - - for (int i = 0; i < numAuthorities; i++) { - DNSRecord rec = this.readAnswer(); - if (rec != null) { - // Add a record, if we were able to create one. - _authoritativeAnswers.add(rec); - } - } + protected synchronized int peek() { + return (pos < count) ? (buf[pos] & 0xff) : -1; + } - for (int i = 0; i < numAdditionals; i++) { - DNSRecord rec = this.readAnswer(); - if (rec != null) { - // Add a record, if we were able to create one. - _additionals.add(rec); - } + public String readName() { + Map names = new HashMap(); + final StringBuilder sb = new StringBuilder(); + boolean finished = false; + while (!finished) { + int len = this.readUnsignedByte(); + if (len == 0) { + finished = true; + break; + } + switch (DNSLabel.labelForByte(len)) { + case Standard: + int offset = pos - 1; + String label = this.readUTF(len) + "."; + sb.append(label); + for (StringBuilder previousLabel : names.values()) { + previousLabel.append(label); } - - // We should have drained the entire stream by now - if (_messageInputStream.available() > 0) { - throw new IOException("Received a message with the wrong length."); + names.put(Integer.valueOf(offset), new StringBuilder(label)); + break; + case Compressed: + int index = (DNSLabel.labelValue(len) << 8) | this.readUnsignedByte(); + String compressedLabel = _names.get(Integer.valueOf(index)); + if (compressedLabel == null) { + logger1.warn( + "Bad domain name: possible circular name detected. Bad offset: 0x{} at 0x{}", + Integer.toHexString(index), + Integer.toHexString(pos - 2)); + compressedLabel = ""; } - } catch (Exception e) { - logger.warn("DNSIncoming() dump " + print(true) + "\n exception ", e); - // This ugly but some JVM don't implement the cause on IOException - IOException ioe = new IOException("DNSIncoming corrupted message"); - ioe.initCause(e); - throw ioe; - } finally { - try { - _messageInputStream.close(); - } catch (Exception e) { - logger.warn("MessageInputStream close error"); + sb.append(compressedLabel); + for (StringBuilder previousLabel : names.values()) { + previousLabel.append(compressedLabel); } + finished = true; + break; + case Extended: + // int extendedLabelClass = DNSLabel.labelValue(len); + logger1.debug("Extended label are not currently supported."); + break; + case Unknown: + default: + logger1.warn("Unsupported DNS label type: '{}'", Integer.toHexString(len & 0xC0)); } + } + for (final Map.Entry entry : names.entrySet()) { + final Integer index = entry.getKey(); + _names.put(index, entry.getValue().toString()); + } + return sb.toString(); } - private DNSIncoming(int flags, int id, boolean multicast, DatagramPacket packet, long receivedTime) { - super(flags, id, multicast); - this._packet = packet; - this._messageInputStream = new MessageInputStream(packet.getData(), packet.getLength()); - this._receivedTime = receivedTime; + public String readNonNameString() { + int len = this.readUnsignedByte(); + return this.readUTF(len); } + } + + private final DatagramPacket _packet; + + private final long _receivedTime; + + private final MessageInputStream _messageInputStream; + + private int _senderUDPPayload; + + /** + * Parse a message from a datagram packet. + * + * @param packet + * @exception IOException + */ + public DNSIncoming(DatagramPacket packet) throws IOException { + super(0, 0, packet.getPort() == DNSConstants.MDNS_PORT); + this._packet = packet; + InetAddress source = packet.getAddress(); + this._messageInputStream = new MessageInputStream(packet.getData(), packet.getLength()); + this._receivedTime = System.currentTimeMillis(); + this._senderUDPPayload = DNSConstants.MAX_MSG_TYPICAL; + + try { + this.setId(_messageInputStream.readUnsignedShort()); + this.setFlags(_messageInputStream.readUnsignedShort()); + if (this.getOperationCode() > 0) { + throw new IOException( + "Received a message with a non standard operation code. Currently unsupported in the specification."); + } + int numQuestions = _messageInputStream.readUnsignedShort(); + int numAnswers = _messageInputStream.readUnsignedShort(); + int numAuthorities = _messageInputStream.readUnsignedShort(); + int numAdditionals = _messageInputStream.readUnsignedShort(); + + logger.debug( + "DNSIncoming() questions:{} answers:{} authorities:{} additionals:{}", + numQuestions, + numAnswers, + numAuthorities, + numAdditionals); + + // We need some sanity checks + // A question is at least 5 bytes and answer 11 so check what we have + + if ((numQuestions * 5 + (numAnswers + numAuthorities + numAdditionals) * 11) + > packet.getLength()) { + throw new IOException( + "questions:" + + numQuestions + + " answers:" + + numAnswers + + " authorities:" + + numAuthorities + + " additionals:" + + numAdditionals); + } + + // parse questions + for (int i = 0; i < numQuestions; i++) { + DNSQuestion question = this.readQuestion(); + if (question != null) _questions.add(question); + } + + // parse answers + for (int i = 0; i < numAnswers; i++) { + DNSRecord rec = this.readAnswer(); + if (rec != null) { + // Add a record, if we were able to create one. + _answers.add(rec); + } + } - /* - * (non-Javadoc) - * @see java.lang.Object#clone() - */ - @Override - public DNSIncoming clone() { - DNSIncoming in = new DNSIncoming(this.getFlags(), this.getId(), this.isMulticast(), this._packet, this._receivedTime); - in._senderUDPPayload = this._senderUDPPayload; - in._questions.addAll(this._questions); - in._answers.addAll(this._answers); - in._authoritativeAnswers.addAll(this._authoritativeAnswers); - in._additionals.addAll(this._additionals); - return in; - } + for (int i = 0; i < numAuthorities; i++) { + DNSRecord rec = this.readAnswer(); + if (rec != null) { + // Add a record, if we were able to create one. + _authoritativeAnswers.add(rec); + } + } - private DNSQuestion readQuestion() { - String domain = _messageInputStream.readName(); - DNSRecordType type = DNSRecordType.typeForIndex(_messageInputStream.readUnsignedShort()); - if (type == DNSRecordType.TYPE_IGNORE) { - logger.warn("Could not find record type: {}", this.print(true)); + for (int i = 0; i < numAdditionals; i++) { + DNSRecord rec = this.readAnswer(); + if (rec != null) { + // Add a record, if we were able to create one. + _additionals.add(rec); } - int recordClassIndex = _messageInputStream.readUnsignedShort(); - DNSRecordClass recordClass = DNSRecordClass.classForIndex(recordClassIndex); - boolean unique = recordClass.isUnique(recordClassIndex); - return DNSQuestion.newQuestion(domain, type, recordClass, unique); + } + + // We should have drained the entire stream by now + if (_messageInputStream.available() > 0) { + throw new IOException("Received a message with the wrong length."); + } + } catch (Exception e) { + logger.warn("DNSIncoming() dump " + print(true) + "\n exception ", e); + // This ugly but some JVM don't implement the cause on IOException + IOException ioe = new IOException("DNSIncoming corrupted message"); + ioe.initCause(e); + throw ioe; + } finally { + try { + _messageInputStream.close(); + } catch (Exception e) { + logger.warn("MessageInputStream close error"); + } } - - private DNSRecord readAnswer() { - String domain = _messageInputStream.readName(); - DNSRecordType type = DNSRecordType.typeForIndex(_messageInputStream.readUnsignedShort()); - if (type == DNSRecordType.TYPE_IGNORE) { - logger.warn("Could not find record type. domain: {}\n{}", domain, this.print(true)); + } + + private DNSIncoming( + int flags, int id, boolean multicast, DatagramPacket packet, long receivedTime) { + super(flags, id, multicast); + this._packet = packet; + this._messageInputStream = new MessageInputStream(packet.getData(), packet.getLength()); + this._receivedTime = receivedTime; + } + + /* + * (non-Javadoc) + * @see java.lang.Object#clone() + */ + @Override + public DNSIncoming clone() { + DNSIncoming in = + new DNSIncoming( + this.getFlags(), this.getId(), this.isMulticast(), this._packet, this._receivedTime); + in._senderUDPPayload = this._senderUDPPayload; + in._questions.addAll(this._questions); + in._answers.addAll(this._answers); + in._authoritativeAnswers.addAll(this._authoritativeAnswers); + in._additionals.addAll(this._additionals); + return in; + } + + private DNSQuestion readQuestion() { + String domain = _messageInputStream.readName(); + DNSRecordType type = DNSRecordType.typeForIndex(_messageInputStream.readUnsignedShort()); + if (type == DNSRecordType.TYPE_IGNORE) { + logger.warn("Could not find record type: {}", this.print(true)); + } + int recordClassIndex = _messageInputStream.readUnsignedShort(); + DNSRecordClass recordClass = DNSRecordClass.classForIndex(recordClassIndex); + boolean unique = recordClass.isUnique(recordClassIndex); + return DNSQuestion.newQuestion(domain, type, recordClass, unique); + } + + private DNSRecord readAnswer() { + String domain = _messageInputStream.readName(); + DNSRecordType type = DNSRecordType.typeForIndex(_messageInputStream.readUnsignedShort()); + if (type == DNSRecordType.TYPE_IGNORE) { + logger.warn("Could not find record type. domain: {}\n{}", domain, this.print(true)); + } + int recordClassIndex = _messageInputStream.readUnsignedShort(); + DNSRecordClass recordClass = + (type == DNSRecordType.TYPE_OPT + ? DNSRecordClass.CLASS_UNKNOWN + : DNSRecordClass.classForIndex(recordClassIndex)); + if ((recordClass == DNSRecordClass.CLASS_UNKNOWN) && (type != DNSRecordType.TYPE_OPT)) { + logger.warn( + "Could not find record class. domain: {} type: {}\n{}", domain, type, this.print(true)); + } + boolean unique = recordClass.isUnique(recordClassIndex); + int ttl = _messageInputStream.readInt(); + int len = _messageInputStream.readUnsignedShort(); + DNSRecord rec = null; + + switch (type) { + case TYPE_A: // IPv4 + rec = + new DNSRecord.IPv4Address( + domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); + break; + case TYPE_AAAA: // IPv6 + rec = + new DNSRecord.IPv6Address( + domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); + break; + case TYPE_CNAME: + case TYPE_PTR: + String service = ""; + service = _messageInputStream.readName(); + if (service.length() > 0) { + rec = new DNSRecord.Pointer(domain, recordClass, unique, ttl, service); + } else { + logger.warn( + "PTR record of class: {}, there was a problem reading the service name of the answer for domain: {}", + recordClass, + domain); } - int recordClassIndex = _messageInputStream.readUnsignedShort(); - DNSRecordClass recordClass = (type == DNSRecordType.TYPE_OPT ? DNSRecordClass.CLASS_UNKNOWN : DNSRecordClass.classForIndex(recordClassIndex)); - if ((recordClass == DNSRecordClass.CLASS_UNKNOWN) && (type != DNSRecordType.TYPE_OPT)) { - logger.warn("Could not find record class. domain: {} type: {}\n{}", domain, type, this.print(true)); + break; + case TYPE_TXT: + rec = + new DNSRecord.Text( + domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); + break; + case TYPE_SRV: + int priority = _messageInputStream.readUnsignedShort(); + int weight = _messageInputStream.readUnsignedShort(); + int port = _messageInputStream.readUnsignedShort(); + String target = ""; + // This is a hack to handle a bug in the BonjourConformanceTest + // It is sending out target strings that don't follow the "domain name" format. + if (USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET) { + target = _messageInputStream.readName(); + } else { + // [PJYF Nov 13 2010] Do we still need this? This looks really bad. All label are supposed + // to start by a length. + target = _messageInputStream.readNonNameString(); } - boolean unique = recordClass.isUnique(recordClassIndex); - int ttl = _messageInputStream.readInt(); - int len = _messageInputStream.readUnsignedShort(); - DNSRecord rec = null; - - switch (type) { - case TYPE_A: // IPv4 - rec = new DNSRecord.IPv4Address(domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); - break; - case TYPE_AAAA: // IPv6 - rec = new DNSRecord.IPv6Address(domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); - break; - case TYPE_CNAME: - case TYPE_PTR: - String service = ""; - service = _messageInputStream.readName(); - if (service.length() > 0) { - rec = new DNSRecord.Pointer(domain, recordClass, unique, ttl, service); - } else { - logger.warn("PTR record of class: {}, there was a problem reading the service name of the answer for domain: {}", recordClass, domain); + rec = + new DNSRecord.Service(domain, recordClass, unique, ttl, priority, weight, port, target); + break; + case TYPE_HINFO: + final StringBuilder sb = new StringBuilder(); + sb.append(_messageInputStream.readUTF(len)); + int index = sb.indexOf(" "); + String cpu = (index > 0 ? sb.substring(0, index) : sb.toString()).trim(); + String os = (index > 0 ? sb.substring(index + 1) : "").trim(); + rec = new DNSRecord.HostInformation(domain, recordClass, unique, ttl, cpu, os); + break; + case TYPE_OPT: + DNSResultCode extendedResultCode = DNSResultCode.resultCodeForFlags(this.getFlags(), ttl); + int version = (ttl & 0x00ff0000) >> 16; + if (version == 0) { + _senderUDPPayload = recordClassIndex; + while (_messageInputStream.available() > 0) { + // Read RDData + int optionCodeInt = 0; + DNSOptionCode optionCode = null; + if (_messageInputStream.available() >= 2) { + optionCodeInt = _messageInputStream.readUnsignedShort(); + optionCode = DNSOptionCode.resultCodeForFlags(optionCodeInt); + } else { + logger.warn("There was a problem reading the OPT record. Ignoring."); + break; + } + int optionLength = 0; + if (_messageInputStream.available() >= 2) { + optionLength = _messageInputStream.readUnsignedShort(); + } else { + logger.warn("There was a problem reading the OPT record. Ignoring."); + break; + } + byte[] optiondata = new byte[0]; + if (_messageInputStream.available() >= optionLength) { + optiondata = _messageInputStream.readBytes(optionLength); + } + // + // We should really do something with those options. + switch (optionCode) { + case Owner: + // Valid length values are 8, 14, 18 and 20 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // |Opt|Len|V|S|Primary MAC|Wakeup MAC | Password | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // + int ownerVersion = 0; + int ownerSequence = 0; + byte[] ownerPrimaryMacAddress = null; + byte[] ownerWakeupMacAddress = null; + byte[] ownerPassword = null; + try { + ownerVersion = optiondata[0]; + ownerSequence = optiondata[1]; + ownerPrimaryMacAddress = + new byte[] { + optiondata[2], + optiondata[3], + optiondata[4], + optiondata[5], + optiondata[6], + optiondata[7] + }; + ownerWakeupMacAddress = ownerPrimaryMacAddress; + if (optiondata.length > 8) { + // We have a wakeupMacAddress. + ownerWakeupMacAddress = + new byte[] { + optiondata[8], + optiondata[9], + optiondata[10], + optiondata[11], + optiondata[12], + optiondata[13] + }; + } + if (optiondata.length == 18) { + // We have a short password. + ownerPassword = + new byte[] {optiondata[14], optiondata[15], optiondata[16], optiondata[17]}; + } + if (optiondata.length == 22) { + // We have a long password. + ownerPassword = + new byte[] { + optiondata[14], + optiondata[15], + optiondata[16], + optiondata[17], + optiondata[18], + optiondata[19], + optiondata[20], + optiondata[21] + }; + } + } catch (Exception exception) { + logger.warn( + "Malformed OPT answer. Option code: Owner data: {}", + this._hexString(optiondata)); } - break; - case TYPE_TXT: - rec = new DNSRecord.Text(domain, recordClass, unique, ttl, _messageInputStream.readBytes(len)); - break; - case TYPE_SRV: - int priority = _messageInputStream.readUnsignedShort(); - int weight = _messageInputStream.readUnsignedShort(); - int port = _messageInputStream.readUnsignedShort(); - String target = ""; - // This is a hack to handle a bug in the BonjourConformanceTest - // It is sending out target strings that don't follow the "domain name" format. - if (USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET) { - target = _messageInputStream.readName(); - } else { - // [PJYF Nov 13 2010] Do we still need this? This looks really bad. All label are supposed to start by a length. - target = _messageInputStream.readNonNameString(); + if (logger.isDebugEnabled()) { + logger.debug( + "Unhandled Owner OPT version: {} sequence: {} MAC address: {} {}{} {}{}", + ownerVersion, + ownerSequence, + this._hexString(ownerPrimaryMacAddress), + (ownerWakeupMacAddress != ownerPrimaryMacAddress + ? " wakeup MAC address: " + : ""), + (ownerWakeupMacAddress != ownerPrimaryMacAddress + ? this._hexString(ownerWakeupMacAddress) + : ""), + (ownerPassword != null ? " password: " : ""), + (ownerPassword != null ? this._hexString(ownerPassword) : "")); } - rec = new DNSRecord.Service(domain, recordClass, unique, ttl, priority, weight, port, target); break; - case TYPE_HINFO: - final StringBuilder sb = new StringBuilder(); - sb.append(_messageInputStream.readUTF(len)); - int index = sb.indexOf(" "); - String cpu = (index > 0 ? sb.substring(0, index) : sb.toString()).trim(); - String os = (index > 0 ? sb.substring(index + 1) : "").trim(); - rec = new DNSRecord.HostInformation(domain, recordClass, unique, ttl, cpu, os); + case LLQ: + case NSID: + case UL: + if (logger.isDebugEnabled()) { + logger.debug( + "There was an OPT answer. Option code: {} data: {}", + optionCode, + this._hexString(optiondata)); + } break; - case TYPE_OPT: - DNSResultCode extendedResultCode = DNSResultCode.resultCodeForFlags(this.getFlags(), ttl); - int version = (ttl & 0x00ff0000) >> 16; - if (version == 0) { - _senderUDPPayload = recordClassIndex; - while (_messageInputStream.available() > 0) { - // Read RDData - int optionCodeInt = 0; - DNSOptionCode optionCode = null; - if (_messageInputStream.available() >= 2) { - optionCodeInt = _messageInputStream.readUnsignedShort(); - optionCode = DNSOptionCode.resultCodeForFlags(optionCodeInt); - } else { - logger.warn("There was a problem reading the OPT record. Ignoring."); - break; - } - int optionLength = 0; - if (_messageInputStream.available() >= 2) { - optionLength = _messageInputStream.readUnsignedShort(); - } else { - logger.warn("There was a problem reading the OPT record. Ignoring."); - break; - } - byte[] optiondata = new byte[0]; - if (_messageInputStream.available() >= optionLength) { - optiondata = _messageInputStream.readBytes(optionLength); - } - // - // We should really do something with those options. - switch (optionCode) { - case Owner: - // Valid length values are 8, 14, 18 and 20 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // |Opt|Len|V|S|Primary MAC|Wakeup MAC | Password | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // - int ownerVersion = 0; - int ownerSequence = 0; - byte[] ownerPrimaryMacAddress = null; - byte[] ownerWakeupMacAddress = null; - byte[] ownerPassword = null; - try { - ownerVersion = optiondata[0]; - ownerSequence = optiondata[1]; - ownerPrimaryMacAddress = new byte[] { optiondata[2], optiondata[3], optiondata[4], optiondata[5], optiondata[6], optiondata[7] }; - ownerWakeupMacAddress = ownerPrimaryMacAddress; - if (optiondata.length > 8) { - // We have a wakeupMacAddress. - ownerWakeupMacAddress = new byte[] { optiondata[8], optiondata[9], optiondata[10], optiondata[11], optiondata[12], optiondata[13] }; - } - if (optiondata.length == 18) { - // We have a short password. - ownerPassword = new byte[] { optiondata[14], optiondata[15], optiondata[16], optiondata[17] }; - } - if (optiondata.length == 22) { - // We have a long password. - ownerPassword = new byte[] { optiondata[14], optiondata[15], optiondata[16], optiondata[17], optiondata[18], optiondata[19], optiondata[20], optiondata[21] }; - } - } catch (Exception exception) { - logger.warn("Malformed OPT answer. Option code: Owner data: {}", this._hexString(optiondata)); - } - if (logger.isDebugEnabled()) { - logger.debug("Unhandled Owner OPT version: {} sequence: {} MAC address: {} {}{} {}{}", - ownerVersion, - ownerSequence, - this._hexString(ownerPrimaryMacAddress), - (ownerWakeupMacAddress != ownerPrimaryMacAddress ? " wakeup MAC address: " : ""), - (ownerWakeupMacAddress != ownerPrimaryMacAddress ? this._hexString(ownerWakeupMacAddress) : ""), - (ownerPassword != null ? " password: ": ""), - (ownerPassword != null ? this._hexString(ownerPassword) : "") - ); - } - break; - case LLQ: - case NSID: - case UL: - if (logger.isDebugEnabled()) { - logger.debug("There was an OPT answer. Option code: {} data: {}", optionCode, this._hexString(optiondata)); - } - break; - case Unknown: - if (optionCodeInt >= 65001 && optionCodeInt <= 65534) { - // RFC 6891 defines this range as used for experimental/local purposes. - logger.debug("There was an OPT answer using an experimental/local option code: {} data: {}", optionCodeInt, this._hexString(optiondata)); - } else { - logger.warn("There was an OPT answer. Not currently handled. Option code: {} data: {}", optionCodeInt, this._hexString(optiondata)); - } - break; - default: - // This is to keep the compiler happy. - break; - } - } + case Unknown: + if (optionCodeInt >= 65001 && optionCodeInt <= 65534) { + // RFC 6891 defines this range as used for experimental/local purposes. + logger.debug( + "There was an OPT answer using an experimental/local option code: {} data: {}", + optionCodeInt, + this._hexString(optiondata)); } else { - logger.warn("There was an OPT answer. Wrong version number: {} result code: {}", version, extendedResultCode); + logger.warn( + "There was an OPT answer. Not currently handled. Option code: {} data: {}", + optionCodeInt, + this._hexString(optiondata)); } break; - default: - logger.debug("DNSIncoming() unknown type: {}", type); - _messageInputStream.skip(len); + default: + // This is to keep the compiler happy. break; + } + } + } else { + logger.warn( + "There was an OPT answer. Wrong version number: {} result code: {}", + version, + extendedResultCode); } - return rec; + break; + default: + logger.debug("DNSIncoming() unknown type: {}", type); + _messageInputStream.skip(len); + break; } - - /** - * Debugging. - */ - String print(boolean dump) { - final StringBuilder sb = new StringBuilder(); - sb.append(this.print()); - if (dump) { - byte[] data = new byte[_packet.getLength()]; - System.arraycopy(_packet.getData(), 0, data, 0, data.length); - sb.append(this.print(data)); - } - return sb.toString(); + return rec; + } + + /** Debugging. */ + String print(boolean dump) { + final StringBuilder sb = new StringBuilder(); + sb.append(this.print()); + if (dump) { + byte[] data = new byte[_packet.getLength()]; + System.arraycopy(_packet.getData(), 0, data, 0, data.length); + sb.append(this.print(data)); } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(); - sb.append(isQuery() ? "dns[query," : "dns[response,"); - if (_packet.getAddress() != null) { - sb.append(_packet.getAddress().getHostAddress()); - } - sb.append(':'); - sb.append(_packet.getPort()); - sb.append(", length="); - sb.append(_packet.getLength()); - sb.append(", id=0x"); - sb.append(Integer.toHexString(this.getId())); - if (this.getFlags() != 0) { - sb.append(", flags=0x"); - sb.append(Integer.toHexString(this.getFlags())); - if ((this.getFlags() & DNSConstants.FLAGS_QR_RESPONSE) != 0) { - sb.append(":r"); - } - if ((this.getFlags() & DNSConstants.FLAGS_AA) != 0) { - sb.append(":aa"); - } - if ((this.getFlags() & DNSConstants.FLAGS_TC) != 0) { - sb.append(":tc"); - } - } - if (this.getNumberOfQuestions() > 0) { - sb.append(", questions="); - sb.append(this.getNumberOfQuestions()); - } - if (this.getNumberOfAnswers() > 0) { - sb.append(", answers="); - sb.append(this.getNumberOfAnswers()); - } - if (this.getNumberOfAuthorities() > 0) { - sb.append(", authorities="); - sb.append(this.getNumberOfAuthorities()); - } - if (this.getNumberOfAdditionals() > 0) { - sb.append(", additionals="); - sb.append(this.getNumberOfAdditionals()); - } - if (this.getNumberOfQuestions() > 0) { - sb.append("\nquestions:"); - for (DNSQuestion question : _questions) { - sb.append("\n\t"); - sb.append(question); - } - } - if (this.getNumberOfAnswers() > 0) { - sb.append("\nanswers:"); - for (DNSRecord record : _answers) { - sb.append("\n\t"); - sb.append(record); - } - } - if (this.getNumberOfAuthorities() > 0) { - sb.append("\nauthorities:"); - for (DNSRecord record : _authoritativeAnswers) { - sb.append("\n\t"); - sb.append(record); - } - } - if (this.getNumberOfAdditionals() > 0) { - sb.append("\nadditionals:"); - for (DNSRecord record : _additionals) { - sb.append("\n\t"); - sb.append(record); - } - } - sb.append(']'); - - return sb.toString(); + return sb.toString(); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append(isQuery() ? "dns[query," : "dns[response,"); + if (_packet.getAddress() != null) { + sb.append(_packet.getAddress().getHostAddress()); } - - public int elapseSinceArrival() { - return (int) (System.currentTimeMillis() - _receivedTime); + sb.append(':'); + sb.append(_packet.getPort()); + sb.append(", length="); + sb.append(_packet.getLength()); + sb.append(", id=0x"); + sb.append(Integer.toHexString(this.getId())); + if (this.getFlags() != 0) { + sb.append(", flags=0x"); + sb.append(Integer.toHexString(this.getFlags())); + if ((this.getFlags() & DNSConstants.FLAGS_QR_RESPONSE) != 0) { + sb.append(":r"); + } + if ((this.getFlags() & DNSConstants.FLAGS_AA) != 0) { + sb.append(":aa"); + } + if ((this.getFlags() & DNSConstants.FLAGS_TC) != 0) { + sb.append(":tc"); + } } - - /** - * This will return the default UDP payload except if an OPT record was found with a different size. - * - * @return the senderUDPPayload - */ - public int getSenderUDPPayload() { - return this._senderUDPPayload; + if (this.getNumberOfQuestions() > 0) { + sb.append(", questions="); + sb.append(this.getNumberOfQuestions()); } - - private static final char[] _nibbleToHex = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' }; - - /** - * Returns a hex-string for printing - * - * @param bytes - * @return Returns a hex-string which can be used within a SQL expression - */ - private String _hexString(byte[] bytes) { - - final StringBuilder result = new StringBuilder(2 * bytes.length); - - for (int i = 0; i < bytes.length; i++) { - int b = bytes[i] & 0xFF; - result.append(_nibbleToHex[b / 16]); - result.append(_nibbleToHex[b % 16]); - } - - return result.toString(); + if (this.getNumberOfAnswers() > 0) { + sb.append(", answers="); + sb.append(this.getNumberOfAnswers()); + } + if (this.getNumberOfAuthorities() > 0) { + sb.append(", authorities="); + sb.append(this.getNumberOfAuthorities()); + } + if (this.getNumberOfAdditionals() > 0) { + sb.append(", additionals="); + sb.append(this.getNumberOfAdditionals()); + } + if (this.getNumberOfQuestions() > 0) { + sb.append("\nquestions:"); + for (DNSQuestion question : _questions) { + sb.append("\n\t"); + sb.append(question); + } + } + if (this.getNumberOfAnswers() > 0) { + sb.append("\nanswers:"); + for (DNSRecord record : _answers) { + sb.append("\n\t"); + sb.append(record); + } + } + if (this.getNumberOfAuthorities() > 0) { + sb.append("\nauthorities:"); + for (DNSRecord record : _authoritativeAnswers) { + sb.append("\n\t"); + sb.append(record); + } + } + if (this.getNumberOfAdditionals() > 0) { + sb.append("\nadditionals:"); + for (DNSRecord record : _additionals) { + sb.append("\n\t"); + sb.append(record); + } + } + sb.append(']'); + + return sb.toString(); + } + + public int elapseSinceArrival() { + return (int) (System.currentTimeMillis() - _receivedTime); + } + + /** + * This will return the default UDP payload except if an OPT record was found with a different + * size. + * + * @return the senderUDPPayload + */ + public int getSenderUDPPayload() { + return this._senderUDPPayload; + } + + private static final char[] _nibbleToHex = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + }; + + /** + * Returns a hex-string for printing + * + * @param bytes + * @return Returns a hex-string which can be used within a SQL expression + */ + private String _hexString(byte[] bytes) { + + final StringBuilder result = new StringBuilder(2 * bytes.length); + + for (int i = 0; i < bytes.length; i++) { + int b = bytes[i] & 0xFF; + result.append(_nibbleToHex[b / 16]); + result.append(_nibbleToHex[b % 16]); } + return result.toString(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSMessage.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSMessage.java index a22a63a42..24a6cc812 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSMessage.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSMessage.java @@ -1,16 +1,13 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - /** * DNSMessage define a DNS message either incoming or outgoing. * @@ -18,318 +15,315 @@ */ public abstract class DNSMessage { - /** - * - */ - public static final boolean MULTICAST = true; - - /** - * - */ - public static final boolean UNICAST = false; - - // protected DatagramPacket _packet; - // protected int _off; - // protected int _len; - // protected byte[] _data; - - private int _id; - - boolean _multicast; - - private int _flags; - - protected final List _questions; - - protected final List _answers; - - protected final List _authoritativeAnswers; - - protected final List _additionals; - - /** - * @param flags - * @param id - * @param multicast - */ - protected DNSMessage(int flags, int id, boolean multicast) { - super(); - _flags = flags; - _id = id; - _multicast = multicast; - _questions = Collections.synchronizedList(new LinkedList()); - _answers = Collections.synchronizedList(new LinkedList()); - _authoritativeAnswers = Collections.synchronizedList(new LinkedList()); - _additionals = Collections.synchronizedList(new LinkedList()); + /** */ + public static final boolean MULTICAST = true; + + /** */ + public static final boolean UNICAST = false; + + // protected DatagramPacket _packet; + // protected int _off; + // protected int _len; + // protected byte[] _data; + + private int _id; + + boolean _multicast; + + private int _flags; + + protected final List _questions; + + protected final List _answers; + + protected final List _authoritativeAnswers; + + protected final List _additionals; + + /** + * @param flags + * @param id + * @param multicast + */ + protected DNSMessage(int flags, int id, boolean multicast) { + super(); + _flags = flags; + _id = id; + _multicast = multicast; + _questions = Collections.synchronizedList(new LinkedList()); + _answers = Collections.synchronizedList(new LinkedList()); + _authoritativeAnswers = Collections.synchronizedList(new LinkedList()); + _additionals = Collections.synchronizedList(new LinkedList()); + } + + // public DatagramPacket getPacket() { + // return _packet; + // } + // + // public int getOffset() { + // return _off; + // } + // + // public int getLength() { + // return _len; + // } + // + // public byte[] getData() { + // if ( _data == null ) _data = new byte[DNSConstants.MAX_MSG_TYPICAL]; + // return _data; + // } + + /** + * @return message id + */ + public int getId() { + return (_multicast ? 0 : _id); + } + + /** + * @param id the id to set + */ + public void setId(int id) { + this._id = id; + } + + /** + * @return message flags + */ + public int getFlags() { + return _flags; + } + + /** + * @param flags the flags to set + */ + public void setFlags(int flags) { + this._flags = flags; + } + + /** + * @return true if multicast + */ + public boolean isMulticast() { + return _multicast; + } + + /** + * @return list of questions + */ + public Collection getQuestions() { + return _questions; + } + + /** + * @return number of questions in the message + */ + public int getNumberOfQuestions() { + return this.getQuestions().size(); + } + + public List getAllAnswers() { + List aList = + new ArrayList( + _answers.size() + _authoritativeAnswers.size() + _additionals.size()); + aList.addAll(_answers); + aList.addAll(_authoritativeAnswers); + aList.addAll(_additionals); + return aList; + } + + /** + * @return list of answers + */ + public Collection getAnswers() { + return _answers; + } + + /** + * @return number of answers in the message + */ + public int getNumberOfAnswers() { + return this.getAnswers().size(); + } + + /** + * @return list of authorities + */ + public Collection getAuthorities() { + return _authoritativeAnswers; + } + + /** + * @return number of authorities in the message + */ + public int getNumberOfAuthorities() { + return this.getAuthorities().size(); + } + + /** + * @return list of additional answers + */ + public Collection getAdditionals() { + return _additionals; + } + + /** + * @return number of additional in the message + */ + public int getNumberOfAdditionals() { + return this.getAdditionals().size(); + } + + /** + * Check is the response code is valid
+ * The only valid value is zero all other values signify an error and the message must be ignored. + * + * @return true if the message has a valid response code. + */ + public boolean isValidResponseCode() { + return (_flags & DNSConstants.FLAGS_RCODE) == 0; + } + + /** + * Returns the operation code value. Currently only standard query 0 is valid. + * + * @return The operation code value. + */ + public int getOperationCode() { + return (_flags & DNSConstants.FLAGS_OPCODE) >> 11; + } + + /** + * Check if the message is truncated. + * + * @return true if the message was truncated + */ + public boolean isTruncated() { + return (_flags & DNSConstants.FLAGS_TC) != 0; + } + + /** + * Check if the message is an authoritative answer. + * + * @return true if the message is an authoritative answer + */ + public boolean isAuthoritativeAnswer() { + return (_flags & DNSConstants.FLAGS_AA) != 0; + } + + /** + * Check if the message is a query. + * + * @return true is the message is a query + */ + public boolean isQuery() { + return (_flags & DNSConstants.FLAGS_QR_MASK) == DNSConstants.FLAGS_QR_QUERY; + } + + /** + * Check if the message is a response. + * + * @return true is the message is a response + */ + public boolean isResponse() { + return (_flags & DNSConstants.FLAGS_QR_MASK) == DNSConstants.FLAGS_QR_RESPONSE; + } + + /** + * Check if the message is empty + * + * @return true is the message is empty + */ + public boolean isEmpty() { + return (this.getNumberOfQuestions() + + this.getNumberOfAnswers() + + this.getNumberOfAuthorities() + + this.getNumberOfAdditionals()) + == 0; + } + + /** Debugging. */ + String print() { + final StringBuilder sb = new StringBuilder(200); + sb.append(this.toString()); + sb.append("\n"); + for (final DNSQuestion question : _questions) { + sb.append("\tquestion: "); + sb.append(question); + sb.append("\n"); } - - // public DatagramPacket getPacket() { - // return _packet; - // } - // - // public int getOffset() { - // return _off; - // } - // - // public int getLength() { - // return _len; - // } - // - // public byte[] getData() { - // if ( _data == null ) _data = new byte[DNSConstants.MAX_MSG_TYPICAL]; - // return _data; - // } - - /** - * @return message id - */ - public int getId() { - return (_multicast ? 0 : _id); - } - - /** - * @param id - * the id to set - */ - public void setId(int id) { - this._id = id; - } - - /** - * @return message flags - */ - public int getFlags() { - return _flags; - } - - /** - * @param flags - * the flags to set - */ - public void setFlags(int flags) { - this._flags = flags; + for (final DNSRecord answer : _answers) { + sb.append("\tanswer: "); + sb.append(answer); + sb.append("\n"); } - - /** - * @return true if multicast - */ - public boolean isMulticast() { - return _multicast; + for (final DNSRecord answer : _authoritativeAnswers) { + sb.append("\tauthoritative: "); + sb.append(answer); + sb.append("\n"); } - - /** - * @return list of questions - */ - public Collection getQuestions() { - return _questions; - } - - /** - * @return number of questions in the message - */ - public int getNumberOfQuestions() { - return this.getQuestions().size(); - } - - public List getAllAnswers() { - List aList = new ArrayList(_answers.size() + _authoritativeAnswers.size() + _additionals.size()); - aList.addAll(_answers); - aList.addAll(_authoritativeAnswers); - aList.addAll(_additionals); - return aList; - } - - /** - * @return list of answers - */ - public Collection getAnswers() { - return _answers; + for (DNSRecord answer : _additionals) { + sb.append("\tadditional: "); + sb.append(answer); + sb.append("\n"); } - - /** - * @return number of answers in the message - */ - public int getNumberOfAnswers() { - return this.getAnswers().size(); - } - - /** - * @return list of authorities - */ - public Collection getAuthorities() { - return _authoritativeAnswers; - } - - /** - * @return number of authorities in the message - */ - public int getNumberOfAuthorities() { - return this.getAuthorities().size(); - } - - /** - * @return list of additional answers - */ - public Collection getAdditionals() { - return _additionals; - } - - /** - * @return number of additional in the message - */ - public int getNumberOfAdditionals() { - return this.getAdditionals().size(); - } - - /** - * Check is the response code is valid
- * The only valid value is zero all other values signify an error and the message must be ignored. - * - * @return true if the message has a valid response code. - */ - public boolean isValidResponseCode() { - return (_flags & DNSConstants.FLAGS_RCODE) == 0; - } - - /** - * Returns the operation code value. Currently only standard query 0 is valid. - * - * @return The operation code value. - */ - public int getOperationCode() { - return (_flags & DNSConstants.FLAGS_OPCODE) >> 11; - } - - /** - * Check if the message is truncated. - * - * @return true if the message was truncated - */ - public boolean isTruncated() { - return (_flags & DNSConstants.FLAGS_TC) != 0; - } - - /** - * Check if the message is an authoritative answer. - * - * @return true if the message is an authoritative answer - */ - public boolean isAuthoritativeAnswer() { - return (_flags & DNSConstants.FLAGS_AA) != 0; - } - - /** - * Check if the message is a query. - * - * @return true is the message is a query - */ - public boolean isQuery() { - return (_flags & DNSConstants.FLAGS_QR_MASK) == DNSConstants.FLAGS_QR_QUERY; - } - - /** - * Check if the message is a response. - * - * @return true is the message is a response - */ - public boolean isResponse() { - return (_flags & DNSConstants.FLAGS_QR_MASK) == DNSConstants.FLAGS_QR_RESPONSE; - } - - /** - * Check if the message is empty - * - * @return true is the message is empty - */ - public boolean isEmpty() { - return (this.getNumberOfQuestions() + this.getNumberOfAnswers() + this.getNumberOfAuthorities() + this.getNumberOfAdditionals()) == 0; - } - - /** - * Debugging. - */ - String print() { - final StringBuilder sb = new StringBuilder(200); - sb.append(this.toString()); - sb.append("\n"); - for (final DNSQuestion question : _questions) { - sb.append("\tquestion: "); - sb.append(question); - sb.append("\n"); + return sb.toString(); + } + + /** + * Debugging. + * + * @param data + * @return data dump + */ + protected String print(byte[] data) { + final StringBuilder sb = new StringBuilder(4000); + for (int off = 0, len = data.length; off < len; off += 32) { + int n = Math.min(32, len - off); + if (off < 0x10) { + sb.append(' '); + } + if (off < 0x100) { + sb.append(' '); + } + if (off < 0x1000) { + sb.append(' '); + } + sb.append(Integer.toHexString(off)); + sb.append(':'); + int index = 0; + for (index = 0; index < n; index++) { + if ((index % 8) == 0) { + sb.append(' '); } - for (final DNSRecord answer : _answers) { - sb.append("\tanswer: "); - sb.append(answer); - sb.append("\n"); + sb.append(Integer.toHexString((data[off + index] & 0xF0) >> 4)); + sb.append(Integer.toHexString((data[off + index] & 0x0F) >> 0)); + } + // for incomplete lines + if (index < 32) { + for (int i = index; i < 32; i++) { + if ((i % 8) == 0) { + sb.append(' '); + } + sb.append(" "); } - for (final DNSRecord answer : _authoritativeAnswers) { - sb.append("\tauthoritative: "); - sb.append(answer); - sb.append("\n"); + } + sb.append(" "); + for (index = 0; index < n; index++) { + if ((index % 8) == 0) { + sb.append(' '); } - for (DNSRecord answer : _additionals) { - sb.append("\tadditional: "); - sb.append(answer); - sb.append("\n"); - } - return sb.toString(); + int ch = data[off + index] & 0xFF; + sb.append(((ch > ' ') && (ch < 127)) ? (char) ch : '.'); + } + sb.append("\n"); + + // limit message size + if (off + 32 >= 2048) { + sb.append("....\n"); + break; + } } - - /** - * Debugging. - * - * @param data - * @return data dump - */ - protected String print(byte[] data) { - final StringBuilder sb = new StringBuilder(4000); - for (int off = 0, len = data.length; off < len; off += 32) { - int n = Math.min(32, len - off); - if (off < 0x10) { - sb.append(' '); - } - if (off < 0x100) { - sb.append(' '); - } - if (off < 0x1000) { - sb.append(' '); - } - sb.append(Integer.toHexString(off)); - sb.append(':'); - int index = 0; - for (index = 0; index < n; index++) { - if ((index % 8) == 0) { - sb.append(' '); - } - sb.append(Integer.toHexString((data[off + index] & 0xF0) >> 4)); - sb.append(Integer.toHexString((data[off + index] & 0x0F) >> 0)); - } - // for incomplete lines - if (index < 32) { - for (int i = index; i < 32; i++) { - if ((i % 8) == 0) { - sb.append(' '); - } - sb.append(" "); - } - } - sb.append(" "); - for (index = 0; index < n; index++) { - if ((index % 8) == 0) { - sb.append(' '); - } - int ch = data[off + index] & 0xFF; - sb.append(((ch > ' ') && (ch < 127)) ? (char) ch : '.'); - } - sb.append("\n"); - - // limit message size - if (off + 32 >= 2048) { - sb.append("....\n"); - break; - } - } - return sb.toString(); - } - + return sb.toString(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSOutgoing.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSOutgoing.java index fe2b33242..3b42a1738 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSOutgoing.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSOutgoing.java @@ -4,425 +4,426 @@ package io.libp2p.discovery.mdns.impl; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; - import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.InetSocketAddress; import java.util.HashMap; import java.util.Map; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - /** * An outgoing DNS message. * * @author Arthur van Hoff, Rick Blair, Werner Randelshofer */ public final class DNSOutgoing extends DNSMessage { - public static class MessageOutputStream extends ByteArrayOutputStream { - private final DNSOutgoing _out; - - private final int _offset; - - /** - * Creates a new message stream, with a buffer capacity of the specified size, in bytes. - * - * @param size - * the initial size. - * @exception IllegalArgumentException - * if size is negative. - */ - MessageOutputStream(int size, DNSOutgoing out) { - this(size, out, 0); - } + public static class MessageOutputStream extends ByteArrayOutputStream { + private final DNSOutgoing _out; - MessageOutputStream(int size, DNSOutgoing out, int offset) { - super(size); - _out = out; - _offset = offset; - } + private final int _offset; - void writeByte(int value) { - this.write(value & 0xFF); - } + /** + * Creates a new message stream, with a buffer capacity of the specified size, in bytes. + * + * @param size the initial size. + * @exception IllegalArgumentException if size is negative. + */ + MessageOutputStream(int size, DNSOutgoing out) { + this(size, out, 0); + } - void writeBytes(String str, int off, int len) { - for (int i = 0; i < len; i++) { - writeByte(str.charAt(off + i)); - } - } + MessageOutputStream(int size, DNSOutgoing out, int offset) { + super(size); + _out = out; + _offset = offset; + } - public void writeBytes(byte data[]) { - if (data != null) { - writeBytes(data, 0, data.length); - } - } + void writeByte(int value) { + this.write(value & 0xFF); + } - void writeBytes(byte data[], int off, int len) { - for (int i = 0; i < len; i++) { - writeByte(data[off + i]); - } - } + void writeBytes(String str, int off, int len) { + for (int i = 0; i < len; i++) { + writeByte(str.charAt(off + i)); + } + } - void writeShort(int value) { - writeByte(value >> 8); - writeByte(value); - } + public void writeBytes(byte data[]) { + if (data != null) { + writeBytes(data, 0, data.length); + } + } - void writeInt(int value) { - writeShort(value >> 16); - writeShort(value); - } + void writeBytes(byte data[], int off, int len) { + for (int i = 0; i < len; i++) { + writeByte(data[off + i]); + } + } - void writeUTF(String str, int off, int len) { - // compute utf length - int utflen = 0; - for (int i = 0; i < len; i++) { - int ch = str.charAt(off + i); - if ((ch >= 0x0001) && (ch <= 0x007F)) { - utflen += 1; - } else { - if (ch > 0x07FF) { - utflen += 3; - } else { - utflen += 2; - } - } - } - // write utf length - writeByte(utflen); - // write utf data - for (int i = 0; i < len; i++) { - int ch = str.charAt(off + i); - if ((ch >= 0x0001) && (ch <= 0x007F)) { - writeByte(ch); - } else { - if (ch > 0x07FF) { - writeByte(0xE0 | ((ch >> 12) & 0x0F)); - writeByte(0x80 | ((ch >> 6) & 0x3F)); - writeByte(0x80 | ((ch >> 0) & 0x3F)); - } else { - writeByte(0xC0 | ((ch >> 6) & 0x1F)); - writeByte(0x80 | ((ch >> 0) & 0x3F)); - } - } - } - } + void writeShort(int value) { + writeByte(value >> 8); + writeByte(value); + } - void writeName(String name) { - writeName(name, true); - } + void writeInt(int value) { + writeShort(value >> 16); + writeShort(value); + } - void writeName(String name, boolean useCompression) { - String aName = name; - while (true) { - int n = indexOfSeparator(aName); - if (n < 0) { - n = aName.length(); - } - if (n <= 0) { - writeByte(0); - return; - } - String label = aName.substring(0, n).replace("\\.", "."); - if (useCompression && USE_DOMAIN_NAME_COMPRESSION) { - Integer offset = _out._names.get(aName); - if (offset != null) { - int val = offset.intValue(); - writeByte((val >> 8) | 0xC0); - writeByte(val & 0xFF); - return; - } - _out._names.put(aName, Integer.valueOf(this.size() + _offset)); - writeUTF(label, 0, label.length()); - } else { - writeUTF(label, 0, label.length()); - } - aName = aName.substring(n); - if (aName.startsWith(".")) { - aName = aName.substring(1); - } - } + void writeUTF(String str, int off, int len) { + // compute utf length + int utflen = 0; + for (int i = 0; i < len; i++) { + int ch = str.charAt(off + i); + if ((ch >= 0x0001) && (ch <= 0x007F)) { + utflen += 1; + } else { + if (ch > 0x07FF) { + utflen += 3; + } else { + utflen += 2; + } } + } + // write utf length + writeByte(utflen); + // write utf data + for (int i = 0; i < len; i++) { + int ch = str.charAt(off + i); + if ((ch >= 0x0001) && (ch <= 0x007F)) { + writeByte(ch); + } else { + if (ch > 0x07FF) { + writeByte(0xE0 | ((ch >> 12) & 0x0F)); + writeByte(0x80 | ((ch >> 6) & 0x3F)); + writeByte(0x80 | ((ch >> 0) & 0x3F)); + } else { + writeByte(0xC0 | ((ch >> 6) & 0x1F)); + writeByte(0x80 | ((ch >> 0) & 0x3F)); + } + } + } + } - private static int indexOfSeparator(String aName) { - int offset = 0; - int n = 0; - - while (true) { - n = aName.indexOf('.', offset); - if (n < 0) - return -1; - - if (n == 0 || aName.charAt(n - 1) != '\\') - return n; + void writeName(String name) { + writeName(name, true); + } - offset = n + 1; - } + void writeName(String name, boolean useCompression) { + String aName = name; + while (true) { + int n = indexOfSeparator(aName); + if (n < 0) { + n = aName.length(); } - - void writeQuestion(DNSQuestion question) { - writeName(question.getName()); - writeShort(question.getRecordType().indexValue()); - writeShort(question.getRecordClass().indexValue()); + if (n <= 0) { + writeByte(0); + return; } - - void writeRecord(DNSRecord rec, long now) { - writeName(rec.getName()); - writeShort(rec.getRecordType().indexValue()); - writeShort(rec.getRecordClass().indexValue() | ((rec.isUnique() && _out.isMulticast()) ? DNSRecordClass.CLASS_UNIQUE : 0)); - writeInt((now == 0) ? rec.getTTL() : rec.getRemainingTTL(now)); - - // We need to take into account the 2 size bytes - MessageOutputStream record = new MessageOutputStream(512, _out, _offset + this.size() + 2); - rec.write(record); - byte[] byteArray = record.toByteArray(); - - writeShort(byteArray.length); - write(byteArray, 0, byteArray.length); + String label = aName.substring(0, n).replace("\\.", "."); + if (useCompression && USE_DOMAIN_NAME_COMPRESSION) { + Integer offset = _out._names.get(aName); + if (offset != null) { + int val = offset.intValue(); + writeByte((val >> 8) | 0xC0); + writeByte(val & 0xFF); + return; + } + _out._names.put(aName, Integer.valueOf(this.size() + _offset)); + writeUTF(label, 0, label.length()); + } else { + writeUTF(label, 0, label.length()); } - + aName = aName.substring(n); + if (aName.startsWith(".")) { + aName = aName.substring(1); + } + } } - /** - * This can be used to turn off domain name compression. This was helpful for tracking problems interacting with other mdns implementations. - */ - public static boolean USE_DOMAIN_NAME_COMPRESSION = true; - - Map _names; + private static int indexOfSeparator(String aName) { + int offset = 0; + int n = 0; - private int _maxUDPPayload; + while (true) { + n = aName.indexOf('.', offset); + if (n < 0) return -1; - private final MessageOutputStream _questionsBytes; + if (n == 0 || aName.charAt(n - 1) != '\\') return n; - private final MessageOutputStream _answersBytes; - - private final MessageOutputStream _authoritativeAnswersBytes; - - private final MessageOutputStream _additionalsAnswersBytes; - - private final static int HEADER_SIZE = 12; - - private InetSocketAddress _destination; - - /** - * Create an outgoing multicast query or response. - * - * @param flags - */ - public DNSOutgoing(int flags) { - this(flags, true, DNSConstants.MAX_MSG_TYPICAL); + offset = n + 1; + } } - /** - * Create an outgoing query or response. - * - * @param flags - * @param multicast - * @param senderUDPPayload - * The sender's UDP payload size is the number of bytes of the largest UDP payload that can be reassembled and delivered in the sender's network stack. - */ - public DNSOutgoing(int flags, boolean multicast, int senderUDPPayload) { - super(flags, 0, multicast); - _names = new HashMap(); - _maxUDPPayload = (senderUDPPayload > 0 ? senderUDPPayload : DNSConstants.MAX_MSG_TYPICAL); - _questionsBytes = new MessageOutputStream(senderUDPPayload, this); - _answersBytes = new MessageOutputStream(senderUDPPayload, this); - _authoritativeAnswersBytes = new MessageOutputStream(senderUDPPayload, this); - _additionalsAnswersBytes = new MessageOutputStream(senderUDPPayload, this); + void writeQuestion(DNSQuestion question) { + writeName(question.getName()); + writeShort(question.getRecordType().indexValue()); + writeShort(question.getRecordClass().indexValue()); } - /** - * Get the forced destination address if a specific one was set. - * - * @return a forced destination address or null if no address is forced. - */ - public InetSocketAddress getDestination() { - return _destination; + void writeRecord(DNSRecord rec, long now) { + writeName(rec.getName()); + writeShort(rec.getRecordType().indexValue()); + writeShort( + rec.getRecordClass().indexValue() + | ((rec.isUnique() && _out.isMulticast()) ? DNSRecordClass.CLASS_UNIQUE : 0)); + writeInt((now == 0) ? rec.getTTL() : rec.getRemainingTTL(now)); + + // We need to take into account the 2 size bytes + MessageOutputStream record = new MessageOutputStream(512, _out, _offset + this.size() + 2); + rec.write(record); + byte[] byteArray = record.toByteArray(); + + writeShort(byteArray.length); + write(byteArray, 0, byteArray.length); } - - /** - * Force a specific destination address if packet is sent. - * - * @param destination - * Set a destination address a packet should be sent to (instead the default one). You could use null to unset the forced destination. - */ - public void setDestination(InetSocketAddress destination) { - _destination = destination; - } - - /** - * Return the number of byte available in the message. - * - * @return available space - */ - public int availableSpace() { - return _maxUDPPayload - HEADER_SIZE - _questionsBytes.size() - _answersBytes.size() - _authoritativeAnswersBytes.size() - _additionalsAnswersBytes.size(); + } + + /** + * This can be used to turn off domain name compression. This was helpful for tracking problems + * interacting with other mdns implementations. + */ + public static boolean USE_DOMAIN_NAME_COMPRESSION = true; + + Map _names; + + private int _maxUDPPayload; + + private final MessageOutputStream _questionsBytes; + + private final MessageOutputStream _answersBytes; + + private final MessageOutputStream _authoritativeAnswersBytes; + + private final MessageOutputStream _additionalsAnswersBytes; + + private static final int HEADER_SIZE = 12; + + private InetSocketAddress _destination; + + /** + * Create an outgoing multicast query or response. + * + * @param flags + */ + public DNSOutgoing(int flags) { + this(flags, true, DNSConstants.MAX_MSG_TYPICAL); + } + + /** + * Create an outgoing query or response. + * + * @param flags + * @param multicast + * @param senderUDPPayload The sender's UDP payload size is the number of bytes of the largest UDP + * payload that can be reassembled and delivered in the sender's network stack. + */ + public DNSOutgoing(int flags, boolean multicast, int senderUDPPayload) { + super(flags, 0, multicast); + _names = new HashMap(); + _maxUDPPayload = (senderUDPPayload > 0 ? senderUDPPayload : DNSConstants.MAX_MSG_TYPICAL); + _questionsBytes = new MessageOutputStream(senderUDPPayload, this); + _answersBytes = new MessageOutputStream(senderUDPPayload, this); + _authoritativeAnswersBytes = new MessageOutputStream(senderUDPPayload, this); + _additionalsAnswersBytes = new MessageOutputStream(senderUDPPayload, this); + } + + /** + * Get the forced destination address if a specific one was set. + * + * @return a forced destination address or null if no address is forced. + */ + public InetSocketAddress getDestination() { + return _destination; + } + + /** + * Force a specific destination address if packet is sent. + * + * @param destination Set a destination address a packet should be sent to (instead the default + * one). You could use null to unset the forced destination. + */ + public void setDestination(InetSocketAddress destination) { + _destination = destination; + } + + /** + * Return the number of byte available in the message. + * + * @return available space + */ + public int availableSpace() { + return _maxUDPPayload + - HEADER_SIZE + - _questionsBytes.size() + - _answersBytes.size() + - _authoritativeAnswersBytes.size() + - _additionalsAnswersBytes.size(); + } + + /** + * Add a question to the message. + * + * @param rec + * @exception IOException + */ + public void addQuestion(DNSQuestion rec) throws IOException { + MessageOutputStream record = new MessageOutputStream(512, this); + record.writeQuestion(rec); + byte[] byteArray = record.toByteArray(); + record.close(); + if (byteArray.length < this.availableSpace()) { + _questions.add(rec); + _questionsBytes.write(byteArray, 0, byteArray.length); + } else { + throw new IOException("message full"); } - - /** - * Add a question to the message. - * - * @param rec - * @exception IOException - */ - public void addQuestion(DNSQuestion rec) throws IOException { + } + + /** + * Add an answer if it is not suppressed. + * + * @param rec + * @exception IOException + */ + public void addAnswer(DNSRecord rec) throws IOException { + this.addAnswer(rec, 0); + } + + /** + * Add an answer to the message. + * + * @param rec + * @param now + * @exception IOException + */ + public void addAnswer(DNSRecord rec, long now) throws IOException { + if (rec != null) { + if ((now == 0) || !rec.isExpired(now)) { MessageOutputStream record = new MessageOutputStream(512, this); - record.writeQuestion(rec); + record.writeRecord(rec, now); byte[] byteArray = record.toByteArray(); record.close(); if (byteArray.length < this.availableSpace()) { - _questions.add(rec); - _questionsBytes.write(byteArray, 0, byteArray.length); + _answers.add(rec); + _answersBytes.write(byteArray, 0, byteArray.length); } else { - throw new IOException("message full"); + throw new IOException("message full"); } + } } - - /** - * Add an answer if it is not suppressed. - * - * @param rec - * @exception IOException - */ - public void addAnswer(DNSRecord rec) throws IOException { - this.addAnswer(rec, 0); + } + + /** + * Builds the final message buffer to be send and returns it. + * + * @return bytes to send. + */ + public byte[] data() { + long now = System.currentTimeMillis(); // System.currentTimeMillis() + _names.clear(); + + MessageOutputStream message = new MessageOutputStream(_maxUDPPayload, this); + message.writeShort(_multicast ? 0 : this.getId()); + message.writeShort(this.getFlags()); + message.writeShort(this.getNumberOfQuestions()); + message.writeShort(this.getNumberOfAnswers()); + message.writeShort(this.getNumberOfAuthorities()); + message.writeShort(this.getNumberOfAdditionals()); + for (DNSQuestion question : _questions) { + message.writeQuestion(question); } - - /** - * Add an answer to the message. - * - * @param rec - * @param now - * @exception IOException - */ - public void addAnswer(DNSRecord rec, long now) throws IOException { - if (rec != null) { - if ((now == 0) || !rec.isExpired(now)) { - MessageOutputStream record = new MessageOutputStream(512, this); - record.writeRecord(rec, now); - byte[] byteArray = record.toByteArray(); - record.close(); - if (byteArray.length < this.availableSpace()) { - _answers.add(rec); - _answersBytes.write(byteArray, 0, byteArray.length); - } else { - throw new IOException("message full"); - } - } - } + for (DNSRecord record : _answers) { + message.writeRecord(record, now); } - - /** - * Builds the final message buffer to be send and returns it. - * - * @return bytes to send. - */ - public byte[] data() { - long now = System.currentTimeMillis(); // System.currentTimeMillis() - _names.clear(); - - MessageOutputStream message = new MessageOutputStream(_maxUDPPayload, this); - message.writeShort(_multicast ? 0 : this.getId()); - message.writeShort(this.getFlags()); - message.writeShort(this.getNumberOfQuestions()); - message.writeShort(this.getNumberOfAnswers()); - message.writeShort(this.getNumberOfAuthorities()); - message.writeShort(this.getNumberOfAdditionals()); - for (DNSQuestion question : _questions) { - message.writeQuestion(question); - } - for (DNSRecord record : _answers) { - message.writeRecord(record, now); - } - for (DNSRecord record : _authoritativeAnswers) { - message.writeRecord(record, now); - } - for (DNSRecord record : _additionals) { - message.writeRecord(record, now); - } - byte[] result = message.toByteArray(); - try { - message.close(); - } catch (IOException exception) {} - return result; + for (DNSRecord record : _authoritativeAnswers) { + message.writeRecord(record, now); } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(); - sb.append(isQuery() ? "dns[query:" : "dns[response:"); - sb.append(" id=0x"); - sb.append(Integer.toHexString(this.getId())); - if (this.getFlags() != 0) { - sb.append(", flags=0x"); - sb.append(Integer.toHexString(this.getFlags())); - if (this.isResponse()) { - sb.append(":r"); - } - if (this.isAuthoritativeAnswer()) { - sb.append(":aa"); - } - if (this.isTruncated()) { - sb.append(":tc"); - } - } - if (this.getNumberOfQuestions() > 0) { - sb.append(", questions="); - sb.append(this.getNumberOfQuestions()); - } - if (this.getNumberOfAnswers() > 0) { - sb.append(", answers="); - sb.append(this.getNumberOfAnswers()); - } - if (this.getNumberOfAuthorities() > 0) { - sb.append(", authorities="); - sb.append(this.getNumberOfAuthorities()); - } - if (this.getNumberOfAdditionals() > 0) { - sb.append(", additionals="); - sb.append(this.getNumberOfAdditionals()); - } - if (this.getNumberOfQuestions() > 0) { - sb.append("\nquestions:"); - for (DNSQuestion question : _questions) { - sb.append("\n\t"); - sb.append(question); - } - } - if (this.getNumberOfAnswers() > 0) { - sb.append("\nanswers:"); - for (DNSRecord record : _answers) { - sb.append("\n\t"); - sb.append(record); - } - } - if (this.getNumberOfAuthorities() > 0) { - sb.append("\nauthorities:"); - for (DNSRecord record : _authoritativeAnswers) { - sb.append("\n\t"); - sb.append(record); - } - } - if (this.getNumberOfAdditionals() > 0) { - sb.append("\nadditionals:"); - for (DNSRecord record : _additionals) { - sb.append("\n\t"); - sb.append(record); - } - } - sb.append("\nnames="); - sb.append(_names); - sb.append("]"); - return sb.toString(); + for (DNSRecord record : _additionals) { + message.writeRecord(record, now); } - - /** - * @return the maxUDPPayload - */ - public int getMaxUDPPayload() { - return this._maxUDPPayload; + byte[] result = message.toByteArray(); + try { + message.close(); + } catch (IOException exception) { } - + return result; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append(isQuery() ? "dns[query:" : "dns[response:"); + sb.append(" id=0x"); + sb.append(Integer.toHexString(this.getId())); + if (this.getFlags() != 0) { + sb.append(", flags=0x"); + sb.append(Integer.toHexString(this.getFlags())); + if (this.isResponse()) { + sb.append(":r"); + } + if (this.isAuthoritativeAnswer()) { + sb.append(":aa"); + } + if (this.isTruncated()) { + sb.append(":tc"); + } + } + if (this.getNumberOfQuestions() > 0) { + sb.append(", questions="); + sb.append(this.getNumberOfQuestions()); + } + if (this.getNumberOfAnswers() > 0) { + sb.append(", answers="); + sb.append(this.getNumberOfAnswers()); + } + if (this.getNumberOfAuthorities() > 0) { + sb.append(", authorities="); + sb.append(this.getNumberOfAuthorities()); + } + if (this.getNumberOfAdditionals() > 0) { + sb.append(", additionals="); + sb.append(this.getNumberOfAdditionals()); + } + if (this.getNumberOfQuestions() > 0) { + sb.append("\nquestions:"); + for (DNSQuestion question : _questions) { + sb.append("\n\t"); + sb.append(question); + } + } + if (this.getNumberOfAnswers() > 0) { + sb.append("\nanswers:"); + for (DNSRecord record : _answers) { + sb.append("\n\t"); + sb.append(record); + } + } + if (this.getNumberOfAuthorities() > 0) { + sb.append("\nauthorities:"); + for (DNSRecord record : _authoritativeAnswers) { + sb.append("\n\t"); + sb.append(record); + } + } + if (this.getNumberOfAdditionals() > 0) { + sb.append("\nadditionals:"); + for (DNSRecord record : _additionals) { + sb.append("\n\t"); + sb.append(record); + } + } + sb.append("\nnames="); + sb.append(_names); + sb.append("]"); + return sb.toString(); + } + + /** + * @return the maxUDPPayload + */ + public int getMaxUDPPayload() { + return this._maxUDPPayload; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSQuestion.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSQuestion.java index 789aaf418..ed87bb56f 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSQuestion.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSQuestion.java @@ -4,87 +4,89 @@ package io.libp2p.discovery.mdns.impl; -import java.util.Set; - +import io.libp2p.discovery.mdns.ServiceInfo; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.libp2p.discovery.mdns.ServiceInfo; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - /** * A DNS question. * * @author Arthur van Hoff, Pierre Frisch */ public class DNSQuestion extends DNSEntry { - private static Logger logger = LoggerFactory.getLogger(DNSQuestion.class.getName()); - - /** - * Pointer question. - */ - private static class Pointer extends DNSQuestion { - Pointer(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { - super(name, type, recordClass, unique); - } + private static Logger logger = LoggerFactory.getLogger(DNSQuestion.class.getName()); - @Override - public void addAnswers(JmDNSImpl jmDNSImpl, Set answers) { - // find matching services - for (ServiceInfo serviceInfo : jmDNSImpl.getServices().values()) { - this.addAnswersForServiceInfo(jmDNSImpl, answers, (ServiceInfoImpl) serviceInfo); - } - } + /** Pointer question. */ + private static class Pointer extends DNSQuestion { + Pointer(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { + super(name, type, recordClass, unique); } - DNSQuestion(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { - super(name, type, recordClass, unique); + @Override + public void addAnswers(JmDNSImpl jmDNSImpl, Set answers) { + // find matching services + for (ServiceInfo serviceInfo : jmDNSImpl.getServices().values()) { + this.addAnswersForServiceInfo(jmDNSImpl, answers, (ServiceInfoImpl) serviceInfo); + } } + } - /** - * Create a question. - * - * @param name - * DNS name to be resolved - * @param type - * Record type to resolve - * @param recordClass - * Record class to resolve - * @param unique - * Request unicast response (Currently not supported in this implementation) - * @return new question - */ - public static DNSQuestion newQuestion(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { - return (type == DNSRecordType.TYPE_PTR) - ? new Pointer(name, type, recordClass, unique) - : null; - } + DNSQuestion(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { + super(name, type, recordClass, unique); + } - /** - * Adds answers to the list for our question. - * - * @param jmDNSImpl - * DNS holding the records - * @param answers - * List of previous answer to append. - */ - public void addAnswers(JmDNSImpl jmDNSImpl, Set answers) { - // By default we do nothing - } + /** + * Create a question. + * + * @param name DNS name to be resolved + * @param type Record type to resolve + * @param recordClass Record class to resolve + * @param unique Request unicast response (Currently not supported in this implementation) + * @return new question + */ + public static DNSQuestion newQuestion( + String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique) { + return (type == DNSRecordType.TYPE_PTR) ? new Pointer(name, type, recordClass, unique) : null; + } - protected void addAnswersForServiceInfo(JmDNSImpl jmDNSImpl, Set answers, ServiceInfoImpl info) { - if (info != null) { - if (this.getName().equalsIgnoreCase(info.getQualifiedName()) || this.getName().equalsIgnoreCase(info.getType()) || this.getName().equalsIgnoreCase(info.getTypeWithSubtype())) { - answers.addAll(info.answers(this.getRecordClass(), DNSRecordClass.UNIQUE, DNSConstants.DNS_TTL, jmDNSImpl.getLocalHost())); - } - logger.debug("{} DNSQuestion({}).addAnswersForServiceInfo(): info: {}\n{}", jmDNSImpl.getName(), this.getName(), info, answers); - } - } + /** + * Adds answers to the list for our question. + * + * @param jmDNSImpl DNS holding the records + * @param answers List of previous answer to append. + */ + public void addAnswers(JmDNSImpl jmDNSImpl, Set answers) { + // By default we do nothing + } - @Override - public boolean isExpired(long now) { - return false; + protected void addAnswersForServiceInfo( + JmDNSImpl jmDNSImpl, Set answers, ServiceInfoImpl info) { + if (info != null) { + if (this.getName().equalsIgnoreCase(info.getQualifiedName()) + || this.getName().equalsIgnoreCase(info.getType()) + || this.getName().equalsIgnoreCase(info.getTypeWithSubtype())) { + answers.addAll( + info.answers( + this.getRecordClass(), + DNSRecordClass.UNIQUE, + DNSConstants.DNS_TTL, + jmDNSImpl.getLocalHost())); + } + logger.debug( + "{} DNSQuestion({}).addAnswersForServiceInfo(): info: {}\n{}", + jmDNSImpl.getName(), + this.getName(), + info, + answers); } -} \ No newline at end of file + } + + @Override + public boolean isExpired(long now) { + return false; + } +} diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSRecord.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSRecord.java index 6e75a6c47..6201af584 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSRecord.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/DNSRecord.java @@ -8,9 +8,6 @@ import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; import io.libp2p.discovery.mdns.impl.util.ByteWrangler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.DataOutputStream; import java.io.IOException; import java.io.UnsupportedEncodingException; @@ -18,7 +15,8 @@ import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Objects; - +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * DNS record @@ -26,358 +24,378 @@ * @author Arthur van Hoff, Rick Blair, Werner Randelshofer, Pierre Frisch */ public abstract class DNSRecord extends DNSEntry { - private static Logger logger = LoggerFactory.getLogger(DNSRecord.class.getName()); - - private int _ttl; - private long _created; - - /** - * Create a DNSRecord with a name, type, class, and ttl. - */ - DNSRecord(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique, int ttl) { - super(name, type, recordClass, unique); - this._ttl = ttl; - this._created = System.currentTimeMillis(); + private static Logger logger = LoggerFactory.getLogger(DNSRecord.class.getName()); + + private int _ttl; + private long _created; + + /** Create a DNSRecord with a name, type, class, and ttl. */ + DNSRecord(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique, int ttl) { + super(name, type, recordClass, unique); + this._ttl = ttl; + this._created = System.currentTimeMillis(); + } + + @Override + public boolean equals(Object other) { + return (other instanceof DNSRecord) && super.equals(other) && sameValue((DNSRecord) other); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), _ttl, _created); + } + + abstract boolean sameValue(DNSRecord other); + + /** Get the expiration time of this record. */ + long getExpirationTime(int percent) { + // ttl is in seconds the constant 10 is 1000 ms / 100 % + return _created + (percent * ((long) _ttl) * 10L); + } + + /** Get the remaining TTL for this record. */ + int getRemainingTTL(long now) { + return (int) Math.max(0, (getExpirationTime(100) - now) / 1000); + } + + @Override + public boolean isExpired(long now) { + return getExpirationTime(100) <= now; + } + + abstract void write(MessageOutputStream out); + + public static class IPv4Address extends Address { + IPv4Address( + String name, DNSRecordClass recordClass, boolean unique, int ttl, InetAddress addr) { + super(name, DNSRecordType.TYPE_A, recordClass, unique, ttl, addr); } - @Override - public boolean equals(Object other) { - return (other instanceof DNSRecord) && super.equals(other) && sameValue((DNSRecord) other); + IPv4Address( + String name, DNSRecordClass recordClass, boolean unique, int ttl, byte[] rawAddress) { + super(name, DNSRecordType.TYPE_A, recordClass, unique, ttl, rawAddress); } @Override - public int hashCode() { - return Objects.hash(super.hashCode(), _ttl, _created); + void write(MessageOutputStream out) { + if (_addr != null) { + byte[] buffer = _addr.getAddress(); + // If we have a type A records we should answer with a IPv4 address + if (_addr instanceof Inet4Address) { + // All is good + } else { + // Get the last four bytes + byte[] tempbuffer = buffer; + buffer = new byte[4]; + System.arraycopy(tempbuffer, 12, buffer, 0, 4); + } + int length = buffer.length; + out.writeBytes(buffer, 0, length); + } } + } - abstract boolean sameValue(DNSRecord other); - - /** - * Get the expiration time of this record. - */ - long getExpirationTime(int percent) { - // ttl is in seconds the constant 10 is 1000 ms / 100 % - return _created + (percent * ((long)_ttl) * 10L); + public static class IPv6Address extends Address { + IPv6Address( + String name, DNSRecordClass recordClass, boolean unique, int ttl, InetAddress addr) { + super(name, DNSRecordType.TYPE_AAAA, recordClass, unique, ttl, addr); } - /** - * Get the remaining TTL for this record. - */ - int getRemainingTTL(long now) { - return (int) Math.max(0, (getExpirationTime(100) - now) / 1000); + IPv6Address( + String name, DNSRecordClass recordClass, boolean unique, int ttl, byte[] rawAddress) { + super(name, DNSRecordType.TYPE_AAAA, recordClass, unique, ttl, rawAddress); } @Override - public boolean isExpired(long now) { - return getExpirationTime(100) <= now; - } - - abstract void write(MessageOutputStream out); - - public static class IPv4Address extends Address { - IPv4Address(String name, DNSRecordClass recordClass, boolean unique, int ttl, InetAddress addr) { - super(name, DNSRecordType.TYPE_A, recordClass, unique, ttl, addr); - } - - IPv4Address(String name, DNSRecordClass recordClass, boolean unique, int ttl, byte[] rawAddress) { - super(name, DNSRecordType.TYPE_A, recordClass, unique, ttl, rawAddress); - } - - @Override - void write(MessageOutputStream out) { - if (_addr != null) { - byte[] buffer = _addr.getAddress(); - // If we have a type A records we should answer with a IPv4 address - if (_addr instanceof Inet4Address) { - // All is good - } else { - // Get the last four bytes - byte[] tempbuffer = buffer; - buffer = new byte[4]; - System.arraycopy(tempbuffer, 12, buffer, 0, 4); - } - int length = buffer.length; - out.writeBytes(buffer, 0, length); + void write(MessageOutputStream out) { + if (_addr != null) { + byte[] buffer = _addr.getAddress(); + // If we have a type AAAA records we should answer with a IPv6 address + if (_addr instanceof Inet4Address) { + byte[] tempbuffer = buffer; + buffer = new byte[16]; + for (int i = 0; i < 16; i++) { + if (i < 11) { + buffer[i] = tempbuffer[i - 12]; + } else { + buffer[i] = 0; } + } } + int length = buffer.length; + out.writeBytes(buffer, 0, length); + } } - - public static class IPv6Address extends Address { - IPv6Address(String name, DNSRecordClass recordClass, boolean unique, int ttl, InetAddress addr) { - super(name, DNSRecordType.TYPE_AAAA, recordClass, unique, ttl, addr); - } - - IPv6Address(String name, DNSRecordClass recordClass, boolean unique, int ttl, byte[] rawAddress) { - super(name, DNSRecordType.TYPE_AAAA, recordClass, unique, ttl, rawAddress); - } - - @Override - void write(MessageOutputStream out) { - if (_addr != null) { - byte[] buffer = _addr.getAddress(); - // If we have a type AAAA records we should answer with a IPv6 address - if (_addr instanceof Inet4Address) { - byte[] tempbuffer = buffer; - buffer = new byte[16]; - for (int i = 0; i < 16; i++) { - if (i < 11) { - buffer[i] = tempbuffer[i - 12]; - } else { - buffer[i] = 0; - } - } - } - int length = buffer.length; - out.writeBytes(buffer, 0, length); - } - } + } + + /** Address record. */ + public abstract static class Address extends DNSRecord { + InetAddress _addr; + + protected Address( + String name, + DNSRecordType type, + DNSRecordClass recordClass, + boolean unique, + int ttl, + InetAddress addr) { + super(name, type, recordClass, unique, ttl); + this._addr = addr; } - /** - * Address record. - */ - public static abstract class Address extends DNSRecord { - InetAddress _addr; - - protected Address(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique, int ttl, InetAddress addr) { - super(name, type, recordClass, unique, ttl); - this._addr = addr; - } + protected Address( + String name, + DNSRecordType type, + DNSRecordClass recordClass, + boolean unique, + int ttl, + byte[] rawAddress) { + super(name, type, recordClass, unique, ttl); + try { + this._addr = InetAddress.getByAddress(rawAddress); + } catch (UnknownHostException exception) { + logger.warn("Address() exception ", exception); + } + } - protected Address(String name, DNSRecordType type, DNSRecordClass recordClass, boolean unique, int ttl, byte[] rawAddress) { - super(name, type, recordClass, unique, ttl); - try { - this._addr = InetAddress.getByAddress(rawAddress); - } catch (UnknownHostException exception) { - logger.warn("Address() exception ", exception); - } + @Override + boolean sameValue(DNSRecord other) { + try { + if (!(other instanceof Address)) { + return false; } - - @Override - boolean sameValue(DNSRecord other) { - try { - if (!(other instanceof Address)) { - return false; - } - Address address = (Address) other; - if ((this.getAddress() == null) && (address.getAddress() != null)) { - return false; - } - return this.getAddress().equals(address.getAddress()); - } catch (Exception e) { - logger.info("Failed to compare addresses of DNSRecords", e); - return false; - } - } - - public InetAddress getAddress() { - return _addr; + Address address = (Address) other; + if ((this.getAddress() == null) && (address.getAddress() != null)) { + return false; } + return this.getAddress().equals(address.getAddress()); + } catch (Exception e) { + logger.info("Failed to compare addresses of DNSRecords", e); + return false; + } + } - /** - * Creates a byte array representation of this record. This is needed for tie-break tests according to draft-cheshire-dnsext-multicastdns-04.txt chapter 9.2. - */ - @Override - protected void toByteArray(DataOutputStream dout) throws IOException { - super.toByteArray(dout); - byte[] buffer = this.getAddress().getAddress(); - for (int i = 0; i < buffer.length; i++) { - dout.writeByte(buffer[i]); - } - } + public InetAddress getAddress() { + return _addr; } /** - * Pointer record. + * Creates a byte array representation of this record. This is needed for tie-break tests + * according to draft-cheshire-dnsext-multicastdns-04.txt chapter 9.2. */ - public static class Pointer extends DNSRecord { - private final String _alias; - - public Pointer(String name, DNSRecordClass recordClass, boolean unique, int ttl, String alias) { - super(name, DNSRecordType.TYPE_PTR, recordClass, unique, ttl); - this._alias = alias; - } - - @Override - public boolean isSameEntry(DNSEntry entry) { - return super.isSameEntry(entry) && (entry instanceof Pointer) && this.sameValue((Pointer) entry); - } + @Override + protected void toByteArray(DataOutputStream dout) throws IOException { + super.toByteArray(dout); + byte[] buffer = this.getAddress().getAddress(); + for (int i = 0; i < buffer.length; i++) { + dout.writeByte(buffer[i]); + } + } + } - @Override - void write(MessageOutputStream out) { - out.writeName(_alias); - } + /** Pointer record. */ + public static class Pointer extends DNSRecord { + private final String _alias; - @Override - boolean sameValue(DNSRecord other) { - if (!(other instanceof Pointer)) { - return false; - } - Pointer pointer = (Pointer) other; - if ((_alias == null) && (pointer._alias != null)) { - return false; - } - return _alias.equals(pointer._alias); - } + public Pointer(String name, DNSRecordClass recordClass, boolean unique, int ttl, String alias) { + super(name, DNSRecordType.TYPE_PTR, recordClass, unique, ttl); + this._alias = alias; } - public static class Text extends DNSRecord { - private final byte[] _text; + @Override + public boolean isSameEntry(DNSEntry entry) { + return super.isSameEntry(entry) + && (entry instanceof Pointer) + && this.sameValue((Pointer) entry); + } - public Text(String name, DNSRecordClass recordClass, boolean unique, int ttl, byte text[]) { - super(name, DNSRecordType.TYPE_TXT, recordClass, unique, ttl); - this._text = (text != null && text.length > 0 ? text : ByteWrangler.EMPTY_TXT); - } + @Override + void write(MessageOutputStream out) { + out.writeName(_alias); + } - /** - * @return the text - */ - public byte[] getText() { - return this._text; - } + @Override + boolean sameValue(DNSRecord other) { + if (!(other instanceof Pointer)) { + return false; + } + Pointer pointer = (Pointer) other; + if ((_alias == null) && (pointer._alias != null)) { + return false; + } + return _alias.equals(pointer._alias); + } + } - @Override - void write(MessageOutputStream out) { - out.writeBytes(_text, 0, _text.length); - } + public static class Text extends DNSRecord { + private final byte[] _text; - @Override - boolean sameValue(DNSRecord other) { - if (!(other instanceof Text)) { - return false; - } - Text txt = (Text) other; - if ((_text == null) && (txt._text != null)) { - return false; - } - if (txt._text.length != _text.length) { - return false; - } - for (int i = _text.length; i-- > 0;) { - if (txt._text[i] != _text[i]) { - return false; - } - } - return true; - } + public Text(String name, DNSRecordClass recordClass, boolean unique, int ttl, byte text[]) { + super(name, DNSRecordType.TYPE_TXT, recordClass, unique, ttl); + this._text = (text != null && text.length > 0 ? text : ByteWrangler.EMPTY_TXT); } /** - * Service record. + * @return the text */ - public static class Service extends DNSRecord { - private final int _priority; - private final int _weight; - private final int _port; - private final String _server; - - public Service(String name, DNSRecordClass recordClass, boolean unique, int ttl, int priority, int weight, int port, String server) { - super(name, DNSRecordType.TYPE_SRV, recordClass, unique, ttl); - this._priority = priority; - this._weight = weight; - this._port = port; - this._server = server; - } + public byte[] getText() { + return this._text; + } - @Override - void write(MessageOutputStream out) { - out.writeShort(_priority); - out.writeShort(_weight); - out.writeShort(_port); - if (DNSIncoming.USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET) { - out.writeName(_server); - } else { - // [PJYF Nov 13 2010] Do we still need this? This looks really bad. All label are supposed to start by a length. - out.writeUTF(_server, 0, _server.length()); + @Override + void write(MessageOutputStream out) { + out.writeBytes(_text, 0, _text.length); + } - // add a zero byte to the end just to be safe, this is the strange form - // used by the BonjourConformanceTest - out.writeByte(0); - } + @Override + boolean sameValue(DNSRecord other) { + if (!(other instanceof Text)) { + return false; + } + Text txt = (Text) other; + if ((_text == null) && (txt._text != null)) { + return false; + } + if (txt._text.length != _text.length) { + return false; + } + for (int i = _text.length; i-- > 0; ) { + if (txt._text[i] != _text[i]) { + return false; } + } + return true; + } + } + + /** Service record. */ + public static class Service extends DNSRecord { + private final int _priority; + private final int _weight; + private final int _port; + private final String _server; + + public Service( + String name, + DNSRecordClass recordClass, + boolean unique, + int ttl, + int priority, + int weight, + int port, + String server) { + super(name, DNSRecordType.TYPE_SRV, recordClass, unique, ttl); + this._priority = priority; + this._weight = weight; + this._port = port; + this._server = server; + } - @Override - protected void toByteArray(DataOutputStream dout) throws IOException { - super.toByteArray(dout); - dout.writeShort(_priority); - dout.writeShort(_weight); - dout.writeShort(_port); - try { - dout.write(_server.getBytes("UTF-8")); - } catch (UnsupportedEncodingException exception) { - /* UTF-8 is always present */ - } - } + @Override + void write(MessageOutputStream out) { + out.writeShort(_priority); + out.writeShort(_weight); + out.writeShort(_port); + if (DNSIncoming.USE_DOMAIN_NAME_FORMAT_FOR_SRV_TARGET) { + out.writeName(_server); + } else { + // [PJYF Nov 13 2010] Do we still need this? This looks really bad. All label are supposed + // to start by a length. + out.writeUTF(_server, 0, _server.length()); + + // add a zero byte to the end just to be safe, this is the strange form + // used by the BonjourConformanceTest + out.writeByte(0); + } + } - /** - * @return the weight - */ - public int getWeight() { - return this._weight; - } + @Override + protected void toByteArray(DataOutputStream dout) throws IOException { + super.toByteArray(dout); + dout.writeShort(_priority); + dout.writeShort(_weight); + dout.writeShort(_port); + try { + dout.write(_server.getBytes("UTF-8")); + } catch (UnsupportedEncodingException exception) { + /* UTF-8 is always present */ + } + } - /** - * @return the port - */ - public int getPort() { - return this._port; - } + /** + * @return the weight + */ + public int getWeight() { + return this._weight; + } - @Override - boolean sameValue(DNSRecord other) { - if (!(other instanceof Service)) { - return false; - } - Service s = (Service) other; - return (_priority == s._priority) && (_weight == s._weight) && (_port == s._port) && _server.equals(s._server); - } + /** + * @return the port + */ + public int getPort() { + return this._port; } - public static class HostInformation extends DNSRecord { - String _os; - String _cpu; - - /** - * @param name - * @param recordClass - * @param unique - * @param ttl - * @param cpu - * @param os - */ - public HostInformation(String name, DNSRecordClass recordClass, boolean unique, int ttl, String cpu, String os) { - super(name, DNSRecordType.TYPE_HINFO, recordClass, unique, ttl); - _cpu = cpu; - _os = os; - } + @Override + boolean sameValue(DNSRecord other) { + if (!(other instanceof Service)) { + return false; + } + Service s = (Service) other; + return (_priority == s._priority) + && (_weight == s._weight) + && (_port == s._port) + && _server.equals(s._server); + } + } - /* - * (non-Javadoc) - * @see javax.jmdns.impl.DNSRecord#sameValue(javax.jmdns.impl.DNSRecord) - */ - @Override - boolean sameValue(DNSRecord other) { - if (!(other instanceof HostInformation)) { - return false; - } - HostInformation hinfo = (HostInformation) other; - if ((_cpu == null) && (hinfo._cpu != null)) { - return false; - } - if ((_os == null) && (hinfo._os != null)) { - return false; - } - return _cpu.equals(hinfo._cpu) && _os.equals(hinfo._os); - } + public static class HostInformation extends DNSRecord { + String _os; + String _cpu; - @Override - void write(MessageOutputStream out) { - String hostInfo = _cpu + " " + _os; - out.writeUTF(hostInfo, 0, hostInfo.length()); - } + /** + * @param name + * @param recordClass + * @param unique + * @param ttl + * @param cpu + * @param os + */ + public HostInformation( + String name, DNSRecordClass recordClass, boolean unique, int ttl, String cpu, String os) { + super(name, DNSRecordType.TYPE_HINFO, recordClass, unique, ttl); + _cpu = cpu; + _os = os; } - public int getTTL() { - return _ttl; + /* + * (non-Javadoc) + * @see javax.jmdns.impl.DNSRecord#sameValue(javax.jmdns.impl.DNSRecord) + */ + @Override + boolean sameValue(DNSRecord other) { + if (!(other instanceof HostInformation)) { + return false; + } + HostInformation hinfo = (HostInformation) other; + if ((_cpu == null) && (hinfo._cpu != null)) { + return false; + } + if ((_os == null) && (hinfo._os != null)) { + return false; + } + return _cpu.equals(hinfo._cpu) && _os.equals(hinfo._os); + } + + @Override + void write(MessageOutputStream out) { + String hostInfo = _cpu + " " + _os; + out.writeUTF(hostInfo, 0, hostInfo.length()); } + } + + public int getTTL() { + return _ttl; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/HostInfo.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/HostInfo.java index c93d6d983..3db3ca6b6 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/HostInfo.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/HostInfo.java @@ -7,7 +7,6 @@ import java.io.IOException; import java.net.*; import java.util.*; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -17,185 +16,192 @@ * @author Pierre Frisch, Werner Randelshofer */ public class HostInfo { - private static Logger logger = LoggerFactory.getLogger(HostInfo.class.getName()); - - protected String _name; - - protected InetAddress _address; - - protected NetworkInterface _interface; - - /** - * @param address - * IP address to bind - * @param dns - * JmDNS instance - * @param jmdnsName - * JmDNS name - * @return new HostInfo - */ - public static HostInfo newHostInfo(InetAddress address, JmDNSImpl dns, String jmdnsName) { - HostInfo localhost = null; - String aName = (jmdnsName != null ? jmdnsName : ""); - InetAddress addr = address; - try { - if (addr == null) { - String ip = System.getProperty("net.mdns.interface"); - if (ip != null) { - addr = InetAddress.getByName(ip); - } else { - addr = InetAddress.getLocalHost(); - if (addr.isLoopbackAddress()) { - // Find local address that isn't a loopback address - InetAddress[] addresses = getInetAddresses(); - if (addresses.length > 0) { - addr = addresses[0]; - } - } - } - if (addr.isLoopbackAddress()) { - logger.warn("Could not find any address beside the loopback."); - } - } - if (aName.length() == 0) { - aName = addr.getHostName(); + private static Logger logger = LoggerFactory.getLogger(HostInfo.class.getName()); + + protected String _name; + + protected InetAddress _address; + + protected NetworkInterface _interface; + + /** + * @param address IP address to bind + * @param dns JmDNS instance + * @param jmdnsName JmDNS name + * @return new HostInfo + */ + public static HostInfo newHostInfo(InetAddress address, JmDNSImpl dns, String jmdnsName) { + HostInfo localhost = null; + String aName = (jmdnsName != null ? jmdnsName : ""); + InetAddress addr = address; + try { + if (addr == null) { + String ip = System.getProperty("net.mdns.interface"); + if (ip != null) { + addr = InetAddress.getByName(ip); + } else { + addr = InetAddress.getLocalHost(); + if (addr.isLoopbackAddress()) { + // Find local address that isn't a loopback address + InetAddress[] addresses = getInetAddresses(); + if (addresses.length > 0) { + addr = addresses[0]; } - if (aName.contains("in-addr.arpa") || (aName.equals(addr.getHostAddress()))) { - aName = ((jmdnsName != null) && (jmdnsName.length() > 0) ? jmdnsName : addr.getHostAddress()); - } - } catch (final IOException e) { - logger.warn("Could not initialize the host network interface on " + address + "because of an error: " + e.getMessage(), e); - // This is only used for running unit test on Debian / Ubuntu - addr = loopbackAddress(); - aName = ((jmdnsName != null) && (jmdnsName.length() > 0) ? jmdnsName : "computer"); + } } - // A host name with "." is illegal. so strip off everything and append .local. - // We also need to be carefull that the .local may already be there - int index = aName.indexOf(".local"); - if (index > 0) { - aName = aName.substring(0, index); + if (addr.isLoopbackAddress()) { + logger.warn("Could not find any address beside the loopback."); } - aName = aName.replaceAll("[:%\\.]", "-"); - aName += ".local."; - localhost = new HostInfo(addr, aName, dns); - return localhost; + } + if (aName.length() == 0) { + aName = addr.getHostName(); + } + if (aName.contains("in-addr.arpa") || (aName.equals(addr.getHostAddress()))) { + aName = + ((jmdnsName != null) && (jmdnsName.length() > 0) ? jmdnsName : addr.getHostAddress()); + } + } catch (final IOException e) { + logger.warn( + "Could not initialize the host network interface on " + + address + + "because of an error: " + + e.getMessage(), + e); + // This is only used for running unit test on Debian / Ubuntu + addr = loopbackAddress(); + aName = ((jmdnsName != null) && (jmdnsName.length() > 0) ? jmdnsName : "computer"); } - - private static InetAddress[] getInetAddresses() { - Set result = new HashSet(); - try { - - for (Enumeration nifs = NetworkInterface.getNetworkInterfaces(); nifs.hasMoreElements();) { - NetworkInterface nif = nifs.nextElement(); - if (useInterface(nif)) { - for (Enumeration iaenum = nif.getInetAddresses(); iaenum.hasMoreElements();) { - InetAddress interfaceAddress = iaenum.nextElement(); - logger.trace("Found NetworkInterface/InetAddress: {} -- {}", nif , interfaceAddress); - result.add(interfaceAddress); - } - } - } - } catch (SocketException se) { - logger.warn("Error while fetching network interfaces addresses: " + se); - } - return result.toArray(new InetAddress[result.size()]); + // A host name with "." is illegal. so strip off everything and append .local. + // We also need to be carefull that the .local may already be there + int index = aName.indexOf(".local"); + if (index > 0) { + aName = aName.substring(0, index); } - - private static boolean useInterface(NetworkInterface networkInterface) { - try { - if (!networkInterface.isUp()) { - return false; - } - - if (!networkInterface.supportsMulticast()) { - return false; - } - - if (networkInterface.isLoopback()) { - return false; - } - - return true; - } catch (Exception exception) { - return false; + aName = aName.replaceAll("[:%\\.]", "-"); + aName += ".local."; + localhost = new HostInfo(addr, aName, dns); + return localhost; + } + + private static InetAddress[] getInetAddresses() { + Set result = new HashSet(); + try { + + for (Enumeration nifs = NetworkInterface.getNetworkInterfaces(); + nifs.hasMoreElements(); ) { + NetworkInterface nif = nifs.nextElement(); + if (useInterface(nif)) { + for (Enumeration iaenum = nif.getInetAddresses(); + iaenum.hasMoreElements(); ) { + InetAddress interfaceAddress = iaenum.nextElement(); + logger.trace("Found NetworkInterface/InetAddress: {} -- {}", nif, interfaceAddress); + result.add(interfaceAddress); + } } + } + } catch (SocketException se) { + logger.warn("Error while fetching network interfaces addresses: " + se); } - - private static InetAddress loopbackAddress() { - try { - return InetAddress.getByName(null); - } catch (UnknownHostException exception) { - return null; - } + return result.toArray(new InetAddress[result.size()]); + } + + private static boolean useInterface(NetworkInterface networkInterface) { + try { + if (!networkInterface.isUp()) { + return false; + } + + if (!networkInterface.supportsMulticast()) { + return false; + } + + if (networkInterface.isLoopback()) { + return false; + } + + return true; + } catch (Exception exception) { + return false; } + } - private HostInfo(final InetAddress address, final String name, final JmDNSImpl dns) { - super(); - this._address = address; - this._name = name; - if (address != null) { - try { - _interface = NetworkInterface.getByInetAddress(address); - } catch (Exception exception) { - logger.warn("LocalHostInfo() exception ", exception); - } - } + private static InetAddress loopbackAddress() { + try { + return InetAddress.getByName(null); + } catch (UnknownHostException exception) { + return null; } - - public String getName() { - return _name; + } + + private HostInfo(final InetAddress address, final String name, final JmDNSImpl dns) { + super(); + this._address = address; + this._name = name; + if (address != null) { + try { + _interface = NetworkInterface.getByInetAddress(address); + } catch (Exception exception) { + logger.warn("LocalHostInfo() exception ", exception); + } } - - public InetAddress getInetAddress() { - return _address; - } - - public NetworkInterface getInterface() { - return _interface; - } - - boolean shouldIgnorePacket(DatagramPacket packet) { - boolean result = false; - if (this.getInetAddress() != null) { - InetAddress from = packet.getAddress(); - if (from != null) { - if ((this.getInetAddress().isLinkLocalAddress() || this.getInetAddress().isMCLinkLocal()) && (!from.isLinkLocalAddress())) { - // A host sending Multicast DNS queries to a link-local destination - // address (including the 224.0.0.251 and FF02::FB link-local multicast - // addresses) MUST only accept responses to that query that originate - // from the local link, and silently discard any other response packets. - // Without this check, it could be possible for remote rogue hosts to - // send spoof answer packets (perhaps unicast to the victim host) which - // the receiving machine could misinterpret as having originated on the - // local link. - result = true; - } - // if (from.isLinkLocalAddress() && (!this.getInetAddress().isLinkLocalAddress())) { - // // Ignore linklocal packets on regular interfaces, unless this is - // // also a linklocal interface. This is to avoid duplicates. This is - // // a terrible hack caused by the lack of an API to get the address - // // of the interface on which the packet was received. - // result = true; - // } - if (from.isLoopbackAddress() && (!this.getInetAddress().isLoopbackAddress())) { - // Ignore loopback packets on a regular interface unless this is also a loopback interface. - result = true; - } - } + } + + public String getName() { + return _name; + } + + public InetAddress getInetAddress() { + return _address; + } + + public NetworkInterface getInterface() { + return _interface; + } + + boolean shouldIgnorePacket(DatagramPacket packet) { + boolean result = false; + if (this.getInetAddress() != null) { + InetAddress from = packet.getAddress(); + if (from != null) { + if ((this.getInetAddress().isLinkLocalAddress() || this.getInetAddress().isMCLinkLocal()) + && (!from.isLinkLocalAddress())) { + // A host sending Multicast DNS queries to a link-local destination + // address (including the 224.0.0.251 and FF02::FB link-local multicast + // addresses) MUST only accept responses to that query that originate + // from the local link, and silently discard any other response packets. + // Without this check, it could be possible for remote rogue hosts to + // send spoof answer packets (perhaps unicast to the victim host) which + // the receiving machine could misinterpret as having originated on the + // local link. + result = true; } - return result; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(1024); - sb.append("local host info["); - sb.append(getName() != null ? getName() : "no name"); - sb.append(", "); - sb.append(getInterface() != null ? getInterface().getDisplayName() : "???"); - sb.append(":"); - sb.append(getInetAddress() != null ? getInetAddress().getHostAddress() : "no address"); - sb.append("]"); - return sb.toString(); + // if (from.isLinkLocalAddress() && (!this.getInetAddress().isLinkLocalAddress())) { + // // Ignore linklocal packets on regular interfaces, unless this is + // // also a linklocal interface. This is to avoid duplicates. This is + // // a terrible hack caused by the lack of an API to get the address + // // of the interface on which the packet was received. + // result = true; + // } + if (from.isLoopbackAddress() && (!this.getInetAddress().isLoopbackAddress())) { + // Ignore loopback packets on a regular interface unless this is also a loopback + // interface. + result = true; + } + } } + return result; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(1024); + sb.append("local host info["); + sb.append(getName() != null ? getName() : "no name"); + sb.append(", "); + sb.append(getInterface() != null ? getInterface().getDisplayName() : "???"); + sb.append(":"); + sb.append(getInetAddress() != null ? getInetAddress().getHostAddress() : "no address"); + sb.append("]"); + return sb.toString(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/JmDNSImpl.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/JmDNSImpl.java index d050053c4..82d9e4b13 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/JmDNSImpl.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/JmDNSImpl.java @@ -12,9 +12,6 @@ import io.libp2p.discovery.mdns.impl.tasks.Responder; import io.libp2p.discovery.mdns.impl.tasks.ServiceResolver; import io.libp2p.discovery.mdns.impl.util.NamedThreadFactory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.net.DatagramPacket; import java.net.Inet6Address; @@ -39,431 +36,426 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.ReentrantLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Derived from mDNS implementation in Java. * - * @author Arthur van Hoff, Rick Blair, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Scott Lewis, Kai Kreuzer, Victor Toni + * @author Arthur van Hoff, Rick Blair, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Scott + * Lewis, Kai Kreuzer, Victor Toni */ public class JmDNSImpl extends JmDNS { - private static Logger logger = LoggerFactory.getLogger(JmDNSImpl.class.getName()); - - /** - * This is the multicast group, we are listening to for multicast DNS messages. - */ - private volatile InetAddress _group; - /** - * This is our multicast socket. - */ - private volatile MulticastSocket _socket; - - private final ConcurrentMap> _answerListeners; - private final ConcurrentMap _serviceResolvers; - - /** - * This hashtable holds the services that have been registered. Keys are instances of String which hold an all lower-case version of the fully qualified service name. Values are instances of ServiceInfo. - */ - private final ConcurrentMap _services; - - /** - * Handle on the local host - */ - private HostInfo _localHost; - - private SocketListener _incomingListener; - - private final ExecutorService _executor = Executors.newSingleThreadExecutor(new NamedThreadFactory("JmDNS")); - - /** - * The source for random values. This is used to introduce random delays in responses. This reduces the potential for collisions on the network. - */ - private final static Random _random = new Random(); - - /** - * This lock is used to coordinate processing of incoming and outgoing messages. This is needed, because the Rendezvous Conformance Test does not forgive race conditions. - */ - private final ReentrantLock _ioLock = new ReentrantLock(); - - private final String _name; - - /** - * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. - * - * @param address IP address to bind to. - * @param name name of the newly created JmDNS - * @throws IOException - */ - public JmDNSImpl(InetAddress address, String name) { - super(); - logger.debug("JmDNS instance created"); - - _answerListeners = new ConcurrentHashMap<>(); - _serviceResolvers = new ConcurrentHashMap<>(); - - _services = new ConcurrentHashMap<>(20); - - _localHost = HostInfo.newHostInfo(address, this, name); - _name = (name != null ? name : _localHost.getName()); - } - - public void start() throws IOException { - // Bind to multicast socket - this.openMulticastSocket(this.getLocalHost()); - this.start(this.getServices().values()); + private static Logger logger = LoggerFactory.getLogger(JmDNSImpl.class.getName()); + + /** This is the multicast group, we are listening to for multicast DNS messages. */ + private volatile InetAddress _group; + + /** This is our multicast socket. */ + private volatile MulticastSocket _socket; + + private final ConcurrentMap> _answerListeners; + private final ConcurrentMap _serviceResolvers; + + /** + * This hashtable holds the services that have been registered. Keys are instances of String which + * hold an all lower-case version of the fully qualified service name. Values are instances of + * ServiceInfo. + */ + private final ConcurrentMap _services; + + /** Handle on the local host */ + private HostInfo _localHost; + + private SocketListener _incomingListener; + + private final ExecutorService _executor = + Executors.newSingleThreadExecutor(new NamedThreadFactory("JmDNS")); + + /** + * The source for random values. This is used to introduce random delays in responses. This + * reduces the potential for collisions on the network. + */ + private static final Random _random = new Random(); + + /** + * This lock is used to coordinate processing of incoming and outgoing messages. This is needed, + * because the Rendezvous Conformance Test does not forgive race conditions. + */ + private final ReentrantLock _ioLock = new ReentrantLock(); + + private final String _name; + + /** + * Create an instance of JmDNS and bind it to a specific network interface given its IP-address. + * + * @param address IP address to bind to. + * @param name name of the newly created JmDNS + * @throws IOException + */ + public JmDNSImpl(InetAddress address, String name) { + super(); + logger.debug("JmDNS instance created"); + + _answerListeners = new ConcurrentHashMap<>(); + _serviceResolvers = new ConcurrentHashMap<>(); + + _services = new ConcurrentHashMap<>(20); + + _localHost = HostInfo.newHostInfo(address, this, name); + _name = (name != null ? name : _localHost.getName()); + } + + public void start() throws IOException { + // Bind to multicast socket + this.openMulticastSocket(this.getLocalHost()); + this.start(this.getServices().values()); + } + + private void start(Collection serviceInfos) { + if (_incomingListener == null) { + _incomingListener = new SocketListener(this); + _incomingListener.start(); } - - private void start(Collection serviceInfos) { - if (_incomingListener == null) { - _incomingListener = new SocketListener(this); - _incomingListener.start(); - } - for (ServiceInfo info : serviceInfos) { - try { - this.registerService(new ServiceInfoImpl(info)); - } catch (final Exception exception) { - logger.warn("start() Registration exception ", exception); - } - } - } - - private void openMulticastSocket(HostInfo hostInfo) throws IOException { - if (_group == null) { - if (hostInfo.getInetAddress() instanceof Inet6Address) { - _group = InetAddress.getByName(DNSConstants.MDNS_GROUP_IPV6); - } else { - _group = InetAddress.getByName(DNSConstants.MDNS_GROUP); - } - } - if (_socket != null) { - this.closeMulticastSocket(); - } - // SocketAddress address = new InetSocketAddress((hostInfo != null ? hostInfo.getInetAddress() : null), DNSConstants.MDNS_PORT); - // System.out.println("Socket Address: " + address); - // try { - // _socket = new MulticastSocket(address); - // } catch (Exception exception) { - // logger.warn("openMulticastSocket() Open socket exception Address: " + address + ", ", exception); - // // The most likely cause is a duplicate address lets open without specifying the address - // _socket = new MulticastSocket(DNSConstants.MDNS_PORT); - // } - _socket = new MulticastSocket(DNSConstants.MDNS_PORT); - if ((hostInfo != null) && (hostInfo.getInterface() != null)) { - final SocketAddress multicastAddr = new InetSocketAddress(_group, DNSConstants.MDNS_PORT); - _socket.setNetworkInterface(hostInfo.getInterface()); - - logger.trace("Trying to joinGroup({}, {})", multicastAddr, hostInfo.getInterface()); - - // this joinGroup() might be less surprisingly so this is the default - _socket.joinGroup(multicastAddr, hostInfo.getInterface()); - } else { - logger.trace("Trying to joinGroup({})", _group); - _socket.joinGroup(_group); - } - - _socket.setTimeToLive(255); + for (ServiceInfo info : serviceInfos) { + try { + this.registerService(new ServiceInfoImpl(info)); + } catch (final Exception exception) { + logger.warn("start() Registration exception ", exception); + } } - - private void closeMulticastSocket() { - // jP: 20010-01-18. See below. We'll need this monitor... - // assert (Thread.holdsLock(this)); - logger.debug("closeMulticastSocket()"); - if (_socket != null) { - // close socket - try { - try { - _socket.leaveGroup(_group); - } catch (SocketException exception) { - // - } - _socket.close(); - } catch (final Exception exception) { - logger.warn("closeMulticastSocket() Close socket exception ", exception); - } - _socket = null; - } + } + + private void openMulticastSocket(HostInfo hostInfo) throws IOException { + if (_group == null) { + if (hostInfo.getInetAddress() instanceof Inet6Address) { + _group = InetAddress.getByName(DNSConstants.MDNS_GROUP_IPV6); + } else { + _group = InetAddress.getByName(DNSConstants.MDNS_GROUP); + } } - - /** - * {@inheritDoc} - */ - @Override - public String getName() { - return _name; + if (_socket != null) { + this.closeMulticastSocket(); } - - /** - * Returns the local host info - * - * @return local host info - */ - public HostInfo getLocalHost() { - return _localHost; + // SocketAddress address = new InetSocketAddress((hostInfo != null ? hostInfo.getInetAddress() : + // null), DNSConstants.MDNS_PORT); + // System.out.println("Socket Address: " + address); + // try { + // _socket = new MulticastSocket(address); + // } catch (Exception exception) { + // logger.warn("openMulticastSocket() Open socket exception Address: " + address + ", ", + // exception); + // // The most likely cause is a duplicate address lets open without specifying the address + // _socket = new MulticastSocket(DNSConstants.MDNS_PORT); + // } + _socket = new MulticastSocket(DNSConstants.MDNS_PORT); + if ((hostInfo != null) && (hostInfo.getInterface() != null)) { + final SocketAddress multicastAddr = new InetSocketAddress(_group, DNSConstants.MDNS_PORT); + _socket.setNetworkInterface(hostInfo.getInterface()); + + logger.trace("Trying to joinGroup({}, {})", multicastAddr, hostInfo.getInterface()); + + // this joinGroup() might be less surprisingly so this is the default + _socket.joinGroup(multicastAddr, hostInfo.getInterface()); + } else { + logger.trace("Trying to joinGroup({})", _group); + _socket.joinGroup(_group); } - void handleServiceAnswers(List answers) { - DNSRecord ptr = answers.get(0); - if (!DNSRecordType.TYPE_PTR.equals(ptr.getRecordType())) - return; - List list = _answerListeners.get(ptr.getKey()); - - if ((list != null) && (!list.isEmpty())) { - final List listCopy; - synchronized (list) { - listCopy = new ArrayList<>(list); - } - for (final AnswerListener listener : listCopy) { - _executor.submit(new Runnable() { - @Override - public void run() { - listener.answersReceived(answers); - } - }); - } - } - } + _socket.setTimeToLive(255); + } - @Override - public void addAnswerListener(String type, int queryInterval, AnswerListener listener) { - final String loType = type.toLowerCase(); - List list = _answerListeners.get(loType); - if (list == null) { - _answerListeners.putIfAbsent(loType, new LinkedList<>()); - list = _answerListeners.get(loType); - } - if (list != null) { - synchronized (list) { - if (!list.contains(listener)) { - list.add(listener); - } - } + private void closeMulticastSocket() { + // jP: 20010-01-18. See below. We'll need this monitor... + // assert (Thread.holdsLock(this)); + logger.debug("closeMulticastSocket()"); + if (_socket != null) { + // close socket + try { + try { + _socket.leaveGroup(_group); + } catch (SocketException exception) { + // } - - startServiceResolver(loType, queryInterval); + _socket.close(); + } catch (final Exception exception) { + logger.warn("closeMulticastSocket() Close socket exception ", exception); + } + _socket = null; } - - /** - * {@inheritDoc} - */ - @Override - public void registerService(ServiceInfo infoAbstract) throws IOException { - final ServiceInfoImpl info = (ServiceInfoImpl) infoAbstract; - - info.setServer(_localHost.getName()); - - _services.putIfAbsent(info.getKey(), info); - - logger.debug("registerService() JmDNS registered service as {}", info); + } + + /** {@inheritDoc} */ + @Override + public String getName() { + return _name; + } + + /** + * Returns the local host info + * + * @return local host info + */ + public HostInfo getLocalHost() { + return _localHost; + } + + void handleServiceAnswers(List answers) { + DNSRecord ptr = answers.get(0); + if (!DNSRecordType.TYPE_PTR.equals(ptr.getRecordType())) return; + List list = _answerListeners.get(ptr.getKey()); + + if ((list != null) && (!list.isEmpty())) { + final List listCopy; + synchronized (list) { + listCopy = new ArrayList<>(list); + } + for (final AnswerListener listener : listCopy) { + _executor.submit( + new Runnable() { + @Override + public void run() { + listener.answersReceived(answers); + } + }); + } } - - /** - * Handle an incoming response. Cache answers, and pass them on to the appropriate questions. - * - * @throws IOException - */ - void handleResponse(DNSIncoming msg) throws IOException { - List allAnswers = msg.getAllAnswers(); - allAnswers = aRecordsLast(allAnswers); - - handleServiceAnswers(allAnswers); + } + + @Override + public void addAnswerListener(String type, int queryInterval, AnswerListener listener) { + final String loType = type.toLowerCase(); + List list = _answerListeners.get(loType); + if (list == null) { + _answerListeners.putIfAbsent(loType, new LinkedList<>()); + list = _answerListeners.get(loType); } - - /** - * In case the a record is received before the srv record the ip address would not be set. - *

- * Multicast Domain Name System (response) - * Transaction ID: 0x0000 - * Flags: 0x8400 Standard query response, No error - * Questions: 0 - * Answer RRs: 2 - * Authority RRs: 0 - * Additional RRs: 8 - * Answers - * _ibisip_http._tcp.local: type PTR, class IN, DeviceManagementService._ibisip_http._tcp.local - * _ibisip_http._tcp.local: type PTR, class IN, PassengerCountingService._ibisip_http._tcp.local - * Additional records - * DeviceManagementService._ibisip_http._tcp.local: type TXT, class IN, cache flush - * PassengerCountingService._ibisip_http._tcp.local: type TXT, class IN, cache flush - * DIST500_7-F07_OC030_05_03941.local: type A, class IN, cache flush, addr 192.168.88.236 - * DeviceManagementService._ibisip_http._tcp.local: type SRV, class IN, cache flush, priority 0, weight 0, port 5000, target DIST500_7-F07_OC030_05_03941.local - * PassengerCountingService._ibisip_http._tcp.local: type SRV, class IN, cache flush, priority 0, weight 0, port 5001, target DIST500_7-F07_OC030_05_03941.local - * DeviceManagementService._ibisip_http._tcp.local: type NSEC, class IN, cache flush, next domain name DeviceManagementService._ibisip_http._tcp.local - * PassengerCountingService._ibisip_http._tcp.local: type NSEC, class IN, cache flush, next domain name PassengerCountingService._ibisip_http._tcp.local - * DIST500_7-F07_OC030_05_03941.local: type NSEC, class IN, cache flush, next domain name DIST500_7-F07_OC030_05_03941.local - */ - private List aRecordsLast(List allAnswers) { - ArrayList ret = new ArrayList(allAnswers.size()); - ArrayList arecords = new ArrayList(); - - for (DNSRecord answer : allAnswers) { - DNSRecordType type = answer.getRecordType(); - if (type.equals(DNSRecordType.TYPE_A) || type.equals(DNSRecordType.TYPE_AAAA)) { - arecords.add(answer); - } else if (type.equals(DNSRecordType.TYPE_PTR)) { - ret.add(0, answer); - } else { - ret.add(answer); - } + if (list != null) { + synchronized (list) { + if (!list.contains(listener)) { + list.add(listener); } - ret.addAll(arecords); - return ret; + } } - - /** - * Handle an incoming query. See if we can answer any part of it given our service infos. - * - * @param in - * @param addr - * @param port - * @throws IOException - */ - void handleQuery(DNSIncoming in, InetAddress addr, int port) throws IOException { - logger.debug("{} handle query: {}", this.getName(), in); - this.ioLock(); - try { - DNSIncoming plannedAnswer = in.clone(); - this.startResponder(plannedAnswer, addr, port); - } finally { - this.ioUnlock(); - } + startServiceResolver(loType, queryInterval); + } + + /** {@inheritDoc} */ + @Override + public void registerService(ServiceInfo infoAbstract) throws IOException { + final ServiceInfoImpl info = (ServiceInfoImpl) infoAbstract; + + info.setServer(_localHost.getName()); + + _services.putIfAbsent(info.getKey(), info); + + logger.debug("registerService() JmDNS registered service as {}", info); + } + + /** + * Handle an incoming response. Cache answers, and pass them on to the appropriate questions. + * + * @throws IOException + */ + void handleResponse(DNSIncoming msg) throws IOException { + List allAnswers = msg.getAllAnswers(); + allAnswers = aRecordsLast(allAnswers); + + handleServiceAnswers(allAnswers); + } + + /** + * In case the a record is received before the srv record the ip address would not be set. + * + *

Multicast Domain Name System (response) Transaction ID: 0x0000 Flags: 0x8400 Standard query + * response, No error Questions: 0 Answer RRs: 2 Authority RRs: 0 Additional RRs: 8 Answers + * _ibisip_http._tcp.local: type PTR, class IN, DeviceManagementService._ibisip_http._tcp.local + * _ibisip_http._tcp.local: type PTR, class IN, PassengerCountingService._ibisip_http._tcp.local + * Additional records DeviceManagementService._ibisip_http._tcp.local: type TXT, class IN, cache + * flush PassengerCountingService._ibisip_http._tcp.local: type TXT, class IN, cache flush + * DIST500_7-F07_OC030_05_03941.local: type A, class IN, cache flush, addr 192.168.88.236 + * DeviceManagementService._ibisip_http._tcp.local: type SRV, class IN, cache flush, priority 0, + * weight 0, port 5000, target DIST500_7-F07_OC030_05_03941.local + * PassengerCountingService._ibisip_http._tcp.local: type SRV, class IN, cache flush, priority 0, + * weight 0, port 5001, target DIST500_7-F07_OC030_05_03941.local + * DeviceManagementService._ibisip_http._tcp.local: type NSEC, class IN, cache flush, next domain + * name DeviceManagementService._ibisip_http._tcp.local + * PassengerCountingService._ibisip_http._tcp.local: type NSEC, class IN, cache flush, next domain + * name PassengerCountingService._ibisip_http._tcp.local DIST500_7-F07_OC030_05_03941.local: type + * NSEC, class IN, cache flush, next domain name DIST500_7-F07_OC030_05_03941.local + */ + private List aRecordsLast(List allAnswers) { + ArrayList ret = new ArrayList(allAnswers.size()); + ArrayList arecords = new ArrayList(); + + for (DNSRecord answer : allAnswers) { + DNSRecordType type = answer.getRecordType(); + if (type.equals(DNSRecordType.TYPE_A) || type.equals(DNSRecordType.TYPE_AAAA)) { + arecords.add(answer); + } else if (type.equals(DNSRecordType.TYPE_PTR)) { + ret.add(0, answer); + } else { + ret.add(answer); + } } - - /** - * Send an outgoing multicast DNS message. - * - * @param out - * @throws IOException - */ - public void send(DNSOutgoing out) throws IOException { - if (!out.isEmpty()) { - final InetAddress addr; - final int port; - - if (out.getDestination() != null) { - addr = out.getDestination().getAddress(); - port = out.getDestination().getPort(); - } else { - addr = _group; - port = DNSConstants.MDNS_PORT; - } - - byte[] message = out.data(); - final DatagramPacket packet = new DatagramPacket(message, message.length, addr, port); - - if (logger.isTraceEnabled()) { - try { - final DNSIncoming msg = new DNSIncoming(packet); - if (logger.isTraceEnabled()) { - logger.trace("send({}) JmDNS out:{}", this.getName(), msg.print(true)); - } - } catch (final IOException e) { - logger.debug(getClass().toString(), ".send(" + this.getName() + ") - JmDNS can not parse what it sends!!!", e); - } - } - final MulticastSocket ms = _socket; - if (ms != null && !ms.isClosed()) { - ms.send(packet); - } - } + ret.addAll(arecords); + return ret; + } + + /** + * Handle an incoming query. See if we can answer any part of it given our service infos. + * + * @param in + * @param addr + * @param port + * @throws IOException + */ + void handleQuery(DNSIncoming in, InetAddress addr, int port) throws IOException { + logger.debug("{} handle query: {}", this.getName(), in); + this.ioLock(); + try { + DNSIncoming plannedAnswer = in.clone(); + this.startResponder(plannedAnswer, addr, port); + } finally { + this.ioUnlock(); } - - private void startServiceResolver(String type, int queryInterval) { - if (_serviceResolvers.containsKey(type)) - return; - - ServiceResolver resolver = new ServiceResolver(this, type, queryInterval); - if (_serviceResolvers.putIfAbsent(type, resolver) == null) - resolver.start(); + } + + /** + * Send an outgoing multicast DNS message. + * + * @param out + * @throws IOException + */ + public void send(DNSOutgoing out) throws IOException { + if (!out.isEmpty()) { + final InetAddress addr; + final int port; + + if (out.getDestination() != null) { + addr = out.getDestination().getAddress(); + port = out.getDestination().getPort(); + } else { + addr = _group; + port = DNSConstants.MDNS_PORT; + } + + byte[] message = out.data(); + final DatagramPacket packet = new DatagramPacket(message, message.length, addr, port); + + if (logger.isTraceEnabled()) { + try { + final DNSIncoming msg = new DNSIncoming(packet); + if (logger.isTraceEnabled()) { + logger.trace("send({}) JmDNS out:{}", this.getName(), msg.print(true)); + } + } catch (final IOException e) { + logger.debug( + getClass().toString(), + ".send(" + this.getName() + ") - JmDNS can not parse what it sends!!!", + e); + } + } + final MulticastSocket ms = _socket; + if (ms != null && !ms.isClosed()) { + ms.send(packet); + } } + } - private void startResponder(DNSIncoming in, InetAddress addr, int port) { - new Responder(this, in, addr, port).start(); - } + private void startServiceResolver(String type, int queryInterval) { + if (_serviceResolvers.containsKey(type)) return; - public void stop() { - logger.debug("Stopping JmDNS: {}", this); + ServiceResolver resolver = new ServiceResolver(this, type, queryInterval); + if (_serviceResolvers.putIfAbsent(type, resolver) == null) resolver.start(); + } - List> shutdowns = new ArrayList<>(); + private void startResponder(DNSIncoming in, InetAddress addr, int port) { + new Responder(this, in, addr, port).start(); + } - shutdowns.add(_incomingListener.stop()); - _incomingListener = null; + public void stop() { + logger.debug("Stopping JmDNS: {}", this); - for (ServiceResolver resolver : _serviceResolvers.values()) - shutdowns.add(resolver.stop()); + List> shutdowns = new ArrayList<>(); - // close socket - this.closeMulticastSocket(); + shutdowns.add(_incomingListener.stop()); + _incomingListener = null; - logger.debug("JmDNS waiting for service stop..."); + for (ServiceResolver resolver : _serviceResolvers.values()) shutdowns.add(resolver.stop()); - for (Future shutdown : shutdowns) { - try { - shutdown.get(10, TimeUnit.SECONDS); - } catch (CancellationException e) { - logger.trace("Task was already cancelled", e); - } catch (InterruptedException e) { - logger.trace("Stopping was interrupted", e); - Thread.currentThread().interrupt(); - } catch (ExecutionException | TimeoutException e) { - logger.debug("Exception when stopping JmDNS: ", e); - throw new RuntimeException(e); - } - } + // close socket + this.closeMulticastSocket(); - _executor.shutdown(); + logger.debug("JmDNS waiting for service stop..."); - logger.debug("JmDNS stopped."); + for (Future shutdown : shutdowns) { + try { + shutdown.get(10, TimeUnit.SECONDS); + } catch (CancellationException e) { + logger.trace("Task was already cancelled", e); + } catch (InterruptedException e) { + logger.trace("Stopping was interrupted", e); + Thread.currentThread().interrupt(); + } catch (ExecutionException | TimeoutException e) { + logger.debug("Exception when stopping JmDNS: ", e); + throw new RuntimeException(e); + } } - /** - * {@inheritDoc} - */ - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(2048); - sb.append("\n"); - sb.append("\t---- Local Host -----"); - sb.append("\n\t"); - sb.append(_localHost); - sb.append("\n\t---- Services -----"); - for (final Map.Entry entry : _services.entrySet()) { - sb.append("\n\t\tService: "); - sb.append(entry.getKey()); - sb.append(": "); - sb.append(entry.getValue()); - } - sb.append("\n"); - sb.append("\t---- Answer Listeners ----"); - for (final Map.Entry> entry : _answerListeners.entrySet()) { - sb.append("\n\t\tAnswer Listener: "); - sb.append(entry.getKey()); - sb.append(": "); - sb.append(entry.getValue()); - } - return sb.toString(); + _executor.shutdown(); + + logger.debug("JmDNS stopped."); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(2048); + sb.append("\n"); + sb.append("\t---- Local Host -----"); + sb.append("\n\t"); + sb.append(_localHost); + sb.append("\n\t---- Services -----"); + for (final Map.Entry entry : _services.entrySet()) { + sb.append("\n\t\tService: "); + sb.append(entry.getKey()); + sb.append(": "); + sb.append(entry.getValue()); } - - public Map getServices() { - return _services; + sb.append("\n"); + sb.append("\t---- Answer Listeners ----"); + for (final Map.Entry> entry : _answerListeners.entrySet()) { + sb.append("\n\t\tAnswer Listener: "); + sb.append(entry.getKey()); + sb.append(": "); + sb.append(entry.getValue()); } + return sb.toString(); + } - public static Random getRandom() { - return _random; - } + public Map getServices() { + return _services; + } - private void ioLock() { - _ioLock.lock(); - } + public static Random getRandom() { + return _random; + } - private void ioUnlock() { - _ioLock.unlock(); - } + private void ioLock() { + _ioLock.lock(); + } - public MulticastSocket getSocket() { - return _socket; - } + private void ioUnlock() { + _ioLock.unlock(); + } - public InetAddress getGroup() { - return _group; - } + public MulticastSocket getSocket() { + return _socket; + } + + public InetAddress getGroup() { + return _group; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/ServiceInfoImpl.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/ServiceInfoImpl.java index cc0c94dbf..54efc0ef7 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/ServiceInfoImpl.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/ServiceInfoImpl.java @@ -7,9 +7,6 @@ import io.libp2p.discovery.mdns.ServiceInfo; import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; import io.libp2p.discovery.mdns.impl.util.ByteWrangler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.net.Inet4Address; import java.net.Inet6Address; @@ -23,6 +20,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * JmDNS service information. @@ -30,435 +29,488 @@ * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Victor Toni */ public class ServiceInfoImpl extends ServiceInfo { - private static Logger logger = LoggerFactory.getLogger(ServiceInfoImpl.class.getName()); - - private String _domain; - private String _protocol; - private String _application; - private String _name; - private String _subtype; - private String _server; - private int _port; - private int _weight; - private int _priority; - private byte[] _text; - private final Set _ipv4Addresses; - private final Set _ipv6Addresses; - - private transient String _key; - - /** - * @param type - * @param name - * @param subtype - * @param port - * @param weight - * @param priority - * @param text - */ - public ServiceInfoImpl(String type, String name, String subtype, int port, int weight, int priority, String text) { - this(ServiceInfoImpl.decodeQualifiedNameMap(type, name, subtype), port, weight, priority, (byte[]) null); - - try { - this._text = ByteWrangler.encodeText(text); - } catch (final IOException e) { - throw new RuntimeException("Unexpected exception: " + e); + private static Logger logger = LoggerFactory.getLogger(ServiceInfoImpl.class.getName()); + + private String _domain; + private String _protocol; + private String _application; + private String _name; + private String _subtype; + private String _server; + private int _port; + private int _weight; + private int _priority; + private byte[] _text; + private final Set _ipv4Addresses; + private final Set _ipv6Addresses; + + private transient String _key; + + /** + * @param type + * @param name + * @param subtype + * @param port + * @param weight + * @param priority + * @param text + */ + public ServiceInfoImpl( + String type, String name, String subtype, int port, int weight, int priority, String text) { + this( + ServiceInfoImpl.decodeQualifiedNameMap(type, name, subtype), + port, + weight, + priority, + (byte[]) null); + + try { + this._text = ByteWrangler.encodeText(text); + } catch (final IOException e) { + throw new RuntimeException("Unexpected exception: " + e); + } + + _server = text; + } + + ServiceInfoImpl( + Map qualifiedNameMap, int port, int weight, int priority, byte text[]) { + Map map = ServiceInfoImpl.checkQualifiedNameMap(qualifiedNameMap); + + this._domain = map.get(Fields.Domain); + this._protocol = map.get(Fields.Protocol); + this._application = map.get(Fields.Application); + this._name = map.get(Fields.Instance); + this._subtype = map.get(Fields.Subtype); + + this._port = port; + this._weight = weight; + this._priority = priority; + this._text = text; + this._ipv4Addresses = Collections.synchronizedSet(new LinkedHashSet()); + this._ipv6Addresses = Collections.synchronizedSet(new LinkedHashSet()); + } + + ServiceInfoImpl(ServiceInfo info) { + this._ipv4Addresses = Collections.synchronizedSet(new LinkedHashSet()); + this._ipv6Addresses = Collections.synchronizedSet(new LinkedHashSet()); + if (info != null) { + this._domain = info.getDomain(); + this._protocol = info.getProtocol(); + this._application = info.getApplication(); + this._name = info.getName(); + this._subtype = info.getSubtype(); + this._port = info.getPort(); + this._weight = info.getWeight(); + this._priority = info.getPriority(); + this._text = info.getTextBytes(); + Inet6Address[] ipv6Addresses = info.getInet6Addresses(); + for (Inet6Address address : ipv6Addresses) { + this._ipv6Addresses.add(address); + } + Inet4Address[] ipv4Addresses = info.getInet4Addresses(); + for (Inet4Address address : ipv4Addresses) { + this._ipv4Addresses.add(address); + } + } + } + + public static Map decodeQualifiedNameMap( + String type, String name, String subtype) { + Map qualifiedNameMap = decodeQualifiedNameMapForType(type); + + qualifiedNameMap.put(Fields.Instance, name); + qualifiedNameMap.put(Fields.Subtype, subtype); + + return checkQualifiedNameMap(qualifiedNameMap); + } + + public static Map decodeQualifiedNameMapForType(String type) { + int index; + + String casePreservedType = type; + + String aType = type.toLowerCase(); + String application = aType; + String protocol = ""; + String subtype = ""; + String name = ""; + String domain = ""; + + if (aType.contains("in-addr.arpa") || aType.contains("ip6.arpa")) { + index = + (aType.contains("in-addr.arpa") + ? aType.indexOf("in-addr.arpa") + : aType.indexOf("ip6.arpa")); + name = removeSeparators(casePreservedType.substring(0, index)); + domain = casePreservedType.substring(index); + application = ""; + } else if ((!aType.contains("_")) && aType.contains(".")) { + index = aType.indexOf('.'); + name = removeSeparators(casePreservedType.substring(0, index)); + domain = removeSeparators(casePreservedType.substring(index)); + application = ""; + } else { + // First remove the name if it there. + if (!aType.startsWith("_") || aType.startsWith("_services")) { + index = aType.indexOf("._"); + if (index > 0) { + // We need to preserve the case for the user readable name. + name = casePreservedType.substring(0, index); + if (index + 1 < aType.length()) { + aType = aType.substring(index + 1); + casePreservedType = casePreservedType.substring(index + 1); + } } - - _server = text; - } - - ServiceInfoImpl(Map qualifiedNameMap, int port, int weight, int priority, byte text[]) { - Map map = ServiceInfoImpl.checkQualifiedNameMap(qualifiedNameMap); - - this._domain = map.get(Fields.Domain); - this._protocol = map.get(Fields.Protocol); - this._application = map.get(Fields.Application); - this._name = map.get(Fields.Instance); - this._subtype = map.get(Fields.Subtype); - - this._port = port; - this._weight = weight; - this._priority = priority; - this._text = text; - this._ipv4Addresses = Collections.synchronizedSet(new LinkedHashSet()); - this._ipv6Addresses = Collections.synchronizedSet(new LinkedHashSet()); - } - - ServiceInfoImpl(ServiceInfo info) { - this._ipv4Addresses = Collections.synchronizedSet(new LinkedHashSet()); - this._ipv6Addresses = Collections.synchronizedSet(new LinkedHashSet()); - if (info != null) { - this._domain = info.getDomain(); - this._protocol = info.getProtocol(); - this._application = info.getApplication(); - this._name = info.getName(); - this._subtype = info.getSubtype(); - this._port = info.getPort(); - this._weight = info.getWeight(); - this._priority = info.getPriority(); - this._text = info.getTextBytes(); - Inet6Address[] ipv6Addresses = info.getInet6Addresses(); - for (Inet6Address address : ipv6Addresses) { - this._ipv6Addresses.add(address); - } - Inet4Address[] ipv4Addresses = info.getInet4Addresses(); - for (Inet4Address address : ipv4Addresses) { - this._ipv4Addresses.add(address); - } + } + + index = aType.lastIndexOf("._"); + if (index > 0) { + int start = index + 2; + int end = aType.indexOf('.', start); + protocol = casePreservedType.substring(start, end); + } + if (protocol.length() > 0) { + index = aType.indexOf("_" + protocol.toLowerCase() + "."); + int start = index + protocol.length() + 2; + int end = aType.length() - (aType.endsWith(".") ? 1 : 0); + if (end > start) { + domain = casePreservedType.substring(start, end); } - } - - public static Map decodeQualifiedNameMap(String type, String name, String subtype) { - Map qualifiedNameMap = decodeQualifiedNameMapForType(type); - - qualifiedNameMap.put(Fields.Instance, name); - qualifiedNameMap.put(Fields.Subtype, subtype); - - return checkQualifiedNameMap(qualifiedNameMap); - } - - public static Map decodeQualifiedNameMapForType(String type) { - int index; - - String casePreservedType = type; - - String aType = type.toLowerCase(); - String application = aType; - String protocol = ""; - String subtype = ""; - String name = ""; - String domain = ""; - - if (aType.contains("in-addr.arpa") || aType.contains("ip6.arpa")) { - index = (aType.contains("in-addr.arpa") ? aType.indexOf("in-addr.arpa") : aType.indexOf("ip6.arpa")); - name = removeSeparators(casePreservedType.substring(0, index)); - domain = casePreservedType.substring(index); - application = ""; - } else if ((!aType.contains("_")) && aType.contains(".")) { - index = aType.indexOf('.'); - name = removeSeparators(casePreservedType.substring(0, index)); - domain = removeSeparators(casePreservedType.substring(index)); - application = ""; + if (index > 0) { + application = casePreservedType.substring(0, index - 1); } else { - // First remove the name if it there. - if (!aType.startsWith("_") || aType.startsWith("_services")) { - index = aType.indexOf("._"); - if (index > 0) { - // We need to preserve the case for the user readable name. - name = casePreservedType.substring(0, index); - if (index + 1 < aType.length()) { - aType = aType.substring(index + 1); - casePreservedType = casePreservedType.substring(index + 1); - } - } - } - - index = aType.lastIndexOf("._"); - if (index > 0) { - int start = index + 2; - int end = aType.indexOf('.', start); - protocol = casePreservedType.substring(start, end); - } - if (protocol.length() > 0) { - index = aType.indexOf("_" + protocol.toLowerCase() + "."); - int start = index + protocol.length() + 2; - int end = aType.length() - (aType.endsWith(".") ? 1 : 0); - if (end > start) { - domain = casePreservedType.substring(start, end); - } - if (index > 0) { - application = casePreservedType.substring(0, index - 1); - } else { - application = ""; - } - } - index = application.toLowerCase().indexOf("._sub"); - if (index > 0) { - int start = index + 5; - subtype = removeSeparators(application.substring(0, index)); - application = application.substring(start); - } - } - - final Map qualifiedNameMap = new HashMap(5); - qualifiedNameMap.put(Fields.Domain, removeSeparators(domain)); - qualifiedNameMap.put(Fields.Protocol, protocol); - qualifiedNameMap.put(Fields.Application, removeSeparators(application)); - qualifiedNameMap.put(Fields.Instance, name); - qualifiedNameMap.put(Fields.Subtype, subtype); - - return qualifiedNameMap; - } - - protected static Map checkQualifiedNameMap(Map qualifiedNameMap) { - Map checkedQualifiedNameMap = new HashMap(5); - - // Optional domain - String domain = (qualifiedNameMap.containsKey(Fields.Domain) ? qualifiedNameMap.get(Fields.Domain) : "local"); - if ((domain == null) || (domain.length() == 0)) { - domain = "local"; - } - domain = removeSeparators(domain); - checkedQualifiedNameMap.put(Fields.Domain, domain); - // Optional protocol - String protocol = (qualifiedNameMap.containsKey(Fields.Protocol) ? qualifiedNameMap.get(Fields.Protocol) : "tcp"); - if ((protocol == null) || (protocol.length() == 0)) { - protocol = "tcp"; - } - protocol = removeSeparators(protocol); - checkedQualifiedNameMap.put(Fields.Protocol, protocol); - // Application - String application = (qualifiedNameMap.containsKey(Fields.Application) ? qualifiedNameMap.get(Fields.Application) : ""); - if ((application == null) || (application.length() == 0)) { - application = ""; + application = ""; } - application = removeSeparators(application); - checkedQualifiedNameMap.put(Fields.Application, application); - // Instance - String instance = (qualifiedNameMap.containsKey(Fields.Instance) ? qualifiedNameMap.get(Fields.Instance) : ""); - if ((instance == null) || (instance.length() == 0)) { - instance = ""; - // throw new IllegalArgumentException("The instance name component of a fully qualified service cannot be empty."); - } - instance = removeSeparators(instance); - checkedQualifiedNameMap.put(Fields.Instance, instance); - // Optional Subtype - String subtype = (qualifiedNameMap.containsKey(Fields.Subtype) ? qualifiedNameMap.get(Fields.Subtype) : ""); - if ((subtype == null) || (subtype.length() == 0)) { - subtype = ""; - } - subtype = removeSeparators(subtype); - checkedQualifiedNameMap.put(Fields.Subtype, subtype); - - return checkedQualifiedNameMap; - } - - private static String removeSeparators(String name) { - if (name == null) { - return ""; - } - String newName = name.trim(); - if (newName.startsWith(".")) { - newName = newName.substring(1); - } - if (newName.startsWith("_")) { - newName = newName.substring(1); - } - if (newName.endsWith(".")) { - newName = newName.substring(0, newName.length() - 1); - } - return newName; - } - - @Override - public String getType() { - String domain = this.getDomain(); - String protocol = this.getProtocol(); - String application = this.getApplication(); - return (application.length() > 0 ? "_" + application + "." : "") + (protocol.length() > 0 ? "_" + protocol + "." : "") + domain + "."; - } - - @Override - public String getTypeWithSubtype() { - String subtype = this.getSubtype(); - return (subtype.length() > 0 ? "_" + subtype + "._sub." : "") + this.getType(); - } - - @Override - public String getName() { - return (_name != null ? _name : ""); - } - - @Override - public String getKey() { - if (this._key == null) { - this._key = this.getQualifiedName().toLowerCase(); - } - return this._key; - } - - @Override - public String getQualifiedName() { - String domain = this.getDomain(); - String protocol = this.getProtocol(); - String application = this.getApplication(); - String instance = this.getName(); - return (instance.length() > 0 ? instance + "." : "") + (application.length() > 0 ? "_" + application + "." : "") + (protocol.length() > 0 ? "_" + protocol + "." : "") + domain + "."; - } - - @Override - public String getServer() { - return (_server != null ? _server : ""); - } - void setServer(String server) { - this._server = server; - } - - public void addAddress(Inet4Address addr) { - _ipv4Addresses.add(addr); - } - public void addAddress(Inet6Address addr) { - _ipv6Addresses.add(addr); - } - - @Override - public Inet4Address[] getInet4Addresses() { - return _ipv4Addresses.toArray(new Inet4Address[_ipv4Addresses.size()]); - } - - @Override - public Inet6Address[] getInet6Addresses() { - return _ipv6Addresses.toArray(new Inet6Address[_ipv6Addresses.size()]); - } - - @Override - public int getPort() { - return _port; - } - - @Override - public int getPriority() { - return _priority; - } - - @Override - public int getWeight() { - return _weight; - } - - @Override - public byte[] getTextBytes() { - return (this._text != null && this._text.length > 0 ? this._text : ByteWrangler.EMPTY_TXT); - } - - @Override - public String getApplication() { - return (_application != null ? _application : ""); - } - - @Override - public String getDomain() { - return (_domain != null ? _domain : "local"); - } - - @Override - public String getProtocol() { - return (_protocol != null ? _protocol : "tcp"); - } - - @Override - public String getSubtype() { - return (_subtype != null ? _subtype : ""); - } - - @Override - public Map getQualifiedNameMap() { - Map map = new HashMap(5); - - map.put(Fields.Domain, this.getDomain()); - map.put(Fields.Protocol, this.getProtocol()); - map.put(Fields.Application, this.getApplication()); - map.put(Fields.Instance, this.getName()); - map.put(Fields.Subtype, this.getSubtype()); - return map; - } - - @Override - public synchronized boolean hasData() { - return this.getServer() != null && this.hasInetAddress() && this.getTextBytes() != null && this.getTextBytes().length > 0; - } - - private boolean hasInetAddress() { - return _ipv4Addresses.size() > 0 || _ipv6Addresses.size() > 0; - } - - @Override - public int hashCode() { - return getQualifiedName().hashCode(); - } - - @Override - public boolean equals(Object obj) { - return (obj instanceof ServiceInfoImpl) && getQualifiedName().equals(((ServiceInfoImpl) obj).getQualifiedName()); - } - - @Override - public ServiceInfoImpl clone() { - ServiceInfoImpl serviceInfo = new ServiceInfoImpl(this.getQualifiedNameMap(), _port, _weight, _priority, _text); - serviceInfo._ipv6Addresses.addAll(Arrays.asList(getInet6Addresses())); - serviceInfo._ipv4Addresses.addAll(Arrays.asList(getInet4Addresses())); - serviceInfo._server = _server; - return serviceInfo; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(); - sb.append('[').append(this.getClass().getSimpleName()).append('@').append(System.identityHashCode(this)); - sb.append(" name: '"); - if (0 < this.getName().length()) { - sb.append(this.getName()).append('.'); - } - sb.append(this.getTypeWithSubtype()); - sb.append("' address: '"); - Inet4Address[] addresses4 = this.getInet4Addresses(); - for (InetAddress address : addresses4) { - sb.append(address).append(':').append(this.getPort()).append(' '); - } - Inet6Address[] addresses6 = this.getInet6Addresses(); - for (InetAddress address : addresses6) { - sb.append(address).append(':').append(this.getPort()).append(' '); - } - - sb.append(this.hasData() ? " has data" : " has NO data"); - sb.append(']'); - - return sb.toString(); - } - - /** - * Create a series of answer that correspond with the give service info. - * - * @param recordClass - * record class of the query - * @param unique - * @param ttl - * @param localHost - * @return collection of answers - */ - public Collection answers(DNSRecordClass recordClass, boolean unique, int ttl, HostInfo localHost) { - List list = new ArrayList(); - - if ((recordClass == DNSRecordClass.CLASS_ANY) || (recordClass == DNSRecordClass.CLASS_IN)) { - if (this.getSubtype().length() > 0) { - list.add(new DNSRecord.Pointer(this.getTypeWithSubtype(), DNSRecordClass.CLASS_IN, DNSRecordClass.NOT_UNIQUE, ttl, this.getQualifiedName())); - } - list.add(new DNSRecord.Pointer(this.getType(), DNSRecordClass.CLASS_IN, DNSRecordClass.NOT_UNIQUE, ttl, this.getQualifiedName())); - list.add(new DNSRecord.Service(this.getQualifiedName(), DNSRecordClass.CLASS_IN, unique, ttl, _priority, _weight, _port, localHost.getName())); - list.add(new DNSRecord.Text(this.getQualifiedName(), DNSRecordClass.CLASS_IN, unique, ttl, this.getTextBytes())); - for (InetAddress address : _ipv4Addresses) - list.add( - new DNSRecord.IPv4Address( - this.getQualifiedName(), - DNSRecordClass.CLASS_IN, - unique, - ttl, - address - ) - ); - for (InetAddress address : _ipv6Addresses) - list.add( - new DNSRecord.IPv6Address( - this.getQualifiedName(), - DNSRecordClass.CLASS_IN, - unique, - ttl, - address - ) - ); - } - - return list; - } + } + index = application.toLowerCase().indexOf("._sub"); + if (index > 0) { + int start = index + 5; + subtype = removeSeparators(application.substring(0, index)); + application = application.substring(start); + } + } + + final Map qualifiedNameMap = new HashMap(5); + qualifiedNameMap.put(Fields.Domain, removeSeparators(domain)); + qualifiedNameMap.put(Fields.Protocol, protocol); + qualifiedNameMap.put(Fields.Application, removeSeparators(application)); + qualifiedNameMap.put(Fields.Instance, name); + qualifiedNameMap.put(Fields.Subtype, subtype); + + return qualifiedNameMap; + } + + protected static Map checkQualifiedNameMap(Map qualifiedNameMap) { + Map checkedQualifiedNameMap = new HashMap(5); + + // Optional domain + String domain = + (qualifiedNameMap.containsKey(Fields.Domain) + ? qualifiedNameMap.get(Fields.Domain) + : "local"); + if ((domain == null) || (domain.length() == 0)) { + domain = "local"; + } + domain = removeSeparators(domain); + checkedQualifiedNameMap.put(Fields.Domain, domain); + // Optional protocol + String protocol = + (qualifiedNameMap.containsKey(Fields.Protocol) + ? qualifiedNameMap.get(Fields.Protocol) + : "tcp"); + if ((protocol == null) || (protocol.length() == 0)) { + protocol = "tcp"; + } + protocol = removeSeparators(protocol); + checkedQualifiedNameMap.put(Fields.Protocol, protocol); + // Application + String application = + (qualifiedNameMap.containsKey(Fields.Application) + ? qualifiedNameMap.get(Fields.Application) + : ""); + if ((application == null) || (application.length() == 0)) { + application = ""; + } + application = removeSeparators(application); + checkedQualifiedNameMap.put(Fields.Application, application); + // Instance + String instance = + (qualifiedNameMap.containsKey(Fields.Instance) + ? qualifiedNameMap.get(Fields.Instance) + : ""); + if ((instance == null) || (instance.length() == 0)) { + instance = ""; + // throw new IllegalArgumentException("The instance name component of a fully qualified + // service cannot be empty."); + } + instance = removeSeparators(instance); + checkedQualifiedNameMap.put(Fields.Instance, instance); + // Optional Subtype + String subtype = + (qualifiedNameMap.containsKey(Fields.Subtype) ? qualifiedNameMap.get(Fields.Subtype) : ""); + if ((subtype == null) || (subtype.length() == 0)) { + subtype = ""; + } + subtype = removeSeparators(subtype); + checkedQualifiedNameMap.put(Fields.Subtype, subtype); + + return checkedQualifiedNameMap; + } + + private static String removeSeparators(String name) { + if (name == null) { + return ""; + } + String newName = name.trim(); + if (newName.startsWith(".")) { + newName = newName.substring(1); + } + if (newName.startsWith("_")) { + newName = newName.substring(1); + } + if (newName.endsWith(".")) { + newName = newName.substring(0, newName.length() - 1); + } + return newName; + } + + @Override + public String getType() { + String domain = this.getDomain(); + String protocol = this.getProtocol(); + String application = this.getApplication(); + return (application.length() > 0 ? "_" + application + "." : "") + + (protocol.length() > 0 ? "_" + protocol + "." : "") + + domain + + "."; + } + + @Override + public String getTypeWithSubtype() { + String subtype = this.getSubtype(); + return (subtype.length() > 0 ? "_" + subtype + "._sub." : "") + this.getType(); + } + + @Override + public String getName() { + return (_name != null ? _name : ""); + } + + @Override + public String getKey() { + if (this._key == null) { + this._key = this.getQualifiedName().toLowerCase(); + } + return this._key; + } + + @Override + public String getQualifiedName() { + String domain = this.getDomain(); + String protocol = this.getProtocol(); + String application = this.getApplication(); + String instance = this.getName(); + return (instance.length() > 0 ? instance + "." : "") + + (application.length() > 0 ? "_" + application + "." : "") + + (protocol.length() > 0 ? "_" + protocol + "." : "") + + domain + + "."; + } + + @Override + public String getServer() { + return (_server != null ? _server : ""); + } + + void setServer(String server) { + this._server = server; + } + + public void addAddress(Inet4Address addr) { + _ipv4Addresses.add(addr); + } + + public void addAddress(Inet6Address addr) { + _ipv6Addresses.add(addr); + } + + @Override + public Inet4Address[] getInet4Addresses() { + return _ipv4Addresses.toArray(new Inet4Address[_ipv4Addresses.size()]); + } + + @Override + public Inet6Address[] getInet6Addresses() { + return _ipv6Addresses.toArray(new Inet6Address[_ipv6Addresses.size()]); + } + + @Override + public int getPort() { + return _port; + } + + @Override + public int getPriority() { + return _priority; + } + + @Override + public int getWeight() { + return _weight; + } + + @Override + public byte[] getTextBytes() { + return (this._text != null && this._text.length > 0 ? this._text : ByteWrangler.EMPTY_TXT); + } + + @Override + public String getApplication() { + return (_application != null ? _application : ""); + } + + @Override + public String getDomain() { + return (_domain != null ? _domain : "local"); + } + + @Override + public String getProtocol() { + return (_protocol != null ? _protocol : "tcp"); + } + + @Override + public String getSubtype() { + return (_subtype != null ? _subtype : ""); + } + + @Override + public Map getQualifiedNameMap() { + Map map = new HashMap(5); + + map.put(Fields.Domain, this.getDomain()); + map.put(Fields.Protocol, this.getProtocol()); + map.put(Fields.Application, this.getApplication()); + map.put(Fields.Instance, this.getName()); + map.put(Fields.Subtype, this.getSubtype()); + return map; + } + + @Override + public synchronized boolean hasData() { + return this.getServer() != null + && this.hasInetAddress() + && this.getTextBytes() != null + && this.getTextBytes().length > 0; + } + + private boolean hasInetAddress() { + return _ipv4Addresses.size() > 0 || _ipv6Addresses.size() > 0; + } + + @Override + public int hashCode() { + return getQualifiedName().hashCode(); + } + + @Override + public boolean equals(Object obj) { + return (obj instanceof ServiceInfoImpl) + && getQualifiedName().equals(((ServiceInfoImpl) obj).getQualifiedName()); + } + + @Override + public ServiceInfoImpl clone() { + ServiceInfoImpl serviceInfo = + new ServiceInfoImpl(this.getQualifiedNameMap(), _port, _weight, _priority, _text); + serviceInfo._ipv6Addresses.addAll(Arrays.asList(getInet6Addresses())); + serviceInfo._ipv4Addresses.addAll(Arrays.asList(getInet4Addresses())); + serviceInfo._server = _server; + return serviceInfo; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append('[') + .append(this.getClass().getSimpleName()) + .append('@') + .append(System.identityHashCode(this)); + sb.append(" name: '"); + if (0 < this.getName().length()) { + sb.append(this.getName()).append('.'); + } + sb.append(this.getTypeWithSubtype()); + sb.append("' address: '"); + Inet4Address[] addresses4 = this.getInet4Addresses(); + for (InetAddress address : addresses4) { + sb.append(address).append(':').append(this.getPort()).append(' '); + } + Inet6Address[] addresses6 = this.getInet6Addresses(); + for (InetAddress address : addresses6) { + sb.append(address).append(':').append(this.getPort()).append(' '); + } + + sb.append(this.hasData() ? " has data" : " has NO data"); + sb.append(']'); + + return sb.toString(); + } + + /** + * Create a series of answer that correspond with the give service info. + * + * @param recordClass record class of the query + * @param unique + * @param ttl + * @param localHost + * @return collection of answers + */ + public Collection answers( + DNSRecordClass recordClass, boolean unique, int ttl, HostInfo localHost) { + List list = new ArrayList(); + + if ((recordClass == DNSRecordClass.CLASS_ANY) || (recordClass == DNSRecordClass.CLASS_IN)) { + if (this.getSubtype().length() > 0) { + list.add( + new DNSRecord.Pointer( + this.getTypeWithSubtype(), + DNSRecordClass.CLASS_IN, + DNSRecordClass.NOT_UNIQUE, + ttl, + this.getQualifiedName())); + } + list.add( + new DNSRecord.Pointer( + this.getType(), + DNSRecordClass.CLASS_IN, + DNSRecordClass.NOT_UNIQUE, + ttl, + this.getQualifiedName())); + list.add( + new DNSRecord.Service( + this.getQualifiedName(), + DNSRecordClass.CLASS_IN, + unique, + ttl, + _priority, + _weight, + _port, + localHost.getName())); + list.add( + new DNSRecord.Text( + this.getQualifiedName(), DNSRecordClass.CLASS_IN, unique, ttl, this.getTextBytes())); + for (InetAddress address : _ipv4Addresses) + list.add( + new DNSRecord.IPv4Address( + this.getQualifiedName(), DNSRecordClass.CLASS_IN, unique, ttl, address)); + for (InetAddress address : _ipv6Addresses) + list.add( + new DNSRecord.IPv6Address( + this.getQualifiedName(), DNSRecordClass.CLASS_IN, unique, ttl, address)); + } + + return list; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/SocketListener.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/SocketListener.java index 3edbf9acc..17a3a97d6 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/SocketListener.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/SocketListener.java @@ -4,85 +4,81 @@ package io.libp2p.discovery.mdns.impl; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; +import io.libp2p.discovery.mdns.impl.util.NamedThreadFactory; import java.io.IOException; import java.net.DatagramPacket; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; - -import io.libp2p.discovery.mdns.impl.util.NamedThreadFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - -/** - * Listen for multicast packets. - */ +/** Listen for multicast packets. */ class SocketListener implements Runnable { - static Logger logger = LoggerFactory.getLogger(SocketListener.class.getName()); + static Logger logger = LoggerFactory.getLogger(SocketListener.class.getName()); - private final JmDNSImpl _jmDNSImpl; - private final String _name; - private volatile boolean _closed; - private final ExecutorService _executor = Executors.newSingleThreadExecutor(new NamedThreadFactory("JmDNS")); - private Future _isShutdown; + private final JmDNSImpl _jmDNSImpl; + private final String _name; + private volatile boolean _closed; + private final ExecutorService _executor = + Executors.newSingleThreadExecutor(new NamedThreadFactory("JmDNS")); + private Future _isShutdown; - SocketListener(JmDNSImpl jmDNSImpl) { - _name = "SocketListener(" + (jmDNSImpl != null ? jmDNSImpl.getName() : "") + ")"; - this._jmDNSImpl = jmDNSImpl; - } + SocketListener(JmDNSImpl jmDNSImpl) { + _name = "SocketListener(" + (jmDNSImpl != null ? jmDNSImpl.getName() : "") + ")"; + this._jmDNSImpl = jmDNSImpl; + } - public void start() { - _isShutdown = _executor.submit(this, null); - } - public Future stop() { - _closed = true; - _executor.shutdown(); - return _isShutdown; - } + public void start() { + _isShutdown = _executor.submit(this, null); + } - @Override - public void run() { + public Future stop() { + _closed = true; + _executor.shutdown(); + return _isShutdown; + } + + @Override + public void run() { + try { + byte buf[] = new byte[DNSConstants.MAX_MSG_ABSOLUTE]; + DatagramPacket packet = new DatagramPacket(buf, buf.length); + while (!_closed) { + packet.setLength(buf.length); + this._jmDNSImpl.getSocket().receive(packet); + if (_closed) break; try { - byte buf[] = new byte[DNSConstants.MAX_MSG_ABSOLUTE]; - DatagramPacket packet = new DatagramPacket(buf, buf.length); - while (!_closed) { - packet.setLength(buf.length); - this._jmDNSImpl.getSocket().receive(packet); - if (_closed) - break; - try { - if (this._jmDNSImpl.getLocalHost().shouldIgnorePacket(packet)) { - continue; - } + if (this._jmDNSImpl.getLocalHost().shouldIgnorePacket(packet)) { + continue; + } - DNSIncoming msg = new DNSIncoming(packet); - if (msg.isValidResponseCode()) { - if (logger.isTraceEnabled()) { - logger.trace("{}.run() JmDNS in:{}", _name, msg.print(true)); - } - if (msg.isQuery()) { - if (packet.getPort() != DNSConstants.MDNS_PORT) { - this._jmDNSImpl.handleQuery(msg, packet.getAddress(), packet.getPort()); - } - this._jmDNSImpl.handleQuery(msg, this._jmDNSImpl.getGroup(), DNSConstants.MDNS_PORT); - } else { - this._jmDNSImpl.handleResponse(msg); - } - } else { - if (logger.isDebugEnabled()) { - logger.debug("{}.run() JmDNS in message with error code: {}", _name, msg.print(true)); - } - } - } catch (IOException e) { - logger.warn(_name + ".run() exception ", e); - } + DNSIncoming msg = new DNSIncoming(packet); + if (msg.isValidResponseCode()) { + if (logger.isTraceEnabled()) { + logger.trace("{}.run() JmDNS in:{}", _name, msg.print(true)); + } + if (msg.isQuery()) { + if (packet.getPort() != DNSConstants.MDNS_PORT) { + this._jmDNSImpl.handleQuery(msg, packet.getAddress(), packet.getPort()); + } + this._jmDNSImpl.handleQuery(msg, this._jmDNSImpl.getGroup(), DNSConstants.MDNS_PORT); + } else { + this._jmDNSImpl.handleResponse(msg); + } + } else { + if (logger.isDebugEnabled()) { + logger.debug("{}.run() JmDNS in message with error code: {}", _name, msg.print(true)); } + } } catch (IOException e) { - if (!_closed) - logger.warn(_name + ".run() exception ", e); + logger.warn(_name + ".run() exception ", e); } - logger.trace("{}.run() exiting.", _name); + } + } catch (IOException e) { + if (!_closed) logger.warn(_name + ".run() exception ", e); } + logger.trace("{}.run() exiting.", _name); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSConstants.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSConstants.java index 7def48772..4bcfcdfed 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSConstants.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSConstants.java @@ -10,57 +10,63 @@ * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public final class DNSConstants { - // http://www.iana.org/assignments/dns-parameters + // http://www.iana.org/assignments/dns-parameters - // changed to final class - jeffs - public static final String MDNS_GROUP = "224.0.0.251"; - public static final String MDNS_GROUP_IPV6 = "FF02::FB"; - public static final int MDNS_PORT = Integer.getInteger("net.mdns.port", 5353); - public static final int DNS_PORT = 53; - public static final int DNS_TTL = Integer.getInteger("net.dns.ttl", 60 * 60); // default one hour TTL - // public static final int DNS_TTL = 120 * 60; // two hour TTL (draft-cheshire-dnsext-multicastdns.txt ch 13) + // changed to final class - jeffs + public static final String MDNS_GROUP = "224.0.0.251"; + public static final String MDNS_GROUP_IPV6 = "FF02::FB"; + public static final int MDNS_PORT = Integer.getInteger("net.mdns.port", 5353); + public static final int DNS_PORT = 53; + public static final int DNS_TTL = + Integer.getInteger("net.dns.ttl", 60 * 60); // default one hour TTL + // public static final int DNS_TTL = 120 * 60; // two hour TTL + // (draft-cheshire-dnsext-multicastdns.txt ch 13) - public static final int MAX_MSG_TYPICAL = 1460; - public static final int MAX_MSG_ABSOLUTE = 8972; + public static final int MAX_MSG_TYPICAL = 1460; + public static final int MAX_MSG_ABSOLUTE = 8972; - public static final int FLAGS_QR_MASK = 0x8000; // Query response mask - public static final int FLAGS_QR_QUERY = 0x0000; // Query - public static final int FLAGS_QR_RESPONSE = 0x8000; // Response + public static final int FLAGS_QR_MASK = 0x8000; // Query response mask + public static final int FLAGS_QR_QUERY = 0x0000; // Query + public static final int FLAGS_QR_RESPONSE = 0x8000; // Response - public static final int FLAGS_OPCODE = 0x7800; // Operation code - public static final int FLAGS_AA = 0x0400; // Authorative answer - public static final int FLAGS_TC = 0x0200; // Truncated - public static final int FLAGS_RD = 0x0100; // Recursion desired - public static final int FLAGS_RA = 0x8000; // Recursion available + public static final int FLAGS_OPCODE = 0x7800; // Operation code + public static final int FLAGS_AA = 0x0400; // Authorative answer + public static final int FLAGS_TC = 0x0200; // Truncated + public static final int FLAGS_RD = 0x0100; // Recursion desired + public static final int FLAGS_RA = 0x8000; // Recursion available - public static final int FLAGS_Z = 0x0040; // Zero - public static final int FLAGS_AD = 0x0020; // Authentic data - public static final int FLAGS_CD = 0x0010; // Checking disabled - public static final int FLAGS_RCODE = 0x000F; // Response code + public static final int FLAGS_Z = 0x0040; // Zero + public static final int FLAGS_AD = 0x0020; // Authentic data + public static final int FLAGS_CD = 0x0010; // Checking disabled + public static final int FLAGS_RCODE = 0x000F; // Response code - // Time Intervals for various functions + // Time Intervals for various functions - public static final int SHARED_QUERY_TIME = 20; // milliseconds before send shared query - public static final int QUERY_WAIT_INTERVAL = 225; // milliseconds between query loops. - public static final int PROBE_WAIT_INTERVAL = 250; // milliseconds between probe loops. - public static final int RESPONSE_MIN_WAIT_INTERVAL = 20; // minimal wait interval for response. - public static final int RESPONSE_MAX_WAIT_INTERVAL = 115; // maximal wait interval for response - public static final int PROBE_CONFLICT_INTERVAL = 1000; // milliseconds to wait after conflict. - public static final int PROBE_THROTTLE_COUNT = 10; // After x tries go 1 time a sec. on probes. - public static final int PROBE_THROTTLE_COUNT_INTERVAL = 5000; // We only increment the throttle count, if the previous increment is inside this interval. - public static final int ANNOUNCE_WAIT_INTERVAL = 1000; // milliseconds between Announce loops. - public static final int RECORD_REAPER_INTERVAL = 10000; // milliseconds between cache cleanups. - public static final int RECORD_EXPIRY_DELAY = 1; // This is 1s delay used in ttl and therefore in seconds - public static final int KNOWN_ANSWER_TTL = 120; - public static final int ANNOUNCED_RENEWAL_TTL_INTERVAL = DNS_TTL * 500; // 50% of the TTL in milliseconds - public static final int FLUSH_RECORD_OLDER_THAN_1_SECOND = 1; // rfc6762, section 10.2 Flush outdated cache (older than 1 second) + public static final int SHARED_QUERY_TIME = 20; // milliseconds before send shared query + public static final int QUERY_WAIT_INTERVAL = 225; // milliseconds between query loops. + public static final int PROBE_WAIT_INTERVAL = 250; // milliseconds between probe loops. + public static final int RESPONSE_MIN_WAIT_INTERVAL = 20; // minimal wait interval for response. + public static final int RESPONSE_MAX_WAIT_INTERVAL = 115; // maximal wait interval for response + public static final int PROBE_CONFLICT_INTERVAL = 1000; // milliseconds to wait after conflict. + public static final int PROBE_THROTTLE_COUNT = 10; // After x tries go 1 time a sec. on probes. + public static final int PROBE_THROTTLE_COUNT_INTERVAL = + 5000; // We only increment the throttle count, if the previous increment is inside this + // interval. + public static final int ANNOUNCE_WAIT_INTERVAL = 1000; // milliseconds between Announce loops. + public static final int RECORD_REAPER_INTERVAL = 10000; // milliseconds between cache cleanups. + public static final int RECORD_EXPIRY_DELAY = + 1; // This is 1s delay used in ttl and therefore in seconds + public static final int KNOWN_ANSWER_TTL = 120; + public static final int ANNOUNCED_RENEWAL_TTL_INTERVAL = + DNS_TTL * 500; // 50% of the TTL in milliseconds + public static final int FLUSH_RECORD_OLDER_THAN_1_SECOND = + 1; // rfc6762, section 10.2 Flush outdated cache (older than 1 second) - public static final int STALE_REFRESH_INCREMENT = 5; - public static final int STALE_REFRESH_STARTING_PERCENTAGE = 80; + public static final int STALE_REFRESH_INCREMENT = 5; + public static final int STALE_REFRESH_STARTING_PERCENTAGE = 80; - public static final long CLOSE_TIMEOUT = ANNOUNCE_WAIT_INTERVAL * 5L; - public static final long SERVICE_INFO_TIMEOUT = ANNOUNCE_WAIT_INTERVAL * 6L; - - public static final int NETWORK_CHECK_INTERVAL = 10 * 1000; // 10 secondes + public static final long CLOSE_TIMEOUT = ANNOUNCE_WAIT_INTERVAL * 5L; + public static final long SERVICE_INFO_TIMEOUT = ANNOUNCE_WAIT_INTERVAL * 6L; + public static final int NETWORK_CHECK_INTERVAL = 10 * 1000; // 10 secondes } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSLabel.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSLabel.java index 86bf62a73..986c31efa 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSLabel.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSLabel.java @@ -1,87 +1,75 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; /** * DNS label. - * + * * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public enum DNSLabel { - /** - * This is unallocated. - */ - Unknown("", 0x80), - /** - * Standard label [RFC 1035] - */ - Standard("standard label", 0x00), - /** - * Compressed label [RFC 1035] - */ - Compressed("compressed label", 0xC0), - /** - * Extended label [RFC 2671] - */ - Extended("extended label", 0x40); + /** This is unallocated. */ + Unknown("", 0x80), + /** Standard label [RFC 1035] */ + Standard("standard label", 0x00), + /** Compressed label [RFC 1035] */ + Compressed("compressed label", 0xC0), + /** Extended label [RFC 2671] */ + Extended("extended label", 0x40); - /** - * DNS label types are encoded on the first 2 bits - */ - static final int LABEL_MASK = 0xC0; - static final int LABEL_NOT_MASK = 0x3F; + /** DNS label types are encoded on the first 2 bits */ + static final int LABEL_MASK = 0xC0; - private final String _externalName; + static final int LABEL_NOT_MASK = 0x3F; - private final int _index; + private final String _externalName; - DNSLabel(String name, int index) { - _externalName = name; - _index = index; - } + private final int _index; - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } + DNSLabel(String name, int index) { + _externalName = name; + _index = index; + } - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; - } + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } - /** - * @param index - * @return label - */ - public static DNSLabel labelForByte(int index) { - int maskedIndex = index & LABEL_MASK; - for (DNSLabel aLabel : DNSLabel.values()) { - if (aLabel._index == maskedIndex) return aLabel; - } - return Unknown; - } + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } - /** - * @param index - * @return masked value - */ - public static int labelValue(int index) { - return index & LABEL_NOT_MASK; + /** + * @param index + * @return label + */ + public static DNSLabel labelForByte(int index) { + int maskedIndex = index & LABEL_MASK; + for (DNSLabel aLabel : DNSLabel.values()) { + if (aLabel._index == maskedIndex) return aLabel; } + return Unknown; + } - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); - } + /** + * @param index + * @return masked value + */ + public static int labelValue(int index) { + return index & LABEL_NOT_MASK; + } + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOperationCode.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOperationCode.java index 5045e0ea7..9c495fddb 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOperationCode.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOperationCode.java @@ -1,86 +1,69 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; /** * DNS operation code. - * + * * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public enum DNSOperationCode { - /** - * Query [RFC1035] - */ - Query("Query", 0), - /** - * IQuery (Inverse Query, Obsolete) [RFC3425] - */ - IQuery("Inverse Query", 1), - /** - * Status [RFC1035] - */ - Status("Status", 2), - /** - * Unassigned - */ - Unassigned("Unassigned", 3), - /** - * Notify [RFC1996] - */ - Notify("Notify", 4), - /** - * Update [RFC2136] - */ - Update("Update", 5); + /** Query [RFC1035] */ + Query("Query", 0), + /** IQuery (Inverse Query, Obsolete) [RFC3425] */ + IQuery("Inverse Query", 1), + /** Status [RFC1035] */ + Status("Status", 2), + /** Unassigned */ + Unassigned("Unassigned", 3), + /** Notify [RFC1996] */ + Notify("Notify", 4), + /** Update [RFC2136] */ + Update("Update", 5); - /** - * DNS RCode types are encoded on the last 4 bits - */ - static final int OpCode_MASK = 0x7800; + /** DNS RCode types are encoded on the last 4 bits */ + static final int OpCode_MASK = 0x7800; - private final String _externalName; + private final String _externalName; - private final int _index; + private final int _index; - DNSOperationCode(String name, int index) { - _externalName = name; - _index = index; - } + DNSOperationCode(String name, int index) { + _externalName = name; + _index = index; + } - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } - - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; - } + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } - /** - * @param flags - * @return label - */ - public static DNSOperationCode operationCodeForFlags(int flags) { - int maskedIndex = (flags & OpCode_MASK) >> 11; - for (DNSOperationCode aCode : DNSOperationCode.values()) { - if (aCode._index == maskedIndex) return aCode; - } - return Unassigned; - } + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); + /** + * @param flags + * @return label + */ + public static DNSOperationCode operationCodeForFlags(int flags) { + int maskedIndex = (flags & OpCode_MASK) >> 11; + for (DNSOperationCode aCode : DNSOperationCode.values()) { + if (aCode._index == maskedIndex) return aCode; } + return Unassigned; + } + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOptionCode.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOptionCode.java index d254c2424..a30f437bf 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOptionCode.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSOptionCode.java @@ -1,78 +1,65 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; /** * DNS option code. - * + * * @author Arthur van Hoff, Pierre Frisch, Rick Blair */ public enum DNSOptionCode { - /** - * Token - */ - Unknown("Unknown", 65535), - /** - * Long-Lived Queries Option [http://files.dns-sd.org/draft-sekar-dns-llq.txt] - */ - LLQ("LLQ", 1), - /** - * Update Leases Option [http://files.dns-sd.org/draft-sekar-dns-ul.txt] - */ - UL("UL", 2), - /** - * Name Server Identifier Option [RFC5001] - */ - NSID("NSID", 3), - /** - * Owner Option [draft-cheshire-edns0-owner-option] - */ - Owner("Owner", 4); + /** Token */ + Unknown("Unknown", 65535), + /** Long-Lived Queries Option [http://files.dns-sd.org/draft-sekar-dns-llq.txt] */ + LLQ("LLQ", 1), + /** Update Leases Option [http://files.dns-sd.org/draft-sekar-dns-ul.txt] */ + UL("UL", 2), + /** Name Server Identifier Option [RFC5001] */ + NSID("NSID", 3), + /** Owner Option [draft-cheshire-edns0-owner-option] */ + Owner("Owner", 4); - private final String _externalName; + private final String _externalName; - private final int _index; + private final int _index; - DNSOptionCode(String name, int index) { - _externalName = name; - _index = index; - } + DNSOptionCode(String name, int index) { + _externalName = name; + _index = index; + } - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } - - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; - } + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } - /** - * @param optioncode - * @return label - */ - public static DNSOptionCode resultCodeForFlags(int optioncode) { - int maskedIndex = optioncode; - for (DNSOptionCode aCode : DNSOptionCode.values()) { - if (aCode._index == maskedIndex) return aCode; - } - return Unknown; - } + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); + /** + * @param optioncode + * @return label + */ + public static DNSOptionCode resultCodeForFlags(int optioncode) { + int maskedIndex = optioncode; + for (DNSOptionCode aCode : DNSOptionCode.values()) { + if (aCode._index == maskedIndex) return aCode; } + return Unknown; + } + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordClass.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordClass.java index 359279f88..40bb610f4 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordClass.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordClass.java @@ -1,6 +1,4 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; import org.slf4j.Logger; @@ -8,131 +6,112 @@ /** * DNS Record Class - * + * * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public enum DNSRecordClass { - /** - * - */ - CLASS_UNKNOWN("?", 0), - /** - * static final Internet - */ - CLASS_IN("in", 1), - /** - * CSNET - */ - CLASS_CS("cs", 2), - /** - * CHAOS - */ - CLASS_CH("ch", 3), - /** - * Hesiod - */ - CLASS_HS("hs", 4), - /** - * Used in DNS UPDATE [RFC 2136] - */ - CLASS_NONE("none", 254), - /** - * Not a DNS class, but a DNS query class, meaning "all classes" - */ - CLASS_ANY("any", 255); - - private static Logger logger = LoggerFactory.getLogger(DNSRecordClass.class.getName()); - - /** - * Multicast DNS uses the bottom 15 bits to identify the record class...
- * Except for pseudo records like OPT. - */ - public static final int CLASS_MASK = 0x7FFF; - - /** - * For answers the top bit indicates that all other cached records are now invalid.
- * For questions it indicates that we should send a unicast response. - */ - public static final int CLASS_UNIQUE = 0x8000; - - /** - * - */ - public static final boolean UNIQUE = true; - - /** - * - */ - public static final boolean NOT_UNIQUE = false; - - private final String _externalName; - - private final int _index; - - DNSRecordClass(String name, int index) { - _externalName = name; - _index = index; - } - - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } - - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; + /** */ + CLASS_UNKNOWN("?", 0), + /** static final Internet */ + CLASS_IN("in", 1), + /** CSNET */ + CLASS_CS("cs", 2), + /** CHAOS */ + CLASS_CH("ch", 3), + /** Hesiod */ + CLASS_HS("hs", 4), + /** Used in DNS UPDATE [RFC 2136] */ + CLASS_NONE("none", 254), + /** Not a DNS class, but a DNS query class, meaning "all classes" */ + CLASS_ANY("any", 255); + + private static Logger logger = LoggerFactory.getLogger(DNSRecordClass.class.getName()); + + /** + * Multicast DNS uses the bottom 15 bits to identify the record class...
+ * Except for pseudo records like OPT. + */ + public static final int CLASS_MASK = 0x7FFF; + + /** + * For answers the top bit indicates that all other cached records are now invalid.
+ * For questions it indicates that we should send a unicast response. + */ + public static final int CLASS_UNIQUE = 0x8000; + + /** */ + public static final boolean UNIQUE = true; + + /** */ + public static final boolean NOT_UNIQUE = false; + + private final String _externalName; + + private final int _index; + + DNSRecordClass(String name, int index) { + _externalName = name; + _index = index; + } + + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } + + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } + + /** + * Checks if the class is unique + * + * @param index + * @return true is the class is unique, false otherwise. + */ + public boolean isUnique(int index) { + return (this != CLASS_UNKNOWN) && ((index & CLASS_UNIQUE) != 0); + } + + /** + * @param name + * @return class for name + */ + public static DNSRecordClass classForName(String name) { + if (name != null) { + String aName = name.toLowerCase(); + for (DNSRecordClass aClass : DNSRecordClass.values()) { + if (aClass._externalName.equals(aName)) return aClass; + } } - - /** - * Checks if the class is unique - * - * @param index - * @return true is the class is unique, false otherwise. - */ - public boolean isUnique(int index) { - return (this != CLASS_UNKNOWN) && ((index & CLASS_UNIQUE) != 0); - } - - /** - * @param name - * @return class for name - */ - public static DNSRecordClass classForName(String name) { - if (name != null) { - String aName = name.toLowerCase(); - for (DNSRecordClass aClass : DNSRecordClass.values()) { - if (aClass._externalName.equals(aName)) return aClass; - } - } - logger.warn("Could not find record class for name: {}", name); - return CLASS_UNKNOWN; - } - - /** - * @param index - * @return class for name - */ - public static DNSRecordClass classForIndex(int index) { - int maskedIndex = index & CLASS_MASK; - for (DNSRecordClass aClass : DNSRecordClass.values()) { - if (aClass._index == maskedIndex) return aClass; - } - logger.warn("Could not find record class for index: {}", index); - return CLASS_UNKNOWN; - } - - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); + logger.warn("Could not find record class for name: {}", name); + return CLASS_UNKNOWN; + } + + /** + * @param index + * @return class for name + */ + public static DNSRecordClass classForIndex(int index) { + int maskedIndex = index & CLASS_MASK; + for (DNSRecordClass aClass : DNSRecordClass.values()) { + if (aClass._index == maskedIndex) return aClass; } - + logger.warn("Could not find record class for index: {}", index); + return CLASS_UNKNOWN; + } + + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordType.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordType.java index 0ad1e8e9f..9ff8754ac 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordType.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSRecordType.java @@ -1,6 +1,4 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; import org.slf4j.Logger; @@ -8,306 +6,187 @@ /** * DNS Record Type - * + * * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public enum DNSRecordType { - /** - * Address - */ - TYPE_IGNORE("ignore", 0), - /** - * Address - */ - TYPE_A("a", 1), - /** - * Name Server - */ - TYPE_NS("ns", 2), - /** - * Mail Destination - */ - TYPE_MD("md", 3), - /** - * Mail Forwarder - */ - TYPE_MF("mf", 4), - /** - * Canonical Name - */ - TYPE_CNAME("cname", 5), - /** - * Start of Authority - */ - TYPE_SOA("soa", 6), - /** - * Mailbox - */ - TYPE_MB("mb", 7), - /** - * Mail Group - */ - TYPE_MG("mg", 8), - /** - * Mail Rename - */ - TYPE_MR("mr", 9), - /** - * NULL RR - */ - TYPE_NULL("null", 10), - /** - * Well-known-service - */ - TYPE_WKS("wks", 11), - /** - * Domain Name pointer - */ - TYPE_PTR("ptr", 12), - /** - * Host information - */ - TYPE_HINFO("hinfo", 13), - /** - * Mailbox information - */ - TYPE_MINFO("minfo", 14), - /** - * Mail exchanger - */ - TYPE_MX("mx", 15), - /** - * Arbitrary text string - */ - TYPE_TXT("txt", 16), - /** - * for Responsible Person [RFC1183] - */ - TYPE_RP("rp", 17), - /** - * for AFS Data Base location [RFC1183] - */ - TYPE_AFSDB("afsdb", 18), - /** - * for X.25 PSDN address [RFC1183] - */ - TYPE_X25("x25", 19), - /** - * for ISDN address [RFC1183] - */ - TYPE_ISDN("isdn", 20), - /** - * for Route Through [RFC1183] - */ - TYPE_RT("rt", 21), - /** - * for NSAP address, NSAP style A record [RFC1706] - */ - TYPE_NSAP("nsap", 22), - /** - * - */ - TYPE_NSAP_PTR("nsap-otr", 23), - /** - * for security signature [RFC2931] - */ - TYPE_SIG("sig", 24), - /** - * for security key [RFC2535] - */ - TYPE_KEY("key", 25), - /** - * X.400 mail mapping information [RFC2163] - */ - TYPE_PX("px", 26), - /** - * Geographical Position [RFC1712] - */ - TYPE_GPOS("gpos", 27), - /** - * IP6 Address [Thomson] - */ - TYPE_AAAA("aaaa", 28), - /** - * Location Information [Vixie] - */ - TYPE_LOC("loc", 29), - /** - * Next Domain - OBSOLETE [RFC2535, RFC3755] - */ - TYPE_NXT("nxt", 30), - /** - * Endpoint Identifier [Patton] - */ - TYPE_EID("eid", 31), - /** - * Nimrod Locator [Patton] - */ - TYPE_NIMLOC("nimloc", 32), - /** - * Server Selection [RFC2782] - */ - TYPE_SRV("srv", 33), - /** - * ATM Address [Dobrowski] - */ - TYPE_ATMA("atma", 34), - /** - * Naming Authority Pointer [RFC2168, RFC2915] - */ - TYPE_NAPTR("naptr", 35), - /** - * Key Exchanger [RFC2230] - */ - TYPE_KX("kx", 36), - /** - * CERT [RFC2538] - */ - TYPE_CERT("cert", 37), - /** - * A6 [RFC2874] - */ - TYPE_A6("a6", 38), - /** - * DNAME [RFC2672] - */ - TYPE_DNAME("dname", 39), - /** - * SINK [Eastlake] - */ - TYPE_SINK("sink", 40), - /** - * OPT [RFC2671] - */ - TYPE_OPT("opt", 41), - /** - * APL [RFC3123] - */ - TYPE_APL("apl", 42), - /** - * Delegation Signer [RFC3658] - */ - TYPE_DS("ds", 43), - /** - * SSH Key Fingerprint [RFC-ietf-secsh-dns-05.txt] - */ - TYPE_SSHFP("sshfp", 44), - /** - * RRSIG [RFC3755] - */ - TYPE_RRSIG("rrsig", 46), - /** - * NSEC [RFC3755] - */ - TYPE_NSEC("nsec", 47), - /** - * DNSKEY [RFC3755] - */ - TYPE_DNSKEY("dnskey", 48), - /** - * [IANA-Reserved] - */ - TYPE_UINFO("uinfo", 100), - /** - * [IANA-Reserved] - */ - TYPE_UID("uid", 101), - /** - * [IANA-Reserved] - */ - TYPE_GID("gid", 102), - /** - * [IANA-Reserved] - */ - TYPE_UNSPEC("unspec", 103), - /** - * Transaction Key [RFC2930] - */ - TYPE_TKEY("tkey", 249), - /** - * Transaction Signature [RFC2845] - */ - TYPE_TSIG("tsig", 250), - /** - * Incremental transfer [RFC1995] - */ - TYPE_IXFR("ixfr", 251), - /** - * Transfer of an entire zone [RFC1035] - */ - TYPE_AXFR("axfr", 252), - /** - * Mailbox-related records (MB, MG or MR) [RFC1035] - */ - TYPE_MAILA("mails", 253), - /** - * Mail agent RRs (Obsolete - see MX) [RFC1035] - */ - TYPE_MAILB("mailb", 254), - /** - * Request for all records [RFC1035] - */ - TYPE_ANY("any", 255); + /** Address */ + TYPE_IGNORE("ignore", 0), + /** Address */ + TYPE_A("a", 1), + /** Name Server */ + TYPE_NS("ns", 2), + /** Mail Destination */ + TYPE_MD("md", 3), + /** Mail Forwarder */ + TYPE_MF("mf", 4), + /** Canonical Name */ + TYPE_CNAME("cname", 5), + /** Start of Authority */ + TYPE_SOA("soa", 6), + /** Mailbox */ + TYPE_MB("mb", 7), + /** Mail Group */ + TYPE_MG("mg", 8), + /** Mail Rename */ + TYPE_MR("mr", 9), + /** NULL RR */ + TYPE_NULL("null", 10), + /** Well-known-service */ + TYPE_WKS("wks", 11), + /** Domain Name pointer */ + TYPE_PTR("ptr", 12), + /** Host information */ + TYPE_HINFO("hinfo", 13), + /** Mailbox information */ + TYPE_MINFO("minfo", 14), + /** Mail exchanger */ + TYPE_MX("mx", 15), + /** Arbitrary text string */ + TYPE_TXT("txt", 16), + /** for Responsible Person [RFC1183] */ + TYPE_RP("rp", 17), + /** for AFS Data Base location [RFC1183] */ + TYPE_AFSDB("afsdb", 18), + /** for X.25 PSDN address [RFC1183] */ + TYPE_X25("x25", 19), + /** for ISDN address [RFC1183] */ + TYPE_ISDN("isdn", 20), + /** for Route Through [RFC1183] */ + TYPE_RT("rt", 21), + /** for NSAP address, NSAP style A record [RFC1706] */ + TYPE_NSAP("nsap", 22), + /** */ + TYPE_NSAP_PTR("nsap-otr", 23), + /** for security signature [RFC2931] */ + TYPE_SIG("sig", 24), + /** for security key [RFC2535] */ + TYPE_KEY("key", 25), + /** X.400 mail mapping information [RFC2163] */ + TYPE_PX("px", 26), + /** Geographical Position [RFC1712] */ + TYPE_GPOS("gpos", 27), + /** IP6 Address [Thomson] */ + TYPE_AAAA("aaaa", 28), + /** Location Information [Vixie] */ + TYPE_LOC("loc", 29), + /** Next Domain - OBSOLETE [RFC2535, RFC3755] */ + TYPE_NXT("nxt", 30), + /** Endpoint Identifier [Patton] */ + TYPE_EID("eid", 31), + /** Nimrod Locator [Patton] */ + TYPE_NIMLOC("nimloc", 32), + /** Server Selection [RFC2782] */ + TYPE_SRV("srv", 33), + /** ATM Address [Dobrowski] */ + TYPE_ATMA("atma", 34), + /** Naming Authority Pointer [RFC2168, RFC2915] */ + TYPE_NAPTR("naptr", 35), + /** Key Exchanger [RFC2230] */ + TYPE_KX("kx", 36), + /** CERT [RFC2538] */ + TYPE_CERT("cert", 37), + /** A6 [RFC2874] */ + TYPE_A6("a6", 38), + /** DNAME [RFC2672] */ + TYPE_DNAME("dname", 39), + /** SINK [Eastlake] */ + TYPE_SINK("sink", 40), + /** OPT [RFC2671] */ + TYPE_OPT("opt", 41), + /** APL [RFC3123] */ + TYPE_APL("apl", 42), + /** Delegation Signer [RFC3658] */ + TYPE_DS("ds", 43), + /** SSH Key Fingerprint [RFC-ietf-secsh-dns-05.txt] */ + TYPE_SSHFP("sshfp", 44), + /** RRSIG [RFC3755] */ + TYPE_RRSIG("rrsig", 46), + /** NSEC [RFC3755] */ + TYPE_NSEC("nsec", 47), + /** DNSKEY [RFC3755] */ + TYPE_DNSKEY("dnskey", 48), + /** [IANA-Reserved] */ + TYPE_UINFO("uinfo", 100), + /** [IANA-Reserved] */ + TYPE_UID("uid", 101), + /** [IANA-Reserved] */ + TYPE_GID("gid", 102), + /** [IANA-Reserved] */ + TYPE_UNSPEC("unspec", 103), + /** Transaction Key [RFC2930] */ + TYPE_TKEY("tkey", 249), + /** Transaction Signature [RFC2845] */ + TYPE_TSIG("tsig", 250), + /** Incremental transfer [RFC1995] */ + TYPE_IXFR("ixfr", 251), + /** Transfer of an entire zone [RFC1035] */ + TYPE_AXFR("axfr", 252), + /** Mailbox-related records (MB, MG or MR) [RFC1035] */ + TYPE_MAILA("mails", 253), + /** Mail agent RRs (Obsolete - see MX) [RFC1035] */ + TYPE_MAILB("mailb", 254), + /** Request for all records [RFC1035] */ + TYPE_ANY("any", 255); - private static Logger logger = LoggerFactory.getLogger(DNSRecordType.class.getName()); + private static Logger logger = LoggerFactory.getLogger(DNSRecordType.class.getName()); - private final String _externalName; + private final String _externalName; - private final int _index; + private final int _index; - DNSRecordType(String name, int index) { - _externalName = name; - _index = index; - } + DNSRecordType(String name, int index) { + _externalName = name; + _index = index; + } - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } - - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; - } + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } - /** - * @param name - * @return type for name - */ - public static DNSRecordType typeForName(String name) { - if (name != null) { - String aName = name.toLowerCase(); - for (DNSRecordType aType : DNSRecordType.values()) { - if (aType._externalName.equals(aName)) return aType; - } - } - logger.warn("Could not find record type for name: {}", name); - return TYPE_IGNORE; - } + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } - /** - * @param index - * @return type for name - */ - public static DNSRecordType typeForIndex(int index) { - for (DNSRecordType aType : DNSRecordType.values()) { - if (aType._index == index) return aType; - } - logger.warn("Could not find record type for index: {}", index); - return TYPE_IGNORE; + /** + * @param name + * @return type for name + */ + public static DNSRecordType typeForName(String name) { + if (name != null) { + String aName = name.toLowerCase(); + for (DNSRecordType aType : DNSRecordType.values()) { + if (aType._externalName.equals(aName)) return aType; + } } + logger.warn("Could not find record type for name: {}", name); + return TYPE_IGNORE; + } - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); + /** + * @param index + * @return type for name + */ + public static DNSRecordType typeForIndex(int index) { + for (DNSRecordType aType : DNSRecordType.values()) { + if (aType._index == index) return aType; } + logger.warn("Could not find record type for index: {}", index); + return TYPE_IGNORE; + } + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSResultCode.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSResultCode.java index 9b73ba628..7175e599c 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSResultCode.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSResultCode.java @@ -1,149 +1,118 @@ -/** - * - */ +/** */ package io.libp2p.discovery.mdns.impl.constants; /** * DNS result code. - * + * * @author Arthur van Hoff, Jeff Sonstein, Werner Randelshofer, Pierre Frisch, Rick Blair */ public enum DNSResultCode { - /** - * Token - */ - Unknown("Unknown", 65535), - /** - * No Error [RFC1035] - */ - NoError("No Error", 0), - /** - * Format Error [RFC1035] - */ - FormErr("Format Error", 1), - /** - * Server Failure [RFC1035] - */ - ServFail("Server Failure", 2), - /** - * Non-Existent Domain [RFC1035] - */ - NXDomain("Non-Existent Domain", 3), - /** - * Not Implemented [RFC1035] - */ - NotImp("Not Implemented", 4), - /** - * Query Refused [RFC1035] - */ - Refused("Query Refused", 5), - /** - * Name Exists when it should not [RFC2136] - */ - YXDomain("Name Exists when it should not", 6), - /** - * RR Set Exists when it should not [RFC2136] - */ - YXRRSet("RR Set Exists when it should not", 7), - /** - * RR Set that should exist does not [RFC2136] - */ - NXRRSet("RR Set that should exist does not", 8), - /** - * Server Not Authoritative for zone [RFC2136]] - */ - NotAuth("Server Not Authoritative for zone", 9), - /** - * Name not contained in zone [RFC2136] - */ - NotZone("NotZone Name not contained in zone", 10), + /** Token */ + Unknown("Unknown", 65535), + /** No Error [RFC1035] */ + NoError("No Error", 0), + /** Format Error [RFC1035] */ + FormErr("Format Error", 1), + /** Server Failure [RFC1035] */ + ServFail("Server Failure", 2), + /** Non-Existent Domain [RFC1035] */ + NXDomain("Non-Existent Domain", 3), + /** Not Implemented [RFC1035] */ + NotImp("Not Implemented", 4), + /** Query Refused [RFC1035] */ + Refused("Query Refused", 5), + /** Name Exists when it should not [RFC2136] */ + YXDomain("Name Exists when it should not", 6), + /** RR Set Exists when it should not [RFC2136] */ + YXRRSet("RR Set Exists when it should not", 7), + /** RR Set that should exist does not [RFC2136] */ + NXRRSet("RR Set that should exist does not", 8), + /** Server Not Authoritative for zone [RFC2136]] */ + NotAuth("Server Not Authoritative for zone", 9), + /** Name not contained in zone [RFC2136] */ + NotZone("NotZone Name not contained in zone", 10), + ; - ; + // 0 NoError No Error [RFC1035] + // 1 FormErr Format Error [RFC1035] + // 2 ServFail Server Failure [RFC1035] + // 3 NXDomain Non-Existent Domain [RFC1035] + // 4 NotImp Not Implemented [RFC1035] + // 5 Refused Query Refused [RFC1035] + // 6 YXDomain Name Exists when it should not [RFC2136] + // 7 YXRRSet RR Set Exists when it should not [RFC2136] + // 8 NXRRSet RR Set that should exist does not [RFC2136] + // 9 NotAuth Server Not Authoritative for zone [RFC2136] + // 10 NotZone Name not contained in zone [RFC2136] + // 11-15 Unassigned + // 16 BADVERS Bad OPT Version [RFC2671] + // 16 BADSIG TSIG Signature Failure [RFC2845] + // 17 BADKEY Key not recognized [RFC2845] + // 18 BADTIME Signature out of time window [RFC2845] + // 19 BADMODE Bad TKEY Mode [RFC2930] + // 20 BADNAME Duplicate key name [RFC2930] + // 21 BADALG Algorithm not supported [RFC2930] + // 22 BADTRUNC Bad Truncation [RFC4635] + // 23-3840 Unassigned + // 3841-4095 Reserved for Private Use [RFC5395] + // 4096-65534 Unassigned + // 65535 Reserved, can be allocated by Standards Action [RFC5395] - // 0 NoError No Error [RFC1035] - // 1 FormErr Format Error [RFC1035] - // 2 ServFail Server Failure [RFC1035] - // 3 NXDomain Non-Existent Domain [RFC1035] - // 4 NotImp Not Implemented [RFC1035] - // 5 Refused Query Refused [RFC1035] - // 6 YXDomain Name Exists when it should not [RFC2136] - // 7 YXRRSet RR Set Exists when it should not [RFC2136] - // 8 NXRRSet RR Set that should exist does not [RFC2136] - // 9 NotAuth Server Not Authoritative for zone [RFC2136] - // 10 NotZone Name not contained in zone [RFC2136] - // 11-15 Unassigned - // 16 BADVERS Bad OPT Version [RFC2671] - // 16 BADSIG TSIG Signature Failure [RFC2845] - // 17 BADKEY Key not recognized [RFC2845] - // 18 BADTIME Signature out of time window [RFC2845] - // 19 BADMODE Bad TKEY Mode [RFC2930] - // 20 BADNAME Duplicate key name [RFC2930] - // 21 BADALG Algorithm not supported [RFC2930] - // 22 BADTRUNC Bad Truncation [RFC4635] - // 23-3840 Unassigned - // 3841-4095 Reserved for Private Use [RFC5395] - // 4096-65534 Unassigned - // 65535 Reserved, can be allocated by Standards Action [RFC5395] + /** DNS Result Code types are encoded on the last 4 bits */ + static final int RCode_MASK = 0x0F; - /** - * DNS Result Code types are encoded on the last 4 bits - */ - final static int RCode_MASK = 0x0F; - /** - * DNS Extended Result Code types are encoded on the first 8 bits - */ - final static int ExtendedRCode_MASK = 0xFF; + /** DNS Extended Result Code types are encoded on the first 8 bits */ + static final int ExtendedRCode_MASK = 0xFF; - private final String _externalName; + private final String _externalName; - private final int _index; + private final int _index; - DNSResultCode(String name, int index) { - _externalName = name; - _index = index; - } + DNSResultCode(String name, int index) { + _externalName = name; + _index = index; + } - /** - * Return the string representation of this type - * - * @return String - */ - public String externalName() { - return _externalName; - } - - /** - * Return the numeric value of this type - * - * @return String - */ - public int indexValue() { - return _index; - } + /** + * Return the string representation of this type + * + * @return String + */ + public String externalName() { + return _externalName; + } - /** - * @param flags - * @return label - */ - public static DNSResultCode resultCodeForFlags(int flags) { - int maskedIndex = flags & RCode_MASK; - for (DNSResultCode aCode : DNSResultCode.values()) { - if (aCode._index == maskedIndex) return aCode; - } - return Unknown; - } + /** + * Return the numeric value of this type + * + * @return String + */ + public int indexValue() { + return _index; + } - public static DNSResultCode resultCodeForFlags(int flags, int extendedRCode) { - int maskedIndex = ((extendedRCode >> 28) & ExtendedRCode_MASK) | (flags & RCode_MASK); - for (DNSResultCode aCode : DNSResultCode.values()) { - if (aCode._index == maskedIndex) return aCode; - } - return Unknown; + /** + * @param flags + * @return label + */ + public static DNSResultCode resultCodeForFlags(int flags) { + int maskedIndex = flags & RCode_MASK; + for (DNSResultCode aCode : DNSResultCode.values()) { + if (aCode._index == maskedIndex) return aCode; } + return Unknown; + } - @Override - public String toString() { - return this.name() + " index " + this.indexValue(); + public static DNSResultCode resultCodeForFlags(int flags, int extendedRCode) { + int maskedIndex = ((extendedRCode >> 28) & ExtendedRCode_MASK) | (flags & RCode_MASK); + for (DNSResultCode aCode : DNSResultCode.values()) { + if (aCode._index == maskedIndex) return aCode; } + return Unknown; + } + @Override + public String toString() { + return this.name() + " index " + this.indexValue(); + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSState.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSState.java index 9b1b99995..9b7578ea9 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSState.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/DNSState.java @@ -11,205 +11,189 @@ */ public enum DNSState { - /** - * - */ - PROBING_1("probing 1", StateClass.probing), - /** - * - */ - PROBING_2("probing 2", StateClass.probing), - /** - * - */ - PROBING_3("probing 3", StateClass.probing), - /** - * - */ - ANNOUNCING_1("announcing 1", StateClass.announcing), - /** - * - */ - ANNOUNCING_2("announcing 2", StateClass.announcing), - /** - * - */ - ANNOUNCED("announced", StateClass.announced), - /** - * - */ - CANCELING_1("canceling 1", StateClass.canceling), - /** - * - */ - CANCELING_2("canceling 2", StateClass.canceling), - /** - * - */ - CANCELING_3("canceling 3", StateClass.canceling), - /** - * - */ - CANCELED("canceled", StateClass.canceled), - /** - * - */ - CLOSING("closing", StateClass.closing), - /** - * - */ - CLOSED("closed", StateClass.closed); - - private enum StateClass { - probing, announcing, announced, canceling, canceled, closing, closed + /** */ + PROBING_1("probing 1", StateClass.probing), + /** */ + PROBING_2("probing 2", StateClass.probing), + /** */ + PROBING_3("probing 3", StateClass.probing), + /** */ + ANNOUNCING_1("announcing 1", StateClass.announcing), + /** */ + ANNOUNCING_2("announcing 2", StateClass.announcing), + /** */ + ANNOUNCED("announced", StateClass.announced), + /** */ + CANCELING_1("canceling 1", StateClass.canceling), + /** */ + CANCELING_2("canceling 2", StateClass.canceling), + /** */ + CANCELING_3("canceling 3", StateClass.canceling), + /** */ + CANCELED("canceled", StateClass.canceled), + /** */ + CLOSING("closing", StateClass.closing), + /** */ + CLOSED("closed", StateClass.closed); + + private enum StateClass { + probing, + announcing, + announced, + canceling, + canceled, + closing, + closed + } + + // private static Logger logger = LoggerFactory.getLogger(DNSState.class.getName()); + + private final String _name; + + private final StateClass _state; + + private DNSState(String name, StateClass state) { + _name = name; + _state = state; + } + + @Override + public final String toString() { + return _name; + } + + /** + * Returns the next advanced state.
+ * In general, this advances one step in the following sequence: PROBING_1, PROBING_2, PROBING_3, + * ANNOUNCING_1, ANNOUNCING_2, ANNOUNCED.
+ * or CANCELING_1, CANCELING_2, CANCELING_3, CANCELED Does not advance for ANNOUNCED and CANCELED + * state. + * + * @return next state + */ + public final DNSState advance() { + switch (this) { + case PROBING_1: + return PROBING_2; + case PROBING_2: + return PROBING_3; + case PROBING_3: + return ANNOUNCING_1; + case ANNOUNCING_1: + return ANNOUNCING_2; + case ANNOUNCING_2: + return ANNOUNCED; + case ANNOUNCED: + return ANNOUNCED; + case CANCELING_1: + return CANCELING_2; + case CANCELING_2: + return CANCELING_3; + case CANCELING_3: + return CANCELED; + case CANCELED: + return CANCELED; + case CLOSING: + return CLOSED; + case CLOSED: + return CLOSED; + default: + // This is just to keep the compiler happy as we have covered all cases before. + return this; } - - // private static Logger logger = LoggerFactory.getLogger(DNSState.class.getName()); - - private final String _name; - - private final StateClass _state; - - private DNSState(String name, StateClass state) { - _name = name; - _state = state; + } + + /** + * Returns to the next reverted state. All states except CANCELED revert to PROBING_1. Status + * CANCELED does not revert. + * + * @return reverted state + */ + public final DNSState revert() { + switch (this) { + case PROBING_1: + case PROBING_2: + case PROBING_3: + case ANNOUNCING_1: + case ANNOUNCING_2: + case ANNOUNCED: + return PROBING_1; + case CANCELING_1: + case CANCELING_2: + case CANCELING_3: + return CANCELING_1; + case CANCELED: + return CANCELED; + case CLOSING: + return CLOSING; + case CLOSED: + return CLOSED; + default: + // This is just to keep the compiler happy as we have covered all cases before. + return this; } - - @Override - public final String toString() { - return _name; - } - - /** - * Returns the next advanced state.
- * In general, this advances one step in the following sequence: PROBING_1, PROBING_2, PROBING_3, ANNOUNCING_1, ANNOUNCING_2, ANNOUNCED.
- * or CANCELING_1, CANCELING_2, CANCELING_3, CANCELED Does not advance for ANNOUNCED and CANCELED state. - * - * @return next state - */ - public final DNSState advance() { - switch (this) { - case PROBING_1: - return PROBING_2; - case PROBING_2: - return PROBING_3; - case PROBING_3: - return ANNOUNCING_1; - case ANNOUNCING_1: - return ANNOUNCING_2; - case ANNOUNCING_2: - return ANNOUNCED; - case ANNOUNCED: - return ANNOUNCED; - case CANCELING_1: - return CANCELING_2; - case CANCELING_2: - return CANCELING_3; - case CANCELING_3: - return CANCELED; - case CANCELED: - return CANCELED; - case CLOSING: - return CLOSED; - case CLOSED: - return CLOSED; - default: - // This is just to keep the compiler happy as we have covered all cases before. - return this; - } - } - - /** - * Returns to the next reverted state. All states except CANCELED revert to PROBING_1. Status CANCELED does not revert. - * - * @return reverted state - */ - public final DNSState revert() { - switch (this) { - case PROBING_1: - case PROBING_2: - case PROBING_3: - case ANNOUNCING_1: - case ANNOUNCING_2: - case ANNOUNCED: - return PROBING_1; - case CANCELING_1: - case CANCELING_2: - case CANCELING_3: - return CANCELING_1; - case CANCELED: - return CANCELED; - case CLOSING: - return CLOSING; - case CLOSED: - return CLOSED; - default: - // This is just to keep the compiler happy as we have covered all cases before. - return this; - } - } - - /** - * Returns true, if this is a probing state. - * - * @return true if probing state, false otherwise - */ - public final boolean isProbing() { - return _state == StateClass.probing; - } - - /** - * Returns true, if this is an announcing state. - * - * @return true if announcing state, false otherwise - */ - public final boolean isAnnouncing() { - return _state == StateClass.announcing; - } - - /** - * Returns true, if this is an announced state. - * - * @return true if announced state, false otherwise - */ - public final boolean isAnnounced() { - return _state == StateClass.announced; - } - - /** - * Returns true, if this is a canceling state. - * - * @return true if canceling state, false otherwise - */ - public final boolean isCanceling() { - return _state == StateClass.canceling; - } - - /** - * Returns true, if this is a canceled state. - * - * @return true if canceled state, false otherwise - */ - public final boolean isCanceled() { - return _state == StateClass.canceled; - } - - /** - * Returns true, if this is a closing state. - * - * @return true if closing state, false otherwise - */ - public final boolean isClosing() { - return _state == StateClass.closing; - } - - /** - * Returns true, if this is a closing state. - * - * @return true if closed state, false otherwise - */ - public final boolean isClosed() { - return _state == StateClass.closed; - } - + } + + /** + * Returns true, if this is a probing state. + * + * @return true if probing state, false otherwise + */ + public final boolean isProbing() { + return _state == StateClass.probing; + } + + /** + * Returns true, if this is an announcing state. + * + * @return true if announcing state, false otherwise + */ + public final boolean isAnnouncing() { + return _state == StateClass.announcing; + } + + /** + * Returns true, if this is an announced state. + * + * @return true if announced state, false otherwise + */ + public final boolean isAnnounced() { + return _state == StateClass.announced; + } + + /** + * Returns true, if this is a canceling state. + * + * @return true if canceling state, false otherwise + */ + public final boolean isCanceling() { + return _state == StateClass.canceling; + } + + /** + * Returns true, if this is a canceled state. + * + * @return true if canceled state, false otherwise + */ + public final boolean isCanceled() { + return _state == StateClass.canceled; + } + + /** + * Returns true, if this is a closing state. + * + * @return true if closing state, false otherwise + */ + public final boolean isClosing() { + return _state == StateClass.closing; + } + + /** + * Returns true, if this is a closing state. + * + * @return true if closed state, false otherwise + */ + public final boolean isClosed() { + return _state == StateClass.closed; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/package-info.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/package-info.java index 829cb4144..41316a160 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/package-info.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/constants/package-info.java @@ -1,2 +1 @@ package io.libp2p.discovery.mdns.impl.constants; - diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/package-info.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/package-info.java index cc0bfbe66..effecf1ed 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/package-info.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/package-info.java @@ -1,2 +1 @@ package io.libp2p.discovery.mdns.impl; - diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/DNSTask.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/DNSTask.java index 7b4595a83..880c87cc2 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/DNSTask.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/DNSTask.java @@ -4,53 +4,50 @@ import io.libp2p.discovery.mdns.impl.DNSOutgoing; import io.libp2p.discovery.mdns.impl.DNSQuestion; import io.libp2p.discovery.mdns.impl.JmDNSImpl; - +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import java.io.IOException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - public abstract class DNSTask implements Runnable { - private final JmDNSImpl _jmDNSImpl; - protected final ScheduledExecutorService _scheduler = - Executors.newScheduledThreadPool(1); - - protected DNSTask(JmDNSImpl jmDNSImpl) { - super(); - this._jmDNSImpl = jmDNSImpl; - } - - protected JmDNSImpl dns() { - return _jmDNSImpl; - } - - public abstract void start(); - - protected abstract String getName(); - - @Override - public String toString() { - return this.getName(); - } - - protected DNSOutgoing addQuestion(DNSOutgoing out, DNSQuestion rec) throws IOException { - DNSOutgoing newOut = out; - try { - newOut.addQuestion(rec); - } catch (final IOException e) { - int flags = newOut.getFlags(); - boolean multicast = newOut.isMulticast(); - int maxUDPPayload = newOut.getMaxUDPPayload(); - int id = newOut.getId(); - - newOut.setFlags(flags | DNSConstants.FLAGS_TC); - newOut.setId(id); - this._jmDNSImpl.send(newOut); - - newOut = new DNSOutgoing(flags, multicast, maxUDPPayload); - newOut.addQuestion(rec); - } - return newOut; + private final JmDNSImpl _jmDNSImpl; + protected final ScheduledExecutorService _scheduler = Executors.newScheduledThreadPool(1); + + protected DNSTask(JmDNSImpl jmDNSImpl) { + super(); + this._jmDNSImpl = jmDNSImpl; + } + + protected JmDNSImpl dns() { + return _jmDNSImpl; + } + + public abstract void start(); + + protected abstract String getName(); + + @Override + public String toString() { + return this.getName(); + } + + protected DNSOutgoing addQuestion(DNSOutgoing out, DNSQuestion rec) throws IOException { + DNSOutgoing newOut = out; + try { + newOut.addQuestion(rec); + } catch (final IOException e) { + int flags = newOut.getFlags(); + boolean multicast = newOut.isMulticast(); + int maxUDPPayload = newOut.getMaxUDPPayload(); + int id = newOut.getId(); + + newOut.setFlags(flags | DNSConstants.FLAGS_TC); + newOut.setId(id); + this._jmDNSImpl.send(newOut); + + newOut = new DNSOutgoing(flags, multicast, maxUDPPayload); + newOut.addQuestion(rec); } + return newOut; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/Responder.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/Responder.java index 77dab738e..cd62930b1 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/Responder.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/Responder.java @@ -4,131 +4,131 @@ package io.libp2p.discovery.mdns.impl.tasks; +import io.libp2p.discovery.mdns.impl.DNSIncoming; +import io.libp2p.discovery.mdns.impl.DNSOutgoing; +import io.libp2p.discovery.mdns.impl.DNSQuestion; +import io.libp2p.discovery.mdns.impl.DNSRecord; +import io.libp2p.discovery.mdns.impl.JmDNSImpl; +import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.HashSet; import java.util.Set; import java.util.concurrent.TimeUnit; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.libp2p.discovery.mdns.impl.DNSIncoming; -import io.libp2p.discovery.mdns.impl.DNSOutgoing; -import io.libp2p.discovery.mdns.impl.DNSQuestion; -import io.libp2p.discovery.mdns.impl.DNSRecord; -import io.libp2p.discovery.mdns.impl.JmDNSImpl; -import io.libp2p.discovery.mdns.impl.constants.DNSConstants; - -/** - * The Responder sends a single answer for the specified service infos and for the host name. - */ +/** The Responder sends a single answer for the specified service infos and for the host name. */ public class Responder extends DNSTask { - static Logger logger = LoggerFactory.getLogger(Responder.class.getName()); + static Logger logger = LoggerFactory.getLogger(Responder.class.getName()); + + private final DNSIncoming _in; + + private final InetAddress _addr; + private final int _port; + + private final boolean _unicast; + + public Responder(JmDNSImpl jmDNSImpl, DNSIncoming in, InetAddress addr, int port) { + super(jmDNSImpl); + this._in = in; + this._addr = addr; + this._port = port; + this._unicast = (port != DNSConstants.MDNS_PORT); + } + + @Override + protected String getName() { + return "Responder(" + (this.dns() != null ? this.dns().getName() : "") + ")"; + } + + @Override + public String toString() { + return super.toString() + " incoming: " + _in; + } + + @Override + public void start() { + int delay = + DNSConstants.RESPONSE_MIN_WAIT_INTERVAL + + JmDNSImpl.getRandom() + .nextInt( + DNSConstants.RESPONSE_MAX_WAIT_INTERVAL + - DNSConstants.RESPONSE_MIN_WAIT_INTERVAL + + 1) + - _in.elapseSinceArrival(); + if (delay < 0) { + delay = 0; + } + logger.trace("{}.start() Responder chosen delay={}", this.getName(), delay); - private final DNSIncoming _in; + _scheduler.schedule(this, delay, TimeUnit.MILLISECONDS); + } - private final InetAddress _addr; - private final int _port; + @Override + public void run() { + // We use these sets to prevent duplicate records + Set questions = new HashSet(); + Set answers = new HashSet(); - private final boolean _unicast; + try { + // Answer questions + for (DNSQuestion question : _in.getQuestions()) { + logger.debug("{}.run() JmDNS responding to: {}", this.getName(), question); - public Responder(JmDNSImpl jmDNSImpl, DNSIncoming in, InetAddress addr, int port) { - super(jmDNSImpl); - this._in = in; - this._addr = addr; - this._port = port; - this._unicast = (port != DNSConstants.MDNS_PORT); - } + // for unicast responses the question must be included + if (_unicast) { + questions.add(question); + } - @Override - protected String getName() { - return "Responder(" + (this.dns() != null ? this.dns().getName() : "") + ")"; - } + question.addAnswers(this.dns(), answers); + } - @Override - public String toString() { - return super.toString() + " incoming: " + _in; - } + // respond if we have answers + if (!answers.isEmpty()) { + logger.debug("{}.run() JmDNS responding", this.getName()); - @Override - public void start() { - int delay = - DNSConstants.RESPONSE_MIN_WAIT_INTERVAL + - JmDNSImpl.getRandom().nextInt( - DNSConstants.RESPONSE_MAX_WAIT_INTERVAL - - DNSConstants.RESPONSE_MIN_WAIT_INTERVAL + 1 - ) - - _in.elapseSinceArrival(); - if (delay < 0) { - delay = 0; + DNSOutgoing out = + new DNSOutgoing( + DNSConstants.FLAGS_QR_RESPONSE | DNSConstants.FLAGS_AA, + !_unicast, + _in.getSenderUDPPayload()); + if (_unicast) { + out.setDestination(new InetSocketAddress(_addr, _port)); } - logger.trace("{}.start() Responder chosen delay={}", this.getName(), delay); - - _scheduler.schedule(this, delay, TimeUnit.MILLISECONDS); - } - - @Override - public void run() { - // We use these sets to prevent duplicate records - Set questions = new HashSet(); - Set answers = new HashSet(); - - try { - // Answer questions - for (DNSQuestion question : _in.getQuestions()) { - logger.debug("{}.run() JmDNS responding to: {}", this.getName(), question); - - // for unicast responses the question must be included - if (_unicast) { - questions.add(question); - } - - question.addAnswers(this.dns(), answers); - } - - // respond if we have answers - if (!answers.isEmpty()) { - logger.debug("{}.run() JmDNS responding", this.getName()); - - DNSOutgoing out = new DNSOutgoing(DNSConstants.FLAGS_QR_RESPONSE | DNSConstants.FLAGS_AA, !_unicast, _in.getSenderUDPPayload()); - if (_unicast) { - out.setDestination(new InetSocketAddress(_addr, _port)); - } - out.setId(_in.getId()); - for (DNSQuestion question : questions) { - out = this.addQuestion(out, question); - } - for (DNSRecord answer : answers) { - out = this.addAnswer(out, answer); - } - if (!out.isEmpty()) - this.dns().send(out); - } - } catch (Throwable e) { - logger.warn(this.getName() + "run() exception ", e); + out.setId(_in.getId()); + for (DNSQuestion question : questions) { + out = this.addQuestion(out, question); } - _scheduler.shutdown(); - } - - private DNSOutgoing addAnswer(DNSOutgoing out, DNSRecord rec) throws IOException { - DNSOutgoing newOut = out; - try { - newOut.addAnswer(rec); - } catch (final IOException e) { - int flags = newOut.getFlags(); - boolean multicast = newOut.isMulticast(); - int maxUDPPayload = newOut.getMaxUDPPayload(); - int id = newOut.getId(); - - newOut.setFlags(flags | DNSConstants.FLAGS_TC); - newOut.setId(id); - this.dns().send(newOut); - - newOut = new DNSOutgoing(flags, multicast, maxUDPPayload); - newOut.addAnswer(rec); + for (DNSRecord answer : answers) { + out = this.addAnswer(out, answer); } - return newOut; + if (!out.isEmpty()) this.dns().send(out); + } + } catch (Throwable e) { + logger.warn(this.getName() + "run() exception ", e); + } + _scheduler.shutdown(); + } + + private DNSOutgoing addAnswer(DNSOutgoing out, DNSRecord rec) throws IOException { + DNSOutgoing newOut = out; + try { + newOut.addAnswer(rec); + } catch (final IOException e) { + int flags = newOut.getFlags(); + boolean multicast = newOut.isMulticast(); + int maxUDPPayload = newOut.getMaxUDPPayload(); + int id = newOut.getId(); + + newOut.setFlags(flags | DNSConstants.FLAGS_TC); + newOut.setId(id); + this.dns().send(newOut); + + newOut = new DNSOutgoing(flags, multicast, maxUDPPayload); + newOut.addAnswer(rec); } -} \ No newline at end of file + return newOut; + } +} diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/ServiceResolver.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/ServiceResolver.java index 4b15a5544..71c34a244 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/ServiceResolver.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/ServiceResolver.java @@ -10,72 +10,73 @@ import io.libp2p.discovery.mdns.impl.constants.DNSConstants; import io.libp2p.discovery.mdns.impl.constants.DNSRecordClass; import io.libp2p.discovery.mdns.impl.constants.DNSRecordType; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.util.concurrent.Future; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The ServiceResolver queries three times consecutively for services of a given type, and then removes itself from the timer. + * The ServiceResolver queries three times consecutively for services of a given type, and then + * removes itself from the timer. */ public class ServiceResolver extends DNSTask { - private static Logger logger = LoggerFactory.getLogger(ServiceResolver.class.getName()); + private static Logger logger = LoggerFactory.getLogger(ServiceResolver.class.getName()); - private final String _type; - private final int _queryInterval; - private ScheduledFuture _isShutdown; + private final String _type; + private final int _queryInterval; + private ScheduledFuture _isShutdown; - public ServiceResolver(JmDNSImpl jmDNSImpl, String type, int queryInterval) { - super(jmDNSImpl); - this._type = type; - this._queryInterval = queryInterval; - } + public ServiceResolver(JmDNSImpl jmDNSImpl, String type, int queryInterval) { + super(jmDNSImpl); + this._type = type; + this._queryInterval = queryInterval; + } - @Override - protected String getName() { - return "ServiceResolver(" + (this.dns() != null ? this.dns().getName() : "") + ")"; - } + @Override + protected String getName() { + return "ServiceResolver(" + (this.dns() != null ? this.dns().getName() : "") + ")"; + } - @Override - public void start() { - _isShutdown = _scheduler.scheduleAtFixedRate( - this, - DNSConstants.QUERY_WAIT_INTERVAL, - _queryInterval * 1000, - TimeUnit.MILLISECONDS - ); - } + @Override + public void start() { + _isShutdown = + _scheduler.scheduleAtFixedRate( + this, DNSConstants.QUERY_WAIT_INTERVAL, _queryInterval * 1000, TimeUnit.MILLISECONDS); + } - @SuppressWarnings("unchecked") - public Future stop() { - _scheduler.shutdown(); - return (Future)_isShutdown; - } + @SuppressWarnings("unchecked") + public Future stop() { + _scheduler.shutdown(); + return (Future) _isShutdown; + } - @Override - public void run() { - try { - logger.debug("{}.run() JmDNS {}",this.getName(), this.description()); - DNSOutgoing out = new DNSOutgoing(DNSConstants.FLAGS_QR_QUERY); - out = this.addQuestions(out); - if (!out.isEmpty()) { - this.dns().send(out); - } - } catch (Throwable e) { - logger.warn(this.getName() + ".run() exception ", e); - } + @Override + public void run() { + try { + logger.debug("{}.run() JmDNS {}", this.getName(), this.description()); + DNSOutgoing out = new DNSOutgoing(DNSConstants.FLAGS_QR_QUERY); + out = this.addQuestions(out); + if (!out.isEmpty()) { + this.dns().send(out); + } + } catch (Throwable e) { + logger.warn(this.getName() + ".run() exception ", e); } + } - private DNSOutgoing addQuestions(DNSOutgoing out) throws IOException { - DNSOutgoing newOut = out; - newOut = this.addQuestion(newOut, DNSQuestion.newQuestion(_type, DNSRecordType.TYPE_PTR, DNSRecordClass.CLASS_IN, DNSRecordClass.NOT_UNIQUE)); - return newOut; - } + private DNSOutgoing addQuestions(DNSOutgoing out) throws IOException { + DNSOutgoing newOut = out; + newOut = + this.addQuestion( + newOut, + DNSQuestion.newQuestion( + _type, DNSRecordType.TYPE_PTR, DNSRecordClass.CLASS_IN, DNSRecordClass.NOT_UNIQUE)); + return newOut; + } - private String description() { - return "querying service"; - } -} \ No newline at end of file + private String description() { + return "querying service"; + } +} diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/package-info.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/package-info.java index 0262d1c57..6e2ee0569 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/package-info.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/tasks/package-info.java @@ -1,2 +1 @@ package io.libp2p.discovery.mdns.impl.tasks; - diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/ByteWrangler.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/ByteWrangler.java index a628f7217..baf198486 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/ByteWrangler.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/ByteWrangler.java @@ -5,57 +5,48 @@ import java.nio.charset.Charset; /** - * This class contains all the byte shifting - * - * @author Victor Toni + * This class contains all the byte shifting * + * @author Victor Toni */ public class ByteWrangler { - /** - * Maximum number of bytes a value can consist of. - */ - public static final int MAX_VALUE_LENGTH = 255; - - /** - * Maximum number of bytes record data can consist of. - * It is {@link #MAX_VALUE_LENGTH} + 1 because the first byte contains the number of the following bytes. - */ - public static final int MAX_DATA_LENGTH = MAX_VALUE_LENGTH + 1; - - /** - * Representation of no value. A zero length array of bytes. - */ - public static final byte[] NO_VALUE = new byte[0]; - - /** - * Representation of empty text. - * The first byte denotes the length of the following character bytes (in this case zero.) - * - * FIXME: Should this be exported as a method since it could change externally??? - */ - public final static byte[] EMPTY_TXT = new byte[] { 0 }; - - /** - * Name for charset used to convert Strings to/from wire bytes: {@value #CHARSET_NAME}. - */ - public final static String CHARSET_NAME = "UTF-8"; - - /** - * Charset used to convert Strings to/from wire bytes: {@value #CHARSET_NAME}. - */ - private final static Charset CHARSET_UTF_8 = Charset.forName(CHARSET_NAME); - - public static byte[] encodeText(final String text) throws IOException { - final byte data[] = text.getBytes(CHARSET_UTF_8); - if (data.length > MAX_VALUE_LENGTH) { - return EMPTY_TXT; - } - - final ByteArrayOutputStream out = new ByteArrayOutputStream(MAX_DATA_LENGTH); - out.write((byte) data.length); - out.write(data, 0, data.length); - - final byte[] encodedText = out.toByteArray(); - return encodedText; + /** Maximum number of bytes a value can consist of. */ + public static final int MAX_VALUE_LENGTH = 255; + + /** + * Maximum number of bytes record data can consist of. It is {@link #MAX_VALUE_LENGTH} + 1 because + * the first byte contains the number of the following bytes. + */ + public static final int MAX_DATA_LENGTH = MAX_VALUE_LENGTH + 1; + + /** Representation of no value. A zero length array of bytes. */ + public static final byte[] NO_VALUE = new byte[0]; + + /** + * Representation of empty text. The first byte denotes the length of the following character + * bytes (in this case zero.) + * + *

FIXME: Should this be exported as a method since it could change externally??? + */ + public static final byte[] EMPTY_TXT = new byte[] {0}; + + /** Name for charset used to convert Strings to/from wire bytes: {@value #CHARSET_NAME}. */ + public static final String CHARSET_NAME = "UTF-8"; + + /** Charset used to convert Strings to/from wire bytes: {@value #CHARSET_NAME}. */ + private static final Charset CHARSET_UTF_8 = Charset.forName(CHARSET_NAME); + + public static byte[] encodeText(final String text) throws IOException { + final byte data[] = text.getBytes(CHARSET_UTF_8); + if (data.length > MAX_VALUE_LENGTH) { + return EMPTY_TXT; } + + final ByteArrayOutputStream out = new ByteArrayOutputStream(MAX_DATA_LENGTH); + out.write((byte) data.length); + out.write(data, 0, data.length); + + final byte[] encodedText = out.toByteArray(); + return encodedText; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/NamedThreadFactory.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/NamedThreadFactory.java index 7906a63ee..e5f0d1765 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/NamedThreadFactory.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/impl/util/NamedThreadFactory.java @@ -8,28 +8,30 @@ import java.util.concurrent.ThreadFactory; /** - * Custom thread factory which sets the name to make it easier to identify where the pooled threads were created. + * Custom thread factory which sets the name to make it easier to identify where the pooled threads + * were created. * * @author Trejkaz, Pierre Frisch */ public class NamedThreadFactory implements ThreadFactory { - private final ThreadFactory _delegate; - private final String _namePrefix; + private final ThreadFactory _delegate; + private final String _namePrefix; - /** - * Constructs the thread factory. - * - * @param namePrefix a prefix to append to thread names (will be separated from the default thread name by a space.) - */ - public NamedThreadFactory(String namePrefix) { - this._namePrefix = namePrefix; - _delegate = Executors.defaultThreadFactory(); - } + /** + * Constructs the thread factory. + * + * @param namePrefix a prefix to append to thread names (will be separated from the default thread + * name by a space.) + */ + public NamedThreadFactory(String namePrefix) { + this._namePrefix = namePrefix; + _delegate = Executors.defaultThreadFactory(); + } - @Override - public Thread newThread(Runnable runnable) { - Thread thread = _delegate.newThread(runnable); - thread.setName(_namePrefix + ' ' + thread.getName()); - return thread; - } + @Override + public Thread newThread(Runnable runnable) { + Thread thread = _delegate.newThread(runnable); + thread.setName(_namePrefix + ' ' + thread.getName()); + return thread; + } } diff --git a/libp2p/src/main/java/io/libp2p/discovery/mdns/package-info.java b/libp2p/src/main/java/io/libp2p/discovery/mdns/package-info.java index a93d49ef7..ed16fce56 100644 --- a/libp2p/src/main/java/io/libp2p/discovery/mdns/package-info.java +++ b/libp2p/src/main/java/io/libp2p/discovery/mdns/package-info.java @@ -3,24 +3,19 @@ /** * Java code in this package is derived from the JmDNS project. * - * JmDNS is a Java implementation of multi-cast DNS and can be used for - * service registration and discovery in local area networks. JmDNS is - * fully compatible with Apple's Bonjour. The project was originally - * started in December 2002 by Arthur van Hoff at Strangeberry. In - * November 2003 the project was moved to SourceForge, and the name - * was changed from JRendezvous to JmDNS for legal reasons. Many - * thanks to Stuart Cheshire for help and moral support. In 2014, it was - * been moved from Sourceforge to Github by Kai Kreuzer with the kind - * approval from Arthur and Rick. - *

- * https://github.com/jmdns/jmdns/ - *

+ *

JmDNS is a Java implementation of multi-cast DNS and can be used for service registration and + * discovery in local area networks. JmDNS is fully compatible with Apple's Bonjour. The project was + * originally started in December 2002 by Arthur van Hoff at Strangeberry. In November 2003 the + * project was moved to SourceForge, and the name was changed from JRendezvous to JmDNS for legal + * reasons. Many thanks to Stuart Cheshire for help and moral support. In 2014, it was been moved + * from Sourceforge to Github by Kai Kreuzer with the kind approval from Arthur and Rick. * - * JmDNS was originally licensed under the GNU Lesser General Public - * License as jRendevous. It was re-released under the - * Apache License, Version 2.0 in 2005. It is under those terms it is reused here. - *

- * JmDNS License Notice
- * JmDNS Changelog - *

- **/ + *

https://github.com/jmdns/jmdns/ JmDNS was + * originally licensed under the GNU Lesser General Public License as jRendevous. It was re-released + * under the Apache License, Version 2.0 in 2005. It is under those terms it is reused here. + * + *

JmDNS License Notice
+ * JmDNS + * Changelog + */ diff --git a/libp2p/src/main/java/io/libp2p/protocol/autonat/AutonatProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/autonat/AutonatProtocol.java new file mode 100644 index 000000000..5425ba16a --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/autonat/AutonatProtocol.java @@ -0,0 +1,172 @@ +package io.libp2p.protocol.autonat; + +import com.google.protobuf.*; +import io.libp2p.core.*; +import io.libp2p.core.Stream; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.multistream.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.autonat.pb.*; +import java.io.*; +import java.net.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.stream.*; +import org.jetbrains.annotations.*; + +public class AutonatProtocol extends ProtobufProtocolHandler { + + public static class Binding extends StrictProtocolBinding { + public Binding() { + super("/libp2p/autonat/v1.0.0", new AutonatProtocol()); + } + } + + public interface AutoNatController { + CompletableFuture rpc(Autonat.Message req); + + default CompletableFuture requestDial( + PeerId ourId, List us) { + if (us.isEmpty()) + throw new IllegalStateException("Requested autonat dial with no addresses!"); + return rpc(Autonat.Message.newBuilder() + .setType(Autonat.Message.MessageType.DIAL) + .setDial( + Autonat.Message.Dial.newBuilder() + .setPeer( + Autonat.Message.PeerInfo.newBuilder() + .addAllAddrs( + us.stream() + .map(a -> ByteString.copyFrom(a.serialize())) + .collect(Collectors.toList())) + .setId(ByteString.copyFrom(ourId.getBytes())))) + .build()) + .thenApply(msg -> msg.getDialResponse()); + } + } + + public static class Sender implements ProtocolMessageHandler, AutoNatController { + private final Stream stream; + private final LinkedBlockingDeque> queue = + new LinkedBlockingDeque<>(); + + public Sender(Stream stream) { + this.stream = stream; + } + + @Override + public void onMessage(@NotNull Stream stream, Autonat.Message msg) { + queue.poll().complete(msg); + } + + public CompletableFuture rpc(Autonat.Message req) { + CompletableFuture res = new CompletableFuture<>(); + queue.add(res); + stream.writeAndFlush(req); + return res; + } + } + + private static boolean sameIP(Multiaddr a, Multiaddr b) { + if (a.has(Protocol.IP4)) + return a.getFirstComponent(Protocol.IP4).equals(b.getFirstComponent(Protocol.IP4)); + if (a.has(Protocol.IP6)) + return a.getFirstComponent(Protocol.IP6).equals(b.getFirstComponent(Protocol.IP6)); + return false; + } + + private static boolean reachableIP(Multiaddr a) { + try { + if (a.has(Protocol.IP4)) + return InetAddress.getByName(a.getFirstComponent(Protocol.IP4).getStringValue()) + .isReachable(1000); + if (a.has(Protocol.IP6)) + return InetAddress.getByName(a.getFirstComponent(Protocol.IP6).getStringValue()) + .isReachable(1000); + } catch (IOException e) { + } + return false; + } + + public static class Receiver + implements ProtocolMessageHandler, AutoNatController { + private final Stream p2pstream; + + public Receiver(Stream p2pstream) { + this.p2pstream = p2pstream; + } + + @Override + public void onMessage(@NotNull Stream stream, Autonat.Message msg) { + switch (msg.getType()) { + case DIAL: + { + Autonat.Message.Dial dial = msg.getDial(); + PeerId peerId = new PeerId(dial.getPeer().getId().toByteArray()); + List requestedDials = + dial.getPeer().getAddrsList().stream() + .map(s -> Multiaddr.deserialize(s.toByteArray())) + .collect(Collectors.toList()); + PeerId streamPeerId = stream.remotePeerId(); + if (!peerId.equals(streamPeerId)) { + p2pstream.close(); + return; + } + + Multiaddr remote = stream.getConnection().remoteAddress(); + Optional reachable = + requestedDials.stream() + .filter(a -> sameIP(a, remote)) + .filter(a -> !a.has(Protocol.P2PCIRCUIT)) + .filter(a -> reachableIP(a)) + .findAny(); + Autonat.Message.Builder resp = + Autonat.Message.newBuilder().setType(Autonat.Message.MessageType.DIAL_RESPONSE); + if (reachable.isPresent()) { + resp = + resp.setDialResponse( + Autonat.Message.DialResponse.newBuilder() + .setStatus(Autonat.Message.ResponseStatus.OK) + .setAddr(ByteString.copyFrom(reachable.get().serialize()))); + } else { + resp = + resp.setDialResponse( + Autonat.Message.DialResponse.newBuilder() + .setStatus(Autonat.Message.ResponseStatus.E_DIAL_ERROR)); + } + p2pstream.writeAndFlush(resp); + } + default: + { + } + } + } + + public CompletableFuture rpc(Autonat.Message msg) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send form a receiver!")); + } + } + + private static final int TRAFFIC_LIMIT = 2 * 1024; + + public AutonatProtocol() { + super(Autonat.Message.getDefaultInstance(), TRAFFIC_LIMIT, TRAFFIC_LIMIT); + } + + @NotNull + @Override + protected CompletableFuture onStartInitiator(@NotNull Stream stream) { + Sender replyPropagator = new Sender(stream); + stream.pushHandler(replyPropagator); + return CompletableFuture.completedFuture(replyPropagator); + } + + @NotNull + @Override + protected CompletableFuture onStartResponder(@NotNull Stream stream) { + Receiver dialer = new Receiver(stream); + stream.pushHandler(dialer); + return CompletableFuture.completedFuture(dialer); + } +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java new file mode 100644 index 000000000..be2be179d --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java @@ -0,0 +1,417 @@ +package io.libp2p.protocol.circuit; + +import com.google.protobuf.*; +import io.libp2p.core.*; +import io.libp2p.core.Stream; +import io.libp2p.core.crypto.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.multistream.*; +import io.libp2p.etc.util.netty.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.crypto.pb.*; +import io.libp2p.protocol.circuit.pb.*; +import io.netty.buffer.*; +import io.netty.channel.*; +import io.netty.handler.codec.protobuf.*; +import java.io.*; +import java.nio.charset.*; +import java.time.*; +import java.time.Duration; +import java.time.temporal.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.function.*; +import java.util.stream.*; +import org.jetbrains.annotations.*; + +public class CircuitHopProtocol extends ProtobufProtocolHandler { + + private static final String HOP_HANDLER_NAME = "HOP_HANDLER"; + private static final String STREAM_CLEARER_NAME = "STREAM_CLEARER"; + + public static class Binding extends StrictProtocolBinding implements HostConsumer { + private final CircuitHopProtocol hop; + + private Binding(CircuitHopProtocol hop) { + super("/libp2p/circuit/relay/0.2.0/hop", hop); + this.hop = hop; + } + + public Binding(RelayManager manager, CircuitStopProtocol.Binding stop) { + this(new CircuitHopProtocol(manager, stop)); + } + + @Override + public void setHost(Host us) { + hop.setHost(us); + } + } + + private static void putUvarint(OutputStream out, long x) throws IOException { + while (x >= 0x80) { + out.write((byte) (x | 0x80)); + x >>= 7; + } + out.write((byte) x); + } + + public static byte[] createVoucher( + PrivKey priv, PeerId relay, PeerId requestor, LocalDateTime expiry) { + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + try { + putUvarint(bout, 0x0302); + } catch (IOException e) { + } + byte[] typeMulticodec = bout.toByteArray(); + byte[] payload = + VoucherOuterClass.Voucher.newBuilder() + .setRelay(ByteString.copyFrom(relay.getBytes())) + .setPeer(ByteString.copyFrom(requestor.getBytes())) + .setExpiration(expiry.toEpochSecond(ZoneOffset.UTC) * 1_000_000_000) + .build() + .toByteArray(); + byte[] signDomain = "libp2p-relay-rsvp".getBytes(StandardCharsets.UTF_8); + ByteArrayOutputStream toSign = new ByteArrayOutputStream(); + try { + putUvarint(toSign, signDomain.length); + toSign.write(signDomain); + putUvarint(toSign, typeMulticodec.length); + toSign.write(typeMulticodec); + putUvarint(toSign, payload.length); + toSign.write(payload); + } catch (IOException e) { + } + byte[] signature = priv.sign(toSign.toByteArray()); + return EnvelopeOuterClass.Envelope.newBuilder() + .setPayloadType(ByteString.copyFrom(typeMulticodec)) + .setPayload(ByteString.copyFrom(payload)) + .setPublicKey( + EnvelopeOuterClass.PublicKey.newBuilder() + .setTypeValue(priv.publicKey().getKeyType().getNumber()) + .setData(ByteString.copyFrom(priv.publicKey().raw()))) + .setSignature(ByteString.copyFrom(signature)) + .build() + .toByteArray(); + } + + public static class Reservation { + public final LocalDateTime expiry; + public final int durationSeconds; + public final long maxBytes; + public final byte[] voucher; + public final Multiaddr[] addrs; + + public Reservation( + LocalDateTime expiry, + int durationSeconds, + long maxBytes, + byte[] voucher, + Multiaddr[] addrs) { + this.expiry = expiry; + this.durationSeconds = durationSeconds; + this.maxBytes = maxBytes; + this.voucher = voucher; + this.addrs = addrs; + } + } + + public interface RelayManager { + boolean hasReservation(PeerId source); + + Optional createReservation(PeerId requestor, Multiaddr addr); + + Optional allowConnection(PeerId target, PeerId initiator); + + static RelayManager limitTo(PrivKey priv, PeerId relayPeerId, int concurrent) { + return new RelayManager() { + Map reservations = new HashMap<>(); + + @Override + public synchronized boolean hasReservation(PeerId source) { + return reservations.containsKey(source); + } + + @Override + public synchronized Optional createReservation( + PeerId requestor, Multiaddr addr) { + if (reservations.size() >= concurrent) return Optional.empty(); + LocalDateTime now = LocalDateTime.now(); + LocalDateTime expiry = now.plusHours(1); + byte[] voucher = createVoucher(priv, relayPeerId, requestor, now); + Reservation resv = new Reservation(expiry, 120, 4096, voucher, new Multiaddr[] {addr}); + reservations.put(requestor, resv); + return Optional.of(resv); + } + + @Override + public synchronized Optional allowConnection(PeerId target, PeerId initiator) { + return Optional.ofNullable(reservations.get(target)); + } + }; + } + } + + public interface HopController { + CompletableFuture rpc(Circuit.HopMessage req); + + default CompletableFuture reserve() { + return rpc(Circuit.HopMessage.newBuilder().setType(Circuit.HopMessage.Type.RESERVE).build()) + .thenApply( + msg -> { + if (msg.getStatus() == Circuit.Status.OK) { + long expiry = msg.getReservation().getExpire(); + return new Reservation( + LocalDateTime.ofEpochSecond(expiry, 0, ZoneOffset.UTC), + msg.getLimit().getDuration(), + msg.getLimit().getData(), + msg.getReservation().getVoucher().toByteArray(), + null); + } + throw new IllegalStateException(msg.getStatus().name()); + }); + } + + CompletableFuture connect(PeerId target); + } + + public static class HopRemover extends ChannelInitializer { + + @Override + protected void initChannel(@NotNull Channel ch) throws Exception { + ch.pipeline().remove(HOP_HANDLER_NAME); + // also remove associated protobuf handlers + ch.pipeline().remove(ProtobufDecoder.class); + ch.pipeline().remove(ProtobufEncoder.class); + ch.pipeline().remove(ProtobufVarint32FrameDecoder.class); + ch.pipeline().remove(ProtobufVarint32LengthFieldPrepender.class); + ch.pipeline().remove(STREAM_CLEARER_NAME); + } + } + + public static class Sender implements ProtocolMessageHandler, HopController { + private final Stream stream; + private final LinkedBlockingDeque> queue = + new LinkedBlockingDeque<>(); + + public Sender(Stream stream) { + this.stream = stream; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.HopMessage msg) { + queue.poll().complete(msg); + } + + public CompletableFuture rpc(Circuit.HopMessage req) { + CompletableFuture res = new CompletableFuture<>(); + queue.add(res); + stream.writeAndFlush(req); + return res; + } + + @Override + public CompletableFuture connect(PeerId target) { + return rpc(Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.CONNECT) + .setPeer(Circuit.Peer.newBuilder().setId(ByteString.copyFrom(target.getBytes()))) + .build()) + .thenApply( + msg -> { + if (msg.getType() == Circuit.HopMessage.Type.STATUS + && msg.getStatus() == Circuit.Status.OK) { + // remove handler for HOP to return bare stream + stream.pushHandler(STREAM_CLEARER_NAME, new HopRemover()); + return stream; + } + throw new IllegalStateException("Circuit dial returned " + msg.getStatus().name()); + }); + } + } + + public static class Receiver + implements ProtocolMessageHandler, HopController { + private final Host us; + private final RelayManager manager; + private final Supplier> publicAddresses; + private final CircuitStopProtocol.Binding stop; + private final AddressBook addressBook; + + public Receiver( + Host us, + RelayManager manager, + Supplier> publicAddresses, + CircuitStopProtocol.Binding stop, + AddressBook addressBook) { + this.us = us; + this.manager = manager; + this.publicAddresses = publicAddresses; + this.stop = stop; + this.addressBook = addressBook; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.HopMessage msg) { + switch (msg.getType()) { + case RESERVE: + { + PeerId requestor = stream.remotePeerId(); + Optional reservation = + manager.createReservation(requestor, stream.getConnection().remoteAddress()); + if (reservation.isEmpty() + || new Multiaddr(stream.getConnection().remoteAddress().toString()) + .has(Protocol.P2PCIRCUIT)) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.RESERVATION_REFUSED)); + return; + } + Reservation resv = reservation.get(); + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK) + .setReservation( + Circuit.Reservation.newBuilder() + .setExpire(resv.expiry.toEpochSecond(ZoneOffset.UTC)) + .addAllAddrs( + publicAddresses.get().stream() + .map(a -> ByteString.copyFrom(a.serialize())) + .collect(Collectors.toList())) + .setVoucher(ByteString.copyFrom(resv.voucher))) + .setLimit( + Circuit.Limit.newBuilder() + .setDuration(resv.durationSeconds) + .setData(resv.maxBytes))); + } + case CONNECT: + { + PeerId target = new PeerId(msg.getPeer().getId().toByteArray()); + if (manager.hasReservation(target)) { + PeerId initiator = stream.remotePeerId(); + Optional res = manager.allowConnection(target, initiator); + if (res.isPresent()) { + Reservation resv = res.get(); + try { + CircuitStopProtocol.StopController stop = + this.stop + .dial(us, target, resv.addrs) + .getController() + .orTimeout(15, TimeUnit.SECONDS) + .join(); + Circuit.StopMessage reply = + stop.connect(initiator, resv.durationSeconds, resv.maxBytes).join(); + if (reply.getStatus().equals(Circuit.Status.OK)) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK)); + Stream toTarget = stop.getStream(); + Stream fromRequestor = stream; + // remove hop and stop handlers from streams before proxying + fromRequestor.pushHandler(STREAM_CLEARER_NAME, new HopRemover()); + toTarget.pushHandler( + CircuitStopProtocol.STOP_REMOVER_NAME, + new CircuitStopProtocol.StopRemover()); + + // connect these streams with time + bytes enforcement + fromRequestor.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + fromRequestor.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + toTarget.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + toTarget.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + fromRequestor.pushHandler(new ProxyHandler(toTarget)); + toTarget.pushHandler(new ProxyHandler(fromRequestor)); + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(reply.getStatus())); + } + } catch (Exception e) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.CONNECTION_FAILED)); + } + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.RESOURCE_LIMIT_EXCEEDED)); + } + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.NO_RESERVATION)); + } + } + } + } + + @Override + public CompletableFuture connect(PeerId target) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send from a receiver!")); + } + + public CompletableFuture rpc(Circuit.HopMessage msg) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send from a receiver!")); + } + } + + private static class ProxyHandler extends ChannelInboundHandlerAdapter { + + private final Stream target; + + public ProxyHandler(Stream target) { + this.target = target; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + target.writeAndFlush(msg); + } + } + } + + private static final int TRAFFIC_LIMIT = 2 * 1024; + private final RelayManager manager; + private final CircuitStopProtocol.Binding stop; + private Host us; + + public CircuitHopProtocol(RelayManager manager, CircuitStopProtocol.Binding stop) { + super(Circuit.HopMessage.getDefaultInstance(), TRAFFIC_LIMIT, TRAFFIC_LIMIT); + this.manager = manager; + this.stop = stop; + } + + public void setHost(Host us) { + this.us = us; + } + + @NotNull + @Override + protected CompletableFuture onStartInitiator(@NotNull Stream stream) { + Sender replyPropagator = new Sender(stream); + stream.pushHandler( + HOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, replyPropagator)); + return CompletableFuture.completedFuture(replyPropagator); + } + + @NotNull + @Override + protected CompletableFuture onStartResponder(@NotNull Stream stream) { + if (us == null) throw new IllegalStateException("null Host for us!"); + Supplier> ourpublicAddresses = () -> us.listenAddresses(); + Receiver dialer = new Receiver(us, manager, ourpublicAddresses, stop, us.getAddressBook()); + stream.pushHandler(HOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, dialer)); + return CompletableFuture.completedFuture(dialer); + } +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java new file mode 100644 index 000000000..b10ee62d4 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java @@ -0,0 +1,157 @@ +package io.libp2p.protocol.circuit; + +import com.google.protobuf.*; +import io.libp2p.core.*; +import io.libp2p.core.multistream.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.pb.*; +import io.netty.channel.*; +import io.netty.handler.codec.protobuf.*; +import java.util.concurrent.*; +import org.jetbrains.annotations.*; + +public class CircuitStopProtocol + extends ProtobufProtocolHandler { + + private static final String STOP_HANDLER_NAME = "STOP_HANDLER"; + public static final String STOP_REMOVER_NAME = "STOP_REMOVER"; + + public static class Binding extends StrictProtocolBinding { + private final CircuitStopProtocol stop; + + public Binding(CircuitStopProtocol stop) { + super("/libp2p/circuit/relay/0.2.0/stop", stop); + this.stop = stop; + } + + public void setTransport(RelayTransport transport) { + stop.setTransport(transport); + } + } + + public interface StopController { + CompletableFuture rpc(Circuit.StopMessage req); + + Stream getStream(); + + default CompletableFuture connect( + PeerId source, int durationSeconds, long maxBytes) { + return rpc( + Circuit.StopMessage.newBuilder() + .setType(Circuit.StopMessage.Type.CONNECT) + .setPeer(Circuit.Peer.newBuilder().setId(ByteString.copyFrom(source.getBytes()))) + .setLimit(Circuit.Limit.newBuilder().setData(maxBytes).setDuration(durationSeconds)) + .build()); + } + } + + public static class Sender + implements ProtocolMessageHandler, StopController { + private final Stream stream; + private final LinkedBlockingDeque> queue = + new LinkedBlockingDeque<>(); + + public Sender(Stream stream) { + this.stream = stream; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.StopMessage msg) { + queue.poll().complete(msg); + } + + public CompletableFuture rpc(Circuit.StopMessage req) { + CompletableFuture res = new CompletableFuture<>(); + queue.add(res); + stream.writeAndFlush(req); + return res; + } + + public Stream getStream() { + return stream; + } + } + + public static class StopRemover extends ChannelInitializer { + + @Override + protected void initChannel(@NotNull Channel ch) throws Exception { + ch.pipeline().remove(ProtobufDecoder.class); + ch.pipeline().remove(ProtobufEncoder.class); + ch.pipeline().remove(ProtobufVarint32FrameDecoder.class); + ch.pipeline().remove(ProtobufVarint32LengthFieldPrepender.class); + ch.pipeline().remove(STOP_HANDLER_NAME); + ch.pipeline().remove(STOP_REMOVER_NAME); + } + } + + public static class Receiver + implements ProtocolMessageHandler, StopController { + private final Stream stream; + private final RelayTransport transport; + + public Receiver(Stream stream, RelayTransport transport) { + this.stream = stream; + this.transport = transport; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.StopMessage msg) { + if (msg.getType() == Circuit.StopMessage.Type.CONNECT) { + PeerId remote = new PeerId(msg.getPeer().getId().toByteArray()); + int durationSeconds = msg.getLimit().getDuration(); + long limitBytes = msg.getLimit().getData(); + stream.writeAndFlush( + Circuit.StopMessage.newBuilder() + .setType(Circuit.StopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK) + .build()); + // remove STOP handler from stream before upgrading + stream.pushHandler(STOP_REMOVER_NAME, new StopRemover()); + + // now upgrade connection with security and muxer protocol + ConnectionHandler connHandler = null; // TODO + RelayTransport.upgradeStream( + stream, false, transport.upgrader, transport, remote, connHandler); + } + } + + public Stream getStream() { + return stream; + } + + public CompletableFuture rpc(Circuit.StopMessage msg) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send form a receiver!")); + } + } + + private static final int TRAFFIC_LIMIT = 2 * 1024; + + private RelayTransport transport; + + public CircuitStopProtocol() { + super(Circuit.StopMessage.getDefaultInstance(), TRAFFIC_LIMIT, TRAFFIC_LIMIT); + } + + public void setTransport(RelayTransport transport) { + this.transport = transport; + } + + @NotNull + @Override + protected CompletableFuture onStartInitiator(@NotNull Stream stream) { + Sender replyPropagator = new Sender(stream); + stream.pushHandler( + STOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, replyPropagator)); + return CompletableFuture.completedFuture(replyPropagator); + } + + @NotNull + @Override + protected CompletableFuture onStartResponder(@NotNull Stream stream) { + Receiver acceptor = new Receiver(stream, transport); + stream.pushHandler(STOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, acceptor)); + return CompletableFuture.completedFuture(acceptor); + } +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java new file mode 100644 index 000000000..c3848e699 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java @@ -0,0 +1,8 @@ +package io.libp2p.protocol.circuit; + +import io.libp2p.core.*; + +public interface HostConsumer { + + void setHost(Host us); +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java new file mode 100644 index 000000000..a6d44e1f8 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java @@ -0,0 +1,326 @@ +package io.libp2p.protocol.circuit; + +import io.libp2p.core.*; +import io.libp2p.core.Stream; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.mux.*; +import io.libp2p.core.security.*; +import io.libp2p.core.transport.*; +import io.libp2p.etc.*; +import io.libp2p.transport.*; +import io.netty.channel.*; +import java.time.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.*; +import java.util.function.Function; +import java.util.stream.*; +import kotlin.*; +import org.jetbrains.annotations.*; + +public class RelayTransport implements Transport, HostConsumer { + private Host us; + private final Map listeners = new ConcurrentHashMap<>(); + private final Map dials = new ConcurrentHashMap<>(); + private final Function> candidateRelays; + private final CircuitHopProtocol.Binding hop; + private final CircuitStopProtocol.Binding stop; + public final ConnectionUpgrader upgrader; + private final AtomicInteger relayCount; + private final ScheduledExecutorService runner; + + public RelayTransport( + CircuitHopProtocol.Binding hop, + CircuitStopProtocol.Binding stop, + ConnectionUpgrader upgrader, + Function> candidateRelays, + ScheduledExecutorService runner) { + this.hop = hop; + this.stop = stop; + this.upgrader = upgrader; + this.candidateRelays = candidateRelays; + this.relayCount = new AtomicInteger(0); + this.runner = runner; + } + + @Override + public void setHost(Host us) { + this.us = us; + hop.setHost(us); + } + + public static class CandidateRelay { + public final PeerId id; + public final List addrs; + + public CandidateRelay(PeerId id, List addrs) { + this.id = id; + this.addrs = addrs; + } + } + + private static class RelayState { + List addrs; + CircuitHopProtocol.HopController controller; + Connection conn; + LocalDateTime renewAfter; + } + + public void setRelayCount(int count) { + relayCount.set(count); + } + + @Override + public int getActiveConnections() { + return dials.size(); + } + + @Override + public int getActiveListeners() { + return listeners.size(); + } + + @NotNull + @Override + public CompletableFuture close() { + return CompletableFuture.allOf( + dials.values().stream().map(Stream::close).toArray(CompletableFuture[]::new)) + .thenApply( + x -> { + dials.clear(); + return null; + }); + } + + static class ConnectionOverStream implements Connection { + private final boolean isInitiator; + private final Transport transport; + private final Stream stream; + private SecureChannel.Session security; + private StreamMuxer.Session muxer; + + public ConnectionOverStream(boolean isInitiator, Transport transport, Stream stream) { + this.isInitiator = isInitiator; + this.transport = transport; + this.stream = stream; + } + + @NotNull + @Override + public Multiaddr localAddress() { + return stream.getConnection().localAddress().withComponent(Protocol.P2PCIRCUIT); + } + + @NotNull + @Override + public Multiaddr remoteAddress() { + return stream.getConnection().remoteAddress().withComponent(Protocol.P2PCIRCUIT); + } + + public void setSecureSession(SecureChannel.Session sec) { + this.security = sec; + } + + @NotNull + @Override + public SecureChannel.Session secureSession() { + return security; + } + + public void setMuxerSession(StreamMuxer.Session mux) { + this.muxer = mux; + } + + @NotNull + @Override + public StreamMuxer.Session muxerSession() { + return muxer; + } + + @NotNull + @Override + public Transport transport() { + return transport; + } + + @Override + public boolean isInitiator() { + return isInitiator; + } + + @Override + public void addHandlerBefore( + @NotNull String s, @NotNull String s1, @NotNull ChannelHandler channelHandler) { + stream.addHandlerBefore(s, s1, channelHandler); + } + + @NotNull + @Override + public CompletableFuture close() { + return stream.close(); + } + + @NotNull + @Override + public CompletableFuture closeFuture() { + return stream.closeFuture(); + } + + @Override + public void pushHandler(@NotNull ChannelHandler channelHandler) { + stream.pushHandler(channelHandler); + } + + @Override + public void pushHandler(@NotNull String s, @NotNull ChannelHandler channelHandler) { + stream.pushHandler(s, channelHandler); + } + } + + @NotNull + @Override + public CompletableFuture dial( + @NotNull Multiaddr multiaddr, + @NotNull ConnectionHandler connHandler, + @Nullable ChannelVisitor channelVisitor) { + // first connect to relay over hop + List comps = multiaddr.getComponents(); + int split = comps.indexOf(new MultiaddrComponent(Protocol.P2PCIRCUIT, null)); + Multiaddr relay = new Multiaddr(comps.subList(0, split)); + Multiaddr target = new Multiaddr(comps.subList(split, comps.size())); + CircuitHopProtocol.HopController ctr = hop.dial(us, relay).getController().join(); + // request proxy to target + Stream stream = ctr.connect(target.getPeerId()).join(); + // upgrade with sec and muxer + return upgradeStream(stream, true, upgrader, this, target.getPeerId(), connHandler); + } + + public static CompletableFuture upgradeStream( + Stream stream, + boolean isInitiator, + ConnectionUpgrader upgrader, + Transport transport, + PeerId remote, + ConnectionHandler connHandler) { + ConnectionOverStream conn = new ConnectionOverStream(isInitiator, transport, stream); + CompletableFuture res = new CompletableFuture<>(); + stream.pushHandler( + new ChannelInitializer<>() { + @Override + protected void initChannel(Channel channel) throws Exception { + channel.attr(AttributesKt.getREMOTE_PEER_ID()).set(remote); + channel.attr(AttributesKt.getCONNECTION()).set(conn); + upgrader + .establishSecureChannel(conn) + .thenCompose( + sess -> { + conn.setSecureSession(sess); + if (sess.getEarlyMuxer() != null) { + return ConnectionUpgrader.Companion.establishMuxer( + sess.getEarlyMuxer(), conn); + } else { + return upgrader.establishMuxer(conn); + } + }) + .thenAccept( + sess -> { + conn.setMuxerSession(sess); + connHandler.handleConnection(conn); + res.complete(conn); + }) + .exceptionally( + t -> { + res.completeExceptionally(t); + return null; + }); + channel.pipeline().fireChannelActive(); + } + }); + return res; + } + + @Override + public boolean handles(@NotNull Multiaddr multiaddr) { + return multiaddr.hasAny(Protocol.P2PCIRCUIT); + } + + @Override + public void initialize() { + stop.setTransport(this); + // find relays and connect and reserve + runner.scheduleAtFixedRate(this::ensureEnoughCurrentRelays, 0, 2 * 60, TimeUnit.SECONDS); + } + + public void ensureEnoughCurrentRelays() { + int active = 0; + // renew existing relays before finding new ones + Set> currentRelays = listeners.entrySet(); + for (Map.Entry current : currentRelays) { + RelayState relay = current.getValue(); + LocalDateTime now = LocalDateTime.now(); + if (now.isBefore(relay.renewAfter)) { + active++; + } else { + try { + CircuitHopProtocol.Reservation reservation = relay.controller.reserve().join(); + relay.renewAfter = reservation.expiry.minusMinutes(1); + active++; + } catch (Exception e) { + listeners.remove(current.getKey()); + } + } + } + if (active >= relayCount.get()) return; + + List candidates = candidateRelays.apply(us); + for (CandidateRelay candidate : candidates) { + // connect to relay and get reservation + CircuitHopProtocol.HopController ctr = + hop.dial(us, candidate.id, candidate.addrs.toArray(new Multiaddr[0])) + .getController() + .join(); + CircuitHopProtocol.Reservation resv = ctr.reserve().join(); + active++; + listeners.put(candidate.id, new RelayState()); + if (active >= relayCount.get()) return; + } + } + + @NotNull + @Override + public CompletableFuture listen( + @NotNull Multiaddr relayAddr, + @NotNull ConnectionHandler connectionHandler, + @Nullable ChannelVisitor channelVisitor) { + List components = relayAddr.getComponents(); + Multiaddr withoutCircuit = new Multiaddr(components.subList(0, components.size() - 1)); + CircuitHopProtocol.HopController ctr = hop.dial(us, withoutCircuit).getController().join(); + return ctr.reserve().thenApply(res -> null); + } + + @NotNull + @Override + public List listenAddresses() { + return listeners.entrySet().stream() + .flatMap( + r -> + r.getValue().addrs.stream() + .map( + a -> + a.withP2P(r.getKey()) + .concatenated( + new Multiaddr( + List.of( + new MultiaddrComponent(Protocol.P2PCIRCUIT, null))) + .withP2P(us.getPeerId())))) + .collect(Collectors.toList()); + } + + @NotNull + @Override + public CompletableFuture unlisten(@NotNull Multiaddr multiaddr) { + RelayState relayState = listeners.get(multiaddr); + if (relayState == null) return CompletableFuture.completedFuture(null); + return relayState.conn.close(); + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/core/Connection.kt b/libp2p/src/main/kotlin/io/libp2p/core/Connection.kt index 579b9a1a0..64eb52b64 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/Connection.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/Connection.kt @@ -27,6 +27,7 @@ interface Connection : P2PChannel { * Returns the local [Multiaddr] of this [Connection] */ fun localAddress(): Multiaddr + /** * Returns the remote [Multiaddr] of this [Connection] */ diff --git a/libp2p/src/main/kotlin/io/libp2p/core/Host.kt b/libp2p/src/main/kotlin/io/libp2p/core/Host.kt index cfdb3586d..1ba4685b0 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/Host.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/Host.kt @@ -15,14 +15,17 @@ interface Host { * Our private key which can be used by different protocols to sign messages */ val privKey: PrivKey + /** * Our [PeerId] which is normally derived from [privKey] */ val peerId: PeerId + /** * [Network] implementation */ val network: Network + /** * [AddressBook] implementation */ @@ -81,6 +84,8 @@ interface Host { */ fun addProtocolHandler(protocolBinding: ProtocolBinding) + fun getProtocols(): List> + /** * Removes the handler added with [addProtocolHandler] */ diff --git a/libp2p/src/main/kotlin/io/libp2p/core/Libp2pException.kt b/libp2p/src/main/kotlin/io/libp2p/core/Libp2pException.kt index 31bab6068..b25269f7c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/Libp2pException.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/Libp2pException.kt @@ -50,6 +50,7 @@ open class NoSuchProtocolException(message: String) : Libp2pException(message) * Indicates that the protocol is not registered at local configuration */ class NoSuchLocalProtocolException(message: String) : NoSuchProtocolException(message) + /** * Indicates that the protocol is not known by the remote party */ diff --git a/libp2p/src/main/kotlin/io/libp2p/core/crypto/Key.kt b/libp2p/src/main/kotlin/io/libp2p/core/crypto/Key.kt index 3324c510c..7d394966d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/crypto/Key.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/crypto/Key.kt @@ -31,7 +31,7 @@ import java.security.SecureRandom import crypto.pb.Crypto.PrivateKey as PbPrivateKey import crypto.pb.Crypto.PublicKey as PbPublicKey -enum class KEY_TYPE { +enum class KeyType { /** * RSA is an enum for the supported RSA key type @@ -56,7 +56,7 @@ enum class KEY_TYPE { interface Key { - val keyType: crypto.pb.Crypto.KeyType + val keyType: Crypto.KeyType /** * Bytes returns a serialized, storeable representation of this key. @@ -124,12 +124,12 @@ abstract class PubKey(override val keyType: Crypto.KeyType) : Key { * @param bits the number of bits desired for the key (only applicable for RSA). */ @JvmOverloads -fun generateKeyPair(type: KEY_TYPE, bits: Int = 2048, random: SecureRandom = SecureRandom()): Pair { +fun generateKeyPair(type: KeyType, bits: Int = 2048, random: SecureRandom = SecureRandom()): Pair { return when (type) { - KEY_TYPE.RSA -> generateRsaKeyPair(bits, random) - KEY_TYPE.ED25519 -> generateEd25519KeyPair(random) - KEY_TYPE.SECP256K1 -> generateSecp256k1KeyPair(random) - KEY_TYPE.ECDSA -> generateEcdsaKeyPair(random) + KeyType.RSA -> generateRsaKeyPair(bits, random) + KeyType.ED25519 -> generateEd25519KeyPair(random) + KeyType.SECP256K1 -> generateSecp256k1KeyPair(random) + KeyType.ECDSA -> generateEcdsaKeyPair(random) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt index 5482b5e8e..ce1416dfd 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt @@ -8,7 +8,7 @@ import io.libp2p.core.ConnectionHandler import io.libp2p.core.Host import io.libp2p.core.P2PChannel import io.libp2p.core.Stream -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.PrivKey import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multiformats.Multiaddr @@ -23,6 +23,7 @@ import io.libp2p.core.security.SecureChannel import io.libp2p.core.transport.Transport import io.libp2p.etc.types.lazyVar import io.libp2p.etc.types.toProtobuf +import io.libp2p.etc.util.netty.LoggingHandlerShort import io.libp2p.host.HostImpl import io.libp2p.host.MemoryAddressBook import io.libp2p.network.NetworkImpl @@ -33,6 +34,7 @@ import io.libp2p.transport.tcp.TcpTransport import io.netty.channel.ChannelHandler import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler +import java.util.concurrent.CopyOnWriteArrayList typealias TransportCtor = (ConnectionUpgrader) -> Transport typealias SecureChannelCtor = (PrivKey, List) -> SecureChannel @@ -173,7 +175,8 @@ open class Builder { } } - val muxers = muxers.map { it.createMuxer(streamMultistreamProtocol, protocols.values) } + val updatableProtocols: MutableList> = CopyOnWriteArrayList(protocols.values) + val muxers = muxers.map { it.createMuxer(streamMultistreamProtocol, updatableProtocols) } val secureChannels = secureChannels.values.map { it(privKey, muxers) } @@ -201,7 +204,7 @@ open class Builder { networkImpl, addressBook, network.listen.map { Multiaddr(it) }, - protocols.values, + updatableProtocols, broadcastConnHandler, streamVisitors ) @@ -217,8 +220,8 @@ class NetworkConfigBuilder { class IdentityBuilder { var factory: IdentityFactory? = null - fun random() = random(KEY_TYPE.ECDSA) - fun random(keyType: KEY_TYPE): IdentityBuilder = apply { factory = { generateKeyPair(keyType).first } } + fun random() = random(KeyType.ECDSA) + fun random(keyType: KeyType): IdentityBuilder = apply { factory = { generateKeyPair(keyType).first } } } class AddressBookBuilder { @@ -239,11 +242,13 @@ class DebugBuilder { * Could be primarily useful for security handshake debugging/monitoring */ val beforeSecureHandler = DebugHandlerBuilder("wire.sec.before") + /** * Injects the [ChannelHandler] right after the connection cipher * to handle plain wire messages */ val afterSecureHandler = DebugHandlerBuilder("wire.sec.after") + /** * Injects the [ChannelHandler] right after the [StreamMuxer] pipeline handler * It intercepts [io.libp2p.mux.MuxFrame] instances @@ -269,6 +274,10 @@ class DebugHandlerBuilder(var name: String) { fun addLogger(level: LogLevel, loggerName: String = name) { addNettyHandler(LoggingHandler(loggerName, level)) } + + fun addCompactLogger(level: LogLevel, loggerName: String = name) { + addNettyHandler(LoggingHandlerShort(loggerName, level)) + } } open class Enumeration(val values: MutableList = mutableListOf()) : MutableList by values { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt index d918c1bfb..8bc56d621 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt @@ -61,7 +61,11 @@ data class Multiaddr(val components: List) { * @throws IllegalArgumentException if existing component value doesn't match [value] */ private fun withComponentImpl(protocol: Protocol, value: ByteArray?): Multiaddr { - val existingComponent = getFirstComponent(protocol) + val existingComponent = if (has(Protocol.P2PCIRCUIT)) { + split { it == Protocol.P2PCIRCUIT }.get(1).getFirstComponent(protocol) + } else { + getFirstComponent(protocol) + } val newComponent = MultiaddrComponent(protocol, value) return if (existingComponent != null) { if (!existingComponent.value.contentEquals(value)) { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/MultiaddrDns.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/MultiaddrDns.kt index 294fc5a8a..2fe118b27 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/MultiaddrDns.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/MultiaddrDns.kt @@ -17,18 +17,20 @@ class MultiaddrDns { private val dnsProtocols = arrayOf(Protocol.DNS4, Protocol.DNS6, Protocol.DNSADDR) fun resolve(addr: Multiaddr, resolver: Resolver = DefaultResolver): List { - if (!addr.hasAny(*dnsProtocols)) + if (!addr.hasAny(*dnsProtocols)) { return listOf(addr) + } val addressesToResolve = addr.split { isDnsProtocol(it) } val resolvedAddresses = mutableListOf>() for (address in addressesToResolve) { val toResolve = address.filterComponents(*dnsProtocols).firstOrNull() - val resolved = if (toResolve != null) + val resolved = if (toResolve != null) { resolve(toResolve.protocol, toResolve.stringValue!!, address, resolver) - else + } else { listOf(address) + } resolvedAddresses.add(resolved) } @@ -70,13 +72,14 @@ class MultiaddrDns { // * /ip4/1.1.1.2/p2p-circuit/ip4/2.1.1.1 // * /ip4/1.1.1.2/p2p-circuit/ip4/2.1.1.2 private fun crossProduct(addressMatrix: List>): List { - return if (addressMatrix.size == 1) + return if (addressMatrix.size == 1) { addressMatrix[0] - else + } else { addressMatrix[0].flatMap { parent -> crossProduct(addressMatrix.subList(1, addressMatrix.size)) .map { child -> parent.concatenated(child) } } + } } private fun isDnsProtocol(proto: Protocol): Boolean { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multihash.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multihash.kt index ad3645506..b09c5fa88 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multihash.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multihash.kt @@ -112,6 +112,7 @@ class Multihash(val bytes: ByteBuf, val desc: Descriptor, val lengthBits: Int, v return Multihash(this, desc, lengthBits, digest) } } + @JvmStatic fun digest(desc: Descriptor, content: ByteBuf, lengthBits: Int? = null): Multihash { val entry = REGISTRY[desc] ?: throw InvalidMultihashException("Unrecognised multihash descriptor") diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt index f097591a1..5d171b811 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt @@ -115,8 +115,9 @@ enum class Protocol( private const val LENGTH_PREFIXED_VAR_SIZE = -1 private val SIZE_VALIDATOR: (Protocol, ByteArray?) -> Unit = { protocol, bytes -> - if (!protocol.hasValue && bytes != null) + if (!protocol.hasValue && bytes != null) { throw IllegalArgumentException("No value expected for protocol $protocol, but got ${bytes.contentToString()}") + } if (protocol.hasValue) { requireNotNull(bytes) { "Non-null value expected for protocol $protocol" } if (protocol.sizeBits != LENGTH_PREFIXED_VAR_SIZE && bytes.size * 8 != protocol.sizeBits) { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multistream/Multistream.kt b/libp2p/src/main/kotlin/io/libp2p/core/multistream/Multistream.kt index 2a7101334..1851086aa 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multistream/Multistream.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multistream/Multistream.kt @@ -25,5 +25,5 @@ interface Multistream : P2PChannelHandler { * For _initiator_ role this is the list of protocols the initiator wants to instantiate. * Basically this is either a single protocol or a protocol versions */ - val bindings: MutableList> + val bindings: List> } diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multistream/NegotiatedProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/multistream/NegotiatedProtocol.kt index e0b04d655..294fd1f6e 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multistream/NegotiatedProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multistream/NegotiatedProtocol.kt @@ -6,7 +6,7 @@ import java.util.concurrent.CompletableFuture /** * Represents [ProtocolBinding] with exact protocol version which was agreed on */ -open class NegotiatedProtocol> ( +open class NegotiatedProtocol>( val binding: TBinding, val protocol: ProtocolId ) { diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multistream/ProtocolMatcher.kt b/libp2p/src/main/kotlin/io/libp2p/core/multistream/ProtocolMatcher.kt index 7e8448c38..8f46d6144 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multistream/ProtocolMatcher.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multistream/ProtocolMatcher.kt @@ -15,10 +15,12 @@ interface ProtocolMatcher { fun strict(protocol: ProtocolId) = object : ProtocolMatcher { override fun matches(proposed: ProtocolId) = protocol == proposed } + @JvmStatic fun prefix(protocolPrefix: String) = object : ProtocolMatcher { override fun matches(proposed: ProtocolId) = proposed.startsWith(protocolPrefix) } + @JvmStatic fun list(protocols: Collection) = object : ProtocolMatcher { override fun matches(proposed: ProtocolId) = proposed in protocols diff --git a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt index 879cd60cd..878e74d05 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt @@ -3,6 +3,8 @@ package io.libp2p.core.mux import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.multistream.ProtocolBinding import io.libp2p.mux.mplex.MplexStreamMuxer +import io.libp2p.mux.yamux.DEFAULT_ACK_BACKLOG_LIMIT +import io.libp2p.mux.yamux.DEFAULT_MAX_BUFFERED_CONNECTION_WRITES import io.libp2p.mux.yamux.YamuxStreamMuxer fun interface StreamMuxerProtocol { @@ -20,14 +22,26 @@ fun interface StreamMuxerProtocol { ) } + /** + * @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection + * @param ackBacklogLimit the maximum amount of opened streams per connection which have not been acknowledged + */ @JvmStatic - val Yamux = StreamMuxerProtocol { multistreamProtocol, protocols -> - YamuxStreamMuxer( - multistreamProtocol.createMultistream( - protocols - ).toStreamHandler(), - multistreamProtocol - ) + @JvmOverloads + fun getYamux( + maxBufferedConnectionWrites: Int = DEFAULT_MAX_BUFFERED_CONNECTION_WRITES, + ackBacklogLimit: Int = DEFAULT_ACK_BACKLOG_LIMIT + ): StreamMuxerProtocol { + return StreamMuxerProtocol { multistreamProtocol, protocols -> + YamuxStreamMuxer( + multistreamProtocol.createMultistream( + protocols + ).toStreamHandler(), + multistreamProtocol, + maxBufferedConnectionWrites, + ackBacklogLimit + ) + } } } } diff --git a/libp2p/src/main/kotlin/io/libp2p/core/pubsub/PubsubApi.kt b/libp2p/src/main/kotlin/io/libp2p/core/pubsub/PubsubApi.kt index 0ecb451d4..21b8dbffe 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/pubsub/PubsubApi.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/pubsub/PubsubApi.kt @@ -173,14 +173,17 @@ interface MessageApi { * Message body */ val data: ByteBuf + /** * Sender identity. Usually it a [PeerId] derived from the sender's public key */ val from: ByteArray? + /** * Sequence id for the sender. A pair [from]` + `[seqId] should be globally unique */ val seqId: Long? + /** * A set of message topics */ diff --git a/libp2p/src/main/kotlin/io/libp2p/crypto/Libp2pCrypto.kt b/libp2p/src/main/kotlin/io/libp2p/crypto/Libp2pCrypto.kt index daa6ff4d2..7cb7b0e71 100644 --- a/libp2p/src/main/kotlin/io/libp2p/crypto/Libp2pCrypto.kt +++ b/libp2p/src/main/kotlin/io/libp2p/crypto/Libp2pCrypto.kt @@ -19,11 +19,11 @@ import org.bouncycastle.crypto.macs.HMac import org.bouncycastle.crypto.params.KeyParameter /** - * ErrRsaKeyTooSmall is returned when trying to generate or parse an RSA key + * ERR_RSA_KEY_TOO_SMALL is returned when trying to generate or parse an RSA key * that's smaller than 512 bits. Keys need to be larger enough to sign a 256bit * hash so this is a reasonable absolute minimum. */ -const val ErrRsaKeyTooSmall = "rsa keys must be >= 512 bits to be useful" +const val ERR_RSA_KEY_TOO_SMALL = "rsa keys must be >= 512 bits to be useful" const val RSA_ALGORITHM = "RSA" const val SHA_ALGORITHM = "SHA-256" diff --git a/libp2p/src/main/kotlin/io/libp2p/crypto/keys/Rsa.kt b/libp2p/src/main/kotlin/io/libp2p/crypto/keys/Rsa.kt index 4e0e14f47..752a23cff 100644 --- a/libp2p/src/main/kotlin/io/libp2p/crypto/keys/Rsa.kt +++ b/libp2p/src/main/kotlin/io/libp2p/crypto/keys/Rsa.kt @@ -16,7 +16,7 @@ import crypto.pb.Crypto import io.libp2p.core.Libp2pException import io.libp2p.core.crypto.PrivKey import io.libp2p.core.crypto.PubKey -import io.libp2p.crypto.ErrRsaKeyTooSmall +import io.libp2p.crypto.ERR_RSA_KEY_TOO_SMALL import io.libp2p.crypto.KEY_PKCS8 import io.libp2p.crypto.Libp2pCrypto import io.libp2p.crypto.RSA_ALGORITHM @@ -100,7 +100,7 @@ class RsaPublicKey(private val k: JavaPublicKey) : PubKey(Crypto.KeyType.RSA) { @JvmOverloads fun generateRsaKeyPair(bits: Int, random: SecureRandom = SecureRandom()): Pair { if (bits < 2048) { - throw Libp2pException(ErrRsaKeyTooSmall) + throw Libp2pException(ERR_RSA_KEY_TOO_SMALL) } val kp: KeyPair = with( diff --git a/libp2p/src/main/kotlin/io/libp2p/discovery/MDnsDiscovery.kt b/libp2p/src/main/kotlin/io/libp2p/discovery/MDnsDiscovery.kt index 585135a58..8c7b19183 100644 --- a/libp2p/src/main/kotlin/io/libp2p/discovery/MDnsDiscovery.kt +++ b/libp2p/src/main/kotlin/io/libp2p/discovery/MDnsDiscovery.kt @@ -1,23 +1,21 @@ package io.libp2p.discovery -import io.libp2p.core.Discoverer -import io.libp2p.core.Host -import io.libp2p.core.PeerId -import io.libp2p.core.PeerInfo -import io.libp2p.core.PeerListener +import io.libp2p.core.* import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.MultiaddrComponent import io.libp2p.core.multiformats.Protocol import io.libp2p.discovery.mdns.AnswerListener import io.libp2p.discovery.mdns.JmDNS import io.libp2p.discovery.mdns.ServiceInfo import io.libp2p.discovery.mdns.impl.DNSRecord import io.libp2p.discovery.mdns.impl.constants.DNSRecordType -import java.net.Inet4Address -import java.net.Inet6Address -import java.net.InetAddress +import java.net.* +import java.util.* import java.util.concurrent.CompletableFuture import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.ForkJoinPool +import java.util.stream.Collectors +import java.util.stream.Stream class MDnsDiscovery( private val host: Host, @@ -61,6 +59,10 @@ class MDnsDiscovery( newPeerFoundListeners.forEach { it(peerInfo) } } + fun addHandler(h: PeerListener) { + newPeerFoundListeners += h + } + private fun ipfsDiscoveryInfo(): ServiceInfo { return ServiceInfo.create( serviceTag, @@ -76,15 +78,91 @@ class MDnsDiscovery( val address = host.listenAddresses().find { it.has(Protocol.IP4) } - val str = address?.getFirstComponent(Protocol.TCP)?.stringValue!! + val ipv6OnlyAddress = if (address == null) { + host.listenAddresses().find { + it.has(Protocol.IP6) + } + } else { + address + } + val str = ipv6OnlyAddress?.getFirstComponent(Protocol.TCP)?.stringValue!! return Integer.parseInt(str) } + /* /ip6/::/tcp/4001 should expand to the following for example: + "/ip6/0:0:0:0:0:0:0:1/udp/4001/quic" + "/ip4/50.116.48.246/tcp/4001" + "/ip4/127.0.0.1/tcp/4001" + "/ip6/2600:3c03:0:0:f03c:92ff:fee7:bc1c/tcp/4001" + "/ip6/0:0:0:0:0:0:0:1/tcp/4001" + "/ip4/50.116.48.246/udp/4001/quic" + "/ip4/127.0.0.1/udp/4001/quic" + "/ip6/2600:3c03:0:0:f03c:92ff:fee7:bc1c/udp/4001/quic" + */ + fun expandWildcardAddresses(addr: Multiaddr): List { + // Do not include /p2p or /ipfs components which are superfluous here + if (!isWildcard(addr)) { + return java.util.List.of( + Multiaddr( + addr.components + .stream() + .filter { c: MultiaddrComponent -> + ( + c.protocol !== Protocol.P2P && + c.protocol !== Protocol.IPFS + ) + } + .collect(Collectors.toList()) + ) + ) + } + if (addr.has(Protocol.IP4)) return listNetworkAddresses(false, addr) + return if (addr.has(Protocol.IP6)) listNetworkAddresses(true, addr) else emptyList() + } + + fun listNetworkAddresses(includeIp6: Boolean, addr: Multiaddr): List { + return try { + Collections.list(NetworkInterface.getNetworkInterfaces()).stream() + .flatMap { net: NetworkInterface -> + net.interfaceAddresses.stream() + .map { obj: InterfaceAddress -> obj.address } + .filter { ip: InetAddress? -> includeIp6 || ip is Inet4Address } + } + .map { ip: InetAddress -> + Multiaddr( + Stream.concat( + Stream.of( + MultiaddrComponent( + if (ip is Inet4Address) Protocol.IP4 else Protocol.IP6, + ip.address + ) + ), + addr.components.stream() + .filter { c: MultiaddrComponent -> + c.protocol !== Protocol.IP4 && c.protocol !== Protocol.IP6 && c.protocol !== Protocol.P2P && c.protocol !== Protocol.IPFS + } + ) + .collect(Collectors.toList()) + ) + } + .collect(Collectors.toList()) + } catch (e: SocketException) { + throw RuntimeException(e) + } + } + + fun isWildcard(addr: Multiaddr): Boolean { + val s = addr.toString() + return s.contains("/::/") || s.contains("/0:0:0:0/") + } + private fun ip4Addresses() = ipAddresses(Protocol.IP4, Inet4Address::class.java) private fun ip6Addresses() = ipAddresses(Protocol.IP6, Inet6Address::class.java) private fun ipAddresses(protocol: Protocol, klass: Class): List { - return host.listenAddresses().map { + return host.listenAddresses().flatMap { + expandWildcardAddresses(it) + }.map { it.getFirstComponent(protocol) }.filterNotNull().map { InetAddress.getByAddress(localhost.hostName, it.value) @@ -105,8 +183,9 @@ class MDnsDiscovery( val aRecords = answers.filter { DNSRecordType.TYPE_A.equals(it.recordType) } val aaaaRecords = answers.filter { DNSRecordType.TYPE_AAAA.equals(it.recordType) } - if (txtRecord == null || srvRecord == null || aRecords.isEmpty()) + if (txtRecord == null || srvRecord == null || (aRecords.isEmpty() && aaaaRecords.isEmpty())) { return // incomplete answers + } txtRecord as DNSRecord.Text srvRecord as DNSRecord.Service diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt index 0e24eb4d4..176650fd3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt @@ -19,6 +19,14 @@ fun CompletableFuture.bind(result: CompletableFuture) { fun CompletableFuture.forward(forwardTo: CompletableFuture) = forwardTo.bind(this) +fun CompletableFuture.forwardException(forwardTo: CompletableFuture): CompletableFuture { + return whenComplete { _, t -> + if (t != null) { + forwardTo.completeExceptionally(t) + } + } +} + /** * The same as [CompletableFuture.get] but unwraps [ExecutionException] */ @@ -61,16 +69,19 @@ fun anyComplete(all: List>): CompletableFuture = anyComplete(*all.toTypedArray()) fun anyComplete(vararg all: CompletableFuture): CompletableFuture { - return if (all.isEmpty()) completedExceptionally(NothingToCompleteException()) - else object : CompletableFuture() { - init { - val counter = AtomicInteger(all.size) - all.forEach { - it.whenComplete { v, t -> - if (t == null) { - complete(v) - } else if (counter.decrementAndGet() == 0) { - completeExceptionally(NonCompleteException(t)) + return if (all.isEmpty()) { + completedExceptionally(NothingToCompleteException()) + } else { + object : CompletableFuture() { + init { + val counter = AtomicInteger(all.size) + all.forEach { + it.whenComplete { v, t -> + if (t == null) { + complete(v) + } else if (counter.decrementAndGet() == 0) { + completeExceptionally(NonCompleteException(t)) + } } } } diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt index cd1515c6e..8bcf79edc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt @@ -13,7 +13,7 @@ fun String.fromHex() = operator fun ByteArray.compareTo(other: ByteArray): Int { if (size != other.size) return size - other.size - for (i in 0 until size) { + for (i in indices) { if (this[i] != other[i]) return this[i].toInt().and(0xFF) - other[i].toInt().and(0xFF) } return 0 diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt index 9c9e3a77a..a67a3c864 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt @@ -1,5 +1,6 @@ package io.libp2p.etc.types +import kotlin.properties.Delegates import kotlin.properties.ReadWriteProperty import kotlin.reflect.KProperty @@ -39,18 +40,20 @@ fun cappedDouble( // thanks to https://stackoverflow.com/a/47948047/9630725 class LazyMutable(val initializer: () -> T, val rejectSetAfterGet: Boolean = false) : ReadWriteProperty { - private object UNINITIALIZED_VALUE - private var prop: Any? = UNINITIALIZED_VALUE + private object UninitializedValue + private var prop: Any? = UninitializedValue private var readAccessed = false @Suppress("UNCHECKED_CAST") override fun getValue(thisRef: Any?, property: KProperty<*>): T { - return if (prop == UNINITIALIZED_VALUE) { + return if (prop == UninitializedValue) { synchronized(this) { readAccessed = true - return if (prop == UNINITIALIZED_VALUE) initializer().also { prop = it } else prop as T + return if (prop == UninitializedValue) initializer().also { prop = it } else prop as T } - } else prop as T + } else { + prop as T + } } override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { @@ -90,3 +93,18 @@ data class CappedValueDelegate>( } } } + +fun Delegates.writeOnce(initialValue: T): ReadWriteProperty = object : ReadWriteProperty { + private var value: T = initialValue + private var wasSet = false + + public override fun getValue(thisRef: Any?, property: KProperty<*>): T { + return value + } + + public override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { + if (wasSet) throw IllegalStateException("Property ${property.name} cannot be set more than once.") + this.value = value + wasSet = true + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/MutableBiMultiMap.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/MutableBiMultiMap.kt index 59d86193b..0d34d0f75 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/MutableBiMultiMap.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/MutableBiMultiMap.kt @@ -55,6 +55,7 @@ operator fun MutableBiMultiMap.minusAssign(key: K) = removeKey(key) internal class MutableBiMultiMapImpl : MutableBiMultiMap { @VisibleForTesting internal val keyToValue: MutableMap = mutableMapOf() + @VisibleForTesting internal val valueToKeys: MutableMap> = mutableMapOf() diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/OtherExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/OtherExt.kt index c62262103..6ea486f61 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/OtherExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/OtherExt.kt @@ -18,7 +18,7 @@ class Deferrable { } fun execute() { - actions.reversed().forEach { + actions.asReversed().forEach { try { it() } catch (e: Exception) { diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/P2PService.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/P2PService.kt index 13c2d3f38..9ad2bdd71 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/P2PService.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/P2PService.kt @@ -34,6 +34,7 @@ abstract class P2PService( ) { private val peersMutable = mutableListOf() + /** * List of connected peers. * Note that connected peer could not be ready for writing yet, so consider [activePeers] @@ -42,6 +43,7 @@ abstract class P2PService( val peers: List = peersMutable private val activePeersMutable = mutableListOf() + /** * List of active peers to which data could be written */ @@ -241,6 +243,7 @@ abstract class P2PService( * Executes the code on the service event thread */ fun submitOnEventThread(run: () -> C): CompletableFuture = CompletableFuture.supplyAsync({ run() }, executor) + /** * Executes the code on the service event thread */ diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/AbstractChildChannel.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/AbstractChildChannel.kt index d914e2a48..260606240 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/AbstractChildChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/AbstractChildChannel.kt @@ -20,7 +20,10 @@ import java.net.SocketAddress */ abstract class AbstractChildChannel(parent: Channel, id: ChannelId?) : AbstractChannel(parent, id) { private enum class State { - OPEN, ACTIVE, INACTIVE, CLOSED + OPEN, + ACTIVE, + INACTIVE, + CLOSED } private val parentCloseFuture = parent.closeFuture() diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt new file mode 100644 index 000000000..71899faaf --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt @@ -0,0 +1,44 @@ +package io.libp2p.etc.util.netty + +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled + +class ByteBufQueue { + private val data: MutableList = mutableListOf() + + fun push(buf: ByteBuf) { + data += buf + } + + fun take(maxLength: Int): ByteBuf { + val wholeBuffers = mutableListOf() + var size = 0 + while (data.isNotEmpty()) { + val bufLen = data.first().readableBytes() + if (size + bufLen > maxLength) break + size += bufLen + wholeBuffers += data.removeAt(0) + if (size == maxLength) break + } + + val partialBufferSlice = + when { + data.isEmpty() -> null + size == maxLength -> null + else -> data.first() + } + ?.let { buf -> + val remainingBytes = maxLength - size + buf.readRetainedSlice(remainingBytes) + } + + val allBuffers = wholeBuffers + listOfNotNull(partialBufferSlice) + return Unpooled.wrappedBuffer(*allBuffers.toTypedArray()) + } + + fun dispose() { + data.forEach { it.release() } + } + + fun readableBytes(): Int = data.sumOf { it.readableBytes() } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt index f50c3a088..a5c49b175 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt @@ -9,7 +9,6 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter import org.slf4j.LoggerFactory import java.util.concurrent.CompletableFuture -import java.util.function.Function typealias MuxChannelInitializer = (MuxChannel) -> Unit @@ -61,10 +60,12 @@ abstract class AbstractMuxHandler() : releaseMessage(msg) throw ConnectionClosedException("Channel with id $id not opened") } + child.remoteDisconnected -> { releaseMessage(msg) throw ConnectionClosedException("Channel with id $id was closed for sending by remote") } + else -> { pendingReadComplete += id child.pipeline().fireChannelRead(msg) @@ -87,6 +88,10 @@ abstract class AbstractMuxHandler() : protected fun onRemoteOpen(id: MuxId) { val initializer = inboundInitializer + if (id in streamMap) { + getChannelHandlerContext().close() + throw Libp2pException("Remote party attempts to open a stream with existing id: $id") + } val child = createChild( id, initializer, @@ -132,7 +137,7 @@ abstract class AbstractMuxHandler() : ): MuxChannel { val child = MuxChannel(this, id, initializer, initiator) streamMap[id] = child - ctx!!.channel().eventLoop().register(child) + ctx!!.channel().eventLoop().register(child).sync() return child } @@ -144,7 +149,7 @@ abstract class AbstractMuxHandler() : try { checkClosed() // if already closed then event loop is already down and async task may never execute return activeFuture.thenApplyAsync( - Function { + { checkClosed() // close may happen after above check and before this point val child = createChild( generateNextId(), diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt index 0d7051d93..e619128bc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt @@ -2,9 +2,14 @@ package io.libp2p.etc.util.netty.mux import io.netty.channel.ChannelId -data class MuxId(val parentId: ChannelId, val id: Long, val initiator: Boolean) : ChannelId { - override fun asShortText() = "$parentId/$id/$initiator" - override fun asLongText() = asShortText() +abstract class MuxId( + val parentId: ChannelId, + val id: Long +) : ChannelId { + override fun compareTo(other: ChannelId?) = asShortText().compareTo(other?.asShortText() ?: "") override fun toString() = asLongText() + + abstract override fun hashCode(): Int + abstract override fun equals(other: Any?): Boolean } diff --git a/libp2p/src/main/kotlin/io/libp2p/host/HostImpl.kt b/libp2p/src/main/kotlin/io/libp2p/host/HostImpl.kt index 328f6018e..40df4e491 100644 --- a/libp2p/src/main/kotlin/io/libp2p/host/HostImpl.kt +++ b/libp2p/src/main/kotlin/io/libp2p/host/HostImpl.kt @@ -78,6 +78,10 @@ class HostImpl( protocolHandlers -= protocolBinding } + override fun getProtocols(): List> { + return protocolHandlers + } + override fun addConnectionHandler(handler: ConnectionHandler) { connectionHandlers += handler } diff --git a/libp2p/src/main/kotlin/io/libp2p/multistream/MultistreamImpl.kt b/libp2p/src/main/kotlin/io/libp2p/multistream/MultistreamImpl.kt index 4b7fc626d..8cf242f68 100644 --- a/libp2p/src/main/kotlin/io/libp2p/multistream/MultistreamImpl.kt +++ b/libp2p/src/main/kotlin/io/libp2p/multistream/MultistreamImpl.kt @@ -6,18 +6,14 @@ import io.libp2p.core.multistream.Multistream import io.libp2p.core.multistream.ProtocolBinding import java.time.Duration import java.util.concurrent.CompletableFuture -import java.util.concurrent.CopyOnWriteArrayList class MultistreamImpl( - initList: List> = listOf(), + override val bindings: List>, val preHandler: P2PChannelHandler<*>? = null, val postHandler: P2PChannelHandler<*>? = null, val negotiationTimeLimit: Duration = DEFAULT_NEGOTIATION_TIME_LIMIT ) : Multistream { - override val bindings: MutableList> = - CopyOnWriteArrayList(initList) - override fun initChannel(ch: P2PChannel): CompletableFuture { return with(ch) { preHandler?.also { diff --git a/libp2p/src/main/kotlin/io/libp2p/multistream/Negotiator.kt b/libp2p/src/main/kotlin/io/libp2p/multistream/Negotiator.kt index 988725dd3..ade6c8e34 100644 --- a/libp2p/src/main/kotlin/io/libp2p/multistream/Negotiator.kt +++ b/libp2p/src/main/kotlin/io/libp2p/multistream/Negotiator.kt @@ -94,8 +94,11 @@ object Negotiator { override fun channelRead0(ctx: ChannelHandlerContext, msg: String) { if (msg == MULTISTREAM_PROTO) { - if (!headerRead) headerRead = true else + if (!headerRead) { + headerRead = true + } else { throw ProtocolNegotiationException("Received multistream header more than once") + } } else { processMsg(ctx, msg)?.also { completeEvent -> // first fire event to setup a handler for selected protocol @@ -103,7 +106,7 @@ object Negotiator { ctx.pipeline().remove(this@GenericHandler) // DelimiterBasedFrameDecoder should be removed last since it // propagates unhandled bytes on removal - prehandlers.reversed().forEach { ctx.pipeline().remove(it) } + prehandlers.asReversed().forEach { ctx.pipeline().remove(it) } // activate a handler for selected protocol ctx.fireChannelActive() } diff --git a/libp2p/src/main/kotlin/io/libp2p/multistream/ProtocolSelect.kt b/libp2p/src/main/kotlin/io/libp2p/multistream/ProtocolSelect.kt index 332d8f2b6..9fef1f3f2 100644 --- a/libp2p/src/main/kotlin/io/libp2p/multistream/ProtocolSelect.kt +++ b/libp2p/src/main/kotlin/io/libp2p/multistream/ProtocolSelect.kt @@ -48,7 +48,8 @@ class ProtocolSelect(val protocols: List() val stream = newStream { streamHandler.handleStream(createStream(it)).forward(controller) - }.thenApply { it.attr(STREAM).get() } + } + .thenApply { it.attr(STREAM).get() } + .forwardException(controller) return StreamPromise(stream, controller) } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt new file mode 100644 index 000000000..b424d7caa --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt @@ -0,0 +1,19 @@ +package io.libp2p.mux + +import io.libp2p.core.Libp2pException +import io.libp2p.etc.util.netty.mux.MuxId + +open class MuxerException(message: String, ex: Exception?) : Libp2pException(message, ex) + +class AckBacklogLimitExceededMuxerException(message: String) : MuxerException(message, null) + +open class ReadMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) +open class WriteMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) + +class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream with id $muxId not found", null) + +class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null) + +class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null) +class ClosedForWritingMuxerException(muxId: MuxId) : + WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt index fb4e11bc0..e454a00bf 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt @@ -25,16 +25,18 @@ import io.netty.buffer.Unpooled * @param data the data segment. * @see [mplex documentation](https://github.com/libp2p/specs/tree/master/mplex#opening-a-new-stream) */ -data class MplexFrame(val id: MuxId, val flag: MplexFlag, val data: ByteBuf) : DefaultByteBufHolder(data) { +data class MplexFrame(val id: MplexId, val flag: MplexFlag, val data: ByteBuf) : DefaultByteBufHolder(data) { companion object { + private fun createFrame(id: MuxId, type: MplexFlag.Type, data: ByteBuf) = + MplexFrame(id as MplexId, MplexFlag.getByType(type, id.initiator), data) fun createDataFrame(id: MuxId, data: ByteBuf) = - MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.DATA, id.initiator), data) + createFrame(id, MplexFlag.Type.DATA, data) fun createOpenFrame(id: MuxId) = - MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.OPEN, id.initiator), Unpooled.EMPTY_BUFFER) + createFrame(id, MplexFlag.Type.OPEN, Unpooled.EMPTY_BUFFER) fun createCloseFrame(id: MuxId) = - MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.CLOSE, id.initiator), Unpooled.EMPTY_BUFFER) + createFrame(id, MplexFlag.Type.CLOSE, Unpooled.EMPTY_BUFFER) fun createResetFrame(id: MuxId) = - MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.RESET, id.initiator), Unpooled.EMPTY_BUFFER) + createFrame(id, MplexFlag.Type.RESET, Unpooled.EMPTY_BUFFER) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt index 9abe21ed8..a429f6360 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt @@ -15,7 +15,6 @@ package io.libp2p.mux.mplex import io.libp2p.core.ProtocolViolationException import io.libp2p.etc.types.readUvarint import io.libp2p.etc.types.writeUvarint -import io.libp2p.etc.util.netty.mux.MuxId import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.ByteToMessageCodec @@ -75,7 +74,7 @@ class MplexFrameCodec( val data = msg.readSlice(lenData.toInt()) data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed val flag = MplexFlag.getByValue(streamTag) - val mplexFrame = MplexFrame(MuxId(ctx.channel().id(), streamId, !flag.isInitiator), flag, data) + val mplexFrame = MplexFrame(MplexId(ctx.channel().id(), streamId, !flag.isInitiator), flag, data) out.add(mplexFrame) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt index b87bdd8e6..4e061cbab 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt @@ -5,7 +5,6 @@ import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize import io.libp2p.etc.util.netty.mux.MuxChannel -import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext @@ -22,7 +21,7 @@ open class MplexHandler( private val idGenerator = AtomicLong(0xF) override fun generateNextId() = - MuxId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true) + MplexId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true) override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { msg as MplexFrame diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexId.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexId.kt new file mode 100644 index 000000000..c69040f50 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexId.kt @@ -0,0 +1,27 @@ +package io.libp2p.mux.mplex + +import io.libp2p.etc.util.netty.mux.MuxId +import io.netty.channel.ChannelId + +class MplexId( + parentId: ChannelId, + id: Long, + val initiator: Boolean +) : MuxId(parentId, id) { + + override fun asShortText() = "${parentId.asShortText()}/$id/$initiator" + override fun asLongText() = "${parentId.asLongText()}/$id/$initiator" + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as MplexId + return id == other.id && initiator == other.initiator + } + + override fun hashCode(): Int { + var result = id.hashCode() + result = 31 * result + initiator.hashCode() + return result + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt new file mode 100644 index 000000000..34f9a10d2 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt @@ -0,0 +1,34 @@ +package io.libp2p.mux.yamux + +import io.libp2p.mux.InvalidFrameMuxerException + +/** + * Contains all the permissible values for flags in the yamux protocol. + */ +enum class YamuxFlag(val intFlag: Int) { + SYN(1), + ACK(2), + FIN(4), + RST(8); + + val asSet: Set = setOf(this) + + companion object { + val NONE = emptySet() + + private val validFlagCombinations = mapOf( + 0 to NONE, + SYN.intFlag to SYN.asSet, + ACK.intFlag to ACK.asSet, + FIN.intFlag to FIN.asSet, + RST.intFlag to RST.asSet, + ) + + fun fromInt(flags: Int) = + validFlagCombinations[flags] ?: throw InvalidFrameMuxerException("Invalid Yamux flags value: $flags") + + fun Set.toInt() = this + .fold(0) { acc, flag -> acc or flag.intFlag } + .also { require(it in validFlagCombinations) { "Invalid Yamux flags combination: $this" } } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt deleted file mode 100644 index 85499d0dd..000000000 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt +++ /dev/null @@ -1,11 +0,0 @@ -package io.libp2p.mux.yamux - -/** - * Contains all the permissible values for flags in the yamux protocol. - */ -object YamuxFlags { - const val SYN = 1 - const val ACK = 2 - const val FIN = 4 - const val RST = 8 -} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt index fefdf1aee..c35dcea88 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt @@ -1,6 +1,5 @@ package io.libp2p.mux.yamux -import io.libp2p.etc.types.toByteArray import io.libp2p.etc.util.netty.mux.MuxId import io.netty.buffer.ByteBuf import io.netty.buffer.DefaultByteBufHolder @@ -8,16 +7,17 @@ import io.netty.buffer.Unpooled /** * Contains the fields that comprise a yamux frame. - * @param streamId the ID of the stream. - * @param flag the flag value for this frame. + * @param id the ID of the stream. + * @param flags the flags for this frame. + * @param length the length field for this frame. * @param data the data segment. */ -class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val lenData: Long, val data: ByteBuf? = null) : +class YamuxFrame(val id: MuxId, val type: YamuxType, val flags: Set, val length: Long, val data: ByteBuf? = null) : DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) { override fun toString(): String { - if (data == null) - return "YamuxFrame(id=$id, type=$type, flag=$flags)" - return "YamuxFrame(id=$id, type=$type, flag=$flags, data=${String(data.toByteArray())})" + val dataString = if (data == null) "" else ", len=${data.readableBytes()}, $data" + val flagsString = flags.joinToString("+") + return "YamuxFrame(id=$id, type=$type, flags=$flagsString, length=$length$dataString)" } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt index d21fb2d4f..f2db941ec 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt @@ -1,7 +1,7 @@ package io.libp2p.mux.yamux import io.libp2p.core.ProtocolViolationException -import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.yamux.YamuxFlag.Companion.toInt import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext @@ -13,7 +13,6 @@ const val DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH = 1 shl 20 * A Netty codec implementation that converts [YamuxFrame] instances to [ByteBuf] and vice-versa. */ class YamuxFrameCodec( - val isInitiator: Boolean, val maxFrameDataLength: Int = DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH ) : ByteToMessageCodec() { @@ -26,10 +25,10 @@ class YamuxFrameCodec( */ override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) { out.writeByte(0) // version - out.writeByte(msg.type) - out.writeShort(msg.flags) + out.writeByte(msg.type.intValue) + out.writeShort(msg.flags.toInt()) out.writeInt(msg.id.id.toInt()) - out.writeInt(msg.data?.readableBytes() ?: msg.lenData.toInt()) + out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt()) out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER) } @@ -42,32 +41,47 @@ class YamuxFrameCodec( */ override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList) { while (msg.isReadable) { - if (msg.readableBytes() < 12) + if (msg.readableBytes() < 12) { return + } val readerIndex = msg.readerIndex() msg.readByte(); // version always 0 val type = msg.readUnsignedByte() + val yamuxType = YamuxType.fromInt(type.toInt()) val flags = msg.readUnsignedShort() val streamId = msg.readUnsignedInt() - val lenData = msg.readUnsignedInt() - if (type.toInt() != YamuxType.DATA) { - val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData) + val length = msg.readUnsignedInt() + val yamuxId = YamuxId(ctx.channel().id(), streamId) + val yamuxFlags = YamuxFlag.fromInt(flags) + if (yamuxType != YamuxType.DATA) { + val yamuxFrame = YamuxFrame( + yamuxId, + yamuxType, + yamuxFlags, + length + ) out.add(yamuxFrame) continue } - if (lenData > maxFrameDataLength) { + if (length > maxFrameDataLength) { msg.skipBytes(msg.readableBytes()) - throw ProtocolViolationException("Yamux frame is too large: $lenData") + throw ProtocolViolationException("Yamux frame is too large: $length") } - if (msg.readableBytes() < lenData) { + if (msg.readableBytes() < length) { // not enough data to read the frame content // will wait for more ... msg.readerIndex(readerIndex) return } - val data = msg.readSlice(lenData.toInt()) + val data = msg.readSlice(length.toInt()) data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed - val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData, data) + val yamuxFrame = YamuxFrame( + yamuxId, + yamuxType, + yamuxFlags, + length, + data + ) out.add(yamuxFrame) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index b92538eeb..bdde3478b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -1,184 +1,303 @@ package io.libp2p.mux.yamux +import io.libp2p.core.ConnectionClosedException import io.libp2p.core.Libp2pException import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.types.writeOnce +import io.libp2p.etc.util.netty.ByteBufQueue import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId -import io.libp2p.mux.MuxHandler +import io.libp2p.mux.* import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger +import kotlin.math.max +import kotlin.properties.Delegates + +const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB +const val DEFAULT_ACK_BACKLOG_LIMIT = 256 const val INITIAL_WINDOW_SIZE = 256 * 1024 -const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024 open class YamuxHandler( override val multistreamProtocol: MultistreamProtocol, override val maxFrameDataLength: Int, ready: CompletableFuture?, inboundStreamHandler: StreamHandler<*>, - initiator: Boolean + private val connectionInitiator: Boolean, + private val maxBufferedConnectionWrites: Int, + private val ackBacklogLimit: Int, + private val initialWindowSize: Int = INITIAL_WINDOW_SIZE ) : MuxHandler(ready, inboundStreamHandler) { - private val idGenerator = AtomicInteger(if (initiator) 1 else 2) // 0 is reserved - private val receiveWindows = ConcurrentHashMap() - private val sendWindows = ConcurrentHashMap() - private val sendBuffers = ConcurrentHashMap() - private val totalBufferedWrites = AtomicInteger() - - inner class SendBuffer(val ctx: ChannelHandlerContext) { - private val buffered = ArrayDeque() - - fun add(data: ByteBuf) { - buffered.add(data) - } - - fun flush(sendWindow: AtomicInteger, id: MuxId): Int { - var written = 0 - while (! buffered.isEmpty()) { - val buf = buffered.first() - if (buf.readableBytes() + written < sendWindow.get()) { - buffered.removeFirst() - sendBlocks(ctx, buf, sendWindow, id) - written += buf.readableBytes() - } else - break + + private inner class YamuxStreamHandler( + val id: MuxId, + val outbound: Boolean + ) { + val acknowledged = AtomicBoolean(false) + val sendWindowSize = AtomicInteger(initialWindowSize) + val receiveWindowSize = AtomicInteger(initialWindowSize) + val sendBuffer = ByteBufQueue() + var closedForWriting by Delegates.writeOnce(false) + + fun dispose() { + sendBuffer.dispose() + } + + fun handleFrameRead(msg: YamuxFrame) { + handleFlags(msg) + when (msg.type) { + YamuxType.DATA -> handleDataRead(msg) + YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) + else -> { + /* ignore */ + } + } + } + + private fun handleDataRead(msg: YamuxFrame) { + val size = msg.length.toInt() + if (size == 0) { + return + } + acknowledgeInboundStreamIfNeeded() + val newWindow = receiveWindowSize.addAndGet(-size) + // send a window update frame once half of the window is depleted + if (newWindow < initialWindowSize / 2) { + val delta = initialWindowSize - newWindow + receiveWindowSize.addAndGet(delta) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.NONE, delta.toLong())) + } + childRead(msg.id, msg.data!!) + } + + private fun handleWindowUpdate(msg: YamuxFrame) { + val delta = msg.length.toInt() + sendWindowSize.addAndGet(delta) + // try to send any buffered messages after the window update + drainBufferAndMaybeClose() + } + + private fun handleFlags(msg: YamuxFrame) { + when { + YamuxFlag.SYN in msg.flags -> { + // ACK the new stream + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0)) + } + + YamuxFlag.ACK in msg.flags -> { + acknowledgeOutboundStreamIfNeeded() + } + + YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id) + YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id) + } + } + + private fun acknowledgeInboundStreamIfNeeded() { + if (!outbound) { + acknowledged.set(true) } - return written } + + private fun acknowledgeOutboundStreamIfNeeded() { + if (outbound) { + acknowledged.set(true) + } + } + + private fun fillBuffer(data: ByteBuf) { + sendBuffer.push(data) + val totalBufferedWrites = calculateTotalBufferedWrites() + if (totalBufferedWrites > maxBufferedConnectionWrites + sendWindowSize.get()) { + onLocalClose() + throw WriteBufferOverflowMuxerException( + "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites). Last stream attempting to write: $id" + ) + } + } + + private fun drainBufferAndMaybeClose() { + val maxSendLength = max(0, sendWindowSize.get()) + val data = sendBuffer.take(maxSendLength) + sendWindowSize.addAndGet(-data.readableBytes()) + data.sliceMaxSize(maxFrameDataLength) + .forEach { slicedData -> + val length = slicedData.readableBytes() + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.NONE, length.toLong(), slicedData)) + } + + if (closedForWriting && sendBuffer.readableBytes() == 0) { + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.FIN.asSet, 0)) + } + } + + fun sendData(data: ByteBuf) { + if (closedForWriting) { + throw ClosedForWritingMuxerException(id) + } + acknowledgeInboundStreamIfNeeded() + fillBuffer(data) + drainBufferAndMaybeClose() + } + + fun onLocalOpen() { + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.SYN.asSet, 0)) + } + + fun onRemoteOpen() { + // nothing + } + + fun onLocalDisconnect() { + closedForWriting = true + drainBufferAndMaybeClose() + } + + fun onLocalClose() { + // close stream immediately so not transferring buffered data + sendBuffer.dispose() + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.RST.asSet, 0)) + } + } + + private val idGenerator = YamuxStreamIdGenerator(connectionInitiator) + + private val streamHandlers: MutableMap = ConcurrentHashMap() + + /** + * Would contain GoAway error code when received, or would be completed with [ConnectionClosedException] + * when the connection closed without GoAway message + */ + val goAwayPromise = CompletableFuture() + + private fun getStreamHandlerOrThrow(id: MuxId): YamuxStreamHandler = getStreamHandlerOrReleaseAndThrow(id, null) + + private fun getStreamHandlerOrReleaseAndThrow(id: MuxId, msgToRelease: ByteBuf?): YamuxStreamHandler = + streamHandlers[id] ?: run { + if (msgToRelease != null) { + releaseMessage(msgToRelease) + } + throw UnknownStreamIdMuxerException(id) + } + + override fun channelUnregistered(ctx: ChannelHandlerContext?) { + streamHandlers.values.forEach { it.dispose() } + + if (!goAwayPromise.isDone) { + goAwayPromise.completeExceptionally(ConnectionClosedException("Connection was closed without Go Away message")) + } + super.channelUnregistered(ctx) } override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { msg as YamuxFrame + when (msg.type) { - YamuxType.DATA -> handleDataRead(msg) - YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) YamuxType.PING -> handlePing(msg) - YamuxType.GO_AWAY -> onRemoteClose(msg.id) - } - } + YamuxType.GO_AWAY -> handleGoAway(msg) + else -> { + if (YamuxFlag.SYN in msg.flags) { + // remote opens a new stream + validateSynRemoteMuxId(msg.id) + onRemoteYamuxOpen(msg.id) + } - fun handlePing(msg: YamuxFrame) { - val ctx = getChannelHandlerContext() - when (msg.flags) { - YamuxFlags.SYN -> ctx.writeAndFlush(YamuxFrame(MuxId(msg.id.parentId, 0, msg.id.initiator), YamuxType.PING, YamuxFlags.ACK, msg.lenData)) - YamuxFlags.ACK -> {} + getStreamHandlerOrReleaseAndThrow(msg.id, msg.data).handleFrameRead(msg) + } } } - fun handleFlags(msg: YamuxFrame) { - val ctx = getChannelHandlerContext() - when (msg.flags) { - YamuxFlags.SYN -> { - // ACK the new stream - onRemoteOpen(msg.id) - ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) - } - YamuxFlags.FIN -> onRemoteDisconnect(msg.id) - YamuxFlags.RST -> onRemoteClose(msg.id) - } + private fun writeAndFlushFrame(yamuxFrame: YamuxFrame) { + getChannelHandlerContext().writeAndFlush(yamuxFrame) } - fun handleDataRead(msg: YamuxFrame) { - val ctx = getChannelHandlerContext() - val size = msg.lenData - handleFlags(msg) - if (size.toInt() == 0) - return - val recWindow = receiveWindows.get(msg.id) - if (recWindow == null) { - releaseMessage(msg.data!!) - throw Libp2pException("No receive window for " + msg.id) - } - val newWindow = recWindow.addAndGet(-size.toInt()) - if (newWindow < INITIAL_WINDOW_SIZE / 2) { - val delta = INITIAL_WINDOW_SIZE / 2 - recWindow.addAndGet(delta) - ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) - ctx.flush() - } - childRead(msg.id, msg.data!!) + private fun abruptlyCloseConnection() { + getChannelHandlerContext().close() } - fun handleWindowUpdate(msg: YamuxFrame) { - handleFlags(msg) - val size = msg.lenData.toInt() - val sendWindow = sendWindows.get(msg.id) - if (sendWindow == null) - throw Libp2pException("No send window for " + msg.id) - sendWindow.addAndGet(size) - val buffer = sendBuffers.get(msg.id) - if (buffer != null) { - val writtenBytes = buffer.flush(sendWindow, msg.id) - totalBufferedWrites.addAndGet(-writtenBytes) + private fun validateSynRemoteMuxId(id: MuxId) { + val isRemoteConnectionInitiator = !connectionInitiator + if (!YamuxStreamIdGenerator.isRemoteSynStreamIdValid(isRemoteConnectionInitiator, id.id)) { + abruptlyCloseConnection() + throw Libp2pException("Invalid remote SYN StreamID: $id, isRemoteInitiator: $isRemoteConnectionInitiator") } } override fun onChildWrite(child: MuxChannel, data: ByteBuf) { - val ctx = getChannelHandlerContext() - - val sendWindow = sendWindows.get(child.id) - if (sendWindow == null) - throw Libp2pException("No send window for " + child.id) - if (sendWindow.get() <= 0) { - // wait until the window is increased to send more data - val buffer = sendBuffers.getOrPut(child.id, { SendBuffer(ctx) }) - buffer.add(data) - if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES) - throw Libp2pException("Overflowed send buffer for connection") - return - } - sendBlocks(ctx, data, sendWindow, child.id) - } - - fun sendBlocks(ctx: ChannelHandlerContext, data: ByteBuf, sendWindow: AtomicInteger, id: MuxId) { - data.sliceMaxSize(minOf(maxFrameDataLength, sendWindow.get())) - .map { frameSliceBuf -> - sendWindow.addAndGet(-frameSliceBuf.readableBytes()) - YamuxFrame(id, YamuxType.DATA, 0, frameSliceBuf.readableBytes().toLong(), frameSliceBuf) - }.forEach { muxFrame -> - ctx.write(muxFrame) - } - ctx.flush() + getStreamHandlerOrReleaseAndThrow(child.id, data).sendData(data) } override fun onLocalOpen(child: MuxChannel) { - onStreamCreate(child) - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0)) + verifyAckBacklogLimitNotReached(child.id, true) + createYamuxStreamHandler(child.id, true).onLocalOpen() } - override fun onRemoteCreated(child: MuxChannel) { - onStreamCreate(child) + private fun onRemoteYamuxOpen(id: MuxId) { + verifyAckBacklogLimitNotReached(id, false) + createYamuxStreamHandler(id, false).onRemoteOpen() + onRemoteOpen(id) } - private fun onStreamCreate(child: MuxChannel) { - receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) - sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + private fun verifyAckBacklogLimitNotReached(id: MuxId, outbound: Boolean) { + val totalUnacknowledgedStreams = + streamHandlers.values.count { it.outbound == outbound && !it.acknowledged.get() } + if (totalUnacknowledgedStreams >= ackBacklogLimit) { + throw AckBacklogLimitExceededMuxerException("The ACK backlog limit of $ackBacklogLimit streams has been reached. Will not open new stream: $id") + } + } + + private fun createYamuxStreamHandler(id: MuxId, outbound: Boolean): YamuxStreamHandler { + val streamHandler = YamuxStreamHandler(id, outbound) + streamHandlers[id] = streamHandler + return streamHandler } override fun onLocalDisconnect(child: MuxChannel) { - val sendWindow = sendWindows.remove(child.id) - val buffered = sendBuffers.remove(child.id) - if (buffered != null && sendWindow != null) { - buffered.flush(sendWindow, child.id) - } - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0)) + getStreamHandlerOrThrow(child.id).onLocalDisconnect() } override fun onLocalClose(child: MuxChannel) { - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) + streamHandlers.remove(child.id)?.onLocalClose() } override fun onChildClosed(child: MuxChannel) { - sendWindows.remove(child.id) - receiveWindows.remove(child.id) - sendBuffers.remove(child.id) + streamHandlers.remove(child.id)?.dispose() + } + + private fun handlePing(msg: YamuxFrame) { + if (msg.id.id != YamuxId.SESSION_STREAM_ID) { + throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}") + } + if (YamuxFlag.SYN in msg.flags) { + writeAndFlushFrame( + YamuxFrame( + YamuxId.sessionId(msg.id.parentId), + YamuxType.PING, + YamuxFlag.ACK.asSet, + msg.length + ) + ) + } + } + + private fun handleGoAway(msg: YamuxFrame) { + if (msg.id.id != YamuxId.SESSION_STREAM_ID) { + throw InvalidFrameMuxerException("Invalid StreamId for GoAway frame type: ${msg.id}") + } + goAwayPromise.complete(msg.length) + } + + private fun calculateTotalBufferedWrites(): Int { + return streamHandlers.values.sumOf { it.sendBuffer.readableBytes() } } override fun generateNextId() = - MuxId(getChannelHandlerContext().channel().id(), idGenerator.addAndGet(2).toLong(), true) + YamuxId(getChannelHandlerContext().channel().id(), idGenerator.next()) } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxId.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxId.kt new file mode 100644 index 000000000..f32e661ee --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxId.kt @@ -0,0 +1,26 @@ +package io.libp2p.mux.yamux + +import io.libp2p.etc.util.netty.mux.MuxId +import io.netty.channel.ChannelId + +class YamuxId( + parentId: ChannelId, + id: Long, +) : MuxId(parentId, id) { + + override fun asShortText() = "${parentId.asShortText()}/$id" + override fun asLongText() = "${parentId.asLongText()}/$id" + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + return (other as YamuxId).id == id + } + + override fun hashCode(): Int = id.hashCode() + + companion object { + val SESSION_STREAM_ID = 0L + fun sessionId(parentId: ChannelId) = YamuxId(parentId, SESSION_STREAM_ID) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamIdGenerator.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamIdGenerator.kt new file mode 100644 index 000000000..7be5c5df1 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamIdGenerator.kt @@ -0,0 +1,15 @@ +package io.libp2p.mux.yamux + +import java.util.concurrent.atomic.AtomicLong + +class YamuxStreamIdGenerator(connectionInitiator: Boolean) { + + private val idCounter = AtomicLong(if (connectionInitiator) 1L else 2L) // 0 is reserved + + fun next() = idCounter.getAndAdd(2) + + companion object { + fun isRemoteSynStreamIdValid(isRemoteConnectionInitiator: Boolean, id: Long) = + id > 0 && (if (isRemoteConnectionInitiator) id % 2 == 1L else id % 2 == 0L) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt index 4b43a0597..b64d81389 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt @@ -12,7 +12,9 @@ import java.util.concurrent.CompletableFuture class YamuxStreamMuxer( val inboundStreamHandler: StreamHandler<*>, - private val multistreamProtocol: MultistreamProtocol + private val multistreamProtocol: MultistreamProtocol, + private val maxBufferedConnectionWrites: Int, + private val ackBacklogLimit: Int ) : StreamMuxer, StreamMuxerDebug { override val protocolDescriptor = ProtocolDescriptor("/yamux/1.0.0") @@ -21,7 +23,7 @@ class YamuxStreamMuxer( override fun initChannel(ch: P2PChannel, selectedProtocol: String): CompletableFuture { val muxSessionReady = CompletableFuture() - val yamuxFrameCodec = YamuxFrameCodec(ch.isInitiator) + val yamuxFrameCodec = YamuxFrameCodec() ch.pushHandler(yamuxFrameCodec) muxFramesDebugHandler?.also { it.visit(ch as Connection) } ch.pushHandler( @@ -30,7 +32,9 @@ class YamuxStreamMuxer( yamuxFrameCodec.maxFrameDataLength, muxSessionReady, inboundStreamHandler, - ch.isInitiator + ch.isInitiator, + maxBufferedConnectionWrites, + ackBacklogLimit ) ) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt index cf66f4b8b..db779e7f9 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt @@ -1,11 +1,20 @@ package io.libp2p.mux.yamux +import io.libp2p.mux.InvalidFrameMuxerException + /** - * Contains all the permissible values for flags in the yamux protocol. + * Contains all the permissible values for types in the yamux protocol. */ -object YamuxType { - const val DATA = 0 - const val WINDOW_UPDATE = 1 - const val PING = 2 - const val GO_AWAY = 3 +enum class YamuxType(val intValue: Int) { + DATA(0), + WINDOW_UPDATE(1), + PING(2), + GO_AWAY(3); + + companion object { + private val intToTypeCache = values().associateBy { it.intValue } + + fun fromInt(intValue: Int): YamuxType = + intToTypeCache[intValue] ?: throw InvalidFrameMuxerException("Invalid Yamux type value: $intValue") + } } diff --git a/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt b/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt index 616cd6450..1a181c96b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt +++ b/libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt @@ -22,20 +22,23 @@ interface PingController { fun ping(): CompletableFuture } -class Ping : PingBinding(PingProtocol()) +class Ping(pingSize: Int) : PingBinding(PingProtocol(pingSize)) { + constructor() : this(32) +} open class PingBinding(ping: PingProtocol) : StrictProtocolBinding("/ipfs/ping/1.0.0", ping) class PingTimeoutException : Libp2pException() -open class PingProtocol : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { +open class PingProtocol(var pingSize: Int) : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() } var curTime: () -> Long = { System.currentTimeMillis() } var random = Random() - var pingSize = 32 var pingTimeout = Duration.ofSeconds(10) + constructor() : this(32) + override fun onStartInitiator(stream: Stream): CompletableFuture { val handler = PingInitiator() stream.pushHandler(handler) @@ -100,7 +103,8 @@ open class PingProtocol : ProtocolHandler(Long.MAX_VALUE, Long.M { requests.remove(dataS)?.second?.completeExceptionally(PingTimeoutException()) }, - pingTimeout.toMillis(), TimeUnit.MILLISECONDS + pingTimeout.toMillis(), + TimeUnit.MILLISECONDS ) } diff --git a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt index 4a3dd52ab..a86dd2cce 100644 --- a/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/protocol/ProtocolMessageHandlerAdapter.kt @@ -54,10 +54,12 @@ class ProtocolMessageHandlerAdapter( } private fun checkedRelease(count: Int, obj: Any) { - if (count == -1) + if (count == -1) { return + } val rc = obj as ReferenceCounted - if (count == rc.refCnt()) + if (count == rc.refCnt()) { rc.release() + } } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt index 1c9d78e50..d5e16401c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt @@ -17,8 +17,6 @@ import java.util.Collections.singletonList import java.util.Optional import java.util.concurrent.CompletableFuture import java.util.concurrent.ScheduledExecutorService -import java.util.function.BiConsumer -import java.util.function.Consumer // 1 MB default max message size const val DEFAULT_MAX_PUBSUB_MESSAGE_SIZE = 1 shl 20 @@ -87,7 +85,6 @@ abstract class AbstractRouter( /** * Flushes all pending message parts for all peers - * @see addPendingRpcPart */ protected fun flushAllPending() { pendingRpcParts.pendingPeers.forEach(::flushPending) @@ -165,7 +162,7 @@ abstract class AbstractRouter( // Validate message if (!validateMessageListLimits(msg)) { - logger.debug("Dropping msg with lists exceeding limits from peer $peer") + logger.debug("Dropping msg with lists exceeding limits from peer {}", peer) return } @@ -175,7 +172,7 @@ abstract class AbstractRouter( .filterIncomingSubscriptions(subscriptions, peersTopics.getByFirst(peer)) .forEach { handleMessageSubscriptions(peer, it) } } catch (e: Exception) { - logger.debug("Subscription filter error, ignoring message from peer $peer", e) + logger.debug("Subscription filter error, ignoring message from peer {}", peer, e) return } @@ -184,7 +181,7 @@ abstract class AbstractRouter( } val (msgSubscribed, nonSubscribed) = msg.publishList - .partition { it.topicIDsList.any { it in subscribedTopics } } + .partition { rpcMsg -> rpcMsg.topicIDsList.any { it in subscribedTopics } } nonSubscribed.forEach { notifyNonSubscribedMessage(peer, it) } @@ -194,7 +191,7 @@ abstract class AbstractRouter( val validationResult = seenMessages[subscribedMessage] if (validationResult != null) { // Message has been seen - notifySeenMessage(peer, seenMessages.getSeenMessage(subscribedMessage), validationResult) + notifySeenMessage(peer, seenMessages.getSeenMessageCached(subscribedMessage), validationResult) false } else { // Message is unseen @@ -209,7 +206,7 @@ abstract class AbstractRouter( messageValidator.validate(it) true } catch (e: Exception) { - logger.debug("Invalid pubsub message from peer $peer: $it", e) + logger.debug("Invalid pubsub message from peer {}: {}", peer, it, e) seenMessages[it] = Optional.of(ValidationResult.Invalid) notifyUnseenInvalidMessage(peer, it) false @@ -223,7 +220,7 @@ abstract class AbstractRouter( validFuts.forEach { (msg, validationFut) -> validationFut.thenAcceptAsync( - Consumer { res -> + { res -> seenMessages[msg] = Optional.of(res) if (res == ValidationResult.Invalid) notifyUnseenInvalidMessage(peer, msg) }, @@ -247,11 +244,19 @@ abstract class AbstractRouter( // broadcast others on completion undone.forEach { it.second.whenCompleteAsync( - BiConsumer { res, err -> + { res, err -> when { err != null -> logger.warn("Exception while handling message from peer $peer: ${it.first}", err) - res == ValidationResult.Invalid -> logger.debug("Invalid pubsub message from peer $peer: ${it.first}") - res == ValidationResult.Ignore -> logger.trace("Ignoring pubsub message from peer $peer: ${it.first}") + res == ValidationResult.Invalid -> logger.debug( + "Invalid pubsub message from peer {}: {}", + peer, + it.first + ) + res == ValidationResult.Ignore -> logger.trace( + "Ignoring pubsub message from peer {}: {}", + peer, + it.first + ) else -> { newValidatedMessages(singletonList(it.first), peer) flushAllPending() @@ -275,15 +280,15 @@ abstract class AbstractRouter( override fun onPeerWireException(peer: PeerHandler?, cause: Throwable) { // exception occurred in protobuf decoders - logger.debug("Malformed message from $peer : $cause") + logger.debug("Malformed message from {} : {}", peer, cause) peer?.also { notifyMalformedMessage(it) } } override fun onServiceException(peer: PeerHandler?, msg: Any?, cause: Throwable) { if (cause is BadPeerException) { - logger.debug("Remote peer ($peer) misbehaviour on message $msg: $cause") + logger.debug("Remote peer ({}) misbehaviour on message {} : {}", peer, msg, cause) } else { - logger.warn("AbstractRouter internal error on message $msg from peer $peer", cause) + logger.warn("AbstractRouter internal error on message {} from peer {}", msg, peer, cause) } } @@ -323,7 +328,11 @@ abstract class AbstractRouter( override fun getPeerTopics(): CompletableFuture>> { return submitOnEventThread { - peersTopics.asFirstToSecondMap().mapKeys { it.key.peerId } + peersTopics.asFirstToSecondMap() + .map { (key, value) -> + key.peerId to value.toSet() + } + .toMap() } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubApiImpl.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubApiImpl.kt index dc4bfdf19..fddc263f2 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubApiImpl.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubApiImpl.kt @@ -81,8 +81,11 @@ open class PubsubApiImpl(val router: PubsubRouter) : PubsubApi { it.receiver.apply(rpc2Msg(msg)) } return validationFuts.thenApplyAll { - if (it.isEmpty()) ValidationResult.Ignore - else it.reduce(validationResultReduce) + if (it.isEmpty()) { + ValidationResult.Ignore + } else { + it.reduce(validationResultReduce) + } } } @@ -139,8 +142,10 @@ class MessageImpl(override val originalMessage: PubsubMessage) : MessageApi { private val msg = originalMessage.protobufMessage override val data = msg.data.toByteArray().toByteBuf() override val from = if (msg.hasFrom()) msg.from.toByteArray() else null - override val seqId = if (msg.hasSeqno() && msg.seqno.size() >= 8) + override val seqId = if (msg.hasSeqno() && msg.seqno.size() >= 8) { msg.seqno.toByteArray().copyOfRange(0, 8).toLongBigEndian() - else null + } else { + null + } override val topics = msg.topicIDsList.map { Topic(it) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt index afab564d1..49cf95239 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubProtocol.kt @@ -6,10 +6,25 @@ enum class PubsubProtocol(val announceStr: ProtocolId) { Gossip_V_1_0("/meshsub/1.0.0"), Gossip_V_1_1("/meshsub/1.1.0"), + Gossip_V_1_2("/meshsub/1.2.0"), Floodsub("/floodsub/1.0.0"); companion object { fun fromProtocol(protocol: ProtocolId) = PubsubProtocol.values().find { protocol == it.announceStr } ?: throw NoSuchElementException("No PubsubProtocol found with protocol $protocol") } + + /** + * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md#prune-backoff-and-peer-exchange + */ + fun supportsBackoffAndPX(): Boolean { + return this == Gossip_V_1_1 || this == Gossip_V_1_2 + } + + /** + * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.2.md#idontwant-message + */ + fun supportsIDontWant(): Boolean { + return this == Gossip_V_1_2 + } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt index ae3d94b16..c960fdb68 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt @@ -22,6 +22,9 @@ interface PubsubMessage { val topics: List get() = protobufMessage.topicIDsList + val size: Int + get() = protobufMessage.data.size() + fun messageSha256() = sha256(protobufMessage.toByteArray()) override fun equals(other: Any?): Boolean diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/SeenCache.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/SeenCache.kt index 239eef021..9136899ac 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/SeenCache.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/SeenCache.kt @@ -17,60 +17,62 @@ import java.util.LinkedList */ interface SeenCache { val size: Int - val messages: Collection - /** - * Returns the 'matching' key if it already exist in the cache or returns the argument if not - */ - fun getSeenMessage(msg: PubsubMessage): PubsubMessage - fun getValue(msg: PubsubMessage): TValue? + fun put(msg: PubsubMessage, value: TValue) + fun get(msg: PubsubMessage): TValue? fun isSeen(msg: PubsubMessage): Boolean fun isSeen(messageId: MessageId): Boolean - fun put(msg: PubsubMessage, value: TValue) - fun remove(msg: PubsubMessage) + fun remove(messageId: MessageId) + + /** + * Returns the 'matching' message if it exists in the cache or falls back to returning the argument if not + * The returned instance may have some data prepared and cached (e.g. `messageId`) which may + * have positive performance effect + */ + fun getSeenMessageCached(msg: PubsubMessage): PubsubMessage } -operator fun SeenCache.get(msg: PubsubMessage) = getValue(msg) +operator fun SeenCache.get(msg: PubsubMessage) = get(msg) operator fun SeenCache.set(msg: PubsubMessage, value: TValue) = put(msg, value) operator fun SeenCache.contains(msg: PubsubMessage) = isSeen(msg) -operator fun SeenCache.minusAssign(msg: PubsubMessage) = remove(msg) +operator fun SeenCache.minusAssign(messageId: MessageId) = remove(messageId) class SimpleSeenCache : SeenCache { - private val map: MutableMap> = mutableMapOf() + private val map: MutableMap = mutableMapOf() override val size: Int get() = map.size - override val messages: Collection - get() = map.values.map { it.first } - override fun getSeenMessage(msg: PubsubMessage) = msg - override fun getValue(msg: PubsubMessage) = map[msg.messageId]?.second + override fun getSeenMessageCached(msg: PubsubMessage) = msg + override fun get(msg: PubsubMessage) = map[msg.messageId] override fun isSeen(msg: PubsubMessage) = msg.messageId in map override fun isSeen(messageId: MessageId) = messageId in map override fun put(msg: PubsubMessage, value: TValue) { - map[msg.messageId] = msg to value + map[msg.messageId] = value + } + override fun remove(messageId: MessageId) { + map -= messageId } - override fun remove(msg: PubsubMessage) { map -= msg.messageId } } class LRUSeenCache(val delegate: SeenCache, private val maxSize: Int) : SeenCache by delegate { - val evictingQueue = LinkedList() + val evictingQueue = LinkedList() override fun put(msg: PubsubMessage, value: TValue) { val oldSize = delegate.size delegate[msg] = value if (oldSize < delegate.size) { - evictingQueue += msg + evictingQueue += msg.messageId if (evictingQueue.size > maxSize) { delegate -= evictingQueue.removeFirst() } } } - override fun remove(msg: PubsubMessage) { - delegate -= msg - evictingQueue -= msg + override fun remove(messageId: MessageId) { + delegate -= messageId + evictingQueue -= messageId } } @@ -80,13 +82,13 @@ class TTLSeenCache( private val curTime: () -> Long ) : SeenCache by delegate { - data class TimedMessage(val time: Long, val message: PubsubMessage) + data class TimedMessage(val time: Long, val messageId: MessageId) val putTimes = LinkedList() override fun put(msg: PubsubMessage, value: TValue) { delegate[msg] = value - putTimes += TimedMessage(curTime(), msg) + putTimes += TimedMessage(curTime(), msg.messageId) pruneOld() } @@ -98,7 +100,7 @@ class TTLSeenCache( if (n.time >= pruneBefore) { break } - delegate -= n.message + delegate -= n.messageId it.remove() } } @@ -106,35 +108,46 @@ class TTLSeenCache( class FastIdSeenCache(private val fastIdFunction: (PubsubMessage) -> Any) : SeenCache { val fastIdMap = mutableBiMultiMap() - val slowIdMap: MutableMap> = mutableMapOf() + val slowIdMap: MutableMap = mutableMapOf() override val size: Int get() = slowIdMap.size - override val messages: Collection - get() = slowIdMap.values.map { it.first } - override fun getSeenMessage(msg: PubsubMessage): PubsubMessage { + override fun getSeenMessageCached(msg: PubsubMessage): PubsubMessage { val slowId = fastIdMap[fastIdFunction(msg)] - return if (slowId == null) msg else slowIdMap[slowId]!!.first + return when { + slowId == null -> msg + msg is FastIdPubsubMessage -> msg + else -> FastIdPubsubMessage(msg, slowId) + } } - override fun getValue(msg: PubsubMessage): TValue? { + override fun get(msg: PubsubMessage): TValue? { val slowId = fastIdMap[fastIdFunction(msg)] ?: msg.messageId - return slowIdMap[slowId]?.second + return slowIdMap[slowId] } override fun isSeen(msg: PubsubMessage) = fastIdFunction(msg) in fastIdMap || msg.messageId in slowIdMap + override fun isSeen(messageId: MessageId) = messageId in slowIdMap override fun put(msg: PubsubMessage, value: TValue) { fastIdMap[fastIdFunction(msg)] = msg.messageId - slowIdMap[msg.messageId] = msg to value + slowIdMap[msg.messageId] = value } - override fun remove(msg: PubsubMessage) { - val slowId = msg.messageId - slowIdMap -= slowId - fastIdMap.removeAllByValue(slowId) + override fun remove(messageId: MessageId) { + slowIdMap -= messageId + fastIdMap.removeAllByValue(messageId) } + + /** + * Wraps [delegate] instance and overrides [messageId] with a cached value + * to avoid slow `messageId` computation by the [delegate] instance + */ + private class FastIdPubsubMessage( + val delegate: PubsubMessage, + override val messageId: MessageId + ) : PubsubMessage by delegate } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt index 0645ddf46..ae5f3c5e2 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt @@ -30,13 +30,24 @@ class Gossip @JvmOverloads constructor( } override val protocolDescriptor = - if (router.protocol == PubsubProtocol.Gossip_V_1_1) - ProtocolDescriptor( - PubsubProtocol.Gossip_V_1_1.announceStr, - PubsubProtocol.Gossip_V_1_0.announceStr - ) - else - ProtocolDescriptor(PubsubProtocol.Gossip_V_1_0.announceStr) + when (router.protocol) { + PubsubProtocol.Gossip_V_1_2 -> { + ProtocolDescriptor( + PubsubProtocol.Gossip_V_1_2.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, + PubsubProtocol.Gossip_V_1_0.announceStr + ) + } + PubsubProtocol.Gossip_V_1_1 -> { + ProtocolDescriptor( + PubsubProtocol.Gossip_V_1_1.announceStr, + PubsubProtocol.Gossip_V_1_0.announceStr + ) + } + else -> { + ProtocolDescriptor(PubsubProtocol.Gossip_V_1_0.announceStr) + } + } override fun handleConnection(conn: Connection) { conn.muxerSession().createStream(listOf(this)) diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt index fb7fa6180..654683ab6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt @@ -22,6 +22,10 @@ fun defaultDLazy(D: Int) = D fun defaultDScore(D: Int) = D * 2 / 3 fun defaultDOut(D: Int, DLow: Int) = min(D / 2, max(DLow - 1, 0)) +// floodPublishMaxMessageSizeThreshold shortcuts +const val NEVER_FLOOD_PUBLISH = 0 +const val ALWAYS_FLOOD_PUBLISH = Int.MAX_VALUE + /** * Parameters of Gossip 1.1 router */ @@ -112,11 +116,16 @@ data class GossipParams( val seenTTL: Duration = 2.minutes, /** - * [floodPublish] is a gossipsub router option that enables flood publishing. - * When this is enabled, published messages are forwarded to all peers with score >= - * to publishThreshold + * [floodPublishMaxMessageSizeThreshold] controls the maximum size (in bytes) a message will be + * published using flood publishing mode. + * When a message size is <= [floodPublishMaxMessageSizeThreshold], published messages are forwarded + * to all peers with score >= to [GossipScoreParams.publishThreshold] + * + * [NEVER_FLOOD_PUBLISH] and [ALWAYS_FLOOD_PUBLISH] can be used as shortcuts. + * + * The default is [NEVER_FLOOD_PUBLISH] (0 KiB). */ - val floodPublish: Boolean = false, + val floodPublishMaxMessageSizeThreshold: Int = NEVER_FLOOD_PUBLISH, /** * [gossipFactor] affects how many peers we will emit gossip to at each heartbeat. @@ -193,17 +202,17 @@ data class GossipParams( val maxGraftMessages: Int? = null, /** - * [maxPrunePeers] controls the number of peers to include in prune Peer eXchange. + * [maxPeersSentInPruneMsg] controls the number of peers to include in prune Peer eXchange. * When we prune a peer that's eligible for PX (has a good score, etc), we will try to - * send them signed peer records for up to [maxPrunePeers] other peers that we + * send them signed peer records for up to [maxPeersSentInPruneMsg] other peers that we * know of. */ - val maxPrunePeers: Int = 16, + val maxPeersSentInPruneMsg: Int = 16, /** - * [maxPeersPerPruneMessage] is the maximum number of peers allowed in an incoming prune message + * [maxPeersAcceptedInPruneMsg] is the maximum number of peers allowed in an incoming prune message */ - val maxPeersPerPruneMessage: Int? = null, + val maxPeersAcceptedInPruneMsg: Int = 16, /** * [pruneBackoff] controls the backoff time for pruned peers. This is how long @@ -231,7 +240,24 @@ data class GossipParams( * callback to notify outer system to which peers Gossip wants to be connected * The second parameter is a signed peer record: https://github.com/libp2p/specs/pull/217 */ - val connectCallback: (PeerId, ByteArray) -> Unit = { _: PeerId, _: ByteArray -> } + val connectCallback: (PeerId, ByteArray) -> Unit = { _: PeerId, _: ByteArray -> }, + + /** + * [maxIDontWantMessageIds] is the maximum number of IDONTWANT message ids allowed per heartbeat per peer + */ + val maxIDontWantMessageIds: Int = maxIHaveLength * maxIHaveMessages, + + /** + * [iDontWantMinMessageSizeThreshold] controls the minimum size (in bytes) that an incoming message needs to be so that an IDONTWANT message is sent to mesh peers. + * The default is 16 KiB. + */ + val iDontWantMinMessageSizeThreshold: Int = 16384, + + /** + * [iDontWantTTL] Expiry time for cache of received IDONTWANT messages for peers + */ + val iDontWantTTL: Duration = 3.seconds + ) { init { check(D >= 0, "D should be >= 0") @@ -243,6 +269,8 @@ data class GossipParams( check(DLow <= D, "DLow should be <= D") check(DHigh >= D, "DHigh should be >= D") check(gossipFactor in 0.0..1.0, "gossipFactor should be in range [0.0, 1.0]") + check(floodPublishMaxMessageSizeThreshold >= 0, "floodPublishMaxMessageSizeThreshold should be >= 0") + check(iDontWantMinMessageSizeThreshold >= 0, "iDontWantMinMessageSizeThreshold should be >= 0") } companion object { @@ -485,7 +513,7 @@ data class GossipTopicScoreParams( * The penalty is only activated after [meshMessageDeliveriesActivation] time in the mesh. * The weight of the parameter MUST be negative (or zero to disable). */ - val meshMessageDeliveriesWeight: Weight = 0.0 /*-1.0*/, // TODO temporarily exclude this parameter + val meshMessageDeliveriesWeight: Weight = 0.0, // TODO temporarily exclude this parameter /** @see meshMessageDeliveriesWeight */ val meshMessageDeliveriesDecay: Double = 0.0, /** @see meshMessageDeliveriesWeight */ diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index 2d3b21625..b385aaa3b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -23,7 +23,6 @@ import kotlin.collections.any import kotlin.collections.component1 import kotlin.collections.component2 import kotlin.collections.count -import kotlin.collections.distinct import kotlin.collections.drop import kotlin.collections.filter import kotlin.collections.filterNot @@ -42,11 +41,9 @@ import kotlin.collections.mutableSetOf import kotlin.collections.none import kotlin.collections.plus import kotlin.collections.plusAssign -import kotlin.collections.reversed import kotlin.collections.set import kotlin.collections.shuffled import kotlin.collections.sortedBy -import kotlin.collections.sum import kotlin.collections.take import kotlin.collections.toMutableSet import kotlin.math.max @@ -56,6 +53,7 @@ const val MaxBackoffEntries = 10 * 1024 const val MaxIAskedEntries = 256 const val MaxPeerIHaveEntries = 256 const val MaxIWantRequestsEntries = 10 * 1024 +const val MaxPeerIDontWantEntries = 256 typealias CurrentTimeSupplier = () -> Long @@ -122,6 +120,7 @@ open class GossipRouter( private val iAsked = createLRUMap(MaxIAskedEntries) private val peerIHave = createLRUMap(MaxPeerIHaveEntries) private val iWantRequests = createLRUMap, Long>(MaxIWantRequestsEntries) + private val peerIDontWant = createLRUMap(MaxPeerIDontWantEntries) private val heartbeatTask by lazy { executor.scheduleWithFixedDelay( ::catchingHeartbeat, @@ -146,7 +145,9 @@ open class GossipRouter( return currentTimeSupplier() < expire - (params.pruneBackoff + params.graftFloodThreshold).toMillis() } - private fun getDirectPeers() = peers.filter(::isDirect) + private fun getDirectPeers(topic: Topic): List { + return getTopicPeers(topic).filter(::isDirect) + } private fun isDirect(peer: PeerHandler) = scoreParams.peerScoreParams.isDirect(peer.peerId) private fun isConnected(peerId: PeerId) = peers.any { it.peerId == peerId } @@ -166,6 +167,7 @@ open class GossipRouter( } override fun notifyUnseenMessage(peer: PeerHandler, msg: PubsubMessage) { + iDontWant(msg, peer) eventBroadcaster.notifyUnseenMessage(peer.peerId, msg) notifyAnyMessage(peer, msg) } @@ -234,7 +236,6 @@ open class GossipRouter( curTime <= whitelistEntry.whitelistedTill && whitelistEntry.messagesAccepted < acceptRequestsWhitelistMaxMessages ) { - acceptRequestsWhitelist[peer] = whitelistEntry.incrementMessageCount() return true } @@ -251,8 +252,8 @@ open class GossipRouter( } override fun validateMessageListLimits(msg: Rpc.RPCOrBuilder): Boolean { - val iWantMessageIdCount = msg.control?.iwantList?.map { w -> w.messageIDsCount }?.sum() ?: 0 - val iHaveMessageIdCount = msg.control?.ihaveList?.map { w -> w.messageIDsCount }?.sum() ?: 0 + val iWantMessageIdCount = msg.control?.iwantList?.sumOf { w -> w.messageIDsCount } ?: 0 + val iHaveMessageIdCount = msg.control?.ihaveList?.sumOf { w -> w.messageIDsCount } ?: 0 return params.maxPublishedMessages?.let { msg.publishCount <= it } ?: true && params.maxTopicsPerPublishedMessage?.let { msg.publishList.none { m -> m.topicIDsCount > it } } ?: true && @@ -261,7 +262,7 @@ open class GossipRouter( params.maxIWantMessageIds?.let { iWantMessageIdCount <= it } ?: true && params.maxGraftMessages?.let { (msg.control?.graftCount ?: 0) <= it } ?: true && params.maxPruneMessages?.let { (msg.control?.pruneCount ?: 0) <= it } ?: true && - params.maxPeersPerPruneMessage?.let { msg.control?.pruneList?.none { p -> p.peersCount > it } } ?: true + params.maxPeersAcceptedInPruneMsg.let { msg.control?.pruneList?.none { p -> p.peersCount > it } } ?: true } private fun processControlMessage(controlMsg: Any, receivedFrom: PeerHandler) { @@ -270,6 +271,7 @@ open class GossipRouter( is Rpc.ControlPrune -> handlePrune(controlMsg, receivedFrom) is Rpc.ControlIHave -> handleIHave(controlMsg, receivedFrom) is Rpc.ControlIWant -> handleIWant(controlMsg, receivedFrom) + is Rpc.ControlIDontWant -> handleIDontWant(controlMsg, receivedFrom) } } @@ -280,6 +282,7 @@ open class GossipRouter( when { isDirect(peer) -> prune(peer, topic) + isBackOff(peer, topic) -> { notifyRouterMisbehavior(peer, 1) if (isBackOffFlood(peer, topic)) { @@ -287,10 +290,13 @@ open class GossipRouter( } prune(peer, topic) } + score.score(peer.peerId) < 0 -> prune(peer, topic) + meshPeers.size >= params.DHigh && !peer.isOutbound() -> prune(peer, topic) + peer !in meshPeers -> graft(peer, topic) } @@ -301,7 +307,7 @@ open class GossipRouter( mesh[topic]?.remove(peer)?.also { notifyPruned(peer, topic) } - if (this.protocol == PubsubProtocol.Gossip_V_1_1) { + if (this.protocol.supportsBackoffAndPX()) { if (msg.hasBackoff()) { setBackOff(peer, topic, msg.backoff.seconds.toMillis()) } else { @@ -318,6 +324,10 @@ open class GossipRouter( } private fun handleIHave(msg: Rpc.ControlIHave, peer: PeerHandler) { + // we ignore IHAVE gossip for unknown topics + if (msg.hasTopicID() && !mesh.containsKey(msg.topicID)) { + return + } val peerScore = score.score(peer.peerId) // we ignore IHAVE gossip from any peer whose score is below the gossip threshold if (peerScore < scoreParams.gossipThreshold) return @@ -345,12 +355,26 @@ open class GossipRouter( msg.messageIDsList .mapNotNull { mCache.getMessageForPeer(peer.peerId, it.toWBytes()) } .filter { it.sentCount < params.gossipRetransmission } - .map { it.msg } - .forEach { submitPublishMessage(peer, it) } + .forEach { submitPublishMessage(peer, it.msg) } + } + + private fun handleIDontWant(msg: Rpc.ControlIDontWant, peer: PeerHandler) { + if (!this.protocol.supportsIDontWant()) return + val peerScore = score.score(peer.peerId) + if (peerScore < scoreParams.gossipThreshold) return + val iDontWantCacheEntry = peerIDontWant.computeIfAbsent(peer) { IDontWantCacheEntry() } + iDontWantCacheEntry.heartbeatMessageIdsCount += msg.messageIDsCount + if (iDontWantCacheEntry.heartbeatMessageIdsCount > params.maxIDontWantMessageIds) { + return + } + val timeReceived = currentTimeSupplier() + msg.messageIDsList + .map { it.toWBytes() } + .associateWithTo(iDontWantCacheEntry.messageIdsAndTimeReceived) { timeReceived } } private fun processPrunePeers(peersList: List) { - peersList.shuffled(random).take(params.maxPrunePeers) + peersList.shuffled(random).take(params.maxPeersAcceptedInPruneMsg) .map { PeerId(it.peerID.toByteArray()) to it.signedPeerRecord.toByteArray() } .filter { (id, _) -> !isConnected(id) } .forEach { (id, record) -> params.connectCallback(id, record) } @@ -358,18 +382,26 @@ open class GossipRouter( override fun processControl(ctrl: Rpc.ControlMessage, receivedFrom: PeerHandler) { ctrl.run { - (graftList + pruneList + ihaveList + iwantList) + (graftList + pruneList + ihaveList + iwantList + idontwantList) }.forEach { processControlMessage(it, receivedFrom) } } override fun broadcastInbound(msgs: List, receivedFrom: PeerHandler) { msgs.forEach { pubMsg -> - pubMsg.topics + val topics = pubMsg.topics + .asSequence() + + val peersFromMesh = topics .mapNotNull { mesh[it] } .flatten() + + val peersFromDirectPeers = topics.flatMap { getDirectPeers(it) } + + peersFromDirectPeers + .plus(peersFromMesh) .distinct() - .plus(getDirectPeers()) - .filter { it != receivedFrom } + .minus(receivedFrom) + .filterNot { peerDoesNotWantMessage(it, pubMsg.messageId) } .forEach { submitPublishMessage(it, pubMsg) } mCache += pubMsg } @@ -379,36 +411,83 @@ open class GossipRouter( override fun broadcastOutbound(msg: PubsubMessage): CompletableFuture { msg.topics.forEach { lastPublished[it] = currentTimeSupplier() } + val floodPublish = msg.size <= params.floodPublishMaxMessageSizeThreshold + val peers = - if (params.floodPublish) { - msg.topics - .flatMap { getTopicPeers(it) } - .filter { score.score(it.peerId) >= scoreParams.publishThreshold } - .plus(getDirectPeers()) + if (floodPublish) { + selectPeersForOutboundBroadcastingInFloodPublish(msg) } else { - msg.topics - .mapNotNull { topic -> - mesh[topic] ?: fanout[topic] ?: getTopicPeers(topic).shuffled(random).take(params.D) - .also { - if (it.isNotEmpty()) fanout[topic] = it.toMutableSet() - } - } - .flatten() + selectPeersForOutboundBroadcasting(msg) } - val list = peers.map { submitPublishMessage(it, msg) } mCache += msg - flushAllPending() - if (list.isNotEmpty()) { - return anyComplete(list) + return if (peers.isNotEmpty()) { + iDontWant(msg) + val publishedMessages = peers + .filterNot { peerDoesNotWantMessage(it, msg.messageId) } + .map { submitPublishMessage(it, msg) } + if (publishedMessages.isEmpty()) { + // all peers have sent IDONTWANT for this message id + CompletableFuture.completedFuture(Unit) + } else { + flushAllPending() + anyComplete(publishedMessages) + } } else { - return completedExceptionally( + completedExceptionally( NoPeersForOutboundMessageException("No peers for message topics ${msg.topics} found") ) } } + private fun selectPeersForOutboundBroadcastingInFloodPublish(msg: PubsubMessage): List { + return msg.topics + .flatMap { getTopicPeers(it) } + .filter { isDirect(it) || score.score(it.peerId) >= scoreParams.publishThreshold } + } + + private fun selectPeersForOutboundBroadcasting(msg: PubsubMessage): List { + val fromMesh = msg.topics + .map { topic -> + val topicMeshPeers = mesh[topic] + if (topicMeshPeers != null) { + // we are subscribed to the topic + if (topicMeshPeers.size < params.D) { + // we need extra non-mesh peers for more reliable publishing + val nonMeshTopicPeers = getTopicPeers(topic) - topicMeshPeers + val (nonMeshTopicPeersAbovePublishThreshold, nonMeshTopicPeersBelowPublishThreshold) = + nonMeshTopicPeers.partition { score.score(it.peerId) >= scoreParams.publishThreshold } + // this deviates from the original spec but we want at least D peers for publishing + // prioritizing mesh peers, then non-mesh peers with acceptable score, + // and then underscored non-mesh peers as a last resort + listOf( + topicMeshPeers, + nonMeshTopicPeersAbovePublishThreshold.shuffled(random), + nonMeshTopicPeersBelowPublishThreshold.shuffled(random) + ) + .flatten() + .take(params.D) + } else { + topicMeshPeers + } + } else { + // we are not subscribed to the topic + fanout[topic] ?: getTopicPeers(topic).shuffled(random).take(params.D) + .also { + if (it.isNotEmpty()) fanout[topic] = it.toMutableSet() + } + } + } + .flatten() + + val fromDirectPeers = msg.topics.flatMap { getDirectPeers(it) } + + return fromMesh + .plus(fromDirectPeers) + .distinct() + } + override fun subscribe(topic: Topic) { super.subscribe(topic) val fanoutPeers = (fanout[topic] ?: mutableSetOf()) @@ -456,6 +535,15 @@ open class GossipRouter( .whenTrue { notifyIWantTimeout(key.first, key.second) } } + val staleIDontWantTime = this.currentTimeSupplier() - params.iDontWantTTL.toMillis() + peerIDontWant.entries.removeIf { (_, cacheEntry) -> + // reset on heartbeat + cacheEntry.heartbeatMessageIdsCount = 0 + cacheEntry.messageIdsAndTimeReceived.values.removeIf { timeReceived -> timeReceived < staleIDontWantTime } + // remove entry for peer if no IDONTWANT message ids are left in the cache + cacheEntry.messageIdsAndTimeReceived.isEmpty() + } + try { mesh.entries.forEach { (topic, peers) -> @@ -475,7 +563,7 @@ open class GossipRouter( val sortedPeers = peers .shuffled(random) .sortedBy { score.score(it.peerId) } - .reversed() + .asReversed() val bestDPeers = sortedPeers.take(params.DScore) val restPeers = sortedPeers.drop(params.DScore).shuffled(random) @@ -545,7 +633,7 @@ open class GossipRouter( peers.shuffled(random) .take(max((params.gossipFactor * peers.size).toInt(), params.DLazy)) - .forEach { enqueueIhave(it, shuffledMessageIds) } + .forEach { enqueueIhave(it, shuffledMessageIds, topic) } } private fun graft(peer: PeerHandler, topic: Topic) { @@ -562,6 +650,10 @@ open class GossipRouter( } } + private fun peerDoesNotWantMessage(peer: PeerHandler, messageId: MessageId): Boolean { + return peerIDontWant[peer]?.messageIdsAndTimeReceived?.contains(messageId) == true + } + private fun iWant(peer: PeerHandler, messageIds: List) { if (messageIds.isEmpty()) return messageIds[random.nextInt(messageIds.size)] @@ -569,10 +661,23 @@ open class GossipRouter( enqueueIwant(peer, messageIds) } + private fun iDontWant(msg: PubsubMessage, receivedFrom: PeerHandler? = null) { + if (!this.protocol.supportsIDontWant()) return + if (msg.size < params.iDontWantMinMessageSizeThreshold) return + // we need to send IDONTWANT messages to mesh peers immediately in order for them to have an effect + msg.topics + .mapNotNull { mesh[it] } + .flatten() + .distinct() + .minus(setOfNotNull(receivedFrom)) + .forEach { sendIdontwant(it, msg.messageId) } + } + private fun enqueuePrune(peer: PeerHandler, topic: Topic) { val peerQueue = pendingRpcParts.getQueue(peer) - if (peer.getPeerProtocol() == PubsubProtocol.Gossip_V_1_1 && this.protocol == PubsubProtocol.Gossip_V_1_1) { + if (peer.getPeerProtocol().supportsBackoffAndPX() && this.protocol.supportsBackoffAndPX()) { val backoffPeers = (getTopicPeers(topic) - peer) + .take(params.maxPeersSentInPruneMsg) .filter { score.score(it.peerId) >= 0 } .map { it.peerId } peerQueue.addPrune(topic, params.pruneBackoff.seconds, backoffPeers) @@ -587,10 +692,28 @@ open class GossipRouter( private fun enqueueIwant(peer: PeerHandler, messageIds: List) = pendingRpcParts.getQueue(peer).addIWants(messageIds) - private fun enqueueIhave(peer: PeerHandler, messageIds: List) = - pendingRpcParts.getQueue(peer).addIHaves(messageIds) + private fun enqueueIhave(peer: PeerHandler, messageIds: List, topic: Topic) = + pendingRpcParts.getQueue(peer).addIHaves(messageIds, topic) + + private fun sendIdontwant(peer: PeerHandler, messageId: MessageId) { + if (!peer.getPeerProtocol().supportsIDontWant()) { + return + } + val iDontWant = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addIdontwant( + Rpc.ControlIDontWant.newBuilder() + .addMessageIDs(messageId.toProtobuf()) + ) + ).build() + send(peer, iDontWant) + } data class AcceptRequestsWhitelistEntry(val whitelistedTill: Long, val messagesAccepted: Int = 0) { fun incrementMessageCount() = AcceptRequestsWhitelistEntry(whitelistedTill, messagesAccepted + 1) } + + data class IDontWantCacheEntry( + var heartbeatMessageIdsCount: Int = 0, + val messageIdsAndTimeReceived: MutableMap = mutableMapOf() + ) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt index bfe86339e..e90332589 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt @@ -10,8 +10,8 @@ import pubsub.pb.Rpc interface GossipRpcPartsQueue : RpcPartsQueue { - fun addIHave(messageId: MessageId) - fun addIHaves(messageIds: Collection) = messageIds.forEach { addIHave(it) } + fun addIHave(messageId: MessageId, topic: Topic) + fun addIHaves(messageIds: Collection, topic: Topic) = messageIds.forEach { addIHave(it, topic) } fun addIWant(messageId: MessageId) fun addIWants(messageIds: Collection) = messageIds.forEach { addIWant(it) } @@ -21,6 +21,7 @@ interface GossipRpcPartsQueue : RpcPartsQueue { * Gossip 1.0 variant */ fun addPrune(topic: Topic) + /** * Gossip 1.1 variant */ @@ -36,14 +37,13 @@ open class DefaultGossipRpcPartsQueue( private val params: GossipParams ) : DefaultRpcPartsQueue(), GossipRpcPartsQueue { - protected data class IHavePart(val messageId: MessageId) : AbstractPart { + protected data class IHavePart(val messageId: MessageId, val topic: Topic) : AbstractPart { override fun appendToBuilder(builder: Rpc.RPC.Builder) { val ctrlBuilder = builder.controlBuilder - val iHaveBuilder = if (ctrlBuilder.ihaveBuilderList.isEmpty()) { - ctrlBuilder.addIhaveBuilder() - } else { - ctrlBuilder.getIhaveBuilder(0) - } + val iHaveBuilder = ctrlBuilder.ihaveBuilderList + .find { it.topicID == topic } + ?: ctrlBuilder.addIhaveBuilder().setTopicID(topic) + iHaveBuilder.addMessageIDs(messageId.toProtobuf()) } } @@ -81,8 +81,8 @@ open class DefaultGossipRpcPartsQueue( } } - override fun addIHave(messageId: MessageId) { - addPart(IHavePart(messageId)) + override fun addIHave(messageId: MessageId, topic: Topic) { + addPart(IHavePart(messageId, topic)) } override fun addIWant(messageId: MessageId) { @@ -118,7 +118,6 @@ open class DefaultGossipRpcPartsQueue( publishCount > 0 && subscriptionCount > 0 && iHaveCount > 0 && iWantCount > 0 && graftCount > 0 && pruneCount > 0 ) { - val part = parts[partIdx++] when (part) { is PublishPart -> publishCount-- diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt index a8f41deb4..c244e0c08 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt @@ -91,9 +91,11 @@ class DefaultGossipScore( inMesh() && ((curTimeMillis() - joinedMeshTimeMillis).millis > params.meshMessageDeliveriesActivation) private fun meshMessageDeliveriesDeficit() = - if (isMeshMessageDeliveriesActive()) + if (isMeshMessageDeliveriesActive()) { max(0.0, params.meshMessageDeliveriesThreshold - meshMessageDeliveries) - else 0.0 + } else { + 0.0 + } fun meshMessageDeliveriesDeficitSqr() = meshMessageDeliveriesDeficit().pow(2) @@ -142,8 +144,11 @@ class DefaultGossipScore( fun isConnected() = connectedTimeMillis > 0 && disconnectedTimeMillis == 0L fun isDisconnected() = disconnectedTimeMillis > 0 fun getDisconnectDuration() = - if (isDisconnected()) (curTimeMillis() - disconnectedTimeMillis).millis - else throw IllegalStateException("Peer is not disconnected") + if (isDisconnected()) { + (curTimeMillis() - disconnectedTimeMillis).millis + } else { + throw IllegalStateException("Peer is not disconnected") + } } val peerParams = params.peerScoreParams @@ -197,8 +202,11 @@ class DefaultGossipScore( val behaviorExcess = peerScore.behaviorPenalty - peerParams.behaviourPenaltyThreshold val routerPenalty = - if (behaviorExcess < 0) 0.0 - else behaviorExcess.pow(2) * peerParams.behaviourPenaltyWeight + if (behaviorExcess < 0) { + 0.0 + } else { + behaviorExcess.pow(2) * peerParams.behaviourPenaltyWeight + } val computedScore = topicsScore + appScore + ipColocationPenalty + routerPenalty peerScore.cachedScore = computedScore diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt index ed272ced7..9696a8f8c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt @@ -34,14 +34,12 @@ class GossipParamsBuilder { private var seenTTL: Duration? = null - private var maxPrunePeers: Int? = null + private var maxPeersSentInPruneMsg: Int? = null - private var maxPeersPerPruneMessage: Int? = null + private var maxPeersAcceptedInPruneMsg: Int? = null private var pruneBackoff: Duration? = null - private var floodPublish: Boolean? = null - private var gossipFactor: Double? = null private var opportunisticGraftPeers: Int? = null @@ -72,6 +70,14 @@ class GossipParamsBuilder { private var connectCallback: Function2? = null + private var maxIDontWantMessageIds: Int? = null + + private var iDontWantMinMessageSizeThreshold: Int? = null + + private var floodPublishMaxMessageSizeThreshold: Int? = null + + private var iDontWantTTL: Duration? = null + init { val source = GossipParams() this.D = source.D @@ -81,10 +87,10 @@ class GossipParamsBuilder { this.gossipHistoryLength = source.gossipHistoryLength this.heartbeatInterval = source.heartbeatInterval this.seenTTL = source.seenTTL - this.maxPrunePeers = source.maxPrunePeers - this.maxPeersPerPruneMessage = source.maxPeersPerPruneMessage + this.maxPeersSentInPruneMsg = source.maxPeersSentInPruneMsg + this.maxPeersAcceptedInPruneMsg = source.maxPeersAcceptedInPruneMsg this.pruneBackoff = source.pruneBackoff - this.floodPublish = source.floodPublish + this.floodPublishMaxMessageSizeThreshold = source.floodPublishMaxMessageSizeThreshold this.gossipFactor = source.gossipFactor this.opportunisticGraftPeers = source.opportunisticGraftPeers this.opportunisticGraftTicks = source.opportunisticGraftTicks @@ -100,6 +106,9 @@ class GossipParamsBuilder { this.maxPruneMessages = source.maxPruneMessages this.gossipRetransmission = source.gossipRetransmission this.connectCallback = source.connectCallback + this.maxIDontWantMessageIds = source.maxIDontWantMessageIds + this.iDontWantMinMessageSizeThreshold = source.iDontWantMinMessageSizeThreshold + this.iDontWantTTL = source.iDontWantTTL } fun D(value: Int): GossipParamsBuilder = apply { D = value } @@ -126,14 +135,12 @@ class GossipParamsBuilder { fun seenTTL(value: Duration): GossipParamsBuilder = apply { seenTTL = value } - fun maxPrunePeers(value: Int): GossipParamsBuilder = apply { maxPrunePeers = value } + fun maxPeersSentInPruneMsg(value: Int): GossipParamsBuilder = apply { maxPeersSentInPruneMsg = value } - fun maxPeersPerPruneMessage(value: Int): GossipParamsBuilder = apply { maxPeersPerPruneMessage = value } + fun maxPeersAcceptedInPruneMsg(value: Int): GossipParamsBuilder = apply { maxPeersAcceptedInPruneMsg = value } fun pruneBackoff(value: Duration): GossipParamsBuilder = apply { pruneBackoff = value } - fun floodPublish(value: Boolean): GossipParamsBuilder = apply { floodPublish = value } - fun gossipFactor(value: Double): GossipParamsBuilder = apply { gossipFactor = value } fun opportunisticGraftPeers(value: Int): GossipParamsBuilder = apply { @@ -172,6 +179,14 @@ class GossipParamsBuilder { connectCallback = value } + fun maxIDontWantMessageIds(value: Int): GossipParamsBuilder = apply { maxIDontWantMessageIds = value } + + fun iDontWantMinMessageSizeThreshold(value: Int): GossipParamsBuilder = apply { iDontWantMinMessageSizeThreshold = value } + + fun floodPublishMaxMessageSizeThreshold(value: Int): GossipParamsBuilder = apply { floodPublishMaxMessageSizeThreshold = value } + + fun iDontWantTTL(value: Duration): GossipParamsBuilder = apply { iDontWantTTL = value } + fun build(): GossipParams { calculateMissing() checkRequiredFields() @@ -188,7 +203,7 @@ class GossipParamsBuilder { gossipHistoryLength = gossipHistoryLength!!, heartbeatInterval = heartbeatInterval!!, seenTTL = seenTTL!!, - floodPublish = floodPublish!!, + floodPublishMaxMessageSizeThreshold = floodPublishMaxMessageSizeThreshold!!, gossipFactor = gossipFactor!!, opportunisticGraftPeers = opportunisticGraftPeers!!, opportunisticGraftTicks = opportunisticGraftTicks!!, @@ -201,12 +216,15 @@ class GossipParamsBuilder { maxIWantMessageIds = maxIWantMessageIds, iWantFollowupTime = iWantFollowupTime!!, maxGraftMessages = maxGraftMessages, - maxPrunePeers = maxPrunePeers!!, - maxPeersPerPruneMessage = maxPeersPerPruneMessage, + maxPeersSentInPruneMsg = maxPeersSentInPruneMsg!!, + maxPeersAcceptedInPruneMsg = maxPeersAcceptedInPruneMsg!!, pruneBackoff = pruneBackoff!!, maxPruneMessages = maxPruneMessages, gossipRetransmission = gossipRetransmission!!, - connectCallback = connectCallback!! + connectCallback = connectCallback!!, + maxIDontWantMessageIds = maxIDontWantMessageIds!!, + iDontWantMinMessageSizeThreshold = iDontWantMinMessageSizeThreshold!!, + iDontWantTTL = iDontWantTTL!! ) } @@ -232,9 +250,9 @@ class GossipParamsBuilder { check(gossipHistoryLength != null, { "gossipHistoryLength must not be null" }) check(heartbeatInterval != null, { "heartbeatInterval must not be null" }) check(seenTTL != null, { "seenTTL must not be null" }) - check(maxPrunePeers != null, { "maxPrunePeers must not be null" }) + check(maxPeersSentInPruneMsg != null, { "maxPeersSentInPruneMsg must not be null" }) check(pruneBackoff != null, { "pruneBackoff must not be null" }) - check(floodPublish != null, { "floodPublish must not be null" }) + check(floodPublishMaxMessageSizeThreshold != null, { "floodPublishMaxMessageSizeThreshold must not be null" }) check(gossipFactor != null, { "gossipFactor must not be null" }) check(opportunisticGraftPeers != null, { "opportunisticGraftPeers must not be null" }) check(opportunisticGraftTicks != null, { "opportunisticGraftTicks must not be null" }) @@ -244,5 +262,8 @@ class GossipParamsBuilder { check(iWantFollowupTime != null, { "iWantFollowupTime must not be null" }) check(gossipRetransmission != null, { "gossipRetransmission must not be null" }) check(connectCallback != null, { "connectCallback must not be null" }) + check(maxIDontWantMessageIds != null, { "maxIDontWantMessageIds must not be null" }) + check(iDontWantMinMessageSizeThreshold != null, { "iDontWantMinMessageSizeThreshold must not be null" }) + check(iDontWantTTL != null, { "iDontWantTTL must not be null" }) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt index 06b7db403..5c783ce5f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt @@ -16,7 +16,7 @@ typealias GossipScoreFactory = open class GossipRouterBuilder( var name: String = "GossipRouter", - var protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1, + var protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_2, var params: GossipParams = GossipParams(), var scoreParams: GossipScoreParams = GossipScoreParams(), diff --git a/libp2p/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt b/libp2p/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt index 9bb407ef6..c9c02d82b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt @@ -97,7 +97,8 @@ class NoiseIoHandshake( private var sentNoiseKeyPayload = false private var instancePayload: ByteArray? = null private var activated = false - private var remotePeerId: PeerId? = null + private var remotePubKey: PubKey? = null + private val remotePeerId: PeerId? get() = remotePubKey?.let { PeerId.fromPubKey(it) } private var expectedRemotePeerId: PeerId? = null init { @@ -139,7 +140,7 @@ class NoiseIoHandshake( // the remote public key has been provided by the XX protocol val derivedRemotePublicKey = handshakeState.remotePublicKey if (derivedRemotePublicKey.hasPublicKey()) { - remotePeerId = verifyPayload(ctx, instancePayload!!, derivedRemotePublicKey) + remotePubKey = verifyPayload(ctx, instancePayload!!, derivedRemotePublicKey) if (role == Role.INIT && expectedRemotePeerId != remotePeerId) { throw InvalidRemotePubKey() } @@ -226,7 +227,6 @@ class NoiseIoHandshake( } // sendNoiseStaticKeyAsPayload private fun sendNoiseMessage(ctx: ChannelHandlerContext, msg: ByteArray? = null) { - val lenMsg = if (!NoiseXXSecureChannel.rustInteroperability) { msg } else if (msg != null) { @@ -249,7 +249,7 @@ class NoiseIoHandshake( ctx: ChannelHandlerContext, payload: ByteArray, remotePublicKeyState: DHState - ): PeerId { + ): PubKey { log.debug("Verifying noise static key payload") val (pubKeyFromMessage, signatureFromMessage) = unpackKeyAndSignature(payload) @@ -265,7 +265,7 @@ class NoiseIoHandshake( handshakeFailed(ctx, InvalidRemotePubKey()) } - return PeerId.fromPubKey(pubKeyFromMessage) + return pubKeyFromMessage } // verifyPayload private fun unpackKeyAndSignature(payload: ByteArray): Pair { @@ -288,7 +288,7 @@ class NoiseIoHandshake( val secureSession = NoiseSecureChannelSession( PeerId.fromPubKey(localKey.publicKey()), remotePeerId!!, - localKey.publicKey(), + remotePubKey!!, aliceSplit, bobSplit ) diff --git a/libp2p/src/main/kotlin/io/libp2p/security/plaintext/PlaintextInsecureChannel.kt b/libp2p/src/main/kotlin/io/libp2p/security/plaintext/PlaintextInsecureChannel.kt index a75c71ce3..7b2c00d71 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/plaintext/PlaintextInsecureChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/plaintext/PlaintextInsecureChannel.kt @@ -84,14 +84,16 @@ class PlaintextHandshakeHandler( val exchangeRecv = Plaintext.Exchange.parser().parseFrom(msg.nioBuffer()) ?: throw InvalidInitialPacket() - if (!exchangeRecv.hasPubkey()) + if (!exchangeRecv.hasPubkey()) { throw InvalidRemotePubKey() + } remotePeerId = PeerId(exchangeRecv.id.toByteArray()) remotePubKey = unmarshalPublicKey(exchangeRecv.pubkey.toByteArray()) val calculatedPeerId = PeerId.fromPubKey(remotePubKey) - if (remotePeerId != calculatedPeerId) + if (remotePeerId != calculatedPeerId) { throw InvalidRemotePubKey() + } handshakeCompleted(ctx) } // channelRead0 diff --git a/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoCodec.kt b/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoCodec.kt index dfdbbb3d5..b319a9b91 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoCodec.kt @@ -27,10 +27,10 @@ class SecIoCodec(val local: SecioParams, val remote: SecioParams) : MessageToMes companion object { fun createCipher(params: SecioParams): StreamCipher { - val aesEngine = AESEngine().apply { + val aesEngine = AESEngine.newInstance().apply { init(true, KeyParameter(params.keys.cipherKey)) } - return SICBlockCipher(aesEngine).apply { + return SICBlockCipher.newInstance(aesEngine).apply { init(true, ParametersWithIV(null, params.keys.iv)) } } @@ -58,8 +58,9 @@ class SecIoCodec(val local: SecioParams, val remote: SecioParams) : MessageToMes val macArr = updateMac(remote, cipherBytes) - if (!macBytes.contentEquals(macArr)) + if (!macBytes.contentEquals(macArr)) { throw InvalidMacException() + } val clearText = processBytes(remoteCipher, cipherBytes) out.add(clearText.toByteBuf()) diff --git a/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoNegotiator.kt b/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoNegotiator.kt index 837a9b8e2..c55a0ac19 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoNegotiator.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/secio/SecIoNegotiator.kt @@ -140,8 +140,9 @@ class SecIoNegotiator( val pubKey = unmarshalPublicKey(remotePubKeyBytes) val calcedPeerId = PeerId.fromPubKey(pubKey) - if (remotePeerId != null && calcedPeerId != remotePeerId) + if (remotePeerId != null && calcedPeerId != remotePeerId) { throw InvalidRemotePubKey() + } return pubKey } // validateRemoteKey @@ -151,8 +152,9 @@ class SecIoNegotiator( val h2 = sha256(localPubKeyBytes + remoteNonce) val keyOrder = h1.compareTo(h2) - if (keyOrder == 0) + if (keyOrder == 0) { throw SelfConnecting() + } return keyOrder } // orderKeys @@ -203,8 +205,9 @@ class SecIoNegotiator( exchangeMsg.signature.toByteArray() ) - if (!signatureIsOk) + if (!signatureIsOk) { throw InvalidSignature() + } } // validateExchangeMessage private fun calcHMac(macKey: ByteArray): HMac { @@ -241,8 +244,9 @@ class SecIoNegotiator( } // generateSharedSecret private fun verifyNonceResponse(buf: ByteBuf) { - if (!nonce.contentEquals(buf.toByteArray())) + if (!nonce.contentEquals(buf.toByteArray())) { throw InvalidInitialPacket() + } state = State.FinalValidated } // verifyNonceResponse diff --git a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt index 53c1957c5..cbbf876c8 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt @@ -109,10 +109,11 @@ fun buildTlsHandler( val connectionKeys = if (certAlgorithm.equals("ECDSA")) generateEcdsaKeyPair() else generateEd25519KeyPair() val javaPrivateKey = getJavaKey(connectionKeys.first) val sslContext = ( - if (isInitiator) + if (isInitiator) { SslContextBuilder.forClient().keyManager(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) - else + } else { SslContextBuilder.forServer(javaPrivateKey, listOf(buildCert(localKey, connectionKeys.first))) + } ) .protocols(listOf("TLSv1.3")) .ciphers(listOf("TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256")) @@ -132,10 +133,11 @@ fun buildTlsHandler( val handshake = handler.handshakeFuture() val engine = handler.engine() handshake.addListener { fut -> - if (! fut.isSuccess) { + if (!fut.isSuccess) { var cause = fut.cause() - if (cause != null && cause.cause != null) + if (cause != null && cause.cause != null) { cause = cause.cause + } handshakeComplete.completeExceptionally(cause) } else { val nextProtocol = handler.applicationProtocol() @@ -174,7 +176,7 @@ private class ChannelSetup( private var activated = false override fun channelActive(ctx: ChannelHandlerContext) { - if (! activated) { + if (!activated) { activated = true val expectedRemotePeerId = ctx.channel().attr(REMOTE_PEER_ID).get() val handler = buildTlsHandler( @@ -218,14 +220,15 @@ class Libp2pTrustManager(private val expectedRemotePeer: Optional) : X50 remoteCert = null } override fun checkClientTrusted(certs: Array?, authType: String?) { - if (certs?.size != 1) + if (certs?.size != 1) { throw CertificateException() + } val cert = certs.get(0) remoteCert = cert val claimedPeerId = verifyAndExtractPeerId(arrayOf(cert)) - if (expectedRemotePeer.map { ex -> ! ex.equals(claimedPeerId) }.orElse(false)) + if (expectedRemotePeer.map { ex -> !ex.equals(claimedPeerId) }.orElse(false)) { throw InvalidRemotePubKey() - println("Trusted!") + } } override fun checkServerTrusted(certs: Array?, authType: String?) { @@ -276,14 +279,16 @@ fun getPubKey(pub: PublicKey): PubKey { if (pub.algorithm.equals("EC")) { return EcdsaPublicKey(pub as ECPublicKey) } - if (pub.algorithm.equals("RSA")) + if (pub.algorithm.equals("RSA")) { throw IllegalStateException("Unimplemented RSA public key support for TLS") + } throw IllegalStateException("Unsupported key type: " + pub.algorithm) } fun verifyAndExtractPeerId(chain: Array): PeerId { - if (chain.size != 1) + if (chain.size != 1) { throw java.lang.IllegalStateException("Cert chain must have exactly 1 element!") + } val cert = chain.get(0) // peerid is in the certificate extension val bcCert = org.bouncycastle.asn1.x509.Certificate @@ -291,8 +296,9 @@ fun verifyAndExtractPeerId(chain: Array): PeerId { val bcX509Cert = X509CertificateHolder(bcCert) val libp2pOid = ASN1ObjectIdentifier("1.3.6.1.4.1.53594.1.1") val extension = bcX509Cert.extensions.getExtension(libp2pOid) - if (extension == null) + if (extension == null) { throw IllegalStateException("Certificate extension not present!") + } val input = ASN1InputStream(extension.extnValue.encoded) val wrapper = input.readObject() as DEROctetString val seq = ASN1InputStream(wrapper.octets).readObject() as DLSequence @@ -300,21 +306,25 @@ fun verifyAndExtractPeerId(chain: Array): PeerId { val signature = (seq.getObjectAt(1) as DEROctetString).octets val pubKey = unmarshalPublicKey(pubKeyProto) val pubKeyAsn1 = bcCert.subjectPublicKeyInfo.encoded - if (! pubKey.verify(certificatePrefix.plus(pubKeyAsn1), signature)) + if (!pubKey.verify(certificatePrefix.plus(pubKeyAsn1), signature)) { throw IllegalStateException("Invalid signature on TLS certificate extension!") + } cert.verify(cert.publicKey) val now = Date() - if (bcCert.endDate.date.before(now)) + if (bcCert.endDate.date.before(now)) { throw IllegalStateException("TLS certificate has expired!") - if (bcCert.startDate.date.after(now)) + } + if (bcCert.startDate.date.after(now)) { throw IllegalStateException("TLS certificate is not valid yet!") + } return PeerId.fromPubKey(pubKey) } fun getPublicKeyFromCert(chain: Array): PubKey { - if (chain.size != 1) + if (chain.size != 1) { throw java.lang.IllegalStateException("Cert chain must have exactly 1 element!") + } val cert = chain.get(0) return getPubKey(cert.publicKey) } diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt index fee796146..90c1d824f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/ConnectionOverNetty.kt @@ -30,8 +30,12 @@ open class ConnectionOverNetty( ch.attr(CONNECTION).set(this) } - fun setMuxerSession(ms: StreamMuxer.Session) { muxerSession = ms } - fun setSecureSession(ss: SecureChannel.Session) { secureSession = ss } + fun setMuxerSession(ms: StreamMuxer.Session) { + muxerSession = ms + } + fun setSecureSession(ss: SecureChannel.Session) { + secureSession = ss + } override fun muxerSession() = muxerSession override fun secureSession() = secureSession @@ -43,10 +47,11 @@ open class ConnectionOverNetty( toMultiaddr(nettyChannel.remoteAddress() as InetSocketAddress) private fun toMultiaddr(addr: InetSocketAddress): Multiaddr { - if (transport is NettyTransport) + if (transport is NettyTransport) { return transport.toMultiaddr(addr) - else + } else { return toMultiaddrDefault(addr) + } } fun toMultiaddrDefault(addr: InetSocketAddress): Multiaddr { diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt index 2a078266d..f29c2bfa6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/implementation/NettyTransport.kt @@ -129,8 +129,7 @@ abstract class NettyTransport( ?: throw Libp2pException("No listeners on address $addr") } // unlisten - override fun dial(addr: Multiaddr, connHandler: ConnectionHandler, preHandler: ChannelVisitor?): - CompletableFuture { + override fun dial(addr: Multiaddr, connHandler: ConnectionHandler, preHandler: ChannelVisitor?): CompletableFuture { if (closed) throw Libp2pException("Transport is closed") val remotePeerId = addr.getPeerId() @@ -186,8 +185,9 @@ abstract class NettyTransport( protected fun hostFromMultiaddr(addr: Multiaddr): String { val resolvedAddresses = MultiaddrDns.resolve(addr) - if (resolvedAddresses.isEmpty()) + if (resolvedAddresses.isEmpty()) { throw Libp2pException("Could not resolve $addr to an IP address") + } return resolvedAddresses[0].components.find { it.protocol in arrayOf(Protocol.IP4, Protocol.IP6) diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt index 71d34c2df..a081ff67d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt @@ -2,8 +2,10 @@ package io.libp2p.transport.tcp import io.libp2p.core.InternalErrorException import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.Protocol.DNSADDR import io.libp2p.core.multiformats.Protocol.IP4 import io.libp2p.core.multiformats.Protocol.IP6 +import io.libp2p.core.multiformats.Protocol.P2PCIRCUIT import io.libp2p.core.multiformats.Protocol.TCP import io.libp2p.core.multiformats.Protocol.WS import io.libp2p.transport.ConnectionUpgrader @@ -27,7 +29,9 @@ open class TcpTransport( override fun handles(addr: Multiaddr) = handlesHost(addr) && addr.has(TCP) && - !addr.has(WS) + !addr.has(WS) && + !addr.has(DNSADDR) && + !addr.has(P2PCIRCUIT) override fun serverTransportBuilder( connectionBuilder: ConnectionBuilder, diff --git a/libp2p/src/main/proto/autonat.proto b/libp2p/src/main/proto/autonat.proto new file mode 100644 index 000000000..0e92a5178 --- /dev/null +++ b/libp2p/src/main/proto/autonat.proto @@ -0,0 +1,37 @@ +syntax = "proto2"; + +package io.libp2p.protocol.autonat.pb; + +message Message { + enum MessageType { + DIAL = 0; + DIAL_RESPONSE = 1; + } + + enum ResponseStatus { + OK = 0; + E_DIAL_ERROR = 100; + E_DIAL_REFUSED = 101; + E_BAD_REQUEST = 200; + E_INTERNAL_ERROR = 300; + } + + message PeerInfo { + optional bytes id = 1; + repeated bytes addrs = 2; + } + + message Dial { + optional PeerInfo peer = 1; + } + + message DialResponse { + optional ResponseStatus status = 1; + optional string statusText = 2; + optional bytes addr = 3; + } + + optional MessageType type = 1; + optional Dial dial = 2; + optional DialResponse dialResponse = 3; +} diff --git a/libp2p/src/main/proto/circuit.proto b/libp2p/src/main/proto/circuit.proto new file mode 100644 index 000000000..efc8d425a --- /dev/null +++ b/libp2p/src/main/proto/circuit.proto @@ -0,0 +1,60 @@ +syntax = "proto2"; + +package io.libp2p.protocol.circuit.pb; + +message HopMessage { + enum Type { + RESERVE = 0; + CONNECT = 1; + STATUS = 2; + } + + required Type type = 1; + + optional Peer peer = 2; + optional Reservation reservation = 3; + optional Limit limit = 4; + + optional Status status = 5; +} + +message StopMessage { + enum Type { + CONNECT = 0; + STATUS = 1; + } + + required Type type = 1; + + optional Peer peer = 2; + optional Limit limit = 3; + + optional Status status = 4; +} + +message Peer { + required bytes id = 1; + repeated bytes addrs = 2; +} + +message Reservation { + required uint64 expire = 1; // Unix expiration time (UTC) + repeated bytes addrs = 2; // relay addrs for reserving peer + optional bytes voucher = 3; // reservation voucher +} + +message Limit { + optional uint32 duration = 1; // seconds + optional uint64 data = 2; // bytes +} + +enum Status { + OK = 100; + RESERVATION_REFUSED = 200; + RESOURCE_LIMIT_EXCEEDED = 201; + PERMISSION_DENIED = 202; + CONNECTION_FAILED = 203; + NO_RESERVATION = 204; + MALFORMED_MESSAGE = 400; + UNEXPECTED_MESSAGE = 401; +} diff --git a/libp2p/src/main/proto/envelope.proto b/libp2p/src/main/proto/envelope.proto new file mode 100644 index 000000000..d303fd553 --- /dev/null +++ b/libp2p/src/main/proto/envelope.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package io.libp2p.protocol.circuit.crypto.pb; + +enum KeyType { + RSA = 0; + Ed25519 = 1; + Secp256k1 = 2; + ECDSA = 3; + Curve25519 = 4; +} + +message PublicKey { + KeyType Type = 1; + bytes Data = 2; +} + +message PrivateKey { + KeyType Type = 1; + bytes Data = 2; +} + +// Envelope encloses a signed payload produced by a peer, along with the public +// key of the keypair it was signed with so that it can be statelessly validated +// by the receiver. +// +// The payload is prefixed with a byte string that determines the type, so it +// can be deserialized deterministically. Often, this byte string is a +// multicodec. +message Envelope { + // public_key is the public key of the keypair the enclosed payload was + // signed with. + PublicKey public_key = 1; + + // payload_type encodes the type of payload, so that it can be deserialized + // deterministically. + bytes payload_type = 2; + + // payload is the actual payload carried inside this envelope. + bytes payload = 3; + + // signature is the signature produced by the private key corresponding to + // the enclosed public key, over the payload, prefixing a domain string for + // additional security. + bytes signature = 5; +} diff --git a/libp2p/src/main/proto/rpc.proto b/libp2p/src/main/proto/rpc.proto index 479e73b8b..080eef471 100644 --- a/libp2p/src/main/proto/rpc.proto +++ b/libp2p/src/main/proto/rpc.proto @@ -28,6 +28,7 @@ message ControlMessage { repeated ControlIWant iwant = 2; repeated ControlGraft graft = 3; repeated ControlPrune prune = 4; + repeated ControlIDontWant idontwant = 5; } message ControlIHave { @@ -49,6 +50,10 @@ message ControlPrune { optional uint64 backoff = 3; } +message ControlIDontWant { + repeated bytes messageIDs = 1; +} + message PeerInfo { optional bytes peerID = 1; optional bytes signedPeerRecord = 2; diff --git a/libp2p/src/main/proto/voucher.proto b/libp2p/src/main/proto/voucher.proto new file mode 100644 index 000000000..5b2dea19e --- /dev/null +++ b/libp2p/src/main/proto/voucher.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package io.libp2p.protocol.circuit.pb; + +message Voucher { + required bytes relay = 1; + required bytes peer = 2; + required uint64 expiration = 3; +} \ No newline at end of file diff --git a/libp2p/src/test/java/io/libp2p/core/AutonatTestJava.java b/libp2p/src/test/java/io/libp2p/core/AutonatTestJava.java new file mode 100644 index 000000000..5d74c06de --- /dev/null +++ b/libp2p/src/test/java/io/libp2p/core/AutonatTestJava.java @@ -0,0 +1,73 @@ +package io.libp2p.core; + +import io.libp2p.core.dsl.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.mux.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.autonat.*; +import io.libp2p.protocol.autonat.pb.*; +import io.libp2p.security.noise.*; +import io.libp2p.transport.tcp.*; +import java.util.concurrent.*; +import org.junit.jupiter.api.*; + +public class AutonatTestJava { + + @Test + void autonatDial() throws Exception { + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .protocol(new AutonatProtocol.Binding()) + .listen("/ip4/127.0.0.1/tcp/0") + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .protocol(new AutonatProtocol.Binding()) + .listen("/ip4/127.0.0.1/tcp/0") + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + StreamPromise autonat = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), serverHost.listenAddresses().get(0)) + .thenApply(it -> it.muxerSession().createStream(new AutonatProtocol.Binding())) + .get(5, TimeUnit.SECONDS); + + Stream autonatStream = autonat.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Autonat stream created"); + AutonatProtocol.AutoNatController autonatCtr = autonat.getController().get(5, TimeUnit.SECONDS); + System.out.println("Autonat controller created"); + + Autonat.Message.DialResponse resp = + autonatCtr + .requestDial(clientHost.getPeerId(), clientHost.listenAddresses()) + .get(5, TimeUnit.SECONDS); + Assertions.assertEquals(resp.getStatus(), Autonat.Message.ResponseStatus.OK); + Multiaddr received = Multiaddr.deserialize(resp.getAddr().toByteArray()); + Assertions.assertEquals(received, clientHost.listenAddresses().get(0)); + + autonatStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Autonat stream closed"); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } +} diff --git a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java index a1ce55b95..f2c9a51ff 100644 --- a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java +++ b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java @@ -1,102 +1,291 @@ package io.libp2p.core; -import io.libp2p.core.crypto.KEY_TYPE; import io.libp2p.core.crypto.KeyKt; +import io.libp2p.core.crypto.KeyType; import io.libp2p.core.crypto.PrivKey; import io.libp2p.core.crypto.PubKey; import io.libp2p.core.dsl.HostBuilder; import io.libp2p.core.multiformats.Multiaddr; import io.libp2p.core.mux.StreamMuxerProtocol; -import io.libp2p.protocol.Ping; -import io.libp2p.protocol.PingController; +import io.libp2p.protocol.*; +import io.libp2p.security.noise.*; import io.libp2p.security.tls.*; import io.libp2p.transport.tcp.TcpTransport; -import kotlin.Pair; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - +import io.netty.handler.logging.LogLevel; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import kotlin.Pair; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; public class HostTestJava { - @Test - void ping() throws Exception { -/* HostImpl clientHost = BuildersJKt.hostJ(b -> { - b.getIdentity().random(); - b.getTransports().add(TcpTransport::new); - b.getSecureChannels().add(SecIoSecureChannel::new); - b.getMuxers().add(MplexStreamMuxer::new); - b.getProtocols().add(new Ping()); - b.getDebug().getMuxFramesHandler().setLogger(LogLevel.ERROR, "host-1-MUX"); - b.getDebug().getBeforeSecureHandler().setLogger(LogLevel.ERROR, "host-1-BS"); - b.getDebug().getAfterSecureHandler().setLogger(LogLevel.ERROR, "host-1-AS"); - }); -*/ - String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; - - Host clientHost = new HostBuilder() - .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::ECDSA) - .muxer(StreamMuxerProtocol::getYamux) - .build(); - - Host serverHost = new HostBuilder() - .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::new) - .muxer(StreamMuxerProtocol::getYamux) - .protocol(new Ping()) - .listen(localListenAddress) - .build(); - - CompletableFuture clientStarted = clientHost.start(); - CompletableFuture serverStarted = serverHost.start(); - clientStarted.get(5, TimeUnit.SECONDS); - System.out.println("Client started"); - serverStarted.get(5, TimeUnit.SECONDS); - System.out.println("Server started"); - - Assertions.assertEquals(0, clientHost.listenAddresses().size()); - Assertions.assertEquals(1, serverHost.listenAddresses().size()); - Assertions.assertEquals( - localListenAddress + "/p2p/" + serverHost.getPeerId(), - serverHost.listenAddresses().get(0).toString() - ); - - StreamPromise ping = - clientHost.getNetwork().connect( - serverHost.getPeerId(), - new Multiaddr(localListenAddress) - ).thenApply( - it -> it.muxerSession().createStream(new Ping()) - ) - .get(5, TimeUnit.SECONDS); - - Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); - System.out.println("Ping stream created"); - PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); - System.out.println("Ping controller created"); - - for (int i = 0; i < 10; i++) { - long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); - System.out.println("Ping is " + latency); - } - pingStream.close().get(5, TimeUnit.SECONDS); - System.out.println("Ping stream closed"); - - Assertions.assertThrows(ExecutionException.class, () -> - pingCtr.ping().get(5, TimeUnit.SECONDS)); - - clientHost.stop().get(5, TimeUnit.SECONDS); - System.out.println("Client stopped"); - serverHost.stop().get(5, TimeUnit.SECONDS); - System.out.println("Server stopped"); + @Test + void ping() throws Exception { + /* HostImpl clientHost = BuildersJKt.hostJ(b -> { + b.getIdentity().random(); + b.getTransports().add(TcpTransport::new); + b.getSecureChannels().add(SecIoSecureChannel::new); + b.getMuxers().add(MplexStreamMuxer::new); + b.getProtocols().add(new Ping()); + b.getDebug().getMuxFramesHandler().setLogger(LogLevel.ERROR, "host-1-MUX"); + b.getDebug().getBeforeSecureHandler().setLogger(LogLevel.ERROR, "host-1-BS"); + b.getDebug().getAfterSecureHandler().setLogger(LogLevel.ERROR, "host-1-AS"); + }); + */ + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::ECDSA) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void largePing() throws Exception { + int pingSize = 200 * 1024; + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping(pingSize)) + .listen(localListenAddress) + .build(); - @Test - void keyPairGeneration() { - Pair pair = KeyKt.generateKeyPair(KEY_TYPE.SECP256K1); - PeerId peerId = PeerId.fromPubKey(pair.component2()); - System.out.println("PeerId: " + peerId.toHex()); + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping(pingSize))) + .join(); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().join(); // get(5, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void largeBlob() throws Exception { + int blobSize = 1024 * 1024; + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "client")) + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .listen(localListenAddress) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "server")) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + StreamPromise blob = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Blob(blobSize))) + .join(); + + Stream blobStream = blob.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream created"); + BlobController blobCtr = blob.getController().get(5, TimeUnit.SECONDS); + System.out.println("Blob controller created"); + + for (int i = 0; i < 10; i++) { + long latency = blobCtr.blob().join(); + System.out.println("Blob round trip is " + latency); + } + blobStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> blobCtr.blob().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void addPingAfterHostStart() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) + .muxer(StreamMuxerProtocol::getYamux) + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(TlsSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .listen(localListenAddress) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + serverHost.addProtocolHandler(new Ping()); + + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void keyPairGeneration() { + Pair pair = KeyKt.generateKeyPair(KeyType.SECP256K1); + PeerId peerId = PeerId.fromPubKey(pair.component2()); + System.out.println("PeerId: " + peerId.toHex()); + } } diff --git a/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java b/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java new file mode 100644 index 000000000..cc6c4a53a --- /dev/null +++ b/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java @@ -0,0 +1,221 @@ +package io.libp2p.core; + +import io.libp2p.core.crypto.*; +import io.libp2p.core.dsl.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.mux.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.*; +import io.libp2p.security.noise.*; +import io.libp2p.transport.tcp.*; +import java.util.*; +import java.util.concurrent.*; +import org.junit.jupiter.api.*; + +public class RelayTestJava { + + private static void enableRelay(BuilderJ b, List relays) { + PrivKey priv = b.getIdentity().random().getFactory().invoke(); + b.getIdentity().setFactory(() -> priv); + PeerId us = PeerId.fromPubKey(priv.publicKey()); + CircuitHopProtocol.RelayManager relayManager = + CircuitHopProtocol.RelayManager.limitTo(priv, us, 5); + CircuitStopProtocol.Binding stop = new CircuitStopProtocol.Binding(new CircuitStopProtocol()); + CircuitHopProtocol.Binding hop = new CircuitHopProtocol.Binding(relayManager, stop); + b.getProtocols().add(hop); + b.getProtocols().add(stop); + b.getTransports() + .add( + u -> new RelayTransport(hop, stop, u, h -> relays, new ScheduledThreadPoolExecutor(1))); + } + + @Test + void pingOverLocalRelay() throws Exception { + Host relayHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, Collections.emptyList())) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .listen("/ip4/127.0.0.1/tcp/0") + .protocol(new Ping()) + .build(); + relayHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(relayHost); + CompletableFuture relayStarted = relayHost.start(); + relayStarted.get(5, TimeUnit.SECONDS); + + List relayAddrs = relayHost.listenAddresses(); + Multiaddr relayAddr = relayAddrs.get(0); + RelayTransport.CandidateRelay relay = + new RelayTransport.CandidateRelay(relayHost.getPeerId(), relayAddrs); + List relays = List.of(relay); + + Host clientHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .build(); + clientHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(clientHost); + + Host serverHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .listen("/ip4/127.0.0.1/tcp/0") + .listen(relayAddr + "/p2p-circuit") + .build(); + serverHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(serverHost); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Multiaddr toDial = + relayAddr.concatenated( + new Multiaddr("/p2p-circuit/p2p/" + serverHost.getPeerId().toBase58())); + System.out.println("Dialling " + toDial + " from " + clientHost.getPeerId()); + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), toDial) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void relayStreamsAreLimited() throws Exception { + Host relayHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, Collections.emptyList())) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .listen("/ip4/127.0.0.1/tcp/0") + .build(); + relayHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(relayHost); + CompletableFuture relayStarted = relayHost.start(); + relayStarted.get(5, TimeUnit.SECONDS); + + List relayAddrs = relayHost.listenAddresses(); + Multiaddr relayAddr = relayAddrs.get(0); + RelayTransport.CandidateRelay relay = + new RelayTransport.CandidateRelay(relayHost.getPeerId(), relayAddrs); + List relays = List.of(relay); + + // Relay streams are limited to 4096 bytes in either direction + // This is the smallest value that triggers the limit + // not sure why there is so much overhead from 3 * multistream + noise + yamux! + int blobSize = 1469; + Host clientHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .build(); + clientHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(clientHost); + + Host serverHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .listen("/ip4/127.0.0.1/tcp/0") + .listen(relayAddr + "/p2p-circuit") + .build(); + serverHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(serverHost); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Multiaddr toDial = + relayAddr.concatenated( + new Multiaddr("/p2p-circuit/p2p/" + serverHost.getPeerId().toBase58())); + System.out.println("Dialling " + toDial + " from " + clientHost.getPeerId()); + StreamPromise blob = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), toDial) + .thenApply(it -> it.muxerSession().createStream(new Blob(blobSize))) + .get(5, TimeUnit.SECONDS); + + Stream blobStream = blob.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream created"); + BlobController blobCtr = blob.getController().get(5, TimeUnit.SECONDS); + System.out.println("Blob controller created"); + + Assertions.assertThrows( + ExecutionException.class, () -> blobCtr.blob().get(30, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } +} diff --git a/libp2p/src/test/java/io/libp2p/pubsub/GossipApiTest.java b/libp2p/src/test/java/io/libp2p/pubsub/GossipApiTest.java index 7f7de2bdf..d95485103 100644 --- a/libp2p/src/test/java/io/libp2p/pubsub/GossipApiTest.java +++ b/libp2p/src/test/java/io/libp2p/pubsub/GossipApiTest.java @@ -1,5 +1,8 @@ package io.libp2p.pubsub; +import static io.libp2p.tools.StubsKt.peerHandlerStub; +import static org.assertj.core.api.Assertions.assertThat; + import com.google.protobuf.ByteString; import io.libp2p.core.pubsub.ValidationResult; import io.libp2p.etc.types.WBytes; @@ -8,123 +11,120 @@ import io.libp2p.pubsub.gossip.GossipParamsKt; import io.libp2p.pubsub.gossip.GossipRouter; import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder; -import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.Test; -import pubsub.pb.Rpc; - import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.function.Function; - -import static io.libp2p.tools.StubsKt.peerHandlerStub; -import static org.assertj.core.api.Assertions.assertThat; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; +import pubsub.pb.Rpc; public class GossipApiTest { - @Test - public void createGossipTest() { - GossipParams gossipParams = GossipParams.builder() - .D(10) - .DHigh(20) - .build(); - GossipRouterBuilder routerBuilder = new GossipRouterBuilder(); - routerBuilder.setParams(gossipParams); - GossipRouter router = routerBuilder.build(); - assertThat(router.getParams().getD()).isEqualTo(10); - assertThat(router.getParams().getDHigh()).isEqualTo(20); - assertThat(router.getParams().getDScore()).isEqualTo(GossipParamsKt.defaultDScore(10)); - } - - @Test - public void testFastMessageId() throws Exception { - List createdMessages = new ArrayList<>(); - - GossipRouterBuilder routerBuilder = new GossipRouterBuilder(); - routerBuilder.setSeenCache(new FastIdSeenCache<>(msg -> msg.getProtobufMessage().getData())); - routerBuilder.setMessageFactory(m -> { - TestPubsubMessage message = new TestPubsubMessage(m); - createdMessages.add(message); - return message; + @Test + public void createGossipTest() { + GossipParams gossipParams = GossipParams.builder().D(10).DHigh(20).build(); + GossipRouterBuilder routerBuilder = new GossipRouterBuilder(); + routerBuilder.setParams(gossipParams); + GossipRouter router = routerBuilder.build(); + assertThat(router.getParams().getD()).isEqualTo(10); + assertThat(router.getParams().getDHigh()).isEqualTo(20); + assertThat(router.getParams().getDScore()).isEqualTo(GossipParamsKt.defaultDScore(10)); + } + + @Test + public void testFastMessageId() throws Exception { + List createdMessages = new ArrayList<>(); + + GossipRouterBuilder routerBuilder = new GossipRouterBuilder(); + routerBuilder.setSeenCache(new FastIdSeenCache<>(msg -> msg.getProtobufMessage().getData())); + routerBuilder.setMessageFactory( + m -> { + TestPubsubMessage message = new TestPubsubMessage(m); + createdMessages.add(message); + return message; }); - GossipRouter router = routerBuilder.build(); - router.subscribe("topic"); - - BlockingQueue messages = new LinkedBlockingQueue<>(); - router.initHandler(m -> { - messages.add(m); - return CompletableFuture.completedFuture(ValidationResult.Valid); + GossipRouter router = routerBuilder.build(); + router.subscribe("topic"); + + BlockingQueue messages = new LinkedBlockingQueue<>(); + router.initHandler( + m -> { + messages.add(m); + return CompletableFuture.completedFuture(ValidationResult.Valid); }); - P2PService.PeerHandler peerHandler = peerHandlerStub(router); - - router.runOnEventThread(() -> router.onInbound(peerHandler, newMessage("Hello-1"))); - TestPubsubMessage message1 = (TestPubsubMessage) messages.poll(1, TimeUnit.SECONDS); - - assertThat(message1).isNotNull(); - assertThat(message1.canonicalId).isNotNull(); - assertThat(createdMessages.size()).isEqualTo(1); - createdMessages.clear(); - - router.runOnEventThread(() -> router.onInbound(peerHandler, newMessage("Hello-1"))); - TestPubsubMessage message2 = (TestPubsubMessage) messages.poll(100, TimeUnit.MILLISECONDS); + P2PService.PeerHandler peerHandler = peerHandlerStub(router); + + router.runOnEventThread(() -> router.onInbound(peerHandler, newMessage("Hello-1"))); + TestPubsubMessage message1 = (TestPubsubMessage) messages.poll(1, TimeUnit.SECONDS); + + assertThat(message1).isNotNull(); + assertThat(message1.canonicalId).isNotNull(); + assertThat(createdMessages.size()).isEqualTo(1); + createdMessages.clear(); + + router.runOnEventThread(() -> router.onInbound(peerHandler, newMessage("Hello-1"))); + TestPubsubMessage message2 = (TestPubsubMessage) messages.poll(100, TimeUnit.MILLISECONDS); + + assertThat(message2).isNull(); + assertThat(createdMessages.size()).isEqualTo(1); + // assert that 'slow' canonicalId was not calculated and the message was filtered as seen by + // fastId + assertThat(createdMessages.get(0).canonicalId).isNull(); + createdMessages.clear(); + } + + private static Rpc.RPC newMessage(String msg) { + return Rpc.RPC + .newBuilder() + .addPublish( + Rpc.Message.newBuilder() + .addTopicIDs("topic") + .setData(ByteString.copyFrom(msg, StandardCharsets.US_ASCII))) + .build(); + } + + private static class TestPubsubMessage implements PubsubMessage { + final Rpc.Message message; + Function canonicalIdCalculator = + m -> new WBytes(("canon-" + m.getData().toString()).getBytes()); + WBytes canonicalId = null; + + public TestPubsubMessage(Rpc.Message message) { + this.message = message; + } - assertThat(message2).isNull(); - assertThat(createdMessages.size()).isEqualTo(1); - // assert that 'slow' canonicalId was not calculated and the message was filtered as seen by fastId - assertThat(createdMessages.get(0).canonicalId).isNull(); - createdMessages.clear(); + @NotNull + @Override + public Rpc.Message getProtobufMessage() { + return message; } - private static Rpc.RPC newMessage(String msg) { - return Rpc.RPC.newBuilder().addPublish( - Rpc.Message.newBuilder() - .addTopicIDs("topic") - .setData(ByteString.copyFrom(msg, StandardCharsets.US_ASCII)) - ).build(); + @NotNull + @Override + public WBytes getMessageId() { + if (canonicalId == null) { + canonicalId = canonicalIdCalculator.apply(getProtobufMessage()); + } + return canonicalId; } - private static class TestPubsubMessage implements PubsubMessage { - final Rpc.Message message; - Function canonicalIdCalculator = m -> new WBytes(("canon-" + m.getData().toString()).getBytes()); - WBytes canonicalId = null; - - public TestPubsubMessage(Rpc.Message message) { - this.message = message; - } - - @NotNull - @Override - public Rpc.Message getProtobufMessage() { - return message; - } - - @NotNull - @Override - public WBytes getMessageId() { - if (canonicalId == null) { - canonicalId = canonicalIdCalculator.apply(getProtobufMessage()); - } - return canonicalId; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TestPubsubMessage that = (TestPubsubMessage) o; - return message.equals(that.message); - } - - @Override - public int hashCode() { - return message.hashCode(); - } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TestPubsubMessage that = (TestPubsubMessage) o; + return message.equals(that.message); } + @Override + public int hashCode() { + return message.hashCode(); + } + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt index eb4513f99..09720736e 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt @@ -75,13 +75,17 @@ class HostTest { var interceptRead = false var interceptWrite = false override fun interceptRead(buf: ByteBuf) = - if (interceptRead) + if (interceptRead) { Unpooled.wrappedBuffer("RRR".toByteArray(Charsets.UTF_8)) - else buf + } else { + buf + } override fun interceptWrite(buf: ByteBuf) = - if (interceptWrite) + if (interceptWrite) { Unpooled.wrappedBuffer("WWW".toByteArray(Charsets.UTF_8)) - else buf + } else { + buf + } } val interceptor = TestInterceptor() diff --git a/libp2p/src/test/kotlin/io/libp2p/core/HostTranportsTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/HostTranportsTest.kt index 3768d39ca..522ce91a4 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/HostTranportsTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/HostTranportsTest.kt @@ -32,17 +32,20 @@ import java.util.concurrent.TimeUnit @Tag("secure-channel") class PlaintextTcpTest : TcpTransportHostTest(::PlaintextInsecureChannel) + @Tag("secure-channel") class PlaintextWsTest : WsTransportHostTest(::PlaintextInsecureChannel) @Tag("secure-channel") class SecioTcpTest : TcpTransportHostTest(::SecIoSecureChannel) + @Tag("secure-channel") class SecioWsTest : WsTransportHostTest(::SecIoSecureChannel) @DisabledIfEnvironmentVariable(named = "TRAVIS", matches = "true") @Tag("secure-channel") class NoiseXXTcpTest : TcpTransportHostTest(::NoiseXXSecureChannel) + @DisabledIfEnvironmentVariable(named = "TRAVIS", matches = "true") @Tag("secure-channel") class NoiseXXWsTest : WsTransportHostTest(::NoiseXXSecureChannel) @@ -96,7 +99,7 @@ abstract class HostTransportsTest( add(secureChannelCtor) } muxers { - + StreamMuxerProtocol.Mplex + +StreamMuxerProtocol.Mplex } protocols { +Ping() @@ -121,7 +124,7 @@ abstract class HostTransportsTest( add(secureChannelCtor) } muxers { - + StreamMuxerProtocol.Mplex + +StreamMuxerProtocol.Mplex } network { listen(listenAddress) diff --git a/libp2p/src/test/kotlin/io/libp2p/core/RpcHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/RpcHandlerTest.kt index 3f43d040e..da8a6b056 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/RpcHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/RpcHandlerTest.kt @@ -106,7 +106,7 @@ class RpcHandlerTest { add(::SecIoSecureChannel) } muxers { - + StreamMuxerProtocol.Mplex + +StreamMuxerProtocol.Mplex } protocols { +RpcProtocol() @@ -127,7 +127,7 @@ class RpcHandlerTest { add(::SecIoSecureChannel) } muxers { - + StreamMuxerProtocol.Mplex + +StreamMuxerProtocol.Mplex } protocols { +RpcProtocol() @@ -166,8 +166,11 @@ class RpcHandlerTest { Assertions.assertEquals(1, streamCounter2) Assertions.assertEquals(1, streamCounter1) for (i in 1..100) { - if (host1.streams.isNotEmpty() || host2.streams.isNotEmpty()) Thread.sleep(10) - else break + if (host1.streams.isNotEmpty() || host2.streams.isNotEmpty()) { + Thread.sleep(10) + } else { + break + } } Assertions.assertEquals(0, host1.streams.size) Assertions.assertEquals(0, host2.streams.size) @@ -188,8 +191,11 @@ class RpcHandlerTest { Assertions.assertEquals(2, streamCounter1) Assertions.assertEquals(2, streamCounter2) for (i in 1..100) { - if (host1.streams.isNotEmpty() || host2.streams.isNotEmpty()) Thread.sleep(10) - else break + if (host1.streams.isNotEmpty() || host2.streams.isNotEmpty()) { + Thread.sleep(10) + } else { + break + } } Assertions.assertEquals(0, host1.streams.size) Assertions.assertEquals(0, host2.streams.size) diff --git a/libp2p/src/test/kotlin/io/libp2p/core/dsl/BuilderDefaultsTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/dsl/BuilderDefaultsTest.kt index f7bc79b7a..b3bafc9d7 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/dsl/BuilderDefaultsTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/dsl/BuilderDefaultsTest.kt @@ -57,7 +57,7 @@ class BuilderDefaultsTest { identity { random() } transports { +::TcpTransport } secureChannels { add(::SecIoSecureChannel) } - muxers { + StreamMuxerProtocol.Mplex } + muxers { +StreamMuxerProtocol.Mplex } } host.start().get(5, SECONDS) diff --git a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrDnsTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrDnsTest.kt index f0dbaf617..9e28b8b6a 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrDnsTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/MultiaddrDnsTest.kt @@ -164,12 +164,13 @@ class MultiaddrDnsTest { val TestResolver = object : MultiaddrDns.Resolver { override fun resolveDns4(hostname: String): List { - val address = if ("pig.com".equals(hostname)) + val address = if ("pig.com".equals(hostname)) { listOf("/ip4/1.1.1.1", "/ip4/1.1.1.2") - else if ("localhost".equals(hostname)) + } else if ("localhost".equals(hostname)) { listOf("/ip4/127.0.0.1") - else + } else { listOf("/ip4/2.2.2.1", "/ip4/2.2.2.2") + } return address.map { Multiaddr(it) } } diff --git a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/ProtocolTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/ProtocolTest.kt index e78c13bd6..8fcd1f075 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/multiformats/ProtocolTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/multiformats/ProtocolTest.kt @@ -21,6 +21,7 @@ class ProtocolTest { assertEquals(protocol, roundTrippedProtocol) } } + @Test fun tcpProtocolProperties() { assertEquals(Protocol.TCP, Protocol.get("tcp")) diff --git a/libp2p/src/test/kotlin/io/libp2p/discovery/MDnsDiscoveryTest.kt b/libp2p/src/test/kotlin/io/libp2p/discovery/MDnsDiscoveryTest.kt index b0cfa93c6..e0bddd949 100644 --- a/libp2p/src/test/kotlin/io/libp2p/discovery/MDnsDiscoveryTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/discovery/MDnsDiscoveryTest.kt @@ -6,6 +6,7 @@ import io.libp2p.core.multiformats.Multiaddr import io.libp2p.crypto.keys.generateEcdsaKeyPair import io.libp2p.tools.NullHost import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test import java.util.concurrent.TimeUnit @@ -24,6 +25,18 @@ class MDnsDiscoveryTest { } } + val hostIpv6 = object : NullHost() { + override val peerId: PeerId = PeerId.fromPubKey( + generateEcdsaKeyPair().second + ) + + override fun listenAddresses(): List { + return listOf( + Multiaddr("/ip6/::/tcp/4001") + ) + } + } + val otherHost = object : NullHost() { override val peerId: PeerId = PeerId.fromPubKey( generateEcdsaKeyPair().second @@ -47,6 +60,15 @@ class MDnsDiscoveryTest { discoverer.stop().get(1, TimeUnit.SECONDS) } + @Test + fun `start and stop discovery ipv6`() { + val discoverer = MDnsDiscovery(hostIpv6, testServiceTag) + + discoverer.start().get(1, TimeUnit.SECONDS) + TimeUnit.MILLISECONDS.sleep(100) + discoverer.stop().get(1, TimeUnit.SECONDS) + } + @Test fun `start discovery and listen for self`() { var peerInfo: PeerInfo? = null @@ -69,6 +91,28 @@ class MDnsDiscoveryTest { assertEquals(host.listenAddresses().size, peerInfo?.addresses?.size) } + @Test + fun `start discovery and listen for self ipv6`() { + var peerInfo: PeerInfo? = null + val discoverer = MDnsDiscovery(hostIpv6, testServiceTag) + + discoverer.newPeerFoundListeners += { + peerInfo = it + } + + discoverer.start().get(1, TimeUnit.SECONDS) + for (i in 0..50) { + if (peerInfo != null) { + break + } + TimeUnit.MILLISECONDS.sleep(100) + } + discoverer.stop().get(5, TimeUnit.SECONDS) + + assertEquals(hostIpv6.peerId, peerInfo?.peerId) + assertTrue(hostIpv6.listenAddresses().size <= peerInfo?.addresses?.size!!) + } + @Test fun `start discovery and listen for other`() { var peerInfo: PeerInfo? = null diff --git a/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt b/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt new file mode 100644 index 000000000..848d783de --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt @@ -0,0 +1,167 @@ +package io.libp2p.etc.util.netty + +import io.libp2p.tools.readAllBytesAndRelease +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Test + +class ByteBufQueueTest { + + val queue = ByteBufQueue() + + val allocatedBufs = mutableListOf() + + @AfterEach + fun cleanUpAndCheck() { + allocatedBufs.forEach { + assertThat(it.refCnt()).isEqualTo(1) + } + } + + fun allocateBuf(): ByteBuf { + val buf = Unpooled.buffer() + buf.retain() // ref counter to 2 to check that exactly 1 ref remains at the end + allocatedBufs += buf + return buf + } + + fun allocateData(data: String): ByteBuf = + allocateBuf().writeBytes(data.toByteArray()) + + fun ByteBuf.readString() = String(this.readAllBytesAndRelease()) + + @Test + fun emptyTest() { + assertThat(queue.take(100).readString()).isEqualTo("") + } + + @Test + fun zeroTest() { + queue.push(allocateData("abc")) + assertThat(queue.take(0).readString()).isEqualTo("") + assertThat(queue.take(100).readString()).isEqualTo("abc") + } + + @Test + fun emptyZeroTest() { + assertThat(queue.take(0).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest1() { + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest2() { + queue.push(allocateData("")) + assertThat(queue.take(0).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest3() { + queue.push(allocateData("")) + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("a") + } + + @Test + fun emptyBuffersTest4() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("a") + } + + @Test + fun emptyBuffersTest5() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(1).readString()).isEqualTo("a") + assertThat(queue.take(1).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest6() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + queue.push(allocateData("")) + queue.push(allocateData("b")) + assertThat(queue.take(10).readString()).isEqualTo("ab") + } + + @Test + fun pushTake1() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(4).readString()).isEqualTo("abcd") + assertThat(queue.take(1).readString()).isEqualTo("e") + assertThat(queue.take(100).readString()).isEqualTo("f") + assertThat(queue.take(100).readString()).isEqualTo("") + } + + @Test + fun pushTake2() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(2).readString()).isEqualTo("ab") + assertThat(queue.take(2).readString()).isEqualTo("cd") + assertThat(queue.take(2).readString()).isEqualTo("ef") + assertThat(queue.take(2).readString()).isEqualTo("") + } + + @Test + fun pushTake3() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(1).readString()).isEqualTo("a") + assertThat(queue.take(1).readString()).isEqualTo("b") + assertThat(queue.take(1).readString()).isEqualTo("c") + assertThat(queue.take(1).readString()).isEqualTo("d") + assertThat(queue.take(1).readString()).isEqualTo("e") + assertThat(queue.take(1).readString()).isEqualTo("f") + assertThat(queue.take(1).readString()).isEqualTo("") + } + + @Test + fun pushTakePush1() { + queue.push(allocateData("abc")) + assertThat(queue.take(2).readString()).isEqualTo("ab") + queue.push(allocateData("def")) + assertThat(queue.take(2).readString()).isEqualTo("cd") + assertThat(queue.take(100).readString()).isEqualTo("ef") + } + + @Test + fun pushTakePush2() { + queue.push(allocateData("abc")) + assertThat(queue.take(3).readString()).isEqualTo("abc") + queue.push(allocateData("def")) + assertThat(queue.take(2).readString()).isEqualTo("de") + assertThat(queue.take(100).readString()).isEqualTo("f") + } + + @Test + fun pushTakePush3() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + assertThat(queue.take(1).readString()).isEqualTo("a") + queue.push(allocateData("ghi")) + assertThat(queue.take(100).readString()).isEqualTo("bcdefghi") + } + + @Test + fun pushTakePush4() { + queue.push(allocateData("abc")) + assertThat(queue.take(1).readString()).isEqualTo("a") + queue.push(allocateData("def")) + queue.push(allocateData("ghi")) + assertThat(queue.take(100).readString()).isEqualTo("bcdefghi") + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index 466718cd9..bb0f21313 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -9,6 +9,7 @@ import io.libp2p.etc.types.toHex import io.libp2p.etc.util.netty.mux.RemoteWriteClosed import io.libp2p.etc.util.netty.nettyInitializer import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* +import io.libp2p.mux.MuxHandlerAbstractTest.TestEventHandler import io.libp2p.tools.TestChannel import io.libp2p.tools.readAllBytesAndRelease import io.netty.buffer.ByteBuf @@ -20,10 +21,7 @@ import io.netty.handler.logging.LoggingHandler import org.assertj.core.api.Assertions.assertThat import org.assertj.core.data.Index import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertFalse -import org.junit.jupiter.api.Assertions.assertThrows -import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import java.util.concurrent.CompletableFuture @@ -38,10 +36,15 @@ abstract class MuxHandlerAbstractTest { val parentChannelId get() = ech.id() val allocatedBufs = mutableListOf() + val activeEventHandlers = mutableListOf() + val isLocalConnectionInitiator = true abstract val maxFrameDataLength: Int abstract fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler + abstract val localMuxIdGenerator: Iterator + abstract val remoteMuxIdGenerator: Iterator + fun createTestStreamHandler(): StreamHandler = StreamHandler { stream -> val handler = TestHandler() @@ -70,11 +73,14 @@ abstract class MuxHandlerAbstractTest { } multistreamHandler = createMuxHandler(streamHandler) - ech = TestChannel("test", true, LoggingHandler(LogLevel.ERROR), multistreamHandler) + ech = TestChannel("test", isLocalConnectionInitiator, LoggingHandler(LogLevel.ERROR), multistreamHandler) } @AfterEach open fun cleanUpAndCheck() { + childHandlers.forEach { + assertThat(it.exceptions).isEmpty() + } childHandlers.clear() allocatedBufs.forEach { @@ -94,13 +100,17 @@ abstract class MuxHandlerAbstractTest { abstract fun writeFrame(frame: AbstractTestMuxFrame) abstract fun readFrame(): AbstractTestMuxFrame? fun readFrameOrThrow() = readFrame() ?: throw AssertionError("No outbound frames") - - fun openStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Open)) + fun openStreamRemote(id: Long) = writeFrame(AbstractTestMuxFrame(id, Open)) + fun openStreamRemote(): Long { + val id = remoteMuxIdGenerator.next() + openStreamRemote(id) + return id + } fun writeStream(id: Long, msg: String) = writeFrame(AbstractTestMuxFrame(id, Data, msg)) fun closeStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Close)) fun resetStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Reset)) - fun openStreamByLocal(): TestHandler { + fun openStreamLocal(): TestHandler { val handlerFut = multistreamHandler.createStream(createTestStreamHandler()).controller ech.runPendingTasks() return handlerFut.get() @@ -113,6 +123,8 @@ abstract class MuxHandlerAbstractTest { return buf } + protected fun allocateMessage(hexBytes: String) = hexBytes.fromHex().toByteBuf(allocateBuf()) + fun assertHandlerCount(count: Int) = assertEquals(count, childHandlers.size) fun assertLastMessage(handler: Int, msgCount: Int, msg: String) { val messages = childHandlers[handler].inboundMessages @@ -122,58 +134,57 @@ abstract class MuxHandlerAbstractTest { @Test fun singleStream() { - openStream(12) + val id1 = openStreamRemote() assertHandlerCount(1) assertTrue(childHandlers[0].isActivated) - writeStream(12, "22") + writeStream(id1, "22") assertHandlerCount(1) assertEquals(1, childHandlers[0].inboundMessages.size) assertEquals("22", childHandlers[0].inboundMessages.last()) - writeStream(12, "44") + writeStream(id1, "44") assertHandlerCount(1) assertEquals(2, childHandlers[0].inboundMessages.size) assertEquals("44", childHandlers[0].inboundMessages.last()) - writeStream(12, "66") + writeStream(id1, "66") assertHandlerCount(1) assertEquals(3, childHandlers[0].inboundMessages.size) assertEquals("66", childHandlers[0].inboundMessages.last()) assertFalse(childHandlers[0].isInactivated) - assertTrue(childHandlers[0].exceptions.isEmpty()) } @Test fun `test that readComplete event is fired to child channel`() { - openStream(12) + val id1 = openStreamRemote() assertThat(childHandlers[0].readCompleteEventCount).isZero() - writeStream(12, "22") + writeStream(id1, "22") assertThat(childHandlers[0].readCompleteEventCount).isEqualTo(1) - writeStream(12, "23") + writeStream(id1, "23") assertThat(childHandlers[0].readCompleteEventCount).isEqualTo(2) } @Test fun `test that readComplete event is fired to reading channels only`() { - openStream(12) - openStream(13) + val id1 = openStreamRemote() + val id2 = openStreamRemote() assertThat(childHandlers[0].readCompleteEventCount).isZero() assertThat(childHandlers[1].readCompleteEventCount).isZero() - writeStream(12, "22") + writeStream(id1, "22") assertThat(childHandlers[0].readCompleteEventCount).isEqualTo(1) assertThat(childHandlers[1].readCompleteEventCount).isEqualTo(0) - writeStream(13, "23") + writeStream(id2, "23") assertThat(childHandlers[0].readCompleteEventCount).isEqualTo(1) assertThat(childHandlers[1].readCompleteEventCount).isEqualTo(1) @@ -181,116 +192,111 @@ abstract class MuxHandlerAbstractTest { @Test fun twoStreamsInterleaved() { - openStream(12) - writeStream(12, "22") + val id1 = openStreamRemote() + writeStream(id1, "22") assertHandlerCount(1) assertLastMessage(0, 1, "22") - writeStream(12, "23") + writeStream(id1, "23") assertHandlerCount(1) assertLastMessage(0, 2, "23") - openStream(22) - writeStream(22, "33") + val id2 = openStreamRemote() + writeStream(id2, "33") assertHandlerCount(2) assertLastMessage(1, 1, "33") - writeStream(12, "24") + writeStream(id1, "24") assertHandlerCount(2) assertLastMessage(0, 3, "24") - writeStream(22, "34") + writeStream(id2, "34") assertHandlerCount(2) assertLastMessage(1, 2, "34") assertFalse(childHandlers[0].isInactivated) - assertTrue(childHandlers[0].exceptions.isEmpty()) assertFalse(childHandlers[1].isInactivated) - assertTrue(childHandlers[1].exceptions.isEmpty()) } @Test fun twoStreamsSequential() { - openStream(12) - writeStream(12, "22") + val id1 = openStreamRemote() + writeStream(id1, "22") assertHandlerCount(1) assertLastMessage(0, 1, "22") - writeStream(12, "23") + writeStream(id1, "23") assertHandlerCount(1) assertLastMessage(0, 2, "23") - writeStream(12, "24") + writeStream(id1, "24") assertHandlerCount(1) assertLastMessage(0, 3, "24") - writeStream(12, "25") + writeStream(id1, "25") assertHandlerCount(1) assertLastMessage(0, 4, "25") assertFalse(childHandlers[0].isInactivated) - resetStream(12) + resetStream(id1) assertTrue(childHandlers[0].isHandlerRemoved) - assertTrue(childHandlers[0].exceptions.isEmpty()) - openStream(22) - writeStream(22, "33") + val id2 = openStreamRemote() + writeStream(id2, "33") assertHandlerCount(2) assertLastMessage(1, 1, "33") - writeStream(22, "34") + writeStream(id2, "34") assertHandlerCount(2) assertLastMessage(1, 2, "34") assertFalse(childHandlers[1].isInactivated) - resetStream(22) + resetStream(id2) assertTrue(childHandlers[1].isHandlerRemoved) - assertTrue(childHandlers[1].exceptions.isEmpty()) } @Test fun streamIsReset() { - openStream(22) + val id1 = openStreamRemote() assertFalse(childHandlers[0].ctx.channel().closeFuture().isDone) assertFalse(childHandlers[0].isInactivated) - resetStream(22) + resetStream(id1) assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) assertTrue(childHandlers[0].isHandlerRemoved) } @Test fun streamIsResetWhenChannelIsClosed() { - openStream(22) + openStreamRemote() assertFalse(childHandlers[0].ctx.channel().closeFuture().isDone) ech.close().await() assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) assertTrue(childHandlers[0].isHandlerRemoved) - assertTrue(childHandlers[0].exceptions.isEmpty()) } @Test fun cantReceiveOnResetStream() { - openStream(18) - resetStream(18) + val id1 = openStreamRemote() + resetStream(id1) assertThrows(Libp2pException::class.java) { - writeStream(18, "35") + writeStream(id1, "35") } assertTrue(childHandlers[0].isHandlerRemoved) } @Test fun cantReceiveOnClosedStream() { - openStream(18) - closeStream(18) + val id1 = openStreamRemote() + closeStream(id1) assertThrows(Libp2pException::class.java) { - writeStream(18, "35") + writeStream(id1, "35") } assertFalse(childHandlers[0].isInactivated) } @@ -304,9 +310,15 @@ abstract class MuxHandlerAbstractTest { } @Test - fun canResetNonExistentStream() { - resetStream(99) + @SuppressWarnings("SwallowedException") + fun `resetting non existing stream doesnt close connection`() { + try { + resetStream(99) + } catch (e: UnknownStreamIdMuxerException) { + // Muxer is free to either throw an exception or just ignore + } assertHandlerCount(0) + assertThat(ech.isOpen).isTrue() } @Test @@ -323,10 +335,20 @@ abstract class MuxHandlerAbstractTest { assertHandlerCount(0) } + @Test + fun `opening a stream with existing id causes connection close`() { + val id1 = openStreamRemote() + assertThrows(Libp2pException::class.java) { + openStreamRemote(id1) + } + + assertThat(ech.isOpen).isFalse() + } + @Test fun `local create and after local disconnect should still read`() { - val handler = openStreamByLocal() - handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + val handler = openStreamLocal() + handler.ctx.writeAndFlush(allocateMessage("1984")) handler.ctx.disconnect().sync() val openFrame = readFrameOrThrow() @@ -350,7 +372,7 @@ abstract class MuxHandlerAbstractTest { @Test fun `local create and after remote disconnect should still write`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() val openFrame = readFrameOrThrow() assertThat(openFrame.flag).isEqualTo(Open) @@ -362,7 +384,7 @@ abstract class MuxHandlerAbstractTest { assertThat(handler.isUnregistered).isFalse() assertThat(handler.userEvents).containsExactly(RemoteWriteClosed) - handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + handler.ctx.writeAndFlush(allocateMessage("1984")) val readFrame = readFrameOrThrow() assertThat(readFrame.flag).isEqualTo(Data) @@ -372,7 +394,7 @@ abstract class MuxHandlerAbstractTest { @Test fun `test remote and local disconnect closes stream`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() handler.ctx.disconnect().sync() readFrameOrThrow() @@ -389,11 +411,11 @@ abstract class MuxHandlerAbstractTest { @Test fun `test large message is split onto slices`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() readFrameOrThrow() val largeMessage = "42".repeat(maxFrameDataLength - 1) + "4344" - handler.ctx.writeAndFlush(largeMessage.fromHex().toByteBuf(allocateBuf())) + handler.ctx.writeAndFlush(allocateMessage(largeMessage)) val dataFrame1 = readFrameOrThrow() assertThat(dataFrame1.data.fromHex()) @@ -412,35 +434,52 @@ abstract class MuxHandlerAbstractTest { @Test fun `should throw when writing to locally closed stream`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() handler.ctx.disconnect() assertThrows(Exception::class.java) { - handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + handler.ctx.writeAndFlush(allocateMessage("42")).sync() } } @Test fun `should throw when writing to reset stream`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() handler.ctx.close() assertThrows(Exception::class.java) { - handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + handler.ctx.writeAndFlush(allocateMessage("42")).sync() } } @Test fun `should throw when writing to closed connection`() { - val handler = openStreamByLocal() + val handler = openStreamLocal() ech.close().sync() assertThrows(Exception::class.java) { - handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + handler.ctx.writeAndFlush(allocateMessage("42")).sync() + } + } + + @Test + fun `test writing to remotely open stream upon activation`() { + activeEventHandlers += TestEventHandler { + val writePromise = it.ctx.writeAndFlush(allocateMessage("42")) + writePromise.sync() } + val id1 = openStreamRemote() + + val dataFrame = readFrameOrThrow() + assertThat(dataFrame.streamId).isEqualTo(id1) + assertThat(dataFrame.data).isEqualTo("42") + } + + fun interface TestEventHandler { + fun handle(testHandler: TestHandler) } - class TestHandler : ChannelInboundHandlerAdapter() { + inner class TestHandler : ChannelInboundHandlerAdapter() { val inboundMessages = mutableListOf() lateinit var ctx: ChannelHandlerContext var readCompleteEventCount = 0 @@ -461,7 +500,7 @@ abstract class MuxHandlerAbstractTest { override fun handlerAdded(ctx: ChannelHandlerContext) { assertFalse(isHandlerAdded) isHandlerAdded = true - println("MultiplexHandlerTest.handlerAdded") + println("MuxHandlerAbstractTest.handlerAdded") this.ctx = ctx } @@ -469,57 +508,58 @@ abstract class MuxHandlerAbstractTest { assertTrue(isHandlerAdded) assertFalse(isRegistered) isRegistered = true - println("MultiplexHandlerTest.channelRegistered") + println("MuxHandlerAbstractTest.channelRegistered") } override fun channelActive(ctx: ChannelHandlerContext) { assertTrue(isRegistered) assertFalse(isActivated) isActivated = true - println("MultiplexHandlerTest.channelActive") + println("MuxHandlerAbstractTest.channelActive") + activeEventHandlers.forEach { it.handle(this) } } override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { assertTrue(isActivated) - println("MultiplexHandlerTest.channelRead") + println("MuxHandlerAbstractTest.channelRead") msg as ByteBuf inboundMessages += msg.readAllBytesAndRelease().toHex() } override fun channelReadComplete(ctx: ChannelHandlerContext?) { readCompleteEventCount++ - println("MultiplexHandlerTest.channelReadComplete") + println("MuxHandlerAbstractTest.channelReadComplete") } override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { userEvents += evt - println("MultiplexHandlerTest.userEventTriggered: $evt") + println("MuxHandlerAbstractTest.userEventTriggered: $evt") } override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { exceptions += cause - println("MultiplexHandlerTest.exceptionCaught") + println("MuxHandlerAbstractTest.exceptionCaught") } override fun channelInactive(ctx: ChannelHandlerContext) { assertTrue(isActivated) assertFalse(isInactivated) isInactivated = true - println("MultiplexHandlerTest.channelInactive") + println("MuxHandlerAbstractTest.channelInactive") } override fun channelUnregistered(ctx: ChannelHandlerContext?) { assertTrue(isInactivated) assertFalse(isUnregistered) isUnregistered = true - println("MultiplexHandlerTest.channelUnregistered") + println("MuxHandlerAbstractTest.channelUnregistered") } override fun handlerRemoved(ctx: ChannelHandlerContext?) { assertTrue(isUnregistered) assertFalse(isHandlerRemoved) isHandlerRemoved = true - println("MultiplexHandlerTest.handlerRemoved") + println("MuxHandlerAbstractTest.handlerRemoved") } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt index 6ba816250..c6f6a5ffa 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt @@ -2,7 +2,6 @@ package io.libp2p.mux.mplex import io.libp2p.etc.types.toByteArray import io.libp2p.etc.types.toByteBuf -import io.libp2p.etc.util.netty.mux.MuxId import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.DefaultChannelId @@ -33,10 +32,13 @@ class MplexFrameCodecTest { val maxFrameDataLength = 1024 val channel = EmbeddedChannel(MplexFrameCodec(maxFrameDataLength = maxFrameDataLength)) + fun muxId(id: Long, initiator: Boolean) = MplexId(dummyId, id, initiator) + @Test fun `check max frame size limit`() { val mplexFrame = MplexFrame( - MuxId(dummyId, 777, true), MplexFlag.MessageInitiator, + muxId(777, true), + MplexFlag.MessageInitiator, ByteArray(maxFrameDataLength).toByteBuf() ) @@ -59,9 +61,9 @@ class MplexFrameCodecTest { @MethodSource("splitIndexes") fun testDecoder(sliceIdx: List) { val mplexFrames = arrayOf( - MplexFrame(MuxId(dummyId, 777, true), MplexFlag.MessageInitiator, "Hello-1".toByteArray().toByteBuf()), - MplexFrame(MuxId(dummyId, 888, true), MplexFlag.MessageInitiator, "Hello-2".toByteArray().toByteBuf()), - MplexFrame(MuxId(dummyId, 999, true), MplexFlag.MessageInitiator, "Hello-3".toByteArray().toByteBuf()) + MplexFrame(muxId(777, true), MplexFlag.MessageInitiator, "Hello-1".toByteArray().toByteBuf()), + MplexFrame(muxId(888, true), MplexFlag.MessageInitiator, "Hello-2".toByteArray().toByteBuf()), + MplexFrame(muxId(999, true), MplexFlag.MessageInitiator, "Hello-3".toByteArray().toByteBuf()) ) assertTrue( channel.writeOutbound(*mplexFrames) @@ -92,13 +94,13 @@ class MplexFrameCodecTest { @Test fun `test id initiator is inverted on decoding`() { val mplexFrames = arrayOf( - MplexFrame.createOpenFrame(MuxId(dummyId, 1, true)), - MplexFrame.createDataFrame(MuxId(dummyId, 2, true), "Hello-2".toByteArray().toByteBuf()), - MplexFrame.createDataFrame(MuxId(dummyId, 3, false), "Hello-3".toByteArray().toByteBuf()), - MplexFrame.createCloseFrame(MuxId(dummyId, 4, true)), - MplexFrame.createCloseFrame(MuxId(dummyId, 5, false)), - MplexFrame.createResetFrame(MuxId(dummyId, 6, true)), - MplexFrame.createResetFrame(MuxId(dummyId, 7, false)), + MplexFrame.createOpenFrame(muxId(1, true)), + MplexFrame.createDataFrame(muxId(2, true), "Hello-2".toByteArray().toByteBuf()), + MplexFrame.createDataFrame(muxId(3, false), "Hello-3".toByteArray().toByteBuf()), + MplexFrame.createCloseFrame(muxId(4, true)), + MplexFrame.createCloseFrame(muxId(5, false)), + MplexFrame.createResetFrame(muxId(6, true)), + MplexFrame.createResetFrame(muxId(7, false)), ) assertTrue( channel.writeOutbound(*mplexFrames) @@ -122,7 +124,7 @@ class MplexFrameCodecTest { assertTrue(frameDataBuf.refCnt() == 1) channel.writeOutbound( - MplexFrame(MuxId(dummyId, 777, true), MplexFlag.MessageInitiator, frameDataBuf) + MplexFrame(muxId(777, true), MplexFlag.MessageInitiator, frameDataBuf) ) val encodedFrame = channel.readOutbound() diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt index 091107331..bcb6c2f2a 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt @@ -4,7 +4,6 @@ import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocolV1 import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.toHex -import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* @@ -16,9 +15,15 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 + override val localMuxIdGenerator = (0L..Long.MAX_VALUE).iterator() + override val remoteMuxIdGenerator = (0L..Long.MAX_VALUE).iterator() + override fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler = object : MplexHandler( - MultistreamProtocolV1, maxFrameDataLength, null, streamHandler + MultistreamProtocolV1, + maxFrameDataLength, + null, + streamHandler ) { // MuxHandler consumes the exception. Override this behaviour for testing @Deprecated("Deprecated in Java") @@ -28,6 +33,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { } override fun writeFrame(frame: AbstractTestMuxFrame) { + val muxId = MplexId(parentChannelId, frame.streamId, true) val mplexFlag = when (frame.flag) { Open -> MplexFlag.Type.OPEN Data -> MplexFlag.Type.DATA @@ -39,7 +45,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { else -> frame.data.fromHex().toByteBuf(allocateBuf()) } val mplexFrame = - MplexFrame(MuxId(parentChannelId, frame.streamId, true), MplexFlag.getByType(mplexFlag, true), data) + MplexFrame(muxId, MplexFlag.getByType(mplexFlag, true), data) ech.writeInbound(mplexFrame) } @@ -51,10 +57,9 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { MplexFlag.Type.DATA -> Data MplexFlag.Type.CLOSE -> Close MplexFlag.Type.RESET -> Reset - else -> throw AssertionError("Unknown mplex flag: ${mplexFrame.flag}") } - val sData = maybeMplexFrame.data.readAllBytesAndRelease().toHex() - AbstractTestMuxFrame(mplexFrame.id.id, flag, sData) + val data = maybeMplexFrame.data.readAllBytesAndRelease().toHex() + AbstractTestMuxFrame(mplexFrame.id.id, flag, data) } } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index b920d1285..2bac4bdf3 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -1,23 +1,46 @@ package io.libp2p.mux.yamux +import io.libp2p.core.Libp2pException import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocolV1 import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.toHex -import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.AckBacklogLimitExceededMuxerException import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* import io.libp2p.tools.readAllBytesAndRelease +import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource class YamuxHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 + private val maxBufferedConnectionWrites = 512 + private val ackBacklogLimit = 42 + private val initialWindowSize = 300 + override val localMuxIdGenerator = YamuxStreamIdGenerator(isLocalConnectionInitiator).toIterator() + override val remoteMuxIdGenerator = YamuxStreamIdGenerator(!isLocalConnectionInitiator).toIterator() + + private val readFrameQueue = ArrayDeque() + fun Long.toMuxId() = YamuxId(parentChannelId, this) override fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler = object : YamuxHandler( - MultistreamProtocolV1, maxFrameDataLength, null, streamHandler, true + MultistreamProtocolV1, + maxFrameDataLength, + null, + streamHandler, + true, + maxBufferedConnectionWrites, + ackBacklogLimit, + initialWindowSize ) { // MuxHandler consumes the exception. Override this behaviour for testing @Deprecated("Deprecated in Java") @@ -27,34 +50,450 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { } override fun writeFrame(frame: AbstractTestMuxFrame) { - val muxId = MuxId(parentChannelId, frame.streamId, true) + val muxId = frame.streamId.toMuxId() val yamuxFrame = when (frame.flag) { - Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) - Data -> YamuxFrame( - muxId, - YamuxType.DATA, - 0, - frame.data.fromHex().size.toLong(), - frame.data.fromHex().toByteBuf(allocateBuf()) - ) - Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.FIN, 0) - Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.RST, 0) + Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.SYN.asSet, 0) + Data -> { + val data = frame.data.fromHex() + YamuxFrame( + muxId, + YamuxType.DATA, + YamuxFlag.NONE, + data.size.toLong(), + data.toByteBuf(allocateBuf()) + ) + } + + Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.FIN.asSet, 0) + Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.RST.asSet, 0) } ech.writeInbound(yamuxFrame) } override fun readFrame(): AbstractTestMuxFrame? { - val maybeYamuxFrame = ech.readOutbound() - return maybeYamuxFrame?.let { yamuxFrame -> - val flag = when { - yamuxFrame.flags == YamuxFlags.SYN -> Open - yamuxFrame.flags == YamuxFlags.FIN -> Close - yamuxFrame.flags == YamuxFlags.RST -> Reset - yamuxFrame.type == YamuxType.DATA -> Data - else -> throw AssertionError("Unsupported yamux frame: $yamuxFrame") - } - val sData = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: "" - AbstractTestMuxFrame(yamuxFrame.id.id, flag, sData) + val yamuxFrame = readYamuxFrame() + if (yamuxFrame != null) { + when { + YamuxFlag.SYN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open) + } + + val data = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: "" + when { + yamuxFrame.type == YamuxType.DATA && data.isNotEmpty() -> + readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Data, data) + } + + when { + YamuxFlag.FIN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close) + YamuxFlag.RST in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset) + } + } + + return readFrameQueue.removeFirstOrNull() + } + + private fun readYamuxFrame(): YamuxFrame? { + return ech.readOutbound() + } + + private fun readYamuxFrameOrThrow() = readYamuxFrame() ?: throw AssertionError("No outbound frames") + + @Test + fun `test ack new stream`() { + // signal opening of new stream + openStreamRemote(12) + + writeStream(12, "23") + + val ackFrame = readYamuxFrameOrThrow() + + // receives ack stream + assertThat(ackFrame.flags).containsExactly(YamuxFlag.ACK) + assertThat(ackFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) + + closeStream(12) + } + + @Test + fun `test window update is sent after more than half of the window is depleted`() { + openStreamLocal() + val streamId = readFrameOrThrow().streamId + + // > 1/2 window size + val length = (initialWindowSize / 2) + 42 + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.DATA, + YamuxFlag.NONE, + length.toLong(), + "42".repeat(length).fromHex().toByteBuf(allocateBuf()) + ) + ) + + val windowUpdateFrame = readYamuxFrameOrThrow() + + // window frame is sent based on the new window + assertThat(windowUpdateFrame.flags).isEmpty() + assertThat(windowUpdateFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) + assertThat(windowUpdateFrame.length).isEqualTo(length.toLong()) + } + + @Test + fun `data should be buffered and sent after window increased from zero`() { + val handler = openStreamLocal() + val streamId = readFrameOrThrow().streamId + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + -initialWindowSize.toLong() + ) + ) + + handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + + assertThat(readFrame()).isNull() + + ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 5000)) + val frame = readFrameOrThrow() + assertThat(frame.data).isEqualTo("1984") + } + + @Test + fun `buffered data should not be sent if it does not fit within window`() { + val handler = openStreamLocal() + val streamId = readFrameOrThrow().streamId + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + -initialWindowSize.toLong() + ) + ) + + val message = "1984".fromHex().toByteBuf(allocateBuf()) + // 2 bytes per message + handler.ctx.writeAndFlush(message) + handler.ctx.writeAndFlush(message.copy()) + + assertThat(readFrame()).isNull() + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + 2 + ) + ) + + var frame = readFrameOrThrow() + // one message is received + assertThat(frame.data).isEqualTo("1984") + // need to wait for another window update to send more data + assertThat(readFrame()).isNull() + // sending window update + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + 1 + ) + ) + frame = readFrameOrThrow() + assertThat(frame.data).isEqualTo("19") + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + 10000 + ) + ) + frame = readFrameOrThrow() + assertThat(frame.data).isEqualTo("84") + } + + @Test + fun `overflowing buffer sends RST flag and throws an exception`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + -initialWindowSize.toLong() + ) + ) + + val createMessage: () -> ByteBuf = + { "42".repeat(maxBufferedConnectionWrites / 5).fromHex().toByteBuf(allocateBuf()) } + + for (i in 1..5) { + val writeResult = handler.ctx.writeAndFlush(createMessage()) + assertThat(writeResult.isSuccess).isTrue() + } + + // next message will overflow the configured buffer + val writeResult = handler.ctx.writeAndFlush(createMessage()) + assertThat(writeResult.isSuccess).isFalse() + assertThat(writeResult.cause()) + .isInstanceOf(Libp2pException::class.java) + .hasMessage("Overflowed send buffer (612/512). Last stream attempting to write: $muxId") + + val frame = readYamuxFrameOrThrow() + assertThat(frame.flags).containsExactly(YamuxFlag.RST) + } + + @Test + fun `frames are sent in order when send buffer is used`() { + val handler = openStreamLocal() + val streamId = readFrameOrThrow().streamId + + val createMessage: (String) -> ByteBuf = + { it.toByteArray().toByteBuf(allocateBuf()) } + + val sendWindowUpdate: (Int) -> Unit = { + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + it.toLong() + ) + ) + } + + // approximately every 5 messages window size will be depleted + val messagesToSend = 500 + val customWindowSize = 14 + sendWindowUpdate(-initialWindowSize + customWindowSize) + + val range = 1..messagesToSend + + // 100 window updates should be sent to ensure buffer is flushed and all messages are sent + // so will send them at random times ensuring maxBufferedConnectionWrites can never be reached + val windowUpdatesIndices = (range).chunked(100).flatMap { + it.shuffled().take(20) + } + + for (i in range) { + if (i in windowUpdatesIndices) { + sendWindowUpdate(customWindowSize) + } + handler.ctx.writeAndFlush(createMessage(i.toString())) + } + + val receivedData = generateSequence { + readYamuxFrame() + } + .map { + assertThat(it.data).isNotNull() + String(it.data!!.readAllBytesAndRelease()) + } + .joinToString(separator = "") + + val expectedData = range.joinToString(separator = "") + + assertThat(receivedData).isEqualTo(expectedData) + } + + @Test + fun `test ping`() { + val id: Long = YamuxId.SESSION_STREAM_ID + ech.writeInbound( + YamuxFrame( + id.toMuxId(), + YamuxType.PING, + YamuxFlag.SYN.asSet, + // opaque value, echoed back + 3 + ) + ) + + val pingFrame = readYamuxFrameOrThrow() + + assertThat(pingFrame.flags).containsExactly(YamuxFlag.ACK) + assertThat(pingFrame.type).isEqualTo(YamuxType.PING) + assertThat(pingFrame.length).isEqualTo(3) + } + + @Test + fun `test go away`() { + val id: Long = YamuxId.SESSION_STREAM_ID + ech.writeInbound( + YamuxFrame( + id.toMuxId(), + YamuxType.GO_AWAY, + YamuxFlag.NONE, + // normal termination + 0x2 + ) + ) + + val yamuxHandler = multistreamHandler as YamuxHandler + assertThat(yamuxHandler.goAwayPromise).isCompletedWithValue(0x2) + } + + @Test + fun `test no go away on close`() { + val yamuxHandler = multistreamHandler as YamuxHandler + + assertThat(yamuxHandler.goAwayPromise).isNotDone + ech.close() + assertThat(yamuxHandler.goAwayPromise).isCompletedExceptionally + } + + @Test + fun `opening a stream with wrong streamId parity should throw and close connection`() { + val isRemoteConnectionInitiator = !isLocalConnectionInitiator + val correctRemoteId = 10L + if (isRemoteConnectionInitiator) 1 else 0 + val incorrectId = correctRemoteId + 1 + Assertions.assertThrows(Libp2pException::class.java) { + openStreamRemote(incorrectId) + } + assertThat(ech.isOpen).isFalse() + } + + @Test + fun `negative sendWindowSize should be correctly handled`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf()) + // writing a message which is larger than sendWindowSize + handler.ctx.writeAndFlush(msg) + + // sendWindowSize is 0 now + + // remote party wants to reduce the window by 10 + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + -10 + ) + ) + + // sendWindowSize is -10 now + + val msgPart1 = readYamuxFrameOrThrow() + assertThat(msgPart1.length).isEqualTo(256L) + assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256) + msgPart1.data!!.release() + + val msgPart2 = readYamuxFrameOrThrow() + assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256) + assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256) + msgPart2.data!!.release() + + // ACKing message receive + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + initialWindowSize.toLong() + ) + ) + + val msgPart3 = readYamuxFrameOrThrow() + assertThat(msgPart3.length).isEqualTo(1L) + assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1) + msgPart3.data!!.release() + } + + @Test + fun `local close for writing should flush buffered data and send close frame on writeWindow update`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf()) + // writing a message which is larger than sendWindowSize + handler.ctx.writeAndFlush(msg) + + val msgPart1 = readYamuxFrameOrThrow() + assertThat(msgPart1.length).isEqualTo(256L) + assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256) + msgPart1.data!!.release() + + val msgPart2 = readYamuxFrameOrThrow() + assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256) + assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256) + msgPart2.data!!.release() + + // locally close for writing while some outbound data is still buffered + handler.ctx.disconnect() + + // ACKing message receive + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlag.ACK.asSet, + initialWindowSize.toLong() + ) + ) + + val msgPart3 = readYamuxFrameOrThrow() + assertThat(msgPart3.length).isEqualTo(1L) + assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1) + msgPart3.data!!.release() + + val closeFrame = readYamuxFrameOrThrow() + assertThat(closeFrame.flags).containsExactly(YamuxFlag.FIN) + assertThat(closeFrame.length).isEqualTo(0L) + assertThat(closeFrame.data).isNull() + } + + @ParameterizedTest + @ValueSource(booleans = [true, false]) + fun `does not create new stream if ACK backlog limit is reached`(outbound: Boolean) { + val openStream: () -> Unit = { + if (outbound) { + openStreamLocal() + } else { + openStreamRemote() + } + } + for (i in 1..ackBacklogLimit) { + openStream() + } + // opening new stream should fail + val exception = assertThrows { openStream() } + + if (outbound) { + assertThat(exception).hasCauseInstanceOf(AckBacklogLimitExceededMuxerException::class.java) + // expected number of SYN frames have been sent + var synFlagFrames = 0 + do { + val frame = readYamuxFrame() + frame?.let { + assertThat(it.flags).isEqualTo(YamuxFlag.SYN.asSet) + synFlagFrames += 1 + } + } while (frame != null) + assertThat(synFlagFrames).isEqualTo(ackBacklogLimit) + } else { + assertThat(exception).isInstanceOf(AckBacklogLimitExceededMuxerException::class.java) + } + } + + companion object { + private fun YamuxStreamIdGenerator.toIterator() = iterator { + while (true) { + yield(this@toIterator.next()) + } } } } diff --git a/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt b/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt new file mode 100644 index 000000000..a763ee60f --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt @@ -0,0 +1,144 @@ +package io.libp2p.protocol + +import io.libp2p.core.BadPeerException +import io.libp2p.core.ConnectionClosedException +import io.libp2p.core.Libp2pException +import io.libp2p.core.Stream +import io.libp2p.core.multistream.StrictProtocolBinding +import io.libp2p.etc.types.completedExceptionally +import io.libp2p.etc.types.lazyVar +import io.libp2p.etc.types.toByteArray +import io.libp2p.etc.types.toHex +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.ByteToMessageCodec +import java.time.Duration +import java.util.Collections +import java.util.Random +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +interface BlobController { + fun blob(): CompletableFuture +} + +class Blob(blobSize: Int) : BlobBinding(BlobProtocol(blobSize)) + +open class BlobBinding(blob: BlobProtocol) : + StrictProtocolBinding("/ipfs/blob-echo/1.0.0", blob) + +class BlobTimeoutException : Libp2pException() + +open class BlobProtocol(var blobSize: Int) : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { + var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() } + var curTime: () -> Long = { System.currentTimeMillis() } + var random = Random() + var blobTimeout = Duration.ofSeconds(10) + + override fun onStartInitiator(stream: Stream): CompletableFuture { + val handler = BlobInitiator() + stream.pushHandler(BlobCodec()) + stream.pushHandler(handler) + stream.pushHandler(BlobCodec()) + return handler.activeFuture + } + + override fun onStartResponder(stream: Stream): CompletableFuture { + val handler = BlobResponder() + stream.pushHandler(BlobCodec()) + stream.pushHandler(BlobResponder()) + stream.pushHandler(BlobCodec()) + return CompletableFuture.completedFuture(handler) + } + + open class BlobCodec : ByteToMessageCodec() { + override fun encode(ctx: ChannelHandlerContext?, msg: ByteArray, out: ByteBuf) { + println("Codec::encode") + out.writeInt(msg.size) + out.writeBytes(msg) + } + + override fun decode(ctx: ChannelHandlerContext?, msg: ByteBuf, out: MutableList) { + println("Codec::decode " + msg.readableBytes()) + val readerIndex = msg.readerIndex() + if (msg.readableBytes() < 4) { + return + } + val len = msg.readInt() + if (msg.readableBytes() < len) { + // not enough data to read the full array + // will wait for more ... + msg.readerIndex(readerIndex) + return + } + val data = msg.readSlice(len) + out.add(data.toByteArray()) + } + } + + open inner class BlobResponder : ProtocolMessageHandler, BlobController { + override fun onMessage(stream: Stream, msg: ByteArray) { + println("Responder::onMessage") + stream.writeAndFlush(msg) + } + + override fun blob(): CompletableFuture { + throw Libp2pException("This is blob responder only") + } + } + + open inner class BlobInitiator : ProtocolMessageHandler, BlobController { + val activeFuture = CompletableFuture() + val requests = Collections.synchronizedMap(mutableMapOf>>()) + lateinit var stream: Stream + var closed = false + + override fun onActivated(stream: Stream) { + this.stream = stream + activeFuture.complete(this) + } + + override fun onMessage(stream: Stream, msg: ByteArray) { + println("Initiator::onMessage") + val dataS = msg.toHex() + val (sentT, future) = requests.remove(dataS) + ?: throw BadPeerException("Unknown or expired blob data in response: $dataS") + future.complete(curTime() - sentT) + } + + override fun onClosed(stream: Stream) { + synchronized(requests) { + closed = true + requests.values.forEach { it.second.completeExceptionally(ConnectionClosedException()) } + requests.clear() + timeoutScheduler.shutdownNow() + } + activeFuture.completeExceptionally(ConnectionClosedException()) + } + + override fun blob(): CompletableFuture { + val ret = CompletableFuture() + val arr = ByteArray(blobSize) + random.nextBytes(arr) + val dataS = arr.toHex() + + synchronized(requests) { + if (closed) return completedExceptionally(ConnectionClosedException()) + requests[dataS] = curTime() to ret + + timeoutScheduler.schedule( + { + requests.remove(dataS)?.second?.completeExceptionally(BlobTimeoutException()) + }, + blobTimeout.toMillis(), + TimeUnit.MILLISECONDS + ) + } + + println("Sender writing " + blobSize) + stream.writeAndFlush(arr) + return ret + } + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/MaxCountTopicSubscriptionFilterTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/MaxCountTopicSubscriptionFilterTest.kt index 0cc30b190..f3483fed7 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/MaxCountTopicSubscriptionFilterTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/MaxCountTopicSubscriptionFilterTest.kt @@ -134,7 +134,8 @@ internal class MaxCountTopicSubscriptionFilterTest { PubsubSubscription("allow_10", true), ) val result = filter.filterIncomingSubscriptions( - subscriptions, listOf("allow_1", "allow_2", "allow_3", "allow_4", "allow_5", "allow_6", "allow_7") + subscriptions, + listOf("allow_1", "allow_2", "allow_3", "allow_4", "allow_5", "allow_6", "allow_7") ) assertThat(result).isEqualTo(subscriptions) } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRouterTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRouterTest.kt index af08bca7c..cc45118e0 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRouterTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRouterTest.kt @@ -1,7 +1,13 @@ package io.libp2p.pubsub -import io.libp2p.core.pubsub.* +import io.libp2p.core.pubsub.MessageApi +import io.libp2p.core.pubsub.RESULT_INVALID +import io.libp2p.core.pubsub.RESULT_VALID +import io.libp2p.core.pubsub.Subscriber import io.libp2p.core.pubsub.Topic +import io.libp2p.core.pubsub.ValidationResult +import io.libp2p.core.pubsub.Validator +import io.libp2p.core.pubsub.createPubsubApi import io.libp2p.etc.types.seconds import io.libp2p.etc.types.toByteBuf import io.libp2p.etc.types.toBytesBigEndian @@ -10,6 +16,7 @@ import io.libp2p.pubsub.gossip.GossipRouter import io.libp2p.tools.TestChannel.TestConnection import io.netty.handler.logging.LogLevel import io.netty.util.ResourceLeakDetector +import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import pubsub.pb.Rpc @@ -279,7 +286,10 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor doTenNeighborsTopology() } - fun doTenNeighborsTopology(randomSeed: Int = 0, routerFactory: DeterministicFuzzRouterFactory = this.routerFactory) { + fun doTenNeighborsTopology( + randomSeed: Int = 0, + routerFactory: DeterministicFuzzRouterFactory = this.routerFactory + ) { val fuzz = DeterministicFuzz().also { it.randomSeed = randomSeed.toLong() } @@ -296,7 +306,7 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor } for (i in 0 until nodesCount) { for (j in 1..neighboursCount / 2) - allConnections += allRouters[i].connectSemiDuplex(allRouters[(i + j) % 21]/*, pubsubLogs = LogLevel.ERROR*/) + allConnections += allRouters[i].connectSemiDuplex(allRouters[(i + j) % 21]) .connections } @@ -398,6 +408,7 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor routers[1].connectSemiDuplex(routers[2], pubsubLogs = LogLevel.ERROR) val apis = routers.map { createPubsubApi(it.router) } + class RecordingSubscriber : Subscriber { var count = 0 override fun accept(t: MessageApi) { @@ -409,7 +420,10 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor val subs2 = topics .map { it to RecordingSubscriber() } - .map { apis[2].subscribe(it.second, it.first); it.second } + .map { + apis[2].subscribe(it.second, it.first) + it.second + } val scheduler = fuzz.createControlledExecutor() val delayed = { result: ValidationResult, delayMs: Long -> @@ -457,4 +471,54 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor Assertions.assertEquals(2, subs2[2].count) Assertions.assertEquals(0, subs2[3].count) } + + @Test + fun `getPeerTopics() should return immutable snapshot`() { + val fuzz = DeterministicFuzz() + + fun executeAsyncNow(asyncTask: () -> CompletableFuture): T { + val future = asyncTask() + fuzz.timeController.addTime(Duration.ofMillis(1)) + if (!future.isDone) throw AssertionError("Async task was not complete within virtual 1ms") + return future.join() + } + + val router1 = fuzz.createTestRouter(routerFactory) + val router2 = fuzz.createTestRouter(routerFactory) + router2.router.subscribe("topic1") + + router1.connectSemiDuplex(router2, LogLevel.DEBUG, LogLevel.DEBUG) + + val peerTopics1 = executeAsyncNow { router1.router.getPeerTopics() } + val peerTopics1MapIt = peerTopics1.entries.iterator() + val peerTopics1SetIt = peerTopics1.entries.first().value.iterator() + + router2.router.subscribe("topic2") + + val router3 = fuzz.createTestRouter(routerFactory) + router3.router.subscribe("topic3") + router1.connectSemiDuplex(router3, LogLevel.DEBUG, LogLevel.DEBUG) + + val peerTopics2 = executeAsyncNow { router1.router.getPeerTopics() } + + assertThat(peerTopics2) + .containsExactlyInAnyOrderEntriesOf( + mapOf( + router2.peerId to setOf("topic1", "topic2"), + router3.peerId to setOf("topic3") + ) + ) + + assertThat(peerTopics1) + .containsExactlyInAnyOrderEntriesOf( + mapOf( + router2.peerId to setOf("topic1") + ) + ) + + assertThat(peerTopics1MapIt.next().key).isEqualTo(router2.peerId) + assertThat(peerTopics1MapIt.hasNext()).isFalse() + assertThat(peerTopics1SetIt.next()).isEqualTo("topic1") + assertThat(peerTopics1SetIt.hasNext()).isFalse() + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/SeenCacheTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/SeenCacheTest.kt index 067604ec5..08081ed9b 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/SeenCacheTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/SeenCacheTest.kt @@ -24,6 +24,29 @@ fun createMessage(number: Int): Rpc.Message { fun createPubsubMessage(number: Int) = TestPubsubMessage(createMessage(number)) fun createPubsubMessage(number: Int, fastId: Int) = TestPubsubMessage(createMessage(number)).also { it.fastID = fastId } +operator fun SeenCache<*>.minusAssign(msg: PubsubMessage) = this.remove(msg.messageId) + +fun assertContainsEntry(cache: SeenCache, fakeMsg: Int) { + assertThat(cache.isSeen(createPubsubMessage(fakeMsg))).isTrue() + assertThat(cache.isSeen(createPubsubMessage(fakeMsg).messageId)).isTrue() + assertThat(cache.get(createPubsubMessage(fakeMsg))).isEqualTo(fakeMsg.toString()) +} +fun assertContainsEntries(cache: SeenCache, vararg fakeMsgs: Int) { + fakeMsgs.forEach { + assertContainsEntry(cache, it) + } +} + +fun assertDoesntContainEntry(cache: SeenCache, fakeMsg: Int) { + assertThat(cache.isSeen(createPubsubMessage(fakeMsg))).isFalse() + assertThat(cache.isSeen(createPubsubMessage(fakeMsg).messageId)).isFalse() + assertThat(cache.get(createPubsubMessage(fakeMsg))).isNull() +} +fun assertDoesntContainEntries(cache: SeenCache, vararg fakeMsgs: Int) { + fakeMsgs.forEach { + assertDoesntContainEntry(cache, it) + } +} class TestPubsubMessage(override val protobufMessage: Rpc.Message) : PubsubMessage { var canonicalIdCalculator: (Rpc.Message) -> WBytes = { @@ -57,48 +80,43 @@ fun genericSanityTest(cache: SeenCache) { cache[createPubsubMessage(1)] = "1" assertThat(cache.size).isEqualTo(1) - assertThat(cache.messages).containsExactly(createPubsubMessage(1)) assertThat(cache.isSeen(createPubsubMessage(1))).isTrue() assertThat(cache.isSeen(createPubsubMessage(2))).isFalse() - assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1") - assertThat(cache.getValue(createPubsubMessage(2))).isNull() - assertThat(cache.getSeenMessage(createPubsubMessage(1))).isEqualTo(createPubsubMessage(1)) + assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1") + assertThat(cache.get(createPubsubMessage(2))).isNull() + assertThat(cache.getSeenMessageCached(createPubsubMessage(1))).isEqualTo(createPubsubMessage(1)) cache[createPubsubMessage(1)] = "1-1" assertThat(cache.size).isEqualTo(1) - assertThat(cache.messages).containsExactly(createPubsubMessage(1)) assertThat(cache.isSeen(createPubsubMessage(1))).isTrue() assertThat(cache.isSeen(createPubsubMessage(2))).isFalse() - assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1-1") - assertThat(cache.getValue(createPubsubMessage(2))).isNull() + assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1-1") + assertThat(cache.get(createPubsubMessage(2))).isNull() cache[createPubsubMessage(2)] = "2" assertThat(cache.size).isEqualTo(2) - assertThat(cache.messages).containsExactly(createPubsubMessage(1), createPubsubMessage(2)) assertThat(cache.isSeen(createPubsubMessage(1))).isTrue() assertThat(cache.isSeen(createPubsubMessage(2))).isTrue() - assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1-1") - assertThat(cache.getValue(createPubsubMessage(2))).isEqualTo("2") + assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1-1") + assertThat(cache.get(createPubsubMessage(2))).isEqualTo("2") cache -= createPubsubMessage(1) assertThat(cache.size).isEqualTo(1) - assertThat(cache.messages).containsExactly(createPubsubMessage(2)) assertThat(cache.isSeen(createPubsubMessage(1))).isFalse() assertThat(cache.isSeen(createPubsubMessage(2))).isTrue() - assertThat(cache.getValue(createPubsubMessage(1))).isNull() - assertThat(cache.getValue(createPubsubMessage(2))).isEqualTo("2") + assertThat(cache.get(createPubsubMessage(1))).isNull() + assertThat(cache.get(createPubsubMessage(2))).isEqualTo("2") cache -= createPubsubMessage(2) assertThat(cache.size).isEqualTo(0) - assertThat(cache.messages).isEmpty() assertThat(cache.isSeen(createPubsubMessage(1))).isFalse() assertThat(cache.isSeen(createPubsubMessage(2))).isFalse() - assertThat(cache.getValue(createPubsubMessage(1))).isNull() - assertThat(cache.getValue(createPubsubMessage(2))).isNull() + assertThat(cache.get(createPubsubMessage(1))).isNull() + assertThat(cache.get(createPubsubMessage(2))).isNull() } class LRUSeenCacheTest { @@ -112,7 +130,6 @@ class LRUSeenCacheTest { assertThat(lruCache.evictingQueue).isEmpty() assertThat(backingCache.size).isEqualTo(0) - assertThat(backingCache.messages).isEmpty() } @Test @@ -124,38 +141,25 @@ class LRUSeenCacheTest { lruCache[createPubsubMessage(3)] = "3" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(1), - createPubsubMessage(2), - createPubsubMessage(3) - ) + assertContainsEntries(lruCache, 1, 2, 3) lruCache[createPubsubMessage(4)] = "4" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(2), - createPubsubMessage(3), - createPubsubMessage(4) - ) + assertDoesntContainEntry(lruCache, 1) + assertContainsEntries(lruCache, 2, 3, 4) lruCache[createPubsubMessage(5)] = "5" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(3), - createPubsubMessage(4), - createPubsubMessage(5) - ) + assertDoesntContainEntries(lruCache, 1, 2) + assertContainsEntries(lruCache, 3, 4, 5) lruCache[createPubsubMessage(1)] = "1" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(4), - createPubsubMessage(5), - createPubsubMessage(1) - ) + assertDoesntContainEntries(lruCache, 2, 3) + assertContainsEntries(lruCache, 1, 4, 5) assertThat(backingCache.size).isEqualTo(3) } @@ -188,11 +192,8 @@ class LRUSeenCacheTest { lruCache[createPubsubMessage(6)] = "6" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(3), - createPubsubMessage(4), - createPubsubMessage(6) - ) + assertDoesntContainEntries(lruCache, 1, 2, 5) + assertContainsEntries(lruCache, 3, 4, 6) lruCache -= createPubsubMessage(3) lruCache -= createPubsubMessage(4) @@ -207,21 +208,18 @@ class LRUSeenCacheTest { lruCache[createPubsubMessage(9)] = "9" assertThat(lruCache.size).isEqualTo(3) - assertThat(lruCache.messages).containsExactly( - createPubsubMessage(7), - createPubsubMessage(8), - createPubsubMessage(9) - ) + assertDoesntContainEntries(lruCache, 1, 2, 3, 4, 5, 6) + assertContainsEntries(lruCache, 7, 8, 9) lruCache -= createPubsubMessage(7) lruCache -= createPubsubMessage(8) lruCache -= createPubsubMessage(9) assertThat(lruCache.size).isEqualTo(0) - assertThat(lruCache.messages).isEmpty() + assertDoesntContainEntries(lruCache, 1, 2, 3, 4, 5, 6, 7, 8, 9) assertThat(lruCache.evictingQueue).isEmpty() assertThat(backingCache.size).isEqualTo(0) - assertThat(backingCache.messages).isEmpty() + assertDoesntContainEntries(backingCache, 1, 2, 3, 4, 5, 6, 7, 8, 9) } } @@ -247,50 +245,35 @@ class TTLSeenCacheTest { ttlCache[createPubsubMessage(3)] = "3" assertThat(ttlCache.size).isEqualTo(3) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(1), - createPubsubMessage(2), - createPubsubMessage(3) - ) + assertContainsEntries(ttlCache, 1, 2, 3) time.set(1001) ttlCache[createPubsubMessage(4)] = "4" assertThat(ttlCache.size).isEqualTo(3) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(2), - createPubsubMessage(3), - createPubsubMessage(4) - ) + assertDoesntContainEntries(ttlCache, 1) + assertContainsEntries(ttlCache, 2, 3, 4) time.set(1002) ttlCache[createPubsubMessage(5)] = "5" assertThat(ttlCache.size).isEqualTo(4) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(2), - createPubsubMessage(3), - createPubsubMessage(4), - createPubsubMessage(5) - ) + assertDoesntContainEntries(ttlCache, 1) + assertContainsEntries(ttlCache, 2, 3, 4, 5) time.set(1102) ttlCache[createPubsubMessage(1)] = "1" assertThat(ttlCache.size).isEqualTo(3) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(4), - createPubsubMessage(5), - createPubsubMessage(1) - ) + assertDoesntContainEntries(ttlCache, 2, 3) + assertContainsEntries(ttlCache, 1, 4, 5) time.set(3000) ttlCache[createPubsubMessage(6)] = "6" assertThat(ttlCache.size).isEqualTo(1) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(6) - ) + assertDoesntContainEntries(ttlCache, 1, 2, 3, 4, 5) + assertContainsEntries(ttlCache, 6) assertThat(ttlCache.putTimes.size).isLessThan(2) } @@ -324,9 +307,8 @@ class TTLSeenCacheTest { ttlCache[createPubsubMessage(6)] = "6" assertThat(ttlCache.size).isEqualTo(1) - assertThat(ttlCache.messages).containsExactly( - createPubsubMessage(6) - ) + assertDoesntContainEntries(ttlCache, 1, 2, 3, 4, 5) + assertContainsEntries(ttlCache, 6) assertThat(ttlCache.putTimes.size).isLessThan(2) } @@ -381,14 +363,14 @@ class FastIdSeenCacheTest { assertThat(m1_1.canonicalId).isNotNull() assertThat(m1_2.canonicalId).isNull() - val m1_3 = cache.getSeenMessage(m1_2) as TestPubsubMessage - assertThat(m1_3.canonicalId).isEqualTo(m1_1.canonicalId) + val m1_3 = cache.getSeenMessageCached(m1_2) + assertThat(m1_3.messageId).isEqualTo(m1_1.canonicalId) assertThat(m1_2.canonicalId).isNull() assertThat(m1_2 in cache).isTrue() assertThat(m1_2.canonicalId).isNull() - assertThat(cache.getValue(m1_2)).isEqualTo("1-1") + assertThat(cache.get(m1_2)).isEqualTo("1-1") assertThat(m1_2.canonicalId).isNull() } @@ -401,14 +383,14 @@ class FastIdSeenCacheTest { cache[m1_1] = "1-1" assertThat(m1_1 in cache).isTrue() assertThat(m1_2 in cache).isTrue() - assertThat(cache.getValue(m1_1)).isEqualTo("1-1") - assertThat(cache.getValue(m1_2)).isEqualTo("1-1") + assertThat(cache.get(m1_1)).isEqualTo("1-1") + assertThat(cache.get(m1_2)).isEqualTo("1-1") cache[m1_2] = "1-2" assertThat(m1_1 in cache).isTrue() assertThat(m1_2 in cache).isTrue() - assertThat(cache.getValue(m1_1)).isEqualTo("1-2") - assertThat(cache.getValue(m1_2)).isEqualTo("1-2") + assertThat(cache.get(m1_1)).isEqualTo("1-2") + assertThat(cache.get(m1_2)).isEqualTo("1-2") val m1_1_1 = createPubsubMessage(1, 1) val m1_2_1 = createPubsubMessage(1, 2) @@ -420,8 +402,8 @@ class FastIdSeenCacheTest { cache -= m1_1 assertThat(m1_1 in cache).isFalse() assertThat(m1_2 in cache).isFalse() - assertThat(cache.getValue(m1_1)).isNull() - assertThat(cache.getValue(m1_2)).isNull() + assertThat(cache.get(m1_1)).isNull() + assertThat(cache.get(m1_2)).isNull() assertThat(cache.fastIdMap.isEmpty()).isTrue() assertThat(cache.slowIdMap).isEmpty() diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/DefaultGossipScoreTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/DefaultGossipScoreTest.kt index ffeceeb3f..7f5056199 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/DefaultGossipScoreTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/DefaultGossipScoreTest.kt @@ -732,7 +732,6 @@ class DefaultGossipScoreTest { @Test fun `test IP colocation penalty`() { - val addr1 = Multiaddr.fromString("/ip4/0.0.0.1") val addr2 = Multiaddr.fromString("/ip4/0.0.0.2") val peer1 = PeerId.random() diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipBackwardCompatibilityTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipBackwardCompatibilityTest.kt index 530318200..66b6f1c25 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipBackwardCompatibilityTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipBackwardCompatibilityTest.kt @@ -6,49 +6,49 @@ import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test class GossipBackwardCompatibilityTest : TwoGossipHostTestBase() { - override val router1 = GossipRouterBuilder(protocol = PubsubProtocol.Gossip_V_1_0).build() - override val router2 = GossipRouterBuilder(protocol = PubsubProtocol.Gossip_V_1_1).build() + override val router1 = GossipRouterBuilder(protocol = PubsubProtocol.Gossip_V_1_1).build() + override val router2 = GossipRouterBuilder(protocol = PubsubProtocol.Gossip_V_1_2).build() @Test - fun testConnect_1_0_to_1_1() { + fun testConnect_1_1_to_1_2() { connect() Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router1.peers[0].getInboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router1.peers[0].getOutboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router2.peers[0].getInboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router2.peers[0].getOutboundHandler()!!.stream.getProtocol().get() ) } @Test - fun testConnect_1_1_to_1_0() { + fun testConnect_1_2_to_1_1() { connect() Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router1.peers[0].getInboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router1.peers[0].getOutboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router2.peers[0].getInboundHandler()!!.stream.getProtocol().get() ) Assertions.assertEquals( - PubsubProtocol.Gossip_V_1_0.announceStr, + PubsubProtocol.Gossip_V_1_1.announceStr, router2.peers[0].getOutboundHandler()!!.stream.getProtocol().get() ) } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipPubsubRouterTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipPubsubRouterTest.kt index b63fcb2af..80297d123 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipPubsubRouterTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipPubsubRouterTest.kt @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit class GossipPubsubRouterTest : PubsubRouterTest( createGossipFuzzRouterFactory { - GossipRouterBuilder(params = GossipParams(3, 3, 100, floodPublish = false)) + GossipRouterBuilder(params = GossipParams(3, 3, 100, floodPublishMaxMessageSizeThreshold = NEVER_FLOOD_PUBLISH)) } ) { @@ -59,7 +59,7 @@ class GossipPubsubRouterTest : PubsubRouterTest( // this is to test ihave/iwant fuzz.timeController.addTime(Duration.ofMillis(1)) - val r = { GossipRouterBuilder(params = GossipParams(3, 3, 3, DOut = 0, DLazy = 1000, floodPublish = false)) } + val r = { GossipRouterBuilder(params = GossipParams(3, 3, 3, DOut = 0, DLazy = 1000, floodPublishMaxMessageSizeThreshold = NEVER_FLOOD_PUBLISH)) } val routerCenter = fuzz.createTestGossipRouter(r) allRouters.add(0, routerCenter) diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt index d29ec1e53..81dc3d761 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt @@ -1,5 +1,6 @@ package io.libp2p.pubsub.gossip +import io.libp2p.pubsub.Topic import io.libp2p.pubsub.gossip.builders.GossipParamsBuilder import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder import io.libp2p.tools.protobuf.RpcBuilder @@ -15,7 +16,7 @@ class GossipRouterListLimitsTest { private val maxIWantMessageIds = 14 private val maxGraftMessages = 15 private val maxPruneMessages = 16 - private val maxPeersPerPruneMessage = 17 + private val maxPeersAcceptedInPruneMsg = 17 private val gossipParamsWithLimits = GossipParamsBuilder() .maxPublishedMessages(maxPublishedMessages) @@ -25,7 +26,7 @@ class GossipRouterListLimitsTest { .maxIWantMessageIds(maxIWantMessageIds) .maxGraftMessages(maxGraftMessages) .maxPruneMessages(maxPruneMessages) - .maxPeersPerPruneMessage(maxPeersPerPruneMessage) + .maxPeersAcceptedInPruneMsg(maxPeersAcceptedInPruneMsg) .build() private val gossipParamsNoLimits = GossipParamsBuilder() @@ -35,6 +36,8 @@ class GossipRouterListLimitsTest { private val routerWithLimits = GossipRouterBuilder(params = gossipParamsWithLimits).build() private val routerWithNoLimits = GossipRouterBuilder(params = gossipParamsNoLimits).build() + private val topic: Topic = "topic1" + @Test fun validateProtobufLists_validMessage() { val msg = fullMsgBuilder().build() @@ -44,7 +47,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_validMessageWithLargeLists_noLimits() { - val msg = fullMsgBuilder(20).build() + val msg = fullMsgBuilder(16).build() Assertions.assertThat(routerWithNoLimits.validateMessageListLimits(msg)).isTrue() } @@ -96,7 +99,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_tooManyIHaves() { val builder = fullMsgBuilder() - builder.addIHaves(maxIHaveLength, 1) + builder.addIHaves(maxIHaveLength, 1, topic) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse() @@ -105,7 +108,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_tooManyIHaveMsgIds() { val builder = fullMsgBuilder() - builder.addIHaves(1, maxIHaveLength) + builder.addIHaves(1, maxIHaveLength, topic) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse() @@ -148,9 +151,9 @@ class GossipRouterListLimitsTest { } @Test - fun validateProtobufLists_tooManyPrunePeers() { + fun validateProtobufLists_tooManyPeersToAcceptInPruneMsg() { val builder = fullMsgBuilder() - builder.addPrunes(1, maxPeersPerPruneMessage + 1) + builder.addPrunes(1, maxPeersAcceptedInPruneMsg + 1) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse() @@ -186,7 +189,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_maxIHaves() { val builder = fullMsgBuilder() - builder.addIHaves(maxIHaveLength - 1, 1) + builder.addIHaves(maxIHaveLength - 1, 1, topic) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue() @@ -195,7 +198,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_maxIHaveMsgIds() { val builder = fullMsgBuilder() - builder.addIHaves(1, maxIHaveLength - 1) + builder.addIHaves(1, maxIHaveLength - 1, topic) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue() @@ -238,9 +241,9 @@ class GossipRouterListLimitsTest { } @Test - fun validateProtobufLists_maxPrunePeers() { + fun validateProtobufLists_maxPeersAcceptedInPruneMsg() { val builder = fullMsgBuilder() - builder.addPrunes(1, maxPeersPerPruneMessage - 1) + builder.addPrunes(1, maxPeersAcceptedInPruneMsg - 1) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue() @@ -256,7 +259,7 @@ class GossipRouterListLimitsTest { // Add some data to all possible fields builder.addSubscriptions(listSize) builder.addPublishMessages(listSize, listSize) - builder.addIHaves(listSize, listSize) + builder.addIHaves(listSize, listSize, topic) builder.addIWants(listSize, listSize) builder.addGrafts(listSize) builder.addPrunes(listSize, listSize) diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt index c5cc7c85c..5b6b35e55 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt @@ -3,6 +3,7 @@ package io.libp2p.pubsub.gossip import io.libp2p.core.PeerId import io.libp2p.etc.types.toProtobuf import io.libp2p.etc.types.toWBytes +import io.libp2p.pubsub.Topic import io.libp2p.pubsub.gossip.builders.GossipParamsBuilder import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder import org.assertj.core.api.Assertions.assertThat @@ -49,7 +50,7 @@ class GossipRpcPartsQueueTest { queue.addPublish(createRpcMessage("topic-$it", "data")) } (1..iHaves).forEach { - queue.addIHave(byteArrayOf(it.toByte()).toWBytes()) + queue.addIHave(byteArrayOf(it.toByte()).toWBytes(), "topic-$it") } (1..iWants).forEach { queue.addIWant(byteArrayOf(it.toByte()).toWBytes()) @@ -259,4 +260,50 @@ class GossipRpcPartsQueueTest { assertThat(msgs).hasSize(3) assertThat(msgs.merge()).isEqualTo(single) } + + @Test + fun `check that resulting IHAVE sets the topic ID`() { + val topic1: Topic = "topic1" + val messageId1 = "1111".toWBytes() + val topic2: Topic = "topic2" + val messageId2 = "2222".toWBytes() + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + partsQueue.addIHave(messageId1, topic1) + partsQueue.addIHave(messageId2, topic2) + val res = partsQueue.takeMerged().first() + + val serialized = res.toByteArray() + val deserializedRpc = Rpc.RPC.parseFrom(serialized) + assertThat(deserializedRpc.control.ihaveList).containsExactlyInAnyOrder( + Rpc.ControlIHave.newBuilder().setTopicID(topic1).addMessageIDs(messageId1.toProtobuf()).build(), + Rpc.ControlIHave.newBuilder().setTopicID(topic2).addMessageIDs(messageId2.toProtobuf()).build(), + ) + } + + @Test + fun `check that resulting IHAVE correctly groups topics`() { + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + + partsQueue.addIHave("1111".toWBytes(), "topic1") + partsQueue.addIHave("2222".toWBytes(), "topic2") + partsQueue.addIHave("3333".toWBytes(), "topic1") + + val res = partsQueue.takeMerged().first() + + val serialized = res.toByteArray() + val deserializedRpc = Rpc.RPC.parseFrom(serialized) + assertThat(deserializedRpc.control.ihaveList).containsExactlyInAnyOrder( + Rpc.ControlIHave.newBuilder() + .setTopicID("topic1") + .addAllMessageIDs( + listOf( + "1111".toWBytes().toProtobuf(), + "3333".toWBytes().toProtobuf() + ) + ).build(), + Rpc.ControlIHave.newBuilder() + .setTopicID("topic2") + .addMessageIDs("2222".toWBytes().toProtobuf()).build(), + ) + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt new file mode 100644 index 000000000..ecc912256 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt @@ -0,0 +1,76 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import io.libp2p.etc.types.toBytesBigEndian +import io.libp2p.etc.types.toProtobuf +import io.libp2p.etc.types.toWBytes +import io.libp2p.pubsub.* +import io.libp2p.pubsub.DeterministicFuzz.Companion.createGossipFuzzRouterFactory +import io.libp2p.pubsub.DeterministicFuzz.Companion.createMockFuzzRouterFactory +import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import io.netty.handler.logging.LogLevel +import pubsub.pb.Rpc + +abstract class GossipTestsBase { + + protected val GossipScore.testPeerScores get() = (this as DefaultGossipScore).peerScores + + protected fun newProtoMessage(topic: Topic, seqNo: Long, data: ByteArray) = + Rpc.Message.newBuilder() + .addTopicIDs(topic) + .setSeqno(seqNo.toBytesBigEndian().toProtobuf()) + .setData(data.toProtobuf()) + .build() + + protected fun newMessage(topic: Topic, seqNo: Long, data: ByteArray) = + DefaultPubsubMessage(newProtoMessage(topic, seqNo, data)) + + protected fun getMessageId(msg: Rpc.Message): MessageId = msg.from.toWBytes() + msg.seqno.toWBytes() + + class ManyRoutersTest( + val mockRouterCount: Int = 10, + val params: GossipParams = GossipParams(), + val scoreParams: GossipScoreParams = GossipScoreParams(), + val protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1 + ) { + val fuzz = DeterministicFuzz() + val gossipRouterBuilderFactory = { GossipRouterBuilder(protocol = protocol, params = params, scoreParams = scoreParams) } + val router0 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) + val routers = (0 until mockRouterCount).map { fuzz.createTestRouter(createMockFuzzRouterFactory()) } + val connections = mutableListOf() + val gossipRouter = router0.router as GossipRouter + val mockRouters = routers.map { it.router as MockRouter } + + fun connectAll() = connect(routers.indices) + fun connect(routerIndexes: IntRange, outbound: Boolean = true): List { + val list = + routers.slice(routerIndexes).map { + if (outbound) { + router0.connectSemiDuplex(it, null, LogLevel.ERROR) + } else { + it.connectSemiDuplex(router0, null, LogLevel.ERROR) + } + } + connections += list + return list + } + + fun getMockRouter(peerId: PeerId) = mockRouters[routers.indexOfFirst { it.peerId == peerId }] + } + + class TwoRoutersTest( + val coreParams: GossipParams = GossipParams(), + val scoreParams: GossipScoreParams = GossipScoreParams(), + val mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory(), + val protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1 + ) { + val fuzz = DeterministicFuzz() + val gossipRouterBuilderFactory = { GossipRouterBuilder(protocol = protocol, params = coreParams, scoreParams = scoreParams) } + val router1 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) + val router2 = fuzz.createTestRouter(mockRouterFactory) + val gossipRouter = router1.router as GossipRouter + val mockRouter = router2.router as MockRouter + + val connection = router1.connectSemiDuplex(router2, null, LogLevel.ERROR) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt index 6dbd3fa88..905a6b489 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt @@ -1,26 +1,13 @@ +@file:Suppress("ktlint:standard:class-naming") + package io.libp2p.pubsub.gossip import com.google.common.util.concurrent.AtomicDouble +import com.google.protobuf.ByteString import io.libp2p.core.PeerId -import io.libp2p.core.pubsub.MessageApi -import io.libp2p.core.pubsub.RESULT_IGNORE -import io.libp2p.core.pubsub.RESULT_INVALID -import io.libp2p.core.pubsub.RESULT_VALID -import io.libp2p.core.pubsub.Subscriber -import io.libp2p.core.pubsub.ValidationResult -import io.libp2p.core.pubsub.Validator -import io.libp2p.core.pubsub.createPubsubApi -import io.libp2p.etc.types.millis -import io.libp2p.etc.types.minutes -import io.libp2p.etc.types.seconds -import io.libp2p.etc.types.times -import io.libp2p.etc.types.toBytesBigEndian -import io.libp2p.etc.types.toProtobuf -import io.libp2p.etc.types.toWBytes -import io.libp2p.pubsub.* -import io.libp2p.pubsub.DeterministicFuzz.Companion.createGossipFuzzRouterFactory -import io.libp2p.pubsub.DeterministicFuzz.Companion.createMockFuzzRouterFactory -import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import io.libp2p.core.pubsub.* +import io.libp2p.etc.types.* +import io.libp2p.pubsub.MockRouter import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandler @@ -28,10 +15,7 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelOutboundHandlerAdapter import io.netty.channel.ChannelPromise import io.netty.handler.logging.LogLevel -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertNotNull -import org.junit.jupiter.api.Assertions.assertNull -import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test import pubsub.pb.Rpc import java.nio.charset.StandardCharsets @@ -40,49 +24,30 @@ import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference - -class GossipV1_1Tests { - - private val GossipScore.testPeerScores get() = (this as DefaultGossipScore).peerScores - - private fun newProtoMessage(topic: Topic, seqNo: Long, data: ByteArray) = - Rpc.Message.newBuilder() - .addTopicIDs(topic) - .setSeqno(seqNo.toBytesBigEndian().toProtobuf()) - .setData(data.toProtobuf()) - .build() - private fun newMessage(topic: Topic, seqNo: Long, data: ByteArray) = - DefaultPubsubMessage(newProtoMessage(topic, seqNo, data)) - - protected fun getMessageId(msg: Rpc.Message): MessageId = msg.from.toWBytes() + msg.seqno.toWBytes() - - class ManyRoutersTest( - val mockRouterCount: Int = 10, - val params: GossipParams = GossipParams(), - val scoreParams: GossipScoreParams = GossipScoreParams(), -// mockRouters: () -> List = { (0 until mockRouterCount).map { MockRouter() } } - ) { - val fuzz = DeterministicFuzz() - val gossipRouterBuilderFactory = { GossipRouterBuilder(params = params, scoreParams = scoreParams) } - val router0 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) - val routers = (0 until mockRouterCount).map { fuzz.createTestRouter(createMockFuzzRouterFactory()) } - val connections = mutableListOf() - val gossipRouter = router0.router as GossipRouter - val mockRouters = routers.map { it.router as MockRouter } - - fun connectAll() = connect(routers.indices) - fun connect(routerIndexes: IntRange, outbound: Boolean = true): List { - val list = - routers.slice(routerIndexes).map { - if (outbound) router0.connectSemiDuplex(it, null, LogLevel.ERROR) - else it.connectSemiDuplex(router0, null, LogLevel.ERROR) - } - connections += list - return list - } - - fun getMockRouter(peerId: PeerId) = mockRouters[routers.indexOfFirst { it.peerId == peerId }] - } +import kotlin.collections.List +import kotlin.collections.component1 +import kotlin.collections.component2 +import kotlin.collections.count +import kotlin.collections.distinct +import kotlin.collections.filter +import kotlin.collections.first +import kotlin.collections.flatMap +import kotlin.collections.forEach +import kotlin.collections.getValue +import kotlin.collections.intersect +import kotlin.collections.map +import kotlin.collections.mapValues +import kotlin.collections.minus +import kotlin.collections.mutableListOf +import kotlin.collections.mutableMapOf +import kotlin.collections.plusAssign +import kotlin.collections.set +import kotlin.collections.slice +import kotlin.collections.take +import kotlin.collections.toMap +import kotlin.collections.withDefault + +class GossipV1_1Tests : GossipTestsBase() { @Test fun selfSanityTest() { @@ -94,21 +59,6 @@ class GossipV1_1Tests { test.mockRouter.waitForMessage { it.publishCount > 0 } } - class TwoRoutersTest( - val coreParams: GossipParams = GossipParams(), - val scoreParams: GossipScoreParams = GossipScoreParams(), - mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory() - ) { - val fuzz = DeterministicFuzz() - val gossipRouterBuilderFactory = { GossipRouterBuilder(params = coreParams, scoreParams = scoreParams) } - val router1 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) - val router2 = fuzz.createTestRouter(mockRouterFactory) - val gossipRouter = router1.router as GossipRouter - val mockRouter = router2.router as MockRouter - - val connection = router1.connectSemiDuplex(router2, null, LogLevel.ERROR) - } - @Test fun testSeenTTL() { val test = TwoRoutersTest(GossipParams(seenTTL = 1.minutes)) @@ -141,7 +91,7 @@ class GossipV1_1Tests { val api = createPubsubApi(test.gossipRouter) val apiMessages = mutableListOf() - api.subscribe(Subscriber { apiMessages += it }, io.libp2p.core.pubsub.Topic("topic2")) + api.subscribe(Subscriber { apiMessages += it }, Topic("topic2")) val msg1 = Rpc.RPC.newBuilder() .addPublish(newProtoMessage("topic2", 0L, "Hello-1".toByteArray())) @@ -175,12 +125,13 @@ class GossipV1_1Tests { super.initChannelWithHandler(streamHandler, handler) } } + val test = TwoRoutersTest(mockRouterFactory = { exec, _, _ -> MalformedMockRouter(exec) }) val mockRouter = test.router2.router as MalformedMockRouter val api = createPubsubApi(test.gossipRouter) val apiMessages = mutableListOf() - api.subscribe(Subscriber { apiMessages += it }, io.libp2p.core.pubsub.Topic("topic1")) + api.subscribe(Subscriber { apiMessages += it }, Topic("topic1")) val msg1 = Rpc.RPC.newBuilder() .addPublish(newProtoMessage("topic1", 0L, "Hello-1".toByteArray())) @@ -580,8 +531,9 @@ class GossipV1_1Tests { @Test fun testNotFloodPublish() { + val message = newMessage("topic1", 0L, "Hello-0".toByteArray()) val appScore = mutableMapOf().withDefault { 0.0 } - val coreParams = GossipParams(3, 3, 3, floodPublish = false) + val coreParams = GossipParams(3, 3, 3, floodPublishMaxMessageSizeThreshold = message.size - 1) val peerScoreParams = GossipPeerScoreParams(appSpecificScore = { appScore.getValue(it) }) val scoreParams = GossipScoreParams(peerScoreParams = peerScoreParams) val test = ManyRoutersTest(params = coreParams, scoreParams = scoreParams) @@ -595,7 +547,7 @@ class GossipV1_1Tests { val topicMesh = test.gossipRouter.mesh["topic1"]!! assertTrue(topicMesh.size > 0 && topicMesh.size < test.routers.size) - test.gossipRouter.publish(newMessage("topic1", 0L, "Hello-0".toByteArray())) + test.gossipRouter.publish(message) test.fuzz.timeController.addTime(50.millis) @@ -607,8 +559,9 @@ class GossipV1_1Tests { @Test fun testFloodPublish() { + val message = newMessage("topic1", 0L, "Hello-0".toByteArray()) val appScore = mutableMapOf().withDefault { 0.0 } - val coreParams = GossipParams(3, 3, 3, floodPublish = true) + val coreParams = GossipParams(3, 3, 3, floodPublishMaxMessageSizeThreshold = message.size) val peerScoreParams = GossipPeerScoreParams( appSpecificScore = { appScore.getValue(it) }, appSpecificWeight = 1.0 @@ -630,7 +583,7 @@ class GossipV1_1Tests { val topicMesh = test.gossipRouter.mesh["topic1"]!!.map { it.peerId } assertTrue(topicMesh.size > 0 && topicMesh.size < test.routers.size) - test.gossipRouter.publish(newMessage("topic1", 0L, "Hello-0".toByteArray())) + test.gossipRouter.publish(message) test.fuzz.timeController.addTime(50.millis) @@ -696,8 +649,12 @@ class GossipV1_1Tests { fun testAdaptiveGossip() { val appScore = mutableMapOf().withDefault { 0.0 } val coreParams = GossipParams( - 3, 3, 3, DLazy = 3, - floodPublish = false, gossipFactor = 0.5 + 3, + 3, + 3, + DLazy = 3, + floodPublishMaxMessageSizeThreshold = NEVER_FLOOD_PUBLISH, + gossipFactor = 0.5 ) val peerScoreParams = GossipPeerScoreParams( appSpecificScore = { appScore.getValue(it) }, @@ -760,7 +717,7 @@ class GossipV1_1Tests { @Test fun testOutboundMeshQuotas1() { val appScore = mutableMapOf().withDefault { 0.0 } - val coreParams = GossipParams(3, 3, 3, DLazy = 3, DOut = 1, floodPublish = false) + val coreParams = GossipParams(3, 3, 3, DLazy = 3, DOut = 1, floodPublishMaxMessageSizeThreshold = NEVER_FLOOD_PUBLISH) val peerScoreParams = GossipPeerScoreParams(appSpecificScore = { appScore.getValue(it) }) val scoreParams = GossipScoreParams(peerScoreParams = peerScoreParams) val test = ManyRoutersTest(params = coreParams, scoreParams = scoreParams) @@ -806,8 +763,13 @@ class GossipV1_1Tests { fun testOpportunisticGraft() { val appScore = mutableMapOf().withDefault { 0.0 } val coreParams = GossipParams( - 3, 3, 10, DLazy = 3, DOut = 1, - opportunisticGraftPeers = 2, opportunisticGraftTicks = 60 + 3, + 3, + 10, + DLazy = 3, + DOut = 1, + opportunisticGraftPeers = 2, + opportunisticGraftTicks = 60 ) val peerScoreParams = GossipPeerScoreParams( appSpecificScore = { appScore.getValue(it) }, @@ -970,8 +932,11 @@ class GossipV1_1Tests { val validationResult = CompletableFuture() val receivedMessages = LinkedBlockingQueue() - val slowValidator = Validator { receivedMessages += it; validationResult } - api.subscribe(slowValidator, io.libp2p.core.pubsub.Topic("topic1")) + val slowValidator = Validator { + receivedMessages += it + validationResult + } + api.subscribe(slowValidator, Topic("topic1")) test.mockRouters.forEach { it.subscribe("topic1") } val gossiper = test.mockRouters[0] @@ -1123,4 +1088,348 @@ class GossipV1_1Tests { assertEquals(5, iWandIds1.size) assertEquals(5, iWandIds1.distinct().size) } + + @Test + fun testMaxPeersSentInPruneMsg() { + val test = TwoRoutersTest() + + val topic = "topic1" + test.mockRouter.subscribe(topic) + test.gossipRouter.subscribe(topic) + + for (i in 0..20) { + val router = test.fuzz.createTestRouter(test.mockRouterFactory) + (router.router as MockRouter).subscribe(topic) + test.router1.connectSemiDuplex(router, null, LogLevel.ERROR) + } + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + test.mockRouter.waitForMessage { it.hasControl() && it.control.graftCount > 0 } + + test.gossipRouter.unsubscribe(topic) + test.fuzz.timeController.addTime(2.seconds) + assertEquals( + 1, + test.mockRouter.inboundMessages.count { + it.hasControl() && it.control.pruneCount == 1 && + it.control.getPrune(0).peersCount == test.gossipRouter.params.maxPeersSentInPruneMsg + } + ) + } + + @Test + fun testMaxPeersAcceptedInPruneMsg() { + val test = TwoRoutersTest() + val topic = "topic1" + + test.mockRouter.subscribe(topic) + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + test.mockRouter.sendToSingle( + createPruneMessage(topic, test.gossipRouter.params.maxPeersAcceptedInPruneMsg + 1) + ) + + // prune message should be dropped because too many peers + assertEquals(1, test.gossipRouter.mesh[topic]!!.size) + + test.mockRouter.sendToSingle( + createPruneMessage(topic, test.gossipRouter.params.maxPeersAcceptedInPruneMsg) + ) + + // prune message should now be processed + assertEquals(0, test.gossipRouter.mesh[topic]!!.size) + } + + @Test + fun `when a peer leaves the mesh it should still be considered for publishing`() { + val test = TwoRoutersTest() + val topic = "topic1" + + test.mockRouter.subscribe(topic) + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == 1) + + // remote peer leaves the mesh + test.mockRouter.sendToSingle(createPruneMessage(topic)) + test.fuzz.timeController.addTime(1.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == 0) + + val message1 = newMessage(topic, 0L, "Hello-0".toByteArray()) + test.gossipRouter.publish(message1) + + test.mockRouter.waitForMessage { it.publishCount > 0 } + } + + @Test + fun `should publish to all mesh peers when mesh exceeds D`() { + val gossipParams = GossipParams(D = 6, DHigh = 10) + val test = ManyRoutersTest(params = gossipParams, mockRouterCount = gossipParams.DHigh) + val topic = "topic1" + test.connectAll() + + test.mockRouters.forEach { + it.subscribe(topic) + } + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == gossipParams.D) + + test.mockRouters.forEach { + it.sendToSingle(createGraftMessage(topic)) + } + + test.fuzz.timeController.addTime(2.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == gossipParams.DHigh) + + // remote peer leaves the mesh + val message1 = newMessage(topic, 0L, "Hello-0".toByteArray()) + test.gossipRouter.publish(message1) + + val routerReceivedMessageCount = + test.mockRouters.count { mockRouter -> + mockRouter.inboundMessages.any { msg -> + msg.publishCount > 0 + } + } + + assertTrue(routerReceivedMessageCount == gossipParams.DHigh) + } + + @Test + fun `publishing should collect at least D peers if mesh is smaller`() { + val params = GossipParams() + + val test = ManyRoutersTest(params = params, mockRouterCount = params.D) + val topic = "topic1" + test.connectAll() + + test.mockRouters.forEach { it.subscribe(topic) } + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + val topicMeshRouters = test.gossipRouter.mesh[topic]!! + assertTrue((topicMeshRouters.size) >= params.DLow) + + // leave just 2 peers in the mesh + topicMeshRouters.drop(2) + .forEach { + test.getMockRouter(it.peerId).sendToSingle(createPruneMessage(topic)) + } + test.fuzz.timeController.addTime(1.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == 2) + + val message1 = newMessage(topic, 0L, "Hello-0".toByteArray()) + test.gossipRouter.publish(message1) + + val routerReceivedMessageCount = + test.mockRouters.count { mockRouter -> + mockRouter.inboundMessages.any { msg -> + msg.publishCount > 0 + } + } + + assertTrue(routerReceivedMessageCount >= params.D) + } + + @Test + fun `publishing should collect at least D peers if mesh is smaller and prefer well scored peers`() { + val params = GossipParams() + val peerAppScores = mutableMapOf() + val gossipScoreParams = GossipScoreParams( + peerScoreParams = GossipPeerScoreParams( + appSpecificScore = { + peerAppScores[it]?.toDouble() ?: 0.0 + }, + appSpecificWeight = 1.0 + ) + ) + + val test = ManyRoutersTest(params = params, scoreParams = gossipScoreParams, mockRouterCount = 10) + val topic = "topic1" + test.connectAll() + + test.mockRouters.forEach { it.subscribe(topic) } + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + val topicMeshRouters = test.gossipRouter.mesh[topic]!!.toList() + assertTrue((topicMeshRouters.size) == params.D) + + // leave just 2 peers in the mesh + topicMeshRouters.drop(2) + .forEach { + test.getMockRouter(it.peerId).sendToSingle(createPruneMessage(topic)) + } + // downscore all peers except 5 + val goodScoredPeers = topicMeshRouters.take(5).map { it.peerId }.toSet() + test.routers + .map { it.peerId } + .filter { it !in goodScoredPeers } + .forEach { peerAppScores[it] = -gossipScoreParams.publishThreshold.toInt() - 1 } + + // for D = 6: 2 peers in the mesh + 3 peers outside of mesh + others are significantly downscored + test.fuzz.timeController.addTime(1.seconds) + + assertTrue((test.gossipRouter.mesh[topic]?.size ?: 0) == 2) + + val message1 = newMessage(topic, 0L, "Hello-0".toByteArray()) + test.gossipRouter.publish(message1) + + // router should take 2 mesh peers, 3 well scored peers and 1 peer scored below publishThreshold + val peersReceivedMessage = test.routers + .filter { + val mockRouter = it.router as MockRouter + mockRouter.inboundMessages.any { msg -> + msg.publishCount > 0 + } + } + .map { it.peerId } + + assertTrue(peersReceivedMessage.size == params.D) + assertTrue(peersReceivedMessage.containsAll(goodScoredPeers)) + } + + @Test + fun `should always flood publish to subscribed direct peers`() { + val message = newMessage("topic1", 0L, "Hello-0".toByteArray()) + val appScore = mutableMapOf().withDefault { 0.0 } + val directPeers = mutableSetOf() + val coreParams = GossipParams(3, 3, 3, floodPublishMaxMessageSizeThreshold = ALWAYS_FLOOD_PUBLISH) + val peerScoreParams = GossipPeerScoreParams( + appSpecificScore = { appScore.getValue(it) }, + appSpecificWeight = 1.0, + isDirect = { directPeers.contains(it) } + ) + val scoreParams = GossipScoreParams( + peerScoreParams = peerScoreParams, + graylistThreshold = -15.0, + publishThreshold = -10.0, + ) + val test = ManyRoutersTest(mockRouterCount = 10, params = coreParams, scoreParams = scoreParams) + test.connectAll() + + test.gossipRouter.subscribe("topic1") + test.routers.slice(0..5).forEach { + it.router.subscribe("topic1") + } + + test.routers.slice(1..6).forEach { + directPeers.add(it.peerId) + } + + // now only peers from 1 to 5 are direct peers subscribed to the topic + + test.fuzz.timeController.addTime(2.seconds) + + // let's down score all peers + test.routers.forEach { + appScore[it.peerId] = -20.0 + } + test.gossipRouter.publish(message) + + test.fuzz.timeController.addTime(50.millis) + + val publishedCount = test.mockRouters.flatMap { it.inboundMessages }.count { it.publishCount > 0 } + + // only subscribed direct peers should receive the message + assertEquals(5, publishedCount) + } + + @Test + fun `should always publish to subscribed direct peers`() { + val message = newMessage("topic1", 0L, "Hello-0".toByteArray()) + val appScore = mutableMapOf().withDefault { 0.0 } + val directPeers = mutableSetOf() + val coreParams = GossipParams(3, 3, 3, floodPublishMaxMessageSizeThreshold = NEVER_FLOOD_PUBLISH) + val peerScoreParams = GossipPeerScoreParams( + appSpecificScore = { appScore.getValue(it) }, + appSpecificWeight = 1.0, + isDirect = { directPeers.contains(it) } + ) + val scoreParams = GossipScoreParams( + peerScoreParams = peerScoreParams, + graylistThreshold = -15.0, + publishThreshold = -10.0, + ) + val test = ManyRoutersTest(mockRouterCount = 10, params = coreParams, scoreParams = scoreParams) + test.connectAll() + + test.gossipRouter.subscribe("topic1") + + test.routers.slice(0..5).forEach { + it.router.subscribe("topic1") + } + test.routers.slice(1..6).forEach { + directPeers.add(it.peerId) + } + + // now only peers from 1 to 5 are direct peers subscribed to the topic + val subscribedDirectPeers = test.routers.slice(1..5).map { it.peerId } + + test.fuzz.timeController.addTime(2.seconds) + + // let's down score all direct peers + directPeers.forEach { + appScore[it] = -20.0 + } + + val topicMeshRouters = test.gossipRouter.mesh["topic1"]!!.toList() + + // the mesh is strictly smaller than the number of subscribed direct peers + assertTrue(topicMeshRouters.size < subscribedDirectPeers.size) + + val expectedPublishedCount = topicMeshRouters.map { it.peerId }.plus(subscribedDirectPeers).distinct().size + + test.gossipRouter.publish(message) + + test.fuzz.timeController.addTime(50.millis) + + val publishedCount = test.mockRouters.flatMap { it.inboundMessages }.count { it.publishCount > 0 } + + assertEquals(expectedPublishedCount, publishedCount) + } + + private fun createGraftMessage(topic: String): Rpc.RPC { + return Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addGraft( + Rpc.ControlGraft.newBuilder() + .setTopicID(topic) + ) + ).build() + } + + private fun createPruneMessage(topic: String, pxPeersCount: Int = 0): Rpc.RPC { + val peerInfos = List(pxPeersCount) { + Rpc.PeerInfo.newBuilder() + .setPeerID(PeerId.random().bytes.toProtobuf()) + .setSignedPeerRecord(ByteString.EMPTY) + .build() + } + return Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addPrune( + Rpc.ControlPrune.newBuilder() + .setTopicID(topic) + .setBackoff(10) + .addAllPeers(peerInfos) + ) + ).build() + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_2Tests.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_2Tests.kt new file mode 100644 index 000000000..8ce1919c5 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_2Tests.kt @@ -0,0 +1,191 @@ +@file:Suppress("ktlint:standard:class-naming") + +package io.libp2p.pubsub.gossip + +import io.libp2p.etc.types.millis +import io.libp2p.etc.types.seconds +import io.libp2p.etc.types.toProtobuf +import io.libp2p.etc.types.toWBytes +import io.libp2p.pubsub.PubsubProtocol +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class GossipV1_2Tests : GossipTestsBase() { + + @Test + fun selfSanityTest() { + val test = TwoRoutersTest(protocol = PubsubProtocol.Gossip_V_1_2) + + test.mockRouter.subscribe("topic1") + val msg = newMessage("topic1", 0L, "Hello".toByteArray()) + test.gossipRouter.publish(msg) + test.mockRouter.waitForMessage { it.publishCount > 0 } + } + + @Test + fun iDontWantIsBroadcastToMeshPeers() { + val test = startSingleTopicNetwork( + params = GossipParams(iDontWantMinMessageSizeThreshold = 5), + mockRouterCount = 3 + ) + + val publisher = test.mockRouters[0] + val gossipers = listOf(test.mockRouters[1], test.mockRouters[2]) + + val msg = newMessage("topic1", 0L, "Hello".toByteArray()) + + publisher.sendToSingle( + Rpc.RPC.newBuilder().addPublish(msg.protobufMessage).build() + ) + + test.fuzz.timeController.addTime(100.millis) + + val iDontWants = + gossipers.flatMap { it.inboundMessages }.filter { it.hasControl() }.flatMap { it.control.idontwantList } + + // both gossipers should have received IDONTWANT from the GossipRouter + assertTrue(iDontWants.size == 2) + + iDontWants.forEach { iDontWant -> + assertThat(iDontWant.messageIDsList.map { it.toWBytes() }).containsExactly(msg.messageId) + } + } + + @Test + fun messageIsNotBroadcastIfPeerHasSentIDONTWANT() { + val test = startSingleTopicNetwork( + params = GossipParams(iDontWantMinMessageSizeThreshold = 5), + mockRouterCount = 2 + ) + + val publisher = test.mockRouters[0] + val iDontWantPeer = test.mockRouters[1] + + val msg = newMessage("topic1", 0L, "Hello".toByteArray()) + + // sending IDONTWANT + iDontWantPeer.sendToSingle( + Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addIdontwant( + Rpc.ControlIDontWant.newBuilder().addMessageIDs(msg.messageId.toProtobuf()) + ) + ).build() + ) + + test.fuzz.timeController.addTime(100.millis) + + publisher.sendToSingle( + Rpc.RPC.newBuilder().addPublish(msg.protobufMessage).build() + ) + + test.fuzz.timeController.addTime(100.millis) + + val receivedMessages = iDontWantPeer.inboundMessages.flatMap { it.publishList } + + // message shouldn't have been received + assertThat(receivedMessages).isEmpty() + } + + @Test + fun iDontWantIsNotSentIfSizeIsLessThanTheMinimumConfigured() { + val test = startSingleTopicNetwork( + params = GossipParams(iDontWantMinMessageSizeThreshold = 5), + mockRouterCount = 3 + ) + + val publisher = test.mockRouters[0] + val gossipers = listOf(test.mockRouters[1], test.mockRouters[2]) + + // 4 bytes and minimum is 5, so IDONTWANT shouldn't be sent + val msg = newMessage("topic1", 0L, "Hell".toByteArray()) + + publisher.sendToSingle( + Rpc.RPC.newBuilder().addPublish(msg.protobufMessage).build() + ) + + test.fuzz.timeController.addTime(100.millis) + + val iDontWants = + gossipers.flatMap { it.inboundMessages }.filter { it.hasControl() }.flatMap { it.control.idontwantList } + + assertThat(iDontWants).isEmpty() + } + + @Test + fun testIDontWantTTL() { + val test = startSingleTopicNetwork( + // set TTL to 700ms + params = GossipParams(iDontWantMinMessageSizeThreshold = 5, iDontWantTTL = 700.millis), + mockRouterCount = 2 + ) + + val publisher = test.mockRouters[0] + val iDontWantPeer = test.mockRouters[1] + + val msg = newMessage("topic1", 0L, "Hello".toByteArray()) + + // sending IDONTWANT + iDontWantPeer.sendToSingle( + Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addIdontwant( + Rpc.ControlIDontWant.newBuilder().addMessageIDs(msg.messageId.toProtobuf()) + ) + ).build() + ) + + // 1 heartbeat - the IDONTWANT should have expired + test.fuzz.timeController.addTime(1.seconds) + + publisher.sendToSingle( + Rpc.RPC.newBuilder().addPublish(msg.protobufMessage).build() + ) + + test.fuzz.timeController.addTime(100.millis) + + val receivedMessages = iDontWantPeer.inboundMessages.flatMap { it.publishList } + + // message shouldn't have been received + assertThat(receivedMessages).containsExactly(msg.protobufMessage) + } + + @Test + fun iDontWantIsSentOnPublishing() { + val test = startSingleTopicNetwork( + params = GossipParams(iDontWantMinMessageSizeThreshold = 5), + mockRouterCount = 3 + ) + + test.mockRouters.forEach { it.subscribe("topic1") } + val msgToPublish = newMessage("topic1", 0L, "Hello".toByteArray()) + test.gossipRouter.publish(msgToPublish) + test.mockRouters.forEach { + // IDONTWANT is received + it.waitForMessage { msg -> + msg.control.idontwantCount == 1 && + msg.control.idontwantList.first().messageIDsList.map { mIds -> mIds.toWBytes() }.contains(msgToPublish.messageId) + } + // msg is received + it.waitForMessage { msg -> msg.publishCount > 0 } + } + } + + private fun startSingleTopicNetwork(params: GossipParams, mockRouterCount: Int): ManyRoutersTest { + val test = ManyRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_2, + params = params, + mockRouterCount = mockRouterCount + ) + + test.connectAll() + + test.gossipRouter.subscribe("topic1") + test.mockRouters.forEach { it.subscribe("topic1") } + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + return test + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/SubscriptionsLimitTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/SubscriptionsLimitTest.kt index 28b4a6f67..bc1f02fa6 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/SubscriptionsLimitTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/SubscriptionsLimitTest.kt @@ -10,7 +10,7 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertDoesNotThrow class SubscriptionsLimitTest : TwoGossipHostTestBase() { - override val params = GossipParams(maxSubscriptions = 5, floodPublish = true) + override val params = GossipParams(maxSubscriptions = 5, floodPublishMaxMessageSizeThreshold = ALWAYS_FLOOD_PUBLISH) @Test fun `new peer subscribed to many topics`() { diff --git a/libp2p/src/test/kotlin/io/libp2p/security/CipherSecureChannelTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/CipherSecureChannelTest.kt index b466f04e6..b7290a6a5 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/CipherSecureChannelTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/CipherSecureChannelTest.kt @@ -1,7 +1,7 @@ package io.libp2p.security import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.mux.StreamMuxer import io.libp2p.tools.TestChannel @@ -16,11 +16,38 @@ import java.util.concurrent.TimeUnit.SECONDS abstract class CipherSecureChannelTest(secureChannelCtor: SecureChannelCtor, muxers: List, announce: String) : SecureChannelTestBase(secureChannelCtor, muxers, announce) { + @Test + fun `verify secure session`() { + val (privKey1, pubKey1) = generateKeyPair(KeyType.ECDSA) + val (privKey2, pubKey2) = generateKeyPair(KeyType.ECDSA) + + val protocolSelect1 = makeSelector(privKey1, muxerIds) + val protocolSelect2 = makeSelector(privKey2, muxerIds) + + val eCh1 = makeDialChannel("#1", protocolSelect1, PeerId.fromPubKey(pubKey2)) + val eCh2 = makeListenChannel("#2", protocolSelect2) + + logger.debug("Connecting channels...") + val connection = TestChannel.interConnect(eCh1, eCh2) + + val secSession1 = protocolSelect1.selectedFuture.join() + assertThat(secSession1.localId).isEqualTo(PeerId.fromPubKey(pubKey1)) + assertThat(secSession1.remoteId).isEqualTo(PeerId.fromPubKey(pubKey2)) + assertThat(secSession1.remotePubKey).isEqualTo(pubKey2) + + val secSession2 = protocolSelect2.selectedFuture.join() + assertThat(secSession2.localId).isEqualTo(PeerId.fromPubKey(pubKey2)) + assertThat(secSession2.remoteId).isEqualTo(PeerId.fromPubKey(pubKey1)) + assertThat(secSession2.remotePubKey).isEqualTo(pubKey1) + + logger.debug("Connection made: $connection") + } + @Test fun `incorrect initiator remote PeerId should throw`() { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (privKey2, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (_, wrongPubKey) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) + val (privKey2, _) = generateKeyPair(KeyType.ECDSA) + val (_, wrongPubKey) = generateKeyPair(KeyType.ECDSA) val protocolSelect1 = makeSelector(privKey1, muxerIds) val protocolSelect2 = makeSelector(privKey2, muxerIds) @@ -37,8 +64,8 @@ abstract class CipherSecureChannelTest(secureChannelCtor: SecureChannelCtor, mux @Test fun `test that on malformed message from remote the connection closes and no log noise`() { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (privKey2, pubKey2) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) + val (privKey2, pubKey2) = generateKeyPair(KeyType.ECDSA) val protocolSelect1 = makeSelector(privKey1, muxerIds) val protocolSelect2 = makeSelector(privKey2, muxerIds) diff --git a/libp2p/src/test/kotlin/io/libp2p/security/SecureChannelTestBase.kt b/libp2p/src/test/kotlin/io/libp2p/security/SecureChannelTestBase.kt index a91651a75..a4a899d81 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/SecureChannelTestBase.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/SecureChannelTestBase.kt @@ -1,7 +1,7 @@ package io.libp2p.security import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.PrivKey import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multistream.ProtocolMatcher @@ -55,8 +55,8 @@ abstract class SecureChannelTestBase( @ParameterizedTest @MethodSource("plainDataSizes") fun secureInterconnect(dataSize: Int) { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (privKey2, pubKey2) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) + val (privKey2, pubKey2) = generateKeyPair(KeyType.ECDSA) val protocolSelect1 = makeSelector(privKey1, muxerIds) val protocolSelect2 = makeSelector(privKey2, muxerIds) diff --git a/libp2p/src/test/kotlin/io/libp2p/security/noise/NoiseHandshakeTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/noise/NoiseHandshakeTest.kt index a6f6454b2..5489a6e94 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/noise/NoiseHandshakeTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/noise/NoiseHandshakeTest.kt @@ -2,7 +2,7 @@ package io.libp2p.security.noise import com.google.protobuf.ByteString import com.southernstorm.noise.protocol.HandshakeState -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.MethodOrderer.OrderAnnotation @@ -120,7 +120,7 @@ class NoiseHandshakeTest { // generate a Peer Identity protobuf object // use it for encoding and decoding peer identities from the wire // this identity is intended to be sent as a Noise transport payload - val (privKey, pubKey) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey, pubKey) = generateKeyPair(KeyType.ECDSA) assert(pubKey.bytes().maxOrNull()?.compareTo(0) != 0) // sign the identity using the identity's private key @@ -144,7 +144,7 @@ class NoiseHandshakeTest { @Test fun testAnnounceAndMatch() { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) val ch1 = NoiseXXSecureChannel(privKey1, listOf()) @@ -155,11 +155,11 @@ class NoiseHandshakeTest { @Test fun testStaticNoiseKeyPerProcess() { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) NoiseXXSecureChannel(privKey1, listOf()) val b1 = NoiseXXSecureChannel.localStaticPrivateKey25519.copyOf() - val (privKey2, _) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey2, _) = generateKeyPair(KeyType.ECDSA) NoiseXXSecureChannel(privKey2, listOf()) val b2 = NoiseXXSecureChannel.localStaticPrivateKey25519.copyOf() diff --git a/libp2p/src/test/kotlin/io/libp2p/security/secio/EchoSampleTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/secio/EchoSampleTest.kt index f1c8b097c..4948a0022 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/secio/EchoSampleTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/secio/EchoSampleTest.kt @@ -4,7 +4,7 @@ import io.libp2p.core.ChannelVisitor import io.libp2p.core.Connection import io.libp2p.core.P2PChannel import io.libp2p.core.P2PChannelHandler -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multiformats.Multiaddr import io.libp2p.core.multistream.MultistreamProtocolV1 @@ -53,7 +53,7 @@ class EchoSampleTest { fun connect1() { val logger = LoggerFactory.getLogger("test") - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) val applicationProtocols = listOf(createSimpleBinding("/echo/1.0.0") { EchoProtocol() }) val muxer = StreamMuxerProtocol.Mplex.createMuxer(MultistreamProtocolV1, applicationProtocols).also { it as MplexStreamMuxer diff --git a/libp2p/src/test/kotlin/io/libp2p/security/secio/SecIoNegotiatorTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/secio/SecIoNegotiatorTest.kt index e0c119768..45309c119 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/secio/SecIoNegotiatorTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/secio/SecIoNegotiatorTest.kt @@ -1,7 +1,7 @@ package io.libp2p.security.secio import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.crypto.unmarshalPrivateKey import io.libp2p.crypto.keys.secp256k1PublicKeyFromCoordinates @@ -19,8 +19,8 @@ import java.math.BigInteger class SecIoNegotiatorTest { @Test fun handshake() { - val (privKey1, pubKey1) = generateKeyPair(KEY_TYPE.ECDSA) - val (privKey2, pubKey2) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, pubKey1) = generateKeyPair(KeyType.ECDSA) + val (privKey2, pubKey2) = generateKeyPair(KeyType.ECDSA) var bb1: ByteBuf? = null var bb2: ByteBuf? = null diff --git a/libp2p/src/test/kotlin/io/libp2p/security/tls/TlsSecureChannelTest.kt b/libp2p/src/test/kotlin/io/libp2p/security/tls/TlsSecureChannelTest.kt index 1d5fe5ed5..128dbe4e9 100644 --- a/libp2p/src/test/kotlin/io/libp2p/security/tls/TlsSecureChannelTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/security/tls/TlsSecureChannelTest.kt @@ -1,7 +1,7 @@ package io.libp2p.security.tls import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multistream.MultistreamProtocolDebug import io.libp2p.core.mux.StreamMuxerProtocol @@ -19,14 +19,14 @@ val MultistreamProtocolV1: MultistreamProtocolDebug = MultistreamProtocolDebugV1 @Tag("secure-channel") class TlsSecureChannelTest : SecureChannelTestBase( ::TlsSecureChannel, - listOf(StreamMuxerProtocol.Yamux.createMuxer(MultistreamProtocolV1, listOf())), + listOf(StreamMuxerProtocol.getYamux().createMuxer(MultistreamProtocolV1, listOf())), TlsSecureChannel.announce ) { @Test fun `incorrect initiator remote PeerId should throw`() { - val (privKey1, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (privKey2, _) = generateKeyPair(KEY_TYPE.ECDSA) - val (_, wrongPubKey) = generateKeyPair(KEY_TYPE.ECDSA) + val (privKey1, _) = generateKeyPair(KeyType.ECDSA) + val (privKey2, _) = generateKeyPair(KeyType.ECDSA) + val (_, wrongPubKey) = generateKeyPair(KeyType.ECDSA) val protocolSelect1 = makeSelector(privKey1, muxerIds) val protocolSelect2 = makeSelector(privKey2, muxerIds) diff --git a/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt b/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt index e73577448..834f2b20f 100644 --- a/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt @@ -16,10 +16,11 @@ class TcpTransportTest : TransportTests() { } override fun localAddress(portNumber: Int): Multiaddr { - return if (ip4DnsAvailable && (portNumber % 2 == 0)) + return if (ip4DnsAvailable && (portNumber % 2 == 0)) { Multiaddr("/dns4/localhost/tcp/$portNumber") - else + } else { Multiaddr("/ip4/127.0.0.1/tcp/$portNumber") + } } override fun badAddress(): Multiaddr = @@ -34,8 +35,7 @@ class TcpTransportTest : TransportTests() { "/ip4/0.0.0.0/tcp/1234", "/ip6/fe80::6f77:b303:aa6e:a16/tcp/42", "/dns4/localhost/tcp/9999", - "/dns6/localhost/tcp/9999", - "/dnsaddr/ipfs.io/tcp/97" + "/dns6/localhost/tcp/9999" ).map { Multiaddr(it) } @JvmStatic diff --git a/libp2p/src/test/kotlin/io/libp2p/transport/ws/WsTransportTest.kt b/libp2p/src/test/kotlin/io/libp2p/transport/ws/WsTransportTest.kt index a3f067610..9c7d42bd5 100644 --- a/libp2p/src/test/kotlin/io/libp2p/transport/ws/WsTransportTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/transport/ws/WsTransportTest.kt @@ -15,10 +15,11 @@ class WsTransportTest : TransportTests() { } // makeTransport override fun localAddress(portNumber: Int): Multiaddr { - return if (ip4DnsAvailable && (portNumber % 2 == 0)) + return if (ip4DnsAvailable && (portNumber % 2 == 0)) { Multiaddr("/dns4/localhost/tcp/$portNumber/ws") - else + } else { Multiaddr("/ip4/127.0.0.1/tcp/$portNumber/ws") + } } // localAddress override fun badAddress(): Multiaddr = diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/AsyncDaemonExecutor.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/AsyncDaemonExecutor.java index db9973351..7dbe50fb3 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/AsyncDaemonExecutor.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/AsyncDaemonExecutor.java @@ -1,50 +1,49 @@ package io.libp2p.tools.p2pd; import io.netty.channel.unix.DomainSocketAddress; - import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -/** - * Created by Anton Nashatyrev on 20.12.2018. - */ +/** Created by Anton Nashatyrev on 20.12.2018. */ public class AsyncDaemonExecutor { - private final SocketAddress address; - - public AsyncDaemonExecutor(SocketAddress address) { - this.address = address; - } - - public CompletableFuture executeWithDaemon( - Function> executor) { - CompletableFuture daemonFut = getDaemon(); - return daemonFut - .thenCompose(executor) - .whenComplete((r, t) -> { - if (!daemonFut.isCompletedExceptionally()) { - try { - daemonFut.get().close(); - } catch (Exception e) {} - } - }); + private final SocketAddress address; + + public AsyncDaemonExecutor(SocketAddress address) { + this.address = address; + } + + public CompletableFuture executeWithDaemon( + Function> executor) { + CompletableFuture daemonFut = getDaemon(); + return daemonFut + .thenCompose(executor) + .whenComplete( + (r, t) -> { + if (!daemonFut.isCompletedExceptionally()) { + try { + daemonFut.get().close(); + } catch (Exception e) { + } + } + }); + } + + public CompletableFuture getDaemon() { + ControlConnector connector; + if (address instanceof InetSocketAddress) { + connector = new TCPControlConnector(); + } else if (address instanceof DomainSocketAddress) { + connector = new UnixSocketControlConnector(); + } else { + throw new IllegalArgumentException(); } - public CompletableFuture getDaemon() { - ControlConnector connector; - if (address instanceof InetSocketAddress) { - connector = new TCPControlConnector(); - } else if (address instanceof DomainSocketAddress) { - connector = new UnixSocketControlConnector(); - } else { - throw new IllegalArgumentException(); - } + return connector.connect(address); + } - return connector.connect(address); - } - - public SocketAddress getAddress() { - return address; - } + public SocketAddress getAddress() { + return address; + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/ControlConnector.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/ControlConnector.java index 0f2e12209..fadf1ffa0 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/ControlConnector.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/ControlConnector.java @@ -6,50 +6,52 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.SimpleChannelInboundHandler; - import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -/** - * Created by Anton Nashatyrev on 13.12.2018. - */ +/** Created by Anton Nashatyrev on 13.12.2018. */ public abstract class ControlConnector { - protected final int connectTimeoutSec = 5; + protected final int connectTimeoutSec = 5; - public abstract CompletableFuture connect(SocketAddress addr); + public abstract CompletableFuture connect(SocketAddress addr); - public abstract ChannelFuture listen(SocketAddress addr, Consumer handlersConsumer); + public abstract ChannelFuture listen( + SocketAddress addr, Consumer handlersConsumer); - protected static class ChannelInit extends ChannelInitializer { - private final Consumer handlersConsumer; - private final boolean initiator; + protected static class ChannelInit extends ChannelInitializer { + private final Consumer handlersConsumer; + private final boolean initiator; - public ChannelInit(Consumer handlersConsumer, boolean initiator) { - this.handlersConsumer = handlersConsumer; - this.initiator = initiator; - } + public ChannelInit(Consumer handlersConsumer, boolean initiator) { + this.handlersConsumer = handlersConsumer; + this.initiator = initiator; + } - @Override - protected void initChannel(Channel ch) throws Exception { - DaemonChannelHandler handler = new DaemonChannelHandler(ch, initiator); - ch.pipeline().addFirst(new SimpleChannelInboundHandler() { + @Override + protected void initChannel(Channel ch) throws Exception { + DaemonChannelHandler handler = new DaemonChannelHandler(ch, initiator); + ch.pipeline() + .addFirst( + new SimpleChannelInboundHandler() { @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { - handlersConsumer.accept(handler); - super.channelActive(ctx); + handlersConsumer.accept(handler); + super.channelActive(ctx); } @Override - protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { - handler.onData((ByteBuf) msg); + protected void channelRead0(ChannelHandlerContext ctx, Object msg) + throws Exception { + handler.onData((ByteBuf) msg); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - handler.onError(cause); + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + handler.onError(cause); } - }); - } + }); } + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonChannelHandler.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonChannelHandler.java index 391aa5a1c..ed0b74d9e 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonChannelHandler.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonChannelHandler.java @@ -12,8 +12,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; -import p2pd.pb.P2Pd; - import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; @@ -27,276 +25,281 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.function.Function; +import p2pd.pb.P2Pd; -/** - * Created by Anton Nashatyrev on 14.12.2018. - */ +/** Created by Anton Nashatyrev on 14.12.2018. */ public class DaemonChannelHandler implements Closeable, AutoCloseable { - private final Channel channel; - private final boolean isInitiator; - private Queue respBuildQueue = new ConcurrentLinkedQueue<>(); - private StreamHandler streamHandler; - private Stream stream; - private ByteBuf prevDataTail = Unpooled.buffer(0); - - public DaemonChannelHandler(Channel channel, boolean isInitiator) { - this.channel = channel; - this.isInitiator = isInitiator; - } - - public void setStreamHandler(StreamHandler streamHandler) { - this.streamHandler = streamHandler; - } - - void onData(ByteBuf msg) throws InvalidProtocolBufferException { - ByteBuf bytes = prevDataTail.isReadable() ? Unpooled.wrappedBuffer(prevDataTail, msg) : msg; - while (bytes.isReadable()) { - if (stream != null) { - streamHandler.onRead(bytes.nioBuffer()); - bytes.clear(); - break; - } else { - ResponseBuilder responseBuilder = respBuildQueue.peek(); - if (responseBuilder == null) { - throw new RuntimeException("Unexpected response message from p2pDaemon"); - } - - try { - ByteBuf bbDup = bytes.duplicate(); - InputStream is = new ByteBufInputStream(bbDup); - int msgLen = CodedInputStream.readRawVarint32(is.read(), is); - if (msgLen > bbDup.readableBytes()) { - break; - } - } catch (IOException e) { - throw new RuntimeException(e); - } - Action action = responseBuilder.parseNextMessage(bytes); - if (action != Action.ContinueResponse) { - respBuildQueue.poll(); - } - - if (action == Action.StartStream) { - P2Pd.StreamInfo resp = responseBuilder.getStreamInfo(); - MuxerAdress remoteAddr = new MuxerAdress(new Peer(resp.getPeer().toByteArray()), resp.getProto()); - MuxerAdress localAddr = MuxerAdress.listenAddress(resp.getProto()); - - stream = new NettyStream(channel, isInitiator, localAddr, remoteAddr); - streamHandler.onCreate(stream); - channel.closeFuture().addListener((ChannelFutureListener) future -> streamHandler.onClose()); - } - } + private final Channel channel; + private final boolean isInitiator; + private Queue respBuildQueue = new ConcurrentLinkedQueue<>(); + private StreamHandler streamHandler; + private Stream stream; + private ByteBuf prevDataTail = Unpooled.buffer(0); + + public DaemonChannelHandler(Channel channel, boolean isInitiator) { + this.channel = channel; + this.isInitiator = isInitiator; + } + + public void setStreamHandler(StreamHandler streamHandler) { + this.streamHandler = streamHandler; + } + + void onData(ByteBuf msg) throws InvalidProtocolBufferException { + ByteBuf bytes = prevDataTail.isReadable() ? Unpooled.wrappedBuffer(prevDataTail, msg) : msg; + while (bytes.isReadable()) { + if (stream != null) { + streamHandler.onRead(bytes.nioBuffer()); + bytes.clear(); + break; + } else { + ResponseBuilder responseBuilder = respBuildQueue.peek(); + if (responseBuilder == null) { + throw new RuntimeException("Unexpected response message from p2pDaemon"); } - prevDataTail = Unpooled.wrappedBuffer(Util.byteBufToArray(bytes)); - } - void onError(Throwable t) { - streamHandler.onError(t); - } - - public CompletableFuture expectResponse( - ResponseBuilder responseBuilder) { - respBuildQueue.add(responseBuilder); - return responseBuilder.getResponse(); - } - - public CompletableFuture call(P2Pd.Request request, - ResponseBuilder responseBuilder) { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); try { - request.writeDelimitedTo(baos); + ByteBuf bbDup = bytes.duplicate(); + InputStream is = new ByteBufInputStream(bbDup); + int msgLen = CodedInputStream.readRawVarint32(is.read(), is); + if (msgLen > bbDup.readableBytes()) { + break; + } } catch (IOException e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } - byte[] msgBytes = baos.toByteArray(); - ByteBuf buffer = channel.alloc().buffer(msgBytes.length).writeBytes(msgBytes); - CompletableFuture ret = expectResponse(responseBuilder); - ChannelFuture channelFuture = channel.writeAndFlush(buffer); - - try { - channelFuture.get(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } catch (ExecutionException e) { - throw new RuntimeException(e); + Action action = responseBuilder.parseNextMessage(bytes); + if (action != Action.ContinueResponse) { + respBuildQueue.poll(); } - return ret; + if (action == Action.StartStream) { + P2Pd.StreamInfo resp = responseBuilder.getStreamInfo(); + MuxerAdress remoteAddr = + new MuxerAdress(new Peer(resp.getPeer().toByteArray()), resp.getProto()); + MuxerAdress localAddr = MuxerAdress.listenAddress(resp.getProto()); + + stream = new NettyStream(channel, isInitiator, localAddr, remoteAddr); + streamHandler.onCreate(stream); + channel + .closeFuture() + .addListener((ChannelFutureListener) future -> streamHandler.onClose()); + } + } } - - public void close() { - channel.close(); + prevDataTail = Unpooled.wrappedBuffer(Util.byteBufToArray(bytes)); + } + + void onError(Throwable t) { + streamHandler.onError(t); + } + + public CompletableFuture expectResponse( + ResponseBuilder responseBuilder) { + respBuildQueue.add(responseBuilder); + return responseBuilder.getResponse(); + } + + public CompletableFuture call( + P2Pd.Request request, ResponseBuilder responseBuilder) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + request.writeDelimitedTo(baos); + } catch (IOException e) { + throw new RuntimeException(e); } - - @FunctionalInterface - public interface FunctionThrowable { - B apply(A arg) throws Exception; + byte[] msgBytes = baos.toByteArray(); + ByteBuf buffer = channel.alloc().buffer(msgBytes.length).writeBytes(msgBytes); + CompletableFuture ret = expectResponse(responseBuilder); + ChannelFuture channelFuture = channel.writeAndFlush(buffer); + + try { + channelFuture.get(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); } - private enum Action { - EndResponse, - ContinueResponse, - StartStream + return ret; + } + + public void close() { + channel.close(); + } + + @FunctionalInterface + public interface FunctionThrowable { + B apply(A arg) throws Exception; + } + + private enum Action { + EndResponse, + ContinueResponse, + StartStream + } + + public abstract static class ResponseBuilder { + protected boolean throwOnResponseError = true; + protected CompletableFuture respFuture = new CompletableFuture<>(); + + protected Action parseNextMessage(ByteBuf bytes) { + ByteBuf buf = bytes.duplicate(); + try { + return parseNextMessage(new ByteBufInputStream(bytes)); + } catch (Exception e) { + respFuture.completeExceptionally( + new RuntimeException("Error parsing message: " + (Util.byteBufToArray(buf)), e)); + return Action.EndResponse; + } } - public static abstract class ResponseBuilder { - protected boolean throwOnResponseError = true; - protected CompletableFuture respFuture = new CompletableFuture<>(); - - protected Action parseNextMessage(ByteBuf bytes) { - ByteBuf buf = bytes.duplicate(); - try { - return parseNextMessage(new ByteBufInputStream(bytes)); - } catch (Exception e) { - respFuture.completeExceptionally(new RuntimeException("Error parsing message: " - + (Util.byteBufToArray(buf)), e)); - return Action.EndResponse; - } - } - - abstract Action parseNextMessage(InputStream is) throws Exception; + abstract Action parseNextMessage(InputStream is) throws Exception; - CompletableFuture getResponse() { - return respFuture; - } + CompletableFuture getResponse() { + return respFuture; + } - P2Pd.StreamInfo getStreamInfo() { - try { - TResponse resp = respFuture.get(); - if (resp instanceof P2Pd.Response) { - return ((P2Pd.Response) resp).getStreamInfo(); - } else { - return (P2Pd.StreamInfo) resp; - } - } catch (Exception e) { - throw new RuntimeException(e); - } + P2Pd.StreamInfo getStreamInfo() { + try { + TResponse resp = respFuture.get(); + if (resp instanceof P2Pd.Response) { + return ((P2Pd.Response) resp).getStreamInfo(); + } else { + return (P2Pd.StreamInfo) resp; } + } catch (Exception e) { + throw new RuntimeException(e); + } } + } - public static class SingleMsgResponseBuilder extends ResponseBuilder{ - FunctionThrowable parser; + public static class SingleMsgResponseBuilder extends ResponseBuilder { + FunctionThrowable parser; - public SingleMsgResponseBuilder(FunctionThrowable parser) { - this.parser = parser; - } + public SingleMsgResponseBuilder(FunctionThrowable parser) { + this.parser = parser; + } - @Override - Action parseNextMessage(InputStream is) { - try { - TResponse response = parser.apply(is); - if (throwOnResponseError && response instanceof P2Pd.Response && - ((P2Pd.Response) response).getType() == P2Pd.Response.Type.ERROR) { - throw new P2PDError(((P2Pd.Response) response).getError().toString()); - } else { - respFuture.complete(response); - } - } catch (Exception e) { - respFuture.completeExceptionally(e); - } - return Action.EndResponse; + @Override + Action parseNextMessage(InputStream is) { + try { + TResponse response = parser.apply(is); + if (throwOnResponseError + && response instanceof P2Pd.Response + && ((P2Pd.Response) response).getType() == P2Pd.Response.Type.ERROR) { + throw new P2PDError(((P2Pd.Response) response).getError().toString()); + } else { + respFuture.complete(response); } + } catch (Exception e) { + respFuture.completeExceptionally(e); + } + return Action.EndResponse; + } - CompletableFuture getResponse() { - return respFuture; - } + CompletableFuture getResponse() { + return respFuture; } + } - public static class SimpleResponseBuilder extends SingleMsgResponseBuilder { - public SimpleResponseBuilder() { - super(P2Pd.Response::parseDelimitedFrom); - } + public static class SimpleResponseBuilder extends SingleMsgResponseBuilder { + public SimpleResponseBuilder() { + super(P2Pd.Response::parseDelimitedFrom); } + } - public static class ListenerStreamBuilder extends SingleMsgResponseBuilder { - public ListenerStreamBuilder() { - super(P2Pd.StreamInfo::parseDelimitedFrom); - } - @Override - protected Action parseNextMessage(ByteBuf bytes) { - super.parseNextMessage(bytes); - return Action.StartStream; - } + public static class ListenerStreamBuilder extends SingleMsgResponseBuilder { + public ListenerStreamBuilder() { + super(P2Pd.StreamInfo::parseDelimitedFrom); } - public static class SimpleResponseStreamBuilder extends SingleMsgResponseBuilder { - public SimpleResponseStreamBuilder() { - super(P2Pd.Response::parseDelimitedFrom); - } + @Override + protected Action parseNextMessage(ByteBuf bytes) { + super.parseNextMessage(bytes); + return Action.StartStream; + } + } - @Override - protected Action parseNextMessage(ByteBuf bytes) { - super.parseNextMessage(bytes); - try { - if (getResponse().get().getType() == P2Pd.Response.Type.OK) { - return Action.StartStream; - } else { - return Action.EndResponse; - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } + public static class SimpleResponseStreamBuilder extends SingleMsgResponseBuilder { + public SimpleResponseStreamBuilder() { + super(P2Pd.Response::parseDelimitedFrom); } - public static class DHTListResponse extends ResponseBuilder> { - private final List items = new ArrayList<>(); - private boolean started; - @Override - Action parseNextMessage(InputStream is) throws Exception { - if (!started) { - P2Pd.Response response = P2Pd.Response.parseDelimitedFrom(is); - if (response.getType() == P2Pd.Response.Type.ERROR) { - throw new P2PDError("" + response.getError()); - } else { - if (!response.hasDht() || response.getDht().getType() != P2Pd.DHTResponse.Type.BEGIN) { - throw new RuntimeException("Invalid DHT list start message: " + response); - } - started = true; - return Action.ContinueResponse; - } - } else { - P2Pd.DHTResponse response = P2Pd.DHTResponse.parseDelimitedFrom(is); - if (response.getType() == P2Pd.DHTResponse.Type.END) { - respFuture.complete(items); - return Action.EndResponse; - } else if (response.getType() == P2Pd.DHTResponse.Type.VALUE) { - items.add(response); - return Action.ContinueResponse; - } else { - throw new RuntimeException("Invalid DHT list message: " + response); - } - } + @Override + protected Action parseNextMessage(ByteBuf bytes) { + super.parseNextMessage(bytes); + try { + if (getResponse().get().getType() == P2Pd.Response.Type.OK) { + return Action.StartStream; + } else { + return Action.EndResponse; } + } catch (Exception e) { + throw new RuntimeException(e); + } } + } + + public static class DHTListResponse extends ResponseBuilder> { + private final List items = new ArrayList<>(); + private boolean started; + + @Override + Action parseNextMessage(InputStream is) throws Exception { + if (!started) { + P2Pd.Response response = P2Pd.Response.parseDelimitedFrom(is); + if (response.getType() == P2Pd.Response.Type.ERROR) { + throw new P2PDError("" + response.getError()); + } else { + if (!response.hasDht() || response.getDht().getType() != P2Pd.DHTResponse.Type.BEGIN) { + throw new RuntimeException("Invalid DHT list start message: " + response); + } + started = true; + return Action.ContinueResponse; + } + } else { + P2Pd.DHTResponse response = P2Pd.DHTResponse.parseDelimitedFrom(is); + if (response.getType() == P2Pd.DHTResponse.Type.END) { + respFuture.complete(items); + return Action.EndResponse; + } else if (response.getType() == P2Pd.DHTResponse.Type.VALUE) { + items.add(response); + return Action.ContinueResponse; + } else { + throw new RuntimeException("Invalid DHT list message: " + response); + } + } + } + } - public static class UnboundMessagesResponse extends ResponseBuilder> { - private final BlockingQueue items = new LinkedBlockingQueue<>(); - private final Function decoder; - private boolean started; + public static class UnboundMessagesResponse + extends ResponseBuilder> { + private final BlockingQueue items = new LinkedBlockingQueue<>(); + private final Function decoder; + private boolean started; - public UnboundMessagesResponse(Function decoder) { - this.decoder = decoder; - } + public UnboundMessagesResponse(Function decoder) { + this.decoder = decoder; + } - @Override - Action parseNextMessage(InputStream is) throws Exception { - if (!started) { - P2Pd.Response response = P2Pd.Response.parseDelimitedFrom(is); - if (response.getType() == P2Pd.Response.Type.ERROR) { - throw new P2PDError("" + response.getError()); - } else { - respFuture.complete(items); - started = true; - return Action.ContinueResponse; - } - } else { - MessageT message = decoder.apply(is); - items.add(message); - return Action.ContinueResponse; - } + @Override + Action parseNextMessage(InputStream is) throws Exception { + if (!started) { + P2Pd.Response response = P2Pd.Response.parseDelimitedFrom(is); + if (response.getType() == P2Pd.Response.Type.ERROR) { + throw new P2PDError("" + response.getError()); + } else { + respFuture.complete(items); + started = true; + return Action.ContinueResponse; } + } else { + MessageT message = decoder.apply(is); + items.add(message); + return Action.ContinueResponse; + } } - + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonLauncher.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonLauncher.java index 472cc9dc9..20ce60dc3 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonLauncher.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/DaemonLauncher.java @@ -7,41 +7,41 @@ public class DaemonLauncher { - public static class Daemon { - public final P2PDHost host; - private final Process process; - - public Daemon(P2PDHost host, Process process) { - this.host = host; - this.process = process; - } - - public void kill() { - process.destroyForcibly(); - } - } - - private final String daemonPath; - private int commandPort = 11111; + public static class Daemon { + public final P2PDHost host; + private final Process process; - public DaemonLauncher(String daemonPath) { - this.daemonPath = daemonPath; + public Daemon(P2PDHost host, Process process) { + this.host = host; + this.process = process; } - public Daemon launch(int nodePort, String ... commandLineArgs) { - ArrayList args = new ArrayList<>(); - int cmdPort = commandPort++; - args.add(daemonPath); - args.add("-listen"); - args.add("/ip4/127.0.0.1/tcp/" + cmdPort); - args.add("-hostAddrs"); - args.add("/ip4/127.0.0.1/tcp/" + nodePort); - args.addAll(Arrays.asList(commandLineArgs)); - try { - Process process = new ProcessBuilder(args).inheritIO().start(); - return new Daemon(new P2PDHost(new InetSocketAddress("127.0.0.1", cmdPort)), process); - } catch (IOException e) { - throw new RuntimeException(e); - } + public void kill() { + process.destroyForcibly(); + } + } + + private final String daemonPath; + private int commandPort = 11111; + + public DaemonLauncher(String daemonPath) { + this.daemonPath = daemonPath; + } + + public Daemon launch(int nodePort, String... commandLineArgs) { + ArrayList args = new ArrayList<>(); + int cmdPort = commandPort++; + args.add(daemonPath); + args.add("-listen"); + args.add("/ip4/127.0.0.1/tcp/" + cmdPort); + args.add("-hostAddrs"); + args.add("/ip4/127.0.0.1/tcp/" + nodePort); + args.addAll(Arrays.asList(commandLineArgs)); + try { + Process process = new ProcessBuilder(args).inheritIO().start(); + return new Daemon(new P2PDHost(new InetSocketAddress("127.0.0.1", cmdPort)), process); + } catch (IOException e) { + throw new RuntimeException(e); } + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/NettyStream.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/NettyStream.java index b981ab62e..e142034ab 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/NettyStream.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/NettyStream.java @@ -4,64 +4,66 @@ import io.libp2p.tools.p2pd.libp2pj.Stream; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; - import java.nio.ByteBuffer; -/** - * Created by Anton Nashatyrev on 14.12.2018. - */ +/** Created by Anton Nashatyrev on 14.12.2018. */ public class NettyStream implements Stream { - private final Channel channel; - private final boolean initiator; - private final Muxer.MuxerAdress localAddress; - private final Muxer.MuxerAdress remoteAddress; + private final Channel channel; + private final boolean initiator; + private final Muxer.MuxerAdress localAddress; + private final Muxer.MuxerAdress remoteAddress; - public NettyStream(Channel channel, boolean initiator, - Muxer.MuxerAdress localAddress, - Muxer.MuxerAdress remoteAddress) { - this.channel = channel; - this.initiator = initiator; - this.localAddress = localAddress; - this.remoteAddress = remoteAddress; - } + public NettyStream( + Channel channel, + boolean initiator, + Muxer.MuxerAdress localAddress, + Muxer.MuxerAdress remoteAddress) { + this.channel = channel; + this.initiator = initiator; + this.localAddress = localAddress; + this.remoteAddress = remoteAddress; + } - public NettyStream(Channel channel, boolean initiator) { - this(channel, initiator, null, null); - } + public NettyStream(Channel channel, boolean initiator) { + this(channel, initiator, null, null); + } - @Override - public void write(ByteBuffer data) { - channel.write(Unpooled.wrappedBuffer(data)); - } + @Override + public void write(ByteBuffer data) { + channel.write(Unpooled.wrappedBuffer(data)); + } - @Override - public void flush() { - channel.flush(); - } + @Override + public void flush() { + channel.flush(); + } - @Override - public boolean isInitiator() { - return initiator; - } + @Override + public boolean isInitiator() { + return initiator; + } - @Override - public void close() { - channel.close(); - } + @Override + public void close() { + channel.close(); + } - @Override - public Muxer.MuxerAdress getRemoteAddress() { - return remoteAddress; - } + @Override + public Muxer.MuxerAdress getRemoteAddress() { + return remoteAddress; + } - @Override - public Muxer.MuxerAdress getLocalAddress() { - return localAddress; - } + @Override + public Muxer.MuxerAdress getLocalAddress() { + return localAddress; + } - @Override - public String toString() { - return "NettyStream{" + getLocalAddress() + (isInitiator() ? " -> " : " <- ") + getRemoteAddress(); - } + @Override + public String toString() { + return "NettyStream{" + + getLocalAddress() + + (isInitiator() ? " -> " : " <- ") + + getRemoteAddress(); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDDht.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDDht.java index 625a892ad..94300d761 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDDht.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDDht.java @@ -6,150 +6,175 @@ import io.libp2p.tools.p2pd.libp2pj.Peer; import io.libp2p.tools.p2pd.libp2pj.PeerInfo; import io.libp2p.tools.p2pd.libp2pj.util.Cid; -import p2pd.pb.P2Pd; - import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import p2pd.pb.P2Pd; -/** - * Created by Anton Nashatyrev on 20.12.2018. - */ +/** Created by Anton Nashatyrev on 20.12.2018. */ public class P2PDDht implements DHT { - private final AsyncDaemonExecutor daemonExecutor; - - public P2PDDht(AsyncDaemonExecutor daemonExecutor) { - this.daemonExecutor = daemonExecutor; - } - - @Override - public CompletableFuture findPeer(Peer peerId) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.FIND_PEER) - .setPeer(ByteString.copyFrom(peerId.getIdBytes()))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> fromResp(r.getDht().getPeer())); + private final AsyncDaemonExecutor daemonExecutor; + + public P2PDDht(AsyncDaemonExecutor daemonExecutor) { + this.daemonExecutor = daemonExecutor; + } + + @Override + public CompletableFuture findPeer(Peer peerId) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.FIND_PEER) + .setPeer(ByteString.copyFrom(peerId.getIdBytes()))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> fromResp(r.getDht().getPeer())); }); - } - - @Override - public CompletableFuture> findPeersConnectedToPeer(Peer peerId) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture> resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.FIND_PEERS_CONNECTED_TO_PEER) - .setPeer(ByteString.copyFrom(peerId.getIdBytes()))) - , new DaemonChannelHandler.DHTListResponse()); - - return resp.thenApply(list -> - list.stream().map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); + } + + @Override + public CompletableFuture> findPeersConnectedToPeer(Peer peerId) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture> resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.FIND_PEERS_CONNECTED_TO_PEER) + .setPeer(ByteString.copyFrom(peerId.getIdBytes()))), + new DaemonChannelHandler.DHTListResponse()); + + return resp.thenApply( + list -> list.stream().map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); }); - } - - @Override - public CompletableFuture> findProviders(Cid cid, int maxRetCount) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture> resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.FIND_PROVIDERS) - .setCid(ByteString.copyFrom(cid.toBytes()))) - , new DaemonChannelHandler.DHTListResponse()); - - return resp.thenApply(list ->list.stream() - .map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); + } + + @Override + public CompletableFuture> findProviders(Cid cid, int maxRetCount) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture> resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.FIND_PROVIDERS) + .setCid(ByteString.copyFrom(cid.toBytes()))), + new DaemonChannelHandler.DHTListResponse()); + + return resp.thenApply( + list -> list.stream().map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); }); - } - - @Override - public CompletableFuture> getClosestPeers(byte[] key) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture> resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.GET_CLOSEST_PEERS) - .setKey(ByteString.copyFrom(key))) - , new DaemonChannelHandler.DHTListResponse()); - - return resp.thenApply(list ->list.stream() - .map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); + } + + @Override + public CompletableFuture> getClosestPeers(byte[] key) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture> resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.GET_CLOSEST_PEERS) + .setKey(ByteString.copyFrom(key))), + new DaemonChannelHandler.DHTListResponse()); + + return resp.thenApply( + list -> list.stream().map(pi -> fromResp(pi.getPeer())).collect(Collectors.toList())); }); - } - - @Override - public CompletableFuture getPublicKey(Peer peerId) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.GET_PUBLIC_KEY) - .setPeer(ByteString.copyFrom(peerId.getIdBytes()))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> r.getDht().getValue().toByteArray()); + } + + @Override + public CompletableFuture getPublicKey(Peer peerId) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.GET_PUBLIC_KEY) + .setPeer(ByteString.copyFrom(peerId.getIdBytes()))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> r.getDht().getValue().toByteArray()); }); - } - - @Override - public CompletableFuture getValue(byte[] key) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.GET_VALUE) - .setKey(ByteString.copyFrom(key))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> r.getDht().getValue().toByteArray()); + } + + @Override + public CompletableFuture getValue(byte[] key) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.GET_VALUE) + .setKey(ByteString.copyFrom(key))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> r.getDht().getValue().toByteArray()); }); - } - - @Override - public CompletableFuture> searchValue(byte[] key) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture> resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.SEARCH_VALUE) - .setKey(ByteString.copyFrom(key))) - , new DaemonChannelHandler.DHTListResponse()); - - return resp.thenApply(list ->list.stream() - .map(val -> val.getValue().toByteArray()).collect(Collectors.toList())); + } + + @Override + public CompletableFuture> searchValue(byte[] key) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture> resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.SEARCH_VALUE) + .setKey(ByteString.copyFrom(key))), + new DaemonChannelHandler.DHTListResponse()); + + return resp.thenApply( + list -> + list.stream() + .map(val -> val.getValue().toByteArray()) + .collect(Collectors.toList())); }); - } - - @Override - public CompletableFuture putValue(byte[] key, byte[] value) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.PUT_VALUE) - .setKey(ByteString.copyFrom(key)) - .setValue(ByteString.copyFrom(value))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> null); + } + + @Override + public CompletableFuture putValue(byte[] key, byte[] value) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.PUT_VALUE) + .setKey(ByteString.copyFrom(key)) + .setValue(ByteString.copyFrom(value))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> null); }); - } - - @Override - public CompletableFuture provide(Cid cid) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newDhtRequest(P2Pd.DHTRequest.newBuilder() - .setType(P2Pd.DHTRequest.Type.PROVIDE) - .setCid(ByteString.copyFrom(cid.toBytes()))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> null); + } + + @Override + public CompletableFuture provide(Cid cid) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newDhtRequest( + P2Pd.DHTRequest.newBuilder() + .setType(P2Pd.DHTRequest.Type.PROVIDE) + .setCid(ByteString.copyFrom(cid.toBytes()))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> null); }); - } - - private static PeerInfo fromResp(P2Pd.PeerInfo pi) { - return new PeerInfo(new Peer(pi.getId().toByteArray()), - pi.getAddrsList().stream() - .map(addr -> Multiaddr.deserialize(addr.toByteArray())) - .collect(Collectors.toList())); - } - - private static P2Pd.Request newDhtRequest(P2Pd.DHTRequest.Builder dht) { - return P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.DHT) - .setDht(dht) - .build(); - } + } + + private static PeerInfo fromResp(P2Pd.PeerInfo pi) { + return new PeerInfo( + new Peer(pi.getId().toByteArray()), + pi.getAddrsList().stream() + .map(addr -> Multiaddr.deserialize(addr.toByteArray())) + .collect(Collectors.toList())); + } + + private static P2Pd.Request newDhtRequest(P2Pd.DHTRequest.Builder dht) { + return P2Pd.Request.newBuilder().setType(P2Pd.Request.Type.DHT).setDht(dht).build(); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDError.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDError.java index bf31ea3f3..bf519774c 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDError.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDError.java @@ -1,10 +1,8 @@ package io.libp2p.tools.p2pd; -/** - * Created by Anton Nashatyrev on 14.12.2018. - */ +/** Created by Anton Nashatyrev on 14.12.2018. */ public class P2PDError extends RuntimeException { - public P2PDError(String message) { - super(message); - } + public P2PDError(String message) { + super(message); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDHost.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDHost.java index 0a1076972..816ae6cf5 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDHost.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDHost.java @@ -10,8 +10,6 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.unix.DomainSocketAddress; -import p2pd.pb.P2Pd; - import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; @@ -22,174 +20,203 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; +import p2pd.pb.P2Pd; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public class P2PDHost implements Host { - private AsyncDaemonExecutor daemonExecutor; + private AsyncDaemonExecutor daemonExecutor; - private final int requestTimeoutSec = 5; + private final int requestTimeoutSec = 5; - public static P2PDHost createDefaultDomainSocket() { - return new P2PDHost(new DomainSocketAddress("/tmp/p2pd.sock")); - } + public static P2PDHost createDefaultDomainSocket() { + return new P2PDHost(new DomainSocketAddress("/tmp/p2pd.sock")); + } - public P2PDHost(SocketAddress addr) { - daemonExecutor = new AsyncDaemonExecutor(addr); - } + public P2PDHost(SocketAddress addr) { + daemonExecutor = new AsyncDaemonExecutor(addr); + } - @Override - public DHT getDht() { - return new P2PDDht(daemonExecutor); - } + @Override + public DHT getDht() { + return new P2PDDht(daemonExecutor); + } - public P2PDPubsub getPubsub() { - return new P2PDPubsub(daemonExecutor); - } + public P2PDPubsub getPubsub() { + return new P2PDPubsub(daemonExecutor); + } - @Override - public Peer getMyId() { - try { - return new Peer(identify().get().getId().toByteArray()); - } catch (Exception e) { - throw new RuntimeException(e); - } + @Override + public Peer getMyId() { + try { + return new Peer(identify().get().getId().toByteArray()); + } catch (Exception e) { + throw new RuntimeException(e); } - - @Override - public List getListenAddresses() { - try { - return identify().get().getAddrsList().stream() - .map(bs -> Multiaddr.deserialize(bs.toByteArray())) - .collect(Collectors.toList()); - } catch (Exception e) { - throw new RuntimeException(e); - } + } + + @Override + public List getListenAddresses() { + try { + return identify().get().getAddrsList().stream() + .map(bs -> Multiaddr.deserialize(bs.toByteArray())) + .collect(Collectors.toList()); + } catch (Exception e) { + throw new RuntimeException(e); } - - private CompletableFuture identify() { - return daemonExecutor.executeWithDaemon(h -> - h.call(P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.IDENTIFY) - .build(), new DaemonChannelHandler.SimpleResponseBuilder()) - .thenApply(P2Pd.Response::getIdentify) - ); - } - - @Override - public CompletableFuture connect(List peerAddresses, Peer peerId) { - return daemonExecutor.executeWithDaemon(handler -> { - CompletableFuture resp = handler.call(P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.CONNECT) - .setConnect(P2Pd.ConnectRequest.newBuilder() - .setPeer(ByteString.copyFrom(peerId.getIdBytes())) - .addAllAddrs(peerAddresses.stream() - .map(addr -> ByteString.copyFrom(addr.serialize())) - .collect(Collectors.toList())) - .setTimeout(requestTimeoutSec) - .build() - ).build(), - new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> null); + } + + private CompletableFuture identify() { + return daemonExecutor.executeWithDaemon( + h -> + h.call( + P2Pd.Request.newBuilder().setType(P2Pd.Request.Type.IDENTIFY).build(), + new DaemonChannelHandler.SimpleResponseBuilder()) + .thenApply(P2Pd.Response::getIdentify)); + } + + @Override + public CompletableFuture connect(List peerAddresses, Peer peerId) { + return daemonExecutor.executeWithDaemon( + handler -> { + CompletableFuture resp = + handler.call( + P2Pd.Request.newBuilder() + .setType(P2Pd.Request.Type.CONNECT) + .setConnect( + P2Pd.ConnectRequest.newBuilder() + .setPeer(ByteString.copyFrom(peerId.getIdBytes())) + .addAllAddrs( + peerAddresses.stream() + .map(addr -> ByteString.copyFrom(addr.serialize())) + .collect(Collectors.toList())) + .setTimeout(requestTimeoutSec) + .build()) + .build(), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> null); }); - } - - private final List activeChannels = new Vector<>(); - private final AtomicInteger counter = new AtomicInteger(); - - @Override - public CompletableFuture dial(MuxerAdress muxerAdress, StreamHandler streamHandler) { - try { - return daemonExecutor.getDaemon().thenCompose(handler -> { + } + + private final List activeChannels = new Vector<>(); + private final AtomicInteger counter = new AtomicInteger(); + + @Override + public CompletableFuture dial( + MuxerAdress muxerAdress, StreamHandler streamHandler) { + try { + return daemonExecutor + .getDaemon() + .thenCompose( + handler -> { try { - handler.setStreamHandler(new StreamHandlerWrapper<>(streamHandler) - .onCreate(s -> activeChannels.add(handler)) - .onClose(() -> activeChannels.remove(handler)) - ); - CompletableFuture resp = handler.call(P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.STREAM_OPEN) - .setStreamOpen(P2Pd.StreamOpenRequest.newBuilder() - .setPeer(ByteString.copyFrom(muxerAdress.getPeer().getIdBytes())) - .addAllProto(muxerAdress.getProtocols().stream() - .map(Protocol::getName).collect(Collectors.toList())) - .setTimeout(requestTimeoutSec) - .build() - ).build(), - new DaemonChannelHandler.SimpleResponseStreamBuilder()); - return resp.whenComplete((r, t) -> { - if (t != null) { - streamHandler.onError(t); - handler.close(); - } - }).thenApply(r -> null); + handler.setStreamHandler( + new StreamHandlerWrapper<>(streamHandler) + .onCreate(s -> activeChannels.add(handler)) + .onClose(() -> activeChannels.remove(handler))); + CompletableFuture resp = + handler.call( + P2Pd.Request.newBuilder() + .setType(P2Pd.Request.Type.STREAM_OPEN) + .setStreamOpen( + P2Pd.StreamOpenRequest.newBuilder() + .setPeer( + ByteString.copyFrom(muxerAdress.getPeer().getIdBytes())) + .addAllProto( + muxerAdress.getProtocols().stream() + .map(Protocol::getName) + .collect(Collectors.toList())) + .setTimeout(requestTimeoutSec) + .build()) + .build(), + new DaemonChannelHandler.SimpleResponseStreamBuilder()); + return resp.whenComplete( + (r, t) -> { + if (t != null) { + streamHandler.onError(t); + handler.close(); + } + }) + .thenApply(r -> null); } catch (Exception e) { - handler.close(); - throw new RuntimeException(e); + handler.close(); + throw new RuntimeException(e); } - }); - } catch (Exception e) { - throw new RuntimeException(e); - } + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public CompletableFuture listen( + MuxerAdress muxerAdress, Supplier> handlerFactory) { + Multiaddr listenMultiaddr; + SocketAddress listenAddr; + ControlConnector connector; + if (daemonExecutor.getAddress() instanceof InetSocketAddress) { + connector = new TCPControlConnector(); + int port = 46666 + counter.incrementAndGet(); + listenAddr = new InetSocketAddress("127.0.0.1", port); + listenMultiaddr = new Multiaddr("/ip4/127.0.0.1/tcp/46666"); + } else if (daemonExecutor.getAddress() instanceof DomainSocketAddress) { + connector = new UnixSocketControlConnector(); + String path = "/tmp/p2pd.client." + counter.incrementAndGet(); + listenAddr = new DomainSocketAddress(path); + listenMultiaddr = new Multiaddr("/unix" + path); + } else { + throw new IllegalStateException(); } - @Override - public CompletableFuture listen(MuxerAdress muxerAdress, Supplier> handlerFactory) { - Multiaddr listenMultiaddr; - SocketAddress listenAddr; - ControlConnector connector; - if (daemonExecutor.getAddress() instanceof InetSocketAddress) { - connector = new TCPControlConnector(); - int port = 46666 + counter.incrementAndGet(); - listenAddr = new InetSocketAddress("127.0.0.1", port); - listenMultiaddr = new Multiaddr("/ip4/127.0.0.1/tcp/46666"); - } else if (daemonExecutor.getAddress() instanceof DomainSocketAddress) { - connector = new UnixSocketControlConnector(); - String path = "/tmp/p2pd.client." + counter.incrementAndGet(); - listenAddr = new DomainSocketAddress(path); - listenMultiaddr = new Multiaddr("/unix" + path); - } else { - throw new IllegalStateException(); - } - - ChannelFuture channelFuture = connector.listen(listenAddr, h -> { - StreamHandler streamHandler = handlerFactory.get(); - h.setStreamHandler(streamHandler); - CompletableFuture response = h.expectResponse(new DaemonChannelHandler.ListenerStreamBuilder()); - response.whenComplete((r, t) -> { - if (t != null) { - streamHandler.onError(t); - } + ChannelFuture channelFuture = + connector.listen( + listenAddr, + h -> { + StreamHandler streamHandler = handlerFactory.get(); + h.setStreamHandler(streamHandler); + CompletableFuture response = + h.expectResponse(new DaemonChannelHandler.ListenerStreamBuilder()); + response.whenComplete( + (r, t) -> { + if (t != null) { + streamHandler.onError(t); + } + }); }); - }); - channelFuture.addListener((ChannelFutureListener) - future -> activeChannels.add(() -> future.channel().close())); - - Closeable ret = () -> channelFuture.channel().close(); - return Util.channelFutureToJava(channelFuture) - .thenCompose(channel -> - daemonExecutor.executeWithDaemon(handler -> - handler.call(P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.STREAM_HANDLER) - .setStreamHandler(P2Pd.StreamHandlerRequest.newBuilder() - .setAddr(ByteString.copyFrom(listenMultiaddr.serialize())) - .addAllProto(muxerAdress.getProtocols().stream() - .map(Protocol::getName).collect(Collectors.toList())) - .build() - ).build(), - new DaemonChannelHandler.SimpleResponseBuilder()))) - .thenApply(resp1 -> ret); - } - - @Override - public void close() { - activeChannels.forEach(ch -> { - try { - ch.close(); - } catch (IOException e) { - e.printStackTrace(); - } + channelFuture.addListener( + (ChannelFutureListener) future -> activeChannels.add(() -> future.channel().close())); + + Closeable ret = () -> channelFuture.channel().close(); + return Util.channelFutureToJava(channelFuture) + .thenCompose( + channel -> + daemonExecutor.executeWithDaemon( + handler -> + handler.call( + P2Pd.Request.newBuilder() + .setType(P2Pd.Request.Type.STREAM_HANDLER) + .setStreamHandler( + P2Pd.StreamHandlerRequest.newBuilder() + .setAddr(ByteString.copyFrom(listenMultiaddr.serialize())) + .addAllProto( + muxerAdress.getProtocols().stream() + .map(Protocol::getName) + .collect(Collectors.toList())) + .build()) + .build(), + new DaemonChannelHandler.SimpleResponseBuilder()))) + .thenApply(resp1 -> ret); + } + + @Override + public void close() { + activeChannels.forEach( + ch -> { + try { + ch.close(); + } catch (IOException e) { + e.printStackTrace(); + } }); - } + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDPubsub.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDPubsub.java index d0986ad7b..156843248 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDPubsub.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/P2PDPubsub.java @@ -1,56 +1,56 @@ package io.libp2p.tools.p2pd; import com.google.protobuf.ByteString; -import p2pd.pb.P2Pd; - import java.io.IOException; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; +import p2pd.pb.P2Pd; -/** - * Created by Anton Nashatyrev on 20.12.2018. - */ +/** Created by Anton Nashatyrev on 20.12.2018. */ public class P2PDPubsub { - private final AsyncDaemonExecutor daemonExecutor; + private final AsyncDaemonExecutor daemonExecutor; - public P2PDPubsub(AsyncDaemonExecutor daemonExecutor) { - this.daemonExecutor = daemonExecutor; - } + public P2PDPubsub(AsyncDaemonExecutor daemonExecutor) { + this.daemonExecutor = daemonExecutor; + } - public CompletableFuture> subscribe(String topic) { - return daemonExecutor.getDaemon().thenCompose(h -> { - return h.call( - newPubsubRequest(P2Pd.PSRequest.newBuilder() - .setType(P2Pd.PSRequest.Type.SUBSCRIBE) - .setTopic(topic)) - , new DaemonChannelHandler.UnboundMessagesResponse<>( - is -> { - try { - return P2Pd.PSMessage.parseDelimitedFrom(is); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - )); - }); - } + public CompletableFuture> subscribe(String topic) { + return daemonExecutor + .getDaemon() + .thenCompose( + h -> { + return h.call( + newPubsubRequest( + P2Pd.PSRequest.newBuilder() + .setType(P2Pd.PSRequest.Type.SUBSCRIBE) + .setTopic(topic)), + new DaemonChannelHandler.UnboundMessagesResponse<>( + is -> { + try { + return P2Pd.PSMessage.parseDelimitedFrom(is); + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + }); + } - public CompletableFuture publish(String topic, byte[] data) { - return daemonExecutor.executeWithDaemon(h -> { - CompletableFuture resp = h.call( - newPubsubRequest(P2Pd.PSRequest.newBuilder() - .setType(P2Pd.PSRequest.Type.PUBLISH) - .setTopic(topic) - .setData(ByteString.copyFrom(data))) - , new DaemonChannelHandler.SimpleResponseBuilder()); - return resp.thenApply(r -> null); + public CompletableFuture publish(String topic, byte[] data) { + return daemonExecutor.executeWithDaemon( + h -> { + CompletableFuture resp = + h.call( + newPubsubRequest( + P2Pd.PSRequest.newBuilder() + .setType(P2Pd.PSRequest.Type.PUBLISH) + .setTopic(topic) + .setData(ByteString.copyFrom(data))), + new DaemonChannelHandler.SimpleResponseBuilder()); + return resp.thenApply(r -> null); }); - } + } - private static P2Pd.Request newPubsubRequest(P2Pd.PSRequest.Builder pubsub) { - return P2Pd.Request.newBuilder() - .setType(P2Pd.Request.Type.PUBSUB) - .setPubsub(pubsub) - .build(); - } + private static P2Pd.Request newPubsubRequest(P2Pd.PSRequest.Builder pubsub) { + return P2Pd.Request.newBuilder().setType(P2Pd.Request.Type.PUBSUB).setPubsub(pubsub).build(); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/StreamHandlerWrapper.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/StreamHandlerWrapper.java index 4fc3acd4d..d034a2e18 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/StreamHandlerWrapper.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/StreamHandlerWrapper.java @@ -2,55 +2,52 @@ import io.libp2p.tools.p2pd.libp2pj.Stream; import io.libp2p.tools.p2pd.libp2pj.StreamHandler; - import java.nio.ByteBuffer; import java.util.function.Consumer; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public class StreamHandlerWrapper implements StreamHandler { - private final StreamHandler delegate; - private Consumer> onCreateListener; - private Runnable onCloseListener; - - public StreamHandlerWrapper(StreamHandler delegate) { - this.delegate = delegate; - } - - public StreamHandlerWrapper onCreate(Consumer> listener) { - this.onCreateListener = listener; - return this; - } - - public StreamHandlerWrapper onClose(Runnable listener) { - this.onCloseListener = listener; - return this; + private final StreamHandler delegate; + private Consumer> onCreateListener; + private Runnable onCloseListener; + + public StreamHandlerWrapper(StreamHandler delegate) { + this.delegate = delegate; + } + + public StreamHandlerWrapper onCreate(Consumer> listener) { + this.onCreateListener = listener; + return this; + } + + public StreamHandlerWrapper onClose(Runnable listener) { + this.onCloseListener = listener; + return this; + } + + @Override + public void onCreate(Stream stream) { + delegate.onCreate(stream); + if (onCreateListener != null) { + onCreateListener.accept(stream); } - - @Override - public void onCreate(Stream stream) { - delegate.onCreate(stream); - if (onCreateListener != null) { - onCreateListener.accept(stream); - } - } - - @Override - public void onRead(ByteBuffer data) { - delegate.onRead(data); + } + + @Override + public void onRead(ByteBuffer data) { + delegate.onRead(data); + } + + @Override + public void onClose() { + delegate.onClose(); + if (onCloseListener != null) { + onCloseListener.run(); } + } - @Override - public void onClose() { - delegate.onClose(); - if (onCloseListener != null) { - onCloseListener.run(); - } - } - - @Override - public void onError(Throwable error) { - delegate.onError(error); - } + @Override + public void onError(Throwable error) { + delegate.onError(error); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/TCPControlConnector.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/TCPControlConnector.java index 4893f2cde..509870093 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/TCPControlConnector.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/TCPControlConnector.java @@ -8,47 +8,48 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; - import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -/** - * Created by Anton Nashatyrev on 06.08.2019. - */ +/** Created by Anton Nashatyrev on 06.08.2019. */ public class TCPControlConnector extends ControlConnector { - static NioEventLoopGroup group = new NioEventLoopGroup(); + static NioEventLoopGroup group = new NioEventLoopGroup(); + + public CompletableFuture connect(String host, int port) { + return connect(new InetSocketAddress(host, port)); + } - public CompletableFuture connect(String host, int port) { - return connect(new InetSocketAddress(host, port)); - } - public CompletableFuture connect(SocketAddress addr) { - CompletableFuture ret = new CompletableFuture<>(); + public CompletableFuture connect(SocketAddress addr) { + CompletableFuture ret = new CompletableFuture<>(); - ChannelFuture channelFuture = new Bootstrap() - .group(group) - .channel(NioSocketChannel.class) - .handler(new ChannelInit(ret::complete, true)) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) - .connect(addr); + ChannelFuture channelFuture = + new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInit(ret::complete, true)) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) + .connect(addr); - channelFuture.addListener((ChannelFutureListener) future -> { - try { + channelFuture.addListener( + (ChannelFutureListener) + future -> { + try { future.get(); - } catch (Exception e) { + } catch (Exception e) { ret.completeExceptionally(e); - } - }); - return ret; - } + } + }); + return ret; + } - public ChannelFuture listen(SocketAddress addr, Consumer handlersConsumer) { - return new ServerBootstrap() - .group(group) - .channel(NioServerSocketChannel.class) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) - .childHandler(new ChannelInit(handlersConsumer, false)) - .bind(addr); - } + public ChannelFuture listen(SocketAddress addr, Consumer handlersConsumer) { + return new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) + .childHandler(new ChannelInit(handlersConsumer, false)) + .bind(addr); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/UnixSocketControlConnector.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/UnixSocketControlConnector.java index edbd3e92c..16db42c4c 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/UnixSocketControlConnector.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/UnixSocketControlConnector.java @@ -10,55 +10,56 @@ import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerDomainSocketChannel; import io.netty.channel.unix.DomainSocketAddress; - import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -/** - * Created by Anton Nashatyrev on 06.08.2019. - */ +/** Created by Anton Nashatyrev on 06.08.2019. */ public class UnixSocketControlConnector extends ControlConnector { - protected static final EventLoopGroup group = new EpollEventLoopGroup(); + protected static final EventLoopGroup group = new EpollEventLoopGroup(); + + public CompletableFuture connect() { + return connect("/tmp/p2pd.sock"); + } - public CompletableFuture connect() { - return connect("/tmp/p2pd.sock"); - } + public CompletableFuture connect(String socketPath) { + return connect(new DomainSocketAddress(socketPath)); + } - public CompletableFuture connect(String socketPath) { - return connect(new DomainSocketAddress(socketPath)); - } - public CompletableFuture connect(SocketAddress addr) { - CompletableFuture ret = new CompletableFuture<>(); + public CompletableFuture connect(SocketAddress addr) { + CompletableFuture ret = new CompletableFuture<>(); - ChannelFuture channelFuture = new Bootstrap() - .group(group) - .channel(EpollDomainSocketChannel.class) - .handler(new ChannelInit(ret::complete, true)) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) - .connect(addr); + ChannelFuture channelFuture = + new Bootstrap() + .group(group) + .channel(EpollDomainSocketChannel.class) + .handler(new ChannelInit(ret::complete, true)) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) + .connect(addr); - channelFuture.addListener((ChannelFutureListener) future -> { - try { + channelFuture.addListener( + (ChannelFutureListener) + future -> { + try { future.get(); - } catch (Exception e) { + } catch (Exception e) { ret.completeExceptionally(e); - } - }); - return ret; - } + } + }); + return ret; + } - public ChannelFuture listen(String socketPath, Consumer handlersConsumer) { - return listen(new DomainSocketAddress(socketPath), handlersConsumer); - } + public ChannelFuture listen(String socketPath, Consumer handlersConsumer) { + return listen(new DomainSocketAddress(socketPath), handlersConsumer); + } - public ChannelFuture listen(SocketAddress addr, Consumer handlersConsumer) { + public ChannelFuture listen(SocketAddress addr, Consumer handlersConsumer) { - return new ServerBootstrap() - .group(group) - .channel(EpollServerDomainSocketChannel.class) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) - .childHandler(new ChannelInit(handlersConsumer, false)) - .bind(addr); - } + return new ServerBootstrap() + .group(group) + .channel(EpollServerDomainSocketChannel.class) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutSec * 1000) + .childHandler(new ChannelInit(handlersConsumer, false)) + .bind(addr); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/Util.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/Util.java index 049a65689..1121ca3e8 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/Util.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/Util.java @@ -8,70 +8,73 @@ import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.Promise; - import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public class Util { - public static Future futureFromJavaToNetty(CompletableFuture javaFut) { - Promise ret = ImmediateEventExecutor.INSTANCE.newPromise(); - javaFut.handle((v, t) -> { - if (t != null) ret.setFailure(t); - else ret.setSuccess(v); - return null; + public static Future futureFromJavaToNetty(CompletableFuture javaFut) { + Promise ret = ImmediateEventExecutor.INSTANCE.newPromise(); + javaFut.handle( + (v, t) -> { + if (t != null) ret.setFailure(t); + else ret.setSuccess(v); + return null; }); - return ret; - } + return ret; + } - public static CompletableFuture channelFutureToJava(ChannelFuture channelFuture) { - CompletableFuture ret = new CompletableFuture<>(); - channelFuture.addListener((ChannelFutureListener) future -> { - try { + public static CompletableFuture channelFutureToJava(ChannelFuture channelFuture) { + CompletableFuture ret = new CompletableFuture<>(); + channelFuture.addListener( + (ChannelFutureListener) + future -> { + try { future.get(); ret.complete(future.channel()); - } catch (Exception e) { + } catch (Exception e) { ret.completeExceptionally(e); - } - }); - return ret; - } + } + }); + return ret; + } - public static CompletableFuture futureFromNettyToJava(Future nettyFut) { - CompletableFuture ret = new CompletableFuture<>(); - addListener(nettyFut, ret::complete, ret::completeExceptionally); - return ret; - } + public static CompletableFuture futureFromNettyToJava(Future nettyFut) { + CompletableFuture ret = new CompletableFuture<>(); + addListener(nettyFut, ret::complete, ret::completeExceptionally); + return ret; + } - private static void addListener(Future future, Consumer success, Consumer error) { - if (future == null) return; + private static void addListener( + Future future, Consumer success, Consumer error) { + if (future == null) return; - future.addListener((GenericFutureListener>) f -> { - try { + future.addListener( + (GenericFutureListener>) + f -> { + try { V v = f.get(); if (success != null) success.accept(v); - } catch (Throwable e) { + } catch (Throwable e) { if (error != null) error.accept(e); - } - }); - } + } + }); + } - private static Promise newPromise() { - return ImmediateEventExecutor.INSTANCE.newPromise(); - } + private static Promise newPromise() { + return ImmediateEventExecutor.INSTANCE.newPromise(); + } - public static byte[] byteBufToArray(ByteBuf bb) { - byte[] ret = new byte[bb.readableBytes()]; - bb.readBytes(ret); - return ret; - } + public static byte[] byteBufToArray(ByteBuf bb) { + byte[] ret = new byte[bb.readableBytes()]; + bb.readBytes(ret); + return ret; + } - public static byte[] byteBufferToArray(ByteBuffer bb) { - byte[] ret = new byte[bb.remaining()]; - bb.get(ret); - return ret; - } + public static byte[] byteBufferToArray(ByteBuffer bb) { + byte[] ret = new byte[bb.remaining()]; + bb.get(ret); + return ret; + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Connector.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Connector.java index 974c5ad0b..468b5b94b 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Connector.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Connector.java @@ -4,14 +4,11 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public interface Connector { - CompletableFuture dial(TEndpoint address, - StreamHandler handler); + CompletableFuture dial(TEndpoint address, StreamHandler handler); - CompletableFuture listen(TEndpoint listenAddress, - Supplier> handlerFactory); + CompletableFuture listen( + TEndpoint listenAddress, Supplier> handlerFactory); } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/DHT.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/DHT.java index 6593149ea..1cdf75a7b 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/DHT.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/DHT.java @@ -1,29 +1,26 @@ package io.libp2p.tools.p2pd.libp2pj; import io.libp2p.tools.p2pd.libp2pj.util.Cid; - import java.util.List; import java.util.concurrent.CompletableFuture; -/** - * Created by Anton Nashatyrev on 21.12.2018. - */ +/** Created by Anton Nashatyrev on 21.12.2018. */ public interface DHT { - CompletableFuture findPeer(Peer peerId); + CompletableFuture findPeer(Peer peerId); - CompletableFuture> findPeersConnectedToPeer(Peer peerId); + CompletableFuture> findPeersConnectedToPeer(Peer peerId); - CompletableFuture> findProviders(Cid cid, int maxRetCount); + CompletableFuture> findProviders(Cid cid, int maxRetCount); - CompletableFuture> getClosestPeers(byte[] key); + CompletableFuture> getClosestPeers(byte[] key); - CompletableFuture getPublicKey(Peer peerId); + CompletableFuture getPublicKey(Peer peerId); - CompletableFuture getValue(byte[] key); + CompletableFuture getValue(byte[] key); - CompletableFuture> searchValue(byte[] key); + CompletableFuture> searchValue(byte[] key); - CompletableFuture putValue(byte[] key, byte[] value); + CompletableFuture putValue(byte[] key, byte[] value); - CompletableFuture provide(Cid cid); + CompletableFuture provide(Cid cid); } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Host.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Host.java index 44c595bda..025b7d40a 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Host.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Host.java @@ -1,37 +1,34 @@ package io.libp2p.tools.p2pd.libp2pj; import io.libp2p.core.multiformats.Multiaddr; - import java.io.Closeable; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public interface Host extends Muxer { - Peer getMyId(); + Peer getMyId(); - List getListenAddresses(); + List getListenAddresses(); - CompletableFuture connect(List peerAddresses, Peer peerId); + CompletableFuture connect(List peerAddresses, Peer peerId); - @Override - CompletableFuture dial(MuxerAdress muxerAdress, StreamHandler handler); + @Override + CompletableFuture dial(MuxerAdress muxerAdress, StreamHandler handler); - @Override - CompletableFuture listen(MuxerAdress muxerAdress, - Supplier> handlerFactory); + @Override + CompletableFuture listen( + MuxerAdress muxerAdress, Supplier> handlerFactory); - void close(); + void close(); - DHT getDht(); + DHT getDht(); -// Peerstore getPeerStore(); + // Peerstore getPeerStore(); -// Network getNetwork(); + // Network getNetwork(); -// ConnectionManager getConnectionManager(); + // ConnectionManager getConnectionManager(); } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Muxer.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Muxer.java index 7eb75ef0c..059353be9 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Muxer.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Muxer.java @@ -8,55 +8,52 @@ import java.util.function.Supplier; import java.util.stream.Collectors; -/** - * Created by Anton Nashatyrev on 10.12.2018. - */ +/** Created by Anton Nashatyrev on 10.12.2018. */ public interface Muxer extends Connector { - @Override - CompletableFuture dial(MuxerAdress muxerAdress, StreamHandler handler); + @Override + CompletableFuture dial(MuxerAdress muxerAdress, StreamHandler handler); + + @Override + CompletableFuture listen( + MuxerAdress muxerAdress, Supplier> handlerFactory); + + public class MuxerAdress { + public static MuxerAdress listenAddress(String... protocolNames) { + return new MuxerAdress(null, protocolNames); + } + + private final List protocols = new ArrayList<>(); + private final Peer peer; + + public MuxerAdress(Peer peer, String... protocolNames) { + this(Arrays.stream(protocolNames).map(Protocol::new).collect(Collectors.toList()), peer); + } + + public MuxerAdress(List protocols, Peer peer) { + this.protocols.addAll(protocols); + this.peer = peer; + } + + public MuxerAdress(Protocol protocol, Peer peer) { + this.protocols.add(protocol); + this.peer = peer; + } + + public List getProtocols() { + return protocols; + } + + public Peer getPeer() { + return peer; + } @Override - CompletableFuture listen(MuxerAdress muxerAdress, - Supplier> handlerFactory); - - public class MuxerAdress { - public static MuxerAdress listenAddress(String... protocolNames) { - return new MuxerAdress(null, protocolNames); - } - - private final List protocols = new ArrayList<>(); - private final Peer peer; - - public MuxerAdress(Peer peer, String... protocolNames) { - this(Arrays.stream(protocolNames) - .map(Protocol::new) - .collect(Collectors.toList()), peer); - } - - public MuxerAdress(List protocols, Peer peer) { - this.protocols.addAll(protocols); - this.peer = peer; - } - - public MuxerAdress(Protocol protocol, Peer peer) { - this.protocols.add(protocol); - this.peer = peer; - } - - public List getProtocols() { - return protocols; - } - - public Peer getPeer() { - return peer; - } - - @Override - public String toString() { - return peer + "[" + - protocols.stream().map(Protocol::toString).collect(Collectors.joining(",")) - + "]"; - } + public String toString() { + return peer + + "[" + + protocols.stream().map(Protocol::toString).collect(Collectors.joining(",")) + + "]"; } + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Peer.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Peer.java index d77de4455..576e00256 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Peer.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Peer.java @@ -1,29 +1,27 @@ package io.libp2p.tools.p2pd.libp2pj; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public class Peer { - private final byte[] id; + private final byte[] id; - public Peer(byte[] id) { - this.id = id; - } + public Peer(byte[] id) { + this.id = id; + } - public byte[] getIdBytes() { - return id; - } + public byte[] getIdBytes() { + return id; + } -// public String getIdBase58() { -// return Base58.encode(getIdBytes()); -// } -// -// public String getIdHexString() { -// return Hex.encodeHexString(getIdBytes()); -// } + // public String getIdBase58() { + // return Base58.encode(getIdBytes()); + // } + // + // public String getIdHexString() { + // return Hex.encodeHexString(getIdBytes()); + // } -// @Override -// public String toString() { -// return "Peer{" + "id=" + getIdBase58() + "}"; -// } + // @Override + // public String toString() { + // return "Peer{" + "id=" + getIdBase58() + "}"; + // } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/PeerInfo.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/PeerInfo.java index 999a530b9..2224ec8e8 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/PeerInfo.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/PeerInfo.java @@ -1,34 +1,28 @@ package io.libp2p.tools.p2pd.libp2pj; import io.libp2p.core.multiformats.Multiaddr; - import java.util.List; -/** - * Created by Anton Nashatyrev on 20.12.2018. - */ +/** Created by Anton Nashatyrev on 20.12.2018. */ public class PeerInfo { - private final Peer id; - private final List addresses; + private final Peer id; + private final List addresses; - public PeerInfo(Peer id, List addresses) { - this.id = id; - this.addresses = addresses; - } + public PeerInfo(Peer id, List addresses) { + this.id = id; + this.addresses = addresses; + } - public Peer getId() { - return id; - } + public Peer getId() { + return id; + } - public List getAddresses() { - return addresses; - } + public List getAddresses() { + return addresses; + } - @Override - public String toString() { - return "PeerInfo{" + - "id=" + id + - ", adresses=" + addresses + - '}'; - } + @Override + public String toString() { + return "PeerInfo{" + "id=" + id + ", adresses=" + addresses + '}'; + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Protocol.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Protocol.java index de0b946ea..b5a2c7327 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Protocol.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Protocol.java @@ -1,21 +1,19 @@ package io.libp2p.tools.p2pd.libp2pj; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public class Protocol { - private final String name; + private final String name; - public Protocol(String name) { - this.name = name; - } + public Protocol(String name) { + this.name = name; + } - public String getName() { - return name; - } + public String getName() { + return name; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Stream.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Stream.java index 623882c11..062d1dffd 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Stream.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Stream.java @@ -2,20 +2,18 @@ import java.nio.ByteBuffer; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public interface Stream { - boolean isInitiator(); + boolean isInitiator(); - void write(ByteBuffer data); + void write(ByteBuffer data); - void flush(); + void flush(); - void close(); + void close(); - TEndpoint getRemoteAddress(); + TEndpoint getRemoteAddress(); - TEndpoint getLocalAddress(); + TEndpoint getLocalAddress(); } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/StreamHandler.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/StreamHandler.java index aec2ea382..f503994f5 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/StreamHandler.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/StreamHandler.java @@ -2,16 +2,14 @@ import java.nio.ByteBuffer; -/** - * Created by Anton Nashatyrev on 18.12.2018. - */ +/** Created by Anton Nashatyrev on 18.12.2018. */ public interface StreamHandler { - void onCreate(Stream stream); + void onCreate(Stream stream); - void onRead(ByteBuffer data); + void onRead(ByteBuffer data); - void onClose(); + void onClose(); - void onError(Throwable error); + void onError(Throwable error); } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Transport.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Transport.java index b56f15adf..5a1818a22 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Transport.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/Transport.java @@ -1,34 +1,29 @@ package io.libp2p.tools.p2pd.libp2pj; import io.libp2p.core.multiformats.Multiaddr; - import java.io.Closeable; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; -/** - * Created by Anton Nashatyrev on 10.12.2018. - */ +/** Created by Anton Nashatyrev on 10.12.2018. */ public interface Transport extends Connector { + @Override + CompletableFuture dial(Multiaddr multiaddress, StreamHandler dialHandler); - @Override - CompletableFuture dial(Multiaddr multiaddress, - StreamHandler dialHandler); + @Override + CompletableFuture listen( + Multiaddr multiaddress, Supplier> handlerFactory); - @Override - CompletableFuture listen(Multiaddr multiaddress, - Supplier> handlerFactory); + interface Listener extends Closeable { - interface Listener extends Closeable { - - @Override - void close(); + @Override + void close(); - Multiaddr getLocalMultiaddress(); + Multiaddr getLocalMultiaddress(); - default CompletableFuture getPublicMultiaddress() { - return CompletableFuture.completedFuture(getLocalMultiaddress()); - } + default CompletableFuture getPublicMultiaddress() { + return CompletableFuture.completedFuture(getLocalMultiaddress()); } + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/MalformedMultiaddressException.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/MalformedMultiaddressException.java index dd015ae8f..2d738647b 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/MalformedMultiaddressException.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/MalformedMultiaddressException.java @@ -1,10 +1,8 @@ package io.libp2p.tools.p2pd.libp2pj.exceptions; -/** - * Created by Anton Nashatyrev on 11.12.2018. - */ +/** Created by Anton Nashatyrev on 11.12.2018. */ public class MalformedMultiaddressException extends RuntimeException { - public MalformedMultiaddressException(String message) { - super(message); - } + public MalformedMultiaddressException(String message) { + super(message); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/UnsupportedTransportException.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/UnsupportedTransportException.java index 0c952ba51..b6e904ccd 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/UnsupportedTransportException.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/exceptions/UnsupportedTransportException.java @@ -1,10 +1,8 @@ package io.libp2p.tools.p2pd.libp2pj.exceptions; -/** - * Created by Anton Nashatyrev on 11.12.2018. - */ +/** Created by Anton Nashatyrev on 11.12.2018. */ public class UnsupportedTransportException extends RuntimeException { - public UnsupportedTransportException(String message) { - super(message); - } + public UnsupportedTransportException(String message) { + super(message); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Base58.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Base58.java index 07fb17b6a..2e99af584 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Base58.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Base58.java @@ -4,170 +4,170 @@ public class Base58 { - private static final char[] ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" - .toCharArray(); - private static final int BASE_58 = ALPHABET.length; - private static final int BASE_256 = 256; - - private static final int[] INDEXES = new int[128]; - static { - for (int i = 0; i < INDEXES.length; i++) { - INDEXES[i] = -1; - } - for (int i = 0; i < ALPHABET.length; i++) { - INDEXES[ALPHABET[i]] = i; - } + private static final char[] ALPHABET = + "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz".toCharArray(); + private static final int BASE_58 = ALPHABET.length; + private static final int BASE_256 = 256; + + private static final int[] INDEXES = new int[128]; + + static { + for (int i = 0; i < INDEXES.length; i++) { + INDEXES[i] = -1; + } + for (int i = 0; i < ALPHABET.length; i++) { + INDEXES[ALPHABET[i]] = i; + } + } + + public static void main(String[] args) throws Exception { + byte[] ret = Base58.decode("QmWaWjD7Sfs7Lw7ZgMgbRN47e2iakSMuZHqPRkctHyhFzf"); + System.out.println(Arrays.toString(ret)); + } + + public static String encode(byte[] input) { + if (input.length == 0) { + // paying with the same coin + return ""; + } + + // + // Make a copy of the input since we are going to modify it. + // + input = copyOfRange(input, 0, input.length); + + // + // Count leading zeroes + // + int zeroCount = 0; + while (zeroCount < input.length && input[zeroCount] == 0) { + ++zeroCount; + } + + // + // The actual encoding + // + byte[] temp = new byte[input.length * 2]; + int j = temp.length; + + int startAt = zeroCount; + while (startAt < input.length) { + byte mod = divmod58(input, startAt); + if (input[startAt] == 0) { + ++startAt; + } + + temp[--j] = (byte) ALPHABET[mod]; + } + + // + // Strip extra '1' if any + // + while (j < temp.length && temp[j] == ALPHABET[0]) { + ++j; } + // + // Add as many leading '1' as there were leading zeros. + // + while (--zeroCount >= 0) { + temp[--j] = (byte) ALPHABET[0]; + } + + byte[] output = copyOfRange(temp, j, temp.length); + return new String(output); + } - public static void main(String[] args) throws Exception { - byte[] ret = Base58.decode("QmWaWjD7Sfs7Lw7ZgMgbRN47e2iakSMuZHqPRkctHyhFzf"); - System.out.println(Arrays.toString(ret)); + public static byte[] decode(String input) { + if (input.length() == 0) { + // paying with the same coin + return new byte[0]; } - public static String encode(byte[] input) { - if (input.length == 0) { - // paying with the same coin - return ""; - } - - // - // Make a copy of the input since we are going to modify it. - // - input = copyOfRange(input, 0, input.length); - - // - // Count leading zeroes - // - int zeroCount = 0; - while (zeroCount < input.length && input[zeroCount] == 0) { - ++zeroCount; - } - - // - // The actual encoding - // - byte[] temp = new byte[input.length * 2]; - int j = temp.length; - - int startAt = zeroCount; - while (startAt < input.length) { - byte mod = divmod58(input, startAt); - if (input[startAt] == 0) { - ++startAt; - } - - temp[--j] = (byte) ALPHABET[mod]; - } - - // - // Strip extra '1' if any - // - while (j < temp.length && temp[j] == ALPHABET[0]) { - ++j; - } - - // - // Add as many leading '1' as there were leading zeros. - // - while (--zeroCount >= 0) { - temp[--j] = (byte) ALPHABET[0]; - } - - byte[] output = copyOfRange(temp, j, temp.length); - return new String(output); + byte[] input58 = new byte[input.length()]; + // + // Transform the String to a base58 byte sequence + // + for (int i = 0; i < input.length(); ++i) { + char c = input.charAt(i); + + int digit58 = -1; + if (c >= 0 && c < 128) { + digit58 = INDEXES[c]; + } + if (digit58 < 0) { + throw new RuntimeException("Not a Base58 input: " + input); + } + + input58[i] = (byte) digit58; } - public static byte[] decode(String input) { - if (input.length() == 0) { - // paying with the same coin - return new byte[0]; - } - - byte[] input58 = new byte[input.length()]; - // - // Transform the String to a base58 byte sequence - // - for (int i = 0; i < input.length(); ++i) { - char c = input.charAt(i); - - int digit58 = -1; - if (c >= 0 && c < 128) { - digit58 = INDEXES[c]; - } - if (digit58 < 0) { - throw new RuntimeException("Not a Base58 input: " + input); - } - - input58[i] = (byte) digit58; - } - - // - // Count leading zeroes - // - int zeroCount = 0; - while (zeroCount < input58.length && input58[zeroCount] == 0) { - ++zeroCount; - } - - // - // The encoding - // - byte[] temp = new byte[input.length()]; - int j = temp.length; - - int startAt = zeroCount; - while (startAt < input58.length) { - byte mod = divmod256(input58, startAt); - if (input58[startAt] == 0) { - ++startAt; - } - - temp[--j] = mod; - } - - // - // Do no add extra leading zeroes, move j to first non null byte. - // - while (j < temp.length && temp[j] == 0) { - ++j; - } - - return copyOfRange(temp, j - zeroCount, temp.length); + // + // Count leading zeroes + // + int zeroCount = 0; + while (zeroCount < input58.length && input58[zeroCount] == 0) { + ++zeroCount; } - private static byte divmod58(byte[] number, int startAt) { - int remainder = 0; - for (int i = startAt; i < number.length; i++) { - int digit256 = (int) number[i] & 0xFF; - int temp = remainder * BASE_256 + digit256; + // + // The encoding + // + byte[] temp = new byte[input.length()]; + int j = temp.length; - number[i] = (byte) (temp / BASE_58); + int startAt = zeroCount; + while (startAt < input58.length) { + byte mod = divmod256(input58, startAt); + if (input58[startAt] == 0) { + ++startAt; + } - remainder = temp % BASE_58; - } + temp[--j] = mod; + } - return (byte) remainder; + // + // Do no add extra leading zeroes, move j to first non null byte. + // + while (j < temp.length && temp[j] == 0) { + ++j; } - private static byte divmod256(byte[] number58, int startAt) { - int remainder = 0; - for (int i = startAt; i < number58.length; i++) { - int digit58 = (int) number58[i] & 0xFF; - int temp = remainder * BASE_58 + digit58; + return copyOfRange(temp, j - zeroCount, temp.length); + } - number58[i] = (byte) (temp / BASE_256); + private static byte divmod58(byte[] number, int startAt) { + int remainder = 0; + for (int i = startAt; i < number.length; i++) { + int digit256 = (int) number[i] & 0xFF; + int temp = remainder * BASE_256 + digit256; - remainder = temp % BASE_256; - } + number[i] = (byte) (temp / BASE_58); - return (byte) remainder; + remainder = temp % BASE_58; } - private static byte[] copyOfRange(byte[] source, int from, int to) { - byte[] range = new byte[to - from]; - System.arraycopy(source, from, range, 0, range.length); + return (byte) remainder; + } - return range; + private static byte divmod256(byte[] number58, int startAt) { + int remainder = 0; + for (int i = startAt; i < number58.length; i++) { + int digit58 = (int) number58[i] & 0xFF; + int temp = remainder * BASE_58 + digit58; + + number58[i] = (byte) (temp / BASE_256); + + remainder = temp % BASE_256; } + + return (byte) remainder; + } + + private static byte[] copyOfRange(byte[] source, int from, int to) { + byte[] range = new byte[to - from]; + System.arraycopy(source, from, range, 0, range.length); + + return range; + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Cid.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Cid.java index c39aa55c2..43f4417bb 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Cid.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Cid.java @@ -2,7 +2,7 @@ public class Cid { - public byte[] toBytes() { - throw new UnsupportedOperationException(); - } + public byte[] toBytes() { + throw new UnsupportedOperationException(); + } } diff --git a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Util.java b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Util.java index 11b4cb79b..16385dc66 100644 --- a/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Util.java +++ b/libp2p/src/testFixtures/java/io/libp2p/tools/p2pd/libp2pj/util/Util.java @@ -3,32 +3,28 @@ import java.util.Arrays; import java.util.List; -/** - * Created by Anton Nashatyrev on 17.12.2018. - */ +/** Created by Anton Nashatyrev on 17.12.2018. */ public class Util { - public static byte[] concat(List arrays) { - byte[] ret = new byte[arrays.stream().mapToInt(arr -> arr.length).sum()]; - int off = 0; - for (byte[] bb : arrays) { - System.arraycopy(bb, 0, ret, off, bb.length); - off += bb.length; - } - return ret; + public static byte[] concat(List arrays) { + byte[] ret = new byte[arrays.stream().mapToInt(arr -> arr.length).sum()]; + int off = 0; + for (byte[] bb : arrays) { + System.arraycopy(bb, 0, ret, off, bb.length); + off += bb.length; } + return ret; + } - /** - * https://developers.google.com/protocol-buffers/docs/encoding - */ - public static byte[] encodeUVariant(long n) { - int size = 0; - byte[] ret = new byte[10]; - while(n > 0) { - if (size > 0) ret[size - 1] |= 0b10000000; - ret[size++] = (byte) (n & 0b01111111); - n >>>= 7; - } - return Arrays.copyOfRange(ret, 0, size); + /** https://developers.google.com/protocol-buffers/docs/encoding */ + public static byte[] encodeUVariant(long n) { + int size = 0; + byte[] ret = new byte[10]; + while (n > 0) { + if (size > 0) ret[size - 1] |= 0b10000000; + ret[size++] = (byte) (n & 0b01111111); + n >>>= 7; } + return Arrays.copyOfRange(ret, 0, size); + } } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt index 80dc9722a..646ee5c5c 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/DeterministicFuzz.kt @@ -1,6 +1,6 @@ package io.libp2p.pubsub -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.etc.types.lazyVar import io.libp2p.pubsub.flood.FloodRouter @@ -37,7 +37,7 @@ class DeterministicFuzz { return TestRouter("" + (cnt++), router).apply { val randomBytes = ByteArray(8) random.nextBytes(randomBytes) - keyPair = generateKeyPair(KEY_TYPE.ECDSA, random = SecureRandom(randomBytes)) + keyPair = generateKeyPair(KeyType.ECDSA, random = SecureRandom(randomBytes)) testExecutor = deterministicExecutor } } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/TestRouter.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/TestRouter.kt index a90a9f45e..60e325aed 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/TestRouter.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/TestRouter.kt @@ -1,7 +1,7 @@ package io.libp2p.pubsub import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.pubsub.RESULT_VALID import io.libp2p.core.pubsub.ValidationResult @@ -47,7 +47,7 @@ class TestRouter( var testExecutor: ScheduledExecutorService by lazyVar { Executors.newSingleThreadScheduledExecutor() } - var keyPair = generateKeyPair(KEY_TYPE.ECDSA) + var keyPair = generateKeyPair(KeyType.ECDSA) val peerId by lazy { PeerId.fromPubKey(keyPair.second) } val protocol = router.protocol.announceStr @@ -62,13 +62,15 @@ class TestRouter( pubsubLogs: LogLevel? = null, initiator: Boolean ): TestChannel { - val parentChannel = TestChannel("dummy-parent-channel", false) val connection = ConnectionOverNetty(parentChannel, NullTransport(), initiator) connection.setSecureSession( SecureChannel.Session( - peerId, remoteRouter.peerId, remoteRouter.keyPair.second, null + peerId, + remoteRouter.peerId, + remoteRouter.keyPair.second, + null ) ) @@ -90,7 +92,6 @@ class TestRouter( wireLogs: LogLevel? = null, pubsubLogs: LogLevel? = null ): TestConnection { - val thisChannel = newChannel("[${idCnt.incrementAndGet()}]$name=>${another.name}", another, wireLogs, pubsubLogs, true) val anotherChannel = another.newChannel("[${idCnt.incrementAndGet()}]${another.name}=>$name", this, wireLogs, pubsubLogs, false) listOf(thisChannel, anotherChannel).forEach { diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/gossip/Eth2GossipParams.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/gossip/Eth2GossipParams.kt index 25aace38c..4eb050f53 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/gossip/Eth2GossipParams.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/gossip/Eth2GossipParams.kt @@ -20,7 +20,7 @@ val Eth2DefaultGossipParams = GossipParams( DLazy = 8, pruneBackoff = 1.minutes, - floodPublish = true, + floodPublishMaxMessageSizeThreshold = 16384, gossipFactor = 0.25, DScore = 4, DOut = 2, diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt index f87a3ad8a..efb48db64 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/HostFactory.kt @@ -2,7 +2,7 @@ package io.libp2p.tools import io.libp2p.core.Host import io.libp2p.core.PeerId -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.PrivKey import io.libp2p.core.crypto.PubKey import io.libp2p.core.crypto.generateKeyPair @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit class HostFactory { - var keyType = KEY_TYPE.ECDSA + var keyType = KeyType.ECDSA var tcpPort = Random().nextInt(10_000) + 6000 var transportCtor = ::TcpTransport var secureCtor: SecureChannelCtor = ::NoiseXXSecureChannel diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullHost.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullHost.kt index 2e3f84c44..476bd0e2e 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullHost.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/NullHost.kt @@ -55,6 +55,10 @@ open class NullHost : Host { TODO("not implemented") } + override fun getProtocols(): List> { + TODO("not implemented") + } + override fun addConnectionHandler(handler: ConnectionHandler) { TODO("not implemented") } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/P2pdRunner.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/P2pdRunner.kt index dd204b7d2..2045ff671 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/P2pdRunner.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/P2pdRunner.kt @@ -14,7 +14,7 @@ class P2pdRunner(val execNames: List = listOf("p2pd", "p2pd.exe"), val e fun findP2pdExe(): String? = (predefinedSearchPaths + execSearchPaths) .flatMap { path -> execNames.map { File(path, it) } } - .firstOrNull() { it.canExecute() } + .firstOrNull { it.canExecute() } ?.absoluteFile?.canonicalPath fun launcher() = findP2pdExe()?.let { DaemonLauncher(it) } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt index 3f4ad2346..eb77980d7 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/TCPProxy.kt @@ -26,7 +26,6 @@ class TCPProxy { it.addLastLocal(object : ChannelInboundHandlerAdapter() { val client = CompletableFuture() override fun channelActive(serverCtx: ChannelHandlerContext) { - serverCtx.channel().pipeline().addFirst(LoggingHandler("server", LogLevel.INFO)) Bootstrap().apply { diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/protobuf/RpcBuilder.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/protobuf/RpcBuilder.kt index a4c55bba6..4da90ef85 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/protobuf/RpcBuilder.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/protobuf/RpcBuilder.kt @@ -1,6 +1,7 @@ package io.libp2p.tools.protobuf import io.libp2p.etc.types.toProtobuf +import io.libp2p.pubsub.Topic import pubsub.pb.Rpc import kotlin.random.Random @@ -28,9 +29,9 @@ class RpcBuilder { } } - fun addIHaves(iHaveCount: Int, messageIdCount: Int) { + fun addIHaves(iHaveCount: Int, messageIdCount: Int, topic: Topic) { for (i in 0 until iHaveCount) { - val iHaveBuilder = Rpc.ControlIHave.newBuilder() + val iHaveBuilder = Rpc.ControlIHave.newBuilder().setTopicID(topic) for (j in 0 until messageIdCount) { iHaveBuilder.addMessageIDs(Random.nextBytes(6).toProtobuf()) } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/transport/NullConnectionUpgrader.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/transport/NullConnectionUpgrader.kt index 62f1b6c2c..b8bf2569e 100644 --- a/libp2p/src/testFixtures/kotlin/io/libp2p/transport/NullConnectionUpgrader.kt +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/transport/NullConnectionUpgrader.kt @@ -3,7 +3,7 @@ package io.libp2p.transport import io.libp2p.core.Connection import io.libp2p.core.PeerId import io.libp2p.core.StreamPromise -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.multistream.ProtocolBinding @@ -20,19 +20,17 @@ class NullMultistreamProtocol : MultistreamProtocol { class NullConnectionUpgrader : ConnectionUpgrader(NullMultistreamProtocol(), emptyList(), NullMultistreamProtocol(), emptyList()) { - override fun establishSecureChannel(connection: Connection): - CompletableFuture { + override fun establishSecureChannel(connection: Connection): CompletableFuture { val nonsenseSession = SecureChannel.Session( PeerId.random(), PeerId.random(), - generateKeyPair(KEY_TYPE.RSA).second, + generateKeyPair(KeyType.RSA).second, null ) return CompletableFuture.completedFuture(nonsenseSession) } // establishSecureChannel - override fun establishMuxer(connection: Connection): - CompletableFuture { + override fun establishMuxer(connection: Connection): CompletableFuture { return CompletableFuture.completedFuture(DoNothingMuxerSession()) } // establishMuxer diff --git a/settings.gradle b/settings.gradle index 3840b0611..3b2d71fae 100644 --- a/settings.gradle +++ b/settings.gradle @@ -15,7 +15,7 @@ dependencyResolutionManagement { } } -rootProject.name = 'jvm-libp2p-minimal' +rootProject.name = 'jvm-libp2p' include ':libp2p' include ':tools:schedulers' diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/AbstractSchedulers.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/AbstractSchedulers.java index c44f372d5..02dfcfa3c 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/AbstractSchedulers.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/AbstractSchedulers.java @@ -5,8 +5,8 @@ /** * The collection of standard Schedulers, Scheduler factory and system time supplier * - * For debugging and testing the default Schedulers instance can be replaced - * with appropriate one + *

For debugging and testing the default Schedulers instance can be replaced with + * appropriate one */ public abstract class AbstractSchedulers implements Schedulers { private static final int BLOCKING_THREAD_COUNT = 128; diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorService.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorService.java index e3367684c..5623695d9 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorService.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorService.java @@ -3,15 +3,14 @@ import java.util.concurrent.ScheduledExecutorService; /** - * The ScheduledExecutorService which functions based on the - * current system time supplied by {@link TimeController#getTime()} instead of - * System.currentTimeMillis() + * The ScheduledExecutorService which functions based on the current system time + * supplied by {@link TimeController#getTime()} instead of System.currentTimeMillis() */ public interface ControlledExecutorService extends ScheduledExecutorService { /** - * Sets up the {@link TimeController} instance which manages ordered tasks execution - * and provides current time + * Sets up the {@link TimeController} instance which manages ordered tasks execution and provides + * current time */ void setTimeController(TimeController timeController); } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorServiceImpl.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorServiceImpl.java index 2f0c60e32..93f1b5507 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorServiceImpl.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledExecutorServiceImpl.java @@ -16,14 +16,15 @@ public class ControlledExecutorServiceImpl implements ControlledExecutorService { - private class ScheduledTask implements TimeController.Task{ + private class ScheduledTask implements TimeController.Task { Callable callable; final ScheduledFutureImpl future = new ScheduledFutureImpl(b -> cancel()); long targetTime; public ScheduledTask(Callable callable, long targetTime) { if (targetTime < getCurrentTime()) { - throw new IllegalStateException("Invalid target time: " + targetTime + " < " + getCurrentTime()); + throw new IllegalStateException( + "Invalid target time: " + targetTime + " < " + getCurrentTime()); } this.callable = callable; this.targetTime = targetTime; @@ -34,14 +35,15 @@ void cancel() { } public CompletableFuture execute() { - delegateExecutor.execute(() -> { - try { - V res = callable.call(); - future.delegate.complete(res); - } catch (Exception e) { - future.delegate.completeExceptionally(e); - } - }); + delegateExecutor.execute( + () -> { + try { + V res = callable.call(); + future.delegate.complete(res); + } catch (Exception e) { + future.delegate.completeExceptionally(e); + } + }); return future.delegate.thenApply(i -> null); } @@ -56,8 +58,7 @@ public String toString() { } } - - private class ScheduledFutureImpl implements ScheduledFuture { + private class ScheduledFutureImpl implements ScheduledFuture { final CompletableFuture delegate = new CompletableFuture<>(); private final Consumer canceller; @@ -108,7 +109,7 @@ public V get(long timeout, TimeUnit unit) private TimeController timeController; public ControlledExecutorServiceImpl() { - this(Runnable::run); // default immediate executor + this(Runnable::run); // default immediate executor } public ControlledExecutorServiceImpl(TimeController timeController) { @@ -138,55 +139,65 @@ public ScheduledFuture schedule(Callable callable, long delay, TimeUni if (delay < 0) { delay = 0; } - ScheduledTask scheduledTask = new ScheduledTask<>(callable, getCurrentTime() + unit.toMillis(delay)); + ScheduledTask scheduledTask = + new ScheduledTask<>(callable, getCurrentTime() + unit.toMillis(delay)); timeController.addTask(scheduledTask); return scheduledTask.future; } @Override - public ScheduledFuture scheduleAtFixedRate(Runnable command, - long initialDelay, long period, TimeUnit unit) { + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { ScheduledFuture[] activeFut = new ScheduledFutureImpl[1]; ScheduledFutureImpl ret = new ScheduledFutureImpl<>(b -> activeFut[0].cancel(b)); - activeFut[0] = schedule(() -> { - command.run(); - if (!activeFut[0].isCancelled()) { - activeFut[0] = scheduleAtFixedRate(command, period, period, unit); - } - return null; - }, initialDelay, unit); + activeFut[0] = + schedule( + () -> { + command.run(); + if (!activeFut[0].isCancelled()) { + activeFut[0] = scheduleAtFixedRate(command, period, period, unit); + } + return null; + }, + initialDelay, + unit); return ret; } @Override public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { - return schedule(() -> { - command.run(); - return null; - }, delay, unit); + return schedule( + () -> { + command.run(); + return null; + }, + delay, + unit); } @Override public Future submit(Callable task) { CompletableFuture ret = new CompletableFuture<>(); - execute(() -> { - try { - ret.complete(task.call()); - } catch (Throwable e) { - ret.completeExceptionally(e); - } - }); + execute( + () -> { + try { + ret.complete(task.call()); + } catch (Throwable e) { + ret.completeExceptionally(e); + } + }); return ret; } @Override public Future submit(Runnable task, T result) { - return submit(() -> { - task.run(); - return result; - }); + return submit( + () -> { + task.run(); + return result; + }); } @Override @@ -204,8 +215,8 @@ public void execute(Runnable command) { } @Override - public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, - TimeUnit unit) { + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { return scheduleAtFixedRate(command, initialDelay, delay, unit); } @@ -216,8 +227,9 @@ public List> invokeAll(Collection> tasks) } @Override - public List> invokeAll(Collection> tasks, long timeout, - TimeUnit unit) throws InterruptedException { + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { throw new UnsupportedOperationException(); } @@ -234,8 +246,7 @@ public T invokeAny(Collection> tasks, long timeout, Ti } @Override - public void shutdown() { - } + public void shutdown() {} @Override public List shutdownNow() { diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulers.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulers.java index f3bfef867..a1c1ed855 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulers.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulers.java @@ -3,40 +3,35 @@ import java.time.Duration; /** - * Special Schedulers implementation which is mostly suitable for testing and simulation. - * The system time is controlled manually and all the schedulers execute tasks according - * to this time. - * Initial system time is equal to 0 + * Special Schedulers implementation which is mostly suitable for testing and simulation. The system + * time is controlled manually and all the schedulers execute tasks according to this time. Initial + * system time is equal to 0 */ public interface ControlledSchedulers extends Schedulers { /** * Sets current time. - * @throws IllegalStateException if this instance is dependent on a parent - * {@link TimeController} + * + * @throws IllegalStateException if this instance is dependent on a parent {@link TimeController} * @see TimeController#setTime(long) */ default void setCurrentTime(long newTime) { getTimeController().setTime(newTime); } - /** - * Just a handy helper method for {@link #setCurrentTime(long)} - */ + /** Just a handy helper method for {@link #setCurrentTime(long)} */ default void addTime(Duration duration) { addTime(duration.toMillis()); } - /** - * Just a handy helper method for {@link #setCurrentTime(long)} - */ + /** Just a handy helper method for {@link #setCurrentTime(long)} */ default void addTime(long millis) { setCurrentTime(getCurrentTime() + millis); } /** - * Returns {@link TimeController} which manages tasks ordered execution and - * supplies current time for this instance + * Returns {@link TimeController} which manages tasks ordered execution and supplies current time + * for this instance */ TimeController getTimeController(); } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulersImpl.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulersImpl.java index 3f195116b..0e2f3f95c 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulersImpl.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ControlledSchedulersImpl.java @@ -25,7 +25,8 @@ protected Scheduler createExecutorScheduler(ScheduledExecutorService executorSer @Override protected ScheduledExecutorService createExecutor(String namePattern, int threads) { - ControlledExecutorServiceImpl service = new ControlledExecutorServiceImpl(createDelegateExecutor()); + ControlledExecutorServiceImpl service = + new ControlledExecutorServiceImpl(createDelegateExecutor()); service.setTimeController(timeController); return service; } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/DefaultSchedulers.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/DefaultSchedulers.java index 837012929..7f6277c8b 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/DefaultSchedulers.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/DefaultSchedulers.java @@ -1,13 +1,12 @@ package io.libp2p.tools.schedulers; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class DefaultSchedulers extends AbstractSchedulers { diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ErrorHandlingScheduler.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ErrorHandlingScheduler.java index 00f953ce6..c400c5d2e 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ErrorHandlingScheduler.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ErrorHandlingScheduler.java @@ -5,16 +5,17 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; -//import reactor.core.Disposable; + +// import reactor.core.Disposable; public class ErrorHandlingScheduler implements Scheduler { private final Scheduler delegate; private final Consumer errorHandler; -// private reactor.core.scheduler.Scheduler cachedReactor; - public ErrorHandlingScheduler(Scheduler delegate, - Consumer errorHandler) { + // private reactor.core.scheduler.Scheduler cachedReactor; + + public ErrorHandlingScheduler(Scheduler delegate, Consumer errorHandler) { this.delegate = delegate; this.errorHandler = errorHandler; } @@ -31,20 +32,17 @@ public CompletableFuture executeWithDelay(Duration delay, Callable tas @Override public CompletableFuture executeAtFixedRate( - Duration initialDelay, Duration period, - RunnableEx task) { + Duration initialDelay, Duration period, RunnableEx task) { return delegate.executeAtFixedRate(initialDelay, period, () -> runAndHandleError(task)); } @Override - public CompletableFuture execute( - RunnableEx task) { + public CompletableFuture execute(RunnableEx task) { return delegate.execute(() -> runAndHandleError(task)); } @Override - public CompletableFuture executeWithDelay(Duration delay, - RunnableEx task) { + public CompletableFuture executeWithDelay(Duration delay, RunnableEx task) { return delegate.executeWithDelay(delay, () -> runAndHandleError(task)); } @@ -53,14 +51,14 @@ public long getCurrentTime() { return delegate.getCurrentTime(); } -// @Override -// public reactor.core.scheduler.Scheduler toReactor() { -// if (cachedReactor == null) { -// cachedReactor = new ErrorHandlingReactorScheduler(delegate.toReactor(), -// delegate::getCurrentTime); -// } -// return cachedReactor; -// } + // @Override + // public reactor.core.scheduler.Scheduler toReactor() { + // if (cachedReactor == null) { + // cachedReactor = new ErrorHandlingReactorScheduler(delegate.toReactor(), + // delegate::getCurrentTime); + // } + // return cachedReactor; + // } private void runAndHandleError(RunnableEx runnable) throws Exception { try { @@ -74,64 +72,66 @@ private void runAndHandleError(RunnableEx runnable) throws Exception { } } -// private class ErrorHandlingReactorScheduler extends DelegatingReactorScheduler { -// -// public ErrorHandlingReactorScheduler(reactor.core.scheduler.Scheduler delegate, -// Supplier timeSupplier) { -// super(delegate, timeSupplier); -// } -// -// @Nonnull -// @Override -// public Disposable schedule(@Nonnull Runnable task) { -// return super.schedule(() -> runAndHandleError(task)); -// } -// -// @Nonnull -// @Override -// public Disposable schedule(Runnable task, long delay, TimeUnit unit) { -// return super.schedule(() -> runAndHandleError(task), delay, unit); -// } -// -// @Nonnull -// @Override -// public Disposable schedulePeriodically(Runnable task, long initialDelay, long period, -// TimeUnit unit) { -// return super.schedulePeriodically(() -> runAndHandleError(task), initialDelay, period, unit); -// } -// -// @Nonnull -// @Override -// public Worker createWorker() { -// return new DelegateWorker(super.createWorker()) { -// @Nonnull -// @Override -// public Disposable schedule(@Nonnull Runnable task) { -// return super.schedule(() -> runAndHandleError(task)); -// } -// -// @Nonnull -// @Override -// public Disposable schedule(Runnable task, long delay, TimeUnit unit) { -// return super.schedule(() -> runAndHandleError(task), delay, unit); -// } -// -// @Nonnull -// @Override -// public Disposable schedulePeriodically(Runnable task, long initialDelay, long period, -// TimeUnit unit) { -// return super.schedulePeriodically(() -> runAndHandleError(task), initialDelay, period, unit); -// } -// }; -// } -// -// private void runAndHandleError(Runnable runnable) { -// try { -// runnable.run(); -// } catch (Throwable t) { -// errorHandler.accept(t); -// throw new RuntimeException(t); -// } -// } -// } + // private class ErrorHandlingReactorScheduler extends DelegatingReactorScheduler { + // + // public ErrorHandlingReactorScheduler(reactor.core.scheduler.Scheduler delegate, + // Supplier timeSupplier) { + // super(delegate, timeSupplier); + // } + // + // @Nonnull + // @Override + // public Disposable schedule(@Nonnull Runnable task) { + // return super.schedule(() -> runAndHandleError(task)); + // } + // + // @Nonnull + // @Override + // public Disposable schedule(Runnable task, long delay, TimeUnit unit) { + // return super.schedule(() -> runAndHandleError(task), delay, unit); + // } + // + // @Nonnull + // @Override + // public Disposable schedulePeriodically(Runnable task, long initialDelay, long period, + // TimeUnit unit) { + // return super.schedulePeriodically(() -> runAndHandleError(task), initialDelay, period, + // unit); + // } + // + // @Nonnull + // @Override + // public Worker createWorker() { + // return new DelegateWorker(super.createWorker()) { + // @Nonnull + // @Override + // public Disposable schedule(@Nonnull Runnable task) { + // return super.schedule(() -> runAndHandleError(task)); + // } + // + // @Nonnull + // @Override + // public Disposable schedule(Runnable task, long delay, TimeUnit unit) { + // return super.schedule(() -> runAndHandleError(task), delay, unit); + // } + // + // @Nonnull + // @Override + // public Disposable schedulePeriodically(Runnable task, long initialDelay, long period, + // TimeUnit unit) { + // return super.schedulePeriodically(() -> runAndHandleError(task), initialDelay, period, + // unit); + // } + // }; + // } + // + // private void runAndHandleError(Runnable runnable) { + // try { + // runnable.run(); + // } catch (Throwable t) { + // errorHandler.accept(t); + // throw new RuntimeException(t); + // } + // } + // } } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ExecutorScheduler.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ExecutorScheduler.java index ffe1bc999..db4ce033b 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ExecutorScheduler.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/ExecutorScheduler.java @@ -12,7 +12,8 @@ public class ExecutorScheduler implements Scheduler { private final ScheduledExecutorService executorService; private final Supplier timeSupplier; -// private reactor.core.scheduler.Scheduler cachedReactor; + + // private reactor.core.scheduler.Scheduler cachedReactor; public ExecutorScheduler(ScheduledExecutorService executorService, Supplier timeSupplier) { this.executorService = executorService; @@ -22,59 +23,69 @@ public ExecutorScheduler(ScheduledExecutorService executorService, Supplier CompletableFuture execute(Callable task) { CompletableFuture future = new CompletableFuture<>(); - executorService.execute(() -> { - try { - future.complete(task.call()); - } catch (Throwable t) { - future.completeExceptionally(t); - } - }); + executorService.execute( + () -> { + try { + future.complete(task.call()); + } catch (Throwable t) { + future.completeExceptionally(t); + } + }); return future; } @Override public CompletableFuture executeWithDelay(Duration delay, Callable task) { CompletableFuture future = new CompletableFuture<>(); - executorService.schedule(() -> { - try { - future.complete(task.call()); - } catch (Throwable t) { - future.completeExceptionally(t); - } - }, delay.toMillis(), TimeUnit.MILLISECONDS); + executorService.schedule( + () -> { + try { + future.complete(task.call()); + } catch (Throwable t) { + future.completeExceptionally(t); + } + }, + delay.toMillis(), + TimeUnit.MILLISECONDS); return future; } @Override - public CompletableFuture executeAtFixedRate(Duration initialDelay, Duration period, - RunnableEx task) { + public CompletableFuture executeAtFixedRate( + Duration initialDelay, Duration period, RunnableEx task) { ScheduledFuture[] scheduledFuture = new ScheduledFuture[1]; - CompletableFuture ret = new CompletableFuture() { - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return scheduledFuture[0].cancel(mayInterruptIfRunning); - } - }; - scheduledFuture[0] = executorService.scheduleAtFixedRate(() -> { - try { - task.run(); - } catch (Throwable e) { - ret.completeExceptionally(e); - throw new RuntimeException(e); - } - }, initialDelay.toMillis(), period.toMillis(), TimeUnit.MILLISECONDS); + CompletableFuture ret = + new CompletableFuture() { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return scheduledFuture[0].cancel(mayInterruptIfRunning); + } + }; + scheduledFuture[0] = + executorService.scheduleAtFixedRate( + () -> { + try { + task.run(); + } catch (Throwable e) { + ret.completeExceptionally(e); + throw new RuntimeException(e); + } + }, + initialDelay.toMillis(), + period.toMillis(), + TimeUnit.MILLISECONDS); return ret; } -// @Override -// public reactor.core.scheduler.Scheduler toReactor() { -// if (cachedReactor == null) { -// cachedReactor = convertToReactor(this); -// } -// return cachedReactor; -// } + // @Override + // public reactor.core.scheduler.Scheduler toReactor() { + // if (cachedReactor == null) { + // cachedReactor = convertToReactor(this); + // } + // return cachedReactor; + // } @Override public long getCurrentTime() { diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LatestExecutor.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LatestExecutor.java index 5219b319c..d34cef742 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LatestExecutor.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LatestExecutor.java @@ -3,15 +3,14 @@ import java.util.function.Consumer; /** - * Processes events submitted via {@link #newEvent(T)} with the specified - * eventProcessor on the specified scheduler. + * Processes events submitted via {@link #newEvent(T)} with the specified eventProcessor + * on the specified scheduler. * - * Guarantees that the latest event would be processed, though other - * intermediate events could be skipped. + *

Guarantees that the latest event would be processed, though other intermediate events could be + * skipped. * - * Skips subsequent events if any previous is still processing. - * Avoids creating scheduling a task for each event thus allowing frequent - * events submitting. + *

Skips subsequent events if any previous is still processing. Avoids creating scheduling a task + * for each event thus allowing frequent events submitting. */ public class LatestExecutor { private final Scheduler scheduler; @@ -25,9 +24,8 @@ public LatestExecutor(Scheduler scheduler, Consumer eventProcessor) { } /** - * Submits a new event for processing. - * This particular event may not be processed if a subsequent event submitted - * shortly + * Submits a new event for processing. This particular event may not be processed if a subsequent + * event submitted shortly */ public synchronized void newEvent(T event) { latestEvent = event; diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LoggerMDCExecutor.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LoggerMDCExecutor.java index 2bd55b4c8..e732b0a82 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LoggerMDCExecutor.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/LoggerMDCExecutor.java @@ -1,11 +1,10 @@ package io.libp2p.tools.schedulers; -import org.slf4j.MDC; - import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; import java.util.function.Supplier; +import org.slf4j.MDC; public class LoggerMDCExecutor implements Executor { diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/RunnableEx.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/RunnableEx.java index 12bc2797b..42f42bae7 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/RunnableEx.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/RunnableEx.java @@ -1,8 +1,6 @@ package io.libp2p.tools.schedulers; -/** - * The same as standard Runnable which can throw unchecked exception - */ +/** The same as standard Runnable which can throw unchecked exception */ public interface RunnableEx { void run() throws Exception; } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Scheduler.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Scheduler.java index 1c51a8009..e4b32e50c 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Scheduler.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Scheduler.java @@ -5,30 +5,32 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; -/** - * Analog for standard ScheduledExecutorService - */ +/** Analog for standard ScheduledExecutorService */ public interface Scheduler { CompletableFuture execute(Callable task); CompletableFuture executeWithDelay(Duration delay, Callable task); - CompletableFuture executeAtFixedRate(Duration initialDelay, Duration period, - RunnableEx task); + CompletableFuture executeAtFixedRate( + Duration initialDelay, Duration period, RunnableEx task); long getCurrentTime(); -// default reactor.core.scheduler.Scheduler toReactor() { -// return convertToReactor(this); -// } + // default reactor.core.scheduler.Scheduler toReactor() { + // return convertToReactor(this); + // } default CompletableFuture executeR(Runnable task) { return execute(task::run); } default CompletableFuture execute(RunnableEx task) { - return execute(() -> {task.run(); return null;}); + return execute( + () -> { + task.run(); + return null; + }); } default CompletableFuture executeWithDelayR(Duration delay, Runnable task) { @@ -36,28 +38,36 @@ default CompletableFuture executeWithDelayR(Duration delay, Runnable task) } default CompletableFuture executeWithDelay(Duration delay, RunnableEx task) { - return executeWithDelay(delay, () -> {task.run(); return null;}); + return executeWithDelay( + delay, + () -> { + task.run(); + return null; + }); } @SuppressWarnings("unchecked") - default CompletableFuture orTimeout(CompletableFuture future, Duration futureTimeout, - Supplier exceptionSupplier) { - return (CompletableFuture) CompletableFuture.anyOf( - future, - executeWithDelay(futureTimeout, - () -> {throw exceptionSupplier.get();})); + default CompletableFuture orTimeout( + CompletableFuture future, Duration futureTimeout, Supplier exceptionSupplier) { + return (CompletableFuture) + CompletableFuture.anyOf( + future, + executeWithDelay( + futureTimeout, + () -> { + throw exceptionSupplier.get(); + })); } -// default reactor.core.scheduler.Scheduler convertToReactor(Scheduler scheduler) { -// if (scheduler instanceof ExecutorScheduler) { -// return new DelegatingReactorScheduler( -// reactor.core.scheduler.Schedulers.fromExecutorService( -// ((ExecutorScheduler) scheduler).getExecutorService()), -// this::getCurrentTime); -// } else { -// throw new UnsupportedOperationException( -// "Conversion from custom Scheduler to Reactor Scheduler not implemented yet."); -// } -// } + // default reactor.core.scheduler.Scheduler convertToReactor(Scheduler scheduler) { + // if (scheduler instanceof ExecutorScheduler) { + // return new DelegatingReactorScheduler( + // reactor.core.scheduler.Schedulers.fromExecutorService( + // ((ExecutorScheduler) scheduler).getExecutorService()), + // this::getCurrentTime); + // } else { + // throw new UnsupportedOperationException( + // "Conversion from custom Scheduler to Reactor Scheduler not implemented yet."); + // } + // } } - diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Schedulers.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Schedulers.java index 86982b4f7..75bfd62be 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Schedulers.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/Schedulers.java @@ -4,24 +4,23 @@ import java.util.function.Supplier; /** - * The collection of standard Schedulers, Scheduler factory and system time supplier - * Any scheduler withing a system should be obtained or created via this interface + * The collection of standard Schedulers, Scheduler factory and system time supplier Any scheduler + * withing a system should be obtained or created via this interface */ public interface Schedulers { - /** - * Creates default Schedulers implementation for production functioning - */ + /** Creates default Schedulers implementation for production functioning */ static Schedulers createDefault() { return new DefaultSchedulers(); } /** - * Creates the ControlledSchedulers implementation (normally for testing or simulation) - * with the specified delegate Executor factory. + * Creates the ControlledSchedulers implementation (normally for testing or simulation) with the + * specified delegate Executor factory. + * * @param delegateExecutor all the tasks are finally executed on executors created by this - * factory. Normally a single executor should be sufficient and could be supplied as - * () -> mySingleExecutor + * factory. Normally a single executor should be sufficient and could be supplied as + * () -> mySingleExecutor */ static ControlledSchedulers createControlled(Supplier delegateExecutor) { return new ControlledSchedulersImpl() { @@ -33,55 +32,49 @@ protected Executor createDelegateExecutor() { } /** - * Creates the ControlledSchedulers implementation (normally for testing or simulation) - * which executes all the tasks immediately on the same thread or if a task scheduled for - * later execution then this task would be executed within appropriate - * {@link ControlledSchedulers#setCurrentTime(long)} call + * Creates the ControlledSchedulers implementation (normally for testing or simulation) which + * executes all the tasks immediately on the same thread or if a task scheduled for later + * execution then this task would be executed within appropriate {@link + * ControlledSchedulers#setCurrentTime(long)} call */ static ControlledSchedulers createControlled() { return createControlled(() -> Runnable::run); } /** - * Returns the current system time - * This method should be used by all components to obtain the current system time - * System.currentTimeMillis() (or other standard Java means) is prohibited. + * Returns the current system time This method should be used by all components to obtain the + * current system time System.currentTimeMillis() (or other standard Java means) is + * prohibited. */ long getCurrentTime(); /** - * Scheduler to execute CPU heavy tasks - * This is normally based on a thread pool with the number of threads - * equal to number of CPU cores + * Scheduler to execute CPU heavy tasks This is normally based on a thread pool with the number of + * threads equal to number of CPU cores */ Scheduler cpuHeavy(); /** - * The scheduler to execute disk read/write tasks (like DB access, file read/write etc) - * and other tasks with potentially short blocking time. - * Tasks with potentially longer blocking time (like waiting for network response) is - * highly recommended to execute in a non-blocking (reactive) manner or at least on - * a dedicated Scheduler + * The scheduler to execute disk read/write tasks (like DB access, file read/write etc) and other + * tasks with potentially short blocking time. Tasks with potentially longer blocking time (like + * waiting for network response) is highly recommended to execute in a non-blocking (reactive) + * manner or at least on a dedicated Scheduler * - * This Scheduler is normally based on a dynamic pool with sufficient number of threads + *

This Scheduler is normally based on a dynamic pool with sufficient number of threads */ Scheduler blocking(); - /** - * Dedicated Scheduler for internal system asynchronous events - */ + /** Dedicated Scheduler for internal system asynchronous events */ Scheduler events(); - /** - * Creates new single thread Scheduler with the specified thread name - */ + /** Creates new single thread Scheduler with the specified thread name */ default Scheduler newSingleThreadDaemon(String threadName) { return newParallelDaemon(threadName, 1); } /** - * Creates new multi-thread Scheduler with the specified thread namePattern and - * number of pool threads + * Creates new multi-thread Scheduler with the specified thread namePattern and number of pool + * threads */ Scheduler newParallelDaemon(String threadNamePattern, int threadPoolCount); } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TaskQueue.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TaskQueue.java index 3d86d0565..79d976496 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TaskQueue.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TaskQueue.java @@ -11,77 +11,83 @@ public class TaskQueue { - private NavigableMap> tasks = Collections.synchronizedNavigableMap(new TreeMap<>()); - private boolean executing = false; + private NavigableMap> tasks = + Collections.synchronizedNavigableMap(new TreeMap<>()); + private boolean executing = false; - public void add(TimeController.Task task) { - tasks.computeIfAbsent(task.getTime(), t -> new ConcurrentLinkedQueue<>()).add(task); - } + public void add(TimeController.Task task) { + tasks.computeIfAbsent(task.getTime(), t -> new ConcurrentLinkedQueue<>()).add(task); + } - public void remove(TimeController.Task task) { - tasks.computeIfPresent(task.getTime(), (t, queue) -> { - queue.remove(task); - return queue; + public void remove(TimeController.Task task) { + tasks.computeIfPresent( + task.getTime(), + (t, queue) -> { + queue.remove(task); + return queue; }); - } + } - public boolean isEmpty() { - return tasks.isEmpty(); - } + public boolean isEmpty() { + return tasks.isEmpty(); + } - public long getEarliestTime() { - Map.Entry> entry = tasks.firstEntry(); - return entry == null ? 0 : entry.getKey(); - } + public long getEarliestTime() { + Map.Entry> entry = tasks.firstEntry(); + return entry == null ? 0 : entry.getKey(); + } - private Queue peekEarliest() { - return tasks.get(getEarliestTime()); - } + private Queue peekEarliest() { + return tasks.get(getEarliestTime()); + } - public void dropEarliest() { - tasks.remove(getEarliestTime()); - } + public void dropEarliest() { + tasks.remove(getEarliestTime()); + } - public void executeEarliest() { - if (executing) return; - executing = true; - try { - Queue taskQueue = peekEarliest(); - if (taskQueue == null) return; + public void executeEarliest() { + if (executing) return; + executing = true; + try { + Queue taskQueue = peekEarliest(); + if (taskQueue == null) return; - Queue> resQueue = new LinkedBlockingQueue<>(); + Queue> resQueue = new LinkedBlockingQueue<>(); - drainQueue(taskQueue, resQueue); + drainQueue(taskQueue, resQueue); - while (!resQueue.isEmpty()) { - try { - resQueue.poll().get(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - dropEarliest(); - } finally { - executing = false; + while (!resQueue.isEmpty()) { + try { + resQueue.poll().get(); + } catch (Exception e) { + throw new RuntimeException(e); } + } + dropEarliest(); + } finally { + executing = false; } + } - public boolean isExecuting() { - return executing; - } + public boolean isExecuting() { + return executing; + } - private synchronized void drainQueue(Queue taskQueue, Queue> resQueue) { - while (!taskQueue.isEmpty()) { - TimeController.Task task = taskQueue.poll(); - CompletableFuture taskFut = task.execute(); - if (taskFut.isDone()) { - resQueue.add(taskFut); - } else { - CompletableFuture resFut = taskFut.whenComplete((v, t) -> { - drainQueue(taskQueue, resQueue); + private synchronized void drainQueue( + Queue taskQueue, Queue> resQueue) { + while (!taskQueue.isEmpty()) { + TimeController.Task task = taskQueue.poll(); + CompletableFuture taskFut = task.execute(); + if (taskFut.isDone()) { + resQueue.add(taskFut); + } else { + CompletableFuture resFut = + taskFut.whenComplete( + (v, t) -> { + drainQueue(taskQueue, resQueue); }); - resQueue.add(resFut); - } - } + resQueue.add(resFut); + } } + } } diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeController.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeController.java index c63e0c843..37aa13555 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeController.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeController.java @@ -4,17 +4,13 @@ import java.util.concurrent.CompletableFuture; /** - * Controls global time and execution order of child executors - * The instance can be either 'root' (with no parent) or dependent - * on the parent controller. - * In the latter case all the calls delegated to the parent controller - * which manages the list of tasks and the global time + * Controls global time and execution order of child executors The instance can be either 'root' + * (with no parent) or dependent on the parent controller. In the latter case all the calls + * delegated to the parent controller which manages the list of tasks and the global time */ public interface TimeController { - /** - * Abstract scheduled task - */ + /** Abstract scheduled task */ interface Task { long getTime(); @@ -23,40 +19,33 @@ interface Task { } /** - * Returns this controller local time which differs from the parent - * time in case if time shift != 0 + * Returns this controller local time which differs from the parent time in case if time shift != + * 0 */ long getTime(); /** - * The method call is only valid for the 'root' controller - * Sets internal clock time and executes any tasks scheduled in period from - * the previous time till new currentTime inclusive. - * Periodic tasks are executed several times if scheduled so. - * @param newTime should be >= the last set time + * The method call is only valid for the 'root' controller Sets internal clock time and executes + * any tasks scheduled in period from the previous time till new currentTime + * inclusive. Periodic tasks are executed several times if scheduled so. * + * @param newTime should be >= the last set time * @throws IllegalStateException if the controller is not root */ void setTime(long newTime); - /** - * Child executors should add new scheduled tasks via this method - */ + /** Child executors should add new scheduled tasks via this method */ void addTask(Task task); - /** - * Child executors should cancel scheduled tasks via this method - */ + /** Child executors should cancel scheduled tasks via this method */ void cancelTask(Task task); - /** - * Sets the parent of this controller making it dependent - */ + /** Sets the parent of this controller making it dependent */ void setParent(TimeController parent); /** - * Simulates system clock deviations - * All children executors will see current time shifted by specified value + * Simulates system clock deviations All children executors will see current time shifted by + * specified value */ void setTimeShift(long timeShift); diff --git a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeControllerImpl.java b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeControllerImpl.java index 18a04be17..da7b74b4e 100644 --- a/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeControllerImpl.java +++ b/tools/schedulers/src/main/java/io/libp2p/tools/schedulers/TimeControllerImpl.java @@ -25,7 +25,7 @@ public long getTime() { public void setTime(long newTime) { if (parent != null) { throw new IllegalStateException( - "setTime() is allowed only for the topmost TimeController (without parent)"); + "setTime() is allowed only for the topmost TimeController (without parent)"); } if (newTime < curTime) { throw new IllegalArgumentException("newTime < curTime: " + newTime + ", " + curTime); diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayer.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayer.kt new file mode 100644 index 000000000..a69a095f2 --- /dev/null +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayer.kt @@ -0,0 +1,29 @@ +package io.libp2p.simulate.delay + +import io.libp2p.simulate.BandwidthDelayer +import io.libp2p.simulate.MessageDelayer +import io.libp2p.simulate.delay.SequentialDelayer.Companion.sequential +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ScheduledExecutorService + +class ChannelMessageDelayer( + executor: ScheduledExecutorService, + localOutboundBandwidthDelayer: BandwidthDelayer, + connectionLatencyDelayer: MessageDelayer, + remoteInboundBandwidthDelayer: BandwidthDelayer, +) : MessageDelayer { + + private val sequentialOutboundBandwidthDelayer = localOutboundBandwidthDelayer.sequential(executor) + private val sequentialInboundBandwidthDelayer = remoteInboundBandwidthDelayer.sequential(executor) + + private val delayer = MessageDelayer { size -> + CompletableFuture.allOf( + sequentialOutboundBandwidthDelayer.delay(size) + .thenCompose { connectionLatencyDelayer.delay(size) }, + connectionLatencyDelayer.delay(size) + .thenCompose { sequentialInboundBandwidthDelayer.delay(size) } + ).thenApply { } + } + + override fun delay(size: Long): CompletableFuture = delayer.delay(size) +} diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/Eth2GossipParams.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/Eth2GossipParams.kt index e65abf015..8d47c0e66 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/Eth2GossipParams.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/Eth2GossipParams.kt @@ -25,7 +25,7 @@ val Eth2DefaultGossipParams = GossipParams( DLazy = 8, pruneBackoff = 1.minutes, - floodPublish = true, + floodPublishMaxMessageSizeThreshold = 16384, gossipFactor = 0.25, DScore = 4, DOut = 2, diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt index eb78b432d..e2aa4f0a9 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimNetwork.kt @@ -26,10 +26,11 @@ class GossipSimNetwork( val commonExecutor = ControlledExecutorServiceImpl(timeController) protected val peerExecutors = - if (cfg.iterationThreadsCount > 1) + if (cfg.iterationThreadsCount > 1) { (0 until cfg.iterationThreadsCount).map { Executors.newSingleThreadScheduledExecutor() } - else + } else { listOf(Executor { it.run() }) + } var simPeerFactory: (Int, SimGossipRouterBuilder) -> GossipSimPeer = { number, router -> GossipSimPeer(number, commonRnd).apply { diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt index eb8188ccf..08eadfaaa 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/gossip/GossipSimPeer.kt @@ -15,7 +15,7 @@ import java.util.concurrent.CompletableFuture class GossipSimPeer( override val simPeerId: Int, override val random: Random, - protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1 + protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_2 ) : StreamSimPeer(true, protocol.announceStr) { var routerBuilder = SimGossipRouterBuilder() @@ -41,8 +41,11 @@ class GossipSimPeer( val logConnection = pubsubLogs(stream.remotePeerId()) router.addPeerWithDebugHandler( stream, - if (logConnection) - LoggingHandler(name, LogLevel.ERROR) else null + if (logConnection) { + LoggingHandler(name, LogLevel.ERROR) + } else { + null + } ) return dummy } diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/BlobDecouplingSimulation.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/BlobDecouplingSimulation.kt index bda3c4e22..ca45433cd 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/BlobDecouplingSimulation.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/BlobDecouplingSimulation.kt @@ -41,8 +41,6 @@ class BlobDecouplingSimulation( val randomSeed: Long = 3L, val rnd: Random = Random(randomSeed), - val floodPublish: Boolean = true, - val sendingPeerBand: Bandwidth = Bandwidth.mbitsPerSec(100), val peerBands: Iterator = iterator { @@ -83,10 +81,6 @@ class BlobDecouplingSimulation( ) val gossipParams = Eth2DefaultGossipParams - .copy( -// heartbeatInterval = 1.minutes - floodPublish = floodPublish - ) val gossipScoreParams = Eth2DefaultScoreParams val gossipRouterCtor = { _: Int -> SimGossipRouterBuilder().also { @@ -135,7 +129,6 @@ class BlobDecouplingSimulation( } fun testOnlyBlockDecoupled() { - for (i in 0 until messageCount) { val sendingPeer = sendingPeerIndexes[i] logger("Sending message $i from peer $sendingPeer") @@ -153,7 +146,6 @@ class BlobDecouplingSimulation( } fun testAllDecoupled() { - for (i in 0 until messageCount) { val sendingPeer = sendingPeerIndexes[i] logger("Sending message $i from peer $sendingPeer") @@ -262,7 +254,7 @@ class BlobDecouplingSimulation( fun main() { val bandwidths = bandwidthDistributions.entries.toList() .let { - listOf(/*it[0],*/ it[2]) + listOf(it[2]) }.toMap() val slowBandwidth = Bandwidth.mbitsPerSec(10) @@ -281,8 +273,11 @@ fun main() { val groupedDelays = sim.simulation.gatherPubDeliveryStats() .aggregateSlowestByPublishTime() .groupBy { - if (it.toPeer.inboundBandwidth.totalBandwidth == slowBandwidth) - "Slow" else "Fast" + if (it.toPeer.inboundBandwidth.totalBandwidth == slowBandwidth) { + "Slow" + } else { + "Fast" + } } .mapValues { it.value.deliveryDelays } return GroupByRangeAggregator(groupedDelays) @@ -293,7 +288,6 @@ fun main() { // logger = {}, nodeCount = 1000, peerBands = band, - floodPublish = false, // randomSeed = 2 ) diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/GossipScoreTestSimulation.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/GossipScoreTestSimulation.kt index a4a64501b..c9758a131 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/GossipScoreTestSimulation.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/main/GossipScoreTestSimulation.kt @@ -16,7 +16,6 @@ fun main() { class GossipScoreTestSimulation { fun run() { - val simConfig = GossipSimConfig( totalPeers = 1000, topics = listOf(Topic(BlocksTopic)), diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/StatsFactory.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/StatsFactory.kt index ded5f14ae..5eb94f547 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/StatsFactory.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/StatsFactory.kt @@ -17,7 +17,7 @@ interface StatsFactory { override fun toString() = "" } - var DEFAULT: StatsFactory = object : StatsFactory { + val DEFAULT: StatsFactory = object : StatsFactory { override fun createStats(name: String) = DescriptiveStatsImpl() } } diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/collect/gossip/GossipMessageResult.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/collect/gossip/GossipMessageResult.kt index 911672dc9..732adedb1 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/collect/gossip/GossipMessageResult.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stats/collect/gossip/GossipMessageResult.kt @@ -165,7 +165,7 @@ class GossipMessageResult( ret += msg curPeer = msg.origMsg.sendingPeer } - return ret.reversed() + return ret.asReversed() } fun findPubMessageFirst(peer: SimPeer, msgId: Long): PubMessageWrapper? = diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/Libp2pConnectionImpl.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/Libp2pConnectionImpl.kt index 9c17a0050..3ca752706 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/Libp2pConnectionImpl.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/Libp2pConnectionImpl.kt @@ -9,8 +9,7 @@ import io.libp2p.simulate.util.NullTransport import io.libp2p.transport.implementation.ConnectionOverNetty class Libp2pConnectionImpl( - val remoteAddr: - Multiaddr, + val remoteAddr: Multiaddr, isInitiator: Boolean, localPubkey: PubKey, remotePubkey: PubKey, diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamNettyChannel.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamNettyChannel.kt index 1ca9d0fa4..b90d7014f 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamNettyChannel.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamNettyChannel.kt @@ -2,6 +2,7 @@ package io.libp2p.simulate.stream import io.libp2p.etc.types.lazyVar import io.libp2p.simulate.* +import io.libp2p.simulate.delay.ChannelMessageDelayer import io.libp2p.simulate.delay.SequentialDelayer.Companion.sequential import io.libp2p.simulate.util.GeneralSizeEstimator import io.netty.channel.Channel @@ -13,7 +14,6 @@ import io.netty.channel.DefaultChannelPromise import io.netty.channel.EventLoop import io.netty.channel.embedded.EmbeddedChannel import io.netty.util.internal.ObjectUtil -import java.util.concurrent.CompletableFuture import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService @@ -26,9 +26,9 @@ class StreamNettyChannel( vararg handlers: ChannelHandler? ) : SimChannel, EmbeddedChannel( - SimChannelId(id), - *handlers -) { + SimChannelId(id), + *handlers + ) { override val msgVisitors: MutableList = mutableListOf() @@ -37,29 +37,17 @@ class StreamNettyChannel( var currentTime: () -> Long = System::currentTimeMillis var msgSizeEstimator = GeneralSizeEstimator private var msgDelayer: MessageDelayer by lazyVar { - createMessageDelayer(outboundBandwidth, MessageDelayer.NO_DELAYER, inboundBandwidth) + createMessageDelayer(MessageDelayer.NO_DELAYER) .sequential(executor) } fun setLatency(latency: MessageDelayer) { - msgDelayer = createMessageDelayer(outboundBandwidth, latency, inboundBandwidth) + msgDelayer = createMessageDelayer(latency) } private fun createMessageDelayer( - outboundBandwidthDelayer: BandwidthDelayer, connectionLatencyDelayer: MessageDelayer, - inboundBandwidthDelayer: BandwidthDelayer, - ): MessageDelayer { - return MessageDelayer { size -> - CompletableFuture.allOf( - outboundBandwidthDelayer.delay(size) - .thenCompose { connectionLatencyDelayer.delay(size) }, - connectionLatencyDelayer.delay(size) - .thenCompose { inboundBandwidthDelayer.delay(size) } - ).thenApply { } - } - .sequential(executor) - } + ): MessageDelayer = ChannelMessageDelayer(executor, outboundBandwidth, connectionLatencyDelayer, inboundBandwidth) @Synchronized fun connect(other: StreamNettyChannel) { diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimPeer.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimPeer.kt index 2958bb1e5..4749122a7 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimPeer.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimPeer.kt @@ -3,7 +3,7 @@ package io.libp2p.simulate.stream import io.libp2p.core.PeerId import io.libp2p.core.Stream import io.libp2p.core.StreamHandler -import io.libp2p.core.crypto.KEY_TYPE +import io.libp2p.core.crypto.KeyType import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multiformats.Multiaddr import io.libp2p.core.multiformats.MultiaddrComponent @@ -42,7 +42,7 @@ abstract class StreamSimPeer( lateinit var currentTime: () -> Long var keyPair by lazyVar { generateKeyPair( - KEY_TYPE.ECDSA, + KeyType.ECDSA, random = SecureRandom(ByteArray(4).also { random.nextBytes(it) }) ) } diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimStream.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimStream.kt index b5a8d1f1b..370949219 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimStream.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/stream/StreamSimStream.kt @@ -22,11 +22,17 @@ class StreamSimStream( init { val from = - if (streamInitiator == SimStream.StreamInitiator.CONNECTION_DIALER) connection.dialer - else connection.listener + if (streamInitiator == SimStream.StreamInitiator.CONNECTION_DIALER) { + connection.dialer + } else { + connection.listener + } val to = - if (streamInitiator == SimStream.StreamInitiator.CONNECTION_LISTENER) connection.dialer - else connection.listener + if (streamInitiator == SimStream.StreamInitiator.CONNECTION_LISTENER) { + connection.dialer + } else { + connection.listener + } val fromIsInitiator = from === connection.dialer val toIsInitiator = !fromIsInitiator @@ -53,7 +59,6 @@ class StreamSimStream( connectionInitiator: Boolean, streamInitiator: Boolean ): StreamNettyChannel { - return StreamNettyChannel( channelName, this, diff --git a/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NumExt.kt b/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NumExt.kt index bc061033c..be3113ac1 100644 --- a/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NumExt.kt +++ b/tools/simulator/src/main/kotlin/io/libp2p/simulate/util/NumExt.kt @@ -8,8 +8,7 @@ fun Collection.groupByRangesBy( valueExtractor: (TSrc) -> TValue, vararg ranges: ClosedRange ): Map, List> - where TKey : Number, TKey : Comparable { - + where TKey : Number, TKey : Comparable { return this .mapNotNull { v -> ranges.firstOrNull { it.contains(keyExtractor(v)) }?.let { it to v } } .groupBy({ it.first }, { valueExtractor(it.second) }) @@ -20,15 +19,19 @@ fun Collection.groupByRangesBy( keyExtractor: (TSrc) -> TKey, vararg ranges: ClosedRange ): Map, List> - where TKey : Number, TKey : Comparable = + where TKey : Number, TKey : Comparable = groupByRangesBy(keyExtractor, { it }, *ranges) -fun Collection>.groupByRanges(vararg ranges: ClosedRange): Map, List> - where T : Number, T : Comparable = +fun Collection>.groupByRanges( + vararg ranges: ClosedRange +): Map, List> + where T : Number, T : Comparable = groupByRangesBy({ it.first }, { it.second }, *ranges) -fun Collection.countByRanges(vararg ranges: ClosedRange): List - where T : Number, T : Comparable { +fun Collection.countByRanges( + vararg ranges: ClosedRange +): List + where T : Number, T : Comparable { val v = this .map { it to it } .groupByRangesBy({ it.first }, { it.second }, *ranges) @@ -36,7 +39,9 @@ fun Collection.countByRanges(vararg ranges: ClosedRange): List return ranges.map { v[it]?.size ?: 0 } } -fun Collection.countByRanges(ranges: List>): List +fun Collection.countByRanges( + ranges: List> +): List where T : Number, T : Comparable = countByRanges(*ranges.toTypedArray()) diff --git a/tools/simulator/src/test/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayerTest.kt b/tools/simulator/src/test/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayerTest.kt new file mode 100644 index 000000000..84294c4df --- /dev/null +++ b/tools/simulator/src/test/kotlin/io/libp2p/simulate/delay/ChannelMessageDelayerTest.kt @@ -0,0 +1,83 @@ +package io.libp2p.simulate.delay + +import io.libp2p.simulate.Bandwidth +import io.libp2p.tools.schedulers.ControlledExecutorServiceImpl +import io.libp2p.tools.schedulers.TimeControllerImpl +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import kotlin.time.Duration.Companion.milliseconds + +class ChannelMessageDelayerTest { + val timeController = TimeControllerImpl() + val executor = ControlledExecutorServiceImpl(timeController) + + private val Int.bytesPerSecond get() = Bandwidth(this.toLong()) + private fun Bandwidth.simpleDelayer() = SimpleBandwidthTracker(this, executor) + private fun Int.millisLatencyDelayer() = TimeDelayer(executor) { this.milliseconds } + + @Test + fun `slow outbound bandwidth prevails`() { + val delayer = ChannelMessageDelayer( + executor, + 1000.bytesPerSecond.simpleDelayer(), + 300.millisLatencyDelayer(), + 1000000.bytesPerSecond.simpleDelayer() + ) + + val delay = delayer.delay(1000).thenApply { timeController.time } + + timeController.addTime(2000) + assertThat(delay).isCompletedWithValue(1300) + } + + @Test + fun `slow inbound bandwidth prevails`() { + val delayer = ChannelMessageDelayer( + executor, + 1000000.bytesPerSecond.simpleDelayer(), + 300.millisLatencyDelayer(), + 1000.bytesPerSecond.simpleDelayer() + ) + + val delay = delayer.delay(1000).thenApply { timeController.time } + timeController.addTime(2000) + + assertThat(delay).isCompletedWithValue(1300) + } + + @ParameterizedTest + @ValueSource(booleans = [true, false]) + fun `subsequent messages ordered and timely`(outboundBandwidthSlower: Boolean) { + val delayer = + when (outboundBandwidthSlower) { + true -> ChannelMessageDelayer( + executor, + 1000.bytesPerSecond.simpleDelayer(), + 300.millisLatencyDelayer(), + 1000000.bytesPerSecond.simpleDelayer() + ) + + false -> ChannelMessageDelayer( + executor, + 1000000.bytesPerSecond.simpleDelayer(), + 300.millisLatencyDelayer(), + 1000.bytesPerSecond.simpleDelayer() + ) + } + + val delay1 = delayer.delay(1000).thenApply { timeController.time } + val delay2 = delayer.delay(10).thenApply { timeController.time } + timeController.addTime(200) + val delay3 = delayer.delay(10).thenApply { timeController.time } + timeController.addTime(1099) + val delay4 = delayer.delay(10).thenApply { timeController.time } + timeController.addTime(10000) + + assertThat(delay1).isCompletedWithValue(1300) + assertThat(delay2).isCompletedWithValue(1310) + assertThat(delay3).isCompletedWithValue(1320) + assertThat(delay4).isCompletedWithValue(1609) + } +} diff --git a/versions.gradle b/versions.gradle index 9bc06333d..ba4a91e5c 100644 --- a/versions.gradle +++ b/versions.gradle @@ -1,33 +1,37 @@ dependencyManagement { + // https://docs.spring.io/dependency-management-plugin/docs/current/reference/html/#pom-generation-disabling + generatedPomCustomization { + enabled = false + } dependencies { dependency "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4" - dependency "com.google.guava:guava:31.1-jre" + dependency "com.google.guava:guava:33.3.1-jre" - dependency "org.slf4j:slf4j-api:2.0.7" - dependencySet(group: 'org.apache.logging.log4j', version: '2.20.0') { + dependency "org.slf4j:slf4j-api:2.0.9" + dependencySet(group: 'org.apache.logging.log4j', version: '2.24.1') { entry 'log4j-core' entry 'log4j-slf4j2-impl' } - dependencySet(group: 'org.junit.jupiter', version: '5.9.2') { + dependencySet(group: 'org.junit.jupiter', version: '5.11.3') { entry 'junit-jupiter-api' entry 'junit-jupiter-engine' entry 'junit-jupiter-params' } dependency "io.mockk:mockk:1.13.3" - dependency "org.assertj:assertj-core:3.24.2" + dependency "org.assertj:assertj-core:3.26.3" - dependencySet(group: "org.openjdk.jmh", version: "1.36") { + dependencySet(group: "org.openjdk.jmh", version: "1.37") { entry 'jmh-core' entry 'jmh-generator-annprocess' } - dependencySet(group: "com.google.protobuf", version: "3.21.12") { + dependencySet(group: "com.google.protobuf", version: "3.25.5") { entry 'protobuf-java' entry 'protoc' } - dependencySet(group: "io.netty", version: "4.1.90.Final") { + dependencySet(group: "io.netty", version: "4.1.115.Final") { entry 'netty-common' entry 'netty-handler' entry 'netty-transport' @@ -35,12 +39,12 @@ dependencyManagement { entry 'netty-codec-http' entry 'netty-transport-classes-epoll' } - dependency "commons-codec:commons-codec:1.15" + dependency "com.github.multiformats:java-multibase:v1.1.1" dependency "tech.pegasys:noise-java:22.1.0" - dependencySet(group: "org.bouncycastle", version: "1.70") { - entry 'bcprov-jdk15on' - entry 'bcpkix-jdk15on' - entry 'bctls-jdk15on' + dependencySet(group: "org.bouncycastle", version: "1.78.1") { + entry 'bcprov-jdk18on' + entry 'bcpkix-jdk18on' + entry 'bctls-jdk18on' } dependency "io.netty.incubator:netty-incubator-codec-native-quic:0.0.38.Final" }