StrictMode to detect untagged network traffic.

Network usage is tracked by the kernel at the UID level, which is
granular enough for normal apps, but large components (such as the
system server) are impossible to debug without adding additional
socket tagging to help identify subsystems within a UID.

To help ensure that system components tag all their network traffic,
this change offers a new StrictMode option to detect and report
untagged sockets.

Test: builds, boots, all common traffic tagged
Bug: 30943431, 30414041
Change-Id: I825c7941076054732264690247de2863342638e2
diff --git a/api/current.txt b/api/current.txt
index 21ff81e..5c4bdd0 100644
--- a/api/current.txt
+++ b/api/current.txt
@@ -24481,6 +24481,7 @@
   public class TrafficStats {
     ctor public TrafficStats();
     method public static void clearThreadStatsTag();
+    method public static int getAndSetThreadStatsTag(int);
     method public static long getMobileRxBytes();
     method public static long getMobileRxPackets();
     method public static long getMobileTxBytes();
@@ -30062,6 +30063,7 @@
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedClosableObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedRegistrationObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedSqlLiteObjects();
+    method public android.os.StrictMode.VmPolicy.Builder detectUntaggedSockets();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeath();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnCleartextNetwork();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnFileUriExposure();
diff --git a/api/system-current.txt b/api/system-current.txt
index 1f590dd..f6317a9 100644
--- a/api/system-current.txt
+++ b/api/system-current.txt
@@ -26460,6 +26460,7 @@
     ctor public TrafficStats();
     method public static void clearThreadStatsTag();
     method public static void clearThreadStatsUid();
+    method public static int getAndSetThreadStatsTag(int);
     method public static long getMobileRxBytes();
     method public static long getMobileRxPackets();
     method public static long getMobileTxBytes();
@@ -32706,6 +32707,7 @@
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedClosableObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedRegistrationObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedSqlLiteObjects();
+    method public android.os.StrictMode.VmPolicy.Builder detectUntaggedSockets();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeath();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnCleartextNetwork();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnFileUriExposure();
diff --git a/api/test-current.txt b/api/test-current.txt
index f4e71e3..9e6d115 100644
--- a/api/test-current.txt
+++ b/api/test-current.txt
@@ -24571,6 +24571,7 @@
   public class TrafficStats {
     ctor public TrafficStats();
     method public static void clearThreadStatsTag();
+    method public static int getAndSetThreadStatsTag(int);
     method public static long getMobileRxBytes();
     method public static long getMobileRxPackets();
     method public static long getMobileTxBytes();
@@ -30174,6 +30175,7 @@
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedClosableObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedRegistrationObjects();
     method public android.os.StrictMode.VmPolicy.Builder detectLeakedSqlLiteObjects();
+    method public android.os.StrictMode.VmPolicy.Builder detectUntaggedSockets();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeath();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnCleartextNetwork();
     method public android.os.StrictMode.VmPolicy.Builder penaltyDeathOnFileUriExposure();
diff --git a/core/java/android/net/SntpClient.java b/core/java/android/net/SntpClient.java
index cea56b5..ffc735c 100644
--- a/core/java/android/net/SntpClient.java
+++ b/core/java/android/net/SntpClient.java
@@ -96,6 +96,7 @@
 
     public boolean requestTime(InetAddress address, int port, int timeout) {
         DatagramSocket socket = null;
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_NTP);
         try {
             socket = new DatagramSocket();
             socket.setSoTimeout(timeout);
@@ -161,6 +162,7 @@
             if (socket != null) {
                 socket.close();
             }
+            TrafficStats.setThreadStatsTag(oldTag);
         }
 
         return true;
diff --git a/core/java/android/net/TrafficStats.java b/core/java/android/net/TrafficStats.java
index e7436be..fc66395 100644
--- a/core/java/android/net/TrafficStats.java
+++ b/core/java/android/net/TrafficStats.java
@@ -168,6 +168,24 @@
 
     /**
      * Set active tag to use when accounting {@link Socket} traffic originating
+     * from the current thread. Only one active tag per thread is supported.
+     * <p>
+     * Changes only take effect during subsequent calls to
+     * {@link #tagSocket(Socket)}.
+     * <p>
+     * Tags between {@code 0xFFFFFF00} and {@code 0xFFFFFFFF} are reserved and
+     * used internally by system services like {@link DownloadManager} when
+     * performing traffic on behalf of an application.
+     *
+     * @return the current tag for the calling thread, which can be used to
+     *         restore any existing values after a nested operation is finished
+     */
+    public static int getAndSetThreadStatsTag(int tag) {
+        return NetworkManagementSocketTagger.setThreadSocketStatsTag(tag);
+    }
+
+    /**
+     * Set active tag to use when accounting {@link Socket} traffic originating
      * from the current thread. The tag used internally is well-defined to
      * distinguish all backup-related traffic.
      *
diff --git a/core/java/android/os/StrictMode.java b/core/java/android/os/StrictMode.java
index 0da4bd1..ae981b7 100644
--- a/core/java/android/os/StrictMode.java
+++ b/core/java/android/os/StrictMode.java
@@ -24,6 +24,7 @@
 import android.content.Context;
 import android.content.Intent;
 import android.content.ServiceConnection;
+import android.net.TrafficStats;
 import android.net.Uri;
 import android.util.ArrayMap;
 import android.util.Log;
@@ -245,11 +246,17 @@
      */
     private static final int DETECT_VM_CONTENT_URI_WITHOUT_PERMISSION = 0x80 << 8;  // for VmPolicy
 
+    /**
+     * @hide
+     */
+    private static final int DETECT_VM_UNTAGGED_SOCKET = 0x80 << 24;  // for VmPolicy
+
     private static final int ALL_VM_DETECT_BITS =
             DETECT_VM_CURSOR_LEAKS | DETECT_VM_CLOSABLE_LEAKS |
             DETECT_VM_ACTIVITY_LEAKS | DETECT_VM_INSTANCE_LEAKS |
             DETECT_VM_REGISTRATION_LEAKS | DETECT_VM_FILE_URI_EXPOSURE |
-            DETECT_VM_CLEARTEXT_NETWORK | DETECT_VM_CONTENT_URI_WITHOUT_PERMISSION;
+            DETECT_VM_CLEARTEXT_NETWORK | DETECT_VM_CONTENT_URI_WITHOUT_PERMISSION |
+            DETECT_VM_UNTAGGED_SOCKET;
 
     // Byte 3: Penalty
 
@@ -300,6 +307,8 @@
      */
     public static final int PENALTY_DEATH_ON_FILE_URI_EXPOSURE = 0x04 << 24;
 
+    // CAUTION: we started stealing the top bits of Byte 4 for VM above
+
     /**
      * Mask of all the penalty bits valid for thread policies.
      */
@@ -715,7 +724,8 @@
             public Builder detectAll() {
                 int flags = DETECT_VM_ACTIVITY_LEAKS | DETECT_VM_CURSOR_LEAKS
                         | DETECT_VM_CLOSABLE_LEAKS | DETECT_VM_REGISTRATION_LEAKS
-                        | DETECT_VM_FILE_URI_EXPOSURE | DETECT_VM_CONTENT_URI_WITHOUT_PERMISSION;
+                        | DETECT_VM_FILE_URI_EXPOSURE | DETECT_VM_CONTENT_URI_WITHOUT_PERMISSION
+                        | DETECT_VM_UNTAGGED_SOCKET;
 
                 // TODO: always add DETECT_VM_CLEARTEXT_NETWORK once we have facility
                 // for apps to mark sockets that should be ignored
@@ -820,6 +830,22 @@
             }
 
             /**
+             * Detect any sockets in the calling app which have not been tagged
+             * using {@link TrafficStats}. Tagging sockets can help you
+             * investigate network usage inside your app, such as a narrowing
+             * down heavy usage to a specific library or component.
+             * <p>
+             * This currently does not detect sockets created in native code.
+             *
+             * @see TrafficStats#setThreadStatsTag(int)
+             * @see TrafficStats#tagSocket(java.net.Socket)
+             * @see TrafficStats#tagDatagramSocket(java.net.DatagramSocket)
+             */
+            public Builder detectUntaggedSockets() {
+                return enable(DETECT_VM_UNTAGGED_SOCKET);
+            }
+
+            /**
              * Crashes the whole process on violation. This penalty runs at the
              * end of all enabled penalties so you'll still get your logging or
              * other violations before the process dies.
@@ -1152,6 +1178,11 @@
             if (IS_ENG_BUILD) {
                 policyBuilder.penaltyLog();
             }
+            // All core system components need to tag their sockets to aid
+            // system health investigations
+            if (android.os.Process.myUid() < android.os.Process.FIRST_APPLICATION_UID) {
+                policyBuilder.detectUntaggedSockets();
+            }
             setVmPolicy(policyBuilder.build());
             setCloseGuardEnabled(vmClosableObjectLeaksEnabled());
         }
@@ -1832,6 +1863,13 @@
     /**
      * @hide
      */
+    public static boolean vmUntaggedSocketEnabled() {
+        return (sVmPolicyMask & DETECT_VM_UNTAGGED_SOCKET) != 0;
+    }
+
+    /**
+     * @hide
+     */
     public static void onSqliteObjectLeaked(String message, Throwable originStack) {
         onVmPolicyViolation(message, originStack);
     }
@@ -1911,6 +1949,14 @@
                 forceDeath);
     }
 
