Don't let standalone Tomcat render its error page after redirect

Previously, if the configured error controller responded with a
redirect to an error caused by an exception, standalone Tomcat would
render its default error page for the original exception. This
occurred because ErrorPageFilter sets the
javax.servlet.error.exception request attribute prior to dispatching
to the error controller and then does not clear it. As the request
unwinds, Tomcat's ErrorReportValve notices that the attribute is set
and renders an error page for the exception that is the attribute's
value.

This commit updates ErrorPageFilter to remove the
javax.servlet.error.exception and javax.servlet.error.exception_type
attributes upon successful completion of a forward to the error
controller. This prevents Tomcat from rendering an error page for
an exception that has already been handled by the error controller.

Closes gh-7920
pull/8319/head
Andy Wilkinson 8 years ago
parent 30074431a7
commit 9247cd9c36

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2016 the original author or authors. * Copyright 2012-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -183,6 +183,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry {
response.reset(); response.reset();
response.sendError(500, ex.getMessage()); response.sendError(500, ex.getMessage());
request.getRequestDispatcher(path).forward(request, response); request.getRequestDispatcher(path).forward(request, response);
request.removeAttribute(ERROR_EXCEPTION);
request.removeAttribute(ERROR_EXCEPTION_TYPE);
} }
private String getDescription(HttpServletRequest request) { private String getDescription(HttpServletRequest request) {

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2016 the original author or authors. * Copyright 2012-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,9 @@
package org.springframework.boot.web.support; package org.springframework.boot.web.support;
import java.io.IOException; import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.RequestDispatcher; import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException; import javax.servlet.ServletException;
@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockRequestDispatcher;
import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncManager;
@ -57,8 +61,7 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter(); private ErrorPageFilter filter = new ErrorPageFilter();
private MockHttpServletRequest request = new MockHttpServletRequest("GET", private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest();
"/test/path");
private MockHttpServletResponse response = new MockHttpServletResponse(); private MockHttpServletResponse response = new MockHttpServletResponse();
@ -261,8 +264,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class); .isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
@ -318,8 +327,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(IllegalStateException.class); .isEqualTo(IllegalStateException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(IllegalStateException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
@ -492,8 +507,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class); .isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
@ -510,4 +531,60 @@ public class ErrorPageFilterTests {
asyncManager.startDeferredResultProcessing(result); asyncManager.startDeferredResultProcessing(result);
} }
private Map<String, Object> getAttributesForDispatch(String path) {
return this.request.getDispatcher(path).getRequestAttributes();
}
private static final class DispatchRecordingMockHttpServletRequest
extends MockHttpServletRequest {
private final Map<String, AttributeCapturingRequestDispatcher> dispatchers = new HashMap<String, AttributeCapturingRequestDispatcher>();
private DispatchRecordingMockHttpServletRequest() {
super("GET", "/test/path");
}
@Override
public RequestDispatcher getRequestDispatcher(String path) {
AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher(
path);
this.dispatchers.put(path, dispatcher);
return dispatcher;
}
private AttributeCapturingRequestDispatcher getDispatcher(String path) {
return this.dispatchers.get(path);
}
private static final class AttributeCapturingRequestDispatcher
extends MockRequestDispatcher {
private final Map<String, Object> requestAttributes = new HashMap<String, Object>();
private AttributeCapturingRequestDispatcher(String resource) {
super(resource);
}
@Override
public void forward(ServletRequest request, ServletResponse response) {
captureAttributes(request);
super.forward(request, response);
}
private void captureAttributes(ServletRequest request) {
Enumeration<String> names = request.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
this.requestAttributes.put(name, request.getAttribute(name));
}
}
private Map<String, Object> getRequestAttributes() {
return this.requestAttributes;
}
}
}
} }

Loading…
Cancel
Save