Merge changes from topic "sp04"

* changes:
  [SP05] add unit test for onStatsProviderLimitReached in NPMS
  [SP04] add unit test for NetworkStatsProvider
  [SP03] support registerNetworkStatsProvider API
  [SP03.1] Replace com.android.internal.util.Preconditions.checkNotNull
diff --git a/api/system-current.txt b/api/system-current.txt
index 52a86c1..5fa6bbd 100755
--- a/api/system-current.txt
+++ b/api/system-current.txt
@@ -1208,6 +1208,10 @@
     field public static final String SERVICE_INTERFACE = "android.app.usage.CacheQuotaService";
   }
 
+  public class NetworkStatsManager {
+    method @NonNull @RequiresPermission(android.Manifest.permission.UPDATE_DEVICE_STATS) public android.net.netstats.provider.NetworkStatsProviderCallback registerNetworkStatsProvider(@NonNull String, @NonNull android.net.netstats.provider.AbstractNetworkStatsProvider);
+  }
+
   public static final class UsageEvents.Event {
     method public int getInstanceId();
     method @Nullable public String getNotificationChannelId();
diff --git a/core/java/android/app/usage/NetworkStatsManager.java b/core/java/android/app/usage/NetworkStatsManager.java
index 7412970..9c4a8f4 100644
--- a/core/java/android/app/usage/NetworkStatsManager.java
+++ b/core/java/android/app/usage/NetworkStatsManager.java
@@ -16,9 +16,10 @@
 
 package android.app.usage;
 
-import static com.android.internal.util.Preconditions.checkNotNull;
-
+import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.RequiresPermission;
+import android.annotation.SystemApi;
 import android.annotation.SystemService;
 import android.annotation.TestApi;
 import android.app.usage.NetworkStats.Bucket;
@@ -29,6 +30,9 @@
 import android.net.INetworkStatsService;
 import android.net.NetworkIdentity;
 import android.net.NetworkTemplate;
+import android.net.netstats.provider.AbstractNetworkStatsProvider;
+import android.net.netstats.provider.NetworkStatsProviderCallback;
+import android.net.netstats.provider.NetworkStatsProviderWrapper;
 import android.os.Binder;
 import android.os.Handler;
 import android.os.Looper;
@@ -42,6 +46,8 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 
+import java.util.Objects;
+
 /**
  * Provides access to network usage history and statistics. Usage data is collected in
  * discrete bins of time called 'Buckets'. See {@link NetworkStats.Bucket} for details.
@@ -418,7 +424,7 @@
     /** @hide */
     public void registerUsageCallback(NetworkTemplate template, int networkType,
             long thresholdBytes, UsageCallback callback, @Nullable Handler handler) {
-        checkNotNull(callback, "UsageCallback cannot be null");
+        Objects.requireNonNull(callback, "UsageCallback cannot be null");
 
         final Looper looper;
         if (handler == null) {
@@ -519,6 +525,34 @@
         private DataUsageRequest request;
     }
 
+    /**
+     * Registers a custom provider of {@link android.net.NetworkStats} so that network statistics
+     * that cannot be seen by the kernel are reported to the system. To unregister, invoke
+     * {@link NetworkStatsProviderCallback#unregister()}.
+     *
+     * @param tag a human readable identifier of the custom network stats provider.
+     * @param provider a custom implementation of {@link AbstractNetworkStatsProvider} that needs to
+     *                 be registered to the system.
+     * @return a {@link NetworkStatsProviderCallback}, which can be used to report events to the
+     *         system.
+     * @hide
+     */
+    @SystemApi
+    @RequiresPermission(android.Manifest.permission.UPDATE_DEVICE_STATS)
+    @NonNull public NetworkStatsProviderCallback registerNetworkStatsProvider(
+            @NonNull String tag,
+            @NonNull AbstractNetworkStatsProvider provider) {
+        try {
+            final NetworkStatsProviderWrapper wrapper = new NetworkStatsProviderWrapper(provider);
+            return new NetworkStatsProviderCallback(
+                    mService.registerNetworkStatsProvider(tag, wrapper));
+        } catch (RemoteException e) {
+            e.rethrowAsRuntimeException();
+        }
+        // Unreachable code, but compiler doesn't know about it.
+        return null;
+    }
+
     private static NetworkTemplate createTemplate(int networkType, String subscriberId) {
         final NetworkTemplate template;
         switch (networkType) {
diff --git a/core/java/android/net/INetworkStatsService.aidl b/core/java/android/net/INetworkStatsService.aidl
index 9994f9f..5fa515a 100644
--- a/core/java/android/net/INetworkStatsService.aidl
+++ b/core/java/android/net/INetworkStatsService.aidl
@@ -23,6 +23,8 @@
 import android.net.NetworkStats;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
+import android.net.netstats.provider.INetworkStatsProvider;
+import android.net.netstats.provider.INetworkStatsProviderCallback;
 import android.os.IBinder;
 import android.os.Messenger;
 import com.android.internal.net.VpnInfo;
@@ -89,4 +91,7 @@
     /** Get the total network stats information since boot */
     long getTotalStats(int type);
 
+    /** Registers a network stats provider */
+    INetworkStatsProviderCallback registerNetworkStatsProvider(String tag,
+            in INetworkStatsProvider provider);
 }
diff --git a/services/core/java/com/android/server/net/NetworkPolicyManagerInternal.java b/services/core/java/com/android/server/net/NetworkPolicyManagerInternal.java
index 7f650ee..b24a938 100644
--- a/services/core/java/com/android/server/net/NetworkPolicyManagerInternal.java
+++ b/services/core/java/com/android/server/net/NetworkPolicyManagerInternal.java
@@ -18,8 +18,10 @@
 
 import static com.android.server.net.NetworkPolicyManagerService.isUidNetworkingBlockedInternal;
 
+import android.annotation.NonNull;
 import android.net.Network;
 import android.net.NetworkTemplate;
+import android.net.netstats.provider.AbstractNetworkStatsProvider;
 import android.telephony.SubscriptionPlan;
 
 import java.util.Set;
@@ -126,4 +128,12 @@
      */
     public abstract void setMeteredRestrictedPackagesAsync(
             Set<String> packageNames, int userId);
+
+    /**
+     *  Notifies that one of the {@link AbstractNetworkStatsProvider}s has reached its quota
+     *  which was set through {@link AbstractNetworkStatsProvider#setLimit(String, long)}.
+     *
+     * @param tag the human readable identifier of the custom network stats provider.
+     */
+    public abstract void onStatsProviderLimitReached(@NonNull String tag);
 }
diff --git a/services/core/java/com/android/server/net/NetworkPolicyManagerService.java b/services/core/java/com/android/server/net/NetworkPolicyManagerService.java
index 18c6e3a..04b51a9 100644
--- a/services/core/java/com/android/server/net/NetworkPolicyManagerService.java
+++ b/services/core/java/com/android/server/net/NetworkPolicyManagerService.java
@@ -75,6 +75,7 @@
 import static android.net.NetworkTemplate.MATCH_WIFI;
 import static android.net.NetworkTemplate.buildTemplateMobileAll;
 import static android.net.TrafficStats.MB_IN_BYTES;
+import static android.net.netstats.provider.AbstractNetworkStatsProvider.QUOTA_UNLIMITED;
 import static android.os.Trace.TRACE_TAG_NETWORK;
 import static android.provider.Settings.Global.NETPOLICY_OVERRIDE_ENABLED;
 import static android.provider.Settings.Global.NETPOLICY_QUOTA_ENABLED;
@@ -388,6 +389,7 @@
     private static final int MSG_METERED_RESTRICTED_PACKAGES_CHANGED = 17;
     private static final int MSG_SET_NETWORK_TEMPLATE_ENABLED = 18;
     private static final int MSG_SUBSCRIPTION_PLANS_CHANGED = 19;
+    private static final int MSG_STATS_PROVIDER_LIMIT_REACHED = 20;
 
     private static final int UID_MSG_STATE_CHANGED = 100;
     private static final int UID_MSG_GONE = 101;
@@ -4543,6 +4545,14 @@
                     mListeners.finishBroadcast();
                     return true;
                 }
+                case MSG_STATS_PROVIDER_LIMIT_REACHED: {
+                    mNetworkStats.forceUpdate();
+                    synchronized (mNetworkPoliciesSecondLock) {
+                        updateNetworkEnabledNL();
+                        updateNotificationsNL();
+                    }
+                    return true;
+                }
                 case MSG_LIMIT_REACHED: {
                     final String iface = (String) msg.obj;
 
@@ -4598,14 +4608,18 @@
                     return true;
                 }
                 case MSG_UPDATE_INTERFACE_QUOTA: {
-                    removeInterfaceQuota((String) msg.obj);
+                    final String iface = (String) msg.obj;
                     // int params need to be stitched back into a long
-                    setInterfaceQuota((String) msg.obj,
-                            ((long) msg.arg1 << 32) | (msg.arg2 & 0xFFFFFFFFL));
+                    final long quota = ((long) msg.arg1 << 32) | (msg.arg2 & 0xFFFFFFFFL);
+                    removeInterfaceQuota(iface);
+                    setInterfaceQuota(iface, quota);
+                    mNetworkStats.setStatsProviderLimit(iface, quota);
                     return true;
                 }
                 case MSG_REMOVE_INTERFACE_QUOTA: {
-                    removeInterfaceQuota((String) msg.obj);
+                    final String iface = (String) msg.obj;
+                    removeInterfaceQuota(iface);
+                    mNetworkStats.setStatsProviderLimit(iface, QUOTA_UNLIMITED);
                     return true;
                 }
                 case MSG_RESET_FIREWALL_RULES_BY_UID: {
@@ -5256,6 +5270,12 @@
             mHandler.obtainMessage(MSG_METERED_RESTRICTED_PACKAGES_CHANGED,
                     userId, 0, packageNames).sendToTarget();
         }
+
+        @Override
+        public void onStatsProviderLimitReached(@NonNull String tag) {
+            Log.v(TAG, "onStatsProviderLimitReached: " + tag);
+            mHandler.obtainMessage(MSG_STATS_PROVIDER_LIMIT_REACHED).sendToTarget();
+        }
     }
 
     private void setMeteredRestrictedPackagesInternal(Set<String> packageNames, int userId) {
diff --git a/services/core/java/com/android/server/net/NetworkStatsManagerInternal.java b/services/core/java/com/android/server/net/NetworkStatsManagerInternal.java
index 4843ede..6d72cb5 100644
--- a/services/core/java/com/android/server/net/NetworkStatsManagerInternal.java
+++ b/services/core/java/com/android/server/net/NetworkStatsManagerInternal.java
@@ -16,6 +16,7 @@
 
 package com.android.server.net;
 
+import android.annotation.NonNull;
 import android.net.NetworkStats;
 import android.net.NetworkTemplate;
 
@@ -34,4 +35,10 @@
 
     /** Force update of statistics. */
     public abstract void forceUpdate();
+
+    /**
+     * Sets the quota limit on all registered custom network stats providers.
+     * Note that a limit set for any interface is sent to every provider.
+     */
+    public abstract void setStatsProviderLimit(@NonNull String iface, long quota);
 }
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index a80fc05..fc39a25 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -18,6 +18,7 @@
 
 import static android.Manifest.permission.ACCESS_NETWORK_STATE;
 import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY;
