Skip to content
Open

Sync #163

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lambda-invoker/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
<groupId>com.networknt</groupId>
<artifactId>metrics</artifactId>
</dependency>
<dependency>
<groupId>com.networknt</groupId>
<artifactId>metrics-config</artifactId>
</dependency>
<dependency>
<groupId>com.networknt</groupId>
<artifactId>body</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
Expand All @@ -26,10 +29,16 @@
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithWebIdentityCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityRequest;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.Base64;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -40,13 +49,81 @@ public class LambdaFunctionHandler implements LightHttpHandler {
private static final Logger logger = LoggerFactory.getLogger(LambdaFunctionHandler.class);
private static final String MISSING_ENDPOINT_FUNCTION = "ERR10063";
private static final String EMPTY_LAMBDA_RESPONSE = "ERR10064";
private static final String STS_TYPE_FUNC_USER = "StsFuncUser";
private static final String STS_TYPE_WEB_IDENTITY = "StsWebIdentity";
private static final String BEARER_PREFIX = "BEARER";
private static final String INVALID_WEB_IDENTITY_TOKEN_MESSAGE = "Missing or invalid Bearer token for STS web identity";
private static AbstractMetricsHandler metricsHandler;

private LambdaInvokerConfig config;
private LambdaAsyncClient client;
private StsAssumeRoleCredentialsProvider stsCredentialsProvider;
private MutableStsWebIdentityCredentialsProvider stsWebIdentityCredentialsProvider;
private StsClient stsClient;

static final class MutableStsWebIdentityCredentialsProvider implements AwsCredentialsProvider, AutoCloseable {
private final LambdaInvokerConfig config;
private final StsClient stsClient;
private StsAssumeRoleWithWebIdentityCredentialsProvider delegate;
private String tokenFingerprint;

MutableStsWebIdentityCredentialsProvider(LambdaInvokerConfig config, StsClient stsClient) {
this.config = config;
this.stsClient = stsClient;
}

synchronized boolean updateToken(String token) {
String nextFingerprint = fingerprintToken(token);
if(nextFingerprint.equals(tokenFingerprint) && delegate != null) {
return false;
}
StsAssumeRoleWithWebIdentityCredentialsProvider nextDelegate =
StsAssumeRoleWithWebIdentityCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(AssumeRoleWithWebIdentityRequest.builder()
.roleArn(config.getRoleArn())
.roleSessionName(config.getRoleSessionName())
.durationSeconds(config.getDurationSeconds())
.webIdentityToken(token)
.build())
.build();
StsAssumeRoleWithWebIdentityCredentialsProvider previousDelegate = delegate;
delegate = nextDelegate;
tokenFingerprint = nextFingerprint;
closeDelegate(previousDelegate);
return true;
}

synchronized String getTokenFingerprint() {
return tokenFingerprint;
}

@Override
public synchronized AwsCredentials resolveCredentials() {
if(delegate == null) {
throw new IllegalStateException("STS web identity credentials provider has not been initialized with a bearer token");
}
return delegate.resolveCredentials();
}

@Override
public synchronized void close() {
closeDelegate(delegate);
delegate = null;
tokenFingerprint = null;
}

private void closeDelegate(StsAssumeRoleWithWebIdentityCredentialsProvider provider) {
if(provider != null) {
try {
provider.close();
} catch (Exception e) {
logger.error("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider", e);
}
}
}
}

// Package-private constructor for testing - avoids loading config from file and metrics chain setup
LambdaFunctionHandler(LambdaInvokerConfig config) {
this.config = config;
Expand All @@ -70,6 +147,40 @@ public LambdaFunctionHandler() {
}

private LambdaAsyncClient initClient(LambdaInvokerConfig config) {
AwsCredentialsProvider credentialsProvider = null;
// If any STS type selected, use the matching credentials provider for automatic refresh.
if(STS_TYPE_FUNC_USER.equals(config.getStsType())) {
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is set to " + STS_TYPE_FUNC_USER + " for role: {}", config.getRoleArn());
stsClient = StsClient.builder()
.region(Region.of(config.getRegion()))
.build();
stsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(AssumeRoleRequest.builder()
.roleArn(config.getRoleArn())
.roleSessionName(config.getRoleSessionName())
.durationSeconds(config.getDurationSeconds())
.build())
.build();
credentialsProvider = stsCredentialsProvider;
} else if(STS_TYPE_WEB_IDENTITY.equals(config.getStsType())) {
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is set to " + STS_TYPE_WEB_IDENTITY + " for role: {}", config.getRoleArn());
stsClient = StsClient.builder()
.region(Region.of(config.getRegion()))
.build();
stsWebIdentityCredentialsProvider = buildMutableStsWebIdentityCredentialsProvider(config, stsClient);
credentialsProvider = stsWebIdentityCredentialsProvider;
} else {
if(logger.isInfoEnabled()) logger.info("No STS AssumeRole is set. Using default credential provider chain for LambdaAsyncClient.");
}
return buildLambdaClient(config, credentialsProvider);
}

MutableStsWebIdentityCredentialsProvider buildMutableStsWebIdentityCredentialsProvider(LambdaInvokerConfig config, StsClient stsClient) {
return new MutableStsWebIdentityCredentialsProvider(config, stsClient);
}

LambdaAsyncClient buildLambdaClient(LambdaInvokerConfig config, AwsCredentialsProvider credentialsProvider) {
SdkAsyncHttpClient asyncHttpClient = NettyNioAsyncHttpClient.builder()
.readTimeout(Duration.ofMillis(config.getApiCallAttemptTimeout()))
.writeTimeout(Duration.ofMillis(config.getApiCallAttemptTimeout()))
Expand Down Expand Up @@ -103,26 +214,24 @@ private LambdaAsyncClient initClient(LambdaInvokerConfig config) {
builder.endpointOverride(URI.create(config.getEndpointOverride()));
}

// If STS is enabled, use StsAssumeRoleCredentialsProvider for automatic credential refresh
if(config.isStsEnabled()) {
if(logger.isInfoEnabled()) logger.info("STS AssumeRole is enabled. Using StsAssumeRoleCredentialsProvider for role: {}", config.getRoleArn());
stsClient = StsClient.builder()
.region(Region.of(config.getRegion()))
.build();
stsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(AssumeRoleRequest.builder()
.roleArn(config.getRoleArn())
.roleSessionName(config.getRoleSessionName())
.durationSeconds(config.getDurationSeconds())
.build())
.build();
builder.credentialsProvider(stsCredentialsProvider);
if(credentialsProvider != null) {
builder.credentialsProvider(credentialsProvider);
}

return builder.build();
}

boolean updateWebIdentityToken(String token) {
if(stsWebIdentityCredentialsProvider == null) {
throw new IllegalStateException("STS web identity credentials provider is not configured");
}
return stsWebIdentityCredentialsProvider.updateToken(token);
}

String currentWebIdentityTokenFingerprint() {
return stsWebIdentityCredentialsProvider == null ? null : stsWebIdentityCredentialsProvider.getTokenFingerprint();
}

@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
LambdaInvokerConfig newConfig = LambdaInvokerConfig.load();
Expand All @@ -146,6 +255,14 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
}
stsCredentialsProvider = null;
}
if(stsWebIdentityCredentialsProvider != null) {
try {
stsWebIdentityCredentialsProvider.close();
} catch (Exception e) {
logger.error("Failed to close the StsAssumeRoleWithWebIdentityCredentialsProvider", e);
}
stsWebIdentityCredentialsProvider = null;
}
if(stsClient != null) {
try {
stsClient.close();
Expand Down Expand Up @@ -187,6 +304,21 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
if(config.isMetricsInjection() && metricsHandler != null) metricsHandler.injectMetrics(exchange, startTime, config.getMetricsName(), endpoint);
return;
}
if(STS_TYPE_WEB_IDENTITY.equals(config.getStsType())) {
String rawAuthHeader = exchange.getRequestHeaders().getFirst(Headers.AUTHORIZATION);
String token = extractBearerToken(rawAuthHeader);
if(token == null || token.isEmpty()) {
exchange.setStatusCode(401);
exchange.getResponseSender().send(INVALID_WEB_IDENTITY_TOKEN_MESSAGE);
if(config.isMetricsInjection() && metricsHandler != null) metricsHandler.injectMetrics(exchange, startTime, config.getMetricsName(), endpoint);
return;
}
if(updateWebIdentityToken(token)) {
if(logger.isDebugEnabled()) logger.debug("Authorization token changed. Refreshed the shared STS web identity credentials provider.");
} else {
if(logger.isDebugEnabled()) logger.debug("Authorization token unchanged. Reusing the shared STS web identity credentials provider.");
}
}
APIGatewayProxyRequestEvent requestEvent = new APIGatewayProxyRequestEvent();
requestEvent.setHttpMethod(httpMethod);
requestEvent.setPath(requestPath);
Expand Down Expand Up @@ -277,4 +409,38 @@ private void setResponseHeaders(HttpServerExchange exchange, Map<String, String>
}
}
}

