Handle race conditions in SCS when statsd dies

This CL aims to fix two race conditions:

1. When statsd restarts after a crash, the ordering of sayHiToStatsd and
binderDied is not guaranteed. Previously, however, we assumed that
binderDied would always be called first and reset sStatsd to null. To
fix this, we no longer assume an ordering: sayHiToStatsd now logs an
error only if sStatsd is non-null and its binder is still alive.

2. When we linked to statsd's death, the death recipient had not yet
been told about the broadcast receivers. Thus, the death recipient might
have known only a partial list of receivers when #binderDied was
triggered. To fix this, we construct the death recipient with the full
list of receivers before we link to death.

Test: atest statsd_test
Test: atest CtsStatsdHostTestCases
Bug: 154275510

Change-Id: I11be65ca2135cde200ab8ecb611a363d8f7c2eb6
diff --git a/apex/service/java/com/android/server/stats/StatsCompanionService.java b/apex/service/java/com/android/server/stats/StatsCompanionService.java
index 93e6c10..5cf5e0b 100644
--- a/apex/service/java/com/android/server/stats/StatsCompanionService.java
+++ b/apex/service/java/com/android/server/stats/StatsCompanionService.java
@@ -54,11 +54,11 @@
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.PrintWriter;
-import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Helper service for statsd (the native stats management service in cmds/statsd/).
@@ -112,17 +112,8 @@
     private final HashMap<Long, String> mDeletedFiles = new HashMap<>();
     private final CompanionHandler mHandler;
 
