Kernel unit tests for XFRM_MIGRATE

This CL adds unit tests for updating addresses and underlying
network of an IPsec SA using CONFIG_XFRM_MIGRATE.

Tested on android-mainline
 - without config set, tests failed
 - with config set, tests ran and passed

Tested on android 5.4
 - without config set, tests were skipped
 - with config set, tests ran and passed

Bug: 169170981
Test: xfrm_tunnel_test.py
Signed-off-by: Yan Yan <evitayan@google.com>
Change-Id: Ieb410dab2ba1c112e2c28a53b8165ebeb0ebb7f2
diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh
index d65769d..6f32f81 100755
--- a/net/test/run_net_test.sh
+++ b/net/test/run_net_test.sh
@@ -47,6 +47,7 @@
 
 # Kernel version specific options
 OPTIONS="$OPTIONS XFRM_INTERFACE"                # Various device kernels
+OPTIONS="$OPTIONS XFRM_MIGRATE"                  # Added in 5.10
 OPTIONS="$OPTIONS CGROUP_BPF"                    # Added in android-4.9
 OPTIONS="$OPTIONS NF_SOCKET_IPV4 NF_SOCKET_IPV6" # Added in 4.9
 OPTIONS="$OPTIONS INET_SCTP_DIAG"                # Added in 4.7
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index c63d2a2..83437bd 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -140,6 +140,11 @@
     "daddr saddr dport dport_mask sport sport_mask "
     "family prefixlen_d prefixlen_s proto ifindex user")
 
