1313import io .undertow .server .HttpHandler ;
1414import io .undertow .server .HttpServerExchange ;
1515import io .undertow .util .HeaderMap ;
16+ import io .undertow .util .Headers ;
1617import io .undertow .util .HttpString ;
1718import org .slf4j .Logger ;
1819import org .slf4j .LoggerFactory ;
20+ import software .amazon .awssdk .auth .credentials .AwsCredentialsProvider ;
21+ import software .amazon .awssdk .auth .credentials .AwsCredentials ;
1922import software .amazon .awssdk .core .SdkBytes ;
2023import software .amazon .awssdk .core .client .config .ClientOverrideConfiguration ;
2124import software .amazon .awssdk .http .async .SdkAsyncHttpClient ;
2629import software .amazon .awssdk .services .lambda .model .InvokeRequest ;
2730import software .amazon .awssdk .services .sts .StsClient ;
2831import software .amazon .awssdk .services .sts .auth .StsAssumeRoleCredentialsProvider ;
32+ import software .amazon .awssdk .services .sts .auth .StsAssumeRoleWithWebIdentityCredentialsProvider ;
2933import software .amazon .awssdk .services .sts .model .AssumeRoleRequest ;
34+ import software .amazon .awssdk .services .sts .model .AssumeRoleWithWebIdentityRequest ;
3035
3136import java .net .URI ;
37+ import java .nio .charset .StandardCharsets ;
38+ import java .security .MessageDigest ;
39+ import java .security .NoSuchAlgorithmException ;
3240import java .time .Duration ;
41+ import java .util .Base64 ;
3342import java .util .Deque ;
3443import java .util .HashMap ;
3544import java .util .Map ;
@@ -40,13 +49,81 @@ public class LambdaFunctionHandler implements LightHttpHandler {
4049 private static final Logger logger = LoggerFactory .getLogger (LambdaFunctionHandler .class );
4150 private static final String MISSING_ENDPOINT_FUNCTION = "ERR10063" ;
4251 private static final String EMPTY_LAMBDA_RESPONSE = "ERR10064" ;
52+ private static final String STS_TYPE_FUNC_USER = "StsFuncUser" ;
53+ private static final String STS_TYPE_WEB_IDENTITY = "StsWebIdentity" ;
54+ private static final String BEARER_PREFIX = "BEARER" ;
55+ private static final String INVALID_WEB_IDENTITY_TOKEN_MESSAGE = "Missing or invalid Bearer token for STS web identity" ;
4356 private static AbstractMetricsHandler metricsHandler ;
4457
4558 private LambdaInvokerConfig config ;
4659 private LambdaAsyncClient client ;
4760 private StsAssumeRoleCredentialsProvider stsCredentialsProvider ;
61+ private MutableStsWebIdentityCredentialsProvider stsWebIdentityCredentialsProvider ;
4862 private StsClient stsClient ;
4963
64+ static final class MutableStsWebIdentityCredentialsProvider implements AwsCredentialsProvider , AutoCloseable {
65+ private final LambdaInvokerConfig config ;
66+ private final StsClient stsClient ;
67+ private StsAssumeRoleWithWebIdentityCredentialsProvider delegate ;
68+ private String tokenFingerprint ;
69+
70+ MutableStsWebIdentityCredentialsProvider (LambdaInvokerConfig config , StsClient stsClient ) {
71+ this .config = config ;
72+ this .stsClient = stsClient ;
73+ }
74+
75+ synchronized boolean updateToken (String token ) {
76+ String nextFingerprint = fingerprintToken (token );
77+ if (nextFingerprint .equals (tokenFingerprint ) && delegate != null ) {
78+ return false ;
79+ }
80+ StsAssumeRoleWithWebIdentityCredentialsProvider nextDelegate =
81+ StsAssumeRoleWithWebIdentityCredentialsProvider .builder ()
82+ .stsClient (stsClient )
83+ .refreshRequest (AssumeRoleWithWebIdentityRequest .builder ()
84+ .roleArn (config .getRoleArn ())
85+ .roleSessionName (config .getRoleSessionName ())
86+ .durationSeconds (config .getDurationSeconds ())
87+ .webIdentityToken (token )
88+ .build ())
89+ .build ();
90+ StsAssumeRoleWithWebIdentityCredentialsProvider previousDelegate = delegate ;
91+ delegate = nextDelegate ;
92+ tokenFingerprint = nextFingerprint ;
93+ closeDelegate (previousDelegate );
94+ return true ;
95+ }
96+
97+ synchronized String getTokenFingerprint () {
98+ return tokenFingerprint ;
99+ }
100+
101+ @ Override
102+ public synchronized AwsCredentials resolveCredentials () {
103+ if (delegate == null ) {
104+ throw new IllegalStateException ("STS web identity credentials provider has not been initialized with a bearer token" );
105+ }
106+ return delegate .resolveCredentials ();
107+ }
108+
109+ @ Override
110+ public synchronized void close () {
111+ closeDelegate (delegate );
112+ delegate = null ;
113+ tokenFingerprint = null ;
114+ }
115+
116+ private void closeDelegate (StsAssumeRoleWithWebIdentityCredentialsProvider provider ) {
117+ if (provider != null ) {
118+ try {
119+ provider .close ();
120+ } catch (Exception e ) {
121+ logger .error ("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider" , e );
122+ }
123+ }
124+ }
125+ }
126+
50127 // Package-private constructor for testing - avoids loading config from file and metrics chain setup
51128 LambdaFunctionHandler (LambdaInvokerConfig config ) {
52129 this .config = config ;
@@ -70,6 +147,40 @@ public LambdaFunctionHandler() {
70147 }
71148
72149 private LambdaAsyncClient initClient (LambdaInvokerConfig config ) {
150+ AwsCredentialsProvider credentialsProvider = null ;
151+ // If any STS type selected, use the matching credentials provider for automatic refresh.
152+ if (STS_TYPE_FUNC_USER .equals (config .getStsType ())) {
153+ if (logger .isInfoEnabled ()) logger .info ("STS AssumeRole is set to " + STS_TYPE_FUNC_USER + " for role: {}" , config .getRoleArn ());
154+ stsClient = StsClient .builder ()
155+ .region (Region .of (config .getRegion ()))
156+ .build ();
157+ stsCredentialsProvider = StsAssumeRoleCredentialsProvider .builder ()
158+ .stsClient (stsClient )
159+ .refreshRequest (AssumeRoleRequest .builder ()
160+ .roleArn (config .getRoleArn ())
161+ .roleSessionName (config .getRoleSessionName ())
162+ .durationSeconds (config .getDurationSeconds ())
163+ .build ())
164+ .build ();
165+ credentialsProvider = stsCredentialsProvider ;
166+ } else if (STS_TYPE_WEB_IDENTITY .equals (config .getStsType ())) {
167+ if (logger .isInfoEnabled ()) logger .info ("STS AssumeRole is set to " + STS_TYPE_WEB_IDENTITY + " for role: {}" , config .getRoleArn ());
168+ stsClient = StsClient .builder ()
169+ .region (Region .of (config .getRegion ()))
170+ .build ();
171+ stsWebIdentityCredentialsProvider = buildMutableStsWebIdentityCredentialsProvider (config , stsClient );
172+ credentialsProvider = stsWebIdentityCredentialsProvider ;
173+ } else {
174+ if (logger .isInfoEnabled ()) logger .info ("No STS AssumeRole is set. Using default credential provider chain for LambdaAsyncClient." );
175+ }
176+ return buildLambdaClient (config , credentialsProvider );
177+ }
178+
179+ MutableStsWebIdentityCredentialsProvider buildMutableStsWebIdentityCredentialsProvider (LambdaInvokerConfig config , StsClient stsClient ) {
180+ return new MutableStsWebIdentityCredentialsProvider (config , stsClient );
181+ }
182+
183+ LambdaAsyncClient buildLambdaClient (LambdaInvokerConfig config , AwsCredentialsProvider credentialsProvider ) {
73184 SdkAsyncHttpClient asyncHttpClient = NettyNioAsyncHttpClient .builder ()
74185 .readTimeout (Duration .ofMillis (config .getApiCallAttemptTimeout ()))
75186 .writeTimeout (Duration .ofMillis (config .getApiCallAttemptTimeout ()))
@@ -103,26 +214,24 @@ private LambdaAsyncClient initClient(LambdaInvokerConfig config) {
103214 builder .endpointOverride (URI .create (config .getEndpointOverride ()));
104215 }
105216
106- // If STS is enabled, use StsAssumeRoleCredentialsProvider for automatic credential refresh
107- if (config .isStsEnabled ()) {
108- if (logger .isInfoEnabled ()) logger .info ("STS AssumeRole is enabled. Using StsAssumeRoleCredentialsProvider for role: {}" , config .getRoleArn ());
109- stsClient = StsClient .builder ()
110- .region (Region .of (config .getRegion ()))
111- .build ();
112- stsCredentialsProvider = StsAssumeRoleCredentialsProvider .builder ()
113- .stsClient (stsClient )
114- .refreshRequest (AssumeRoleRequest .builder ()
115- .roleArn (config .getRoleArn ())
116- .roleSessionName (config .getRoleSessionName ())
117- .durationSeconds (config .getDurationSeconds ())
118- .build ())
119- .build ();
120- builder .credentialsProvider (stsCredentialsProvider );
217+ if (credentialsProvider != null ) {
218+ builder .credentialsProvider (credentialsProvider );
121219 }
122220
123221 return builder .build ();
124222 }
125223
224+ boolean updateWebIdentityToken (String token ) {
225+ if (stsWebIdentityCredentialsProvider == null ) {
226+ throw new IllegalStateException ("STS web identity credentials provider is not configured" );
227+ }
228+ return stsWebIdentityCredentialsProvider .updateToken (token );
229+ }
230+
231+ String currentWebIdentityTokenFingerprint () {
232+ return stsWebIdentityCredentialsProvider == null ? null : stsWebIdentityCredentialsProvider .getTokenFingerprint ();
233+ }
234+
126235 @ Override
127236 public void handleRequest (HttpServerExchange exchange ) throws Exception {
128237 LambdaInvokerConfig newConfig = LambdaInvokerConfig .load ();
@@ -146,6 +255,14 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
146255 }
147256 stsCredentialsProvider = null ;
148257 }
258+ if (stsWebIdentityCredentialsProvider != null ) {
259+ try {
260+ stsWebIdentityCredentialsProvider .close ();
261+ } catch (Exception e ) {
262+ logger .error ("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider" , e );
263+ }
264+ stsWebIdentityCredentialsProvider = null ;
265+ }
149266 if (stsClient != null ) {
150267 try {
151268 stsClient .close ();
@@ -187,6 +304,21 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
187304 if (config .isMetricsInjection () && metricsHandler != null ) metricsHandler .injectMetrics (exchange , startTime , config .getMetricsName (), endpoint );
188305 return ;
189306 }
307+ if (STS_TYPE_WEB_IDENTITY .equals (config .getStsType ())) {
308+ String rawAuthHeader = exchange .getRequestHeaders ().getFirst (Headers .AUTHORIZATION );
309+ String token = extractBearerToken (rawAuthHeader );
310+ if (token == null || token .isEmpty ()) {
311+ exchange .setStatusCode (401 );
312+ exchange .getResponseSender ().send (INVALID_WEB_IDENTITY_TOKEN_MESSAGE );
313+ if (config .isMetricsInjection () && metricsHandler != null ) metricsHandler .injectMetrics (exchange , startTime , config .getMetricsName (), endpoint );
314+ return ;
315+ }
316+ if (updateWebIdentityToken (token )) {
317+ if (logger .isDebugEnabled ()) logger .debug ("Authorization token changed. Refreshed the shared STS web identity credentials provider." );
318+ } else {
319+ if (logger .isDebugEnabled ()) logger .debug ("Authorization token unchanged. Reusing the shared STS web identity credentials provider." );
320+ }
321+ }
190322 APIGatewayProxyRequestEvent requestEvent = new APIGatewayProxyRequestEvent ();
191323 requestEvent .setHttpMethod (httpMethod );
192324 requestEvent .setPath (requestPath );
@@ -277,4 +409,38 @@ private void setResponseHeaders(HttpServerExchange exchange, Map<String, String>
277409 }
278410 }
279411 }
412+
413+ /**
414+ * Extracts the bearer token from a raw Authorization header value.
415+ * Returns the token string if the header starts with "Bearer " (case-insensitive),
416+ * or {@code null} if the header is missing/empty or does not use the Bearer scheme.
417+ */
418+ static String extractBearerToken (String authorizationHeaderValue ) {
419+ if (authorizationHeaderValue == null || authorizationHeaderValue .isEmpty ()) {
420+ if (logger .isDebugEnabled ()) logger .debug ("Missing Authorization header from request. STS AssumeRole with Web Identity may fail" );
421+ return null ;
422+ }
423+ if (authorizationHeaderValue .length () > BEARER_PREFIX .length () + 1 &&
424+ authorizationHeaderValue .regionMatches (true , 0 , BEARER_PREFIX , 0 , BEARER_PREFIX .length ()) &&
425+ authorizationHeaderValue .charAt (BEARER_PREFIX .length ()) == ' ' ) {
426+ String token = authorizationHeaderValue .substring (BEARER_PREFIX .length () + 1 ).trim ();
427+ if (token .isEmpty ()) {
428+ if (logger .isDebugEnabled ()) logger .debug ("Authorization header contains a blank Bearer token. STS AssumeRole with Web Identity may fail" );
429+ return null ;
430+ }
431+ return token ;
432+ }
433+ if (logger .isDebugEnabled ()) logger .debug ("Authorization header does not start with Bearer. STS AssumeRole with Web Identity may fail" );
434+ return null ;
435+ }
436+
437+ static String fingerprintToken (String token ) {
438+ try {
439+ MessageDigest digest = MessageDigest .getInstance ("SHA-256" );
440+ byte [] hashed = digest .digest (token .getBytes (StandardCharsets .UTF_8 ));
441+ return Base64 .getEncoder ().encodeToString (hashed );
442+ } catch (NoSuchAlgorithmException e ) {
443+ throw new IllegalStateException ("SHA-256 is not available" , e );
444+ }
445+ }
280446}
0 commit comments