diff --git a/Sources/LiveKit/Core/DataChannelPair.swift b/Sources/LiveKit/Core/DataChannelPair.swift index a5f4a1d9d..9827acaa8 100644 --- a/Sources/LiveKit/Core/DataChannelPair.swift +++ b/Sources/LiveKit/Core/DataChannelPair.swift @@ -104,19 +104,21 @@ class DataChannelPair: NSObject, Loggable { } public func send(userPacket: Livekit_UserPacket, kind: Livekit_DataPacket.Kind) throws { - guard isOpen else { - throw LiveKitError(.invalidState, message: "Data channel is not open") - } - - let packet = Livekit_DataPacket.with { + try send(dataPacket: .with { $0.kind = kind $0.user = userPacket + }) + } + + public func send(dataPacket packet: Livekit_DataPacket) throws { + guard isOpen else { + throw LiveKitError(.invalidState, message: "Data channel is not open") } let serializedData = try packet.serializedData() let rtcData = RTC.createDataBuffer(data: serializedData) - let channel = _state.read { kind == .reliable ? $0.reliable : $0.lossy } + let channel = _state.read { packet.kind == .reliable ? $0.reliable : $0.lossy } guard let sendDataResult = channel?.sendData(rtcData), sendDataResult else { throw LiveKitError(.invalidState, message: "sendData failed") } diff --git a/Sources/LiveKit/Core/RPC.swift b/Sources/LiveKit/Core/RPC.swift new file mode 100644 index 000000000..590dcb010 --- /dev/null +++ b/Sources/LiveKit/Core/RPC.swift @@ -0,0 +1,187 @@ +/* + * Copyright 2025 LiveKit + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Foundation + +/// Specialized error handling for RPC methods. +/// +/// Instances of this type, when thrown in a RPC method handler, will have their `message` +/// serialized and sent across the wire. The sender will receive an equivalent error on the other side. +/// +/// Built-in types are included but developers may use any message string, with a max length of 256 bytes. +struct RpcError: Error { + /// The error code of the RPC call. Error codes 1001-1999 are reserved for built-in errors. + /// + /// See `RpcError.BuiltInError` for built-in error information. + let code: Int + + /// A message to include. Strings over 256 bytes will be truncated. + let message: String + + /// An optional data payload. Must be smaller than 15KB in size, or else will be truncated. + let data: String + + enum BuiltInError { + case applicationError + case connectionTimeout + case responseTimeout + case recipientDisconnected + case responsePayloadTooLarge + case sendFailed + case unsupportedMethod + case recipientNotFound + case requestPayloadTooLarge + case unsupportedServer + case unsupportedVersion + + var code: Int { + switch self { + case .applicationError: return 1500 + case .connectionTimeout: return 1501 + case .responseTimeout: return 1502 + case .recipientDisconnected: return 1503 + case .responsePayloadTooLarge: return 1504 + case .sendFailed: return 1505 + case .unsupportedMethod: return 1400 + case .recipientNotFound: return 1401 + case .requestPayloadTooLarge: return 1402 + case .unsupportedServer: return 1403 + case .unsupportedVersion: return 1404 + } + } + + var message: String { + switch self { + case .applicationError: return "Application error in method handler" + case .connectionTimeout: return "Connection timeout" + case .responseTimeout: return "Response timeout" + case .recipientDisconnected: return "Recipient disconnected" + case .responsePayloadTooLarge: return "Response payload too large" + case .sendFailed: return "Failed to send" + case .unsupportedMethod: return "Method not supported at destination" + case .recipientNotFound: return "Recipient not found" + case .requestPayloadTooLarge: return "Request payload too large" + case .unsupportedServer: return "RPC not supported by server" + case .unsupportedVersion: return "Unsupported RPC version" + } + } + + func create(data: String = "") -> RpcError { + RpcError(code: code, message: message, data: data) + } + } + + static func builtIn(_ key: BuiltInError, data: String = "") -> RpcError { + RpcError(code: key.code, message: key.message, data: data) + } + + static let MAX_MESSAGE_BYTES = 256 + static let MAX_DATA_BYTES = 15360 // 15 KB + + static func fromProto(_ proto: Livekit_RpcError) -> RpcError { + RpcError( + code: Int(proto.code), + message: (proto.message).truncate(maxBytes: MAX_MESSAGE_BYTES), + data: proto.data.truncate(maxBytes: MAX_DATA_BYTES) + ) + } + + func toProto() -> Livekit_RpcError { + Livekit_RpcError.with { + $0.code = UInt32(code) + $0.message = message + $0.data = data + } + } +} + +/* + * Maximum payload size for RPC requests and responses. If a payload exceeds this size, + * the RPC call will fail with a REQUEST_PAYLOAD_TOO_LARGE(1402) or RESPONSE_PAYLOAD_TOO_LARGE(1504) error. + */ +let MAX_RPC_PAYLOAD_BYTES = 15360 // 15 KB + +/// A handler that processes an RPC request and returns a string +/// that will be sent back to the requester. +/// +/// Throwing an `RpcError` will send the error back to the requester. +/// +/// - SeeAlso: `LocalParticipant.registerRpcMethod` +public typealias RpcHandler = (RpcInvocationData) async throws -> String + +public struct RpcInvocationData { + /// A unique identifier for this RPC request + let requestId: String + + /// The identity of the RemoteParticipant who initiated the RPC call + let callerIdentity: Participant.Identity + + /// The data sent by the caller (as a string) + let payload: String + + /// The maximum time available to return a response + let responseTimeout: TimeInterval +} + +struct PendingRpcResponse { + let participantIdentity: Participant.Identity + let onResolve: (_ payload: String?, _ error: RpcError?) -> Void +} + +actor RpcStateManager { + private var handlers: [String: RpcHandler] = [:] // methodName to handler + private var pendingAcks: Set = Set() + private var pendingResponses: [String: PendingRpcResponse] = [:] // requestId to pending response + + func registerHandler(_ method: String, handler: @escaping RpcHandler) { + handlers[method] = handler + } + + func unregisterHandler(_ method: String) { + handlers.removeValue(forKey: method) + } + + func getHandler(for method: String) -> RpcHandler? { + handlers[method] + } + + func addPendingAck(_ requestId: String) { + pendingAcks.insert(requestId) + } + + @discardableResult + func removePendingAck(_ requestId: String) -> Bool { + pendingAcks.remove(requestId) != nil + } + + func hasPendingAck(_ requestId: String) -> Bool { + pendingAcks.contains(requestId) + } + + func setPendingResponse(_ requestId: String, response: PendingRpcResponse) { + pendingResponses[requestId] = response + } + + @discardableResult + func removePendingResponse(_ requestId: String) -> PendingRpcResponse? { + pendingResponses.removeValue(forKey: requestId) + } + + func removeAllPending(_ requestId: String) async { + pendingAcks.remove(requestId) + pendingResponses.removeValue(forKey: requestId) + } +} diff --git a/Sources/LiveKit/Core/Room+Engine.swift b/Sources/LiveKit/Core/Room+Engine.swift index 0b197b374..5b85b877e 100644 --- a/Sources/LiveKit/Core/Room+Engine.swift +++ b/Sources/LiveKit/Core/Room+Engine.swift @@ -69,6 +69,13 @@ extension Room { } func send(userPacket: Livekit_UserPacket, kind: Livekit_DataPacket.Kind) async throws { + try await send(dataPacket: .with { + $0.user = userPacket + $0.kind = kind + }) + } + + func send(dataPacket packet: Livekit_DataPacket) async throws { func ensurePublisherConnected() async throws { guard _state.isSubscriberPrimary else { return } @@ -96,7 +103,7 @@ extension Room { } // Should return true if successful - try publisherDataChannel.send(userPacket: userPacket, kind: kind) + try publisherDataChannel.send(dataPacket: packet) } } diff --git a/Sources/LiveKit/Core/Room+EngineDelegate.swift b/Sources/LiveKit/Core/Room+EngineDelegate.swift index b2e001083..12585d438 100644 --- a/Sources/LiveKit/Core/Room+EngineDelegate.swift +++ b/Sources/LiveKit/Core/Room+EngineDelegate.swift @@ -239,4 +239,39 @@ extension Room { $0.participant?(participant, trackPublication: publication, didReceiveTranscriptionSegments: segments) } } + + func room(didReceiveRpcResponse response: Livekit_RpcResponse) { + let (payload, error): (String?, RpcError?) = switch response.value { + case let .payload(v): (v, nil) + case let .error(e): (nil, RpcError.fromProto(e)) + default: (nil, nil) + } + + localParticipant.handleIncomingRpcResponse(requestId: response.requestID, + payload: payload, + error: error) + } + + func room(didReceiveRpcAck ack: Livekit_RpcAck) { + let requestId = ack.requestID + localParticipant.handleIncomingRpcAck(requestId: requestId) + } + + func room(didReceiveRpcRequest request: Livekit_RpcRequest, from participantIdentity: String) { + let callerIdentity = Participant.Identity(from: participantIdentity) + let requestId = request.id + let method = request.method + let payload = request.payload + let responseTimeout = TimeInterval(UInt64(request.responseTimeoutMs) / MSEC_PER_SEC) + let version = Int(request.version) + + Task { + await localParticipant.handleIncomingRpcRequest(callerIdentity: callerIdentity, + requestId: requestId, + method: method, + payload: payload, + responseTimeout: responseTimeout, + version: version) + } + } } diff --git a/Sources/LiveKit/Core/Room.swift b/Sources/LiveKit/Core/Room.swift index 953a8f439..96ed89908 100644 --- a/Sources/LiveKit/Core/Room.swift +++ b/Sources/LiveKit/Core/Room.swift @@ -535,6 +535,9 @@ extension Room: DataChannelDelegate { case let .speaker(update): engine(self, didUpdateSpeakers: update.speakers) case let .user(userPacket): engine(self, didReceiveUserPacket: userPacket) case let .transcription(packet): room(didReceiveTranscriptionPacket: packet) + case let .rpcResponse(response): room(didReceiveRpcResponse: response) + case let .rpcAck(ack): room(didReceiveRpcAck: ack) + case let .rpcRequest(request): room(didReceiveRpcRequest: request, from: dataPacket.participantIdentity) default: return } } diff --git a/Sources/LiveKit/Errors.swift b/Sources/LiveKit/Errors.swift index 04f3186ac..d6f103b05 100644 --- a/Sources/LiveKit/Errors.swift +++ b/Sources/LiveKit/Errors.swift @@ -29,6 +29,7 @@ public enum LiveKitErrorType: Int, Sendable { case failedToParseUrl = 102 case failedToConvertData = 103 case invalidState = 104 + case invalidParameter = 105 case webRTC = 201 @@ -66,6 +67,8 @@ extension LiveKitErrorType: CustomStringConvertible { return "Failed to convert data" case .invalidState: return "Invalid state" + case .invalidParameter: + return "Invalid parameter" case .webRTC: return "WebRTC error" case .network: diff --git a/Sources/LiveKit/Extensions/String.swift b/Sources/LiveKit/Extensions/String.swift index cd45e15bf..271ae03d8 100644 --- a/Sources/LiveKit/Extensions/String.swift +++ b/Sources/LiveKit/Extensions/String.swift @@ -21,4 +21,29 @@ extension String { var nilIfEmpty: String? { isEmpty ? nil : self } + + var byteLength: Int { + data(using: .utf8)?.count ?? 0 + } + + func truncate(maxBytes: Int) -> String { + if byteLength <= maxBytes { + return self + } + + var low = 0 + var high = count + + while low < high { + let mid = (low + high + 1) / 2 + let substring = String(prefix(mid)) + if substring.byteLength <= maxBytes { + low = mid + } else { + high = mid - 1 + } + } + + return String(prefix(low)) + } } diff --git a/Sources/LiveKit/Participant/LocalParticipant+RPC.swift b/Sources/LiveKit/Participant/LocalParticipant+RPC.swift new file mode 100644 index 000000000..441e22ade --- /dev/null +++ b/Sources/LiveKit/Participant/LocalParticipant+RPC.swift @@ -0,0 +1,292 @@ +/* + * Copyright 2025 LiveKit + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Foundation + +// MARK: - Public RPC methods + +public extension LocalParticipant { + /// Establishes the participant as a receiver for calls of the specified RPC method. + /// Will overwrite any existing callback for the same method. + /// + /// Example: + /// ```swift + /// try await room.localParticipant.registerRpcMethod("greet") { data in + /// print("Received greeting from \(data.callerIdentity): \(data.payload)") + /// return "Hello, \(data.callerIdentity)!" + /// } + /// ``` + /// + /// The handler receives an `RpcInvocationData` containing the following parameters: + /// - `requestId`: A unique identifier for this RPC request + /// - `callerIdentity`: The identity of the RemoteParticipant who initiated the RPC call + /// - `payload`: The data sent by the caller (as a string) + /// - `responseTimeout`: The maximum time available to return a response + /// + /// The handler should return a string. + /// If unable to respond within responseTimeout, the request will result in an error on the caller's side. + /// + /// You may throw errors of type RpcError with a string message in the handler, + /// and they will be received on the caller's side with the message intact. + /// Other errors thrown in your handler will not be transmitted as-is, and will instead arrive to the caller as 1500 ("Application Error"). + /// + /// - Parameters: + /// - method: The name of the indicated RPC method + /// - handler: Will be invoked when an RPC request for this method is received + func registerRpcMethod(_ method: String, + handler: @escaping RpcHandler) async + { + await rpcState.registerHandler(method, handler: handler) + } + + /// Unregisters a previously registered RPC method. + /// + /// - Parameter method: The name of the RPC method to unregister + func unregisterRpcMethod(_ method: String) async { + await rpcState.unregisterHandler(method) + } + + /// Initiate an RPC call to a remote participant + /// - Parameters: + /// - destinationIdentity: The identity of the destination participant + /// - method: The method name to call + /// - payload: The payload to pass to the method + /// - responseTimeout: Timeout for receiving a response after initial connection. (default 10s) + /// - Returns: The response payload + /// - Throws: RpcError on failure. Details in RpcError.message + func performRpc(destinationIdentity: Identity, + method: String, + payload: String, + responseTimeout: TimeInterval = 10) async throws -> String + { + guard payload.byteLength <= MAX_RPC_PAYLOAD_BYTES else { + throw RpcError.builtIn(.requestPayloadTooLarge) + } + + let requestId = UUID().uuidString + let maxRoundTripLatency: TimeInterval = 2 + let effectiveTimeout = responseTimeout - maxRoundTripLatency + + try await publishRpcRequest(destinationIdentity: destinationIdentity, + requestId: requestId, + method: method, + payload: payload, + responseTimeout: effectiveTimeout) + + do { + return try await withThrowingTimeout(timeout: responseTimeout) { + try await withCheckedThrowingContinuation { continuation in + Task { + await self.rpcState.addPendingAck(requestId) + + await self.rpcState.setPendingResponse(requestId, response: PendingRpcResponse( + participantIdentity: destinationIdentity, + onResolve: { payload, error in + Task { + await self.rpcState.removePendingAck(requestId) + await self.rpcState.removePendingResponse(requestId) + + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: payload ?? "") + } + } + } + )) + } + + Task { + try await Task.sleep(nanoseconds: UInt64(maxRoundTripLatency * 1_000_000_000)) + + if await self.rpcState.hasPendingAck(requestId) { + await self.rpcState.removeAllPending(requestId) + continuation.resume(throwing: RpcError.builtIn(.connectionTimeout)) + } + } + } + } + } catch { + if let error = error as? LiveKitError { + if error.type == .timedOut { + throw RpcError.builtIn(.connectionTimeout) + } + } + throw error + } + } +} + +// MARK: - RPC Internal + +extension LocalParticipant { + private func publishRpcRequest(destinationIdentity: Identity, + requestId: String, + method: String, + payload: String, + responseTimeout: TimeInterval = 10) async throws + { + guard payload.byteLength <= MAX_RPC_PAYLOAD_BYTES else { + throw LiveKitError(.invalidParameter, + message: "cannot publish data larger than \(MAX_RPC_PAYLOAD_BYTES)") + } + + let room = try requireRoom() + + let dataPacket = Livekit_DataPacket.with { + $0.destinationIdentities = [destinationIdentity.stringValue] + $0.kind = .reliable + $0.rpcRequest = Livekit_RpcRequest.with { + $0.id = requestId + $0.method = method + $0.payload = payload + $0.responseTimeoutMs = UInt32(responseTimeout * 1000) + $0.version = 1 + } + } + + try await room.send(dataPacket: dataPacket) + } + + private func publishRpcResponse(destinationIdentity: Identity, + requestId: String, + payload: String?, + error: RpcError?) async throws + { + let room = try requireRoom() + + let dataPacket = Livekit_DataPacket.with { + $0.destinationIdentities = [destinationIdentity.stringValue] + $0.kind = .reliable + $0.rpcResponse = Livekit_RpcResponse.with { + $0.requestID = requestId + if let error { + $0.error = error.toProto() + } else { + $0.payload = payload ?? "" + } + } + } + + try await room.send(dataPacket: dataPacket) + } + + private func publishRpcAck(destinationIdentity: Identity, + requestId: String) async throws + { + let room = try requireRoom() + + let dataPacket = Livekit_DataPacket.with { + $0.destinationIdentities = [destinationIdentity.stringValue] + $0.kind = .reliable + $0.rpcAck = Livekit_RpcAck.with { + $0.requestID = requestId + } + } + + try await room.send(dataPacket: dataPacket) + } + + func handleIncomingRpcRequest(callerIdentity: Identity, + requestId: String, + method: String, + payload: String, + responseTimeout: TimeInterval, + version: Int) async + { + do { + try await publishRpcAck(destinationIdentity: callerIdentity, + requestId: requestId) + } catch { + log("[Rpc] Failed to publish RPC ack for \(requestId)", .error) + } + + guard version == 1 else { + do { + try await publishRpcResponse(destinationIdentity: callerIdentity, + requestId: requestId, + payload: nil, + error: RpcError.builtIn(.unsupportedVersion)) + } catch { + log("[Rpc] Failed to publish RPC error response for \(requestId)", .error) + } + return + } + + guard let handler = await rpcState.getHandler(for: method) else { + do { + try await publishRpcResponse(destinationIdentity: callerIdentity, + requestId: requestId, + payload: nil, + error: RpcError.builtIn(.unsupportedMethod)) + } catch { + log("[Rpc] Failed to publish RPC error response for \(requestId)", .error) + } + return + } + + var responseError: RpcError? + var responsePayload: String? + + do { + let response = try await handler(RpcInvocationData(requestId: requestId, + callerIdentity: callerIdentity, + payload: payload, + responseTimeout: responseTimeout)) + + if response.byteLength > MAX_RPC_PAYLOAD_BYTES { + responseError = RpcError.builtIn(.responsePayloadTooLarge) + log("[Rpc] Response payload too large for \(method)", .warning) + } else { + responsePayload = response + } + } catch let error as RpcError { + responseError = error + } catch { + log("[Rpc] Uncaught error returned by RPC handler for \(method). Returning APPLICATION_ERROR instead.", .warning) + responseError = RpcError.builtIn(.applicationError) + } + + do { + try await publishRpcResponse(destinationIdentity: callerIdentity, + requestId: requestId, + payload: responsePayload, + error: responseError) + } catch { + log("[Rpc] Failed to publish RPC response for \(requestId)", .error) + } + } + + func handleIncomingRpcAck(requestId: String) { + Task { + await rpcState.removePendingAck(requestId) + } + } + + func handleIncomingRpcResponse(requestId: String, + payload: String?, + error: RpcError?) + { + Task { + guard let handler = await rpcState.removePendingResponse(requestId) else { + log("[Rpc] Response received for unexpected RPC request, id = \(requestId)", .error) + return + } + + handler.onResolve(payload, error) + } + } +} diff --git a/Sources/LiveKit/Participant/LocalParticipant.swift b/Sources/LiveKit/Participant/LocalParticipant.swift index f197a2719..2eafd55ee 100644 --- a/Sources/LiveKit/Participant/LocalParticipant.swift +++ b/Sources/LiveKit/Participant/LocalParticipant.swift @@ -38,6 +38,8 @@ public class LocalParticipant: Participant { private var trackPermissions: [ParticipantTrackPermission] = [] + let rpcState = RpcStateManager() + /// publish a new audio track to the Room @objc @discardableResult diff --git a/Sources/LiveKit/Support/Global.swift b/Sources/LiveKit/Support/Global.swift index 8c9ff22c4..d250d4721 100644 --- a/Sources/LiveKit/Support/Global.swift +++ b/Sources/LiveKit/Support/Global.swift @@ -14,7 +14,36 @@ * limitations under the License. */ +import Foundation + // merge a ClosedRange func merge(range range1: ClosedRange, with range2: ClosedRange) -> ClosedRange where T: Comparable { min(range1.lowerBound, range2.lowerBound) ... max(range1.upperBound, range2.upperBound) } + +// throws a timeout if the operation takes longer than the given timeout +func withThrowingTimeout(timeout: TimeInterval, + operation: @escaping () async throws -> T) async throws -> T +{ + try await withThrowingTaskGroup(of: T.self) { group in + group.addTask { + try await operation() + } + + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + throw LiveKitError(.timedOut) + } + + let result = try await group.next() + + group.cancelAll() + + guard let result else { + // This should never happen since we know we added tasks + throw LiveKitError(.invalidState) + } + + return result + } +} diff --git a/Tests/LiveKitTests/Extensions/StringTests.swift b/Tests/LiveKitTests/Extensions/StringTests.swift new file mode 100644 index 000000000..4d7826af3 --- /dev/null +++ b/Tests/LiveKitTests/Extensions/StringTests.swift @@ -0,0 +1,47 @@ +/* + * Copyright 2025 LiveKit + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@testable import LiveKit +import XCTest + +final class StringTests: XCTestCase { + func testByteLength() { + // ASCII characters (1 byte each) + XCTAssertEqual("hello".byteLength, 5) + XCTAssertEqual("".byteLength, 0) + + // Unicode characters (variable bytes) + XCTAssertEqual("👋".byteLength, 4) // Emoji (4 bytes) + XCTAssertEqual("ñ".byteLength, 2) // Spanish n with tilde (2 bytes) + XCTAssertEqual("你好".byteLength, 6) // Chinese characters (3 bytes each) + } + + func testTruncate() { + // Test ASCII strings + XCTAssertEqual("hello".truncate(maxBytes: 5), "hello") + XCTAssertEqual("hello".truncate(maxBytes: 3), "hel") + XCTAssertEqual("".truncate(maxBytes: 5), "") + + // Test Unicode strings + XCTAssertEqual("👋hello".truncate(maxBytes: 4), "👋") // Emoji is 4 bytes + XCTAssertEqual("hi👋".truncate(maxBytes: 5), "hi") // Won't cut in middle of emoji + XCTAssertEqual("你好world".truncate(maxBytes: 6), "你好") // Chinese characters are 3 bytes each + + // Test edge cases + XCTAssertEqual("hello".truncate(maxBytes: 0), "") + XCTAssertEqual("hello".truncate(maxBytes: 100), "hello") + } +} diff --git a/Tests/LiveKitTests/RpcTests.swift b/Tests/LiveKitTests/RpcTests.swift new file mode 100644 index 000000000..bd9567352 --- /dev/null +++ b/Tests/LiveKitTests/RpcTests.swift @@ -0,0 +1,209 @@ +/* + * Copyright 2025 LiveKit + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@testable import LiveKit +import XCTest + +class RpcTests: XCTestCase { + // Mock DataChannelPair to intercept outgoing packets + class MockDataChannelPair: DataChannelPair { + var packetHandler: (Livekit_DataPacket) -> Void + + init(packetHandler: @escaping (Livekit_DataPacket) -> Void) { + self.packetHandler = packetHandler + } + + override func send(dataPacket packet: Livekit_DataPacket) throws { + packetHandler(packet) + } + } + + // Test performing RPC calls and verifying outgoing packets + func testPerformRpc() async throws { + try await withRooms([RoomTestingOptions()]) { rooms in + let room = rooms[0] + + let expectRequest = self.expectation(description: "Should send RPC request packet") + + let mockDataChannel = MockDataChannelPair { packet in + guard case let .rpcRequest(request) = packet.value else { + print("Not an RPC request packet") + return + } + + guard request.method == "test-method", request.payload == "test-payload", request.responseTimeoutMs == 8000 else { + return + } + + // Trigger fake response packets + Task { + try await Task.sleep(nanoseconds: 100_000_000) + + room.localParticipant.handleIncomingRpcAck(requestId: request.id) + + try await Task.sleep(nanoseconds: 100_000_000) + + room.localParticipant.handleIncomingRpcResponse( + requestId: request.id, + payload: "response-payload", + error: nil + ) + } + expectRequest.fulfill() + } + + room.publisherDataChannel = mockDataChannel + + let response = try await room.localParticipant.performRpc( + destinationIdentity: Participant.Identity(from: "test-destination"), + method: "test-method", + payload: "test-payload" + ) + + XCTAssertEqual(response, "response-payload") + await self.fulfillment(of: [expectRequest], timeout: 5.0) + } + } + + // Test registering and handling incoming RPC requests + func testHandleIncomingRpcRequest() async throws { + try await withRooms([RoomTestingOptions()]) { rooms in + let room = rooms[0] + + let expectResponse = self.expectation(description: "Should send RPC response packet") + + let mockDataChannel = MockDataChannelPair { packet in + guard case let .rpcResponse(response) = packet.value else { + return + } + + guard case let .payload(payload) = response.value else { + return + } + + guard response.requestID == "test-request-1", + payload == "Hello, test-caller!" + else { + return + } + + expectResponse.fulfill() + } + + room.publisherDataChannel = mockDataChannel + + await room.localParticipant.registerRpcMethod("greet") { data in + "Hello, \(data.callerIdentity)!" + } + + await room.localParticipant.handleIncomingRpcRequest( + callerIdentity: Participant.Identity(from: "test-caller"), + requestId: "test-request-1", + method: "greet", + payload: "Hi there!", + responseTimeout: 8, + version: 1 + ) + + await self.fulfillment(of: [expectResponse], timeout: 5.0) + } + } + + // Test error handling for RPC calls + func testRpcErrorHandling() async throws { + try await withRooms([RoomTestingOptions()]) { rooms in + let room = rooms[0] + + let expectError = self.expectation(description: "Should send error response packet") + + let mockDataChannel = MockDataChannelPair { packet in + guard case let .rpcResponse(response) = packet.value, + case let .error(error) = response.value + else { + return + } + + guard error.code == 2000, + error.message == "Custom error", + error.data == "Additional data" + else { + return + } + + expectError.fulfill() + } + + room.publisherDataChannel = mockDataChannel + + await room.localParticipant.registerRpcMethod("failingMethod") { _ in + throw RpcError(code: 2000, message: "Custom error", data: "Additional data") + } + + await room.localParticipant.handleIncomingRpcRequest( + callerIdentity: Participant.Identity(from: "test-caller"), + requestId: "test-request-1", + method: "failingMethod", + payload: "test", + responseTimeout: 8, + version: 1 + ) + + await self.fulfillment(of: [expectError], timeout: 5.0) + } + } + + // Test unregistering RPC methods + func testUnregisterRpcMethod() async throws { + try await withRooms([RoomTestingOptions()]) { rooms in + let room = rooms[0] + + let expectUnsupportedMethod = self.expectation(description: "Should send unsupported method error packet") + + let mockDataChannel = MockDataChannelPair { packet in + guard case let .rpcResponse(response) = packet.value, + case let .error(error) = response.value + else { + return + } + + guard error.code == RpcError.BuiltInError.unsupportedMethod.code else { + return + } + + expectUnsupportedMethod.fulfill() + } + + room.publisherDataChannel = mockDataChannel + + await room.localParticipant.registerRpcMethod("test") { _ in + "test response" + } + + await room.localParticipant.unregisterRpcMethod("test") + + await room.localParticipant.handleIncomingRpcRequest( + callerIdentity: Participant.Identity(from: "test-caller"), + requestId: "test-request-1", + method: "test", + payload: "test", + responseTimeout: 10, + version: 1 + ) + + await self.fulfillment(of: [expectUnsupportedMethod], timeout: 5.0) + } + } +}