Close sockets when changing network permissions.

Bug: 23113288

(cherry picked from commit c6201c3754710e235f16118761b23760ff4136ad)

Change-Id: I1407644e06e960e453a31b298e46ef866f0eebd2
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 55d8dea..f6b8b10 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -31,7 +31,9 @@
 #include <android-base/strings.h>
 #include <cutils/log.h>
 
+#include "Fwmark.h"
 #include "NetdConstants.h"
+#include "Permission.h"
 #include "SockDiag.h"
 
 #include <chrono>
@@ -40,6 +42,8 @@
 #define SOCK_DESTROY 21
 #endif
 
+#define INET_DIAG_BC_MARK_COND 10
+
 namespace {
 
 int checkError(int fd) {
@@ -186,9 +190,9 @@
     attrs.nla.nla_len = sizeof(attrs) + addrlen;
 
     iovec iov[] = {
-        { nullptr, 0 },
-        { &attrs, sizeof(attrs) },
-        { addr, addrlen },
+        { nullptr,           0 },
+        { &attrs,            sizeof(attrs) },
+        { addr,              addrlen },
     };
 
     uint32_t states = ~(1 << TCP_TIME_WAIT);
@@ -316,18 +320,19 @@
     return mSocketsDestroyed;
 }
 
-int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter, const char *what,
+                                 iovec *iov, int iovcnt) {
     int proto = IPPROTO_TCP;
 
     for (const int family : {AF_INET, AF_INET6}) {
         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
-        if (int ret = sendDumpRequest(proto, family, states)) {
-            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+        if (int ret = sendDumpRequest(proto, family, states, iov, iovcnt)) {
+            ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
             return ret;
         }
         if (int ret = readDiagMsg(proto, destroyFilter)) {
-            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
             return ret;
         }
     }
@@ -377,7 +382,11 @@
                !(excludeLoopback && isLoopbackSocket(msg));
     };
 
-    if (int ret = destroyLiveSockets(shouldDestroy)) {
+    iovec iov[] = {
+        { nullptr, 0 },
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
         return ret;
     }
 
@@ -395,3 +404,95 @@
 
     return 0;
 }
+
+// Destroys all "live" (ESTABLISHED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
+// 1. The opening app no longer has permission to use this network, or:
+// 2. The opening app does have permission, but did not explicitly select this network.
+//
+// We destroy sockets without the explicit bit because we want to avoid the situation where a
+// privileged app uses its privileges without knowing it is doing so. For example, a privileged app
+// might have opened a socket on this network just because it was the default network at the
+// time. If we don't kill these sockets, those apps could continue to use them without realizing
+// that they are now sending and receiving traffic on a network that is now restricted.
+int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
+                                              bool excludeLoopback) {
+    struct markmatch {
+        inet_diag_bc_op op;
+        // TODO: switch to inet_diag_markcond
+        __u32 mark;
+        __u32 mask;
+    } __attribute__((packed));
+    constexpr uint8_t matchlen = sizeof(markmatch);  // Length of one MARK_COND instruction.
+
+    Fwmark netIdMark, netIdMask;
+    netIdMark.netId = netId;
+    netIdMask.netId = 0xffff;  // Passed as the mask of the netId MARK_COND below.
+
+    Fwmark controlMark;  // Used as both value and mask below: only these two fields are compared.
+    controlMark.explicitlySelected = true;
+    controlMark.permission = permission;
+
+    // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
+    struct bytecode {
+        markmatch netIdMatch;
+        markmatch controlMatch;
+        inet_diag_bc_op controlJump;
+    } __attribute__((packed)) bytecode;
+
+    // The length of the INET_DIAG_BC_JMP instruction.
+    constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
+    // Jump exactly this far past the end of the program to reject.
+    constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
+    // Total length of the program.
+    constexpr uint8_t bytecodelen = sizeof(bytecode);
+
+    bytecode = (struct bytecode) {
+        // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
+        { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
+          netIdMark.intValue, netIdMask.intValue },
+
+        // If explicit and permission bits match, go to the JMP below which rejects the socket
+        // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
+        // socket (so we destroy it).
+        { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
+          controlMark.intValue, controlMark.intValue },
+
+        // This JMP unconditionally rejects the socket by jumping to the reject target. It is
+        // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
+        // is invalid because the target of every no jump must always be reachable by yes jumps.
+        // Without this JMP, the accept target is not reachable by yes jumps and the program will
+        // be rejected by the validator.
+        { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
+
+        // We have reached the end of the program. Accept the socket, and destroy it below.
+    };
+
+    struct nlattr nla = {  // Designators follow struct nlattr's declaration order (len, type).
+        .nla_len = sizeof(struct nlattr) + bytecodelen,
+        .nla_type = INET_DIAG_REQ_BYTECODE,
+    };
+
+    iovec iov[] = {
+        { nullptr,   0 },  // Empty leading slot, as in the other destroyLiveSockets() caller.
+        { &nla,      sizeof(nla) },
+        { &bytecode, bytecodelen },
+    };
+
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
+        return ret;  // Propagate the failure reported by destroyLiveSockets().
+    }
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for netId %u permission=%d in %.1f ms",
+              mSocketsDestroyed, netId, static_cast<int>(permission), s.timeTaken());
+    }
+
+    return 0;
+}