Add fallback logic and enable XFRM-I support in netd

This patch adds fallback logic: during Init(), netd checks whether the
kernel supports XFRM interfaces (XFRM-I) and switches to XFRM-I if it does.
Fallbacks to VTIs are provided for backward compatibility with 4.4 kernels.
Parameters (mark/mask for VTI versus xfrm_if_id for XFRM-I) are selected
based on the kernel support for XFRM interfaces.

This is part 3/3 of a patch set that enables XFRM-I support in
XfrmController, with automatic fallback to VTI.

Bug: 77856928
Test: Binder tests updated, passing. CTS & unit tests also passing
Change-Id: Idf90adeec0d499fe4d566e4203f0eabb2b94fffa
diff --git a/server/NetdNativeService.cpp b/server/NetdNativeService.cpp
index 7e1965b..3758991 100644
--- a/server/NetdNativeService.cpp
+++ b/server/NetdNativeService.cpp
@@ -188,6 +188,9 @@
     gCtls->trafficCtrl.dump(dw, false);
     dw.blankline();
 
+    gCtls->xfrmCtrl.dump(dw);
+    dw.blankline();
+
     {
         ScopedIndent indentLog(dw);
         if (contains(args, String16(OPT_SHORT))) {
diff --git a/server/XfrmController.cpp b/server/XfrmController.cpp
index 297a4b0..cf4d367 100644
--- a/server/XfrmController.cpp
+++ b/server/XfrmController.cpp
@@ -31,6 +31,7 @@
 #include <inttypes.h>
 
 #include <arpa/inet.h>
+#include <net/if.h>
 #include <netinet/in.h>
 
 #include <sys/socket.h>
@@ -53,6 +54,7 @@
 #include <log/log.h>
 #include <log/log_properties.h>
 #include <logwrap/logwrap.h>
+#include "DumpWriter.h"
 #include "Fwmark.h"
 #include "InterfaceController.h"
 #include "NetdConstants.h"
@@ -60,6 +62,9 @@
 #include "Permission.h"
 #include "ResponseCode.h"
 #include "XfrmController.h"
+#include "android-base/stringprintf.h"
+#include "android-base/strings.h"
+#include "android-base/unique_fd.h"
 #include "netdutils/Fd.h"
 #include "netdutils/Slice.h"
 #include "netdutils/Syscalls.h"
@@ -93,6 +98,9 @@
 constexpr const char* INFO_KIND_VTI6 = "vti6";
 constexpr const char* INFO_KIND_XFRMI = "xfrm";
 constexpr int INFO_KIND_MAX_LEN = 8;
+constexpr int LOOPBACK_IFINDEX = 1;
+
+bool mIsXfrmIntfSupported = false;
 
 static inline bool isEngBuild() {
     static const std::string sBuildType = android::base::GetProperty("ro.build.type", "user");
@@ -386,8 +394,15 @@
 //
 XfrmController::XfrmController(void) {}
 
+// Test-only constructor allowing override of XFRM Interface support checks
+XfrmController::XfrmController(bool xfrmIntfSupport) {
+    mIsXfrmIntfSupported = xfrmIntfSupport;
+}
+
 netdutils::Status XfrmController::Init() {
     RETURN_IF_NOT_OK(flushInterfaces());
+    mIsXfrmIntfSupported = isXfrmIntfSupported();
+
     XfrmSocketImpl sock;
     RETURN_IF_NOT_OK(sock.open());
     RETURN_IF_NOT_OK(flushSaDb(sock));
@@ -424,6 +439,18 @@
     return s.sendMessage(XFRM_MSG_FLUSHPOLICY, NETLINK_REQUEST_FLAGS, 0, &iov);
 }
 
+bool XfrmController::isXfrmIntfSupported() {
+    const char* IPSEC_TEST_INTF_NAME = "ipsec_test";
+    const int32_t XFRM_TEST_IF_ID = 0xFFFF;
+
+    bool errored = false;
+    errored |=
+            ipSecAddXfrmInterface(IPSEC_TEST_INTF_NAME, XFRM_TEST_IF_ID, NETLINK_ROUTE_CREATE_FLAGS)
+                    .code();
+    errored |= ipSecRemoveTunnelInterface(IPSEC_TEST_INTF_NAME).code();
+    return !errored;
+}
+
 netdutils::Status XfrmController::ipSecSetEncapSocketOwner(const android::base::unique_fd& socket,
                                                            int newUid, uid_t callerUid) {
     ALOGD("XfrmController:%s, line=%d", __FUNCTION__, __LINE__);
@@ -656,9 +683,13 @@
                                                      XfrmCommonInfo* info) {
     info->transformId = transformId;
     info->spi = htonl(spi);
-    info->mark.v = markValue;
-    info->mark.m = markMask;
-    info->xfrm_if_id = xfrmInterfaceId;
+
+    if (mIsXfrmIntfSupported) {
+        info->xfrm_if_id = xfrmInterfaceId;
+    } else {
+        info->mark.v = markValue;
+        info->mark.m = markMask;
+    }
 
     return netdutils::status::ok;
 }
@@ -904,11 +935,13 @@
         return netdutils::statusFromErrno(EINVAL, "Key length invalid; exceeds MAX_KEY_LENGTH");
     }
 
-    if (record.mode != XfrmMode::TUNNEL && record.xfrm_if_id != 0) {
-        // TODO: Also throw errors if output mark or mark supplied
+    if (record.mode != XfrmMode::TUNNEL &&
+        (record.xfrm_if_id != 0 || record.netId != 0 || record.mark.v != 0 || record.mark.m != 0)) {
         return netdutils::statusFromErrno(EINVAL,
-                                          "xfrm_if_id parameter invalid for non "
-                                          "tunnel-mode transform");
+                                          "xfrm_if_id, mark and netid parameters invalid "
+                                          "for non tunnel-mode transform");
+    } else if (record.mode == XfrmMode::TUNNEL && !mIsXfrmIntfSupported && record.xfrm_if_id != 0) {
+        return netdutils::statusFromErrno(EINVAL, "xfrm_if_id set for VTI Security Association");
     }
 
     int len;
@@ -1255,6 +1288,11 @@
 }
 
 int XfrmController::fillNlAttrXfrmMark(const XfrmCommonInfo& record, nlattr_xfrm_mark* mark) {
+    // Do not set if we were not given a mark
+    if (record.mark.v == 0 && record.mark.m == 0) {
+        return 0;
+    }
+
     mark->mark.v = record.mark.v; // set to 0 if it's not used
     mark->mark.m = record.mark.m; // set to 0 if it's not used
     int len = NLA_HDRLEN + sizeof(xfrm_mark);
@@ -1325,12 +1363,15 @@
 
     uint16_t flags = isUpdate ? NETLINK_REQUEST_FLAGS : NETLINK_ROUTE_CREATE_FLAGS;
 
-    return ipSecAddVirtualTunnelInterface(deviceName, localAddress, remoteAddress, ikey, okey,
-                                          flags);
+    if (mIsXfrmIntfSupported) {
+        return ipSecAddXfrmInterface(deviceName, interfaceId, flags);
+    } else {
+        return ipSecAddVirtualTunnelInterface(deviceName, localAddress, remoteAddress, ikey, okey,
+                                              flags);
+    }
 }
 
 netdutils::Status XfrmController::ipSecAddXfrmInterface(const std::string& deviceName,
-                                                        int32_t underlyingInterface,
                                                         int32_t interfaceId, uint16_t flags) {
     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
 
@@ -1391,7 +1432,10 @@
                                                  .nla_len = RTA_LENGTH(sizeof(uint32_t)),
                                                  .nla_type = IFLA_XFRM_LINK,
                                          },
-                                 .xfrmLink = static_cast<uint32_t>(underlyingInterface),
+                                 //   Always use LOOPBACK_IFINDEX, since we use output marks for
+                                 //   route lookup instead. The use case of having a Network with
+                                 //   loopback in it is unsupported in tunnel mode.
+                                 .xfrmLink = static_cast<uint32_t>(LOOPBACK_IFINDEX),
 
                                  .xfrmIfIdNla =
                                          {
@@ -1575,5 +1619,13 @@
     return netdutils::statusFromErrno(ret, "Error in deleting IpSec interface " + deviceName);
 }
 
+void XfrmController::dump(DumpWriter& dw) {
+    ScopedIndent indentForXfrmController(dw);
+    dw.println("XfrmController");
+
+    ScopedIndent indentForXfrmISupport(dw);
+    dw.println("XFRM-I support: %d", mIsXfrmIntfSupported);
+}
+
 } // namespace net
 } // namespace android