+    /**
+     * @hide
+     */
+    public static void onUntaggedSocket() {
+        onVmPolicyViolation(null, new Throwable("Untagged socket detected; use"
+                + " TrafficStats.setThreadSocketTag() to track all network usage"));
+    }
+
     // Map from VM violation fingerprint to uptime millis.
     private static final HashMap<Integer, Long> sLastVmViolationTime = new HashMap<Integer, Long>();
 
diff --git a/core/java/com/android/server/NetworkManagementSocketTagger.java b/core/java/com/android/server/NetworkManagementSocketTagger.java
index 06ef4c9..03f2bc1 100644
--- a/core/java/com/android/server/NetworkManagementSocketTagger.java
+++ b/core/java/com/android/server/NetworkManagementSocketTagger.java
@@ -16,6 +16,7 @@
 
 package com.android.server;
 
+import android.os.StrictMode;
 import android.os.SystemProperties;
 import android.util.Log;
 import android.util.Slog;
@@ -50,16 +51,20 @@
         SocketTagger.set(new NetworkManagementSocketTagger());
     }
 
-    public static void setThreadSocketStatsTag(int tag) {
+    public static int setThreadSocketStatsTag(int tag) {
+        final int old = threadSocketTags.get().statsTag;
         threadSocketTags.get().statsTag = tag;
+        return old;
     }
 
     public static int getThreadSocketStatsTag() {
         return threadSocketTags.get().statsTag;
     }
 
-    public static void setThreadSocketStatsUid(int uid) {
+    public static int setThreadSocketStatsUid(int uid) {
+        final int old = threadSocketTags.get().statsUid;
         threadSocketTags.get().statsUid = uid;
+        return old;
     }
 
     @Override
@@ -69,6 +74,9 @@
             Log.d(TAG, "tagSocket(" + fd.getInt$() + ") with statsTag=0x"
                     + Integer.toHexString(options.statsTag) + ", statsUid=" + options.statsUid);
         }
+        if (options.statsTag == -1 && StrictMode.vmUntaggedSocketEnabled()) {
+            StrictMode.onUntaggedSocket();
+        }
         // TODO: skip tagging when options would be no-op
         tagSocketFd(fd, options.statsTag, options.statsUid);
     }
