Move InvalidationTracker to SafeIterableMap

Test: refactoring CL.
Change-Id: I46c2acda1b5d50dada12a99fdefff91298807b9d
diff --git a/room/runtime/src/main/java/com/android/support/room/InvalidationTracker.java b/room/runtime/src/main/java/com/android/support/room/InvalidationTracker.java
index 49b2605..f1ad04f 100644
--- a/room/runtime/src/main/java/com/android/support/room/InvalidationTracker.java
+++ b/room/runtime/src/main/java/com/android/support/room/InvalidationTracker.java
@@ -24,13 +24,14 @@
 import android.support.v4.util.ArrayMap;
 import android.util.Log;
 
-import com.android.support.apptoolkit.internal.ObserverSet;
+import com.android.support.apptoolkit.internal.SafeIterableMap;
 import com.android.support.db.SupportSQLiteDatabase;
 import com.android.support.db.SupportSQLiteStatement;
 import com.android.support.executors.AppToolkitTaskExecutor;
 
 import java.lang.ref.WeakReference;
 import java.util.Arrays;
+import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -106,42 +107,14 @@
 
     private ObservedTableTracker mObservedTableTracker;
 
+    // should be accessed with synchronization only.
     @VisibleForTesting
-    SyncObserverSet<ObserverWrapper> mObserverSet;
-
-    private ObserverSet.Callback<ObserverWrapper> mInvalidCheck =
-            new ObserverSet.Callback<ObserverWrapper>() {
-
-                @Override
-                public void run(ObserverWrapper key) {
-                    key.checkForInvalidation(mTableVersions);
-                }
-            };
+    SafeIterableMap<Observer, ObserverWrapper> mObserverMap = new SafeIterableMap<>();
 
     @SuppressWarnings("WeakerAccess") // used by generated code.
     public InvalidationTracker(RoomDatabase database, String... tableNames) {
         mDatabase = database;
         mObservedTableTracker = new ObservedTableTracker(tableNames.length);
-        mObserverSet = new SyncObserverSet<ObserverWrapper>() {
-            @Override
-            protected boolean checkEquality(ObserverWrapper existing, ObserverWrapper added) {
-                return existing.mObserver == added.mObserver;
-            }
-
-            @Override
-            protected void onAdded(ObserverWrapper item) {
-                if (mObservedTableTracker.onAdded(item.mTableIds)) {
-                    AppToolkitTaskExecutor.getInstance().executeOnDiskIO(mSyncTriggers);
-                }
-            }
-
-            @Override
-            protected void onRemoved(ObserverWrapper item) {
-                if (mObservedTableTracker.onRemoved(item.mTableIds)) {
-                    AppToolkitTaskExecutor.getInstance().executeOnDiskIO(mSyncTriggers);
-                }
-            }
-        };
         mTableIdLookup = new ArrayMap<>();
         final int size = tableNames.length;
         mTableNames = new String[size];
@@ -245,7 +218,14 @@
             tableIds[i] = tableId;
             versions[i] = mMaxVersion;
         }
-        mObserverSet.add(new ObserverWrapper(observer, tableIds, versions));
+        ObserverWrapper wrapper = new ObserverWrapper(observer, tableIds, versions);
+        ObserverWrapper currentObserver;
+        synchronized (mObserverMap) {
+            currentObserver = mObserverMap.putIfAbsent(observer, wrapper);
+        }
+        if (currentObserver == null && mObservedTableTracker.onAdded(tableIds)) {
+            AppToolkitTaskExecutor.getInstance().executeOnDiskIO(mSyncTriggers);
+        }
     }
 
     /**
@@ -270,7 +250,13 @@
      */
     @SuppressWarnings("WeakerAccess")
     public void removeObserver(final Observer observer) {
-        mObserverSet.remove(new ObserverWrapper(observer, null, null));
+        ObserverWrapper wrapper;
+        synchronized (mObserverMap) {
+            wrapper = mObserverMap.remove(observer);
+        }
+        if (wrapper != null && mObservedTableTracker.onRemoved(wrapper.mTableIds)) {
+            AppToolkitTaskExecutor.getInstance().executeOnDiskIO(mSyncTriggers);
+        }
     }
 
     private Runnable mSyncTriggers = new Runnable() {
@@ -359,7 +345,11 @@
                 cursor.close();
             }
             if (hasUpdatedTable) {
-                mObserverSet.forEach(mInvalidCheck);
+                synchronized (mObserverMap) {
+                    for (Map.Entry<Observer, ObserverWrapper> entry : mObserverMap) {
+                        entry.getValue().checkForInvalidation(mTableVersions);
+                    }
+                }
             }
         }
     };
@@ -598,34 +588,4 @@
             }
         }
     }
-
-    /**
-     * Poor man's sync on observer set.
-     * <p>
-     * When we revisit observer set, we should consider making it thread safe.
-     *
-     * @param <T> The type of the items.
-     */
-    abstract static class SyncObserverSet<T> extends ObserverSet<T> {
-        @Override
-        public void add(T observer) {
-            synchronized (this) {
-                super.add(observer);
-            }
-        }
-
-        @Override
-        public void remove(T observer) {
-            synchronized (this) {
-                super.remove(observer);
-            }
-        }
-
-        @Override
-        public void forEach(Callback<T> func) {
-            synchronized (this) {
-                super.forEach(func);
-            }
-        }
-    }
 }
diff --git a/room/runtime/src/test/java/com/android/support/room/InvalidationTrackerTest.java b/room/runtime/src/test/java/com/android/support/room/InvalidationTrackerTest.java
index 3a3e52b..b77222c 100644
--- a/room/runtime/src/test/java/com/android/support/room/InvalidationTrackerTest.java
+++ b/room/runtime/src/test/java/com/android/support/room/InvalidationTrackerTest.java
@@ -83,13 +83,13 @@
         InvalidationTracker.Observer observer = new LatchObserver(1, "a");
         mTracker.addObserver(observer);
         drainTasks();
-        assertThat(mTracker.mObserverSet.size(), is(1));
+        assertThat(mTracker.mObserverMap.size(), is(1));
         mTracker.removeObserver(new LatchObserver(1, "a"));
         drainTasks();
-        assertThat(mTracker.mObserverSet.size(), is(1));
+        assertThat(mTracker.mObserverMap.size(), is(1));
         mTracker.removeObserver(observer);
         drainTasks();
-        assertThat(mTracker.mObserverSet.size(), is(0));
+        assertThat(mTracker.mObserverMap.size(), is(0));
     }
 
     private void drainTasks() throws InterruptedException {