+XfrmMigrate = cstruct.Struct(
+    "XfrmMigrate", "=16s16s16s16sBBxxIHH",
+    "old_daddr old_saddr new_daddr new_saddr proto "
+    "mode reqid old_family new_family")
+
 XfrmLifetimeCfg = cstruct.Struct(
     "XfrmLifetimeCfg", "=QQQQQQQQ",
     "soft_byte hard_byte soft_packet hard_packet "
@@ -710,6 +715,52 @@
     for selector in selectors:
       self.DeletePolicyInfo(selector, direction, mark, xfrm_if_id)
 
+  def MigrateTunnel(self, direction, selector, old_saddr, old_daddr,
+                    new_saddr, new_daddr, spi,
+                    encryption, auth_trunc, aead,
+                    encap, new_output_mark, xfrm_if_id):
+    """Update addresses and underlying network of Policies and an SA
+
+    Args:
+      direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
+      selector: An XfrmSelector of the tunnel that needs to be updated.
+        If the passed-in selector is None, it means the tunnel is
+        dual-stack and thus both IPv4 and IPv6 policies will be updated.
+      old_saddr: the old (current) source address of the tunnel
+      old_daddr: the old (current) destination address of the tunnel
+      new_saddr: the new source address the IPsec SA will be migrated to
+      new_daddr: the new destination address the tunnel will be migrated to
+      spi: The SPI for the IPsec SA that encapsulates the tunneled packets
+      encryption: A tuple of an XfrmAlgo and raw key bytes, or None.
+      auth_trunc: A tuple of an XfrmAlgoAuth and raw key bytes, or None.
+      aead: A tuple of an XfrmAlgoAead and raw key bytes, or None.
+      encap: An XfrmEncapTmpl structure, or None.
+      new_output_mark: The mark used to select the new underlying network
+        for packets outbound from xfrm. None means unspecified.
+      xfrm_if_id: The XFRM interface ID
+    """
+
+    if selector is None:
+      selectors = [EmptySelector(AF_INET), EmptySelector(AF_INET6)]
+    else:
+      selectors = [selector]
+
+    nlattrs = []
+    xfrmMigrate = XfrmMigrate((PaddedAddress(old_daddr), PaddedAddress(old_saddr),
+                      PaddedAddress(new_daddr), PaddedAddress(new_saddr),
+                      IPPROTO_ESP, XFRM_MODE_TUNNEL, 0,
+                      net_test.GetAddressFamily(net_test.GetAddressVersion(old_saddr)),
+                      net_test.GetAddressFamily(net_test.GetAddressVersion(new_saddr))))
+    nlattrs.append((XFRMA_MIGRATE, xfrmMigrate))
+
+    for selector in selectors:
+        self.SendXfrmNlRequest(XFRM_MSG_MIGRATE,
+                               XfrmUserpolicyId(sel=selector, dir=direction), nlattrs)
+
+    # UPDSA is called exclusively to update the set_mark=new_output_mark.
+    self.AddSaInfo(new_saddr, new_daddr, spi, XFRM_MODE_TUNNEL, 0, encryption,
+                   auth_trunc, aead, encap, None, new_output_mark, True, xfrm_if_id)
+
 
 if __name__ == "__main__":
   x = Xfrm()
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
index f175c09..7497ea2 100755
--- a/net/test/xfrm_tunnel_test.py
+++ b/net/test/xfrm_tunnel_test.py
@@ -37,9 +37,13 @@
 _LOOPBACK_IFINDEX = 1
 _TEST_XFRM_IFNAME = "ipsec42"
 _TEST_XFRM_IF_ID = 42
+_TEST_SPI = 0x1234
 
 # Does the kernel support xfrmi interfaces?
 def HaveXfrmInterfaces():
+  if net_test.LINUX_VERSION >= (4, 19, 0):
+    return True
+
   try:
     i = iproute.IPRoute()
     i.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID,
@@ -56,13 +60,45 @@
 
 HAVE_XFRM_INTERFACES = HaveXfrmInterfaces()
 
+# Does the kernel support CONFIG_XFRM_MIGRATE?
+def SupportsXfrmMigrate():
+  if net_test.LINUX_VERSION >= (5, 10, 0):
+    return True
+
+  # XFRM_MIGRATE depends on xfrmi interfaces
+  if not HAVE_XFRM_INTERFACES:
+    return False
+
+  try:
+    x = xfrm.Xfrm()
+    wildcard_addr = net_test.GetWildcardAddress(6)
+    selector = xfrm.EmptySelector(AF_INET6)
+
+    # Expect migration to fail with EINVAL because it is trying to migrate a
+    # non-existent SA.
+    x.MigrateTunnel(xfrm.XFRM_POLICY_OUT, selector, wildcard_addr, wildcard_addr,
+                    wildcard_addr, wildcard_addr, _TEST_SPI,
+                    None, None, None, None, None, None)
+    print("Migration succeeded unexpectedly, assuming XFRM_MIGRATE is enabled")
+    return True
+  except IOError as err:
+    if err.errno == ENOPROTOOPT:
+      return False
+    elif err.errno == EINVAL:
+      return True
+    else:
+      print("Unexpected error, assuming XFRM_MIGRATE is enabled:", err.errno)
+      return True
+
+SUPPORTS_XFRM_MIGRATE = SupportsXfrmMigrate()
+
 # Parameters to setup tunnels as special networks
 _TUNNEL_NETID_OFFSET = 0xFC00  # Matches reserved netid range for IpSecService
 _BASE_TUNNEL_NETID = {4: 40, 6: 60}
 _BASE_VTI_OKEY = 2000000100
 _BASE_VTI_IKEY = 2000000200
 
-_TEST_OUT_SPI = 0x1234
+_TEST_OUT_SPI = _TEST_SPI
 _TEST_IN_SPI = _TEST_OUT_SPI
 
 _TEST_OKEY = 2000000100
@@ -132,6 +168,7 @@
   InjectParameterizedTests(XfrmTunnelTest)
   InjectParameterizedTests(XfrmInterfaceTest)
   InjectParameterizedTests(XfrmVtiTest)
+  InjectParameterizedTests(XfrmInterfaceMigrateTest)
 
 
 def InjectParameterizedTests(cls):
@@ -334,6 +371,9 @@
     else:
       auth, crypt = xfrm_base._ALGO_HMAC_SHA1, xfrm_base._ALGO_CBC_AES_256
 
+    self.auth = auth
+    self.crypt = crypt
+
     self._SetupXfrmByType(auth, crypt)
 
   def Rekey(self, outer_family, new_out_sa, new_in_sa):
@@ -448,7 +488,7 @@
 class XfrmInterface(IpSecBaseInterface):
 
   def __init__(self, iface, netid, underlying_netid, ifindex, local, remote,
-               version):
+               version, use_null_crypt=False):
     super(XfrmInterface, self).__init__(iface, netid, underlying_netid, local,
                                         remote, version)
 
@@ -456,7 +496,7 @@
     self.xfrm_if_id = netid
 
     self.SetupInterface()
-    self.SetupXfrm(False)
+    self.SetupXfrm(use_null_crypt)
 
   def SetupInterface(self):
     """Create an XFRM interface."""
@@ -505,9 +545,30 @@
     self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP, None,
                            self.xfrm_if_id)
 
+  def Migrate(self, new_underlying_netid, new_local, new_remote):
+    self.xfrm.MigrateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
+                            new_remote, new_local, self.in_sa.spi,
+                            self.crypt, self.auth, None, None,
+                            new_underlying_netid, self.xfrm_if_id)
+
+    self.xfrm.MigrateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
+                            new_local, new_remote, self.out_sa.spi,
+                            self.crypt, self.auth, None, None,
+                            new_underlying_netid, self.xfrm_if_id)
+
+    self.local = new_local
+    self.remote = new_remote
+    self.underlying_netid = new_underlying_netid
+
 
 class XfrmTunnelBase(xfrm_base.XfrmBaseTest):
 
