Skip to content

Commit

Permalink
Merge branch 'dev' into v1.19.1
Browse files Browse the repository at this point in the history
  • Loading branch information
JialinXin committed Nov 21, 2022
2 parents 3470d1f + 0800589 commit 8d74734
Show file tree
Hide file tree
Showing 38 changed files with 489 additions and 243 deletions.
7 changes: 7 additions & 0 deletions AzureSignalR.sln
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ChatSample.Net60", "samples
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ChatSample.Net70", "samples\ChatSample\ChatSample.Net70\ChatSample.Net70.csproj", "{49634EE4-A0F4-4672-A8B3-B994CF81C9AB}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ManagementPublisher", "samples\ChatSample\ChatSample.ManagementPublisher\ManagementPublisher.csproj", "{0F32E624-7AC8-4CA7-8ED9-E1A877442020}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -196,6 +198,10 @@ Global
{49634EE4-A0F4-4672-A8B3-B994CF81C9AB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{49634EE4-A0F4-4672-A8B3-B994CF81C9AB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{49634EE4-A0F4-4672-A8B3-B994CF81C9AB}.Release|Any CPU.Build.0 = Release|Any CPU
{0F32E624-7AC8-4CA7-8ED9-E1A877442020}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{0F32E624-7AC8-4CA7-8ED9-E1A877442020}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0F32E624-7AC8-4CA7-8ED9-E1A877442020}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0F32E624-7AC8-4CA7-8ED9-E1A877442020}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -230,6 +236,7 @@ Global
{82C1FF3D-EC6C-4B21-B6A4-E69E8D75D0D0} = {2429FBD8-1FCE-4C42-AA28-DF32F7249E77}
{594EC59A-7305-4A36-8BE6-4A928FBFD71B} = {C965ED06-6A17-4329-B3C6-811830F4F4ED}
{49634EE4-A0F4-4672-A8B3-B994CF81C9AB} = {C965ED06-6A17-4329-B3C6-811830F4F4ED}
{0F32E624-7AC8-4CA7-8ED9-E1A877442020} = {C965ED06-6A17-4329-B3C6-811830F4F4ED}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {7945A4E4-ACDB-4F6E-95CA-6AC6E7C2CD59}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.Azure.SignalR.Management\Microsoft.Azure.SignalR.Management.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.Azure.SignalR.Management;

// A simple library file to test cross project reference issue: https://github.com/Azure/azure-signalr/issues/1720
namespace ManagementPublisher
{
internal class MessagePublisher
{
private const string Target = "Target";
private const string HubName = "Chat";
private ServiceHubContext? _hubContext;

public async Task InitAsync(string connectionString, ServiceTransportType transportType = ServiceTransportType.Transient)
{
var serviceManager = new ServiceManagerBuilder().WithOptions(option =>
{
option.ConnectionString = connectionString;
option.ServiceTransportType = transportType;
})
// Uncomment the following line to get more logs
//.WithLoggerFactory(LoggerFactory.Create(builder => builder.AddConsole()))
.BuildServiceManager();

_hubContext = await serviceManager.CreateHubContextAsync(HubName, default);
}

public Task SendMessages(string command, string? receiver, string message)
{
if (_hubContext == null)
{
throw new ArgumentNullException(nameof(_hubContext));
}
switch (command)
{
case "broadcast":
return _hubContext.Clients.All.SendCoreAsync(Target, new[] { message });
case "user":
var userId = receiver ?? throw new ArgumentNullException(nameof(receiver));
return _hubContext.Clients.User(userId).SendCoreAsync(Target, new[] { message });
case "group":
var groupName = receiver ?? throw new ArgumentNullException(nameof(receiver));
return _hubContext.Clients.Group(groupName).SendCoreAsync(Target, new[] { message });
default:
Console.WriteLine($"Can't recognize command {command}");
return Task.CompletedTask;
}
}

public Task CloseConnection(string connectionId, string reason)
{
if (_hubContext == null)
{
throw new ArgumentNullException(nameof(_hubContext));
}
return _hubContext.ClientManager.CloseConnectionAsync(connectionId, reason);
}

public Task<bool> CheckExist(string type, string id)
{
if (_hubContext == null)
{
throw new ArgumentNullException(nameof(_hubContext));
}
return type switch
{
"connection" => _hubContext.ClientManager.ConnectionExistsAsync(id),
"user" => _hubContext.ClientManager.UserExistsAsync(id),
"group" => _hubContext.ClientManager.UserExistsAsync(id),
_ => throw new NotSupportedException(),
};
}

public async Task DisposeAsync()
{
if (_hubContext != null)
{
await _hubContext.DisposeAsync();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,26 @@ internal class ServiceEndpointProvider : IServiceEndpointProvider
"or explicitly pass one using IAppBuilder.RunAzureSignalR(connectionString) in Startup.ConfigureServices.";

private const string ClientPath = "aspnetclient";

private const string ServerPath = "aspnetserver";

private readonly string _audienceBaseUrl;

private readonly string _clientEndpoint;

private readonly string _serverEndpoint;

private readonly AccessKey _accessKey;

private readonly string _appName;

private readonly TimeSpan _accessTokenLifetime;

private readonly AccessTokenAlgorithm _algorithm;

public IWebProxy Proxy { get; }

public ServiceEndpointProvider(
ServiceEndpoint endpoint,
ServiceOptions options)
public ServiceEndpointProvider(ServiceEndpoint endpoint, ServiceOptions options)
{
_accessTokenLifetime = options.AccessTokenLifetime;

Expand All @@ -48,39 +52,13 @@ public ServiceEndpointProvider(
Proxy = options.Proxy;
}

private string GetPrefixedHubName(string applicationName, string hubName)
{
return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}";
}

public Task<string> GenerateClientAccessTokenAsync(string hubName = null, IEnumerable<Claim> claims = null, TimeSpan? lifetime = null)
{
var audience = $"{_audienceBaseUrl}{ClientPath}";

return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
{
if (_accessKey is AadAccessKey key)
{
return key.GenerateAadTokenAsync();
}

IEnumerable<Claim> claims = null;
if (userId != null)
{
claims = new[]
{
new Claim(ClaimTypes.NameIdentifier, userId)
};
}

var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";

return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null)
{
var queryBuilder = new StringBuilder();
Expand Down Expand Up @@ -114,5 +92,28 @@ public string GetServerEndpoint(string hubName)
{
return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
}

public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
if (_accessKey is AadAccessKey aadAccessKey)
{
return new AadTokenProvider(aadAccessKey);
}
else if (_accessKey is not null)
{
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
}
else
{
throw new ArgumentNullException(nameof(AccessKey));
}
}

private string GetPrefixedHubName(string applicationName, string hubName)
{
return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}";
}
}
}
20 changes: 20 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Auth/AadTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal class AadTokenProvider : IAccessTokenProvider
{
private readonly AadAccessKey _accessKey;

public AadTokenProvider(AadAccessKey accessKey)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
}

public Task<string> ProvideAsync() => _accessKey.GenerateAadTokenAsync();
}
}
39 changes: 39 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal class LocalTokenProvider : IAccessTokenProvider
{
private readonly AccessKey _accessKey;

private readonly AccessTokenAlgorithm _algorithm;

private readonly string _audience;

private readonly TimeSpan _tokenLifetime;

private readonly IEnumerable<Claim> _claims;

public LocalTokenProvider(
AccessKey accessKey,
string audience,
IEnumerable<Claim> claims,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256,
TimeSpan? tokenLifetime = null)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
_algorithm = algorithm;
_audience = audience;
_claims = claims;
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
}

