Order Filters, Servlets etc. separately in EmbeddedWebApplicationContext

Users could be surpised if they register a Filter with an @Order and it
isn't apparently respected. This change accumulates all Filters and
FilterRegistrations (for instance) before sorting them.

Fixes gh-1455
pull/1487/merge
Dave Syer 10 years ago
parent 1ddcf3657b
commit 47b59046bd

@ -225,27 +225,35 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
*/ */
protected Collection<ServletContextInitializer> getServletContextInitializerBeans() { protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>(); List<ServletContextInitializer> filters = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> servlets = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> listeners = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> other = new ArrayList<ServletContextInitializer>();
Set<Servlet> servletRegistrations = new LinkedHashSet<Servlet>(); Set<Servlet> servletRegistrations = new LinkedHashSet<Servlet>();
Set<Filter> filterRegistrations = new LinkedHashSet<Filter>(); Set<Filter> filterRegistrations = new LinkedHashSet<Filter>();
Set<EventListener> listenerRegistrations = new LinkedHashSet<EventListener>(); Set<EventListener> listenerRegistrations = new LinkedHashSet<EventListener>();
for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) { for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) {
ServletContextInitializer initializer = initializerBean.getValue(); ServletContextInitializer initializer = initializerBean.getValue();
initializers.add(initializer);
if (initializer instanceof ServletRegistrationBean) { if (initializer instanceof ServletRegistrationBean) {
servlets.add(initializer);
ServletRegistrationBean servlet = (ServletRegistrationBean) initializer; ServletRegistrationBean servlet = (ServletRegistrationBean) initializer;
servletRegistrations.add(servlet.getServlet()); servletRegistrations.add(servlet.getServlet());
} }
if (initializer instanceof FilterRegistrationBean) { else if (initializer instanceof FilterRegistrationBean) {
filters.add(initializer);
FilterRegistrationBean filter = (FilterRegistrationBean) initializer; FilterRegistrationBean filter = (FilterRegistrationBean) initializer;
filterRegistrations.add(filter.getFilter()); filterRegistrations.add(filter.getFilter());
} }
if (initializer instanceof ServletListenerRegistrationBean) { else if (initializer instanceof ServletListenerRegistrationBean) {
listeners.add(initializer);
listenerRegistrations listenerRegistrations
.add(((ServletListenerRegistrationBean<?>) initializer) .add(((ServletListenerRegistrationBean<?>) initializer)
.getListener()); .getListener());
} }
else {
other.add(initializer);
}
} }
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class); List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
@ -261,7 +269,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
servlet, url); servlet, url);
registration.setName(name); registration.setName(name);
registration.setMultipartConfig(getMultipartConfig()); registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration); registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(servlet));
servlets.add(registration);
} }
} }
@ -271,7 +281,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
if (!filterRegistrations.contains(filter)) { if (!filterRegistrations.contains(filter)) {
FilterRegistrationBean registration = new FilterRegistrationBean(filter); FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name); registration.setName(name);
initializers.add(registration); registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(filter));
filters.add(registration);
} }
} }
@ -285,12 +297,23 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
ServletListenerRegistrationBean<EventListener> registration = new ServletListenerRegistrationBean<EventListener>( ServletListenerRegistrationBean<EventListener> registration = new ServletListenerRegistrationBean<EventListener>(
listener); listener);
registration.setName(name); registration.setName(name);
initializers.add(registration); registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(listener));
listeners.add(registration);
} }
} }
} }
AnnotationAwareOrderComparator.sort(filters);
AnnotationAwareOrderComparator.sort(servlets);
AnnotationAwareOrderComparator.sort(listeners);
AnnotationAwareOrderComparator.sort(other);
return initializers; List<ServletContextInitializer> list = new ArrayList<ServletContextInitializer>(
filters);
list.addAll(servlets);
list.addAll(listeners);
list.addAll(other);
return list;
} }
private MultipartConfigElement getMultipartConfig() { private MultipartConfigElement getMultipartConfig() {
@ -425,4 +448,15 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
return this.embeddedServletContainer; return this.embeddedServletContainer;
} }
private static class CustomOrderAwareComparator extends
AnnotationAwareOrderComparator {
public static CustomOrderAwareComparator INSTANCE = new CustomOrderAwareComparator();
@Override
protected int getOrder(Object obj) {
return super.getOrder(obj);
}
}
} }

@ -16,13 +16,18 @@
package org.springframework.boot.context.embedded; package org.springframework.boot.context.embedded;
import java.io.IOException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.Properties; import java.util.Properties;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.Servlet; import javax.servlet.Servlet;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletContextListener; import javax.servlet.ServletContextListener;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -30,6 +35,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.mockito.InOrder; import org.mockito.InOrder;
import org.mockito.Mockito;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConstructorArgumentValues; import org.springframework.beans.factory.config.ConstructorArgumentValues;
@ -39,9 +45,11 @@ import org.springframework.context.ApplicationListener;
import org.springframework.context.support.AbstractApplicationContext; import org.springframework.context.support.AbstractApplicationContext;
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.web.context.ServletContextAware; import org.springframework.web.context.ServletContextAware;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.SessionScope; import org.springframework.web.context.request.SessionScope;
import org.springframework.web.filter.GenericFilterBean;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
@ -192,6 +200,23 @@ public class EmbeddedWebApplicationContextTests {
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/"); verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/");
} }
@Test
public void orderedBeanInsertedCorrectly() throws Exception {
addEmbeddedServletContainerFactoryBean();
OrderedFilter filter = new OrderedFilter();
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
FilterRegistrationBean registration = new FilterRegistrationBean();
registration.setFilter(Mockito.mock(Filter.class));
registration.setOrder(100);
this.context.registerBeanDefinition("filterRegistrationBean",
beanDefinition(registration));
this.context.refresh();
MockEmbeddedServletContainerFactory escf = getEmbeddedServletContainerFactory();
verify(escf.getServletContext()).addFilter("filterBean", filter);
verify(escf.getServletContext()).addFilter("object", registration.getFilter());
assertEquals(filter, escf.getRegisteredFilter(0).getFilter());
}
@Test @Test
public void multipleServletBeans() throws Exception { public void multipleServletBeans() throws Exception {
addEmbeddedServletContainerFactoryBean(); addEmbeddedServletContainerFactoryBean();
@ -422,4 +447,14 @@ public class EmbeddedWebApplicationContextTests {
} }
} }
@Order(10)
protected static class OrderedFilter extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
}
}
} }

Loading…
Cancel
Save