diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatEmbeddedContext.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatEmbeddedContext.java index 9384c259fc..1cb2008928 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatEmbeddedContext.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatEmbeddedContext.java @@ -16,13 +16,23 @@ package org.springframework.boot.web.embedded.tomcat; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Stream; + +import javax.servlet.ServletException; + import org.apache.catalina.Container; import org.apache.catalina.LifecycleException; import org.apache.catalina.Manager; +import org.apache.catalina.Wrapper; import org.apache.catalina.core.StandardContext; +import org.apache.catalina.core.StandardWrapper; import org.apache.catalina.session.ManagerBase; -import org.springframework.util.Assert; +import org.springframework.boot.web.server.WebServerException; import org.springframework.util.ClassUtils; /** @@ -52,11 +62,37 @@ class TomcatEmbeddedContext extends StandardContext { public void deferredLoadOnStartup() throws LifecycleException { doWithThreadContextClassLoader(getLoader().getClassLoader(), () -> { - boolean started = super.loadOnStartup(findChildren()); - Assert.state(started, "Unable to start embedded tomcat context " + getName()); + getLoadOnStartupWrappers(findChildren()).forEach(this::load); }); } + private Stream getLoadOnStartupWrappers(Container[] children) { + Map> grouped = new TreeMap<>(); + for (Container child : children) { + Wrapper wrapper = (Wrapper) child; + int order = wrapper.getLoadOnStartup(); + if (order >= 0) { + grouped.computeIfAbsent(order, ArrayList::new); + grouped.get(order).add(wrapper); + } + } + return grouped.values().stream().flatMap(List::stream); + } + + private void load(Wrapper wrapper) { + try { + wrapper.load(); + } + catch (ServletException ex) { + String message = sm.getString("standardContext.loadOnStartup.loadException", + getName(), wrapper.getName()); + if (getComputedFailCtxIfServletStartFails()) { + throw new WebServerException(message, ex); + } + getLogger().error(message, StandardWrapper.getRootCause(ex)); + } + } + /** * Some older Servlet frameworks (e.g. Struts, BIRT) use the Thread context class * loader to create servlet instances in this phase. If they do that and then try to diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactory.java index 646b756f8a..c7e4cebb39 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactory.java @@ -329,8 +329,9 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto ServletContextInitializer[] initializers) { TomcatStarter starter = new TomcatStarter(initializers); if (context instanceof TomcatEmbeddedContext) { - // Should be true - ((TomcatEmbeddedContext) context).setStarter(starter); + TomcatEmbeddedContext embeddedContext = (TomcatEmbeddedContext) context; + embeddedContext.setStarter(starter); + embeddedContext.setFailCtxIfServletStartFails(true); } context.addServletContainerInitializer(starter, NO_CLASSES); for (LifecycleListener lifecycleListener : this.contextLifecycleListeners) { diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatWebServer.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatWebServer.java index 75d67d781e..5629fb8bd4 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatWebServer.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/tomcat/TomcatWebServer.java @@ -284,7 +284,9 @@ public class TomcatWebServer implements WebServer { } } catch (Exception ex) { - logger.error("Cannot start connector: ", ex); + if (ex instanceof WebServerException) { + throw (WebServerException) ex; + } throw new WebServerException("Unable to start embedded Tomcat connectors", ex); } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactoryTests.java index b78488a961..f7a866a4b2 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/tomcat/TomcatServletWebServerFactoryTests.java @@ -30,7 +30,6 @@ import java.util.Set; import javax.naming.InitialContext; import javax.naming.NamingException; import javax.servlet.ServletException; -import javax.servlet.http.HttpServlet; import org.apache.catalina.Container; import org.apache.catalina.Context; @@ -440,6 +439,20 @@ public class TomcatServletWebServerFactoryTests .isThrownBy(this.webServer::start); } + @Test + public void exceptionThrownOnLoadFailureWhenFailCtxIfServletStartFailsIsFalse() { + TomcatServletWebServerFactory factory = getFactory(); + factory.addContextCustomizers((context) -> { + if (context instanceof StandardContext) { + ((StandardContext) context).setFailCtxIfServletStartFails(false); + } + }); + this.webServer = factory.getWebServer((context) -> { + context.addServlet("failing", FailingServlet.class).setLoadOnStartup(0); + }); + this.webServer.start(); + } + @Override protected JspServlet getJspServlet() throws ServletException { Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat(); @@ -488,13 +501,4 @@ public class TomcatServletWebServerFactoryTests assertThat(((ConnectorStartFailedException) ex).getPort()).isEqualTo(blockedPort); } - static class FailingServlet extends HttpServlet { - - @Override - public void init() throws ServletException { - throw new RuntimeException("Init Failure"); - } - - } - } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java index 14307e0083..871c373fc6 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java @@ -1046,6 +1046,28 @@ public abstract class AbstractServletWebServerFactoryTests { verify(listener).contextDestroyed(any(ServletContextEvent.class)); } + @Test + public void exceptionThrownOnLoadFailureIsRethrown() { + AbstractServletWebServerFactory factory = getFactory(); + this.webServer = factory.getWebServer((context) -> { + context.addServlet("failing", FailingServlet.class).setLoadOnStartup(0); + }); + assertThatExceptionOfType(WebServerException.class) + .isThrownBy(this.webServer::start) + .satisfies(this::wrapsFailingServletException); + } + + private void wrapsFailingServletException(WebServerException ex) { + Throwable cause = ex.getCause(); + while (cause != null) { + if (cause instanceof FailingServletException) { + return; + } + cause = cause.getCause(); + } + fail("Exception did not wrap FailingServletException"); + } + protected abstract void addConnector(int port, AbstractServletWebServerFactory factory); @@ -1344,4 +1366,21 @@ public abstract class AbstractServletWebServerFactoryTests { } + public static class FailingServlet extends HttpServlet { + + @Override + public void init() throws ServletException { + throw new FailingServletException(); + } + + } + + private static class FailingServletException extends RuntimeException { + + FailingServletException() { + super("Init Failure"); + } + + } + }