Test mark-based routing for outgoing packets.

Change-Id: Ifd696803f22a29bae162ab1d21d7b3552b9b80d3
diff --git a/tests/net_test/mark_test.py b/tests/net_test/mark_test.py
index edfa476..775b441 100755
--- a/tests/net_test/mark_test.py
+++ b/tests/net_test/mark_test.py
@@ -12,6 +12,7 @@
 
 import net_test
 
+DEBUG = False
 
 IFF_TUN = 1
 IFF_TAP = 2
@@ -200,7 +201,8 @@
         cmds = self.COMMANDS
         if self.AUTOCONF_TABLE_OFFSET < 0:
           # Set up routing manually.
-          cmds += self.ROUTE_COMMANDS
+          # Don't do cmds += self.ROUTE_COMMANDS as this modifies self.COMMANDS.
+          cmds = self.COMMANDS + self.ROUTE_COMMANDS
 
       if version == 4:
         # Deleting addresses also causes routes to be deleted, so watch the
@@ -224,7 +226,7 @@
       }).split("\n")
       for cmd in cmds:
         cmd = cmd.split(" ")
-        #print cmd
+        if DEBUG: print " ".join(cmd)
         ret = os.spawnvp(os.P_WAIT, cmd[0], cmd)
         if ret:
           raise ConfigurationError("Setup command failed: %s" % " ".join(cmd))
@@ -317,16 +319,6 @@
       self.assertMultiLineEqual(str(expected).encode("hex"),
                                 str(actual).encode("hex"))
     
-  def ExpectPacketOn(self, netid, msg, expected):
-    try:
-      actual = self.tuns[netid].read(4096)
-    except IOError, e:
-      raise AssertionError(msg + ": " + str(e))
-      
-    self.assertTrue(actual)
-    if expected:
-      self.CheckExpectedPacket(expected, actual, msg)
-
   def assertNoPacketsOn(self, netids, msg):
     for netid in netids:
       try:
@@ -341,6 +333,21 @@
   def assertNoPacketsExceptOn(self, netid, msg):
     self.assertNoPacketsOn([n for n in self.tuns if n != netid], msg)
 
+  def ExpectPacketOn(self, netid, msg, expected=None):
+    # Check no packets were sent on any other netid.
+    self.assertNoPacketsExceptOn(netid, msg)
+
+    # Check that a packet was sent on netid.
+    try:
+      actual = self.tuns[netid].read(4096)
+    except IOError, e:
+      raise AssertionError(msg + ": " + str(e))
+    self.assertTrue(actual)
+
+    # If we know what sort of packet we expect, check that here.
+    if expected:
+      self.CheckExpectedPacket(expected, actual, msg)
+
   def ReceivePacketOn(self, netid, ip_packet):
     routermac = self._RouterMacAddress(netid)
     mymac = self._MyMacAddress(netid)
@@ -362,8 +369,60 @@
   def _GetRemoteAddress(version):
     return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
 
+  def MarkSocket(self, s, netid):
+    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
+
+  def GetProtocolFamily(self, version):
+    return {4: AF_INET, 6: AF_INET6}[version]
+
+  def testOutgoingPackets(self):
+    """Checks that socket marking selects the right outgoing interface."""
+
+    def CheckPingPacket(version, netid, packet):
+      s = net_test.PingSocket(self.GetProtocolFamily(version))
+      dstaddr = self._GetRemoteAddress(version)
+      self.MarkSocket(s, netid)
+      s.sendto(packet, (dstaddr, 19321))
+      self.ExpectPacketOn(netid, "IPv%d ping: mark %d" % (version, netid))
+
+    for netid in self.tuns:
+      CheckPingPacket(4, netid, net_test.IPV4_PING)
+      CheckPingPacket(6, netid, net_test.IPV6_PING)
+
+    def CheckTCPSYNPacket(version, netid, dstaddr):
+      s = net_test.TCPSocket(self.GetProtocolFamily(version))
+      self.MarkSocket(s, netid)
+      # Non-blocking TCP connects always return EINPROGRESS.
+      self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
+      self.ExpectPacketOn(netid, "IPv%d TCP connect: mark %d" % (version,
+                                                                 netid))
+      s.close()
+
+    for netid in self.tuns:
+      CheckTCPSYNPacket(4, netid, net_test.IPV4_ADDR)
+      CheckTCPSYNPacket(6, netid, net_test.IPV6_ADDR)
+      CheckTCPSYNPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
+
+    def CheckUDPPacket(version, netid, dstaddr):
+      s = net_test.UDPSocket(self.GetProtocolFamily(version))
+      self.MarkSocket(s, netid)
+      s.sendto("hello", (dstaddr, 53))
+      self.ExpectPacketOn(netid, "IPv%d UDP sendto: mark %d" % (version, netid))
+      s.connect((dstaddr, 53))
+      s.send("hello")
+      self.ExpectPacketOn(netid, "IPv%d UDP connect/send: mark %d" % (version,
+                                                                      netid))
+      s.close()
+
+    for netid in self.tuns:
+      CheckUDPPacket(4, netid, net_test.IPV4_ADDR)
+      CheckUDPPacket(6, netid, net_test.IPV6_ADDR)
+      CheckUDPPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
+
   def CheckReflection(self, version, packet_generator, reply_generator):
