pull/35031/head
Phillip Webb 2 years ago
parent 1849b82334
commit 2951cc7594

@ -19,6 +19,7 @@ package org.springframework.boot.autoconfigure.service.connection;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Stream;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.annotation.AnnotationAwareOrderComparator;
@ -35,76 +36,62 @@ import org.springframework.core.style.ToStringCreator;
*/ */
public class ConnectionDetailsFactories { public class ConnectionDetailsFactories {
private List<FactoryDetails> registeredFactories = new ArrayList<>(); private List<Registration<?, ?>> registrations = new ArrayList<>();
public ConnectionDetailsFactories() { public ConnectionDetailsFactories() {
this(SpringFactoriesLoader.forDefaultResourceLocation(ConnectionDetailsFactory.class.getClassLoader())); this(SpringFactoriesLoader.forDefaultResourceLocation(ConnectionDetailsFactory.class.getClassLoader()));
} }
@SuppressWarnings("rawtypes") @SuppressWarnings({ "rawtypes", "unchecked" })
ConnectionDetailsFactories(SpringFactoriesLoader loader) { ConnectionDetailsFactories(SpringFactoriesLoader loader) {
List<ConnectionDetailsFactory> factories = loader.load(ConnectionDetailsFactory.class); List<ConnectionDetailsFactory> factories = loader.load(ConnectionDetailsFactory.class);
factories.stream().map(this::factoryDetails).filter(Objects::nonNull).forEach(this::register); Stream<Registration<?, ?>> registrations = factories.stream().map(Registration::get);
registrations.filter(Objects::nonNull).forEach(this.registrations::add);
} }
@SuppressWarnings("unchecked") public <S> ConnectionDetails getConnectionDetails(S source) {
private FactoryDetails factoryDetails(ConnectionDetailsFactory<?, ?> factory) { return getConnectionDetailsFactory(source).getConnectionDetails(source);
ResolvableType connectionDetailsFactory = findConnectionDetailsFactory(
ResolvableType.forClass(factory.getClass()));
if (connectionDetailsFactory != null) {
ResolvableType input = connectionDetailsFactory.getGeneric(0);
ResolvableType output = connectionDetailsFactory.getGeneric(1);
return new FactoryDetails(input.getRawClass(), (Class<? extends ConnectionDetails>) output.getRawClass(),
factory);
}
return null;
}
private ResolvableType findConnectionDetailsFactory(ResolvableType type) {
try {
ResolvableType[] interfaces = type.getInterfaces();
for (ResolvableType iface : interfaces) {
if (iface.getRawClass().equals(ConnectionDetailsFactory.class)) {
return iface;
}
}
}
catch (TypeNotPresentException ex) {
// A type referenced by the factory is not present. Skip it.
}
ResolvableType superType = type.getSuperType();
return ResolvableType.NONE.equals(superType) ? null : findConnectionDetailsFactory(superType);
}
private void register(FactoryDetails details) {
this.registeredFactories.add(details);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <S> ConnectionDetailsFactory<S, ConnectionDetails> getConnectionDetailsFactory(S source) { public <S> ConnectionDetailsFactory<S, ConnectionDetails> getConnectionDetailsFactory(S source) {
Class<S> input = (Class<S>) source.getClass(); Class<S> sourceType = (Class<S>) source.getClass();
List<ConnectionDetailsFactory<S, ConnectionDetails>> matchingFactories = new ArrayList<>(); List<ConnectionDetailsFactory<S, ConnectionDetails>> result = new ArrayList<>();
for (FactoryDetails factoryDetails : this.registeredFactories) { for (Registration<?, ?> candidate : this.registrations) {
if (factoryDetails.input.isAssignableFrom(input)) { if (candidate.sourceType().isAssignableFrom(sourceType)) {
matchingFactories.add((ConnectionDetailsFactory<S, ConnectionDetails>) factoryDetails.factory); result.add((ConnectionDetailsFactory<S, ConnectionDetails>) candidate.factory());
} }
} }
if (matchingFactories.isEmpty()) { if (result.isEmpty()) {
throw new ConnectionDetailsFactoryNotFoundException(source); throw new ConnectionDetailsFactoryNotFoundException(source);
} }
else { AnnotationAwareOrderComparator.sort(result);
if (matchingFactories.size() == 1) { return (result.size() != 1) ? new CompositeConnectionDetailsFactory<>(result) : result.get(0);
return matchingFactories.get(0); }
/**
* A {@link ConnectionDetailsFactory} registration.
*/
private record Registration<S, D extends ConnectionDetails>(Class<S> sourceType, Class<D> connectionDetailsType,
ConnectionDetailsFactory<S, D> factory) {
@SuppressWarnings("unchecked")
private static <S, D extends ConnectionDetails> Registration<S, D> get(ConnectionDetailsFactory<S, D> factory) {
ResolvableType type = ResolvableType.forClass(ConnectionDetailsFactory.class, factory.getClass());
if (!type.hasUnresolvableGenerics()) {
Class<?>[] generics = type.resolveGenerics();
return new Registration<>((Class<S>) generics[0], (Class<D>) generics[1], factory);
} }
AnnotationAwareOrderComparator.sort(matchingFactories); return null;
return new CompositeConnectionDetailsFactory<>(matchingFactories);
} }
}
private record FactoryDetails(Class<?> input, Class<? extends ConnectionDetails> output,
ConnectionDetailsFactory<?, ?> factory) {
} }
/**
* Composite {@link ConnectionDetailsFactory} implementation.
*
* @param <S> the source type
*/
static class CompositeConnectionDetailsFactory<S> implements ConnectionDetailsFactory<S, ConnectionDetails> { static class CompositeConnectionDetailsFactory<S> implements ConnectionDetailsFactory<S, ConnectionDetails> {
private final List<ConnectionDetailsFactory<S, ConnectionDetails>> delegates; private final List<ConnectionDetailsFactory<S, ConnectionDetails>> delegates;
@ -114,15 +101,16 @@ public class ConnectionDetailsFactories {
} }
@Override @Override
@SuppressWarnings("unchecked") public ConnectionDetails getConnectionDetails(S source) {
public ConnectionDetails getConnectionDetails(Object source) { return this.delegates.stream()
for (ConnectionDetailsFactory<S, ConnectionDetails> delegate : this.delegates) { .map((delegate) -> delegate.getConnectionDetails(source))
ConnectionDetails connectionDetails = delegate.getConnectionDetails((S) source); .filter(Objects::nonNull)
if (connectionDetails != null) { .findFirst()
return connectionDetails; .orElse(null);
} }
}
return null; List<ConnectionDetailsFactory<S, ConnectionDetails>> getDelegates() {
return this.delegates;
} }
@Override @Override
@ -130,10 +118,6 @@ public class ConnectionDetailsFactories {
return new ToStringCreator(this).append("delegates", this.delegates).toString(); return new ToStringCreator(this).append("delegates", this.delegates).toString();
} }
List<ConnectionDetailsFactory<S, ConnectionDetails>> getDelegates() {
return this.delegates;
}
} }
} }

