Update AMS to wait for network state update if requested by the main thread.

Bug: 27803922
Test: runtest -c com.android.server.am.ActivityManagerServiceTest frameworks-services
      runtest -c com.android.server.am.ActivityManagerInternalTest frameworks-services
      cts-tradefed run singleCommand cts-dev --module CtsHostsideNetworkTests
      and manual
Change-Id: I7d1052b9941c1fae51ff8ab1c9b89dca3919ccd2
diff --git a/services/tests/servicestests/src/com/android/server/am/ActivityManagerInternalTest.java b/services/tests/servicestests/src/com/android/server/am/ActivityManagerInternalTest.java
index b5934ee..e7c91c0 100644
--- a/services/tests/servicestests/src/com/android/server/am/ActivityManagerInternalTest.java
+++ b/services/tests/servicestests/src/com/android/server/am/ActivityManagerInternalTest.java
@@ -17,9 +17,11 @@
 package com.android.server.am;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 import android.app.ActivityManagerInternal;
-import android.support.test.filters.SmallTest;
+import android.os.SystemClock;
+import android.support.test.filters.MediumTest;
 import android.support.test.runner.AndroidJUnit4;
 
 import org.junit.Before;
@@ -43,9 +45,15 @@
  * Run: adb shell am instrument -e class com.android.server.am.ActivityManagerInternalTest -w \
  *     com.android.frameworks.servicestests/android.support.test.runner.AndroidJUnitRunner
  */
-@SmallTest
 @RunWith(AndroidJUnit4.class)
 public class ActivityManagerInternalTest {
+    private static final int TEST_UID1 = 111; // uid passed to notifyNetworkPolicyRulesUpdated
+    private static final int TEST_UID2 = 112; // control uid; its record must stay unchanged
+
+    private static final long TEST_PROC_STATE_SEQ1 = 1111; // ordered: SEQ1 < SEQ2 < SEQ3
+    private static final long TEST_PROC_STATE_SEQ2 = 1112;
+    private static final long TEST_PROC_STATE_SEQ3 = 1113;
+
     @Mock private ActivityManagerService.Injector mMockInjector;
 
     private ActivityManagerService mAms;
@@ -58,26 +66,149 @@
         mAmi = mAms.new LocalService();
     }
 
+    @MediumTest
     @Test
