Auto-configure CORS options for GraphQL web endpoints

This commit adds `"spring.graphql.cors.*"` configuration properties to
customize the CORS configuration for GraphQL web endpoints.

See gh-29140
pull/29177/head
Brian Clozel 3 years ago
parent 0099460155
commit 4ef9b9e3e5

@ -0,0 +1,160 @@
/*
* Copyright 2012-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.autoconfigure.graphql;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.convert.DurationUnit;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.web.cors.CorsConfiguration;
/**
* Configuration properties for GraphQL endpoint's CORS support.
*
* @author Andy Wilkinson
* @author Brian Clozel
* @since 2.7.0
*/
@ConfigurationProperties(prefix = "spring.graphql.cors")
public class GraphQlCorsProperties {
/**
* Comma-separated list of origins to allow with '*' allowing all origins. When
* allow-credentials is enabled, '*' cannot be used, and setting origin patterns
* should be considered instead. When neither allowed origins nor allowed origin
* patterns are set, cross-origin requests are effectively disabled.
*/
private List<String> allowedOrigins = new ArrayList<>();
/**
* Comma-separated list of origin patterns to allow. Unlike allowed origins which only
* support '*', origin patterns are more flexible, e.g. 'https://*.example.com', and
* can be used with allow-credentials. When neither allowed origins nor allowed origin
* patterns are set, cross-origin requests are effectively disabled.
*/
private List<String> allowedOriginPatterns = new ArrayList<>();
/**
* Comma-separated list of HTTP methods to allow. '*' allows all methods. When not
* set, defaults to GET.
*/
private List<String> allowedMethods = new ArrayList<>();
/**
* Comma-separated list of HTTP headers to allow in a request. '*' allows all headers.
*/
private List<String> allowedHeaders = new ArrayList<>();
/**
* Comma-separated list of headers to include in a response.
*/
private List<String> exposedHeaders = new ArrayList<>();
/**
* Whether credentials are supported. When not set, credentials are not supported.
*/
@Nullable
private Boolean allowCredentials;
/**
* How long the response from a pre-flight request can be cached by clients. If a
* duration suffix is not specified, seconds will be used.
*/
@DurationUnit(ChronoUnit.SECONDS)
private Duration maxAge = Duration.ofSeconds(1800);
public List<String> getAllowedOrigins() {
return this.allowedOrigins;
}
public void setAllowedOrigins(List<String> allowedOrigins) {
this.allowedOrigins = allowedOrigins;
}
public List<String> getAllowedOriginPatterns() {
return this.allowedOriginPatterns;
}
public void setAllowedOriginPatterns(List<String> allowedOriginPatterns) {
this.allowedOriginPatterns = allowedOriginPatterns;
}
public List<String> getAllowedMethods() {
return this.allowedMethods;
}
public void setAllowedMethods(List<String> allowedMethods) {
this.allowedMethods = allowedMethods;
}
public List<String> getAllowedHeaders() {
return this.allowedHeaders;
}
public void setAllowedHeaders(List<String> allowedHeaders) {
this.allowedHeaders = allowedHeaders;
}
public List<String> getExposedHeaders() {
return this.exposedHeaders;
}
public void setExposedHeaders(List<String> exposedHeaders) {
this.exposedHeaders = exposedHeaders;
}
@Nullable
public Boolean getAllowCredentials() {
return this.allowCredentials;
}
public void setAllowCredentials(Boolean allowCredentials) {
this.allowCredentials = allowCredentials;
}
public Duration getMaxAge() {
return this.maxAge;
}
public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge;
}
@Nullable
public CorsConfiguration toCorsConfiguration() {
if (CollectionUtils.isEmpty(this.allowedOrigins) && CollectionUtils.isEmpty(this.allowedOriginPatterns)) {
return null;
}
PropertyMapper map = PropertyMapper.get();
CorsConfiguration config = new CorsConfiguration();
map.from(this::getAllowedOrigins).to(config::setAllowedOrigins);
map.from(this::getAllowedOriginPatterns).to(config::setAllowedOriginPatterns);
map.from(this::getAllowedHeaders).whenNot(CollectionUtils::isEmpty).to(config::setAllowedHeaders);
map.from(this::getAllowedMethods).whenNot(CollectionUtils::isEmpty).to(config::setAllowedMethods);
map.from(this::getExposedHeaders).whenNot(CollectionUtils::isEmpty).to(config::setExposedHeaders);
map.from(this::getMaxAge).whenNonNull().as(Duration::getSeconds).to(config::setMaxAge);
map.from(this::getAllowCredentials).whenNonNull().to(config::setAllowCredentials);
return config;
}
}

