Implementation of data usage callbacks.

NetworkStatsService will register data usage requests
and keep data usage stats scoped to the request.

There are different types of data usage requests
- scoped to a set of NetworkTemplate; these are restrictred to
device owners and carrier apps and allow the caller to monitor
all activity on the specified interfaces.
- scoped to all uids visible to the user, if the user has
android.Manifest.permission#PACKAGE_USAGE_STATS permission.
The set of uids may change over time, so we keep track of that.
- scoped to a set of uids given by the caller, granted that
the caller has access to those uids.
- scoped to the caller's own data usage. This doesn't require
PACKAGE_USAGE_STATS.

Bug: 25812785
Change-Id: Ie11f35fc1f29d0dbe82f7fc924b169bb55c76708
diff --git a/core/java/android/app/usage/NetworkStatsManager.java b/core/java/android/app/usage/NetworkStatsManager.java
index 13aeef0..2e3aca4 100644
--- a/core/java/android/app/usage/NetworkStatsManager.java
+++ b/core/java/android/app/usage/NetworkStatsManager.java
@@ -25,9 +25,15 @@
 import android.net.DataUsageRequest;
 import android.net.NetworkIdentity;
 import android.net.NetworkTemplate;
+import android.net.INetworkStatsService;
+import android.os.Binder;
 import android.os.Build;
+import android.os.Message;
+import android.os.Messenger;
 import android.os.Handler;
+import android.os.Looper;
 import android.os.RemoteException;
+import android.os.ServiceManager;
 import android.util.Log;
 
 /**
@@ -75,16 +81,26 @@
  * not included.
  */
 public class NetworkStatsManager {
-    private final static String TAG = "NetworkStatsManager";
+    private static final String TAG = "NetworkStatsManager";
+    private static final boolean DBG = false;
+
+    /** @hide */
+    public static final int CALLBACK_LIMIT_REACHED = 0;
+    /** @hide */
+    public static final int CALLBACK_RELEASED = 1;
 
     private final Context mContext;
+    private final INetworkStatsService mService;
 
     /**
      * {@hide}
      */
     public NetworkStatsManager(Context context) {
         mContext = context;
+        mService = INetworkStatsService.Stub.asInterface(
+                ServiceManager.getService(Context.NETWORK_STATS_SERVICE));
     }
+
     /**
      * Query network usage statistics summaries. Result is summarised data usage for the whole
      * device. Result is a single Bucket aggregated over time, state, uid, tag and roaming. This
@@ -322,7 +338,40 @@
         checkNotNull(policy, "DataUsagePolicy cannot be null");
         checkNotNull(callback, "DataUsageCallback cannot be null");
 
-        // TODO: Implement stub.
+        final Looper looper;
+        if (handler == null) {
+            looper = Looper.myLooper();
+        } else {
+            looper = handler.getLooper();
+        }
+
+        if (DBG) Log.d(TAG, "registerDataUsageCallback called with " + policy);
+
+        NetworkTemplate[] templates;
+        if (policy.subscriberIds == null || policy.subscriberIds.length == 0) {
+            templates = new NetworkTemplate[1];
+            templates[0] = createTemplate(policy.networkType, null /* subscriberId */);
+        } else {
+            templates = new NetworkTemplate[policy.subscriberIds.length];
+            for (int i = 0; i < policy.subscriberIds.length; i++) {
+                templates[i] = createTemplate(policy.networkType, policy.subscriberIds[i]);
+            }
+        }
+        DataUsageRequest request = new DataUsageRequest(DataUsageRequest.REQUEST_ID_UNSET,
+                templates, policy.uids, policy.thresholdInBytes);
+        try {
+            CallbackHandler callbackHandler = new CallbackHandler(looper, callback);
+            callback.request = mService.registerDataUsageCallback(
+                    mContext.getOpPackageName(), request, new Messenger(callbackHandler),
+                    new Binder());
+            if (DBG) Log.d(TAG, "registerDataUsageCallback returned " + callback.request);
+
+            if (callback.request == null) {
+                Log.e(TAG, "Request from callback is null; should not happen");
+            }
+        } catch (RemoteException e) {
+            if (DBG) Log.d(TAG, "Remote exception when registering callback");
+        }
     }
 
     /**
@@ -331,9 +380,15 @@
      * @param callback The {@link DataUsageCallback} used when registering.
      */
     public void unregisterDataUsageCallback(DataUsageCallback callback) {
-        checkNotNull(callback, "DataUsageCallback cannot be null");
-
-        // TODO: Implement stub.
+        if (callback == null || callback.request == null
+                || callback.request.requestId == DataUsageRequest.REQUEST_ID_UNSET) {
+            throw new IllegalArgumentException("Invalid DataUsageCallback");
+        }
+        try {
+            mService.unregisterDataUsageRequest(callback.request);
+        } catch (RemoteException e) {
+            if (DBG) Log.d(TAG, "Remote exception when unregistering callback");
+        }
     }
 
     /**
@@ -366,4 +421,38 @@
         }
         return template;
     }
+
+    private static class CallbackHandler extends Handler {
+        private DataUsageCallback mCallback;
+        CallbackHandler(Looper looper, DataUsageCallback callback) {
+            super(looper);
+            mCallback = callback;
+        }
+
+        @Override
+        public void handleMessage(Message message) {
+            DataUsageRequest request =
+                    (DataUsageRequest) getObject(message, DataUsageRequest.PARCELABLE_KEY);
+
+            switch (message.what) {
+                case CALLBACK_LIMIT_REACHED: {
+                    if (mCallback != null) {
+                        mCallback.onLimitReached();
+                    } else {
+                        Log.e(TAG, "limit reached with released callback for " + request);
+                    }
+                    break;
+                }
+                case CALLBACK_RELEASED: {
+                    if (DBG) Log.d(TAG, "callback released for " + request);
+                    mCallback = null;
+                    break;
+                }
+            }
+        }
+
+        private static Object getObject(Message msg, String key) {
+            return msg.getData().getParcelable(key);
+        }
+    }
 }
diff --git a/core/java/android/net/DataUsageRequest.java b/core/java/android/net/DataUsageRequest.java
index 0e46f4c..5e96cc1 100644
--- a/core/java/android/net/DataUsageRequest.java
+++ b/core/java/android/net/DataUsageRequest.java
@@ -34,6 +34,11 @@
     /**
      * @hide
      */
+    public static final String PARCELABLE_KEY = "DataUsageRequest";
+
+    /**
+     * @hide
+     */
     public static final int REQUEST_ID_UNSET = 0;
 
     /**
diff --git a/core/java/android/net/INetworkStatsService.aidl b/core/java/android/net/INetworkStatsService.aidl
index 6436e42..2eea940 100644
--- a/core/java/android/net/INetworkStatsService.aidl
+++ b/core/java/android/net/INetworkStatsService.aidl
@@ -16,10 +16,13 @@
 
 package android.net;
 
+import android.net.DataUsageRequest;
 import android.net.INetworkStatsSession;
 import android.net.NetworkStats;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
+import android.os.IBinder;
+import android.os.Messenger;
 
 /** {@hide} */
 interface INetworkStatsService {
@@ -57,4 +60,11 @@
     /** Advise persistance threshold; may be overridden internally. */
     void advisePersistThreshold(long thresholdBytes);
 
+    /** Registers a callback on data usage. */
+    DataUsageRequest registerDataUsageCallback(String callingPackage,
+            in DataUsageRequest request, in Messenger messenger, in IBinder binder);
+
+    /** Unregisters a callback on data usage. */
+    void unregisterDataUsageRequest(in DataUsageRequest request);
+
 }
diff --git a/services/core/java/com/android/server/net/NetworkStatsAccess.java b/services/core/java/com/android/server/net/NetworkStatsAccess.java
index 479b065..98fe770 100644
--- a/services/core/java/com/android/server/net/NetworkStatsAccess.java
+++ b/services/core/java/com/android/server/net/NetworkStatsAccess.java
@@ -17,6 +17,7 @@
 package com.android.server.net;
 
 import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY;
+import static android.net.NetworkStats.UID_ALL;
 import static android.net.TrafficStats.UID_REMOVED;
 import static android.net.TrafficStats.UID_TETHERING;
 
@@ -48,6 +49,7 @@
     @IntDef({
             Level.DEFAULT,
             Level.USER,
+            Level.DEVICESUMMARY,
             Level.DEVICE,
     })
     @Retention(RetentionPolicy.SOURCE)
@@ -147,6 +149,12 @@
                 // Device-level access - can access usage for any uid.
                 return true;
             case NetworkStatsAccess.Level.DEVICESUMMARY:
+                // Can access usage for any app running in the same user, along
+                // with some special uids (system, removed, or tethering) and
+                // anonymized uids
+                return uid == android.os.Process.SYSTEM_UID || uid == UID_REMOVED
+                        || uid == UID_TETHERING || uid == UID_ALL
+                        || UserHandle.getUserId(uid) == UserHandle.getUserId(callerUid);
             case NetworkStatsAccess.Level.USER:
                 // User-level access - can access usage for any app running in the same user, along
                 // with some special uids (system, removed, or tethering).
