Use SOCK_DESTROY in netd.

Bug: 26976388
Change-Id: I1965ece8ae65d78323b5a49eeebefe29677be63f
diff --git a/server/NetlinkHandler.cpp b/server/NetlinkHandler.cpp
index 97dc3e0..718fbdb 100644
--- a/server/NetlinkHandler.cpp
+++ b/server/NetlinkHandler.cpp
@@ -29,6 +29,7 @@
 #include "NetlinkHandler.h"
 #include "NetlinkManager.h"
 #include "ResponseCode.h"
+#include "SockDiag.h"
 
 static const char *kUpdated = "updated";
 static const char *kRemoved = "removed";
@@ -78,14 +79,36 @@
             const char *flags = evt->findParam("FLAGS");
             const char *scope = evt->findParam("SCOPE");
             if (action == NetlinkEvent::Action::kAddressRemoved && iface && address) {
-                int resetMask = strchr(address, ':') ? RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
-                resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
-                if (int ret = ifc_reset_connections(iface, resetMask)) {
-                    ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
-                          address, strerror(ret));
+                // Note: if this interface was deleted, iface is "" and we don't notify.
+                SockDiag sd;
+                if (sd.open()) {
+                    char addrstr[INET6_ADDRSTRLEN] = {};  // zero-filled: always NUL-terminated
+                    strncpy(addrstr, address, sizeof(addrstr) - 1);  // never overwrite last NUL
+                    char *slash = strchr(addrstr, '/');
+                    if (slash) {
+                        *slash = '\0';  // keep only the part before '/'
+                    }
+
+                    int ret = sd.destroySockets(addrstr);
+                    if (ret < 0) {  // negative return is -errno; strerror() needs it positive
+                        ALOGE("Error destroying sockets: %s", strerror(-ret));
+                    }
+                } else {
+                    ALOGE("Error opening NETLINK_SOCK_DIAG socket: %s", strerror(errno));
+                }
+
+                // TODO: delete this once SOCK_DESTROY works everywhere.
+                if (iface[0]) {
+                    int resetMask = strchr(address, ':') ?
+                            RESET_IPV6_ADDRESSES : RESET_IPV4_ADDRESSES;
+                    resetMask |= RESET_IGNORE_INTERFACE_ADDRESS;
+                    if (int ret = ifc_reset_connections(iface, resetMask)) {
+                        ALOGE("ifc_reset_connections failed on iface %s for address %s (%s)", iface,
+                              address, strerror(ret));
+                    }
                 }
             }
-            if (iface && flags && scope) {
+            if (iface && iface[0] && address && flags && scope) {
                 notifyAddressChanged(action, address, iface, flags, scope);
             }
         } else if (action == NetlinkEvent::Action::kRdnss) {
diff --git a/server/SockDiag.cpp b/server/SockDiag.cpp
index 2f1437c..b9f69cd 100644
--- a/server/SockDiag.cpp
+++ b/server/SockDiag.cpp
@@ -33,6 +33,8 @@
 #include "NetdConstants.h"
 #include "SockDiag.h"
 
+#include <chrono>
+
 #ifndef SOCK_DESTROY
 #define SOCK_DESTROY 21
 #endif
@@ -208,6 +210,10 @@
 }
 
 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
+    if (msg == nullptr) {
+        return 0;  // no message, so no destroy request is sent
+    }
+
     DestroyRequest request = {
         .nlh = {
             .nlmsg_type = SOCK_DESTROY,
@@ -226,5 +232,47 @@
         return -errno;
     }
 
-    return checkError(mWriteSock);
+    int ret = checkError(mWriteSock);
+    if (!ret) mSocketsDestroyed++;
+    return ret;
+}
+
+int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
+    // Both netlink sockets must be open before anything can be sent.
+    if (!hasSocks()) return -EBADFD;
+
+    // Send the dump request for this protocol/family/address; non-zero is an error code.
+    const int dumpStatus = sendDumpRequest(proto, family, addrstr);
+    if (dumpStatus != 0) {
+        return dumpStatus;
+    }
+
+    // Read the replies, forwarding each message handed to the callback to sockDestroy().
+    return readDiagMsg(proto, [this] (uint8_t msgProto, const inet_diag_msg *diagMsg) {
+        return this->sockDestroy(msgProto, diagMsg);
+    });
+}
+
+int SockDiag::destroySockets(const char *addrstr) {
+    // Destroys TCP sockets on addrstr; returns the number destroyed or a negative error code.
+    using Millis = std::chrono::duration<float, std::ratio<1, 1000>>;
+    mSocketsDestroyed = 0;
+    const auto begin = std::chrono::steady_clock::now();
+    // No ':' means an IPv4 address: sweep AF_INET first. AF_INET6 is swept in either case.
+    int status = strchr(addrstr, ':') ? 0 : destroySockets(IPPROTO_TCP, AF_INET, addrstr);
+    if (status != 0) {
+        ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-status));
+        return status;
+    }
+    status = destroySockets(IPPROTO_TCP, AF_INET6, addrstr);
+    if (status != 0) {
+        ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-status));
+        return status;
+    }
+    const auto elapsed = std::chrono::duration_cast<Millis>(
+            std::chrono::steady_clock::now() - begin);
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+    }
+    return mSocketsDestroyed;
 }
diff --git a/server/SockDiag.h b/server/SockDiag.h
index 3b6ca8b..56acbdb 100644
--- a/server/SockDiag.h
+++ b/server/SockDiag.h
@@ -5,6 +5,7 @@
 #include <linux/inet_diag.h>
 
 struct inet_diag_msg;
+class SockDiagTest;
 
 class SockDiag {
 
@@ -17,17 +18,20 @@
         inet_diag_req_v2 req;
     } __attribute__((__packed__));
 
-    SockDiag() : mSock(-1), mWriteSock(-1) {}
+    SockDiag() : mSock(-1), mWriteSock(-1), mSocketsDestroyed(0) {}
     bool open();
     virtual ~SockDiag() { closeSocks(); }
 
     int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
     int readDiagMsg(uint8_t proto, DumpCallback callback);
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
+    int destroySockets(const char *addrstr);
 
   private:
     int mSock;
     int mWriteSock;
+    int mSocketsDestroyed;
+    int destroySockets(uint8_t proto, int family, const char *addrstr);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
 };