Test mark-based routing for outgoing packets.
Change-Id: Ifd696803f22a29bae162ab1d21d7b3552b9b80d3
diff --git a/tests/net_test/mark_test.py b/tests/net_test/mark_test.py
index edfa476..775b441 100755
--- a/tests/net_test/mark_test.py
+++ b/tests/net_test/mark_test.py
@@ -12,6 +12,7 @@
import net_test
+DEBUG = False
IFF_TUN = 1
IFF_TAP = 2
@@ -200,7 +201,8 @@
cmds = self.COMMANDS
if self.AUTOCONF_TABLE_OFFSET < 0:
# Set up routing manually.
- cmds += self.ROUTE_COMMANDS
+ # Don't do cmds += self.ROUTE_COMMANDS as this modifies self.COMMANDS.
+ cmds = self.COMMANDS + self.ROUTE_COMMANDS
if version == 4:
# Deleting addresses also causes routes to be deleted, so watch the
@@ -224,7 +226,7 @@
}).split("\n")
for cmd in cmds:
cmd = cmd.split(" ")
- #print cmd
+ if DEBUG: print " ".join(cmd)
ret = os.spawnvp(os.P_WAIT, cmd[0], cmd)
if ret:
raise ConfigurationError("Setup command failed: %s" % " ".join(cmd))
@@ -317,16 +319,6 @@
self.assertMultiLineEqual(str(expected).encode("hex"),
str(actual).encode("hex"))
- def ExpectPacketOn(self, netid, msg, expected):
- try:
- actual = self.tuns[netid].read(4096)
- except IOError, e:
- raise AssertionError(msg + ": " + str(e))
-
- self.assertTrue(actual)
- if expected:
- self.CheckExpectedPacket(expected, actual, msg)
-
def assertNoPacketsOn(self, netids, msg):
for netid in netids:
try:
@@ -341,6 +333,21 @@
def assertNoPacketsExceptOn(self, netid, msg):
self.assertNoPacketsOn([n for n in self.tuns if n != netid], msg)
+ def ExpectPacketOn(self, netid, msg, expected=None):
+ # Check no packets were sent on any other netid.
+ self.assertNoPacketsExceptOn(netid, msg)
+
+ # Check that a packet was sent on netid.
+ try:
+ actual = self.tuns[netid].read(4096)
+ except IOError, e:
+ raise AssertionError(msg + ": " + str(e))
+ self.assertTrue(actual)
+
+ # If we know what sort of packet we expect, check that here.
+ if expected:
+ self.CheckExpectedPacket(expected, actual, msg)
+
def ReceivePacketOn(self, netid, ip_packet):
routermac = self._RouterMacAddress(netid)
mymac = self._MyMacAddress(netid)
@@ -362,8 +369,60 @@
def _GetRemoteAddress(version):
return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
+ def MarkSocket(self, s, netid):
+ s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
+
+ def GetProtocolFamily(self, version):
+ return {4: AF_INET, 6: AF_INET6}[version]
+
+ def testOutgoingPackets(self):
+ """Checks that socket marking selects the right outgoing interface."""
+
+ def CheckPingPacket(version, netid, packet):
+ s = net_test.PingSocket(self.GetProtocolFamily(version))
+ dstaddr = self._GetRemoteAddress(version)
+ self.MarkSocket(s, netid)
+ s.sendto(packet, (dstaddr, 19321))
+ self.ExpectPacketOn(netid, "IPv%d ping: mark %d" % (version, netid))
+
+ for netid in self.tuns:
+ CheckPingPacket(4, netid, net_test.IPV4_PING)
+ CheckPingPacket(6, netid, net_test.IPV6_PING)
+
+ def CheckTCPSYNPacket(version, netid, dstaddr):
+ s = net_test.TCPSocket(self.GetProtocolFamily(version))
+ self.MarkSocket(s, netid)
+ # Non-blocking TCP connects always return EINPROGRESS.
+ self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
+ self.ExpectPacketOn(netid, "IPv%d TCP connect: mark %d" % (version,
+ netid))
+ s.close()
+
+ for netid in self.tuns:
+ CheckTCPSYNPacket(4, netid, net_test.IPV4_ADDR)
+ CheckTCPSYNPacket(6, netid, net_test.IPV6_ADDR)
+ CheckTCPSYNPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
+
+ def CheckUDPPacket(version, netid, dstaddr):
+ s = net_test.UDPSocket(self.GetProtocolFamily(version))
+ self.MarkSocket(s, netid)
+ s.sendto("hello", (dstaddr, 53))
+ self.ExpectPacketOn(netid, "IPv%d UDP sendto: mark %d" % (version, netid))
+ s.connect((dstaddr, 53))
+ s.send("hello")
+ self.ExpectPacketOn(netid, "IPv%d UDP connect/send: mark %d" % (version,
+ netid))
+ s.close()
+
+ for netid in self.tuns:
+ CheckUDPPacket(4, netid, net_test.IPV4_ADDR)
+ CheckUDPPacket(6, netid, net_test.IPV6_ADDR)
+ CheckUDPPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
+
def CheckReflection(self, version, packet_generator, reply_generator):
- # Test packets addressed to the IP addresses of all our interfaces...
+ """Checks that replies go out on the same interface as the original."""
+
+ # Check packets addressed to the IP addresses of all our interfaces...
for dest_ip_netid in self.tuns:
dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
@@ -391,7 +450,6 @@
# Expect a reply on the interface the original packet came in on.
self.ClearTunQueues()
self.ReceivePacketOn(iif_netid, packet)
- self.assertNoPacketsExceptOn(iif_netid, msg)
self.ExpectPacketOn(iif_netid, msg, reply)
def SYNToClosedPort(self, *args):
@@ -419,6 +477,10 @@
self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
@unittest.skipUnless(False, "skipping: doesn't work yet")
+ def testIPv4SYNACKsReflectMark(self):
+ self.CheckReflection(4, Packets.SYNToOpenPort, Packets.SYNACK)
+
+ @unittest.skipUnless(False, "skipping: doesn't work yet")
def testIPv6SYNACKsReflectMark(self):
self.CheckReflection(6, Packets.SYNToOpenPort, Packets.SYNACK)
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
index da892f0..caf79a0 100755
--- a/tests/net_test/net_test.py
+++ b/tests/net_test/net_test.py
@@ -60,23 +60,39 @@
us = (ms % 1000) * 1000
sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
-# Convenience functions to create ping sockets.
+def SetNonBlocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
s = socket(family, sock_type, protocol)
SetSocketTimeout(s, 1000)
return s
+def PingSocket(family):
+ proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
+ return Socket(family, SOCK_DGRAM, proto)
+
def IPv4PingSocket():
- return Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)
+ return PingSocket(AF_INET)
def IPv6PingSocket():
- return Socket(AF_INET6, SOCK_DGRAM, IPPROTO_ICMPV6)
+ return PingSocket(AF_INET6)
+
+def TCPSocket(family):
+ s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ SetNonBlocking(s.fileno())
+ return s
def IPv4TCPSocket():
- return Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
+ return TCPSocket(AF_INET)
def IPv6TCPSocket():
- return Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)
+ return TCPSocket(AF_INET6)
+
+def UDPSocket(family):
+ return Socket(family, SOCK_DGRAM, IPPROTO_UDP)
def IPv6PacketSocket():
return Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_IPV6))
@@ -86,10 +102,6 @@
s.setsockopt(SOL_IP, IP_HDRINCL, 1)
return s
-def SetNonBlocking(fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
- fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
def GetInterfaceIndex(ifname):
s = IPv4PingSocket()
ifr = struct.pack("16si", ifname, 0)
@@ -97,6 +109,7 @@
return struct.unpack("16si", ifr)[1]
def SetInterfaceHWAddr(ifname, hwaddr):
+ s = IPv4PingSocket()
hwaddr = hwaddr.replace(":", "")
hwaddr = hwaddr.decode("hex")
if len(hwaddr) != 6: