Merge pull request #35874 from romangr

* gh-35874:
  Polish "Support custom token validators for OAuth2"
  Support custom token validators for OAuth2

Closes gh-35874
pull/35191/head
Andy Wilkinson 1 year ago
commit 83d5d89efc

@ -61,6 +61,7 @@ import org.springframework.util.CollectionUtils;
* @author HaiTao Zhang * @author HaiTao Zhang
* @author Anastasiia Losieva * @author Anastasiia Losieva
* @author Mushtaq Ahmed * @author Mushtaq Ahmed
* @author Roman Golovin
*/ */
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
class ReactiveOAuth2ResourceServerJwkConfiguration { class ReactiveOAuth2ResourceServerJwkConfiguration {
@ -71,8 +72,12 @@ class ReactiveOAuth2ResourceServerJwkConfiguration {
private final OAuth2ResourceServerProperties.Jwt properties; private final OAuth2ResourceServerProperties.Jwt properties;
JwtConfiguration(OAuth2ResourceServerProperties properties) { private final List<OAuth2TokenValidator<Jwt>> additionalValidators;
JwtConfiguration(OAuth2ResourceServerProperties properties,
ObjectProvider<OAuth2TokenValidator<Jwt>> additionalValidators) {
this.properties = properties.getJwt(); this.properties = properties.getJwt();
this.additionalValidators = additionalValidators.orderedStream().toList();
} }
@Bean @Bean
@ -98,13 +103,16 @@ class ReactiveOAuth2ResourceServerJwkConfiguration {
private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) { private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences(); List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences)) { if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) {
return defaultValidator; return defaultValidator;
} }
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>(); List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator); validators.add(defaultValidator);
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD, if (!CollectionUtils.isEmpty(audiences)) {
(aud) -> aud != null && !Collections.disjoint(aud, audiences))); validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
validators.addAll(this.additionalValidators);
return new DelegatingOAuth2TokenValidator<>(validators); return new DelegatingOAuth2TokenValidator<>(validators);
} }

@ -62,6 +62,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
* @author Artsiom Yudovin * @author Artsiom Yudovin
* @author HaiTao Zhang * @author HaiTao Zhang
* @author Mushtaq Ahmed * @author Mushtaq Ahmed
* @author Roman Golovin
*/ */
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
class OAuth2ResourceServerJwtConfiguration { class OAuth2ResourceServerJwtConfiguration {
@ -72,8 +73,12 @@ class OAuth2ResourceServerJwtConfiguration {
private final OAuth2ResourceServerProperties.Jwt properties; private final OAuth2ResourceServerProperties.Jwt properties;
JwtDecoderConfiguration(OAuth2ResourceServerProperties properties) { private final List<OAuth2TokenValidator<Jwt>> additionalValidators;
JwtDecoderConfiguration(OAuth2ResourceServerProperties properties,
ObjectProvider<OAuth2TokenValidator<Jwt>> additionalValidators) {
this.properties = properties.getJwt(); this.properties = properties.getJwt();
this.additionalValidators = additionalValidators.orderedStream().toList();
} }
@Bean @Bean
@ -98,13 +103,16 @@ class OAuth2ResourceServerJwtConfiguration {
private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) { private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences(); List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences)) { if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) {
return defaultValidator; return defaultValidator;
} }
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>(); List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator); validators.add(defaultValidator);
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD, if (!CollectionUtils.isEmpty(audiences)) {
(aud) -> aud != null && !Collections.disjoint(aud, audiences))); validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
validators.addAll(this.additionalValidators);
return new DelegatingOAuth2TokenValidator<>(validators); return new DelegatingOAuth2TokenValidator<>(validators);
} }

