Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC - Trust all certs issued by trusted root authority - Do Not Merge #424

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions source/Halibut.Tests/TlsInsepctionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using FluentAssertions;
using Halibut.ServiceModel;
using Halibut.Tests.Support;
using Halibut.TestUtils.Contracts;
using Halibut.Transport;
using NUnit.Framework;

namespace Halibut.Tests
{
public class TlsInspectionTests : BaseTest
{

[Test]
public void TrustExplicitlySpecifiedCerts()
{
var services = GetDelegateServiceFactory();
using (var octopusServer = new HalibutRuntimeBuilder()
.WithServerCertificate(CertAndThumbprint.Octopus.Certificate2)
.Build())
using (var tentaclePolling = new HalibutRuntimeBuilder()
.WithServiceFactory(services)
.WithServerCertificate(CertAndThumbprint.TentaclePolling.Certificate2)
.Build())
{
octopusServer.Trust(CertAndThumbprint.TentaclePolling.Thumbprint);

var port = octopusServer.Listen();

tentaclePolling.Poll(new Uri("poll://SQ-TENTAPOLL"), new ServiceEndPoint(new Uri("https://localhost:" + port), CertAndThumbprint.Octopus.Thumbprint));

var echo = octopusServer.CreateClient<IEchoService>("poll://SQ-TENTAPOLL", CertAndThumbprint.TentaclePolling.Thumbprint);

var result = echo.SayHello("World");
result.Should().Be("World...");
}
}

[Test]
public void TrustAnyCertificateIssuedByTrustedCertificateRootAuthority()
{
var store = new X509Store(StoreName.My);
store.Open(OpenFlags.ReadOnly);

var trustedCert = FindTrustedCert(store.Certificates);
var serverCert = trustedCert;

var services = GetDelegateServiceFactory();
using (var octopusServer = new HalibutRuntimeBuilder()
.WithServerCertificate(serverCert)
.Build())
using (var tentaclePolling = new HalibutRuntimeBuilder()
.WithServiceFactory(services)
.WithClientCertificateValidatorFactory(new TrustRootCertificateAuthorityValidatorFactory())
.WithServerCertificate(CertAndThumbprint.TentaclePolling.Certificate2)
.Build())
{
octopusServer.Trust(CertAndThumbprint.TentaclePolling.Thumbprint);

var port = octopusServer.Listen();

tentaclePolling.Poll(new Uri("poll://SQ-TENTAPOLL"), new ServiceEndPoint(new Uri("https://localhost:" + port), null));

var echo = octopusServer.CreateClient<IEchoService>("poll://SQ-TENTAPOLL", CertAndThumbprint.TentaclePolling.Thumbprint);

var result = echo.SayHello("World");
result.Should().Be("World...");
}
}

X509Certificate2 FindTrustedCert(X509Certificate2Collection storeCertificates)
{
foreach (var storeCertificate in storeCertificates)
{
if (storeCertificate.Verify() && storeCertificate.HasPrivateKey)
{
return storeCertificate;
}
}

return null;
}

static DelegateServiceFactory GetDelegateServiceFactory()
{
var services = new DelegateServiceFactory();
services.Register<IEchoService>(() => new EchoService());
return services;
}

}
}
4 changes: 3 additions & 1 deletion source/Halibut.Tests/Transport/SecureClientFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ public class SecureClientFixture : IDisposable
ServiceEndPoint endpoint;
HalibutRuntime tentacle;
ILog log;
ClientCertificateValidatorFactory clientCertificateValidatorFactory;

