Skip to content

Commit

Permalink
Add ServerEndpoint support for Azure AD scenario (#1805)
Browse files Browse the repository at this point in the history
* Add ServerEndpoint support for Azure AD scenario

* add threadsafe lock
  • Loading branch information
terencefan authored Jul 11, 2023
1 parent 1d1bc29 commit ed1086d
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 67 deletions.
10 changes: 3 additions & 7 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,10 @@ private set

private Task<object> InitializedTask => _initializedTcs.Task;

public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)
public AadAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint)
{
var builder = new UriBuilder(Endpoint)
{
Path = "/api/v1/auth/accessKey",
Port = uri.Port
};
AuthorizeUrl = builder.Uri.AbsoluteUri;
var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey");
AuthorizeUrl = authorizeUri.AbsoluteUri;
TokenCredential = credential;
}

Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ namespace Microsoft.Azure.SignalR
internal class AccessKey
{
public string Id => Key?.Item1;
public string Value => Key?.Item2;

protected Tuple<string, string> Key { get; set; }
public string Value => Key?.Item2;

public Uri Endpoint { get; }

protected Tuple<string, string> Key { get; set; }

public AccessKey(string uri, string key) : this(new Uri(uri))
{
Key = new Tuple<string, string>(key.GetHashCode().ToString(), key);
Expand Down
41 changes: 33 additions & 8 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ namespace Microsoft.Azure.SignalR
public class ServiceEndpoint
{
private readonly Uri _serviceEndpoint;

private readonly Uri _serverEndpoint;

private readonly Uri _clientEndpoint;

private readonly TokenCredential _tokenCredential;

private readonly object _lock = new object();

private volatile AccessKey _accessKey;

public string ConnectionString { get; }

public EndpointType EndpointType { get; } = EndpointType.Primary;
Expand Down Expand Up @@ -42,6 +50,7 @@ public Uri ClientEndpoint
_clientEndpoint = value;
}
}

/// <summary>
/// When current app server instance has server connections connected to the target endpoint for current hub, it can deliver messages to that endpoint.
/// The endpoint is then considered as *Online*; otherwise, *Offline*.
Expand Down Expand Up @@ -69,7 +78,21 @@ public Uri ClientEndpoint

internal string Version { get; }

internal AccessKey AccessKey { get; private set; }
internal AccessKey AccessKey
{
get
{
if (_accessKey is null)
{
lock (_lock)
{
_accessKey ??= new AadAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint);
}
}
return _accessKey;
}
private init => _accessKey = value;
}

// Flag to indicate an updaing endpoint needs staging
internal virtual bool PendingReload { get; set; }
Expand Down Expand Up @@ -132,16 +155,18 @@ public ServiceEndpoint(string nameWithEndpointType, Uri endpoint, TokenCredentia
/// <param name="name">The endpoint name.</param>
/// <param name="serverEndpoint">The endpoint for servers to connect to Azure SignalR.</param>
/// <param name="clientEndpoint">The endpoint for clients to connect to Azure SignalR.</param>
public ServiceEndpoint(Uri endpoint, TokenCredential credential, EndpointType endpointType = EndpointType.Primary, string name = "",
Uri serverEndpoint = null, Uri clientEndpoint = null)
public ServiceEndpoint(Uri endpoint,
TokenCredential credential,
EndpointType endpointType = EndpointType.Primary,
string name = "",
Uri serverEndpoint = null,
Uri clientEndpoint = null)
{
_serviceEndpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
CheckScheme(endpoint);
if (credential is null)
{
throw new ArgumentNullException(nameof(credential));
}
AccessKey = new AadAccessKey(endpoint, credential);

_tokenCredential = credential ?? throw new ArgumentNullException(nameof(credential));

EndpointType = endpointType;
Name = name;

Expand Down
16 changes: 16 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;

namespace Microsoft.Azure.SignalR
{
internal static class UriExtensions
{
public static Uri Append(this Uri uri, params string[] paths)
{
return new Uri(paths.Aggregate(uri.AbsoluteUri, (current, path) => string.Format("{0}/{1}", current.TrimEnd('/'), path.TrimStart('/'))));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ internal static class ConnectionStringParser

private const string EndpointProperty = "endpoint";

private const string ServerEndpointProperty = "ServerEndpoint";

private const string InvalidVersionValueFormat = "Version {0} is not supported.";

private const string PortProperty = "port";

private const string ServerEndpoint = "ServerEndpoint";

// For SDK 1.x, only support Azure SignalR Service 1.x
private const string SupportedVersion = "1";

Expand Down Expand Up @@ -114,6 +114,7 @@ internal static ParsedConnectionString Parse(string connectionString)
}

Uri clientEndpointUri = null;
Uri serverEndpointUri = null;

// parse and validate clientEndpoint.
if (dict.TryGetValue(ClientEndpointProperty, out var clientEndpoint))
Expand All @@ -124,25 +125,26 @@ internal static ParsedConnectionString Parse(string connectionString)
}
}

// parse and validate clientEndpoint.
if (dict.TryGetValue(ServerEndpointProperty, out var serverEndpoint))
{
if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri))
{
throw new ArgumentException($"{ServerEndpointProperty} property in connection string is not a valid URI: {serverEndpoint}.");
}
}

