Handle servlet startup failures consistently

Ensure that all servlet containers handle servlet startup failures
consistently and throw a `WebServerException` that wraps the original
cause.

Both Undertow and Jetty already dealt with startup failures in this
way, but Tomcat did not. The `TomcatEmbeddedContext` has now been
changed to no longer call `super.loadOnStartup` but instead re-implement
a version of that method that wraps and rethrows the original exception
(as long as `failCtxIfServletStartFails` is `true`, which it now is by
default).

Closes gh-14790
pull/14228/merge
Phillip Webb 6 years ago
parent 4823114e1c
commit 683e9532d6

@ -16,13 +16,23 @@
package org.springframework.boot.web.embedded.tomcat; package org.springframework.boot.web.embedded.tomcat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Stream;
import javax.servlet.ServletException;
import org.apache.catalina.Container; import org.apache.catalina.Container;
import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleException;
import org.apache.catalina.Manager; import org.apache.catalina.Manager;
import org.apache.catalina.Wrapper;
import org.apache.catalina.core.StandardContext; import org.apache.catalina.core.StandardContext;
import org.apache.catalina.core.StandardWrapper;
import org.apache.catalina.session.ManagerBase; import org.apache.catalina.session.ManagerBase;
import org.springframework.util.Assert; import org.springframework.boot.web.server.WebServerException;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
/** /**
@ -52,11 +62,37 @@ class TomcatEmbeddedContext extends StandardContext {
public void deferredLoadOnStartup() throws LifecycleException { public void deferredLoadOnStartup() throws LifecycleException {
doWithThreadContextClassLoader(getLoader().getClassLoader(), () -> { doWithThreadContextClassLoader(getLoader().getClassLoader(), () -> {
boolean started = super.loadOnStartup(findChildren()); getLoadOnStartupWrappers(findChildren()).forEach(this::load);
Assert.state(started, "Unable to start embedded tomcat context " + getName());
}); });
} }
private Stream<Wrapper> getLoadOnStartupWrappers(Container[] children) {
Map<Integer, List<Wrapper>> grouped = new TreeMap<>();
for (Container child : children) {
Wrapper wrapper = (Wrapper) child;
int order = wrapper.getLoadOnStartup();
if (order >= 0) {
grouped.computeIfAbsent(order, ArrayList::new);
grouped.get(order).add(wrapper);
}
}
return grouped.values().stream().flatMap(List::stream);
}
private void load(Wrapper wrapper) {
try {
wrapper.load();
}
catch (ServletException ex) {
String message = sm.getString("standardContext.loadOnStartup.loadException",
getName(), wrapper.getName());
if (getComputedFailCtxIfServletStartFails()) {
throw new WebServerException(message, ex);
}
getLogger().error(message, StandardWrapper.getRootCause(ex));
}
}
/** /**
* Some older Servlet frameworks (e.g. Struts, BIRT) use the Thread context class * Some older Servlet frameworks (e.g. Struts, BIRT) use the Thread context class
* loader to create servlet instances in this phase. If they do that and then try to * loader to create servlet instances in this phase. If they do that and then try to

@ -329,8 +329,9 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
ServletContextInitializer[] initializers) { ServletContextInitializer[] initializers) {
TomcatStarter starter = new TomcatStarter(initializers); TomcatStarter starter = new TomcatStarter(initializers);
if (context instanceof TomcatEmbeddedContext) { if (context instanceof TomcatEmbeddedContext) {
// Should be true TomcatEmbeddedContext embeddedContext = (TomcatEmbeddedContext) context;
((TomcatEmbeddedContext) context).setStarter(starter); embeddedContext.setStarter(starter);
embeddedContext.setFailCtxIfServletStartFails(true);
} }
context.addServletContainerInitializer(starter, NO_CLASSES); context.addServletContainerInitializer(starter, NO_CLASSES);
for (LifecycleListener lifecycleListener : this.contextLifecycleListeners) { for (LifecycleListener lifecycleListener : this.contextLifecycleListeners) {

@ -284,7 +284,9 @@ public class TomcatWebServer implements WebServer {
} }
} }
catch (Exception ex) { catch (Exception ex) {
logger.error("Cannot start connector: ", ex); if (ex instanceof WebServerException) {
throw (WebServerException) ex;
}
throw new WebServerException("Unable to start embedded Tomcat connectors", throw new WebServerException("Unable to start embedded Tomcat connectors",
ex); ex);
} }

@ -30,7 +30,6 @@ import java.util.Set;
import javax.naming.InitialContext; import javax.naming.InitialContext;
import javax.naming.NamingException; import javax.naming.NamingException;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import org.apache.catalina.Container; import org.apache.catalina.Container;
import org.apache.catalina.Context; import org.apache.catalina.Context;
@ -440,6 +439,20 @@ public class TomcatServletWebServerFactoryTests
.isThrownBy(this.webServer::start); .isThrownBy(this.webServer::start);
} }
@Test
public void exceptionThrownOnLoadFailureWhenFailCtxIfServletStartFailsIsFalse() {
TomcatServletWebServerFactory factory = getFactory();
factory.addContextCustomizers((context) -> {
if (context instanceof StandardContext) {
((StandardContext) context).setFailCtxIfServletStartFails(false);
}
});
this.webServer = factory.getWebServer((context) -> {
context.addServlet("failing", FailingServlet.class).setLoadOnStartup(0);
});
this.webServer.start();
}
@Override @Override
protected JspServlet getJspServlet() throws ServletException { protected JspServlet getJspServlet() throws ServletException {
Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat(); Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat();
@ -488,13 +501,4 @@ public class TomcatServletWebServerFactoryTests
assertThat(((ConnectorStartFailedException) ex).getPort()).isEqualTo(blockedPort); assertThat(((ConnectorStartFailedException) ex).getPort()).isEqualTo(blockedPort);
} }
static class FailingServlet extends HttpServlet {
@Override
public void init() throws ServletException {
throw new RuntimeException("Init Failure");
}
}
} }

@ -1046,6 +1046,28 @@ public abstract class AbstractServletWebServerFactoryTests {
verify(listener).contextDestroyed(any(ServletContextEvent.class)); verify(listener).contextDestroyed(any(ServletContextEvent.class));
} }
@Test
public void exceptionThrownOnLoadFailureIsRethrown() {
AbstractServletWebServerFactory factory = getFactory();
this.webServer = factory.getWebServer((context) -> {
context.addServlet("failing", FailingServlet.class).setLoadOnStartup(0);
});
assertThatExceptionOfType(WebServerException.class)
.isThrownBy(this.webServer::start)
.satisfies(this::wrapsFailingServletException);
}
private void wrapsFailingServletException(WebServerException ex) {
Throwable cause = ex.getCause();
while (cause != null) {
if (cause instanceof FailingServletException) {
return;
}
cause = cause.getCause();
}
fail("Exception did not wrap FailingServletException");
}
protected abstract void addConnector(int port, protected abstract void addConnector(int port,
AbstractServletWebServerFactory factory); AbstractServletWebServerFactory factory);
@ -1344,4 +1366,21 @@ public abstract class AbstractServletWebServerFactoryTests {
} }
public static class FailingServlet extends HttpServlet {
@Override
public void init() throws ServletException {
throw new FailingServletException();
}
}
private static class FailingServletException extends RuntimeException {
FailingServletException() {
super("Init Failure");
}
}
} }

Loading…
Cancel
Save