diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListener.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListener.java index 83d6008931..b353cd41c6 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListener.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListener.java @@ -22,6 +22,8 @@ import java.util.Set; import org.mockito.Mockito; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -80,7 +82,7 @@ public class ResetMocksTestExecutionListener extends AbstractTestExecutionListen BeanDefinition definition = beanFactory.getBeanDefinition(name); if (definition.isSingleton() && instantiatedSingletons.contains(name)) { Object bean = getBean(beanFactory, name); - if (reset.equals(MockReset.get(bean))) { + if (bean != null && reset.equals(MockReset.get(bean))) { Mockito.reset(bean); } } @@ -103,11 +105,25 @@ public class ResetMocksTestExecutionListener extends AbstractTestExecutionListen private Object getBean(ConfigurableListableBeanFactory beanFactory, String name) { try { - return beanFactory.getBean(name); + if (isStandardBeanOrSingletonFactoryBean(beanFactory, name)) { + return beanFactory.getBean(name); + } } catch (Exception ex) { - return beanFactory.getSingleton(name); + // Continue + } + return beanFactory.getSingleton(name); + } + + private boolean isStandardBeanOrSingletonFactoryBean(ConfigurableListableBeanFactory beanFactory, String name) { + String factoryBeanName = BeanFactory.FACTORY_BEAN_PREFIX + name; + if (beanFactory.containsBean(factoryBeanName)) { + FactoryBean factoryBean = (FactoryBean) beanFactory.getBean(factoryBeanName); + if (!factoryBean.isSingleton()) { + return false; + } } + return true; } } diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListenerTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListenerTests.java index 1c51cedf94..5735f4a3ff 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListenerTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/ResetMocksTestExecutionListenerTests.java @@ -53,6 +53,7 @@ class ResetMocksTestExecutionListenerTests { given(getMock("before").greeting()).willReturn("before"); given(getMock("after").greeting()).willReturn("after"); given(getMock("fromFactoryBean").greeting()).willReturn("fromFactoryBean"); + assertThat(this.context.getBean(NonSingletonFactoryBean.class).getObjectInvocations).isEqualTo(0); } @Test @@ -61,6 +62,7 @@ class ResetMocksTestExecutionListenerTests { assertThat(getMock("before").greeting()).isNull(); assertThat(getMock("after").greeting()).isNull(); assertThat(getMock("fromFactoryBean").greeting()).isNull(); + assertThat(this.context.getBean(NonSingletonFactoryBean.class).getObjectInvocations).isEqualTo(0); } ExampleService getMock(String name) { @@ -109,6 +111,11 @@ class ResetMocksTestExecutionListenerTests { return new WorkingFactoryBean(); } + @Bean + NonSingletonFactoryBean nonSingletonFactoryBean() { + return new NonSingletonFactoryBean(); + } + } static class BrokenFactoryBean implements FactoryBean { @@ -132,9 +139,11 @@ class ResetMocksTestExecutionListenerTests { static class WorkingFactoryBean implements FactoryBean { + private final ExampleService service = mock(ExampleService.class, MockReset.before()); + @Override public ExampleService getObject() { - return mock(ExampleService.class, MockReset.before()); + return this.service; } @Override @@ -149,4 +158,26 @@ class ResetMocksTestExecutionListenerTests { } + static class NonSingletonFactoryBean implements FactoryBean { + + private int getObjectInvocations = 0; + + @Override + public ExampleService getObject() { + this.getObjectInvocations++; + return mock(ExampleService.class, MockReset.before()); + } + + @Override + public Class getObjectType() { + return ExampleService.class; + } + + @Override + public boolean isSingleton() { + return false; + } + + } + }