// try building accesskey.
dict.TryGetValue(AuthTypeProperty, out var type);
var accessKey = type?.ToLower() switch
{
TypeAzureAD => BuildAadAccessKey(builder.Uri, dict),
TypeAzure => BuildAzureAccessKey(builder.Uri, dict),
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, dict),
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, dict),
TypeAzureAD => BuildAzureADAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzure => BuildAzureAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, serverEndpointUri, dict),
_ => BuildAccessKey(builder.Uri, dict),
};

Uri serverEndpointUri = null;
if (dict.TryGetValue(ServerEndpoint, out var serverEndpoint))
{
if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri))
{
throw new ArgumentException($"{ServerEndpoint} property in connection string is not a valid URI: {serverEndpoint}.");
}
}
return new ParsedConnectionString()
{
Endpoint = builder.Uri,
Expand All @@ -159,19 +161,19 @@ internal static bool TryGetEndpointUri(string endpoint, out Uri uriResult)
(uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps);
}

private static AccessKey BuildAadAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(ClientIdProperty, out var clientId))
{
if (dict.TryGetValue(TenantIdProperty, out var tenantId))
{
if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret));
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath));
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
else
{
Expand All @@ -180,30 +182,28 @@ private static AccessKey BuildAadAccessKey(Uri uri, Dictionary<string, string> d
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId));
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri);
}
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential());
return new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
}
}

private static AccessKey BuildAccessKey(Uri uri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(AccessKeyProperty, out var key))
{
return new AccessKey(uri, key);
}
throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
return dict.TryGetValue(AccessKeyProperty, out var key)
? new AccessKey(uri, key)
: throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
}

private static AccessKey BuildAzureAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
return new AadAccessKey(uri, new DefaultAzureCredential());
return new AadAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri);
}

private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (!dict.TryGetValue(ClientIdProperty, out var clientId))
{
Expand All @@ -217,22 +217,20 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary<string, stri

if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret));
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath));
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty);
}

private static AccessKey BuildAzureMsiAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(ClientIdProperty, out var clientId))
{
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId));
}
return new AadAccessKey(uri, new ManagedIdentityCredential());
return dict.TryGetValue(ClientIdProperty, out var clientId)
? new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri)
: new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
}

private static Dictionary<string, string> ToDictionary(string connectionString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;

using Azure.Identity;

using Xunit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ public void TestAzureApplication(string connectionString)
{
var r = ConnectionStringParser.Parse(connectionString);

var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ClientSecretCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ClientSecretCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
Assert.Null(r.Version);
Assert.Null(r.ClientEndpoint);
Expand Down Expand Up @@ -148,8 +148,8 @@ internal void TestDefaultAzureCredential(string expectedEndpoint, string connect
var r = ConnectionStringParser.Parse(connectionString);

Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/'));
var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<DefaultAzureCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<DefaultAzureCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
}

Expand All @@ -165,12 +165,25 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri
var r = ConnectionStringParser.Parse(connectionString);

Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/'));
var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ManagedIdentityCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ManagedIdentityCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
Assert.Null(r.ClientEndpoint);
}

[Theory]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar/", "https://foo/bar/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")]
internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl)
{
var r = ConnectionStringParser.Parse(connectionString);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.Equal(expectedAuthorizeUrl, key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase);
}

public class ClientEndpointTestData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
Expand Down
Loading

0 comments on commit ed1086d

Please sign in to comment.