Skip to content

Commit 013d519

Browse files
Merge master into feature/console-session-profile
2 parents fcebbee + 75d1f24 commit 013d519

File tree

6 files changed

+163
-102
lines changed

6 files changed

+163
-102
lines changed

packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
SmusErrorCodes,
2121
extractAccountIdFromResourceMetadata,
2222
convertToToolkitCredentialProvider,
23+
isIamDomain,
2324
} from '../../shared/smusUtils'
2425
import {
2526
createSmusProfile,
@@ -212,10 +213,15 @@ export class SmusAuthenticationProvider {
212213

213214
const credentialsProvider = (await this.getDerCredentialsProvider()) as CredentialsProvider
214215

215-
// Get DataZoneCustomClientHelper instance and check if domain is IAM mode
216+
// Get DataZoneCustomClientHelper instance and fetch domain details to check if it's IAM mode
216217
const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance(credentialsProvider, region)
217-
const isIamMode = await datazoneCustomClientHelper.isIamDomain(domainId)
218-
this.logger.debug(`is in IAM mode ${isIamMode}`)
218+
const domain = await datazoneCustomClientHelper.getDomain(domainId)
219+
const isIamMode = isIamDomain({
220+
domainVersion: domain.domainVersion,
221+
iamSignIns: domain.iamSignIns,
222+
domainId: domainId,
223+
})
224+
this.logger.debug(`Domain ${domainId} is in IAM mode: ${isIamMode}`)
219225
await setSmusIamModeContext(isIamMode)
220226
}
221227
} catch (error) {

packages/core/src/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.ts

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import * as DataZoneCustomClient from './datazonecustomclient'
1212
import { adaptConnectionCredentialsProvider } from './credentialsAdapter'
1313
import { CredentialsProvider } from '../../../auth/providers/credentials'
1414
import { ToolkitError } from '../../../shared/errors'
15-
import { SmusUtils } from '../smusUtils'
15+
import { SmusUtils, isIamDomain } from '../smusUtils'
1616
import { DevSettings } from '../../../shared/settings'
1717

1818
import { SmusErrorCodes } from '../smusUtils'
@@ -196,7 +196,7 @@ export class DataZoneCustomClientHelper {
196196
}
197197

198198
/**
199-
* Gets the domain with IAM authentication mode in preferences using pagination with early termination
199+
* Gets the domain with IAM authentication mode using pagination with early termination
200200
* @returns Promise resolving to the DataZone domain or undefined if not found
201201
*/
202202
public async getIamDomain(): Promise<DataZoneCustomClient.Types.DomainSummary | undefined> {
@@ -209,7 +209,7 @@ export class DataZoneCustomClientHelper {
209209
let totalDomainsChecked = 0
210210
const maxResultsPerPage = 25
211211

212-
// Paginate through domains and check each page for IAM-based domain
212+
// Paginate through domains and check each page for IAM domain
213213
do {
214214
const response = await this.listDomains({
215215
status: 'AVAILABLE',
@@ -224,12 +224,19 @@ export class DataZoneCustomClientHelper {
224224
`DataZoneCustomClientHelper: Checking ${domains.length} domains in current page (total checked: ${totalDomainsChecked})`
225225
)
226226

227-
// Check each domain in the current page for IAM authentication mode
227+
// Check each domain in the current page for IAM domain
228228
for (const domain of domains) {
229-
if (domain.preferences && domain.preferences.DOMAIN_MODE === 'EXPRESS') {
230-
logger.info(
231-
`DataZoneCustomClientHelper: Found IAM-based domain, id: ${domain.id} (${domain.name})`
232-
)
229+
// Log the complete domain object for debugging
230+
logger.debug(`DataZoneCustomClientHelper: Domain ${domain.id} full response: %O`, domain)
231+
232+
const isIam = isIamDomain({
233+
domainVersion: domain.domainVersion,
234+
iamSignIns: domain.iamSignIns,
235+
domainId: domain.id,
236+
})
237+
238+
if (isIam) {
239+
logger.info(`DataZoneCustomClientHelper: Found IAM domain, id: ${domain.id} (${domain.name})`)
233240
return domain
234241
}
235242
}
@@ -238,7 +245,7 @@ export class DataZoneCustomClientHelper {
238245
} while (nextToken)
239246

240247
logger.info(
241-
`DataZoneCustomClientHelper: No domain with IAM authentication (DOMAIN_MODE: EXPRESS) found after checking all ${totalDomainsChecked} domains`
248+
`DataZoneCustomClientHelper: No IAM domain found after checking all ${totalDomainsChecked} domains`
242249
)
243250
return undefined
244251
} catch (err) {
@@ -272,28 +279,6 @@ export class DataZoneCustomClientHelper {
272279
}
273280
}
274281

275-
/**
276-
* Checks if a specific domain is an IAM-based domain
277-
* @param domainId The ID of the domain to check
278-
* @returns Promise resolving to true if the domain is IAM-based, false otherwise
279-
*/
280-
public async isIamDomain(domainId: string): Promise<boolean> {
281-
try {
282-
this.logger.debug(`DataZoneCustomClientHelper: Checking if domain ${domainId} is IAM-based`)
283-
284-
const domain = await this.getDomain(domainId)
285-
const isIamMode = domain.preferences?.DOMAIN_MODE === 'EXPRESS' || false
286-
287-
this.logger.debug(
288-
`DataZoneCustomClientHelper: Domain ${domainId} is ${isIamMode ? 'IAM-based' : 'not IAM-based'}`
289-
)
290-
return isIamMode
291-
} catch (err) {
292-
this.logger.error('DataZoneCustomClientHelper: Failed to check if domain is IAM-based: %s', err as Error)
293-
throw err
294-
}
295-
}
296-
297282
/**
298283
* Searches for group profiles in the DataZone domain
299284
* @param domainIdentifier The domain identifier to search in

packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ export const SmusTimeouts = {
108108
*/
109109
export const DataZoneServiceId = 'datazone'
110110

111+
/**
112+
* Domain version constants
113+
*/
114+
export const DomainVersionV1 = 'V1'
115+
export const DomainVersionV2 = 'V2'
116+
117+
/**
118+
* IAM sign-in type constants
119+
*/
120+
export const IamSignInRole = 'IAM_ROLE'
121+
export const IamSignInUser = 'IAM_USER'
122+
123+
/**
124+
* Input interface for IAM domain check function
125+
*/
126+
export interface IamDomainCheckInput {
127+
domainVersion: string | undefined
128+
iamSignIns?: string[] | undefined
129+
domainId?: string
130+
}
131+
111132
/**
112133
* Interface for AWS credential objects that need validation
113134
*/
@@ -490,6 +511,49 @@ export class SmusUtils {
490511
}
491512
}
492513

514+
/**
515+
* Determines if a domain is an IAM domain based on IamSignIns field.
516+
*
517+
* IAM domains are V2 domains that support both IAM role and IAM user authentication.
518+
* A domain is considered an IAM domain if its IamSignIns array contains both:
519+
* - IAM_ROLE
520+
* - IAM_USER
521+
*
522+
* @param input - Object containing domain version, IamSignIns, and optional domainId for logging
523+
* @returns true if the domain is an IAM domain, false otherwise
524+
*/
525+
export function isIamDomain(input: IamDomainCheckInput): boolean {
526+
const logger = getLogger('smus')
527+
const domainIdLog = input.domainId ? ` for domain ${input.domainId}` : ''
528+
529+
// Only V2 domains can be IAM domains
530+
if (input.domainVersion !== DomainVersionV2) {
531+
logger.debug(
532+
`IAM domain check${domainIdLog}: Domain version is not V2 (value: ${input.domainVersion}), returning false`
533+
)
534+
return false
535+
}
536+
537+
// Check if IamSignIns contains both IAM_ROLE and IAM_USER
538+
if (!input.iamSignIns || !Array.isArray(input.iamSignIns)) {
539+
logger.debug(`IAM domain check${domainIdLog}: IamSignIns is missing or invalid, returning false`)
540+
return false
541+
}
542+
543+
const hasIamRole = input.iamSignIns.includes(IamSignInRole)
544+
const hasIamUser = input.iamSignIns.includes(IamSignInUser)
545+
546+
if (hasIamRole && hasIamUser) {
547+
logger.debug(`IAM domain check${domainIdLog}: IAM domain detected via IamSignIns`)
548+
return true
549+
}
550+
551+
logger.debug(
552+
`IAM domain check${domainIdLog}: IamSignIns does not contain both IAM_ROLE and IAM_USER, returning false`
553+
)
554+
return false
555+
}
556+
493557
/**
494558
* Extracts the account ID from a SageMaker ARN.
495559
* Supports formats like:

packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,6 @@ describe('SmusAuthenticationProvider', function () {
13741374
let getResourceMetadataStub: sinon.SinonStub
13751375
let getDerCredentialsProviderStub: sinon.SinonStub
13761376
let getInstanceStub: sinon.SinonStub
1377-
let isIamDomainStub: sinon.SinonStub
13781377
let mockCredentialsProvider: any
13791378
let mockClientHelper: any
13801379

@@ -1404,12 +1403,19 @@ describe('SmusAuthenticationProvider', function () {
14041403
.resolves(mockCredentialsProvider)
14051404

14061405
// Mock DataZoneCustomClientHelper
1407-
isIamDomainStub = sinon.stub()
1406+
const getDomainStub = sinon.stub()
14081407
mockClientHelper = {
1409-
isIamDomain: isIamDomainStub,
1408+
getDomain: getDomainStub,
14101409
}
14111410

14121411
getInstanceStub = sinon.stub(DataZoneCustomClientHelper, 'getInstance').returns(mockClientHelper)
1412+
1413+
// Setup getDomain to return domain details
1414+
getDomainStub.resolves({
1415+
id: testResourceMetadata.AdditionalMetadata.DataZoneDomainId,
1416+
domainVersion: 'V2',
1417+
iamSignIns: ['IAM_ROLE', 'IAM_USER'],
1418+
})
14131419
})
14141420

14151421
afterEach(function () {
@@ -1418,7 +1424,6 @@ describe('SmusAuthenticationProvider', function () {
14181424

14191425
it('should set IAM mode context to true when domain is IAM mode', async function () {
14201426
getResourceMetadataStub.returns(testResourceMetadata)
1421-
isIamDomainStub.resolves(true)
14221427

14231428
await smusAuthProvider['initIamModeContextInSpaceEnvironment']()
14241429

@@ -1430,13 +1435,17 @@ describe('SmusAuthenticationProvider', function () {
14301435
testResourceMetadata.AdditionalMetadata.DataZoneDomainRegion
14311436
)
14321437
)
1433-
assert.ok(isIamDomainStub.calledWith(testResourceMetadata.AdditionalMetadata.DataZoneDomainId))
14341438
assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', true))
14351439
})
14361440

14371441
it('should set IAM mode context to false when domain is not IAM mode', async function () {
14381442
getResourceMetadataStub.returns(testResourceMetadata)
1439-
isIamDomainStub.resolves(false)
1443+
1444+
// Override getDomain to return a non-IAM domain
1445+
mockClientHelper.getDomain = sinon.stub().resolves({
1446+
id: testResourceMetadata.AdditionalMetadata.DataZoneDomainId,
1447+
domainVersion: 'V2',
1448+
})
14401449

14411450
await smusAuthProvider['initIamModeContextInSpaceEnvironment']()
14421451

@@ -1448,7 +1457,6 @@ describe('SmusAuthenticationProvider', function () {
14481457
testResourceMetadata.AdditionalMetadata.DataZoneDomainRegion
14491458
)
14501459
)
1451-
assert.ok(isIamDomainStub.calledWith(testResourceMetadata.AdditionalMetadata.DataZoneDomainId))
14521460
assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', false))
14531461
})
14541462

@@ -1460,7 +1468,6 @@ describe('SmusAuthenticationProvider', function () {
14601468
assert.ok(getResourceMetadataStub.called)
14611469
assert.ok(getDerCredentialsProviderStub.notCalled)
14621470
assert.ok(getInstanceStub.notCalled)
1463-
assert.ok(isIamDomainStub.notCalled)
14641471
assert.ok(setContextStubGlobal.notCalled)
14651472
})
14661473

@@ -1474,7 +1481,6 @@ describe('SmusAuthenticationProvider', function () {
14741481
assert.ok(getResourceMetadataStub.called)
14751482
assert.ok(getDerCredentialsProviderStub.called)
14761483
assert.ok(getInstanceStub.notCalled)
1477-
assert.ok(isIamDomainStub.notCalled)
14781484
assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', false))
14791485
})
14801486
})

packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.test.ts

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ describe('DataZoneCustomClientHelper', () => {
247247
managedAccountId: '123456789012',
248248
status: 'AVAILABLE',
249249
createdAt: new Date(),
250+
domainVersion: 'V2',
251+
iamSignIns: ['IAM_ROLE'],
250252
preferences: { DOMAIN_MODE: 'STANDARD' },
251253
},
252254
{
@@ -256,6 +258,8 @@ describe('DataZoneCustomClientHelper', () => {
256258
managedAccountId: '123456789012',
257259
status: 'AVAILABLE',
258260
createdAt: new Date(),
261+
domainVersion: 'V2',
262+
iamSignIns: ['IAM_ROLE', 'IAM_USER'],
259263
preferences: { DOMAIN_MODE: 'EXPRESS' },
260264
},
261265
] as DataZoneDomain[],
@@ -412,65 +416,6 @@ describe('DataZoneCustomClientHelper', () => {
412416
})
413417
})
414418

415-
describe('isIamDomain', () => {
416-
it('should return true for EXPRESS domain', async () => {
417-
const mockDomainId = 'dzd_express123'
418-
const mockResponse = {
419-
id: mockDomainId,
420-
name: 'Express Domain',
421-
arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`,
422-
status: 'AVAILABLE',
423-
preferences: { DOMAIN_MODE: 'EXPRESS' },
424-
}
425-
426-
const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse)
427-
428-
const result = await client.isIamDomain(mockDomainId)
429-
430-
assert.strictEqual(result, true)
431-
assert.ok(getDomainStub.calledOnce)
432-
assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId)
433-
})
434-
435-
it('should return false for STANDARD domain', async () => {
436-
const mockDomainId = 'dzd_standard123'
437-
const mockResponse = {
438-
id: mockDomainId,
439-
name: 'Standard Domain',
440-
arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`,
441-
status: 'AVAILABLE',
442-
preferences: { DOMAIN_MODE: 'STANDARD' },
443-
}
444-
445-
const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse)
446-
447-
const result = await client.isIamDomain(mockDomainId)
448-
449-
assert.strictEqual(result, false)
450-
assert.ok(getDomainStub.calledOnce)
451-
assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId)
452-
})
453-
454-
it('should return false for domain without preferences', async () => {
455-
const mockDomainId = 'dzd_no_prefs123'
456-
const mockResponse = {
457-
id: mockDomainId,
458-
name: 'Domain Without Preferences',
459-
arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`,
460-
status: 'AVAILABLE',
461-
// No preferences field
462-
}
463-
464-
const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse)
465-
466-
const result = await client.isIamDomain(mockDomainId)
467-
468-
assert.strictEqual(result, false)
469-
assert.ok(getDomainStub.calledOnce)
470-
assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId)
471-
})
472-
})
473-
474419
describe('searchGroupProfiles', () => {
475420
const mockDomainId = 'dzd_test123'
476421

0 commit comments

Comments
 (0)