@ -31,7 +31,9 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.autoconfigure.graphql.GraphQlAutoConfiguration;
import org.springframework.boot.autoconfigure.graphql.GraphQlCorsProperties;
import org.springframework.boot.autoconfigure.graphql.GraphQlProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ResourceLoader;
@ -45,6 +47,9 @@ import org.springframework.graphql.web.webflux.SchemaHandler;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.reactive.config.CorsRegistry;
import org.springframework.web.reactive.config.WebFluxConfigurer;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerResponse;
@ -64,6 +69,7 @@ import static org.springframework.web.reactive.function.server.RequestPredicates
@ConditionalOnClass({ GraphQL.class, GraphQlHttpHandler.class })
@ConditionalOnBean(GraphQlService.class)
@AutoConfigureAfter(GraphQlAutoConfiguration.class)
@EnableConfigurationProperties(GraphQlCorsProperties.class)
public class GraphQlWebFluxAutoConfiguration {
private static final Log logger = LogFactory.getLog(GraphQlWebFluxAutoConfiguration.class);
@ -111,4 +117,26 @@ public class GraphQlWebFluxAutoConfiguration {
return builder.build();
}
@Configuration(proxyBeanMethods = false)
public static class GraphQlEndpointCorsConfiguration implements WebFluxConfigurer {
final GraphQlProperties graphQlProperties;
final GraphQlCorsProperties corsProperties;
public GraphQlEndpointCorsConfiguration(GraphQlProperties graphQlProps, GraphQlCorsProperties corsProps) {
this.graphQlProperties = graphQlProps;
this.corsProperties = corsProps;
}
@Override
public void addCorsMappings(CorsRegistry registry) {
CorsConfiguration configuration = this.corsProperties.toCorsConfiguration();
if (configuration != null) {
registry.addMapping(this.graphQlProperties.getPath()).combine(configuration);
}
}
}
}

@ -31,7 +31,9 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.autoconfigure.graphql.GraphQlAutoConfiguration;
import org.springframework.boot.autoconfigure.graphql.GraphQlCorsProperties;
import org.springframework.boot.autoconfigure.graphql.GraphQlProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ResourceLoader;
@ -46,6 +48,9 @@ import org.springframework.graphql.web.webmvc.SchemaHandler;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.servlet.function.RequestPredicates;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
@ -63,6 +68,7 @@ import org.springframework.web.servlet.function.ServerResponse;
@ConditionalOnClass({ GraphQL.class, GraphQlHttpHandler.class })
@ConditionalOnBean(GraphQlService.class)
@AutoConfigureAfter(GraphQlAutoConfiguration.class)
@EnableConfigurationProperties(GraphQlCorsProperties.class)
public class GraphQlWebMvcAutoConfiguration {
private static final Log logger = LogFactory.getLog(GraphQlWebMvcAutoConfiguration.class);
@ -112,4 +118,26 @@ public class GraphQlWebMvcAutoConfiguration {
return builder.build();
}
@Configuration(proxyBeanMethods = false)
public static class GraphQlEndpointCorsConfiguration implements WebMvcConfigurer {
final GraphQlProperties graphQlProperties;
final GraphQlCorsProperties corsProperties;
public GraphQlEndpointCorsConfiguration(GraphQlProperties graphQlProps, GraphQlCorsProperties corsProps) {
this.graphQlProperties = graphQlProps;
this.corsProperties = corsProps;
}
@Override
public void addCorsMappings(CorsRegistry registry) {
CorsConfiguration configuration = this.corsProperties.toCorsConfiguration();
if (configuration != null) {
registry.addMapping(this.graphQlProperties.getPath()).combine(configuration);
}
}
}
}

