diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/SkipPatternJarScanner.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/SkipPatternJarScanner.java index c0d464f769..de1fdb9540 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/SkipPatternJarScanner.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/tomcat/SkipPatternJarScanner.java @@ -16,6 +16,8 @@ package org.springframework.boot.context.embedded.tomcat; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.Collections; import java.util.LinkedHashSet; import java.util.Set; @@ -27,6 +29,9 @@ import org.apache.tomcat.JarScanner; import org.apache.tomcat.JarScannerCallback; import org.apache.tomcat.util.scan.StandardJarScanner; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; /** * {@link JarScanner} decorator allowing alternative default jar pattern matching. This @@ -38,6 +43,10 @@ import org.springframework.util.Assert; */ class SkipPatternJarScanner extends StandardJarScanner { + private static final String JAR_SCAN_FILTER_CLASS = "org.apache.tomcat.JarScanFilter"; + + private static final String STANDARD_JAR_SCAN_FILTER_CLASS = "org.apache.tomcat.util.scan.StandardJarScanFilter"; + private final JarScanner jarScanner; private final SkipPattern pattern; @@ -46,6 +55,32 @@ class SkipPatternJarScanner extends StandardJarScanner { Assert.notNull(jarScanner, "JarScanner must not be null"); this.jarScanner = jarScanner; this.pattern = (pattern == null ? new SkipPattern() : new SkipPattern(pattern)); + setPatternToTomcat8SkipFilter(this.pattern); + } + + private void setPatternToTomcat8SkipFilter(SkipPattern pattern) { + if (ClassUtils.isPresent(JAR_SCAN_FILTER_CLASS, null)) { + try { + Class filterClass = Class.forName(JAR_SCAN_FILTER_CLASS); + Method setJarScanner = ReflectionUtils.findMethod( + StandardJarScanner.class, "setJarScanFilter", filterClass); + setJarScanner.invoke(this, createStandardJarScanFilter(pattern)); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + } + + private Object createStandardJarScanFilter(SkipPattern pattern) + throws ClassNotFoundException, InstantiationException, + IllegalAccessException, InvocationTargetException { + Class filterClass = Class.forName(STANDARD_JAR_SCAN_FILTER_CLASS); + Method setTldSkipMethod = ReflectionUtils.findMethod(filterClass, "setTldSkip", + String.class); + Object scanner = filterClass.newInstance(); + setTldSkipMethod.invoke(scanner, pattern.asCommaDelimitedString()); + return scanner; } @Override @@ -124,6 +159,10 @@ class SkipPatternJarScanner extends StandardJarScanner { this.patterns.add(patterns); } + public String asCommaDelimitedString() { + return StringUtils.collectionToCommaDelimitedString(this.patterns); + } + public Set asSet() { return Collections.unmodifiableSet(this.patterns); }