diff --git a/server/XfrmController.h b/server/XfrmController.h
index 9208f54..bba84e2 100644
--- a/server/XfrmController.h
+++ b/server/XfrmController.h
@@ -51,6 +51,7 @@
 // Suggest we avoid the smallest and largest ints
 class XfrmMessage;
 class TransportModeSecurityAssociation;
+class DumpWriter;
 
 class XfrmSocket {
 public:
@@ -135,6 +136,9 @@
 public:
     XfrmController();
 
+    // Test-only constructor that overrides the XFRM-I support check
+    explicit XfrmController(bool xfrmIntfSupport);
+
     static netdutils::Status Init();
 
     static netdutils::Status ipSecSetEncapSocketOwner(const android::base::unique_fd& socket,
@@ -195,6 +199,8 @@
 
     static netdutils::Status ipSecRemoveTunnelInterface(const std::string& deviceName);
 
+    void dump(DumpWriter& dw);
+
     // Some XFRM netlink attributes comprise a header, a struct, and some data
     // after the struct. We wrap all of those in one struct for easier
     // marshalling. The structs below must be ABI compatible with the kernel and
@@ -328,6 +334,8 @@
                   "is needed.");
 #endif
 
+    static bool isXfrmIntfSupported();
+
     // helper functions for filling in the XfrmCommonInfo (and XfrmSaInfo) structure
     static netdutils::Status fillXfrmCommonInfo(const std::string& sourceAddress,
                                                 const std::string& destinationAddress, int32_t spi,
@@ -394,8 +402,7 @@
     static netdutils::Status flushPolicyDb(const XfrmSocket& s);
 
     static netdutils::Status ipSecAddXfrmInterface(const std::string& deviceName,
-                                                   int32_t underlyingInterface, int32_t interfaceId,
-                                                   uint16_t flags);
+                                                   int32_t interfaceId, uint16_t flags);
     static netdutils::Status ipSecAddVirtualTunnelInterface(const std::string& deviceName,
                                                             const std::string& localAddress,
                                                             const std::string& remoteAddress,
diff --git a/server/XfrmControllerTest.cpp b/server/XfrmControllerTest.cpp
index be819d8..5f74e32 100644
--- a/server/XfrmControllerTest.cpp
+++ b/server/XfrmControllerTest.cpp
@@ -136,20 +136,6 @@
     testing::StrictMock<netdutils::ScopedMockSyscalls> mockSyscalls;
 };
 
-// Test class allowing IPv4/IPv6 parameterized tests.
-class XfrmControllerParameterizedTest : public XfrmControllerTest,
-                                        public ::testing::WithParamInterface<int> {};
-
-// Helper to make generated test names readable.
-std::string FamilyName(::testing::TestParamInfo<int> info) {
-    switch(info.param) {
-        case 4: return "IPv4";
-        case 5: return "IPv64DualStack";
-        case 6: return "IPv6";
-    }
-    return android::base::StringPrintf("UNKNOWN family type: %d", info.param);
-}
-
 /* Generate function to set value refered to by 3rd argument.
  *
  * This allows us to mock functions that pass in a pointer, expecting the result to be put into
@@ -217,12 +203,62 @@
     EXPECT_EQ(netdutils::statusFromErrno(EINVAL, "Socket did not have UDP-encap sockopt set"), res);
 }
 
+struct testCaseParams {
+    const int version;
+    const bool xfrmInterfacesEnabled;
+
+    int getTunnelInterfaceNlAttrsLen() {
+        if (xfrmInterfacesEnabled) {
+            return NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
+        } else {
+            return NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark));
+        }
+    }
+};
+
+// Test class allowing tests parameterized by IP version and XFRM-I/VTI support.
+class XfrmControllerParameterizedTest : public XfrmControllerTest,
+                                        public ::testing::WithParamInterface<testCaseParams> {};
+
+// Helper to make generated test names readable.
+std::string TestNameGenerator(::testing::TestParamInfo<testCaseParams> info) {
+    std::string name = "";
+    switch (info.param.version) {
+        case 4:
+            name += "IPv4";
+            break;
+        case 5:
+            name += "IPv64DualStack";
+            break;
+        case 6:
+            name += "IPv6";
+            break;
+        default:
+            name += android::base::StringPrintf("UNKNOWN family type: %d", info.param.version);
+            break;
+    }
+
+    name += "_";
+
+    if (info.param.xfrmInterfacesEnabled) {
+        name += "XFRMI";
+    } else {
+        name += "VTI";
+    }
+
+    return name;
+}
+
 // The TEST_P cases below will run with each of the following value parameters.
-INSTANTIATE_TEST_CASE_P(ByFamily, XfrmControllerParameterizedTest, Values(4, 5, 6),
-                        FamilyName);
+INSTANTIATE_TEST_CASE_P(ByFamily, XfrmControllerParameterizedTest,
+                        Values(testCaseParams{4, false}, testCaseParams{4, true},
+                               testCaseParams{5, false}, testCaseParams{5, true},
+                               testCaseParams{6, false}, testCaseParams{6, true}),
+                        TestNameGenerator);
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecAllocateSpi) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -231,6 +267,7 @@
     response.hdr.nlmsg_type = XFRM_MSG_ALLOCSPI;
     Slice responseSlice = netdutils::makeSlice(response);
 
+    // No IF_ID expected for allocSPI.
     size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userspi_info));
 
     // A vector to hold the flattened netlink message for nlMsgSlice
@@ -240,7 +277,7 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     int outSpi = 0;
     Status res = ctrl.ipSecAllocateSpi(1 /* resourceId */, localAddr,
                                        remoteAddr, DROID_SPI, &outSpi);
