ebpf - adjust packet/byte statistics for GSO frames

Test: builds, atest
Bug: 146972467
Signed-off-by: Maciej Żenczykowski <maze@google.com>
Change-Id: Ie6cc79181f0af1540fd586ba3cab5acbb0cf00c4
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index bde12f9..519ddb9 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -23,6 +23,7 @@
 #include <linux/in6.h>
 #include <linux/ip.h>
 #include <linux/ipv6.h>
+#include <linux/tcp.h>
 #include <stdbool.h>
 #include <stdint.h>
 #include "bpf_net_helpers.h"
@@ -62,23 +63,61 @@
     return (uid <= MAX_SYSTEM_UID) && (uid >= MIN_SYSTEM_UID);
 }
 
+/*
+ * Note: this blindly assumes an MTU of 1500, and that packets > MTU are always TCP,
+ * and that TCP is using the Linux default settings with TCP timestamp option enabled
+ * which uses 12 TCP option bytes per frame.
+ *
+ * These are not unreasonable assumptions:
+ *
+ * The internet does not really support MTUs greater than 1500, so most TCP traffic will
+ * be at that MTU, or slightly below it (worst case our upwards adjustment is too small).
+ *
+ * The chance our traffic isn't IP at all is basically zero, so the IP overhead correction
+ * is bound to be needed.
+ *
+ * Furthermore, the likelihood that we're having to deal with GSO (i.e. > MTU) packets that
+ * are not IP/TCP is pretty small (few other things are supported by Linux) and worst case
+ * our extra overhead will be slightly off, but probably still better than assuming none.
+ *
+ * Most servers are also Linux and thus support/default to using TCP timestamp option
+ * (and indeed TCP timestamp option comes from RFC 1323 titled "TCP Extensions for High
+ * Performance" which also defined TCP window scaling and is thus absolutely ancient...).
+ *
+ * Altogether this should be more correct than if we simply ignored GSO frames
+ * (i.e. counted them as single packets with no extra overhead).
+ *
+ * Especially since the number of packets is important for any future clat offload correction
+ * (which adjusts upward by 20 bytes per packet to account for IPv4 -> IPv6 header conversion).
+ */
 #define DEFINE_UPDATE_STATS(the_stats_map, TypeOfKey)                                          \
     static __always_inline inline void update_##the_stats_map(struct __sk_buff* skb,           \
                                                               int direction, TypeOfKey* key) { \
-        StatsValue* value;                                                                     \
-        value = bpf_##the_stats_map##_lookup_elem(key);                                        \
+        StatsValue* value = bpf_##the_stats_map##_lookup_elem(key);                            \
         if (!value) {                                                                          \
             StatsValue newValue = {};                                                          \
             bpf_##the_stats_map##_update_elem(key, &newValue, BPF_NOEXIST);                    \
             value = bpf_##the_stats_map##_lookup_elem(key);                                    \
         }                                                                                      \
         if (value) {                                                                           \
+            const int mtu = 1500; /* assumed MTU, per the note above the macro */              \
+            uint64_t packets = 1; /* a frame that fits the MTU is one packet */                \
+            uint64_t bytes = skb->len;                                                         \
+            if (bytes > mtu) { /* over MTU: assume a TCP GSO frame */                          \
+                bool is_ipv6 = (skb->protocol == htons(ETH_P_IPV6));                           \
+                int ip_overhead = (is_ipv6 ? sizeof(struct ipv6hdr) : sizeof(struct iphdr));   \
+                int tcp_overhead = ip_overhead + sizeof(struct tcphdr) + 12; /* +TCP tstamp */ \
+                int mss = mtu - tcp_overhead; /* payload bytes per full segment */             \
+                uint64_t payload = bytes - tcp_overhead; /* drop one set of headers */         \
+                packets = (payload + mss - 1) / mss; /* ceil(payload / mss) */                 \
+                bytes = tcp_overhead * packets + payload; /* headers per packet + payload */   \
+            }                                                                                  \
             if (direction == BPF_EGRESS) {                                                     \
-                __sync_fetch_and_add(&value->txPackets, 1);                                    \
-                __sync_fetch_and_add(&value->txBytes, skb->len);                               \
+                __sync_fetch_and_add(&value->txPackets, packets);                              \
+                __sync_fetch_and_add(&value->txBytes, bytes);                                  \
             } else if (direction == BPF_INGRESS) {                                             \
-                __sync_fetch_and_add(&value->rxPackets, 1);                                    \
-                __sync_fetch_and_add(&value->rxBytes, skb->len);                               \
+                __sync_fetch_and_add(&value->rxPackets, packets);                              \
+                __sync_fetch_and_add(&value->rxBytes, bytes);                                  \
             }                                                                                  \
         }                                                                                      \
     }