@ -225,27 +225,35 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
* /
protected Collection < ServletContextInitializer > getServletContextInitializerBeans ( ) {
Set < ServletContextInitializer > initializers = new LinkedHashSet < ServletContextInitializer > ( ) ;
List < ServletContextInitializer > filters = new ArrayList < ServletContextInitializer > ( ) ;
List < ServletContextInitializer > servlets = new ArrayList < ServletContextInitializer > ( ) ;
List < ServletContextInitializer > listeners = new ArrayList < ServletContextInitializer > ( ) ;
List < ServletContextInitializer > other = new ArrayList < ServletContextInitializer > ( ) ;
Set < Servlet > servletRegistrations = new LinkedHashSet < Servlet > ( ) ;
Set < Filter > filterRegistrations = new LinkedHashSet < Filter > ( ) ;
Set < EventListener > listenerRegistrations = new LinkedHashSet < EventListener > ( ) ;
for ( Entry < String , ServletContextInitializer > initializerBean : getOrderedBeansOfType ( ServletContextInitializer . class ) ) {
ServletContextInitializer initializer = initializerBean . getValue ( ) ;
initializers . add ( initializer ) ;
if ( initializer instanceof ServletRegistrationBean ) {
servlets . add ( initializer ) ;
ServletRegistrationBean servlet = ( ServletRegistrationBean ) initializer ;
servletRegistrations . add ( servlet . getServlet ( ) ) ;
}
if ( initializer instanceof FilterRegistrationBean ) {
else if ( initializer instanceof FilterRegistrationBean ) {
filters . add ( initializer ) ;
FilterRegistrationBean filter = ( FilterRegistrationBean ) initializer ;
filterRegistrations . add ( filter . getFilter ( ) ) ;
}
if ( initializer instanceof ServletListenerRegistrationBean ) {
else if ( initializer instanceof ServletListenerRegistrationBean ) {
listeners . add ( initializer ) ;
listenerRegistrations
. add ( ( ( ServletListenerRegistrationBean < ? > ) initializer )
. getListener ( ) ) ;
}
else {
other . add ( initializer ) ;
}
}
List < Entry < String , Servlet > > servletBeans = getOrderedBeansOfType ( Servlet . class ) ;
@ -261,7 +269,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
servlet , url ) ;
registration . setName ( name ) ;
registration . setMultipartConfig ( getMultipartConfig ( ) ) ;
initializers . add ( registration ) ;
registration . setOrder ( CustomOrderAwareComparator . INSTANCE
. getOrder ( servlet ) ) ;
servlets . add ( registration ) ;
}
}
@ -271,7 +281,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
if ( ! filterRegistrations . contains ( filter ) ) {
FilterRegistrationBean registration = new FilterRegistrationBean ( filter ) ;
registration . setName ( name ) ;
initializers . add ( registration ) ;
registration . setOrder ( CustomOrderAwareComparator . INSTANCE
. getOrder ( filter ) ) ;
filters . add ( registration ) ;
}
}
@ -285,12 +297,23 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
ServletListenerRegistrationBean < EventListener > registration = new ServletListenerRegistrationBean < EventListener > (
listener ) ;
registration . setName ( name ) ;
initializers . add ( registration ) ;
registration . setOrder ( CustomOrderAwareComparator . INSTANCE
. getOrder ( listener ) ) ;
listeners . add ( registration ) ;
}
}
}
AnnotationAwareOrderComparator . sort ( filters ) ;
AnnotationAwareOrderComparator . sort ( servlets ) ;
AnnotationAwareOrderComparator . sort ( listeners ) ;
AnnotationAwareOrderComparator . sort ( other ) ;
return initializers ;
List < ServletContextInitializer > list = new ArrayList < ServletContextInitializer > (
filters ) ;
list . addAll ( servlets ) ;
list . addAll ( listeners ) ;
list . addAll ( other ) ;
return list ;
}
private MultipartConfigElement getMultipartConfig ( ) {
@ -425,4 +448,15 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
return this . embeddedServletContainer ;
}
private static class CustomOrderAwareComparator extends
AnnotationAwareOrderComparator {
public static CustomOrderAwareComparator INSTANCE = new CustomOrderAwareComparator ( ) ;
@Override
protected int getOrder ( Object obj ) {
return super . getOrder ( obj ) ;
}
}
}