diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfiguration.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfiguration.java new file mode 100644 index 0000000000..b87e997e43 --- /dev/null +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfiguration.java @@ -0,0 +1,43 @@ +/* + * Copyright 2012-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.bootstrap.autoconfigure.web; + +import javax.servlet.MultipartConfigElement; + +import org.springframework.bootstrap.context.annotation.ConditionalOnBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.multipart.support.StandardServletMultipartResolver; + +/** + * Autoconfiguration for multipart uploads. It detects the existence of a + * {@link MultipartConfigElement} in the app context and then adds critical beans + * while also autowiring it into the Jetty/Tomcat embedded containers. + * + * @author Greg Turnquist + * + */ +@Configuration +public class MultipartAutoConfiguration { + + @ConditionalOnBean(MultipartConfigElement.class) + @Bean + public StandardServletMultipartResolver multipartResolver() { + System.out.println("Loading up a MultipartResolver!!!"); + return new StandardServletMultipartResolver(); + } + +} diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedServletContainerFactory.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedServletContainerFactory.java index 508983f713..b0c303930d 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedServletContainerFactory.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedServletContainerFactory.java @@ -41,6 +41,6 @@ public interface EmbeddedServletContainerFactory { * @see EmbeddedServletContainer#stop() */ EmbeddedServletContainer getEmbdeddedServletContainer( - ServletContextInitializer... initializers); + ServletContextInitializer... initializers); //TODO(6/14/2013) Fix name of method } diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedWebApplicationContext.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedWebApplicationContext.java index c7c0d65660..33131fe6d7 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedWebApplicationContext.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/EmbeddedWebApplicationContext.java @@ -23,10 +23,12 @@ import java.util.Comparator; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Set; import javax.servlet.Filter; +import javax.servlet.MultipartConfigElement; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletContext; @@ -208,10 +210,17 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext } initializers.add(initializer); } + + Map multipartConfigBeans; + MultipartConfigElement multipartConfigElement = null; + multipartConfigBeans = getBeanFactory().getBeansOfType(MultipartConfigElement.class); + for (MultipartConfigElement bean : multipartConfigBeans.values()) { + multipartConfigElement = bean; + } List> servletBeans = getOrderedBeansOfType(Servlet.class); for (Entry servletBean : servletBeans) { - String name = servletBean.getKey(); + final String name = servletBean.getKey(); Servlet servlet = servletBean.getValue(); if (targets.contains(servlet)) { continue; @@ -220,10 +229,15 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext if (name.equals(DISPATCHER_SERVLET_NAME)) { url = "/"; // always map the main dispatcherServlet to "/" } - ServletRegistrationBean registration = new ServletRegistrationBean(servlet, - url); - registration.setName(name); - initializers.add(registration); + if (multipartConfigElement != null) { + initializers.add(new ServletRegistrationBean(servlet, multipartConfigElement, url) {{ + setName(name); + }}); + } else { + initializers.add(new ServletRegistrationBean(servlet, url) {{ + setName(name); + }}); + } } for (Entry filterBean : getOrderedBeansOfType(Filter.class)) { diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/ServletRegistrationBean.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/ServletRegistrationBean.java index 3cf5a48077..b796bb959c 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/ServletRegistrationBean.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/embedded/ServletRegistrationBean.java @@ -22,6 +22,7 @@ import java.util.LinkedHashSet; import java.util.Set; import javax.servlet.Filter; +import javax.servlet.MultipartConfigElement; import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletException; @@ -55,6 +56,8 @@ public class ServletRegistrationBean extends RegistrationBean { private int loadOnStartup = 1; private Set filters = new LinkedHashSet(); + + private MultipartConfigElement multipartConfigElement = null; /** * Create a new {@link ServletRegistrationBean} instance. @@ -72,6 +75,11 @@ public class ServletRegistrationBean extends RegistrationBean { setServlet(servlet); addUrlMappings(urlMappings); } + + public ServletRegistrationBean(Servlet servlet, MultipartConfigElement multipartConfigElement, String... urlMappings) { + this(servlet, urlMappings); + this.multipartConfigElement = multipartConfigElement; + } /** * Sets the servlet to be registered. @@ -181,5 +189,8 @@ public class ServletRegistrationBean extends RegistrationBean { } registration.addMapping(urlMapping); registration.setLoadOnStartup(this.loadOnStartup); + if (multipartConfigElement != null) { + registration.setMultipartConfig(multipartConfigElement); + } } } diff --git a/spring-bootstrap/src/test/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfigurationTests.java b/spring-bootstrap/src/test/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfigurationTests.java new file mode 100644 index 0000000000..a42734c934 --- /dev/null +++ b/spring-bootstrap/src/test/java/org/springframework/bootstrap/autoconfigure/web/MultipartAutoConfigurationTests.java @@ -0,0 +1,234 @@ +/* + * Copyright 2012-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.bootstrap.autoconfigure.web; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; + +import javax.servlet.MultipartConfigElement; + +import org.junit.Test; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.bootstrap.context.embedded.AnnotationConfigEmbeddedWebApplicationContext; +import org.springframework.bootstrap.context.embedded.jetty.JettyEmbeddedServletContainerFactory; +import org.springframework.bootstrap.context.embedded.tomcat.TomcatEmbeddedServletContainerFactory; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.multipart.MultipartResolver; +import org.springframework.web.multipart.support.StandardServletMultipartResolver; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +/** + * A series of embedded unit tests, based on an empty configuration, no multipart + * configuration, and a multipart configuration, with both Jetty and Tomcat. + * + * @author Greg Turnquist + */ +public class MultipartAutoConfigurationTests { + + private AnnotationConfigEmbeddedWebApplicationContext context; + + @Test + public void containerWithNothing() { + this.context = new AnnotationConfigEmbeddedWebApplicationContext( + ContainerWithNothing.class, + EmbeddedServletContainerAutoConfiguration.class, + MultipartAutoConfiguration.class); + try { + DispatcherServlet servlet = this.context.getBean(DispatcherServlet.class); + assertNull(servlet.getMultipartResolver()); + try { + this.context.getBean(StandardServletMultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + try { + this.context.getBean(MultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + } finally { + this.context.close(); + } + } + + @Configuration + public static class ContainerWithNothing { + } + + @Test + public void containerWithNoMultipartJettyConfiguration() { + this.context = new AnnotationConfigEmbeddedWebApplicationContext( + ContainerWithNoMultipartJetty.class, + EmbeddedServletContainerAutoConfiguration.class, + MultipartAutoConfiguration.class); + try { + DispatcherServlet servlet = this.context.getBean(DispatcherServlet.class); + assertNull(servlet.getMultipartResolver()); + try { + this.context.getBean(StandardServletMultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + try { + this.context.getBean(MultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + verifyServletWorks(); + } finally { + this.context.close(); + } + } + + @Configuration + public static class ContainerWithNoMultipartJetty { + @Bean + JettyEmbeddedServletContainerFactory containerFactory() { + return new JettyEmbeddedServletContainerFactory(); + } + @Bean + WebController controller() { + return new WebController(); + } + } + + @Test + public void containerWithNoMultipartTomcatConfiguration() { + this.context = new AnnotationConfigEmbeddedWebApplicationContext( + ContainerWithNoMultipartTomcat.class, + EmbeddedServletContainerAutoConfiguration.class, + MultipartAutoConfiguration.class); + try { + DispatcherServlet servlet = this.context.getBean(DispatcherServlet.class); + assertNull(servlet.getMultipartResolver()); + try { + this.context.getBean(StandardServletMultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + try { + this.context.getBean(MultipartResolver.class); + fail("Expected to receive a " + NoSuchBeanDefinitionException.class); + } catch (NoSuchBeanDefinitionException e) { + } + verifyServletWorks(); + } finally { + this.context.close(); + } + } + + @Configuration + public static class ContainerWithNoMultipartTomcat { + @Bean + TomcatEmbeddedServletContainerFactory containerFactory() { + return new TomcatEmbeddedServletContainerFactory(); + } + @Bean + WebController controller() { + return new WebController(); + } + } + + @Test + public void containerWithAutomatedMultipartJettyConfiguration() { + this.context = new AnnotationConfigEmbeddedWebApplicationContext( + ContainerWithEverythingJetty.class, + EmbeddedServletContainerAutoConfiguration.class, + MultipartAutoConfiguration.class); + try { + this.context.getBean(MultipartConfigElement.class); + assertSame( + this.context.getBean(DispatcherServlet.class).getMultipartResolver(), + this.context.getBean(StandardServletMultipartResolver.class)); + verifyServletWorks(); + } finally { + this.context.close(); + } + } + + @Configuration + public static class ContainerWithEverythingJetty { + @Bean + MultipartConfigElement multipartConfigElement() { + return new MultipartConfigElement(""); + } + @Bean + JettyEmbeddedServletContainerFactory containerFactory() { + return new JettyEmbeddedServletContainerFactory(); + } + @Bean + WebController webController() { + return new WebController(); + } + } + + @Test + public void containerWithAutomatedMultipartTomcatConfiguration() { + this.context = new AnnotationConfigEmbeddedWebApplicationContext( + ContainerWithEverythingTomcat.class, + EmbeddedServletContainerAutoConfiguration.class, + MultipartAutoConfiguration.class); + try { + this.context.getBean(MultipartConfigElement.class); + assertSame( + this.context.getBean(DispatcherServlet.class).getMultipartResolver(), + this.context.getBean(StandardServletMultipartResolver.class)); + verifyServletWorks(); + } finally { + this.context.close(); + } + } + + @Configuration + @EnableWebMvc + public static class ContainerWithEverythingTomcat { + @Bean + MultipartConfigElement multipartConfigElement() { + return new MultipartConfigElement(""); + } + @Bean + TomcatEmbeddedServletContainerFactory containerFactory() { + return new TomcatEmbeddedServletContainerFactory(); + } + @Bean + WebController webController() { + return new WebController(); + } + } + + @Controller + public static class WebController { + @RequestMapping("/") + public @ResponseBody String index() { + return "Hello"; + } + } + + private void verifyServletWorks() { + RestTemplate restTemplate = new RestTemplate(); + assertEquals(restTemplate.getForObject("http://localhost:8080/", String.class), "Hello"); + } + + +}