[SetUp]
public void SetUp()
{
var services = new DelegateServiceFactory();
services.Register<IEchoService>(() => new EchoService());
tentacle = new HalibutRuntime(services, Certificates.TentacleListening);
clientCertificateValidatorFactory = new ClientCertificateValidatorFactory();
var tentaclePort = tentacle.Listen();
tentacle.Trust(Certificates.OctopusPublicThumbprint);
endpoint = new ServiceEndPoint("https://localhost:" + tentaclePort, Certificates.TentacleListeningPublicThumbprint)
Expand Down Expand Up @@ -67,7 +69,7 @@ public async Task SecureClientClearsPoolWhenAllConnectionsCorrupt(SyncOrAsync sy
Params = new object[] { "Fred" }
};

var secureClient = new SecureListeningClient((s, l) => GetProtocol(s, l, syncOrAsync), endpoint, Certificates.Octopus, log, connectionManager);
var secureClient = new SecureListeningClient((s, l) => GetProtocol(s, l, syncOrAsync), endpoint, Certificates.Octopus, log, connectionManager, clientCertificateValidatorFactory);
ResponseMessage response = null!;

using var requestCancellationTokens = new RequestCancellationTokens(CancellationToken.None, CancellationToken.None);
Expand Down
11 changes: 7 additions & 4 deletions source/Halibut/HalibutRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class HalibutRuntime : IHalibutRuntime
readonly ITypeRegistry typeRegistry;
readonly Lazy<ResponseCache> responseCache = new();
readonly Func<RetryPolicy> pollingReconnectRetryPolicy;
readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory;

[Obsolete]
public HalibutRuntime(X509Certificate2 serverCertificate) : this(new NullServiceFactory(), serverCertificate, new DefaultTrustProvider())
Expand Down Expand Up @@ -80,7 +81,8 @@ internal HalibutRuntime(
ITypeRegistry typeRegistry,
IMessageSerializer messageSerializer,
Func<RetryPolicy> pollingReconnectRetryPolicy,
AsyncHalibutFeature asyncHalibutFeature)
AsyncHalibutFeature asyncHalibutFeature,
IClientCertificateValidatorFactory clientCertificateValidatorFactory)
{
AsyncHalibutFeature = asyncHalibutFeature;
this.serverCertificate = serverCertificate;
Expand All @@ -90,6 +92,7 @@ internal HalibutRuntime(
this.typeRegistry = typeRegistry;
this.messageSerializer = messageSerializer;
this.pollingReconnectRetryPolicy = pollingReconnectRetryPolicy;
this.clientCertificateValidatorFactory = clientCertificateValidatorFactory;
invoker = new ServiceInvoker(serviceFactory);
}

Expand Down Expand Up @@ -181,7 +184,7 @@ public void Poll(Uri subscription, ServiceEndPoint endPoint, CancellationToken c
}
else
{
client = new SecureClient(ExchangeProtocolBuilder(), endPoint, serverCertificate, log, connectionManager);
client = new SecureClient(ExchangeProtocolBuilder(), endPoint, serverCertificate, log, connectionManager, clientCertificateValidatorFactory);
}
pollingClients.Add(new PollingClient(subscription, client, HandleIncomingRequest, log, cancellationToken, pollingReconnectRetryPolicy, AsyncHalibutFeature));
}
Expand Down Expand Up @@ -350,7 +353,7 @@ async Task<ResponseMessage> SendOutgoingRequestAsync(RequestMessage request, Met
[Obsolete]
ResponseMessage SendOutgoingHttpsRequest(RequestMessage request, CancellationToken cancellationToken)
{
var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager);
var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager, clientCertificateValidatorFactory);

