From 5bdda249850ffdf6d9506727dd3ebb267fc46a61 Mon Sep 17 00:00:00 2001 From: Karen Chen <64801825+karenc-bq@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:28:44 -0800 Subject: [PATCH 1/2] feat: IAM support for GDB --- .../iam_authentication_plugin.ts | 35 ++-- common/lib/plugin_manager.ts | 1 + .../custom_endpoint/custom_endpoint_plugin.ts | 3 +- .../federated_auth/federated_auth_plugin.ts | 115 +----------- .../federated_auth/okta_auth_plugin.ts | 117 +----------- .../federated_auth/saml_auth_plugin.ts | 167 ++++++++++++++++++ .../saml_credentials_provider_factory.ts | 7 +- common/lib/utils/gdb_region_utils.ts | 102 +++++++++++ common/lib/utils/iam_auth_utils.ts | 35 ++-- common/lib/utils/messages.ts | 2 + common/lib/utils/region_utils.ts | 19 +- tests/unit/federated_auth_plugin.test.ts | 164 ----------------- tests/unit/iam_authentication_plugin.test.ts | 39 ++-- tests/unit/okta_auth_plugin.test.ts | 166 ----------------- tests/unit/saml_auth_plugin.test.ts | 161 +++++++++++++++++ 15 files changed, 513 insertions(+), 620 deletions(-) create mode 100644 common/lib/plugins/federated_auth/saml_auth_plugin.ts create mode 100644 common/lib/utils/gdb_region_utils.ts delete mode 100644 tests/unit/federated_auth_plugin.test.ts delete mode 100644 tests/unit/okta_auth_plugin.test.ts create mode 100644 tests/unit/saml_auth_plugin.test.ts diff --git a/common/lib/authentication/iam_authentication_plugin.ts b/common/lib/authentication/iam_authentication_plugin.ts index e129514e1..9b1f8cf91 100644 --- a/common/lib/authentication/iam_authentication_plugin.ts +++ b/common/lib/authentication/iam_authentication_plugin.ts @@ -26,19 +26,26 @@ import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils"; import { ClientWrapper } from "../client_wrapper"; import { RegionUtils } from "../utils/region_utils"; import { CanReleaseResources } from "../can_release_resources"; +import { RdsUrlType } from "../utils/rds_url_type"; +import { RdsUtils } from "../utils/rds_utils"; +import { GDBRegionUtils } from "../utils/gdb_region_utils"; export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements CanReleaseResources { private static readonly SUBSCRIBED_METHODS = new Set(["connect", "forceConnect"]); protected static readonly tokenCache = new Map(); private readonly telemetryFactory; private readonly fetchTokenCounter; - private pluginService: PluginService; + private readonly pluginService: PluginService; + private readonly rdsUtils: RdsUtils = new RdsUtils(); + protected regionUtils: RegionUtils; + protected readonly iamAuthUtils: IamAuthUtils; - constructor(pluginService: PluginService) { + constructor(pluginService: PluginService, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) { super(); this.pluginService = pluginService; this.telemetryFactory = this.pluginService.getTelemetryFactory(); this.fetchTokenCounter = this.telemetryFactory.createCounter("iam.fetchTokenCount"); + this.iamAuthUtils = iamAuthUtils; } getSubscribedMethods(): Set { @@ -74,14 +81,22 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements throw new AwsWrapperError(`${WrapperProperties.USER} is null or empty`); } - const host = IamAuthUtils.getIamHost(props, hostInfo); - const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); - const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort); + const host = this.iamAuthUtils.getIamHost(props, hostInfo); + const port = this.iamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort); + + const type: RdsUrlType = this.rdsUtils.identifyRdsType(host.host); + this.regionUtils = type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER ? new GDBRegionUtils() : new RegionUtils(); + const region: string | null = await this.regionUtils.getRegion(WrapperProperties.IAM_REGION.name, host, props); + + if (!region) { + throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unableToDetermineRegion", WrapperProperties.IAM_REGION.name)); + } + const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props); if (tokenExpirationSec < 0) { throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero")); } - const cacheKey: string = IamAuthUtils.getCacheKey(port, user, host, region); + const cacheKey: string = this.iamAuthUtils.getCacheKey(port, user, host.host, region); const tokenInfo = IamAuthenticationPlugin.tokenCache.get(cacheKey); const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired(); @@ -91,8 +106,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements WrapperProperties.PASSWORD.set(props, tokenInfo.token); } else { const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000; - const token = await IamAuthUtils.generateAuthenticationToken( - host, + const token = await this.iamAuthUtils.generateAuthenticationToken( + host.host, port, region, user, @@ -118,8 +133,8 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin implements // Try to generate a new token and try to connect again const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000; - const token = await IamAuthUtils.generateAuthenticationToken( - host, + const token = await this.iamAuthUtils.generateAuthenticationToken( + host.host, port, region, user, diff --git a/common/lib/plugin_manager.ts b/common/lib/plugin_manager.ts index 492566962..f66e65baa 100644 --- a/common/lib/plugin_manager.ts +++ b/common/lib/plugin_manager.ts @@ -32,6 +32,7 @@ import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level"; import { ConnectionProvider } from "./connection_provider"; import { ConnectionPluginFactory } from "./plugin_factory"; import { ConfigurationProfile } from "./profile/configuration_profile"; +import { BaseSamlAuthPlugin } from "./plugins/federated_auth/saml_auth_plugin"; type PluginFunc = (plugin: ConnectionPlugin, targetFunc: () => Promise) => Promise; diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts index 99e5bfb42..3b73c60ac 100644 --- a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts @@ -106,7 +106,8 @@ export class CustomEndpointPlugin extends AbstractConnectionPlugin implements Ca throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.errorParsingEndpointIdentifier", this.customEndpointHostInfo.host)); } - this.region = RegionUtils.getRegion(props.get(WrapperProperties.CUSTOM_ENDPOINT_REGION.name), this.customEndpointHostInfo.host); + const regionUtils = new RegionUtils(); + this.region = await regionUtils.getRegion(WrapperProperties.CUSTOM_ENDPOINT_REGION.name, this.customEndpointHostInfo, props); if (!this.region) { throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.unableToDetermineRegion", WrapperProperties.CUSTOM_ENDPOINT_REGION.name)); } diff --git a/common/lib/plugins/federated_auth/federated_auth_plugin.ts b/common/lib/plugins/federated_auth/federated_auth_plugin.ts index 2a440dd54..406720379 100644 --- a/common/lib/plugins/federated_auth/federated_auth_plugin.ts +++ b/common/lib/plugins/federated_auth/federated_auth_plugin.ts @@ -14,118 +14,13 @@ limitations under the License. */ -import { AbstractConnectionPlugin } from "../../abstract_connection_plugin"; import { PluginService } from "../../plugin_service"; -import { RdsUtils } from "../../utils/rds_utils"; -import { HostInfo } from "../../host_info"; -import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils"; -import { WrapperProperties } from "../../wrapper_property"; -import { logger } from "../../../logutils"; -import { AwsWrapperError } from "../../utils/errors"; -import { Messages } from "../../utils/messages"; import { CredentialsProviderFactory } from "./credentials_provider_factory"; -import { SamlUtils } from "../../utils/saml_utils"; -import { ClientWrapper } from "../../client_wrapper"; -import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; -import { RegionUtils } from "../../utils/region_utils"; -import { CanReleaseResources } from "../../can_release_resources"; +import { BaseSamlAuthPlugin } from "./saml_auth_plugin"; +import { IamAuthUtils } from "../../utils/iam_auth_utils"; -export class FederatedAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources { - protected static readonly tokenCache = new Map(); - protected rdsUtils: RdsUtils = new RdsUtils(); - protected pluginService: PluginService; - private static readonly subscribedMethods = new Set(["connect", "forceConnect"]); - private readonly credentialsProviderFactory: CredentialsProviderFactory; - private readonly fetchTokenCounter: TelemetryCounter; - - public getSubscribedMethods(): Set { - return FederatedAuthPlugin.subscribedMethods; - } - - constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) { - super(); - this.credentialsProviderFactory = credentialsProviderFactory; - this.pluginService = pluginService; - this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("federatedAuth.fetchToken.count"); - } - - connect( - hostInfo: HostInfo, - props: Map, - isInitialConnection: boolean, - connectFunc: () => Promise - ): Promise { - return this.connectInternal(hostInfo, props, connectFunc); - } - - forceConnect( - hostInfo: HostInfo, - props: Map, - isInitialConnection: boolean, - forceConnectFunc: () => Promise - ): Promise { - return this.connectInternal(hostInfo, props, forceConnectFunc); - } - - async connectInternal(hostInfo: HostInfo, props: Map, connectFunc: () => Promise): Promise { - SamlUtils.checkIdpCredentialsWithFallback(props); - - const host = IamAuthUtils.getIamHost(props, hostInfo); - const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); - - const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region); - const tokenInfo = FederatedAuthPlugin.tokenCache.get(cacheKey); - - const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired(); - - if (isCachedToken && tokenInfo) { - logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token)); - WrapperProperties.PASSWORD.set(props, tokenInfo.token); - } else { - await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host); - } - WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)); - this.pluginService.updateConfigWithProperties(props); - - try { - return await connectFunc(); - } catch (e) { - if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) { - throw e; - } - try { - await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host); - return await connectFunc(); - } catch (e: any) { - throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message)); - } - } - } - - public async updateAuthenticationToken(hostInfo: HostInfo, props: Map, region: string, cacheKey: string, iamHost: string) { - const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props); - if (tokenExpirationSec < 0) { - throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero")); - } - const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000; - const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - const token = await IamAuthUtils.generateAuthenticationToken( - iamHost, - port, - region, - WrapperProperties.DB_USER.get(props), - await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props), - this.pluginService - ); - this.fetchTokenCounter.inc(); - logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token)); - WrapperProperties.PASSWORD.set(props, token); - FederatedAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry)); - } - - releaseResources(): Promise { - FederatedAuthPlugin.tokenCache.clear(); - return; +export class FederatedAuthPlugin extends BaseSamlAuthPlugin { + constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) { + super(pluginService, credentialsProviderFactory, "federatedAuth.fetchToken.count", iamAuthUtils); } } diff --git a/common/lib/plugins/federated_auth/okta_auth_plugin.ts b/common/lib/plugins/federated_auth/okta_auth_plugin.ts index 9de79bd15..31b6144f8 100644 --- a/common/lib/plugins/federated_auth/okta_auth_plugin.ts +++ b/common/lib/plugins/federated_auth/okta_auth_plugin.ts @@ -14,120 +14,13 @@ limitations under the License. */ -import { AbstractConnectionPlugin } from "../../abstract_connection_plugin"; -import { HostInfo } from "../../host_info"; -import { SamlUtils } from "../../utils/saml_utils"; -import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils"; import { PluginService } from "../../plugin_service"; import { CredentialsProviderFactory } from "./credentials_provider_factory"; -import { RdsUtils } from "../../utils/rds_utils"; -import { WrapperProperties } from "../../wrapper_property"; -import { logger } from "../../../logutils"; -import { Messages } from "../../utils/messages"; -import { AwsWrapperError } from "../../utils/errors"; -import { ClientWrapper } from "../../client_wrapper"; -import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; -import { RegionUtils } from "../../utils/region_utils"; -import { CanReleaseResources } from "../../can_release_resources"; +import { BaseSamlAuthPlugin } from "./saml_auth_plugin"; +import { IamAuthUtils } from "../../utils/iam_auth_utils"; -export class OktaAuthPlugin extends AbstractConnectionPlugin implements CanReleaseResources { - protected static readonly tokenCache = new Map(); - private static readonly subscribedMethods = new Set(["connect", "forceConnect"]); - protected pluginService: PluginService; - protected rdsUtils = new RdsUtils(); - private readonly credentialsProviderFactory: CredentialsProviderFactory; - private readonly fetchTokenCounter: TelemetryCounter; - - constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory) { - super(); - this.pluginService = pluginService; - this.credentialsProviderFactory = credentialsProviderFactory; - this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter("oktaAuth.fetchToken.count"); - } - - public getSubscribedMethods(): Set { - return OktaAuthPlugin.subscribedMethods; - } - - connect( - hostInfo: HostInfo, - props: Map, - isInitialConnection: boolean, - connectFunc: () => Promise - ): Promise { - return this.connectInternal(hostInfo, props, connectFunc); - } - - forceConnect( - hostInfo: HostInfo, - props: Map, - isInitialConnection: boolean, - connectFunc: () => Promise - ): Promise { - return this.connectInternal(hostInfo, props, connectFunc); - } - - async connectInternal(hostInfo: HostInfo, props: Map, connectFunc: () => Promise): Promise { - SamlUtils.checkIdpCredentialsWithFallback(props); - - const host = IamAuthUtils.getIamHost(props, hostInfo); - const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - const region = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); - - const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region); - const tokenInfo = OktaAuthPlugin.tokenCache.get(cacheKey); - - const isCachedToken = tokenInfo !== undefined && !tokenInfo.isExpired(); - - if (isCachedToken) { - logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token)); - WrapperProperties.PASSWORD.set(props, tokenInfo.token); - } else { - await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host); - } - WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)); - this.pluginService.updateConfigWithProperties(props); - - try { - return await connectFunc(); - } catch (e: any) { - if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) { - logger.debug(Messages.get("Authentication.connectError", e.message)); - throw e; - } - try { - await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host); - return await connectFunc(); - } catch (e: any) { - throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message)); - } - } - } - - public async updateAuthenticationToken(hostInfo: HostInfo, props: Map, region: string, cacheKey: string, iamHost): Promise { - const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props); - if (tokenExpirationSec < 0) { - throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero")); - } - const tokenExpiry = Date.now() + tokenExpirationSec * 1000; - const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - this.fetchTokenCounter.inc(); - const token = await IamAuthUtils.generateAuthenticationToken( - iamHost, - port, - region, - WrapperProperties.DB_USER.get(props), - await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props), - this.pluginService - ); - logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token)); - WrapperProperties.PASSWORD.set(props, token); - this.pluginService.updateConfigWithProperties(props); - OktaAuthPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry)); - } - - releaseResources(): Promise { - OktaAuthPlugin.tokenCache.clear(); - return; +export class OktaAuthPlugin extends BaseSamlAuthPlugin { + constructor(pluginService: PluginService, credentialsProviderFactory: CredentialsProviderFactory, iamAuthUtils: IamAuthUtils = new IamAuthUtils()) { + super(pluginService, credentialsProviderFactory, "oktaAuth.fetchToken.count", iamAuthUtils); } } diff --git a/common/lib/plugins/federated_auth/saml_auth_plugin.ts b/common/lib/plugins/federated_auth/saml_auth_plugin.ts new file mode 100644 index 000000000..7aa48db07 --- /dev/null +++ b/common/lib/plugins/federated_auth/saml_auth_plugin.ts @@ -0,0 +1,167 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { AbstractConnectionPlugin } from "../../abstract_connection_plugin"; +import { PluginService } from "../../plugin_service"; +import { RdsUtils } from "../../utils/rds_utils"; +import { HostInfo } from "../../host_info"; +import { IamAuthUtils, TokenInfo } from "../../utils/iam_auth_utils"; +import { WrapperProperties } from "../../wrapper_property"; +import { logger } from "../../../logutils"; +import { AwsWrapperError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; +import { CredentialsProviderFactory } from "./credentials_provider_factory"; +import { SamlUtils } from "../../utils/saml_utils"; +import { ClientWrapper } from "../../client_wrapper"; +import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { RegionUtils } from "../../utils/region_utils"; +import { RdsUrlType } from "../../utils/rds_url_type"; +import { GDBRegionUtils } from "../../utils/gdb_region_utils"; +import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity"; + +export class BaseSamlAuthPlugin extends AbstractConnectionPlugin { + protected static readonly tokenCache = new Map(); + protected rdsUtils: RdsUtils = new RdsUtils(); + protected pluginService: PluginService; + private static readonly subscribedMethods = new Set(["connect", "forceConnect"]); + protected readonly credentialsProviderFactory: CredentialsProviderFactory; + protected readonly fetchTokenCounter: TelemetryCounter; + protected regionUtils: RegionUtils; + protected readonly tokenCacheInstance: Map; + + private readonly iamAuthUtils: IamAuthUtils; + + public getSubscribedMethods(): Set { + return BaseSamlAuthPlugin.subscribedMethods; + } + + protected constructor( + pluginService: PluginService, + credentialsProviderFactory: CredentialsProviderFactory, + telemetryCounterName: string, + iamAuthUtils: IamAuthUtils = new IamAuthUtils() + ) { + super(); + this.credentialsProviderFactory = credentialsProviderFactory; + this.pluginService = pluginService; + this.fetchTokenCounter = this.pluginService.getTelemetryFactory().createCounter(telemetryCounterName); + this.tokenCacheInstance = BaseSamlAuthPlugin.tokenCache; + this.iamAuthUtils = iamAuthUtils; + } + + connect( + hostInfo: HostInfo, + props: Map, + isInitialConnection: boolean, + connectFunc: () => Promise + ): Promise { + return this.connectInternal(hostInfo, props, connectFunc); + } + + forceConnect( + hostInfo: HostInfo, + props: Map, + isInitialConnection: boolean, + forceConnectFunc: () => Promise + ): Promise { + return this.connectInternal(hostInfo, props, forceConnectFunc); + } + + async connectInternal(hostInfo: HostInfo, props: Map, connectFunc: () => Promise): Promise { + SamlUtils.checkIdpCredentialsWithFallback(props); + + const host = this.iamAuthUtils.getIamHost(props, hostInfo); + const port = this.iamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); + + const type: RdsUrlType = this.rdsUtils.identifyRdsType(host.host); + + let credentialsProvider: AwsCredentialIdentity | AwsCredentialIdentityProvider | undefined = undefined; + if (type === RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER) { + credentialsProvider = await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, null, props); + } + + this.regionUtils = type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER ? new GDBRegionUtils(credentialsProvider) : new RegionUtils(); + const region: string | null = await this.regionUtils.getRegion(WrapperProperties.IAM_REGION.name, host, props); + + if (!region) { + throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unableToDetermineRegion", WrapperProperties.IAM_REGION.name)); + } + + const cacheKey = this.iamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host.host, region); + const tokenInfo = this.tokenCacheInstance.get(cacheKey); + + const isCachedToken: boolean = tokenInfo !== undefined && !tokenInfo.isExpired(); + + if (isCachedToken && tokenInfo) { + logger.debug(Messages.get("AuthenticationToken.useCachedToken", tokenInfo.token)); + WrapperProperties.PASSWORD.set(props, tokenInfo.token); + } else { + await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host.host, credentialsProvider); + } + WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)); + this.pluginService.updateConfigWithProperties(props); + + try { + return await connectFunc(); + } catch (e: any) { + if (!this.pluginService.isLoginError(e as Error) || !isCachedToken) { + throw e; + } + try { + await this.updateAuthenticationToken(hostInfo, props, region, cacheKey, host.host, credentialsProvider); + return await connectFunc(); + } catch (e: any) { + throw new AwsWrapperError(Messages.get("SamlAuthPlugin.unhandledError", e.message)); + } + } + } + + public async updateAuthenticationToken( + hostInfo: HostInfo, + props: Map, + region: string, + cacheKey: string, + iamHost: string, + credentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider + ): Promise { + const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props); + if (tokenExpirationSec < 0) { + throw new AwsWrapperError(Messages.get("AuthenticationToken.tokenExpirationLessThanZero")); + } + const tokenExpiry: number = Date.now() + tokenExpirationSec * 1000; + const port = this.iamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); + + this.fetchTokenCounter.inc(); + + const token = await this.iamAuthUtils.generateAuthenticationToken( + iamHost, + port, + region, + WrapperProperties.DB_USER.get(props), + credentials ?? (await this.credentialsProviderFactory.getAwsCredentialsProvider(hostInfo.host, region, props)), + this.pluginService + ); + + logger.debug(Messages.get("AuthenticationToken.generatedNewToken", token)); + WrapperProperties.PASSWORD.set(props, token); + this.pluginService.updateConfigWithProperties(props); + this.tokenCacheInstance.set(cacheKey, new TokenInfo(token, tokenExpiry)); + } + + static releaseResources(): void { + BaseSamlAuthPlugin.tokenCache.clear(); + } +} diff --git a/common/lib/plugins/federated_auth/saml_credentials_provider_factory.ts b/common/lib/plugins/federated_auth/saml_credentials_provider_factory.ts index bfdfe7f6d..e6dfc397c 100644 --- a/common/lib/plugins/federated_auth/saml_credentials_provider_factory.ts +++ b/common/lib/plugins/federated_auth/saml_credentials_provider_factory.ts @@ -17,7 +17,6 @@ import { CredentialsProviderFactory } from "./credentials_provider_factory"; import { AssumeRoleWithSAMLCommand, STSClient } from "@aws-sdk/client-sts"; import { WrapperProperties } from "../../wrapper_property"; - import { AwsWrapperError } from "../../utils/errors"; import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity"; import { decode } from "entities"; @@ -25,7 +24,7 @@ import { decode } from "entities"; export abstract class SamlCredentialsProviderFactory implements CredentialsProviderFactory { async getAwsCredentialsProvider( host: string, - region: string, + region: string | null, props: Map ): Promise { const samlAssertion = await this.getSamlAssertion(props); @@ -35,9 +34,7 @@ export abstract class SamlCredentialsProviderFactory implements CredentialsProvi PrincipalArn: WrapperProperties.IAM_IDP_ARN.get(props) }); - const stsClient = new STSClient({ - region: region - }); + const stsClient = region !== null ? new STSClient({ region }) : new STSClient(); const results = await stsClient.send(assumeRoleWithSamlRequest); const credentials = results["Credentials"]; diff --git a/common/lib/utils/gdb_region_utils.ts b/common/lib/utils/gdb_region_utils.ts new file mode 100644 index 000000000..12bd9bb01 --- /dev/null +++ b/common/lib/utils/gdb_region_utils.ts @@ -0,0 +1,102 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { RegionUtils } from "./region_utils"; +import { HostInfo } from "../host_info"; +import { AwsCredentialsManager } from "../authentication/aws_credentials_manager"; +import { DescribeGlobalClustersCommand, GlobalCluster, GlobalClusterMember, RDSClient } from "@aws-sdk/client-rds"; +import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity"; + +export class GDBRegionUtils extends RegionUtils { + private static readonly GDB_CLUSTER_ARN_PATTERN = /^arn:aws:rds:(?[^:\n]*):([^:\n]*):([^:/\n]*[:/])?(.*)$/; + private static readonly REGION_GROUP = "region"; + private credentialsProvider: AwsCredentialIdentity | AwsCredentialIdentityProvider | undefined; + + constructor(credentialsProvider?: AwsCredentialIdentity | AwsCredentialIdentityProvider) { + super(); + this.credentialsProvider = credentialsProvider; + } + + async getRegion(regionKey: string, hostInfo?: HostInfo, props?: Map): Promise { + if (props.get(regionKey)) { + return super.getRegion(props.get(regionKey), hostInfo); + } + + if (!hostInfo || !props) { + return null; + } + + const clusterId = GDBRegionUtils.rdsUtils.getRdsClusterId(hostInfo.host); + if (!clusterId) { + return null; + } + + const writerClusterArn = await this.findWriterClusterArn(hostInfo, props, clusterId); + return writerClusterArn ? this.getRegionFromClusterArn(writerClusterArn) : null; + } + + private async findWriterClusterArn(hostInfo: HostInfo, props: Map, globalClusterIdentifier: string): Promise { + if (this.credentialsProvider != null) { + this.credentialsProvider = AwsCredentialsManager.getProvider(hostInfo, props); + } + + const rdsClient = this.getRdsClient(); + + try { + const command = new DescribeGlobalClustersCommand({ + GlobalClusterIdentifier: globalClusterIdentifier + }); + + const response = await rdsClient.send(command); + return this.extractWriterClusterArn(response.GlobalClusters); + } finally { + rdsClient.destroy(); + } + } + + private extractWriterClusterArn(globalClusters?: GlobalCluster[]): string | null { + if (!globalClusters) { + return null; + } + + for (const cluster of globalClusters) { + const writerArn = this.findWriterMemberArn(cluster.GlobalClusterMembers); + if (writerArn) { + return writerArn; + } + } + + return null; + } + + private findWriterMemberArn(members?: GlobalClusterMember[]): string | null { + if (!members) { + return null; + } + + const writerMember = members.find((member) => member.IsWriter); + return writerMember?.DBClusterArn ?? null; + } + + private getRegionFromClusterArn(clusterArn: string): string | null { + const match = clusterArn.match(GDBRegionUtils.GDB_CLUSTER_ARN_PATTERN); + return match?.groups?.[GDBRegionUtils.REGION_GROUP] ?? null; + } + + private getRdsClient(): RDSClient { + return new RDSClient({ credentials: this.credentialsProvider }); + } +} diff --git a/common/lib/utils/iam_auth_utils.ts b/common/lib/utils/iam_auth_utils.ts index 42525ddf8..e31169085 100644 --- a/common/lib/utils/iam_auth_utils.ts +++ b/common/lib/utils/iam_auth_utils.ts @@ -16,23 +16,26 @@ import { logger } from "../../logutils"; import { HostInfo } from "../host_info"; -import { WrapperProperties, WrapperProperty } from "../wrapper_property"; -import { AwsWrapperError } from "./errors"; +import { WrapperProperties } from "../wrapper_property"; import { Messages } from "./messages"; -import { RdsUtils } from "./rds_utils"; import { Signer } from "@aws-sdk/rds-signer"; import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity"; import { PluginService } from "../plugin_service"; import { TelemetryTraceLevel } from "./telemetry/telemetry_trace_level"; +import { HostInfoBuilder } from "../host_info_builder"; export class IamAuthUtils { - private static readonly TELEMETRY_FETCH_TOKEN = "fetch IAM token"; + private readonly TELEMETRY_FETCH_TOKEN = "fetch IAM token"; - public static getIamHost(props: Map, hostInfo: HostInfo): string { - return WrapperProperties.IAM_HOST.get(props) ? WrapperProperties.IAM_HOST.get(props) : hostInfo.host; + public getIamHost(props: Map, hostInfo: HostInfo): HostInfo { + const iamHost: string | null = WrapperProperties.IAM_HOST.get(props); + + return iamHost + ? new HostInfoBuilder({ hostAvailabilityStrategy: hostInfo.hostAvailabilityStrategy }).copyFrom(hostInfo).withHost(iamHost).build() + : hostInfo; } - public static getIamPort(props: Map, hostInfo: HostInfo, defaultPort: number): number { + public getIamPort(props: Map, hostInfo: HostInfo, defaultPort: number): number { const port = WrapperProperties.IAM_DEFAULT_PORT.get(props); if (port) { if (isNaN(port) || port <= 0) { @@ -49,23 +52,11 @@ export class IamAuthUtils { } } - public static getRdsRegion(hostname: string, rdsUtils: RdsUtils, props: Map, wrapperProperty: WrapperProperty): string { - const rdsRegion = rdsUtils.getRdsRegion(hostname); - - if (!rdsRegion) { - const errorMessage = Messages.get("Authentication.unsupportedHostname", hostname); - logger.debug(errorMessage); - throw new AwsWrapperError(errorMessage); - } - - return wrapperProperty.get(props) ? wrapperProperty.get(props) : rdsRegion; - } - - public static getCacheKey(port: number, user?: string, hostname?: string, region?: string): string { + public getCacheKey(port: number, user?: string, hostname?: string, region?: string): string { return `${region}:${hostname}:${port}:${user}`; } - public static async generateAuthenticationToken( + public async generateAuthenticationToken( hostname: string, port: number, region: string, @@ -74,7 +65,7 @@ export class IamAuthUtils { pluginService: PluginService ): Promise { const telemetryFactory = pluginService.getTelemetryFactory(); - const telemetryContext = telemetryFactory.openTelemetryContext(IamAuthUtils.TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); + const telemetryContext = telemetryFactory.openTelemetryContext(this.TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); return await telemetryContext.start(async () => { const signer = new Signer({ hostname: hostname, diff --git a/common/lib/utils/messages.ts b/common/lib/utils/messages.ts index abc4db8d5..cc4c77509 100644 --- a/common/lib/utils/messages.ts +++ b/common/lib/utils/messages.ts @@ -167,6 +167,8 @@ const MESSAGES: Record = { "Okta SAML Assertion request failed with HTTP status '%s', reason phrase '%s', and response '%s'", "SamlCredentialsProviderFactory.getSamlAssertionFailed": "Failed to get SAML Assertion due to error: '%s'", "SamlAuthPlugin.unhandledError": "Unhandled error: '%s'", + "SamlAuthPlugin.unableToDetermineRegion": + "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%s' property.", "HostAvailabilityStrategy.invalidMaxRetries": "Invalid value of '%s' for configuration parameter `hostAvailabilityStrategyMaxRetries`. It must be an integer greater or equal to 1.", "HostAvailabilityStrategy.invalidInitialBackoffTime": diff --git a/common/lib/utils/region_utils.ts b/common/lib/utils/region_utils.ts index 7b26a619c..992164a81 100644 --- a/common/lib/utils/region_utils.ts +++ b/common/lib/utils/region_utils.ts @@ -17,6 +17,7 @@ import { RdsUtils } from "./rds_utils"; import { AwsWrapperError } from "./errors"; import { Messages } from "./messages"; +import { HostInfo } from "../host_info"; export class RegionUtils { static readonly REGIONS: string[] = [ @@ -67,21 +68,21 @@ export class RegionUtils { protected static readonly rdsUtils = new RdsUtils(); - static getRegion(regionString: string, host?: string): string | null { - const region = RegionUtils.getRegionFromRegionString(regionString); + async getRegion(regionKey: string, hostInfo?: HostInfo, props?: Map): Promise { + const region = this.getRegionFromRegionString(props.get(regionKey)); if (region !== null) { - return region; + return Promise.resolve(region); } - if (host) { - return RegionUtils.getRegionFromHost(host); + if (hostInfo) { + return Promise.resolve(this.getRegionFromHost(hostInfo.host)); } - return region; + return Promise.resolve(region); } - private static getRegionFromRegionString(regionString: string): string { + private getRegionFromRegionString(regionString: string): string | null { if (!regionString) { return null; } @@ -94,12 +95,12 @@ export class RegionUtils { return region; } - private static getRegionFromHost(host: string): string | null { + private getRegionFromHost(host: string): string | null { const regionString = RegionUtils.rdsUtils.getRdsRegion(host); if (!regionString) { throw new AwsWrapperError(Messages.get("AwsSdk.unsupportedRegion", regionString)); } - return RegionUtils.getRegionFromRegionString(regionString); + return this.getRegionFromRegionString(regionString); } } diff --git a/tests/unit/federated_auth_plugin.test.ts b/tests/unit/federated_auth_plugin.test.ts deleted file mode 100644 index d696e4511..000000000 --- a/tests/unit/federated_auth_plugin.test.ts +++ /dev/null @@ -1,164 +0,0 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -import { HostInfo, HostRole, PluginManager } from "../../common/lib"; -import { FederatedAuthPlugin } from "../../common/lib/plugins/federated_auth/federated_auth_plugin"; -import { PluginServiceImpl } from "../../common/lib/plugin_service"; -import { IamAuthUtils, TokenInfo } from "../../common/lib/utils/iam_auth_utils"; -import { WrapperProperties } from "../../common/lib/wrapper_property"; -import { anything, instance, mock, spy, verify, when } from "ts-mockito"; -import { CredentialsProviderFactory } from "../../common/lib/plugins/federated_auth/credentials_provider_factory"; -import { DatabaseDialect } from "../../common/lib/database_dialect/database_dialect"; -import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; -import { jest } from "@jest/globals"; -import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; - -const testToken = "testToken"; -const defaultPort = 5432; -const pgCacheKey = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; -const dbUser = "iamUser"; -const expirationFiveMinutes = 5 * 60 * 1000; -const tokenCache = new Map(); - -const host = "pg.testdb.us-east-2.rds.amazonaws.com"; -const iamHost = "pg-123.testdb.us-east-2.rds.amazonaws.com"; -const hostInfo = new HostInfo(host, defaultPort, HostRole.WRITER); -const testTokenInfo = new TokenInfo(testToken, Date.now() + expirationFiveMinutes); - -const mockDialect = mock(); -const mockDialectInstance = instance(mockDialect); -const mockPluginService = mock(PluginServiceImpl); -const mockCredentialsProviderFactory = mock(); -const spyIamUtils = spy(IamAuthUtils); -const testCredentials = { - accessKeyId: "foo", - secretAccessKey: "bar", - sessionToken: "baz" -}; -const mockConnectFunc = jest.fn(() => { - return Promise.resolve(mock(PgClientWrapper)); -}); - -describe("federatedAuthTest", () => { - let spyPlugin: FederatedAuthPlugin; - let props: Map; - - beforeEach(() => { - when(mockPluginService.getDialect()).thenReturn(mockDialectInstance); - when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); - when(mockDialect.getDefaultPort()).thenReturn(defaultPort); - when(mockCredentialsProviderFactory.getAwsCredentialsProvider(anything(), anything(), anything())).thenResolve(instance(testCredentials)); - props = new Map(); - WrapperProperties.PLUGINS.set(props, "federatedAuth"); - WrapperProperties.DB_USER.set(props, dbUser); - spyPlugin = spy(new FederatedAuthPlugin(instance(mockPluginService), instance(mockCredentialsProviderFactory))); - }); - - afterEach(async () => { - await PluginManager.releaseResources(); - }); - - it("testCachedToken", async () => { - const spyPluginInstance = instance(spyPlugin); - FederatedAuthPlugin["tokenCache"].set(pgCacheKey, testTokenInfo); - - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - tokenCache.set(key, testTokenInfo); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testExpiredCachedToken", async () => { - const spyPluginInstance: FederatedAuthPlugin = instance(spyPlugin); - - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - const expiredToken = "expiredToken"; - const expiredTokenInfo = new TokenInfo(expiredToken, Date.now() - 300000); - - FederatedAuthPlugin["tokenCache"].set(key, expiredTokenInfo); - - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testNoCachedToken", async () => { - const spyPluginInstance = instance(spyPlugin); - - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testSpecifiedIamHostPortRegion", async () => { - const expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; - const expectedPort = 9876; - const expectedRegion = "us-west-2"; - - WrapperProperties.IAM_HOST.set(props, expectedHost); - WrapperProperties.IAM_DEFAULT_PORT.set(props, expectedPort); - WrapperProperties.IAM_REGION.set(props, expectedRegion); - - const key = `us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:${expectedPort}:iamUser`; - FederatedAuthPlugin["tokenCache"].set(key, testTokenInfo); - - const spyPluginInstance = instance(spyPlugin); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testIdpCredentialsFallback", async () => { - const expectedUser = "expectedUser"; - const expectedPassword = "expectedPassword"; - WrapperProperties.USER.set(props, expectedUser); - WrapperProperties.PASSWORD.set(props, expectedPassword); - - const spyPluginInstance = instance(spyPlugin); - - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - FederatedAuthPlugin["tokenCache"].set(key, testTokenInfo); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - expect(expectedUser).toBe(WrapperProperties.IDP_USERNAME.get(props)); - expect(expectedPassword).toBe(WrapperProperties.IDP_PASSWORD.get(props)); - }); - - it("testUsingIamHost", async () => { - WrapperProperties.IAM_HOST.set(props, iamHost); - const spyPluginInstance = instance(spyPlugin); - - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - verify(spyIamUtils.generateAuthenticationToken(iamHost, anything(), anything(), anything(), anything(), anything())).once(); - }); -}); diff --git a/tests/unit/iam_authentication_plugin.test.ts b/tests/unit/iam_authentication_plugin.test.ts index c6d384d46..f2ee620fb 100644 --- a/tests/unit/iam_authentication_plugin.test.ts +++ b/tests/unit/iam_authentication_plugin.test.ts @@ -21,7 +21,7 @@ import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availabili import { AwsClient } from "../../common/lib/aws_client"; import { WrapperProperties } from "../../common/lib/wrapper_property"; import fetch from "node-fetch"; -import { anything, instance, mock, spy, when } from "ts-mockito"; +import { anything, instance, mock, reset, spy, when } from "ts-mockito"; import { IamAuthUtils, TokenInfo } from "../../common/lib/utils/iam_auth_utils"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; @@ -58,22 +58,11 @@ const props = new Map(); const mockPluginService: PluginServiceImpl = mock(PluginServiceImpl); const mockClient: AwsClient = mock(AwsClient); -const spyIamAuthUtils = spy(IamAuthUtils); class IamAuthenticationPluginTestClass extends IamAuthenticationPlugin { put(key: string, token: TokenInfo) { IamAuthenticationPlugin.tokenCache.set(key, token); } - - public async generateAuthenticationToken( - hostInfo: HostInfo, - props: Map, - hostname: string, - port: number, - region: string - ): Promise { - return Promise.resolve(GENERATED_TOKEN); - } } async function testGenerateToken(info: HostInfo, plugin: IamAuthenticationPluginTestClass) { @@ -100,8 +89,11 @@ async function testToken(info: HostInfo, plugin: IamAuthenticationPlugin) { } describe("testIamAuth", () => { + let spyIamAuthUtils: IamAuthUtils; + beforeEach(() => { PluginManager.releaseResources(); + spyIamAuthUtils = spy(new IamAuthUtils()); props.clear(); props.set(WrapperProperties.USER.name, "postgresqlUser"); @@ -110,15 +102,20 @@ describe("testIamAuth", () => { when(mockPluginService.getCurrentClient()).thenReturn(instance(mockClient)); when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); + when(spyIamAuthUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve( GENERATED_TOKEN ); }); + afterEach(() => { + reset(spyIamAuthUtils); + }); + it("testPostgresConnectValidTokenInCache", async () => { when(mockClient.defaultPort).thenReturn(DEFAULT_PG_PORT); - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(PG_CACHE_KEY, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); await testToken(PG_HOST_INFO, plugin); }); @@ -129,7 +126,7 @@ describe("testIamAuth", () => { when(mockClient.defaultPort).thenReturn(DEFAULT_MYSQL_PORT); - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(MYSQL_CACHE_KEY, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); await testToken(MYSQL_HOST_INFO, plugin); @@ -139,7 +136,7 @@ describe("testIamAuth", () => { props.set(WrapperProperties.IAM_DEFAULT_PORT.name, 0); const cacheKeyWithNewPort: string = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${PG_HOST_INFO_WITH_PORT.port}:postgresqlUser`; - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(cacheKeyWithNewPort, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); @@ -151,7 +148,7 @@ describe("testIamAuth", () => { when(mockClient.defaultPort).thenReturn(DEFAULT_PG_PORT); const cacheKeyWithNewPort: string = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${DEFAULT_PG_PORT}:postgresqlUser`; - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(cacheKeyWithNewPort, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); @@ -160,7 +157,7 @@ describe("testIamAuth", () => { it("testConnectExpiredTokenInCache", async () => { when(mockClient.defaultPort).thenReturn(DEFAULT_PG_PORT); - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(PG_CACHE_KEY, new TokenInfo(TEST_TOKEN, Date.now() - 300000)); await testGenerateToken(PG_HOST_INFO, plugin); @@ -168,13 +165,13 @@ describe("testIamAuth", () => { it("testConnectEmptyCache", async () => { when(mockClient.defaultPort).thenReturn(DEFAULT_PG_PORT); - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); await testGenerateToken(PG_HOST_INFO, plugin); }); it("testConnectWithSpecifiedPort", async () => { const cacheKeyWithSpecifiedPort: string = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:postgresqlUser"; - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(cacheKeyWithSpecifiedPort, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); await testToken(PG_HOST_INFO_WITH_PORT, plugin); @@ -185,7 +182,7 @@ describe("testIamAuth", () => { props.set(WrapperProperties.IAM_DEFAULT_PORT.name, iamDefaultPort); const cacheKeyWithNewPort: string = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${iamDefaultPort}:postgresqlUser`; - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(cacheKeyWithNewPort, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); await testToken(PG_HOST_INFO_WITH_PORT, plugin); @@ -196,7 +193,7 @@ describe("testIamAuth", () => { when(mockClient.defaultPort).thenReturn(DEFAULT_PG_PORT); const cacheKeyWithNewRegion: string = `us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:${DEFAULT_PG_PORT}:postgresqlUser`; - const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService)); + const plugin = new IamAuthenticationPluginTestClass(instance(mockPluginService), instance(spyIamAuthUtils)); plugin.put(cacheKeyWithNewRegion, new TokenInfo(TEST_TOKEN, Date.now() + 300000)); await testToken(PG_HOST_INFO_WITH_REGION, plugin); diff --git a/tests/unit/okta_auth_plugin.test.ts b/tests/unit/okta_auth_plugin.test.ts deleted file mode 100644 index 59a3c2811..000000000 --- a/tests/unit/okta_auth_plugin.test.ts +++ /dev/null @@ -1,166 +0,0 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -import { anything, instance, mock, spy, verify, when } from "ts-mockito"; -import { PluginServiceImpl } from "../../common/lib/plugin_service"; -import { CredentialsProviderFactory } from "../../common/lib/plugins/federated_auth/credentials_provider_factory"; -import { IamAuthUtils, TokenInfo } from "../../common/lib/utils/iam_auth_utils"; -import { HostInfo, PluginManager } from "../../common/lib"; -import { WrapperProperties } from "../../common/lib/wrapper_property"; -import { DatabaseDialect } from "../../common/lib/database_dialect/database_dialect"; - -import { OktaAuthPlugin } from "../../common/lib/plugins/federated_auth/okta_auth_plugin"; -import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; -import { jest } from "@jest/globals"; -import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; - -const defaultPort = 1234; -const host = "pg.testdb.us-east-2.rds.amazonaws.com"; -const iamHost = "pg-123.testdb.us-east-2.rds.amazonaws.com"; -const hostInfo = new HostInfo(host, defaultPort); -const dbUser = "iamUser"; -const region = "us-east-2"; -const testToken = "someTestToken"; -const testTokenInfo = new TokenInfo(testToken, Date.now() + 300000); - -const mockPluginService = mock(PluginServiceImpl); -const mockDialect = mock(); -const mockDialectInstance = instance(mockDialect); -const testCredentials = { - accessKeyId: "foo", - secretAccessKey: "bar", - sessionToken: "baz" -}; -const spyIamUtils = spy(IamAuthUtils); -const mockCredentialsProviderFactory = mock(); -const mockConnectFunc = jest.fn(() => { - return Promise.resolve(mock(PgClientWrapper)); -}); - -describe("oktaAuthTest", () => { - let spyPlugin: OktaAuthPlugin; - let props: Map; - - beforeEach(() => { - when(mockPluginService.getDialect()).thenReturn(mockDialectInstance); - when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); - when(mockDialect.getDefaultPort()).thenReturn(defaultPort); - when(mockCredentialsProviderFactory.getAwsCredentialsProvider(anything(), anything(), anything())).thenResolve(testCredentials); - props = new Map(); - WrapperProperties.PLUGINS.set(props, "okta"); - WrapperProperties.DB_USER.set(props, dbUser); - spyPlugin = spy(new OktaAuthPlugin(instance(mockPluginService), instance(mockCredentialsProviderFactory))); - }); - - afterEach(async () => { - await PluginManager.releaseResources(); - }); - - it("testCachedToken", async () => { - const spyPluginInstance = instance(spyPlugin); - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - - OktaAuthPlugin["tokenCache"].set(key, testTokenInfo); - - await spyPluginInstance.connect(hostInfo, props, false, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testExpiredCachedToken", async () => { - const spyPluginInstance = instance(spyPlugin); - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - - const someExpiredToken = "someExpiredToken"; - const expiredTokenInfo = new TokenInfo(someExpiredToken, Date.now() - 300000); - - OktaAuthPlugin["tokenCache"].set(key, expiredTokenInfo); - - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, false, mockConnectFunc); - - verify( - spyIamUtils.generateAuthenticationToken(hostInfo.host, defaultPort, region, dbUser, testCredentials, instance(mockPluginService)) - ).called(); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testNoCachedToken", async () => { - const spyPluginInstance = instance(spyPlugin); - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - verify( - spyIamUtils.generateAuthenticationToken(hostInfo.host, defaultPort, region, dbUser, testCredentials, instance(mockPluginService)) - ).called(); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testSpecifiedIamHostPortRegion", async () => { - const spyPluginInstance = instance(spyPlugin); - const expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; - const expectedPort = 9876; - const expectedRegion = "us-west-2"; - - WrapperProperties.IAM_HOST.set(props, expectedHost); - WrapperProperties.IAM_DEFAULT_PORT.set(props, expectedPort); - WrapperProperties.IAM_REGION.set(props, expectedRegion); - - const key = `us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:${expectedPort}:iamUser`; - - OktaAuthPlugin["tokenCache"].set(key, testTokenInfo); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); - - it("testIdpCredentialsFallback", async () => { - const spyPluginInstance = instance(spyPlugin); - const expectedUser = "expectedUser"; - const expectedPassword = "expectedPassword"; - - WrapperProperties.USER.set(props, expectedUser); - WrapperProperties.PASSWORD.set(props, expectedPassword); - - const key = `us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:${defaultPort}:iamUser`; - OktaAuthPlugin["tokenCache"].set(key, testTokenInfo); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - expect(expectedUser).toBe(WrapperProperties.IDP_USERNAME.get(props)); - expect(expectedPassword).toBe(WrapperProperties.IDP_PASSWORD.get(props)); - }); - - it("testUsingIamHost", async () => { - WrapperProperties.IAM_HOST.set(props, iamHost); - const spyPluginInstance = instance(spyPlugin); - when(spyIamUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); - - await spyPluginInstance.connect(hostInfo, props, true, mockConnectFunc); - - verify(spyIamUtils.generateAuthenticationToken(iamHost, defaultPort, region, dbUser, testCredentials, instance(mockPluginService))).called(); - expect(dbUser).toBe(WrapperProperties.USER.get(props)); - expect(testToken).toBe(WrapperProperties.PASSWORD.get(props)); - }); -}); diff --git a/tests/unit/saml_auth_plugin.test.ts b/tests/unit/saml_auth_plugin.test.ts new file mode 100644 index 000000000..76b55d246 --- /dev/null +++ b/tests/unit/saml_auth_plugin.test.ts @@ -0,0 +1,161 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { HostInfo, HostRole } from "../../common/lib"; +import { FederatedAuthPlugin } from "../../common/lib/plugins/federated_auth/federated_auth_plugin"; +import { OktaAuthPlugin } from "../../common/lib/plugins/federated_auth/okta_auth_plugin"; +import { BaseSamlAuthPlugin } from "../../common/lib/plugins/federated_auth/saml_auth_plugin"; +import { PluginServiceImpl } from "../../common/lib/plugin_service"; +import { IamAuthUtils, TokenInfo } from "../../common/lib/utils/iam_auth_utils"; +import { WrapperProperties } from "../../common/lib/wrapper_property"; +import { anything, instance, mock, reset, spy, verify, when } from "ts-mockito"; +import { CredentialsProviderFactory } from "../../common/lib/plugins/federated_auth/credentials_provider_factory"; +import { DatabaseDialect } from "../../common/lib/database_dialect/database_dialect"; +import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; +import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; + +const testToken = "testToken"; +const defaultPort = 5432; +const dbUser = "iamUser"; +const expirationFiveMinutes = 5 * 60 * 1000; + +const host = "pg.testdb.us-east-2.rds.amazonaws.com"; +const iamHost = "pg-123.testdb.us-east-2.rds.amazonaws.com"; +const testTokenInfo = new TokenInfo(testToken, Date.now() + expirationFiveMinutes); +const hostInfo = new HostInfo(host, defaultPort, HostRole.WRITER); + +const mockDialect = mock(); +const mockDialectInstance = instance(mockDialect); +const mockPluginService = mock(PluginServiceImpl); +const mockCredentialsProviderFactory = mock(); +const mockClientWrapper = mock(PgClientWrapper); +const testCredentials = { + accessKeyId: "foo", + secretAccessKey: "bar", + sessionToken: "baz" +}; +const mockConnectFunc = () => Promise.resolve(instance(mockClientWrapper)); + +describe.each([ + { pluginName: "federatedAuth", PluginClass: FederatedAuthPlugin }, + { pluginName: "okta", PluginClass: OktaAuthPlugin } +])("$pluginName plugin tests", ({ pluginName, PluginClass }) => { + let plugin: FederatedAuthPlugin | OktaAuthPlugin; + let spyIamAuthUtils: IamAuthUtils; + let props: Map; + + beforeEach(() => { + spyIamAuthUtils = spy(new IamAuthUtils()); + + when(mockPluginService.getDialect()).thenReturn(mockDialectInstance); + when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); + when(mockDialect.getDefaultPort()).thenReturn(defaultPort); + when(mockCredentialsProviderFactory.getAwsCredentialsProvider(anything(), anything(), anything())).thenResolve(testCredentials); + + props = new Map(); + WrapperProperties.PLUGINS.set(props, pluginName); + WrapperProperties.DB_USER.set(props, dbUser); + + plugin = new PluginClass(instance(mockPluginService), instance(mockCredentialsProviderFactory), instance(spyIamAuthUtils)); + }); + + afterEach(() => { + BaseSamlAuthPlugin.releaseResources(); + reset(spyIamAuthUtils); + }); + + it("testCachedToken", async () => { + const pgCacheKey = `us-east-2:${host}:${defaultPort}:${dbUser}`; + + BaseSamlAuthPlugin["tokenCache"].set(pgCacheKey, testTokenInfo); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + }); + + it("testExpiredCachedToken", async () => { + const key = `us-east-2:${host}:${defaultPort}:${dbUser}`; + const expiredToken = "expiredToken"; + const expiredTokenInfo = new TokenInfo(expiredToken, Date.now() - 300000); + + BaseSamlAuthPlugin["tokenCache"].set(key, expiredTokenInfo); + + when(spyIamAuthUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + }); + + it("testNoCachedToken", async () => { + when(spyIamAuthUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + }); + + it("testSpecifiedIamHostPortRegion", async () => { + const expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; + const expectedPort = 9876; + const expectedRegion = "us-west-2"; + + WrapperProperties.IAM_HOST.set(props, expectedHost); + WrapperProperties.IAM_DEFAULT_PORT.set(props, expectedPort); + WrapperProperties.IAM_REGION.set(props, expectedRegion); + + const key = `${expectedRegion}:${expectedHost}:${expectedPort}:${dbUser}`; + BaseSamlAuthPlugin["tokenCache"].set(key, testTokenInfo); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + }); + + it("testIdpCredentialsFallback", async () => { + const expectedUser = "expectedUser"; + const expectedPassword = "expectedPassword"; + WrapperProperties.USER.set(props, expectedUser); + WrapperProperties.PASSWORD.set(props, expectedPassword); + + const key = `us-east-2:${host}:${defaultPort}:${dbUser}`; + BaseSamlAuthPlugin["tokenCache"].set(key, testTokenInfo); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + expect(WrapperProperties.IDP_USERNAME.get(props)).toBe(expectedUser); + expect(WrapperProperties.IDP_PASSWORD.get(props)).toBe(expectedPassword); + }); + + it("testUsingIamHost", async () => { + WrapperProperties.IAM_HOST.set(props, iamHost); + + when(spyIamAuthUtils.generateAuthenticationToken(anything(), anything(), anything(), anything(), anything(), anything())).thenResolve(testToken); + + await plugin.connect(hostInfo, props, true, mockConnectFunc); + + expect(WrapperProperties.USER.get(props)).toBe(dbUser); + expect(WrapperProperties.PASSWORD.get(props)).toBe(testToken); + verify(spyIamAuthUtils.generateAuthenticationToken(iamHost, anything(), anything(), anything(), anything(), anything())).once(); + }); +}); From 9f8754c3be0b7286363303e220d5b531a61667ac Mon Sep 17 00:00:00 2001 From: Karen Chen <64801825+karenc-bq@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:21:46 -0800 Subject: [PATCH 2/2] chore: address comments --- .../lib/plugins/custom_endpoint/custom_endpoint_plugin.ts | 6 ++++-- common/lib/utils/iam_auth_utils.ts | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts index 3b73c60ac..70dee99cc 100644 --- a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts @@ -31,11 +31,14 @@ import { sleep } from "../../utils/utils"; import { CustomEndpointMonitor, CustomEndpointMonitorImpl } from "./custom_endpoint_monitor_impl"; import { SubscribedMethodHelper } from "../../utils/subscribed_method_helper"; import { CanReleaseResources } from "../../can_release_resources"; +import { RdsUrlType } from "../../utils/rds_url_type"; +import { GDBRegionUtils } from "../../utils/gdb_region_utils"; export class CustomEndpointPlugin extends AbstractConnectionPlugin implements CanReleaseResources { private static readonly TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter"; private static SUBSCRIBED_METHODS: Set = new Set(SubscribedMethodHelper.NETWORK_BOUND_METHODS); private static readonly CACHE_CLEANUP_NANOS = BigInt(60_000_000_000); + private static readonly regionUtils: RegionUtils = new RegionUtils(); private static readonly rdsUtils = new RdsUtils(); protected static readonly monitors: SlidingExpirationCache = new SlidingExpirationCache( @@ -106,8 +109,7 @@ export class CustomEndpointPlugin extends AbstractConnectionPlugin implements Ca throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.errorParsingEndpointIdentifier", this.customEndpointHostInfo.host)); } - const regionUtils = new RegionUtils(); - this.region = await regionUtils.getRegion(WrapperProperties.CUSTOM_ENDPOINT_REGION.name, this.customEndpointHostInfo, props); + this.region = await CustomEndpointPlugin.regionUtils.getRegion(WrapperProperties.CUSTOM_ENDPOINT_REGION.name, this.customEndpointHostInfo, props); if (!this.region) { throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.unableToDetermineRegion", WrapperProperties.CUSTOM_ENDPOINT_REGION.name)); } diff --git a/common/lib/utils/iam_auth_utils.ts b/common/lib/utils/iam_auth_utils.ts index e31169085..7a8d29e5c 100644 --- a/common/lib/utils/iam_auth_utils.ts +++ b/common/lib/utils/iam_auth_utils.ts @@ -25,7 +25,7 @@ import { TelemetryTraceLevel } from "./telemetry/telemetry_trace_level"; import { HostInfoBuilder } from "../host_info_builder"; export class IamAuthUtils { - private readonly TELEMETRY_FETCH_TOKEN = "fetch IAM token"; + private static readonly TELEMETRY_FETCH_TOKEN = "fetch IAM token"; public getIamHost(props: Map, hostInfo: HostInfo): HostInfo { const iamHost: string | null = WrapperProperties.IAM_HOST.get(props); @@ -65,7 +65,7 @@ export class IamAuthUtils { pluginService: PluginService ): Promise { const telemetryFactory = pluginService.getTelemetryFactory(); - const telemetryContext = telemetryFactory.openTelemetryContext(this.TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); + const telemetryContext = telemetryFactory.openTelemetryContext(IamAuthUtils.TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); return await telemetryContext.start(async () => { const signer = new Signer({ hostname: hostname,