diff --git a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/HttpTraceWebFilter.java b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/HttpTraceWebFilter.java index 6809c23d69..e66170d4c1 100644 --- a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/HttpTraceWebFilter.java +++ b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/HttpTraceWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2018 the original author or authors. + * Copyright 2012-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,6 @@ import org.springframework.boot.actuate.trace.http.HttpTrace; import org.springframework.boot.actuate.trace.http.HttpTraceRepository; import org.springframework.boot.actuate.trace.http.Include; import org.springframework.core.Ordered; -import org.springframework.http.HttpStatus; -import org.springframework.http.server.reactive.ServerHttpResponse; -import org.springframework.http.server.reactive.ServerHttpResponseDecorator; -import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; @@ -95,38 +91,20 @@ public class HttpTraceWebFilter implements WebFilter, Ordered { Principal principal, WebSession session) { ServerWebExchangeTraceableRequest request = new ServerWebExchangeTraceableRequest( exchange); - HttpTrace trace = this.tracer.receivedRequest(request); - return chain.filter(exchange).doAfterSuccessOrError((aVoid, ex) -> { + final HttpTrace trace = this.tracer.receivedRequest(request); + exchange.getResponse().beforeCommit(() -> { TraceableServerHttpResponse response = new TraceableServerHttpResponse( - (ex != null) ? new CustomStatusResponseDecorator(ex, - exchange.getResponse()) : exchange.getResponse()); + exchange.getResponse()); this.tracer.sendingResponse(trace, response, () -> principal, () -> getStartedSessionId(session)); this.repository.add(trace); + return Mono.empty(); }); + return chain.filter(exchange); } private String getStartedSessionId(WebSession session) { return (session != null && session.isStarted()) ? session.getId() : null; } - private static final class CustomStatusResponseDecorator - extends ServerHttpResponseDecorator { - - private final HttpStatus status; - - private CustomStatusResponseDecorator(Throwable ex, ServerHttpResponse delegate) { - super(delegate); - this.status = (ex instanceof ResponseStatusException) - ? ((ResponseStatusException) ex).getStatus() - : HttpStatus.INTERNAL_SERVER_ERROR; - } - - @Override - public HttpStatus getStatusCode() { - return this.status; - } - - } - } diff --git a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/TraceableServerHttpResponse.java b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/TraceableServerHttpResponse.java index db1f44c71c..b3c968ee27 100644 --- a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/TraceableServerHttpResponse.java +++ b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/trace/reactive/TraceableServerHttpResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2018 the original author or authors. + * Copyright 2012-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,21 +30,25 @@ import org.springframework.http.server.reactive.ServerHttpResponse; */ class TraceableServerHttpResponse implements TraceableResponse { - private final ServerHttpResponse response; + private final int status; + + private final Map> headers; + + TraceableServerHttpResponse(ServerHttpResponse response) { + this.status = (response.getStatusCode() != null) + ? response.getStatusCode().value() : 200; + this.headers = new LinkedHashMap<>(response.getHeaders()); - TraceableServerHttpResponse(ServerHttpResponse exchange) { - this.response = exchange; } @Override public int getStatus() { - return (this.response.getStatusCode() != null) - ? this.response.getStatusCode().value() : 200; + return this.status; } @Override public Map> getHeaders() { - return new LinkedHashMap<>(this.response.getHeaders()); + return this.headers; } } diff --git a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/http/reactive/HttpTraceWebFilterTests.java b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/http/reactive/HttpTraceWebFilterTests.java index b08d7a2c07..af82cc82bf 100644 --- a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/http/reactive/HttpTraceWebFilterTests.java +++ b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/http/reactive/HttpTraceWebFilterTests.java @@ -16,13 +16,10 @@ package org.springframework.boot.actuate.trace.http.reactive; -import java.io.IOException; import java.security.Principal; import java.time.Duration; import java.util.EnumSet; -import javax.servlet.ServletException; - import org.junit.Test; import reactor.core.publisher.Mono; @@ -33,10 +30,11 @@ import org.springframework.boot.actuate.trace.http.Include; import org.springframework.boot.actuate.web.trace.reactive.HttpTraceWebFilter; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchangeDecorator; +import org.springframework.web.server.WebFilterChain; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -56,8 +54,8 @@ public class HttpTraceWebFilterTests { this.tracer, EnumSet.allOf(Include.class)); @Test - public void filterTracesExchange() throws ServletException, IOException { - this.filter.filter( + public void filterTracesExchange() { + executeFilter( MockServerWebExchange .from(MockServerHttpRequest.get("https://api.example.com")), (exchange) -> Mono.empty()).block(Duration.ofSeconds(30)); @@ -65,9 +63,8 @@ public class HttpTraceWebFilterTests { } @Test - public void filterCapturesSessionIdWhenSessionIsUsed() - throws ServletException, IOException { - this.filter.filter( + public void filterCapturesSessionIdWhenSessionIsUsed() { + executeFilter( MockServerWebExchange .from(MockServerHttpRequest.get("https://api.example.com")), (exchange) -> { @@ -82,9 +79,8 @@ public class HttpTraceWebFilterTests { } @Test - public void filterDoesNotCaptureIdOfUnusedSession() - throws ServletException, IOException { - this.filter.filter( + public void filterDoesNotCaptureIdOfUnusedSession() { + executeFilter( MockServerWebExchange .from(MockServerHttpRequest.get("https://api.example.com")), (exchange) -> { @@ -97,10 +93,10 @@ public class HttpTraceWebFilterTests { } @Test - public void filterCapturesPrincipal() throws ServletException, IOException { + public void filterCapturesPrincipal() { Principal principal = mock(Principal.class); given(principal.getName()).willReturn("alice"); - this.filter.filter(new ServerWebExchangeDecorator(MockServerWebExchange + executeFilter(new ServerWebExchangeDecorator(MockServerWebExchange .from(MockServerHttpRequest.get("https://api.example.com"))) { @Override @@ -120,17 +116,9 @@ public class HttpTraceWebFilterTests { assertThat(tracedPrincipal.getName()).isEqualTo("alice"); } - @Test - public void statusIsAssumedToBe500WhenChainFails() - throws ServletException, IOException { - assertThatExceptionOfType(Exception.class).isThrownBy(() -> this.filter - .filter(MockServerWebExchange - .from(MockServerHttpRequest.get("https://api.example.com")), - (exchange) -> Mono.error(new RuntimeException())) - .block(Duration.ofSeconds(30))); - assertThat(this.repository.findAll()).hasSize(1); - assertThat(this.repository.findAll().get(0).getResponse().getStatus()) - .isEqualTo(500); + private Mono executeFilter(ServerWebExchange exchange, WebFilterChain chain) { + return this.filter.filter(exchange, chain) + .then(Mono.defer(() -> exchange.getResponse().setComplete())); } }