diff --git a/Sources/GRPC/CallOptions.swift b/Sources/GRPC/CallOptions.swift index 866d888bb..4589ef928 100644 --- a/Sources/GRPC/CallOptions.swift +++ b/Sources/GRPC/CallOptions.swift @@ -133,6 +133,7 @@ extension CallOptions { self.source = source } + @usableFromInline internal func requestID() -> String? { switch self.source { case .none: diff --git a/Sources/GRPC/ConnectionPool/PooledChannel.swift b/Sources/GRPC/ConnectionPool/PooledChannel.swift index 022f30510..d3a407178 100644 --- a/Sources/GRPC/ConnectionPool/PooledChannel.swift +++ b/Sources/GRPC/ConnectionPool/PooledChannel.swift @@ -132,6 +132,11 @@ internal final class PooledChannel: GRPCChannel { callOptions: CallOptions, interceptors: [ClientInterceptor] ) -> Call where Request: Message, Response: Message { + var callOptions = callOptions + if let requestID = callOptions.requestIDProvider.requestID() { + callOptions.applyRequestID(requestID) + } + let (stream, eventLoop) = self._makeStreamChannel(callOptions: callOptions) return Call( @@ -157,6 +162,11 @@ internal final class PooledChannel: GRPCChannel { callOptions: CallOptions, interceptors: [ClientInterceptor] ) -> Call where Request: GRPCPayload, Response: GRPCPayload { + var callOptions = callOptions + if let requestID = callOptions.requestIDProvider.requestID() { + callOptions.applyRequestID(requestID) + } + let (stream, eventLoop) = self._makeStreamChannel(callOptions: callOptions) return Call( @@ -192,3 +202,14 @@ internal final class PooledChannel: GRPCChannel { self._pool.shutdown(mode: .graceful(deadline), promise: promise) } } + +extension CallOptions { + @usableFromInline + mutating func applyRequestID(_ requestID: String) { + self.logger[metadataKey: MetadataKey.requestID] = "\(requestID)" + // Add the request ID header too. + if let requestIDHeader = self.requestIDHeader { + self.customMetadata.add(name: requestIDHeader, value: requestID) + } + } +} diff --git a/Tests/GRPCTests/EchoHelpers/Providers/MetadataEchoProvider.swift b/Tests/GRPCTests/EchoHelpers/Providers/MetadataEchoProvider.swift new file mode 100644 index 000000000..2b39df4e1 --- /dev/null +++ b/Tests/GRPCTests/EchoHelpers/Providers/MetadataEchoProvider.swift @@ -0,0 +1,54 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoModel +import GRPC +import NIOCore + +internal final class MetadataEchoProvider: Echo_EchoProvider { + let interceptors: Echo_EchoServerInterceptorFactoryProtocol? = nil + + func get( + request: Echo_EchoRequest, + context: StatusOnlyCallContext + ) -> EventLoopFuture { + let response = Echo_EchoResponse.with { + $0.text = context.headers.sorted(by: { $0.name < $1.name }).map { + $0.name + ": " + $0.value + }.joined(separator: "\n") + } + + return context.eventLoop.makeSucceededFuture(response) + } + + func expand( + request: Echo_EchoRequest, + context: StreamingResponseCallContext + ) -> EventLoopFuture { + return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unimplemented)) + } + + func collect( + context: UnaryResponseCallContext + ) -> EventLoopFuture<(StreamEvent) -> Void> { + return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unimplemented)) + } + + func update( + context: StreamingResponseCallContext + ) -> EventLoopFuture<(StreamEvent) -> Void> { + return context.eventLoop.makeFailedFuture(GRPCStatus(code: .unimplemented)) + } +} diff --git a/Tests/GRPCTests/RequestIDTests.swift b/Tests/GRPCTests/RequestIDTests.swift new file mode 100644 index 000000000..e3d2f662a --- /dev/null +++ b/Tests/GRPCTests/RequestIDTests.swift @@ -0,0 +1,85 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoModel +import GRPC +import NIOCore +import NIOPosix +import XCTest + +internal final class RequestIDTests: GRPCTestCase { + private var server: Server! + private var group: EventLoopGroup! + + override func setUp() { + super.setUp() + + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.server = try! Server.insecure(group: self.group) + .withServiceProviders([MetadataEchoProvider()]) + .withLogger(self.serverLogger) + .bind(host: "127.0.0.1", port: 0) + .wait() + } + + override func tearDown() { + XCTAssertNoThrow(try self.server.close().wait()) + XCTAssertNoThrow(try self.group.syncShutdownGracefully()) + super.tearDown() + } + + func testRequestIDIsPopulatedClientConnection() throws { + let channel = ClientConnection.insecure(group: self.group) + .connect(host: "127.0.0.1", port: self.server.channel.localAddress!.port!) + + defer { + let loop = group.next() + let promise = loop.makePromise(of: Void.self) + channel.closeGracefully(deadline: .now() + .seconds(30), promise: promise) + XCTAssertNoThrow(try promise.futureResult.wait()) + } + + try self._testRequestIDIsPopulated(channel: channel) + } + + func testRequestIDIsPopulatedChannelPool() throws { + let channel = try! GRPCChannelPool.with( + target: .host("127.0.0.1", port: self.server.channel.localAddress!.port!), + transportSecurity: .plaintext, + eventLoopGroup: self.group + ) + + defer { + let loop = group.next() + let promise = loop.makePromise(of: Void.self) + channel.closeGracefully(deadline: .now() + .seconds(30), promise: promise) + XCTAssertNoThrow(try promise.futureResult.wait()) + } + + try self._testRequestIDIsPopulated(channel: channel) + } + + func _testRequestIDIsPopulated(channel: GRPCChannel) throws { + let echo = Echo_EchoClient(channel: channel) + let options = CallOptions( + requestIDProvider: .userDefined("foo"), + requestIDHeader: "request-id-header" + ) + + let get = echo.get(.with { $0.text = "ignored" }, callOptions: options) + let response = try get.response.wait() + XCTAssert(response.text.contains("request-id-header: foo")) + } +}