apf: Drop ARP reply if SPA is 0.0.0.0

Some network re-writing packet from broadcast MACs to unicast,
result in this kind of packets cannot be dropped by APF filter.
Thus, drop ARP reply if source IP is 0.0.0.0.

Note: Linux kernel always ignores such replies in the function arp_process().

Bug: 118044271
Test: runtest frameworks-net -c android.net.apf.ApfTest
Change-Id: Id293bf231913d9b483ce7d8dd909e05fa927ccd7
diff --git a/services/net/java/android/net/apf/ApfFilter.java b/services/net/java/android/net/apf/ApfFilter.java
index b9cc372..f037905 100644
--- a/services/net/java/android/net/apf/ApfFilter.java
+++ b/services/net/java/android/net/apf/ApfFilter.java
@@ -139,7 +139,8 @@
         DROPPED_IPV6_MULTICAST_PING,
         DROPPED_IPV6_NON_ICMP_MULTICAST,
         DROPPED_802_3_FRAME,
-        DROPPED_ETHERTYPE_BLACKLISTED;
+        DROPPED_ETHERTYPE_BLACKLISTED,
+        DROPPED_ARP_REPLY_SPA_NO_HOST;
 
         // Returns the negative byte offset from the end of the APF data segment for
         // a given counter.
@@ -156,7 +157,7 @@
     /**
      * When APFv4 is supported, loads R1 with the offset of the specified counter.
      */
-    private void maybeSetCounter(ApfGenerator gen, Counter c) {
+    private void maybeSetupCounter(ApfGenerator gen, Counter c) {
         if (mApfCapabilities.hasDataAccess()) {
             gen.addLoadImmediate(Register.R1, c.offset());
         }
@@ -288,16 +289,18 @@
     private static final int DHCP_CLIENT_MAC_OFFSET = ETH_HEADER_LEN + UDP_HEADER_LEN + 28;
 
     private static final int ARP_HEADER_OFFSET = ETH_HEADER_LEN;
-    private static final int ARP_OPCODE_OFFSET = ARP_HEADER_OFFSET + 6;
-    private static final short ARP_OPCODE_REQUEST = 1;
-    private static final short ARP_OPCODE_REPLY = 2;
     private static final byte[] ARP_IPV4_HEADER = {
             0, 1, // Hardware type: Ethernet (1)
             8, 0, // Protocol type: IP (0x0800)
             6,    // Hardware size: 6
             4,    // Protocol size: 4
     };
-    private static final int ARP_TARGET_IP_ADDRESS_OFFSET = ETH_HEADER_LEN + 24;
+    private static final int ARP_OPCODE_OFFSET = ARP_HEADER_OFFSET + 6;
+    // Opcode: ARP request (0x0001), ARP reply (0x0002)
+    private static final short ARP_OPCODE_REQUEST = 1;
+    private static final short ARP_OPCODE_REPLY = 2;
+    private static final int ARP_SOURCE_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 14;
+    private static final int ARP_TARGET_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 24;
     // Do not log ApfProgramEvents whose actual lifetimes was less than this.
     private static final int APF_PROGRAM_EVENT_LIFETIME_THRESHOLD = 2;
     // Limit on the Black List size to cap on program usage for this
@@ -816,7 +819,7 @@
                     gen.addJumpIfR0LessThan(filterLifetime, nextFilterLabel);
                 }
             }
-            maybeSetCounter(gen, Counter.DROPPED_RA);
+            maybeSetupCounter(gen, Counter.DROPPED_RA);
             gen.addJump(mCountAndDropLabel);
             gen.defineLabel(nextFilterLabel);
             return filterLifetime;
@@ -883,6 +886,8 @@
         //   pass
         // if not ARP IPv4 reply or request
         //   pass
+        // if ARP reply source ip is 0.0.0.0
+        //   drop
         // if unicast ARP reply
         //   pass
         // if interface has no IPv4 address
@@ -897,18 +902,23 @@
 
         // Pass if not ARP IPv4.
         gen.addLoadImmediate(Register.R0, ARP_HEADER_OFFSET);
