Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions source/Halibut.Tests.DotMemory/MemoryFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,26 @@ public void TcpClientsAreDisposedCorrectly()
.WriteTo.NUnitOutput()
.CreateLogger();

// Two separate HalibutRuntime instances are used to avoid an SChannel session cache
// collision on Windows (.NET Framework / SslProtocols.None). SChannel's TLS session
// cache is per-process and keyed on certificate + host. If a single runtime acts as
// both a TLS server (accepting inbound connections) and a TLS client (making outbound
// polling connections) using the same certificate, SChannel can incorrectly reuse a
// server-side session for a client-side handshake, causing SSPI/TLS failures.
//
// - server: pure TLS server — accepts inbound connections from listening tentacles.
// Uses Certificates.Octopus.
// - pollingServer: pure TLS client — only makes outbound polling connections to tentacles.
// Must use a DIFFERENT certificate (Certificates.TentacleListening) so
// that SChannel never sees the same cert in both server and client roles
// within this process.
HalibutRuntime? server = null;
HalibutRuntime? pollingServer = null;

try
{
server = RunServer(Certificates.Octopus, out var port);
pollingServer = RunPollingServer(Certificates.TentacleListening);

var expectedTcpClientCount = 1; //server listen = 1 tcpclient
//valid requests
Expand All @@ -75,13 +90,15 @@ public void TcpClientsAreDisposedCorrectly()
for (var i = 0; i < NumberOfClients; i++)
{
expectedTcpClientCount++; // each time the server polls, it keeps a tcpclient (as we dont have support to say StopPolling)
RunPollingClient(server, Certificates.TentaclePolling, Certificates.TentaclePollingPublicThumbprint).GetAwaiter().GetResult();
RunPollingClient(pollingServer, Certificates.TentaclePolling, Certificates.TentaclePollingPublicThumbprint).GetAwaiter().GetResult();
}

#if SUPPORTS_WEB_SOCKET_CLIENT
//setup polling websocket
AddSslCertToLocalStoreAndRegisterFor("0.0.0.0:8434");
for (var i = 0; i < NumberOfClients; i++)
{
RunWebSocketPollingClient(server, Certificates.TentaclePolling, Certificates.TentaclePollingPublicThumbprint, Certificates.OctopusPublicThumbprint).GetAwaiter().GetResult();
RunWebSocketPollingClient(pollingServer, Certificates.TentaclePolling, Certificates.TentaclePollingPublicThumbprint, Certificates.TentacleListeningPublicThumbprint).GetAwaiter().GetResult();
}
#endif

Expand All @@ -106,6 +123,7 @@ public void TcpClientsAreDisposedCorrectly()
finally
{
server?.DisposeAsync().GetAwaiter().GetResult();
pollingServer?.DisposeAsync().GetAwaiter().GetResult();
}
}

Expand Down Expand Up @@ -142,16 +160,24 @@ static HalibutRuntime RunServer(X509Certificate2 serverCertificate, out int port
.WithLogFactory(new TestContextLogFactory("client", LogLevel.Info))
.Build();

//set up listening
server.Trust(Certificates.TentacleListeningPublicThumbprint);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment here probs doesn't help

port = server.Listen();

//setup polling websocket
AddSslCertToLocalStoreAndRegisterFor("0.0.0.0:8434");

return server;
}

static HalibutRuntime RunPollingServer(X509Certificate2 serverCertificate)
{
var services = new DelegateServiceFactory();
services.Register<ICalculatorService, IAsyncCalculatorService>(() => new AsyncCalculatorService());

return new HalibutRuntimeBuilder()
.WithServerCertificate(serverCertificate)
.WithServiceFactory(services)
.WithLogFactory(new TestContextLogFactory("polling-server", LogLevel.Info))
.Build();
}