@ -51,19 +51,8 @@ class RedisContainerConnectionDetailsFactory
private RedisContainerConnectionDetails( private RedisContainerConnectionDetails(
ContainerConnectionSource<RedisServiceConnection, RedisConnectionDetails, GenericContainer<?>> source) { ContainerConnectionSource<RedisServiceConnection, RedisConnectionDetails, GenericContainer<?>> source) {
super(source); super(source);
this.standalone = new Standalone() { this.standalone = Standalone.of(source.getContainer().getHost(),
source.getContainer().getFirstMappedPort());
@Override
public String getHost() {
return source.getContainer().getHost();
}
@Override
public int getPort() {
return source.getContainer().getFirstMappedPort();
}
};
} }
@Override @Override

@ -17,6 +17,7 @@
package org.springframework.boot.test.autoconfigure.service.connection; package org.springframework.boot.test.autoconfigure.service.connection;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -24,7 +25,6 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactories; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactories;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory;
import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextCustomizer; import org.springframework.test.context.ContextCustomizer;
import org.springframework.test.context.MergedContextConfiguration; import org.springframework.test.context.MergedContextConfiguration;
@ -61,21 +61,21 @@ class ServiceConnectionContextCustomizer implements ContextCustomizer {
private void registerServiceConnection(BeanDefinitionRegistry registry, ContainerConnectionSource<?, ?, ?> source) { private void registerServiceConnection(BeanDefinitionRegistry registry, ContainerConnectionSource<?, ?, ?> source) {
ConnectionDetails connectionDetails = getConnectionDetails(source); ConnectionDetails connectionDetails = getConnectionDetails(source);
String beanName = source.getBeanName(); register(connectionDetails, registry, source.getBeanName());
registry.registerBeanDefinition(beanName, createBeanDefinition(connectionDetails));
} }
private <S> ConnectionDetails getConnectionDetails(S source) { @SuppressWarnings("unchecked")
ConnectionDetailsFactory<S, ConnectionDetails> factory = this.factories.getConnectionDetailsFactory(source); private <T> void register(ConnectionDetails connectionDetails, BeanDefinitionRegistry registry, String beanName) {
ConnectionDetails connectionDetails = factory.getConnectionDetails(source); Class<T> beanType = (Class<T>) connectionDetails.getClass();
Assert.state(connectionDetails != null, Supplier<T> beanSupplier = () -> (T) connectionDetails;
() -> "No connection details created by %s".formatted(factory.getClass().getName())); BeanDefinition beanDefinition = new RootBeanDefinition(beanType, beanSupplier);
return connectionDetails; registry.registerBeanDefinition(beanName, beanDefinition);
} }
@SuppressWarnings("unchecked") private <S> ConnectionDetails getConnectionDetails(S source) {
private <T> BeanDefinition createBeanDefinition(T instance) { ConnectionDetails connectionDetails = this.factories.getConnectionDetails(source);
return new RootBeanDefinition((Class<T>) instance.getClass(), () -> instance); Assert.state(connectionDetails != null, () -> "No connection details created for %s".formatted(source));
return connectionDetails;
} }
List<ContainerConnectionSource<?, ?, ?>> getSources() { List<ContainerConnectionSource<?, ?, ?>> getSources() {

Loading…
Cancel
Save