-    public void testNotifyNetworkPolicyRulesUpdated() {
-        // For checking there is no crash when there are no active uid records.
-        mAmi.notifyNetworkPolicyRulesUpdated(111, 11);
+    public void testNotifyNetworkPolicyRulesUpdated() throws Exception {
+        // Check there is no crash when there are no active uid records.
+        mAmi.notifyNetworkPolicyRulesUpdated(TEST_UID1, TEST_PROC_STATE_SEQ1);

-        // Insert active uid records.
-        final UidRecord record1 = addActiveUidRecord(222, 22);
-        final UidRecord record2 = addActiveUidRecord(333, 33);
-        // Notify that network policy rules are updated for uid 222.
-        mAmi.notifyNetworkPolicyRulesUpdated(222, 44);
-        assertEquals("UidRecord for uid 222 should be updated",
-                44L, record1.lastNetworkUpdatedProcStateSeq);
-        assertEquals("UidRecord for uid 333 should not be updated",
-                33L, record2.lastNetworkUpdatedProcStateSeq);
+        // Notify that network policy rules are updated for TEST_UID1 and verify that
+        // UidRecord.lastNetworkUpdatedProcStateSeq is updated and any blocked threads are notified.
+        verifyNetworkUpdatedProcStateSeq(
+                TEST_PROC_STATE_SEQ2, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ2, // procStateSeq to notify
+                true); // expectNotify
+
+        // Notify that network policy rules are updated for TEST_UID1 with already handled
+        // procStateSeq and verify that there is no notify call.
+        verifyNetworkUpdatedProcStateSeq(
+                TEST_PROC_STATE_SEQ1, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ1, // procStateSeq to notify
+                false); // expectNotify
+
+        // Notify that network policy rules are updated for TEST_UID1 with procStateSeq older
+        // than its UidRecord.curProcStateSeq and verify that there is no notify call.
+        verifyNetworkUpdatedProcStateSeq(
+                TEST_PROC_STATE_SEQ3, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ2, // procStateSeq to notify
+                false); // expectNotify
     }
 
-    private UidRecord addActiveUidRecord(int uid, long lastNetworkUpdatedProcStateSeq) {
+    private void verifyNetworkUpdatedProcStateSeq(long curProcStateSeq,
+            long lastNetworkUpdatedProcStateSeq, long expectedProcStateSeq, boolean expectNotify)
+            throws Exception {
+        final UidRecord record1 = addActiveUidRecord(TEST_UID1, curProcStateSeq,
+                lastNetworkUpdatedProcStateSeq);
+        final UidRecord record2 = addActiveUidRecord(TEST_UID2, curProcStateSeq,
+                lastNetworkUpdatedProcStateSeq);
+        // Park one thread on each record's lock before notifying.
+        final CustomThread thread1 = new CustomThread(record1.lock); // waits on record1.lock
+        thread1.startAndWait("Unexpected state for " + record1);
+        final CustomThread thread2 = new CustomThread(record2.lock); // waits on record2.lock
+        thread2.startAndWait("Unexpected state for " + record2);
+        // Only TEST_UID1 is notified; TEST_UID2 serves as the control.
+        mAmi.notifyNetworkPolicyRulesUpdated(TEST_UID1, expectedProcStateSeq);
+        assertEquals(record1 + " should be updated",
+                expectedProcStateSeq, record1.lastNetworkUpdatedProcStateSeq);
+        assertEquals(record2 + " should not be updated",
+                lastNetworkUpdatedProcStateSeq, record2.lastNetworkUpdatedProcStateSeq);
+
+        if (expectNotify) {
+            thread1.assertTerminated("Unexpected state for " + record1);
+            assertTrue("Threads waiting for network should be notified: " + record1,
+                    thread1.mNotified);
+        } else {
+            thread1.assertWaiting("Unexpected state for " + record1);
+            thread1.interrupt(); // release the still-waiting thread
+        }
+        thread2.assertWaiting("Unexpected state for " + record2);
+        thread2.interrupt(); // release the still-waiting thread
+
+        mAms.mActiveUids.clear(); // isolate the next scenario
+    }
+    // Registers an active UidRecord for uid with the given seqs and waitingForNetwork set.
+    private UidRecord addActiveUidRecord(int uid, long curProcStateSeq,
+            long lastNetworkUpdatedProcStateSeq) {
         final UidRecord record = new UidRecord(uid);
         record.lastNetworkUpdatedProcStateSeq = lastNetworkUpdatedProcStateSeq;
+        record.curProcStateSeq = curProcStateSeq;
+        record.waitingForNetwork = true; // presumably required for AMS to notify
         mAms.mActiveUids.put(uid, record);
         return record;
     }
+
+    static class CustomThread extends Thread {
+        private static final long WAIT_TIMEOUT_MS = 1000;
+        private static final long WAIT_INTERVAL_MS = 100;
+
+        private final Object mLock;
+        private Runnable mRunnable;
+        volatile boolean mNotified; // written by run(), read by the test thread
+
+        public CustomThread(Object lock) {
+            mLock = lock;
+        }
+
+        public CustomThread(Object lock, Runnable runnable) {
+            super(runnable);
+            mLock = lock;
+            mRunnable = runnable;
+        }
+        // Runs mRunnable if set, else waits on mLock; then records whether it was interrupted.
+        @Override
+        public void run() {
+            if (mRunnable != null) {
+                mRunnable.run();
+            } else {
+                synchronized (mLock) {
+                    try {
+                        mLock.wait(); // no loop: a spurious wakeup is reported as a notify
+                    } catch (InterruptedException e) {
+                        Thread.currentThread().interrupt(); // restore flag for the check below
+                    }
+                }
+            }
+            mNotified = !Thread.interrupted(); // false if the thread ended interrupted
+        }
+        // Starts the thread and fails unless it reaches WAITING within WAIT_TIMEOUT_MS.
+        public void startAndWait(String errMsg) throws Exception {
+            startAndWait(errMsg, false);
+        }
+
+        public void startAndWait(String errMsg, boolean timedWaiting) throws Exception {
+            start();
+            final long endTime = SystemClock.elapsedRealtime() + WAIT_TIMEOUT_MS;
+            final Thread.State stateToReach = timedWaiting
+                    ? Thread.State.TIMED_WAITING : Thread.State.WAITING;
+            while (getState() != stateToReach
+                    && SystemClock.elapsedRealtime() < endTime) {
+                Thread.sleep(WAIT_INTERVAL_MS);
+            }
+            if (timedWaiting) {
+                assertTimedWaiting(errMsg);
+            } else {
+                assertWaiting(errMsg);
+            }
+        }
+
+        public void assertWaiting(String errMsg) {
+            assertEquals(errMsg, Thread.State.WAITING, getState());
+        }
+
+        public void assertTimedWaiting(String errMsg) {
+            assertEquals(errMsg, Thread.State.TIMED_WAITING, getState());
+        }
+        // Polls until the thread terminates; fails if not TERMINATED within WAIT_TIMEOUT_MS.
+        public void assertTerminated(String errMsg) throws Exception {
+            final long endTime = SystemClock.elapsedRealtime() + WAIT_TIMEOUT_MS;
+            while (getState() != Thread.State.TERMINATED
+                    && SystemClock.elapsedRealtime() < endTime) {
+                Thread.sleep(WAIT_INTERVAL_MS);
+            }
+            assertEquals(errMsg, Thread.State.TERMINATED, getState());
+        }
+    }
 }
diff --git a/services/tests/servicestests/src/com/android/server/am/ActivityManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/am/ActivityManagerServiceTest.java
index 4e9333f..cc5764b 100644
--- a/services/tests/servicestests/src/com/android/server/am/ActivityManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/am/ActivityManagerServiceTest.java
@@ -28,8 +28,12 @@
 import static android.app.ActivityManager.PROCESS_STATE_SERVICE;
 import static android.app.ActivityManager.PROCESS_STATE_TOP;
 import static android.util.DebugUtils.valueToString;
+import static com.android.server.am.ActivityManagerInternalTest.CustomThread;
 import static com.android.server.am.ActivityManagerService.DISPATCH_UIDS_CHANGED_UI_MSG;
 import static com.android.server.am.ActivityManagerService.Injector;
+import static com.android.server.am.ActivityManagerService.NETWORK_STATE_BLOCK;
+import static com.android.server.am.ActivityManagerService.NETWORK_STATE_NO_CHANGE;
+import static com.android.server.am.ActivityManagerService.NETWORK_STATE_UNBLOCK;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -40,11 +44,14 @@
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import android.app.ActivityManager;
 import android.app.AppOpsManager;
+import android.app.IApplicationThread;
 import android.app.IUidObserver;
+import android.content.pm.ApplicationInfo;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
@@ -57,6 +64,7 @@
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
 
+import com.android.internal.os.BatteryStatsImpl;
 import com.android.server.AppOpsService;
 
 import org.junit.After;
@@ -67,6 +75,7 @@
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
+import java.io.File;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -109,6 +118,7 @@
 
     @Mock private AppOpsService mAppOpsService;
 
+    private TestInjector mInjector; // kept to toggle setNetworkRestrictedForUid() in tests
     private ActivityManagerService mAms;
     private HandlerThread mHandlerThread;
     private TestHandler mHandler;
@@ -120,7 +130,8 @@
         mHandlerThread = new HandlerThread(TAG);
         mHandlerThread.start();
         mHandler = new TestHandler(mHandlerThread.getLooper());
-        mAms = new ActivityManagerService(new TestInjector());
+        mInjector = new TestInjector();
+        mAms = new ActivityManagerService(mInjector);
     }
 
     @After
