From c53a36717d2c9a278042a2e49fa6cc794331a903 Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Wed, 5 Mar 2014 09:45:13 +0000 Subject: [PATCH] Consider subtypes in exception error page mapping Fixes gh-417 --- .../boot/context/web/ErrorPageFilter.java | 26 +++++++++++++++++-- ...ryTests.java => ErrorPageFilterTests.java} | 25 +++++++++++++++++- 2 files changed, 48 insertions(+), 3 deletions(-) rename spring-boot/src/test/java/org/springframework/boot/context/web/{ErrorWrapperEmbeddedServletContainerFactoryTests.java => ErrorPageFilterTests.java} (82%) diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index a16f36f250..9b46e5b644 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -71,6 +71,8 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple private final Map, String> exceptions = new HashMap, String>(); + private final Map, Class> subtypes = new HashMap, Class>(); + @Override public void init(FilterConfig filterConfig) throws ServletException { } @@ -118,14 +120,34 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple private void handleException(HttpServletRequest request, HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex) throws IOException, ServletException { - String errorPath = getErrorPath(this.exceptions, ex.getClass()); + Class type = ex.getClass(); + String errorPath = this.global; + if (this.exceptions.containsKey(type)) { + errorPath = this.exceptions.get(type); + } + else { + if (this.subtypes.containsKey(type)) { + errorPath = this.exceptions.get(this.subtypes.get(type)); + } + else { + Class subtype = type; + while (subtype != Object.class) { + subtype = subtype.getSuperclass(); + if (this.exceptions.containsKey(subtype)) { + this.subtypes.put(subtype, type); + errorPath = this.exceptions.get(subtype); + break; + } + } + } + } if (errorPath == null) { rethrow(ex); return; } setErrorAttributes(request, 500, ex.getMessage()); request.setAttribute(ERROR_EXCEPTION, ex); - request.setAttribute(ERROR_EXCEPTION_TYPE, ex.getClass().getName()); + request.setAttribute(ERROR_EXCEPTION_TYPE, type.getName()); wrapped.sendError(500, ex.getMessage()); request.getRequestDispatcher(errorPath).forward(request, response); } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java similarity index 82% rename from spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java rename to spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index 4119a117fd..a7cdb113a2 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -40,7 +40,7 @@ import static org.junit.Assert.assertThat; * * @author Dave Syer */ -public class ErrorWrapperEmbeddedServletContainerFactoryTests { +public class ErrorPageFilterTests { private ErrorPageFilter filter = new ErrorPageFilter(); @@ -119,4 +119,27 @@ public class ErrorWrapperEmbeddedServletContainerFactoryTests { assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), equalTo((Object) RuntimeException.class.getName())); } + + @Test + public void subClassExceptionError() throws Exception { + this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500")); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + throw new IllegalStateException("BAD"); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(), + equalTo(500)); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE), + equalTo((Object) 500)); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE), + equalTo((Object) "BAD")); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), + equalTo((Object) IllegalStateException.class.getName())); + } + }