Add functions to add, update and delete XFRM policies in XfrmController

Bug: 69561215
Test: runtest -x tests/netd_unit_test.cpp

Change-Id: I17f4f069de182eafedf4d98748e1d7be02e663a9
diff --git a/server/XfrmController.cpp b/server/XfrmController.cpp
index 03f2fe3..92dd757 100644
--- a/server/XfrmController.cpp
+++ b/server/XfrmController.cpp
@@ -655,7 +655,76 @@
     return status;
 }
 
-void XfrmController::fillTransportModeSelector(const XfrmSaInfo& record, xfrm_selector* selector) {
+// Adds a tunnel-mode security policy (XFRM_MSG_NEWPOLICY) for traffic in |direction| between
+// |localAddress| and |remoteAddress|, built together with |transformId| and |spi|.
+netdutils::Status XfrmController::ipSecAddSecurityPolicy(int32_t transformId, int32_t direction,
+                                                         const std::string& localAddress,
+                                                         const std::string& remoteAddress,
+                                                         int32_t spi) {
+    return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, spi,
+                                 XFRM_MSG_NEWPOLICY);
+}
+
+// Same as ipSecAddSecurityPolicy(), but sends XFRM_MSG_UPDPOLICY to replace an existing policy.
+netdutils::Status XfrmController::ipSecUpdateSecurityPolicy(int32_t transformId, int32_t direction,
+                                                            const std::string& localAddress,
+                                                            const std::string& remoteAddress,
+                                                            int32_t spi) {
+    return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, spi,
+                                 XFRM_MSG_UPDPOLICY);
+}
+
+// Deletes (XFRM_MSG_DELPOLICY) the security policy matching |direction| and the selector built
+// from |localAddress| and |remoteAddress|. No SPI is needed to identify a policy.
+netdutils::Status XfrmController::ipSecDeleteSecurityPolicy(int32_t transformId, int32_t direction,
+                                                            const std::string& localAddress,
+                                                            const std::string& remoteAddress) {
+    return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, 0,
+                                 XFRM_MSG_DELPOLICY);
+}
+
+// Shared implementation of ipSec{Add,Update,Delete}SecurityPolicy().
+//
+// Validates |direction|, fills a tunnel-mode XfrmSaInfo from the addresses, |spi| and
+// |transformId|, then sends |msgType| (XFRM_MSG_NEWPOLICY, XFRM_MSG_UPDPOLICY or
+// XFRM_MSG_DELPOLICY) over a newly opened XFRM netlink socket.
+netdutils::Status XfrmController::processSecurityPolicy(int32_t transformId, int32_t direction,
+                                                        const std::string& localAddress,
+                                                        const std::string& remoteAddress,
+                                                        int32_t spi, int32_t msgType) {
+    ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
+    ALOGD("transformId=%d", transformId);
+    ALOGD("direction=%d", direction);
+    ALOGD("localAddress=%s", localAddress.c_str());
+    ALOGD("remoteAddress=%s", remoteAddress.c_str());
+    ALOGD("spi=%.8x", static_cast<uint32_t>(spi));
+    ALOGD("msgType=%d", msgType);
+
+    // The kernel only accepts IN, OUT and FWD policies. Reject anything else before opening a
+    // socket, and before an out-of-range value is cast to XfrmDirection below.
+    if (direction != XFRM_POLICY_IN && direction != XFRM_POLICY_OUT &&
+        direction != XFRM_POLICY_FWD) {
+        return netdutils::statusFromErrno(EINVAL, "Invalid security policy direction");
+    }
+
+    XfrmSaInfo saInfo{};
+    saInfo.mode = XfrmMode::TUNNEL;
+    RETURN_IF_NOT_OK(fillXfrmId(localAddress, remoteAddress, spi, transformId, &saInfo));
+
+    // Only open the netlink socket once the request has been validated.
+    XfrmSocketImpl sock;
+    RETURN_IF_NOT_OK(sock.open());
+
+    const auto xfrmDirection = static_cast<XfrmDirection>(direction);
+    if (msgType == XFRM_MSG_DELPOLICY) {
+        return deleteTunnelModeSecurityPolicy(saInfo, sock, xfrmDirection);
+    }
+    return updateTunnelModeSecurityPolicy(saInfo, sock, xfrmDirection,
+                                          static_cast<uint16_t>(msgType));
+}
+
+// Fills |selector| from |record|; shared by SA and SP messages.
+void XfrmController::fillXfrmSelector(const XfrmSaInfo& record, xfrm_selector* selector) {
     selector->family = record.addrFamily;
     selector->proto = AF_UNSPEC;      // TODO: do we need to match the protocol? it's
                                       // possible via the socket
@@ -797,7 +848,7 @@
 }
 
 int XfrmController::fillUserSaInfo(const XfrmSaInfo& record, xfrm_usersa_info* usersa) {
-    fillTransportModeSelector(record, &usersa->sel);
+    fillXfrmSelector(record, &usersa->sel);
 
     usersa->id.proto = IPPROTO_ESP;
     usersa->id.spi = record.spi;
@@ -898,9 +949,51 @@
     return ret;
 }
 
