diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/service/connection/ConnectionDetailsFactories.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/service/connection/ConnectionDetailsFactories.java index 01d53c65d9..1a16ac5fb0 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/service/connection/ConnectionDetailsFactories.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/service/connection/ConnectionDetailsFactories.java @@ -19,6 +19,7 @@ package org.springframework.boot.autoconfigure.service.connection; import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.stream.Stream; import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotationAwareOrderComparator; @@ -35,76 +36,62 @@ import org.springframework.core.style.ToStringCreator; */ public class ConnectionDetailsFactories { - private List registeredFactories = new ArrayList<>(); + private List> registrations = new ArrayList<>(); public ConnectionDetailsFactories() { this(SpringFactoriesLoader.forDefaultResourceLocation(ConnectionDetailsFactory.class.getClassLoader())); } - @SuppressWarnings("rawtypes") + @SuppressWarnings({ "rawtypes", "unchecked" }) ConnectionDetailsFactories(SpringFactoriesLoader loader) { List factories = loader.load(ConnectionDetailsFactory.class); - factories.stream().map(this::factoryDetails).filter(Objects::nonNull).forEach(this::register); + Stream> registrations = factories.stream().map(Registration::get); + registrations.filter(Objects::nonNull).forEach(this.registrations::add); } - @SuppressWarnings("unchecked") - private FactoryDetails factoryDetails(ConnectionDetailsFactory factory) { - ResolvableType connectionDetailsFactory = findConnectionDetailsFactory( - ResolvableType.forClass(factory.getClass())); - if (connectionDetailsFactory != null) { - ResolvableType input = connectionDetailsFactory.getGeneric(0); - ResolvableType output = connectionDetailsFactory.getGeneric(1); - return new FactoryDetails(input.getRawClass(), (Class) output.getRawClass(), - factory); - } - return null; - } - - private ResolvableType findConnectionDetailsFactory(ResolvableType type) { - try { - ResolvableType[] interfaces = type.getInterfaces(); - for (ResolvableType iface : interfaces) { - if (iface.getRawClass().equals(ConnectionDetailsFactory.class)) { - return iface; - } - } - } - catch (TypeNotPresentException ex) { - // A type referenced by the factory is not present. Skip it. - } - ResolvableType superType = type.getSuperType(); - return ResolvableType.NONE.equals(superType) ? null : findConnectionDetailsFactory(superType); - } - - private void register(FactoryDetails details) { - this.registeredFactories.add(details); + public ConnectionDetails getConnectionDetails(S source) { + return getConnectionDetailsFactory(source).getConnectionDetails(source); } @SuppressWarnings("unchecked") public ConnectionDetailsFactory getConnectionDetailsFactory(S source) { - Class input = (Class) source.getClass(); - List> matchingFactories = new ArrayList<>(); - for (FactoryDetails factoryDetails : this.registeredFactories) { - if (factoryDetails.input.isAssignableFrom(input)) { - matchingFactories.add((ConnectionDetailsFactory) factoryDetails.factory); + Class sourceType = (Class) source.getClass(); + List> result = new ArrayList<>(); + for (Registration candidate : this.registrations) { + if (candidate.sourceType().isAssignableFrom(sourceType)) { + result.add((ConnectionDetailsFactory) candidate.factory()); } } - if (matchingFactories.isEmpty()) { + if (result.isEmpty()) { throw new ConnectionDetailsFactoryNotFoundException(source); } - else { - if (matchingFactories.size() == 1) { - return matchingFactories.get(0); + AnnotationAwareOrderComparator.sort(result); + return (result.size() != 1) ? new CompositeConnectionDetailsFactory<>(result) : result.get(0); + } + + /** + * A {@link ConnectionDetailsFactory} registration. + */ + private record Registration(Class sourceType, Class connectionDetailsType, + ConnectionDetailsFactory factory) { + + @SuppressWarnings("unchecked") + private static Registration get(ConnectionDetailsFactory factory) { + ResolvableType type = ResolvableType.forClass(ConnectionDetailsFactory.class, factory.getClass()); + if (!type.hasUnresolvableGenerics()) { + Class[] generics = type.resolveGenerics(); + return new Registration<>((Class) generics[0], (Class) generics[1], factory); } - AnnotationAwareOrderComparator.sort(matchingFactories); - return new CompositeConnectionDetailsFactory<>(matchingFactories); + return null; } - } - private record FactoryDetails(Class input, Class output, - ConnectionDetailsFactory factory) { } + /** + * Composite {@link ConnectionDetailsFactory} implementation. + * + * @param the source type + */ static class CompositeConnectionDetailsFactory implements ConnectionDetailsFactory { private final List> delegates; @@ -114,15 +101,16 @@ public class ConnectionDetailsFactories { } @Override - @SuppressWarnings("unchecked") - public ConnectionDetails getConnectionDetails(Object source) { - for (ConnectionDetailsFactory delegate : this.delegates) { - ConnectionDetails connectionDetails = delegate.getConnectionDetails((S) source); - if (connectionDetails != null) { - return connectionDetails; - } - } - return null; + public ConnectionDetails getConnectionDetails(S source) { + return this.delegates.stream() + .map((delegate) -> delegate.getConnectionDetails(source)) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + } + + List> getDelegates() { + return this.delegates; } @Override @@ -130,10 +118,6 @@ public class ConnectionDetailsFactories { return new ToStringCreator(this).append("delegates", this.delegates).toString(); } - List> getDelegates() { - return this.delegates; - } - } } diff --git a/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/data/redis/RedisContainerConnectionDetailsFactory.java b/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/data/redis/RedisContainerConnectionDetailsFactory.java index 01f2742938..0fed65d73e 100644 --- a/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/data/redis/RedisContainerConnectionDetailsFactory.java +++ b/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/data/redis/RedisContainerConnectionDetailsFactory.java @@ -51,19 +51,8 @@ class RedisContainerConnectionDetailsFactory private RedisContainerConnectionDetails( ContainerConnectionSource> source) { super(source); - this.standalone = new Standalone() { - - @Override - public String getHost() { - return source.getContainer().getHost(); - } - - @Override - public int getPort() { - return source.getContainer().getFirstMappedPort(); - } - - }; + this.standalone = Standalone.of(source.getContainer().getHost(), + source.getContainer().getFirstMappedPort()); } @Override diff --git a/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/service/connection/ServiceConnectionContextCustomizer.java b/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/service/connection/ServiceConnectionContextCustomizer.java index 839b125dac..7b68f19196 100644 --- a/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/service/connection/ServiceConnectionContextCustomizer.java +++ b/spring-boot-project/spring-boot-test-autoconfigure/src/main/java/org/springframework/boot/test/autoconfigure/service/connection/ServiceConnectionContextCustomizer.java @@ -17,6 +17,7 @@ package org.springframework.boot.test.autoconfigure.service.connection; import java.util.List; +import java.util.function.Supplier; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -24,7 +25,6 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactories; -import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.test.context.ContextCustomizer; import org.springframework.test.context.MergedContextConfiguration; @@ -61,21 +61,21 @@ class ServiceConnectionContextCustomizer implements ContextCustomizer { private void registerServiceConnection(BeanDefinitionRegistry registry, ContainerConnectionSource source) { ConnectionDetails connectionDetails = getConnectionDetails(source); - String beanName = source.getBeanName(); - registry.registerBeanDefinition(beanName, createBeanDefinition(connectionDetails)); + register(connectionDetails, registry, source.getBeanName()); } - private ConnectionDetails getConnectionDetails(S source) { - ConnectionDetailsFactory factory = this.factories.getConnectionDetailsFactory(source); - ConnectionDetails connectionDetails = factory.getConnectionDetails(source); - Assert.state(connectionDetails != null, - () -> "No connection details created by %s".formatted(factory.getClass().getName())); - return connectionDetails; + @SuppressWarnings("unchecked") + private void register(ConnectionDetails connectionDetails, BeanDefinitionRegistry registry, String beanName) { + Class beanType = (Class) connectionDetails.getClass(); + Supplier beanSupplier = () -> (T) connectionDetails; + BeanDefinition beanDefinition = new RootBeanDefinition(beanType, beanSupplier); + registry.registerBeanDefinition(beanName, beanDefinition); } - @SuppressWarnings("unchecked") - private BeanDefinition createBeanDefinition(T instance) { - return new RootBeanDefinition((Class) instance.getClass(), () -> instance); + private ConnectionDetails getConnectionDetails(S source) { + ConnectionDetails connectionDetails = this.factories.getConnectionDetails(source); + Assert.state(connectionDetails != null, () -> "No connection details created for %s".formatted(source)); + return connectionDetails; } List> getSources() {