Merge pull request #17286 from nosan

* gh-17286:
  Allow multiple values to be specified when configuring a default header

Closes gh-17286
pull/17389/head
Andy Wilkinson 5 years ago
commit f4202d2be2

@ -91,7 +91,7 @@ public class RestTemplateBuilder {
private final BasicAuthentication basicAuthentication; private final BasicAuthentication basicAuthentication;
private final Map<String, String> defaultHeaders; private final Map<String, List<String>> defaultHeaders;
private final Set<RestTemplateCustomizer> customizers; private final Set<RestTemplateCustomizer> customizers;
@ -122,7 +122,7 @@ public class RestTemplateBuilder {
String rootUri, Set<HttpMessageConverter<?>> messageConverters, String rootUri, Set<HttpMessageConverter<?>> messageConverters,
Set<ClientHttpRequestInterceptor> interceptors, Supplier<ClientHttpRequestFactory> requestFactorySupplier, Set<ClientHttpRequestInterceptor> interceptors, Supplier<ClientHttpRequestFactory> requestFactorySupplier,
UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler,
BasicAuthentication basicAuthentication, Map<String, String> defaultHeaders, BasicAuthentication basicAuthentication, Map<String, List<String>> defaultHeaders,
Set<RestTemplateCustomizer> customizers, Set<RestTemplateRequestCustomizer<?>> requestCustomizers) { Set<RestTemplateCustomizer> customizers, Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
this.requestFactoryCustomizer = requestFactoryCustomizer; this.requestFactoryCustomizer = requestFactoryCustomizer;
this.detectRequestFactory = detectRequestFactory; this.detectRequestFactory = detectRequestFactory;
@ -395,15 +395,17 @@ public class RestTemplateBuilder {
* Add a default header that will be set if not already present on the outgoing * Add a default header that will be set if not already present on the outgoing
* {@link HttpClientRequest}. * {@link HttpClientRequest}.
* @param name the name of the header * @param name the name of the header
* @param value the header value * @param values the header values
* @return a new builder instance * @return a new builder instance
* @since 2.2.0 * @since 2.2.0
*/ */
public RestTemplateBuilder defaultHeader(String name, String value) { public RestTemplateBuilder defaultHeader(String name, String... values) {
Assert.notNull(name, "Name must not be null");
Assert.notNull(values, "Values must not be null");
return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri,
this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler,
this.errorHandler, this.basicAuthentication, append(this.defaultHeaders, name, value), this.customizers, this.errorHandler, this.basicAuthentication, append(this.defaultHeaders, name, values),
this.requestCustomizers); this.customizers, this.requestCustomizers);
} }
/** /**
@ -683,6 +685,10 @@ public class RestTemplateBuilder {
return Collections.unmodifiableSet(new LinkedHashSet<>(collection)); return Collections.unmodifiableSet(new LinkedHashSet<>(collection));
} }
private static <T> List<T> listOf(T[] items) {
return Collections.unmodifiableList(new ArrayList<>(Arrays.asList(items)));
}
private static <T> Set<T> append(Collection<? extends T> collection, Collection<? extends T> additions) { private static <T> Set<T> append(Collection<? extends T> collection, Collection<? extends T> additions) {
Set<T> result = new LinkedHashSet<>((collection != null) ? collection : Collections.emptySet()); Set<T> result = new LinkedHashSet<>((collection != null) ? collection : Collections.emptySet());
if (additions != null) { if (additions != null) {
@ -691,9 +697,11 @@ public class RestTemplateBuilder {
return Collections.unmodifiableSet(result); return Collections.unmodifiableSet(result);
} }
private static <K, V> Map<K, V> append(Map<K, V> map, K key, V value) { private static <K, V> Map<K, List<V>> append(Map<K, List<V>> map, K key, V[] values) {
Map<K, V> result = new LinkedHashMap<>((map != null) ? map : Collections.emptyMap()); Map<K, List<V>> result = new LinkedHashMap<>((map != null) ? map : Collections.emptyMap());
result.put(key, value); if (values != null) {
result.put(key, listOf(values));
}
return Collections.unmodifiableMap(result); return Collections.unmodifiableMap(result);
} }

@ -18,6 +18,7 @@ package org.springframework.boot.web.client;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -39,12 +40,12 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
private final BasicAuthentication basicAuthentication; private final BasicAuthentication basicAuthentication;
private final Map<String, String> defaultHeaders; private final Map<String, List<String>> defaultHeaders;
private final Set<RestTemplateRequestCustomizer<?>> requestCustomizers; private final Set<RestTemplateRequestCustomizer<?>> requestCustomizers;
RestTemplateBuilderClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory, RestTemplateBuilderClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory,
BasicAuthentication basicAuthentication, Map<String, String> defaultHeaders, BasicAuthentication basicAuthentication, Map<String, List<String>> defaultHeaders,
Set<RestTemplateRequestCustomizer<?>> requestCustomizers) { Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
super(requestFactory); super(requestFactory);
this.basicAuthentication = basicAuthentication; this.basicAuthentication = basicAuthentication;
@ -61,7 +62,7 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
if (this.basicAuthentication != null) { if (this.basicAuthentication != null) {
this.basicAuthentication.applyTo(headers); this.basicAuthentication.applyTo(headers);
} }
this.defaultHeaders.forEach(headers::addIfAbsent); this.defaultHeaders.forEach(headers::putIfAbsent);
LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request) LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request)
.invoke((customizer) -> customizer.customize(request)); .invoke((customizer) -> customizer.customize(request));
return request; return request;

@ -18,9 +18,11 @@ package org.springframework.boot.web.client;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -80,16 +82,16 @@ public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests {
@Test @Test
void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException { void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException {
this.headers.add("one", "existing"); this.headers.add("one", "existing");
Map<String, String> defaultHeaders = new LinkedHashMap<>(); Map<String, List<String>> defaultHeaders = new LinkedHashMap<>();
defaultHeaders.put("one", "1"); defaultHeaders.put("one", Collections.singletonList("1"));
defaultHeaders.put("two", "2"); defaultHeaders.put("two", Arrays.asList("2", "3"));
defaultHeaders.put("three", "3"); defaultHeaders.put("three", Collections.singletonList("4"));
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null, this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null,
defaultHeaders, Collections.emptySet()); defaultHeaders, Collections.emptySet());
ClientHttpRequest request = createRequest(); ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get("one")).containsExactly("existing"); assertThat(request.getHeaders().get("one")).containsExactly("existing");
assertThat(request.getHeaders().get("two")).containsExactly("2"); assertThat(request.getHeaders().get("two")).containsExactly("2", "3");
assertThat(request.getHeaders().get("three")).containsExactly("3"); assertThat(request.getHeaders().get("three")).containsExactly("4");
} }
@Test @Test

@ -20,6 +20,7 @@ import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -33,6 +34,7 @@ import org.mockito.MockitoAnnotations;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory; import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
@ -321,6 +323,16 @@ class RestTemplateBuilderTests {
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot"))); assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot")));
} }
@Test
void defaultHeaderAddsHeaderValues() throws IOException {
String name = HttpHeaders.ACCEPT;
String[] values = { MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE };
RestTemplate template = this.builder.defaultHeader(name, values).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
assertThat(request.getHeaders()).contains(entry(name, Arrays.asList(values)));
}
@Test @Test
void requestCustomizersAddsCustomizers() throws IOException { void requestCustomizersAddsCustomizers() throws IOException {
RestTemplate template = this.builder RestTemplate template = this.builder

Loading…
Cancel
Save