diff --git a/services/core/java/com/android/server/net/NetworkStatsCollection.java b/services/core/java/com/android/server/net/NetworkStatsCollection.java
index eec7d93..d986e94b 100644
--- a/services/core/java/com/android/server/net/NetworkStatsCollection.java
+++ b/services/core/java/com/android/server/net/NetworkStatsCollection.java
@@ -135,7 +135,11 @@
     }
 
     public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel) {
-        final int callerUid = Binder.getCallingUid();
+        return getRelevantUids(accessLevel, Binder.getCallingUid());
+    }
+
+    public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel,
+                final int callerUid) {
         IntArray uids = new IntArray();
         for (int i = 0; i < mStats.size(); i++) {
             final Key key = mStats.keyAt(i);
@@ -169,7 +173,17 @@
     public NetworkStatsHistory getHistory(
             NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end,
             @NetworkStatsAccess.Level int accessLevel) {
-        final int callerUid = Binder.getCallingUid();
+        return getHistory(template, uid, set, tag, fields, start, end, accessLevel,
+                Binder.getCallingUid());
+    }
+
+    /**
+     * Combine all {@link NetworkStatsHistory} in this collection which match
+     * the requested parameters.
+     */
+    public NetworkStatsHistory getHistory(
+            NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end,
+            @NetworkStatsAccess.Level int accessLevel, int callerUid) {
         if (!NetworkStatsAccess.isAccessibleToUser(uid, callerUid, accessLevel)) {
             throw new SecurityException("Network stats history of uid " + uid
                     + " is forbidden for caller " + callerUid);
@@ -198,6 +212,15 @@
      */
     public NetworkStats getSummary(NetworkTemplate template, long start, long end,
             @NetworkStatsAccess.Level int accessLevel) {
+        return getSummary(template, start, end, accessLevel, Binder.getCallingUid());
+    }
+
+    /**
+     * Summarize all {@link NetworkStatsHistory} in this collection which match
+     * the requested parameters.
+     */
+    public NetworkStats getSummary(NetworkTemplate template, long start, long end,
+            @NetworkStatsAccess.Level int accessLevel, int callerUid) {
         final long now = System.currentTimeMillis();
 
         final NetworkStats stats = new NetworkStats(end - start, 24);
@@ -207,7 +230,6 @@
         final NetworkStats.Entry entry = new NetworkStats.Entry();
         NetworkStatsHistory.Entry historyEntry = null;
 
-        final int callerUid = Binder.getCallingUid();
         for (int i = 0; i < mStats.size(); i++) {
             final Key key = mStats.keyAt(i);
             if (templateMatches(template, key.ident)
diff --git a/services/core/java/com/android/server/net/NetworkStatsObservers.java b/services/core/java/com/android/server/net/NetworkStatsObservers.java
new file mode 100644
index 0000000..2f55562
--- /dev/null
+++ b/services/core/java/com/android/server/net/NetworkStatsObservers.java
@@ -0,0 +1,493 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net;
+
+import static android.net.TrafficStats.MB_IN_BYTES;
+import static com.android.internal.util.Preconditions.checkArgument;
+
+import android.app.usage.NetworkStatsManager;
+import android.net.DataUsageRequest;
+import android.net.NetworkStats;
+import android.net.NetworkStats.NonMonotonicObserver;
+import android.net.NetworkStatsHistory;
+import android.net.NetworkTemplate;
+import android.os.Binder;
+import android.os.Bundle;
+import android.os.Looper;
+import android.os.Message;
+import android.os.Messenger;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.IBinder;
+import android.os.Process;
+import android.os.RemoteException;
+import android.util.ArrayMap;
+import android.util.IntArray;
+import android.util.SparseArray;
+import android.util.Slog;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.internal.net.VpnInfo;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Manages observers of {@link NetworkStats}. Allows observers to be notified when
+ * data usage has been reported in {@link NetworkStatsService}. An observer can set
+ * a threshold of how much data it cares about to be notified.
+ */
+class NetworkStatsObservers {
+    private static final String TAG = "NetworkStatsObservers";
+    private static final boolean LOGV = true;
+
+    private static final long MIN_THRESHOLD_BYTES = 2 * MB_IN_BYTES;
+
+    private static final int MSG_REGISTER = 1;
+    private static final int MSG_UNREGISTER = 2;
+    private static final int MSG_UPDATE_STATS = 3;
+
+    // All access to this map must be done from the handler thread.
+    // indexed by DataUsageRequest#requestId
+    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
+
+    // Sequence number of DataUsageRequests
+    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
+
+    // Lazily instantiated when an observer is registered.
+    private Handler mHandler;
+
+    /**
+     * Creates a wrapper that contains the caller context and a normalized request.
+     * The request should be returned to the caller app, and the wrapper should be sent to this
+     * object through #addObserver by the service handler.
+     *
+     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
+     *
+     * @return the normalized request wrapped within {@link RequestInfo}.
+     */
+    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
+                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
+        checkVisibilityUids(callingUid, accessLevel, inputRequest.uids);
+
+        DataUsageRequest request = buildRequest(inputRequest);
+        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
+                accessLevel);
+
+        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
+        getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
+        return request;
+    }
+
+    /**
+     * Unregister a data usage observer.
+     *
+     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
+     */
+    public void unregister(DataUsageRequest request, int callingUid) {
+        getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
+                request));
+    }
+
+    /**
+     * Updates data usage statistics of registered observers and notifies if limits are reached.
+     *
+     * <p>It will update stats asynchronously, so it is safe to call from any thread.
+     */
+    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
+                ArrayMap<String, NetworkIdentitySet> activeIfaces,
+                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
+                VpnInfo[] vpnArray, long currentTime) {
+        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
+                activeUidIfaces, vpnArray, currentTime);
+        getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
+    }
+
+    private Handler getHandler() {
+        if (mHandler == null) {
+            synchronized (this) {
+                if (mHandler == null) {
+                    if (LOGV) Slog.v(TAG, "Creating handler");
+                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
+                }
+            }
+        }
+        return mHandler;
+    }
+
+    @VisibleForTesting
+    protected Looper getHandlerLooperLocked() {
+        HandlerThread handlerThread = new HandlerThread(TAG);
+        handlerThread.start();
+        return handlerThread.getLooper();
+    }
+
+    private Handler.Callback mHandlerCallback = new Handler.Callback() {
+        @Override
+        public boolean handleMessage(Message msg) {
+            switch (msg.what) {
+                case MSG_REGISTER: {
+                    handleRegister((RequestInfo) msg.obj);
+                    return true;
+                }
+                case MSG_UNREGISTER: {
+                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
+                    return true;
+                }
+                case MSG_UPDATE_STATS: {
+                    handleUpdateStats((StatsContext) msg.obj);
+                    return true;
+                }
+                default: {
+                    return false;
+                }
+            }
+        }
+    };
+
+    /**
+     * Adds a {@link RequestInfo} as an observer.
+     * Should only be called from the handler thread otherwise there will be a race condition
+     * on mDataUsageRequests.
+     */
+    private void handleRegister(RequestInfo requestInfo) {
+        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
+    }
+
+    /**
+     * Removes a {@link DataUsageRequest} if the calling uid is authorized.
+     * Should only be called from the handler thread otherwise there will be a race condition
+     * on mDataUsageRequests.
+     */
+    private void handleUnregister(DataUsageRequest request, int callingUid) {
+        RequestInfo requestInfo;
+        requestInfo = mDataUsageRequests.get(request.requestId);
+        if (requestInfo == null) {
+            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
+            return;
+        }
+        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
+            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
+            return;
+        }
+
+        if (LOGV) Slog.v(TAG, "Unregistering " + request);
+        mDataUsageRequests.remove(request.requestId);
+        requestInfo.unlinkDeathRecipient();
+        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
+    }
+
+    private void handleUpdateStats(StatsContext statsContext) {
+        if (mDataUsageRequests.size() == 0) {
+            if (LOGV) Slog.v(TAG, "No registered listeners of data usage");
+            return;
+        }
+
+        if (LOGV) Slog.v(TAG, "Checking if any registered observer needs to be notified");
+        for (int i = 0; i < mDataUsageRequests.size(); i++) {
+            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
+            requestInfo.updateStats(statsContext);
+        }
+    }
+
+    private DataUsageRequest buildRequest(DataUsageRequest request) {
+        // Cap the minimum threshold to a safe default to avoid too many callbacks
+        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
+        if (thresholdInBytes < request.thresholdInBytes) {
+            Slog.w(TAG, "Threshold was too low for " + request
+                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
+        }
+        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
+                request.templates, request.uids, thresholdInBytes);
+    }
+
+    private RequestInfo buildRequestInfo(DataUsageRequest request,
+                Messenger messenger, IBinder binder, int callingUid,
+                @NetworkStatsAccess.Level int accessLevel) {
+        if (accessLevel <= NetworkStatsAccess.Level.USER
+                || request.uids != null && request.uids.length > 0) {
+            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
+                    accessLevel);
+        } else {
+            // Safety check in case a new access level is added and we forgot to update this
+            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
+            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
+                    accessLevel);
+        }
+    }
+
+    private void checkVisibilityUids(int callingUid, @NetworkStatsAccess.Level int accessLevel,
+                int[] uids) {
+        if (uids == null) {
+            return;
+        }
+        for (int i = 0; i < uids.length; i++) {
+            if (!NetworkStatsAccess.isAccessibleToUser(uids[i], callingUid, accessLevel)) {
+                throw new SecurityException("Caller " + callingUid + " cannot monitor network stats"
+                        + " for uid " + uids[i] + " with accessLevel " + accessLevel);
+            }
+        }
+    }
+
+    /**
+     * Tracks information relevant to a data usage observer.
+     * It will notice when the calling process dies so we can self-expire.
+     */
+    private abstract static class RequestInfo implements IBinder.DeathRecipient {
+        private final NetworkStatsObservers mStatsObserver;
+        protected final DataUsageRequest mRequest;
+        private final Messenger mMessenger;
+        private final IBinder mBinder;
+        protected final int mCallingUid;
+        protected final @NetworkStatsAccess.Level int mAccessLevel;
+        protected NetworkStatsRecorder mRecorder;
+        protected NetworkStatsCollection mCollection;
+
+        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
+                    Messenger messenger, IBinder binder, int callingUid,
+                    @NetworkStatsAccess.Level int accessLevel) {
+            mStatsObserver = statsObserver;
+            mRequest = request;
+            mMessenger = messenger;
+            mBinder = binder;
+            mCallingUid = callingUid;
+            mAccessLevel = accessLevel;
+
+            try {
+                mBinder.linkToDeath(this, 0);
+            } catch (RemoteException e) {
+                binderDied();
+            }
+        }
+
+        @Override
+        public void binderDied() {
+            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
+                    + mRequest + ", " + mBinder + ")");
+            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
+            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
+        }
+
+        @Override
+        public String toString() {
+            return "RequestInfo from uid:" + mCallingUid
+                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
+        }
+
+        private void unlinkDeathRecipient() {
+            if (mBinder != null) {
+                mBinder.unlinkToDeath(this, 0);
+            }
+        }
+
+        /**
+         * Update stats given the samples and interface to identity mappings.
+         */
+        private void updateStats(StatsContext statsContext) {
+            if (mRecorder == null) {
+                // First run; establish baseline stats
+                resetRecorder();
+                recordSample(statsContext);
+                return;
+            }
+            recordSample(statsContext);
+
+            if (checkStats()) {
+                resetRecorder();
+                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
+            }
+        }
+
+        private void callCallback(int callbackType) {
+            Bundle bundle = new Bundle();
+            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
+            Message msg = Message.obtain();
+            msg.what = callbackType;
+            msg.setData(bundle);
+            try {
+                if (LOGV) {
+                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
+                            + " for " + mRequest);
+                }
+                mMessenger.send(msg);
+            } catch (RemoteException e) {
+                // May occur naturally in the race of binder death.
+                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
+            }
+        }
+
+        private void resetRecorder() {
+            mRecorder = new NetworkStatsRecorder();
+            mCollection = mRecorder.getSinceBoot();
+        }
+
+        protected abstract boolean checkStats();
+
+        protected abstract void recordSample(StatsContext statsContext);
+
+        private String callbackTypeToName(int callbackType) {
+            switch (callbackType) {
+                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
+                    return "LIMIT_REACHED";
+                case NetworkStatsManager.CALLBACK_RELEASED:
+                    return "RELEASED";
+                default:
+                    return "UNKNOWN";
+            }
+        }
+    }
+
+    private static class NetworkUsageRequestInfo extends RequestInfo {
+        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
+                    Messenger messenger, IBinder binder, int callingUid,
+                    @NetworkStatsAccess.Level int accessLevel) {
+            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
+        }
+
+        @Override
+        protected boolean checkStats() {
+            for (int i = 0; i < mRequest.templates.length; i++) {
+                long bytesSoFar = getTotalBytesForNetwork(mRequest.templates[i]);
+                if (LOGV) {
+                    Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
+                            + mRequest.templates[i]);
+                }
+                if (bytesSoFar > mRequest.thresholdInBytes) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        @Override
+        protected void recordSample(StatsContext statsContext) {
+            // Recorder does not need to be locked in this context since only the handler
+            // thread will update it
+            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
+                    statsContext.mVpnArray, statsContext.mCurrentTime);
+        }
+
+        /**
+         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
+         * over all buckets, which in this case should be only one since we built it big enough
+         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
+         */
+        private long getTotalBytesForNetwork(NetworkTemplate template) {
+            NetworkStats stats = mCollection.getSummary(template,
+                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
+                    mAccessLevel, mCallingUid);
+            if (LOGV) {
+                Slog.v(TAG, "Netstats for " + template + ": " + stats);
+            }
+            return stats.getTotalBytes();
+        }
+    }
+
+    private static class UserUsageRequestInfo extends RequestInfo {
+        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
+                    Messenger messenger, IBinder binder, int callingUid,
+                    @NetworkStatsAccess.Level int accessLevel) {
+            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
+        }
+
+        @Override
+        protected boolean checkStats() {
+            int[] uidsToMonitor = getUidsToMonitor();
+
+            for (int i = 0; i < mRequest.templates.length; i++) {
+                for (int j = 0; j < uidsToMonitor.length; j++) {
+                    long bytesSoFar = getTotalBytesForNetworkUid(mRequest.templates[i],
+                            uidsToMonitor[j]);
+
+                    if (LOGV) {
+                        Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
+                                + mRequest.templates[i] + " for uid=" + uidsToMonitor[j]);
+                    }
+                    if (bytesSoFar > mRequest.thresholdInBytes) {
+                        return true;
+                    }
+                }
+            }
+            return false;
+        }
+
+        @Override
+        protected void recordSample(StatsContext statsContext) {
+            // Recorder does not need to be locked in this context since only the handler
+            // thread will update it
+            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
+                    statsContext.mVpnArray, statsContext.mCurrentTime);
+        }
+
+        /**
+         * Reads all stats matching the given template and uid. Ther history will likely only
+         * contain one bucket per ident since we build it big enough that it will outlive the
+         * caller lifetime.
+         */
+        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
+            try {
+                NetworkStatsHistory history = mCollection.getHistory(template, uid,
+                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
+                        NetworkStatsHistory.FIELD_ALL,
+                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
+                        mAccessLevel, mCallingUid);
+                return history.getTotalBytes();
+            } catch (SecurityException e) {
+                if (LOGV) {
+                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
+                            + uid);
+                }
+                return 0;
+            }
+        }
+
+        private int[] getUidsToMonitor() {
+            if (mRequest.uids == null || mRequest.uids.length == 0) {
+                return mCollection.getRelevantUids(mAccessLevel, mCallingUid);
+            }
+            // Pick only uids from the request that are currently accessible to the user
+            IntArray accessibleUids = new IntArray(mRequest.uids.length);
+            for (int i = 0; i < mRequest.uids.length; i++) {
+                int uid = mRequest.uids[i];
+                if (NetworkStatsAccess.isAccessibleToUser(uid, mCallingUid, mAccessLevel)) {
+                    accessibleUids.add(uid);
+                }
+            }
+            return accessibleUids.toArray();
+        }
+    }
+
+    private static class StatsContext {
+        NetworkStats mXtSnapshot;
+        NetworkStats mUidSnapshot;
+        ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
+        ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
+        VpnInfo[] mVpnArray;
+        long mCurrentTime;
+
+        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
+                ArrayMap<String, NetworkIdentitySet> activeIfaces,
+                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
+                VpnInfo[] vpnArray, long currentTime) {
+            mXtSnapshot = xtSnapshot;
+            mUidSnapshot = uidSnapshot;
+            mActiveIfaces = activeIfaces;
+            mActiveUidIfaces = activeUidIfaces;
+            mVpnArray = vpnArray;
+            mCurrentTime = currentTime;
+        }
+    }
+}
diff --git a/services/core/java/com/android/server/net/NetworkStatsRecorder.java b/services/core/java/com/android/server/net/NetworkStatsRecorder.java
index c091960..04dc917 100644
--- a/services/core/java/com/android/server/net/NetworkStatsRecorder.java
+++ b/services/core/java/com/android/server/net/NetworkStatsRecorder.java
@@ -19,6 +19,7 @@
 import static android.net.NetworkStats.TAG_NONE;
 import static android.net.TrafficStats.KB_IN_BYTES;
 import static android.net.TrafficStats.MB_IN_BYTES;
