Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ndiritu committed Nov 5, 2024
1 parent 7007775 commit b1d4800
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 88 deletions.
158 changes: 79 additions & 79 deletions src/http/httpClient/Middleware/AuthorizationHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,92 +15,92 @@

namespace Microsoft.Kiota.Http.HttpClientLibrary.Middleware
{
/// <summary>
/// Adds an Authorization header to the request if the header is not already present.
/// Also handles Continuous Access Evaluation (CAE) claims challenges if the initial
/// token request was made using this handler
/// </summary>
public class AuthorizationHandler : DelegatingHandler
{

private const string AuthorizationHeader = "Authorization";
private readonly BaseBearerTokenAuthenticationProvider authenticationProvider;

/// <summary>
/// Adds an Authorization header to the request if the header is not already present.
/// Also handles Continuous Access Evaluation (CAE) claims challenges if the initial
/// token request was made using this handler
/// Constructs an <see cref="AuthorizationHandler"/>
/// </summary>
public class AuthorizationHandler : DelegatingHandler
/// <param name="authenticationProvider"></param>
/// <exception cref="ArgumentNullException"></exception>
public AuthorizationHandler(BaseBearerTokenAuthenticationProvider authenticationProvider)
{
if(authenticationProvider == null) throw new ArgumentNullException(nameof(authenticationProvider));
this.authenticationProvider = authenticationProvider;
}

private const string AuthorizationHeader = "Authorization";
private readonly BaseBearerTokenAuthenticationProvider authenticationProvider;

/// <summary>
/// Constructs an <see cref="AuthorizationHandler"/>
/// </summary>
/// <param name="authenticationProvider"></param>
/// <exception cref="ArgumentNullException"></exception>
public AuthorizationHandler(BaseBearerTokenAuthenticationProvider authenticationProvider)
{
if(authenticationProvider == null) throw new ArgumentNullException(nameof(authenticationProvider));
this.authenticationProvider = authenticationProvider;
}
/// <summary>
/// Adds an Authorization header if not already provided
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
CancellationToken cancellationToken)
{
if(request == null) throw new ArgumentNullException(nameof(request));

/// <summary>
/// Adds an Authorization header if not already provided
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
CancellationToken cancellationToken)
Activity? activity = null;
if(request.GetRequestOption<ObservabilityOptions>() is { } obsOptions)
{
var activitySource = ActivitySourceRegistry.DefaultInstance.GetOrCreateActivitySource(obsOptions.TracerInstrumentationName);
activity = activitySource?.StartActivity($"{nameof(AuthorizationHandler)}_{nameof(SendAsync)}");
activity?.SetTag("com.microsoft.kiota.handler.authorization.enable", true);
}
try
{
if(request.Headers.Contains(AuthorizationHeader))
{
if(request == null) throw new ArgumentNullException(nameof(request));

Activity? activity = null;
if(request.GetRequestOption<ObservabilityOptions>() is { } obsOptions)
{
var activitySource = ActivitySourceRegistry.DefaultInstance.GetOrCreateActivitySource(obsOptions.TracerInstrumentationName);
activity = activitySource?.StartActivity($"{nameof(AuthorizationHandler)}_{nameof(SendAsync)}");
activity?.SetTag("com.microsoft.kiota.handler.authorization.enable", true);
}
try
{
if(request.Headers.Contains(AuthorizationHeader))
{
activity?.SetTag("com.microsoft.kiota.handler.authorization.token_present", true);
return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
}
Dictionary<string, object> additionalAuthenticationContext = new Dictionary<string, object>();
await AuthenticateRequestAsync(request, additionalAuthenticationContext, cancellationToken, activity).ConfigureAwait(false);
var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
if(response.StatusCode != HttpStatusCode.Unauthorized || response.RequestMessage == null || !response.RequestMessage.IsBuffered())
return response;
// Attempt CAE claims challenge
var claims = ContinuousAccessEvaluation.GetClaims(response);
if(string.IsNullOrEmpty(claims))
return response;
activity?.AddEvent(new ActivityEvent("com.microsoft.kiota.handler.authorization.challenge_received"));
additionalAuthenticationContext[ContinuousAccessEvaluation.ClaimsKey] = claims;
HttpRequestMessage retryRequest = response.RequestMessage;
await AuthenticateRequestAsync(retryRequest, additionalAuthenticationContext, cancellationToken, activity).ConfigureAwait(false);
activity?.SetTag("http.request.resend_count", 1);
return await base.SendAsync(retryRequest, cancellationToken).ConfigureAwait(false);
}
finally
{
activity?.Dispose();
}
activity?.SetTag("com.microsoft.kiota.handler.authorization.token_present", true);
return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
}
Dictionary<string, object> additionalAuthenticationContext = new Dictionary<string, object>();
await AuthenticateRequestAsync(request, additionalAuthenticationContext, cancellationToken, activity).ConfigureAwait(false);
var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
if(response.StatusCode != HttpStatusCode.Unauthorized || response.RequestMessage == null || !response.RequestMessage.IsBuffered())
return response;
// Attempt CAE claims challenge
var claims = ContinuousAccessEvaluation.GetClaims(response);
if(string.IsNullOrEmpty(claims))
return response;
activity?.AddEvent(new ActivityEvent("com.microsoft.kiota.handler.authorization.challenge_received"));
additionalAuthenticationContext[ContinuousAccessEvaluation.ClaimsKey] = claims;
HttpRequestMessage retryRequest = response.RequestMessage;
await AuthenticateRequestAsync(retryRequest, additionalAuthenticationContext, cancellationToken, activity).ConfigureAwait(false);
activity?.SetTag("http.request.resend_count", 1);
return await base.SendAsync(retryRequest, cancellationToken).ConfigureAwait(false);
}
finally
{
activity?.Dispose();
}
}

