Merge "Remove BroadcastReceiver ref on unregister"
diff --git a/packages/SystemUI/src/com/android/systemui/broadcast/UserBroadcastDispatcher.kt b/packages/SystemUI/src/com/android/systemui/broadcast/UserBroadcastDispatcher.kt
index d50666c..b2942bb 100644
--- a/packages/SystemUI/src/com/android/systemui/broadcast/UserBroadcastDispatcher.kt
+++ b/packages/SystemUI/src/com/android/systemui/broadcast/UserBroadcastDispatcher.kt
@@ -28,6 +28,7 @@
 import android.util.ArraySet
 import android.util.Log
 import androidx.annotation.MainThread
+import androidx.annotation.VisibleForTesting
 import com.android.internal.util.Preconditions
 import com.android.systemui.Dumpable
 import java.io.FileDescriptor
@@ -78,6 +79,14 @@
     private val actionsToReceivers = ArrayMap<String, MutableSet<ReceiverData>>()
     private val receiverToReceiverData = ArrayMap<BroadcastReceiver, MutableSet<ReceiverData>>()
 
+    /** True if [receiverToReceiverData] or [actionsToReceivers] still reference [receiver]. */
+    @VisibleForTesting
+    internal fun isReceiverReferenceHeld(receiver: BroadcastReceiver): Boolean =
+            receiverToReceiverData.containsKey(receiver) ||
+                    actionsToReceivers.values.any { receivers ->
+                        receivers.any { data -> data.receiver == receiver }
+                    }
+
     // Only call on BG thread as it reads from the maps
     private fun createFilter(): IntentFilter {
         Preconditions.checkState(bgHandler.looper.isCurrentThread,
@@ -142,7 +151,7 @@
         if (DEBUG) Log.w(TAG, "Unregister receiver: $receiver")
         val actions = receiverToReceiverData.getOrElse(receiver) { return }
                 .flatMap { it.filter.actionsIterator().asSequence().asIterable() }.toSet()
-        receiverToReceiverData.get(receiver)?.clear()
+        receiverToReceiverData.remove(receiver)?.clear()
         var changed = false
         actions.forEach { action ->
             actionsToReceivers.get(action)?.removeIf { it.receiver == receiver }
diff --git a/packages/SystemUI/tests/src/com/android/systemui/broadcast/UserBroadcastDispatcherTest.kt b/packages/SystemUI/tests/src/com/android/systemui/broadcast/UserBroadcastDispatcherTest.kt
index e838d9e..21ed155 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/broadcast/UserBroadcastDispatcherTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/broadcast/UserBroadcastDispatcherTest.kt
@@ -77,7 +77,7 @@
     private lateinit var argumentCaptor: ArgumentCaptor<IntentFilter>
 
     private lateinit var testableLooper: TestableLooper
-    private lateinit var universalBroadcastReceiver: UserBroadcastDispatcher
+    private lateinit var userBroadcastDispatcher: UserBroadcastDispatcher
     private lateinit var intentFilter: IntentFilter
     private lateinit var intentFilterOther: IntentFilter
     private lateinit var handler: Handler
@@ -88,9 +88,9 @@
         testableLooper = TestableLooper.get(this)
         handler = Handler(testableLooper.looper)
 
-        universalBroadcastReceiver = UserBroadcastDispatcher(
+        userBroadcastDispatcher = UserBroadcastDispatcher(
                 mockContext, USER_ID, handler, testableLooper.looper)
-        universalBroadcastReceiver.pendingResult = mPendingResult
+        userBroadcastDispatcher.pendingResult = mPendingResult
     }
 
     @Test
@@ -107,11 +107,11 @@
     fun testSingleReceiverRegistered() {
         intentFilter = IntentFilter(ACTION_1)
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, mockHandler, USER_HANDLE))
         testableLooper.processAllMessages()
 
-        assertTrue(universalBroadcastReceiver.isRegistered())
+        assertTrue(userBroadcastDispatcher.isRegistered())
         verify(mockContext).registerReceiverAsUser(
                 any(),
                 eq(USER_HANDLE),
@@ -127,19 +127,19 @@
     fun testSingleReceiverUnregistered() {
         intentFilter = IntentFilter(ACTION_1)
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, mockHandler, USER_HANDLE))
         testableLooper.processAllMessages()
         reset(mockContext)
 
-        assertTrue(universalBroadcastReceiver.isRegistered())
+        assertTrue(userBroadcastDispatcher.isRegistered())
 
-        universalBroadcastReceiver.unregisterReceiver(broadcastReceiver)
+        userBroadcastDispatcher.unregisterReceiver(broadcastReceiver)
         testableLooper.processAllMessages()
 
         verify(mockContext, atLeastOnce()).unregisterReceiver(any())
         verify(mockContext, never()).registerReceiverAsUser(any(), any(), any(), any(), any())
-        assertFalse(universalBroadcastReceiver.isRegistered())
+        assertFalse(userBroadcastDispatcher.isRegistered())
     }
 
     @Test
@@ -150,13 +150,13 @@
             addCategory(CATEGORY_2)
         }
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, mockHandler, USER_HANDLE))
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiverOther, intentFilterOther, mockHandler, USER_HANDLE))
 
         testableLooper.processAllMessages()