+// Sends a |msgType| request (processSecurityPolicy() passes XFRM_MSG_NEWPOLICY or
+// XFRM_MSG_UPDPOLICY) for |direction|. The payload is an xfrm_userpolicy_info built from
+// |record| followed by a single XFRMA_TMPL attribute.
+netdutils::Status XfrmController::updateTunnelModeSecurityPolicy(const XfrmSaInfo& record,
+                                                                 const XfrmSocket& sock,
+                                                                 XfrmDirection direction,
+                                                                 uint16_t msgType) {
+    xfrm_userpolicy_info userpolicy{};
+    nlattr_user_tmpl usertmpl{};
+
+    // Fill both payloads up front; their sizes determine the iovec lengths below.
+    const int policyLen = fillTransportModeUserSpInfo(record, direction, &userpolicy);
+    const int tmplLen = fillNlAttrUserTemplate(record, &usertmpl);
+
+    std::vector<iovec> iov = {
+        {NULL, 0}, // reserved for the eventual addition of a NLMSG_HDR
+        {&userpolicy, static_cast<size_t>(policyLen)},
+        {kPadBytes, static_cast<size_t>(NLMSG_ALIGN(policyLen) - policyLen)}, // NLMSG padding
+        {&usertmpl, static_cast<size_t>(tmplLen)},
+        {kPadBytes, static_cast<size_t>(NLA_ALIGN(tmplLen) - tmplLen)}, // NLA padding
+    };
+
+    return sock.sendMessage(msgType, NETLINK_REQUEST_FLAGS, 0, &iov);
+}
+
+// Sends an XFRM_MSG_DELPOLICY request for |direction|; the xfrm_userpolicy_id selector is
+// built from |record|.
+netdutils::Status XfrmController::deleteTunnelModeSecurityPolicy(const XfrmSaInfo& record,
+                                                                 const XfrmSocket& sock,
+                                                                 XfrmDirection direction) {
+    xfrm_userpolicy_id policyid{};
+    const int policyIdLen = fillUserPolicyId(record, direction, &policyid);
+
+    std::vector<iovec> iov = {
+        {NULL, 0}, // reserved for the eventual addition of a NLMSG_HDR
+        {&policyid, static_cast<size_t>(policyIdLen)},
+        {kPadBytes, static_cast<size_t>(NLMSG_ALIGN(policyIdLen) - policyIdLen)}, // padding
+    };
+
+    return sock.sendMessage(XFRM_MSG_DELPOLICY, NETLINK_REQUEST_FLAGS, 0, &iov);
+}
+
 int XfrmController::fillTransportModeUserSpInfo(const XfrmSaInfo& record, XfrmDirection direction,
                                                 xfrm_userpolicy_info* usersp) {
-    fillTransportModeSelector(record, &usersp->sel);
+    fillXfrmSelector(record, &usersp->sel);
     fillXfrmLifetimeDefaults(&usersp->lft);
     fillXfrmCurLifetimeDefaults(&usersp->curlft);
     /* if (index) index & 0x3 == dir -- must be true
@@ -929,8 +1036,28 @@
                                         // algos, we should find it and apply it.
                                         // I can't find one.
     tmpl->ealgos = ALGO_MASK_CRYPT_ALL; // TODO: if there's a bitmask somewhere...
-    return 0;
+    return sizeof(xfrm_user_tmpl);
 }
 
+// Fills |tmpl| with an XFRMA_TMPL netlink attribute: an nlattr header followed by the
+// xfrm_user_tmpl built from |record|. Returns the attribute length, excluding NLA padding.
+int XfrmController::fillNlAttrUserTemplate(const XfrmSaInfo& record, nlattr_user_tmpl* tmpl) {
+    fillUserTemplate(record, &tmpl->tmpl);
+
+    const int len = NLA_HDRLEN + sizeof(xfrm_user_tmpl);
+    fillXfrmNlaHdr(&tmpl->hdr, XFRMA_TMPL, len);
+    return len;
+}
+
+// Fills |policyId| with the selector (from |record|) and |direction| that identify a policy,
+// and returns the number of bytes used.
+int XfrmController::fillUserPolicyId(const XfrmSaInfo& record, XfrmDirection direction,
+                                     xfrm_userpolicy_id* policyId) {
+    // For DELPOLICY, when index is absent, selector is needed to match the policy
+    fillXfrmSelector(record, &policyId->sel);
+    policyId->dir = static_cast<uint8_t>(direction);
+    return sizeof(*policyId);
+}
+
 } // namespace net
 } // namespace android
diff --git a/server/XfrmController.h b/server/XfrmController.h
index ee42dec..0d29c1d 100644
--- a/server/XfrmController.h
+++ b/server/XfrmController.h
@@ -148,6 +148,23 @@
 
     netdutils::Status ipSecRemoveTransportModeTransform(const android::base::unique_fd& socket);
 
+    // Adds a tunnel-mode security policy (XFRM_MSG_NEWPOLICY) for traffic in |direction|
+    // between |localAddress| and |remoteAddress|, built together with |transformId| and |spi|.
+    netdutils::Status ipSecAddSecurityPolicy(int32_t transformId, int32_t direction,
+                                             const std::string& localAddress,
+                                             const std::string& remoteAddress, int32_t spi);
+
+    // Replaces an existing security policy (XFRM_MSG_UPDPOLICY); arguments as for Add.
+    netdutils::Status ipSecUpdateSecurityPolicy(int32_t transformId, int32_t direction,
+                                                const std::string& localAddress,
+                                                const std::string& remoteAddress, int32_t spi);
+
+    // Deletes the security policy (XFRM_MSG_DELPOLICY) matching |direction| and the selector
+    // built from |localAddress| and |remoteAddress|.
+    netdutils::Status ipSecDeleteSecurityPolicy(int32_t transformId, int32_t direction,
+                                                const std::string& localAddress,
+                                                const std::string& remoteAddress);
+
     // Some XFRM netlink attributes comprise a header, a struct, and some data
     // after the struct. We wrap all of those in one struct for easier
     // marshalling. The structs below must be ABI compatible with the kernel and
@@ -246,7 +258,7 @@
 
     // TODO(messagerefactor): FACTOR OUT ALL MESSAGE BUILDING CODE BELOW HERE
     // Shared between SA and SP
-    static void fillTransportModeSelector(const XfrmSaInfo& record, xfrm_selector* selector);
+    static void fillXfrmSelector(const XfrmSaInfo& record, xfrm_selector* selector);
 
     // Shared between Transport and Tunnel Mode
     static int fillNlAttrXfrmAlgoEnc(const XfrmAlgo& in_algo, nlattr_algo_crypt* algo);
@@ -264,12 +276,32 @@
                                                        const XfrmSocket& sock);
     static int fillUserSaId(const XfrmId& record, xfrm_usersa_id* said);
     static int fillUserTemplate(const XfrmSaInfo& record, xfrm_user_tmpl* tmpl);
     static int fillTransportModeUserSpInfo(const XfrmSaInfo& record, XfrmDirection direction,
                                            xfrm_userpolicy_info* usersp);
+    // Fills an XFRMA_TMPL attribute from |record|; returns its length without NLA padding.
+    static int fillNlAttrUserTemplate(const XfrmSaInfo& record, nlattr_user_tmpl* tmpl);
+    // Fills the selector and direction that identify a policy for XFRM_MSG_DELPOLICY.
+    static int fillUserPolicyId(const XfrmSaInfo& record, XfrmDirection direction,
+                                xfrm_userpolicy_id* policyId);
 
     static netdutils::Status allocateSpi(const XfrmSaInfo& record, uint32_t minSpi, uint32_t maxSpi,
                                          uint32_t* outSpi, const XfrmSocket& sock);
 
+    // Shared implementation of ipSec{Add,Update,Delete}SecurityPolicy().
+    static netdutils::Status processSecurityPolicy(int32_t transformId, int32_t direction,
+                                                   const std::string& localAddress,
+                                                   const std::string& remoteAddress,
+                                                   int32_t spi, int32_t msgType);
+    // Sends a |msgType| policy add/update request built from |record| for |direction|.
+    static netdutils::Status updateTunnelModeSecurityPolicy(const XfrmSaInfo& record,
+                                                            const XfrmSocket& sock,
+                                                            XfrmDirection direction,
+                                                            uint16_t msgType);
+    // Sends an XFRM_MSG_DELPOLICY request for the policy matching |record| and |direction|.
+    static netdutils::Status deleteTunnelModeSecurityPolicy(const XfrmSaInfo& record,
+                                                            const XfrmSocket& sock,
+                                                            XfrmDirection direction);
+
     // END TODO(messagerefactor)
 };
 
diff --git a/server/XfrmControllerTest.cpp b/server/XfrmControllerTest.cpp
index 4c70fc6..762bede 100644
--- a/server/XfrmControllerTest.cpp
+++ b/server/XfrmControllerTest.cpp
@@ -499,5 +499,142 @@
     expectAddressEquals(family, remoteAddr, said.daddr);
 }
 
+TEST_P(XfrmControllerParameterizedTest, TestIpSecAddSecurityPolicy) {
+    const int version = GetParam();
+    const int family = (version == 6) ? AF_INET6 : AF_INET;
+    const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
+    const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
+
+    NetlinkResponse response{};
+    response.hdr.nlmsg_type = XFRM_MSG_NEWPOLICY;
+    Slice responseSlice = netdutils::makeSlice(response);
+
+    // nlmsghdr + NLMSG-aligned xfrm_userpolicy_info + NLA-aligned XFRMA_TMPL attribute.
+    size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_info)) +
+                               NLA_ALIGN(sizeof(XfrmController::nlattr_user_tmpl));
+
+    std::vector<uint8_t> nlMsgBuf;
+    EXPECT_CALL(mockSyscalls, writev(_, _))
+        .WillOnce(DoAll(SaveFlattenedIovecs<1>(&nlMsgBuf), Return(expectedMsgLength)));
+    EXPECT_CALL(mockSyscalls, read(_, _))
+        .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
+
+    XfrmController ctrl;
+    Status res =
+        ctrl.ipSecAddSecurityPolicy(1 /* resourceId */, static_cast<int>(XfrmDirection::OUT),
+                                    localAddr, remoteAddr, 0 /* SPI */);
+
+    EXPECT_TRUE(isOk(res)) << res;
+    EXPECT_EQ(expectedMsgLength, nlMsgBuf.size());
+
+    Slice nlMsgSlice = netdutils::makeSlice(nlMsgBuf);
+    nlmsghdr hdr;
+    netdutils::extract(nlMsgSlice, hdr);
+    EXPECT_EQ(XFRM_MSG_NEWPOLICY, hdr.nlmsg_type);
+
+    nlMsgSlice = netdutils::drop(nlMsgSlice, NLMSG_HDRLEN);
+    xfrm_userpolicy_info userpolicy{};
+    netdutils::extract(nlMsgSlice, userpolicy);
+
+    EXPECT_EQ(static_cast<uint8_t>(XfrmDirection::OUT), userpolicy.dir);
+
+    // Drop the user policy info, which the controller pads to NLMSG alignment.
+    Slice attr_buf = drop(nlMsgSlice, NLMSG_ALIGN(sizeof(xfrm_userpolicy_info)));
+
+    // Extract and check the user tmpl.
+    XfrmController::nlattr_user_tmpl usertmpl{};
+    auto attrHandler = [&usertmpl](const nlattr& attr, const Slice& attr_payload) {
+        Slice buf = attr_payload;
+        if (attr.nla_type == XFRMA_TMPL) {
+            usertmpl.hdr = attr;
+            netdutils::extract(buf, usertmpl.tmpl);
+        } else {
+            FAIL() << "Unexpected nlattr type: " << attr.nla_type;
+        }
+    };
+    forEachNetlinkAttribute(attr_buf, attrHandler);
+
+    // Fail clearly if the XFRMA_TMPL attribute was missing, rather than on a zeroed template.
+    EXPECT_EQ(XFRMA_TMPL, usertmpl.hdr.nla_type);
+    expectAddressEquals(family, remoteAddr, usertmpl.tmpl.id.daddr);
+}
+
+TEST_P(XfrmControllerParameterizedTest, TestIpSecUpdateSecurityPolicy) {
+    const int version = GetParam();
+    const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
+    const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
+
+    NetlinkResponse response{};
+    response.hdr.nlmsg_type = XFRM_MSG_UPDPOLICY;
+    Slice responseSlice = netdutils::makeSlice(response);
+
+    size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_info)) +
+                               NLA_ALIGN(sizeof(XfrmController::nlattr_user_tmpl));
+
+    std::vector<uint8_t> nlMsgBuf;
+    EXPECT_CALL(mockSyscalls, writev(_, _))
+        .WillOnce(DoAll(SaveFlattenedIovecs<1>(&nlMsgBuf), Return(expectedMsgLength)));
+    EXPECT_CALL(mockSyscalls, read(_, _))
+        .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
+
+    XfrmController ctrl;
+    Status res =
+        ctrl.ipSecUpdateSecurityPolicy(1 /* resourceId */, static_cast<int>(XfrmDirection::OUT),
+                                       localAddr, remoteAddr, 0 /* SPI */);
+
+    EXPECT_TRUE(isOk(res)) << res;
+    EXPECT_EQ(expectedMsgLength, nlMsgBuf.size());
+
+    Slice nlMsgSlice = netdutils::makeSlice(nlMsgBuf);
+    nlmsghdr hdr;
+    netdutils::extract(nlMsgSlice, hdr);
+    EXPECT_EQ(XFRM_MSG_UPDPOLICY, hdr.nlmsg_type);
+
+    nlMsgSlice = netdutils::drop(nlMsgSlice, NLMSG_HDRLEN);
+    xfrm_userpolicy_info userpolicy{};
+    netdutils::extract(nlMsgSlice, userpolicy);
+
+    EXPECT_EQ(static_cast<uint8_t>(XfrmDirection::OUT), userpolicy.dir);
+}
+
+TEST_P(XfrmControllerParameterizedTest, TestIpSecDeleteSecurityPolicy) {
+    const int version = GetParam();
+    const int family = (version == 6) ? AF_INET6 : AF_INET;
+    const std::string localAddr = (version == 6) ? LOCALHOST_V6 : LOCALHOST_V4;
+    const std::string remoteAddr = (version == 6) ? TEST_ADDR_V6 : TEST_ADDR_V4;
+
+    NetlinkResponse response{};
+    response.hdr.nlmsg_type = XFRM_MSG_DELPOLICY;
+    Slice responseSlice = netdutils::makeSlice(response);
+
+    size_t expectedMsgLength = NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(xfrm_userpolicy_id));
+
+    std::vector<uint8_t> nlMsgBuf;
+    EXPECT_CALL(mockSyscalls, writev(_, _))
+        .WillOnce(DoAll(SaveFlattenedIovecs<1>(&nlMsgBuf), Return(expectedMsgLength)));
+    EXPECT_CALL(mockSyscalls, read(_, _))
+        .WillOnce(DoAll(SetArgSlice<1>(responseSlice), Return(responseSlice)));
+
+    XfrmController ctrl;
+    Status res =
+        ctrl.ipSecDeleteSecurityPolicy(1 /* resourceId */, static_cast<int>(XfrmDirection::OUT),
+                                       localAddr, remoteAddr);
+
+    EXPECT_TRUE(isOk(res)) << res;
+    EXPECT_EQ(expectedMsgLength, nlMsgBuf.size());
+
+    Slice nlMsgSlice = netdutils::makeSlice(nlMsgBuf);
+    nlmsghdr hdr;
+    netdutils::extract(nlMsgSlice, hdr);
+    EXPECT_EQ(XFRM_MSG_DELPOLICY, hdr.nlmsg_type);
+
+    nlMsgSlice = netdutils::drop(nlMsgSlice, NLMSG_HDRLEN);
+    xfrm_userpolicy_id policyid{};
+    netdutils::extract(nlMsgSlice, policyid);
+
+    EXPECT_EQ(static_cast<uint8_t>(XfrmDirection::OUT), policyid.dir);
+    EXPECT_EQ(family, policyid.sel.family);
+}
+
 } // namespace net
 } // namespace android