Merge "[SP17] Wait for stats providers to report stats update"
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index 415ccb8..1c52110 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -155,6 +155,8 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Collect and persist detailed network statistics, and provide this data to
@@ -280,8 +282,11 @@
     private final DropBoxNonMonotonicObserver mNonMonotonicObserver =
             new DropBoxNonMonotonicObserver();
 
+    private static final int MAX_STATS_PROVIDER_POLL_WAIT_TIME_MS = 100;
     private final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList =
             new RemoteCallbackList<>();
+    /** Released on each provider stats report; polls wait on it, bounded by a timeout. */
+    private final Semaphore mStatsProviderSem = new Semaphore(0, true);
 
     @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mDevRecorder;
@@ -1337,6 +1342,25 @@
         final boolean persistUid = (flags & FLAG_PERSIST_UID) != 0;
         final boolean persistForce = (flags & FLAG_PERSIST_FORCE) != 0;
 
+        // Request a stats update from all providers, then wait (bounded by a short timeout) for
+        // them to report in; binder calls are normally fast. TODO: request with a valid token.
+        Trace.traceBegin(TRACE_TAG_NETWORK, "provider.requestStatsUpdate");
+        final int registeredCallbackCount = mStatsProviderCbList.getRegisteredCallbackCount();
+        mStatsProviderSem.drainPermits();
+        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.requestStatsUpdate(0 /* unused */));
+        try {
+            if (!mStatsProviderSem.tryAcquire(registeredCallbackCount,
+                    MAX_STATS_PROVIDER_POLL_WAIT_TIME_MS, TimeUnit.MILLISECONDS)) {
+                // On timeout no permits are taken, so this counts replies (racy; debug only).
+                Log.d(TAG, "requestStatsUpdate - providers responded "
+                        + mStatsProviderSem.availablePermits() + "/" + registeredCallbackCount);
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // Restore the interrupt flag for the caller.
+            Log.d(TAG, "requestStatsUpdate - interrupted while waiting for providers: " + e);
+        }
+        Trace.traceEnd(TRACE_TAG_NETWORK);
+
         // TODO: consider marking "untrusted" times in historical stats
         final long currentTime = mClock.millis();
 
@@ -1374,10 +1398,6 @@
             performSampleLocked();
         }
 
-        // request asynchronous stats update from all providers for next poll.
-        // TODO: request with a valid token.
-        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.requestStatsUpdate(0 /* unused */));
-
         // finally, dispatch updated event to any listeners
         final Intent updatedIntent = new Intent(ACTION_NETWORK_STATS_UPDATED);
         updatedIntent.setFlags(Intent.FLAG_RECEIVER_REGISTERED_ONLY);
@@ -1798,7 +1818,8 @@
         Objects.requireNonNull(tag, "tag is null");
         try {
             NetworkStatsProviderCallbackImpl callback = new NetworkStatsProviderCallbackImpl(
-                            tag, provider, mAlertObserver, mStatsProviderCbList);
+                    tag, provider, mStatsProviderSem, mAlertObserver,
+                    mStatsProviderCbList);
             mStatsProviderCbList.register(callback);
             Log.d(TAG, "registerNetworkStatsProvider from " + callback.mTag + " uid/pid="
                     + getCallingUid() + "/" + getCallingPid());
@@ -1846,6 +1867,7 @@
         @NonNull final String mTag;
         @NonNull private final Object mProviderStatsLock = new Object();
         @NonNull final INetworkStatsProvider mProvider;
+        @NonNull private final Semaphore mSemaphore; // Service-wide, shared by all providers.
         @NonNull final INetworkManagementEventObserver mAlertObserver;
         @NonNull final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList;
 
@@ -1857,12 +1879,14 @@
 
         NetworkStatsProviderCallbackImpl(
                 @NonNull String tag, @NonNull INetworkStatsProvider provider,
+                @NonNull Semaphore semaphore, // Released on each stats report.
                 @NonNull INetworkManagementEventObserver alertObserver,
                 @NonNull RemoteCallbackList<NetworkStatsProviderCallbackImpl> cbList)
                 throws RemoteException {
             mTag = tag;
             mProvider = provider;
             mProvider.asBinder().linkToDeath(this, 0);
+            mSemaphore = semaphore;
             mAlertObserver = alertObserver;
             mStatsProviderCbList = cbList;
         }
@@ -1895,6 +1919,7 @@
                 if (ifaceStats != null) mIfaceStats.combineAllValues(ifaceStats);
                 if (uidStats != null) mUidStats.combineAllValues(uidStats);
             }
+            mSemaphore.release(); // Lets a waiting poll proceed; stale permits get drained.
         }
 
         @Override