Order Filters, Servlets etc. separately in EmbeddedWebApplicationContext

Users could be surpised if they register a Filter with an @Order and it
isn't apparently respected. This change accumulates all Filters and
FilterRegistrations (for instance) before sorting them.

Fixes gh-1455
pull/1487/merge
Dave Syer 10 years ago
parent 1ddcf3657b
commit 47b59046bd

@ -225,27 +225,35 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
*/
protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>();
List<ServletContextInitializer> filters = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> servlets = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> listeners = new ArrayList<ServletContextInitializer>();
List<ServletContextInitializer> other = new ArrayList<ServletContextInitializer>();
Set<Servlet> servletRegistrations = new LinkedHashSet<Servlet>();
Set<Filter> filterRegistrations = new LinkedHashSet<Filter>();
Set<EventListener> listenerRegistrations = new LinkedHashSet<EventListener>();
for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) {
ServletContextInitializer initializer = initializerBean.getValue();
initializers.add(initializer);
if (initializer instanceof ServletRegistrationBean) {
servlets.add(initializer);
ServletRegistrationBean servlet = (ServletRegistrationBean) initializer;
servletRegistrations.add(servlet.getServlet());
}
if (initializer instanceof FilterRegistrationBean) {
else if (initializer instanceof FilterRegistrationBean) {
filters.add(initializer);
FilterRegistrationBean filter = (FilterRegistrationBean) initializer;
filterRegistrations.add(filter.getFilter());
}
if (initializer instanceof ServletListenerRegistrationBean) {
else if (initializer instanceof ServletListenerRegistrationBean) {
listeners.add(initializer);
listenerRegistrations
.add(((ServletListenerRegistrationBean<?>) initializer)
.getListener());
}
else {
other.add(initializer);
}
}
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
@ -261,7 +269,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
servlet, url);
registration.setName(name);
registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration);
registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(servlet));
servlets.add(registration);
}
}
@ -271,7 +281,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
if (!filterRegistrations.contains(filter)) {
FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(filter));
filters.add(registration);
}
}
@ -285,12 +297,23 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
ServletListenerRegistrationBean<EventListener> registration = new ServletListenerRegistrationBean<EventListener>(
listener);
registration.setName(name);
initializers.add(registration);
registration.setOrder(CustomOrderAwareComparator.INSTANCE
.getOrder(listener));
listeners.add(registration);
}
}
}
AnnotationAwareOrderComparator.sort(filters);
AnnotationAwareOrderComparator.sort(servlets);
AnnotationAwareOrderComparator.sort(listeners);
AnnotationAwareOrderComparator.sort(other);
return initializers;
List<ServletContextInitializer> list = new ArrayList<ServletContextInitializer>(
filters);
list.addAll(servlets);
list.addAll(listeners);
list.addAll(other);
return list;
}
private MultipartConfigElement getMultipartConfig() {
@ -425,4 +448,15 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
return this.embeddedServletContainer;
}
private static class CustomOrderAwareComparator extends
AnnotationAwareOrderComparator {
public static CustomOrderAwareComparator INSTANCE = new CustomOrderAwareComparator();
@Override
protected int getOrder(Object obj) {
return super.getOrder(obj);
}
}
}

@ -16,13 +16,18 @@
package org.springframework.boot.context.embedded;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Properties;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextListener;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import org.junit.After;
import org.junit.Before;
@ -30,6 +35,7 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.InOrder;
import org.mockito.Mockito;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConstructorArgumentValues;
@ -39,9 +45,11 @@ import org.springframework.context.ApplicationListener;
import org.springframework.context.support.AbstractApplicationContext;
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.web.context.ServletContextAware;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.SessionScope;
import org.springframework.web.filter.GenericFilterBean;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
@ -192,6 +200,23 @@ public class EmbeddedWebApplicationContextTests {
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/");
}
@Test
public void orderedBeanInsertedCorrectly() throws Exception {
addEmbeddedServletContainerFactoryBean();
OrderedFilter filter = new OrderedFilter();
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
FilterRegistrationBean registration = new FilterRegistrationBean();
registration.setFilter(Mockito.mock(Filter.class));
registration.setOrder(100);
this.context.registerBeanDefinition("filterRegistrationBean",
beanDefinition(registration));
this.context.refresh();
MockEmbeddedServletContainerFactory escf = getEmbeddedServletContainerFactory();
verify(escf.getServletContext()).addFilter("filterBean", filter);
verify(escf.getServletContext()).addFilter("object", registration.getFilter());
assertEquals(filter, escf.getRegisteredFilter(0).getFilter());
}
@Test
public void multipleServletBeans() throws Exception {
addEmbeddedServletContainerFactoryBean();
@ -422,4 +447,14 @@ public class EmbeddedWebApplicationContextTests {
}
}
@Order(10)
protected static class OrderedFilter extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
}
}
}

Loading…
Cancel
Save