Merge "Allow overriding flags in Robolectric tests" into ub-launcher3-master
diff --git a/robolectric_tests/src/com/android/launcher3/config/FlagOverrideRule.java b/robolectric_tests/src/com/android/launcher3/config/FlagOverrideRule.java
new file mode 100644
index 0000000..92bcc64
--- /dev/null
+++ b/robolectric_tests/src/com/android/launcher3/config/FlagOverrideRule.java
@@ -0,0 +1,138 @@
+package com.android.launcher3.config;
+
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+import org.robolectric.RuntimeEnvironment;
+
+import java.lang.annotation.Annotation;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Repeatable;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Test rule that makes overriding flags in Robolectric tests easier. Before and after each test,
+ * this rule resets every togglable flag to its default value, so overrides made by one test
+ * method cannot leak into subsequent methods.
+ *
+ * <p>Usage:
+ * <pre>
+ * {@literal @}Rule public final FlagOverrideRule flags = new FlagOverrideRule();
+ *
+ * {@literal @}FlagOverride(key = "FOO", value = true)
+ * {@literal @}Test public void myTest() {
+ *     ...
+ * }
+ * </pre>
+ *
+ * <p>The key must match a flag in {@code FeatureFlags.getTogglableFlags()}; an unknown key
+ * makes the test fail with an {@link IllegalStateException}.
+ */
+public final class FlagOverrideRule implements TestRule {
+
+    /**
+     * Container annotation for handling multiple {@link FlagOverride} annotations.
+     *
+     * <p>Don't use this directly, use repeated {@link FlagOverride} annotations instead.
+     */
+    @Retention(RetentionPolicy.RUNTIME)
+    @Target({ElementType.METHOD})
+    public @interface FlagOverrides {
+        FlagOverride[] value();
+    }
+
+    /**
+     * Overrides one togglable flag for the duration of the annotated test method. Repeat the
+     * annotation to override several flags.
+     */
+    @Retention(RetentionPolicy.RUNTIME)
+    @Target({ElementType.METHOD})
+    @Repeatable(FlagOverrides.class)
+    public @interface FlagOverride {
+        /** Key of the flag to override, matched against {@code TogglableFlag.getKey()}. */
+        String key();
+
+        /** Value the flag is set to while the test runs. */
+        boolean value();
+    }
+
+    /** True only while a test statement wrapped by this rule is running. */
+    private boolean ruleInProgress;
+
+    @Override
+    public Statement apply(Statement base, Description description) {
+        return new Statement() {
+            @Override
+            public void evaluate() throws Throwable {
+                FeatureFlags.initialize(RuntimeEnvironment.application.getApplicationContext());
+                ruleInProgress = true;
+                try {
+                    // Start from defaults, then layer this method's overrides on top.
+                    clearOverrides();
+                    applyAnnotationOverrides(description);
+                    base.evaluate();
+                } finally {
+                    // Restore defaults even if the test failed, so later tests are unaffected.
+                    ruleInProgress = false;
+                    clearOverrides();
+                }
+            }
+        };
+    }
+
+    /**
+     * Sets {@code flag} to {@code newValue}.
+     *
+     * @throws IllegalStateException if called while no test wrapped by this rule is running
+     */
+    private void override(BaseFlags.TogglableFlag flag, boolean newValue) {
+        if (!ruleInProgress) {
+            throw new IllegalStateException(
+                    "Rule isn't in progress. Did you remember to mark it with @Rule?");
+        }
+        flag.setForTests(newValue);
+    }
+
+    /**
+     * Applies every {@link FlagOverride} on the test method described by {@code description},
+     * whether it appears once or repeated (in which case the compiler wraps the repeats in a
+     * {@link FlagOverrides} container).
+     */
+    private void applyAnnotationOverrides(Description description) {
+        for (Annotation annotation : description.getAnnotations()) {
+            if (annotation.annotationType() == FlagOverride.class) {
+                applyAnnotation((FlagOverride) annotation);
+            } else if (annotation.annotationType() == FlagOverrides.class) {
+                for (FlagOverride flagOverride : ((FlagOverrides) annotation).value()) {
+                    applyAnnotation(flagOverride);
+                }
+            }
+        }
+    }
+
+    /**
+     * Overrides the togglable flag whose key equals {@code flagOverride.key()}.
+     *
+     * @throws IllegalStateException if no togglable flag has that key
+     */
+    private void applyAnnotation(FlagOverride flagOverride) {
+        for (BaseFlags.TogglableFlag flag : FeatureFlags.getTogglableFlags()) {
+            if (flag.getKey().equals(flagOverride.key())) {
+                override(flag, flagOverride.value());
+                return;
+            }
+        }
+        throw new IllegalStateException("Flag " + flagOverride.key() + " not found");
+    }
+
+    /**
+     * Resets all togglable flags to their default values.
+     */
+    private void clearOverrides() {
+        for (BaseFlags.TogglableFlag flag : FeatureFlags.getTogglableFlags()) {
+            flag.setForTests(flag.getDefaultValue());
+        }
+    }
+}
diff --git a/robolectric_tests/src/com/android/launcher3/config/FlagOverrideSampleTest.java b/robolectric_tests/src/com/android/launcher3/config/FlagOverrideSampleTest.java
new file mode 100644
index 0000000..c5a0820
--- /dev/null
+++ b/robolectric_tests/src/com/android/launcher3/config/FlagOverrideSampleTest.java
@@ -0,0 +1,42 @@
+package com.android.launcher3.config;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.android.launcher3.config.FlagOverrideRule.FlagOverride;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.robolectric.RobolectricTestRunner;
+
+/**
+ * Demonstrates how {@link FlagOverrideRule} lets a Robolectric test force flag values, either
+ * with a single {@link FlagOverride} or with several repeated ones.
+ */
+@RunWith(RobolectricTestRunner.class)
+public class FlagOverrideSampleTest {
+
+    /**
+     * Resets flags around every test and applies the method's {@link FlagOverride}s. See
+     * https://junit.org/junit4/javadoc/4.12/org/junit/Rule.html for background on rules.
+     */
+    @Rule
+    public final FlagOverrideRule flags = new FlagOverrideRule();
+
+    /** Two repeated overrides: one flag forced on, another forced off. */
+    @Test
+    @FlagOverride(key = "EXAMPLE_FLAG", value = true)
+    @FlagOverride(key = "QUICK_SWITCH", value = false)
+    public void withFlagOn() {
+        assertTrue(FeatureFlags.EXAMPLE_FLAG.get());
+        assertFalse(FeatureFlags.QUICK_SWITCH.get());
+    }
+
+    /** A single override forcing the flag off. */
+    @Test
+    @FlagOverride(key = "EXAMPLE_FLAG", value = false)
+    public void withFlagOff() {
+        assertFalse(FeatureFlags.EXAMPLE_FLAG.get());
+    }
+}
diff --git a/src/com/android/launcher3/config/BaseFlags.java b/src/com/android/launcher3/config/BaseFlags.java
index e5a8a01..b27ae31 100644
--- a/src/com/android/launcher3/config/BaseFlags.java
+++ b/src/com/android/launcher3/config/BaseFlags.java
@@ -31,6 +31,7 @@
 
 import androidx.annotation.GuardedBy;
 import androidx.annotation.Keep;
+import androidx.annotation.VisibleForTesting;
 
 /**
  * Defines a set of flags used to control various launcher behaviors.
@@ -148,7 +149,14 @@
             }
         }
 
-        String getKey() {
+        /** Test-only hook that replaces the flag's current value with {@code forcedValue}. */
+        @VisibleForTesting
+        void setForTests(boolean forcedValue) {
+            currentValue = forcedValue;
+        }
+
+        @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
+        public String getKey() { // FlagOverrideRule matches @FlagOverride keys against this
             return key;
         }