-        maybeSetCounter(gen, Counter.PASSED_ARP_NON_IPV4);
+        maybeSetupCounter(gen, Counter.PASSED_ARP_NON_IPV4);
         gen.addJumpIfBytesNotEqual(Register.R0, ARP_IPV4_HEADER, mCountAndPassLabel);
 
         // Pass if unknown ARP opcode.
         gen.addLoad16(Register.R0, ARP_OPCODE_OFFSET);
         gen.addJumpIfR0Equals(ARP_OPCODE_REQUEST, checkTargetIPv4); // Skip to unicast check
-        maybeSetCounter(gen, Counter.PASSED_ARP_UNKNOWN);
+        maybeSetupCounter(gen, Counter.PASSED_ARP_UNKNOWN);
         gen.addJumpIfR0NotEquals(ARP_OPCODE_REPLY, mCountAndPassLabel);
 
+        // Drop if ARP reply source IP is 0.0.0.0
+        gen.addLoad32(Register.R0, ARP_SOURCE_IP_ADDRESS_OFFSET);
+        maybeSetupCounter(gen, Counter.DROPPED_ARP_REPLY_SPA_NO_HOST);
+        gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel);
+
         // Pass if unicast reply.
         gen.addLoadImmediate(Register.R0, ETH_DEST_ADDR_OFFSET);
-        maybeSetCounter(gen, Counter.PASSED_ARP_UNICAST_REPLY);
+        maybeSetupCounter(gen, Counter.PASSED_ARP_UNICAST_REPLY);
         gen.addJumpIfBytesNotEqual(Register.R0, ETH_BROADCAST_MAC_ADDRESS, mCountAndPassLabel);
 
         // Either a unicast request, a unicast reply, or a broadcast reply.
@@ -916,17 +926,17 @@
         if (mIPv4Address == null) {
             // When there is no IPv4 address, drop GARP replies (b/29404209).
             gen.addLoad32(Register.R0, ARP_TARGET_IP_ADDRESS_OFFSET);
-            maybeSetCounter(gen, Counter.DROPPED_GARP_REPLY);
+            maybeSetupCounter(gen, Counter.DROPPED_GARP_REPLY);
             gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel);
         } else {
             // When there is an IPv4 address, drop unicast/broadcast requests
             // and broadcast replies with a different target IPv4 address.
             gen.addLoadImmediate(Register.R0, ARP_TARGET_IP_ADDRESS_OFFSET);
-            maybeSetCounter(gen, Counter.DROPPED_ARP_OTHER_HOST);
+            maybeSetupCounter(gen, Counter.DROPPED_ARP_OTHER_HOST);
             gen.addJumpIfBytesNotEqual(Register.R0, mIPv4Address, mCountAndDropLabel);
         }
 
-        maybeSetCounter(gen, Counter.PASSED_ARP);
+        maybeSetupCounter(gen, Counter.PASSED_ARP);
         gen.addJump(mCountAndPassLabel);
     }
 
@@ -970,7 +980,7 @@
             // NOTE: Relies on R1 containing IPv4 header offset.
             gen.addAddR1();
             gen.addJumpIfBytesNotEqual(Register.R0, mHardwareAddress, skipDhcpv4Filter);
-            maybeSetCounter(gen, Counter.PASSED_DHCP);
+            maybeSetupCounter(gen, Counter.PASSED_DHCP);
             gen.addJump(mCountAndPassLabel);
 
             // Drop all multicasts/broadcasts.
@@ -979,30 +989,30 @@
             // If IPv4 destination address is in multicast range, drop.
             gen.addLoad8(Register.R0, IPV4_DEST_ADDR_OFFSET);
             gen.addAnd(0xf0);
-            maybeSetCounter(gen, Counter.DROPPED_IPV4_MULTICAST);
+            maybeSetupCounter(gen, Counter.DROPPED_IPV4_MULTICAST);
             gen.addJumpIfR0Equals(0xe0, mCountAndDropLabel);
 
             // If IPv4 broadcast packet, drop regardless of L2 (b/30231088).
