Update NetlinkSocketTest to handle RTM_GETNEIGH{TBL} restrictions
Modify testBasicWorkingGetNeighborsQuery to handle the case where
sending RTM_GETNEIGH{TBL} messages is not allowed.
Add test BasicWorkingGetAddrQuery, which performs a simple RTM_GETADDR
request, as a smoke test for when sending RTM_GETNEIGH{TBL} is not
allowed.
Test: atest NetlinkSocketTest
Bug: 171572148
Change-Id: I09628cf7830c7f348eaf27a33cca0903ba84722c
diff --git a/tests/unit/src/android/net/netlink/NetlinkSocketTest.java b/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
index 5716803..6a84a85 100644
--- a/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
+++ b/tests/unit/src/android/net/netlink/NetlinkSocketTest.java
@@ -17,21 +17,36 @@
package android.net.netlink;
import static android.net.netlink.NetlinkSocket.DEFAULT_RECV_BUFSIZE;
+import static android.net.netlink.StructNlMsgHdr.NLM_F_DUMP;
+import static android.net.netlink.StructNlMsgHdr.NLM_F_REQUEST;
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.AF_UNSPEC;
+import static android.system.OsConstants.EACCES;
import static android.system.OsConstants.NETLINK_ROUTE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import android.content.Context;
import android.net.netlink.NetlinkSocket;
import android.net.netlink.RtNetlinkNeighborMessage;
import android.net.netlink.StructNlMsgHdr;
+import android.system.ErrnoException;
import android.system.NetlinkSocketAddress;
import android.system.Os;
import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
+import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.Struct;
+import com.android.net.module.util.Struct.Field;
+import com.android.net.module.util.Struct.Type;
+
import libcore.io.IoUtils;
import org.junit.Test;
@@ -47,7 +62,7 @@
private final String TAG = "NetlinkSocketTest";
@Test
- public void testBasicWorkingGetNeighborsQuery() throws Exception {
+ public void testGetNeighborsQuery() throws Exception {
final FileDescriptor fd = NetlinkSocket.forProto(NETLINK_ROUTE);
assertNotNull(fd);
@@ -63,6 +78,25 @@
assertNotNull(req);
final long TIMEOUT = 500;
+ final Context ctx = InstrumentationRegistry.getInstrumentation().getContext();
+ final int targetSdk =
+ ctx.getPackageManager()
+ .getApplicationInfo(ctx.getPackageName(), 0)
+ .targetSdkVersion;
+
+ // Apps targeting an SDK version > S are not allowed to send RTM_GETNEIGH{TBL} messages
+ if (SdkLevel.isAtLeastT() && targetSdk > 31) {
+ try {
+ NetlinkSocket.sendMessage(fd, req, 0, req.length, TIMEOUT);
+ fail("RTM_GETNEIGH is not allowed for apps targeting SDK > 31 on T+ platforms");
+ } catch (ErrnoException e) {
+ // Expected
+ assertEquals(e.errno, EACCES);
+ return;
+ }
+ }
+
+ // Check that apps targeting lower API levels / running on older platforms succeed
assertEquals(req.length, NetlinkSocket.sendMessage(fd, req, 0, req.length, TIMEOUT));
int neighMessageCount = 0;
@@ -103,4 +137,105 @@
IoUtils.closeQuietly(fd);
}
+
+ @Test
+ public void testBasicWorkingGetAddrQuery() throws Exception {
+ final FileDescriptor fd = NetlinkSocket.forProto(NETLINK_ROUTE);
+ assertNotNull(fd);
+
+ NetlinkSocket.connectToKernel(fd);
+
+ final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd);
+ assertNotNull(localAddr);
+ assertEquals(0, localAddr.getGroupsMask());
+ assertTrue(0 != localAddr.getPortId());
+
+ final int testSeqno = 8;
+ final byte[] req = newGetAddrRequest(testSeqno);
+ assertNotNull(req);
+
+ final long timeout = 500;
+ assertEquals(req.length, NetlinkSocket.sendMessage(fd, req, 0, req.length, timeout));
+
+ int addrMessageCount = 0;
+
+ while (true) {
+ ByteBuffer response = NetlinkSocket.recvMessage(fd, DEFAULT_RECV_BUFSIZE, timeout);
+ assertNotNull(response);
+ assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit());
+ assertEquals(0, response.position());
+ assertEquals(ByteOrder.nativeOrder(), response.order());
+
+ final StructNlMsgHdr nlmsghdr = StructNlMsgHdr.parse(response);
+ assertNotNull(nlmsghdr);
+
+ if (nlmsghdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) {
+ break;
+ }
+
+ assertEquals(NetlinkConstants.RTM_NEWADDR, nlmsghdr.nlmsg_type);
+ assertTrue((nlmsghdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0);
+ assertEquals(testSeqno, nlmsghdr.nlmsg_seq);
+ assertEquals(localAddr.getPortId(), nlmsghdr.nlmsg_pid);
+ addrMessageCount++;
+
+ final IfaddrMsg ifaMsg = Struct.parse(IfaddrMsg.class, response);
+ assertTrue(
+ "Non-IP address family: " + ifaMsg.family,
+ ifaMsg.family == AF_INET || ifaMsg.family == AF_INET6);
+ }
+
+ assertTrue(addrMessageCount > 0);
+
+ IoUtils.closeQuietly(fd);
+ }
+
+ /** A convenience method to create an RTM_GETADDR request message. */
+ private static byte[] newGetAddrRequest(int seqNo) {
+ final int length = StructNlMsgHdr.STRUCT_SIZE + Struct.getSize(RtgenMsg.class);
+ final byte[] bytes = new byte[length];
+ final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+ byteBuffer.order(ByteOrder.nativeOrder());
+
+ final StructNlMsgHdr nlmsghdr = new StructNlMsgHdr();
+ nlmsghdr.nlmsg_len = length;
+ nlmsghdr.nlmsg_type = NetlinkConstants.RTM_GETADDR;
+ nlmsghdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ nlmsghdr.nlmsg_seq = seqNo;
+ nlmsghdr.pack(byteBuffer);
+
+ final RtgenMsg rtgenMsg = new RtgenMsg();
+ rtgenMsg.family = (byte) AF_UNSPEC;
+ rtgenMsg.writeToByteBuffer(byteBuffer);
+
+ return bytes;
+ }
+
+ /** From uapi/linux/rtnetlink.h */
+ private static class RtgenMsg extends Struct {
+ @Field(order = 0, type = Type.U8)
+ public short family;
+ }
+
+ /**
+ * From uapi/linux/ifaddr.h
+ *
+ * Public ensures visibility to Struct class
+ */
+ public static class IfaddrMsg extends Struct {
+ @Field(order = 0, type = Type.U8)
+ public short family;
+
+ @Field(order = 1, type = Type.U8)
+ public short prefixlen;
+
+ @Field(order = 2, type = Type.U8)
+ public short flags;
+
+ @Field(order = 3, type = Type.U8)
+ public short scope;
+
+ @Field(order = 4, type = Type.U32)
+ public long index;
+ }
}