Update NetlinkSocketTest to handle RTM_GETNEIGH{TBL} restrictions

Modify testBasicWorkingGetNeighborsQuery to handle the case where
sending RTM_GETNEIGH{TBL} messages is not allowed.

Add test BasicWorkingGetAddrQuery, which performs a simple RTM_GETADDR
request, as a smoke test for when sending RTM_GETNEIGH{TBL} is not
allowed.

Test: atest NetlinkSocketTest
Bug: 171572148
Change-Id: I09628cf7830c7f348eaf27a33cca0903ba84722c
diff --git a/tests/unit/src/android/net/netlink/NetlinkSocketTest.java b/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
index 5716803..6a84a85 100644
--- a/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
+++ b/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
@@ -17,21 +17,36 @@
 package android.net.netlink;
 
 import static android.net.netlink.NetlinkSocket.DEFAULT_RECV_BUFSIZE;
+import static android.net.netlink.StructNlMsgHdr.NLM_F_DUMP;
+import static android.net.netlink.StructNlMsgHdr.NLM_F_REQUEST;
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.AF_UNSPEC;
+import static android.system.OsConstants.EACCES;
 import static android.system.OsConstants.NETLINK_ROUTE;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
+import android.content.Context;
 import android.net.netlink.NetlinkSocket;
 import android.net.netlink.RtNetlinkNeighborMessage;
 import android.net.netlink.StructNlMsgHdr;
+import android.system.ErrnoException;
 import android.system.NetlinkSocketAddress;
 import android.system.Os;
 
 import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.Struct;
+import com.android.net.module.util.Struct.Field;
+import com.android.net.module.util.Struct.Type;
+
 import libcore.io.IoUtils;
 
 import org.junit.Test;
@@ -47,7 +62,7 @@
     private final String TAG = "NetlinkSocketTest";
 
     @Test
-    public void testBasicWorkingGetNeighborsQuery() throws Exception {
+    public void testGetNeighborsQuery() throws Exception {
         final FileDescriptor fd = NetlinkSocket.forProto(NETLINK_ROUTE);
         assertNotNull(fd);
 
@@ -63,6 +78,25 @@
         assertNotNull(req);
 
         final long TIMEOUT = 500;
+        final Context ctx = InstrumentationRegistry.getInstrumentation().getContext();
+        final int targetSdk =
+                ctx.getPackageManager()
+                        .getApplicationInfo(ctx.getPackageName(), 0)
+                        .targetSdkVersion;
+
+        // Apps targeting an SDK version > S are not allowed to send RTM_GETNEIGH{TBL} messages
+        if (SdkLevel.isAtLeastT() && targetSdk > 31) {
+            try {
+                NetlinkSocket.sendMessage(fd, req, 0, req.length, TIMEOUT);
+                fail("RTM_GETNEIGH is not allowed for apps targeting SDK > 31 on T+ platforms");
+            } catch (ErrnoException e) {
+                // Expected
+                assertEquals(e.errno, EACCES);
+                return;
+            }
+        }
+
+        // Check that apps targeting lower API levels / running on older platforms succeed
         assertEquals(req.length, NetlinkSocket.sendMessage(fd, req, 0, req.length, TIMEOUT));
 
         int neighMessageCount = 0;
@@ -103,4 +137,105 @@
 
         IoUtils.closeQuietly(fd);
     }
+
+    @Test
+    public void testBasicWorkingGetAddrQuery() throws Exception {
+        final FileDescriptor fd = NetlinkSocket.forProto(NETLINK_ROUTE);
+        assertNotNull(fd);
+
+        NetlinkSocket.connectToKernel(fd);
+
+        final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd);
+        assertNotNull(localAddr);
+        assertEquals(0, localAddr.getGroupsMask());
+        assertTrue(0 != localAddr.getPortId());
+
+        final int testSeqno = 8;
+        final byte[] req = newGetAddrRequest(testSeqno);
+        assertNotNull(req);
+
+        final long timeout = 500;
+        assertEquals(req.length, NetlinkSocket.sendMessage(fd, req, 0, req.length, timeout));
+
+        int addrMessageCount = 0;
+
+        while (true) {
+            ByteBuffer response = NetlinkSocket.recvMessage(fd, DEFAULT_RECV_BUFSIZE, timeout);
+            assertNotNull(response);
+            assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit());
+            assertEquals(0, response.position());
+            assertEquals(ByteOrder.nativeOrder(), response.order());
+
+            final StructNlMsgHdr nlmsghdr = StructNlMsgHdr.parse(response);
+            assertNotNull(nlmsghdr);
+
+            if (nlmsghdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) {
+                break;
+            }
+
+            assertEquals(NetlinkConstants.RTM_NEWADDR, nlmsghdr.nlmsg_type);
+            assertTrue((nlmsghdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0);
+            assertEquals(testSeqno, nlmsghdr.nlmsg_seq);
+            assertEquals(localAddr.getPortId(), nlmsghdr.nlmsg_pid);
+            addrMessageCount++;
+
+            final IfaddrMsg ifaMsg = Struct.parse(IfaddrMsg.class, response);
+            assertTrue(
+                    "Non-IP address family: " + ifaMsg.family,
+                    ifaMsg.family == AF_INET || ifaMsg.family == AF_INET6);
+        }
+
+        assertTrue(addrMessageCount > 0);
+
+        IoUtils.closeQuietly(fd);
+    }
+
+    /** A convenience method to create an RTM_GETADDR request message. */
+    private static byte[] newGetAddrRequest(int seqNo) {
+        final int length = StructNlMsgHdr.STRUCT_SIZE + Struct.getSize(RtgenMsg.class);
+        final byte[] bytes = new byte[length];
+        final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+        byteBuffer.order(ByteOrder.nativeOrder());
+
+        final StructNlMsgHdr nlmsghdr = new StructNlMsgHdr();
+        nlmsghdr.nlmsg_len = length;
+        nlmsghdr.nlmsg_type = NetlinkConstants.RTM_GETADDR;
+        nlmsghdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+        nlmsghdr.nlmsg_seq = seqNo;
+        nlmsghdr.pack(byteBuffer);
+
+        final RtgenMsg rtgenMsg = new RtgenMsg();
+        rtgenMsg.family = (byte) AF_UNSPEC;
+        rtgenMsg.writeToByteBuffer(byteBuffer);
+
+        return bytes;
+    }
+
+    /** From uapi/linux/rtnetlink.h */
+    private static class RtgenMsg extends Struct {
+        @Field(order = 0, type = Type.U8)
+        public short family;
+    }
+
+    /**
+     * From uapi/linux/ifaddr.h
+     *
+     * Public ensures visibility to Struct class
+     */
+    public static class IfaddrMsg extends Struct {
+        @Field(order = 0, type = Type.U8)
+        public short family;
+
+        @Field(order = 1, type = Type.U8)
+        public short prefixlen;
+
+        @Field(order = 2, type = Type.U8)
+        public short flags;
+
+        @Field(order = 3, type = Type.U8)
+        public short scope;
+
+        @Field(order = 4, type = Type.U32)
+        public long index;
+    }
 }