static async Task RunListeningClient(X509Certificate2 clientCertificate, int port, string remoteThumbprint, bool expectSuccess = true)
{
await using (var runtime = new HalibutRuntimeBuilder().WithServerCertificate(clientCertificate).Build())
Expand All @@ -169,9 +195,13 @@ static async Task RunPollingClient(HalibutRuntime server, X509Certificate2 clien
.Build())
{
runtime.Listen(new IPEndPoint(IPAddress.IPv6Any, 8433));
runtime.Trust(Certificates.OctopusPublicThumbprint);
// Trust the thumbprint of pollingServer's certificate (TentacleListening), which is
// the cert pollingServer presents when it dials in to establish the polling connection.
runtime.Trust(Certificates.TentacleListeningPublicThumbprint);

//setup polling
// The remote thumbprint here is this runtime's own certificate (TentaclePolling),
// which pollingServer verifies when it connects to port 8433.
var serverEndpoint = new ServiceEndPoint(new Uri("https://localhost:8433"), Certificates.TentaclePollingPublicThumbprint, runtime.TimeoutsAndLimits)
{
TcpClientConnectTimeout = TimeSpan.FromSeconds(5)
Expand Down
27 changes: 18 additions & 9 deletions source/Halibut.Tests/BadCertificatesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public async Task SucceedsWhenPollingServicePresentsWrongCertificate_ButServiceI
var clientTrustProvider = new DefaultTrustProvider();
var unauthorizedThumbprint = "";
var firstCall = true;
var serviceThumbprint = "";

var unauthorizedClientHasConnected = new TaskCompletionSource<bool>();
CancellationToken.Register(() => unauthorizedClientHasConnected.TrySetCanceled()); // backup to fail the test in case it never connects
Expand All @@ -43,7 +44,7 @@ public async Task SucceedsWhenPollingServicePresentsWrongCertificate_ButServiceI
{
if (firstCall)
{
clientTrustProvider.IsTrusted(CertAndThumbprint.TentaclePolling.Thumbprint).Should().BeFalse();
clientTrustProvider.IsTrusted(serviceThumbprint).Should().BeFalse();
firstCall = false;
}

Expand All @@ -53,6 +54,8 @@ public async Task SucceedsWhenPollingServicePresentsWrongCertificate_ButServiceI
})
.Build(CancellationToken))
{
serviceThumbprint = clientAndBuilder.ServiceThumbprint;

// Act
var clientCountingService = clientAndBuilder.CreateAsyncClient<ICountingService, IAsyncClientCountingService>();
await clientCountingService.IncrementAsync();
Expand All @@ -62,8 +65,8 @@ public async Task SucceedsWhenPollingServicePresentsWrongCertificate_ButServiceI
// Assert
countingService.CurrentValue().Should().Be(1);

clientTrustProvider.IsTrusted(CertAndThumbprint.TentaclePolling.Thumbprint).Should().BeTrue();
unauthorizedThumbprint.Should().Be(CertAndThumbprint.TentaclePolling.Thumbprint);
clientTrustProvider.IsTrusted(serviceThumbprint).Should().BeTrue();
unauthorizedThumbprint.Should().Be(serviceThumbprint);
}
}

Expand Down Expand Up @@ -93,6 +96,8 @@ public async Task FailWhenPollingServicePresentsWrongCertificate_ButServiceIsCon
})
.Build(CancellationToken))
{
var serviceThumbprint = clientAndBuilder.ServiceThumbprint;

using var cts = new CancellationTokenSource();
var clientCountingService = clientAndBuilder.CreateAsyncClient<ICountingService, IAsyncClientCountingServiceWithOptions>(point =>
{
Expand All @@ -105,7 +110,7 @@ public async Task FailWhenPollingServicePresentsWrongCertificate_ButServiceIsCon
// Interestingly the message exchange error is logged to a non polling looking URL, perhaps because it has not been identified?
Wait.UntilActionSucceeds(() => {
AllLogs(serviceLoggers).Select(l => l.FormattedMessage).ToArray()
.Should().Contain(s => s.Contains("and attempted a message exchange, but it presented a client certificate with the thumbprint '4098EC3A2FC2B92B97339D3831BA230CC1DD590F' which is not in the list of thumbprints that we trust"));
.Should().Contain(s => s.Contains($"and attempted a message exchange, but it presented a client certificate with the thumbprint '{serviceThumbprint}' which is not in the list of thumbprints that we trust"));
},
TimeSpan.FromSeconds(10),
Logger,
Expand All @@ -124,7 +129,7 @@ public async Task FailWhenPollingServicePresentsWrongCertificate_ButServiceIsCon
// Assert
countingService.CurrentValue().Should().Be(0, "With a bad certificate the request never should have been made");

unauthorizedThumbprint.Should().Be(CertAndThumbprint.TentaclePolling.Thumbprint);
unauthorizedThumbprint.Should().Be(serviceThumbprint);
}
}

Expand Down Expand Up @@ -195,8 +200,8 @@ public async Task FailWhenClientPresentsWrongCertificateToListeningService(Clien

serviceLoggers[serviceLoggers.Keys.First(x => x != nameof(MessageSerializer))].GetLogs().Should()
.Contain(log => log.FormattedMessage
.Contains("and attempted a message exchange, but it presented a client certificate with the thumbprint " +
"'76225C0717A16C1D0BA4A7FFA76519D286D8A248' which is not in the list of thumbprints that we trust"));
.Contains("and attempted a message exchange, but it presented a client certificate with the thumbprint")
&& log.FormattedMessage.Contains("which is not in the list of thumbprints that we trust"));
}
}