@@ -263,8 +300,23 @@
     EXPECT_EQ(DROID_SPI, static_cast<int>(userspi.max));
 }
 
-void testIpSecAddSecurityAssociation(int version, const MockSyscalls& mockSyscalls,
-                                     const XfrmMode& mode, __u32 underlying_netid) {
+void verifyXfrmiArguments(int mark, int mask, int ifId) {
+    // Check that correct arguments (and only those) are non-zero, and correct.
+    EXPECT_EQ(0, mark);
+    EXPECT_EQ(0, mask);
+    EXPECT_EQ(TEST_XFRM_IF_ID, ifId);
+}
+
+void verifyVtiArguments(int mark, int mask, int ifId) {
+    // Check that correct arguments (and only those) are non-zero, and correct.
+    EXPECT_EQ(TEST_XFRM_MARK, mark);
+    EXPECT_EQ(TEST_XFRM_MASK, mask);
+    EXPECT_EQ(0, ifId);
+}
+
+void testIpSecAddSecurityAssociation(testCaseParams params, const MockSyscalls& mockSyscalls,
+                                     const XfrmMode& mode) {
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -280,17 +332,20 @@
     size_t expectedMsgLength =
             NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_usersa_info)) +
             NLA_ALIGN(offsetof(XfrmController::nlattr_algo_crypt, key) + KEY_LENGTH) +
