diff --git a/source/Halibut.Tests/TlsInsepctionTests.cs b/source/Halibut.Tests/TlsInsepctionTests.cs new file mode 100644 index 00000000..be0827f9 --- /dev/null +++ b/source/Halibut.Tests/TlsInsepctionTests.cs @@ -0,0 +1,94 @@ +using System; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using FluentAssertions; +using Halibut.ServiceModel; +using Halibut.Tests.Support; +using Halibut.TestUtils.Contracts; +using Halibut.Transport; +using NUnit.Framework; + +namespace Halibut.Tests +{ + public class TlsInspectionTests : BaseTest + { + + [Test] + public void TrustExplicitlySpecifiedCerts() + { + var services = GetDelegateServiceFactory(); + using (var octopusServer = new HalibutRuntimeBuilder() + .WithServerCertificate(CertAndThumbprint.Octopus.Certificate2) + .Build()) + using (var tentaclePolling = new HalibutRuntimeBuilder() + .WithServiceFactory(services) + .WithServerCertificate(CertAndThumbprint.TentaclePolling.Certificate2) + .Build()) + { + octopusServer.Trust(CertAndThumbprint.TentaclePolling.Thumbprint); + + var port = octopusServer.Listen(); + + tentaclePolling.Poll(new Uri("poll://SQ-TENTAPOLL"), new ServiceEndPoint(new Uri("https://localhost:" + port), CertAndThumbprint.Octopus.Thumbprint)); + + var echo = octopusServer.CreateClient("poll://SQ-TENTAPOLL", CertAndThumbprint.TentaclePolling.Thumbprint); + + var result = echo.SayHello("World"); + result.Should().Be("World..."); + } + } + + [Test] + public void TrustAnyCertificateIssuedByTrustedCertificateRootAuthority() + { + var store = new X509Store(StoreName.My); + store.Open(OpenFlags.ReadOnly); + + var trustedCert = FindTrustedCert(store.Certificates); + var serverCert = trustedCert; + + var services = GetDelegateServiceFactory(); + using (var octopusServer = new HalibutRuntimeBuilder() + .WithServerCertificate(serverCert) + .Build()) + using (var tentaclePolling = new HalibutRuntimeBuilder() + .WithServiceFactory(services) + .WithClientCertificateValidatorFactory(new TrustRootCertificateAuthorityValidatorFactory()) + .WithServerCertificate(CertAndThumbprint.TentaclePolling.Certificate2) + .Build()) + { + octopusServer.Trust(CertAndThumbprint.TentaclePolling.Thumbprint); + + var port = octopusServer.Listen(); + + tentaclePolling.Poll(new Uri("poll://SQ-TENTAPOLL"), new ServiceEndPoint(new Uri("https://localhost:" + port), null)); + + var echo = octopusServer.CreateClient("poll://SQ-TENTAPOLL", CertAndThumbprint.TentaclePolling.Thumbprint); + + var result = echo.SayHello("World"); + result.Should().Be("World..."); + } + } + + X509Certificate2 FindTrustedCert(X509Certificate2Collection storeCertificates) + { + foreach (var storeCertificate in storeCertificates) + { + if (storeCertificate.Verify() && storeCertificate.HasPrivateKey) + { + return storeCertificate; + } + } + + return null; + } + + static DelegateServiceFactory GetDelegateServiceFactory() + { + var services = new DelegateServiceFactory(); + services.Register(() => new EchoService()); + return services; + } + + } +} \ No newline at end of file diff --git a/source/Halibut.Tests/Transport/SecureClientFixture.cs b/source/Halibut.Tests/Transport/SecureClientFixture.cs index 88b949ed..dd64f3cd 100644 --- a/source/Halibut.Tests/Transport/SecureClientFixture.cs +++ b/source/Halibut.Tests/Transport/SecureClientFixture.cs @@ -20,6 +20,7 @@ public class SecureClientFixture : IDisposable ServiceEndPoint endpoint; HalibutRuntime tentacle; ILog log; + ClientCertificateValidatorFactory clientCertificateValidatorFactory; [SetUp] public void SetUp() @@ -27,6 +28,7 @@ public void SetUp() var services = new DelegateServiceFactory(); services.Register(() => new EchoService()); tentacle = new HalibutRuntime(services, Certificates.TentacleListening); + clientCertificateValidatorFactory = new ClientCertificateValidatorFactory(); var tentaclePort = tentacle.Listen(); tentacle.Trust(Certificates.OctopusPublicThumbprint); endpoint = new ServiceEndPoint("https://localhost:" + tentaclePort, Certificates.TentacleListeningPublicThumbprint) @@ -67,7 +69,7 @@ public async Task SecureClientClearsPoolWhenAllConnectionsCorrupt(SyncOrAsync sy Params = new object[] { "Fred" } }; - var secureClient = new SecureListeningClient((s, l) => GetProtocol(s, l, syncOrAsync), endpoint, Certificates.Octopus, log, connectionManager); + var secureClient = new SecureListeningClient((s, l) => GetProtocol(s, l, syncOrAsync), endpoint, Certificates.Octopus, log, connectionManager, clientCertificateValidatorFactory); ResponseMessage response = null!; using var requestCancellationTokens = new RequestCancellationTokens(CancellationToken.None, CancellationToken.None); diff --git a/source/Halibut/HalibutRuntime.cs b/source/Halibut/HalibutRuntime.cs index 97909004..a18c338a 100644 --- a/source/Halibut/HalibutRuntime.cs +++ b/source/Halibut/HalibutRuntime.cs @@ -36,6 +36,7 @@ public class HalibutRuntime : IHalibutRuntime readonly ITypeRegistry typeRegistry; readonly Lazy responseCache = new(); readonly Func pollingReconnectRetryPolicy; + readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory; [Obsolete] public HalibutRuntime(X509Certificate2 serverCertificate) : this(new NullServiceFactory(), serverCertificate, new DefaultTrustProvider()) @@ -80,7 +81,8 @@ internal HalibutRuntime( ITypeRegistry typeRegistry, IMessageSerializer messageSerializer, Func pollingReconnectRetryPolicy, - AsyncHalibutFeature asyncHalibutFeature) + AsyncHalibutFeature asyncHalibutFeature, + IClientCertificateValidatorFactory clientCertificateValidatorFactory) { AsyncHalibutFeature = asyncHalibutFeature; this.serverCertificate = serverCertificate; @@ -90,6 +92,7 @@ internal HalibutRuntime( this.typeRegistry = typeRegistry; this.messageSerializer = messageSerializer; this.pollingReconnectRetryPolicy = pollingReconnectRetryPolicy; + this.clientCertificateValidatorFactory = clientCertificateValidatorFactory; invoker = new ServiceInvoker(serviceFactory); } @@ -181,7 +184,7 @@ public void Poll(Uri subscription, ServiceEndPoint endPoint, CancellationToken c } else { - client = new SecureClient(ExchangeProtocolBuilder(), endPoint, serverCertificate, log, connectionManager); + client = new SecureClient(ExchangeProtocolBuilder(), endPoint, serverCertificate, log, connectionManager, clientCertificateValidatorFactory); } pollingClients.Add(new PollingClient(subscription, client, HandleIncomingRequest, log, cancellationToken, pollingReconnectRetryPolicy, AsyncHalibutFeature)); } @@ -350,7 +353,7 @@ async Task SendOutgoingRequestAsync(RequestMessage request, Met [Obsolete] ResponseMessage SendOutgoingHttpsRequest(RequestMessage request, CancellationToken cancellationToken) { - var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager); + var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager, clientCertificateValidatorFactory); ResponseMessage response = null; client.ExecuteTransaction(protocol => @@ -362,7 +365,7 @@ ResponseMessage SendOutgoingHttpsRequest(RequestMessage request, CancellationTok async Task SendOutgoingHttpsRequestAsync(RequestMessage request, RequestCancellationTokens requestCancellationTokens) { - var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager); + var client = new SecureListeningClient(ExchangeProtocolBuilder(), request.Destination, serverCertificate, logs.ForEndpoint(request.Destination.BaseUri), connectionManager, clientCertificateValidatorFactory); ResponseMessage response = null; diff --git a/source/Halibut/HalibutRuntimeBuilder.cs b/source/Halibut/HalibutRuntimeBuilder.cs index 86f25693..bc321f76 100644 --- a/source/Halibut/HalibutRuntimeBuilder.cs +++ b/source/Halibut/HalibutRuntimeBuilder.cs @@ -3,6 +3,7 @@ using System.Security.Cryptography.X509Certificates; using Halibut.Diagnostics; using Halibut.ServiceModel; +using Halibut.Transport; using Halibut.Transport.Protocol; using Halibut.Util; @@ -20,6 +21,7 @@ public class HalibutRuntimeBuilder Func pollingReconnectRetryPolicy = RetryPolicy.Create; AsyncHalibutFeature asyncHalibutFeature = AsyncHalibutFeature.Disabled; Func onUnauthorizedClientConnect; + IClientCertificateValidatorFactory clientCertificateValidatorFactory; public HalibutRuntimeBuilder WithServiceFactory(IServiceFactory serviceFactory) { @@ -27,6 +29,12 @@ public HalibutRuntimeBuilder WithServiceFactory(IServiceFactory serviceFactory) return this; } + public HalibutRuntimeBuilder WithClientCertificateValidatorFactory(IClientCertificateValidatorFactory clientCertificateValidator) + { + this.clientCertificateValidatorFactory = clientCertificateValidator; + return this; + } + public HalibutRuntimeBuilder WithServerCertificate(X509Certificate2 serverCertificate) { this.serverCertificate = serverCertificate; @@ -97,6 +105,7 @@ public HalibutRuntime Build() #pragma warning restore CS0612 var trustProvider = this.trustProvider ?? new DefaultTrustProvider(); var typeRegistry = this.typeRegistry ?? new TypeRegistry(); + var clientCertificateValidatorFactory = this.clientCertificateValidatorFactory ?? new ClientCertificateValidatorFactory(); var messageContracts = serviceFactory.RegisteredServiceTypes.ToArray(); typeRegistry.AddToMessageContract(messageContracts); @@ -114,7 +123,8 @@ public HalibutRuntime Build() typeRegistry, messageSerializer, pollingReconnectRetryPolicy, - asyncHalibutFeature); + asyncHalibutFeature, + clientCertificateValidatorFactory); if (onUnauthorizedClientConnect is not null) { diff --git a/source/Halibut/Transport/ClientCertificateValidator.cs b/source/Halibut/Transport/ClientCertificateValidator.cs index 74dda559..a55b921c 100644 --- a/source/Halibut/Transport/ClientCertificateValidator.cs +++ b/source/Halibut/Transport/ClientCertificateValidator.cs @@ -4,7 +4,7 @@ namespace Halibut.Transport { - class ClientCertificateValidator + class ClientCertificateValidator : IClientCertificateValidator { readonly ServiceEndPoint endPoint; @@ -26,4 +26,39 @@ public bool Validate(object sender, X509Certificate certificate, X509Chain chain throw new UnexpectedCertificateException(providedCert, endPoint); } } + + class TrustRootCertificateAuthorityValidator : IClientCertificateValidator + { + public bool Validate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslpolicyerrors) + { + var result = new X509Certificate2(certificate).Verify(); + return result; + } + } + + public interface IClientCertificateValidator + { + bool Validate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslpolicyerrors); + } + + public interface IClientCertificateValidatorFactory + { + IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint); + } + + public class ClientCertificateValidatorFactory : IClientCertificateValidatorFactory + { + public IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint) + { + return new ClientCertificateValidator(serviceEndpoint); + } + } + + public class TrustRootCertificateAuthorityValidatorFactory : IClientCertificateValidatorFactory + { + public IClientCertificateValidator Create(ServiceEndPoint serviceEndpoint) + { + return new TrustRootCertificateAuthorityValidator(); + } + } } \ No newline at end of file diff --git a/source/Halibut/Transport/SecureClient.cs b/source/Halibut/Transport/SecureClient.cs index 8ec0fff1..c8a0db62 100644 --- a/source/Halibut/Transport/SecureClient.cs +++ b/source/Halibut/Transport/SecureClient.cs @@ -19,16 +19,18 @@ public class SecureClient : ISecureClient [Obsolete("Replaced by HalibutLimits.RetryCountLimit")] public const int RetryCountLimit = 5; readonly ILog log; readonly ConnectionManager connectionManager; + readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory; readonly X509Certificate2 clientCertificate; readonly ExchangeProtocolBuilder protocolBuilder; - public SecureClient(ExchangeProtocolBuilder protocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager) + public SecureClient(ExchangeProtocolBuilder protocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager, IClientCertificateValidatorFactory clientCertificateValidatorFactory) { this.protocolBuilder = protocolBuilder; this.ServiceEndpoint = serviceEndpoint; this.clientCertificate = clientCertificate; this.log = log; this.connectionManager = connectionManager; + this.clientCertificateValidatorFactory = clientCertificateValidatorFactory; } public ServiceEndPoint ServiceEndpoint { get; } @@ -58,7 +60,7 @@ public void ExecuteTransaction(ExchangeAction protocolHandler, CancellationToken IConnection connection = null; try { - connection = connectionManager.AcquireConnection(protocolBuilder, new TcpConnectionFactory(clientCertificate), ServiceEndpoint, log, cancellationToken); + connection = connectionManager.AcquireConnection(protocolBuilder, new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, cancellationToken); // Beyond this point, we have no way to be certain that the server hasn't tried to process a request; therefore, we can't retry after this point retryAllowed = false; @@ -158,7 +160,7 @@ public async Task ExecuteTransactionAsync(ExchangeActionAsync protocolHandler, R { connection = await connectionManager.AcquireConnectionAsync( protocolBuilder, - new TcpConnectionFactory(clientCertificate), + new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, requestCancellationTokens.LinkedCancellationToken).ConfigureAwait(false); diff --git a/source/Halibut/Transport/SecureListeningClient.cs b/source/Halibut/Transport/SecureListeningClient.cs index aacf47b2..977fe3f2 100644 --- a/source/Halibut/Transport/SecureListeningClient.cs +++ b/source/Halibut/Transport/SecureListeningClient.cs @@ -18,16 +18,18 @@ class SecureListeningClient : ISecureClient { readonly ILog log; readonly ConnectionManager connectionManager; + readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory; readonly X509Certificate2 clientCertificate; readonly ExchangeProtocolBuilder exchangeProtocolBuilder; - public SecureListeningClient(ExchangeProtocolBuilder exchangeProtocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager) + public SecureListeningClient(ExchangeProtocolBuilder exchangeProtocolBuilder, ServiceEndPoint serviceEndpoint, X509Certificate2 clientCertificate, ILog log, ConnectionManager connectionManager, IClientCertificateValidatorFactory clientCertificateValidatorFactory) { this.exchangeProtocolBuilder = exchangeProtocolBuilder; this.ServiceEndpoint = serviceEndpoint; this.clientCertificate = clientCertificate; this.log = log; this.connectionManager = connectionManager; + this.clientCertificateValidatorFactory = clientCertificateValidatorFactory; } public ServiceEndPoint ServiceEndpoint { get; } @@ -57,7 +59,7 @@ public void ExecuteTransaction(ExchangeAction protocolHandler, CancellationToken IConnection connection = null; try { - connection = connectionManager.AcquireConnection(exchangeProtocolBuilder, new TcpConnectionFactory(clientCertificate), ServiceEndpoint, log, cancellationToken); + connection = connectionManager.AcquireConnection(exchangeProtocolBuilder, new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, cancellationToken); // Beyond this point, we have no way to be certain that the server hasn't tried to process a request; therefore, we can't retry after this point retryAllowed = false; @@ -168,7 +170,7 @@ public async Task ExecuteTransactionAsync(ExchangeActionAsync protocolHandler, R { connection = await connectionManager.AcquireConnectionAsync( exchangeProtocolBuilder, - new TcpConnectionFactory(clientCertificate), + new TcpConnectionFactory(clientCertificate, clientCertificateValidatorFactory), ServiceEndpoint, log, requestCancellationTokens.LinkedCancellationToken).ConfigureAwait(false); diff --git a/source/Halibut/Transport/TcpConnectionFactory.cs b/source/Halibut/Transport/TcpConnectionFactory.cs index 52aac44b..48992941 100644 --- a/source/Halibut/Transport/TcpConnectionFactory.cs +++ b/source/Halibut/Transport/TcpConnectionFactory.cs @@ -19,10 +19,12 @@ public class TcpConnectionFactory : IConnectionFactory static readonly byte[] MxLine = Encoding.ASCII.GetBytes("MX" + Environment.NewLine + Environment.NewLine); readonly X509Certificate2 clientCertificate; + readonly IClientCertificateValidatorFactory clientCertificateValidatorFactory; - public TcpConnectionFactory(X509Certificate2 clientCertificate) + public TcpConnectionFactory(X509Certificate2 clientCertificate, IClientCertificateValidatorFactory clientCertificateValidatorFactory) { this.clientCertificate = clientCertificate; + this.clientCertificateValidatorFactory = clientCertificateValidatorFactory; } [Obsolete] @@ -30,12 +32,12 @@ public IConnection EstablishNewConnection(ExchangeProtocolBuilder exchangeProtoc { log.Write(EventType.OpeningNewConnection, $"Opening a new connection to {serviceEndpoint.BaseUri}"); - var certificateValidator = new ClientCertificateValidator(serviceEndpoint); var client = CreateConnectedTcpClient(serviceEndpoint, log, cancellationToken); log.Write(EventType.Diagnostic, $"Connection established to {client.Client.RemoteEndPoint} for {serviceEndpoint.BaseUri}"); var stream = client.GetStream(); + var certificateValidator = clientCertificateValidatorFactory.Create(serviceEndpoint); log.Write(EventType.SecurityNegotiation, "Performing TLS handshake"); var ssl = new SslStream(stream, false, certificateValidator.Validate, UserCertificateSelectionCallback); ssl.AuthenticateAsClient(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false);