private async Task AuthenticateRequestAsync(HttpRequestMessage request,
Dictionary<string, object> additionalAuthenticationContext,
CancellationToken cancellationToken,
Activity? activityForAttributes)
{
var accessTokenProvider = authenticationProvider.AccessTokenProvider;
if(request.RequestUri == null || !accessTokenProvider.AllowedHostsValidator.IsUrlHostValid(
request.RequestUri))
{
return;
}
var accessToken = await accessTokenProvider.GetAuthorizationTokenAsync(
request.RequestUri,
additionalAuthenticationContext, cancellationToken).ConfigureAwait(false);
activityForAttributes?.SetTag("com.microsoft.kiota.handler.authorization.token_obtained", true);
if(string.IsNullOrEmpty(accessToken)) return;
request.Headers.TryAddWithoutValidation(AuthorizationHeader, $"Bearer {accessToken}");
}
private async Task AuthenticateRequestAsync(HttpRequestMessage request,
Dictionary<string, object> additionalAuthenticationContext,
CancellationToken cancellationToken,
Activity? activityForAttributes)
{
var accessTokenProvider = authenticationProvider.AccessTokenProvider;
if(request.RequestUri == null || !accessTokenProvider.AllowedHostsValidator.IsUrlHostValid(
request.RequestUri))
{
return;
}
var accessToken = await accessTokenProvider.GetAuthorizationTokenAsync(
request.RequestUri,
additionalAuthenticationContext, cancellationToken).ConfigureAwait(false);
activityForAttributes?.SetTag("com.microsoft.kiota.handler.authorization.token_obtained", true);
if(string.IsNullOrEmpty(accessToken)) return;
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
}
}
}
18 changes: 9 additions & 9 deletions tests/http/httpClient/Middleware/AuthorizationHandlerTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using Microsoft.Kiota.Abstractions.Authentication;
using Microsoft.Kiota.Http.HttpClientLibrary.Middleware;
using Microsoft.Kiota.Http.HttpClientLibrary.Tests.Mocks;
Expand All @@ -12,9 +13,6 @@ namespace Microsoft.Kiota.Http.HttpClientLibrary.Tests.Middleware
public class AuthorizationHandlerTests : IDisposable
{
private readonly MockRedirectHandler _testHttpMessageHandler;

private IAccessTokenProvider _mockAccessTokenProvider;

private readonly string _expectedAccessToken = "token";

private readonly string _expectedAccessTokenAfterCAE = "token2";
Expand All @@ -34,14 +32,14 @@ public AuthorizationHandlerTests()
It.IsAny<Uri>(),
It.IsAny<Dictionary<string, object>>(),
It.IsAny<CancellationToken>()
)).Returns(new Task<string>(() => _expectedAccessToken))
.Returns(new Task<string>(() => _expectedAccessTokenAfterCAE));
).Result).Returns(_expectedAccessToken)
.Returns(_expectedAccessTokenAfterCAE);

mockAccessTokenProvider.Setup(x => x.AllowedHostsValidator).Returns(
new AllowedHostsValidator(new List<string> { "https://graph.microsoft.com" })
new AllowedHostsValidator(new List<string> { "graph.microsoft.com" })
);
this._mockAccessTokenProvider = mockAccessTokenProvider.Object;
this._authenticationProvider = new BaseBearerTokenAuthenticationProvider(_mockAccessTokenProvider!);
var mockAuthenticationProvider = new Mock<BaseBearerTokenAuthenticationProvider>(mockAccessTokenProvider.Object);
this._authenticationProvider = mockAuthenticationProvider.Object;
this._authorizationHandler = new AuthorizationHandler(_authenticationProvider)
{
InnerHandler = this._testHttpMessageHandler
Expand Down Expand Up @@ -98,11 +96,12 @@ public async Task AuthorizationHandlerShouldAttemptCAEClaimsChallenge()
{
// Arrange
HttpRequestMessage httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://graph.microsoft.com");
httpRequestMessage.Content = new ByteArrayContent(Encoding.UTF8.GetBytes("test"));

HttpResponseMessage httpResponse = new HttpResponseMessage(HttpStatusCode.Unauthorized);
httpResponse.Headers.WwwAuthenticate.Add(new AuthenticationHeaderValue("Bearer", _claimsChallengeHeaderValue));

this._testHttpMessageHandler.SetHttpResponse(httpResponse);// set the mock response
this._testHttpMessageHandler.SetHttpResponse(httpResponse, new HttpResponseMessage(HttpStatusCode.OK));// set the mock response

// Act
HttpResponseMessage response = await this._invoker.SendAsync(httpRequestMessage, new CancellationToken());
Expand All @@ -112,6 +111,7 @@ public async Task AuthorizationHandlerShouldAttemptCAEClaimsChallenge()
Assert.True(response.RequestMessage.Headers.Contains("Authorization"));
Assert.True(response.RequestMessage.Headers.GetValues("Authorization").Count() == 1);
Assert.Equal($"Bearer {_expectedAccessTokenAfterCAE}", response.RequestMessage.Headers.GetValues("Authorization").First());
Assert.Equal("test", await response.RequestMessage.Content!.ReadAsStringAsync());
}
}
}

0 comments on commit b1d4800

Please sign in to comment.