-            NLA_ALIGN(offsetof(XfrmController::nlattr_algo_auth, key) + KEY_LENGTH) +
-            NLA_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark));
+            NLA_ALIGN(offsetof(XfrmController::nlattr_algo_auth, key) + KEY_LENGTH);
 
     uint32_t testIfId = 0;
+    uint32_t testMark = 0;
+    uint32_t testMarkMask = 0;
+    uint32_t testOutputNetid = 0;
     if (mode == XfrmMode::TUNNEL) {
-        expectedMsgLength += NLA_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
-        testIfId = TEST_XFRM_IF_ID;
-    }
-
-    if (underlying_netid) {
+        expectedMsgLength += params.getTunnelInterfaceNlAttrsLen();
         expectedMsgLength += NLA_ALIGN(sizeof(XfrmController::nlattr_xfrm_output_mark));
+
+        testIfId = TEST_XFRM_IF_ID;
+        testMark = TEST_XFRM_MARK;
+        testMarkMask = TEST_XFRM_MASK;
+        testOutputNetid = TEST_XFRM_UNDERLYING_NET;
     }
 
     std::vector<uint8_t> nlMsgBuf;
@@ -299,11 +354,11 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecAddSecurityAssociation(
             1 /* resourceId */, static_cast<int>(mode), localAddr, remoteAddr,
-            underlying_netid /* underlying netid */, DROID_SPI, TEST_XFRM_MARK /* mark */,
-            TEST_XFRM_MASK /* mask */, "hmac(sha256)" /* auth algo */, authKey,
+            testOutputNetid /* underlying netid */, DROID_SPI, testMark /* mark */,
+            testMarkMask /* mask */, "hmac(sha256)" /* auth algo */, authKey,
             128 /* auth trunc length */, "cbc(aes)" /* encryption algo */, cryptKey,
             0 /* crypt trunc length? */, "" /* AEAD algo */, {}, 0,
             static_cast<int>(XfrmEncapType::NONE), 0 /* local port */, 0 /* remote port */,
@@ -376,34 +431,36 @@
                         reinterpret_cast<void*>(&encryptAlgo.key), KEY_LENGTH));
     EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(authKey.data()),
                         reinterpret_cast<void*>(&authAlgo.key), KEY_LENGTH));