@@ -128,53 +139,127 @@
         mHandlerThread.quit();
     }
 
+    @MediumTest
     @Test
-    public void testIncrementProcStateSeqIfNeeded() {
+    public void incrementProcStateSeqAndNotifyAppsLocked() throws Exception {
         final UidRecord uidRec = new UidRecord(TEST_UID);
+        uidRec.waitingForNetwork = true;
+        mAms.mActiveUids.put(TEST_UID, uidRec);

-        assertEquals("Initially global seq counter should be 0", 0, mAms.mProcStateSeqCounter);
-        assertEquals("Initially seq counter in uidRecord should be 0", 0, uidRec.curProcStateSeq);
+        final BatteryStatsImpl batteryStats = Mockito.mock(BatteryStatsImpl.class);
+        final ProcessRecord appRec = new ProcessRecord(batteryStats,
+                new ApplicationInfo(), TAG, TEST_UID);
+        appRec.thread = Mockito.mock(IApplicationThread.class);
+        mAms.mLruProcesses.add(appRec);
+        // A second process with a different uid must never be notified.
+        final ProcessRecord appRec2 = new ProcessRecord(batteryStats,
+                new ApplicationInfo(), TAG, TEST_UID + 1);
+        appRec2.thread = Mockito.mock(IApplicationThread.class);
+        mAms.mLruProcesses.add(appRec2);

         // Uid state is not moving from background to foreground or vice versa.
-        uidRec.setProcState = PROCESS_STATE_TOP;
-        uidRec.curProcState = PROCESS_STATE_TOP;
-        mAms.incrementProcStateSeqIfNeeded(uidRec);
-        assertEquals(0, mAms.mProcStateSeqCounter);
-        assertEquals(0, uidRec.curProcStateSeq);
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_TOP, // prevState
+                PROCESS_STATE_TOP, // curState
+                0, // expectedGlobalCounter
+                0, // expectedCurProcStateSeq
+                NETWORK_STATE_NO_CHANGE, // expectedBlockState
+                false); // expectNotify

         // Uid state is moving from foreground to background.
-        uidRec.curProcState = PROCESS_STATE_FOREGROUND_SERVICE;
-        uidRec.setProcState = PROCESS_STATE_SERVICE;
-        mAms.incrementProcStateSeqIfNeeded(uidRec);
-        assertEquals(1, mAms.mProcStateSeqCounter);
-        assertEquals(1, uidRec.curProcStateSeq);
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_FOREGROUND_SERVICE, // prevState
+                PROCESS_STATE_SERVICE, // curState
+                1, // expectedGlobalCounter
+                1, // expectedCurProcStateSeq
+                NETWORK_STATE_UNBLOCK, // expectedBlockState
+                true); // expectNotify

         // Explicitly setting the seq counter for more verification.
         mAms.mProcStateSeqCounter = 42;

         // Uid state is not moving from background to foreground or vice versa.
