Kill sockets when a VPN comes up.

1. Change the SockDiag callback from a function that optionally
   kills a socket into a filter that returns a bool; every socket
   for which the filter returns true is destroyed. All existing
   callbacks only existed to kill sockets under certain
   conditions, so making them return a boolean lets the same
   callback signature be reused to filter sockets as well.
2. Add a new SockDiag method to kill sockets based on a UidRanges
   object (which contains a number of UID ranges) and a set of
   UIDs whose sockets must be skipped.
3. Add a new UIDRANGE mode to SockDiagTest to test the above.
4. When UID ranges are added to or removed from the VPN, kill
   sockets in those UID ranges unless the socket UIDs are in
   mProtectableUsers, because their creator might have set the
   protect bit on the socket mark. Short of actually being able
   to see the mark on each socket and basing our decision on that,
   this is the best we can do.

Bug: 26976388
Change-Id: I53a30df3feb63254a6451a29fa6041c9b679f9bb
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 57ba19c..6a39997 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -28,6 +28,7 @@
 
 #define LOG_TAG "Netd"
 
+#include <android-base/strings.h>
 #include <cutils/log.h>
 
 #include "NetdConstants.h"
@@ -239,7 +240,9 @@
               }
               default:
                 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
-                callback(proto, msg);
+                if (callback(proto, msg)) {
+                    sockDestroy(proto, msg);
+                }
             }
         }
     } while (bytesread > 0);
@@ -284,11 +287,9 @@
         return ret;
     }
 
-    auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
-        return this->sockDestroy(proto, msg);
-    };
+    auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
 
-    return readDiagMsg(proto, destroy);
+    return readDiagMsg(proto, destroyAll);
 }
 
 int SockDiag::destroySockets(const char *addrstr) {
@@ -313,16 +314,31 @@
     return mSocketsDestroyed;
 }
 
+// Dumps live IPv4/IPv6 TCP sockets (ESTABLISHED, SYN_SENT, SYN_RECV) and destroys each one
+// that destroyFilter accepts. Returns 0 on success or the first nonzero dump/destroy error.
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+    const uint8_t proto = IPPROTO_TCP;
+    const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+    for (const int family : {AF_INET, AF_INET6}) {
+        const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
+        if (int ret = sendDumpRequest(proto, family, states)) {
+            ALOGE("Failed to dump %s sockets: %s", familyName, strerror(-ret));
+            return ret;
+        }
+        if (int ret = readDiagMsg(proto, destroyFilter)) {
+            ALOGE("Failed to destroy %s sockets: %s", familyName, strerror(-ret));
+            return ret;
+        }
+    }
+    return 0;
+}
+
 int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
     mSocketsDestroyed = 0;
     Stopwatch s;
 
-    auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
-        if (msg != nullptr && msg->idiag_uid == uid) {
-            return this->sockDestroy(proto, msg);
-        } else {
-            return 0;
-        }
+    auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
+        return (msg != nullptr && msg->idiag_uid == uid);
     };
 
     for (const int family : {AF_INET, AF_INET6}) {
@@ -332,7 +348,7 @@
             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
             return ret;
         }
-        if (int ret = readDiagMsg(proto, destroy)) {
+        if (int ret = readDiagMsg(proto, shouldDestroy)) {
             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
             return ret;
         }
@@ -344,3 +360,32 @@
 
     return 0;
 }
+
+// Destroys live TCP sockets whose owning UID lies in uidRanges, unless that UID is
+// listed in skipUids. Returns 0 on success, otherwise the nonzero error returned by
+// destroyLiveSockets(). A summary is logged only if at least one socket was destroyed.
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    // Capturing by reference is safe: the filter only runs synchronously inside this call.
+    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr &&
+               uidRanges.hasUid(msg->idiag_uid) &&
+               skipUids.find(msg->idiag_uid) == skipUids.end();
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy)) {
+        return ret;
+    }
+
+    if (mSocketsDestroyed > 0) {
+        // std::set iterates in ascending order, so Join() yields a sorted skip list without
+        // copying or re-sorting; the string is only built when it will actually be logged.
+        ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
+              mSocketsDestroyed, uidRanges.toString().c_str(),
+              android::base::Join(skipUids, " ").c_str(), s.timeTaken());
+    }
+
+    return 0;
+}