Snap for 7708232 from 912440cd7e915b2020af273373381d3f1c01a79d to sc-v2-release

Change-Id: I5266c147aa7257697be41981405b18683dbaf79f
diff --git a/TestParameterInjector.iml b/TestParameterInjector.iml
deleted file mode 100644
index 77d3a30..0000000
--- a/TestParameterInjector.iml
+++ /dev/null
@@ -1,17 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<module type="JAVA_MODULE" version="4">
-  <component name="NewModuleRootManager" inherit-compiler-output="true">
-    <exclude-output />
-    <content url="file://$MODULE_DIR$">
-      <sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
-      <sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
-    </content>
-    <orderEntry type="sourceFolder" forTests="false" />
-    <orderEntry type="module" module-name="framework_srcjars" />
-    <orderEntry type="module" module-name="base" />
-    <orderEntry type="module" module-name="modules-utils" />
-    <orderEntry type="module" module-name="Connectivity" />
-    <orderEntry type="module" module-name="dependencies" />
-    <orderEntry type="inheritedJdk" />
-  </component>
-</module>
\ No newline at end of file
diff --git a/src/main/java/com/google/testing/junit/testparameterinjector/PluggableTestRunner.java b/src/main/java/com/google/testing/junit/testparameterinjector/PluggableTestRunner.java
index 2c9a199..86fb534 100644
--- a/src/main/java/com/google/testing/junit/testparameterinjector/PluggableTestRunner.java
+++ b/src/main/java/com/google/testing/junit/testparameterinjector/PluggableTestRunner.java
@@ -83,16 +83,44 @@
 
   /**
    * If true, all test methods (across different TestMethodProcessors) will be sorted in a
-   * deterministic way by their test name.
+   * deterministic way (by name hash code, then by name) by the default {@link #sortTestMethods}.
    *
    * <p>Deterministic means that the order will not change, even when tests are added/removed or
    * between releases.
+   *
+   * @deprecated Override {@link #sortTestMethods} to choose a sorting strategy instead.
    */
+  @Deprecated
   protected boolean shouldSortTestMethodsDeterministically() {
     return false; // Don't sort methods by default
   }
 
   /**
+   * Sorts test methods (across different TestMethodProcessors).
+   *
+   * <p>The order should be deterministic: it should not change when tests are added/removed or
+   * between releases. The default sorts only if {@link #shouldSortTestMethodsDeterministically}.
+   */
+  protected Stream<FrameworkMethod> sortTestMethods(Stream<FrameworkMethod> methods) {
+    if (shouldSortTestMethodsDeterministically()) {
+      // String#hashCode is fully specified by the Java SE API, so this order is stable.
+      return methods.sorted(
+          comparing((FrameworkMethod method) -> method.getName().hashCode())
+              .thenComparing(FrameworkMethod::getName));
+    }
+    return methods;
+  }
+
+  /**
+   * Returns the annotation types that mark a method of the test class as a test method.
+   *
+   * <p>Defaults to {@link Test}. Override to support custom (runtime-retained) test annotations.
+   */
+  protected ImmutableList<Class<? extends Annotation>> getSupportedTestAnnotations() {
+    return ImmutableList.<Class<? extends Annotation>>of(Test.class);
+  }
+
+  /**
    * {@link TestRule}s that will be executed after the ones defined in the test class (but still
    * before all {@link MethodRule}s). This is meant to be overridden by subclasses.
    */
@@ -146,14 +174,11 @@
   @Override
   protected final ImmutableList<FrameworkMethod> computeTestMethods() {
     Stream<FrameworkMethod> processedMethods =
-        super.computeTestMethods().stream().flatMap(method -> processMethod(method).stream());
+        getSupportedTestAnnotations().stream() // One method may match several; dedupe below.
+            .flatMap(annotation -> getTestClass().getAnnotatedMethods(annotation).stream())
+            .distinct().flatMap(method -> processMethod(method).stream());
 
-    if (shouldSortTestMethodsDeterministically()) {
-      processedMethods =
-          processedMethods.sorted(
-              comparing((FrameworkMethod method) -> method.getName().hashCode())
-                  .thenComparing(FrameworkMethod::getName));
-    }
+    processedMethods = sortTestMethods(processedMethods);
 
     return processedMethods.collect(toImmutableList());
   }
