diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 9de267f268..0f6fe7214a 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -91,7 +91,7 @@ public class RestTemplateBuilder { private final BasicAuthentication basicAuthentication; - private final Map defaultHeaders; + private final Map> defaultHeaders; private final Set customizers; @@ -122,7 +122,7 @@ public class RestTemplateBuilder { String rootUri, Set> messageConverters, Set interceptors, Supplier requestFactorySupplier, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, - BasicAuthentication basicAuthentication, Map defaultHeaders, + BasicAuthentication basicAuthentication, Map> defaultHeaders, Set customizers, Set> requestCustomizers) { this.requestFactoryCustomizer = requestFactoryCustomizer; this.detectRequestFactory = detectRequestFactory; @@ -395,15 +395,17 @@ public class RestTemplateBuilder { * Add a default header that will be set if not already present on the outgoing * {@link HttpClientRequest}. * @param name the name of the header - * @param value the header value + * @param values the header values * @return a new builder instance * @since 2.2.0 */ - public RestTemplateBuilder defaultHeader(String name, String value) { + public RestTemplateBuilder defaultHeader(String name, String... values) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(values, "Values must not be null"); return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, - this.errorHandler, this.basicAuthentication, append(this.defaultHeaders, name, value), this.customizers, - this.requestCustomizers); + this.errorHandler, this.basicAuthentication, append(this.defaultHeaders, name, values), + this.customizers, this.requestCustomizers); } /** @@ -683,6 +685,10 @@ public class RestTemplateBuilder { return Collections.unmodifiableSet(new LinkedHashSet<>(collection)); } + private static List listOf(T[] items) { + return Collections.unmodifiableList(new ArrayList<>(Arrays.asList(items))); + } + private static Set append(Collection collection, Collection additions) { Set result = new LinkedHashSet<>((collection != null) ? collection : Collections.emptySet()); if (additions != null) { @@ -691,9 +697,11 @@ public class RestTemplateBuilder { return Collections.unmodifiableSet(result); } - private static Map append(Map map, K key, V value) { - Map result = new LinkedHashMap<>((map != null) ? map : Collections.emptyMap()); - result.put(key, value); + private static Map> append(Map> map, K key, V[] values) { + Map> result = new LinkedHashMap<>((map != null) ? map : Collections.emptyMap()); + if (values != null) { + result.put(key, listOf(values)); + } return Collections.unmodifiableMap(result); } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapper.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapper.java index bc3019ad65..b327650dad 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapper.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapper.java @@ -18,6 +18,7 @@ package org.springframework.boot.web.client; import java.io.IOException; import java.net.URI; +import java.util.List; import java.util.Map; import java.util.Set; @@ -39,12 +40,12 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH private final BasicAuthentication basicAuthentication; - private final Map defaultHeaders; + private final Map> defaultHeaders; private final Set> requestCustomizers; RestTemplateBuilderClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory, - BasicAuthentication basicAuthentication, Map defaultHeaders, + BasicAuthentication basicAuthentication, Map> defaultHeaders, Set> requestCustomizers) { super(requestFactory); this.basicAuthentication = basicAuthentication; @@ -61,7 +62,7 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH if (this.basicAuthentication != null) { this.basicAuthentication.applyTo(headers); } - this.defaultHeaders.forEach(headers::addIfAbsent); + this.defaultHeaders.forEach(headers::putIfAbsent); LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request) .invoke((customizer) -> customizer.customize(request)); return request; diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapperTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapperTests.java index 655174240d..cc7c2280b2 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapperTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderClientHttpRequestFactoryWrapperTests.java @@ -18,9 +18,11 @@ package org.springframework.boot.web.client; import java.io.IOException; import java.net.URI; +import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -80,16 +82,16 @@ public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests { @Test void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException { this.headers.add("one", "existing"); - Map defaultHeaders = new LinkedHashMap<>(); - defaultHeaders.put("one", "1"); - defaultHeaders.put("two", "2"); - defaultHeaders.put("three", "3"); + Map> defaultHeaders = new LinkedHashMap<>(); + defaultHeaders.put("one", Collections.singletonList("1")); + defaultHeaders.put("two", Arrays.asList("2", "3")); + defaultHeaders.put("three", Collections.singletonList("4")); this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null, defaultHeaders, Collections.emptySet()); ClientHttpRequest request = createRequest(); assertThat(request.getHeaders().get("one")).containsExactly("existing"); - assertThat(request.getHeaders().get("two")).containsExactly("2"); - assertThat(request.getHeaders().get("three")).containsExactly("3"); + assertThat(request.getHeaders().get("two")).containsExactly("2", "3"); + assertThat(request.getHeaders().get("three")).containsExactly("4"); } @Test diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index 03435302fa..3865ab0bae 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.Arrays; import java.util.Collections; import java.util.Set; import java.util.function.Supplier; @@ -33,6 +34,7 @@ import org.mockito.MockitoAnnotations; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; import org.springframework.http.client.BufferingClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; @@ -321,6 +323,16 @@ class RestTemplateBuilderTests { assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot"))); } + @Test + void defaultHeaderAddsHeaderValues() throws IOException { + String name = HttpHeaders.ACCEPT; + String[] values = { MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE }; + RestTemplate template = this.builder.defaultHeader(name, values).build(); + ClientHttpRequestFactory requestFactory = template.getRequestFactory(); + ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET); + assertThat(request.getHeaders()).contains(entry(name, Arrays.asList(values))); + } + @Test void requestCustomizersAddsCustomizers() throws IOException { RestTemplate template = this.builder