-    EXPECT_EQ(TEST_XFRM_MARK, mark.mark.v);
-    EXPECT_EQ(TEST_XFRM_MASK, mark.mark.m);
-    EXPECT_EQ(testIfId, xfrm_if_id.if_id);
 
-    if (underlying_netid) {
+    if (mode == XfrmMode::TUNNEL) {
+        if (params.xfrmInterfacesEnabled) {
+            verifyXfrmiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+        } else {
+            verifyVtiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+        }
+
         Fwmark fwmark;
         fwmark.intValue = outputmark.outputMark;
-        EXPECT_EQ(underlying_netid, fwmark.netId);
+        EXPECT_EQ(testOutputNetid, fwmark.netId);
         EXPECT_EQ(PERMISSION_SYSTEM, fwmark.permission);
         EXPECT_TRUE(fwmark.explicitlySelected);
         EXPECT_TRUE(fwmark.protectedFromVpn);
+    } else {
+        EXPECT_EQ(0, outputmark.outputMark);
+        EXPECT_EQ(0, mark.mark.v);
+        EXPECT_EQ(0, mark.mark.m);
+        EXPECT_EQ(0, xfrm_if_id.if_id);
     }
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestTransportModeIpSecAddSecurityAssociation) {
-    const int version = GetParam();
-    testIpSecAddSecurityAssociation(version, mockSyscalls, XfrmMode::TRANSPORT, 0);
+    testCaseParams params = GetParam();
+    testIpSecAddSecurityAssociation(params, mockSyscalls, XfrmMode::TRANSPORT);
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestTunnelModeIpSecAddSecurityAssociation) {
-    const int version = GetParam();
-    testIpSecAddSecurityAssociation(version, mockSyscalls, XfrmMode::TUNNEL, 0);
-}
-
-TEST_P(XfrmControllerParameterizedTest, TestTunnelModeIpSecAddSecurityAssociationWithOutputMark) {
-    const int version = GetParam();
-    testIpSecAddSecurityAssociation(version, mockSyscalls, XfrmMode::TUNNEL,
-                                    TEST_XFRM_UNDERLYING_NET);
+    testCaseParams params = GetParam();
+    testIpSecAddSecurityAssociation(params, mockSyscalls, XfrmMode::TUNNEL);
 }
 
 TEST_F(XfrmControllerTest, TestIpSecAddSecurityAssociationIPv4Encap) {
@@ -442,7 +499,8 @@
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecApplyTransportModeTransform) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int sockFamily = (version == 4) ? AF_INET : AF_INET6;
     const int xfrmFamily = (version == 6) ? AF_INET6: AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
@@ -468,7 +526,7 @@
         .WillOnce(DoAll(WithArg<3>(Invoke(SavePolicy)), SaveArg<4>(&optlen),
                         Return(netdutils::status::ok)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecApplyTransportModeTransform(sock, 1 /* resourceId */,
                                                        static_cast<int>(XfrmDirection::OUT),
                                                        localAddr, remoteAddr, DROID_SPI);
@@ -484,7 +542,8 @@
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecRemoveTransportModeTransform) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -503,7 +562,7 @@
     EXPECT_CALL(mockSyscalls, setsockopt(_, _, _, _, _))
         .WillOnce(DoAll(SaveArg<3>(&optval), SaveArg<4>(&optlen),
                         Return(netdutils::status::ok)));
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecRemoveTransportModeTransform(sock);
 
     EXPECT_TRUE(isOk(res)) << res;
@@ -511,8 +570,26 @@
     EXPECT_EQ(static_cast<socklen_t>(0), optlen);
 }
 
