pull/26813/head
Phillip Webb 3 years ago
parent 87d35250a5
commit be23a29651

@ -30,12 +30,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
*/
public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements DatabaseInitializerDetector {
/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();
@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try {
@ -47,4 +41,10 @@ public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements
}
}
/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();
}

@ -32,13 +32,6 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector
implements DependsOnDatabaseInitializationDetector {
/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();
@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try {
@ -50,4 +43,11 @@ public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector
}
}
/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();
}

@ -16,9 +16,11 @@
package org.springframework.boot.sql.init.dependency;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@ -65,16 +67,23 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef
@Override
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
if (registry.containsBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName())) {
return;
String name = DependsOnDatabaseInitializationPostProcessor.class.getName();
if (!registry.containsBeanDefinition(name)) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(
DependsOnDatabaseInitializationPostProcessor.class,
this::createDependsOnDatabaseInitializationPostProcessor);
registry.registerBeanDefinition(name, builder.getBeanDefinition());
}
registry.registerBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName(),
BeanDefinitionBuilder
.genericBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class,
() -> new DependsOnDatabaseInitializationPostProcessor(this.environment))
.getBeanDefinition());
}
private DependsOnDatabaseInitializationPostProcessor createDependsOnDatabaseInitializationPostProcessor() {
return new DependsOnDatabaseInitializationPostProcessor(this.environment);
}
/**
* {@link BeanFactoryPostProcessor} used to configure database initialization
* dependency relationships.
*/
static class DependsOnDatabaseInitializationPostProcessor implements BeanFactoryPostProcessor {
private final Environment environment;
@ -85,58 +94,55 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) {
Set<String> detectedDatabaseInitializers = detectDatabaseInitializers(beanFactory);
if (detectedDatabaseInitializers.isEmpty()) {
Set<String> initializerBeanNames = detectInitializerBeanNames(beanFactory);
if (initializerBeanNames.isEmpty()) {
return;
}
for (String dependentDefinitionName : detectDependsOnDatabaseInitialization(beanFactory,
this.environment)) {
BeanDefinition definition = getBeanDefinition(dependentDefinitionName, beanFactory);
String[] dependencies = definition.getDependsOn();
for (String dependencyName : detectedDatabaseInitializers) {
dependencies = StringUtils.addStringToArray(dependencies, dependencyName);
}
definition.setDependsOn(dependencies);
for (String dependsOnInitializationBeanNames : detectDependsOnInitializationBeanNames(beanFactory)) {
BeanDefinition definition = getBeanDefinition(dependsOnInitializationBeanNames, beanFactory);
definition.setDependsOn(merge(definition.getDependsOn(), initializerBeanNames));
}
}
private Set<String> detectDatabaseInitializers(ConfigurableListableBeanFactory beanFactory) {
List<DatabaseInitializerDetector> detectors = instantiateDetectors(beanFactory, this.environment,
DatabaseInitializerDetector.class);
Set<String> detected = new HashSet<>();
private String[] merge(String[] source, Set<String> additional) {
Set<String> result = new LinkedHashSet<>((source != null) ? Arrays.asList(source) : Collections.emptySet());
result.addAll(additional);
return StringUtils.toStringArray(result);
}
private Set<String> detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) {
List<DatabaseInitializerDetector> detectors = getDetectors(beanFactory, DatabaseInitializerDetector.class);
Set<String> beanNames = new HashSet<>();
for (DatabaseInitializerDetector detector : detectors) {
for (String initializerName : detector.detect(beanFactory)) {
detected.add(initializerName);
beanFactory.getBeanDefinition(initializerName)
.setAttribute(DatabaseInitializerDetector.class.getName(), detector.getClass().getName());
for (String beanName : detector.detect(beanFactory)) {
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
beanDefinition.setAttribute(DatabaseInitializerDetector.class.getName(),
detector.getClass().getName());
beanNames.add(beanName);
}
}
detected = Collections.unmodifiableSet(detected);
beanNames = Collections.unmodifiableSet(beanNames);
for (DatabaseInitializerDetector detector : detectors) {
detector.detectionComplete(beanFactory, detected);
detector.detectionComplete(beanFactory, beanNames);
}
return detected;
return beanNames;
}
private Collection<String> detectDependsOnDatabaseInitialization(ConfigurableListableBeanFactory beanFactory,
Environment environment) {
List<DependsOnDatabaseInitializationDetector> detectors = instantiateDetectors(beanFactory, environment,
private Collection<String> detectDependsOnInitializationBeanNames(ConfigurableListableBeanFactory beanFactory) {
List<DependsOnDatabaseInitializationDetector> detectors = getDetectors(beanFactory,
DependsOnDatabaseInitializationDetector.class);
Set<String> dependentUponDatabaseInitialization = new HashSet<>();
Set<String> beanNames = new HashSet<>();
for (DependsOnDatabaseInitializationDetector detector : detectors) {
dependentUponDatabaseInitialization.addAll(detector.detect(beanFactory));
beanNames.addAll(detector.detect(beanFactory));
}
return dependentUponDatabaseInitialization;
return beanNames;
}
private <T> List<T> instantiateDetectors(ConfigurableListableBeanFactory beanFactory, Environment environment,
Class<T> detectorType) {
List<String> detectorNames = SpringFactoriesLoader.loadFactoryNames(detectorType,
beanFactory.getBeanClassLoader());
Instantiator<T> instantiator = new Instantiator<>(detectorType,
(availableParameters) -> availableParameters.add(Environment.class, environment));
List<T> detectors = instantiator.instantiate(detectorNames);
return detectors;
private <T> List<T> getDetectors(ConfigurableListableBeanFactory beanFactory, Class<T> type) {
List<String> names = SpringFactoriesLoader.loadFactoryNames(type, beanFactory.getBeanClassLoader());
Instantiator<T> instantiator = new Instantiator<>(type,
(availableParameters) -> availableParameters.add(Environment.class, this.environment));
return instantiator.instantiate(names);
}
private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) {

@ -33,7 +33,6 @@ import java.util.stream.Collectors;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -47,6 +46,7 @@ import org.springframework.mock.env.MockEnvironment;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
@ -59,16 +59,12 @@ class DatabaseInitializationDependencyConfigurerTests {
private final ConfigurableEnvironment environment = new MockEnvironment();
DatabaseInitializerDetector databaseInitializerDetector = MockedDatabaseInitializerDetector.mock;
DependsOnDatabaseInitializationDetector dependsOnDatabaseInitializationDetector = MockedDependsOnDatabaseInitializationDetector.mock;
@TempDir
File temp;
@BeforeEach
void resetMocks() {
reset(MockedDatabaseInitializerDetector.mock, MockedDependsOnDatabaseInitializationDetector.mock);
reset(MockDatabaseInitializerDetector.instance, MockedDependsOnDatabaseInitializationDetector.instance);
}
@Test
@ -89,19 +85,19 @@ class DatabaseInitializationDependencyConfigurerTests {
void whenDependenciesAreConfiguredThenBeansThatDependUponDatabaseInitializationDependUponDetectedDatabaseInitializers() {
BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
performDetection(Arrays.asList(MockedDatabaseInitializerDetector.class,
performDetection(Arrays.asList(MockDatabaseInitializerDetector.class,
MockedDependsOnDatabaseInitializationDetector.class), (context) -> {
context.registerBeanDefinition("alpha", alpha);
context.registerBeanDefinition("bravo", bravo);
given(this.databaseInitializerDetector.detect(context.getBeanFactory()))
given(MockDatabaseInitializerDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("alpha"));
given(this.dependsOnDatabaseInitializationDetector.detect(context.getBeanFactory()))
given(MockedDependsOnDatabaseInitializationDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("bravo"));
context.refresh();
assertThat(alpha.getAttribute(DatabaseInitializerDetector.class.getName()))
.isEqualTo(MockedDatabaseInitializerDetector.class.getName());
.isEqualTo(MockDatabaseInitializerDetector.class.getName());
assertThat(bravo.getAttribute(DatabaseInitializerDetector.class.getName())).isNull();
verify(this.databaseInitializerDetector).detectionComplete(context.getBeanFactory(),
verify(MockDatabaseInitializerDetector.instance).detectionComplete(context.getBeanFactory(),
Collections.singleton("alpha"));
assertThat(bravo.getDependsOn()).containsExactly("alpha");
});
@ -156,31 +152,31 @@ class DatabaseInitializationDependencyConfigurerTests {
}
static class MockedDatabaseInitializerDetector implements DatabaseInitializerDetector {
static class MockDatabaseInitializerDetector implements DatabaseInitializerDetector {
private static DatabaseInitializerDetector mock = Mockito.mock(DatabaseInitializerDetector.class);
private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class);
@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDatabaseInitializerDetector.mock.detect(beanFactory);
return MockDatabaseInitializerDetector.instance.detect(beanFactory);
}
@Override
public void detectionComplete(ConfigurableListableBeanFactory beanFactory,
Set<String> databaseInitializerNames) {
mock.detectionComplete(beanFactory, databaseInitializerNames);
instance.detectionComplete(beanFactory, databaseInitializerNames);
}
}
static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector {
private static DependsOnDatabaseInitializationDetector mock = Mockito
.mock(DependsOnDatabaseInitializationDetector.class);
private static DependsOnDatabaseInitializationDetector instance = mock(
DependsOnDatabaseInitializationDetector.class);
@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDependsOnDatabaseInitializationDetector.mock.detect(beanFactory);
return instance.detect(beanFactory);
}
}

Loading…
Cancel
Save