-    # Test packets addressed to the IP addresses of all our interfaces...
+    """Checks that replies go out on the same interface as the original."""
+
+    # Check packets addressed to the IP addresses of all our interfaces...
     for dest_ip_netid in self.tuns:
       dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
 
@@ -391,7 +450,6 @@
         # Expect a reply on the interface the original packet came in on.
         self.ClearTunQueues()
         self.ReceivePacketOn(iif_netid, packet)
-        self.assertNoPacketsExceptOn(iif_netid, msg)
         self.ExpectPacketOn(iif_netid, msg, reply)
 
   def SYNToClosedPort(self, *args):
@@ -419,6 +477,10 @@
     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
 
   @unittest.skipUnless(False, "skipping: doesn't work yet")
+  def testIPv4SYNACKsReflectMark(self):
+    self.CheckReflection(4, Packets.SYNToOpenPort, Packets.SYNACK)
+
+  @unittest.skipUnless(False, "skipping: doesn't work yet")
   def testIPv6SYNACKsReflectMark(self):
     self.CheckReflection(6, Packets.SYNToOpenPort, Packets.SYNACK)
 
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
index da892f0..caf79a0 100755
--- a/tests/net_test/net_test.py
+++ b/tests/net_test/net_test.py
@@ -60,23 +60,39 @@
   us = (ms % 1000) * 1000
   sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
 
-# Convenience functions to create ping sockets.
+def SetNonBlocking(fd):
+  flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+  fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+# Convenience functions to create sockets.
 def Socket(family, sock_type, protocol):
   s = socket(family, sock_type, protocol)
   SetSocketTimeout(s, 1000)
   return s
 
+def PingSocket(family):
+  proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
+  return Socket(family, SOCK_DGRAM, proto)
+
 def IPv4PingSocket():
-  return Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)
+  return PingSocket(AF_INET)
 
 def IPv6PingSocket():
-  return Socket(AF_INET6, SOCK_DGRAM, IPPROTO_ICMPV6)
+  return PingSocket(AF_INET6)
+
+def TCPSocket(family):
+  s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
+  SetNonBlocking(s.fileno())
+  return s
 
 def IPv4TCPSocket():
-  return Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
+  return TCPSocket(AF_INET)
 
 def IPv6TCPSocket():
-  return Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)
+  return TCPSocket(AF_INET6)
+
+def UDPSocket(family):
+  return Socket(family, SOCK_DGRAM, IPPROTO_UDP)
 
 def IPv6PacketSocket():
   return Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_IPV6))
@@ -86,10 +102,6 @@
   s.setsockopt(SOL_IP, IP_HDRINCL, 1)
   return s
 
-def SetNonBlocking(fd):
-  flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
-  fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
 def GetInterfaceIndex(ifname):
   s = IPv4PingSocket()
   ifr = struct.pack("16si", ifname, 0)
@@ -97,6 +109,7 @@
   return struct.unpack("16si", ifr)[1]
 
 def SetInterfaceHWAddr(ifname, hwaddr):
+  s = IPv4PingSocket()
   hwaddr = hwaddr.replace(":", "")
   hwaddr = hwaddr.decode("hex")
   if len(hwaddr) != 6: