Merge "Remove network requests properly." into mnc-dev
diff --git a/core/java/android/net/NetworkFactory.java b/core/java/android/net/NetworkFactory.java
index e47220b..71fda1c 100644
--- a/core/java/android/net/NetworkFactory.java
+++ b/core/java/android/net/NetworkFactory.java
@@ -24,6 +24,7 @@
 import android.util.Log;
 import android.util.SparseArray;
 
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.Protocol;
 
 /**
@@ -176,9 +177,9 @@
 
     private void handleRemoveRequest(NetworkRequest request) {
         NetworkRequestInfo n = mNetworkRequests.get(request.requestId);
-        if (n != null && n.requested) {
+        if (n != null) { // drop any tracked request, whether or not n.requested is set
             mNetworkRequests.remove(request.requestId);
-            releaseNetworkFor(n.request);
+            if (n.requested) releaseNetworkFor(n.request); // only requested ones get released
         }
     }
 
@@ -273,6 +274,11 @@
         sendMessage(obtainMessage(CMD_SET_FILTER, new NetworkCapabilities(netCap)));
     }
 
+    @VisibleForTesting
+    protected int getRequestCount() { // test hook: number of tracked requests
+        return mNetworkRequests.size(); // NOTE(review): unsynchronized read; check caller thread
+    }
+
     protected void log(String s) {
         Log.d(LOG_TAG, s);
     }
diff --git a/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java b/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
index 6684be4..bb0a36f 100644
--- a/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
@@ -48,6 +48,7 @@
 import android.net.NetworkAgent;
 import android.net.NetworkCapabilities;
 import android.net.NetworkConfig;
+import android.net.NetworkFactory;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkMisc;
@@ -55,6 +56,7 @@
 import android.net.RouteInfo;
 import android.os.ConditionVariable;
 import android.os.Handler;
+import android.os.HandlerThread;
 import android.os.Looper;
 import android.os.INetworkManagementService;
 import android.test.AndroidTestCase;
@@ -68,6 +70,7 @@
 
 import java.net.InetAddress;
 import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Tests for {@link ConnectivityService}.
@@ -224,6 +227,39 @@
         }
     }
 
+    /**
+     * Minimal {@link NetworkFactory} for tests: records whether startNetwork() or
+     * stopNetwork() was called last and exposes the factory's request count.
+     */
+    private static class MockNetworkFactory extends NetworkFactory {
+        private final AtomicBoolean mNetworkStarted = new AtomicBoolean(false);
+
+        public MockNetworkFactory(Looper looper, Context context, String logTag,
+                NetworkCapabilities filter) {
+            super(looper, context, logTag, filter);
+        }
+
+        /** Returns the number of requests the factory is currently tracking. */
+        public int getMyRequestCount() {
+            return getRequestCount();
+        }
+
+        @Override
+        protected void startNetwork() {
+            mNetworkStarted.set(true);
+        }
+
+        @Override
+        protected void stopNetwork() {
+            mNetworkStarted.set(false);
+        }
+
+        /** Returns true if startNetwork() was called more recently than stopNetwork(). */
+        public boolean getMyStartRequested() {
+            return mNetworkStarted.get();
+        }
+    }
+
     private class WrappedConnectivityService extends ConnectivityService {
         public WrappedConnectivityService(Context context, INetworkManagementService netManager,
                 INetworkStatsService statsService, INetworkPolicyManager policyManager) {
@@ -447,6 +475,75 @@
         verifyNoNetwork();
     }
 
+    /**
+     * Checks that a NetworkFactory's request count follows requests as they are filed and
+     * removed, and that the factory stops its network while a higher-scored network is
+     * connected and starts it again once that network disconnects.
+     */
+    @LargeTest
+    public void testNetworkFactoryRequests() throws Exception {
+        NetworkCapabilities filter = new NetworkCapabilities();
+        filter.addCapability(NET_CAPABILITY_INTERNET);
+        final HandlerThread handlerThread = new HandlerThread("testNetworkFactoryRequests");
+        handlerThread.start();
+        MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
+                mServiceContext, "testFactory", filter);
+        testFactory.setScoreFilter(40);
+        testFactory.register();
+        // Unregister the factory and stop its thread even if an assertion fails, so a failure
+        // here does not leave a live factory and handler thread behind for later tests.
+        try {
+            // Request handling is asynchronous; the fixed sleeps give it time to settle.
+            Thread.sleep(500);
+            assertEquals(1, testFactory.getMyRequestCount());
+            assertTrue(testFactory.getMyStartRequested());
+
+            // now bring in a higher scored network
+            MockNetworkAgent testAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
+            ConditionVariable cv = waitForConnectivityBroadcasts(1);
+            testAgent.connect(true);
+            cv.block();
+            // part of the bringup makes another network request and then releases it
+            // wait for the release
+            Thread.sleep(500);
+            assertEquals(1, testFactory.getMyRequestCount());
+            assertFalse(testFactory.getMyStartRequested());
+
+            // bring in a bunch of requests..
+            ConnectivityManager.NetworkCallback[] networkCallbacks =
+                    new ConnectivityManager.NetworkCallback[10];
+            for (int i = 0; i < networkCallbacks.length; i++) {
+                networkCallbacks[i] = new ConnectivityManager.NetworkCallback();
+                NetworkRequest.Builder builder = new NetworkRequest.Builder();
+                builder.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET);
+                mCm.requestNetwork(builder.build(), networkCallbacks[i]);
+            }
+            Thread.sleep(1000);
+            assertEquals(11, testFactory.getMyRequestCount());
+            assertFalse(testFactory.getMyStartRequested());
+
+            // remove the requests
+            for (ConnectivityManager.NetworkCallback callback : networkCallbacks) {
+                mCm.unregisterNetworkCallback(callback);
+            }
+            Thread.sleep(500);
+            assertEquals(1, testFactory.getMyRequestCount());
+            assertFalse(testFactory.getMyStartRequested());
+
+            // drop the higher scored network
+            cv = waitForConnectivityBroadcasts(1);
+            testAgent.disconnect();
+            cv.block();
+            // as in the steps above, let the factory's handler catch up before checking it
+            Thread.sleep(500);
+            assertEquals(1, testFactory.getMyRequestCount());
+            assertTrue(testFactory.getMyStartRequested());
+        } finally {
+            testFactory.unregister();
+            handlerThread.quit();
+        }
+    }
+
 //    @Override
 //    public void tearDown() throws Exception {
 //        super.tearDown();