Collect HTTP trace at commit time for WebFlux

Prior to this commit, the `HttpTraceWebFilter` would collect the
response information (status and headers) for tracing purposes, after
the handling chain is done with the exchange - inside a
`doAfterSuccessOrError`.

Once the handler has processed the exchange, there is no strong
guarantee about the HTTP resources being still present. Depending on the
web server implementation, HTTP resources (including HTTP header maps)
might be recycled, because pooled in the first place.

This commit moves the collection and processing of the HTTP trace right
before the response is committed. This removes the need to handle
special cases with exceptions, since by that time all exception handlers
have processed the response and the information that we extract is the
information that's about to be written to the network.

Fixes gh-15819
pull/15889/head
Brian Clozel 6 years ago
parent cba6079b7b
commit 72c8e5d366

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2018 the original author or authors. * Copyright 2012-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,10 +26,6 @@ import org.springframework.boot.actuate.trace.http.HttpTrace;
import org.springframework.boot.actuate.trace.http.HttpTraceRepository; import org.springframework.boot.actuate.trace.http.HttpTraceRepository;
import org.springframework.boot.actuate.trace.http.Include; import org.springframework.boot.actuate.trace.http.Include;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
@ -95,38 +91,20 @@ public class HttpTraceWebFilter implements WebFilter, Ordered {
Principal principal, WebSession session) { Principal principal, WebSession session) {
ServerWebExchangeTraceableRequest request = new ServerWebExchangeTraceableRequest( ServerWebExchangeTraceableRequest request = new ServerWebExchangeTraceableRequest(
exchange); exchange);
HttpTrace trace = this.tracer.receivedRequest(request); final HttpTrace trace = this.tracer.receivedRequest(request);
return chain.filter(exchange).doAfterSuccessOrError((aVoid, ex) -> { exchange.getResponse().beforeCommit(() -> {
TraceableServerHttpResponse response = new TraceableServerHttpResponse( TraceableServerHttpResponse response = new TraceableServerHttpResponse(
(ex != null) ? new CustomStatusResponseDecorator(ex, exchange.getResponse());
exchange.getResponse()) : exchange.getResponse());
this.tracer.sendingResponse(trace, response, () -> principal, this.tracer.sendingResponse(trace, response, () -> principal,
() -> getStartedSessionId(session)); () -> getStartedSessionId(session));
this.repository.add(trace); this.repository.add(trace);
return Mono.empty();
}); });
return chain.filter(exchange);
} }
private String getStartedSessionId(WebSession session) { private String getStartedSessionId(WebSession session) {
return (session != null && session.isStarted()) ? session.getId() : null; return (session != null && session.isStarted()) ? session.getId() : null;
} }
private static final class CustomStatusResponseDecorator
extends ServerHttpResponseDecorator {
private final HttpStatus status;
private CustomStatusResponseDecorator(Throwable ex, ServerHttpResponse delegate) {
super(delegate);
this.status = (ex instanceof ResponseStatusException)
? ((ResponseStatusException) ex).getStatus()
: HttpStatus.INTERNAL_SERVER_ERROR;
}
@Override
public HttpStatus getStatusCode() {
return this.status;
}
}
} }

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2018 the original author or authors. * Copyright 2012-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,21 +30,25 @@ import org.springframework.http.server.reactive.ServerHttpResponse;
*/ */
class TraceableServerHttpResponse implements TraceableResponse { class TraceableServerHttpResponse implements TraceableResponse {
private final ServerHttpResponse response; private final int status;
private final Map<String, List<String>> headers;
TraceableServerHttpResponse(ServerHttpResponse response) {
this.status = (response.getStatusCode() != null)
? response.getStatusCode().value() : 200;
this.headers = new LinkedHashMap<>(response.getHeaders());
TraceableServerHttpResponse(ServerHttpResponse exchange) {
this.response = exchange;
} }
@Override @Override
public int getStatus() { public int getStatus() {
return (this.response.getStatusCode() != null) return this.status;
? this.response.getStatusCode().value() : 200;
} }
@Override @Override
public Map<String, List<String>> getHeaders() { public Map<String, List<String>> getHeaders() {
return new LinkedHashMap<>(this.response.getHeaders()); return this.headers;
} }
} }