/**
* Extracts the bearer token from a raw Authorization header value.
* Returns the token string if the header starts with "Bearer " (case-insensitive),
* or {@code null} if the header is missing/empty or does not use the Bearer scheme.
*/
static String extractBearerToken(String authorizationHeaderValue) {
if (authorizationHeaderValue == null || authorizationHeaderValue.isEmpty()) {
if(logger.isDebugEnabled()) logger.debug("Missing Authorization header from request. STS AssumeRole with Web Identity may fail");
return null;
}
if (authorizationHeaderValue.length() > BEARER_PREFIX.length() + 1 &&
authorizationHeaderValue.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length()) &&
authorizationHeaderValue.charAt(BEARER_PREFIX.length()) == ' ') {
String token = authorizationHeaderValue.substring(BEARER_PREFIX.length() + 1).trim();
if (token.isEmpty()) {
if(logger.isDebugEnabled()) logger.debug("Authorization header contains a blank Bearer token. STS AssumeRole with Web Identity may fail");
return null;
}
return token;
}
if(logger.isDebugEnabled()) logger.debug("Authorization header does not start with Bearer. STS AssumeRole with Web Identity may fail");
return null;
}

static String fingerprintToken(String token) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
byte[] hashed = digest.digest(token.getBytes(StandardCharsets.UTF_8));
return Base64.getEncoder().encodeToString(hashed);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("SHA-256 is not available", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class LambdaInvokerConfig {
private static final String MAX_CONCURRENCY = "maxConcurrency";
private static final String MAX_PENDING_CONNECTION_ACQUIRES = "maxPendingConnectionAcquires";
private static final String CONNECTION_ACQUISITION_TIMEOUT = "connectionAcquisitionTimeout";
private static final String STS_ENABLED = "stsEnabled";
private static final String STS_TYPE = "stsType";
private static final String ROLE_ARN = "roleArn";
private static final String ROLE_SESSION_NAME = "roleSessionName";
private static final String DURATION_SECONDS = "durationSeconds";
Expand Down Expand Up @@ -152,21 +152,25 @@ public class LambdaInvokerConfig {
)
private String metricsName;

@BooleanField(
configFieldName = STS_ENABLED,
externalizedKeyName = STS_ENABLED,
@StringField(
configFieldName = STS_TYPE,
externalizedKeyName = STS_TYPE,
description = "Enable STS AssumeRole to obtain temporary credentials for Lambda invocation instead of using the\n" +
"permanent IAM credentials. When set to true, the handler will call STS AssumeRole with the configured\n" +
"roleArn, roleSessionName, and durationSeconds to get short-lived credentials. This is the recommended\n" +
"approach for production environments to follow the principle of least privilege.\n",
defaultValue = "false"
"permanent IAM credentials. Only 2 STS types supported: StsFuncUser and StsWebIdentity.\n" +
"If STS is not to be used set this property as empty. When StsFuncUser is set the handler will\n" +
"use the configured AWS IAM User to assume the given RoleARN. When StsWebIdentity is set the handler will\n" +
"use the request bearer token as the WEB_IDENTITY_TOKEN to be exchanged for STS token. Regardless of the\n" +
"selected type, the handler will call STS AssumeRole with the configured roleArn, roleSessionName, and\n" +
"durationSeconds to get short-lived credentials. Using one of the STS types is the recommended approach\n" +
"for production environments to follow the principle of least privilege.\n",
defaultValue = ""
)
private boolean stsEnabled;
private String stsType;

@StringField(
configFieldName = ROLE_ARN,
externalizedKeyName = ROLE_ARN,
description = "The ARN of the IAM role to assume when stsEnabled is true. For example,\n" +
description = "The ARN of the IAM role to assume when stsType is not empty. For example,\n" +
"arn:aws:iam::123456789012:role/LambdaInvokerRole\n"
)
private String roleArn;
Expand Down Expand Up @@ -329,12 +333,12 @@ public void setConnectionAcquisitionTimeout(int connectionAcquisitionTimeout) {
this.connectionAcquisitionTimeout = connectionAcquisitionTimeout;
}

public boolean isStsEnabled() {
return stsEnabled;
public String getStsType() {
return stsType;
}

public void setStsEnabled(boolean stsEnabled) {
this.stsEnabled = stsEnabled;
public void setStsType(String stsType) {
this.stsType = stsType;
}

public String getRoleArn() {
Expand Down Expand Up @@ -398,8 +402,8 @@ private void setConfigData() {
object = mappedConfig.get(CONNECTION_ACQUISITION_TIMEOUT);
if (object != null) connectionAcquisitionTimeout = Config.loadIntegerValue(CONNECTION_ACQUISITION_TIMEOUT, object);

object = mappedConfig.get(STS_ENABLED);
if(object != null) stsEnabled = Config.loadBooleanValue(STS_ENABLED, object);
object = mappedConfig.get(STS_TYPE);
if (object != null) stsType = (String) object;

object = mappedConfig.get(ROLE_ARN);
if(object != null) roleArn = (String) object;
Expand Down Expand Up @@ -456,8 +460,17 @@ private void setConfigMap() {
}

private void validate() {
if (stsEnabled && (roleArn == null || roleArn.trim().isEmpty())) {
throw new ConfigException(ROLE_ARN + " must be configured when " + STS_ENABLED + " is true.");
String normalizedStsType = stsType == null ? null : stsType.trim();
// Write normalized value back so downstream equals() comparisons work correctly
// even when the config value has leading/trailing whitespace (e.g. "StsWebIdentity ").
stsType = normalizedStsType;
if (normalizedStsType != null && !normalizedStsType.isEmpty()) {
if (!"StsFuncUser".equals(normalizedStsType) && !"StsWebIdentity".equals(normalizedStsType)) {
throw new ConfigException(STS_TYPE + " must be one of [StsFuncUser, StsWebIdentity], but was: " + normalizedStsType);
}
if (roleArn == null || roleArn.trim().isEmpty()) {
throw new ConfigException(ROLE_ARN + " must be configured when " + STS_TYPE + " is not empty.");
}
}
}
}
Loading