Register for kernel global data usage alerts.

Instead of polling every 15 minutes, register for alerts that trigger
when system-wide traffic passes a threshold.  Still mixed with polling
to persist UID stats, but relaxed to 30 minutes.  Currently watches
for every 512kB.

Make persistence decision separately for network versus UID, and use
total delta bytes when making decision.  Use light bootstrap during
systemReady() instead of heavy poll, which had been force-loading all
UID data unnecessarily.

Bug: 5023631
Change-Id: I04b723d6c4bf872fb1028071122dba66a8e1b576
diff --git a/core/java/android/net/NetworkStats.java b/core/java/android/net/NetworkStats.java
index 272545d..5b883a0 100644
--- a/core/java/android/net/NetworkStats.java
+++ b/core/java/android/net/NetworkStats.java
@@ -313,6 +313,22 @@
     }
 
     /**
+     * Return total bytes represented by this snapshot object, usually used when
+     * checking if a {@link #subtract(NetworkStats)} delta passes a threshold.
+     */
+    public long getTotalBytes() {
+        long totalBytes = 0;
+        for (int i = 0; i < size; i++) {
+            // skip specific tags, since already counted in TAG_NONE
+            if (tag[i] != TAG_NONE) continue;
+
+            totalBytes += rxBytes[i];
+            totalBytes += txBytes[i];
+        }
+        return totalBytes;
+    }
+
+    /**
      * Subtract the given {@link NetworkStats}, effectively leaving the delta
      * between two snapshots in time. Assumes that statistics rows collect over
      * time, and that none of them have disappeared.
diff --git a/core/tests/coretests/src/android/net/NetworkStatsTest.java b/core/tests/coretests/src/android/net/NetworkStatsTest.java
index 47ba88a..c36685d 100644
--- a/core/tests/coretests/src/android/net/NetworkStatsTest.java
+++ b/core/tests/coretests/src/android/net/NetworkStatsTest.java
@@ -17,7 +17,9 @@
 package android.net;
 
 import static android.net.NetworkStats.SET_DEFAULT;
+import static android.net.NetworkStats.SET_FOREGROUND;
 import static android.net.NetworkStats.TAG_NONE;
+import static android.net.NetworkStats.UID_ALL;
 
 import android.test.suitebuilder.annotation.SmallTest;
 
@@ -27,6 +29,7 @@
 public class NetworkStatsTest extends TestCase {
 
     private static final String TEST_IFACE = "test0";
+    private static final String TEST_IFACE2 = "test2";
     private static final int TEST_UID = 1001;
     private static final long TEST_START = 1194220800000L;
 
@@ -135,6 +138,44 @@
         assertValues(result, 2, TEST_IFACE, 102, SET_DEFAULT, TAG_NONE, 1024L, 8L, 1024L, 8L, 20);
     }
 
+    public void testSubtractMissingRows() throws Exception {
+        final NetworkStats before = new NetworkStats(TEST_START, 2)
+                .addValues(TEST_IFACE, UID_ALL, SET_DEFAULT, TAG_NONE, 1024L, 0L, 0L, 0L, 0)
+                .addValues(TEST_IFACE2, UID_ALL, SET_DEFAULT, TAG_NONE, 2048L, 0L, 0L, 0L, 0);
+
+        final NetworkStats after = new NetworkStats(TEST_START, 1)
+                .addValues(TEST_IFACE2, UID_ALL, SET_DEFAULT, TAG_NONE, 2049L, 2L, 3L, 4L, 0);
+
+        final NetworkStats result = after.subtract(before);
+
+        // should silently drop omitted rows
+        assertEquals(1, result.size());
+        assertValues(result, 0, TEST_IFACE2, UID_ALL, SET_DEFAULT, TAG_NONE, 1L, 2L, 3L, 4L, 0);
+        assertEquals(4L, result.getTotalBytes());
+    }
+
+    public void testTotalBytes() throws Exception {
+        final NetworkStats iface = new NetworkStats(TEST_START, 2)
+                .addValues(TEST_IFACE, UID_ALL, SET_DEFAULT, TAG_NONE, 128L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE2, UID_ALL, SET_DEFAULT, TAG_NONE, 256L, 0L, 0L, 0L, 0L);
+        assertEquals(384L, iface.getTotalBytes());
+
+        final NetworkStats uidSet = new NetworkStats(TEST_START, 3)
+                .addValues(TEST_IFACE, 100, SET_DEFAULT, TAG_NONE, 32L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE, 101, SET_DEFAULT, TAG_NONE, 32L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE, 101, SET_FOREGROUND, TAG_NONE, 32L, 0L, 0L, 0L, 0L);
+        assertEquals(96L, uidSet.getTotalBytes());
+
+        final NetworkStats uidTag = new NetworkStats(TEST_START, 3)
+                .addValues(TEST_IFACE, 100, SET_DEFAULT, TAG_NONE, 16L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE2, 100, SET_DEFAULT, TAG_NONE, 16L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE2, 100, SET_DEFAULT, 0xF00D, 8L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE2, 100, SET_FOREGROUND, TAG_NONE, 16L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE, 101, SET_DEFAULT, TAG_NONE, 16L, 0L, 0L, 0L, 0L)
+                .addValues(TEST_IFACE, 101, SET_DEFAULT, 0xF00D, 8L, 0L, 0L, 0L, 0L);
+        assertEquals(64L, uidTag.getTotalBytes());
+    }
+
     private static void assertValues(NetworkStats stats, int index, String iface, int uid, int set,
             int tag, long rxBytes, long rxPackets, long txBytes, long txPackets, int operations) {
         final NetworkStats.Entry entry = stats.getValues(index, null);
diff --git a/services/java/com/android/server/NetworkManagementService.java b/services/java/com/android/server/NetworkManagementService.java
index 06077dd..c679dcf 100644
--- a/services/java/com/android/server/NetworkManagementService.java
+++ b/services/java/com/android/server/NetworkManagementService.java
@@ -69,7 +69,8 @@
 /**
  * @hide
  */
