diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index fdf3f657ac..f6af4d60ce 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -88,6 +88,11 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine ErrorPageFilter.this.doFilter(request, response, chain); } + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + }; @Override diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index f83be0f6c0..3e1af18dbf 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -34,6 +34,10 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -373,6 +377,62 @@ public class ErrorPageFilterTests { assertTrue(this.response.isCommitted()); } + @Test + public void responseIsNotCommitedDuringAsyncDispatch() throws Exception { + setUpAsyncDispatch(); + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertFalse(this.response.isCommitted()); + } + + @Test + public void responseIsCommitedWhenExceptionIsThrownDuringAsyncDispatch() + throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + setUpAsyncDispatch(); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + throw new RuntimeException("BAD"); + } + }; + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); + } + + @Test + public void responseIsCommitedWhenStatusIs400PlusDuringAsyncDispatch() + throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + setUpAsyncDispatch(); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + ((HttpServletResponse) response).sendError(400, "BAD"); + } + }; + + this.filter.doFilter(this.request, this.response, this.chain); + + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(), + equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); + } + @Test public void responseIsNotFlushedIfStatusIsLessThan400AndItHasAlreadyBeenCommitted() throws Exception { @@ -419,4 +479,14 @@ public class ErrorPageFilterTests { assertThat(this.output.toString(), containsString("request [/test/alpha]")); } + private void setUpAsyncDispatch() throws Exception { + this.request.setAsyncSupported(true); + this.request.setAsyncStarted(true); + DeferredResult result = new DeferredResult(); + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request); + asyncManager.setAsyncWebRequest(new StandardServletAsyncWebRequest(this.request, + this.response)); + asyncManager.startDeferredResultProcessing(result); + } + }