Merge "ConnectivityManager: release all requests mapping to a callback."
am: 1a8f04b1b2

Change-Id: I30d3790822430d250d6005cc165e7fa10f56649e
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 2a985e7..0e5d049 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -15,8 +15,6 @@
  */
 package android.net;
 
-import static com.android.internal.util.Preconditions.checkNotNull;
-
 import android.annotation.IntDef;
 import android.annotation.Nullable;
 import android.annotation.SdkConstant;
@@ -50,16 +48,19 @@
 
 import com.android.internal.telephony.ITelephony;
 import com.android.internal.telephony.PhoneConstants;
-import com.android.internal.util.Protocol;
 import com.android.internal.util.MessageUtils;
+import com.android.internal.util.Preconditions;
+import com.android.internal.util.Protocol;
 
 import libcore.net.event.NetworkEventDispatcher;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.net.InetAddress;
+import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Class that answers queries about the state of network connectivity. It also
@@ -1547,8 +1548,8 @@
         }
 
         private PacketKeepalive(Network network, PacketKeepaliveCallback callback) {
-            checkNotNull(network, "network cannot be null");
-            checkNotNull(callback, "callback cannot be null");
+            Preconditions.checkNotNull(network, "network cannot be null");
+            Preconditions.checkNotNull(callback, "callback cannot be null");
             mNetwork = network;
             mCallback = callback;
             HandlerThread thread = new HandlerThread(TAG);
@@ -1835,8 +1836,8 @@
      * {@hide}
      */
     public ConnectivityManager(Context context, IConnectivityManager service) {
-        mContext = checkNotNull(context, "missing context");
-        mService = checkNotNull(service, "missing IConnectivityManager");
+        mContext = Preconditions.checkNotNull(context, "missing context");
+        mService = Preconditions.checkNotNull(service, "missing IConnectivityManager");
         sInstance = this;
     }
 
@@ -2099,7 +2100,7 @@
     @SystemApi
     public void startTethering(int type, boolean showProvisioningUi,
             final OnStartTetheringCallback callback, Handler handler) {
-        checkNotNull(callback, "OnStartTetheringCallback cannot be null.");
+        Preconditions.checkNotNull(callback, "OnStartTetheringCallback cannot be null.");
 
         ResultReceiver wrappedCallback = new ResultReceiver(handler) {
             @Override
@@ -2559,8 +2560,16 @@
     }
 
     /**
-     * Base class for NetworkRequest callbacks.  Used for notifications about network
-     * changes.  Should be extended by applications wanting notifications.
+     * Base class for {@code NetworkRequest} callbacks. Used for notifications about network
+     * changes. Should be extended by applications wanting notifications.
+     *
+     * A {@code NetworkCallback} is registered by calling
+     * {@link #requestNetwork(NetworkRequest, NetworkCallback)},
+     * {@link #registerNetworkCallback(NetworkRequest, NetworkCallback)},
+     * or {@link #registerDefaultNetworkCallback(NetworkCallback)}. A {@code NetworkCallback} is
+     * unregistered by calling {@link #unregisterNetworkCallback(NetworkCallback)}.
+     * A {@code NetworkCallback} should be registered at most once at any time.
+     * A {@code NetworkCallback} that has been unregistered can be registered again.
      */
     public static class NetworkCallback {
         /**
@@ -2663,6 +2672,10 @@
         public void onNetworkResumed(Network network) {}
 
         private NetworkRequest networkRequest;
+
+        private boolean isRegistered() {  // true iff networkRequest is set and has an id.
+            return !(networkRequest == null || networkRequest.requestId == REQUEST_ID_UNSET);
+        }
     }
 
     private static final int BASE = Protocol.BASE_CONNECTIVITY_MANAGER;
@@ -2680,6 +2693,7 @@
     public static final int CALLBACK_CAP_CHANGED         = BASE + 6;
     /** @hide */
     public static final int CALLBACK_IP_CHANGED          = BASE + 7;
+    // TODO: consider deleting CALLBACK_RELEASED and shifting following enum codes down by 1.
     /** @hide */
     public static final int CALLBACK_RELEASED            = BASE + 8;
     // TODO: consider deleting CALLBACK_EXIT and shifting following enum codes down by 1.
@@ -2798,13 +2812,6 @@
                     break;
                 }
                 case CALLBACK_RELEASED: {
-                    final NetworkCallback callback;
-                    synchronized(sCallbacks) {
-                        callback = sCallbacks.remove(request);
-                    }
-                    if (callback == null) {
-                        Log.e(TAG, "callback not found for RELEASED message");
-                    }
                     break;
                 }
                 case CALLBACK_EXIT: {
@@ -2822,12 +2829,12 @@
         }
 
         private NetworkCallback getCallback(NetworkRequest req, String name) {
-            NetworkCallback callback;
+            final NetworkCallback callback;
             synchronized(sCallbacks) {
                 callback = sCallbacks.get(req);
             }
             if (callback == null) {
-                Log.e(TAG, "callback not found for " + name + " message");
+                Log.w(TAG, "callback not found for " + name + " message");
             }
             return callback;
         }
@@ -2850,17 +2857,16 @@
 
     private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback,
             int timeoutMs, int action, int legacyType, CallbackHandler handler) {
-        if (callback == null) {
-            throw new IllegalArgumentException("null NetworkCallback");
-        }
-        if (need == null && action != REQUEST) {
-            throw new IllegalArgumentException("null NetworkCapabilities");
-        }
-        // TODO: throw an exception if callback.networkRequest is not null.
-        // http://b/20701525
+        Preconditions.checkArgument(callback != null, "null NetworkCallback");
+        Preconditions.checkArgument(action == REQUEST || need != null, "null NetworkCapabilities");
         final NetworkRequest request;
         try {
             synchronized(sCallbacks) {
+                if (callback.isRegistered()) {
+                    // TODO: throw exception instead and enforce 1:1 mapping of callbacks
+                    // and requests (http://b/20701525).
+                    Log.e(TAG, "NetworkCallback was already registered");
+                }
                 Messenger messenger = new Messenger(handler);
                 Binder binder = new Binder();
                 if (action == LISTEN) {
@@ -3325,25 +3331,42 @@
     }
 
     /**
-     * Unregisters callbacks about and possibly releases networks originating from
+     * Unregisters a {@code NetworkCallback} and possibly releases networks originating from
      * {@link #requestNetwork(NetworkRequest, NetworkCallback)} and
      * {@link #registerNetworkCallback(NetworkRequest, NetworkCallback)} calls.
      * If the given {@code NetworkCallback} had previously been used with
      * {@code #requestNetwork}, any networks that had been connected to only to satisfy that request
      * will be disconnected.
      *
+     * Notifications that would have triggered that {@code NetworkCallback} will stop
+     * triggering it as soon as this call returns.
+     *
      * @param networkCallback The {@link NetworkCallback} used when making the request.
      */
     public void unregisterNetworkCallback(NetworkCallback networkCallback) {
-        if (networkCallback == null || networkCallback.networkRequest == null ||
-                networkCallback.networkRequest.requestId == REQUEST_ID_UNSET) {
-            throw new IllegalArgumentException("Invalid NetworkCallback");
-        }
-        try {
-            // CallbackHandler will release callback when receiving CALLBACK_RELEASED.
-            mService.releaseNetworkRequest(networkCallback.networkRequest);
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
+        Preconditions.checkArgument(networkCallback != null, "null NetworkCallback");
+        final List<NetworkRequest> reqs = new ArrayList<>();
+        // Find all requests associated to this callback and stop callback triggers immediately.
+        // Callback is reusable immediately. http://b/20701525, http://b/35921499.
+        synchronized (sCallbacks) {
+            Preconditions.checkArgument(
+                    networkCallback.isRegistered(), "NetworkCallback was not registered");
+            for (Map.Entry<NetworkRequest, NetworkCallback> e : sCallbacks.entrySet()) {
+                if (e.getValue() == networkCallback) {
+                    reqs.add(e.getKey());
+                }
+            }
+            // TODO: throw exception if callback was registered more than once (http://b/20701525).
+            for (NetworkRequest r : reqs) {
+                try {
+                    mService.releaseNetworkRequest(r);
+                } catch (RemoteException e) {
+                    throw e.rethrowFromSystemServer();
+                }
+                // Only remove mapping if rpc was successful.
+                sCallbacks.remove(r);
+            }
+            networkCallback.networkRequest = null;
         }
     }
 
diff --git a/tests/net/java/android/net/ConnectivityManagerTest.java b/tests/net/java/android/net/ConnectivityManagerTest.java
index b984bbf..ceb0135 100644
--- a/tests/net/java/android/net/ConnectivityManagerTest.java
+++ b/tests/net/java/android/net/ConnectivityManagerTest.java
@@ -36,21 +36,50 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import android.net.ConnectivityManager;
 import android.net.NetworkCapabilities;
-
+import android.content.Context;
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.Looper;
+import android.os.Message;
+import android.os.Messenger;
+import android.content.pm.ApplicationInfo;
+import android.os.Build.VERSION_CODES;
+import android.net.ConnectivityManager.NetworkCallback;
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
 
-import org.junit.runner.RunWith;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
 
 
 
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class ConnectivityManagerTest {
+
+    @Mock Context mCtx;
+    @Mock IConnectivityManager mService;
+
+    @Before  // Runs before every test case.
+    public void setUp() {
+        MockitoAnnotations.initMocks(this);  // Instantiates the @Mock fields (mCtx, mService).
+    }
+
     static NetworkCapabilities verifyNetworkCapabilities(
             int legacyType, int transportType, int... capabilities) {
         final NetworkCapabilities nc = ConnectivityManager.networkCapabilitiesForType(legacyType);
@@ -173,4 +202,149 @@
         verifyUnrestrictedNetworkCapabilities(
                 ConnectivityManager.TYPE_ETHERNET, TRANSPORT_ETHERNET);
     }
+
+    /**
+     * Verifies that unregistering a callback releases its request in the service and that
+     * messages delivered for that request afterwards no longer reach the callback.
+     */
+    @Test
+    public void testCallbackRelease() throws Exception {
+        ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        NetworkRequest request = makeRequest(1);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        Handler handler = new Handler(Looper.getMainLooper());
+        ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
+
+        // register callback
+        when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt()))
+                .thenReturn(request);
+        manager.requestNetwork(request, callback, handler);
+
+        // callback triggers
+        captor.getValue().send(makeMessage(request, ConnectivityManager.CALLBACK_AVAILABLE));
+        verify(callback, timeout(500).times(1)).onAvailable(any());
+
+        // unregister callback
+        manager.unregisterNetworkCallback(callback);
+        verify(mService, times(1)).releaseNetworkRequest(request);
+
+        // callback does not trigger anymore. The message is handled asynchronously on the main
+        // looper, so wait out the whole period: timeout(...).times(0) succeeds on its first
+        // check, before the message is even dispatched, and could never catch a regression.
+        captor.getValue().send(makeMessage(request, ConnectivityManager.CALLBACK_LOSING));
+        verify(callback, org.mockito.Mockito.after(500).never()).onLosing(any(), anyInt());
+    }
+
+    /**
+     * Verifies that a callback can be reused: once unregistered it can be registered again
+     * with a new request, which then triggers it.
+     */
+    @Test
+    public void testCallbackRecycling() throws Exception {
+        ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        NetworkRequest req1 = makeRequest(1);
+        NetworkRequest req2 = makeRequest(2);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        Handler handler = new Handler(Looper.getMainLooper());
+        ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
+
+        // register callback
+        when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt()))
+                .thenReturn(req1);
+        manager.requestNetwork(req1, callback, handler);
+
+        // callback triggers
+        captor.getValue().send(makeMessage(req1, ConnectivityManager.CALLBACK_AVAILABLE));
+        verify(callback, timeout(100).times(1)).onAvailable(any());
+
+        // unregister callback
+        manager.unregisterNetworkCallback(callback);
+        verify(mService, times(1)).releaseNetworkRequest(req1);
+
+        // callback does not trigger anymore. after() waits the full period; a
+        // timeout(...).times(0) check would pass before the message is even handled.
+        captor.getValue().send(makeMessage(req1, ConnectivityManager.CALLBACK_LOSING));
+        verify(callback, org.mockito.Mockito.after(100).never()).onLosing(any(), anyInt());
+
+        // callback can be registered again
+        when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt()))
+                .thenReturn(req2);
+        manager.requestNetwork(req2, callback, handler);
+
+        // callback triggers
+        captor.getValue().send(makeMessage(req2, ConnectivityManager.CALLBACK_LOST));
+        verify(callback, timeout(100).times(1)).onLost(any());
+
+        // unregister callback
+        manager.unregisterNetworkCallback(callback);
+        verify(mService, times(1)).releaseNetworkRequest(req2);
+    }
+
+    /**
+     * Checks that registering a callback that is still registered throws
+     * IllegalArgumentException, and that unregistering it makes it registrable again.
+     */
+    // TODO: turn on this test when the request/callback 1:1 mapping is enforced
+    //@Test
+    private void noDoubleCallbackRegistration() throws Exception {
+        ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        NetworkRequest request = makeRequest(1);
+        NetworkCallback callback = new ConnectivityManager.NetworkCallback();
+        ApplicationInfo info = new ApplicationInfo();
+        // TODO: update version when starting to enforce 1:1 mapping
+        info.targetSdkVersion = VERSION_CODES.N_MR1 + 1;
+
+        when(mCtx.getApplicationInfo()).thenReturn(info);
+        when(mService.requestNetwork(any(), any(), anyInt(), any(), anyInt())).thenReturn(request);
+
+        Handler handler = new Handler(Looper.getMainLooper());
+        manager.requestNetwork(request, callback, handler);
+
+        // callback is already registered, reregistration should fail.
+        Class<IllegalArgumentException> wantException = IllegalArgumentException.class;
+        expectThrowable(() -> manager.requestNetwork(request, callback), wantException);
+
+        manager.unregisterNetworkCallback(callback);
+        verify(mService, times(1)).releaseNetworkRequest(request);
+
+        // unregistering the callback should make it registrable again.
+        manager.requestNetwork(request, callback);
+    }
+
+    /** Builds a message of type {@code messageType} whose data bundle carries {@code req}. */
+    static Message makeMessage(NetworkRequest req, int messageType) {
+        Bundle bundle = new Bundle();
+        bundle.putParcelable(NetworkRequest.class.getSimpleName(), req);
+        Message msg = Message.obtain();
+        msg.what = messageType;
+        msg.setData(bundle);
+        return msg;
+    }
+
+    /**
+     * Builds a {@link NetworkRequest} with no capabilities, legacy type {@code TYPE_NONE} and
+     * the given request id.
+     */
+    static NetworkRequest makeRequest(int requestId) {
+        NetworkRequest request = new NetworkRequest.Builder().clearCapabilities().build();
+        return new NetworkRequest(request.networkCapabilities, ConnectivityManager.TYPE_NONE,
+                requestId, NetworkRequest.Type.NONE);
+    }
+
+    /**
+     * Runs {@code block} and fails unless it throws an instance of {@code throwableType}
+     * (subclasses included). An unexpected throwable is kept as the failure's cause.
+     */
+    static void expectThrowable(Runnable block, Class<? extends Throwable> throwableType) {
+        try {
+            block.run();
+        } catch (Throwable t) {
+            if (throwableType.isInstance(t)) {
+                return;
+            }
+            throw new AssertionError(
+                    "expected exception of type " + throwableType + ", but was " + t.getClass(), t);
+        }
+        fail("expected exception of type " + throwableType);
+    }
 }