-class NetworkManagementService extends INetworkManagementService.Stub implements Watchdog.Monitor {
+public class NetworkManagementService extends INetworkManagementService.Stub
+        implements Watchdog.Monitor {
     private static final String TAG = "NetworkManagementService";
     private static final boolean DBG = false;
     private static final String NETD_TAG = "NetdConnector";
@@ -87,6 +88,12 @@
     /** Path to {@code /proc/net/xt_qtaguid/iface_stat}. */
     private final File mStatsXtIface;
 
+    /**
+     * Name representing {@link #setGlobalAlert(long)} limit when delivered to
+     * {@link INetworkManagementEventObserver#limitReached(String, String)}.
+     */
+    public static final String LIMIT_GLOBAL_ALERT = "globalAlert";
+
     /** {@link #mStatsXtUid} headers. */
     private static final String KEY_IFACE = "iface";
     private static final String KEY_UID = "uid_tag_int";
diff --git a/services/java/com/android/server/net/NetworkPolicyManagerService.java b/services/java/com/android/server/net/NetworkPolicyManagerService.java
index 84880f9..84e5eae 100644
--- a/services/java/com/android/server/net/NetworkPolicyManagerService.java
+++ b/services/java/com/android/server/net/NetworkPolicyManagerService.java
@@ -60,6 +60,7 @@
 import static com.android.server.net.NetworkStatsService.ACTION_NETWORK_STATS_UPDATED;
 import static org.xmlpull.v1.XmlPullParser.END_DOCUMENT;
 import static org.xmlpull.v1.XmlPullParser.START_TAG;
+import static com.android.server.NetworkManagementService.LIMIT_GLOBAL_ALERT;
 
 import android.app.IActivityManager;
 import android.app.INotificationManager;
@@ -454,7 +455,7 @@
             mContext.enforceCallingOrSelfPermission(CONNECTIVITY_INTERNAL, TAG);
 
             synchronized (mRulesLock) {
-                if (mMeteredIfaces.contains(iface)) {
+                if (mMeteredIfaces.contains(iface) && !LIMIT_GLOBAL_ALERT.equals(limitName)) {
                     try {
                         // force stats update to make sure we have numbers that
                         // caused alert to trigger.
@@ -763,7 +764,12 @@
             // disable data connection when over limit and not snoozed
             final boolean overLimit = policy.limitBytes != LIMIT_DISABLED
                     && totalBytes > policy.limitBytes && policy.lastSnooze < start;
-            setNetworkTemplateEnabled(policy.template, !overLimit);
+            final boolean enabled = !overLimit;
+
+            if (LOGD) {
+                Slog.d(TAG, "setting template=" + policy.template + " enabled=" + enabled);
+            }
+            setNetworkTemplateEnabled(policy.template, enabled);
         }
     }
 
@@ -772,7 +778,6 @@
      * for the given {@link NetworkTemplate}.
      */
     private void setNetworkTemplateEnabled(NetworkTemplate template, boolean enabled) {
-        if (LOGD) Slog.d(TAG, "setting template=" + template + " enabled=" + enabled);
         switch (template.getMatchRule()) {
             case MATCH_MOBILE_3G_LOWER:
             case MATCH_MOBILE_4G:
diff --git a/services/java/com/android/server/net/NetworkStatsService.java b/services/java/com/android/server/net/NetworkStatsService.java
index c911687..80ae9bc 100644
--- a/services/java/com/android/server/net/NetworkStatsService.java
+++ b/services/java/com/android/server/net/NetworkStatsService.java
@@ -43,6 +43,7 @@
 import static android.text.format.DateUtils.HOUR_IN_MILLIS;
 import static android.text.format.DateUtils.MINUTE_IN_MILLIS;
 import static com.android.internal.util.Preconditions.checkNotNull;
+import static com.android.server.NetworkManagementService.LIMIT_GLOBAL_ALERT;
 import static com.android.server.NetworkManagementSocketTagger.resetKernelUidStats;
 import static com.android.server.NetworkManagementSocketTagger.setKernelCounterSet;
 
@@ -56,6 +57,7 @@
 import android.content.IntentFilter;
 import android.content.pm.ApplicationInfo;
 import android.net.IConnectivityManager;
+import android.net.INetworkManagementEventObserver;
 import android.net.INetworkStatsService;
 import android.net.NetworkIdentity;
 import android.net.NetworkInfo;
@@ -121,7 +123,8 @@
     private static final int VERSION_UID_WITH_TAG = 3;
     private static final int VERSION_UID_WITH_SET = 4;
 
-    private static final int MSG_FORCE_UPDATE = 0x1;
+    private static final int MSG_PERFORM_POLL = 0x1;
+    private static final int MSG_PERFORM_POLL_DETAILED = 0x2;
 
     private final Context mContext;
     private final INetworkManagementService mNetworkManager;
@@ -141,7 +144,6 @@
 
     private PendingIntent mPollIntent;
 
-    // TODO: listen for kernel push events through netd instead of polling
     // TODO: trim empty history objects entirely
 
     private static final long KB_IN_BYTES = 1024;
@@ -174,17 +176,18 @@
     /** Flag if {@link #mUidStats} have been loaded from disk. */
     private boolean mUidStatsLoaded = false;
 
-    private NetworkStats mLastNetworkSnapshot;
-    private NetworkStats mLastPersistNetworkSnapshot;
+    private NetworkStats mLastPollNetworkSnapshot;
+    private NetworkStats mLastPollUidSnapshot;
+    private NetworkStats mLastPollOperationsSnapshot;
 
-    private NetworkStats mLastUidSnapshot;
+    private NetworkStats mLastPersistNetworkSnapshot;
+    private NetworkStats mLastPersistUidSnapshot;
 
     /** Current counter sets for each UID. */
     private SparseIntArray mActiveUidCounterSet = new SparseIntArray();
 
     /** Data layer operation counters for splicing into other structures. */
     private NetworkStats mOperations = new NetworkStats(0L, 10);
-    private NetworkStats mLastOperationsSnapshot;
 
     private final HandlerThread mHandlerThread;
     private final Handler mHandler;
@@ -252,13 +255,18 @@
         mContext.registerReceiver(mShutdownReceiver, shutdownFilter);
 
         try {
-            registerPollAlarmLocked();
+            mNetworkManager.registerObserver(mAlertObserver);
         } catch (RemoteException e) {
-            Slog.w(TAG, "unable to register poll alarm");
+            // ouch, no push updates means we fall back to
+            // ACTION_NETWORK_STATS_POLL intervals.
+            Slog.e(TAG, "unable to register INetworkManagementEventObserver", e);
         }
 
-        // kick off background poll to bootstrap deltas
-        mHandler.obtainMessage(MSG_FORCE_UPDATE).sendToTarget();
+        registerPollAlarmLocked();
+        registerGlobalAlert();
+
+        // bootstrap initial stats to prevent double-counting later
+        bootstrapStats();
     }
 
     private void shutdownLocked() {
@@ -280,17 +288,37 @@
      * Clear any existing {@link #ACTION_NETWORK_STATS_POLL} alarms, and
      * reschedule based on current {@link NetworkStatsSettings#getPollInterval()}.
      */
-    private void registerPollAlarmLocked() throws RemoteException {
-        if (mPollIntent != null) {
-            mAlarmManager.remove(mPollIntent);
+    private void registerPollAlarmLocked() {
+        try {
+            if (mPollIntent != null) {
+                mAlarmManager.remove(mPollIntent);
+            }
+
+            mPollIntent = PendingIntent.getBroadcast(
+                    mContext, 0, new Intent(ACTION_NETWORK_STATS_POLL), 0);
+
+            final long currentRealtime = SystemClock.elapsedRealtime();
+            mAlarmManager.setInexactRepeating(AlarmManager.ELAPSED_REALTIME, currentRealtime,
+                    mSettings.getPollInterval(), mPollIntent);
+        } catch (RemoteException e) {
+            Slog.w(TAG, "problem registering for poll alarm: " + e);
         }
+    }
 
-        mPollIntent = PendingIntent.getBroadcast(
-                mContext, 0, new Intent(ACTION_NETWORK_STATS_POLL), 0);
-
-        final long currentRealtime = SystemClock.elapsedRealtime();
-        mAlarmManager.setInexactRepeating(AlarmManager.ELAPSED_REALTIME, currentRealtime,
-                mSettings.getPollInterval(), mPollIntent);
+    /**
+     * Register for a global alert that is delivered through
+     * {@link INetworkManagementEventObserver} once a threshold amount of data
+     * has been transferred.
+     */
+    private void registerGlobalAlert() {
+        try {
+            final long alertBytes = mSettings.getPersistThreshold();
+            mNetworkManager.setGlobalAlert(alertBytes);
+        } catch (IllegalStateException e) {
+            Slog.w(TAG, "problem registering for global alert: " + e);
+        } catch (RemoteException e) {
+            Slog.w(TAG, "problem registering for global alert: " + e);
+        }
     }
 
     @Override
@@ -475,10 +503,7 @@
     @Override
     public void forceUpdate() {
         mContext.enforceCallingOrSelfPermission(READ_NETWORK_USAGE_HISTORY, TAG);
-
-        synchronized (mStatsLock) {
-            performPollLocked(true, false);
-        }
+        performPoll(true, false);
     }
 
     /**
@@ -507,14 +532,10 @@
         public void onReceive(Context context, Intent intent) {
             // on background handler thread, and verified UPDATE_DEVICE_STATS
             // permission above.
-            synchronized (mStatsLock) {
-                mWakeLock.acquire();
-                try {
-                    performPollLocked(true, false);
-                } finally {
-                    mWakeLock.release();
-                }
-            }
+            performPoll(true, false);
+
+            // verify that we're watching global alert
+            registerGlobalAlert();
         }
     };
 
@@ -547,6 +568,26 @@
     };
 
     /**
+     * Observer that watches for {@link INetworkManagementService} alerts.
+     */
+    private INetworkManagementEventObserver mAlertObserver = new NetworkAlertObserver() {
+        @Override
+        public void limitReached(String limitName, String iface) {
+            // only someone like NMS should be calling us
+            mContext.enforceCallingOrSelfPermission(CONNECTIVITY_INTERNAL, TAG);
+
+            if (LIMIT_GLOBAL_ALERT.equals(limitName)) {
+                // kick off background poll to collect network stats; UID stats
+                // are handled during normal polling interval.
+                mHandler.obtainMessage(MSG_PERFORM_POLL).sendToTarget();
+
+                // re-arm global alert for next update
+                registerGlobalAlert();
+            }
+        }
+    };
+
+    /**
      * Inspect all current {@link NetworkState} to derive mapping from {@code
      * iface} to {@link NetworkStatsHistory}. When multiple {@link NetworkInfo}
      * are active on a single {@code iface}, they are combined under a single
@@ -588,6 +629,33 @@
     }
 
     /**
+     * Bootstrap initial stats snapshot, usually during {@link #systemReady()}
+     * so we have baseline values without double-counting.
+     */
+    private void bootstrapStats() {
+        try {
+            mLastPollNetworkSnapshot = mNetworkManager.getNetworkStatsSummary();
+            mLastPollUidSnapshot = mNetworkManager.getNetworkStatsUidDetail(UID_ALL);
+            mLastPollOperationsSnapshot = new NetworkStats(0L, 0);
+        } catch (IllegalStateException e) {
+            Slog.w(TAG, "problem reading network stats: " + e);
+        } catch (RemoteException e) {
+            Slog.w(TAG, "problem reading network stats: " + e);
+        }
+    }
+
+    private void performPoll(boolean detailedPoll, boolean forcePersist) {
+        synchronized (mStatsLock) {
+            mWakeLock.acquire();
+            try {
+                performPollLocked(detailedPoll, forcePersist);
+            } finally {
+                mWakeLock.release();
+            }
+        }
+    }
+
+    /**
      * Periodic poll operation, reading current statistics and recording into
      * {@link NetworkStatsHistory}.
      *
@@ -596,6 +664,7 @@
      */
     private void performPollLocked(boolean detailedPoll, boolean forcePersist) {
         if (LOGV) Slog.v(TAG, "performPollLocked()");
+        final long startRealtime = SystemClock.elapsedRealtime();
 
         // try refreshing time source when stale
         if (mTime.getCacheAge() > mSettings.getTimeCacheMaxAge()) {
@@ -605,6 +674,7 @@
         // TODO: consider marking "untrusted" times in historical stats
         final long currentTime = mTime.hasCache() ? mTime.currentTimeMillis()
                 : System.currentTimeMillis();
+        final long persistThreshold = mSettings.getPersistThreshold();
 
         final NetworkStats networkSnapshot;
         final NetworkStats uidSnapshot;
@@ -620,30 +690,32 @@
         }
 
         performNetworkPollLocked(networkSnapshot, currentTime);
-        if (detailedPoll) {
-            performUidPollLocked(uidSnapshot, currentTime);
+
+        // persist when enough network data has occurred
+        final NetworkStats persistNetworkDelta = computeStatsDelta(
+                mLastPersistNetworkSnapshot, networkSnapshot, true);
+        if (forcePersist || persistNetworkDelta.getTotalBytes() > persistThreshold) {
+            writeNetworkStatsLocked();
+            mLastPersistNetworkSnapshot = networkSnapshot;
         }
 
-        // decide if enough has changed to trigger persist
-        final NetworkStats persistDelta = computeStatsDelta(
-                mLastPersistNetworkSnapshot, networkSnapshot, true);
-        final long persistThreshold = mSettings.getPersistThreshold();
+        if (detailedPoll) {
+            performUidPollLocked(uidSnapshot, currentTime);
 
-        NetworkStats.Entry entry = null;
-        for (String iface : persistDelta.getUniqueIfaces()) {
-            final int index = persistDelta.findIndex(iface, UID_ALL, SET_DEFAULT, TAG_NONE);
-            entry = persistDelta.getValues(index, entry);
-            if (forcePersist || entry.rxBytes > persistThreshold
-                    || entry.txBytes > persistThreshold) {
-                writeNetworkStatsLocked();
-                if (mUidStatsLoaded) {
-                    writeUidStatsLocked();
-                }
+            // persist when enough network data has occurred
+            final NetworkStats persistUidDelta = computeStatsDelta(
+                    mLastPersistUidSnapshot, uidSnapshot, true);
+            if (forcePersist || persistUidDelta.getTotalBytes() > persistThreshold) {
+                writeUidStatsLocked();
                 mLastPersistNetworkSnapshot = networkSnapshot;
-                break;
             }
         }
 
+        if (LOGV) {
+            final long duration = SystemClock.elapsedRealtime() - startRealtime;
+            Slog.v(TAG, "performPollLocked() took " + duration + "ms");
+        }
+
         // finally, dispatch updated event to any listeners
         final Intent updatedIntent = new Intent(ACTION_NETWORK_STATS_UPDATED);
         updatedIntent.setFlags(Intent.FLAG_RECEIVER_REGISTERED_ONLY);
@@ -656,7 +728,7 @@
     private void performNetworkPollLocked(NetworkStats networkSnapshot, long currentTime) {
         final HashSet<String> unknownIface = Sets.newHashSet();
 
-        final NetworkStats delta = computeStatsDelta(mLastNetworkSnapshot, networkSnapshot, false);
+        final NetworkStats delta = computeStatsDelta(mLastPollNetworkSnapshot, networkSnapshot, false);
         final long timeStart = currentTime - delta.getElapsedRealtime();
 
         NetworkStats.Entry entry = null;
@@ -678,7 +750,7 @@
             history.removeBucketsBefore(currentTime - maxHistory);
         }
 
-        mLastNetworkSnapshot = networkSnapshot;
+        mLastPollNetworkSnapshot = networkSnapshot;
 
         if (LOGD && unknownIface.size() > 0) {
             Slog.w(TAG, "unknown interfaces " + unknownIface.toString() + ", ignoring those stats");
@@ -691,9 +763,9 @@
     private void performUidPollLocked(NetworkStats uidSnapshot, long currentTime) {
         ensureUidStatsLoadedLocked();
 
-        final NetworkStats delta = computeStatsDelta(mLastUidSnapshot, uidSnapshot, false);
+        final NetworkStats delta = computeStatsDelta(mLastPollUidSnapshot, uidSnapshot, false);
         final NetworkStats operationsDelta = computeStatsDelta(
-                mLastOperationsSnapshot, mOperations, false);
+                mLastPollOperationsSnapshot, mOperations, false);
         final long timeStart = currentTime - delta.getElapsedRealtime();
 
         NetworkStats.Entry entry = null;
@@ -731,8 +803,8 @@
             }
         }
 
-        mLastUidSnapshot = uidSnapshot;
-        mLastOperationsSnapshot = mOperations;
+        mLastPollUidSnapshot = uidSnapshot;
+        mLastPollOperationsSnapshot = mOperations;
         mOperations = new NetworkStats(0L, 10);
     }
 
@@ -1162,8 +1234,12 @@
         /** {@inheritDoc} */
         public boolean handleMessage(Message msg) {
             switch (msg.what) {
-                case MSG_FORCE_UPDATE: {
-                    forceUpdate();
+                case MSG_PERFORM_POLL: {
+                    performPoll(false, false);
+                    return true;
+                }
+                case MSG_PERFORM_POLL_DETAILED: {
+                    performPoll(true, false);
                     return true;
                 }
                 default: {
@@ -1226,10 +1302,10 @@
         }
 
         public long getPollInterval() {
-            return getSecureLong(NETSTATS_POLL_INTERVAL, 15 * MINUTE_IN_MILLIS);
+            return getSecureLong(NETSTATS_POLL_INTERVAL, 30 * MINUTE_IN_MILLIS);
         }
         public long getPersistThreshold() {
-            return getSecureLong(NETSTATS_PERSIST_THRESHOLD, 16 * KB_IN_BYTES);
+            return getSecureLong(NETSTATS_PERSIST_THRESHOLD, 512 * KB_IN_BYTES);
         }
         public long getNetworkBucketDuration() {
             return getSecureLong(NETSTATS_NETWORK_BUCKET_DURATION, HOUR_IN_MILLIS);
diff --git a/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java b/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java
index 6138490..6dd8cd6 100644
--- a/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java
@@ -38,6 +38,7 @@
 import static android.text.format.DateUtils.WEEK_IN_MILLIS;
 import static com.android.server.net.NetworkStatsService.ACTION_NETWORK_STATS_POLL;
 import static org.easymock.EasyMock.anyLong;
+import static org.easymock.EasyMock.capture;
 import static org.easymock.EasyMock.createMock;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
@@ -49,6 +50,7 @@
 import android.app.PendingIntent;
 import android.content.Intent;
 import android.net.IConnectivityManager;
+import android.net.INetworkManagementEventObserver;
 import android.net.LinkProperties;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
@@ -65,10 +67,10 @@
 import com.android.server.net.NetworkStatsService;
 import com.android.server.net.NetworkStatsService.NetworkStatsSettings;
 
+import org.easymock.Capture;
 import org.easymock.EasyMock;
 
 import java.io.File;
-import java.util.concurrent.Future;
 
 import libcore.io.IoUtils;
 
@@ -105,6 +107,7 @@
     private IConnectivityManager mConnManager;
 
     private NetworkStatsService mService;
+    private INetworkManagementEventObserver mNetworkObserver;
 
     @Override
     public void setUp() throws Exception {
@@ -132,13 +135,20 @@
         expectDefaultSettings();
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(buildEmptyStats());
-        final Future<?> firstPoll = expectSystemReady();
+        expectSystemReady();
+
+        // catch INetworkManagementEventObserver during systemReady()
+        final Capture<INetworkManagementEventObserver> networkObserver = new Capture<
+                INetworkManagementEventObserver>();
+        mNetManager.registerObserver(capture(networkObserver));
+        expectLastCall().atLeastOnce();
 
         replay();
         mService.systemReady();
-        firstPoll.get();
         verifyAndReset();
 
+        mNetworkObserver = networkObserver.getValue();
+
     }
 
     @Override
@@ -183,6 +193,7 @@
         expectNetworkStatsSummary(new NetworkStats(getElapsedRealtime(), 1)
                 .addIfaceValues(TEST_IFACE, 1024L, 1L, 2048L, 2L));
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
@@ -199,6 +210,7 @@
         expectNetworkStatsSummary(new NetworkStats(getElapsedRealtime(), 1)
                 .addIfaceValues(TEST_IFACE, 4096L, 4L, 8192L, 8L));
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
@@ -238,6 +250,7 @@
                 .addValues(TEST_IFACE, UID_RED, SET_FOREGROUND, TAG_NONE, 512L, 4L, 256L, 2L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_FOREGROUND, 0xFAAD, 256L, 2L, 128L, 1L, 0L)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 128L, 1L, 128L, 1L, 0L));
+        expectNetworkStatsPoll();
 
         mService.setUidForeground(UID_RED, false);
         mService.incrementOperationCount(UID_RED, 0xFAAD, 4);
@@ -273,11 +286,18 @@
         expectDefaultSettings();
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(buildEmptyStats());
-        final Future<?> firstPoll = expectSystemReady();
+        expectSystemReady();
+
+        // catch INetworkManagementEventObserver during systemReady()
+        final Capture<INetworkManagementEventObserver> networkObserver = new Capture<
+                INetworkManagementEventObserver>();
+        mNetManager.registerObserver(capture(networkObserver));
+        expectLastCall().atLeastOnce();
 
         replay();
         mService.systemReady();
-        firstPoll.get();
+
+        mNetworkObserver = networkObserver.getValue();
 
         // after systemReady(), we should have historical stats loaded again
         assertNetworkTotal(sTemplateWifi, 1024L, 8L, 2048L, 16L, 0);
@@ -312,6 +332,7 @@
         expectNetworkStatsSummary(new NetworkStats(getElapsedRealtime(), 1)
                 .addIfaceValues(TEST_IFACE, 512L, 4L, 512L, 4L));
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
@@ -329,6 +350,7 @@
         expectSettings(0L, 30 * MINUTE_IN_MILLIS, WEEK_IN_MILLIS);
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
@@ -363,6 +385,7 @@
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1536L, 12L, 512L, 4L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, 512L, 4L, 512L, 4L, 0L)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 512L, 4L, 0L, 0L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xF00D, 10);
 
@@ -384,6 +407,7 @@
         expectNetworkState(buildMobile3gState(IMSI_2));
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(CONNECTIVITY_ACTION));
@@ -399,6 +423,7 @@
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 128L, 1L, 1024L, 8L, 0L)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, 0xFAAD, 128L, 1L, 1024L, 8L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_BLUE, 0xFAAD, 10);
 
@@ -441,6 +466,7 @@
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xFAAD, 16L, 1L, 16L, 1L, 0L)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 4096L, 258L, 512L, 32L, 0L)
                 .addValues(TEST_IFACE, UID_GREEN, SET_DEFAULT, TAG_NONE, 16L, 1L, 16L, 1L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xFAAD, 10);
 
@@ -494,6 +520,7 @@
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1024L, 8L, 1024L, 8L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, 512L, 4L, 512L, 4L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xF00D, 5);
 
@@ -511,6 +538,7 @@
         expectNetworkState(buildMobile4gState());
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(CONNECTIVITY_ACTION));
@@ -525,6 +553,7 @@
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 512L, 4L, 256L, 2L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xFAAD, 512L, 4L, 256L, 2L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xFAAD, 5);
 
@@ -558,6 +587,7 @@
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 50L, 5L, 50L, 5L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, 10L, 1L, 10L, 1L, 0L)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 1024L, 8L, 512L, 4L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xF00D, 1);
 
@@ -576,6 +606,7 @@
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .addValues(TEST_IFACE, UID_BLUE, SET_DEFAULT, TAG_NONE, 2048L, 16L, 1024L, 8L, 0L));
+        expectNetworkStatsPoll();
 
         replay();
         mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
@@ -617,6 +648,7 @@
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 128L, 2L, 128L, 2L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, 64L, 1L, 64L, 1L, 0L));
+        expectNetworkStatsPoll();
 
         mService.incrementOperationCount(UID_RED, 0xF00D, 1);
 
