Skip to content

Commit 74b3370

Browse files
stevehuDiogoFKTCopilotCopilot
authored
Sync (#163)
* support for AssumeRoleWithWebIdentity * adding cache for IDToken * Apply review feedback: fix StsWebIdentity handler issues and config validation Agent-Logs-Url: https://github.com/networknt/light-aws-lambda/sessions/7d681ae3-7571-4446-82f3-109707801103 Co-authored-by: stevehu <2042337+stevehu@users.noreply.github.com> * Add StsWebIdentity unit tests: extract bearer token parsing and add test coverage Agent-Logs-Url: https://github.com/networknt/light-aws-lambda/sessions/e1696a54-4c13-45ea-831c-e02b5160ef43 Co-authored-by: stevehu <2042337+stevehu@users.noreply.github.com> * Update lambda-invoker/src/main/java/com/networknt/aws/lambda/LambdaFunctionHandler.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update lambda-invoker/src/test/java/com/networknt/aws/lambda/LambdaInvokerConfigTest.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix stsType normalization: assign trimmed value back to field in validate() Agent-Logs-Url: https://github.com/networknt/light-aws-lambda/sessions/2cd98e34-9726-4923-b74b-1cb85395a819 Co-authored-by: stevehu <2042337+stevehu@users.noreply.github.com> * address the compiler issue introduced by Copilot * address Copilot comments * change to reuse client * Update lambda-invoker/src/main/java/com/networknt/aws/lambda/LambdaFunctionHandler.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Diogo Fekete <Diogo.Fekete@sunlife.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: stevehu <2042337+stevehu@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 62a3cd1 commit 74b3370

14 files changed

Lines changed: 468 additions & 60 deletions

lambda-invoker/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
<groupId>com.networknt</groupId>
4848
<artifactId>metrics</artifactId>
4949
</dependency>
50+
<dependency>
51+
<groupId>com.networknt</groupId>
52+
<artifactId>metrics-config</artifactId>
53+
</dependency>
5054
<dependency>
5155
<groupId>com.networknt</groupId>
5256
<artifactId>body</artifactId>

lambda-invoker/src/main/java/com/networknt/aws/lambda/LambdaFunctionHandler.java

Lines changed: 181 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
import io.undertow.server.HttpHandler;
1414
import io.undertow.server.HttpServerExchange;
1515
import io.undertow.util.HeaderMap;
16+
import io.undertow.util.Headers;
1617
import io.undertow.util.HttpString;
1718
import org.slf4j.Logger;
1819
import org.slf4j.LoggerFactory;
20+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
21+
import software.amazon.awssdk.auth.credentials.AwsCredentials;
1922
import software.amazon.awssdk.core.SdkBytes;
2023
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
2124
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
@@ -26,10 +29,16 @@
2629
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
2730
import software.amazon.awssdk.services.sts.StsClient;
2831
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
32+
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithWebIdentityCredentialsProvider;
2933
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
34+
import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityRequest;
3035

3136
import java.net.URI;
37+
import java.nio.charset.StandardCharsets;
38+
import java.security.MessageDigest;
39+
import java.security.NoSuchAlgorithmException;
3240
import java.time.Duration;
41+
import java.util.Base64;
3342
import java.util.Deque;
3443
import java.util.HashMap;
3544
import java.util.Map;
@@ -40,13 +49,81 @@ public class LambdaFunctionHandler implements LightHttpHandler {
4049
private static final Logger logger = LoggerFactory.getLogger(LambdaFunctionHandler.class);
4150
private static final String MISSING_ENDPOINT_FUNCTION = "ERR10063";
4251
private static final String EMPTY_LAMBDA_RESPONSE = "ERR10064";
52+
private static final String STS_TYPE_FUNC_USER = "StsFuncUser";
53+
private static final String STS_TYPE_WEB_IDENTITY = "StsWebIdentity";
54+
private static final String BEARER_PREFIX = "BEARER";
55+
private static final String INVALID_WEB_IDENTITY_TOKEN_MESSAGE = "Missing or invalid Bearer token for STS web identity";
4356
private static AbstractMetricsHandler metricsHandler;
4457

4558
private LambdaInvokerConfig config;
4659
private LambdaAsyncClient client;
4760
private StsAssumeRoleCredentialsProvider stsCredentialsProvider;
61+
private MutableStsWebIdentityCredentialsProvider stsWebIdentityCredentialsProvider;
4862
private StsClient stsClient;
4963

64+
static final class MutableStsWebIdentityCredentialsProvider implements AwsCredentialsProvider, AutoCloseable {
65+
private final LambdaInvokerConfig config;
66+
private final StsClient stsClient;
67+
private StsAssumeRoleWithWebIdentityCredentialsProvider delegate;
68+
private String tokenFingerprint;
69+
70+
MutableStsWebIdentityCredentialsProvider(LambdaInvokerConfig config, StsClient stsClient) {
71+
this.config = config;
72+
this.stsClient = stsClient;
73+
}
74+
75+
synchronized boolean updateToken(String token) {
76+
String nextFingerprint = fingerprintToken(token);
77+
if(nextFingerprint.equals(tokenFingerprint) && delegate != null) {
78+
return false;
79+
}
80+
StsAssumeRoleWithWebIdentityCredentialsProvider nextDelegate =
81+
StsAssumeRoleWithWebIdentityCredentialsProvider.builder()
82+
.stsClient(stsClient)
83+
.refreshRequest(AssumeRoleWithWebIdentityRequest.builder()
84+
.roleArn(config.getRoleArn())
85+
.roleSessionName(config.getRoleSessionName())
86+
.durationSeconds(config.getDurationSeconds())
87+
.webIdentityToken(token)
88+
.build())
89+
.build();
90+
StsAssumeRoleWithWebIdentityCredentialsProvider previousDelegate = delegate;
91+
delegate = nextDelegate;
92+
tokenFingerprint = nextFingerprint;
93+
closeDelegate(previousDelegate);
94+
return true;
95+
}
96+
97+
synchronized String getTokenFingerprint() {
98+
return tokenFingerprint;
99+
}
100+
101+
@Override
102+
public synchronized AwsCredentials resolveCredentials() {
103+
if(delegate == null) {
104+
throw new IllegalStateException("STS web identity credentials provider has not been initialized with a bearer token");
105+
}
106+
return delegate.resolveCredentials();
107+
}
108+
109+
@Override
110+
public synchronized void close() {
111+
closeDelegate(delegate);
112+
delegate = null;
113+
tokenFingerprint = null;
114+
}
115+
116+
private void closeDelegate(StsAssumeRoleWithWebIdentityCredentialsProvider provider) {
117+
if(provider != null) {
118+
try {
119+
provider.close();
120+
} catch (Exception e) {
121+
logger.error("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider", e);
122+
}
123+
}
124+
}
125+
}
126+
50127
// Package-private constructor for testing - avoids loading config from file and metrics chain setup
51128
LambdaFunctionHandler(LambdaInvokerConfig config) {
52129
this.config = config;
@@ -70,6 +147,40 @@ public LambdaFunctionHandler() {
70147
}
71148

72149
private LambdaAsyncClient initClient(LambdaInvokerConfig config) {
150+
AwsCredentialsProvider credentialsProvider = null;
151+
// If any STS type selected, use the matching credentials provider for automatic refresh.
152+
if(STS_TYPE_FUNC_USER.equals(config.getStsType())) {
153+
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is set to " + STS_TYPE_FUNC_USER + " for role: {}", config.getRoleArn());
154+
stsClient = StsClient.builder()
155+
.region(Region.of(config.getRegion()))
156+
.build();
157+
stsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
158+
.stsClient(stsClient)
159+
.refreshRequest(AssumeRoleRequest.builder()
160+
.roleArn(config.getRoleArn())
161+
.roleSessionName(config.getRoleSessionName())
162+
.durationSeconds(config.getDurationSeconds())
163+
.build())
164+
.build();
165+
credentialsProvider = stsCredentialsProvider;
166+
} else if(STS_TYPE_WEB_IDENTITY.equals(config.getStsType())) {
167+
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is set to " + STS_TYPE_WEB_IDENTITY + " for role: {}", config.getRoleArn());
168+
stsClient = StsClient.builder()
169+
.region(Region.of(config.getRegion()))
170+
.build();
171+
stsWebIdentityCredentialsProvider = buildMutableStsWebIdentityCredentialsProvider(config, stsClient);
172+
credentialsProvider = stsWebIdentityCredentialsProvider;
173+
} else {
174+
if(logger.isInfoEnabled()) logger.info("No STS AssumeRole is set. Using default credential provider chain for LambdaAsyncClient.");
175+
}
176+
return buildLambdaClient(config, credentialsProvider);
177+
}
178+
179+
MutableStsWebIdentityCredentialsProvider buildMutableStsWebIdentityCredentialsProvider(LambdaInvokerConfig config, StsClient stsClient) {
180+
return new MutableStsWebIdentityCredentialsProvider(config, stsClient);
181+
}
182+
183+
LambdaAsyncClient buildLambdaClient(LambdaInvokerConfig config, AwsCredentialsProvider credentialsProvider) {
73184
SdkAsyncHttpClient asyncHttpClient = NettyNioAsyncHttpClient.builder()
74185
.readTimeout(Duration.ofMillis(config.getApiCallAttemptTimeout()))
75186
.writeTimeout(Duration.ofMillis(config.getApiCallAttemptTimeout()))
@@ -103,26 +214,24 @@ private LambdaAsyncClient initClient(LambdaInvokerConfig config) {
103214
builder.endpointOverride(URI.create(config.getEndpointOverride()));
104215
}
105216

106-
// If STS is enabled, use StsAssumeRoleCredentialsProvider for automatic credential refresh
107-
if(config.isStsEnabled()) {
108-
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is enabled. Using StsAssumeRoleCredentialsProvider for role: {}", config.getRoleArn());
109-
stsClient = StsClient.builder()
110-
.region(Region.of(config.getRegion()))
111-
.build();
112-
stsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
113-
.stsClient(stsClient)
114-
.refreshRequest(AssumeRoleRequest.builder()
115-
.roleArn(config.getRoleArn())
116-
.roleSessionName(config.getRoleSessionName())
117-
.durationSeconds(config.getDurationSeconds())
118-
.build())
119-
.build();
120-
builder.credentialsProvider(stsCredentialsProvider);
217+
if(credentialsProvider != null) {
218+
builder.credentialsProvider(credentialsProvider);
121219
}
122220

123221
return builder.build();
124222
}
125223

224+
boolean updateWebIdentityToken(String token) {
225+
if(stsWebIdentityCredentialsProvider == null) {
226+
throw new IllegalStateException("STS web identity credentials provider is not configured");
227+
}
228+
return stsWebIdentityCredentialsProvider.updateToken(token);
229+
}
230+
231+
String currentWebIdentityTokenFingerprint() {
232+
return stsWebIdentityCredentialsProvider == null ? null : stsWebIdentityCredentialsProvider.getTokenFingerprint();
233+
}
234+
126235
@Override
127236
public void handleRequest(HttpServerExchange exchange) throws Exception {
128237
LambdaInvokerConfig newConfig = LambdaInvokerConfig.load();
@@ -146,6 +255,14 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
146255
}
147256
stsCredentialsProvider = null;
148257
}
258+
if(stsWebIdentityCredentialsProvider != null) {
259+
try {
260+
stsWebIdentityCredentialsProvider.close();
261+
} catch (Exception e) {
262+
logger.error("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider", e);
263+
}
264+
stsWebIdentityCredentialsProvider = null;
265+
}
149266
if(stsClient != null) {
150267
try {
151268
stsClient.close();
@@ -187,6 +304,21 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
187304
if(config.isMetricsInjection() && metricsHandler != null) metricsHandler.injectMetrics(exchange, startTime, config.getMetricsName(), endpoint);
188305
return;
189306
}
307+
if(STS_TYPE_WEB_IDENTITY.equals(config.getStsType())) {
308+
String rawAuthHeader = exchange.getRequestHeaders().getFirst(Headers.AUTHORIZATION);
309+
String token = extractBearerToken(rawAuthHeader);
310+
if(token == null || token.isEmpty()) {
311+
exchange.setStatusCode(401);
312+
exchange.getResponseSender().send(INVALID_WEB_IDENTITY_TOKEN_MESSAGE);
313+
if(config.isMetricsInjection() && metricsHandler != null) metricsHandler.injectMetrics(exchange, startTime, config.getMetricsName(), endpoint);
314+
return;
315+
}
316+
if(updateWebIdentityToken(token)) {
317+
if(logger.isDebugEnabled()) logger.debug("Authorization token changed. Refreshed the shared STS web identity credentials provider.");
318+
} else {
319+
if(logger.isDebugEnabled()) logger.debug("Authorization token unchanged. Reusing the shared STS web identity credentials provider.");
320+
}
321+
}
190322
APIGatewayProxyRequestEvent requestEvent = new APIGatewayProxyRequestEvent();
191323
requestEvent.setHttpMethod(httpMethod);
192324
requestEvent.setPath(requestPath);
@@ -277,4 +409,38 @@ private void setResponseHeaders(HttpServerExchange exchange, Map<String, String>
277409
}
278410
}
279411
}
412+
413+
/**
414+
* Extracts the bearer token from a raw Authorization header value.
415+
* Returns the token string if the header starts with "Bearer " (case-insensitive),
416+
* or {@code null} if the header is missing/empty or does not use the Bearer scheme.
417+
*/
418+
static String extractBearerToken(String authorizationHeaderValue) {
419+
if (authorizationHeaderValue == null || authorizationHeaderValue.isEmpty()) {
420+
if(logger.isDebugEnabled()) logger.debug("Missing Authorization header from request. STS AssumeRole with Web Identity may fail");
421+
return null;
422+
}
423+
if (authorizationHeaderValue.length() > BEARER_PREFIX.length() + 1 &&
424+
authorizationHeaderValue.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length()) &&
425+
authorizationHeaderValue.charAt(BEARER_PREFIX.length()) == ' ') {
426+
String token = authorizationHeaderValue.substring(BEARER_PREFIX.length() + 1).trim();
427+
if (token.isEmpty()) {
428+
if(logger.isDebugEnabled()) logger.debug("Authorization header contains a blank Bearer token. STS AssumeRole with Web Identity may fail");
429+
return null;
430+
}
431+
return token;
432+
}
433+
if(logger.isDebugEnabled()) logger.debug("Authorization header does not start with Bearer. STS AssumeRole with Web Identity may fail");
434+
return null;
435+
}
436+
437+
static String fingerprintToken(String token) {
438+
try {
439+
MessageDigest digest = MessageDigest.getInstance("SHA-256");
440+
byte[] hashed = digest.digest(token.getBytes(StandardCharsets.UTF_8));
441+
return Base64.getEncoder().encodeToString(hashed);
442+
} catch (NoSuchAlgorithmException e) {
443+
throw new IllegalStateException("SHA-256 is not available", e);
444+
}
445+
}
280446
}