diff --git a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
index 5f9efe7..85d1d1e 100644
--- a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
@@ -23,6 +23,7 @@
 import android.net.Network;
 import android.net.NetworkUtils;
 import android.net.RouteInfo;
+import android.net.TrafficStats;
 import android.os.SystemClock;
 import android.system.ErrnoException;
 import android.system.Os;
@@ -381,7 +382,12 @@
         protected void setupSocket(
                 int sockType, int protocol, long writeTimeout, long readTimeout, int dstPort)
                 throws ErrnoException, IOException {
-            mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
+            final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_PROBE);
+            try {
+                mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
+            } finally {
+                TrafficStats.setThreadStatsTag(oldTag);
+            }
             // Setting SNDTIMEO is purely for defensive purposes.
             Os.setsockoptTimeval(mFileDescriptor,
                     SOL_SOCKET, SO_SNDTIMEO, StructTimeval.fromMillis(writeTimeout));
diff --git a/services/core/java/com/android/server/connectivity/NetworkMonitor.java b/services/core/java/com/android/server/connectivity/NetworkMonitor.java
index c40780e..fbda901 100644
--- a/services/core/java/com/android/server/connectivity/NetworkMonitor.java
+++ b/services/core/java/com/android/server/connectivity/NetworkMonitor.java
@@ -779,6 +779,7 @@
         int httpResponseCode = 599;
         String redirectUrl = null;
         final Stopwatch probeTimer = new Stopwatch().start();
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_PROBE);
         try {
             urlConnection = (HttpURLConnection) mNetworkAgentInfo.network.openConnection(url);
             urlConnection.setInstanceFollowRedirects(probeType == ValidationProbeEvent.PROBE_PAC);
@@ -839,6 +840,7 @@
             if (urlConnection != null) {
                 urlConnection.disconnect();
             }
+            TrafficStats.setThreadStatsTag(oldTag);
         }
         logValidationProbe(probeTimer.stop(), probeType, httpResponseCode);
         return new CaptivePortalProbeResult(httpResponseCode, redirectUrl, url.toString());
