Consider subtypes in exception error page mapping

Fixes gh-417
pull/450/head
Dave Syer 11 years ago
parent a0ba8c90a6
commit c53a36717d

@ -71,6 +71,8 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
private final Map<Class<?>, String> exceptions = new HashMap<Class<?>, String>(); private final Map<Class<?>, String> exceptions = new HashMap<Class<?>, String>();
private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();
@Override @Override
public void init(FilterConfig filterConfig) throws ServletException { public void init(FilterConfig filterConfig) throws ServletException {
} }
@ -118,14 +120,34 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
private void handleException(HttpServletRequest request, private void handleException(HttpServletRequest request,
HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex) HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex)
throws IOException, ServletException { throws IOException, ServletException {
String errorPath = getErrorPath(this.exceptions, ex.getClass()); Class<?> type = ex.getClass();
String errorPath = this.global;
if (this.exceptions.containsKey(type)) {
errorPath = this.exceptions.get(type);
}
else {
if (this.subtypes.containsKey(type)) {
errorPath = this.exceptions.get(this.subtypes.get(type));
}
else {
Class<?> subtype = type;
while (subtype != Object.class) {
subtype = subtype.getSuperclass();
if (this.exceptions.containsKey(subtype)) {
this.subtypes.put(subtype, type);
errorPath = this.exceptions.get(subtype);
break;
}
}
}
}
if (errorPath == null) { if (errorPath == null) {
rethrow(ex); rethrow(ex);
return; return;
} }
setErrorAttributes(request, 500, ex.getMessage()); setErrorAttributes(request, 500, ex.getMessage());
request.setAttribute(ERROR_EXCEPTION, ex); request.setAttribute(ERROR_EXCEPTION, ex);
request.setAttribute(ERROR_EXCEPTION_TYPE, ex.getClass().getName()); request.setAttribute(ERROR_EXCEPTION_TYPE, type.getName());
wrapped.sendError(500, ex.getMessage()); wrapped.sendError(500, ex.getMessage());
request.getRequestDispatcher(errorPath).forward(request, response); request.getRequestDispatcher(errorPath).forward(request, response);
} }

@ -40,7 +40,7 @@ import static org.junit.Assert.assertThat;
* *
* @author Dave Syer * @author Dave Syer
*/ */
public class ErrorWrapperEmbeddedServletContainerFactoryTests { public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter(); private ErrorPageFilter filter = new ErrorPageFilter();
@ -119,4 +119,27 @@ public class ErrorWrapperEmbeddedServletContainerFactoryTests {
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) RuntimeException.class.getName())); equalTo((Object) RuntimeException.class.getName()));
} }
@Test
public void subClassExceptionError() throws Exception {
this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
super.doFilter(request, response);
throw new IllegalStateException("BAD");
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
equalTo(500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE),
equalTo((Object) 500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) IllegalStateException.class.getName()));
}
} }
Loading…
Cancel
Save