+import static android.Manifest.permission.UPDATE_DEVICE_STATS;
 import static android.content.Intent.ACTION_SHUTDOWN;
 import static android.content.Intent.ACTION_UID_REMOVED;
 import static android.content.Intent.ACTION_USER_REMOVED;
@@ -71,6 +72,7 @@
 import static com.android.server.NetworkManagementSocketTagger.setKernelCounterSet;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.app.AlarmManager;
 import android.app.PendingIntent;
 import android.app.usage.NetworkStatsManager;
@@ -97,6 +99,9 @@
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
 import android.net.TrafficStats;
+import android.net.netstats.provider.INetworkStatsProvider;
+import android.net.netstats.provider.INetworkStatsProviderCallback;
+import android.net.netstats.provider.NetworkStatsProviderCallback;
 import android.os.BestClock;
 import android.os.Binder;
 import android.os.DropBoxManager;
@@ -109,6 +114,7 @@
 import android.os.Message;
 import android.os.Messenger;
 import android.os.PowerManager;
+import android.os.RemoteCallbackList;
 import android.os.RemoteException;
 import android.os.SystemClock;
 import android.os.Trace;
@@ -176,7 +182,7 @@
      * This avoids firing the global alert too often on devices with high transfer speeds and
      * high quota.
      */