public Task<string> ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal interface IAccessTokenProvider
{
Task<string> ProvideAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal interface IServiceEndpointProvider

string GetClientEndpoint(string hubName, string originalPath, string queryString);

Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null);
IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId);

string GetServerEndpoint(string hubName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ public partial class SignalRServiceRestClient
{
private readonly string _userAgent;

public SignalRServiceRestClient(string userAgent, ServiceClientCredentials credentials, HttpClient httpClient, bool disposeHttpClient) : this(credentials, httpClient, disposeHttpClient)
public SignalRServiceRestClient(string userAgent,
ServiceClientCredentials credentials,
HttpClient httpClient,
bool disposeHttpClient) : this(credentials, httpClient, disposeHttpClient)
{
if (string.IsNullOrWhiteSpace(userAgent))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace Microsoft.Azure.SignalR
internal class ConnectionFactory : IConnectionFactory
{
private readonly ILoggerFactory _loggerFactory;

private readonly string _serverId;

public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory)
Expand All @@ -31,7 +32,9 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint hubServiceE
{
var provider = hubServiceEndpoint.Provider;
var hubName = hubServiceEndpoint.Hub;
Task<string> accessTokenGenerater() => provider.GenerateServerAccessTokenAsync(hubName, _serverId);

var accessTokenProvider = provider.GetServerAccessTokenProvider(hubName, _serverId);

var url = GetServiceUrl(provider, hubName, connectionId, target);

headers ??= new Dictionary<string, string>();
Expand All @@ -45,7 +48,7 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint hubServiceE
Headers = headers,
Proxy = provider.Proxy,
};
var connection = new WebSocketConnectionContext(connectionOptions, _loggerFactory, accessTokenGenerater);
var connection = new WebSocketConnectionContext(connectionOptions, _loggerFactory, accessTokenProvider);
try
{
await connection.StartAsync(url, cancellationToken);
Expand Down Expand Up @@ -91,6 +94,7 @@ private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, str
private sealed class GracefulLoggerFactory : ILoggerFactory
{
private readonly ILoggerFactory _inner;

public GracefulLoggerFactory(ILoggerFactory inner)
{
_inner = inner;
Expand All @@ -115,6 +119,7 @@ public void AddProvider(ILoggerProvider provider)
private sealed class GracefulLogger : ILogger
{
private readonly ILogger _inner;

public GracefulLogger(ILogger inner)
{
_inner = inner;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ internal partial class WebSocketsTransport : IDuplexPipe

private readonly WebSocketMessageType _webSocketMessageType = WebSocketMessageType.Binary;
private readonly ClientWebSocket _webSocket;
private readonly Func<Task<string>> _accessTokenProvider;
private readonly IAccessTokenProvider _accessTokenProvider;
private IDuplexPipe _application;
private readonly ILogger _logger;
private readonly TimeSpan _closeTimeout;
Expand All @@ -37,7 +37,9 @@ internal partial class WebSocketsTransport : IDuplexPipe

public PipeWriter Output => _transport.Output;

public WebSocketsTransport(WebSocketConnectionOptions connectionOptions, ILoggerFactory loggerFactory, Func<Task<string>> accessTokenProvider)
public WebSocketsTransport(WebSocketConnectionOptions connectionOptions,
ILoggerFactory loggerFactory,
IAccessTokenProvider accessTokenProvider)
{
_logger = (loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory))).CreateLogger<WebSocketsTransport>();
_webSocket = new ClientWebSocket();
Expand Down Expand Up @@ -105,7 +107,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa
// We don't need to capture to a local because we never change this delegate.
if (_accessTokenProvider != null)
{
var accessToken = await _accessTokenProvider();
var accessToken = await _accessTokenProvider.ProvideAsync();
if (!string.IsNullOrEmpty(accessToken))
{
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
Expand Down
Loading

0 comments on commit 8d74734

Please sign in to comment.