Remove servet/filter class tangle

Remove class tangle between ServletRegistrationBean and
FilterRegistrationBean.
pull/7/head
Phillip Webb 12 years ago
parent beef5ab177
commit 4e15d705aa

@ -206,9 +206,6 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
if (initializer instanceof RegistrationBean) { if (initializer instanceof RegistrationBean) {
targets.add(((RegistrationBean) initializer).getRegistrationTarget()); targets.add(((RegistrationBean) initializer).getRegistrationTarget());
} }
if (initializer instanceof ServletRegistrationBean) {
targets.addAll(((ServletRegistrationBean) initializer).getFilters());
}
initializers.add(initializer); initializers.add(initializer);
} }

@ -21,7 +21,6 @@ import java.util.Collection;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.MultipartConfigElement; import javax.servlet.MultipartConfigElement;
import javax.servlet.Servlet; import javax.servlet.Servlet;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
@ -55,8 +54,6 @@ public class ServletRegistrationBean extends RegistrationBean {
private int loadOnStartup = 1; private int loadOnStartup = 1;
private Set<Filter> filters = new LinkedHashSet<Filter>();
private MultipartConfigElement multipartConfig; private MultipartConfigElement multipartConfig;
/** /**
@ -121,32 +118,6 @@ public class ServletRegistrationBean extends RegistrationBean {
this.loadOnStartup = loadOnStartup; this.loadOnStartup = loadOnStartup;
} }
/**
* Sets any Filters that should be registered to this servlet. Any previously
* specified Filters will be replaced.
* @param filters the Filters to set
*/
public void setFilters(Collection<? extends Filter> filters) {
Assert.notNull(filters, "Filters must not be null");
this.filters = new LinkedHashSet<Filter>(filters);
}
/**
* Returns a mutable collection of the Filters being registered with this servlet.
*/
public Collection<Filter> getFilters() {
return this.filters;
}
/**
* Add Filters that will be registered with this servlet.
* @param filters the Filters to add
*/
public void addFilters(Filter... filters) {
Assert.notNull(filters, "Filters must not be null");
this.filters.addAll(Arrays.asList(filters));
}
/** /**
* Set the the {@link MultipartConfigElement multi-part configuration}. * Set the the {@link MultipartConfigElement multi-part configuration}.
* @param multipartConfig the muti-part configuration to set or {@code null} * @param multipartConfig the muti-part configuration to set or {@code null}
@ -179,12 +150,6 @@ public class ServletRegistrationBean extends RegistrationBean {
public void onStartup(ServletContext servletContext) throws ServletException { public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.servlet, "Servlet must not be null"); Assert.notNull(this.servlet, "Servlet must not be null");
configure(servletContext.addServlet(getServletName(), this.servlet)); configure(servletContext.addServlet(getServletName(), this.servlet));
for (Filter filter : this.filters) {
FilterRegistrationBean filterRegistration = new FilterRegistrationBean(
filter, this);
filterRegistration.setAsyncSupported(isAsyncSupported());
filterRegistration.onStartup(servletContext);
}
} }
/** /**

@ -40,10 +40,6 @@ import org.springframework.core.Ordered;
import org.springframework.web.context.ServletContextAware; import org.springframework.web.context.ServletContextAware;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.SessionScope; import org.springframework.web.context.request.SessionScope;
import org.springframework.zero.context.embedded.EmbeddedWebApplicationContext;
import org.springframework.zero.context.embedded.FilterRegistrationBean;
import org.springframework.zero.context.embedded.ServletContextInitializer;
import org.springframework.zero.context.embedded.ServletRegistrationBean;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
@ -310,7 +306,6 @@ public class EmbeddedWebApplicationContextTests {
Servlet servlet = mock(Servlet.class); Servlet servlet = mock(Servlet.class);
Filter filter = mock(Filter.class); Filter filter = mock(Filter.class);
ServletRegistrationBean initializer = new ServletRegistrationBean(servlet, "/foo"); ServletRegistrationBean initializer = new ServletRegistrationBean(servlet, "/foo");
initializer.addFilters(filter);
this.context.registerBeanDefinition("initializerBean", this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer)); beanDefinition(initializer));
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet)); this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));

@ -18,12 +18,10 @@ package org.springframework.zero.context.embedded;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import javax.servlet.DispatcherType;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterRegistration; import javax.servlet.FilterRegistration;
import javax.servlet.Servlet; import javax.servlet.Servlet;
@ -36,11 +34,11 @@ import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.springframework.zero.context.embedded.ServletRegistrationBean;
import static org.mockito.BDDMockito.*; import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.*; import static org.mockito.Matchers.anyObject;
import static org.mockito.Mockito.*; import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link ServletRegistrationBean}. * Tests for {@link ServletRegistrationBean}.
@ -123,38 +121,38 @@ public class ServletRegistrationBeanTests {
@Test @Test
public void setServletMustNotBeNull() throws Exception { public void setServletMustNotBeNull() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(); ServletRegistrationBean bean = new ServletRegistrationBean();
thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("Servlet must not be null"); this.thrown.expectMessage("Servlet must not be null");
bean.onStartup(this.servletContext); bean.onStartup(this.servletContext);
} }
@Test @Test
public void createServletMustNotBeNull() throws Exception { public void createServletMustNotBeNull() throws Exception {
thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("Servlet must not be null"); this.thrown.expectMessage("Servlet must not be null");
new ServletRegistrationBean(null); new ServletRegistrationBean(null);
} }
@Test @Test
public void setMappingMustNotBeNull() throws Exception { public void setMappingMustNotBeNull() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet); ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet);
thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("UrlMappings must not be null"); this.thrown.expectMessage("UrlMappings must not be null");
bean.setUrlMappings(null); bean.setUrlMappings(null);
} }
@Test @Test
public void createMappingMustNotBeNull() throws Exception { public void createMappingMustNotBeNull() throws Exception {
thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("UrlMappings must not be null"); this.thrown.expectMessage("UrlMappings must not be null");
new ServletRegistrationBean(this.servlet, (String[]) null); new ServletRegistrationBean(this.servlet, (String[]) null);
} }
@Test @Test
public void addMappingMustNotBeNull() throws Exception { public void addMappingMustNotBeNull() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet); ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet);
thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("UrlMappings must not be null"); this.thrown.expectMessage("UrlMappings must not be null");
bean.addUrlMappings((String[]) null); bean.addUrlMappings((String[]) null);
} }
@ -177,31 +175,4 @@ public class ServletRegistrationBeanTests {
verify(this.registration).setInitParameters(Collections.singletonMap("a", "c")); verify(this.registration).setInitParameters(Collections.singletonMap("a", "c"));
} }
@Test
public void filters() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet);
Filter filter = new MockFilter();
bean.addFilters(filter);
bean.onStartup(this.servletContext);
verify(servletContext).addFilter("mockFilter", filter);
verify(filterRegistration).setAsyncSupported(true);
verify(filterRegistration).addMappingForServletNames(
EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD,
DispatcherType.INCLUDE, DispatcherType.ASYNC), false,
"mockServlet");
}
@Test
public void filtersNoAsync() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet);
Filter filter = new MockFilter();
bean.addFilters(filter);
bean.setAsyncSupported(false);
bean.onStartup(this.servletContext);
verify(servletContext).addFilter("mockFilter", filter);
verify(filterRegistration).setAsyncSupported(false);
verify(filterRegistration).addMappingForServletNames(
EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD,
DispatcherType.INCLUDE), false, "mockServlet");
}
} }

Loading…
Cancel
Save