Automatically register reflection hints for scanned WebListeners

Closes gh-36008
3.0.x
Andy Wilkinson 1 year ago
parent 26b9602596
commit 458418be29

@ -19,10 +19,17 @@ package org.springframework.boot.web.servlet;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.BeansException; import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition; import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -40,7 +47,8 @@ import org.springframework.web.context.WebApplicationContext;
* @see ServletComponentScan * @see ServletComponentScan
* @see ServletComponentScanRegistrar * @see ServletComponentScanRegistrar
*/ */
class ServletComponentRegisteringPostProcessor implements BeanFactoryPostProcessor, ApplicationContextAware { class ServletComponentRegisteringPostProcessor
implements BeanFactoryPostProcessor, ApplicationContextAware, BeanFactoryInitializationAotProcessor {
private static final List<ServletComponentHandler> HANDLERS; private static final List<ServletComponentHandler> HANDLERS;
@ -105,4 +113,29 @@ class ServletComponentRegisteringPostProcessor implements BeanFactoryPostProcess
this.applicationContext = applicationContext; this.applicationContext = applicationContext;
} }
@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
return new BeanFactoryInitializationAotContribution() {
@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
for (String beanName : beanFactory.getBeanDefinitionNames()) {
BeanDefinition definition = beanFactory.getBeanDefinition(beanName);
if (Objects.equals(definition.getBeanClassName(),
WebListenerHandler.ServletComponentWebListenerRegistrar.class.getName())) {
String listenerClassName = (String) definition.getConstructorArgumentValues()
.getArgumentValue(0, String.class)
.getValue();
generationContext.getRuntimeHints()
.reflection()
.registerType(TypeReference.of(listenerClassName),
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
}
}
}
};
}
} }

@ -22,11 +22,9 @@ import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar; import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.AnnotationAttributes; import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.core.type.AnnotationMetadata; import org.springframework.core.type.AnnotationMetadata;
@ -102,13 +100,4 @@ class ServletComponentScanRegistrar implements ImportBeanDefinitionRegistrar {
} }
static class ServletComponentScanBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {
@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return BEAN_NAME.equals(registeredBean.getBeanName());
}
}
} }

@ -20,6 +20,3 @@ org.springframework.boot.jackson.JsonComponentModule.JsonComponentBeanFactoryIni
org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\ org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\
org.springframework.boot.context.properties.ConfigurationPropertiesBeanRegistrationAotProcessor,\ org.springframework.boot.context.properties.ConfigurationPropertiesBeanRegistrationAotProcessor,\
org.springframework.boot.jackson.JsonMixinModuleEntriesBeanRegistrationAotProcessor org.springframework.boot.jackson.JsonMixinModuleEntriesBeanRegistrationAotProcessor
org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\
org.springframework.boot.web.servlet.ServletComponentScanRegistrar.ServletComponentScanBeanRegistrationExcludeFilter

@ -46,7 +46,10 @@ import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFa
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory; import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.boot.web.servlet.server.ServletWebServerFactory; import org.springframework.boot.web.servlet.server.ServletWebServerFactory;
import org.springframework.boot.web.servlet.testcomponents.TestMultipartServlet; import org.springframework.boot.web.servlet.testcomponents.filter.TestFilter;
import org.springframework.boot.web.servlet.testcomponents.listener.TestListener;
import org.springframework.boot.web.servlet.testcomponents.servlet.TestMultipartServlet;
import org.springframework.boot.web.servlet.testcomponents.servlet.TestServlet;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@ -128,11 +131,9 @@ class ServletComponentScanIntegrationTests {
File metaInf = new File(temp, "META-INF"); File metaInf = new File(temp, "META-INF");
metaInf.mkdirs(); metaInf.mkdirs();
Properties index = new Properties(); Properties index = new Properties();
index.setProperty("org.springframework.boot.web.servlet.testcomponents.TestFilter", WebFilter.class.getName()); index.setProperty(TestFilter.class.getName(), WebFilter.class.getName());
index.setProperty("org.springframework.boot.web.servlet.testcomponents.TestListener", index.setProperty(TestListener.class.getName(), WebListener.class.getName());
WebListener.class.getName()); index.setProperty(TestServlet.class.getName(), WebServlet.class.getName());
index.setProperty("org.springframework.boot.web.servlet.testcomponents.TestServlet",
WebServlet.class.getName());
try (FileWriter writer = new FileWriter(new File(metaInf, "spring.components"))) { try (FileWriter writer = new FileWriter(new File(metaInf, "spring.components"))) {
index.store(writer, null); index.store(writer, null);
} }

@ -21,8 +21,12 @@ import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.boot.web.servlet.testcomponents.listener.TestListener;
import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
@ -141,6 +145,19 @@ class ServletComponentScanRegistrarTests {
}); });
} }
@Test
void processAheadOfTimeRegistersReflectionHintsForWebListeners() {
AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext();
context.registerBean(ScanListenerPackage.class);
TestGenerationContext generationContext = new TestGenerationContext(
ClassName.get(getClass().getPackageName(), "TestTarget"));
new ApplicationContextAotGenerator().processAheadOfTime(context, generationContext);
assertThat(RuntimeHintsPredicates.reflection()
.onType(TestListener.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS))
.accepts(generationContext.getRuntimeHints());
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void compile(GenericApplicationContext context, Consumer<GenericApplicationContext> freshContext) { private void compile(GenericApplicationContext context, Consumer<GenericApplicationContext> freshContext) {
TestGenerationContext generationContext = new TestGenerationContext( TestGenerationContext generationContext = new TestGenerationContext(
@ -192,4 +209,10 @@ class ServletComponentScanRegistrarTests {
} }
@Configuration(proxyBeanMethods = false)
@ServletComponentScan("org.springframework.boot.web.servlet.testcomponents.listener")
static class ScanListenerPackage {
}
} }

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.boot.web.servlet.testcomponents; package org.springframework.boot.web.servlet.testcomponents.filter;
import java.io.IOException; import java.io.IOException;
@ -27,7 +27,7 @@ import jakarta.servlet.ServletResponse;
import jakarta.servlet.annotation.WebFilter; import jakarta.servlet.annotation.WebFilter;
@WebFilter("/*") @WebFilter("/*")
class TestFilter implements Filter { public class TestFilter implements Filter {
@Override @Override
public void init(FilterConfig filterConfig) throws ServletException { public void init(FilterConfig filterConfig) throws ServletException {

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.boot.web.servlet.testcomponents; package org.springframework.boot.web.servlet.testcomponents.listener;
import java.io.IOException; import java.io.IOException;
import java.util.EnumSet; import java.util.EnumSet;

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.boot.web.servlet.testcomponents; package org.springframework.boot.web.servlet.testcomponents.servlet;
import jakarta.servlet.annotation.MultipartConfig; import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.annotation.WebServlet;

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.boot.web.servlet.testcomponents; package org.springframework.boot.web.servlet.testcomponents.servlet;
import java.io.IOException; import java.io.IOException;
Loading…
Cancel
Save