+import static android.text.format.DateUtils.YEAR_IN_MILLIS;
 import static com.android.internal.util.Preconditions.checkNotNull;
 
 import android.net.NetworkStats;
@@ -54,7 +55,7 @@
  * Logic to record deltas between periodic {@link NetworkStats} snapshots into
  * {@link NetworkStatsHistory} that belong to {@link NetworkStatsCollection}.
  * Keeps pending changes in memory until they pass a specific threshold, in
- * bytes. Uses {@link FileRotator} for persistence logic.
+ * bytes. Uses {@link FileRotator} for persistence logic if present.
  * <p>
  * Not inherently thread safe.
  */
@@ -86,6 +87,29 @@
 
     private WeakReference<NetworkStatsCollection> mComplete;
 
+    /**
+     * Non-persisted recorder, with only one bucket. Used by {@link NetworkStatsObservers}.
+     */
+    public NetworkStatsRecorder() {
+        mRotator = null;
+        mObserver = null;
+        mDropBox = null;
+        mCookie = null;
+
+        // set the bucket big enough to have all data in one bucket, but allow some
+        // slack to avoid overflow
+        mBucketDuration = YEAR_IN_MILLIS;
+        mOnlyTags = false;
+
+        mPending = null;
+        mSinceBoot = new NetworkStatsCollection(mBucketDuration);
+
+        mPendingRewriter = null;
+    }
+
+    /**
+     * Persisted recorder.
+     */
     public NetworkStatsRecorder(FileRotator rotator, NonMonotonicObserver<String> observer,
             DropBoxManager dropBox, String cookie, long bucketDuration, boolean onlyTags) {
         mRotator = checkNotNull(rotator, "missing FileRotator");
@@ -110,9 +134,15 @@
 
     public void resetLocked() {
         mLastSnapshot = null;
-        mPending.reset();
-        mSinceBoot.reset();
-        mComplete.clear();
+        if (mPending != null) {
+            mPending.reset();
+        }
+        if (mSinceBoot != null) {
+            mSinceBoot.reset();
+        }
+        if (mComplete != null) {
+            mComplete.clear();
+        }
     }
 
     public NetworkStats.Entry getTotalSinceBootLocked(NetworkTemplate template) {
@@ -120,6 +150,10 @@
                 NetworkStatsAccess.Level.DEVICE).getTotal(null);
     }
 