+void parseTunnelNetlinkAttrs(XfrmController::nlattr_xfrm_mark* mark,
+                             XfrmController::nlattr_xfrm_interface_id* xfrm_if_id, Slice attr_buf) {
+    auto attrHandler = [mark, xfrm_if_id](const nlattr& attr, const Slice& attr_payload) {
+        Slice buf = attr_payload;
+        if (attr.nla_type == XFRMA_MARK) {
+            mark->hdr = attr;
+            netdutils::extract(buf, mark->mark);
+        } else if (attr.nla_type == XFRMA_IF_ID) {
+            xfrm_if_id->hdr = attr;
+            netdutils::extract(buf, xfrm_if_id->if_id);
+        } else {
+            FAIL() << "Unexpected nlattr type: " << attr.nla_type;
+        }
+    };
+    forEachNetlinkAttribute(attr_buf, attrHandler);
+}
+
 TEST_P(XfrmControllerParameterizedTest, TestIpSecDeleteSecurityAssociation) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -522,8 +599,7 @@
     Slice responseSlice = netdutils::makeSlice(response);
 
     size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_usersa_id)) +
-                               NLA_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark)) +
-                               NLA_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
+                               params.getTunnelInterfaceNlAttrsLen();
 
     std::vector<uint8_t> nlMsgBuf;
     EXPECT_CALL(mockSyscalls, writev(_, _))
@@ -531,7 +607,7 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecDeleteSecurityAssociation(1 /* resourceId */, localAddr, remoteAddr,
                                                      DROID_SPI, TEST_XFRM_MARK, TEST_XFRM_MASK,
                                                      TEST_XFRM_IF_ID);
@@ -549,21 +625,21 @@
     EXPECT_EQ(htonl(DROID_SPI), said.spi);
     expectAddressEquals(family, remoteAddr, said.daddr);
 
-    // Extract and check the mark.
+    // Extract and check the marks and xfrm_if_id
     XfrmController::nlattr_xfrm_mark mark{};
-    netdutils::extract(nlMsgSlice, mark);
-    nlMsgSlice = drop(nlMsgSlice, sizeof(XfrmController::nlattr_xfrm_mark));
-    EXPECT_EQ(TEST_XFRM_MARK, mark.mark.v);
-    EXPECT_EQ(TEST_XFRM_MASK, mark.mark.m);
-
-    // Extract and check the interface id.
     XfrmController::nlattr_xfrm_interface_id xfrm_if_id{};
-    netdutils::extract(nlMsgSlice, xfrm_if_id);
-    EXPECT_EQ(TEST_XFRM_IF_ID, xfrm_if_id.if_id);
+    parseTunnelNetlinkAttrs(&mark, &xfrm_if_id, nlMsgSlice);
+
+    if (params.xfrmInterfacesEnabled) {
+        verifyXfrmiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    } else {
+        verifyVtiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    }
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecAddSecurityPolicy) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -574,8 +650,7 @@
 
     size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_info)) +
                                NLMSG_ALIGN(sizeof(XfrmController::nlattr_user_tmpl)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
+                               params.getTunnelInterfaceNlAttrsLen();
 
     std::vector<uint8_t> nlMsgBuf;
     EXPECT_CALL(mockSyscalls, writev(_, _))
@@ -583,7 +658,7 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecAddSecurityPolicy(
             1 /* resourceId */, family, static_cast<int>(XfrmDirection::OUT), localAddr, remoteAddr,
             0 /* SPI */, TEST_XFRM_MARK, TEST_XFRM_MASK, TEST_XFRM_IF_ID);
@@ -628,13 +703,17 @@
     forEachNetlinkAttribute(attr_buf, attrHandler);
 
     expectAddressEquals(family, remoteAddr, usertmpl.tmpl.id.daddr);
-    EXPECT_EQ(TEST_XFRM_MARK, mark.mark.v);
-    EXPECT_EQ(TEST_XFRM_MASK, mark.mark.m);
-    EXPECT_EQ(TEST_XFRM_IF_ID, xfrm_if_id.if_id);
+
+    if (params.xfrmInterfacesEnabled) {
+        verifyXfrmiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    } else {
+        verifyVtiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    }
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecUpdateSecurityPolicy) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -645,8 +724,7 @@
 
     size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_info)) +
                                NLMSG_ALIGN(sizeof(XfrmController::nlattr_user_tmpl)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
+                               params.getTunnelInterfaceNlAttrsLen();
 
     std::vector<uint8_t> nlMsgBuf;
     EXPECT_CALL(mockSyscalls, writev(_, _))
@@ -654,10 +732,11 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecUpdateSecurityPolicy(
             1 /* resourceId */, family, static_cast<int>(XfrmDirection::OUT), localAddr, remoteAddr,
-            0 /* SPI */, 0 /* Mark */, 0 /* Mask */, TEST_XFRM_IF_ID /* xfrm_if_id */);
+            0 /* SPI */, TEST_XFRM_MARK /* Mark */, TEST_XFRM_MASK /* Mask */,
+            TEST_XFRM_IF_ID /* xfrm_if_id */);
 
     EXPECT_TRUE(isOk(res)) << res;
     EXPECT_EQ(expectedMsgLength, nlMsgBuf.size());