@ -17,14 +17,17 @@
package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive; package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive;
import java.io.IOException; import java.io.IOException;
import java.net.MalformedURLException; import java.net.URI;
import java.net.URL; import java.net.URL;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
@ -33,6 +36,7 @@ import com.nimbusds.jose.JWSAlgorithm;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.InstanceOfAssertFactories; import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.ThrowingConsumer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
@ -87,6 +91,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
* @author HaiTao Zhang * @author HaiTao Zhang
* @author Anastasiia Losieva * @author Anastasiia Losieva
* @author Mushtaq Ahmed * @author Mushtaq Ahmed
* @author Roman Golovin
*/ */
class ReactiveOAuth2ResourceServerAutoConfigurationTests { class ReactiveOAuth2ResourceServerAutoConfigurationTests {
@ -438,7 +443,6 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> assertThat(context).doesNotHaveBean(ReactiveOpaqueTokenIntrospector.class)); .run((context) -> assertThat(context).doesNotHaveBean(ReactiveOpaqueTokenIntrospector.class));
} }
@SuppressWarnings("unchecked")
@Test @Test
void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri() throws Exception { void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
@ -454,15 +458,11 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils validate(jwt().claim("iss", issuer), reactiveJwtDecoder,
.getField(reactiveJwtDecoder, "jwtValidator"); (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class));
Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}); });
} }
@SuppressWarnings("unchecked")
@Test @Test
void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfPropertyNotConfigured() throws Exception { void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfPropertyNotConfigured() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
@ -476,13 +476,8 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils validate(jwt(), reactiveJwtDecoder, (validators) -> assertThat(validators).singleElement()
.getField(reactiveJwtDecoder, "jwtValidator"); .isInstanceOf(JwtTimestampValidator.class));
Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
assertThat(tokenValidators).hasExactlyElementsOfTypes(JwtTimestampValidator.class);
assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class);
assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class);
}); });
} }
@ -502,39 +497,15 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(issuerUri, reactiveJwtDecoder); validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
reactiveJwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
}); });
} }
@SuppressWarnings("unchecked")
private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder) throws MalformedURLException {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com"));
if (issuerUri != null) {
builder.claim("iss", new URL(issuerUri));
}
Jwt jwt = builder.build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse();
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
validateDelegates(issuerUri, delegates);
}
@SuppressWarnings("unchecked")
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates) {
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class);
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream()
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator)
.findFirst()
.get();
Collection<OAuth2TokenValidator<Jwt>> nestedDelegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(delegatingValidator, "tokenValidators");
if (issuerUri != null) {
assertThat(nestedDelegates).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception { void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception {
@ -552,7 +523,12 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "jwtDecoderMono"); .getField(supplierJwtDecoderBean, "jwtDecoderMono");
ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block(); ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block();
validate(issuerUri, jwtDecoder); validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
jwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
}); });
} }
@ -570,7 +546,33 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class); assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class); ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(null, jwtDecoder); validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder,
(validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator()));
});
}
@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureCustomValidators() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri)
.withUserConfiguration(CustomJwtClaimValidatorConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.getBean("customJwtClaimValidator");
validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"),
reactiveJwtDecoder, (validators) -> assertThat(validators).contains(customValidator)
.hasAtLeastOneElementOfType(JwtIssuerValidator.class));
}); });
} }
@ -600,6 +602,30 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
}); });
} }
@SuppressWarnings("unchecked")
@Test
void customValidatorWhenInvalid() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri)
.withUserConfiguration(CustomJwtClaimValidatorConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt jwt = jwt().claim("iss", new URL(issuerUri)).claim("custom_claim", "invalid_value").build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isTrue();
});
}
private void assertFilterConfiguredWithJwtAuthenticationManager(AssertableReactiveWebApplicationContext context) { private void assertFilterConfiguredWithJwtAuthenticationManager(AssertableReactiveWebApplicationContext context) {
MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN); .getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
@ -683,6 +709,37 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.subject("mock-test-subject"); .subject("mock-test-subject");
} }
@SuppressWarnings("unchecked")
private void validate(Jwt.Builder builder, ReactiveJwtDecoder jwtDecoder,
ThrowingConsumer<List<OAuth2TokenValidator<Jwt>>> validatorsConsumer) {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse();
validatorsConsumer.accept(extractValidators(jwtValidator));
}
@SuppressWarnings("unchecked")
private List<OAuth2TokenValidator<Jwt>> extractValidators(DelegatingOAuth2TokenValidator<Jwt> delegatingValidator) {
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(delegatingValidator, "tokenValidators");
List<OAuth2TokenValidator<Jwt>> extracted = new ArrayList<>();
for (OAuth2TokenValidator<Jwt> delegate : delegates) {
if (delegate instanceof DelegatingOAuth2TokenValidator<Jwt> delegatingDelegate) {
extracted.addAll(extractValidators(delegatingDelegate));
}
else {
extracted.add(delegate);
}
}
return extracted;
}
private Consumer<OAuth2TokenValidator<Jwt>> audClaimValidator() {
return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class)
.extracting("claim")
.isEqualTo("aud");
}
@EnableWebFluxSecurity @EnableWebFluxSecurity
static class TestConfig { static class TestConfig {
@ -740,4 +797,14 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
} }
@Configuration(proxyBeanMethods = false)
static class CustomJwtClaimValidatorConfig {
@Bean
JwtClaimValidator<String> customJwtClaimValidator() {
return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals);
}
}
} }

