NetworkStatsService: Fix getDetailedUidStats to take VPNs into account.

(cherry picked from commit 29d2ef2fe10ce9686b57a21aa8e10f6481c92a9a)

This API is similar to one provided by NetworkStatsFactory with the
difference that NSS also migrates traffic from VPN UID to other apps.

Since traffic can only be migrated over NetworkStats delta, NSS
therefore maintains NetworkStats snapshot across all UIDs/ifaces/tags.

This snapshot gets updated whenever NSS records a new snapshot
(based on various hooks such as VPN updating its underlying networks,
network getting lost, etc.), or getDetailedUidStats API is invoked by
one of its callers.

Bug: 113122541
Bug: 120145746
Test: atest FrameworksNetTests
Test: manually verified that battery stats are migrating traffic off of
TUN (after patching above CL where we point BatteryStats to use this
API).

Change-Id: I4b8d7c5b6905a4a12c1806dfd35c2c4c63610404
diff --git a/core/java/android/net/NetworkStats.java b/core/java/android/net/NetworkStats.java
index f09f2ee..e8625f3 100644
--- a/core/java/android/net/NetworkStats.java
+++ b/core/java/android/net/NetworkStats.java
@@ -34,6 +34,7 @@
 import java.io.CharArrayWriter;
 import java.io.PrintWriter;
 import java.util.Arrays;
+import java.util.function.Predicate;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
@@ -994,23 +995,33 @@
         if (limitUid == UID_ALL && limitTag == TAG_ALL && limitIfaces == INTERFACES_ALL) {
             return;
         }
+        filter(e -> (limitUid == UID_ALL || limitUid == e.uid)
+            && (limitTag == TAG_ALL || limitTag == e.tag)
+            && (limitIfaces == INTERFACES_ALL
+                    || ArrayUtils.contains(limitIfaces, e.iface)));
+    }
 
+    /**
+     * Only keep entries with {@link #set} value less than {@link #SET_DEBUG_START}.
+     *
+     * <p>This mutates the original structure in place.
+     */
+    public void filterDebugEntries() {
+        filter(e -> e.set < SET_DEBUG_START);
+    }
+
+    private void filter(Predicate<Entry> predicate) {
         Entry entry = new Entry();
         int nextOutputEntry = 0;
         for (int i = 0; i < size; i++) {
             entry = getValues(i, entry);
-            final boolean matches =
-                    (limitUid == UID_ALL || limitUid == entry.uid)
-                    && (limitTag == TAG_ALL || limitTag == entry.tag)
-                    && (limitIfaces == INTERFACES_ALL
-                            || ArrayUtils.contains(limitIfaces, entry.iface));
-
-            if (matches) {
-                setValues(nextOutputEntry, entry);
+            if (predicate.test(entry)) {
+                if (nextOutputEntry != i) {
+                    setValues(nextOutputEntry, entry);
+                }
                 nextOutputEntry++;
             }
         }
-
         size = nextOutputEntry;
     }
 
diff --git a/services/core/java/com/android/server/net/NetworkStatsFactory.java b/services/core/java/com/android/server/net/NetworkStatsFactory.java
index 69efd02..473cc97 100644
--- a/services/core/java/com/android/server/net/NetworkStatsFactory.java
+++ b/services/core/java/com/android/server/net/NetworkStatsFactory.java
@@ -263,6 +263,10 @@
         return stats;
     }
 
+    /**
+     * @deprecated Use NetworkStatsService#getDetailedUidStats which also accounts for
+     * VPN traffic
+     */
     public NetworkStats readNetworkStatsDetail() throws IOException {
         return readNetworkStatsDetail(UID_ALL, null, TAG_ALL, null);
     }
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index f34ace5..a13368f 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -293,6 +293,22 @@
     /** Data layer operation counters for splicing into other structures. */
     private NetworkStats mUidOperations = new NetworkStats(0L, 10);
 
+    /**
+     * Snapshot containing most recent network stats for all UIDs across all interfaces and tags
+     * since boot.
+     *
+     * <p>Maintains migrated VPN stats which are result of performing TUN migration on {@link
+     * #mLastUidDetailSnapshot}.
+     */
+    @GuardedBy("mStatsLock")
+    private NetworkStats mTunAdjustedStats;
+    /**
+     * Used by {@link #mTunAdjustedStats} to migrate VPN traffic over delta between this snapshot
+     * and latest snapshot.
+     */
+    @GuardedBy("mStatsLock")
+    private NetworkStats mLastUidDetailSnapshot;
+
     /** Must be set in factory by calling #setHandler. */
     private Handler mHandler;
     private Handler.Callback mHandlerCallback;
