From 2ff424c2744ff304859d521ab6aa83d611d0cd8d Mon Sep 17 00:00:00 2001 From: Caleb Lloyd <2414837+caleblloyd@users.noreply.github.com> Date: Tue, 4 Jun 2024 15:47:13 -0400 Subject: [PATCH] Keep sub alive when reading channel (#506) * keep sub alive when reading channel Signed-off-by: Caleb Lloyd * refactor memory tests to use Task.Run Signed-off-by: Caleb Lloyd * subscribe benchmark Signed-off-by: Caleb Lloyd --------- Signed-off-by: Caleb Lloyd --- .../MicroBenchmark/PublishParallelBench.cs | 3 + sandbox/MicroBenchmark/PublishSerialBench.cs | 3 + sandbox/MicroBenchmark/Subscribe.cs | 132 +++++++++++++ .../Internal/ActivityEndingMsgReader.cs | 112 +++++++++-- .../NatsConnection.Subscribe.cs | 10 +- src/NATS.Client.Core/NatsSub.cs | 2 +- .../NatsSubTests.cs | 178 ++++++++++++++---- 7 files changed, 385 insertions(+), 55 deletions(-) create mode 100644 sandbox/MicroBenchmark/Subscribe.cs diff --git a/sandbox/MicroBenchmark/PublishParallelBench.cs b/sandbox/MicroBenchmark/PublishParallelBench.cs index 27f4b3310..b93b03f70 100644 --- a/sandbox/MicroBenchmark/PublishParallelBench.cs +++ b/sandbox/MicroBenchmark/PublishParallelBench.cs @@ -45,6 +45,9 @@ public async Task Setup() await _nats.ConnectAsync(); } + [GlobalCleanup] + public async Task Cleanup() => await _nats.DisposeAsync(); + [Benchmark] public async Task PublishParallelAsync() { diff --git a/sandbox/MicroBenchmark/PublishSerialBench.cs b/sandbox/MicroBenchmark/PublishSerialBench.cs index 52e5c1242..5f5920ef1 100644 --- a/sandbox/MicroBenchmark/PublishSerialBench.cs +++ b/sandbox/MicroBenchmark/PublishSerialBench.cs @@ -23,6 +23,9 @@ public async Task SetupAsync() await _nats.ConnectAsync(); } + [GlobalCleanup] + public async Task Cleanup() => await _nats.DisposeAsync(); + [Benchmark] public async Task PublishAsync() { diff --git a/sandbox/MicroBenchmark/Subscribe.cs b/sandbox/MicroBenchmark/Subscribe.cs new file mode 100644 index 000000000..7b1eed671 --- /dev/null +++ b/sandbox/MicroBenchmark/Subscribe.cs @@ 
-0,0 +1,132 @@ +using System.Diagnostics; +using BenchmarkDotNet.Attributes; +using NATS.Client.Core; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +namespace MicroBenchmark; + +[MemoryDiagnoser] +[ShortRunJob] +[PlainExporter] +public class Subscribe +{ + private const int TotalMsgs = 500_000; + private NatsConnection _nats; + private CancellationTokenSource _cts; + private Task _pubTask; + + [GlobalSetup] + public async Task Setup() + { + _nats = new NatsConnection(NatsOpts.Default); + await _nats.ConnectAsync(); + } + + [GlobalCleanup] + public async Task Cleanup() => await _nats.DisposeAsync(); + + [IterationSetup] + public void IterSetup() + { + _cts = new CancellationTokenSource(); + _pubTask = PubTask(_cts); + } + + [IterationCleanup] + public void IterCleanup() + { + _cts.Cancel(); + _pubTask.GetAwaiter().GetResult(); + } + + [Benchmark] + public async Task SubscribeAsync() + { + var count = 0; +#pragma warning disable SA1312 + await foreach (var _ in _nats.SubscribeAsync("test")) +#pragma warning restore SA1312 + { + if (++count >= TotalMsgs) + { + return; + } + } + } + + [Benchmark] + public async Task CoreWait() + { + var count = 0; + await using var sub = await _nats.SubscribeCoreAsync("test"); + while (await sub.Msgs.WaitToReadAsync()) + { + while (sub.Msgs.TryRead(out _)) + { + if (++count >= TotalMsgs) + { + return; + } + } + } + } + + [Benchmark] + public async Task CoreRead() + { + var count = 0; + await using var sub = await _nats.SubscribeCoreAsync("test"); + while (true) + { + await sub.Msgs.ReadAsync(); + if (++count >= TotalMsgs) + { + return; + } + } + } + + [Benchmark] + public async Task CoreReadAll() + { + var count = 0; + await using var sub = await _nats.SubscribeCoreAsync("test"); +#pragma warning disable SA1312 + await foreach (var _ in sub.Msgs.ReadAllAsync()) +#pragma warning restore SA1312 + { + if (++count >= TotalMsgs) + { + return; 
+ } + } + } + + // limit pub to the same rate across benchmarks + // pub in batches so that groups of messages are available + private Task PubTask(CancellationTokenSource cts) => + Task.Run(async () => + { + const long pubMaxPerSecond = TotalMsgs; + const long batchSize = 100; + const long ticksBetweenBatches = TimeSpan.TicksPerSecond / pubMaxPerSecond * batchSize; + + var sw = new Stopwatch(); + sw.Start(); + var lastTick = sw.ElapsedTicks; + var i = 0L; + while (!cts.IsCancellationRequested) + { + await _nats.PublishAsync("test", "data"); + if (++i % batchSize == 0) + { + while (sw.ElapsedTicks - lastTick < ticksBetweenBatches) + { + } + + lastTick = sw.ElapsedTicks; + } + } + }); +} diff --git a/src/NATS.Client.Core/Internal/ActivityEndingMsgReader.cs b/src/NATS.Client.Core/Internal/ActivityEndingMsgReader.cs index a18f6cc34..156917d2c 100644 --- a/src/NATS.Client.Core/Internal/ActivityEndingMsgReader.cs +++ b/src/NATS.Client.Core/Internal/ActivityEndingMsgReader.cs @@ -1,21 +1,66 @@ using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Threading.Channels; namespace NATS.Client.Core.Internal; +// ActivityEndingMsgReader serves 2 purposes +// 1. 
Keep the INatsSub from being garbage collected as long as calls interacting +// with the _inner channel are being made +// To achieve (1): +// Calls that result in a read from the _inner channel should msg.Headers?.Activity?.Dispose() +// To achieve (2): +// Synchronous calls should call GC.KeepAlive(_sub); immediately before returning +// Asynchronous calls should allocate a GCHandle.Alloc(_sub) at the start of the method, +// and then free it in a try/finally block internal sealed class ActivityEndingMsgReader : ChannelReader> { private readonly ChannelReader> _inner; - public ActivityEndingMsgReader(ChannelReader> inner) => _inner = inner; + private readonly INatsSub _sub; - public override bool CanCount => _inner.CanCount; + public ActivityEndingMsgReader(ChannelReader> inner, INatsSub sub) + { + _inner = inner; + _sub = sub; + } - public override bool CanPeek => _inner.CanPeek; + public override bool CanCount + { + get + { + GC.KeepAlive(_sub); + return _inner.CanCount; + } + } - public override int Count => _inner.Count; + public override bool CanPeek + { + get + { + GC.KeepAlive(_sub); + return _inner.CanPeek; + } + } - public override Task Completion => _inner.Completion; + public override int Count + { + get + { + GC.KeepAlive(_sub); + return _inner.Count; + } + } + + public override Task Completion + { + get + { + GC.KeepAlive(_sub); + return _inner.Completion; + } + } /// [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -26,16 +71,61 @@ public override bool TryRead(out NatsMsg item) item.Headers?.Activity?.Dispose(); + GC.KeepAlive(_sub); return true; } - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public override ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) => _inner.WaitToReadAsync(cancellationToken); + public override async ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) + { + var handle = GCHandle.Alloc(_sub); + try + { + return await 
_inner.WaitToReadAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + handle.Free(); + } + } - public override ValueTask> ReadAsync(CancellationToken cancellationToken = default) => _inner.ReadAsync(cancellationToken); + public override async ValueTask> ReadAsync(CancellationToken cancellationToken = default) + { + var handle = GCHandle.Alloc(_sub); + try + { + var msg = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); + msg.Headers?.Activity?.Dispose(); + return msg; + } + finally + { + handle.Free(); + } + } - public override bool TryPeek(out NatsMsg item) => _inner.TryPeek(out item); + public override bool TryPeek(out NatsMsg item) + { + GC.KeepAlive(_sub); + return _inner.TryPeek(out item); + } - public override IAsyncEnumerable> ReadAllAsync(CancellationToken cancellationToken = default) => _inner.ReadAllAsync(cancellationToken); + public override async IAsyncEnumerable> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var handle = GCHandle.Alloc(_sub); + try + { + while (await _inner.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (_inner.TryRead(out var msg)) + { + msg.Headers?.Activity?.Dispose(); + yield return msg; + } + } + } + finally + { + handle.Free(); + } + } } diff --git a/src/NATS.Client.Core/NatsConnection.Subscribe.cs b/src/NATS.Client.Core/NatsConnection.Subscribe.cs index 649f7d1fd..7f169d98e 100644 --- a/src/NATS.Client.Core/NatsConnection.Subscribe.cs +++ b/src/NATS.Client.Core/NatsConnection.Subscribe.cs @@ -16,19 +16,15 @@ public async IAsyncEnumerable> SubscribeAsync(string subject, stri { serializer ??= Opts.SerializerRegistry.GetDeserializer(); + // call to RegisterSubAnchor is no longer needed; sub is kept alive in ActivityEndingMsgReader await using var sub = new NatsSub(this, SubscriptionManager.GetManagerFor(subject), subject, queueGroup, opts, serializer, cancellationToken); - using var anchor = RegisterSubAnchor(sub); - await 
SubAsync(sub, cancellationToken: cancellationToken).ConfigureAwait(false); // We don't cancel the channel reader here because we want to keep reading until the subscription // channel writer completes so that messages left in the channel can be consumed before exit the loop. - while (await sub.Msgs.WaitToReadAsync(CancellationToken.None).ConfigureAwait(false)) + await foreach (var msg in sub.Msgs.ReadAllAsync(CancellationToken.None).ConfigureAwait(false)) { - while (sub.Msgs.TryRead(out var msg)) - { - yield return msg; - } + yield return msg; } } diff --git a/src/NATS.Client.Core/NatsSub.cs b/src/NATS.Client.Core/NatsSub.cs index 67c80fc0c..0d5e51ec0 100644 --- a/src/NATS.Client.Core/NatsSub.cs +++ b/src/NATS.Client.Core/NatsSub.cs @@ -23,7 +23,7 @@ internal NatsSub( connection.GetChannelOpts(connection.Opts, opts?.ChannelOpts), msg => Connection.OnMessageDropped(this, _msgs?.Reader.Count ?? 0, msg)); - Msgs = new ActivityEndingMsgReader(_msgs.Reader); + Msgs = new ActivityEndingMsgReader(_msgs.Reader, this); Serializer = serializer; } diff --git a/tests/NATS.Client.Core.MemoryTests/NatsSubTests.cs b/tests/NATS.Client.Core.MemoryTests/NatsSubTests.cs index 0d2488109..d8708b4be 100644 --- a/tests/NATS.Client.Core.MemoryTests/NatsSubTests.cs +++ b/tests/NATS.Client.Core.MemoryTests/NatsSubTests.cs @@ -1,3 +1,4 @@ +using System.Threading.Channels; using JetBrains.dotMemoryUnit; using NATS.Client.Core.Tests; @@ -37,80 +38,185 @@ async Task Isolator() } finally { - server.DisposeAsync().GetAwaiter().GetResult(); + server.DisposeAsync().AsTask().GetAwaiter().GetResult(); } } [Test] - public void Subscription_should_not_be_collected_when_in_async_enumerator() + public void Subscription_should_not_be_collected_subscribe_async() { var server = NatsServer.Start(); try { + const int iterations = 10; + const string subject = "foo.data"; var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) }); + var received = 
Channel.CreateUnbounded(); - var sync = 0; - - var sub = Task.Run(async () => + var subTask = Task.Run(async () => { - var count = 0; - await foreach (var msg in nats.SubscribeAsync("foo.*")) + var i = 0; +#pragma warning disable SA1312 + await foreach (var _ in nats.SubscribeAsync(subject)) +#pragma warning restore SA1312 { - if (msg.Subject == "foo.sync") - { - Interlocked.Increment(ref sync); - continue; - } - - if (++count == 10) + await received.Writer.WriteAsync(new object()); + if (++i >= iterations) break; } }); - var pub = Task.Run(async () => + RunSubTest(iterations, subject, nats, received, subTask); + } + finally + { + server.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + } + + [Test] + public void Subscription_should_not_be_collected_subscribe_core_async_read_all_async() + { + var server = NatsServer.Start(); + try + { + const int iterations = 10; + const string subject = "foo.data"; + var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) }); + var received = Channel.CreateUnbounded(); + + var subTask = Task.Run(async () => { - while (Volatile.Read(ref sync) == 0) + var i = 0; + await using var sub = await nats.SubscribeCoreAsync(subject); +#pragma warning disable SA1312 + await foreach (var _ in sub.Msgs.ReadAllAsync()) +#pragma warning restore SA1312 { - await nats.PublishAsync("foo.sync", "sync"); + await received.Writer.WriteAsync(new object()); + if (++i >= iterations) + break; } + }); - for (var i = 0; i < 10; i++) + RunSubTest(iterations, subject, nats, received, subTask); + } + finally + { + server.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + } + + [Test] + public void Subscription_should_not_be_collected_subscribe_core_async_read_async() + { + var server = NatsServer.Start(); + try + { + const int iterations = 10; + const string subject = "foo.data"; + var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) }); + var received = 
Channel.CreateUnbounded(); + + var subTask = Task.Run(async () => + { + var i = 0; + await using var sub = await nats.SubscribeCoreAsync(subject); + while (true) { - GC.Collect(); + await sub.Msgs.ReadAsync(); + await received.Writer.WriteAsync(new object()); + if (++i >= iterations) + break; + } + }); - dotMemory.Check(memory => + RunSubTest(iterations, subject, nats, received, subTask); + } + finally + { + server.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + } + + [Test] + public void Subscription_should_not_be_collected_subscribe_core_async_wait_to_read_async() + { + var server = NatsServer.Start(); + try + { + const int iterations = 10; + const string subject = "foo.data"; + var nats = server.CreateClientConnection(new NatsOpts { RequestTimeout = TimeSpan.FromSeconds(10) }); + var received = Channel.CreateUnbounded(); + + var subTask = Task.Run(async () => + { + var i = 0; + await using var sub = await nats.SubscribeCoreAsync(subject); + while (await sub.Msgs.WaitToReadAsync()) + { + while (sub.Msgs.TryRead(out _)) { - var count = memory.GetObjects(where => where.Type.Is>()).ObjectsCount; - Assert.That(count, Is.EqualTo(1), "Alive"); - }); + await received.Writer.WriteAsync(new object()); + i++; + } - await nats.PublishAsync("foo.data", "data"); + if (i >= iterations) + { + break; + } } }); - var waitPub = Task.WaitAll(new[] { pub }, TimeSpan.FromSeconds(10)); - if (!waitPub) + RunSubTest(iterations, subject, nats, received, subTask); + } + finally + { + server.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } + } + + private void RunSubTest(int iterations, string subject, NatsConnection nats, Channel received, Task subTask) + { + var i = 0; + var fail = 0; + while (true) + { + nats.PublishAsync(subject, "data").AsTask().GetAwaiter().GetResult(); + try { - Assert.Fail("Timed out waiting for pub task to complete"); + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100)); + 
received.Reader.ReadAsync(cts.Token).AsTask().GetAwaiter().GetResult(); } - - var waitSub = Task.WaitAll(new[] { sub }, TimeSpan.FromSeconds(10)); - if (!waitSub) + catch (OperationCanceledException) { - Assert.Fail("Timed out waiting for sub task to complete"); + if (++fail <= 10) + { + continue; + } + + Assert.Fail($"failed to receive a reply 10 times"); } - GC.Collect(); + if (++i >= iterations) + break; + GC.Collect(); dotMemory.Check(memory => { var count = memory.GetObjects(where => where.Type.Is>()).ObjectsCount; - Assert.That(count, Is.EqualTo(0), "Collected"); + Assert.That(count, Is.EqualTo(1), $"Alive - received {i}"); }); } - finally + + subTask.GetAwaiter().GetResult(); + + GC.Collect(); + dotMemory.Check(memory => { - server.DisposeAsync().GetAwaiter().GetResult(); - } + var count = memory.GetObjects(where => where.Type.Is>()).ObjectsCount; + Assert.That(count, Is.EqualTo(0), "Collected"); + }); } }