-        uidRec.setProcState = PROCESS_STATE_IMPORTANT_BACKGROUND;
-        uidRec.curProcState = PROCESS_STATE_IMPORTANT_FOREGROUND;
-        mAms.incrementProcStateSeqIfNeeded(uidRec);
-        assertEquals(42, mAms.mProcStateSeqCounter);
-        assertEquals(1, uidRec.curProcStateSeq);
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_IMPORTANT_BACKGROUND, // prevState
+                PROCESS_STATE_IMPORTANT_FOREGROUND, // curState
+                42, // expectedGlobalCounter
+                1, // expectedCurProcStateSeq
+                NETWORK_STATE_NO_CHANGE, // expectedBlockState
+                false); // expectNotify

         // Uid state is moving from background to foreground.
-        uidRec.setProcState = PROCESS_STATE_LAST_ACTIVITY;
-        uidRec.curProcState = PROCESS_STATE_TOP;
-        mAms.incrementProcStateSeqIfNeeded(uidRec);
-        assertEquals(43, mAms.mProcStateSeqCounter);
-        assertEquals(43, uidRec.curProcStateSeq);
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_LAST_ACTIVITY, // prevState
+                PROCESS_STATE_TOP, // curState
+                43, // expectedGlobalCounter
+                43, // expectedCurProcStateSeq
+                NETWORK_STATE_BLOCK, // expectedBlockState
+                false); // expectNotify
+
+        // Verify waiting threads are not notified when waitingForNetwork is false.
+        uidRec.waitingForNetwork = false;
+        // Uid state is moving from foreground to background.
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_FOREGROUND_SERVICE, // prevState
+                PROCESS_STATE_SERVICE, // curState
+                44, // expectedGlobalCounter
+                44, // expectedCurProcStateSeq
+                NETWORK_STATE_UNBLOCK, // expectedBlockState
+                false); // expectNotify
+
+        // Verify when uid is not restricted, procStateSeq is not incremented.
+        uidRec.waitingForNetwork = true;
+        mInjector.setNetworkRestrictedForUid(false);
+        verifySeqCounterAndInteractions(uidRec,
+                PROCESS_STATE_IMPORTANT_BACKGROUND, // prevState
+                PROCESS_STATE_TOP, // curState
+                44, // expectedGlobalCounter
+                44, // expectedCurProcStateSeq
+                -1, // expectedBlockState, -1 to verify there are no interactions with main thread.
+                false); // expectNotify
+    }
+
+    private void verifySeqCounterAndInteractions(UidRecord uidRec, int prevState, int curState,
+            int expectedGlobalCounter, int expectedCurProcStateSeq, int expectedBlockState,
+            boolean expectNotify) throws Exception {
+        CustomThread thread = new CustomThread(uidRec.lock); // waits on uidRec.lock
+        thread.startAndWait("Unexpected state for " + uidRec);
+
+        uidRec.setProcState = prevState;
+        uidRec.curProcState = curState;
+        mAms.incrementProcStateSeqAndNotifyAppsLocked();
+
+        assertEquals(expectedGlobalCounter, mAms.mProcStateSeqCounter);
+        assertEquals(expectedCurProcStateSeq, uidRec.curProcStateSeq);
+        // Check every process's IApplicationThread mock, then reset it for the next call.
+        for (int i = mAms.mLruProcesses.size() - 1; i >= 0; --i) {
+            final ProcessRecord app = mAms.mLruProcesses.get(i);
+            // Expect setNetworkBlockSeq only for this uid's apps and only for NETWORK_STATE_BLOCK.
+            if (app.uid == uidRec.uid && expectedBlockState == NETWORK_STATE_BLOCK) {
+                verify(app.thread).setNetworkBlockSeq(uidRec.curProcStateSeq);
+            } else {
+                verifyZeroInteractions(app.thread);
+            }
+            Mockito.reset(app.thread);
+        }
+        // The waiter must exit only when a notify is expected; otherwise release it here.
+        if (expectNotify) {
+            thread.assertTerminated("Unexpected state for " + uidRec);
+        } else {
+            thread.assertWaiting("Unexpected state for " + uidRec);
+            thread.interrupt();
+        }
     }
 
     @Test
