XFRM - Cleanup of XfrmId Usage Inconsistencies

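XfrmSaId is renamed to XfrmId, which contains all the
fields required to identify an SA or a policy during
either creation or deletion. This patch makes creation
of the internal XfrmId structure that we pass around in
the XfrmController consistent: fillXfrmSaId is renamed
to fillXfrmId and now also fills in transformId and
direction, so callers no longer set those fields by
hand. It also includes minor formatting cleanups.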

Bug: 70594971
Test: runtest -x system/netd/server/netd_unit_test.cpp
Test: cts passes
Change-Id: I39b9ed1599ef65ca957866b0dcb37726e33c53fa
diff --git a/server/XfrmController.cpp b/server/XfrmController.cpp
index 996c9ec..e70ed49 100644
--- a/server/XfrmController.cpp
+++ b/server/XfrmController.cpp
@@ -15,8 +15,8 @@
  * limitations under the License.
  */
 
-#include <string>
 #include <random>
+#include <string>
 #include <vector>
 
 #include <ctype.h>
@@ -125,15 +125,15 @@
     if (__android_log_is_debuggable()) {                                                           \
         do {                                                                                       \
             logHex(__desc16__, __buf__, __len__);                                                  \
-           } while (0);                                                                            \
+        } while (0);                                                                               \
     }
 
 #define LOG_IOV(__iov__)                                                                           \
     if (__android_log_is_debuggable()) {                                                           \
         do {                                                                                       \
             logIov(__iov__);                                                                       \
-           } while (0);                                                                            \
-     }
+        } while (0);                                                                               \
+    }
 
 void logHex(const char* desc16, const char* buf, size_t len) {
     char* printBuf = new char[len * 2 + 1 + 26]; // len->ascii, +newline, +prefix strlen
@@ -227,7 +227,9 @@
     netdutils::Status sendMessage(uint16_t nlMsgType, uint16_t nlMsgFlags, uint16_t nlMsgSeqNum,
                                   std::vector<iovec>* iovecs) const override {
         nlmsghdr nlMsg = {
-            .nlmsg_type = nlMsgType, .nlmsg_flags = nlMsgFlags, .nlmsg_seq = nlMsgSeqNum,
+            .nlmsg_type = nlMsgType,
+            .nlmsg_flags = nlMsgFlags,
+            .nlmsg_seq = nlMsgSeqNum,
         };
 
         (*iovecs)[0].iov_base = &nlMsg;
@@ -350,9 +352,8 @@
     ALOGD("inSpi=%0.8x", inSpi);
 
     XfrmSaInfo saInfo{};
-
     netdutils::Status ret =
-        fillXfrmSaId(direction, localAddress, remoteAddress, INVALID_SPI, &saInfo);
+        fillXfrmId(direction, localAddress, remoteAddress, INVALID_SPI, transformId, &saInfo);
     if (!isOk(ret)) {
         return ret;
     }
@@ -405,13 +406,12 @@
     ALOGD("encapRemotePort=%d", encapRemotePort);
 
     XfrmSaInfo saInfo{};
-    netdutils::Status ret = fillXfrmSaId(direction, localAddress, remoteAddress, spi, &saInfo);
+    netdutils::Status ret =
+        fillXfrmId(direction, localAddress, remoteAddress, spi, transformId, &saInfo);
     if (!isOk(ret)) {
         return ret;
     }
 
-    saInfo.transformId = transformId;
-
     saInfo.auth = XfrmAlgo{
         .name = authAlgo, .key = authKey, .truncLenBits = static_cast<uint16_t>(authTruncBits)};
 
@@ -485,8 +485,9 @@
     ALOGD("remoteAddress=%s", remoteAddress.c_str());
     ALOGD("spi=%0.8x", spi);
 
-    XfrmSaId saId{};
-    netdutils::Status ret = fillXfrmSaId(direction, localAddress, remoteAddress, spi, &saId);
+    XfrmId saId{};
+    netdutils::Status ret =
+        fillXfrmId(direction, localAddress, remoteAddress, spi, transformId, &saId);
     if (!isOk(ret)) {
         return ret;
     }
@@ -506,11 +507,16 @@
     return ret;
 }
 
-netdutils::Status XfrmController::fillXfrmSaId(int32_t direction, const std::string& localAddress,
-                                               const std::string& remoteAddress, int32_t spi,
-                                               XfrmSaId* xfrmId) {
-    xfrm_address_t localXfrmAddr{}, remoteXfrmAddr{};
+netdutils::Status XfrmController::fillXfrmId(int32_t direction, const std::string& localAddress,
+                                             const std::string& remoteAddress, int32_t spi,
+                                             int32_t transformId, XfrmId* xfrmId) {
+    // Fill the straightforward fields first
+    xfrmId->transformId = transformId;
+    xfrmId->direction = static_cast<XfrmDirection>(direction);
+    xfrmId->spi = htonl(spi);
 
+    // Use the addresses to determine the address family and do validation
+    xfrm_address_t localXfrmAddr{}, remoteXfrmAddr{};
     StatusOr<int> addrFamilyLocal, addrFamilyRemote;
     addrFamilyRemote = convertToXfrmAddr(remoteAddress, &remoteXfrmAddr);
     addrFamilyLocal = convertToXfrmAddr(localAddress, &localXfrmAddr);
@@ -529,8 +535,6 @@
 
     xfrmId->addrFamily = addrFamilyRemote.value();
 
-    xfrmId->spi = htonl(spi);
-
     switch (static_cast<XfrmDirection>(direction)) {
         case XfrmDirection::IN:
             xfrmId->dstAddr = localXfrmAddr;
@@ -560,22 +564,16 @@
     ALOGD("remoteAddress=%s", remoteAddress.c_str());
     ALOGD("spi=%0.8x", spi);
 
-    struct sockaddr_storage saddr;
-
     StatusOr<sockaddr_storage> ret = getSyscallInstance().getsockname<sockaddr_storage>(Fd(socket));
     if (!isOk(ret)) {
         ALOGE("Failed to get socket info in %s", __FUNCTION__);
         return ret;
     }
-
-    saddr = ret.value();
+    struct sockaddr_storage saddr = ret.value();
 
     XfrmSaInfo saInfo{};
-    saInfo.transformId = transformId;
-    saInfo.direction = static_cast<XfrmDirection>(direction);
-    saInfo.spi = spi;
-
-    netdutils::Status status = fillXfrmSaId(direction, localAddress, remoteAddress, spi, &saInfo);
+    netdutils::Status status =
+        fillXfrmId(direction, localAddress, remoteAddress, spi, transformId, &saInfo);
     if (!isOk(status)) {
         ALOGE("Couldn't build SA ID %s", __FUNCTION__);
         return status;
@@ -706,7 +704,7 @@
     int len = NLA_HDRLEN + sizeof(xfrm_algo);
     // Kernel always changes last char to null terminator; no safety checks needed.
     strncpy(algo->crypt.alg_name, inAlgo.name.c_str(), sizeof(algo->crypt.alg_name));
-    algo->crypt.alg_key_len = inAlgo.key.size() * 8;      // bits
+    algo->crypt.alg_key_len = inAlgo.key.size() * 8; // bits
     memcpy(algo->key, &inAlgo.key[0], inAlgo.key.size());
     len += inAlgo.key.size();
     fillXfrmNlaHdr(&algo->hdr, XFRMA_ALG_CRYPT, len);
@@ -787,7 +785,7 @@
     return sizeof(*usersa);
 }
 
-int XfrmController::fillUserSaId(const XfrmSaId& record, xfrm_usersa_id* said) {
+int XfrmController::fillUserSaId(const XfrmId& record, xfrm_usersa_id* said) {
     said->daddr = record.dstAddr;
     said->spi = record.spi;
     said->family = record.addrFamily;
@@ -796,7 +794,7 @@
     return sizeof(*said);
 }
 
-netdutils::Status XfrmController::deleteSecurityAssociation(const XfrmSaId& record,
+netdutils::Status XfrmController::deleteSecurityAssociation(const XfrmId& record,
                                                             const XfrmSocket& sock) {
     xfrm_usersa_id said{};
 
diff --git a/server/XfrmController.h b/server/XfrmController.h
index 7e508a0..a881b64 100644
--- a/server/XfrmController.h
+++ b/server/XfrmController.h
@@ -99,7 +99,8 @@
     uint16_t dstPort;
 };
 
-struct XfrmSaId {
+// minimally sufficient structure to match either an SA or a Policy
+struct XfrmId {
     XfrmDirection direction;
     xfrm_address_t dstAddr; // network order
     xfrm_address_t srcAddr;
@@ -108,7 +109,7 @@
     int spi;
 };
 
-struct XfrmSaInfo : XfrmSaId {
+struct XfrmSaInfo : XfrmId {
     XfrmAlgo auth;
     XfrmAlgo crypt;
     XfrmAlgo aead;
@@ -233,10 +234,10 @@
                   "struct xfrm_userspi_info has changed and does not match the kernel struct.");
 #endif
 
-    // helper function for filling in the XfrmSaInfo structure
-    static netdutils::Status fillXfrmSaId(int32_t direction, const std::string& localAddress,
+    // helper function for filling in the XfrmId (and XfrmSaInfo) structure
+    static netdutils::Status fillXfrmId(int32_t direction, const std::string& localAddress,
                                           const std::string& remoteAddress, int32_t spi,
-                                          XfrmSaId* xfrmId);
+                                          int32_t transformId, XfrmId* xfrmId);
 
     // Top level functions for managing a Transport Mode Transform
     static netdutils::Status addTransportModeTransform(const XfrmSaInfo& record);
@@ -258,9 +259,9 @@
     static int fillUserSaInfo(const XfrmSaInfo& record, xfrm_usersa_info* usersa);
 
     // Functions for deleting a Transport Mode SA
-    static netdutils::Status deleteSecurityAssociation(const XfrmSaId& record,
+    static netdutils::Status deleteSecurityAssociation(const XfrmId& record,
                                                        const XfrmSocket& sock);
-    static int fillUserSaId(const XfrmSaId& record, xfrm_usersa_id* said);
+    static int fillUserSaId(const XfrmId& record, xfrm_usersa_id* said);
     static int fillUserTemplate(const XfrmSaInfo& record, xfrm_user_tmpl* tmpl);
     static int fillTransportModeUserSpInfo(const XfrmSaInfo& record, xfrm_userpolicy_info* usersp);