@ -16,13 +16,10 @@
package org.springframework.boot.actuate.trace.http.reactive; package org.springframework.boot.actuate.trace.http.reactive;
import java.io.IOException;
import java.security.Principal; import java.security.Principal;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet; import java.util.EnumSet;
import javax.servlet.ServletException;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -33,10 +30,11 @@ import org.springframework.boot.actuate.trace.http.Include;
import org.springframework.boot.actuate.web.trace.reactive.HttpTraceWebFilter; import org.springframework.boot.actuate.web.trace.reactive.HttpTraceWebFilter;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebExchangeDecorator; import org.springframework.web.server.ServerWebExchangeDecorator;
import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -56,8 +54,8 @@ public class HttpTraceWebFilterTests {
this.tracer, EnumSet.allOf(Include.class)); this.tracer, EnumSet.allOf(Include.class));
@Test @Test
public void filterTracesExchange() throws ServletException, IOException { public void filterTracesExchange() {
this.filter.filter( executeFilter(
MockServerWebExchange MockServerWebExchange
.from(MockServerHttpRequest.get("https://api.example.com")), .from(MockServerHttpRequest.get("https://api.example.com")),
(exchange) -> Mono.empty()).block(Duration.ofSeconds(30)); (exchange) -> Mono.empty()).block(Duration.ofSeconds(30));
@ -65,9 +63,8 @@ public class HttpTraceWebFilterTests {
} }
@Test @Test
public void filterCapturesSessionIdWhenSessionIsUsed() public void filterCapturesSessionIdWhenSessionIsUsed() {
throws ServletException, IOException { executeFilter(
this.filter.filter(
MockServerWebExchange MockServerWebExchange
.from(MockServerHttpRequest.get("https://api.example.com")), .from(MockServerHttpRequest.get("https://api.example.com")),
(exchange) -> { (exchange) -> {
@ -82,9 +79,8 @@ public class HttpTraceWebFilterTests {
} }
@Test @Test
public void filterDoesNotCaptureIdOfUnusedSession() public void filterDoesNotCaptureIdOfUnusedSession() {
throws ServletException, IOException { executeFilter(
this.filter.filter(
MockServerWebExchange MockServerWebExchange
.from(MockServerHttpRequest.get("https://api.example.com")), .from(MockServerHttpRequest.get("https://api.example.com")),
(exchange) -> { (exchange) -> {
@ -97,10 +93,10 @@ public class HttpTraceWebFilterTests {
} }
@Test @Test
public void filterCapturesPrincipal() throws ServletException, IOException { public void filterCapturesPrincipal() {
Principal principal = mock(Principal.class); Principal principal = mock(Principal.class);
given(principal.getName()).willReturn("alice"); given(principal.getName()).willReturn("alice");
this.filter.filter(new ServerWebExchangeDecorator(MockServerWebExchange executeFilter(new ServerWebExchangeDecorator(MockServerWebExchange
.from(MockServerHttpRequest.get("https://api.example.com"))) { .from(MockServerHttpRequest.get("https://api.example.com"))) {
@Override @Override
@ -120,17 +116,9 @@ public class HttpTraceWebFilterTests {
assertThat(tracedPrincipal.getName()).isEqualTo("alice"); assertThat(tracedPrincipal.getName()).isEqualTo("alice");
} }
@Test private Mono<Void> executeFilter(ServerWebExchange exchange, WebFilterChain chain) {
public void statusIsAssumedToBe500WhenChainFails() return this.filter.filter(exchange, chain)
throws ServletException, IOException { .then(Mono.defer(() -> exchange.getResponse().setComplete()));
assertThatExceptionOfType(Exception.class).isThrownBy(() -> this.filter
.filter(MockServerWebExchange
.from(MockServerHttpRequest.get("https://api.example.com")),
(exchange) -> Mono.error(new RuntimeException()))
.block(Duration.ofSeconds(30)));
assertThat(this.repository.findAll()).hasSize(1);
assertThat(this.repository.findAll().get(0).getResponse().getStatus())
.isEqualTo(500);
} }
} }

Loading…
Cancel
Save