-    public void testShouldIncrementProcStateSeq() {
+    public void testBlockStateForUid() {
         final UidRecord uidRec = new UidRecord(TEST_UID);
+        int expectedBlockState; // reassigned before each check below

-        final String error1 = "Seq should be incremented: prevState: %s, curState: %s";
-        final String error2 = "Seq should not be incremented: prevState: %s, curState: %s";
-        Function<String, String> errorMsg = errorTemplate -> {
+        final String errorTemplate = "Block state should be %s, prevState: %s, curState: %s";
+        Function<Integer, String> errorMsg = (blockState) -> {
             return String.format(errorTemplate,
+                    valueToString(ActivityManagerService.class, "NETWORK_STATE_", blockState),
                     valueToString(ActivityManager.class, "PROCESS_STATE_", uidRec.setProcState),
                     valueToString(ActivityManager.class, "PROCESS_STATE_", uidRec.curProcState));
         };
@@ -182,32 +267,44 @@
         // No change in uid state
         uidRec.setProcState = PROCESS_STATE_RECEIVER;
         uidRec.curProcState = PROCESS_STATE_RECEIVER;
-        assertFalse(errorMsg.apply(error2), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_NO_CHANGE;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));

         // Foreground to foreground
         uidRec.setProcState = PROCESS_STATE_FOREGROUND_SERVICE;
         uidRec.curProcState = PROCESS_STATE_BOUND_FOREGROUND_SERVICE;
-        assertFalse(errorMsg.apply(error2), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_NO_CHANGE;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));

         // Background to background
         uidRec.setProcState = PROCESS_STATE_CACHED_ACTIVITY;
         uidRec.curProcState = PROCESS_STATE_CACHED_EMPTY;
-        assertFalse(errorMsg.apply(error2), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_NO_CHANGE;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));

         // Background to background
         uidRec.setProcState = PROCESS_STATE_NONEXISTENT;
         uidRec.curProcState = PROCESS_STATE_CACHED_ACTIVITY;
-        assertFalse(errorMsg.apply(error2), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_NO_CHANGE;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));

         // Background to foreground
         uidRec.setProcState = PROCESS_STATE_SERVICE;
         uidRec.curProcState = PROCESS_STATE_FOREGROUND_SERVICE;
-        assertTrue(errorMsg.apply(error1), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_BLOCK;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));

         // Foreground to background
         uidRec.setProcState = PROCESS_STATE_TOP;
         uidRec.curProcState = PROCESS_STATE_LAST_ACTIVITY;
