Merge "Fix flaky invalidation test" into oc-mr1-support-27.0-dev
diff --git a/room/integration-tests/testapp/src/androidTest/java/android/arch/persistence/room/integration/testapp/test/InvalidationTest.java b/room/integration-tests/testapp/src/androidTest/java/android/arch/persistence/room/integration/testapp/test/InvalidationTest.java
index 84f20ec..33f4018 100644
--- a/room/integration-tests/testapp/src/androidTest/java/android/arch/persistence/room/integration/testapp/test/InvalidationTest.java
+++ b/room/integration-tests/testapp/src/androidTest/java/android/arch/persistence/room/integration/testapp/test/InvalidationTest.java
@@ -17,20 +17,17 @@
 package android.arch.persistence.room.integration.testapp.test;
 
 import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
 
-import android.arch.core.executor.ArchTaskExecutor;
-import android.arch.core.executor.TaskExecutor;
+import android.arch.core.executor.testing.CountingTaskExecutorRule;
 import android.arch.persistence.room.InvalidationTracker;
 import android.arch.persistence.room.Room;
 import android.arch.persistence.room.integration.testapp.TestDatabase;
 import android.arch.persistence.room.integration.testapp.dao.UserDao;
 import android.arch.persistence.room.integration.testapp.vo.User;
 import android.content.Context;
-import android.os.Handler;
-import android.os.Looper;
 import android.support.annotation.NonNull;
 import android.support.test.InstrumentationRegistry;
 import android.support.test.filters.SmallTest;
@@ -38,17 +35,13 @@
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
 import java.util.Set;