-    private static final int PERFORM_POLL_DELAY_MS = 1000;
+    private static final int DEFAULT_PERFORM_POLL_DELAY_MS = 1000;
 
     private static final String TAG_NETSTATS_ERROR = "netstats_error";
 
@@ -220,6 +226,7 @@
      */
     public interface NetworkStatsSettings {
         public long getPollInterval();
+        public long getPollDelay();
         public boolean getSampleEnabled();
         public boolean getAugmentEnabled();
 
@@ -248,6 +255,7 @@
     }
 
     private final Object mStatsLock = new Object();
+    private final Object mStatsProviderLock = new Object();
 
     /** Set of currently active ifaces. */
     @GuardedBy("mStatsLock")
@@ -272,6 +280,9 @@
     private final DropBoxNonMonotonicObserver mNonMonotonicObserver =
             new DropBoxNonMonotonicObserver();
 
+    private final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList =
+            new RemoteCallbackList<>();
+
     @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mDevRecorder;
     @GuardedBy("mStatsLock")
@@ -502,9 +513,9 @@
     }
 
     /**
-     * Register for a global alert that is delivered through
-     * {@link INetworkManagementEventObserver} once a threshold amount of data
-     * has been transferred.
+     * Register for a global alert that is delivered through {@link INetworkManagementEventObserver}
+     * or {@link NetworkStatsProviderCallback#onAlertReached()} once a threshold amount of data has
+     * been transferred.
      */
     private void registerGlobalAlert() {
         try {
@@ -514,6 +525,7 @@
         } catch (RemoteException e) {
             // ignored; service lives in system_server
         }
+        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.setAlert(mGlobalAlertBytes));
     }
 
     @Override
