diff --git a/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/BinaryMessageFormatter.cs b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/BinaryMessageFormatter.cs new file mode 100644 index 0000000000000..d30853007e2f4 --- /dev/null +++ b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/BinaryMessageFormatter.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; + +namespace Microsoft.AspNetCore.Internal; + +internal static class BinaryMessageFormatter +{ + public static void WriteLengthPrefix(long length, IBufferWriter output) + { + Span lenBuffer = stackalloc byte[5]; + + var lenNumBytes = WriteLengthPrefix(length, lenBuffer); + + output.Write(lenBuffer.Slice(0, lenNumBytes)); + } + + public static int WriteLengthPrefix(long length, Span output) + { + // This code writes length prefix of the message as a VarInt. Read the comment in + // the BinaryMessageParser.TryParseMessage for details. + var lenNumBytes = 0; + do + { + ref var current = ref output[lenNumBytes]; + current = (byte)(length & 0x7f); + length >>= 7; + if (length > 0) + { + current |= 0x80; + } + lenNumBytes++; + } + while (length > 0); + + return lenNumBytes; + } + + public static int LengthPrefixLength(long length) + { + var lenNumBytes = 0; + do + { + length >>= 7; + lenNumBytes++; + } + while (length > 0); + + return lenNumBytes; + } +} \ No newline at end of file diff --git a/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/DefaultMessagePackHubProtocolWorker.cs b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/DefaultMessagePackHubProtocolWorker.cs new file mode 100644 index 0000000000000..7e9c5c1dd7204 --- /dev/null +++ b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/DefaultMessagePackHubProtocolWorker.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.IO; + +using MessagePack; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +internal sealed class DefaultMessagePackHubProtocolWorker : MessagePackHubProtocolWorker +{ + private readonly MessagePackSerializerOptions _messagePackSerializerOptions; + + public DefaultMessagePackHubProtocolWorker(MessagePackSerializerOptions messagePackSerializerOptions) + { + _messagePackSerializerOptions = messagePackSerializerOptions; + } + + protected override object? DeserializeObject(ref MessagePackReader reader, Type type, string field) + { + try + { + return MessagePackSerializer.Deserialize(type, ref reader, _messagePackSerializerOptions); + } + catch (Exception ex) + { + throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex); + } + } + + protected override void Serialize(ref MessagePackWriter writer, Type type, object value) + { + MessagePackSerializer.Serialize(type, ref writer, value, _messagePackSerializerOptions); + } +} \ No newline at end of file diff --git a/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MemoryBufferWriter.cs b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MemoryBufferWriter.cs new file mode 100644 index 0000000000000..7861509636683 --- /dev/null +++ b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MemoryBufferWriter.cs @@ -0,0 +1,404 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Internal; + +internal sealed class MemoryBufferWriter : Stream, IBufferWriter +{ + [ThreadStatic] + private static MemoryBufferWriter? _cachedInstance; + +#if DEBUG + private bool _inUse; +#endif + + private readonly int _minimumSegmentSize; + private int _bytesWritten; + + private List? _completedSegments; + private byte[]? _currentSegment; + private int _position; + + public MemoryBufferWriter(int minimumSegmentSize = 4096) + { + _minimumSegmentSize = minimumSegmentSize; + } + + public override long Length => _bytesWritten; + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public static MemoryBufferWriter Get() + { + var writer = _cachedInstance; + if (writer == null) + { + writer = new MemoryBufferWriter(); + } + else + { + // Taken off the thread static + _cachedInstance = null; + } +#if DEBUG + if (writer._inUse) + { + throw new InvalidOperationException("The reader wasn't returned!"); + } + + writer._inUse = true; +#endif + + return writer; + } + + public static void Return(MemoryBufferWriter writer) + { + _cachedInstance = writer; +#if DEBUG + writer._inUse = false; +#endif + writer.Reset(); + } + + public void Reset() + { + if (_completedSegments != null) + { + for (var i = 0; i < _completedSegments.Count; i++) + { + _completedSegments[i].Return(); + } + + _completedSegments.Clear(); + } + + if (_currentSegment != null) + { + ArrayPool.Shared.Return(_currentSegment); + _currentSegment = null; + } + + _bytesWritten = 0; + _position = 0; + } + + public void Advance(int count) + { + _bytesWritten += count; + _position += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment.AsMemory(_position, _currentSegment.Length - _position); + } + + public Span GetSpan(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment.AsSpan(_position, _currentSegment.Length - _position); + } + + public void CopyTo(IBufferWriter destination) + { + if (_completedSegments != null) + { + // Copy completed segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + destination.Write(_completedSegments[i].Span); + } + } + + destination.Write(_currentSegment.AsSpan(0, _position)); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (_completedSegments == null && _currentSegment is not null) + { + // There is only one segment so write without awaiting. + return destination.WriteAsync(_currentSegment, 0, _position, cancellationToken); + } + + return CopyToSlowAsync(destination, cancellationToken); + } + + [MemberNotNull(nameof(_currentSegment))] + private void EnsureCapacity(int sizeHint) + { + // This does the Right Thing. It only subtracts _position from the current segment length if it's non-null. + // If _currentSegment is null, it returns 0. + var remainingSize = _currentSegment?.Length - _position ?? 0; + + // If the sizeHint is 0, any capacity will do + // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment. + if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint)) + { + // We have capacity in the current segment +#pragma warning disable CS8774 // Member must have a non-null value when exiting. + return; +#pragma warning restore CS8774 // Member must have a non-null value when exiting. + } + + AddSegment(sizeHint); + } + + [MemberNotNull(nameof(_currentSegment))] + private void AddSegment(int sizeHint = 0) + { + if (_currentSegment != null) + { + // We're adding a segment to the list + if (_completedSegments == null) + { + _completedSegments = new List(); + } + + // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when + // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to + // ignore any empty space in it. + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); + } + + // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment. + _currentSegment = ArrayPool.Shared.Rent(Math.Max(_minimumSegmentSize, sizeHint)); + _position = 0; + } + + private async Task CopyToSlowAsync(Stream destination, CancellationToken cancellationToken) + { + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; +#if NETCOREAPP + await destination.WriteAsync(segment.Buffer.AsMemory(0, segment.Length), cancellationToken).ConfigureAwait(false); +#else + await destination.WriteAsync(segment.Buffer, 0, segment.Length, cancellationToken).ConfigureAwait(false); +#endif + } + } + + if (_currentSegment is not null) + { +#if NETCOREAPP + await destination.WriteAsync(_currentSegment.AsMemory(0, _position), cancellationToken).ConfigureAwait(false); +#else + await destination.WriteAsync(_currentSegment, 0, _position, cancellationToken).ConfigureAwait(false); +#endif + } + } + + public byte[] ToArray() + { + if (_currentSegment == null) + { + return Array.Empty(); + } + + var result = new byte[_bytesWritten]; + + var totalWritten = 0; + + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; + segment.Span.CopyTo(result.AsSpan(totalWritten)); + totalWritten += segment.Span.Length; + } + } + + // Copy current incomplete segment + _currentSegment.AsSpan(0, _position).CopyTo(result.AsSpan(totalWritten)); + + return result; + } + + public void CopyTo(Span span) + { + Debug.Assert(span.Length >= _bytesWritten); + + if (_currentSegment == null) + { + return; + } + + var totalWritten = 0; + + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; + segment.Span.CopyTo(span.Slice(totalWritten)); + totalWritten += segment.Span.Length; + } + } + + // Copy current incomplete segment + _currentSegment.AsSpan(0, _position).CopyTo(span.Slice(totalWritten)); + + Debug.Assert(_bytesWritten == totalWritten + _position); + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void WriteByte(byte value) + { + if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length) + { + _currentSegment[_position] = value; + } + else + { + AddSegment(); + _currentSegment[0] = value; + } + + _position++; + _bytesWritten++; + } + + public override void Write(byte[] buffer, int offset, int count) + { + var position = _position; + if (_currentSegment != null && position < _currentSegment.Length - count) + { + Buffer.BlockCopy(buffer, offset, _currentSegment, position, count); + + _position = position + count; + _bytesWritten += count; + } + else + { + BuffersExtensions.Write(this, buffer.AsSpan(offset, count)); + } + } + +#if NETCOREAPP + public override void Write(ReadOnlySpan span) + { + if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan(_position))) + { + _position += span.Length; + _bytesWritten += span.Length; + } + else + { + BuffersExtensions.Write(this, span); + } + } +#endif + + public WrittenBuffers DetachAndReset() + { + _completedSegments ??= new List(); + + if (_currentSegment is not null) + { + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); + } + + var written = new WrittenBuffers(_completedSegments, _bytesWritten); + + _currentSegment = null; + _completedSegments = null; + _bytesWritten = 0; + _position = 0; + + return written; + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + Reset(); + } + } + + /// + /// Holds the written segments from a MemoryBufferWriter and is no longer attached to a MemoryBufferWriter. + /// You are now responsible for calling Dispose on this type to return the memory to the pool. + /// + internal readonly ref struct WrittenBuffers + { + public readonly List Segments; + private readonly int _bytesWritten; + + public WrittenBuffers(List segments, int bytesWritten) + { + Segments = segments; + _bytesWritten = bytesWritten; + } + + public int ByteLength => _bytesWritten; + + public void Dispose() + { + for (var i = 0; i < Segments.Count; i++) + { + Segments[i].Return(); + } + Segments.Clear(); + } + } + + /// + /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. + /// + internal readonly struct CompletedBuffer + { + public byte[] Buffer { get; } + public int Length { get; } + + public ReadOnlySpan Span => Buffer.AsSpan(0, Length); + + public CompletedBuffer(byte[] buffer, int length) + { + Buffer = buffer; + Length = length; + } + + public void Return() + { + ArrayPool.Shared.Return(Buffer); + } + } +} \ No newline at end of file diff --git a/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocol.cs b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocol.cs new file mode 100644 index 0000000000000..66fe3b701d5e0 --- /dev/null +++ b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocol.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +using MessagePack; +using MessagePack.Formatters; +using MessagePack.Resolvers; + +using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// Implements the SignalR Hub Protocol using MessagePack. +/// +public class MessagePackHubProtocol : IHubProtocol +{ + private const string ProtocolName = "messagepack"; + private const int ProtocolVersion = 2; + private readonly DefaultMessagePackHubProtocolWorker _worker; + + /// + public string Name => ProtocolName; + + /// + public int Version => ProtocolVersion; + + /// + public TransferFormat TransferFormat => TransferFormat.Binary; + + /// + /// Initializes a new instance of the class. + /// + public MessagePackHubProtocol() + : this(Options.Create(new MessagePackHubProtocolOptions())) + { } + + /// + /// Initializes a new instance of the class. + /// + /// The options used to initialize the protocol. + public MessagePackHubProtocol(IOptions options) + { + ArgumentNullThrowHelper.ThrowIfNull(options); + + _worker = new DefaultMessagePackHubProtocolWorker(options.Value.SerializerOptions); + } + + /// + public bool IsVersionSupported(int version) + { + return version <= Version; + } + + /// + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message) + => _worker.TryParseMessage(ref input, binder, out message); + + /// + public void WriteMessage(HubMessage message, IBufferWriter output) + => _worker.WriteMessage(message, output); + + /// + public ReadOnlyMemory GetMessageBytes(HubMessage message) + => _worker.GetMessageBytes(message); + + internal static MessagePackSerializerOptions CreateDefaultMessagePackSerializerOptions() => + MessagePackSerializerOptions + .Standard + .WithResolver(SignalRResolver.Instance) + .WithSecurity(MessagePackSecurity.UntrustedData); + + internal sealed class SignalRResolver : IFormatterResolver + { + public static readonly IFormatterResolver Instance = new SignalRResolver(); + + public static readonly IReadOnlyList Resolvers = new IFormatterResolver[] + { + DynamicEnumAsStringResolver.Instance, + ContractlessStandardResolver.Instance, + }; + + public IMessagePackFormatter? GetFormatter() + { + return Cache.Formatter; + } + + private static class Cache + { + public static readonly IMessagePackFormatter? Formatter = ResolveFormatter(); + + private static IMessagePackFormatter? ResolveFormatter() + { + foreach (var resolver in Resolvers) + { + var formatter = resolver.GetFormatter(); + if (formatter != null) + { + return formatter; + } + } + + return null; + } + } + } +} \ No newline at end of file diff --git a/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocolWorker.cs b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocolWorker.cs new file mode 100644 index 0000000000000..7c1f23f8bd215 --- /dev/null +++ b/sdk/signalr/Microsoft.Azure.WebJobs.Extensions.SignalRService/src/TriggerBindings/MessagePackHubProtocol/MessagePackHubProtocolWorker.cs @@ -0,0 +1,718 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0005 // This file is shared across multiple projects making it ugly to ignore unused usings + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.IO.Pipelines; +using System.Runtime.ExceptionServices; +using System.Text; + +using MessagePack; + +using Microsoft.AspNetCore.Internal; +using Microsoft.Azure.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// Implements support for MessagePackHubProtocol. This code is shared between SignalR and Blazor. +/// +internal abstract class MessagePackHubProtocolWorker +{ + private const int ErrorResult = 1; + private const int VoidResult = 2; + private const int NonVoidResult = 3; + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message) + { + if (!BinaryMessageParser.TryParseMessage(ref input, out var payload)) + { + message = null; + return false; + } + + var reader = new MessagePackReader(payload); + message = ParseMessage(ref reader, binder); + return message != null; + } + + private HubMessage? ParseMessage(ref MessagePackReader reader, IInvocationBinder binder) + { + var itemCount = reader.ReadArrayHeader(); + + var messageType = ReadInt32(ref reader, "messageType"); + + switch (messageType) + { + case HubProtocolConstants.InvocationMessageType: + return CreateInvocationMessage(ref reader, binder, itemCount); + case HubProtocolConstants.StreamInvocationMessageType: + return CreateStreamInvocationMessage(ref reader, binder, itemCount); + case HubProtocolConstants.StreamItemMessageType: + return CreateStreamItemMessage(ref reader, binder); + case HubProtocolConstants.CompletionMessageType: + return CreateCompletionMessage(ref reader, binder); + case HubProtocolConstants.CancelInvocationMessageType: + return CreateCancelInvocationMessage(ref reader); + case HubProtocolConstants.PingMessageType: + return PingMessage.Instance; + case HubProtocolConstants.CloseMessageType: + return CreateCloseMessage(ref reader, itemCount); + case HubProtocolConstants.AckMessageType: + return CreateAckMessage(ref reader); + case HubProtocolConstants.SequenceMessageType: + return CreateSequenceMessage(ref reader); + default: + // Future protocol changes can add message types, old clients can ignore them + return null; + } + } + + private HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount) + { + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + + // For MsgPack, we represent an empty invocation ID as an empty string, + // so we need to normalize that to "null", which is what indicates a non-blocking invocation. + if (string.IsNullOrEmpty(invocationId)) + { + invocationId = null; + } + + var target = ReadString(ref reader, binder, "target"); + ThrowIfNullOrEmpty(target, "target for Invocation message"); + + object?[]? arguments; + try + { + var parameterTypes = binder.GetParameterTypes(target); + arguments = BindArguments(ref reader, parameterTypes); + } + catch (Exception ex) + { + return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)); + } + + string[]? streams = null; + // Previous clients will send 5 items, so we check if they sent a stream array or not + if (itemCount > 5) + { + streams = ReadStreamIds(ref reader); + } + + return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams)); + } + + private HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount) + { + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + ThrowIfNullOrEmpty(invocationId, "invocation ID for StreamInvocation message"); + + var target = ReadString(ref reader, "target"); + ThrowIfNullOrEmpty(target, "target for StreamInvocation message"); + + object?[] arguments; + try + { + var parameterTypes = binder.GetParameterTypes(target); + arguments = BindArguments(ref reader, parameterTypes); + } + catch (Exception ex) + { + return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)); + } + + string[]? streams = null; + // Previous clients will send 5 items, so we check if they sent a stream array or not + if (itemCount > 5) + { + streams = ReadStreamIds(ref reader); + } + + return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams)); + } + + private HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder) + { + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + ThrowIfNullOrEmpty(invocationId, "invocation ID for StreamItem message"); + + object? value; + try + { + var itemType = binder.GetStreamItemType(invocationId); + value = DeserializeObject(ref reader, itemType, "item"); + } + catch (Exception ex) + { + return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex)); + } + + return ApplyHeaders(headers, new StreamItemMessage(invocationId, value)); + } + + private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder) + { + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + ThrowIfNullOrEmpty(invocationId, "invocation ID for Completion message"); + + var resultKind = ReadInt32(ref reader, "resultKind"); + + string? error = null; + object? result = null; + var hasResult = false; + + switch (resultKind) + { + case ErrorResult: + error = ReadString(ref reader, "error"); + break; + case NonVoidResult: + hasResult = true; + var itemType = ProtocolHelper.TryGetReturnType(binder, invocationId); + if (itemType is null) + { + reader.Skip(); + } + else + { + if (itemType == typeof(RawResult)) + { + result = new RawResult(reader.ReadRaw()); + } + else + { + try + { + result = DeserializeObject(ref reader, itemType, "argument"); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {itemType.Name}. {ex.Message}"; + hasResult = false; + } + } + } + break; + case VoidResult: + hasResult = false; + break; + default: + throw new InvalidDataException("Invalid invocation result kind."); + } + + return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult)); + } + + private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader) + { + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + ThrowIfNullOrEmpty(invocationId, "invocation ID for CancelInvocation message"); + + return ApplyHeaders(headers, new CancelInvocationMessage(invocationId)); + } + + private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount) + { + var error = ReadString(ref reader, "error"); + var allowReconnect = false; + + if (itemCount > 2) + { + allowReconnect = ReadBoolean(ref reader, "allowReconnect"); + } + + // An empty string is still an error + if (error == null && !allowReconnect) + { + return CloseMessage.Empty; + } + + return new CloseMessage(error, allowReconnect); + } + + private static Dictionary? ReadHeaders(ref MessagePackReader reader) + { + var headerCount = ReadMapLength(ref reader, "headers"); + if (headerCount > 0) + { + var headers = new Dictionary(StringComparer.Ordinal); + + for (var i = 0; i < headerCount; i++) + { + var key = ReadString(ref reader, $"headers[{i}].Key"); + ThrowIfNullOrEmpty(key, "key in header"); + + var value = ReadString(ref reader, $"headers[{i}].Value"); + ThrowIfNullOrEmpty(value, "value in header"); + + headers.Add(key, value); + } + return headers; + } + else + { + return null; + } + } + + private static string[]? ReadStreamIds(ref MessagePackReader reader) + { + var streamIdCount = ReadArrayLength(ref reader, "streamIds"); + List? streams = null; + + if (streamIdCount > 0) + { + streams = new List(); + for (var i = 0; i < streamIdCount; i++) + { + var id = reader.ReadString(); + ThrowIfNullOrEmpty(id, "value in streamIds received"); + + streams.Add(id); + } + } + + return streams?.ToArray(); + } + + private static AckMessage CreateAckMessage(ref MessagePackReader reader) + { + return new AckMessage(ReadInt64(ref reader, "sequenceId")); + } + + private static SequenceMessage CreateSequenceMessage(ref MessagePackReader reader) + { + return new SequenceMessage(ReadInt64(ref reader, "sequenceId")); + } + + private object?[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes) + { + var argumentCount = ReadArrayLength(ref reader, "arguments"); + + if (parameterTypes.Count != argumentCount) + { + throw new InvalidDataException( + $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}."); + } + + try + { + var arguments = new object?[argumentCount]; + for (var i = 0; i < argumentCount; i++) + { + arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument"); + } + + return arguments; + } + catch (Exception ex) + { + throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex); + } + } + + protected abstract object? DeserializeObject(ref MessagePackReader reader, Type type, string field); + + private static T ApplyHeaders(IDictionary? source, T destination) where T : HubInvocationMessage + { + if (source != null && source.Count > 0) + { + destination.Headers = source; + } + + return destination; + } + + /// + public void WriteMessage(HubMessage message, IBufferWriter output) + { + var memoryBufferWriter = MemoryBufferWriter.Get(); + + try + { + var writer = new MessagePackWriter(memoryBufferWriter); + + // Write message to a buffer so we can get its length + WriteMessageCore(message, ref writer); + + // Write length then message to output + BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output); + memoryBufferWriter.CopyTo(output); + } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + + /// + public ReadOnlyMemory GetMessageBytes(HubMessage message) + { + var memoryBufferWriter = MemoryBufferWriter.Get(); + + try + { + var writer = new MessagePackWriter(memoryBufferWriter); + + // Write message to a buffer so we can get its length + WriteMessageCore(message, ref writer); + + var dataLength = memoryBufferWriter.Length; + var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length); + + var array = new byte[dataLength + prefixLength]; + var span = array.AsSpan(); + + // Write length then message to output + var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span); + Debug.Assert(written == prefixLength); + memoryBufferWriter.CopyTo(span.Slice(prefixLength)); + + return array; + } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + + private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer) + { + switch (message) + { + case InvocationMessage invocationMessage: + WriteInvocationMessage(invocationMessage, ref writer); + break; + case StreamInvocationMessage streamInvocationMessage: + WriteStreamInvocationMessage(streamInvocationMessage, ref writer); + break; + case StreamItemMessage streamItemMessage: + WriteStreamingItemMessage(streamItemMessage, ref writer); + break; + case CompletionMessage completionMessage: + WriteCompletionMessage(completionMessage, ref writer); + break; + case CancelInvocationMessage cancelInvocationMessage: + WriteCancelInvocationMessage(cancelInvocationMessage, ref writer); + break; + case PingMessage: + WritePingMessage(ref writer); + break; + case CloseMessage closeMessage: + WriteCloseMessage(closeMessage, ref writer); + break; + case AckMessage ackMessage: + WriteAckMessage(ackMessage, ref writer); + break; + case SequenceMessage sequenceMessage: + WriteSequenceMessage(sequenceMessage, ref writer); + break; + default: + throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); + } + + writer.Flush(); + } + + private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(6); + + writer.Write(HubProtocolConstants.InvocationMessageType); + PackHeaders(message.Headers, ref writer); + if (string.IsNullOrEmpty(message.InvocationId)) + { + writer.WriteNil(); + } + else + { + writer.Write(message.InvocationId); + } + writer.Write(message.Target); + + if (message.Arguments is null) + { + writer.WriteArrayHeader(0); + } + else + { + writer.WriteArrayHeader(message.Arguments.Length); + foreach (var arg in message.Arguments) + { + WriteArgument(arg, ref writer); + } + } + + WriteStreamIds(message.StreamIds, ref writer); + } + + private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(6); + + writer.Write(HubProtocolConstants.StreamInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(message.Target); + + if (message.Arguments is null) + { + writer.WriteArrayHeader(0); + } + else + { + writer.WriteArrayHeader(message.Arguments.Length); + foreach (var arg in message.Arguments) + { + WriteArgument(arg, ref writer); + } + } + + WriteStreamIds(message.StreamIds, ref writer); + } + + private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(4); + writer.Write(HubProtocolConstants.StreamItemMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + WriteArgument(message.Item, ref writer); + } + + private void WriteArgument(object? argument, ref MessagePackWriter writer) + { + if (argument == null) + { + writer.WriteNil(); + } + else if (argument is RawResult result) + { + writer.WriteRaw(result.RawSerializedData); + } + else + { + Serialize(ref writer, argument.GetType(), argument); + } + } + + protected abstract void Serialize(ref MessagePackWriter writer, Type type, object value); + + private static void WriteStreamIds(string[]? streamIds, ref MessagePackWriter writer) + { + if (streamIds != null) + { + writer.WriteArrayHeader(streamIds.Length); + foreach (var streamId in streamIds) + { + writer.Write(streamId); + } + } + else + { + writer.WriteArrayHeader(0); + } + } + + private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer) + { + var resultKind = + message.Error != null ? ErrorResult : + message.HasResult ? NonVoidResult : + VoidResult; + + writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0)); + writer.Write(HubProtocolConstants.CompletionMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(resultKind); + switch (resultKind) + { + case ErrorResult: + writer.Write(message.Error); + break; + case NonVoidResult: + WriteArgument(message.Result, ref writer); + break; + } + } + + private static void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CancelInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + } + + private static void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CloseMessageType); + if (string.IsNullOrEmpty(message.Error)) + { + writer.WriteNil(); + } + else + { + writer.Write(message.Error); + } + + writer.Write(message.AllowReconnect); + } + + private static void WritePingMessage(ref MessagePackWriter writer) + { + writer.WriteArrayHeader(1); + writer.Write(HubProtocolConstants.PingMessageType); + } + + private static void WriteAckMessage(AckMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(2); + writer.Write(HubProtocolConstants.AckMessageType); + writer.Write(message.SequenceId); + } + + private static void WriteSequenceMessage(SequenceMessage message, ref MessagePackWriter writer) + { + writer.WriteArrayHeader(2); + writer.Write(HubProtocolConstants.SequenceMessageType); + writer.Write(message.SequenceId); + } + + private static void PackHeaders(IDictionary? headers, ref MessagePackWriter writer) + { + if (headers != null) + { + writer.WriteMapHeader(headers.Count); + if (headers.Count > 0) + { + foreach (var header in headers) + { + writer.Write(header.Key); + writer.Write(header.Value); + } + } + } + else + { + writer.WriteMapHeader(0); + } + } + + private static string? ReadInvocationId(ref MessagePackReader reader) => + ReadString(ref reader, "invocationId"); + + private static bool ReadBoolean(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadBoolean(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); + } + } + + private static int ReadInt32(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadInt32(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); + } + } + + private static long ReadInt64(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadInt64(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as Int64 failed.", ex); + } + } + + protected static string? ReadString(ref MessagePackReader reader, IInvocationBinder binder, string field) + { + try + { +#if NETCOREAPP + if (reader.TryReadStringSpan(out var span)) + { + return binder.GetTarget(span) ?? Encoding.UTF8.GetString(span); + } + return reader.ReadString(); +#else + return reader.ReadString(); +#endif + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); + } + } + + protected static string? ReadString(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadString(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); + } + } + + private static long ReadMapLength(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadMapHeader(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading map length for '{field}' failed.", ex); + } + } + + private static long ReadArrayLength(ref MessagePackReader reader, string field) + { + try + { + return reader.ReadArrayHeader(); + } + catch (Exception ex) + { + throw new InvalidDataException($"Reading array length for '{field}' failed.", ex); + } + } + + private static void ThrowIfNullOrEmpty([NotNull] string? target, string message) + { + if (string.IsNullOrEmpty(target)) + { + throw new InvalidDataException($"Null or empty {message}."); + } + } +} \ No newline at end of file