+    public NetworkStatsCollection getSinceBoot() {
+        return mSinceBoot;
+    }
+
     /**
      * Load complete history represented by {@link FileRotator}. Caches
      * internally as a {@link WeakReference}, and updated with future
@@ -127,6 +161,7 @@
      * as reference is valid.
      */
     public NetworkStatsCollection getOrLoadCompleteLocked() {
+        checkNotNull(mRotator, "missing FileRotator");
         NetworkStatsCollection res = mComplete != null ? mComplete.get() : null;
         if (res == null) {
             res = loadLocked(Long.MIN_VALUE, Long.MAX_VALUE);
@@ -136,6 +171,7 @@
     }
 
     public NetworkStatsCollection getOrLoadPartialLocked(long start, long end) {
+        checkNotNull(mRotator, "missing FileRotator");
         NetworkStatsCollection res = mComplete != null ? mComplete.get() : null;
         if (res == null) {
             res = loadLocked(start, end);
@@ -205,7 +241,9 @@
 
             // only record tag data when requested
             if ((entry.tag == TAG_NONE) != mOnlyTags) {
-                mPending.recordData(ident, entry.uid, entry.set, entry.tag, start, end, entry);
+                if (mPending != null) {
+                    mPending.recordData(ident, entry.uid, entry.set, entry.tag, start, end, entry);
+                }
 
                 // also record against boot stats when present
                 if (mSinceBoot != null) {
@@ -231,6 +269,7 @@
      * {@link #mPersistThresholdBytes}.
      */
     public void maybePersistLocked(long currentTimeMillis) {
+        checkNotNull(mRotator, "missing FileRotator");
         final long pendingBytes = mPending.getTotalBytes();
         if (pendingBytes >= mPersistThresholdBytes) {
             forcePersistLocked(currentTimeMillis);
@@ -243,6 +282,7 @@
      * Force persisting any pending deltas.
      */
     public void forcePersistLocked(long currentTimeMillis) {
+        checkNotNull(mRotator, "missing FileRotator");
         if (mPending.isDirty()) {
             if (LOGD) Slog.d(TAG, "forcePersistLocked() writing for " + mCookie);
             try {
@@ -264,20 +304,26 @@
      * to {@link TrafficStats#UID_REMOVED}.
      */
     public void removeUidsLocked(int[] uids) {
-        try {
-            // Rewrite all persisted data to migrate UID stats
-            mRotator.rewriteAll(new RemoveUidRewriter(mBucketDuration, uids));
-        } catch (IOException e) {
-            Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e);
-            recoverFromWtf();
-        } catch (OutOfMemoryError e) {
-            Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e);
-            recoverFromWtf();
+        if (mRotator != null) {
+            try {
+                // Rewrite all persisted data to migrate UID stats
+                mRotator.rewriteAll(new RemoveUidRewriter(mBucketDuration, uids));
+            } catch (IOException e) {
+                Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e);
+                recoverFromWtf();
+            } catch (OutOfMemoryError e) {
+                Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e);
+                recoverFromWtf();
+            }
         }
 
         // Remove any pending stats
-        mPending.removeUids(uids);
-        mSinceBoot.removeUids(uids);
+        if (mPending != null) {
+            mPending.removeUids(uids);
+        }
+        if (mSinceBoot != null) {
+            mSinceBoot.removeUids(uids);
+        }
 
         // Clear UID from current stats snapshot
         if (mLastSnapshot != null) {
@@ -361,6 +407,8 @@
     }
 
     public void importLegacyNetworkLocked(File file) throws IOException {
+        checkNotNull(mRotator, "missing FileRotator");
+
         // legacy file still exists; start empty to avoid double importing
         mRotator.deleteAll();
 
@@ -379,6 +427,8 @@
     }
 
     public void importLegacyUidLocked(File file) throws IOException {
+        checkNotNull(mRotator, "missing FileRotator");
+
         // legacy file still exists; start empty to avoid double importing
         mRotator.deleteAll();
 
@@ -397,7 +447,9 @@
     }
 
     public void dumpLocked(IndentingPrintWriter pw, boolean fullHistory) {
-        pw.print("Pending bytes: "); pw.println(mPending.getTotalBytes());
+        if (mPending != null) {
+            pw.print("Pending bytes: "); pw.println(mPending.getTotalBytes());
+        }
         if (fullHistory) {
             pw.println("Complete history:");
             getOrLoadCompleteLocked().dump(pw);
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index 3aeceef..2c2e9b9 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -57,6 +57,7 @@
 import static android.text.format.DateUtils.HOUR_IN_MILLIS;
 import static android.text.format.DateUtils.MINUTE_IN_MILLIS;
 import static android.text.format.DateUtils.SECOND_IN_MILLIS;
+import static com.android.internal.util.Preconditions.checkArgument;
 import static com.android.internal.util.Preconditions.checkNotNull;
 import static com.android.server.NetworkManagementService.LIMIT_GLOBAL_ALERT;
 import static com.android.server.NetworkManagementSocketTagger.resetKernelUidStats;
@@ -72,6 +73,7 @@
 import android.content.IntentFilter;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.PackageManager;
+import android.net.DataUsageRequest;
 import android.net.IConnectivityManager;
 import android.net.INetworkManagementEventObserver;
 import android.net.INetworkStatsService;
@@ -90,8 +92,10 @@
 import android.os.Environment;
 import android.os.Handler;
 import android.os.HandlerThread;
+import android.os.IBinder;
 import android.os.INetworkManagementService;
 import android.os.Message;
+import android.os.Messenger;
 import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.SystemClock;
@@ -152,6 +156,7 @@
     private final TrustedTime mTime;
     private final TelephonyManager mTeleManager;
     private final NetworkStatsSettings mSettings;
+    private final NetworkStatsObservers mStatsObservers;
 
     private final File mSystemDir;
     private final File mBaseDir;
@@ -233,43 +238,65 @@
     /** Data layer operation counters for splicing into other structures. */
     private NetworkStats mUidOperations = new NetworkStats(0L, 10);
 
-    private final Handler mHandler;
+    /** Must be set in factory by calling #setHandler. */
+    private Handler mHandler;
+    private Handler.Callback mHandlerCallback;
 
     private boolean mSystemReady;
     private long mPersistThreshold = 2 * MB_IN_BYTES;
     private long mGlobalAlertBytes;
 
-    public NetworkStatsService(
-            Context context, INetworkManagementService networkManager, IAlarmManager alarmManager) {
-        this(context, networkManager, alarmManager, NtpTrustedTime.getInstance(context),
-                getDefaultSystemDir(), new DefaultNetworkStatsSettings(context));
-    }
-
     private static File getDefaultSystemDir() {
         return new File(Environment.getDataDirectory(), "system");
     }
 
-    public NetworkStatsService(Context context, INetworkManagementService networkManager,
-            IAlarmManager alarmManager, TrustedTime time, File systemDir,
-            NetworkStatsSettings settings) {
+    private static File getDefaultBaseDir() {
+        File baseDir = new File(getDefaultSystemDir(), "netstats");
+        baseDir.mkdirs();
+        return baseDir;
+    }
+
+    public static NetworkStatsService create(Context context,
+                INetworkManagementService networkManager) {
+        AlarmManager alarmManager = (AlarmManager) context.getSystemService(Context.ALARM_SERVICE);
+        PowerManager powerManager = (PowerManager) context.getSystemService(Context.POWER_SERVICE);
+        PowerManager.WakeLock wakeLock =
+                powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG);
+
+        NetworkStatsService service = new NetworkStatsService(context, networkManager, alarmManager,
+                wakeLock, NtpTrustedTime.getInstance(context), TelephonyManager.getDefault(),
+                new DefaultNetworkStatsSettings(context), new NetworkStatsObservers(),
+                getDefaultSystemDir(), getDefaultBaseDir());
+
+        HandlerThread handlerThread = new HandlerThread(TAG);
+        Handler.Callback callback = new HandlerCallback(service);
+        handlerThread.start();
+        Handler handler = new Handler(handlerThread.getLooper(), callback);
+        service.setHandler(handler, callback);
+        return service;
+    }
+
+    @VisibleForTesting
+    NetworkStatsService(Context context, INetworkManagementService networkManager,
+            AlarmManager alarmManager, PowerManager.WakeLock wakeLock, TrustedTime time,
+            TelephonyManager teleManager, NetworkStatsSettings settings,
+            NetworkStatsObservers statsObservers, File systemDir, File baseDir) {
         mContext = checkNotNull(context, "missing Context");
         mNetworkManager = checkNotNull(networkManager, "missing INetworkManagementService");
+        mAlarmManager = checkNotNull(alarmManager, "missing AlarmManager");
         mTime = checkNotNull(time, "missing TrustedTime");
-        mTeleManager = checkNotNull(TelephonyManager.getDefault(), "missing TelephonyManager");
         mSettings = checkNotNull(settings, "missing NetworkStatsSettings");
-        mAlarmManager = (AlarmManager) context.getSystemService(Context.ALARM_SERVICE);
+        mTeleManager = checkNotNull(teleManager, "missing TelephonyManager");
+        mWakeLock = checkNotNull(wakeLock, "missing WakeLock");
+        mStatsObservers = checkNotNull(statsObservers, "missing NetworkStatsObservers");
+        mSystemDir = checkNotNull(systemDir, "missing systemDir");
+        mBaseDir = checkNotNull(baseDir, "missing baseDir");
+    }
 
-        final PowerManager powerManager = (PowerManager) context.getSystemService(
-                Context.POWER_SERVICE);
-        mWakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG);
-
-        HandlerThread thread = new HandlerThread(TAG);
-        thread.start();
-        mHandler = new Handler(thread.getLooper(), mHandlerCallback);
-
-        mSystemDir = checkNotNull(systemDir);
-        mBaseDir = new File(systemDir, "netstats");
-        mBaseDir.mkdirs();
+    @VisibleForTesting
+    void setHandler(Handler handler, Handler.Callback callback) {
+        mHandler = handler;
+        mHandlerCallback = callback;
     }
 
     public void bindConnectivityManager(IConnectivityManager connManager) {
@@ -733,6 +760,46 @@
         registerGlobalAlert();
     }
 
+    @Override
+    public DataUsageRequest registerDataUsageCallback(String callingPackage,
+                DataUsageRequest request, Messenger messenger, IBinder binder) {
+        checkNotNull(callingPackage, "calling package is null");
+        checkNotNull(request, "DataUsageRequest is null");
+        checkNotNull(request.templates, "NetworkTemplate is null");
+        checkArgument(request.templates.length > 0);
+        checkNotNull(messenger, "messenger is null");
+        checkNotNull(binder, "binder is null");
+
+        int callingUid = Binder.getCallingUid();
+        @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(callingPackage);
+        DataUsageRequest normalizedRequest;
+        final long token = Binder.clearCallingIdentity();
+        try {
+            normalizedRequest = mStatsObservers.register(request, messenger, binder,
+                    callingUid, accessLevel);
+        } finally {
+            Binder.restoreCallingIdentity(token);
+        }
+
+        // Create baseline stats
+        mHandler.sendMessage(mHandler.obtainMessage(MSG_PERFORM_POLL, FLAG_PERSIST_ALL));
+
+        return normalizedRequest;
+   }
+
+    @Override
+    public void unregisterDataUsageRequest(DataUsageRequest request) {
+        checkNotNull(request, "DataUsageRequest is null");
+
+        int callingUid = Binder.getCallingUid();
+        final long token = Binder.clearCallingIdentity();
+        try {
+            mStatsObservers.unregister(request, callingUid);
+        } finally {
+            Binder.restoreCallingIdentity(token);
+        }
+    }
+
     /**
      * Update {@link NetworkStatsRecorder} and {@link #mGlobalAlertBytes} to
      * reflect current {@link #mPersistThreshold} value. Always defers to
@@ -945,6 +1012,11 @@
         mXtRecorder.recordSnapshotLocked(xtSnapshot, mActiveIfaces, null, currentTime);
         mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, vpnArray, currentTime);
         mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, vpnArray, currentTime);
+
+        // We need to make copies of member fields that are sent to the observer to avoid
+        // a race condition between the service handler thread and the observer's
+        mStatsObservers.updateStats(xtSnapshot, uidSnapshot, new ArrayMap<>(mActiveIfaces),
+                new ArrayMap<>(mActiveUidIfaces), vpnArray, currentTime);
     }
 
     /**
@@ -1243,21 +1315,28 @@
         }
     }
 
-    private Handler.Callback mHandlerCallback = new Handler.Callback() {
+    @VisibleForTesting
+    static class HandlerCallback implements Handler.Callback {
+        private final NetworkStatsService mService;
+
+        HandlerCallback(NetworkStatsService service) {
+            this.mService = service;
+        }
+
         @Override
         public boolean handleMessage(Message msg) {
             switch (msg.what) {
                 case MSG_PERFORM_POLL: {
                     final int flags = msg.arg1;
-                    performPoll(flags);
+                    mService.performPoll(flags);
                     return true;
                 }
                 case MSG_UPDATE_IFACES: {
-                    updateIfaces();
+                    mService.updateIfaces();
                     return true;
                 }
                 case MSG_REGISTER_GLOBAL_ALERT: {
-                    registerGlobalAlert();
+                    mService.registerGlobalAlert();
                     return true;
                 }
                 default: {
@@ -1265,7 +1344,7 @@
                 }
             }
         }
-    };
+    }
 
     private void assertBandwidthControlEnabled() {
         if (!isBandwidthControlEnabled()) {
diff --git a/services/java/com/android/server/SystemServer.java b/services/java/com/android/server/SystemServer.java
index 0cf9328..ac972a9 100644
--- a/services/java/com/android/server/SystemServer.java
+++ b/services/java/com/android/server/SystemServer.java
@@ -760,7 +760,7 @@
 
                 traceBeginAndSlog("StartNetworkStatsService");
                 try {
-                    networkStats = new NetworkStatsService(context, networkManagement, alarm);
+                    networkStats = NetworkStatsService.create(context, networkManagement);
                     ServiceManager.addService(Context.NETWORK_STATS_SERVICE, networkStats);
                 } catch (Throwable e) {
                     reportWtf("starting NetworkStats Service", e);
diff --git a/services/tests/servicestests/src/com/android/server/net/NetworkStatsObserversTest.java b/services/tests/servicestests/src/com/android/server/net/NetworkStatsObserversTest.java
new file mode 100644
index 0000000..b9e9aa9
--- /dev/null
+++ b/services/tests/servicestests/src/com/android/server/net/NetworkStatsObserversTest.java
@@ -0,0 +1,634 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net;
+
+import static android.net.ConnectivityManager.TYPE_MOBILE;
+import static android.net.ConnectivityManager.TYPE_WIFI;
+import static android.text.format.DateUtils.MINUTE_IN_MILLIS;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.isA;
+import static org.mockito.Mockito.when;
+
+import static android.net.NetworkStats.SET_DEFAULT;
+import static android.net.NetworkStats.ROAMING_DEFAULT;
+import static android.net.NetworkStats.TAG_NONE;
+import static android.net.NetworkTemplate.buildTemplateMobileAll;
+import static android.net.NetworkTemplate.buildTemplateWifiWildcard;
+import static android.net.TrafficStats.MB_IN_BYTES;
+import static android.text.format.DateUtils.MINUTE_IN_MILLIS;
+
+import android.app.usage.NetworkStatsManager;
+import android.net.DataUsageRequest;
+import android.net.NetworkIdentity;
+import android.net.NetworkStats;
+import android.net.NetworkTemplate;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.IBinder;
+import android.os.Process;
+
+import android.os.ConditionVariable;
+import android.os.Looper;
+import android.os.Messenger;
+import android.os.Message;
+import android.os.UserHandle;
+import android.telephony.TelephonyManager;
+import android.util.ArrayMap;
+
+import com.android.internal.net.VpnInfo;
+import com.android.server.net.NetworkStatsService;
+import com.android.server.net.NetworkStatsServiceTest.IdleableHandlerThread;
+import com.android.server.net.NetworkStatsServiceTest.LatchedHandler;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import junit.framework.TestCase;
+
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+/**
+ * Tests for {@link NetworkStatsObservers}.
+ */
+public class NetworkStatsObserversTest extends TestCase {
+    private static final String TEST_IFACE = "test0";
+    private static final String TEST_IFACE2 = "test1";
+    private static final long TEST_START = 1194220800000L;
+
+    private static final String IMSI_1 = "310004";
+    private static final String IMSI_2 = "310260";
+    private static final String TEST_SSID = "AndroidAP";
+
+    private static NetworkTemplate sTemplateWifi = buildTemplateWifiWildcard();
+    private static NetworkTemplate sTemplateImsi1 = buildTemplateMobileAll(IMSI_1);
+    private static NetworkTemplate sTemplateImsi2 = buildTemplateMobileAll(IMSI_2);
+
+    private static final int UID_RED = UserHandle.PER_USER_RANGE + 1;
+    private static final int UID_BLUE = UserHandle.PER_USER_RANGE + 2;
+    private static final int UID_GREEN = UserHandle.PER_USER_RANGE + 3;
+    private static final int UID_ANOTHER_USER = 2 * UserHandle.PER_USER_RANGE + 4;
+
+    private static final long WAIT_TIMEOUT = 500;  // 1/2 sec
+    private static final long THRESHOLD_BYTES = 2 * MB_IN_BYTES;
+    private static final long BASE_BYTES = 7 * MB_IN_BYTES;
+    private static final int INVALID_TYPE = -1;
+
+    private static final int[] NO_UIDS = null;
+    private static final VpnInfo[] VPN_INFO = new VpnInfo[0];
+
+    private long mElapsedRealtime;
+
+    private IdleableHandlerThread mObserverHandlerThread;
+    private Handler mObserverNoopHandler;
+
+    private LatchedHandler mHandler;
+    private ConditionVariable mCv;
+
+    private NetworkStatsObservers mStatsObservers;
+    private Messenger mMessenger;
+    private ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
+    private ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
+
+    @Mock private IBinder mockBinder;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        MockitoAnnotations.initMocks(this);
+
+        mObserverHandlerThread = new IdleableHandlerThread("HandlerThread");
+        mObserverHandlerThread.start();
+        final Looper observerLooper = mObserverHandlerThread.getLooper();
+        mStatsObservers = new NetworkStatsObservers() {
+            @Override
+            protected Looper getHandlerLooperLocked() {
+                return observerLooper;
+            }
+        };
+
+        mCv = new ConditionVariable();
+        mHandler = new LatchedHandler(Looper.getMainLooper(), mCv);
+        mMessenger = new Messenger(mHandler);
+
+        mActiveIfaces = new ArrayMap<>();
+        mActiveUidIfaces = new ArrayMap<>();
+    }
+
+    public void testRegister_thresholdTooLow_setsDefaultThreshold() throws Exception {
+        long thresholdTooLowBytes = 1L;
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateWifi };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, thresholdTooLowBytes);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+    }
+
+    public void testRegister_highThreshold_accepted() throws Exception {
+        long highThresholdBytes = 2 * THRESHOLD_BYTES;
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateWifi };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, highThresholdBytes);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(highThresholdBytes, request.thresholdInBytes);
+    }
+
+    public void testRegister_twoRequests_twoIds() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateWifi };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request1 = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request1.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request1.templates));
+        assertNull(request1.uids);
+        assertEquals(THRESHOLD_BYTES, request1.thresholdInBytes);
+
+        DataUsageRequest request2 = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request2.requestId > request1.requestId);
+        assertTrue(Arrays.deepEquals(templates, request2.templates));
+        assertNull(request2.uids);
+        assertEquals(THRESHOLD_BYTES, request2.thresholdInBytes);
+    }
+
+    public void testRegister_defaultAccess_otherUids_securityException() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        int[] uids = new int[] { UID_RED, UID_BLUE, UID_GREEN };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, uids, THRESHOLD_BYTES);
+
+        try {
+            mStatsObservers.register(inputRequest, mMessenger, mockBinder, UID_RED,
+                    NetworkStatsAccess.Level.DEFAULT);
+            fail("Should have denied access");
+        } catch (SecurityException expected) {}
+    }
+
+    public void testRegister_userAccess_otherUidsSameUser()
+            throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        int[] uids = new int[] { UID_RED, UID_BLUE, UID_GREEN };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, uids, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.USER);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertTrue(Arrays.equals(uids, request.uids));
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+    }
+
+    public void testRegister_defaultAccess_sameUid() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        int[] uids = new int[] { UID_RED };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, uids, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.DEFAULT);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertTrue(Arrays.equals(uids, request.uids));
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+    }
+
+    public void testUnregister_unknownRequest_noop() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateWifi };
+        DataUsageRequest unknownRequest = new DataUsageRequest(
+                123456 /* id */, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        mStatsObservers.unregister(unknownRequest, UID_RED);
+    }
+
+    public void testUnregister_knownRequest_releasesCaller() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+        Mockito.verify(mockBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt());
+
+        mStatsObservers.unregister(request, Process.SYSTEM_UID);
+        waitForObserverToIdle();
+
+        Mockito.verify(mockBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt());
+    }
+
+    public void testUnregister_knownRequest_invalidUid_doesNotUnregister() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+        Mockito.verify(mockBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt());
+
+        mStatsObservers.unregister(request, UID_BLUE);
+        waitForObserverToIdle();
+
+        Mockito.verifyZeroInteractions(mockBinder);
+    }
+
+    public void testUpdateStats_initialSample_doesNotNotify() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = new NetworkStats(TEST_START, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES, 8L, BASE_BYTES, 16L);
+        NetworkStats uidSnapshot = null;
+
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(INVALID_TYPE, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_belowThreshold_doesNotNotify() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = new NetworkStats(TEST_START, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES, 8L, BASE_BYTES, 16L);
+        NetworkStats uidSnapshot = null;
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        xtSnapshot = new NetworkStats(TEST_START, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES + 1024L, 10L, BASE_BYTES + 2048L, 20L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        mCv.block(WAIT_TIMEOUT);
+        assertEquals(INVALID_TYPE, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_aboveThresholdNetwork_notifies() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = new NetworkStats(TEST_START, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES, 8L, BASE_BYTES, 16L);
+        NetworkStats uidSnapshot = null;
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        xtSnapshot = new NetworkStats(TEST_START + MINUTE_IN_MILLIS, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES + THRESHOLD_BYTES, 12L,
+                        BASE_BYTES + THRESHOLD_BYTES, 22L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_aboveThresholdMultipleNetwork_notifies() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1, sTemplateImsi2 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.DEVICESUMMARY);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet1 = new NetworkIdentitySet();
+        identSet1.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveIfaces.put(TEST_IFACE, identSet1);
+
+        NetworkIdentitySet identSet2 = new NetworkIdentitySet();
+        identSet2.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_2, null /* networkId */, false /* roaming */));
+        mActiveIfaces.put(TEST_IFACE2, identSet2);
+
+        // Baseline
+        NetworkStats xtSnapshot = new NetworkStats(TEST_START, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES, 8L, BASE_BYTES, 16L)
+                .addIfaceValues(TEST_IFACE2, BASE_BYTES + 1234L, 18L, BASE_BYTES, 12L);
+        NetworkStats uidSnapshot = null;
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta - traffic on IMSI2
+        xtSnapshot = new NetworkStats(TEST_START + MINUTE_IN_MILLIS, 1 /* initialSize */)
+                .addIfaceValues(TEST_IFACE, BASE_BYTES, 8L, BASE_BYTES, 16L)
+                .addIfaceValues(TEST_IFACE2, BASE_BYTES + THRESHOLD_BYTES, 22L,
+                        BASE_BYTES + THRESHOLD_BYTES, 24L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_aboveThresholdUid_notifies() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        int[] uids = new int[] { UID_RED, UID_BLUE, UID_GREEN };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, uids, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertTrue(Arrays.equals(uids,request.uids));
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveUidIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = null;
+        NetworkStats uidSnapshot = new NetworkStats(TEST_START, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES, 2L, BASE_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        uidSnapshot = new NetworkStats(TEST_START+ 2 * MINUTE_IN_MILLIS, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES + THRESHOLD_BYTES, 2L, BASE_BYTES + THRESHOLD_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_defaultAccess_noUid_notifiesSameUid() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.DEFAULT);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveUidIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = null;
+        NetworkStats uidSnapshot = new NetworkStats(TEST_START, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES, 2L, BASE_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        uidSnapshot = new NetworkStats(TEST_START+ 2 * MINUTE_IN_MILLIS, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES + THRESHOLD_BYTES, 2L, BASE_BYTES + THRESHOLD_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_defaultAccess_noUid_usageOtherUid_doesNotNotify() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_BLUE, NetworkStatsAccess.Level.DEFAULT);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveUidIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = null;
+        NetworkStats uidSnapshot = new NetworkStats(TEST_START, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES, 2L, BASE_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        uidSnapshot = new NetworkStats(TEST_START+ 2 * MINUTE_IN_MILLIS, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES + THRESHOLD_BYTES, 2L, BASE_BYTES + THRESHOLD_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(INVALID_TYPE, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_userAccess_usageSameUser_notifies() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_BLUE, NetworkStatsAccess.Level.USER);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveUidIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = null;
+        NetworkStats uidSnapshot = new NetworkStats(TEST_START, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES, 2L, BASE_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        uidSnapshot = new NetworkStats(TEST_START+ 2 * MINUTE_IN_MILLIS, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES + THRESHOLD_BYTES, 2L, BASE_BYTES + THRESHOLD_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.mLastMessageType);
+    }
+
+    public void testUpdateStats_userAccess_usageAnotherUser_doesNotNotify() throws Exception {
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1 };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, NO_UIDS, THRESHOLD_BYTES);
+
+        DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder,
+                UID_RED, NetworkStatsAccess.Level.USER);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+
+        NetworkIdentitySet identSet = new NetworkIdentitySet();
+        identSet.add(new NetworkIdentity(
+                TYPE_MOBILE, TelephonyManager.NETWORK_TYPE_UNKNOWN,
+                IMSI_1, null /* networkId */, false /* roaming */));
+        mActiveUidIfaces.put(TEST_IFACE, identSet);
+
+        // Baseline
+        NetworkStats xtSnapshot = null;
+        NetworkStats uidSnapshot = new NetworkStats(TEST_START, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_ANOTHER_USER, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES, 2L, BASE_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+
+        // Delta
+        uidSnapshot = new NetworkStats(TEST_START+ 2 * MINUTE_IN_MILLIS, 2 /* initialSize */)
+                .addValues(TEST_IFACE, UID_ANOTHER_USER, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        BASE_BYTES + THRESHOLD_BYTES, 2L, BASE_BYTES + THRESHOLD_BYTES, 2L, 0L);
+        mStatsObservers.updateStats(
+                xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces,
+                VPN_INFO, TEST_START);
+        waitForObserverToIdle();
+
+        assertTrue(mCv.block(WAIT_TIMEOUT));
+        assertEquals(INVALID_TYPE, mHandler.mLastMessageType);
+    }
+
+    private void waitForObserverToIdle() {
+        // Send dummy message to make sure that any previous message has been handled
+        mHandler.sendMessage(mHandler.obtainMessage(-1));
+        mObserverHandlerThread.waitForIdle(WAIT_TIMEOUT);
+    }
+}
diff --git a/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java b/services/tests/servicestests/src/com/android/server/net/NetworkStatsServiceTest.java
similarity index 73%
rename from services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java
rename to services/tests/servicestests/src/com/android/server/net/NetworkStatsServiceTest.java
index 8cbd32d..4f6c7b9 100644
--- a/services/tests/servicestests/src/com/android/server/NetworkStatsServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/net/NetworkStatsServiceTest.java
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package com.android.server;
+package com.android.server.net;
 
 import static android.content.Intent.ACTION_UID_REMOVED;
 import static android.content.Intent.EXTRA_UID;
@@ -43,6 +43,7 @@
 import static com.android.server.net.NetworkStatsService.ACTION_NETWORK_STATS_POLL;
 import static org.easymock.EasyMock.anyInt;
 import static org.easymock.EasyMock.anyLong;
+import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.capture;
 import static org.easymock.EasyMock.createMock;
 import static org.easymock.EasyMock.eq;
@@ -54,7 +55,10 @@
 import android.app.IAlarmListener;
 import android.app.IAlarmManager;
 import android.app.PendingIntent;
+import android.app.usage.NetworkStatsManager;
+import android.content.Context;
 import android.content.Intent;
+import android.net.DataUsageRequest;
 import android.net.IConnectivityManager;
 import android.net.INetworkManagementEventObserver;
 import android.net.INetworkStatsSession;
@@ -65,7 +69,17 @@
 import android.net.NetworkStats;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
+import android.os.ConditionVariable;
+import android.os.Handler;
+import android.os.HandlerThread;
 import android.os.INetworkManagementService;
+import android.os.IBinder;
+import android.os.Looper;
+import android.os.Messenger;
+import android.os.MessageQueue;
+import android.os.MessageQueue.IdleHandler;
+import android.os.Message;
+import android.os.PowerManager;
 import android.os.WorkSource;
 import android.telephony.TelephonyManager;
 import android.test.AndroidTestCase;
@@ -74,6 +88,7 @@
 import android.util.TrustedTime;
 
 import com.android.internal.net.VpnInfo;
+import com.android.server.BroadcastInterceptingContext;
 import com.android.server.net.NetworkStatsService;
 import com.android.server.net.NetworkStatsService.NetworkStatsSettings;
 import com.android.server.net.NetworkStatsService.NetworkStatsSettings.Config;
@@ -85,6 +100,7 @@
 
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 /**
@@ -113,16 +129,20 @@
     private static final int UID_BLUE = 1002;
     private static final int UID_GREEN = 1003;
 
+    private static final long WAIT_TIMEOUT = 2 * 1000;  // 2 secs
+    private static final int INVALID_TYPE = -1;
+
     private long mElapsedRealtime;
 
     private BroadcastInterceptingContext mServiceContext;
     private File mStatsDir;
 
     private INetworkManagementService mNetManager;
-    private IAlarmManager mAlarmManager;
     private TrustedTime mTime;
     private NetworkStatsSettings mSettings;
     private IConnectivityManager mConnManager;
+    private IdleableHandlerThread mHandlerThread;
+    private Handler mHandler;
 
     private NetworkStatsService mService;
     private INetworkStatsSession mSession;
@@ -139,13 +159,28 @@
         }
 
         mNetManager = createMock(INetworkManagementService.class);
-        mAlarmManager = createMock(IAlarmManager.class);
+
+        // TODO: Mock AlarmManager when migrating this test to Mockito.
+        AlarmManager alarmManager = (AlarmManager) mServiceContext
+                .getSystemService(Context.ALARM_SERVICE);
         mTime = createMock(TrustedTime.class);
         mSettings = createMock(NetworkStatsSettings.class);
         mConnManager = createMock(IConnectivityManager.class);
 
+        PowerManager powerManager = (PowerManager) mServiceContext.getSystemService(
+                Context.POWER_SERVICE);
+        PowerManager.WakeLock wakeLock =
+                powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG);
+
         mService = new NetworkStatsService(
-                mServiceContext, mNetManager, mAlarmManager, mTime, mStatsDir, mSettings);
+                mServiceContext, mNetManager, alarmManager, wakeLock, mTime,
+                TelephonyManager.getDefault(), mSettings, new NetworkStatsObservers(),
+                mStatsDir, getBaseDir(mStatsDir));
+        mHandlerThread = new IdleableHandlerThread("HandlerThread");
+        mHandlerThread.start();
+        Handler.Callback callback = new NetworkStatsService.HandlerCallback(mService);
+        mHandler = new Handler(mHandlerThread.getLooper(), callback);
+        mService.setHandler(mHandler, callback);
         mService.bindConnectivityManager(mConnManager);
 
         mElapsedRealtime = 0L;
@@ -178,7 +213,6 @@
         mStatsDir = null;
 
         mNetManager = null;
-        mAlarmManager = null;
         mTime = null;
         mSettings = null;
         mConnManager = null;
@@ -217,7 +251,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateWifi, 1024L, 1L, 2048L, 2L, 0);
@@ -234,7 +268,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateWifi, 4096L, 4L, 8192L, 8L, 0);
@@ -282,7 +316,7 @@
         mService.incrementOperationCount(UID_RED, 0xFAAD, 6);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateWifi, 1024L, 8L, 2048L, 16L, 0);
@@ -362,7 +396,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         history = mSession.getHistoryForNetwork(sTemplateWifi, FIELD_ALL);
@@ -380,7 +414,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify identical stats, but spread across 4 buckets now
         history = mSession.getHistoryForNetwork(sTemplateWifi, FIELD_ALL);
@@ -420,7 +454,7 @@
         mService.incrementOperationCount(UID_RED, 0xF00D, 10);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateImsi1, 2048L, 16L, 512L, 4L, 0);
@@ -446,7 +480,7 @@
 
         replay();
         mService.forceUpdateIfaces();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
         verifyAndReset();
 
         // create traffic on second network
@@ -465,7 +499,7 @@
         mService.incrementOperationCount(UID_BLUE, 0xFAAD, 10);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify original history still intact
         assertNetworkTotal(sTemplateImsi1, 2048L, 16L, 512L, 4L, 0);
@@ -511,7 +545,7 @@
         mService.incrementOperationCount(UID_RED, 0xFAAD, 10);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateWifi, 4128L, 258L, 544L, 34L, 0);
@@ -578,7 +612,7 @@
         mService.incrementOperationCount(UID_RED, 0xF00D, 5);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertUidTotal(sTemplateImsi1, UID_RED, 1024L, 8L, 1024L, 8L, 5);
@@ -598,7 +632,7 @@
 
         replay();
         mService.forceUpdateIfaces();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
         verifyAndReset();
 
         // create traffic on second network
@@ -616,7 +650,7 @@
         mService.incrementOperationCount(UID_RED, 0xFAAD, 5);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify that ALL_MOBILE template combines both
         assertUidTotal(sTemplateImsi1, UID_RED, 1536L, 12L, 1280L, 10L, 10);
@@ -652,7 +686,7 @@
         mService.incrementOperationCount(UID_RED, 0xF00D, 1);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertUidTotal(sTemplateWifi, UID_RED, 50L, 5L, 50L, 5L, 1);
@@ -671,7 +705,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // first verify entire history present
         NetworkStats stats = mSession.getSummaryForAllUid(
@@ -722,7 +756,7 @@
         mService.incrementOperationCount(UID_RED, 0xF00D, 1);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertUidTotal(sTemplateWifi, UID_RED, 128L, 2L, 128L, 2L, 1);
@@ -744,7 +778,7 @@
         mService.incrementOperationCount(UID_RED, 0xFAAD, 1);
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // test that we combined correctly
         assertUidTotal(sTemplateWifi, UID_RED, 160L, 4L, 160L, 4L, 2);
@@ -795,7 +829,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertUidTotal(sTemplateImsi1, UID_RED, 128L, 2L, 128L, 2L, 0);
@@ -843,7 +877,7 @@
         expectNetworkStatsPoll();
 
         replay();
-        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        forcePollAndWaitForIdle();
 
         // verify service recorded history
         assertNetworkTotal(sTemplateImsi1, 2048L, 16L, 512L, 4L, 0);
@@ -853,6 +887,285 @@
 
     }
 
+    public void testRegisterDataUsageCallback_network() throws Exception {
+        // pretend that wifi network comes online; service should ask about full
+        // network state, and poll any existing interfaces before updating.
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkState(buildWifiState());
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+        expectBandwidthControlCheck();
+
+        replay();
+        mService.forceUpdateIfaces();
+
+        // verify service has empty history for wifi
+        assertNetworkTotal(sTemplateWifi, 0L, 0L, 0L, 0L, 0);
+        verifyAndReset();
+
+        String callingPackage = "the.calling.package";
+        long thresholdInBytes = 1L;  // very small; should be overriden by framework
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateWifi };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, null /* uids */, thresholdInBytes);
+
+        // Create a messenger that waits for callback activity
+        ConditionVariable cv = new ConditionVariable(false);
+        LatchedHandler latchedHandler = new LatchedHandler(Looper.getMainLooper(), cv);
+        Messenger messenger = new Messenger(latchedHandler);
+
+        // Allow binder to connect
+        IBinder mockBinder = createMock(IBinder.class);
+        mockBinder.linkToDeath((IBinder.DeathRecipient) anyObject(), anyInt());
+        EasyMock.replay(mockBinder);
+
+        // Force poll
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+        replay();
+
+        // Register and verify request and that binder was called
+        DataUsageRequest request =
+                mService.registerDataUsageCallback(callingPackage, inputRequest,
+                        messenger, mockBinder);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertNull(request.uids);
+        long minThresholdInBytes = 2 * 1024 * 1024; // 2 MB
+        assertEquals(minThresholdInBytes, request.thresholdInBytes);
+
+        // Send dummy message to make sure that any previous message has been handled
+        mHandler.sendMessage(mHandler.obtainMessage(-1));
+        mHandlerThread.waitForIdle(WAIT_TIMEOUT);
+
+        verifyAndReset();
+
+        // Make sure that the caller binder gets connected
+        EasyMock.verify(mockBinder);
+        EasyMock.reset(mockBinder);
+
+        // modify some number on wifi, and trigger poll event
+        // not enough traffic to call data usage callback
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(new NetworkStats(getElapsedRealtime(), 1)
+                .addIfaceValues(TEST_IFACE, 1024L, 1L, 2048L, 2L));
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+
+        replay();
+        forcePollAndWaitForIdle();
+
+        // verify service recorded history
+        verifyAndReset();
+        assertNetworkTotal(sTemplateWifi, 1024L, 1L, 2048L, 2L, 0);
+
+        // make sure callback has not being called
+        assertEquals(INVALID_TYPE, latchedHandler.mLastMessageType);
+
+        // and bump forward again, with counters going higher. this is
+        // important, since it will trigger the data usage callback
+        incrementCurrentTime(DAY_IN_MILLIS);
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(new NetworkStats(getElapsedRealtime(), 1)
+                .addIfaceValues(TEST_IFACE, 4096000L, 4L, 8192000L, 8L));
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+
+        replay();
+        forcePollAndWaitForIdle();
+
+        // verify service recorded history
+        assertNetworkTotal(sTemplateWifi, 4096000L, 4L, 8192000L, 8L, 0);
+        verifyAndReset();
+
+        // Wait for the caller to ack receipt of CALLBACK_LIMIT_REACHED
+        assertTrue(cv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, latchedHandler.mLastMessageType);
+        cv.close();
+
+        // Allow binder to disconnect
+        expect(mockBinder.unlinkToDeath((IBinder.DeathRecipient) anyObject(), anyInt()))
+                .andReturn(true);
+        EasyMock.replay(mockBinder);
+
+        // Unregister request
+        mService.unregisterDataUsageRequest(request);
+
+        // Wait for the caller to ack receipt of CALLBACK_RELEASED
+        assertTrue(cv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_RELEASED, latchedHandler.mLastMessageType);
+
+        // Make sure that the caller binder gets disconnected
+        EasyMock.verify(mockBinder);
+    }
+
+    public void testRegisterDataUsageCallback_uids() throws Exception {
+        // pretend that network comes online
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkState(buildMobile3gState(IMSI_1, true /* isRoaming */));
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+        expectBandwidthControlCheck();
+
+        replay();
+        mService.forceUpdateIfaces();
+        verifyAndReset();
+
+        String callingPackage = "the.calling.package";
+        long thresholdInBytes = 10 * 1024 * 1024;  // 10 MB
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1, sTemplateImsi2 };
+        int[] uids = new int[] { UID_RED };
+        DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, templates, uids, thresholdInBytes);
+
+        // Create a messenger that waits for callback activity
+        ConditionVariable cv = new ConditionVariable(false);
+        cv.close();
+        LatchedHandler latchedHandler = new LatchedHandler(Looper.getMainLooper(), cv);
+        Messenger messenger = new Messenger(latchedHandler);
+
+        // Allow binder to connect
+        IBinder mockBinder = createMock(IBinder.class);
+        mockBinder.linkToDeath((IBinder.DeathRecipient) anyObject(), anyInt());
+        EasyMock.replay(mockBinder);
+
+        // Force poll
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+        expectNetworkStatsPoll();
+        replay();
+
+        // Register and verify request and that binder was called
+        DataUsageRequest request =
+                mService.registerDataUsageCallback(callingPackage, inputRequest,
+                        messenger, mockBinder);
+        assertTrue(request.requestId > 0);
+        assertTrue(Arrays.deepEquals(templates, request.templates));
+        assertTrue(Arrays.equals(uids, request.uids));
+        assertEquals(thresholdInBytes, request.thresholdInBytes);
+
+        // Wait for service to handle internal MSG_REGISTER_DATA_USAGE_LISTENER
+        mHandler.sendMessage(mHandler.obtainMessage(-1));
+        mHandlerThread.waitForIdle(WAIT_TIMEOUT);
+
+        verifyAndReset();
+
+        // Make sure that the caller binder gets connected
+        EasyMock.verify(mockBinder);
+        EasyMock.reset(mockBinder);
+
+        // modify some number on mobile interface, and trigger poll event
+        // not enough traffic to call data usage callback
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT, 128L, 2L,
+                        128L, 2L, 0L)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, ROAMING_DEFAULT, 64L, 1L, 64L,
+                        1L, 0L));
+        expectNetworkStatsPoll();
+
+        replay();
+        forcePollAndWaitForIdle();
+
+        // verify service recorded history
+        assertUidTotal(sTemplateImsi1, UID_RED, 128L, 2L, 128L, 2L, 0);
+
+        // verify entire history present
+        NetworkStats stats = mSession.getSummaryForAllUid(
+                sTemplateImsi1, Long.MIN_VALUE, Long.MAX_VALUE, true);
+        assertEquals(2, stats.size());
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_ROAMING, 128L, 2L,
+                128L, 2L, 0);
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, 0xF00D, ROAMING_ROAMING, 64L, 1L, 64L,
+                1L, 0);
+
+        verifyAndReset();
+
+        // make sure callback has not being called
+        assertEquals(INVALID_TYPE, latchedHandler.mLastMessageType);
+
+        // and bump forward again, with counters going higher. this is
+        // important, since it will trigger the data usage callback
+        incrementCurrentTime(DAY_IN_MILLIS);
+        expectCurrentTime();
+        expectDefaultSettings();
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_DEFAULT,
+                        128000000L, 2L, 128000000L, 2L, 0L)
+                .addValues(TEST_IFACE, UID_RED, SET_DEFAULT, 0xF00D, ROAMING_DEFAULT,
+                        64000000L, 1L, 64000000L, 1L, 0L));
+        expectNetworkStatsPoll();
+
+        replay();
+        forcePollAndWaitForIdle();
+
+        // verify service recorded history
+        assertUidTotal(sTemplateImsi1, UID_RED, 128000000L, 2L, 128000000L, 2L, 0);
+
+        // verify entire history present
+        stats = mSession.getSummaryForAllUid(
+                sTemplateImsi1, Long.MIN_VALUE, Long.MAX_VALUE, true);
+        assertEquals(2, stats.size());
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, TAG_NONE, ROAMING_ROAMING,
+                128000000L, 2L, 128000000L, 2L, 0);
+        assertValues(stats, IFACE_ALL, UID_RED, SET_DEFAULT, 0xF00D, ROAMING_ROAMING,
+                64000000L, 1L, 64000000L, 1L, 0);
+
+        verifyAndReset();
+
+        // Wait for the caller to ack receipt of CALLBACK_LIMIT_REACHED
+        assertTrue(cv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, latchedHandler.mLastMessageType);
+        cv.close();
+
+        // Allow binder to disconnect
+        expect(mockBinder.unlinkToDeath((IBinder.DeathRecipient) anyObject(), anyInt()))
+                .andReturn(true);
+        EasyMock.replay(mockBinder);
+
+        // Unregister request
+        mService.unregisterDataUsageRequest(request);
+
+        // Wait for the caller to ack receipt of CALLBACK_RELEASED
+        assertTrue(cv.block(WAIT_TIMEOUT));
+        assertEquals(NetworkStatsManager.CALLBACK_RELEASED, latchedHandler.mLastMessageType);
+
+        // Make sure that the caller binder gets disconnected
+        EasyMock.verify(mockBinder);
+    }
+
+    public void testUnregisterDataUsageCallback_unknown_noop() throws Exception {
+        String callingPackage = "the.calling.package";
+        long thresholdInBytes = 10 * 1024 * 1024;  // 10 MB
+        NetworkTemplate[] templates = new NetworkTemplate[] { sTemplateImsi1, sTemplateImsi2 };
+        DataUsageRequest unknownRequest = new DataUsageRequest(
+                2, templates, null /* uids */, thresholdInBytes);
+
+        mService.unregisterDataUsageRequest(unknownRequest);
+    }
+
+    private static File getBaseDir(File statsDir) {
+        File baseDir = new File(statsDir, "netstats");
+        baseDir.mkdirs();
+        return baseDir;
+    }
+
     private void assertNetworkTotal(NetworkTemplate template, long rxBytes, long rxPackets,
             long txBytes, long txPackets, int operations) throws Exception {
         assertNetworkTotal(template, Long.MIN_VALUE, Long.MAX_VALUE, rxBytes, rxPackets, txBytes,
@@ -894,16 +1207,6 @@
     }
 
     private void expectSystemReady() throws Exception {
-        mAlarmManager.remove(isA(PendingIntent.class), EasyMock.<IAlarmListener>isNull());
-        expectLastCall().anyTimes();
-
-        mAlarmManager.set(eq(getContext().getPackageName()),
-                eq(AlarmManager.ELAPSED_REALTIME), anyLong(), anyLong(), anyLong(),
-                anyInt(), isA(PendingIntent.class), EasyMock.<IAlarmListener>isNull(),
-                EasyMock.<String>isNull(), EasyMock.<WorkSource>isNull(),
-                EasyMock.<AlarmManager.AlarmClockInfo>isNull());
-        expectLastCall().anyTimes();
-
         mNetManager.setGlobalAlert(anyLong());
         expectLastCall().atLeastOnce();
 
@@ -1093,11 +1396,75 @@
     }
 
     private void replay() {
-        EasyMock.replay(mNetManager, mAlarmManager, mTime, mSettings, mConnManager);
+        EasyMock.replay(mNetManager, mTime, mSettings, mConnManager);
     }
 
     private void verifyAndReset() {
-        EasyMock.verify(mNetManager, mAlarmManager, mTime, mSettings, mConnManager);
-        EasyMock.reset(mNetManager, mAlarmManager, mTime, mSettings, mConnManager);
+        EasyMock.verify(mNetManager, mTime, mSettings, mConnManager);
+        EasyMock.reset(mNetManager, mTime, mSettings, mConnManager);
     }
+
+    private void forcePollAndWaitForIdle() {
+        mServiceContext.sendBroadcast(new Intent(ACTION_NETWORK_STATS_POLL));
+        // Send dummy message to make sure that any previous message has been handled
+        mHandler.sendMessage(mHandler.obtainMessage(-1));
+        mHandlerThread.waitForIdle(WAIT_TIMEOUT);
+    }
+
+    static class LatchedHandler extends Handler {
+        private final ConditionVariable mCv;
+        int mLastMessageType = INVALID_TYPE;
+
+        LatchedHandler(Looper looper, ConditionVariable cv) {
+            super(looper);
+            mCv = cv;
+        }
+
+        @Override
+        public void handleMessage(Message msg) {
+            mLastMessageType = msg.what;
+            mCv.open();
+            super.handleMessage(msg);
+        }
+    }
+
+    /**
+     * A subclass of HandlerThread that allows callers to wait for it to become idle. waitForIdle
+     * will return immediately if the handler is already idle.
+     */
+    static class IdleableHandlerThread extends HandlerThread {
+        private IdleHandler mIdleHandler;
+
+        public IdleableHandlerThread(String name) {
+            super(name);
+        }
+
+        public void waitForIdle(long timeoutMs) {
+            final ConditionVariable cv = new ConditionVariable();
+            final MessageQueue queue = getLooper().getQueue();
+
+            synchronized (queue) {
+                if (queue.isIdle()) {
+                    return;
+                }
+
+                assertNull("BUG: only one idle handler allowed", mIdleHandler);
+                mIdleHandler = new IdleHandler() {
+                    public boolean queueIdle() {
+                        cv.open();
+                        mIdleHandler = null;
+                        return false;  // Remove the handler.
+                    }
+                };
+                queue.addIdleHandler(mIdleHandler);
+            }
+
+            if (!cv.block(timeoutMs)) {
+                fail("HandlerThread " + getName() + " did not become idle after " + timeoutMs
+                        + " ms");
+                queue.removeIdleHandler(mIdleHandler);
+            }
+        }
+    }
+
 }