@@ -803,8 +815,7 @@
     @Override
     public void incrementOperationCount(int uid, int tag, int operationCount) {
         if (Binder.getCallingUid() != uid) {
-            mContext.enforceCallingOrSelfPermission(
-                    android.Manifest.permission.UPDATE_DEVICE_STATS, TAG);
+            mContext.enforceCallingOrSelfPermission(UPDATE_DEVICE_STATS, TAG);
         }
 
         if (operationCount < 0) {
@@ -1095,7 +1106,7 @@
     /**
      * Observer that watches for {@link INetworkManagementService} alerts.
      */
-    private INetworkManagementEventObserver mAlertObserver = new BaseNetworkObserver() {
+    private final INetworkManagementEventObserver mAlertObserver = new BaseNetworkObserver() {
         @Override
         public void limitReached(String limitName, String iface) {
             // only someone like NMS should be calling us
@@ -1106,7 +1117,7 @@
                 // such a call pending; UID stats are handled during normal polling interval.
                 if (!mHandler.hasMessages(MSG_PERFORM_POLL_REGISTER_ALERT)) {
                     mHandler.sendEmptyMessageDelayed(MSG_PERFORM_POLL_REGISTER_ALERT,
-                            PERFORM_POLL_DELAY_MS);
+                            mSettings.getPollDelay());
                 }
             }
         }
@@ -1252,6 +1263,14 @@
         xtSnapshot.combineAllValues(tetherSnapshot);
         devSnapshot.combineAllValues(tetherSnapshot);
 
+        // Snapshot for dev/xt stats from all custom stats providers. Counts per-interface data
+        // from stats providers that isn't already counted by dev and XT stats.
+        Trace.traceBegin(TRACE_TAG_NETWORK, "snapshotStatsProvider");
+        final NetworkStats providersnapshot = getNetworkStatsFromProviders(STATS_PER_IFACE);
+        Trace.traceEnd(TRACE_TAG_NETWORK);
+        xtSnapshot.combineAllValues(providersnapshot);
+        devSnapshot.combineAllValues(providersnapshot);
+
         // For xt/dev, we pass a null VPN array because usage is aggregated by UID, so VPN traffic
         // can't be reattributed to responsible apps.
         Trace.traceBegin(TRACE_TAG_NETWORK, "recordDev");
@@ -1355,6 +1374,10 @@
             performSampleLocked();
         }
 
+        // request asynchronous stats update from all providers for next poll.
+        // TODO: request with a valid token.
+        invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.requestStatsUpdate(0 /* unused */));
+
         // finally, dispatch updated event to any listeners
         final Intent updatedIntent = new Intent(ACTION_NETWORK_STATS_UPDATED);
         updatedIntent.setFlags(Intent.FLAG_RECEIVER_REGISTERED_ONLY);