+  # Subclass that does not allow multiple tunnels (e.g. XfrmInterfaceMigrateTest)
+  # should override this method.
+  @classmethod
+  def allowMultipleTunnels(cls):
+    return True
+
   @classmethod
   def setUpClass(cls):
     xfrm_base.XfrmBaseTest.setUpClass()
@@ -520,6 +581,10 @@
     # IPv6 tunnel
     cls.tunnelsV4 = {}
     cls.tunnelsV6 = {}
+
+    if not cls.allowMultipleTunnels():
+      return
+
     for i, underlying_netid in enumerate(cls.tuns):
       for version in 4, 6:
         netid = _BASE_TUNNEL_NETID[version] + _TUNNEL_NETID_OFFSET + i
@@ -947,6 +1012,79 @@
   def ParamTestXfrmIntfRekey(self, inner_version, outer_version):
     self._TestTunnelRekey(inner_version, outer_version)
 
+@unittest.skipUnless(SUPPORTS_XFRM_MIGRATE, "XFRM migration unsupported")
+class XfrmInterfaceMigrateTest(XfrmTunnelBase):
+  # TODO: b/172497215 There is a kernel issue that XFRM_MIGRATE cannot work correctly
+  # when there are multiple tunnels with the same selectors. Thus before this issue
+  # is fixed, #allowMultipleTunnels must be overridden to avoid setting up multiple
+  # tunnels. This need to be removed after the kernel issue is fixed.
+  @classmethod
+  def allowMultipleTunnels(cls):
+    return False
+
+  def setUpTunnel(self, outer_version, use_null_crypt):
+    underlying_netid = self.RandomNetid()
+    netid = _BASE_TUNNEL_NETID[outer_version] + _TUNNEL_NETID_OFFSET
+    iface = "ipsec%s" % netid
+    ifindex = self.ifindices[underlying_netid]
+
+    local = self.MyAddress(outer_version, underlying_netid)
+    remote = net_test.IPV4_ADDR if outer_version == 4 else net_test.IPV6_ADDR
+
+    tunnel = XfrmInterface(iface, netid, underlying_netid, ifindex,
+                           local, remote, outer_version, use_null_crypt)
+    self._SetInboundMarking(netid, iface, True)
+    self._SetupTunnelNetwork(tunnel, True)
+
+    return tunnel
+
+  def tearDownTunnel(self, tunnel):
+    self._SetInboundMarking(tunnel.netid, tunnel.iface, False)
+    self._SetupTunnelNetwork(tunnel, False)
+    tunnel.Teardown()
+
+  def _TestTunnel(self, inner_version, outer_version, func, use_null_crypt):
+    tunnel = self.setUpTunnel(outer_version, use_null_crypt)
+
+    # Verify functionality before migration
+    local_inner = tunnel.addrs[inner_version]
+    remote_inner = _GetRemoteInnerAddress(inner_version)
+    func(tunnel, inner_version, local_inner, remote_inner)
+
+    # Migrate tunnel
+    # TODO:b/169170981 Add tests that migrate 4 -> 6 and 6 -> 4
+    new_underlying_netid = self.RandomNetid(exclude=tunnel.underlying_netid)
+    new_local = self.MyAddress(outer_version, new_underlying_netid)
+    new_remote = net_test.IPV4_ADDR2 if outer_version == 4 else net_test.IPV6_ADDR2
+
+    tunnel.Migrate(new_underlying_netid, new_local, new_remote)
+
+    # Verify functionality after migration
+    func(tunnel, inner_version, local_inner, remote_inner)
+
+    self.tearDownTunnel(tunnel)
+
+  def ParamTestMigrateXfrmIntfInput(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput, True)
+
+  def ParamTestMigrateXfrmIntfOutput(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
+                     True)
+
+  def ParamTestMigrateXfrmIntfInOutEncrypted(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version, self._CheckTunnelEncryption,
+                     False)
+
+  def ParamTestMigrateXfrmIntfIcmp(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version, self._CheckTunnelIcmp, False)
+
+  def ParamTestMigrateXfrmIntfEncryptionWithIcmp(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version,
+                     self._CheckTunnelEncryptionWithIcmp, False)
+
+  def ParamTestMigrateXfrmIntfRekey(self, inner_version, outer_version):
+    self._TestTunnel(inner_version, outer_version, self._CheckTunnelRekey,
+                     True)
 
 if __name__ == "__main__":
   InjectTests()