From f4c42c85360bc3e669fff56e915727930a5f1337 Mon Sep 17 00:00:00 2001 From: Steve Gordon Date: Wed, 30 Oct 2024 12:06:42 +0000 Subject: [PATCH] Enhance StreamResponse handling and update dependencies (#121) Updated `ResponseBuilderDefaults` to include `StreamResponse` in `SpecialTypes`. Refactored `SetBodyCoreAsync` in `DefaultResponseBuilder.cs` for readability and removed unnecessary `using` statements. Modified `RequestCoreAsync` in `HttpWebRequestInvoker.cs` and `BuildResponseAsync` in `InMemoryRequestInvoker.cs` to handle `StreamResponse` types with proper disposal. Updated `Elastic.Transport.csproj` to reference `System.Text.Json` version `8.0.5`. --- .../Pipeline/DefaultResponseBuilder.cs | 142 +++++---- .../TransportClient/HttpRequestInvoker.cs | 33 +- .../TransportClient/HttpWebRequestInvoker.cs | 59 +++- .../TransportClient/InMemoryRequestInvoker.cs | 52 ++-- .../Requests/Body/PostData.cs | 6 + .../Responses/Special/StreamResponse.cs | 41 ++- .../Responses/TransportResponse.cs | 24 +- .../Http/StreamResponseTests.cs | 208 +++++++++++++ .../Plumbing/InMemoryConnectionFactory.cs | 5 +- .../ResponseBuilderDisposeTests.cs | 294 +++++++++++------- 10 files changed, 635 insertions(+), 229 deletions(-) create mode 100644 tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs diff --git a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs index 892bd64..e7b2a77 100644 --- a/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs +++ b/src/Elastic.Transport/Components/Pipeline/DefaultResponseBuilder.cs @@ -26,7 +26,7 @@ internal static class ResponseBuilderDefaults public static readonly Type[] SpecialTypes = { - typeof(StringResponse), typeof(BytesResponse), typeof(VoidResponse), typeof(DynamicResponse) + typeof(StringResponse), typeof(BytesResponse), typeof(VoidResponse), typeof(DynamicResponse), typeof(StreamResponse) }; } @@ -66,11 +66,8 @@ IReadOnlyDictionary tcpStats // Only attempt to set the body if the response may have content if (MayHaveBody(statusCode, requestData.Method, contentLength)) response = SetBody(details, requestData, responseStream, mimeType); - else - responseStream.Dispose(); response ??= new TResponse(); - response.ApiCallDetails = details; return response; } @@ -101,11 +98,8 @@ public override async Task ToResponseAsync( if (MayHaveBody(statusCode, requestData.Method, contentLength)) response = await SetBodyAsync(details, requestData, responseStream, mimeType, cancellationToken).ConfigureAwait(false); - else - responseStream.Dispose(); response ??= new TResponse(); - response.ApiCallDetails = details; return response; } @@ -211,6 +205,8 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, var disableDirectStreaming = requestData.PostData?.DisableDirectStreaming ?? requestData.ConnectionSettings.DisableDirectStreaming; var requiresErrorDeserialization = RequiresErrorDeserialization(details, requestData); + var ownsStream = false; + if (disableDirectStreaming || NeedsToEagerReadStream() || requiresErrorDeserialization) { var inMemoryStream = requestData.MemoryStreamFactory.Create(); @@ -221,90 +217,111 @@ private async ValueTask SetBodyCoreAsync(bool isAsync, responseStream.CopyTo(inMemoryStream, BufferSize); bytes = SwapStreams(ref responseStream, ref inMemoryStream); + ownsStream = true; details.ResponseBodyInBytes = bytes; } - using (responseStream) + if (TrySetSpecialType(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var response)) { - if (SetSpecialTypes(mimeType, bytes, responseStream, requestData.MemoryStreamFactory, out var r)) return r; + ConditionalDisposal(responseStream, ownsStream, response); + return response; + } - if (details.HttpStatusCode.HasValue && - requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) - return null; + if (details.HttpStatusCode.HasValue && + requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value)) + { + ConditionalDisposal(responseStream, ownsStream, response); + return null; + } - var serializer = requestData.ConnectionSettings.RequestResponseSerializer; + var serializer = requestData.ConnectionSettings.RequestResponseSerializer; - TResponse response; - if (requestData.CustomResponseBuilder != null) - { - var beforeTicks = Stopwatch.GetTimestamp(); + if (requestData.CustomResponseBuilder != null) + { + var beforeTicks = Stopwatch.GetTimestamp(); - if (isAsync) - response = await requestData.CustomResponseBuilder - .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) - .ConfigureAwait(false) as TResponse; - else - response = requestData.CustomResponseBuilder - .DeserializeResponse(serializer, details, responseStream) as TResponse; + if (isAsync) + response = await requestData.CustomResponseBuilder + .DeserializeResponseAsync(serializer, details, responseStream, cancellationToken) + .ConfigureAwait(false) as TResponse; + else + response = requestData.CustomResponseBuilder + .DeserializeResponse(serializer, details, responseStream) as TResponse; - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + ConditionalDisposal(responseStream, ownsStream, response); + return response; + } + + // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! + // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. + try + { + if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) + { + response = new TResponse(); + SetErrorOnResponse(response, error); + ConditionalDisposal(responseStream, ownsStream, response); return response; } - // TODO: Handle empty data in a nicer way as throwing exceptions has a cost we'd like to avoid! - // ie. check content-length (add to ApiCallDetails)? Content-length cannot be retrieved from a GZip content stream which is annoying. - try + if (!requestData.ValidateResponseContentType(mimeType)) { - if (requiresErrorDeserialization && TryGetError(details, requestData, responseStream, out var error) && error.HasError()) - { - response = new TResponse(); - SetErrorOnResponse(response, error); - return response; - } + ConditionalDisposal(responseStream, ownsStream, response); + return default; + } - if (!requestData.ValidateResponseContentType(mimeType)) - return default; + var beforeTicks = Stopwatch.GetTimestamp(); - var beforeTicks = Stopwatch.GetTimestamp(); + if (isAsync) + response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); + else + response = serializer.Deserialize(responseStream); - if (isAsync) - response = await serializer.DeserializeAsync(responseStream, cancellationToken).ConfigureAwait(false); - else - response = serializer.Deserialize(responseStream); + var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); - var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000); + if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) + Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); - if (deserializeResponseMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute && OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested) - Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportDeserializeResponseMs, deserializeResponseMs); + ConditionalDisposal(responseStream, ownsStream, response); + return response; + } + catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) + { + // Note the exception this handles is ONLY thrown after a check if the stream length is zero. + // When the length is zero, `default` is returned by Deserialize(Async) instead. + ConditionalDisposal(responseStream, ownsStream, response); + return default; + } - return response; - } - catch (JsonException ex) when (ex.Message.Contains("The input does not contain any JSON tokens")) - { - return default; - } + static void ConditionalDisposal(Stream responseStream, bool ownsStream, TResponse response) + { + // We only dispose of the responseStream if we created it (i.e. it is a MemoryStream) we + // created via MemoryStreamFactory. + if (ownsStream && (response is null || !response.LeaveOpen)) + responseStream.Dispose(); } } - private static bool SetSpecialTypes(string mimeType, byte[] bytes, Stream responseStream, - MemoryStreamFactory memoryStreamFactory, out TResponse cs) + private static bool TrySetSpecialType(string mimeType, byte[] bytes, Stream responseStream, + MemoryStreamFactory memoryStreamFactory, out TResponse response) where TResponse : TransportResponse, new() { - cs = null; + response = null; var responseType = typeof(TResponse); if (!SpecialTypes.Contains(responseType)) return false; if (responseType == typeof(StringResponse)) - cs = new StringResponse(bytes.Utf8String()) as TResponse; + response = new StringResponse(bytes.Utf8String()) as TResponse; else if (responseType == typeof(StreamResponse)) - cs = new StreamResponse(responseStream, mimeType) as TResponse; + response = new StreamResponse(responseStream, mimeType) as TResponse; else if (responseType == typeof(BytesResponse)) - cs = new BytesResponse(bytes) as TResponse; + response = new BytesResponse(bytes) as TResponse; else if (responseType == typeof(VoidResponse)) - cs = VoidResponse.Default as TResponse; + response = VoidResponse.Default as TResponse; else if (responseType == typeof(DynamicResponse)) { //if not json store the result under "body" @@ -314,17 +331,17 @@ private static bool SetSpecialTypes(string mimeType, byte[] bytes, St { ["body"] = new DynamicValue(bytes.Utf8String()) }; - cs = new DynamicResponse(dictionary) as TResponse; + response = new DynamicResponse(dictionary) as TResponse; } else { using var ms = memoryStreamFactory.Create(bytes); var body = LowLevelRequestResponseSerializer.Instance.Deserialize(ms); - cs = new DynamicResponse(body) as TResponse; + response = new DynamicResponse(body) as TResponse; } } - return cs != null; + return response != null; } private static bool NeedsToEagerReadStream() @@ -336,7 +353,6 @@ private static bool NeedsToEagerReadStream() private static byte[] SwapStreams(ref Stream responseStream, ref MemoryStream ms) { var bytes = ms.ToArray(); - responseStream.Dispose(); responseStream = ms; responseStream.Position = 0; return bytes; diff --git a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs index a466b78..4fd3335 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpRequestInvoker.cs @@ -75,7 +75,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req Exception ex = null; string mimeType = null; long contentLength = -1; - IDisposable receive = DiagnosticSources.SingletonDisposable; + IDisposable receivedResponse = DiagnosticSources.SingletonDisposable; ReadOnlyDictionary tcpStats = null; ReadOnlyDictionary threadPoolStats = null; Dictionary> responseHeaders = null; @@ -118,7 +118,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req responseMessage = client.SendAsync(requestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).GetAwaiter().GetResult(); #endif - receive = responseMessage; + receivedResponse = responseMessage; statusCode = (int)responseMessage.StatusCode; } @@ -154,13 +154,10 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req ex = e; } - var isStreamResponse = typeof(TResponse) == typeof(StreamResponse); + TResponse response; - using (isStreamResponse ? DiagnosticSources.SingletonDisposable : receive) - using (isStreamResponse ? Stream.Null : responseStream ??= Stream.Null) + try { - TResponse response; - if (isAsync) response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) @@ -169,9 +166,18 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - // Defer disposal of the response message - if (response is StreamResponse sr) - sr.Finalizer = () => receive.Dispose(); + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) + { + response.LinkedDisposables = [receivedResponse, responseStream]; + } + else + { + responseStream.Dispose(); + receivedResponse.Dispose(); + } if (!OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners || (!(Activity.Current?.IsAllDataRequested ?? false))) return response; @@ -185,6 +191,13 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req return response; } + catch + { + // if there's an exception, ensure we always release the stream and response so that the connection is freed. + responseStream.Dispose(); + receivedResponse.Dispose(); + throw; + } } private static Dictionary>? ParseHeaders(RequestData requestData, HttpResponseMessage responseMessage) diff --git a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs index e1c8ae7..6ced0e1 100644 --- a/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs @@ -68,6 +68,7 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req Exception ex = null; string mimeType = null; long contentLength = -1; + IDisposable receivedResponse = DiagnosticSources.SingletonDisposable; ReadOnlyDictionary tcpStats = null; ReadOnlyDictionary threadPoolStats = null; Dictionary> responseHeaders = null; @@ -146,6 +147,8 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req httpWebResponse = (HttpWebResponse)request.GetResponse(); } + receivedResponse = httpWebResponse; + HandleResponse(httpWebResponse, out statusCode, out responseStream, out mimeType); responseHeaders = ParseHeaders(requestData, httpWebResponse, responseHeaders); contentLength = httpWebResponse.ContentLength; @@ -161,28 +164,50 @@ private async ValueTask RequestCoreAsync(bool isAsync, Req { unregisterWaitHandle?.Invoke(); } - responseStream ??= Stream.Null; - - TResponse response; - if (isAsync) - response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) - .ConfigureAwait(false); - else - response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse - (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); - - if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + try { - var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); - foreach (var attribute in attributes) + TResponse response; + + if (isAsync) + response = await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponseAsync + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats, cancellationToken) + .ConfigureAwait(false); + else + response = requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse + (requestData, ex, statusCode, responseHeaders, responseStream, mimeType, contentLength, threadPoolStats, tcpStats); + + // Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpResponseMessage + // to release the connection. In cases, where the derived response works directly on the stream, it can be left open and additional IDisposable + // resources can be linked such that their disposal is deferred. + if (response.LeaveOpen) { - Activity.Current?.SetTag(attribute.Key, attribute.Value); + response.LinkedDisposables = [receivedResponse, responseStream]; + } + else + { + responseStream.Dispose(); + receivedResponse.Dispose(); } - } - return response; + if (OpenTelemetry.CurrentSpanIsElasticTransportOwnedAndHasListeners && (Activity.Current?.IsAllDataRequested ?? false)) + { + var attributes = requestData.ConnectionSettings.ProductRegistration.ParseOpenTelemetryAttributesFromApiCallDetails(response.ApiCallDetails); + foreach (var attribute in attributes) + { + Activity.Current?.SetTag(attribute.Key, attribute.Value); + } + } + + return response; + } + catch + { + // if there's an exception, ensure we always release the stream and response so that the connection is freed. + responseStream.Dispose(); + receivedResponse.Dispose(); + throw; + } } private static Dictionary> ParseHeaders(RequestData requestData, HttpWebResponse responseMessage, Dictionary> responseHeaders) diff --git a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs index 12e1f9b..aead20a 100644 --- a/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs +++ b/src/Elastic.Transport/Components/TransportClient/InMemoryRequestInvoker.cs @@ -66,24 +66,27 @@ public TResponse BuildResponse(RequestData requestData, byte[] respon { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + if (requestData.HttpCompression) + { + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + data.Write(zipStream, requestData.ConnectionSettings); + } + else { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - data.Write(zipStream, requestData.ConnectionSettings); - } - else - data.Write(stream, requestData.ConnectionSettings); + data.Write(stream, requestData.ConnectionSettings); } } requestData.MadeItToResponse = true; var sc = statusCode ?? _statusCode; - Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); - return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder.ToResponse(requestData, _exception, sc, _headers, s, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); + Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + + return requestData.ConnectionSettings.ProductRegistration.ResponseBuilder + .ToResponse(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType ?? RequestData.DefaultMimeType, body?.Length ?? 0, null, null); } /// > @@ -93,26 +96,29 @@ public async Task BuildResponseAsync(RequestData requestDa { var body = responseBody ?? _responseBody; var data = requestData.PostData; - if (data != null) + + if (data is not null) { - using (var stream = requestData.MemoryStreamFactory.Create()) + using var stream = requestData.MemoryStreamFactory.Create(); + + if (requestData.HttpCompression) { - if (requestData.HttpCompression) - { - using var zipStream = new GZipStream(stream, CompressionMode.Compress); - await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); - } - else - await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + using var zipStream = new GZipStream(stream, CompressionMode.Compress); + await data.WriteAsync(zipStream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); + } + else + { + await data.WriteAsync(stream, requestData.ConnectionSettings, cancellationToken).ConfigureAwait(false); } } requestData.MadeItToResponse = true; var sc = statusCode ?? _statusCode; - Stream s = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + + Stream responseStream = body != null ? requestData.MemoryStreamFactory.Create(body) : requestData.MemoryStreamFactory.Create(EmptyBody); + return await requestData.ConnectionSettings.ProductRegistration.ResponseBuilder - .ToResponseAsync(requestData, _exception, sc, _headers, s, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) + .ToResponseAsync(requestData, _exception, sc, _headers, responseStream, contentType ?? _contentType, body?.Length ?? 0, null, null, cancellationToken) .ConfigureAwait(false); } - } diff --git a/src/Elastic.Transport/Requests/Body/PostData.cs b/src/Elastic.Transport/Requests/Body/PostData.cs index 68f2c52..9e59775 100644 --- a/src/Elastic.Transport/Requests/Body/PostData.cs +++ b/src/Elastic.Transport/Requests/Body/PostData.cs @@ -111,6 +111,7 @@ protected void FinishStream(Stream writableStream, MemoryStream buffer, ITranspo buffer.Position = 0; buffer.CopyTo(writableStream, BufferSize); WrittenBytes ??= buffer.ToArray(); + buffer.Dispose(); } /// @@ -132,5 +133,10 @@ protected async buffer.Position = 0; await buffer.CopyToAsync(writableStream, BufferSize, ctx).ConfigureAwait(false); WrittenBytes ??= buffer.ToArray(); +#if NET + await buffer.DisposeAsync().ConfigureAwait(false); +#else + buffer.Dispose(); +#endif } } diff --git a/src/Elastic.Transport/Responses/Special/StreamResponse.cs b/src/Elastic.Transport/Responses/Special/StreamResponse.cs index 53dd22a..08b2de6 100644 --- a/src/Elastic.Transport/Responses/Special/StreamResponse.cs +++ b/src/Elastic.Transport/Responses/Special/StreamResponse.cs @@ -10,14 +10,12 @@ namespace Elastic.Transport; /// /// A response that exposes the response as . /// -/// Must be disposed after use. +/// MUST be disposed after use to ensure the HTTP connection is freed for reuse. /// /// -public sealed class StreamResponse : - TransportResponse, - IDisposable +public class StreamResponse : TransportResponse, IDisposable { - internal Action? Finalizer { get; set; } + private bool _disposed; /// /// The MIME type of the response, if present. @@ -38,10 +36,37 @@ public StreamResponse(Stream body, string? mimeType) MimeType = mimeType ?? string.Empty; } - /// + internal override bool LeaveOpen => true; + + /// + /// Disposes the underlying stream. + /// + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Body.Dispose(); + + if (LinkedDisposables is not null) + { + foreach (var disposable in LinkedDisposables) + disposable.Dispose(); + } + } + + _disposed = true; + } + } + + /// + /// Disposes the underlying stream. + /// public void Dispose() { - Body.Dispose(); - Finalizer?.Invoke(); + Dispose(disposing: true); + GC.SuppressFinalize(this); } } diff --git a/src/Elastic.Transport/Responses/TransportResponse.cs b/src/Elastic.Transport/Responses/TransportResponse.cs index e74cece..4e7f4a8 100644 --- a/src/Elastic.Transport/Responses/TransportResponse.cs +++ b/src/Elastic.Transport/Responses/TransportResponse.cs @@ -2,13 +2,15 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System; +using System.Collections.Generic; using System.Text.Json.Serialization; namespace Elastic.Transport; /// /// A response from an Elastic product including details about the request/response life cycle. Base class for the built in low level response -/// types, , , and +/// types, , , , and /// public abstract class TransportResponse : TransportResponse { @@ -34,5 +36,25 @@ public abstract class TransportResponse public override string ToString() => ApiCallDetails?.DebugInformation // ReSharper disable once ConstantNullCoalescingCondition ?? $"{nameof(ApiCallDetails)} not set, likely a bug, reverting to default ToString(): {base.ToString()}"; + + /// + /// Allows other disposable resources to to be disposed along with the response. + /// + /// + /// While it's slightly confusing to have this on the base type which is NOT IDisposable, it avoids + /// specialised type checking in the request invoker and response builder code. Currently, only used by + /// StreamResponse and kept internal. If we later make this public, we might need to refine this. + /// + [JsonIgnore] + internal IEnumerable? LinkedDisposables { get; set; } + + /// + /// Allows the response to identify that the response stream should NOT be automatically disposed. + /// + /// + /// Currently only used by StreamResponse and therefore internal. + /// + [JsonIgnore] + internal virtual bool LeaveOpen { get; } = false; } diff --git a/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs new file mode 100644 index 0000000..7ec6720 --- /dev/null +++ b/tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs @@ -0,0 +1,208 @@ +// Licensed to Elasticsearch B.V under one or more agreements. +// Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Elastic.Transport.IntegrationTests.Plumbing; +using Elastic.Transport.Products.Elasticsearch; +using FluentAssertions; +using Microsoft.AspNetCore.Mvc; +using Xunit; + +namespace Elastic.Transport.IntegrationTests.Http; + +public class StreamResponseTests(TransportTestServer instance) : AssemblyServerTestsBase(instance) +{ + private const string Path = "/streamresponse"; + + [Fact] + public async Task StreamResponse_ShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))); + var transport = new DistributedTransport(config); + + var response = await transport.PostAsync(Path, PostData.String("{}")); + + // Ensure the stream is readable + using var sr = new StreamReader(response.Body); + _ = sr.ReadToEndAsync(); + } + + [Fact] + public async Task StreamResponse_MemoryStreamShouldNotBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + // When disable direct streaming, we have 1 for the original content, 1 for the buffered request bytes and the last for the buffered response + memoryStreamFactory.Created.Count.Should().Be(3); + memoryStreamFactory.Created.Last().IsDisposed.Should().BeFalse(); + } + + [Fact] + public async Task StringResponse_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory); + + var transport = new DistributedTransport(config); + + _ = await transport.PostAsync(Path, PostData.String("{}")); + + memoryStreamFactory.Created.Count.Should().Be(2); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + [Fact] + public async Task WhenInvalidJson_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + var payload = new Payload { ResponseJsonString = " " }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); + + memoryStreamFactory.Created.Count.Should().Be(3); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + [Fact] + public async Task WhenNoContent_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory); + + var transport = new DistributedTransport(config); + + var payload = new Payload { ResponseJsonString = "", StatusCode = 204 }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); + + // We expect one for sending the request payload, but as the response is 204, we shouldn't + // see other memory streams being created for the response. + memoryStreamFactory.Created.Count.Should().Be(1); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + [Fact] + public async Task PlainText_MemoryStreamShouldBeDisposed() + { + var nodePool = new SingleNodePool(Server.Uri); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient))) + .MemoryStreamFactory(memoryStreamFactory) + .DisableDirectStreaming(true); + + var transport = new DistributedTransport(config); + + var payload = new Payload { ResponseJsonString = "text", ContentType = "text/plain" }; + _ = await transport.PostAsync(Path, PostData.Serializable(payload)); + + memoryStreamFactory.Created.Count.Should().Be(3); + foreach (var memoryStream in memoryStreamFactory.Created) + { + memoryStream.IsDisposed.Should().BeTrue(); + } + } + + private class TestResponse : TransportResponse + { + } + + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } + + public TrackDisposeStream(byte[] bytes) : base(bytes) { } + + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + + public bool IsDisposed { get; private set; } + + protected override void Dispose(bool disposing) + { + IsDisposed = true; + base.Dispose(disposing); + } + } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes, int index, int count) + { + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; + } + } +} + +public class Payload +{ + public string ResponseJsonString { get; set; } = "{}"; + public string ContentType { get; set; } = "application/json"; + public int StatusCode { get; set; } = 200; +} + +[ApiController, Route("[controller]")] +public class StreamResponseController : ControllerBase +{ + [HttpPost] + public async Task Post([FromBody] Payload payload) + { + Response.ContentType = payload.ContentType; + + if (payload.StatusCode != 204) + { + await Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes(payload.ResponseJsonString)); + await Response.BodyWriter.CompleteAsync(); + } + + return StatusCode(payload.StatusCode); + } +} diff --git a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs index c38bb05..f4eb21d 100644 --- a/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs +++ b/tests/Elastic.Transport.Tests/Plumbing/InMemoryConnectionFactory.cs @@ -3,16 +3,17 @@ // See the LICENSE file in the project root for more information using System; +using Elastic.Transport.Products; namespace Elastic.Transport.Tests.Plumbing { public static class InMemoryConnectionFactory { - public static TransportConfiguration Create() + public static TransportConfiguration Create(ProductRegistration productRegistration = null) { var invoker = new InMemoryRequestInvoker(); var pool = new SingleNodePool(new Uri("http://localhost:9200")); - var settings = new TransportConfiguration(pool, invoker); + var settings = new TransportConfiguration(pool, invoker, productRegistration: productRegistration); return settings; } } diff --git a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs index db5f3e4..8b79c6a 100644 --- a/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs +++ b/tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs @@ -5,144 +5,228 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; +using Elastic.Transport.Products; using Elastic.Transport.Tests.Plumbing; using FluentAssertions; using Xunit; -namespace Elastic.Transport.Tests +namespace Elastic.Transport.Tests; + +public class ResponseBuilderDisposeTests { - public class ResponseBuilderDisposeTests + private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); + private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); + + [Fact] + public async Task StreamResponseWithPotentialBody_StreamIsNotDisposed() => + await AssertResponse(false, expectedDisposed: false); + + [Fact] + public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsNotDisposed() => + await AssertResponse(true, expectedDisposed: false); + + [Fact] + public async Task ResponseWithPotentialBodyButInvalidMimeType_MemoryStreamIsDisposed() => + await AssertResponse(true, mimeType: "application/not-valid", expectedDisposed: true); + + [Fact] + public async Task ResponseWithPotentialBodyButSkippedStatusCode_MemoryStreamIsDisposed() => + await AssertResponse(true, skipStatusCode: 200, expectedDisposed: true); + + [Fact] + public async Task ResponseWithPotentialBodyButEmptyJson_MemoryStreamIsDisposed() => + await AssertResponse(true, responseJson: " ", expectedDisposed: true); + + [Fact] + // NOTE: The empty string here hits a fast path in STJ which returns default if the stream length is zero. + public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_MemoryStreamIsDisposed() => + await AssertResponse(true, responseJson: "", expectedDisposed: true); + + [Fact] + public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_MemoryStreamIsDisposed() => + await AssertResponse(true, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true); + + [Fact] + // NOTE: We expect one memory stream factory creation when handling error responses + public async Task ResponseWithPotentialBodyAndErrorResponse_StreamIsDisposed() => + await AssertResponse(true, productRegistration: new TestProductRegistration(), expectedDisposed: true); + + [Fact] + public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() => + await AssertResponse(false, expectedDisposed: true, memoryStreamCreateExpected: 1); + + private async Task AssertResponse(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10, + bool expectedDisposed = true, string mimeType = "application/json", string responseJson = "{}", int skipStatusCode = -1, + CustomResponseBuilder customResponseBuilder = null, ProductRegistration productRegistration = null, int memoryStreamCreateExpected = -1) + where T : TransportResponse, new() { - private readonly ITransportConfiguration _settings = InMemoryConnectionFactory.Create().DisableDirectStreaming(false); - private readonly ITransportConfiguration _settingsDisableDirectStream = InMemoryConnectionFactory.Create().DisableDirectStreaming(); - - [Fact] public async Task ResponseWithHttpStatusCode() => await AssertRegularResponse(false, 1); + ITransportConfiguration config; - [Fact] public async Task ResponseBuilderWithNoHttpStatusCode() => await AssertRegularResponse(false); + if (skipStatusCode > -1 ) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming) + .SkipDeserializationForStatusCodes(skipStatusCode); + } + else if (productRegistration is not null) + { + config = InMemoryConnectionFactory.Create(productRegistration) + .DisableDirectStreaming(disableDirectStreaming); + } + else + { + config = disableDirectStreaming ? _settingsDisableDirectStream : _settings; + } - [Fact] public async Task ResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true, 1); + var memoryStreamFactory = new TrackMemoryStreamFactory(); + + var requestData = new RequestData(httpMethod, "/", null, config, null, customResponseBuilder, memoryStreamFactory, default) + { + Node = new Node(new Uri("http://localhost:9200")) + }; - [Fact] public async Task ResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertRegularResponse(true); + var stream = new TrackDisposeStream(); - private async Task AssertRegularResponse(bool disableDirectStreaming, int? statusCode = null) + if (!string.IsNullOrEmpty(responseJson)) { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; + stream.Write(Encoding.UTF8.GetBytes(responseJson), 0, responseJson.Length); + stream.Position = 0; + } - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[0]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); - - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - if (disableDirectStreaming) - { - var memoryStream = memoryStreamFactory.Created[1]; - memoryStream.IsDisposed.Should().BeTrue(); - } - stream.IsDisposed.Should().BeTrue(); + var response = config.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, mimeType, contentLength, null, null); + + response.Should().NotBeNull(); + + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected : disableDirectStreaming ? 1 : 0); + if (disableDirectStreaming) + { + var memoryStream = memoryStreamFactory.Created[0]; + memoryStream.IsDisposed.Should().Be(expectedDisposed); } - [Fact] public async Task StreamResponseWithHttpStatusCode() => await AssertStreamResponse(false, 200); + // The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal + stream.IsDisposed.Should().Be(false); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCode() => await AssertStreamResponse(false); + stream = new TrackDisposeStream(); + var ct = new CancellationToken(); - [Fact] public async Task StreamResponseWithHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true, 1); + response = await config.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, contentLength, null, null, + cancellationToken: ct); - [Fact] public async Task StreamResponseBuilderWithNoHttpStatusCodeDisableDirectStreaming() => - await AssertStreamResponse(true); + response.Should().NotBeNull(); - private async Task AssertStreamResponse(bool disableDirectStreaming, int? statusCode = null) + memoryStreamFactory.Created.Count.Should().Be(memoryStreamCreateExpected > -1 ? memoryStreamCreateExpected + 1 : disableDirectStreaming ? 2 : 0); + if (disableDirectStreaming) { - var settings = disableDirectStreaming ? _settingsDisableDirectStream : _settings; - var memoryStreamFactory = new TrackMemoryStreamFactory(); + var memoryStream = memoryStreamFactory.Created[0]; + memoryStream.IsDisposed.Should().Be(expectedDisposed); + } - var requestData = new RequestData(HttpMethod.GET, "/", null, settings, null, null, memoryStreamFactory, default) - { - Node = new Node(new Uri("http://localhost:9200")) - }; + // The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal + stream.IsDisposed.Should().Be(false); + } - var stream = new TrackDisposeStream(); - var response = _settings.ProductRegistration.ResponseBuilder.ToResponse(requestData, null, statusCode, null, stream, null, -1, null, null); - response.Should().NotBeNull(); - - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 1 : 0); - stream.IsDisposed.Should().Be(true); - - stream = new TrackDisposeStream(); - var ct = new CancellationToken(); - response = await _settings.ProductRegistration.ResponseBuilder.ToResponseAsync(requestData, null, statusCode, null, stream, null, -1, null, null, - cancellationToken: ct); - response.Should().NotBeNull(); - memoryStreamFactory.Created.Count().Should().Be(disableDirectStreaming ? 2 : 0); - stream.IsDisposed.Should().Be(true); - } + private class TestProductRegistration : ProductRegistration + { + public override string DefaultMimeType => "application/json"; + public override string Name => "name"; + public override string ServiceIdentifier => "id"; + public override bool SupportsPing => false; + public override bool SupportsSniff => false; + public override HeadersList ResponseHeadersToParse => []; + public override MetaHeaderProvider MetaHeaderProvider => null; + public override string ProductAssemblyVersion => "0.0.0"; + public override IReadOnlyDictionary DefaultOpenTelemetryAttributes => new Dictionary(); + public override RequestData CreatePingRequestData(Node node, RequestConfiguration requestConfiguration, ITransportConfiguration global, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override RequestData CreateSniffRequestData(Node node, IRequestConfiguration requestConfiguration, ITransportConfiguration settings, MemoryStreamFactory memoryStreamFactory) => throw new NotImplementedException(); + public override IReadOnlyCollection DefaultHeadersToParse() => []; + public override bool HttpStatusCodeClassifier(HttpMethod method, int statusCode) => true; + public override bool NodePredicate(Node node) => throw new NotImplementedException(); + public override Dictionary ParseOpenTelemetryAttributesFromApiCallDetails(ApiCallDetails callDetails) => throw new NotImplementedException(); + public override TransportResponse Ping(IRequestInvoker requestInvoker, RequestData pingData) => throw new NotImplementedException(); + public override Task PingAsync(IRequestInvoker requestInvoker, RequestData pingData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override Tuple> Sniff(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData) => throw new NotImplementedException(); + public override Task>> SniffAsync(IRequestInvoker requestInvoker, bool forceSsl, RequestData requestData, CancellationToken cancellationToken) => throw new NotImplementedException(); + public override int SniffOrder(Node node) => throw new NotImplementedException(); + public override bool TryGetServerErrorReason(TResponse response, out string reason) => throw new NotImplementedException(); + public override ResponseBuilder ResponseBuilder => new TestErrorResponseBuilder(); + } + + private class TestError : ErrorResponse + { + public string MyError { get; set; } + public override bool HasError() => true; + } - private class TrackDisposeStream : MemoryStream + private class TestErrorResponseBuilder : DefaultResponseBuilder + { + protected override void SetErrorOnResponse(TResponse response, TestError error) { - public TrackDisposeStream() { } + // nothing to do in this scenario + } + + protected override bool TryGetError(ApiCallDetails apiCallDetails, RequestData requestData, Stream responseStream, out TestError error) + { + error = new TestError(); + return true; + } + + protected override bool RequiresErrorDeserialization(ApiCallDetails details, RequestData requestData) => true; + } + + private class TestCustomResponseBuilder : CustomResponseBuilder + { + public override object DeserializeResponse(Serializer serializer, ApiCallDetails response, Stream stream) => + new TestResponse { ApiCallDetails = response }; + + public override Task DeserializeResponseAsync(Serializer serializer, ApiCallDetails response, Stream stream, CancellationToken ctx = default) => + Task.FromResult(new TestResponse { ApiCallDetails = response }); + } + + private class TrackDisposeStream : MemoryStream + { + public TrackDisposeStream() { } - public TrackDisposeStream(byte[] bytes) : base(bytes) { } + public TrackDisposeStream(byte[] bytes) : base(bytes) { } - public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } + public TrackDisposeStream(byte[] bytes, int index, int count) : base(bytes, index, count) { } - public bool IsDisposed { get; private set; } + public bool IsDisposed { get; private set; } - protected override void Dispose(bool disposing) - { - IsDisposed = true; - base.Dispose(disposing); - } + protected override void Dispose(bool disposing) + { + IsDisposed = true; + base.Dispose(disposing); + } + } + + private class TrackMemoryStreamFactory : MemoryStreamFactory + { + public IList Created { get; } = []; + + public override MemoryStream Create() + { + var stream = new TrackDisposeStream(); + Created.Add(stream); + return stream; + } + + public override MemoryStream Create(byte[] bytes) + { + var stream = new TrackDisposeStream(bytes); + Created.Add(stream); + return stream; } - private class TrackMemoryStreamFactory : MemoryStreamFactory + public override MemoryStream Create(byte[] bytes, int index, int count) { - public IList Created { get; } = new List(); - - public override MemoryStream Create() - { - var stream = new TrackDisposeStream(); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes) - { - var stream = new TrackDisposeStream(bytes); - Created.Add(stream); - return stream; - } - - public override MemoryStream Create(byte[] bytes, int index, int count) - { - var stream = new TrackDisposeStream(bytes, index, count); - Created.Add(stream); - return stream; - } + var stream = new TrackDisposeStream(bytes, index, count); + Created.Add(stream); + return stream; } } }