@@ -1476,6 +1499,12 @@
         public void forceUpdate() {
             NetworkStatsService.this.forceUpdate();
         }
+
+        @Override
+        public void setStatsProviderLimit(@NonNull String iface, long quota) {
+            Slog.v(TAG, "setStatsProviderLimit(" + iface + "," + quota + ")");
+            invokeForAllStatsProviderCallbacks((cb) -> cb.mProvider.setLimit(iface, quota));
+        }
     }
 
     @Override
@@ -1690,6 +1719,12 @@
             uidSnapshot.combineAllValues(vtStats);
         }
 
+        // get a stale copy of uid stats snapshot provided by providers.
+        final NetworkStats providerStats = getNetworkStatsFromProviders(STATS_PER_UID);
+        providerStats.filter(UID_ALL, ifaces, TAG_ALL);
+        mStatsFactory.apply464xlatAdjustments(uidSnapshot, providerStats, mUseBpfTrafficStats);
+        uidSnapshot.combineAllValues(providerStats);
+
         uidSnapshot.combineAllValues(mUidOperations);
 
         return uidSnapshot;
@@ -1726,6 +1761,152 @@
         }
     }
 
+    /**
+     * Registers a custom provider of {@link android.net.NetworkStats} so that network statistics
+     * that cannot be seen by the kernel are reported to the system. To unregister, invoke
+     * {@code unregister()} on the returned callback.
+     *
+     * @param tag a human readable identifier of the custom network stats provider.
+     * @param provider the binder interface of
+     *                 {@link android.net.netstats.provider.AbstractNetworkStatsProvider} that
+     *                 needs to be registered to the system.
+     *
+     * @return a binder interface of
+     *         {@link android.net.netstats.provider.NetworkStatsProviderCallback}, which can be
+     *         used to report events to the system.
+     */
+    public @NonNull INetworkStatsProviderCallback registerNetworkStatsProvider(
+            @NonNull String tag, @NonNull INetworkStatsProvider provider) {
+        mContext.enforceCallingOrSelfPermission(UPDATE_DEVICE_STATS, TAG);
+        Objects.requireNonNull(provider, "provider is null");
+        Objects.requireNonNull(tag, "tag is null");
+        try {
+            NetworkStatsProviderCallbackImpl callback = new NetworkStatsProviderCallbackImpl(
+                            tag, provider, mAlertObserver, mStatsProviderCbList);
+            mStatsProviderCbList.register(callback);
+            Log.d(TAG, "registerNetworkStatsProvider from " + callback.mTag + " uid/pid="
+                    + getCallingUid() + "/" + getCallingPid());
+            return callback;
+        } catch (RemoteException e) {
+            Log.e(TAG, "registerNetworkStatsProvider failed", e);
+        }
+        return null;
+    }
+
+    // Collect stats from local cache of providers.
+    private @NonNull NetworkStats getNetworkStatsFromProviders(int how) {
+        final NetworkStats ret = new NetworkStats(0L, 0);
+        invokeForAllStatsProviderCallbacks((cb) -> ret.combineAllValues(cb.getCachedStats(how)));
+        return ret;
+    }
+
+    @FunctionalInterface
+    private interface ThrowingConsumer<S, T extends Throwable> {
+        void accept(S s) throws T;
+    }
+
+    private void invokeForAllStatsProviderCallbacks(
+            @NonNull ThrowingConsumer<NetworkStatsProviderCallbackImpl, RemoteException> task) {
+        synchronized (mStatsProviderCbList) {
+            final int length = mStatsProviderCbList.beginBroadcast();
+            try {
+                for (int i = 0; i < length; i++) {
+                    final NetworkStatsProviderCallbackImpl cb =
+                            mStatsProviderCbList.getBroadcastItem(i);
+                    try {
+                        task.accept(cb);
+                    } catch (RemoteException e) {
+                        Log.e(TAG, "Fail to broadcast to provider: " + cb.mTag, e);
+                    }
+                }
+            } finally {
+                mStatsProviderCbList.finishBroadcast();
+            }
+        }
+    }
+
+    private static class NetworkStatsProviderCallbackImpl extends INetworkStatsProviderCallback.Stub
+            implements IBinder.DeathRecipient {
+        @NonNull final String mTag;
+        @NonNull private final Object mProviderStatsLock = new Object();
+        @NonNull final INetworkStatsProvider mProvider;
+        @NonNull final INetworkManagementEventObserver mAlertObserver;
+        @NonNull final RemoteCallbackList<NetworkStatsProviderCallbackImpl> mStatsProviderCbList;
+
+        @GuardedBy("mProviderStatsLock")
+        // STATS_PER_IFACE and STATS_PER_UID
+        private final NetworkStats mIfaceStats = new NetworkStats(0L, 0);
+        @GuardedBy("mProviderStatsLock")
+        private final NetworkStats mUidStats = new NetworkStats(0L, 0);
+
+        NetworkStatsProviderCallbackImpl(
+                @NonNull String tag, @NonNull INetworkStatsProvider provider,
+                @NonNull INetworkManagementEventObserver alertObserver,
+                @NonNull RemoteCallbackList<NetworkStatsProviderCallbackImpl> cbList)
+                throws RemoteException {
+            mTag = tag;
+            mProvider = provider;
+            mProvider.asBinder().linkToDeath(this, 0);
+            mAlertObserver = alertObserver;
+            mStatsProviderCbList = cbList;
+        }
+
+        @NonNull
+        public NetworkStats getCachedStats(int how) {
+            synchronized (mProviderStatsLock) {
+                NetworkStats stats;
+                switch (how) {
+                    case STATS_PER_IFACE:
+                        stats = mIfaceStats;
+                        break;
+                    case STATS_PER_UID:
+                        stats = mUidStats;
+                        break;
+                    default:
+                        throw new IllegalArgumentException("Invalid type: " + how);
+                }
+                // Return a defensive copy instead of local reference.
+                return stats.clone();
+            }
+        }
+
+        @Override
+        public void onStatsUpdated(int token, @Nullable NetworkStats ifaceStats,
+                @Nullable NetworkStats uidStats) {
+            // TODO: 1. Use token to map ifaces to correct NetworkIdentity.
+            //       2. Store the difference and store it directly to the recorder.
+            synchronized (mProviderStatsLock) {
+                if (ifaceStats != null) mIfaceStats.combineAllValues(ifaceStats);
+                if (uidStats != null) mUidStats.combineAllValues(uidStats);
+            }
+        }
+
+        @Override
+        public void onAlertReached() throws RemoteException {
+            mAlertObserver.limitReached(LIMIT_GLOBAL_ALERT, null /* unused */);
+        }
+
+        @Override
+        public void onLimitReached() {
+            Log.d(TAG, mTag + ": onLimitReached");
+            LocalServices.getService(NetworkPolicyManagerInternal.class)
+                    .onStatsProviderLimitReached(mTag);
+        }
+
+        @Override
+        public void binderDied() {
+            Log.d(TAG, mTag + ": binderDied");
+            mStatsProviderCbList.unregister(this);
+        }
+
+        @Override
+        public void unregister() {
+            Log.d(TAG, mTag + ": unregister");
+            mStatsProviderCbList.unregister(this);
+        }
+
+    }
+
     @VisibleForTesting
     static class HandlerCallback implements Handler.Callback {
         private final NetworkStatsService mService;
@@ -1815,6 +1996,10 @@
             return getGlobalLong(NETSTATS_POLL_INTERVAL, 30 * MINUTE_IN_MILLIS);
         }
         @Override