-            maybeSetCounter(gen, Counter.DROPPED_IPV4_BROADCAST_ADDR);
+            maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_ADDR);
             gen.addLoad32(Register.R0, IPV4_DEST_ADDR_OFFSET);
             gen.addJumpIfR0Equals(IPV4_BROADCAST_ADDRESS, mCountAndDropLabel);
             if (mIPv4Address != null && mIPv4PrefixLength < 31) {
-                maybeSetCounter(gen, Counter.DROPPED_IPV4_BROADCAST_NET);
+                maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_NET);
                 int broadcastAddr = ipv4BroadcastAddress(mIPv4Address, mIPv4PrefixLength);
                 gen.addJumpIfR0Equals(broadcastAddr, mCountAndDropLabel);
             }
 
             // If L2 broadcast packet, drop.
             // TODO: can we invert this condition to fall through to the common pass case below?
-            maybeSetCounter(gen, Counter.PASSED_IPV4_UNICAST);
+            maybeSetupCounter(gen, Counter.PASSED_IPV4_UNICAST);
             gen.addLoadImmediate(Register.R0, ETH_DEST_ADDR_OFFSET);
             gen.addJumpIfBytesNotEqual(Register.R0, ETH_BROADCAST_MAC_ADDRESS, mCountAndPassLabel);
-            maybeSetCounter(gen, Counter.DROPPED_IPV4_L2_BROADCAST);
+            maybeSetupCounter(gen, Counter.DROPPED_IPV4_L2_BROADCAST);
             gen.addJump(mCountAndDropLabel);
         }
 
         // Otherwise, pass
-        maybeSetCounter(gen, Counter.PASSED_IPV4);
+        maybeSetupCounter(gen, Counter.PASSED_IPV4);
         gen.addJump(mCountAndPassLabel);
     }
 
@@ -1050,16 +1060,16 @@
 
             // Drop all other packets sent to ff00::/8 (multicast prefix).
             gen.defineLabel(dropAllIPv6MulticastsLabel);
-            maybeSetCounter(gen, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST);
+            maybeSetupCounter(gen, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST);
             gen.addLoad8(Register.R0, IPV6_DEST_ADDR_OFFSET);
             gen.addJumpIfR0Equals(0xff, mCountAndDropLabel);
             // Not multicast. Pass.
-            maybeSetCounter(gen, Counter.PASSED_IPV6_UNICAST_NON_ICMP);
+            maybeSetupCounter(gen, Counter.PASSED_IPV6_UNICAST_NON_ICMP);
             gen.addJump(mCountAndPassLabel);
             gen.defineLabel(skipIPv6MulticastFilterLabel);
         } else {
             // If not ICMPv6, pass.
-            maybeSetCounter(gen, Counter.PASSED_IPV6_NON_ICMP);
+            maybeSetupCounter(gen, Counter.PASSED_IPV6_NON_ICMP);
             gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6, mCountAndPassLabel);
         }
 
@@ -1069,7 +1079,7 @@
         String skipUnsolicitedMulticastNALabel = "skipUnsolicitedMulticastNA";
         gen.addLoad8(Register.R0, ICMP6_TYPE_OFFSET);
         // Drop all router solicitations (b/32833400)
-        maybeSetCounter(gen, Counter.DROPPED_IPV6_ROUTER_SOLICITATION);
+        maybeSetupCounter(gen, Counter.DROPPED_IPV6_ROUTER_SOLICITATION);
         gen.addJumpIfR0Equals(ICMPV6_ROUTER_SOLICITATION, mCountAndDropLabel);
         // If not neighbor announcements, skip filter.
         gen.addJumpIfR0NotEquals(ICMPV6_NEIGHBOR_ADVERTISEMENT, skipUnsolicitedMulticastNALabel);
@@ -1078,7 +1088,7 @@
         gen.addLoadImmediate(Register.R0, IPV6_DEST_ADDR_OFFSET);
         gen.addJumpIfBytesNotEqual(Register.R0, IPV6_ALL_NODES_ADDRESS,
                 skipUnsolicitedMulticastNALabel);
