[SP17] Wait for stats providers to report stats update

From current design, the traffic accounted by stats providers
will be updated asynchronously when force polling. When upper
layer make two subsequently queries. They will get stale
stats upon the first query, and may get newest/stale stats
base on the result of race.

Thus, wait for a bit of time to allow asynchronous stats update
complete to reduce the chance of race. In pratice, it would
be finished in ~2ms when testing.

Test: systrace.py network
Test: atest FrameworksNetTests
Bug: 147460444
Change-Id: I22a00fc4049cddf77fd578e25769ae1979f2cc6d
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index 415ccb8..1c52110 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -155,6 +155,8 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Collect and persist detailed network statistics, and provide this data to
@@ -280,8 +282,11 @@
     private final DropBoxNonMonotonicObserver mNonMonotonicObserver =
             new DropBoxNonMonotonicObserver();
 
+    private static final int MAX_STATS_PROVIDER_POLL_WAIT_TIME_MS = 100;
     private final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList =
             new RemoteCallbackList<>();
+    /** Semaphore used to wait for stats provider to respond to request stats update. */
+    private final Semaphore mStatsProviderSem = new Semaphore(0, true);
 
     @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mDevRecorder;
@@ -1337,6 +1342,25 @@
         final boolean persistUid = (flags & FLAG_PERSIST_UID) != 0;
         final boolean persistForce = (flags & FLAG_PERSIST_FORCE) != 0;
 
+        // Request asynchronous stats update from all providers for next poll. And wait a bit of
+        // time to allow providers report-in given that normally binder call should be fast.
+        // TODO: request with a valid token.
+        Trace.traceBegin(TRACE_TAG_NETWORK, "provider.requestStatsUpdate");
+        final int registeredCallbackCount = mStatsProviderCbList.getRegisteredCallbackCount();
+        mStatsProviderSem.drainPermits();
+        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.requestStatsUpdate(0 /* unused */));
+        try {
+            mStatsProviderSem.tryAcquire(registeredCallbackCount,
+                    MAX_STATS_PROVIDER_POLL_WAIT_TIME_MS, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            // Strictly speaking it's possible a provider happened to deliver between the timeout
+            // and the log, and that doesn't matter too much as this is just a debug log.
+            Log.d(TAG, "requestStatsUpdate - providers responded "
+                    + mStatsProviderSem.availablePermits()
+                    + "/" + registeredCallbackCount + " : " + e);
+        }
+        Trace.traceEnd(TRACE_TAG_NETWORK);
+
         // TODO: consider marking "untrusted" times in historical stats
         final long currentTime = mClock.millis();
 
@@ -1374,10 +1398,6 @@
             performSampleLocked();
         }
 
-        // request asynchronous stats update from all providers for next poll.
-        // TODO: request with a valid token.
-        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.requestStatsUpdate(0 /* unused */));
-
         // finally, dispatch updated event to any listeners
         final Intent updatedIntent = new Intent(ACTION_NETWORK_STATS_UPDATED);
         updatedIntent.setFlags(Intent.FLAG_RECEIVER_REGISTERED_ONLY);
@@ -1798,7 +1818,8 @@
         Objects.requireNonNull(tag, "tag is null");
         try {
             NetworkStatsProviderCallbackImpl callback = new NetworkStatsProviderCallbackImpl(
-                            tag, provider, mAlertObserver, mStatsProviderCbList);
+                    tag, provider, mStatsProviderSem, mAlertObserver,
+                    mStatsProviderCbList);
             mStatsProviderCbList.register(callback);
             Log.d(TAG, "registerNetworkStatsProvider from " + callback.mTag + " uid/pid="
                     + getCallingUid() + "/" + getCallingPid());
@@ -1846,6 +1867,7 @@
         @NonNull final String mTag;
         @NonNull private final Object mProviderStatsLock = new Object();
         @NonNull final INetworkStatsProvider mProvider;
+        @NonNull private final Semaphore mSemaphore;
         @NonNull final INetworkManagementEventObserver mAlertObserver;
         @NonNull final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList;
 
@@ -1857,12 +1879,14 @@
 
         NetworkStatsProviderCallbackImpl(
                 @NonNull String tag, @NonNull INetworkStatsProvider provider,
+                @NonNull Semaphore semaphore,
                 @NonNull INetworkManagementEventObserver alertObserver,
                 @NonNull RemoteCallbackList<NetworkStatsProviderCallbackImpl> cbList)
                 throws RemoteException {
             mTag = tag;
             mProvider = provider;
             mProvider.asBinder().linkToDeath(this, 0);
+            mSemaphore = semaphore;
             mAlertObserver = alertObserver;
             mStatsProviderCbList = cbList;
         }
@@ -1895,6 +1919,7 @@
                 if (ifaceStats != null) mIfaceStats.combineAllValues(ifaceStats);
                 if (uidStats != null) mUidStats.combineAllValues(uidStats);
             }
+            mSemaphore.release();
         }
 
         @Override