+        public long getPollDelay() {
+            return DEFAULT_PERFORM_POLL_DELAY_MS;
+        }
+        @Override
         public long getGlobalAlertBytes(long def) {
             return getGlobalLong(NETSTATS_GLOBAL_ALERT_BYTES, def);
         }
diff --git a/services/tests/servicestests/src/com/android/server/net/NetworkPolicyManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/net/NetworkPolicyManagerServiceTest.java
index 8dfd6a4..2c941c6 100644
--- a/services/tests/servicestests/src/com/android/server/net/NetworkPolicyManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/net/NetworkPolicyManagerServiceTest.java
@@ -37,6 +37,7 @@
 import static android.net.NetworkStats.IFACE_ALL;
 import static android.net.NetworkStats.SET_ALL;
 import static android.net.NetworkStats.TAG_ALL;
+import static android.net.NetworkStats.TAG_NONE;
 import static android.net.NetworkTemplate.buildTemplateMobileAll;
 import static android.net.NetworkTemplate.buildTemplateWifi;
 import static android.net.TrafficStats.MB_IN_BYTES;
@@ -1683,6 +1684,57 @@
     }
 
     /**
+     * Test that when StatsProvider triggers limit reached, new limit will be calculated and
+     * re-armed.
+     */
+    @Test
+    public void testStatsProviderLimitReached() throws Exception {
+        final int CYCLE_DAY = 15;
+
+        final NetworkStats stats = new NetworkStats(0L, 1);
+        stats.addEntry(TEST_IFACE, UID_A, SET_ALL, TAG_NONE,
+                2999, 1, 2000, 1, 0);
+        when(mStatsService.getNetworkTotalBytes(any(), anyLong(), anyLong()))
+                .thenReturn(stats.getTotalBytes());
+        when(mStatsService.getNetworkUidBytes(any(), anyLong(), anyLong()))
+                .thenReturn(stats);
+
+        // Get active mobile network in place
+        expectMobileDefaults();
+        mService.updateNetworks();
+        verify(mStatsService).setStatsProviderLimit(TEST_IFACE, Long.MAX_VALUE);
+
+        // Set limit to 10KB.
+        setNetworkPolicies(new NetworkPolicy(
+                sTemplateMobileAll, CYCLE_DAY, TIMEZONE_UTC, WARNING_DISABLED, 10000L,
+                true));
+        postMsgAndWaitForCompletion();
+
+        // Verifies that remaining quota is set to providers.
+        verify(mStatsService).setStatsProviderLimit(TEST_IFACE, 10000L - 4999L);
+
+        reset(mStatsService);
+
+        // Increase the usage.
+        stats.addEntry(TEST_IFACE, UID_A, SET_ALL, TAG_NONE,
+                1000, 1, 999, 1, 0);
+        when(mStatsService.getNetworkTotalBytes(any(), anyLong(), anyLong()))
+                .thenReturn(stats.getTotalBytes());
+        when(mStatsService.getNetworkUidBytes(any(), anyLong(), anyLong()))
+                .thenReturn(stats);
+
+        // Simulates the provider firing limit reached early, while the quota has actually not
+        // yet been reached.
+        final NetworkPolicyManagerInternal npmi = LocalServices
+                .getService(NetworkPolicyManagerInternal.class);
+        npmi.onStatsProviderLimitReached("TEST");
+
+        // Verifies that the limit reached leads to a force update.
+        postMsgAndWaitForCompletion();
+        verify(mStatsService).forceUpdate();
+    }
+
+    /**
      * Exhaustively test isUidNetworkingBlocked to output the expected results based on external
      * conditions.
      */