-        maybeSetCounter(gen, Counter.DROPPED_IPV6_MULTICAST_NA);
+        maybeSetupCounter(gen, Counter.DROPPED_IPV6_MULTICAST_NA);
         gen.addJump(mCountAndDropLabel);
         gen.defineLabel(skipUnsolicitedMulticastNALabel);
     }
@@ -1108,7 +1118,7 @@
 
         if (mApfCapabilities.hasDataAccess()) {
             // Increment TOTAL_PACKETS
-            maybeSetCounter(gen, Counter.TOTAL_PACKETS);
+            maybeSetupCounter(gen, Counter.TOTAL_PACKETS);
             gen.addLoadData(Register.R0, 0);  // load counter
             gen.addAdd(1);
             gen.addStoreData(Register.R0, 0);  // write-back counter
@@ -1134,12 +1144,12 @@
 
         if (mDrop802_3Frames) {
             // drop 802.3 frames (ethtype < 0x0600)
-            maybeSetCounter(gen, Counter.DROPPED_802_3_FRAME);
+            maybeSetupCounter(gen, Counter.DROPPED_802_3_FRAME);
             gen.addJumpIfR0LessThan(ETH_TYPE_MIN, mCountAndDropLabel);
         }
 
         // Handle ether-type black list
-        maybeSetCounter(gen, Counter.DROPPED_ETHERTYPE_BLACKLISTED);
+        maybeSetupCounter(gen, Counter.DROPPED_ETHERTYPE_BLACKLISTED);
         for (int p : mEthTypeBlackList) {
             gen.addJumpIfR0Equals(p, mCountAndDropLabel);
         }
@@ -1168,9 +1178,9 @@
 
         // Drop non-IP non-ARP broadcasts, pass the rest
         gen.addLoadImmediate(Register.R0, ETH_DEST_ADDR_OFFSET);
-        maybeSetCounter(gen, Counter.PASSED_NON_IP_UNICAST);
+        maybeSetupCounter(gen, Counter.PASSED_NON_IP_UNICAST);
         gen.addJumpIfBytesNotEqual(Register.R0, ETH_BROADCAST_MAC_ADDRESS, mCountAndPassLabel);
-        maybeSetCounter(gen, Counter.DROPPED_ETH_BROADCAST);
+        maybeSetupCounter(gen, Counter.DROPPED_ETH_BROADCAST);
         gen.addJump(mCountAndDropLabel);
 
         // Add IPv6 filters:
@@ -1193,7 +1203,7 @@
 
         // Execution will reach the bottom of the program if none of the filters match,
         // which will pass the packet to the application processor.
-        maybeSetCounter(gen, Counter.PASSED_IPV6_ICMP);
+        maybeSetupCounter(gen, Counter.PASSED_IPV6_ICMP);
 
         // Append the count & pass trampoline, which increments the counter at the data address
         // pointed to by R1, then jumps to the pass label. This saves a few bytes over inserting
diff --git a/tests/net/java/android/net/apf/ApfTest.java b/tests/net/java/android/net/apf/ApfTest.java
index 436dd85..151b559 100644
--- a/tests/net/java/android/net/apf/ApfTest.java
+++ b/tests/net/java/android/net/apf/ApfTest.java
@@ -1048,12 +1048,17 @@
             4,    // Protocol size: 4
             0, 2  // Opcode: reply (2)
     };
-    private static final int ARP_TARGET_IP_ADDRESS_OFFSET = ETH_HEADER_LEN + 24;
+    private static final int ARP_SOURCE_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 14;
+    private static final int ARP_TARGET_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 24;
 
     private static final byte[] MOCK_IPV4_ADDR           = {10, 0, 0, 1};
     private static final byte[] MOCK_BROADCAST_IPV4_ADDR = {10, 0, 31, (byte) 255}; // prefix = 19
     private static final byte[] MOCK_MULTICAST_IPV4_ADDR = {(byte) 224, 0, 0, 1};
     private static final byte[] ANOTHER_IPV4_ADDR        = {10, 0, 0, 2};