-    // Flag that is set when PHASE_BOOT_COMPLETED is triggered in the StatsCompanion lifecycle. This
-    // and the flag mSentBootComplete below is used for synchronization to ensure that the boot
-    // complete signal is only ever sent once to statsd. Two signals are needed because
-    // #sayHiToStatsd can be called from both statsd and #onBootPhase
-    // PHASE_THIRD_PARTY_APPS_CAN_START.
-    @GuardedBy("sStatsdLock")
-    private boolean mBootCompleted = false;
-    // Flag that is set when IStatsd#bootCompleted is called. This flag ensures that boot complete
-    // signal is only ever sent once.
-    @GuardedBy("sStatsdLock")
-    private boolean mSentBootComplete = false;
+    // Flag that is set when PHASE_BOOT_COMPLETED is triggered in the StatsCompanion lifecycle.
+    private AtomicBoolean mBootCompleted = new AtomicBoolean(false);
 
     public StatsCompanionService(Context context) {
         super();
@@ -607,27 +598,35 @@
     // Statsd related code
 
     /**
-     * Fetches the statsd IBinder service. This is a blocking call.
+     * Fetches the statsd IBinder service. This is a blocking call that always refetches statsd
+     * instead of returning the cached sStatsd.
      * Note: This should only be called from {@link #sayHiToStatsd()}. All other clients should use
      * the cached sStatsd via {@link #getStatsdNonblocking()}.
      */
-    private IStatsd fetchStatsdService(StatsdDeathRecipient deathRecipient) {
-        synchronized (sStatsdLock) {
-            if (sStatsd == null) {
-                sStatsd = IStatsd.Stub.asInterface(StatsFrameworkInitializer
-                        .getStatsServiceManager()
-                        .getStatsdServiceRegisterer()
-                        .get());
-                if (sStatsd != null) {
-                    try {
-                        sStatsd.asBinder().linkToDeath(deathRecipient, /* flags */ 0);
-                    } catch (RemoteException e) {
-                        Log.e(TAG, "linkToDeath(StatsdDeathRecipient) failed");
-                        statsdNotReadyLocked();
-                    }
+    private IStatsd fetchStatsdServiceLocked() {
+        sStatsd = IStatsd.Stub.asInterface(StatsFrameworkInitializer
+                .getStatsServiceManager()
+                .getStatsdServiceRegisterer()
+                .get());
+        return sStatsd;
+    }
+
+    private void registerStatsdDeathRecipient(IStatsd statsd, List<BroadcastReceiver> receivers) {
+        StatsdDeathRecipient deathRecipient = new StatsdDeathRecipient(statsd, receivers);
+
+        try {
+            statsd.asBinder().linkToDeath(deathRecipient, /*flags=*/0);
+        } catch (RemoteException e) {
+            Log.e(TAG, "linkToDeath (StatsdDeathRecipient) failed");
+            // Statsd has already died. Unregister receivers ourselves.
+            for (BroadcastReceiver receiver : receivers) {
+                mContext.unregisterReceiver(receiver);
+            }
+            synchronized (sStatsdLock) {
+                if (statsd == sStatsd) {
+                    statsdNotReadyLocked();
                 }
             }
-            return sStatsd;
         }
     }
 
@@ -648,22 +647,23 @@
      * statsd.
      */
     private void sayHiToStatsd() {
-        if (getStatsdNonblocking() != null) {
-            Log.e(TAG, "Trying to fetch statsd, but it was already fetched",
-                    new IllegalStateException(
-                            "sStatsd is not null when being fetched"));
-            return;
+        IStatsd statsd;
+        synchronized (sStatsdLock) {
+            if (sStatsd != null && sStatsd.asBinder().isBinderAlive()) {
+                Log.e(TAG, "statsd has already been fetched before",
+                        new IllegalStateException("IStatsd object should be null or dead"));
+                return;
+            }
+            statsd = fetchStatsdServiceLocked();
         }
-        StatsdDeathRecipient deathRecipient = new StatsdDeathRecipient();
-        IStatsd statsd = fetchStatsdService(deathRecipient);
+
         if (statsd == null) {
-            Log.i(TAG,
-                    "Could not yet find statsd to tell it that StatsCompanion is "
-                            + "alive.");
+            Log.i(TAG, "Could not yet find statsd to tell it that StatsCompanion is alive.");
             return;
         }
-        mStatsManagerService.statsdReady(statsd);
+
         if (DEBUG) Log.d(TAG, "Saying hi to statsd");
+        mStatsManagerService.statsdReady(statsd);
         try {
             statsd.statsCompanionReady();
 
@@ -682,8 +682,7 @@
             mContext.registerReceiverForAllUsers(appUpdateReceiver, filter, null, null);
 
             // Setup receiver for user initialize (which happens once for a new user)
-            // and
-            // if a user is removed.
+            // and if a user is removed.
             filter = new IntentFilter(Intent.ACTION_USER_INITIALIZE);
             filter.addAction(Intent.ACTION_USER_REMOVED);
             mContext.registerReceiverForAllUsers(userUpdateReceiver, filter, null, null);
@@ -691,27 +690,20 @@
             // Setup receiver for device reboots or shutdowns.
             filter = new IntentFilter(Intent.ACTION_REBOOT);
             filter.addAction(Intent.ACTION_SHUTDOWN);
-            mContext.registerReceiverForAllUsers(
-                    shutdownEventReceiver, filter, null, null);
+            mContext.registerReceiverForAllUsers(shutdownEventReceiver, filter, null, null);
 
-            // Only add the receivers if the registration is successful.
-            deathRecipient.addRegisteredBroadcastReceivers(
-                    List.of(appUpdateReceiver, userUpdateReceiver, shutdownEventReceiver));
+            // Register death recipient.
+            List<BroadcastReceiver> broadcastReceivers =
+                    List.of(appUpdateReceiver, userUpdateReceiver, shutdownEventReceiver);
+            registerStatsdDeathRecipient(statsd, broadcastReceivers);
 
-            // Used so we can call statsd.bootComplete() outside of the lock.
-            boolean shouldSendBootComplete = false;
-            synchronized (sStatsdLock) {
-                if (mBootCompleted && !mSentBootComplete) {
-                    mSentBootComplete = true;
-                    shouldSendBootComplete = true;
-                }
-            }
-            if (shouldSendBootComplete) {
+            // Tell statsd that boot has completed. The signal may have already been sent, but since
+            // the signal-receiving function is idempotent, that's ok.
+            if (mBootCompleted.get()) {
                 statsd.bootCompleted();
             }
 
-            // Pull the latest state of UID->app name, version mapping when
-            // statsd starts.
+            // Pull the latest state of UID->app name, version mapping when statsd starts.
             informAllUids(mContext);
 
             Log.i(TAG, "Told statsd that StatsCompanionService is alive.");
@@ -722,18 +714,16 @@
 
     private class StatsdDeathRecipient implements IBinder.DeathRecipient {
 
-        private List<BroadcastReceiver> mReceiversToUnregister;
+        private final IStatsd mStatsd;
+        private final List<BroadcastReceiver> mReceiversToUnregister;
 
-        StatsdDeathRecipient() {
-            mReceiversToUnregister = new ArrayList<>();
+        StatsdDeathRecipient(IStatsd statsd, List<BroadcastReceiver> receivers) {
+            mStatsd = statsd;
+            mReceiversToUnregister = receivers;
         }
 
-        public void addRegisteredBroadcastReceivers(List<BroadcastReceiver> receivers) {
-            synchronized (sStatsdLock) {
-                mReceiversToUnregister.addAll(receivers);
-            }
-        }
-
+        // It is possible for binderDied to be called after a restarted statsd calls statsdReady,
+        // but that's alright because the code does not assume an ordering of the two calls.
         @Override
         public void binderDied() {
             Log.i(TAG, "Statsd is dead - erase all my knowledge, except pullers");
@@ -762,13 +752,19 @@
                         }
                     }
                 }
-                // We only unregister in binder death becaseu receivers can only be unregistered
-                // once, or an IllegalArgumentException is thrown.
+
+                // Unregister receivers on death because receivers can only be unregistered once.
+                // Otherwise, an IllegalArgumentException is thrown.
                 for (BroadcastReceiver receiver: mReceiversToUnregister) {
                     mContext.unregisterReceiver(receiver);
                 }
-                statsdNotReadyLocked();
-                mSentBootComplete = false;
+
+                // It's possible for statsd to have restarted and called statsdReady, causing a new
+                // sStatsd binder object to be fetched, before the binderDied callback runs. Only
+                // call #statsdNotReadyLocked if that hasn't happened yet.
+                if (mStatsd == sStatsd) {
+                    statsdNotReadyLocked();
+                }
             }
         }
     }
@@ -779,19 +775,12 @@
     }
 
     void bootCompleted() {
+        mBootCompleted.set(true);
         IStatsd statsd = getStatsdNonblocking();
-        synchronized (sStatsdLock) {
-            mBootCompleted = true;
-            if (mSentBootComplete) {
-                // do not send a boot complete a second time.
-                return;
-            }
-            if (statsd == null) {
-                // Statsd is not yet ready.
-                // Delay the boot completed ping to {@link #sayHiToStatsd()}
-                return;
-            }
-            mSentBootComplete = true;
+        if (statsd == null) {
+            // Statsd is not yet ready.
+            // Delay the boot completed ping to {@link #sayHiToStatsd()}
+            return;
         }
         try {
             statsd.bootCompleted();
@@ -808,8 +797,7 @@
         }
 
         synchronized (sStatsdLock) {
-            writer.println(
-                    "Number of configuration files deleted: " + mDeletedFiles.size());
+            writer.println("Number of configuration files deleted: " + mDeletedFiles.size());
             if (mDeletedFiles.size() > 0) {
                 writer.println("  timestamp, deleted file name");
             }
@@ -817,8 +805,7 @@
                     SystemClock.currentThreadTimeMillis() - SystemClock.elapsedRealtime();
             for (Long elapsedMillis : mDeletedFiles.keySet()) {
                 long deletionMillis = lastBootMillis + elapsedMillis;
-                writer.println(
-                        "  " + deletionMillis + ", " + mDeletedFiles.get(elapsedMillis));
+                writer.println("  " + deletionMillis + ", " + mDeletedFiles.get(elapsedMillis));
             }
         }
     }