diff --git a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
index 6de068e..a9e0b9a 100644
--- a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -80,6 +80,7 @@
 import android.net.NetworkStats;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
+import android.net.netstats.provider.INetworkStatsProviderCallback;
 import android.os.ConditionVariable;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -102,6 +103,7 @@
 import com.android.server.net.NetworkStatsService.NetworkStatsSettings;
 import com.android.server.net.NetworkStatsService.NetworkStatsSettings.Config;
 import com.android.testutils.HandlerUtilsKt;
+import com.android.testutils.TestableNetworkStatsProvider;
 
 import libcore.io.IoUtils;
 
@@ -1001,6 +1003,88 @@
         mService.unregisterUsageRequest(unknownRequest);
     }
 
+    @Test
+    public void testStatsProviderUpdateStats() throws Exception {
+        // Pretend that network comes online.
+        expectDefaultSettings();
+        final NetworkState[] states = new NetworkState[]{buildWifiState(true /* isMetered */)};
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+
+        // Register custom provider and retrieve callback.
+        final TestableNetworkStatsProvider provider = new TestableNetworkStatsProvider();
+        final INetworkStatsProviderCallback cb =
+                mService.registerNetworkStatsProvider("TEST", provider);
+        assertNotNull(cb);
+
+        mService.forceUpdateIfaces(NETWORKS_WIFI, states, getActiveIface(states), new VpnInfo[0]);
+
+        // Verifies that one requestStatsUpdate will be called during iface update.
+        provider.expectStatsUpdate(0 /* unused */);
+
+        // Create some initial traffic and report to the service.
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        final NetworkStats expectedStats = new NetworkStats(0L, 1)
+                .addValues(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT,
+                        TAG_NONE, METERED_YES, ROAMING_NO, DEFAULT_NETWORK_YES,
+                        128L, 2L, 128L, 2L, 1L))
+                .addValues(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT,
+                        0xF00D, METERED_YES, ROAMING_NO, DEFAULT_NETWORK_YES,
+                        64L, 1L, 64L, 1L, 1L));
+        cb.onStatsUpdated(0 /* unused */, expectedStats, expectedStats);
+
+        // Make another empty mutable stats object. This is necessary since the new NetworkStats
+        // object will be compared with the old one in NetworkStatsRecorder, so the two of them
+        // cannot be the same object.
+        expectNetworkStatsUidDetail(buildEmptyStats());
+
+        forcePollAndWaitForIdle();
+
+        // Verifies that one requestStatsUpdate and setAlert will be called during polling.
+        provider.expectStatsUpdate(0 /* unused */);
+        provider.expectSetAlert(MB_IN_BYTES);
+
+        // Verifies that service recorded history, does not verify uid tag part.
+        assertUidTotal(sTemplateWifi, UID_RED, 128L, 2L, 128L, 2L, 1);
+
+        // Verifies that onStatsUpdated updates the stats accordingly.
+        final NetworkStats stats = mSession.getSummaryForAllUid(
+                sTemplateWifi, Long.MIN_VALUE, Long.MAX_VALUE, true);
+        assertEquals(2, stats.size());
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, TAG_NONE, METERED_YES, ROAMING_NO,
+                DEFAULT_NETWORK_YES, 128L, 2L, 128L, 2L, 1L);
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, 0xF00D, METERED_YES, ROAMING_NO,
+                DEFAULT_NETWORK_YES, 64L, 1L, 64L, 1L, 1L);
+
+        // Verifies that unregistering the callback removes the provider from the service.
+        cb.unregister();
+        forcePollAndWaitForIdle();
+        provider.assertNoCallback();
+    }
+
+    @Test
+    public void testStatsProviderSetAlert() throws Exception {
+        // Pretend that network comes online.
+        expectDefaultSettings();
+        NetworkState[] states = new NetworkState[]{buildWifiState(true /* isMetered */)};
+        mService.forceUpdateIfaces(NETWORKS_WIFI, states, getActiveIface(states), new VpnInfo[0]);
+
+        // Register custom provider and retrieve callback.
+        final TestableNetworkStatsProvider provider = new TestableNetworkStatsProvider();
+        final INetworkStatsProviderCallback cb =
+                mService.registerNetworkStatsProvider("TEST", provider);
+        assertNotNull(cb);
+
+        // Simulates alert quota of the provider has been reached.
+        cb.onAlertReached();
+        HandlerUtilsKt.waitForIdle(mHandler, WAIT_TIMEOUT);
+
+        // Verifies that polling is triggered by alert reached.
+        provider.expectStatsUpdate(0 /* unused */);
+        // Verifies that global alert will be re-armed.
+        provider.expectSetAlert(MB_IN_BYTES);
+    }
+
     private static File getBaseDir(File statsDir) {
         File baseDir = new File(statsDir, "netstats");
         baseDir.mkdirs();
@@ -1102,6 +1186,7 @@
     private void expectSettings(long persistBytes, long bucketDuration, long deleteAge)
             throws Exception {
         when(mSettings.getPollInterval()).thenReturn(HOUR_IN_MILLIS);
+        when(mSettings.getPollDelay()).thenReturn(0L);
         when(mSettings.getSampleEnabled()).thenReturn(true);
 
         final Config config = new Config(bucketDuration, deleteAge, deleteAge);