@@ -669,7 +748,8 @@
 }
 
 TEST_P(XfrmControllerParameterizedTest, TestIpSecDeleteSecurityPolicy) {
-    const int version = GetParam();
+    testCaseParams params = GetParam();
+    const int version = params.version;
     const int family = (version == 6) ? AF_INET6 : AF_INET;
     const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
     const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
@@ -679,8 +759,7 @@
     Slice responseSlice = netdutils::makeSlice(response);
 
     size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_id)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_mark)) +
-                               NLMSG_ALIGN(sizeof(XfrmController::nlattr_xfrm_interface_id));
+                               params.getTunnelInterfaceNlAttrsLen();
 
     std::vector<uint8_t> nlMsgBuf;
     EXPECT_CALL(mockSyscalls, writev(_, _))
@@ -688,7 +767,7 @@
     EXPECT_CALL(mockSyscalls, read(_, _))
         .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
 
-    XfrmController ctrl;
+    XfrmController ctrl(params.xfrmInterfacesEnabled);
     Status res = ctrl.ipSecDeleteSecurityPolicy(1 /* resourceId */, family,
                                                 static_cast<int>(XfrmDirection::OUT),
                                                 TEST_XFRM_MARK, TEST_XFRM_MASK, TEST_XFRM_IF_ID);
@@ -707,17 +786,16 @@
     // Drop the user policy id.
     nlMsgSlice = drop(nlMsgSlice, NLA_ALIGN(sizeof(xfrm_userpolicy_id)));
 
-    // Extract and check the mark.
+    // Extract and check the marks and xfrm_if_id
     XfrmController::nlattr_xfrm_mark mark{};
-    netdutils::extract(nlMsgSlice, mark);
-    nlMsgSlice = drop(nlMsgSlice, sizeof(XfrmController::nlattr_xfrm_mark));
-    EXPECT_EQ(TEST_XFRM_MARK, mark.mark.v);
-    EXPECT_EQ(TEST_XFRM_MASK, mark.mark.m);
-
-    // Extract and check the interface id.
     XfrmController::nlattr_xfrm_interface_id xfrm_if_id{};
-    netdutils::extract(nlMsgSlice, xfrm_if_id);
-    EXPECT_EQ(TEST_XFRM_IF_ID, xfrm_if_id.if_id);
+    parseTunnelNetlinkAttrs(&mark, &xfrm_if_id, nlMsgSlice);
+
+    if (params.xfrmInterfacesEnabled) {
+        verifyXfrmiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    } else {
+        verifyVtiArguments(mark.mark.v, mark.mark.m, xfrm_if_id.if_id);
+    }
 }
 
 // TODO: Add tests for VTIs, ensuring that we are sending the correct data over netlink.
diff --git a/server/binder/android/net/INetd.aidl b/server/binder/android/net/INetd.aidl
index 48d064f..7543c8f 100644
--- a/server/binder/android/net/INetd.aidl
+++ b/server/binder/android/net/INetd.aidl
@@ -342,7 +342,8 @@
     * @param mode either Transport or Tunnel mode
     * @param sourceAddress InetAddress as string for the sending endpoint
     * @param destinationAddress InetAddress as string for the receiving endpoint
-    * @param underlyingNetId the netId of the network to which the SA is applied.
+    * @param underlyingNetId the netId of the network to which the SA is applied. Only accepted for
+    *        tunnel mode SAs.
     * @param spi a 32-bit unique ID allocated to the user
     * @param markValue a 32-bit unique ID chosen by the user
     * @param markMask a 32-bit mask chosen by the user