-        assertTrue(universalBroadcastReceiver.isRegistered())
+        assertTrue(userBroadcastDispatcher.isRegistered())
 
         verify(mockContext, times(2)).registerReceiverAsUser(
                 any(),
@@ -178,14 +178,14 @@
         intentFilter = IntentFilter(ACTION_1)
         intentFilterOther = IntentFilter(ACTION_2)
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, handler, USER_HANDLE))
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiverOther, intentFilterOther, handler, USER_HANDLE))
 
         val intent = Intent(ACTION_2)
 
-        universalBroadcastReceiver.onReceive(mockContext, intent)
+        userBroadcastDispatcher.onReceive(mockContext, intent)
         testableLooper.processAllMessages()
 
         verify(broadcastReceiver, never()).onReceive(any(), any())
@@ -197,14 +197,14 @@
         intentFilter = IntentFilter(ACTION_1)
         intentFilterOther = IntentFilter(ACTION_2)
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, handler, USER_HANDLE))
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilterOther, handler, USER_HANDLE))
 
         val intent = Intent(ACTION_2)
 
-        universalBroadcastReceiver.onReceive(mockContext, intent)
+        userBroadcastDispatcher.onReceive(mockContext, intent)
         testableLooper.processAllMessages()
 
         verify(broadcastReceiver).onReceive(mockContext, intent)
@@ -217,14 +217,14 @@
         intentFilterOther = IntentFilter(ACTION_1)
         intentFilterOther.addCategory(CATEGORY_2)
 
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, handler, USER_HANDLE))
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiverOther, intentFilterOther, handler, USER_HANDLE))
 
         val intent = Intent(ACTION_1)
 
-        universalBroadcastReceiver.onReceive(mockContext, intent)
+        userBroadcastDispatcher.onReceive(mockContext, intent)
         testableLooper.processAllMessages()
 
         verify(broadcastReceiver).onReceive(mockContext, intent)
@@ -234,15 +234,32 @@
     @Test
     fun testPendingResult() {
         intentFilter = IntentFilter(ACTION_1)
-        universalBroadcastReceiver.registerReceiver(
+        userBroadcastDispatcher.registerReceiver(
                 ReceiverData(broadcastReceiver, intentFilter, handler, USER_HANDLE))
 
         val intent = Intent(ACTION_1)
-        universalBroadcastReceiver.onReceive(mockContext, intent)
+        userBroadcastDispatcher.onReceive(mockContext, intent)
 
         testableLooper.processAllMessages()
 
         verify(broadcastReceiver).onReceive(mockContext, intent)
         verify(broadcastReceiver).pendingResult = mPendingResult
     }
+
+    @Test
+    fun testRemoveReceiverReferences() {
+        intentFilter = IntentFilter(ACTION_1)
+        userBroadcastDispatcher.registerReceiver(
+                ReceiverData(broadcastReceiver, intentFilter, handler, USER_HANDLE))
+        intentFilterOther = IntentFilter(ACTION_1).apply { addAction(ACTION_2) }
+        userBroadcastDispatcher.registerReceiver(
+                ReceiverData(broadcastReceiverOther, intentFilterOther, handler, USER_HANDLE))
+        testableLooper.processAllMessages()
+        assertTrue(userBroadcastDispatcher.isReceiverReferenceHeld(broadcastReceiver))
+        // Unregistering one receiver must drop all of its references but keep the other's.
+        userBroadcastDispatcher.unregisterReceiver(broadcastReceiver)
+        testableLooper.processAllMessages()
+        assertFalse(userBroadcastDispatcher.isReceiverReferenceHeld(broadcastReceiver))
+        assertTrue(userBroadcastDispatcher.isReceiverReferenceHeld(broadcastReceiverOther))
+    }
 }