Close open sockets when enabling firewall rules.

When enabling a firewall rule that will deny networking to apps,
first close any sockets opened by those apps. Just dropping an
app's packets without closing its connections has the following
problems:

1. The app has no way to know this has happened until a network
   timeout occurs.
2. The app's connections stay open, so the other end of the
   connection (e.g., a server) might continue to retransmit
   packets. These packets will wake up the kernel and cause
   battery drain, but we cannot respond to them because packets
   on those connections are dropped by the kernel (since the app
   is blackholed). So the other end might keep retransmitting.
3. Even though we think the connections are still open, the
   other end of the connection, or any intermediate NATs or
   firewalls, might time out and close the connection (e.g., by
   sending a RST). Because the app is blackholed, we have no way
   of knowing that this has happened, so when the app is granted
   network access again, these connections might just get stuck.

Bug: 27824851
Bug: 27867653
Change-Id: Iaaad1b26954fc5f1ba5c9ed8bdee039282f5e249
diff --git a/services/core/java/com/android/server/NetworkManagementService.java b/services/core/java/com/android/server/NetworkManagementService.java
index 7458898..9fc17a2 100644
--- a/services/core/java/com/android/server/NetworkManagementService.java
+++ b/services/core/java/com/android/server/NetworkManagementService.java
@@ -72,6 +72,7 @@
 import android.os.RemoteCallbackList;
 import android.os.RemoteException;
 import android.os.ServiceManager;
+import android.os.ServiceSpecificException;
 import android.os.StrictMode;
 import android.os.SystemClock;
 import android.os.SystemProperties;
@@ -2047,6 +2048,65 @@
         }
     }
 
+    private void closeSocketsForFirewallChain(int chain, String chainName) {
+        // UID ranges to close sockets on.
+        UidRange[] ranges;
+        // Individual UIDs whose sockets we won't touch.
+        int[] exemptUids;
+
+        SparseIntArray rules = getUidFirewallRules(chain);
+        int numUids = 0;
+
+        if (getFirewallType(chain) == FIREWALL_TYPE_WHITELIST) {
+            // Close all sockets on all non-system UIDs...
+            ranges = new UidRange[] {
+                // TODO: is there a better way of finding all existing users? If so, we could
+                // specify their ranges here.
+                new UidRange(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE),
+            };
+            // ... except for the UIDs that have allow rules.
+            exemptUids = new int[rules.size()];
+            for (int i = 0; i < exemptUids.length; i++) {
+                if (rules.valueAt(i) == NetworkPolicyManager.FIREWALL_RULE_ALLOW) {
+                    exemptUids[numUids] = rules.keyAt(i);
+                    numUids++;
+                }
+            }
+            // Normally, whitelist chains only contain allow rules, so numUids == exemptUids.length.
+            // But the code does not guarantee this in any way, and at least in one case - if we add
+            // a UID rule to the firewall, and then disable the firewall - the chains can contain
+            // the wrong type of rule. In this case, don't close connections that we shouldn't.
+            //
+            // TODO: tighten up this code by ensuring we never set the wrong type of rule, and
+            // fix setFirewallEnabled to grab mQuotaLock and clear rules.
+            if (numUids != exemptUids.length) {
+                exemptUids = Arrays.copyOf(exemptUids, numUids);
+            }
+        } else {
+            // Close sockets for every UID that has a deny rule...
+            ranges = new UidRange[rules.size()];
+            for (int i = 0; i < ranges.length; i++) {
+                if (rules.valueAt(i) == NetworkPolicyManager.FIREWALL_RULE_DENY) {
+                    int uid = rules.keyAt(i);
+                    ranges[numUids] = new UidRange(uid, uid);
+                    numUids++;
+                }
+            }
+            // As above; usually numUids == ranges.length, but not always.
+            if (numUids != ranges.length) {
+                ranges = Arrays.copyOf(ranges, numUids);
+            }
+            // ... with no exceptions.
+            exemptUids = new int[0];
+        }
+
+        try {
+            mNetdService.socketDestroy(ranges, exemptUids);
+        } catch(RemoteException | ServiceSpecificException e) {
+            Slog.e(TAG, "Error closing sockets after enabling chain " + chainName + ": " + e);
+        }
+    }
+
     @Override
     public void setFirewallChainEnabled(int chain, boolean enable) {
         enforceSystemUid();
@@ -2059,25 +2119,35 @@
             mFirewallChainStates.put(chain, enable);
 
             final String operation = enable ? "enable_chain" : "disable_chain";
+            String chainName;
+            switch(chain) {
+                case FIREWALL_CHAIN_STANDBY:
+                    chainName = FIREWALL_CHAIN_NAME_STANDBY;
+                    break;
+                case FIREWALL_CHAIN_DOZABLE:
+                    chainName = FIREWALL_CHAIN_NAME_DOZABLE;
+                    break;
+                case FIREWALL_CHAIN_POWERSAVE:
+                    chainName = FIREWALL_CHAIN_NAME_POWERSAVE;
+                    break;
+                default:
+                    throw new IllegalArgumentException("Bad child chain: " + chain);
+            }
+
             try {
-                String chainName;
-                switch(chain) {
-                    case FIREWALL_CHAIN_STANDBY:
-                        chainName = FIREWALL_CHAIN_NAME_STANDBY;
-                        break;
-                    case FIREWALL_CHAIN_DOZABLE:
-                        chainName = FIREWALL_CHAIN_NAME_DOZABLE;
-                        break;
-                    case FIREWALL_CHAIN_POWERSAVE:
-                        chainName = FIREWALL_CHAIN_NAME_POWERSAVE;
-                        break;
-                    default:
-                        throw new IllegalArgumentException("Bad child chain: " + chain);
-                }
                 mConnector.execute("firewall", operation, chainName);
             } catch (NativeDaemonConnectorException e) {
                 throw e.rethrowAsParcelableException();
             }
+
+            // Close any sockets that were opened by the affected UIDs. This has to be done after
+            // disabling network connectivity, in case they react to the socket close by reopening
+            // the connection and race with the iptables commands that enable the firewall. All
+            // whitelist and blacklist chains allow RSTs through.
+            if (enable) {
+                if (DBG) Slog.d(TAG, "Closing sockets after enabling chain " + chainName);
+                closeSocketsForFirewallChain(chain, chainName);
+            }
         }
     }
 
@@ -2376,7 +2446,7 @@
     }
 
     private void dumpUidFirewallRule(PrintWriter pw, String name, SparseIntArray rules) {
-        pw.print("UID firewall");
+        pw.print("UID firewall ");
         pw.print(name);
         pw.print(" rule: [");
         final int size = rules.size();