Merge "Fix FD leak in ConnectivityManager.getConnectionOwnerUid" am: 20bdc26698
am: 0b6d974713

Change-Id: I5d3b329018a1f640f086887ec222c5cf7e2975b4
diff --git a/services/net/java/android/net/netlink/InetDiagMessage.java b/services/net/java/android/net/netlink/InetDiagMessage.java
index af9e601..31a2556 100644
--- a/services/net/java/android/net/netlink/InetDiagMessage.java
+++ b/services/net/java/android/net/netlink/InetDiagMessage.java
@@ -16,26 +16,23 @@
 
 package android.net.netlink;
 
-import static android.os.Process.INVALID_UID;
 import static android.net.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
 import static android.net.netlink.NetlinkSocket.DEFAULT_RECV_BUFSIZE;
 import static android.net.netlink.StructNlMsgHdr.NLM_F_DUMP;
 import static android.net.netlink.StructNlMsgHdr.NLM_F_REQUEST;
+import static android.os.Process.INVALID_UID;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
 import static android.system.OsConstants.IPPROTO_UDP;
 import static android.system.OsConstants.NETLINK_INET_DIAG;
 
-import android.os.Build;
-import android.os.Process;
+import android.net.util.SocketUtils;
 import android.system.ErrnoException;
 import android.util.Log;
 
 import java.io.FileDescriptor;
+import java.io.IOException;
 import java.io.InterruptedIOException;
-import java.net.DatagramSocket;
-import java.net.DatagramSocket;
-import java.net.InetAddress;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetSocketAddress;
@@ -163,17 +160,25 @@
      */
     public static int getConnectionOwnerUid(int protocol, InetSocketAddress local,
                                             InetSocketAddress remote) {
+        int uid = INVALID_UID;
+        FileDescriptor fd = null;
         try {
-            final FileDescriptor fd = NetlinkSocket.forProto(NETLINK_INET_DIAG);
+            fd = NetlinkSocket.forProto(NETLINK_INET_DIAG);
             NetlinkSocket.connectToKernel(fd);
-
-            return lookupUid(protocol, local, remote, fd);
-
+            uid = lookupUid(protocol, local, remote, fd);
         } catch (ErrnoException | SocketException | IllegalArgumentException
                 | InterruptedIOException e) {
             Log.e(TAG, e.toString());
+        } finally {
+            if (fd != null) {
+                try {
+                    SocketUtils.closeSocket(fd);
+                } catch (IOException e) {
+                    Log.e(TAG, e.toString());
+                }
+            }
         }
-        return INVALID_UID;
+        return uid;
     }
 
     @Override
diff --git a/tests/net/java/android/net/netlink/InetDiagSocketTest.java b/tests/net/java/android/net/netlink/InetDiagSocketTest.java
index 2adbb06..b4f6e99 100644
--- a/tests/net/java/android/net/netlink/InetDiagSocketTest.java
+++ b/tests/net/java/android/net/netlink/InetDiagSocketTest.java
@@ -18,7 +18,6 @@
 
 import static android.net.netlink.StructNlMsgHdr.NLM_F_DUMP;
 import static android.net.netlink.StructNlMsgHdr.NLM_F_REQUEST;
-import static android.os.Process.INVALID_UID;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
 import static android.system.OsConstants.IPPROTO_TCP;
@@ -28,6 +27,7 @@
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
@@ -45,7 +45,6 @@
 import libcore.util.HexEncoding;
 
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -152,9 +151,13 @@
 
     private void checkConnectionOwnerUid(int protocol, InetSocketAddress local,
                                          InetSocketAddress remote, boolean expectSuccess) {
-        final int expectedUid = expectSuccess ? Process.myUid() : INVALID_UID;
         final int uid = mCm.getConnectionOwnerUid(protocol, local, remote);
-        assertEquals(expectedUid, uid);
+
+        if (expectSuccess) {
+            assertEquals(Process.myUid(), uid);
+        } else {
+            assertNotEquals(Process.myUid(), uid);
+        }
     }
 
     private int findLikelyFreeUdpPort(UdpConnection conn) throws Exception {
@@ -165,11 +168,11 @@
         return localPort;
     }
 
+    /**
+     * Create a test connection for UDP and TCP sockets and verify that this
+     * {protocol, local, remote} socket results in receiving a valid UID.
+     */
     public void checkGetConnectionOwnerUid(String to, String from) throws Exception {
-        /**
-         * For TCP connections, create a test connection and verify that this
-         * {protocol, local, remote} socket result in receiving a valid UID.
-         */
         TcpConnection tcp = new TcpConnection(to, from);
         checkConnectionOwnerUid(tcp.protocol, tcp.local, tcp.remote, true);
         checkConnectionOwnerUid(IPPROTO_UDP, tcp.local, tcp.remote, false);
@@ -177,20 +180,14 @@
         checkConnectionOwnerUid(tcp.protocol, tcp.local, new InetSocketAddress(0), false);
         tcp.close();
 
-        /**
-         * For UDP connections, either a complete match {protocol, local, remote} or a
-         * partial match {protocol, local} should return a valid UID.
-         */
         UdpConnection udp = new UdpConnection(to,from);
         checkConnectionOwnerUid(udp.protocol, udp.local, udp.remote, true);
-        checkConnectionOwnerUid(udp.protocol, udp.local, new InetSocketAddress(0), true);
         checkConnectionOwnerUid(IPPROTO_TCP, udp.local, udp.remote, false);
         checkConnectionOwnerUid(udp.protocol, new InetSocketAddress(findLikelyFreeUdpPort(udp)),
                 udp.remote, false);
         udp.close();
     }
 
-    @Ignore
     @Test
     public void testGetConnectionOwnerUid() throws Exception {
         checkGetConnectionOwnerUid("::", null);
@@ -203,6 +200,16 @@
         checkGetConnectionOwnerUid("::1", "::1");
     }
 
+    /* Verify fix for b/141603906 (FD leak): call getConnectionOwnerUid many times in a row. */
+    @Test
+    public void testB141603906() throws Exception {
+        final InetSocketAddress src = new InetSocketAddress(0);
+        final InetSocketAddress dst = new InetSocketAddress(0);
+        for (int i = 1; i <= 100000; i++) {
+            mCm.getConnectionOwnerUid(IPPROTO_TCP, src, dst);
+        }
+    }
+
     // Hexadecimal representation of InetDiagReqV2 request.
     private static final String INET_DIAG_REQ_V2_UDP_INET4_HEX =
             // struct nlmsghdr