ResponseMessage response = null;
client.ExecuteTransaction(protocol =>
Expand All @@ -362,7 +365,7 @@ ResponseMessage SendOutgoingHttpsRequest(RequestMessage request, CancellationTok

async Task<ResponseMessage> SendOutgoingHttpsRequestAsync(RequestMessage request, RequestCancellationTokens requestCancellationTokens)
{
var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager);
var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager, clientCertificateValidatorFactory);

ResponseMessage response = null;

Expand Down
12 changes: 11 additions & 1 deletion source/Halibut/HalibutRuntimeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Security.Cryptography.X509Certificates;
using Halibut.Diagnostics;
using Halibut.ServiceModel;
using Halibut.Transport;
using Halibut.Transport.Protocol;
using Halibut.Util;

Expand All @@ -20,13 +21,20 @@ public class HalibutRuntimeBuilder
Func<RetryPolicy> pollingReconnectRetryPolicy = RetryPolicy.Create;
AsyncHalibutFeature asyncHalibutFeature = AsyncHalibutFeature.Disabled;
Func<string, string, UnauthorizedClientConnectResponse> onUnauthorizedClientConnect;
IClientCertificateValidatorFactory clientCertificateValidatorFactory;

public HalibutRuntimeBuilder WithServiceFactory(IServiceFactory serviceFactory)
{
this.serviceFactory = serviceFactory;
return this;
}

public HalibutRuntimeBuilder WithClientCertificateValidatorFactory(IClientCertificateValidatorFactory clientCertificateValidator)
{
this.clientCertificateValidatorFactory = clientCertificateValidator;
return this;
}

public HalibutRuntimeBuilder WithServerCertificate(X509Certificate2 serverCertificate)
{
this.serverCertificate = serverCertificate;
Expand Down Expand Up @@ -97,6 +105,7 @@ public HalibutRuntime Build()
#pragma warning restore CS0612
var trustProvider = this.trustProvider ?? new DefaultTrustProvider();
var typeRegistry = this.typeRegistry ?? new TypeRegistry();
var clientCertificateValidatorFactory = this.clientCertificateValidatorFactory ?? new ClientCertificateValidatorFactory();

var messageContracts = serviceFactory.RegisteredServiceTypes.ToArray();
typeRegistry.AddToMessageContract(messageContracts);
Expand All @@ -114,7 +123,8 @@ public HalibutRuntime Build()
typeRegistry,
messageSerializer,
pollingReconnectRetryPolicy,
asyncHalibutFeature);
asyncHalibutFeature,
clientCertificateValidatorFactory);

if (onUnauthorizedClientConnect is not null)
{
Expand Down
37 changes: 36 additions & 1 deletion source/Halibut/Transport/ClientCertificateValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Halibut.Transport
{
class ClientCertificateValidator
class ClientCertificateValidator : IClientCertificateValidator
{
readonly ServiceEndPoint endPoint;

Expand All @@ -26,4 +26,39 @@ public bool Validate(object sender, X509Certificate certificate, X509Chain chain
throw new UnexpectedCertificateException(providedCert, endPoint);
}
}

class TrustRootCertificateAuthorityValidator : IClientCertificateValidator
{
public bool Validate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslpolicyerrors)
{
var result = new X509Certificate2(certificate).Verify();
return result;
}
}

public interface IClientCertificateValidator
{
bool Validate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslpolicyerrors);
}

public interface IClientCertificateValidatorFactory
{
IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint);
}

public class ClientCertificateValidatorFactory : IClientCertificateValidatorFactory
{
public IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint)
{
return new ClientCertificateValidator(serviceEndpoint);
}
}

public class TrustRootCertificateAuthorityValidatorFactory : IClientCertificateValidatorFactory
{
public IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint)
{
return new TrustRootCertificateAuthorityValidator();
}
}
}
8 changes: 5 additions & 3 deletions source/Halibut/Transport/SecureClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ public class SecureClient : ISecureClient
[Obsolete("Replaced by HalibutLimits.RetryCountLimit")] public const int RetryCountLimit = 5;
readonly ILog log;
readonly ConnectionManager connectionManager;
readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory;
readonly X509Certificate2 clientCertificate;
readonly ExchangeProtocolBuilder protocolBuilder;

public SecureClient(ExchangeProtocolBuilder protocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager)
public SecureClient(ExchangeProtocolBuilder protocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager, IClientCertificateValidatorFactory clientCertificateValidatorFactory)
{
this.protocolBuilder = protocolBuilder;
this.ServiceEndpoint = serviceEndpoint;
this.clientCertificate = clientCertificate;
this.log = log;
this.connectionManager = connectionManager;
this.clientCertificateValidatorFactory = clientCertificateValidatorFactory;
}

