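Mark sockets on creation (socket()) and accept4().

Hook socket() so that, when a process-wide network has been set, new
sockets in address families that FwmarkClient marks are assigned to
that network. Hook accept4() instead of accept(): accept() just calls
accept4(..., 0), so this covers both. connect() and accept4() now
decide whether to mark a socket from its address family (the
sockaddr's sa_family, or SO_DOMAIN for an accepted socket whose
address was not returned) instead of passing the fd and address to
FwmarkClient::shouldSetFwmark().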

Continued from: https://android-review.git.corp.google.com/#/c/95094/

Change-Id: Ib0b8f5d7c5013b91eae6bbc3847852eb355c7714
diff --git a/client/NetdClient.cpp b/client/NetdClient.cpp
index 31c2e4d..c0acdc0 100644
--- a/client/NetdClient.cpp
+++ b/client/NetdClient.cpp
@@ -25,6 +25,26 @@
 
 namespace {
 
+// Process-wide network IDs; NETID_UNSET means that no network has been selected. netIdForProcess
+// is applied to eligible new sockets by netdClientSocket(); netIdForResolv is presumably the
+// network used for DNS resolution (see getNetworkForResolv()) -- confirm.
+// TODO: Convert to C++11 std::atomic<unsigned>.
+volatile sig_atomic_t netIdForProcess = NETID_UNSET;
+volatile sig_atomic_t netIdForResolv = NETID_UNSET;
+
+typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
+typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
+typedef int (*SocketFunctionType)(int, int, int);
+typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
+
+// Original libc implementations, saved by the netdClientInit*() functions before they install the
+// netdClient*() wrappers, which call through these pointers.
+// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
+// it's okay that they are read later at runtime without a lock.
+Accept4FunctionType libcAccept4 = 0;
+ConnectFunctionType libcConnect = 0;
+SocketFunctionType libcSocket = 0;
+
 int closeFdAndRestoreErrno(int fd) {
     int error = errno;
     close(fd);
@@ -32,39 +47,31 @@
     return -1;
 }
 
-typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
-typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*);
-typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
-
-// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
-// it's okay that they are read later at runtime without a lock.
-ConnectFunctionType libcConnect = 0;
-AcceptFunctionType libcAccept = 0;
-
-int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
-    if (FwmarkClient::shouldSetFwmark(sockfd, addr)) {
-        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0};
-        if (!FwmarkClient().send(&command, sizeof(command), sockfd)) {
-            return -1;
-        }
-    }
-    return libcConnect(sockfd, addr, addrlen);
-}
-
-int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) {
-    int acceptedSocket = libcAccept(sockfd, addr, addrlen);
+// Wrapper installed in place of libc's accept4(). After a successful accept, it sends an ON_ACCEPT
+// FwmarkCommand for the new socket if FwmarkClient::shouldSetFwmark() accepts its address family.
+// If the family cannot be determined or the command cannot be sent, the new socket is closed and
+// -1 is returned.
+int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
+    // accept4() truncates the peer address to the caller's buffer, so addr->sa_family is only
+    // filled in when that buffer can hold it. Check this now, because accept4() overwrites
+    // *addrlen with the full address length.
+    bool familyInAddr = addr && addrlen &&
+            *addrlen >= static_cast<socklen_t>(sizeof(addr->sa_family));
+    int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
     if (acceptedSocket == -1) {
         return -1;
     }
-    sockaddr socketAddress;
-    if (!addr) {
-        socklen_t socketAddressLen = sizeof(socketAddress);
-        if (getsockname(acceptedSocket, &socketAddress, &socketAddressLen) == -1) {
+    int family;
+    if (familyInAddr) {
+        family = addr->sa_family;
+    } else {
+        // No usable address buffer: ask the kernel for the socket's domain instead.
+        socklen_t familyLen = sizeof(family);
+        if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
             return closeFdAndRestoreErrno(acceptedSocket);
         }
-        addr = &socketAddress;
     }
-    if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) {
+    if (FwmarkClient::shouldSetFwmark(family)) {
         FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0};
         if (!FwmarkClient().send(&command, sizeof(command), acceptedSocket)) {
             return closeFdAndRestoreErrno(acceptedSocket);
@@ -73,9 +70,38 @@
     return acceptedSocket;
 }
 
-// TODO: Convert to C++11 std::atomic<unsigned>.
-volatile sig_atomic_t netIdForProcess = NETID_UNSET;
-volatile sig_atomic_t netIdForResolv = NETID_UNSET;
+// Wraps libc's connect() (libcConnect). If sockfd is non-negative and FwmarkClient says the
+// destination's address family should be marked, an ON_CONNECT FwmarkCommand is sent for the
+// socket first; if that send fails, -1 is returned without calling connect().
+int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
+    if (sockfd >= 0 && addr && FwmarkClient::shouldSetFwmark(addr->sa_family)) {
+        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0};
+        if (!FwmarkClient().send(&command, sizeof(command), sockfd)) {
+            return -1;
+        }
+    }
+    return libcConnect(sockfd, addr, addrlen);
+}
+
+// Wraps libc's socket() (libcSocket). If a process-wide network has been selected (netIdForProcess
+// is not NETID_UNSET) and FwmarkClient says the domain should be marked, setNetworkForSocket() is
+// called for the new socket; if it fails, the socket is closed and -1 is returned.
+// NOTE(review): the caller then sees whatever errno setNetworkForSocket() left behind -- confirm
+// that it always sets errno on failure.
+int netdClientSocket(int domain, int type, int protocol) {
+    int socketFd = libcSocket(domain, type, protocol);
+    if (socketFd == -1) {
+        return -1;
+    }
+    // Read the volatile global once so the check and the call below see the same value.
+    unsigned netId = netIdForProcess;
+    if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
+        if (!setNetworkForSocket(netId, socketFd)) {
+            return closeFdAndRestoreErrno(socketFd);
+        }
+    }
+    return socketFd;
+}
 
 unsigned getNetworkForResolv(unsigned netId) {
     if (netId != NETID_UNSET) {
@@ -110,6 +127,16 @@
 
 }  // namespace
 
+// Hooks accept4(): remembers the original *function in libcAccept4 and replaces it with
+// netdClientAccept4. Does nothing if function or *function is null.
+// accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
+extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
+    if (function && *function) {
+        libcAccept4 = *function;
+        *function = netdClientAccept4;
+    }
+}
+
 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
     if (function && *function) {
         libcConnect = *function;
@@ -117,10 +142,10 @@
     }
 }
 
-extern "C" void netdClientInitAccept(AcceptFunctionType* function) {
+extern "C" void netdClientInitSocket(SocketFunctionType* function) {
     if (function && *function) {
-        libcAccept = *function;
-        *function = netdClientAccept;
+        libcSocket = *function;
+        *function = netdClientSocket;
     }
 }