diff --git a/src/NATS.Client.Core/Internal/WebSocketConnection.cs b/src/NATS.Client.Core/Internal/WebSocketConnection.cs index 588ad82f3..74085e21e 100644 --- a/src/NATS.Client.Core/Internal/WebSocketConnection.cs +++ b/src/NATS.Client.Core/Internal/WebSocketConnection.cs @@ -39,11 +39,12 @@ public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) /// /// Connect with Timeout. When failed, Dispose this connection. /// - public async ValueTask ConnectAsync(Uri uri, TimeSpan timeout) + public async ValueTask ConnectAsync(Uri uri, NatsOpts opts) { - using var cts = new CancellationTokenSource(timeout); + using var cts = new CancellationTokenSource(opts.ConnectTimeout); try { + await InvokeCallbackForClientWebSocketOptionsAsync(opts, uri, _socket.Options, cts.Token).ConfigureAwait(false); await _socket.ConnectAsync(uri, cts.Token).ConfigureAwait(false); } catch (Exception ex) @@ -130,4 +131,13 @@ public void SignalDisconnected(Exception exception) { _waitForClosedSource.TrySetResult(exception); } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private async Task InvokeCallbackForClientWebSocketOptionsAsync(NatsOpts opts, Uri uri, ClientWebSocketOptions options, CancellationToken token) + { + if (opts.ConfigureWebSocketOpts != null) + { + await opts.ConfigureWebSocketOpts(uri, options, token).ConfigureAwait(false); + } + } } diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 93d2ab021..6e4967b97 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -318,7 +318,7 @@ private async ValueTask InitialConnectAsync() if (uri.IsWebSocket) { var conn = new WebSocketConnection(); - await conn.ConnectAsync(uri.Uri, Opts.ConnectTimeout).ConfigureAwait(false); + await conn.ConnectAsync(uri.Uri, Opts).ConfigureAwait(false); _socket = conn; } else @@ -606,7 +606,7 @@ private async void ReconnectLoop() { _logger.LogDebug(NatsLogEvents.Connection, "Trying to reconnect using WebSocket {Url} [{ReconnectCount}]", url, reconnectCount); var conn = new WebSocketConnection(); - await conn.ConnectAsync(url.Uri, Opts.ConnectTimeout).ConfigureAwait(false); + await conn.ConnectAsync(url.Uri, Opts).ConfigureAwait(false); _socket = conn; } else diff --git a/src/NATS.Client.Core/NatsOpts.cs b/src/NATS.Client.Core/NatsOpts.cs index c9e2f2f00..49b1ef1d9 100644 --- a/src/NATS.Client.Core/NatsOpts.cs +++ b/src/NATS.Client.Core/NatsOpts.cs @@ -1,3 +1,4 @@ +using System.Net.WebSockets; using System.Text; using System.Threading.Channels; using Microsoft.Extensions.Logging; @@ -114,6 +115,28 @@ public sealed record NatsOpts /// public BoundedChannelFullMode SubPendingChannelFullMode { get; init; } = BoundedChannelFullMode.DropNewest; + /// + /// An optional async callback handler for manipulation of ClientWebSocketOptions used for WebSocket connections. + /// + /// + /// This can be used to set authorization header and other HTTP header values. + /// Note: Setting HTTP header values is not supported by Blazor WebAssembly as the underlying browser implementation does not support adding headers to a WebSocket. + /// The callback's execution time contributes to the connection establishment subject to the . + /// Implementors should use the passed CancellationToken for async operations called by this handler. + /// + /// + /// await using var nats = new NatsConnection(new NatsOpts + /// { + /// Url = "ws://localhost:8080", + /// ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) => + /// { + /// clientWsOpts.SetRequestHeader("authorization", $"Bearer MY_TOKEN"); + /// return ValueTask.CompletedTask; + /// }, + /// }); + /// + public Func? ConfigureWebSocketOpts { get; init; } = null; + internal NatsUri[] GetSeedUris() { var urls = Url.Split(','); diff --git a/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs b/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs new file mode 100644 index 000000000..fce13df15 --- /dev/null +++ b/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs @@ -0,0 +1,325 @@ +using System.Net; +using Microsoft.Extensions.Logging; +using NATS.Client.TestUtilities; + +namespace NATS.Client.Core.Tests; + +public class WebSocketOptionsTest +{ + private readonly List _logs = new(); + + // Modeled after similar test in SendBufferTest.cs which also uses the MockServer. + [Fact] + public async Task MockWebsocketServer_PubSubWithCancelAndReconnect_ShouldCallbackTwice() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + List pubs = new(); + await using var server = new MockServer( + handler: (client, cmd) => + { + if (cmd.Name == "PUB") + { + lock (pubs) + pubs.Add($"PUB {cmd.Subject}"); + } + + if (cmd is { Name: "PUB", Subject: "close" }) + { + client.Close(); + } + + return Task.CompletedTask; + }, + Log, + info: $"{{\"max_payload\":{1024 * 4}}}", + cancellationToken: cts.Token); + + await using var wsServer = new WebSocketMockServer( + server.Url, + connectHandler: (httpContext) => + { + return true; + }, + Log, + cts.Token); + + Log("__________________________________"); + + var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m => + { + Log($"[NC] {m.Message}"); + }); + + var tokenCount = 0; + await using var nats = new NatsConnection(new NatsOpts + { + Url = wsServer.WebSocketUrl, + LoggerFactory = testLogger, + ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) => + { + tokenCount++; + Log($"[C] ConfigureWebSocketOpts {serverUri}, accessToken TOKEN_{tokenCount}"); + clientWsOpts.SetRequestHeader("authorization", $"Bearer TOKEN_{tokenCount}"); + return ValueTask.CompletedTask; + }, + }); + + Log($"[C] connect {server.Url}"); + await nats.ConnectAsync(); + + Log($"[C] ping"); + var rtt = await nats.PingAsync(cts.Token); + Log($"[C] ping rtt={rtt}"); + + Log($"[C] publishing x1..."); + await nats.PublishAsync("x1", "x", cancellationToken: cts.Token); + + // we will close the connection in mock server when we receive subject "close" + Log($"[C] publishing close (4KB)..."); + var pubTask = nats.PublishAsync("close", new byte[1024 * 4], cancellationToken: cts.Token).AsTask(); + + await pubTask.WaitAsync(cts.Token); + + for (var i = 1; i <= 10; i++) + { + try + { + await nats.PingAsync(cts.Token); + break; + } + catch (OperationCanceledException) + { + if (i == 10) + throw; + await Task.Delay(10 * i, cts.Token); + } + } + + Log($"[C] publishing x2..."); + await nats.PublishAsync("x2", "x", cancellationToken: cts.Token); + + Log($"[C] flush..."); + await nats.PingAsync(cts.Token); + + // Look for logs like the following: + // [WS] Received WebSocketRequest with authorization header: Bearer TOKEN_2 + var tokens = GetLogs().Where(l => l.Contains("Bearer")).ToList(); + Assert.Equal(2, tokens.Count); + var token = tokens.Where(t => t.Contains("TOKEN_1")); + Assert.Single(token); + token = tokens.Where(t => t.Contains("TOKEN_2")); + Assert.Single(token); + + lock (pubs) + { + Assert.Equal(3, pubs.Count); + Assert.Equal("PUB x1", pubs[0]); + Assert.Equal("PUB close", pubs[1]); + Assert.Equal("PUB x2", pubs[2]); + } + } + + [Fact] + public async Task WebSocketRespondsWithHttpError_ShouldThrowNatsException() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + await using var server = new MockServer( + handler: (client, cmd) => + { + return Task.CompletedTask; + }, + Log, + info: $"{{\"max_payload\":{1024 * 4}}}", + cancellationToken: cts.Token); + + await using var wsServer = new WebSocketMockServer( + server.Url, + connectHandler: (httpContext) => + { + httpContext.Response.StatusCode = (int)HttpStatusCode.Forbidden; + return false; // reject connection + }, + Log, + cts.Token); + + Log("__________________________________"); + + var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m => + { + Log($"[NC] {m.Message}"); + }); + + await using var nats = new NatsConnection(new NatsOpts + { + Url = wsServer.WebSocketUrl, + LoggerFactory = testLogger, + }); + + Log($"[C] connect {server.Url}"); + + // expect: NATS.Client.Core.NatsException : can not connect uris: ws://127.0.0.1:5004 + var exception = await Assert.ThrowsAsync(async () => await nats.ConnectAsync()); + } + + [Fact] + public async Task HttpErrorDuringReconnect_ShouldContinueToReconnect() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + await using var server = new MockServer( + handler: (client, cmd) => + { + if (cmd is { Name: "PUB", Subject: "close" }) + { + client.Close(); + } + + return Task.CompletedTask; + }, + Log, + info: $"{{\"max_payload\":{1024 * 4}}}", + cancellationToken: cts.Token); + + var tokenCount = 0; + + await using var wsServer = new WebSocketMockServer( + server.Url, + connectHandler: (httpContext) => + { + var token = httpContext.Request.Headers.Authorization; + if (token.Contains("Bearer TOKEN_1") || token.Contains("Bearer TOKEN_4")) + { + return true; + } + else + { + httpContext.Response.StatusCode = (int)HttpStatusCode.Forbidden; + return false; // reject connection + } + }, + Log, + cts.Token); + + Log("__________________________________"); + + var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m => + { + Log($"[NC] {m.Message}"); + }); + + await using var nats = new NatsConnection(new NatsOpts + { + Url = wsServer.WebSocketUrl, + LoggerFactory = testLogger, + ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) => + { + tokenCount++; + Log($"[C] ConfigureWebSocketOpts {serverUri}, accessToken TOKEN_{tokenCount}"); + clientWsOpts.SetRequestHeader("authorization", $"Bearer TOKEN_{tokenCount}"); + return ValueTask.CompletedTask; + }, + }); + + Log($"[C] connect {server.Url}"); + + // close connection and trigger reconnect + Log($"[C] publishing close ..."); + await nats.PublishAsync("close", "x", cancellationToken: cts.Token); + + for (var i = 1; i <= 10; i++) + { + try + { + await nats.PingAsync(cts.Token); + break; + } + catch (OperationCanceledException) + { + if (i == 10) + throw; + await Task.Delay(100 * i, cts.Token); + } + } + + Log($"[C] publishing reconnected"); + await nats.PublishAsync("reconnected", "rc", cancellationToken: cts.Token); + await Task.Delay(100); // short delay to allow log to be collected for reconnect + + // Expect to see in logs: + // 1st callback and TOKEN_1 + // Initial Connect + // 2nd callback with rejected TOKEN_2 + // NC reconnect + // 3rd callback with rejected TOKEN_3 + // NC reconnect + // 4th callback with good TOKEN_4 + // Successful Publish after reconnect + + // 4 tokens + var logs = GetLogs(); + var tokens = logs.Where(l => l.Contains("Bearer")).ToList(); + Assert.Equal(4, tokens.Count); + Assert.Single(tokens.Where(t => t.Contains("TOKEN_1"))); + Assert.Single(tokens.Where(t => t.Contains("TOKEN_2"))); + Assert.Single(tokens.Where(t => t.Contains("TOKEN_3"))); + Assert.Single(tokens.Where(t => t.Contains("TOKEN_4"))); + + // 2 errors in NATS.Client triggering the reconnect + var failures = logs.Where(l => l.Contains("[NC] Failed to connect NATS")); + Assert.Equal(2, failures.Count()); + + // 2 connects in MockServer + var connects = logs.Where(l => l.Contains("RCV CONNECT")); + Assert.Equal(2, failures.Count()); + + // 1 reconnect in MockServer + var reconnectPublish = logs.Where(l => l.Contains("RCV PUB reconnected")); + Assert.Single(reconnectPublish); + } + + [Fact] + public async Task ExceptionThrownInCallback_ShouldThrowNatsException() + { + // void Log(string m) => TmpFileLogger.Log(m); + List logs = new(); + void Log(string m) + { + lock (logs) + logs.Add(m); + } + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m => + { + Log($"[NC] {m.Message}"); + }); + + await using var nats = new NatsConnection(new NatsOpts + { + Url = "ws://localhost:1234", + LoggerFactory = testLogger, + ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) => + { + throw new Exception("Error in callback"); + }, + }); + + // expect: NATS.Client.Core.NatsException : can not connect uris: ws://localhost:1234 + var exception = await Assert.ThrowsAsync(async () => await nats.ConnectAsync()); + } + + private void Log(string m) + { + lock (_logs) + _logs.Add(m); + } + + private List GetLogs() + { + lock (_logs) + return _logs.ToList(); + } +} diff --git a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs index c04100c24..46f39b375 100644 --- a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs +++ b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs @@ -47,7 +47,12 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except public bool IsEnabled(LogLevel logLevel) => logLevel >= level; +#if NET8_0_OR_GREATER + public IDisposable? BeginScope(TState state) + where TState : notnull => new NullDisposable(); +#else public IDisposable BeginScope(TState state) => new NullDisposable(); +#endif private class NullDisposable : IDisposable { diff --git a/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj b/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj index 4bf81b8e4..b9421f412 100644 --- a/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj +++ b/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj @@ -1,4 +1,4 @@ - + net6.0;net8.0 diff --git a/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs b/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs index a487ab562..75f5285e4 100644 --- a/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs +++ b/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs @@ -41,7 +41,12 @@ public Logger(string categoryName, ITestOutputHelper testOutputHelper, NatsServe _natsServer = natsServer; } +#if NET8_0_OR_GREATER + public IDisposable? BeginScope(TState state) + where TState : notnull +#else public IDisposable BeginScope(TState state) +#endif { return NullDisposable.Instance; } diff --git a/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs b/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs new file mode 100644 index 000000000..33779d500 --- /dev/null +++ b/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs @@ -0,0 +1,138 @@ +using System.Net; +using System.Net.Sockets; +using System.Net.WebSockets; +using System.Text; +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; + +namespace NATS.Client.TestUtilities; + +public class WebSocketMockServer : IAsyncDisposable +{ + private readonly string _natsServerUrl; + private readonly Action _logger; + private readonly CancellationTokenSource _cts; + private readonly Task _wsServerTask; + + public WebSocketMockServer( + string natsServerUrl, + Func connectHandler, + Action logger, + CancellationToken cancellationToken = default) + { + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _cts.Token; + _natsServerUrl = natsServerUrl; + _logger = logger; + WebSocketPort = 5004; + + _wsServerTask = RunWsServer(connectHandler, cancellationToken); + } + + public int WebSocketPort { get; } + + public string WebSocketUrl => $"ws://127.0.0.1:{WebSocketPort}"; + + public async ValueTask DisposeAsync() + { + _cts.Cancel(); + + try + { + await _wsServerTask.WaitAsync(TimeSpan.FromSeconds(10)); + } + catch (TimeoutException) + { + } + catch (OperationCanceledException) + { + } + catch (SocketException) + { + } + catch (IOException) + { + } + } + + private Task RunWsServer(Func connectHandler, CancellationToken ct) + { + var wsServerTask = WebHost.CreateDefaultBuilder() + .SuppressStatusMessages(true) + .ConfigureLogging(logging => logging.ClearProviders()) + .ConfigureKestrel(options => options.ListenLocalhost(WebSocketPort)) // unfortunately need to hard-code WebSocket port because ListenLocalhost() doesn't support picking a dynamic port + .Configure(app => app.UseWebSockets().Run(async context => + { + _logger($"[WS] Received WebSocketRequest with authorization header: {context.Request.Headers.Authorization}"); + + if (!connectHandler(context)) + return; + + if (context.WebSockets.IsWebSocketRequest) + { + using var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + await HandleRequestResponse(webSocket, ct); + } + })) + .Build().RunAsync(ct); + + return wsServerTask; + } + + private async Task HandleRequestResponse(WebSocket webSocket, CancellationToken ct) + { + var wsRequestBuffer = new byte[1024 * 4]; + using TcpClient tcpClient = new(); + var endpoint = IPEndPoint.Parse(_natsServerUrl); + await tcpClient.ConnectAsync(endpoint); + await using var stream = tcpClient.GetStream(); + + // send responses received from NATS mock server back to WebSocket client + var respondBackTask = Task.Run(async () => + { + try + { + var tcpResponseBuffer = new byte[1024 * 4]; + + while (!ct.IsCancellationRequested) + { + var received = await stream.ReadAsync(tcpResponseBuffer, ct); + + var message = Encoding.UTF8.GetString(tcpResponseBuffer, 0, received); + + await webSocket.SendAsync( + new ArraySegment(tcpResponseBuffer, 0, received), + WebSocketMessageType.Binary, + true, + ct); + } + } + catch (Exception e) + { + // if our TCP connection with the NATS mock server breaks then close the WebSocket too. + _logger($"[WS] Exception in response task: {e.Message}"); + webSocket.Abort(); + } + }); + + // forward received message via TCP to NATS mock server + var receiveResult = await webSocket.ReceiveAsync( + new ArraySegment(wsRequestBuffer), ct); + + while (!receiveResult.CloseStatus.HasValue) + { + await stream.WriteAsync(wsRequestBuffer, 0, receiveResult.Count, ct); + + receiveResult = await webSocket.ReceiveAsync( + new ArraySegment(wsRequestBuffer), ct); + } + + await webSocket.CloseAsync( + receiveResult.CloseStatus.Value, + receiveResult.CloseStatusDescription, + CancellationToken.None); + } +}