diff --git a/src/NATS.Client.Core/Internal/WebSocketConnection.cs b/src/NATS.Client.Core/Internal/WebSocketConnection.cs
index 588ad82f3..74085e21e 100644
--- a/src/NATS.Client.Core/Internal/WebSocketConnection.cs
+++ b/src/NATS.Client.Core/Internal/WebSocketConnection.cs
@@ -39,11 +39,12 @@ public Task ConnectAsync(Uri uri, CancellationToken cancellationToken)
///
/// Connect with Timeout. When failed, Dispose this connection.
///
- public async ValueTask ConnectAsync(Uri uri, TimeSpan timeout)
+ public async ValueTask ConnectAsync(Uri uri, NatsOpts opts)
{
- using var cts = new CancellationTokenSource(timeout);
+ using var cts = new CancellationTokenSource(opts.ConnectTimeout);
try
{
+ await InvokeCallbackForClientWebSocketOptionsAsync(opts, uri, _socket.Options, cts.Token).ConfigureAwait(false);
await _socket.ConnectAsync(uri, cts.Token).ConfigureAwait(false);
}
catch (Exception ex)
@@ -130,4 +131,13 @@ public void SignalDisconnected(Exception exception)
{
_waitForClosedSource.TrySetResult(exception);
}
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private async Task InvokeCallbackForClientWebSocketOptionsAsync(NatsOpts opts, Uri uri, ClientWebSocketOptions options, CancellationToken token)
+ {
+ if (opts.ConfigureWebSocketOpts != null)
+ {
+ await opts.ConfigureWebSocketOpts(uri, options, token).ConfigureAwait(false);
+ }
+ }
}
diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs
index 93d2ab021..6e4967b97 100644
--- a/src/NATS.Client.Core/NatsConnection.cs
+++ b/src/NATS.Client.Core/NatsConnection.cs
@@ -318,7 +318,7 @@ private async ValueTask InitialConnectAsync()
if (uri.IsWebSocket)
{
var conn = new WebSocketConnection();
- await conn.ConnectAsync(uri.Uri, Opts.ConnectTimeout).ConfigureAwait(false);
+ await conn.ConnectAsync(uri.Uri, Opts).ConfigureAwait(false);
_socket = conn;
}
else
@@ -606,7 +606,7 @@ private async void ReconnectLoop()
{
_logger.LogDebug(NatsLogEvents.Connection, "Trying to reconnect using WebSocket {Url} [{ReconnectCount}]", url, reconnectCount);
var conn = new WebSocketConnection();
- await conn.ConnectAsync(url.Uri, Opts.ConnectTimeout).ConfigureAwait(false);
+ await conn.ConnectAsync(url.Uri, Opts).ConfigureAwait(false);
_socket = conn;
}
else
diff --git a/src/NATS.Client.Core/NatsOpts.cs b/src/NATS.Client.Core/NatsOpts.cs
index c9e2f2f00..49b1ef1d9 100644
--- a/src/NATS.Client.Core/NatsOpts.cs
+++ b/src/NATS.Client.Core/NatsOpts.cs
@@ -1,3 +1,4 @@
+using System.Net.WebSockets;
using System.Text;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
@@ -114,6 +115,28 @@ public sealed record NatsOpts
///
public BoundedChannelFullMode SubPendingChannelFullMode { get; init; } = BoundedChannelFullMode.DropNewest;
+ ///
+ /// An optional async callback handler for manipulation of ClientWebSocketOptions used for WebSocket connections.
+ ///
+ ///
+ /// This can be used to set authorization header and other HTTP header values.
+ /// Note: Setting HTTP header values is not supported by Blazor WebAssembly as the underlying browser implementation does not support adding headers to a WebSocket.
+ /// The callback's execution time contributes to the connection establishment subject to the .
+ /// Implementors should use the passed CancellationToken for async operations called by this handler.
+ ///
+ ///
+ /// await using var nats = new NatsConnection(new NatsOpts
+ /// {
+ /// Url = "ws://localhost:8080",
+ /// ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) =>
+ /// {
+ /// clientWsOpts.SetRequestHeader("authorization", $"Bearer MY_TOKEN");
+ /// return ValueTask.CompletedTask;
+ /// },
+ /// });
+ ///
+ public Func? ConfigureWebSocketOpts { get; init; } = null;
+
internal NatsUri[] GetSeedUris()
{
var urls = Url.Split(',');
diff --git a/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs b/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs
new file mode 100644
index 000000000..fce13df15
--- /dev/null
+++ b/tests/NATS.Client.Core.Tests/WebSocketOptionsTest.cs
@@ -0,0 +1,325 @@
+using System.Net;
+using Microsoft.Extensions.Logging;
+using NATS.Client.TestUtilities;
+
+namespace NATS.Client.Core.Tests;
+
+public class WebSocketOptionsTest
+{
+ private readonly List _logs = new();
+
+ // Modeled after similar test in SendBufferTest.cs which also uses the MockServer.
+ [Fact]
+ public async Task MockWebsocketServer_PubSubWithCancelAndReconnect_ShouldCallbackTwice()
+ {
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
+
+ List pubs = new();
+ await using var server = new MockServer(
+ handler: (client, cmd) =>
+ {
+ if (cmd.Name == "PUB")
+ {
+ lock (pubs)
+ pubs.Add($"PUB {cmd.Subject}");
+ }
+
+ if (cmd is { Name: "PUB", Subject: "close" })
+ {
+ client.Close();
+ }
+
+ return Task.CompletedTask;
+ },
+ Log,
+ info: $"{{\"max_payload\":{1024 * 4}}}",
+ cancellationToken: cts.Token);
+
+ await using var wsServer = new WebSocketMockServer(
+ server.Url,
+ connectHandler: (httpContext) =>
+ {
+ return true;
+ },
+ Log,
+ cts.Token);
+
+ Log("__________________________________");
+
+ var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m =>
+ {
+ Log($"[NC] {m.Message}");
+ });
+
+ var tokenCount = 0;
+ await using var nats = new NatsConnection(new NatsOpts
+ {
+ Url = wsServer.WebSocketUrl,
+ LoggerFactory = testLogger,
+ ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) =>
+ {
+ tokenCount++;
+ Log($"[C] ConfigureWebSocketOpts {serverUri}, accessToken TOKEN_{tokenCount}");
+ clientWsOpts.SetRequestHeader("authorization", $"Bearer TOKEN_{tokenCount}");
+ return ValueTask.CompletedTask;
+ },
+ });
+
+ Log($"[C] connect {server.Url}");
+ await nats.ConnectAsync();
+
+ Log($"[C] ping");
+ var rtt = await nats.PingAsync(cts.Token);
+ Log($"[C] ping rtt={rtt}");
+
+ Log($"[C] publishing x1...");
+ await nats.PublishAsync("x1", "x", cancellationToken: cts.Token);
+
+ // we will close the connection in mock server when we receive subject "close"
+ Log($"[C] publishing close (4KB)...");
+ var pubTask = nats.PublishAsync("close", new byte[1024 * 4], cancellationToken: cts.Token).AsTask();
+
+ await pubTask.WaitAsync(cts.Token);
+
+ for (var i = 1; i <= 10; i++)
+ {
+ try
+ {
+ await nats.PingAsync(cts.Token);
+ break;
+ }
+ catch (OperationCanceledException)
+ {
+ if (i == 10)
+ throw;
+ await Task.Delay(10 * i, cts.Token);
+ }
+ }
+
+ Log($"[C] publishing x2...");
+ await nats.PublishAsync("x2", "x", cancellationToken: cts.Token);
+
+ Log($"[C] flush...");
+ await nats.PingAsync(cts.Token);
+
+ // Look for logs like the following:
+ // [WS] Received WebSocketRequest with authorization header: Bearer TOKEN_2
+ var tokens = GetLogs().Where(l => l.Contains("Bearer")).ToList();
+ Assert.Equal(2, tokens.Count);
+ var token = tokens.Where(t => t.Contains("TOKEN_1"));
+ Assert.Single(token);
+ token = tokens.Where(t => t.Contains("TOKEN_2"));
+ Assert.Single(token);
+
+ lock (pubs)
+ {
+ Assert.Equal(3, pubs.Count);
+ Assert.Equal("PUB x1", pubs[0]);
+ Assert.Equal("PUB close", pubs[1]);
+ Assert.Equal("PUB x2", pubs[2]);
+ }
+ }
+
+ [Fact]
+ public async Task WebSocketRespondsWithHttpError_ShouldThrowNatsException()
+ {
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
+
+ await using var server = new MockServer(
+ handler: (client, cmd) =>
+ {
+ return Task.CompletedTask;
+ },
+ Log,
+ info: $"{{\"max_payload\":{1024 * 4}}}",
+ cancellationToken: cts.Token);
+
+ await using var wsServer = new WebSocketMockServer(
+ server.Url,
+ connectHandler: (httpContext) =>
+ {
+ httpContext.Response.StatusCode = (int)HttpStatusCode.Forbidden;
+ return false; // reject connection
+ },
+ Log,
+ cts.Token);
+
+ Log("__________________________________");
+
+ var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m =>
+ {
+ Log($"[NC] {m.Message}");
+ });
+
+ await using var nats = new NatsConnection(new NatsOpts
+ {
+ Url = wsServer.WebSocketUrl,
+ LoggerFactory = testLogger,
+ });
+
+ Log($"[C] connect {server.Url}");
+
+ // expect: NATS.Client.Core.NatsException : can not connect uris: ws://127.0.0.1:5004
+ var exception = await Assert.ThrowsAsync(async () => await nats.ConnectAsync());
+ }
+
+ [Fact]
+ public async Task HttpErrorDuringReconnect_ShouldContinueToReconnect()
+ {
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
+
+ await using var server = new MockServer(
+ handler: (client, cmd) =>
+ {
+ if (cmd is { Name: "PUB", Subject: "close" })
+ {
+ client.Close();
+ }
+
+ return Task.CompletedTask;
+ },
+ Log,
+ info: $"{{\"max_payload\":{1024 * 4}}}",
+ cancellationToken: cts.Token);
+
+ var tokenCount = 0;
+
+ await using var wsServer = new WebSocketMockServer(
+ server.Url,
+ connectHandler: (httpContext) =>
+ {
+ var token = httpContext.Request.Headers.Authorization;
+ if (token.Contains("Bearer TOKEN_1") || token.Contains("Bearer TOKEN_4"))
+ {
+ return true;
+ }
+ else
+ {
+ httpContext.Response.StatusCode = (int)HttpStatusCode.Forbidden;
+ return false; // reject connection
+ }
+ },
+ Log,
+ cts.Token);
+
+ Log("__________________________________");
+
+ var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m =>
+ {
+ Log($"[NC] {m.Message}");
+ });
+
+ await using var nats = new NatsConnection(new NatsOpts
+ {
+ Url = wsServer.WebSocketUrl,
+ LoggerFactory = testLogger,
+ ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) =>
+ {
+ tokenCount++;
+ Log($"[C] ConfigureWebSocketOpts {serverUri}, accessToken TOKEN_{tokenCount}");
+ clientWsOpts.SetRequestHeader("authorization", $"Bearer TOKEN_{tokenCount}");
+ return ValueTask.CompletedTask;
+ },
+ });
+
+ Log($"[C] connect {server.Url}");
+
+ // close connection and trigger reconnect
+ Log($"[C] publishing close ...");
+ await nats.PublishAsync("close", "x", cancellationToken: cts.Token);
+
+ for (var i = 1; i <= 10; i++)
+ {
+ try
+ {
+ await nats.PingAsync(cts.Token);
+ break;
+ }
+ catch (OperationCanceledException)
+ {
+ if (i == 10)
+ throw;
+ await Task.Delay(100 * i, cts.Token);
+ }
+ }
+
+ Log($"[C] publishing reconnected");
+ await nats.PublishAsync("reconnected", "rc", cancellationToken: cts.Token);
+ await Task.Delay(100); // short delay to allow log to be collected for reconnect
+
+ // Expect to see in logs:
+ // 1st callback and TOKEN_1
+ // Initial Connect
+ // 2nd callback with rejected TOKEN_2
+ // NC reconnect
+ // 3rd callback with rejected TOKEN_3
+ // NC reconnect
+ // 4th callback with good TOKEN_4
+ // Successful Publish after reconnect
+
+ // 4 tokens
+ var logs = GetLogs();
+ var tokens = logs.Where(l => l.Contains("Bearer")).ToList();
+ Assert.Equal(4, tokens.Count);
+ Assert.Single(tokens.Where(t => t.Contains("TOKEN_1")));
+ Assert.Single(tokens.Where(t => t.Contains("TOKEN_2")));
+ Assert.Single(tokens.Where(t => t.Contains("TOKEN_3")));
+ Assert.Single(tokens.Where(t => t.Contains("TOKEN_4")));
+
+ // 2 errors in NATS.Client triggering the reconnect
+ var failures = logs.Where(l => l.Contains("[NC] Failed to connect NATS"));
+ Assert.Equal(2, failures.Count());
+
+ // 2 connects in MockServer
+ var connects = logs.Where(l => l.Contains("RCV CONNECT"));
+ Assert.Equal(2, failures.Count());
+
+ // 1 reconnect in MockServer
+ var reconnectPublish = logs.Where(l => l.Contains("RCV PUB reconnected"));
+ Assert.Single(reconnectPublish);
+ }
+
+ [Fact]
+ public async Task ExceptionThrownInCallback_ShouldThrowNatsException()
+ {
+ // void Log(string m) => TmpFileLogger.Log(m);
+ List logs = new();
+ void Log(string m)
+ {
+ lock (logs)
+ logs.Add(m);
+ }
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
+
+ var testLogger = new InMemoryTestLoggerFactory(LogLevel.Warning, m =>
+ {
+ Log($"[NC] {m.Message}");
+ });
+
+ await using var nats = new NatsConnection(new NatsOpts
+ {
+ Url = "ws://localhost:1234",
+ LoggerFactory = testLogger,
+ ConfigureWebSocketOpts = (serverUri, clientWsOpts, ct) =>
+ {
+ throw new Exception("Error in callback");
+ },
+ });
+
+ // expect: NATS.Client.Core.NatsException : can not connect uris: ws://localhost:1234
+ var exception = await Assert.ThrowsAsync(async () => await nats.ConnectAsync());
+ }
+
+ private void Log(string m)
+ {
+ lock (_logs)
+ _logs.Add(m);
+ }
+
+ private List GetLogs()
+ {
+ lock (_logs)
+ return _logs.ToList();
+ }
+}
diff --git a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs
index c04100c24..46f39b375 100644
--- a/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs
+++ b/tests/NATS.Client.TestUtilities/InMemoryTestLoggerFactory.cs
@@ -47,7 +47,12 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except
public bool IsEnabled(LogLevel logLevel) => logLevel >= level;
+#if NET8_0_OR_GREATER
+ public IDisposable? BeginScope(TState state)
+ where TState : notnull => new NullDisposable();
+#else
public IDisposable BeginScope(TState state) => new NullDisposable();
+#endif
private class NullDisposable : IDisposable
{
diff --git a/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj b/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj
index 4bf81b8e4..b9421f412 100644
--- a/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj
+++ b/tests/NATS.Client.TestUtilities/NATS.Client.TestUtilities.csproj
@@ -1,4 +1,4 @@
-
+
net6.0;net8.0
diff --git a/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs b/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs
index a487ab562..75f5285e4 100644
--- a/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs
+++ b/tests/NATS.Client.TestUtilities/OutputHelperLogger.cs
@@ -41,7 +41,12 @@ public Logger(string categoryName, ITestOutputHelper testOutputHelper, NatsServe
_natsServer = natsServer;
}
+#if NET8_0_OR_GREATER
+ public IDisposable? BeginScope(TState state)
+ where TState : notnull
+#else
public IDisposable BeginScope(TState state)
+#endif
{
return NullDisposable.Instance;
}
diff --git a/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs b/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs
new file mode 100644
index 000000000..33779d500
--- /dev/null
+++ b/tests/NATS.Client.TestUtilities/WebSocketMockServer.cs
@@ -0,0 +1,138 @@
+using System.Net;
+using System.Net.Sockets;
+using System.Net.WebSockets;
+using System.Text;
+using Microsoft.AspNetCore;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Logging;
+
+namespace NATS.Client.TestUtilities;
+
+public class WebSocketMockServer : IAsyncDisposable
+{
+ private readonly string _natsServerUrl;
+ private readonly Action _logger;
+ private readonly CancellationTokenSource _cts;
+ private readonly Task _wsServerTask;
+
+ public WebSocketMockServer(
+ string natsServerUrl,
+ Func connectHandler,
+ Action logger,
+ CancellationToken cancellationToken = default)
+ {
+ _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ cancellationToken = _cts.Token;
+ _natsServerUrl = natsServerUrl;
+ _logger = logger;
+ WebSocketPort = 5004;
+
+ _wsServerTask = RunWsServer(connectHandler, cancellationToken);
+ }
+
+ public int WebSocketPort { get; }
+
+ public string WebSocketUrl => $"ws://127.0.0.1:{WebSocketPort}";
+
+ public async ValueTask DisposeAsync()
+ {
+ _cts.Cancel();
+
+ try
+ {
+ await _wsServerTask.WaitAsync(TimeSpan.FromSeconds(10));
+ }
+ catch (TimeoutException)
+ {
+ }
+ catch (OperationCanceledException)
+ {
+ }
+ catch (SocketException)
+ {
+ }
+ catch (IOException)
+ {
+ }
+ }
+
+ private Task RunWsServer(Func connectHandler, CancellationToken ct)
+ {
+ var wsServerTask = WebHost.CreateDefaultBuilder()
+ .SuppressStatusMessages(true)
+ .ConfigureLogging(logging => logging.ClearProviders())
+ .ConfigureKestrel(options => options.ListenLocalhost(WebSocketPort)) // unfortunately need to hard-code WebSocket port because ListenLocalhost() doesn't support picking a dynamic port
+ .Configure(app => app.UseWebSockets().Run(async context =>
+ {
+ _logger($"[WS] Received WebSocketRequest with authorization header: {context.Request.Headers.Authorization}");
+
+ if (!connectHandler(context))
+ return;
+
+ if (context.WebSockets.IsWebSocketRequest)
+ {
+ using var webSocket = await context.WebSockets.AcceptWebSocketAsync();
+ await HandleRequestResponse(webSocket, ct);
+ }
+ }))
+ .Build().RunAsync(ct);
+
+ return wsServerTask;
+ }
+
+ private async Task HandleRequestResponse(WebSocket webSocket, CancellationToken ct)
+ {
+ var wsRequestBuffer = new byte[1024 * 4];
+ using TcpClient tcpClient = new();
+ var endpoint = IPEndPoint.Parse(_natsServerUrl);
+ await tcpClient.ConnectAsync(endpoint);
+ await using var stream = tcpClient.GetStream();
+
+ // send responses received from NATS mock server back to WebSocket client
+ var respondBackTask = Task.Run(async () =>
+ {
+ try
+ {
+ var tcpResponseBuffer = new byte[1024 * 4];
+
+ while (!ct.IsCancellationRequested)
+ {
+ var received = await stream.ReadAsync(tcpResponseBuffer, ct);
+
+ var message = Encoding.UTF8.GetString(tcpResponseBuffer, 0, received);
+
+ await webSocket.SendAsync(
+ new ArraySegment(tcpResponseBuffer, 0, received),
+ WebSocketMessageType.Binary,
+ true,
+ ct);
+ }
+ }
+ catch (Exception e)
+ {
+ // if our TCP connection with the NATS mock server breaks then close the WebSocket too.
+ _logger($"[WS] Exception in response task: {e.Message}");
+ webSocket.Abort();
+ }
+ });
+
+ // forward received message via TCP to NATS mock server
+ var receiveResult = await webSocket.ReceiveAsync(
+ new ArraySegment(wsRequestBuffer), ct);
+
+ while (!receiveResult.CloseStatus.HasValue)
+ {
+ await stream.WriteAsync(wsRequestBuffer, 0, receiveResult.Count, ct);
+
+ receiveResult = await webSocket.ReceiveAsync(
+ new ArraySegment(wsRequestBuffer), ct);
+ }
+
+ await webSocket.CloseAsync(
+ receiveResult.CloseStatus.Value,
+ receiveResult.CloseStatusDescription,
+ CancellationToken.None);
+ }
+}