diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/mongo/MongoDataAutoConfiguration.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/mongo/MongoDataAutoConfiguration.java index c1e014acaf..680b4f80e6 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/mongo/MongoDataAutoConfiguration.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/mongo/MongoDataAutoConfiguration.java @@ -26,12 +26,16 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.dao.DataAccessException; +import org.springframework.dao.support.PersistenceExceptionTranslator; import org.springframework.data.mongodb.MongoDbFactory; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.SimpleMongoDbFactory; import org.springframework.data.mongodb.gridfs.GridFsTemplate; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import com.mongodb.DB; import com.mongodb.Mongo; /** @@ -73,11 +77,49 @@ public class MongoDataAutoConfiguration { @Bean @ConditionalOnMissingBean - public GridFsTemplate gridFsTemplate(Mongo mongo, MongoTemplate mongoTemplate) { - String db = StringUtils.hasText(this.properties.getGridFsDatabase()) ? this.properties - .getGridFsDatabase() : this.properties.getMongoClientDatabase(); - return new GridFsTemplate(new SimpleMongoDbFactory(mongo, db), - mongoTemplate.getConverter()); + public GridFsTemplate gridFsTemplate(MongoDbFactory mongoDbFactory, + MongoTemplate mongoTemplate) { + return new GridFsTemplate(new GridFsMongoDbFactory(mongoDbFactory, + this.properties), mongoTemplate.getConverter()); + } + + /** + * {@link MongoDbFactory} decorator to respect + * {@link MongoProperties#getGridFsDatabase()} if set. + */ + private static class GridFsMongoDbFactory implements MongoDbFactory { + + private final MongoDbFactory mongoDbFactory; + + private final MongoProperties properties; + + public GridFsMongoDbFactory(MongoDbFactory mongoDbFactory, + MongoProperties properties) { + Assert.notNull(mongoDbFactory, "MongoDbFactory must not be null"); + Assert.notNull(properties, "Properties must not be null"); + this.mongoDbFactory = mongoDbFactory; + this.properties = properties; + } + + @Override + public DB getDb() throws DataAccessException { + String gridFsDatabase = this.properties.getGridFsDatabase(); + if (StringUtils.hasText(gridFsDatabase)) { + return this.mongoDbFactory.getDb(gridFsDatabase); + } + return this.mongoDbFactory.getDb(); + } + + @Override + public DB getDb(String dbName) throws DataAccessException { + return this.mongoDbFactory.getDb(dbName); + } + + @Override + public PersistenceExceptionTranslator getExceptionTranslator() { + return this.mongoDbFactory.getExceptionTranslator(); + } + } }