public ServiceEndPoint ServiceEndpoint { get; }
Expand Down Expand Up @@ -58,7 +60,7 @@ public void ExecuteTransaction(ExchangeAction protocolHandler, CancellationToken
IConnection connection = null;
try
{
connection = connectionManager.AcquireConnection(protocolBuilder, new TcpConnectionFactory(clientCertificate), ServiceEndpoint, log, cancellationToken);
connection = connectionManager.AcquireConnection(protocolBuilder, new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, cancellationToken);

// Beyond this point, we have no way to be certain that the server hasn't tried to process a request; therefore, we can't retry after this point
retryAllowed = false;
Expand Down Expand Up @@ -158,7 +160,7 @@ public async Task ExecuteTransactionAsync(ExchangeActionAsync protocolHandler, R
{
connection = await connectionManager.AcquireConnectionAsync(
protocolBuilder,
new TcpConnectionFactory(clientCertificate),
new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory),
ServiceEndpoint,
log,
requestCancellationTokens.LinkedCancellationToken).ConfigureAwait(false);
Expand Down
8 changes: 5 additions & 3 deletions source/Halibut/Transport/SecureListeningClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ class SecureListeningClient : ISecureClient
{
readonly ILog log;
readonly ConnectionManager connectionManager;
readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory;
readonly X509Certificate2 clientCertificate;
readonly ExchangeProtocolBuilder exchangeProtocolBuilder;

public SecureListeningClient(ExchangeProtocolBuilder exchangeProtocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager)
public SecureListeningClient(ExchangeProtocolBuilder exchangeProtocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager, IClientCertificateValidatorFactory clientCertificateValidatorFactory)
{
this.exchangeProtocolBuilder = exchangeProtocolBuilder;
this.ServiceEndpoint = serviceEndpoint;
this.clientCertificate = clientCertificate;
this.log = log;
this.connectionManager = connectionManager;
this.clientCertificateValidatorFactory = clientCertificateValidatorFactory;
}

public ServiceEndPoint ServiceEndpoint { get; }
Expand Down Expand Up @@ -57,7 +59,7 @@ public void ExecuteTransaction(ExchangeAction protocolHandler, CancellationToken
IConnection connection = null;
try
{
connection = connectionManager.AcquireConnection(exchangeProtocolBuilder, new TcpConnectionFactory(clientCertificate), ServiceEndpoint, log, cancellationToken);
connection = connectionManager.AcquireConnection(exchangeProtocolBuilder, new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, cancellationToken);

// Beyond this point, we have no way to be certain that the server hasn't tried to process a request; therefore, we can't retry after this point
retryAllowed = false;
Expand Down Expand Up @@ -168,7 +170,7 @@ public async Task ExecuteTransactionAsync(ExchangeActionAsync protocolHandler, R
{
connection = await connectionManager.AcquireConnectionAsync(
exchangeProtocolBuilder,
new TcpConnectionFactory(clientCertificate),
new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory),
ServiceEndpoint,
log,
requestCancellationTokens.LinkedCancellationToken).ConfigureAwait(false);
Expand Down
6 changes: 4 additions & 2 deletions source/Halibut/Transport/TcpConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,25 @@ public class TcpConnectionFactory : IConnectionFactory
static readonly byte[] MxLine = Encoding.ASCII.GetBytes("MX" + Environment.NewLine + Environment.NewLine);

readonly X509Certificate2 clientCertificate;
readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory;

public TcpConnectionFactory(X509Certificate2 clientCertificate)
public TcpConnectionFactory(X509Certificate2 clientCertificate, IClientCertificateValidatorFactory clientCertificateValidatorFactory)
{
this.clientCertificate = clientCertificate;
this.clientCertificateValidatorFactory = clientCertificateValidatorFactory;
}

[Obsolete]
public IConnection EstablishNewConnection(ExchangeProtocolBuilder exchangeProtocolBuilder, ServiceEndPoint serviceEndpoint, ILog log, CancellationToken cancellationToken)
{
log.Write(EventType.OpeningNewConnection, $"Opening a new connection to {serviceEndpoint.BaseUri}");

var certificateValidator = new ClientCertificateValidator(serviceEndpoint);
var client = CreateConnectedTcpClient(serviceEndpoint, log, cancellationToken);
log.Write(EventType.Diagnostic, $"Connection established to {client.Client.RemoteEndPoint} for {serviceEndpoint.BaseUri}");

var stream = client.GetStream();

var certificateValidator = clientCertificateValidatorFactory.Create(serviceEndpoint);
log.Write(EventType.SecurityNegotiation, "Performing TLS handshake");
var ssl = new SslStream(stream, false, certificateValidator.Validate, UserCertificateSelectionCallback);
ssl.AuthenticateAsClient(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false);
Expand Down