@@ -637,6 +669,7 @@
                 .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, 64L, 1L, 64L, 1L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_FOREGROUND, TAG_NONE, 32L, 2L, 32L, 2L, 0L)
                 .addValues(TEST_IFACE, UID_RED, SET_FOREGROUND, 0xFAAD, 1L, 1L, 1L, 1L, 0L));
+        expectNetworkStatsPoll();
 
         mService.setUidForeground(UID_RED, true);
         mService.incrementOperationCount(UID_RED, 0xFAAD, 1);
@@ -679,7 +712,7 @@
                 txPackets, operations);
     }
 
-    private Future<?> expectSystemReady() throws Exception {
+    private void expectSystemReady() throws Exception {
         mAlarmManager.remove(isA(PendingIntent.class));
         expectLastCall().anyTimes();
 
@@ -687,8 +720,8 @@
                 eq(AlarmManager.ELAPSED_REALTIME), anyLong(), anyLong(), isA(PendingIntent.class));
         expectLastCall().atLeastOnce();
 
-        return mServiceContext.nextBroadcastIntent(
-                NetworkStatsService.ACTION_NETWORK_STATS_UPDATED);
+        mNetManager.setGlobalAlert(anyLong());
+        expectLastCall().atLeastOnce();
     }
 
     private void expectNetworkState(NetworkState... state) throws Exception {
@@ -727,6 +760,11 @@
         expect(mTime.getCacheCertainty()).andReturn(0L).anyTimes();
     }
 
+    private void expectNetworkStatsPoll() throws Exception {
+        mNetManager.setGlobalAlert(anyLong());
+        expectLastCall().anyTimes();
+    }
+
     private void assertStatsFilesExist(boolean exist) {
         final File networkFile = new File(mStatsDir, "netstats.bin");
         final File uidFile = new File(mStatsDir, "netstats_uid.bin");