@ -1,5 +1,5 @@
/ *
* Copyright 2012 - 201 6 the original author or authors .
* Copyright 2012 - 201 7 the original author or authors .
*
* Licensed under the Apache License , Version 2.0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -17,6 +17,9 @@
package org.springframework.boot.web.support ;
import java.io.IOException ;
import java.util.Enumeration ;
import java.util.HashMap ;
import java.util.Map ;
import javax.servlet.RequestDispatcher ;
import javax.servlet.ServletException ;
@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.mock.web.MockRequestDispatcher ;
import org.springframework.web.context.request.async.DeferredResult ;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest ;
import org.springframework.web.context.request.async.WebAsyncManager ;
@ -57,8 +61,7 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter ( ) ;
private MockHttpServletRequest request = new MockHttpServletRequest ( "GET" ,
"/test/path" ) ;
private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest ( ) ;
private MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
@ -261,8 +264,14 @@ public class ErrorPageFilterTests {
. isEqualTo ( 500 ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_MESSAGE ) )
. isEqualTo ( "BAD" ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
Map < String , Object > requestAttributes = getAttributesForDispatch ( "/500" ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isEqualTo ( RuntimeException . class ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION ) )
. isInstanceOf ( RuntimeException . class ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION ) ) . isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_REQUEST_URI ) )
. isEqualTo ( "/test/path" ) ;
assertThat ( this . response . isCommitted ( ) ) . isTrue ( ) ;
@ -318,8 +327,14 @@ public class ErrorPageFilterTests {
. isEqualTo ( 500 ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_MESSAGE ) )
. isEqualTo ( "BAD" ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
Map < String , Object > requestAttributes = getAttributesForDispatch ( "/500" ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isEqualTo ( IllegalStateException . class ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION ) )
. isInstanceOf ( IllegalStateException . class ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION ) ) . isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_REQUEST_URI ) )
. isEqualTo ( "/test/path" ) ;
assertThat ( this . response . isCommitted ( ) ) . isTrue ( ) ;
@ -492,8 +507,14 @@ public class ErrorPageFilterTests {
. isEqualTo ( 500 ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_MESSAGE ) )
. isEqualTo ( "BAD" ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
Map < String , Object > requestAttributes = getAttributesForDispatch ( "/500" ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isEqualTo ( RuntimeException . class ) ;
assertThat ( requestAttributes . get ( RequestDispatcher . ERROR_EXCEPTION ) )
. isInstanceOf ( RuntimeException . class ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION_TYPE ) )
. isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_EXCEPTION ) ) . isNull ( ) ;
assertThat ( this . request . getAttribute ( RequestDispatcher . ERROR_REQUEST_URI ) )
. isEqualTo ( "/test/path" ) ;
assertThat ( this . response . isCommitted ( ) ) . isTrue ( ) ;
@ -510,4 +531,60 @@ public class ErrorPageFilterTests {
asyncManager . startDeferredResultProcessing ( result ) ;
}
private Map < String , Object > getAttributesForDispatch ( String path ) {
return this . request . getDispatcher ( path ) . getRequestAttributes ( ) ;
}
private static final class DispatchRecordingMockHttpServletRequest
extends MockHttpServletRequest {
private final Map < String , AttributeCapturingRequestDispatcher > dispatchers = new HashMap < String , AttributeCapturingRequestDispatcher > ( ) ;
private DispatchRecordingMockHttpServletRequest ( ) {
super ( "GET" , "/test/path" ) ;
}
@Override
public RequestDispatcher getRequestDispatcher ( String path ) {
AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher (
path ) ;
this . dispatchers . put ( path , dispatcher ) ;
return dispatcher ;
}
private AttributeCapturingRequestDispatcher getDispatcher ( String path ) {
return this . dispatchers . get ( path ) ;
}
private static final class AttributeCapturingRequestDispatcher
extends MockRequestDispatcher {
private final Map < String , Object > requestAttributes = new HashMap < String , Object > ( ) ;
private AttributeCapturingRequestDispatcher ( String resource ) {
super ( resource ) ;
}
@Override
public void forward ( ServletRequest request , ServletResponse response ) {
captureAttributes ( request ) ;
super . forward ( request , response ) ;
}
private void captureAttributes ( ServletRequest request ) {
Enumeration < String > names = request . getAttributeNames ( ) ;
while ( names . hasMoreElements ( ) ) {
String name = names . nextElement ( ) ;
this . requestAttributes . put ( name , request . getAttribute ( name ) ) ;
}
}
private Map < String , Object > getRequestAttributes ( ) {
return this . requestAttributes ;
}
}
}
}