diff --git a/services/core/java/com/android/server/connectivity/PacManager.java b/services/core/java/com/android/server/connectivity/PacManager.java
index 58c76ec..34826b6 100644
--- a/services/core/java/com/android/server/connectivity/PacManager.java
+++ b/services/core/java/com/android/server/connectivity/PacManager.java
@@ -25,6 +25,7 @@
 import android.content.IntentFilter;
 import android.content.ServiceConnection;
 import android.net.ProxyInfo;
+import android.net.TrafficStats;
 import android.net.Uri;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -103,11 +104,15 @@
             String file;
             synchronized (mProxyLock) {
                 if (Uri.EMPTY.equals(mPacUrl)) return;
+                final int oldTag = TrafficStats
+                        .getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_PAC);
                 try {
                     file = get(mPacUrl);
                 } catch (IOException ioe) {
                     file = null;
                     Log.w(TAG, "Failed to load PAC file: " + ioe);
+                } finally {
+                    TrafficStats.setThreadStatsTag(oldTag);
                 }
             }
             if (file != null) {
diff --git a/services/core/java/com/android/server/location/GpsXtraDownloader.java b/services/core/java/com/android/server/location/GpsXtraDownloader.java
index bf779859..62332c9 100644
--- a/services/core/java/com/android/server/location/GpsXtraDownloader.java
+++ b/services/core/java/com/android/server/location/GpsXtraDownloader.java
@@ -16,22 +16,19 @@
 
 package com.android.server.location;
 
+import android.net.TrafficStats;
 import android.text.TextUtils;
 import android.util.Log;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.net.HttpURLConnection;
 import java.net.URL;
-
-import java.io.ByteArrayOutputStream;
-import java.io.InputStream;
-import java.io.IOException;
 import java.util.Properties;
 import java.util.Random;
 import java.util.concurrent.TimeUnit;
 
-import libcore.io.Streams;
-
 /**
  * A class for downloading GPS XTRA data.
  *
@@ -94,7 +91,12 @@
 
         // load balance our requests among the available servers
         while (result == null) {
-            result = doDownload(mXtraServers[mNextServerIndex]);
+            final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_GPS);
+            try {
+                result = doDownload(mXtraServers[mNextServerIndex]);
+            } finally {
+                TrafficStats.setThreadStatsTag(oldTag);
+            }
 
             // increment mNextServerIndex and wrap around if necessary
             mNextServerIndex++;
diff --git a/services/net/java/android/net/dhcp/DhcpClient.java b/services/net/java/android/net/dhcp/DhcpClient.java
index 8dd05b1..2624f0b 100644
--- a/services/net/java/android/net/dhcp/DhcpClient.java
+++ b/services/net/java/android/net/dhcp/DhcpClient.java
@@ -30,6 +30,7 @@
 import android.net.InterfaceConfiguration;
 import android.net.LinkAddress;
 import android.net.NetworkUtils;
+import android.net.TrafficStats;
 import android.net.metrics.IpConnectivityLog;
 import android.net.metrics.DhcpClientEvent;
 import android.net.metrics.DhcpErrorEvent;
@@ -303,6 +304,7 @@
     }
 
     private boolean initUdpSocket() {
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_DHCP);
         try {
             mUdpSock = Os.socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
             Os.setsockoptInt(mUdpSock, SOL_SOCKET, SO_REUSEADDR, 1);
@@ -314,6 +316,8 @@
         } catch(SocketException|ErrnoException e) {
             Log.e(TAG, "Error creating UDP socket", e);
             return false;
+        } finally {
+            TrafficStats.setThreadStatsTag(oldTag);
         }
         return true;
     }
diff --git a/services/net/java/android/net/ip/RouterAdvertisementDaemon.java b/services/net/java/android/net/ip/RouterAdvertisementDaemon.java
index 6802cff..ba1621d 100644
--- a/services/net/java/android/net/ip/RouterAdvertisementDaemon.java
+++ b/services/net/java/android/net/ip/RouterAdvertisementDaemon.java
@@ -22,6 +22,7 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.NetworkUtils;
+import android.net.TrafficStats;
 import android.system.ErrnoException;
 import android.system.Os;
 import android.system.StructGroupReq;
@@ -563,6 +564,7 @@
     private boolean createSocket() {
         final int SEND_TIMEOUT_MS = 300;
 
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(TrafficStats.TAG_SYSTEM_NEIGHBOR);
         try {
             mSocket = Os.socket(AF_INET6, SOCK_RAW, IPPROTO_ICMPV6);
             // Setting SNDTIMEO is purely for defensive purposes.
@@ -574,6 +576,8 @@
         } catch (ErrnoException | IOException e) {
             Log.e(TAG, "Failed to create RA daemon socket: " + e);
             return false;
+        } finally {
+            TrafficStats.setThreadStatsTag(oldTag);
         }
 
         return true;