@ -1,5 +1,5 @@
/ *
* Copyright 2012 - 201 4 the original author or authors .
* Copyright 2012 - 201 6 the original author or authors .
*
* Licensed under the Apache License , Version 2.0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -18,12 +18,17 @@ package org.springframework.boot;
import java.io.File ;
import java.io.IOException ;
import java.io.InputStream ;
import java.net.JarURLConnection ;
import java.net.URL ;
import java.net.URLConnection ;
import java.security.CodeSource ;
import java.security.ProtectionDomain ;
import java.util.Enumeration ;
import java.util.jar.JarFile ;
import java.util.jar.Manifest ;
import org.springframework.util.ClassUtils ;
import org.springframework.util.StringUtils ;
/ * *
@ -51,17 +56,51 @@ public class ApplicationHome {
* @param sourceClass the source class or { @code null }
* /
public ApplicationHome ( Class < ? > sourceClass ) {
this . source = findSource ( sourceClass = = null ? get Class( ) : sourceClass ) ;
this . source = findSource ( sourceClass = = null ? get Start Class( ) : sourceClass ) ;
this . dir = findHomeDir ( this . source ) ;
}
private Class < ? > getStartClass ( ) {
try {
ClassLoader classLoader = getClass ( ) . getClassLoader ( ) ;
return getStartClass ( classLoader . getResources ( "META-INF/MANIFEST.MF" ) ) ;
}
catch ( Exception ex ) {
return null ;
}
}
private Class < ? > getStartClass ( Enumeration < URL > manifestResources ) {
while ( manifestResources . hasMoreElements ( ) ) {
try {
InputStream inputStream = manifestResources . nextElement ( ) . openStream ( ) ;
try {
Manifest manifest = new Manifest ( inputStream ) ;
String startClass = manifest . getMainAttributes ( )
. getValue ( "Start-Class" ) ;
if ( startClass ! = null ) {
return ClassUtils . forName ( startClass ,
getClass ( ) . getClassLoader ( ) ) ;
}
}
finally {
inputStream . close ( ) ;
}
}
catch ( Exception ex ) {
}
}
return null ;
}
private File findSource ( Class < ? > sourceClass ) {
try {
ProtectionDomain protectionDomain = sourceClass . getProtectionDomain ( ) ;
CodeSource codeSource = protectionDomain . getCodeSource ( ) ;
ProtectionDomain domain = ( sourceClass = = null ? null
: sourceClass . getProtectionDomain ( ) ) ;
CodeSource codeSource = ( domain = = null ? null : domain . getCodeSource ( ) ) ;
URL location = ( codeSource = = null ? null : codeSource . getLocation ( ) ) ;
File source = ( location = = null ? null : findSource ( location ) ) ;
if ( source ! = null & & source . exists ( ) ) {
if ( source ! = null & & source . exists ( ) & & ! isUnitTest ( ) ) {
return source . getAbsoluteFile ( ) ;
}
return null ;
@ -71,14 +110,36 @@ public class ApplicationHome {
}
}
private boolean isUnitTest ( ) {
try {
for ( StackTraceElement element : Thread . currentThread ( ) . getStackTrace ( ) ) {
if ( element . getClassName ( ) . startsWith ( "org.junit." ) ) {
return true ;
}
}
}
catch ( Exception ex ) {
}
return false ;
}
private File findSource ( URL location ) throws IOException {
URLConnection connection = location . openConnection ( ) ;
if ( connection instanceof JarURLConnection ) {
return new File ( ( ( JarURLConnection ) connection ) . getJarFile ( ) . getName ( ) ) ;
return getRootJar File( ( ( JarURLConnection ) connection ) . getJarFil e( ) ) ;
}
return new File ( location . getPath ( ) ) ;
}
private File getRootJarFile ( JarFile jarFile ) {
String name = jarFile . getName ( ) ;
int separator = name . indexOf ( "!/" ) ;
if ( separator > 0 ) {
name = name . substring ( 0 , separator ) ;
}
return new File ( name ) ;
}
private File findHomeDir ( File source ) {
File homeDir = source ;
homeDir = ( homeDir = = null ? findDefaultHomeDir ( ) : homeDir ) ;