Call context.close() rather than shutdown hook in DevTools restart

Previously, when DevTools was restarting the application it would
use reflection to run all of the JVM's shutdown hooks. This was done
to close any SpringApplications' application contexts. Unfortunately,
it had the unwanted side-effect of running other shutdown hooks as
well.

The other shutdown hooks were often written with the, entirely
reasonable, expectation that they would only be called when the JVM
was shutting down. Calling them at another time could leave the
hook's library in an unexpected state. One such example is Log4J2
which was worked around in aaae4aa3 (see gh-4279). Another is the
problem with Eureka (see gh-4097). There's no work around for this
problem, even with reflective hackery, hence the change being made
here.

This commit updates the Restarter so that shutdown hooks are no longer
called during a restart. This removes the chance of a restart having
the unwanted side-effect of leaving a third-party library in a broken
state. RestartApplicationListener now prepares the Restarter with the
root application context, and the Restarter then closes it as part of
the restart. The changes have been tested with an application that
uses a single context and an application with a context hierarchy.

Closes gh-4097
pull/3499/merge
Andy Wilkinson 9 years ago
parent 6e3faecce6
commit 2522a5f9ef

@ -1,63 +0,0 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.devtools.log4j2;
import java.lang.reflect.Field;
import java.util.Collection;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.core.util.Cancellable;
import org.apache.logging.log4j.core.util.ShutdownCallbackRegistry;
import org.apache.logging.log4j.spi.LoggerContextFactory;
import org.springframework.boot.devtools.restart.RestartListener;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
/**
* {@link RestartListener} that prepares Log4J2 for an application restart.
*
* @author Andy Wilkinson
* @since 1.3.0
*/
public class Log4J2RestartListener implements RestartListener {
@Override
public void beforeRestart() {
if (ClassUtils.isPresent("org.apache.logging.log4j.core.impl.Log4jContextFactory",
getClass().getClassLoader())) {
prepareLog4J2ForRestart();
}
}
private void prepareLog4J2ForRestart() {
LoggerContextFactory factory = LogManager.getFactory();
Field field = ReflectionUtils.findField(factory.getClass(),
"shutdownCallbackRegistry");
ReflectionUtils.makeAccessible(field);
ShutdownCallbackRegistry shutdownCallbackRegistry = (ShutdownCallbackRegistry) ReflectionUtils
.getField(field, factory);
Field hooksField = ReflectionUtils.findField(shutdownCallbackRegistry.getClass(),
"hooks");
ReflectionUtils.makeAccessible(hooksField);
@SuppressWarnings("unchecked")
Collection<Cancellable> hooks = (Collection<Cancellable>) ReflectionUtils
.getField(hooksField, shutdownCallbackRegistry);
hooks.clear();
}
}

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2015 the original author or authors. * Copyright 2012-2016 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,20 +16,19 @@
package org.springframework.boot.devtools.restart; package org.springframework.boot.devtools.restart;
import java.util.List;
import org.springframework.boot.context.event.ApplicationFailedEvent; import org.springframework.boot.context.event.ApplicationFailedEvent;
import org.springframework.boot.context.event.ApplicationPreparedEvent;
import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.boot.context.event.ApplicationStartedEvent; import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.io.support.SpringFactoriesLoader;
/** /**
* {@link ApplicationListener} to initialize the {@link Restarter}. * {@link ApplicationListener} to initialize the {@link Restarter}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson
* @since 1.3.0 * @since 1.3.0
* @see Restarter * @see Restarter
*/ */
@ -45,9 +44,16 @@ public class RestartApplicationListener
if (event instanceof ApplicationStartedEvent) { if (event instanceof ApplicationStartedEvent) {
onApplicationStartedEvent((ApplicationStartedEvent) event); onApplicationStartedEvent((ApplicationStartedEvent) event);
} }
if (event instanceof ApplicationPreparedEvent) {
Restarter.getInstance()
.prepare(((ApplicationPreparedEvent) event).getApplicationContext());
}
if (event instanceof ApplicationReadyEvent if (event instanceof ApplicationReadyEvent
|| event instanceof ApplicationFailedEvent) { || event instanceof ApplicationFailedEvent) {
Restarter.getInstance().finish(); Restarter.getInstance().finish();
if (event instanceof ApplicationFailedEvent) {
Restarter.getInstance().prepare(null);
}
} }
} }
@ -59,11 +65,7 @@ public class RestartApplicationListener
String[] args = event.getArgs(); String[] args = event.getArgs();
DefaultRestartInitializer initializer = new DefaultRestartInitializer(); DefaultRestartInitializer initializer = new DefaultRestartInitializer();
boolean restartOnInitialize = !AgentReloader.isActive(); boolean restartOnInitialize = !AgentReloader.isActive();
List<RestartListener> restartListeners = SpringFactoriesLoader Restarter.initialize(args, false, initializer, restartOnInitialize);
.loadFactories(RestartListener.class, getClass().getClassLoader());
Restarter.initialize(args, false, initializer, restartOnInitialize,
restartListeners
.toArray(new RestartListener[restartListeners.size()]));
} }
else { else {
Restarter.disable(); Restarter.disable();

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2015 the original author or authors. * Copyright 2012-2016 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,12 +19,10 @@ package org.springframework.boot.devtools.restart;
import java.beans.Introspector; import java.beans.Introspector;
import java.lang.Thread.UncaughtExceptionHandler; import java.lang.Thread.UncaughtExceptionHandler;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL; import java.net.URL;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.LinkedList; import java.util.LinkedList;
@ -49,6 +47,7 @@ import org.springframework.boot.devtools.restart.classloader.ClassLoaderFiles;
import org.springframework.boot.devtools.restart.classloader.RestartClassLoader; import org.springframework.boot.devtools.restart.classloader.RestartClassLoader;
import org.springframework.boot.logging.DeferredLog; import org.springframework.boot.logging.DeferredLog;
import org.springframework.cglib.core.ClassNameReader; import org.springframework.cglib.core.ClassNameReader;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -72,6 +71,7 @@ import org.springframework.util.ReflectionUtils;
* URLs or class file updates for remote restart scenarios. * URLs or class file updates for remote restart scenarios.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson
* @since 1.3.0 * @since 1.3.0
* @see RestartApplicationListener * @see RestartApplicationListener
* @see #initialize(String[]) * @see #initialize(String[])
@ -108,11 +108,11 @@ public class Restarter {
private final BlockingDeque<LeakSafeThread> leakSafeThreads = new LinkedBlockingDeque<LeakSafeThread>(); private final BlockingDeque<LeakSafeThread> leakSafeThreads = new LinkedBlockingDeque<LeakSafeThread>();
private final RestartListener[] listeners;
private boolean finished = false; private boolean finished = false;
private Lock stopLock = new ReentrantLock(); private final Lock stopLock = new ReentrantLock();
private volatile ConfigurableApplicationContext rootContext;
/** /**
* Internal constructor to create a new {@link Restarter} instance. * Internal constructor to create a new {@link Restarter} instance.
@ -120,11 +120,10 @@ public class Restarter {
* @param args the application arguments * @param args the application arguments
* @param forceReferenceCleanup if soft/weak reference cleanup should be forced * @param forceReferenceCleanup if soft/weak reference cleanup should be forced
* @param initializer the restart initializer * @param initializer the restart initializer
* @param listeners listeners to be notified of restarts
* @see #initialize(String[]) * @see #initialize(String[])
*/ */
protected Restarter(Thread thread, String[] args, boolean forceReferenceCleanup, protected Restarter(Thread thread, String[] args, boolean forceReferenceCleanup,
RestartInitializer initializer, RestartListener... listeners) { RestartInitializer initializer) {
Assert.notNull(thread, "Thread must not be null"); Assert.notNull(thread, "Thread must not be null");
Assert.notNull(args, "Args must not be null"); Assert.notNull(args, "Args must not be null");
Assert.notNull(initializer, "Initializer must not be null"); Assert.notNull(initializer, "Initializer must not be null");
@ -137,7 +136,6 @@ public class Restarter {
this.args = args; this.args = args;
this.exceptionHandler = thread.getUncaughtExceptionHandler(); this.exceptionHandler = thread.getUncaughtExceptionHandler();
this.leakSafeThreads.add(new LeakSafeThread()); this.leakSafeThreads.add(new LeakSafeThread());
this.listeners = listeners;
} }
private String getMainClassName(Thread thread) { private String getMainClassName(Thread thread) {
@ -250,7 +248,6 @@ public class Restarter {
@Override @Override
public Void call() throws Exception { public Void call() throws Exception {
Restarter.this.beforeRestart();
Restarter.this.stop(); Restarter.this.stop();
Restarter.this.start(failureHandler); Restarter.this.start(failureHandler);
return null; return null;
@ -313,7 +310,10 @@ public class Restarter {
this.logger.debug("Stopping application"); this.logger.debug("Stopping application");
this.stopLock.lock(); this.stopLock.lock();
try { try {
triggerShutdownHooks(); if (this.rootContext != null) {
this.rootContext.close();
this.rootContext = null;
}
cleanupCaches(); cleanupCaches();
if (this.forceReferenceCleanup) { if (this.forceReferenceCleanup) {
forceReferenceCleanup(); forceReferenceCleanup();
@ -326,23 +326,6 @@ public class Restarter {
System.runFinalization(); System.runFinalization();
} }
private void beforeRestart() {
for (RestartListener listener : this.listeners) {
listener.beforeRestart();
}
}
@SuppressWarnings("rawtypes")
private void triggerShutdownHooks() throws Exception {
Class<?> hooksClass = Class.forName("java.lang.ApplicationShutdownHooks");
Method runHooks = hooksClass.getDeclaredMethod("runHooks");
runHooks.setAccessible(true);
runHooks.invoke(null);
Field field = hooksClass.getDeclaredField("hooks");
field.setAccessible(true);
field.set(null, new IdentityHashMap());
}
private void cleanupCaches() throws Exception { private void cleanupCaches() throws Exception {
Introspector.flushCaches(); Introspector.flushCaches();
cleanupKnownCaches(); cleanupKnownCaches();
@ -418,10 +401,17 @@ public class Restarter {
} }
} }
boolean isFinished() { synchronized boolean isFinished() {
return this.finished; return this.finished;
} }
void prepare(ConfigurableApplicationContext applicationContext) {
if (applicationContext != null && applicationContext.getParent() != null) {
return;
}
this.rootContext = applicationContext;
}
private LeakSafeThread getLeakSafeThread() { private LeakSafeThread getLeakSafeThread() {
try { try {
return this.leakSafeThreads.takeFirst(); return this.leakSafeThreads.takeFirst();
@ -520,16 +510,14 @@ public class Restarter {
* @param initializer the restart initializer * @param initializer the restart initializer
* @param restartOnInitialize if the restarter should be restarted immediately when * @param restartOnInitialize if the restarter should be restarted immediately when
* the {@link RestartInitializer} returns non {@code null} results * the {@link RestartInitializer} returns non {@code null} results
* @param listeners listeners to be notified of restarts
*/ */
public static void initialize(String[] args, boolean forceReferenceCleanup, public static void initialize(String[] args, boolean forceReferenceCleanup,
RestartInitializer initializer, boolean restartOnInitialize, RestartInitializer initializer, boolean restartOnInitialize) {
RestartListener... listeners) {
Restarter localInstance = null; Restarter localInstance = null;
synchronized (Restarter.class) { synchronized (Restarter.class) {
if (instance == null) { if (instance == null) {
localInstance = new Restarter(Thread.currentThread(), args, localInstance = new Restarter(Thread.currentThread(), args,
forceReferenceCleanup, initializer, listeners); forceReferenceCleanup, initializer);
instance = localInstance; instance = localInstance;
} }
} }

@ -16,7 +16,3 @@ org.springframework.boot.devtools.autoconfigure.RemoteDevToolsAutoConfiguration
org.springframework.boot.env.EnvironmentPostProcessor=\ org.springframework.boot.env.EnvironmentPostProcessor=\
org.springframework.boot.devtools.env.DevToolsHomePropertiesPostProcessor,\ org.springframework.boot.devtools.env.DevToolsHomePropertiesPostProcessor,\
org.springframework.boot.devtools.env.DevToolsPropertyDefaultsPostProcessor org.springframework.boot.devtools.env.DevToolsPropertyDefaultsPostProcessor
# Restart Listeners
org.springframework.boot.devtools.restart.RestartListener=\
org.springframework.boot.devtools.log4j2.Log4J2RestartListener

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2015 the original author or authors. * Copyright 2012-2016 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@ import org.junit.Test;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.context.event.ApplicationFailedEvent; import org.springframework.boot.context.event.ApplicationFailedEvent;
import org.springframework.boot.context.event.ApplicationPreparedEvent;
import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.boot.context.event.ApplicationStartedEvent; import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.ConfigurableApplicationContext;
@ -36,6 +37,7 @@ import static org.mockito.Mockito.mock;
* Tests for {@link RestartApplicationListener}. * Tests for {@link RestartApplicationListener}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson
*/ */
public class RestartApplicationListenerTests { public class RestartApplicationListenerTests {
@ -62,6 +64,8 @@ public class RestartApplicationListenerTests {
assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "args")) assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "args"))
.isEqualTo(ARGS); .isEqualTo(ARGS);
assertThat(Restarter.getInstance().isFinished()).isTrue(); assertThat(Restarter.getInstance().isFinished()).isTrue();
assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "rootContext"))
.isNotNull();
} }
@Test @Test
@ -70,6 +74,8 @@ public class RestartApplicationListenerTests {
assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "args")) assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "args"))
.isEqualTo(ARGS); .isEqualTo(ARGS);
assertThat(Restarter.getInstance().isFinished()).isTrue(); assertThat(Restarter.getInstance().isFinished()).isTrue();
assertThat(ReflectionTestUtils.getField(Restarter.getInstance(), "rootContext"))
.isNull();
} }
@Test @Test
@ -89,6 +95,8 @@ public class RestartApplicationListenerTests {
listener.onApplicationEvent(new ApplicationStartedEvent(application, ARGS)); listener.onApplicationEvent(new ApplicationStartedEvent(application, ARGS));
assertThat(Restarter.getInstance()).isNotEqualTo(nullValue()); assertThat(Restarter.getInstance()).isNotEqualTo(nullValue());
assertThat(Restarter.getInstance().isFinished()).isFalse(); assertThat(Restarter.getInstance().isFinished()).isFalse();
listener.onApplicationEvent(
new ApplicationPreparedEvent(application, ARGS, context));
if (failed) { if (failed) {
listener.onApplicationEvent(new ApplicationFailedEvent(application, ARGS, listener.onApplicationEvent(new ApplicationFailedEvent(application, ARGS,
context, new RuntimeException())); context, new RuntimeException()));

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2015 the original author or authors. * Copyright 2012-2016 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,7 +34,9 @@ import org.springframework.boot.devtools.restart.classloader.ClassLoaderFile;
import org.springframework.boot.devtools.restart.classloader.ClassLoaderFile.Kind; import org.springframework.boot.devtools.restart.classloader.ClassLoaderFile.Kind;
import org.springframework.boot.devtools.restart.classloader.ClassLoaderFiles; import org.springframework.boot.devtools.restart.classloader.ClassLoaderFiles;
import org.springframework.boot.test.OutputCapture; import org.springframework.boot.test.OutputCapture;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.event.ContextClosedEvent;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled; import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -51,6 +53,7 @@ import static org.mockito.Mockito.verifyZeroInteractions;
* Tests for {@link Restarter}. * Tests for {@link Restarter}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson
*/ */
public class RestarterTests { public class RestarterTests {
@ -94,7 +97,7 @@ public class RestarterTests {
String output = this.out.toString(); String output = this.out.toString();
assertThat(StringUtils.countOccurrencesOf(output, "Tick 0")).isGreaterThan(1); assertThat(StringUtils.countOccurrencesOf(output, "Tick 0")).isGreaterThan(1);
assertThat(StringUtils.countOccurrencesOf(output, "Tick 1")).isGreaterThan(1); assertThat(StringUtils.countOccurrencesOf(output, "Tick 1")).isGreaterThan(1);
assertThat(TestRestartListener.restarts).isGreaterThan(0); assertThat(CloseCountingApplicationListener.closed).isGreaterThan(1);
} }
@Test @Test
@ -213,15 +216,14 @@ public class RestarterTests {
} }
public static void main(String... args) { public static void main(String... args) {
Restarter.initialize(args, false, new MockRestartInitializer(), true, Restarter.initialize(args, false, new MockRestartInitializer(), true);
new TestRestartListener());
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
SampleApplication.class); SampleApplication.class);
context.registerShutdownHook(); context.addApplicationListener(new CloseCountingApplicationListener());
Restarter.getInstance().prepare(context);
System.out.println("Sleep " + Thread.currentThread()); System.out.println("Sleep " + Thread.currentThread());
sleep(); sleep();
quit = true; quit = true;
context.close();
} }
private static void sleep() { private static void sleep() {
@ -235,6 +237,18 @@ public class RestarterTests {
} }
private static class CloseCountingApplicationListener
implements ApplicationListener<ContextClosedEvent> {
static int closed = 0;
@Override
public void onApplicationEvent(ContextClosedEvent event) {
closed++;
}
}
private static class TestableRestarter extends Restarter { private static class TestableRestarter extends Restarter {
private ClassLoader relaunchClassLoader; private ClassLoader relaunchClassLoader;
@ -276,15 +290,4 @@ public class RestarterTests {
} }
private static class TestRestartListener implements RestartListener {
private static int restarts;
@Override
public void beforeRestart() {
restarts++;
}
}
} }

Loading…
Cancel
Save