-import java.util.concurrent.Callable;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 /**
  * Tests invalidation tracking.
@@ -56,138 +49,97 @@
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class InvalidationTest {
+    @Rule
+    public CountingTaskExecutorRule executorRule = new CountingTaskExecutorRule();
     private UserDao mUserDao;
     private TestDatabase mDb;
 
     @Before
-    public void createDb() {
+    public void createDb() throws TimeoutException, InterruptedException {
         Context context = InstrumentationRegistry.getTargetContext();
         mDb = Room.inMemoryDatabaseBuilder(context, TestDatabase.class).build();
         mUserDao = mDb.getUserDao();
-    }
-
-    @Before
-    public void setSingleThreadedIO() {
-        ArchTaskExecutor.getInstance().setDelegate(new TaskExecutor() {
-            ExecutorService mIOExecutor = Executors.newSingleThreadExecutor();
-            Handler mHandler = new Handler(Looper.getMainLooper());
-
-            @Override
-            public void executeOnDiskIO(Runnable runnable) {
-                mIOExecutor.execute(runnable);
-            }
-
-            @Override
-            public void postToMainThread(Runnable runnable) {
-                mHandler.post(runnable);
-            }
-
-            @Override
-            public boolean isMainThread() {
-                return Thread.currentThread() == Looper.getMainLooper().getThread();
-            }
-        });
+        drain();
     }
 
     @After
-    public void clearExecutor() {
-        ArchTaskExecutor.getInstance().setDelegate(null);
+    public void closeDb() throws TimeoutException, InterruptedException {
+        mDb.close();
+        drain();
     }
 
-    private void waitUntilIOThreadIsIdle() {
-        FutureTask<Void> future = new FutureTask<>(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                return null;
-            }
-        });
-        ArchTaskExecutor.getInstance().executeOnDiskIO(future);
-        //noinspection TryWithIdenticalCatches
-        try {
-            future.get();
-        } catch (InterruptedException e) {
-            throw new RuntimeException(e);
-        } catch (ExecutionException e) {
-            throw new RuntimeException(e);
-        }
+    private void drain() throws TimeoutException, InterruptedException {
+        executorRule.drainTasks(1, TimeUnit.MINUTES);
     }
 
     @Test
-    public void testInvalidationOnUpdate() throws InterruptedException {
+    public void testInvalidationOnUpdate() throws InterruptedException, TimeoutException {
         User user = TestUtil.createUser(3);
         mUserDao.insert(user);
-        LatchObserver observer = new LatchObserver(1, "User");
+        LoggingObserver observer = new LoggingObserver("User");
         mDb.getInvalidationTracker().addObserver(observer);
+        drain();
         mUserDao.updateById(3, "foo2");
-        waitUntilIOThreadIsIdle();
-        assertThat(observer.await(), is(true));
+        drain();
         assertThat(observer.getInvalidatedTables(), hasSize(1));
         assertThat(observer.getInvalidatedTables(), hasItem("User"));
     }
 
     @Test
-    public void testInvalidationOnDelete() throws InterruptedException {
+    public void testInvalidationOnDelete() throws InterruptedException, TimeoutException {
         User user = TestUtil.createUser(3);
         mUserDao.insert(user);
-        LatchObserver observer = new LatchObserver(1, "User");
+        LoggingObserver observer = new LoggingObserver("User");
         mDb.getInvalidationTracker().addObserver(observer);
+        drain();
         mUserDao.delete(user);
-        waitUntilIOThreadIsIdle();
-        assertThat(observer.await(), is(true));
+        drain();
         assertThat(observer.getInvalidatedTables(), hasSize(1));
         assertThat(observer.getInvalidatedTables(), hasItem("User"));
     }
 
     @Test
-    public void testInvalidationOnInsert() throws InterruptedException {
-        LatchObserver observer = new LatchObserver(1, "User");
+    public void testInvalidationOnInsert() throws InterruptedException, TimeoutException {
+        LoggingObserver observer = new LoggingObserver("User");
         mDb.getInvalidationTracker().addObserver(observer);
+        drain();
         mUserDao.insert(TestUtil.createUser(3));
-        waitUntilIOThreadIsIdle();
-        assertThat(observer.await(), is(true));
+        drain();
         assertThat(observer.getInvalidatedTables(), hasSize(1));
         assertThat(observer.getInvalidatedTables(), hasItem("User"));
     }
 
     @Test
-    public void testDontInvalidateOnLateInsert() throws InterruptedException {
-        LatchObserver observer = new LatchObserver(1, "User");
+    public void testDontInvalidateOnLateInsert() throws InterruptedException, TimeoutException {
+        LoggingObserver observer = new LoggingObserver("User");
         mUserDao.insert(TestUtil.createUser(3));
-        waitUntilIOThreadIsIdle();
+        drain();
         mDb.getInvalidationTracker().addObserver(observer);
-        waitUntilIOThreadIsIdle();
-        assertThat(observer.await(), is(false));
+        drain();
+        assertThat(observer.getInvalidatedTables(), nullValue());
     }
 
     @Test
-    public void testMultipleTables() throws InterruptedException {
-        LatchObserver observer = new LatchObserver(1, "User", "Pet");
+    public void testMultipleTables() throws InterruptedException, TimeoutException {
+        LoggingObserver observer = new LoggingObserver("User", "Pet");
         mDb.getInvalidationTracker().addObserver(observer);
+        drain();
         mUserDao.insert(TestUtil.createUser(3));
-        waitUntilIOThreadIsIdle();
-        assertThat(observer.await(), is(true));
+        drain();
         assertThat(observer.getInvalidatedTables(), hasSize(1));
         assertThat(observer.getInvalidatedTables(), hasItem("User"));
     }
 
-    private static class LatchObserver extends InvalidationTracker.Observer {
-        CountDownLatch mLatch;
-
+    private static class LoggingObserver extends InvalidationTracker.Observer {
         private Set<String> mInvalidatedTables;
 
-        LatchObserver(int permits, String... tables) {
+        LoggingObserver(String... tables) {
             super(tables);
-            mLatch = new CountDownLatch(permits);
-        }
-
-        boolean await() throws InterruptedException {
-            return mLatch.await(5, TimeUnit.SECONDS);
         }
 
         @Override
         public void onInvalidated(@NonNull Set<String> tables) {
             mInvalidatedTables = tables;
-            mLatch.countDown();
         }
 
         Set<String> getInvalidatedTables() {