Expand Down Expand Up @@ -254,11 +259,13 @@ public async Task FailWhenListeningServicePresentsWrongCertificate(ClientAndServ
.WithCountingService(countingService)
.Build(CancellationToken))
{
var serviceThumbprint = clientAndBuilder.ServiceThumbprint;

var clientCountingService = clientAndBuilder.CreateAsyncClient<ICountingService, IAsyncClientCountingService>();
(await AssertionExtensions.Should(() => clientCountingService.IncrementAsync()).ThrowAsync<HalibutClientException>())
.And.Message.Should().Contain("" +
"We expected the server to present a certificate with the thumbprint 'EC32122053C6BFF582F8246F5697633D06F0F97F'. " +
"Instead, it presented a certificate with a thumbprint of '36F35047CE8B000CF4C671819A2DD1AFCDE3403D'");
$"Instead, it presented a certificate with a thumbprint of '{serviceThumbprint}'");
countingService.CurrentValue().Should().Be(0, "With a bad certificate the request never should have been made");
}
}
Expand All @@ -275,6 +282,8 @@ public async Task FailWhenPollingServicePresentsWrongCertificate(ClientAndServic
.RecordingClientLogs(out var serviceLoggers)
.Build(CancellationToken))
{
var serviceThumbprint = clientAndBuilder.ServiceThumbprint;

using var cts = new CancellationTokenSource();
var clientCountingService = clientAndBuilder.CreateAsyncClient<ICountingService, IAsyncClientCountingServiceWithOptions>(point =>
{
Expand All @@ -285,7 +294,7 @@ public async Task FailWhenPollingServicePresentsWrongCertificate(ClientAndServic

// Interestingly the message exchange error is logged to a non polling looking URL, perhaps because it has not been identified?
Wait.UntilActionSucceeds(() => { AllLogs(serviceLoggers).Select(l => l.FormattedMessage).ToArray()
.Should().Contain(s => s.Contains("and attempted a message exchange, but it presented a client certificate with the thumbprint '4098EC3A2FC2B92B97339D3831BA230CC1DD590F' which is not in the list of thumbprints that we trust")); },
.Should().Contain(s => s.Contains($"and attempted a message exchange, but it presented a client certificate with the thumbprint '{serviceThumbprint}' which is not in the list of thumbprints that we trust")); },
TimeSpan.FromSeconds(10),
Logger,
CancellationToken);
Expand Down
42 changes: 31 additions & 11 deletions source/Halibut.Tests/ClientServerLifecycleTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,33 @@ namespace Halibut.Tests
{
public class ClientServerLifecycleTests : BaseTest
{
DisposableCollection disposables = null!;
ICertAndThumbprint serverCert = null!;
ICertAndThumbprint listenerCert = null!;
ICertAndThumbprint pollerCert = null!;

[SetUp]
public void SetUpCerts()
{
disposables = new DisposableCollection();
serverCert = TestCertificates.CertFor(CertAndThumbprint.Octopus, disposedBy: disposables);
listenerCert = TestCertificates.CertFor(CertAndThumbprint.TentacleListening, disposedBy: disposables);
pollerCert = TestCertificates.CertFor(CertAndThumbprint.TentaclePolling, disposedBy: disposables);
}

[TearDown]
public void TearDownCerts()
{
disposables?.Dispose();
}

[Test]
public async Task ListeningConfiguration()
{
await using var server = RunServer(out var serverPort);

await using var runtime = CreateRuntimeForListener();
var client = CreateClient(runtime, serverPort);
var client = CreateClient(runtime, serverPort, serverCert);
var result = await client.AddAsync(2, 2);
result.Should().Be(4);
}
Expand All @@ -56,7 +76,7 @@ public async Task ListeningThenPollingConfiguration()
HalibutRuntime CreateRuntimeForListener()
{
var runtime = new HalibutRuntimeBuilder()
.WithServerCertificate(Certificates.TentacleListening)
.WithServerCertificate(listenerCert.Certificate2)
.WithLogFactory(new TestLogFactory(HalibutLog))
.Build();
return runtime;
Expand All @@ -65,15 +85,15 @@ HalibutRuntime CreateRuntimeForListener()
HalibutRuntime CreateRuntimeForPoller(HalibutRuntime serverRuntime, out IAsyncClientCalculatorService client)
{
var runtime = new HalibutRuntimeBuilder()
.WithServerCertificate(Certificates.TentaclePolling)
.WithServerCertificate(pollerCert.Certificate2)
.WithLogFactory(new TestLogFactory(HalibutLog))
.Build();
var port = runtime.Listen();
runtime.Trust(Certificates.OctopusPublicThumbprint);
runtime.Trust(serverCert.Thumbprint);

var pollEndpoint = new ServiceEndPoint(
baseUri: new Uri($"https://localhost:{port}/"),
remoteThumbprint: Certificates.TentaclePollingPublicThumbprint,
remoteThumbprint: pollerCert.Thumbprint,
halibutTimeoutsAndLimits: runtime.TimeoutsAndLimits
)
{
Expand All @@ -83,19 +103,19 @@ HalibutRuntime CreateRuntimeForPoller(HalibutRuntime serverRuntime, out IAsyncCl
serverRuntime.Poll(pollingUri, pollEndpoint, CancellationToken);
var clientEndpoint = new ServiceEndPoint(
baseUri: pollingUri,
remoteThumbprint: Certificates.OctopusPublicThumbprint,
remoteThumbprint: serverCert.Thumbprint,
halibutTimeoutsAndLimits: runtime.TimeoutsAndLimits
);
client = runtime.CreateAsyncClient<ICalculatorService, IAsyncClientCalculatorService>(clientEndpoint);

return runtime;
}

static IAsyncClientCalculatorService CreateClient(HalibutRuntime runtime, int port)
static IAsyncClientCalculatorService CreateClient(HalibutRuntime runtime, int port, ICertAndThumbprint serverCertAndThumbprint)
{
var endpoint = new ServiceEndPoint(
baseUri: $"https://localhost:{port}",
remoteThumbprint: Certificates.OctopusPublicThumbprint,
remoteThumbprint: serverCertAndThumbprint.Thumbprint,
halibutTimeoutsAndLimits: runtime.TimeoutsAndLimits
);
var client = runtime
Expand All @@ -115,13 +135,13 @@ HalibutRuntime RunServer(out int port)
var services = CreateServiceFactory();

var runtime = new HalibutRuntimeBuilder()
.WithServerCertificate(Certificates.Octopus)
.WithServerCertificate(serverCert.Certificate2)
.WithServiceFactory(services)
.WithLogFactory(new TestLogFactory(HalibutLog))
.Build();

runtime.Trust(Certificates.TentacleListeningPublicThumbprint);
runtime.Trust(Certificates.TentaclePollingPublicThumbprint);
runtime.Trust(listenerCert.Thumbprint);
runtime.Trust(pollerCert.Thumbprint);
port = runtime.Listen();

return runtime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,33 @@ public class HalibutTestBinaryPath
public string BinPath(string version)
{
var onDiskVersion = version.Replace(".", "_");
var projectName = $"Halibut.TestUtils.CompatBinary.v{onDiskVersion}";
return ResolveProjectBinPath(projectName);
}

public string SchannelProbeBinPath()
{
return ResolveProjectBinPath("Halibut.TestUtils.CompatBinary.SchannelProbe");
}

string ResolveProjectBinPath(string projectName)
{
var assemblyDir = new DirectoryInfo(Path.GetDirectoryName(typeof(HalibutTestBinaryRunner).Assembly.Location)!);
var upAt = assemblyDir.Parent!.Parent!.Parent!.Parent!;
var projectName = $"Halibut.TestUtils.CompatBinary.v{onDiskVersion}";
var executable = Path.Combine(upAt.FullName, projectName, assemblyDir.Parent.Parent.Name, assemblyDir.Parent.Name, assemblyDir.Name, projectName);
executable = AddExeForWindows(executable);

if (!File.Exists(executable))
{
throw new Exception("Could not executable at path:\n" +
throw new Exception("Could not find executable at path:\n" +
executable + "\n" +
$"Did you forget to update the csproj to depend on {projectName}\n" +
"If testing a previously untested version of Halibut a new project may be required.");
}

return executable;
}

string AddExeForWindows(string path)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return path + ".exe";
Expand Down
Loading