-        assertTrue(errorMsg.apply(error1), mAms.shouldIncrementProcStateSeq(uidRec));
+        expectedBlockState = NETWORK_STATE_UNBLOCK;
+        assertEquals(errorMsg.apply(expectedBlockState),
+                expectedBlockState, mAms.getBlockStateForUid(uidRec));
     }
 
     /**
@@ -552,6 +649,81 @@
         }
     }
 
+    @MediumTest
+    @Test
+    public void testWaitForNetworkStateUpdate() throws Exception {
+        // Check there is no crash when there is no UidRecord for myUid
+        mAms.waitForNetworkStateUpdate(TEST_PROC_STATE_SEQ1);
+
+        // Verify there is no waiting when UidRecord.curProcStateSeq is greater than
+        // the procStateSeq in the request to wait.
+        verifyWaitingForNetworkStateUpdate(
+                TEST_PROC_STATE_SEQ1, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastDispatchedProcStateSeq
+                TEST_PROC_STATE_SEQ1 - 4, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ1 - 2, // procStateSeqToWait
+                false); // expectWait
+
+        // Verify there is no waiting when the procStateSeq in the request to wait is
+        // not dispatched to NPMS.
+        verifyWaitingForNetworkStateUpdate(
+                TEST_PROC_STATE_SEQ1, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1 - 1, // lastDispatchedProcStateSeq
+                TEST_PROC_STATE_SEQ1 - 1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ1, // procStateSeqToWait
+                false); // expectWait
+
+        // Verify there is no waiting when the procStateSeq in the request already has
+        // an updated network state.
+        verifyWaitingForNetworkStateUpdate(
+                TEST_PROC_STATE_SEQ1, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastDispatchedProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ1, // procStateSeqToWait
+                false); // expectWait
+
+        // Verify waiting for network works
+        verifyWaitingForNetworkStateUpdate(
+                TEST_PROC_STATE_SEQ1, // curProcStateSeq
+                TEST_PROC_STATE_SEQ1, // lastDispatchedProcStateSeq
+                TEST_PROC_STATE_SEQ1 - 1, // lastNetworkUpdatedProcStateSeq
+                TEST_PROC_STATE_SEQ1, // procStateSeqToWait
+                true); // expectWait
+    }
+
+    private void verifyWaitingForNetworkStateUpdate(long curProcStateSeq,
+            long lastDispatchedProcStateSeq, long lastNetworkUpdatedProcStateSeq,
+            final long procStateSeqToWait, boolean expectWait) throws Exception {
+        final UidRecord record = new UidRecord(Process.myUid());
+        record.curProcStateSeq = curProcStateSeq;
+        record.lastDispatchedProcStateSeq = lastDispatchedProcStateSeq;
+        record.lastNetworkUpdatedProcStateSeq = lastNetworkUpdatedProcStateSeq;
+        mAms.mActiveUids.put(Process.myUid(), record);
+        // Runs waitForNetworkStateUpdate on a separate thread so its state can be observed.
+        final CustomThread thread = new CustomThread(record.lock, new Runnable() {
+            @Override
+            public void run() {
+                mAms.waitForNetworkStateUpdate(procStateSeqToWait);
+            }
+        });
+        final String errMsg = "Unexpected state for " + record;
+        if (expectWait) {
+            thread.startAndWait(errMsg, true);
+            thread.assertTimedWaiting(errMsg);
+            synchronized (record.lock) {
+                record.lock.notifyAll(); // simulate the wakeup so the waiter can return
+            }
+            thread.assertTerminated(errMsg);
+            assertTrue("Waiting thread should be notified: " + record, thread.mNotified);
+            assertFalse("waitingForNetwork not cleared: " + record, record.waitingForNetwork);
+        } else {
+            thread.start(); // expected to finish without waiting
+            thread.assertTerminated(errMsg);
+        }
+
+        mAms.mActiveUids.clear();
+    }
+
     private class TestHandler extends Handler {
         private static final long WAIT_FOR_MSG_TIMEOUT_MS = 4000; // 4 sec
         private static final long WAIT_FOR_MSG_INTERVAL_MS = 400; // 0.4 sec
@@ -582,15 +754,26 @@
         }
     }
 
-    private class TestInjector implements Injector {
+    private class TestInjector extends Injector {
+        private boolean mRestricted = true; // every uid restricted unless a test overrides it
+
         @Override
-        public AppOpsService getAppOpsService() {
+        public AppOpsService getAppOpsService(File file, Handler handler) {
             return mAppOpsService;
         }

         @Override
-        public Handler getHandler() {
+        public Handler getUiHandler(ActivityManagerService service) {
             return mHandler;
         }
+
+        @Override
+        public boolean isNetworkRestrictedForUid(int uid) {
+            return mRestricted; // same answer for all uids
+        }
+        // Test hook: changes the value returned by isNetworkRestrictedForUid().
+        public void setNetworkRestrictedForUid(boolean restricted) {
+            mRestricted = restricted;
+        }
     }
 }
\ No newline at end of file