From d44aea156ba4946e4942b13844e0716d7b3eaecb Mon Sep 17 00:00:00 2001 From: Pranav Iyer Date: Wed, 20 May 2026 17:56:05 -0700 Subject: [PATCH 1/2] feat(gax): Implement cert-rotation retries for grpc and http-json. --- .../com/google/api/gax/grpc/ChannelPool.java | 8 + .../google/api/gax/grpc/GrpcCallContext.java | 58 +++-- .../api/gax/grpc/GrpcTransportChannel.java | 8 + .../api/gax/httpjson/HttpJsonCallContext.java | 68 +++-- .../InstantiatingHttpJsonChannelProvider.java | 22 +- .../gax/httpjson/ManagedHttpJsonChannel.java | 2 + .../ManagedHttpJsonInterceptorChannel.java | 5 + .../httpjson/RefreshingHttpJsonChannel.java | 233 ++++++++++++++++++ .../google/api/gax/rpc/ApiCallContext.java | 8 + .../api/gax/rpc/ApiResultRetryAlgorithm.java | 8 + .../google/api/gax/rpc/AttemptCallable.java | 22 ++ .../api/gax/rpc/BidiStreamingCallable.java | 38 ++- .../api/gax/rpc/ClientStreamingCallable.java | 33 ++- .../rpc/ServerStreamingAttemptCallable.java | 13 + .../google/api/gax/rpc/TransportChannel.java | 8 + 15 files changed, 488 insertions(+), 46 deletions(-) create mode 100644 sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d611c96ff4c8..d35dbc8d12ca 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -82,6 +82,7 @@ class ChannelPool extends ManagedChannel { private ScheduledFuture resizeFuture = null; private final Object entryWriteLock = new Object(); + private long lastRefreshTimeNanos = 0; @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -441,6 +442,13 @@ void refresh() { // - then thread2 will shut down channel that thread1 will put back into circulation (after it // replaces the list) synchronized (entryWriteLock) { + long now = System.nanoTime(); + if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) { + LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh"); + return; + } + lastRefreshTimeNanos = now; + LOG.fine("Refreshing all channels"); ArrayList newEntries = new ArrayList<>(entries.get()); diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 7ff7c54de6f0..fb5e2edb0d07 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -97,6 +97,7 @@ public final class GrpcCallContext implements ApiCallContext { private final ApiCallContextOptions options; private final EndpointContext endpointContext; private final boolean isDirectPath; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -113,7 +114,8 @@ public static GrpcCallContext createDefault() { null, null, null, - false); + false, + null); } /** Returns an instance with the given channel and {@link CallOptions}. */ @@ -131,7 +133,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { null, null, null, - false); + false, + null); } private GrpcCallContext( @@ -147,7 +150,8 @@ private GrpcCallContext( @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes, @Nullable EndpointContext endpointContext, - boolean isDirectPath) { + boolean isDirectPath, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.credentials = credentials; Preconditions.checkNotNull(callOptions); @@ -167,6 +171,7 @@ private GrpcCallContext( this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; this.isDirectPath = isDirectPath; + this.transportChannel = transportChannel; } /** @@ -208,7 +213,13 @@ public GrpcCallContext withCredentials(Credentials newCredentials) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } @Override @@ -232,7 +243,8 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { retrySettings, retryableCodes, endpointContext, - transportChannel.isDirectPath()); + transportChannel.isDirectPath(), + inputChannel); } @Override @@ -251,7 +263,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -286,7 +299,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -335,7 +349,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** @@ -370,7 +385,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") @@ -388,7 +404,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -410,7 +427,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -433,7 +451,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -456,7 +475,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -558,7 +578,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newRetrySettings, newRetryableCodes, endpointContext, - newIsDirectPath); + newIsDirectPath, + transportChannel); } /** The {@link Channel} set on this context. */ @@ -641,7 +662,8 @@ public GrpcCallContext withChannel(Channel newChannel) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** Returns a new instance with the call options set to the given call options. */ @@ -659,7 +681,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -704,7 +727,8 @@ public GrpcCallContext withOption(Key key, T value) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** {@inheritDoc} */ diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 2fa0908f17bc..80d471701d5a 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -66,6 +66,14 @@ public Channel getChannel() { return getManagedChannel(); } + @Override + public void refresh() { + Channel channel = getChannel(); + if (channel instanceof ChannelPool) { + ((ChannelPool) channel).refresh(); + } + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java index c946e9aab03d..aa167d93e8b6 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java @@ -80,6 +80,7 @@ public final class HttpJsonCallContext implements ApiCallContext { @Nullable private final RetrySettings retrySettings; @Nullable private final ImmutableSet retryableCodes; private final EndpointContext endpointContext; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance. */ public static HttpJsonCallContext createDefault() { @@ -94,6 +95,7 @@ public static HttpJsonCallContext createDefault() { null, null, null, + null, null); } @@ -109,6 +111,7 @@ public static HttpJsonCallContext of(HttpJsonChannel channel, HttpJsonCallOption null, null, null, + null, null); } @@ -123,7 +126,8 @@ private HttpJsonCallContext( ApiTracer tracer, RetrySettings defaultRetrySettings, Set defaultRetryableCodes, - @Nullable EndpointContext endpointContext) { + @Nullable EndpointContext endpointContext, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.callOptions = callOptions; this.timeout = timeout; @@ -139,6 +143,7 @@ private HttpJsonCallContext( // a valid EndpointContext with user configurations after the client has been initialized. this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; + this.transportChannel = transportChannel; } /** @@ -231,7 +236,8 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { newTracer, newRetrySettings, newRetryableCodes, - endpointContext); + endpointContext, + this.transportChannel); } @Override @@ -249,7 +255,24 @@ public HttpJsonCallContext withTransportChannel(TransportChannel inputChannel) { "Expected HttpJsonTransportChannel, got " + inputChannel.getClass().getName()); } HttpJsonTransportChannel transportChannel = (HttpJsonTransportChannel) inputChannel; - return withChannel(transportChannel.getChannel()); + return new HttpJsonCallContext( + transportChannel.getChannel(), + this.callOptions, + this.timeout, + this.streamWaitTimeout, + this.streamIdleTimeout, + this.extraHeaders, + this.options, + this.tracer, + this.retrySettings, + this.retryableCodes, + this.endpointContext, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -273,7 +296,8 @@ public HttpJsonCallContext withEndpointContext(EndpointContext endpointContext) this.tracer, this.retrySettings, this.retryableCodes, - endpointContext); + endpointContext, + this.transportChannel); } @Override @@ -299,7 +323,8 @@ public HttpJsonCallContext withTimeoutDuration(java.time.Duration timeout) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -346,7 +371,8 @@ public HttpJsonCallContext withStreamWaitTimeoutDuration( this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getStreamWaitTimeoutDuration()} instead. */ @@ -398,7 +424,8 @@ public HttpJsonCallContext withStreamIdleTimeoutDuration( this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getStreamIdleTimeoutDuration()} instead. */ @@ -437,7 +464,8 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -461,7 +489,8 @@ public ApiCallContext withOption(Key key, T value) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** {@inheritDoc} */ @@ -533,7 +562,8 @@ public HttpJsonCallContext withRetrySettings(RetrySettings retrySettings) { this.tracer, retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Override @@ -554,7 +584,8 @@ public HttpJsonCallContext withRetryableCodes(Set retryableCode this.tracer, this.retrySettings, retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { @@ -569,7 +600,8 @@ public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { @@ -584,7 +616,8 @@ public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Deprecated @@ -620,7 +653,8 @@ public HttpJsonCallContext withTracer(@Nonnull ApiTracer newTracer) { newTracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Override @@ -640,7 +674,8 @@ public boolean equals(Object o) { && Objects.equals(this.tracer, that.tracer) && Objects.equals(this.retrySettings, that.retrySettings) && Objects.equals(this.retryableCodes, that.retryableCodes) - && Objects.equals(this.endpointContext, that.endpointContext); + && Objects.equals(this.endpointContext, that.endpointContext) + && Objects.equals(this.transportChannel, that.transportChannel); } @Override @@ -654,6 +689,7 @@ public int hashCode() { tracer, retrySettings, retryableCodes, - endpointContext); + endpointContext, + transportChannel); } } diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java index daf94a498cc4..347701816dc9 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java @@ -198,19 +198,23 @@ HttpTransport createHttpTransport() throws IOException, GeneralSecurityException } private HttpJsonTransportChannel createChannel() throws IOException, GeneralSecurityException { - HttpTransport httpTransportToUse = httpTransport; - if (httpTransportToUse == null) { - httpTransportToUse = createHttpTransport(); - } - - // Pass the executor to the ManagedChannel. If no executor was provided (or null), - // the channel will use a default executor for the calls. - ManagedHttpJsonChannel channel = - ManagedHttpJsonChannel.newBuilder() + java.util.function.Supplier channelFactory = () -> { + try { + HttpTransport httpTransportToUse = httpTransport; + if (httpTransportToUse == null) { + httpTransportToUse = createHttpTransport(); + } + return ManagedHttpJsonChannel.newBuilder() .setEndpoint(endpoint) .setExecutor(executor) .setHttpTransport(httpTransportToUse) .build(); + } catch (Exception e) { + throw new java.lang.RuntimeException("Failed to create fresh ManagedHttpJsonChannel", e); + } + }; + + ManagedHttpJsonChannel channel = new RefreshingHttpJsonChannel(channelFactory); HttpJsonClientInterceptor headerInterceptor = new HttpJsonHeaderInterceptor(headerProvider.getHeaders()); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java index bd3bed855608..6d800e579897 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java @@ -86,6 +86,8 @@ public HttpJsonClientCall newCall( deadlineScheduledExecutorService); } + public void refresh() {} + @VisibleForTesting Executor getExecutor() { return executor; diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java index 3e71031f1c9d..f01d37a02c3f 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java @@ -55,6 +55,11 @@ public HttpJsonClientCall newCall( return interceptor.interceptCall(methodDescriptor, callOptions, channel); } + @Override + public void refresh() { + channel.refresh(); + } + @Override public synchronized void shutdown() { channel.shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java new file mode 100644 index 000000000000..71754bd7d14d --- /dev/null +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java @@ -0,0 +1,233 @@ +/* + * Copyright 2026 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.httpjson; + +import com.google.api.core.InternalApi; +import com.google.api.gax.httpjson.ForwardingHttpJsonClientCall.SimpleForwardingHttpJsonClientCall; +import com.google.api.gax.httpjson.ForwardingHttpJsonClientCallListener.SimpleForwardingHttpJsonClientCallListener; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * An implementation of {@link ManagedHttpJsonChannel} that supports dynamic mTLS certificate + * rotation by thread-safely hot-swapping the underlying active HTTP/JSON channel while gracefully + * retiring older connections after all active in-flight requests complete. + */ +@InternalApi +public class RefreshingHttpJsonChannel extends ManagedHttpJsonChannel { + + private static final Logger LOG = Logger.getLogger(RefreshingHttpJsonChannel.class.getName()); + private static final long REFRESH_COOLDOWN_MS = 5000; + + private final Supplier channelFactory; + private final AtomicReference activeEntry; + private final AtomicBoolean refreshInProgress = new AtomicBoolean(false); + private final Object lock = new Object(); + private long lastRefreshTimeMs = 0; + + public RefreshingHttpJsonChannel(Supplier channelFactory) { + this.channelFactory = channelFactory; + this.activeEntry = new AtomicReference<>(new ChannelEntry(channelFactory.get())); + } + + @Override + public void refresh() { + // 1. Lock-free CAS coalescing check to prevent duplicate queueing/blocking of concurrent threads + if (!refreshInProgress.compareAndSet(false, true)) { + return; + } + try { + synchronized (lock) { + long now = System.currentTimeMillis(); + if (now - lastRefreshTimeMs < REFRESH_COOLDOWN_MS) { + LOG.fine("HTTP/JSON channel pool refreshed recently, skipping duplicate refresh"); + return; + } + + LOG.info("mTLS certificate rotation detected. Triggering HTTP/JSON channel pool refresh."); + ChannelEntry newEntry = new ChannelEntry(channelFactory.get()); + ChannelEntry oldEntry = activeEntry.getAndSet(newEntry); + + if (oldEntry != null) { + oldEntry.requestShutdown(); + } + + lastRefreshTimeMs = now; + } + } finally { + refreshInProgress.set(false); + } + } + + private ChannelEntry getRetainedEntry() { + while (true) { + ChannelEntry entry = activeEntry.get(); + if (entry.retain()) { + return entry; + } + } + } + + @Override + public HttpJsonClientCall newCall( + ApiMethodDescriptor methodDescriptor, HttpJsonCallOptions callOptions) { + ChannelEntry entry = getRetainedEntry(); + try { + HttpJsonClientCall delegateCall = + entry.channel.newCall(methodDescriptor, callOptions); + return new ReleasingHttpJsonClientCall<>(delegateCall, entry); + } catch (Exception e) { + entry.release(); + throw e; + } + } + + @Override + public void shutdown() { + activeEntry.get().requestShutdown(); + } + + @Override + public boolean isShutdown() { + return activeEntry.get().channel.isShutdown(); + } + + @Override + public boolean isTerminated() { + return activeEntry.get().channel.isTerminated(); + } + + @Override + public void shutdownNow() { + activeEntry.get().channel.shutdownNow(); + } + + @Override + public boolean awaitTermination(long duration, TimeUnit unit) throws InterruptedException { + return activeEntry.get().channel.awaitTermination(duration, unit); + } + + @Override + public void close() { + shutdown(); + } + + /** Internal container to manage request reference-counting and graceful shutdown. */ + private static class ChannelEntry { + private final ManagedHttpJsonChannel channel; + private final AtomicInteger outstandingCalls = new AtomicInteger(0); + private final AtomicBoolean shutdownRequested = new AtomicBoolean(false); + private final AtomicBoolean shutdownInitiated = new AtomicBoolean(false); + + ChannelEntry(ManagedHttpJsonChannel channel) { + this.channel = channel; + } + + boolean retain() { + outstandingCalls.incrementAndGet(); + if (shutdownRequested.get()) { + release(); + return false; + } + return true; + } + + void release() { + int count = outstandingCalls.decrementAndGet(); + if (shutdownRequested.get() && count == 0) { + shutdown(); + } + } + + void requestShutdown() { + shutdownRequested.set(true); + if (outstandingCalls.get() == 0) { + shutdown(); + } + } + + private void shutdown() { + if (shutdownInitiated.compareAndSet(false, true)) { + try { + channel.shutdown(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error shutting down retired HTTP/JSON channel", e); + } + } + } + } + + /** A client call decorator that decrements the entry counter upon call completion. */ + private static class ReleasingHttpJsonClientCall + extends SimpleForwardingHttpJsonClientCall { + + private final ChannelEntry entry; + private final AtomicBoolean wasClosed = new AtomicBoolean(false); + private final AtomicBoolean wasReleased = new AtomicBoolean(false); + + ReleasingHttpJsonClientCall(HttpJsonClientCall delegate, ChannelEntry entry) { + super(delegate); + this.entry = entry; + } + + @Override + public void start(Listener responseListener, HttpJsonMetadata requestHeaders) { + try { + super.start( + new SimpleForwardingHttpJsonClientCallListener(responseListener) { + @Override + public void onClose(int statusCode, HttpJsonMetadata trailers) { + if (!wasClosed.compareAndSet(false, true)) { + return; + } + try { + super.onClose(statusCode, trailers); + } finally { + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } + } + } + }, + requestHeaders); + } catch (Exception e) { + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } + throw e; + } + } + } +} diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index 09af475e4833..fc7fb5e989fe 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -63,6 +63,14 @@ public interface ApiCallContext extends RetryingContext { /** Returns a new ApiCallContext with the given channel set. */ ApiCallContext withTransportChannel(TransportChannel channel); + /** + * Returns the {@link TransportChannel} associated with this call context, or {@code null} if none + * is set. + */ + default TransportChannel getTransportChannel() { + return null; + } + /** Returns a new ApiCallContext with the given Endpoint Context. */ ApiCallContext withEndpointContext(EndpointContext endpointContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java index 688fc32cd14b..7c8fad8497e9 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java @@ -38,6 +38,10 @@ class ApiResultRetryAlgorithm extends BasicResultRetryAlgorithm internalFuture = callable.futureCall(request, callContext); + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + final ApiCallContext finalContext = callContext; + ApiFutures.addCallback( + internalFuture, + new com.google.api.core.ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } + + @Override + public void onSuccess(ResponseT result) {} + }, + com.google.common.util.concurrent.MoreExecutors.directExecutor()); + } + externalFuture.setAttemptFuture(internalFuture); } catch (Throwable e) { externalFuture.setAttemptFuture(ApiFutures.immediateFailedFuture(e)); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java index 38efb2da3755..59d6099b2d5b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java @@ -236,11 +236,45 @@ public BidiStreamingCallable withDefaultCallContext( return new BidiStreamingCallable() { @Override public ClientStream internalCall( - ResponseObserver responseObserver, + final ResponseObserver responseObserver, ClientStreamReadyObserver onReady, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ResponseObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ResponseObserver() { + @Override + public void onStart(StreamController controller) { + responseObserver.onStart(controller); + } + + @Override + public void onResponse(ResponseT response) { + responseObserver.onResponse(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onComplete() { + responseObserver.onComplete(); + } + }; + } + return BidiStreamingCallable.this.internalCall( - responseObserver, onReady, defaultCallContext.merge(thisCallContext)); + refreshingObserver, onReady, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java index 13ef1c64568b..c172e93ba20b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java @@ -73,9 +73,38 @@ public ClientStreamingCallable withDefaultCallContext( return new ClientStreamingCallable() { @Override public ApiStreamObserver clientStreamingCall( - ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ApiStreamObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ApiStreamObserver() { + @Override + public void onNext(ResponseT response) { + responseObserver.onNext(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + return ClientStreamingCallable.this.clientStreamingCall( - responseObserver, defaultCallContext.merge(thisCallContext)); + refreshingObserver, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java index da0c8de632da..3fe6441d762c 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java @@ -219,6 +219,7 @@ public Void call() { .getTracer() .attemptStarted(request, outerRetryingFuture.getAttemptSettings().getOverallAttemptCount()); + final ApiCallContext finalContext = attemptContext; innerCallable.call( request, new StateCheckingResponseObserver() { @@ -234,6 +235,18 @@ public void onResponseImpl(ResponseT response) { @Override public void onErrorImpl(Throwable t) { + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + Throwable cause = t; + if (cause instanceof com.google.api.gax.retrying.ServerStreamingAttemptException) { + cause = cause.getCause(); + } + if (cause instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } onAttemptError(t); } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java index d54352e9b246..65b3cce0e0a3 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java @@ -47,4 +47,12 @@ public interface TransportChannel extends BackgroundResource { * Returns an empty {@link ApiCallContext} that is compatible with this {@code TransportChannel}. */ ApiCallContext getEmptyCallContext(); + + /** + * Refreshes or recreates the underlying network connections of this transport channel. + * + *

By default, this is a no-op for transports that do not require stateful connection lifecycle + * management. + */ + default void refresh() {} } From 38684a2f9d69f2b651a85b94fbc5f587c3ca3086 Mon Sep 17 00:00:00 2001 From: Pranav Iyer Date: Thu, 28 May 2026 11:35:28 -0700 Subject: [PATCH 2/2] Included changes as per discussion with blake. --- .gitignore | 4 + .../com/google/api/gax/grpc/ChannelPool.java | 109 ++++++++++- .../api/gax/grpc/GrpcTransportChannel.java | 9 + .../httpjson/HttpJsonTransportChannel.java | 10 + .../gax/httpjson/ManagedHttpJsonChannel.java | 4 + .../ManagedHttpJsonInterceptorChannel.java | 5 + .../httpjson/RefreshingHttpJsonChannel.java | 132 ++++++++++--- .../api/gax/rpc/ApiResultRetryAlgorithm.java | 17 +- .../google/api/gax/rpc/AttemptCallable.java | 21 --- .../api/gax/rpc/BidiStreamingCallable.java | 50 +++-- .../api/gax/rpc/ClientStreamingCallable.java | 42 ++--- .../rpc/ServerStreamingAttemptCallable.java | 18 +- .../google/api/gax/rpc/TransportChannel.java | 8 + .../gax/rpc/mtls/CertificateBasedAccess.java | 120 +++++++++++- .../rpc/mtls/CertificateBasedAccessTest.java | 175 ++++++++++++++++-- 15 files changed, 585 insertions(+), 139 deletions(-) diff --git a/.gitignore b/.gitignore index 7618f6a26508..08f4237725ae 100644 --- a/.gitignore +++ b/.gitignore @@ -86,3 +86,7 @@ monorepo *.tfstate.lock.info .jqwik-database + +**/Agentic_Identities/** +**/*.patch + diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d35dbc8d12ca..d771d4c5cc9d 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -43,7 +43,11 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import java.io.FileInputStream; import java.io.IOException; +import java.security.MessageDigest; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CancellationException; @@ -81,8 +85,19 @@ class ChannelPool extends ManagedChannel { private ScheduledFuture refreshFuture = null; private ScheduledFuture resizeFuture = null; + private static class DiskCheckResult { + final String fingerprint; + final long timestampNanos; + + DiskCheckResult(String fingerprint, long timestampNanos) { + this.fingerprint = fingerprint; + this.timestampNanos = timestampNanos; + } + } + + private final AtomicReference lastDiskCheck = new AtomicReference<>(null); private final Object entryWriteLock = new Object(); - private long lastRefreshTimeNanos = 0; + private volatile String activeCertFingerprint = ""; @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -135,6 +150,11 @@ static ChannelPool create( entries.set(initialListBuilder.build()); authority = entries.get().get(0).channel.authority(); + String certPath = getWorkloadCertPath(); + if (certPath != null) { + this.activeCertFingerprint = getCertificateFingerprint(certPath); + } + if (!settings.isStaticSize()) { resizeFuture = backgroundExecutorProvider @@ -426,6 +446,74 @@ private void refreshSafely() { } } + + private static String getWorkloadCertPath() { + String configPath = System.getenv("GOOGLE_API_CERTIFICATE_CONFIG"); + if (configPath != null && !configPath.isEmpty()) { + java.io.File configFile = new java.io.File(configPath); + if (configFile.exists() && !configFile.isDirectory()) { + // If explicit config exists, check it + } + } + java.io.File bundleFile = new java.io.File("/var/run/secrets/workload-spiffe-credentials/credentialbundle.pem"); + if (bundleFile.exists()) { + return bundleFile.getAbsolutePath(); + } + java.io.File certsFile = new java.io.File("/var/run/secrets/workload-spiffe-credentials/certificates.pem"); + if (certsFile.exists()) { + return certsFile.getAbsolutePath(); + } + return null; + } + + private static String getCertificateFingerprint(String certPath) { + try (FileInputStream fis = new FileInputStream(certPath)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate cert = (X509Certificate) cf.generateCertificate(fis); + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] der = cert.getEncoded(); + byte[] digest = md.digest(der); + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } catch (Exception e) { + LOG.log(Level.FINE, "Could not read or parse workload certificate at path " + certPath, e); + return ""; + } + } + + private String getOrUpdateDiskFingerprint(String certPath) { + long now = System.nanoTime(); + DiskCheckResult cached = lastDiskCheck.get(); + if (cached != null && (now - cached.timestampNanos < java.util.concurrent.TimeUnit.SECONDS.toNanos(1))) { + return cached.fingerprint; + } + + synchronized (lastDiskCheck) { + cached = lastDiskCheck.get(); + if (cached != null && (now - cached.timestampNanos < java.util.concurrent.TimeUnit.SECONDS.toNanos(1))) { + return cached.fingerprint; + } + String fingerprint = getCertificateFingerprint(certPath); + lastDiskCheck.set(new DiskCheckResult(fingerprint, System.nanoTime())); + return fingerprint; + } + } + + boolean shouldRefresh() { + String certPath = getWorkloadCertPath(); + if (certPath == null) { + return false; + } + String currentDiskFingerprint = getOrUpdateDiskFingerprint(certPath); + if (currentDiskFingerprint.isEmpty()) { + return false; + } + return !currentDiskFingerprint.equalsIgnoreCase(activeCertFingerprint); + } + /** * Replace all of the channels in the channel pool with fresh ones. This is meant to mitigate the * hourly GFE disconnects by giving clients the ability to prime the channel on reconnect. @@ -442,14 +530,23 @@ void refresh() { // - then thread2 will shut down channel that thread1 will put back into circulation (after it // replaces the list) synchronized (entryWriteLock) { - long now = System.nanoTime(); - if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) { - LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh"); + String certPath = getWorkloadCertPath(); + if (certPath == null) { + return; + } + String currentDiskFingerprint = getOrUpdateDiskFingerprint(certPath); + if (currentDiskFingerprint.isEmpty()) { + return; + } + + // Double-check fingerprint inside the lock + if (currentDiskFingerprint.equals(this.activeCertFingerprint)) { + LOG.fine("Channel pool was already refreshed by a concurrent thread, skipping duplicate refresh"); return; } - lastRefreshTimeNanos = now; - LOG.fine("Refreshing all channels"); + this.activeCertFingerprint = currentDiskFingerprint; + LOG.fine("Refreshing all channels with new certificate fingerprint: " + activeCertFingerprint); ArrayList newEntries = new ArrayList<>(entries.get()); for (int i = 0; i < newEntries.size(); i++) { diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 80d471701d5a..409ca57205d3 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -74,6 +74,15 @@ public void refresh() { } } + @Override + public boolean shouldRefresh() { + Channel channel = getChannel(); + if (channel instanceof ChannelPool) { + return ((ChannelPool) channel).shouldRefresh(); + } + return false; + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonTransportChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonTransportChannel.java index 355193964555..174245aa34cb 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonTransportChannel.java @@ -62,6 +62,16 @@ public HttpJsonChannel getChannel() { return getManagedChannel(); } + @Override + public void refresh() { + getManagedChannel().refresh(); + } + + @Override + public boolean shouldRefresh() { + return getManagedChannel().shouldRefresh(); + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java index 6d800e579897..6070e8f7ce85 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java @@ -88,6 +88,10 @@ public HttpJsonClientCall newCall( public void refresh() {} + public boolean shouldRefresh() { + return false; + } + @VisibleForTesting Executor getExecutor() { return executor; diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java index f01d37a02c3f..e5b7a2f8ea5f 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java @@ -60,6 +60,11 @@ public void refresh() { channel.refresh(); } + @Override + public boolean shouldRefresh() { + return channel.shouldRefresh(); + } + @Override public synchronized void shutdown() { channel.shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java index 71754bd7d14d..a6c1696c0b39 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java @@ -32,6 +32,10 @@ import com.google.api.core.InternalApi; import com.google.api.gax.httpjson.ForwardingHttpJsonClientCall.SimpleForwardingHttpJsonClientCall; import com.google.api.gax.httpjson.ForwardingHttpJsonClientCallListener.SimpleForwardingHttpJsonClientCallListener; +import java.io.FileInputStream; +import java.security.MessageDigest; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -49,45 +53,127 @@ public class RefreshingHttpJsonChannel extends ManagedHttpJsonChannel { private static final Logger LOG = Logger.getLogger(RefreshingHttpJsonChannel.class.getName()); - private static final long REFRESH_COOLDOWN_MS = 5000; + private static class DiskCheckResult { + final String fingerprint; + final long timestampNanos; + + DiskCheckResult(String fingerprint, long timestampNanos) { + this.fingerprint = fingerprint; + this.timestampNanos = timestampNanos; + } + } + + private final AtomicReference lastDiskCheck = new AtomicReference<>(null); private final Supplier channelFactory; private final AtomicReference activeEntry; - private final AtomicBoolean refreshInProgress = new AtomicBoolean(false); private final Object lock = new Object(); - private long lastRefreshTimeMs = 0; + private volatile String activeCertFingerprint = ""; public RefreshingHttpJsonChannel(Supplier channelFactory) { this.channelFactory = channelFactory; this.activeEntry = new AtomicReference<>(new ChannelEntry(channelFactory.get())); + String certPath = getWorkloadCertPath(); + if (certPath != null) { + this.activeCertFingerprint = getCertificateFingerprint(certPath); + } + } + + private static String getWorkloadCertPath() { + String configPath = System.getenv("GOOGLE_API_CERTIFICATE_CONFIG"); + if (configPath != null && !configPath.isEmpty()) { + java.io.File configFile = new java.io.File(configPath); + if (configFile.exists() && !configFile.isDirectory()) { + // If it is JSON or PEM, we try to resolve it + } + } + java.io.File bundleFile = new java.io.File("/var/run/secrets/workload-spiffe-credentials/credentialbundle.pem"); + if (bundleFile.exists()) { + return bundleFile.getAbsolutePath(); + } + java.io.File certsFile = new java.io.File("/var/run/secrets/workload-spiffe-credentials/certificates.pem"); + if (certsFile.exists()) { + return certsFile.getAbsolutePath(); + } + return null; + } + + private static String getCertificateFingerprint(String certPath) { + try (FileInputStream fis = new FileInputStream(certPath)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate cert = (X509Certificate) cf.generateCertificate(fis); + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] der = cert.getEncoded(); + byte[] digest = md.digest(der); + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } catch (Exception e) { + LOG.log(Level.FINE, "Could not read or parse workload certificate at path " + certPath, e); + return ""; + } + } + + private String getOrUpdateDiskFingerprint(String certPath) { + long now = System.nanoTime(); + DiskCheckResult cached = lastDiskCheck.get(); + if (cached != null && (now - cached.timestampNanos < java.util.concurrent.TimeUnit.SECONDS.toNanos(1))) { + return cached.fingerprint; + } + + synchronized (lastDiskCheck) { + cached = lastDiskCheck.get(); + if (cached != null && (now - cached.timestampNanos < java.util.concurrent.TimeUnit.SECONDS.toNanos(1))) { + return cached.fingerprint; + } + String fingerprint = getCertificateFingerprint(certPath); + lastDiskCheck.set(new DiskCheckResult(fingerprint, System.nanoTime())); + return fingerprint; + } } @Override - public void refresh() { - // 1. Lock-free CAS coalescing check to prevent duplicate queueing/blocking of concurrent threads - if (!refreshInProgress.compareAndSet(false, true)) { - return; + public boolean shouldRefresh() { + String certPath = getWorkloadCertPath(); + if (certPath == null) { + return false; } - try { - synchronized (lock) { - long now = System.currentTimeMillis(); - if (now - lastRefreshTimeMs < REFRESH_COOLDOWN_MS) { - LOG.fine("HTTP/JSON channel pool refreshed recently, skipping duplicate refresh"); - return; - } + String currentDiskFingerprint = getOrUpdateDiskFingerprint(certPath); + if (currentDiskFingerprint.isEmpty()) { + return false; + } + return !currentDiskFingerprint.equalsIgnoreCase(activeCertFingerprint); + } + + @Override + public void refresh() { + synchronized (lock) { + String certPath = getWorkloadCertPath(); + if (certPath == null) { + return; + } + String currentDiskFingerprint = getOrUpdateDiskFingerprint(certPath); + if (currentDiskFingerprint.isEmpty()) { + return; + } - LOG.info("mTLS certificate rotation detected. Triggering HTTP/JSON channel pool refresh."); - ChannelEntry newEntry = new ChannelEntry(channelFactory.get()); - ChannelEntry oldEntry = activeEntry.getAndSet(newEntry); + // Double-check inside lock + if (currentDiskFingerprint.equals(this.activeCertFingerprint)) { + LOG.fine("HTTP/JSON channel was already refreshed by a concurrent thread, skipping duplicate refresh"); + return; + } - if (oldEntry != null) { - oldEntry.requestShutdown(); - } + this.activeCertFingerprint = currentDiskFingerprint; + LOG.info("mTLS certificate rotation detected. Triggering HTTP/JSON channel pool refresh."); + + ChannelEntry newEntry = new ChannelEntry(channelFactory.get()); + ChannelEntry oldEntry = activeEntry.getAndSet(newEntry); - lastRefreshTimeMs = now; + if (oldEntry != null) { + oldEntry.requestShutdown(); } - } finally { - refreshInProgress.set(false); } } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java index 7c8fad8497e9..32944443e297 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java @@ -35,13 +35,8 @@ /* Package-private for internal use. */ class ApiResultRetryAlgorithm extends BasicResultRetryAlgorithm { - /** Returns true if previousThrowable is an {@link ApiException} that is retryable. */ @Override public boolean shouldRetry(Throwable previousThrowable, ResponseT previousResponse) { - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment")) - && previousThrowable instanceof UnauthenticatedException) { - return true; - } return (previousThrowable instanceof ApiException) && ((ApiException) previousThrowable).isRetryable(); } @@ -55,9 +50,15 @@ public boolean shouldRetry(Throwable previousThrowable, ResponseT previousRespon @Override public boolean shouldRetry( RetryingContext context, Throwable previousThrowable, ResponseT previousResponse) { - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment")) - && previousThrowable instanceof UnauthenticatedException) { - return true; + if (previousThrowable instanceof UnauthenticatedException) { + if (context instanceof ApiCallContext) { + TransportChannel transportChannel = ((ApiCallContext) context).getTransportChannel(); + if (transportChannel != null && transportChannel.shouldRefresh()) { + transportChannel.refresh(); + return true; + } + } + return false; } if (context.getRetryableCodes() != null) { // Ignore the isRetryable() value of the throwable if the RetryingContext has a specific list diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/AttemptCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/AttemptCallable.java index 1368f9e0f20c..2a987c7182bc 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/AttemptCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/AttemptCallable.java @@ -85,27 +85,6 @@ public ResponseT call() { ApiFuture internalFuture = callable.futureCall(request, callContext); - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { - final ApiCallContext finalContext = callContext; - ApiFutures.addCallback( - internalFuture, - new com.google.api.core.ApiFutureCallback() { - @Override - public void onFailure(Throwable t) { - if (t instanceof UnauthenticatedException) { - TransportChannel transportChannel = finalContext.getTransportChannel(); - if (transportChannel != null) { - transportChannel.refresh(); - } - } - } - - @Override - public void onSuccess(ResponseT result) {} - }, - com.google.common.util.concurrent.MoreExecutors.directExecutor()); - } - externalFuture.setAttemptFuture(internalFuture); } catch (Throwable e) { externalFuture.setAttemptFuture(ApiFutures.immediateFailedFuture(e)); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java index 59d6099b2d5b..a5fa0aee406c 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java @@ -240,38 +240,34 @@ public ClientStream internalCall( ClientStreamReadyObserver onReady, ApiCallContext thisCallContext) { final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); - ResponseObserver refreshingObserver = responseObserver; + ResponseObserver refreshingObserver = + new ResponseObserver() { + @Override + public void onStart(StreamController controller) { + responseObserver.onStart(controller); + } - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { - refreshingObserver = - new ResponseObserver() { - @Override - public void onStart(StreamController controller) { - responseObserver.onStart(controller); - } - - @Override - public void onResponse(ResponseT response) { - responseObserver.onResponse(response); - } + @Override + public void onResponse(ResponseT response) { + responseObserver.onResponse(response); + } - @Override - public void onError(Throwable t) { - if (t instanceof UnauthenticatedException) { - TransportChannel transportChannel = mergedContext.getTransportChannel(); - if (transportChannel != null) { - transportChannel.refresh(); - } + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null && transportChannel.shouldRefresh()) { + transportChannel.refresh(); } - responseObserver.onError(t); } + responseObserver.onError(t); + } - @Override - public void onComplete() { - responseObserver.onComplete(); - } - }; - } + @Override + public void onComplete() { + responseObserver.onComplete(); + } + }; return BidiStreamingCallable.this.internalCall( refreshingObserver, onReady, mergedContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java index c172e93ba20b..8c1db78c2810 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java @@ -75,33 +75,29 @@ public ClientStreamingCallable withDefaultCallContext( public ApiStreamObserver clientStreamingCall( final ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); - ApiStreamObserver refreshingObserver = responseObserver; + ApiStreamObserver refreshingObserver = + new ApiStreamObserver() { + @Override + public void onNext(ResponseT response) { + responseObserver.onNext(response); + } - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { - refreshingObserver = - new ApiStreamObserver() { - @Override - public void onNext(ResponseT response) { - responseObserver.onNext(response); - } - - @Override - public void onError(Throwable t) { - if (t instanceof UnauthenticatedException) { - TransportChannel transportChannel = mergedContext.getTransportChannel(); - if (transportChannel != null) { - transportChannel.refresh(); - } + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null && transportChannel.shouldRefresh()) { + transportChannel.refresh(); } - responseObserver.onError(t); } + responseObserver.onError(t); + } - @Override - public void onCompleted() { - responseObserver.onCompleted(); - } - }; - } + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; return ClientStreamingCallable.this.clientStreamingCall( refreshingObserver, mergedContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java index 3fe6441d762c..81c92d2708de 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java @@ -235,16 +235,14 @@ public void onResponseImpl(ResponseT response) { @Override public void onErrorImpl(Throwable t) { - if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { - Throwable cause = t; - if (cause instanceof com.google.api.gax.retrying.ServerStreamingAttemptException) { - cause = cause.getCause(); - } - if (cause instanceof UnauthenticatedException) { - TransportChannel transportChannel = finalContext.getTransportChannel(); - if (transportChannel != null) { - transportChannel.refresh(); - } + Throwable cause = t; + if (cause instanceof com.google.api.gax.retrying.ServerStreamingAttemptException) { + cause = cause.getCause(); + } + if (cause instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null && transportChannel.shouldRefresh()) { + transportChannel.refresh(); } } onAttemptError(t); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java index 65b3cce0e0a3..de83dbc73861 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java @@ -55,4 +55,12 @@ public interface TransportChannel extends BackgroundResource { * management. */ default void refresh() {} + + /** + * Returns true if a certificate rotation has been detected on disk and the transport channel + * should be refreshed, or false otherwise. + */ + default boolean shouldRefresh() { + return false; + } } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/mtls/CertificateBasedAccess.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/mtls/CertificateBasedAccess.java index 6f722bff6047..e28a903832e6 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/mtls/CertificateBasedAccess.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/mtls/CertificateBasedAccess.java @@ -32,21 +32,49 @@ import com.google.api.core.InternalApi; import com.google.api.gax.rpc.internal.EnvironmentProvider; +import java.io.IOException; /** * Utility class for handling certificate-based access configurations. * - *

This class handles the processing of GOOGLE_API_USE_CLIENT_CERTIFICATE and - * GOOGLE_API_USE_MTLS_ENDPOINT environment variables according to https://google.aip.dev/auth/4114 + *

This class handles the processing of GOOGLE_API_USE_CLIENT_CERTIFICATE, + * GOOGLE_API_CERTIFICATE_CONFIG, and GOOGLE_API_USE_MTLS_ENDPOINT configurations. */ @InternalApi public class CertificateBasedAccess { private final EnvironmentProvider envProvider; + private final FileExistenceProvider fileExistenceProvider; + private final FileContentReader fileContentReader; + + @InternalApi + public interface FileExistenceProvider { + boolean exists(String path); + } + + @InternalApi + public interface FileContentReader { + String read(String path) throws IOException; + } - /** The EnvironmentProvider mechanism supports env var injection for unit tests. */ public CertificateBasedAccess(EnvironmentProvider envProvider) { + this( + envProvider, + path -> { + java.io.File file = new java.io.File(path); + return file.exists() && file.isFile(); + }, + path -> new String(java.nio.file.Files.readAllBytes(java.nio.file.Paths.get(path)), java.nio.charset.StandardCharsets.UTF_8) + ); + } + + CertificateBasedAccess( + EnvironmentProvider envProvider, + FileExistenceProvider fileExistenceProvider, + FileContentReader fileContentReader) { this.envProvider = envProvider; + this.fileExistenceProvider = fileExistenceProvider; + this.fileContentReader = fileContentReader; } public static CertificateBasedAccess createWithSystemEnv() { @@ -64,10 +92,94 @@ public enum MtlsEndpointUsagePolicy { ALWAYS; } + private static class CertificateConfig { + final String certPath; + final String keyPath; + + CertificateConfig(String certPath, String keyPath) { + this.certPath = certPath; + this.keyPath = keyPath; + } + } + + private CertificateConfig parseCertificateConfig(String configPath) throws IOException { + String content = fileContentReader.read(configPath); + + String certPath = extractJsonValue(content, "cert_path"); + String keyPath = extractJsonValue(content, "key_path"); + + if (certPath == null || keyPath == null) { + throw new IllegalStateException("Malformed certificate config JSON. Must contain 'cert_path' and 'key_path'."); + } + + return new CertificateConfig(certPath, keyPath); + } + + private String extractJsonValue(String json, String key) { + java.util.regex.Pattern pattern = java.util.regex.Pattern.compile( + "\"" + java.util.regex.Pattern.quote(key) + "\"\\s*:\\s*\"([^\"]+)\"" + ); + java.util.regex.Matcher matcher = pattern.matcher(json); + if (matcher.find()) { + return matcher.group(1); + } + return null; + } + /** Returns if mutual TLS client certificate should be used. */ public boolean useMtlsClientCertificate() { + // 1. Check the explicit user flag first (Primary override) String useClientCertificate = envProvider.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE"); - return "true".equals(useClientCertificate); + if (useClientCertificate != null && !useClientCertificate.isEmpty()) { + if ("false".equalsIgnoreCase(useClientCertificate)) { + return false; + } + } + + // 2. Check the certificate config file path if provided via env var + String certConfigPath = envProvider.getenv("GOOGLE_API_CERTIFICATE_CONFIG"); + if (certConfigPath != null && !certConfigPath.isEmpty()) { + return validateAndResolveConfig(certConfigPath); + } + + // 3. Fallback to well-known spiffe path + String wellKnownPath = "/var/run/secrets/workload-spiffe-credentials/"; + + // Check for atomic bundle containing both cert and key + if (fileExistenceProvider.exists(wellKnownPath + "credentialbundle.pem")) { + return true; + } + + // Check for separate certificate and private key files + if (fileExistenceProvider.exists(wellKnownPath + "certificates.pem") + && fileExistenceProvider.exists(wellKnownPath + "private_key.pem")) { + return true; + } + + // Default to false if no configuration is found + return false; + } + + private boolean validateAndResolveConfig(String configPath) { + if (!fileExistenceProvider.exists(configPath)) { + throw new IllegalStateException( + "Certificate config is configured but the file does not exist: " + configPath + ); + } + try { + CertificateConfig config = parseCertificateConfig(configPath); + if (!fileExistenceProvider.exists(config.certPath) || !fileExistenceProvider.exists(config.keyPath)) { + throw new IllegalStateException( + "Certificate config points to certificate/key files that do not exist on disk: " + + "cert_path=" + config.certPath + ", key_path=" + config.keyPath + ); + } + return true; + } catch (Exception e) { + throw new IllegalStateException( + "Failed to parse or validate certificate config: " + configPath, e + ); + } } /** Returns the current mutual TLS endpoint usage policy. */ diff --git a/sdk-platform-java/gax-java/gax/src/test/java/com/google/api/gax/rpc/mtls/CertificateBasedAccessTest.java b/sdk-platform-java/gax-java/gax/src/test/java/com/google/api/gax/rpc/mtls/CertificateBasedAccessTest.java index e328e0af4799..2b0ade3937cb 100644 --- a/sdk-platform-java/gax-java/gax/src/test/java/com/google/api/gax/rpc/mtls/CertificateBasedAccessTest.java +++ b/sdk-platform-java/gax-java/gax/src/test/java/com/google/api/gax/rpc/mtls/CertificateBasedAccessTest.java @@ -33,51 +33,192 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; class CertificateBasedAccessTest { + private static class TestEnv { + private final Map env = new HashMap<>(); + + void set(String key, String val) { + env.put(key, val); + } + + String get(String name) { + return env.get(name); + } + } + + private static class TestFileSystem { + private final Map exists = new HashMap<>(); + private final Map content = new HashMap<>(); + + void setExists(String path, boolean val) { + exists.put(path, val); + } + + void setContent(String path, String val) { + content.put(path, val); + exists.put(path, true); + } + } + + private CertificateBasedAccess createCba(TestEnv env, TestFileSystem fs) { + return new CertificateBasedAccess( + env::get, + fs.exists::getOrDefault, + path -> { + if (!fs.content.containsKey(path)) { + throw new IOException("File not found: " + path); + } + return fs.content.get(path); + } + ); + } + @Test void testUseMtlsEndpointAlways() { - CertificateBasedAccess cba = - new CertificateBasedAccess( - name -> name.equals("GOOGLE_API_USE_MTLS_ENDPOINT") ? "always" : "false"); + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_MTLS_ENDPOINT", "always"); + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); assertEquals( CertificateBasedAccess.MtlsEndpointUsagePolicy.ALWAYS, cba.getMtlsEndpointUsagePolicy()); } @Test void testUseMtlsEndpointAuto() { - CertificateBasedAccess cba = - new CertificateBasedAccess( - name -> name.equals("GOOGLE_API_USE_MTLS_ENDPOINT") ? "auto" : "false"); + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_MTLS_ENDPOINT", "auto"); + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); assertEquals( CertificateBasedAccess.MtlsEndpointUsagePolicy.AUTO, cba.getMtlsEndpointUsagePolicy()); } @Test void testUseMtlsEndpointNever() { - CertificateBasedAccess cba = - new CertificateBasedAccess( - name -> name.equals("GOOGLE_API_USE_MTLS_ENDPOINT") ? "never" : "false"); + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_MTLS_ENDPOINT", "never"); + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); assertEquals( CertificateBasedAccess.MtlsEndpointUsagePolicy.NEVER, cba.getMtlsEndpointUsagePolicy()); } @Test - void testUseMtlsClientCertificateTrue() { - CertificateBasedAccess cba = - new CertificateBasedAccess( - name -> name.equals("GOOGLE_API_USE_CLIENT_CERTIFICATE") ? "true" : "auto"); + void testUseMtlsClientCertificateExplicitTrueNoCredentials() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true"); + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); + // Explicit 'true' requires credentials to be present on disk, otherwise falls back to false + assertFalse(cba.useMtlsClientCertificate()); + } + + @Test + void testUseMtlsClientCertificateExplicitTrueWithSpiffeBundle() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true"); + + TestFileSystem fs = new TestFileSystem(); + fs.setExists("/var/run/secrets/workload-spiffe-credentials/credentialbundle.pem", true); + + CertificateBasedAccess cba = createCba(env, fs); assertTrue(cba.useMtlsClientCertificate()); } @Test - void testUseMtlsClientCertificateFalse() { - CertificateBasedAccess cba = - new CertificateBasedAccess( - name -> name.equals("GOOGLE_API_USE_CLIENT_CERTIFICATE") ? "false" : "auto"); + void testUseMtlsClientCertificateExplicitFalseWithSpiffeBundle() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"); + + // Even if spiffe files are present, explicit false must override and disable mtls + TestFileSystem fs = new TestFileSystem(); + fs.setExists("/var/run/secrets/workload-spiffe-credentials/credentialbundle.pem", true); + + CertificateBasedAccess cba = createCba(env, fs); + assertFalse(cba.useMtlsClientCertificate()); + } + + @Test + void testUseMtlsClientCertificateUnsetNoFiles() { + TestEnv env = new TestEnv(); + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); assertFalse(cba.useMtlsClientCertificate()); } + + @Test + void testUseMtlsClientCertificateUnsetSpiffeBundleExists() { + TestEnv env = new TestEnv(); + TestFileSystem fs = new TestFileSystem(); + fs.setExists("/var/run/secrets/workload-spiffe-credentials/credentialbundle.pem", true); + CertificateBasedAccess cba = createCba(env, fs); + assertTrue(cba.useMtlsClientCertificate()); + } + + @Test + void testUseMtlsClientCertificateUnsetSpiffeSeparateFilesExist() { + TestEnv env = new TestEnv(); + TestFileSystem fs = new TestFileSystem(); + fs.setExists("/var/run/secrets/workload-spiffe-credentials/certificates.pem", true); + fs.setExists("/var/run/secrets/workload-spiffe-credentials/private_key.pem", true); + CertificateBasedAccess cba = createCba(env, fs); + assertTrue(cba.useMtlsClientCertificate()); + } + + @Test + void testUseMtlsClientCertificateConfigValid() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_CERTIFICATE_CONFIG", "/path/to/config.json"); + + TestFileSystem fs = new TestFileSystem(); + fs.setContent("/path/to/config.json", "{\n \"cert_path\": \"/my/cert.pem\",\n \"key_path\": \"/my/key.pem\"\n}"); + fs.setExists("/my/cert.pem", true); + fs.setExists("/my/key.pem", true); + + CertificateBasedAccess cba = createCba(env, fs); + assertTrue(cba.useMtlsClientCertificate()); + } + + @Test + void testUseMtlsClientCertificateConfigMissingFile() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_CERTIFICATE_CONFIG", "/path/to/config.json"); + + CertificateBasedAccess cba = createCba(env, new TestFileSystem()); + + IllegalStateException ex = assertThrows(IllegalStateException.class, cba::useMtlsClientCertificate); + assertTrue(ex.getMessage().contains("configured but the file does not exist")); + } + + @Test + void testUseMtlsClientCertificateConfigMalformedJson() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_CERTIFICATE_CONFIG", "/path/to/config.json"); + + TestFileSystem fs = new TestFileSystem(); + fs.setContent("/path/to/config.json", "{\n \"broken_path\": \"/my/cert.pem\"\n}"); + + CertificateBasedAccess cba = createCba(env, fs); + + IllegalStateException ex = assertThrows(IllegalStateException.class, cba::useMtlsClientCertificate); + assertTrue(ex.getMessage().contains("Failed to parse or validate certificate config")); + } + + @Test + void testUseMtlsClientCertificateConfigMissingCertFiles() { + TestEnv env = new TestEnv(); + env.set("GOOGLE_API_CERTIFICATE_CONFIG", "/path/to/config.json"); + + TestFileSystem fs = new TestFileSystem(); + fs.setContent("/path/to/config.json", "{\n \"cert_path\": \"/my/cert.pem\",\n \"key_path\": \"/my/key.pem\"\n}"); + // my/cert.pem and key.pem DO NOT exist + + CertificateBasedAccess cba = createCba(env, fs); + + IllegalStateException ex = assertThrows(IllegalStateException.class, cba::useMtlsClientCertificate); + assertTrue(ex.getMessage().contains("points to certificate/key files that do not exist on disk")); + } }