From b3597107bab903d3c0012d0279a240aeefe6e25b Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Fri, 16 Jan 2015 11:25:13 +0000 Subject: [PATCH] More careful masking for the HTTP status in non-error cases The ErrorPageFilter exposes a wrapped response to the downstream chain and unless more care is taken the chain will be able to set the response status, but not inspect it. Fixes gh-2367 --- .../boot/context/web/ErrorPageFilter.java | 14 +- .../web/ErrorPageFilterIntegrationTests.java | 175 ++++++++++++++++++ .../context/web/ErrorPageFilterTests.java | 20 ++ 3 files changed, 205 insertions(+), 4 deletions(-) create mode 100644 spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterIntegrationTests.java diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index 55d6b66a75..c802f9e78d 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -272,11 +272,10 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine private String message; - private boolean errorToSend; + private boolean errorToSend = false; public ErrorWrapperResponse(HttpServletResponse response) { super(response); - this.status = response.getStatus(); } @Override @@ -288,13 +287,20 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine public void sendError(int status, String message) throws IOException { this.status = status; this.message = message; - this.errorToSend = true; + // Do not call super because the container may prevent us from handling the + // error ourselves } @Override public int getStatus() { - return this.status; + if (this.errorToSend) { + return this.status; + } + else { + // If there was no error we need to trust the wrapped response + return super.getStatus(); + } } @Override diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterIntegrationTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterIntegrationTests.java new file mode 100644 index 0000000000..34e5041ce7 --- /dev/null +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterIntegrationTests.java @@ -0,0 +1,175 @@ +/* + * Copyright 2012-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.context.web; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.embedded.AnnotationConfigEmbeddedWebApplicationContext; +import org.springframework.boot.context.embedded.EmbeddedServletContainerFactory; +import org.springframework.boot.context.embedded.tomcat.TomcatEmbeddedServletContainerFactory; +import org.springframework.boot.context.web.ErrorPageFilterIntegrationTests.TomcatConfig; +import org.springframework.boot.test.IntegrationTest; +import org.springframework.boot.test.SpringApplicationConfiguration; +import org.springframework.boot.test.TestRestTemplate; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.bind.annotation.ResponseStatus; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; +import org.springframework.web.servlet.handler.HandlerInterceptorAdapter; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * + * @author Dave Syer + */ +@RunWith(SpringJUnit4ClassRunner.class) +@SpringApplicationConfiguration(classes = TomcatConfig.class) +@IntegrationTest +@WebAppConfiguration +public class ErrorPageFilterIntegrationTests { + + @Autowired + private HelloWorldController controller; + + @Autowired + private AnnotationConfigEmbeddedWebApplicationContext context; + + @After + public void init() { + this.controller.reset(); + } + + @Test + public void created() throws Exception { + doTest(this.context, "/create", HttpStatus.CREATED); + assertThat(this.controller.getStatus(), equalTo(201)); + } + + @Test + public void ok() throws Exception { + doTest(this.context, "/hello", HttpStatus.OK); + assertThat(this.controller.getStatus(), equalTo(200)); + } + + private void doTest(AnnotationConfigEmbeddedWebApplicationContext context, + String resourcePath, HttpStatus status) throws Exception { + int port = context.getEmbeddedServletContainer().getPort(); + TestRestTemplate template = new TestRestTemplate(); + ResponseEntity entity = template.getForEntity(new URI("http://localhost:" + + port + resourcePath), String.class); + assertThat(entity.getBody(), equalTo("Hello World")); + assertThat(entity.getStatusCode(), equalTo(status)); + } + + @Configuration + @EnableWebMvc + public static class TomcatConfig { + + @Bean + public EmbeddedServletContainerFactory containerFactory() { + return new TomcatEmbeddedServletContainerFactory(0); + } + + @Bean + public ErrorPageFilter errorPageFilter() { + return new ErrorPageFilter(); + } + + @Bean + public DispatcherServlet dispatcherServlet() { + return new DispatcherServlet(); + } + + @Bean + public HelloWorldController helloWorldController() { + return new HelloWorldController(); + } + } + + @Controller + public static class HelloWorldController extends WebMvcConfigurerAdapter { + + private int status; + + private CountDownLatch latch = new CountDownLatch(1); + + public int getStatus() throws InterruptedException { + assertThat("Timed out waiting for latch", + this.latch.await(1, TimeUnit.SECONDS), equalTo(true)); + return this.status; + } + + public void setStatus(int status) { + this.status = status; + } + + public void reset() { + this.status = 0; + this.latch = new CountDownLatch(1); + } + + @Override + public void addInterceptors(InterceptorRegistry registry) { + registry.addInterceptor(new HandlerInterceptorAdapter() { + @Override + public void postHandle(HttpServletRequest request, + HttpServletResponse response, Object handler, + ModelAndView modelAndView) throws Exception { + HelloWorldController.this.setStatus(response.getStatus()); + HelloWorldController.this.latch.countDown(); + } + }); + } + + @RequestMapping("/hello") + @ResponseBody + public String sayHello() { + return "Hello World"; + } + + @RequestMapping("/create") + @ResponseBody + @ResponseStatus(HttpStatus.CREATED) + public String created() { + return "Hello World"; + } + + } + +} diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index 19dc76992d..30fcce0556 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -77,6 +77,26 @@ public class ErrorPageFilterTests { assertThat(this.response.getForwardedUrl(), is(nullValue())); } + @Test + public void notAnErrorButNotOK() throws Exception { + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + ((HttpServletResponse) response).setStatus(201); + super.doFilter(request, response); + response.flushBuffer(); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertThat(((HttpServletResponse) this.chain.getResponse()).getStatus(), + equalTo(201)); + assertThat( + ((HttpServletResponse) ((HttpServletResponseWrapper) this.chain.getResponse()) + .getResponse()).getStatus(), equalTo(201)); + assertTrue(this.response.isCommitted()); + } + @Test public void unauthorizedWithErrorPath() throws Exception { this.filter.addErrorPages(new ErrorPage("/error"));