@ -2210,6 +2210,45 @@
}
]
},
{
"name": "spring.graphql.cors.allowed-headers",
"values": [
{
"value": "*"
}
],
"providers": [
{
"name": "any"
}
]
},
{
"name": "spring.graphql.cors.allowed-methods",
"values": [
{
"value": "*"
}
],
"providers": [
{
"name": "any"
}
]
},
{
"name": "spring.graphql.cors.allowed-origins",
"values": [
{
"value": "*"
}
],
"providers": [
{
"name": "any"
}
]
},
{
"name": "spring.jmx.server",
"providers": [

@ -34,6 +34,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.graphql.execution.RuntimeWiringConfigurer;
import org.springframework.graphql.web.WebInterceptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.test.web.reactive.server.WebTestClient;
@ -55,7 +56,9 @@ class GraphQlWebFluxAutoConfigurationTests {
GraphQlWebFluxAutoConfiguration.class))
.withUserConfiguration(DataFetchersConfiguration.class, CustomWebInterceptor.class)
.withPropertyValues("spring.main.web-application-type=reactive", "spring.graphql.graphiql.enabled=true",
"spring.graphql.schema.printer.enabled=true");
"spring.graphql.schema.printer.enabled=true",
"spring.graphql.cors.allowed-origins=https://example.com",
"spring.graphql.cors.allowed-methods=POST", "spring.graphql.cors.allow-credentials=true");
@Test
void simpleQueryShouldWork() {
@ -114,6 +117,19 @@ class GraphQlWebFluxAutoConfigurationTests {
});
}
@Test
void shouldSupportCors() {
testWithWebClient((client) -> {
String query = "{" + " bookById(id: \\\"book-1\\\"){ " + " id" + " name" + " pageCount"
+ " author" + " }" + "}";
client.post().uri("/graphql").bodyValue("{ \"query\": \"" + query + "\"}")
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(HttpHeaders.ORIGIN, "https://example.com").exchange().expectStatus().isOk().expectHeader()
.valueEquals(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com").expectHeader()
.valueEquals(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
});
}
private void testWithWebClient(Consumer<WebTestClient> consumer) {
this.contextRunner.run((context) -> {
WebTestClient client = WebTestClient.bindToApplicationContext(context).configureClient()

@ -32,6 +32,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.graphql.execution.RuntimeWiringConfigurer;
import org.springframework.graphql.web.WebInterceptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
@ -60,7 +61,9 @@ class GraphQlWebMvcAutoConfigurationTests {
GraphQlAutoConfiguration.class, GraphQlWebMvcAutoConfiguration.class))
.withUserConfiguration(DataFetchersConfiguration.class, CustomWebInterceptor.class)
.withPropertyValues("spring.main.web-application-type=servlet", "spring.graphql.graphiql.enabled=true",
"spring.graphql.schema.printer.enabled=true");
"spring.graphql.schema.printer.enabled=true",
"spring.graphql.cors.allowed-origins=https://example.com",
"spring.graphql.cors.allowed-methods=POST", "spring.graphql.cors.allow-credentials=true");
@Test
void simpleQueryShouldWork() {
@ -119,6 +122,21 @@ class GraphQlWebMvcAutoConfigurationTests {
});
}
@Test
void shouldSupportCors() {
testWith((mockMvc) -> {
String query = "{" + " bookById(id: \\\"book-1\\\"){ " + " id" + " name" + " pageCount"
+ " author" + " }" + "}";
MvcResult result = mockMvc.perform(post("/graphql")
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(HttpHeaders.ORIGIN, "https://example.com").content("{\"query\": \"" + query + "\"}"))
.andReturn();
mockMvc.perform(asyncDispatch(result)).andExpect(status().isOk())
.andExpect(header().stringValues(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com"))
.andExpect(header().stringValues(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"));
});
}
private void testWith(MockMvcConsumer mockMvcConsumer) {
this.contextRunner.run((context) -> {
MediaType mediaType = MediaType.APPLICATION_JSON;

Loading…
Cancel
Save