Ensure we close the classloader when done with it

It might leave a bunch of jar files open when it reads
them.

Test: unit tests
Bug: 150865834
Change-Id: Ica1b8395c62579590c953908ab84b070bcf8658d
diff --git a/src/com/android/tradefed/util/TestFilterHelper.java b/src/com/android/tradefed/util/TestFilterHelper.java
index 892a0ce..84ea0f3 100644
--- a/src/com/android/tradefed/util/TestFilterHelper.java
+++ b/src/com/android/tradefed/util/TestFilterHelper.java
@@ -253,35 +253,40 @@
     public boolean shouldRun(Description desc, List<File> extraJars) {
         // We need to build the packageName for a description object
         Class<?> classObj = null;
+        URLClassLoader cl = null;
         try {
-            List<URL> urlList = new ArrayList<>();
-            for (File f : extraJars) {
-                urlList.add(f.toURI().toURL());
+            try {
+                List<URL> urlList = new ArrayList<>();
+                for (File f : extraJars) {
+                    urlList.add(f.toURI().toURL());
+                }
+                cl = URLClassLoader.newInstance(urlList.toArray(new URL[0]));
+                classObj = cl.loadClass(desc.getClassName());
+            } catch (MalformedURLException | ClassNotFoundException e) {
+                throw new IllegalArgumentException(
+                        String.format("Could not load Test class %s", classObj), e);
             }
-            URLClassLoader cl = URLClassLoader.newInstance(urlList.toArray(new URL[0]));
-            classObj = cl.loadClass(desc.getClassName());
-        } catch (MalformedURLException | ClassNotFoundException e) {
-            throw new IllegalArgumentException(String.format("Could not load Test class %s",
-                    classObj), e);
-        }
-        // If class is explicitly annotated to be excluded, exclude it.
-        if (isExcluded(Arrays.asList(classObj.getAnnotations()))) {
-            return false;
-        }
-        String packageName = classObj.getPackage().getName();
 
-        String className = desc.getClassName();
-        String methodName = String.format("%s#%s", className, desc.getMethodName());
-        if (!shouldRunFilter(packageName, className, methodName)) {
-            return false;
+            // If class is explicitly annotated to be excluded, exclude it.
+            if (isExcluded(Arrays.asList(classObj.getAnnotations()))) {
+                return false;
+            }
+            String packageName = classObj.getPackage().getName();
+            String className = desc.getClassName();
+            String methodName = String.format("%s#%s", className, desc.getMethodName());
+            if (!shouldRunFilter(packageName, className, methodName)) {
+                return false;
+            }
+            if (!shouldTestRun(desc)) {
+                return false;
+            }
+            return mIncludeFilters.isEmpty()
+                    || mIncludeFilters.contains(methodName)
+                    || mIncludeFilters.contains(className)
+                    || mIncludeFilters.contains(packageName);
+        } finally {
+            StreamUtil.close(cl);
         }
-        if (!shouldTestRun(desc)) {
-            return false;
-        }
-        return mIncludeFilters.isEmpty()
-                || mIncludeFilters.contains(methodName)
-                || mIncludeFilters.contains(className)
-                || mIncludeFilters.contains(packageName);
     }
 
     /**
diff --git a/src/com/android/tradefed/util/TestLoader.java b/src/com/android/tradefed/util/TestLoader.java
index 848d2cd..962cc9f 100644
--- a/src/com/android/tradefed/util/TestLoader.java
+++ b/src/com/android/tradefed/util/TestLoader.java
@@ -55,8 +55,12 @@
             Set<String> classNames =
                     scanner.getEntriesFromJar(testJarFile, new ExternalClassNameFilter()).keySet();
 
-            ClassLoader jarClassLoader = buildJarClassLoader(testJarFile, dependentJars);
-            return loadTests(classNames, jarClassLoader);
+            URLClassLoader jarClassLoader = buildJarClassLoader(testJarFile, dependentJars);
+            try {
+                return loadTests(classNames, jarClassLoader);
+            } finally {
+                jarClassLoader.close();
+            }
         } catch (IOException e) {
             Log.e(LOG_TAG, String.format("IOException when loading test classes from jar %s",
                     testJarFile.getAbsolutePath()));
@@ -65,7 +69,7 @@
         return null;
     }
 
-    private ClassLoader buildJarClassLoader(File jarFile, Collection<File> dependentJars)
+    private URLClassLoader buildJarClassLoader(File jarFile, Collection<File> dependentJars)
             throws MalformedURLException {
         URL[] urls = new URL[dependentJars.size() + 1];
         urls[0] = jarFile.toURI().toURL();
diff --git a/test_framework/com/android/tradefed/testtype/HostTest.java b/test_framework/com/android/tradefed/testtype/HostTest.java
index f9aad4b..476c43b 100644
--- a/test_framework/com/android/tradefed/testtype/HostTest.java
+++ b/test_framework/com/android/tradefed/testtype/HostTest.java
@@ -181,6 +181,8 @@
     private boolean mSkipTestClassCheck = false;
 
     private List<Object> mTestMethods;
+    private List<Class<?>> mLoadedClasses = new ArrayList<>();
+    private List<URLClassLoader> mOpenClassLoaders = new ArrayList<>();
 
     // Initialized as -1 to indicate that this value needs to be recalculated
     // when test count is requested.
@@ -508,38 +510,46 @@
         mFilterHelper.addAllExcludeAnnotation(mExcludeAnnotations);
 
         try {
-            List<Class<?>> classes = getClasses();
-            if (!mSkipTestClassCheck) {
-                if (classes.isEmpty()) {
-                    throw new IllegalArgumentException("No '--class' option was specified.");
+            try {
+                List<Class<?>> classes = getClasses();
+                if (!mSkipTestClassCheck) {
+                    if (classes.isEmpty()) {
+                        throw new IllegalArgumentException("No '--class' option was specified.");
+                    }
                 }
+                if (mMethodName != null && classes.size() > 1) {
+                    throw new IllegalArgumentException(
+                            String.format(
+                                    "'--method' only supports one '--class' name. Multiple were "
+                                            + "given: '%s'",
+                                    classes));
+                }
+            } catch (IllegalArgumentException e) {
+                listener.testRunStarted(this.getClass().getCanonicalName(), 0);
+                FailureDescription failureDescription =
+                        FailureDescription.create(StreamUtil.getStackTrace(e));
+                failureDescription.setFailureStatus(FailureStatus.TEST_FAILURE);
+                listener.testRunFailed(failureDescription);
+                listener.testRunEnded(0L, new HashMap<String, Metric>());
+                throw e;
             }
-            if (mMethodName != null && classes.size() > 1) {
-                throw new IllegalArgumentException(
-                        String.format(
-                                "'--method' only supports one '--class' name. Multiple were "
-                                        + "given: '%s'",
-                                classes));
-            }
-        } catch (IllegalArgumentException e) {
-            listener.testRunStarted(this.getClass().getCanonicalName(), 0);
-            FailureDescription failureDescription =
-                    FailureDescription.create(StreamUtil.getStackTrace(e));
-            failureDescription.setFailureStatus(FailureStatus.TEST_FAILURE);
-            listener.testRunFailed(failureDescription);
-            listener.testRunEnded(0L, new HashMap<String, Metric>());
-            throw e;
-        }
 
-        // Add a pretty logger to the events to mark clearly start/end of test cases.
-        if (mEnableHostDeviceLogs) {
-            PrettyTestEventLogger logger = new PrettyTestEventLogger(mTestInfo.getDevices());
-            listener = new ResultForwarder(logger, listener);
-        }
-        if (mTestMethods != null) {
-            runTestCases(listener);
-        } else {
-            runTestClasses(listener);
+            // Add a pretty logger to the events to mark clearly start/end of test cases.
+            if (mEnableHostDeviceLogs) {
+                PrettyTestEventLogger logger = new PrettyTestEventLogger(mTestInfo.getDevices());
+                listener = new ResultForwarder(logger, listener);
+            }
+            if (mTestMethods != null) {
+                runTestCases(listener);
+            } else {
+                runTestClasses(listener);
+            }
+        } finally {
+            mLoadedClasses.clear();
+            for (URLClassLoader cl : mOpenClassLoaders) {
+                StreamUtil.close(cl);
+            }
+            mOpenClassLoaders.clear();
         }
     }
 
@@ -885,9 +895,12 @@
     }
 
     protected final List<Class<?>> getClasses() throws IllegalArgumentException {
+        if (!mLoadedClasses.isEmpty()) {
+            return mLoadedClasses;
+        }
         // Use a set to avoid repeat between filters and jar search
         Set<String> classNames = new HashSet<>();
-        List<Class<?>> classes = new ArrayList<>();
+        List<Class<?>> classes = mLoadedClasses;
         for (String className : mClasses) {
             if (classNames.contains(className)) {
                 continue;
@@ -910,18 +923,21 @@
                                 .getUniqueMap()
                                 .get(ModuleDefinition.MODULE_NAME);
                 if (moduleName != null) {
+                    URLClassLoader cl = null;
                     try {
                         File f = getJarFile(moduleName + ".jar", mTestInfo);
                         URL[] urls = {f.toURI().toURL()};
-                        URLClassLoader cl = URLClassLoader.newInstance(urls);
+                        cl = URLClassLoader.newInstance(urls);
                         mJUnit4JarFiles.add(f);
                         Class<?> cls = cl.loadClass(className);
                         classes.add(cls);
                         classNames.add(className);
                         initialError = null;
+                        mOpenClassLoaders.add(cl);
                     } catch (FileNotFoundException
                             | MalformedURLException
                             | ClassNotFoundException fallbackSearch) {
+                        StreamUtil.close(cl);
                         CLog.e(
                                 "Fallback search for a jar containing '%s' didn't work."
                                         + "Consider using --jar option directly instead of using --class",
@@ -933,6 +949,7 @@
                 throw initialError;
             }
         }
+        URLClassLoader cl = null;
         // Inspect for the jar files
         for (String jarName : mJars) {
             JarFile jarFile = null;
@@ -941,8 +958,9 @@
                 jarFile = new JarFile(file);
                 Enumeration<JarEntry> e = jarFile.entries();
                 URL[] urls = {file.toURI().toURL()};
-                URLClassLoader cl = URLClassLoader.newInstance(urls);
+                cl = URLClassLoader.newInstance(urls);
                 mJUnit4JarFiles.add(file);
+                mOpenClassLoaders.add(cl);
 
                 while (e.hasMoreElements()) {
                     JarEntry je = e.nextElement();
@@ -1191,56 +1209,63 @@
         }
         mTestInfo = testInfo;
         List<IRemoteTest> listTests = new ArrayList<>();
-        List<Class<?>> classes = getClasses();
-        if (classes.isEmpty()) {
-            throw new IllegalArgumentException("Missing Test class name");
-        }
-        if (mMethodName != null && classes.size() > 1) {
-            throw new IllegalArgumentException("Method name given with multiple test classes");
-        }
-        List<? extends Object> testObjects;
-        if (shardUnitIsMethod()) {
-            testObjects = getTestMethods();
-        } else {
-            testObjects = classes;
-            // ignore shardCount when shard unit is class;
-            // simply shard by the number of classes
-            shardCount = testObjects.size();
-        }
-        if (testObjects.size() == 1) {
-            return null;
-        }
-        int i = 0;
-        int numTotalTestCases = countTestCases();
-        for (Object testObj : testObjects) {
-            Class<?> classObj = Class.class.isInstance(testObj) ? (Class<?>)testObj : null;
-            HostTest test;
-            if (i >= listTests.size()) {
-                test = createHostTest(classObj);
-                test.mRuntimeHint = 0;
-                // Carry over non-annotation filters to shards.
-                test.addAllExcludeFilters(mFilterHelper.getExcludeFilters());
-                test.addAllIncludeFilters(mFilterHelper.getIncludeFilters());
-                listTests.add(test);
+        try {
+            List<Class<?>> classes = getClasses();
+            if (classes.isEmpty()) {
+                throw new IllegalArgumentException("Missing Test class name");
             }
-            test = (HostTest) listTests.get(i);
-            Collection<? extends Object> subTests;
-            if (classObj != null) {
-                test.addClassName(classObj.getName());
-                subTests = test.mClasses;
+            if (mMethodName != null && classes.size() > 1) {
+                throw new IllegalArgumentException("Method name given with multiple test classes");
+            }
+            List<? extends Object> testObjects;
+            if (shardUnitIsMethod()) {
+                testObjects = getTestMethods();
             } else {
-                test.addTestMethod(testObj);
-                subTests = test.mTestMethods;
+                testObjects = classes;
+                // ignore shardCount when shard unit is class;
+                // simply shard by the number of classes
+                shardCount = testObjects.size();
             }
-            if (numTotalTestCases == 0) {
-                // In case there is no tests left
-                test.mRuntimeHint = 0L;
-            } else {
-                test.mRuntimeHint = mRuntimeHint * subTests.size() / numTotalTestCases;
+            if (testObjects.size() == 1) {
+                return null;
             }
-            i = (i + 1) % shardCount;
+            int i = 0;
+            int numTotalTestCases = countTestCases();
+            for (Object testObj : testObjects) {
+                Class<?> classObj = Class.class.isInstance(testObj) ? (Class<?>) testObj : null;
+                HostTest test;
+                if (i >= listTests.size()) {
+                    test = createHostTest(classObj);
+                    test.mRuntimeHint = 0;
+                    // Carry over non-annotation filters to shards.
+                    test.addAllExcludeFilters(mFilterHelper.getExcludeFilters());
+                    test.addAllIncludeFilters(mFilterHelper.getIncludeFilters());
+                    listTests.add(test);
+                }
+                test = (HostTest) listTests.get(i);
+                Collection<? extends Object> subTests;
+                if (classObj != null) {
+                    test.addClassName(classObj.getName());
+                    subTests = test.mClasses;
+                } else {
+                    test.addTestMethod(testObj);
+                    subTests = test.mTestMethods;
+                }
+                if (numTotalTestCases == 0) {
+                    // In case there is no tests left
+                    test.mRuntimeHint = 0L;
+                } else {
+                    test.mRuntimeHint = mRuntimeHint * subTests.size() / numTotalTestCases;
+                }
+                i = (i + 1) % shardCount;
+            }
+        } finally {
+            mLoadedClasses.clear();
+            for (URLClassLoader cl : mOpenClassLoaders) {
+                StreamUtil.close(cl);
+            }
+            mOpenClassLoaders.clear();
         }
-
         return listTests;
     }