Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/keyvault/azure-security-keyvault-jca/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 2.11.0-beta.1 (Unreleased)

### Features Added
- Add support for Workload Identity authentication in Azure Kubernetes Service (AKS).

### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ private AccessToken getAccessTokenByHttpRequest() {
disableChallengeResourceVerification);
accessToken
= AccessTokenUtil.getAccessToken(resource, aadAuthenticationUri, tenantId, clientId, clientSecret);
} else if (AccessTokenUtil.isFederatedTokenFileConfigured()) {
accessToken = AccessTokenUtil.getAccessTokenUsingWorkloadIdentity(keyVaultBaseUri, tenantId, clientId);
} else {
accessToken = AccessTokenUtil.getAccessToken(resource, managedIdentity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
import java.util.logging.Logger;

import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.addTrailingSlashIfRequired;
Expand Down Expand Up @@ -78,6 +83,11 @@ public final class AccessTokenUtil {
private static final String PROPERTY_IDENTITY_ENDPOINT = "IDENTITY_ENDPOINT";
private static final String PROPERTY_IDENTITY_HEADER = "IDENTITY_HEADER";

private static final String ENV_AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE";
private static final String ENV_AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
private static final String ENV_AZURE_TENANT_ID = "AZURE_TENANT_ID";
private static final String ENV_AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST";

/**
* Get an access token for a managed identity.
*
Expand Down Expand Up @@ -168,6 +178,110 @@ public static AccessToken getAccessToken(String resource, String aadAuthenticati
return result;
}

public static boolean isFederatedTokenFileConfigured() {
String federatedTokenFilePath = System.getenv(ENV_AZURE_FEDERATED_TOKEN_FILE);
return !isNullOrBlank(federatedTokenFilePath);
}

/**
* Get an access token via <a href="https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential">client creds grant flow</a>
* using <a href="https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview">Microsoft Entra Workload ID with AKS</a>.
* Uses the federate token in file located at environment variable <code>AZURE_FEDERATED_TOKEN_FILE</code>
* and provided <code>clientId</code> and <code>tenantId</code> to issue an access token HTTP request.
*
* @param keyVaultBaseUri Base URI of the keyvault.
* @param tenantId Tenant ID to use. If blank fallback to environment variable <code>AZURE_TENANT_ID</code>
* @param clientId Client ID of the managed identity to use. If blank fallback to environment variable <code>AZURE_CLIENT_ID</code>
* @return An access token.
*/
public static AccessToken getAccessTokenUsingWorkloadIdentity(String keyVaultBaseUri, String tenantId,
String clientId) {
LOGGER.entering("AccessTokenUtil", "getAccessTokenUsingWorkloadIdentity",
new Object[] { keyVaultBaseUri, tenantId, clientId });
LOGGER.info("Getting access token using federated Workload Identity token");

String tokenFilePath = System.getenv(ENV_AZURE_FEDERATED_TOKEN_FILE);
LOGGER.log(INFO, "Using federated token file: {0}", tokenFilePath);

tenantId = useDefaultIfBlank(tenantId, () -> System.getenv(ENV_AZURE_TENANT_ID));
clientId = useDefaultIfBlank(clientId, () -> System.getenv(ENV_AZURE_CLIENT_ID));
LOGGER.log(INFO, "Using clientId {0} in tenantId {1}", new Object[] { clientId, tenantId });

// scope is required to end with "/.default"
if (!keyVaultBaseUri.endsWith(".default")) {
keyVaultBaseUri = addTrailingSlashIfRequired(keyVaultBaseUri) + ".default";
}

// allow override of authority host via environment variable
String authorityHost = useDefaultIfBlank(System.getenv(ENV_AZURE_AUTHORITY_HOST), () -> OAUTH2_TOKEN_BASE_URL);

AccessToken result = null;

String federatedToken = readFile(tokenFilePath);
if (!isNullOrBlank(federatedToken)) {
String requestUrl = addTrailingSlashIfRequired(authorityHost) + tenantId + "/oauth2/v2.0/token";
String requestBody = "grant_type=client_credentials" + "&client_id=" + urlEncode(clientId)
+ "&client_assertion_type=" + urlEncode("urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
+ "&client_assertion=" + urlEncode(federatedToken) + "&scope=" + urlEncode(keyVaultBaseUri);

String response = HttpUtil.post(requestUrl, requestBody, "application/x-www-form-urlencoded");
result = parseAccessTokenResponse(response);
} else {
LOGGER.log(WARNING, "Failed to read federated token from file: {0}", tokenFilePath);
}

LOGGER.exiting("AccessTokenUtil", "getAccessTokenUsingWorkloadIdentity", result);

return result;
}

private static String useDefaultIfBlank(String value, Supplier<String> defaultValueSupplier) {
if (isNullOrBlank(value)) {
return defaultValueSupplier.get();
}
return value;
}

private static boolean isNullOrBlank(String value) {
return value == null || value.trim().isEmpty();
}

private static String urlEncode(String text) {
if (text == null) {
return null;
}

try {
return URLEncoder.encode(text, "UTF-8");
} catch (UnsupportedEncodingException e) {
LOGGER.log(WARNING, "Failed to encode text.", e);
return null;
}
}

private static AccessToken parseAccessTokenResponse(String response) {
if (response == null) {
return null;
}

try {
return JsonConverterUtil.fromJson(AccessToken::fromJson, response);
} catch (IOException e) {
LOGGER.log(WARNING, "Failed to parse access token from response.", e);
return null;
}
}

static String readFile(String filePath) {
try {
Path path = Paths.get(filePath);
return new String(Files.readAllBytes(path), StandardCharsets.UTF_8).trim();
} catch (IOException e) {
LOGGER.log(WARNING, "Failed to read file.", e);
return null;
}
}

/**
* Get the access token on Azure App Service.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.security.keyvault.jca;
package com.azure.security.keyvault.jca.implementation.utils;

import com.azure.security.keyvault.jca.PropertyConvertorUtils;
import com.azure.security.keyvault.jca.implementation.model.AccessToken;
import com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.io.TempDir;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

import static com.azure.security.keyvault.jca.implementation.utils.AccessTokenUtil.getLoginUri;
import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.API_VERSION_POSTFIX;
import static com.azure.security.keyvault.jca.implementation.utils.HttpUtil.addTrailingSlashIfRequired;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.*;

/**
* The JUnit test for the AuthClient.
Expand Down Expand Up @@ -46,4 +49,15 @@ public void testGetLoginUri() {
assertNotNull(result);
assertDoesNotThrow(() -> new URI(result));
}

@Test
void testReadFile(@TempDir Path tempDir) throws Exception {
Path tempFile = Files.createTempFile(tempDir, "simple_text_file_", ".txt");
String expectedContent = "Just a dummy string";
Files.write(tempFile, expectedContent.getBytes(StandardCharsets.UTF_8));

String actualContent = AccessTokenUtil.readFile(tempFile.toAbsolutePath().toString());
assertNotNull(actualContent);
assertEquals(expectedContent, actualContent);
}
}