@ -16,14 +16,16 @@
package org.springframework.boot.autoconfigure.security.oauth2.resource.servlet; package org.springframework.boot.autoconfigure.security.oauth2.resource.servlet;
import java.net.MalformedURLException; import java.net.URI;
import java.net.URL; import java.net.URL;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
@ -33,6 +35,7 @@ import jakarta.servlet.Filter;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.InstanceOfAssertFactories; import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.ThrowingConsumer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
@ -80,6 +83,7 @@ import static org.mockito.Mockito.mock;
* @author Artsiom Yudovin * @author Artsiom Yudovin
* @author HaiTao Zhang * @author HaiTao Zhang
* @author Mushtaq Ahmed * @author Mushtaq Ahmed
* @author Roman Golovin
*/ */
class OAuth2ResourceServerAutoConfigurationTests { class OAuth2ResourceServerAutoConfigurationTests {
@ -190,8 +194,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
}); });
} }
@Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws Exception { void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
@ -215,8 +219,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
assertThat(this.server.getRequestCount()).isEqualTo(2); assertThat(this.server.getRequestCount()).isEqualTo(2);
} }
@Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureResourceServerUsingOidcRfc8414IssuerUri() throws Exception { void autoConfigurationShouldConfigureResourceServerUsingOidcRfc8414IssuerUri() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
@ -240,8 +244,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
assertThat(this.server.getRequestCount()).isEqualTo(3); assertThat(this.server.getRequestCount()).isEqualTo(3);
} }
@Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureResourceServerUsingOAuthIssuerUri() throws Exception { void autoConfigurationShouldConfigureResourceServerUsingOAuthIssuerUri() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
@ -472,9 +476,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class); assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") validate(jwt().claim("iss", issuer), jwtDecoder,
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) (validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class));
.hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}); });
} }
@ -491,11 +494,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class); assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
assertThat(jwtDecoder).extracting("jwtValidator.tokenValidators") validate(jwt(), jwtDecoder, (validators) -> assertThat(validators).singleElement()
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) .isInstanceOf(JwtTimestampValidator.class));
.hasExactlyElementsOfTypes(JwtTimestampValidator.class)
.doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class)
.doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class);
}); });
} }
@ -515,7 +515,12 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class); assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
validate(issuerUri, jwtDecoder); validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
jwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
}); });
} }
@ -536,36 +541,39 @@ class OAuth2ResourceServerAutoConfigurationTests {
Supplier<JwtDecoder> jwtDecoderSupplier = (Supplier<JwtDecoder>) ReflectionTestUtils Supplier<JwtDecoder> jwtDecoderSupplier = (Supplier<JwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "delegate"); .getField(supplierJwtDecoderBean, "delegate");
JwtDecoder jwtDecoder = jwtDecoderSupplier.get(); JwtDecoder jwtDecoder = jwtDecoderSupplier.get();
validate(issuerUri, jwtDecoder); validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
jwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
}); });
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void validate(String issuerUri, JwtDecoder jwtDecoder) throws MalformedURLException { @Test
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils void autoConfigurationShouldConfigureCustomValidators() throws Exception {
.getField(jwtDecoder, "jwtValidator"); this.server = new MockWebServer();
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com")); this.server.start();
if (issuerUri != null) { String path = "test";
builder.claim("iss", new URL(issuerUri)); String issuer = this.server.url(path).toString();
} String cleanIssuerPath = cleanIssuerPath(issuer);
Jwt jwt = builder.build(); setupMockResponse(cleanIssuerPath);
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri)
.getField(jwtValidator, "tokenValidators"); .withUserConfiguration(CustomJwtClaimValidatorConfig.class)
validateDelegates(issuerUri, delegates); .run((context) -> {
} SupplierJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierJwtDecoder.class);
Supplier<JwtDecoder> jwtDecoderSupplier = (Supplier<JwtDecoder>) ReflectionTestUtils
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates) { .getField(supplierJwtDecoderBean, "delegate");
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class); JwtDecoder jwtDecoder = jwtDecoderSupplier.get();
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream() assertThat(context).hasBean("customJwtClaimValidator");
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator) OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.findFirst() .getBean("customJwtClaimValidator");
.get(); validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"),
if (issuerUri != null) { jwtDecoder, (validators) -> assertThat(validators).contains(customValidator)
assertThat(delegatingValidator).extracting("tokenValidators") .hasAtLeastOneElementOfType(JwtIssuerValidator.class));
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class)) });
.hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}
} }
@Test @Test
@ -582,7 +590,8 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> { .run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class); assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
validate(null, jwtDecoder); validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder,
(validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator()));
}); });
} }
@ -692,6 +701,37 @@ class OAuth2ResourceServerAutoConfigurationTests {
.subject("mock-test-subject"); .subject("mock-test-subject");
} }
@SuppressWarnings("unchecked")
private void validate(Jwt.Builder builder, JwtDecoder jwtDecoder,
ThrowingConsumer<List<OAuth2TokenValidator<Jwt>>> validatorsConsumer) {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse();
validatorsConsumer.accept(extractValidators(jwtValidator));
}
@SuppressWarnings("unchecked")
private List<OAuth2TokenValidator<Jwt>> extractValidators(DelegatingOAuth2TokenValidator<Jwt> delegatingValidator) {
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(delegatingValidator, "tokenValidators");
List<OAuth2TokenValidator<Jwt>> extracted = new ArrayList<>();
for (OAuth2TokenValidator<Jwt> delegate : delegates) {
if (delegate instanceof DelegatingOAuth2TokenValidator<Jwt> delegatingDelegate) {
extracted.addAll(extractValidators(delegatingDelegate));
}
else {
extracted.add(delegate);
}
}
return extracted;
}
private Consumer<OAuth2TokenValidator<Jwt>> audClaimValidator() {
return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class)
.extracting("claim")
.isEqualTo("aud");
}
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
@EnableWebSecurity @EnableWebSecurity
static class TestConfig { static class TestConfig {
@ -745,4 +785,14 @@ class OAuth2ResourceServerAutoConfigurationTests {
} }
@Configuration(proxyBeanMethods = false)
static class CustomJwtClaimValidatorConfig {
@Bean
JwtClaimValidator<String> customJwtClaimValidator() {
return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals);
}
}
} }

Loading…
Cancel
Save