@@ -812,15 +828,39 @@
     @Override
     public NetworkStats getDetailedUidStats(String[] requiredIfaces) {
         try {
+            // Get the latest snapshot from NetworkStatsFactory.
+            // TODO: Querying for INTERFACES_ALL may incur performance penalty. Consider restricting
+            // this to limited set of ifaces.
+            NetworkStats uidDetailStats = getNetworkStatsUidDetail(INTERFACES_ALL);
+
+            // Migrate traffic from VPN UID over delta and update mTunAdjustedStats.
+            NetworkStats result;
+            synchronized (mStatsLock) {
+                migrateTunTraffic(uidDetailStats, mVpnInfos);
+                result = mTunAdjustedStats.clone();
+            }
+
+            // Apply filter based on ifacesToQuery.
             final String[] ifacesToQuery =
                     NetworkStatsFactory.augmentWithStackedInterfaces(requiredIfaces);
-            return getNetworkStatsUidDetail(ifacesToQuery);
+            result.filter(UID_ALL, ifacesToQuery, TAG_ALL);
+            return result;
         } catch (RemoteException e) {
             Log.wtf(TAG, "Error compiling UID stats", e);
             return new NetworkStats(0L, 0);
         }
     }
 
+    @VisibleForTesting
+    NetworkStats getTunAdjustedStats() {
+        synchronized (mStatsLock) {
+            if (mTunAdjustedStats == null) {
+                return null;
+            }
+            return mTunAdjustedStats.clone();
+        }
+    }
+
     @Override
     public String[] getMobileIfaces() {
         return mMobileIfaces;
@@ -1295,6 +1335,34 @@
         // a race condition between the service handler thread and the observer's
         mStatsObservers.updateStats(xtSnapshot, uidSnapshot, new ArrayMap<>(mActiveIfaces),
                 new ArrayMap<>(mActiveUidIfaces), vpnArray, currentTime);
+
+        migrateTunTraffic(uidSnapshot, vpnArray);
+    }
+
+    /**
+     * Updates {@link #mTunAdjustedStats} with the delta containing traffic migrated off of VPNs.
+     */
+    @GuardedBy("mStatsLock")
+    private void migrateTunTraffic(NetworkStats uidDetailStats, VpnInfo[] vpnInfoArray) {
+        if (mTunAdjustedStats == null) {
+            // Either device booted or system server restarted, hence traffic cannot be migrated
+            // correctly without knowing the past state of VPN's underlying networks.
+            mTunAdjustedStats = uidDetailStats;
+            mLastUidDetailSnapshot = uidDetailStats;
+            return;
+        }
+        // Migrate delta traffic from VPN to other apps.
+        NetworkStats delta = uidDetailStats.subtract(mLastUidDetailSnapshot);
+        for (VpnInfo info : vpnInfoArray) {
+            delta.migrateTun(info.ownerUid, info.vpnIface, info.underlyingIfaces);
+        }
+        // Filter out debug entries as that may lead to over counting.
+        delta.filterDebugEntries();
+        // Update #mTunAdjustedStats with migrated delta.
+        mTunAdjustedStats.combineAllValues(delta);
+        mTunAdjustedStats.setElapsedRealtime(uidDetailStats.getElapsedRealtime());
+        // Update last snapshot.
+        mLastUidDetailSnapshot = uidDetailStats;
     }
 
     /**
diff --git a/tests/net/java/android/net/NetworkStatsTest.java b/tests/net/java/android/net/NetworkStatsTest.java
index 9ed6cb9..c16a0f4 100644
--- a/tests/net/java/android/net/NetworkStatsTest.java
+++ b/tests/net/java/android/net/NetworkStatsTest.java
@@ -813,6 +813,37 @@
     }
 
     @Test
+    public void testFilterDebugEntries() {
+        NetworkStats.Entry entry1 = new NetworkStats.Entry(
+                "test1", 10100, SET_DEFAULT, TAG_NONE, METERED_NO, ROAMING_NO,
+                DEFAULT_NETWORK_NO, 50000L, 25L, 100000L, 50L, 0L);
+
+        NetworkStats.Entry entry2 = new NetworkStats.Entry(
+                "test2", 10101, SET_DBG_VPN_IN, TAG_NONE, METERED_NO, ROAMING_NO,
+                DEFAULT_NETWORK_NO, 50000L, 25L, 100000L, 50L, 0L);
+
+        NetworkStats.Entry entry3 = new NetworkStats.Entry(
+                "test2", 10101, SET_DEFAULT, TAG_NONE, METERED_NO, ROAMING_NO,
+                DEFAULT_NETWORK_NO, 50000L, 25L, 100000L, 50L, 0L);
+
+        NetworkStats.Entry entry4 = new NetworkStats.Entry(
+                "test2", 10101, SET_DBG_VPN_OUT, TAG_NONE, METERED_NO, ROAMING_NO,
+                DEFAULT_NETWORK_NO, 50000L, 25L, 100000L, 50L, 0L);
+
+        NetworkStats stats = new NetworkStats(TEST_START, 4)
+                .addValues(entry1)
+                .addValues(entry2)
+                .addValues(entry3)
+                .addValues(entry4);
+
+        stats.filterDebugEntries();
+
+        assertEquals(2, stats.size());
+        assertEquals(entry1, stats.getValues(0, null));
+        assertEquals(entry3, stats.getValues(1, null));
+    }
+
+    @Test
     public void testApply464xlatAdjustments() {
         final String v4Iface = "v4-wlan0";
         final String baseIface = "wlan0";
diff --git a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
index c46534c..e35c34a 100644
--- a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -57,6 +57,7 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -216,11 +217,16 @@
         expectNetworkStatsUidDetail(buildEmptyStats());
         expectSystemReady();
 
+        assertNull(mService.getTunAdjustedStats());
         mService.systemReady();
+        // Verify that system ready fetches realtime stats and initializes tun adjusted stats.
+        verify(mNetManager).getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL);
+        assertNotNull("failed to initialize TUN adjusted stats", mService.getTunAdjustedStats());
+        assertEquals(0, mService.getTunAdjustedStats().size());
+
         mSession = mService.openSession();
         assertNotNull("openSession() failed", mSession);
 
-
         // catch INetworkManagementEventObserver during systemReady()
         ArgumentCaptor<INetworkManagementEventObserver> networkObserver =
               ArgumentCaptor.forClass(INetworkManagementEventObserver.class);
@@ -733,11 +739,13 @@
 
         NetworkStats stats = mService.getDetailedUidStats(ifaceFilter);
 
-        verify(mNetManager, times(1)).getNetworkStatsUidDetail(eq(UID_ALL), argThat(ifaces ->
-                ifaces != null && ifaces.length == 2
-                        && ArrayUtils.contains(ifaces, TEST_IFACE)
-                        && ArrayUtils.contains(ifaces, stackedIface)));
-
+        // mNetManager#getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL) has following invocations:
+        // 1) NetworkStatsService#systemReady from #setUp.
+        // 2) mService#forceUpdateIfaces in the test above.
+        // 3) Finally, mService#getDetailedUidStats.
+        verify(mNetManager, times(3)).getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL);
+        assertTrue(ArrayUtils.contains(stats.getUniqueIfaces(), TEST_IFACE));
+        assertTrue(ArrayUtils.contains(stats.getUniqueIfaces(), stackedIface));
         assertEquals(2, stats.size());
         assertEquals(uidStats, stats.getValues(0, null));
         assertEquals(tetheredStats1, stats.getValues(1, null));
@@ -1158,6 +1166,134 @@
     }
 
     @Test
+    public void recordSnapshot_migratesTunTrafficAndUpdatesTunAdjustedStats() throws Exception {
+        assertEquals(0, mService.getTunAdjustedStats().size());
+        // VPN using WiFi (TEST_IFACE).
+        VpnInfo[] vpnInfos = new VpnInfo[] {createVpnInfo(new String[] {TEST_IFACE})};
+        expectBandwidthControlCheck();
+        // create some traffic (assume 10 bytes of MTU for VPN interface and 1 byte encryption
+        // overhead per packet):
+        // 1000 bytes (100 packets) were downloaded by UID_RED over VPN.
+        // VPN received 1100 bytes (100 packets) over WiFi.
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 2)
+              .addValues(TUN_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1000L, 100L, 0L, 0L, 0L)
+              .addValues(TEST_IFACE, UID_VPN, SET_DEFAULT, TAG_NONE, 1100L, 100L, 0L, 0L, 0L));
+
+        // this should lead to NSS#recordSnapshotLocked
+        mService.forceUpdateIfaces(
+                new Network[0], vpnInfos, new NetworkState[0], null /* activeIface */);
+
+        // Verify TUN adjusted stats have traffic migrated correctly.
+        // Of 1100 bytes VPN received over WiFi, expect 1000 bytes attributed to UID_RED and 100
+        // bytes attributed to UID_VPN.
+        NetworkStats tunAdjStats = mService.getTunAdjustedStats();
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 1000L, 100L, 0L, 0L, 0);
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 100L, 0L, 0L, 0L, 0);
+    }
+
+    @Test
+    public void getDetailedUidStats_migratesTunTrafficAndUpdatesTunAdjustedStats()
+            throws Exception {
+        assertEquals(0, mService.getTunAdjustedStats().size());
+        // VPN using WiFi (TEST_IFACE).
+        VpnInfo[] vpnInfos = new VpnInfo[] {createVpnInfo(new String[] {TEST_IFACE})};
+        expectBandwidthControlCheck();
+        mService.forceUpdateIfaces(
+                new Network[0], vpnInfos, new NetworkState[0], null /* activeIface */);
+        // create some traffic (assume 10 bytes of MTU for VPN interface and 1 byte encryption
+        // overhead per packet):
+        // 1000 bytes (100 packets) were downloaded by UID_RED over VPN.
+        // VPN received 1100 bytes (100 packets) over WiFi.
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 2)
+              .addValues(TUN_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1000L, 100L, 0L, 0L, 0L)
+              .addValues(TEST_IFACE, UID_VPN, SET_DEFAULT, TAG_NONE, 1100L, 100L, 0L, 0L, 0L));
+
+        mService.getDetailedUidStats(INTERFACES_ALL);
+
+        // Verify internally maintained TUN adjusted stats
+        NetworkStats tunAdjStats = mService.getTunAdjustedStats();
+        // Verify stats for TEST_IFACE (WiFi):
+        // Of 1100 bytes VPN received over WiFi, expect 1000 bytes attributed to UID_RED and 100
+        // bytes attributed to UID_VPN.
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 1000L, 100L, 0L, 0L, 0);
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 100L, 0L, 0L, 0L, 0);
+        // Verify stats for TUN_IFACE; only UID_RED should have usage on it.
+        assertValues(
+                tunAdjStats, TUN_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 1000L, 100L, 0L, 0L, 0);
+        assertValues(
+                tunAdjStats, TUN_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 0L, 0L, 0L, 0L, 0);
+
+        // lets assume that since last time, VPN received another 1100 bytes (same assumptions as
+        // before i.e. UID_RED downloaded another 1000 bytes).
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        // Note - NetworkStatsFactory returns counters that are monotonically increasing.
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 2)
+              .addValues(TUN_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 2000L, 200L, 0L, 0L, 0L)
+              .addValues(TEST_IFACE, UID_VPN, SET_DEFAULT, TAG_NONE, 2200L, 200L, 0L, 0L, 0L));
+
+        mService.getDetailedUidStats(INTERFACES_ALL);
+
+        tunAdjStats = mService.getTunAdjustedStats();
+        // verify TEST_IFACE stats:
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 2000L, 200L, 0L, 0L, 0);
+        assertValues(
+                tunAdjStats, TEST_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 200L, 0L, 0L, 0L, 0);
+        // verify TUN_IFACE stats:
+        assertValues(
+                tunAdjStats, TUN_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 2000L, 200L, 0L, 0L, 0);
+        assertValues(
+                tunAdjStats, TUN_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 0L, 0L, 0L, 0L, 0);
+    }
+
+    @Test
+    public void getDetailedUidStats_returnsCorrectStatsWithVpnRunning() throws Exception {
+        // VPN using WiFi (TEST_IFACE).
+        VpnInfo[] vpnInfos = new VpnInfo[] {createVpnInfo(new String[] {TEST_IFACE})};
+        expectBandwidthControlCheck();
+        mService.forceUpdateIfaces(
+                new Network[0], vpnInfos, new NetworkState[0], null /* activeIface */);
+        // create some traffic (assume 10 bytes of MTU for VPN interface and 1 byte encryption
+        // overhead per packet):
+        // 1000 bytes (100 packets) were downloaded by UID_RED over VPN.
+        // VPN received 1100 bytes (100 packets) over WiFi.
+        incrementCurrentTime(HOUR_IN_MILLIS);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 2)
+              .addValues(TUN_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1000L, 100L, 0L, 0L, 0L)
+              .addValues(TEST_IFACE, UID_VPN, SET_DEFAULT, TAG_NONE, 1100L, 100L, 0L, 0L, 0L));
+
+        // Query realtime stats for TEST_IFACE.
+        NetworkStats queriedStats =
+                mService.getDetailedUidStats(new String[] {TEST_IFACE});
+
+        assertEquals(HOUR_IN_MILLIS, queriedStats.getElapsedRealtime());
+        // verify that returned stats are only for TEST_IFACE and VPN traffic is migrated correctly.
+        assertEquals(new String[] {TEST_IFACE}, queriedStats.getUniqueIfaces());
+        assertValues(
+                queriedStats, TEST_IFACE, UID_RED, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 1000L, 100L, 0L, 0L, 0);
+        assertValues(
+                queriedStats, TEST_IFACE, UID_VPN, SET_ALL, TAG_NONE, METERED_ALL, ROAMING_ALL,
+                DEFAULT_NETWORK_ALL, 100L, 0L, 0L, 0L, 0);
+    }
+
+    @Test
     public void testRegisterUsageCallback() throws Exception {
         // pretend that wifi network comes online; service should ask about full
         // network state, and poll any existing interfaces before updating.