@@ -324,7 +349,10 @@
 
   @Override
   protected final void validateTestMethods(List<Throwable> list) {
-    List<FrameworkMethod> testMethods = getTestClass().getAnnotatedMethods(Test.class);
+    List<FrameworkMethod> testMethods =
+            getSupportedTestAnnotations().stream()
+                    .flatMap(annotation -> getTestClass().getAnnotatedMethods(annotation).stream())
+                    .collect(Collectors.toList());
     for (FrameworkMethod testMethod : testMethods) {
       boolean isHandled = false;
       for (TestMethodProcessor testMethodProcessor : getTestMethodProcessors()) {
diff --git a/src/test/java/com/google/testing/junit/testparameterinjector/PluggableTestRunnerTest.java b/src/test/java/com/google/testing/junit/testparameterinjector/PluggableTestRunnerTest.java
index 686b152..13561ff 100644
--- a/src/test/java/com/google/testing/junit/testparameterinjector/PluggableTestRunnerTest.java
+++ b/src/test/java/com/google/testing/junit/testparameterinjector/PluggableTestRunnerTest.java
@@ -15,9 +15,16 @@
 package com.google.testing.junit.testparameterinjector;
 
 import static com.google.common.truth.Truth.assertThat;
+import static java.util.Comparator.comparing;
 
 import com.google.common.collect.ImmutableList;
+
+import java.lang.annotation.Annotation;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Stream;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.MethodRule;
@@ -30,8 +37,11 @@
 
 @RunWith(JUnit4.class)
 public class PluggableTestRunnerTest {
+  @Retention(RetentionPolicy.RUNTIME) // Visible to reflection, so JUnit can find it.
+  private @interface CustomTest {}
 
   private static int ruleInvocationCount = 0;
+  private static int testMethodInvocationCount = 0; // Bumped by CustomTestAnnotationTestClass.
 
   public static class TestAndMethodRule implements MethodRule, TestRule {
 
@@ -49,7 +59,7 @@
   }
 
   @RunWith(PluggableTestRunner.class)
-  public static class PluggableTestRunnerTestClass {
+  public static class TestAndMethodRuleTestClass {
 
     @Rule public TestAndMethodRule rule = new TestAndMethodRule();
 
@@ -62,7 +72,7 @@
   @Test
   public void ruleThatIsBothTestRuleAndMethodRuleIsInvokedOnceOnly() throws Exception {
     PluggableTestRunner.run(
-        new PluggableTestRunner(PluggableTestRunnerTestClass.class) {
+        new PluggableTestRunner(TestAndMethodRuleTestClass.class) {
           @Override
           protected List<TestMethodProcessor> createTestMethodProcessorList() {
             return ImmutableList.of();
@@ -71,4 +81,76 @@
 
     assertThat(ruleInvocationCount).isEqualTo(1);
   }
+
+  @RunWith(PluggableTestRunner.class)
+  public static class CustomTestAnnotationTestClass {
+    @CustomTest
+    @SuppressWarnings("JUnit4TestNotRun") // Only run when the runner supports @CustomTest.
+    public void customTestAnnotatedTest() {
+      testMethodInvocationCount += 1;
+    }
+
+    @Test
+    public void testAnnotatedTest() {
+      testMethodInvocationCount += 1;
+    }
+  }
+
+  @Test
+  public void testMarkedWithCustomClassIsInvoked() throws Exception {
+    testMethodInvocationCount = 0;
+    PluggableTestRunner.run(
+        new PluggableTestRunner(CustomTestAnnotationTestClass.class) {
+          @Override
+          protected List<TestMethodProcessor> createTestMethodProcessorList() {
+            return ImmutableList.of();
+          }
+
+          @Override
+          protected ImmutableList<Class<? extends Annotation>> getSupportedTestAnnotations() {
+            return ImmutableList.of(Test.class, CustomTest.class);
+          }
+        });
+    // One invocation from the @Test method plus one from the @CustomTest method.
+    assertThat(testMethodInvocationCount).isEqualTo(2);
+  }
+
+  private static final List<String> testOrder = new ArrayList<>(); // Names of tests, in run order.
+
+  @RunWith(PluggableTestRunner.class)
+  public static class SortedPluggableTestRunnerTestClass { // Each test logs its name to testOrder.
+    @Test
+    public void a() {
+      testOrder.add("a");
+    }
+
+    @Test
+    public void b() {
+      testOrder.add("b");
+    }
+
+    @Test
+    public void c() {
+      testOrder.add("c");
+    }
+  }
+
+  @Test
+  public void testsAreSortedCorrectly() throws Exception {
+    testOrder.clear(); // testOrder is static and shared, so start from a clean slate.
+    PluggableTestRunner.run(
+        new PluggableTestRunner(SortedPluggableTestRunnerTestClass.class) {
+          @Override
+          protected List<TestMethodProcessor> createTestMethodProcessorList() {
+            return ImmutableList.of();
+          }
+
+          @Override
+          protected Stream<FrameworkMethod> sortTestMethods(Stream<FrameworkMethod> methods) {
+            return methods.sorted(comparing(FrameworkMethod::getName).reversed());
+          }
+        });
+    assertThat(testOrder).containsExactly("c", "b", "a").inOrder();
+  }
+
 }