diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 1a13885ab7..8d136d2da2 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -134,11 +134,17 @@ public class TestRestTemplate { httpClientOptions); } + private static RestTemplate buildRestTemplate( + RestTemplateBuilder restTemplateBuilder) { + Assert.notNull(restTemplateBuilder, "RestTemplateBuilder must not be null"); + return restTemplateBuilder.build(); + } + private TestRestTemplate(RestTemplate restTemplate, String username, String password, HttpClientOption... httpClientOptions) { Assert.notNull(restTemplate, "RestTemplate must not be null"); this.httpClientOptions = httpClientOptions; - if (restTemplate.getRequestFactory().getClass().getName() + if (getRequestFactoryClass(restTemplate).getName() .equals("org.springframework.http.client.HttpComponentsClientHttpRequestFactory")) { restTemplate.setRequestFactory( new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); @@ -148,10 +154,16 @@ public class TestRestTemplate { this.restTemplate = restTemplate; } - private static RestTemplate buildRestTemplate( - RestTemplateBuilder restTemplateBuilder) { - Assert.notNull(restTemplateBuilder, "RestTemplateBuilder must not be null"); - return restTemplateBuilder.build(); + private Class getRequestFactoryClass(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); + if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) { + Field requestFactoryField = ReflectionUtils + .findField(RestTemplate.class, "requestFactory"); + ReflectionUtils.makeAccessible(requestFactoryField); + requestFactory = (ClientHttpRequestFactory) + ReflectionUtils.getField(requestFactoryField, restTemplate); + } + return requestFactory.getClass(); } private void addAuthentication(RestTemplate restTemplate, String username, @@ -1022,11 +1034,11 @@ public class TestRestTemplate { * @since 1.4.1 */ public TestRestTemplate withBasicAuth(String username, String password) { - RestTemplate restTemplate = new RestTemplate(); - restTemplate.setMessageConverters(getRestTemplate().getMessageConverters()); - restTemplate.setInterceptors(getRestTemplate().getInterceptors()); - restTemplate.setRequestFactory(getRequestFactory(getRestTemplate())); - restTemplate.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler()); + RestTemplate restTemplate = new RestTemplateBuilder() + .messageConverters(getRestTemplate().getMessageConverters()) + .interceptors(getRestTemplate().getInterceptors()) + .uriTemplateHandler(getRestTemplate().getUriTemplateHandler()) + .build(); TestRestTemplate testRestTemplate = new TestRestTemplate(restTemplate, username, password, this.httpClientOptions); testRestTemplate.getRestTemplate() @@ -1034,18 +1046,6 @@ public class TestRestTemplate { return testRestTemplate; } - private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { - ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) { - Field requestFactoryField = ReflectionUtils - .findField(RestTemplate.class, "requestFactory"); - ReflectionUtils.makeAccessible(requestFactoryField); - requestFactory = (ClientHttpRequestFactory) - ReflectionUtils.getField(requestFactoryField, getRestTemplate()); - } - return requestFactory; - } - @SuppressWarnings({ "rawtypes", "unchecked" }) private RequestEntity createRequestEntityWithRootAppliedUri( RequestEntity requestEntity) {