+    private static final byte[] IPV4_SOURCE_ADDR         = {10, 0, 0, 3};
+    private static final byte[] ANOTHER_IPV4_SOURCE_ADDR = {(byte) 192, 0, 2, 1};
+    private static final byte[] BUG_PROBE_SOURCE_ADDR1   = {0, 0, 1, 2};
+    private static final byte[] BUG_PROBE_SOURCE_ADDR2   = {3, 4, 0, 0};
     private static final byte[] IPV4_ANY_HOST_ADDR       = {0, 0, 0, 0};
 
     // Helper to initialize a default apfFilter.
@@ -1399,10 +1404,16 @@
         assertVerdict(filterResult, program, arpRequestBroadcast(ANOTHER_IPV4_ADDR));
         assertDrop(program, arpRequestBroadcast(IPV4_ANY_HOST_ADDR));
 
+        // Verify ARP reply packets from different source ip
+        assertDrop(program, arpReply(IPV4_ANY_HOST_ADDR, IPV4_ANY_HOST_ADDR));
+        assertPass(program, arpReply(ANOTHER_IPV4_SOURCE_ADDR, IPV4_ANY_HOST_ADDR));
+        assertPass(program, arpReply(BUG_PROBE_SOURCE_ADDR1, IPV4_ANY_HOST_ADDR));
+        assertPass(program, arpReply(BUG_PROBE_SOURCE_ADDR2, IPV4_ANY_HOST_ADDR));
+
         // Verify unicast ARP reply packet is always accepted.
-        assertPass(program, arpReplyUnicast(MOCK_IPV4_ADDR));
-        assertPass(program, arpReplyUnicast(ANOTHER_IPV4_ADDR));
-        assertPass(program, arpReplyUnicast(IPV4_ANY_HOST_ADDR));
+        assertPass(program, arpReply(IPV4_SOURCE_ADDR, MOCK_IPV4_ADDR));
+        assertPass(program, arpReply(IPV4_SOURCE_ADDR, ANOTHER_IPV4_ADDR));
+        assertPass(program, arpReply(IPV4_SOURCE_ADDR, IPV4_ANY_HOST_ADDR));
 
         // Verify GARP reply packets are always filtered
         assertDrop(program, garpReply());
@@ -1431,19 +1442,20 @@
         apfFilter.shutdown();
     }
 
-    private static byte[] arpRequestBroadcast(byte[] tip) {
+    private static byte[] arpReply(byte[] sip, byte[] tip) {
         ByteBuffer packet = ByteBuffer.wrap(new byte[100]);
         packet.putShort(ETH_ETHERTYPE_OFFSET, (short)ETH_P_ARP);
-        put(packet, ETH_DEST_ADDR_OFFSET, ETH_BROADCAST_MAC_ADDRESS);
         put(packet, ARP_HEADER_OFFSET, ARP_IPV4_REPLY_HEADER);
+        put(packet, ARP_SOURCE_IP_ADDRESS_OFFSET, sip);
         put(packet, ARP_TARGET_IP_ADDRESS_OFFSET, tip);
         return packet.array();
     }
 
-    private static byte[] arpReplyUnicast(byte[] tip) {
+    private static byte[] arpRequestBroadcast(byte[] tip) {
         ByteBuffer packet = ByteBuffer.wrap(new byte[100]);
         packet.putShort(ETH_ETHERTYPE_OFFSET, (short)ETH_P_ARP);
-        put(packet, ARP_HEADER_OFFSET, ARP_IPV4_REPLY_HEADER);
+        put(packet, ETH_DEST_ADDR_OFFSET, ETH_BROADCAST_MAC_ADDRESS);
+        put(packet, ARP_HEADER_OFFSET, ARP_IPV4_REQUEST_HEADER);
         put(packet, ARP_TARGET_IP_ADDRESS_OFFSET, tip);
         return packet.array();
     }