lambda-invoker/src/main/java/com/networknt/aws/lambda/LambdaInvokerConfig.java

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public class LambdaInvokerConfig {
4040
private static final String MAX_CONCURRENCY = "maxConcurrency";
4141
private static final String MAX_PENDING_CONNECTION_ACQUIRES = "maxPendingConnectionAcquires";
4242
private static final String CONNECTION_ACQUISITION_TIMEOUT = "connectionAcquisitionTimeout";
43-
private static final String STS_ENABLED = "stsEnabled";
43+
private static final String STS_TYPE = "stsType";
4444
private static final String ROLE_ARN = "roleArn";
4545
private static final String ROLE_SESSION_NAME = "roleSessionName";
4646
private static final String DURATION_SECONDS = "durationSeconds";
@@ -152,21 +152,25 @@ public class LambdaInvokerConfig {
152152
)
153153
private String metricsName;
154154

155-
@BooleanField(
156-
configFieldName = STS_ENABLED,
157-
externalizedKeyName = STS_ENABLED,
155+
@StringField(
156+
configFieldName = STS_TYPE,
157+
externalizedKeyName = STS_TYPE,
158158
description = "Enable STS AssumeRole to obtain temporary credentials for Lambda invocation instead of using the\n" +
159-
"permanent IAM credentials. When set to true, the handler will call STS AssumeRole with the configured\n" +
160-
"roleArn, roleSessionName, and durationSeconds to get short-lived credentials. This is the recommended\n" +
161-
"approach for production environments to follow the principle of least privilege.\n",
162-
defaultValue = "false"
159+
"permanent IAM credentials. Only 2 STS types supported: StsFuncUser and StsWebIdentity.\n" +
160+
"If STS is not to be used set this property as empty. When StsFuncUser is set the handler will\n" +
161+
"use the configured AWS IAM User to assume the given RoleARN. When StsWebIdentity is set the handler will\n" +
162+
"use the request bearer token as the WEB_IDENTITY_TOKEN to be exchanged for STS token. Regardless of the\n" +
163+
"selected type, the handler will call STS AssumeRole with the configured roleArn, roleSessionName, and\n" +
164+
"durationSeconds to get short-lived credentials. Using one of the STS types is the recommended approach\n" +
165+
"for production environments to follow the principle of least privilege.\n",
166+
defaultValue = ""
163167
)
164-
private boolean stsEnabled;
168+
private String stsType;
165169

166170
@StringField(
167171
configFieldName = ROLE_ARN,
168172
externalizedKeyName = ROLE_ARN,
169-
description = "The ARN of the IAM role to assume when stsEnabled is true. For example,\n" +
173+
description = "The ARN of the IAM role to assume when stsType is not empty. For example,\n" +
170174
"arn:aws:iam::123456789012:role/LambdaInvokerRole\n"
171175
)
172176
private String roleArn;
@@ -329,12 +333,12 @@ public void setConnectionAcquisitionTimeout(int connectionAcquisitionTimeout) {
329333
this.connectionAcquisitionTimeout = connectionAcquisitionTimeout;
330334
}
331335

332-
public boolean isStsEnabled() {
333-
return stsEnabled;
336+
public String getStsType() {
337+
return stsType;
334338
}
335339

336-
public void setStsEnabled(boolean stsEnabled) {
337-
this.stsEnabled = stsEnabled;
340+
public void setStsType(String stsType) {
341+
this.stsType = stsType;
338342
}
339343

340344
public String getRoleArn() {
@@ -398,8 +402,8 @@ private void setConfigData() {
398402
object = mappedConfig.get(CONNECTION_ACQUISITION_TIMEOUT);
399403
if (object != null) connectionAcquisitionTimeout = Config.loadIntegerValue(CONNECTION_ACQUISITION_TIMEOUT, object);
400404

401-
object = mappedConfig.get(STS_ENABLED);
402-
if(object != null) stsEnabled = Config.loadBooleanValue(STS_ENABLED, object);
405+
object = mappedConfig.get(STS_TYPE);
406+
if (object != null) stsType = (String) object;
403407

404408
object = mappedConfig.get(ROLE_ARN);
405409
if(object != null) roleArn = (String) object;
@@ -456,8 +460,17 @@ private void setConfigMap() {
456460
}
457461

458462
private void validate() {
459-
if (stsEnabled && (roleArn == null || roleArn.trim().isEmpty())) {
460-
throw new ConfigException(ROLE_ARN + " must be configured when " + STS_ENABLED + " is true.");
463+
String normalizedStsType = stsType == null ? null : stsType.trim();
464+
// Write normalized value back so downstream equals() comparisons work correctly
465+
// even when the config value has leading/trailing whitespace (e.g. "StsWebIdentity ").
466+
stsType = normalizedStsType;
467+
if (normalizedStsType != null && !normalizedStsType.isEmpty()) {
468+
if (!"StsFuncUser".equals(normalizedStsType) && !"StsWebIdentity".equals(normalizedStsType)) {
469+
throw new ConfigException(STS_TYPE + " must be one of [StsFuncUser, StsWebIdentity], but was: " + normalizedStsType);
470+
}
471+
if (roleArn == null || roleArn.trim().isEmpty()) {
472+
throw new ConfigException(ROLE_ARN + " must be configured when " + STS_TYPE + " is not empty.");
473+
}
461474
}
462475
}
463476
}

0 commit comments

Comments
 (0)