From ed1086d3853aaeb762d38548a4fd61b789b3691f Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Tue, 11 Jul 2023 10:15:24 +0800 Subject: [PATCH] Add ServerEndpoint support for Azure AD scenario (#1805) * Add ServerEndpoint support for Azure AD scenario * add threadsafe lock --- .../Endpoints/AadAccessKey.cs | 10 +-- .../Endpoints/AccessKey.cs | 5 +- .../Endpoints/ServiceEndpoint.cs | 41 ++++++++--- .../Endpoints/UriExtensions.cs | 16 +++++ .../Utilities/ConnectionStringParser.cs | 68 +++++++++---------- .../Auth/AadAccessKeyTests.cs | 1 - .../Auth/ConnectionStringParserTests.cs | 25 +++++-- .../ServiceEndpointFacts.cs | 44 +++++++++--- 8 files changed, 143 insertions(+), 67 deletions(-) create mode 100644 src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs index 58dc7d21f..76b982875 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs @@ -61,14 +61,10 @@ private set private Task InitializedTask => _initializedTcs.Task; - public AadAccessKey(Uri uri, TokenCredential credential) : base(uri) + public AadAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint) { - var builder = new UriBuilder(Endpoint) - { - Path = "/api/v1/auth/accessKey", - Port = uri.Port - }; - AuthorizeUrl = builder.Uri.AbsoluteUri; + var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey"); + AuthorizeUrl = authorizeUri.AbsoluteUri; TokenCredential = credential; } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs index 5e0ab8b6d..b16d4bff6 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs @@ -12,12 +12,13 @@ namespace Microsoft.Azure.SignalR internal class AccessKey { public string Id => Key?.Item1; - public string Value => Key?.Item2; - protected Tuple Key { get; set; } + public string Value => Key?.Item2; public Uri Endpoint { get; } + protected Tuple Key { get; set; } + public AccessKey(string uri, string key) : this(new Uri(uri)) { Key = new Tuple(key.GetHashCode().ToString(), key); diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs index b0d67456e..d9588ecb2 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs @@ -10,9 +10,17 @@ namespace Microsoft.Azure.SignalR public class ServiceEndpoint { private readonly Uri _serviceEndpoint; + private readonly Uri _serverEndpoint; + private readonly Uri _clientEndpoint; + private readonly TokenCredential _tokenCredential; + + private readonly object _lock = new object(); + + private volatile AccessKey _accessKey; + public string ConnectionString { get; } public EndpointType EndpointType { get; } = EndpointType.Primary; @@ -42,6 +50,7 @@ public Uri ClientEndpoint _clientEndpoint = value; } } + /// /// When current app server instance has server connections connected to the target endpoint for current hub, it can deliver messages to that endpoint. /// The endpoint is then considered as *Online*; otherwise, *Offline*. @@ -69,7 +78,21 @@ public Uri ClientEndpoint internal string Version { get; } - internal AccessKey AccessKey { get; private set; } + internal AccessKey AccessKey + { + get + { + if (_accessKey is null) + { + lock (_lock) + { + _accessKey ??= new AadAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint); + } + } + return _accessKey; + } + private init => _accessKey = value; + } // Flag to indicate an updaing endpoint needs staging internal virtual bool PendingReload { get; set; } @@ -132,16 +155,18 @@ public ServiceEndpoint(string nameWithEndpointType, Uri endpoint, TokenCredentia /// The endpoint name. /// The endpoint for servers to connect to Azure SignalR. /// The endpoint for clients to connect to Azure SignalR. - public ServiceEndpoint(Uri endpoint, TokenCredential credential, EndpointType endpointType = EndpointType.Primary, string name = "", - Uri serverEndpoint = null, Uri clientEndpoint = null) + public ServiceEndpoint(Uri endpoint, + TokenCredential credential, + EndpointType endpointType = EndpointType.Primary, + string name = "", + Uri serverEndpoint = null, + Uri clientEndpoint = null) { _serviceEndpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); CheckScheme(endpoint); - if (credential is null) - { - throw new ArgumentNullException(nameof(credential)); - } - AccessKey = new AadAccessKey(endpoint, credential); + + _tokenCredential = credential ?? throw new ArgumentNullException(nameof(credential)); + EndpointType = endpointType; Name = name; diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs new file mode 100644 index 000000000..dd05e87cc --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; + +namespace Microsoft.Azure.SignalR +{ + internal static class UriExtensions + { + public static Uri Append(this Uri uri, params string[] paths) + { + return new Uri(paths.Aggregate(uri.AbsoluteUri, (current, path) => string.Format("{0}/{1}", current.TrimEnd('/'), path.TrimStart('/')))); + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs index 1e713f0a9..c9c75f32c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs @@ -24,12 +24,12 @@ internal static class ConnectionStringParser private const string EndpointProperty = "endpoint"; + private const string ServerEndpointProperty = "ServerEndpoint"; + private const string InvalidVersionValueFormat = "Version {0} is not supported."; private const string PortProperty = "port"; - private const string ServerEndpoint = "ServerEndpoint"; - // For SDK 1.x, only support Azure SignalR Service 1.x private const string SupportedVersion = "1"; @@ -114,6 +114,7 @@ internal static ParsedConnectionString Parse(string connectionString) } Uri clientEndpointUri = null; + Uri serverEndpointUri = null; // parse and validate clientEndpoint. if (dict.TryGetValue(ClientEndpointProperty, out var clientEndpoint)) @@ -124,25 +125,26 @@ internal static ParsedConnectionString Parse(string connectionString) } } + // parse and validate clientEndpoint. + if (dict.TryGetValue(ServerEndpointProperty, out var serverEndpoint)) + { + if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri)) + { + throw new ArgumentException($"{ServerEndpointProperty} property in connection string is not a valid URI: {serverEndpoint}."); + } + } + // try building accesskey. dict.TryGetValue(AuthTypeProperty, out var type); var accessKey = type?.ToLower() switch { - TypeAzureAD => BuildAadAccessKey(builder.Uri, dict), - TypeAzure => BuildAzureAccessKey(builder.Uri, dict), - TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, dict), - TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, dict), + TypeAzureAD => BuildAzureADAccessKey(builder.Uri, serverEndpointUri, dict), + TypeAzure => BuildAzureAccessKey(builder.Uri, serverEndpointUri, dict), + TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, serverEndpointUri, dict), + TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, serverEndpointUri, dict), _ => BuildAccessKey(builder.Uri, dict), }; - Uri serverEndpointUri = null; - if (dict.TryGetValue(ServerEndpoint, out var serverEndpoint)) - { - if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri)) - { - throw new ArgumentException($"{ServerEndpoint} property in connection string is not a valid URI: {serverEndpoint}."); - } - } return new ParsedConnectionString() { Endpoint = builder.Uri, @@ -159,7 +161,7 @@ internal static bool TryGetEndpointUri(string endpoint, out Uri uriResult) (uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps); } - private static AccessKey BuildAadAccessKey(Uri uri, Dictionary dict) + private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { if (dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -167,11 +169,11 @@ private static AccessKey BuildAadAccessKey(Uri uri, Dictionary d { if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret)); + return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath)); + return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); } else { @@ -180,30 +182,28 @@ private static AccessKey BuildAadAccessKey(Uri uri, Dictionary d } else { - return new AadAccessKey(uri, new ManagedIdentityCredential(clientId)); + return new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); } } else { - return new AadAccessKey(uri, new ManagedIdentityCredential()); + return new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); } } private static AccessKey BuildAccessKey(Uri uri, Dictionary dict) { - if (dict.TryGetValue(AccessKeyProperty, out var key)) - { - return new AccessKey(uri, key); - } - throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty); + return dict.TryGetValue(AccessKeyProperty, out var key) + ? new AccessKey(uri, key) + : throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty); } - private static AccessKey BuildAzureAccessKey(Uri uri, Dictionary dict) + private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { - return new AadAccessKey(uri, new DefaultAzureCredential()); + return new AadAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri); } - private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary dict) + private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { if (!dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -217,22 +217,20 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary dict) + private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { - if (dict.TryGetValue(ClientIdProperty, out var clientId)) - { - return new AadAccessKey(uri, new ManagedIdentityCredential(clientId)); - } - return new AadAccessKey(uri, new ManagedIdentityCredential()); + return dict.TryGetValue(ClientIdProperty, out var clientId) + ? new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) + : new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); } private static Dictionary ToDictionary(string connectionString) diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs index bb0e0593e..d10eed434 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs @@ -3,7 +3,6 @@ using System.Security.Claims; using System.Threading; using System.Threading.Tasks; - using Azure.Identity; using Xunit; diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index d7f5d2670..4bb555f13 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -101,8 +101,8 @@ public void TestAzureApplication(string connectionString) { var r = ConnectionStringParser.Parse(connectionString); - var aadAccessKey = Assert.IsType(r.AccessKey); - Assert.IsType(aadAccessKey.TokenCredential); + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.Version); Assert.Null(r.ClientEndpoint); @@ -148,8 +148,8 @@ internal void TestDefaultAzureCredential(string expectedEndpoint, string connect var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var aadAccessKey = Assert.IsType(r.AccessKey); - Assert.IsType(aadAccessKey.TokenCredential); + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); } @@ -165,12 +165,25 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var aadAccessKey = Assert.IsType(r.AccessKey); - Assert.IsType(aadAccessKey.TokenCredential); + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.ClientEndpoint); } + [Theory] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar/", "https://foo/bar/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")] + internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) + { + var r = ConnectionStringParser.Parse(connectionString); + var key = Assert.IsType(r.AccessKey); + Assert.Equal(expectedAuthorizeUrl, key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + } + public class ClientEndpointTestData : IEnumerable { public IEnumerator GetEnumerator() diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs index ce6039bb3..dd89383d1 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs @@ -65,11 +65,11 @@ public void TestCustomizeEndpointInConstructor() var serverEndpoint = new Uri("http://serverEndpoint:123/path"); var endpoint = "https://test.service.signalr.net"; var serviceEndpoints = new ServiceEndpoint[]{ - new ServiceEndpoint(new Uri(endpoint), new DefaultAzureCredential()) - { - ClientEndpoint = clientEndpoint, - ServerEndpoint = serverEndpoint - }, + new ServiceEndpoint(new Uri(endpoint), new DefaultAzureCredential()) + { + ClientEndpoint = clientEndpoint, + ServerEndpoint = serverEndpoint + }, new ServiceEndpoint($"Endpoint={endpoint};AccessKey={DefaultKey}") { ClientEndpoint = clientEndpoint, @@ -113,7 +113,7 @@ public void TestCreateServiceEndpointFromAnother() [InlineData("http://localhost/", "http://localhost", 80)] [InlineData("http://localhost/foo", "http://localhost", 80)] [InlineData("https://localhost/foo/", "https://localhost", 443)] - public void TestAadConstructor(string url, string expectedEndpoint, int port) + public void TestAzureADConstructor(string url, string expectedEndpoint, int port) { var uri = new Uri(url); var serviceEndpoint = new ServiceEndpoint(uri, new DefaultAzureCredential()); @@ -129,7 +129,7 @@ public void TestAadConstructor(string url, string expectedEndpoint, int port) [InlineData("ftp://localhost")] [InlineData("ws://localhost")] [InlineData("localhost:5050")] - public void TestAadConstructorThrowsError(string url) + public void TestAzureADConstructorThrowsError(string url) { var uri = new Uri(url); Assert.Throws(() => new ServiceEndpoint(uri, new DefaultAzureCredential())); @@ -146,7 +146,7 @@ public void TestAadConstructorThrowsError(string url) [InlineData(":bar", ":bar", EndpointType.Primary)] [InlineData(":primary", "", EndpointType.Primary)] [InlineData(":secondary", "", EndpointType.Secondary)] - public void TestAadConstructorWithKey(string key, string name, EndpointType type) + public void TestAzureADConstructorWithKey(string key, string name, EndpointType type) { var uri = new Uri("http://localhost"); var serviceEndpoint = new ServiceEndpoint(key, uri, new DefaultAzureCredential()); @@ -156,6 +156,34 @@ public void TestAadConstructorWithKey(string key, string name, EndpointType type TestCopyConstructor(serviceEndpoint); } + [Fact] + public void TestAzureADConstructorWithServerEndpoint() + { + var serverEndpoint1 = new Uri("http://serverEndpoint:123"); + var serverEndpoint2 = new Uri("http://serverEndpoint:123/path"); + var serviceEndpoint = "https://test.service.signalr.net"; + var endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential()) + { + ServerEndpoint = serverEndpoint1 + }; + var key = Assert.IsType(endpoint.AccessKey); + Assert.Same(key, endpoint.AccessKey); + Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + + endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint2); + key = Assert.IsType(endpoint.AccessKey); + Assert.Same(key, endpoint.AccessKey); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + + endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint1) + { + ServerEndpoint = serverEndpoint2 // property initialize should override constructor param. + }; + key = Assert.IsType(endpoint.AccessKey); + Assert.Same(key, endpoint.AccessKey); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + } + [Theory] [ClassData(typeof(EndpointEqualityTestData))] public void TestEndpointsEquality(ServiceEndpoint first, ServiceEndpoint second, bool expected)