diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationContext.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationContext.java new file mode 100644 index 00000000000..dc8751dfa4c --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationContext.java @@ -0,0 +1,157 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.jspecify.annotations.Nullable; + +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +/** + * An {@link OAuth2AuthenticationContext} that holds an + * {@link OAuth2TokenExchangeAuthenticationToken} and additional information and is used + * when validating the OAuth 2.0 Token Exchange Grant Request. + * + * @author Rakesh Kumar Singh + * @since 7.1 + * @see OAuth2AuthenticationContext + * @see OAuth2TokenExchangeAuthenticationToken + * @see OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OAuth2TokenExchangeAuthenticationContext implements OAuth2AuthenticationContext { + + private static final String ACTOR_AUTHORIZATION_ATTR_NAME = OAuth2TokenExchangeAuthenticationContext.class.getName() + .concat(".actorAuthorization"); + + private final Map context; + + private OAuth2TokenExchangeAuthenticationContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Override + public @Nullable V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + /** + * Returns the {@link RegisteredClient registered client}. + * @return the {@link RegisteredClient} + */ + public RegisteredClient getRegisteredClient() { + RegisteredClient registeredClient = get(RegisteredClient.class); + Assert.notNull(registeredClient, "registeredClient cannot be null"); + return registeredClient; + } + + /** + * Returns the subject {@link OAuth2Authorization authorization}. + * @return the subject {@link OAuth2Authorization} + */ + public OAuth2Authorization getSubjectAuthorization() { + OAuth2Authorization subjectAuthorization = get(OAuth2Authorization.class); + Assert.notNull(subjectAuthorization, "subjectAuthorization cannot be null"); + return subjectAuthorization; + } + + /** + * Returns the actor {@link OAuth2Authorization authorization}, or {@code null} if not + * available (impersonation case). + * @return the actor {@link OAuth2Authorization}, or {@code null} + */ + public @Nullable OAuth2Authorization getActorAuthorization() { + return get(ACTOR_AUTHORIZATION_ATTR_NAME); + } + + /** + * Constructs a new {@link Builder} with the provided + * {@link OAuth2TokenExchangeAuthenticationToken}. + * @param authentication the {@link OAuth2TokenExchangeAuthenticationToken} + * @return the {@link Builder} + */ + public static Builder with(OAuth2TokenExchangeAuthenticationToken authentication) { + return new Builder(authentication); + } + + /** + * A builder for {@link OAuth2TokenExchangeAuthenticationContext}. + */ + public static final class Builder + extends AbstractBuilder { + + private Builder(OAuth2TokenExchangeAuthenticationToken authentication) { + super(authentication); + } + + /** + * Sets the {@link RegisteredClient registered client}. + * @param registeredClient the {@link RegisteredClient} + * @return the {@link Builder} for further configuration + */ + public Builder registeredClient(RegisteredClient registeredClient) { + return put(RegisteredClient.class, registeredClient); + } + + /** + * Sets the subject {@link OAuth2Authorization}. + * @param subjectAuthorization the subject {@link OAuth2Authorization} + * @return the {@link Builder} for further configuration + */ + public Builder subjectAuthorization(OAuth2Authorization subjectAuthorization) { + return put(OAuth2Authorization.class, subjectAuthorization); + } + + /** + * Sets the actor {@link OAuth2Authorization}, or {@code null} for impersonation. + * @param actorAuthorization the actor {@link OAuth2Authorization}, may be + * {@code null} + * @return the {@link Builder} for further configuration + */ + public Builder actorAuthorization(@Nullable OAuth2Authorization actorAuthorization) { + if (actorAuthorization != null) { + getContext().put(ACTOR_AUTHORIZATION_ATTR_NAME, actorAuthorization); + } + return getThis(); + } + + /** + * Builds a new {@link OAuth2TokenExchangeAuthenticationContext}. + * @return the {@link OAuth2TokenExchangeAuthenticationContext} + */ + @Override + public OAuth2TokenExchangeAuthenticationContext build() { + Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null"); + Assert.notNull(get(OAuth2Authorization.class), "subjectAuthorization cannot be null"); + return new OAuth2TokenExchangeAuthenticationContext(getContext()); + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProvider.java index b0760ad264f..6c6c40b7015 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProvider.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProvider.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -86,6 +87,8 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti private final OAuth2TokenGenerator tokenGenerator; + private Consumer authenticationValidator = new OAuth2TokenExchangeAuthenticationValidator(); + /** * Constructs an {@code OAuth2TokenExchangeAuthenticationProvider} using the provided * parameters. @@ -204,12 +207,20 @@ else if (authorizedActorClaims != null) { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); } + OAuth2TokenExchangeAuthenticationContext authenticationContext = OAuth2TokenExchangeAuthenticationContext + .with(tokenExchangeAuthentication) + .registeredClient(registeredClient) + .subjectAuthorization(subjectAuthorization) + .actorAuthorization(actorAuthorization) + .build(); + this.authenticationValidator.accept(authenticationContext); + Set authorizedScopes = Collections.emptySet(); if (!CollectionUtils.isEmpty(tokenExchangeAuthentication.getScopes())) { - authorizedScopes = validateRequestedScopes(registeredClient, tokenExchangeAuthentication.getScopes()); + authorizedScopes = new LinkedHashSet<>(tokenExchangeAuthentication.getScopes()); } else if (!CollectionUtils.isEmpty(subjectAuthorization.getAuthorizedScopes())) { - authorizedScopes = validateRequestedScopes(registeredClient, subjectAuthorization.getAuthorizedScopes()); + authorizedScopes = new LinkedHashSet<>(subjectAuthorization.getAuthorizedScopes()); } // Verify the DPoP Proof (if available) @@ -285,16 +296,6 @@ private static boolean isValidTokenType(String tokenType, OAuth2Authorization.To && OAuth2TokenFormat.SELF_CONTAINED.getValue().equals(tokenFormat); } - private static Set validateRequestedScopes(RegisteredClient registeredClient, Set requestedScopes) { - for (String requestedScope : requestedScopes) { - if (!registeredClient.getScopes().contains(requestedScope)) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE); - } - } - - return new LinkedHashSet<>(requestedScopes); - } - private static void validateClaims(Map expectedClaims, @Nullable Map actualClaims, String... claimNames) { if (actualClaims == null) { @@ -342,4 +343,25 @@ public boolean supports(Class authentication) { return OAuth2TokenExchangeAuthenticationToken.class.isAssignableFrom(authentication); } + /** + * Sets the {@code Consumer} providing access to the + * {@link OAuth2TokenExchangeAuthenticationContext} and is responsible for validating + * specific OAuth 2.0 Token Exchange Grant Request parameters associated in the + * {@link OAuth2TokenExchangeAuthenticationToken}. The default authentication validator + * is {@link OAuth2TokenExchangeAuthenticationValidator}. + * + *

+ * NOTE: The authentication validator MUST throw + * {@link org.springframework.security.oauth2.core.OAuth2AuthenticationException} if + * validation fails. + * @param authenticationValidator the {@code Consumer} providing access to the + * {@link OAuth2TokenExchangeAuthenticationContext} and is responsible for validating + * specific OAuth 2.0 Token Exchange Grant Request parameters + */ + public void setAuthenticationValidator( + Consumer authenticationValidator) { + Assert.notNull(authenticationValidator, "authenticationValidator cannot be null"); + this.authenticationValidator = authenticationValidator; + } + } diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationValidator.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationValidator.java new file mode 100644 index 00000000000..e67672300df --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationValidator.java @@ -0,0 +1,90 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Set; +import java.util.function.Consumer; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.CollectionUtils; + +/** + * A {@code Consumer} providing access to the + * {@link OAuth2TokenExchangeAuthenticationContext} containing an + * {@link OAuth2TokenExchangeAuthenticationToken} and is the default + * {@link OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer) + * authentication validator} used for validating specific OAuth 2.0 Token Exchange Grant + * Request parameters. + * + *

+ * The default implementation validates + * {@link OAuth2TokenExchangeAuthenticationToken#getScopes()}. If validation fails, an + * {@link OAuth2AuthenticationException} is thrown. + * + * @author Rakesh Kumar Singh + * @since 7.1 + * @see OAuth2TokenExchangeAuthenticationContext + * @see OAuth2TokenExchangeAuthenticationToken + * @see OAuth2TokenExchangeAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OAuth2TokenExchangeAuthenticationValidator + implements Consumer { + + private static final Log LOGGER = LogFactory.getLog(OAuth2TokenExchangeAuthenticationValidator.class); + + /** + * The default validator for + * {@link OAuth2TokenExchangeAuthenticationToken#getScopes()}. + */ + public static final Consumer DEFAULT_SCOPE_VALIDATOR = OAuth2TokenExchangeAuthenticationValidator::validateScope; + + private final Consumer authenticationValidator = DEFAULT_SCOPE_VALIDATOR; + + @Override + public void accept(OAuth2TokenExchangeAuthenticationContext authenticationContext) { + this.authenticationValidator.accept(authenticationContext); + } + + private static void validateScope(OAuth2TokenExchangeAuthenticationContext authenticationContext) { + OAuth2TokenExchangeAuthenticationToken tokenExchangeAuthentication = authenticationContext.getAuthentication(); + RegisteredClient registeredClient = authenticationContext.getRegisteredClient(); + OAuth2Authorization subjectAuthorization = authenticationContext.getSubjectAuthorization(); + + Set requestedScopes = tokenExchangeAuthentication.getScopes(); + if (CollectionUtils.isEmpty(requestedScopes)) { + requestedScopes = subjectAuthorization.getAuthorizedScopes(); + } + + Set allowedScopes = registeredClient.getScopes(); + if (!requestedScopes.isEmpty() && !allowedScopes.containsAll(requestedScopes)) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(LogMessage.format( + "Invalid request: requested scope is not allowed" + " for registered client '%s'", + registeredClient.getId())); + } + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE); + } + } + +} diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProviderTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProviderTests.java index dfa782c3ccb..b84497ae57c 100644 --- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProviderTests.java +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenExchangeAuthenticationProviderTests.java @@ -124,6 +124,66 @@ public void tearDown() { AuthorizationServerContextHolder.resetContext(); } + @Test + public void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setAuthenticationValidator(null)) + .withMessage("authenticationValidator cannot be null"); + // @formatter:on + } + + @Test + public void authenticateWhenCustomAuthenticationValidatorThenUsed() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .build(); + OAuth2TokenExchangeAuthenticationToken authentication = createDelegationRequest(registeredClient); + OAuth2Authorization subjectAuthorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createAccessToken(SUBJECT_TOKEN)) + .build(); + OAuth2Authorization actorAuthorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createAccessToken(ACTOR_TOKEN)) + .build(); + given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))) + .willReturn(subjectAuthorization, actorAuthorization); + OAuth2AccessToken accessToken = createAccessToken("token-value"); + given(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).willReturn(accessToken); + + Consumer customValidator = mock(Consumer.class); + this.authenticationProvider.setAuthenticationValidator(customValidator); + this.authenticationProvider.authenticate(authentication); + + verify(customValidator).accept(any(OAuth2TokenExchangeAuthenticationContext.class)); + } + + @Test + public void authenticateWhenCustomAuthenticationValidatorThrowsThenPropagated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .build(); + OAuth2TokenExchangeAuthenticationToken authentication = createDelegationRequest(registeredClient); + OAuth2Authorization subjectAuthorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createAccessToken(SUBJECT_TOKEN)) + .build(); + OAuth2Authorization actorAuthorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createAccessToken(ACTOR_TOKEN)) + .build(); + given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))) + .willReturn(subjectAuthorization, actorAuthorization); + + this.authenticationProvider + .setAuthenticationValidator((ctx) -> { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); }); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + verifyNoInteractions(this.tokenGenerator); + } + @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { // @formatter:off