Mark sockets on creation (socket()) and accept4().
Continued from: https://android-review.git.corp.google.com/#/c/95094/
Change-Id: Ib0b8f5d7c5013b91eae6bbc3847852eb355c7714
diff --git a/client/NetdClient.cpp b/client/NetdClient.cpp
index 31c2e4d..c0acdc0 100644
--- a/client/NetdClient.cpp
+++ b/client/NetdClient.cpp
@@ -25,6 +25,21 @@
namespace {
+// Network selections for this process; both start out as NETID_UNSET. netIdForProcess is read by
+// netdClientSocket() to decide whether a newly created socket should be assigned to a network.
+// NOTE: volatile sig_atomic_t only guarantees atomic access with respect to signal handlers; it is
+// not a cross-thread synchronization primitive, hence the TODO below.
+// TODO: Convert to C++11 std::atomic<unsigned>.
+volatile sig_atomic_t netIdForProcess = NETID_UNSET;
+volatile sig_atomic_t netIdForResolv = NETID_UNSET;
+
+// Pointer types matching the libc entry points intercepted by the netdClientInit*() functions
+// (accept4(), connect(), socket()). NetIdForResolvFunctionType is used outside this view —
+// presumably by a resolver netId hook; confirm against its initializer.
+typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
+typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
+typedef int (*SocketFunctionType)(int, int, int);
+typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
+
+// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
+// it's okay that they are read later at runtime without a lock.
+// Each one holds libc's original implementation, saved by the matching netdClientInit*() function
+// before it redirects libc to the netdClient* wrapper; the wrapper then forwards to it.
+Accept4FunctionType libcAccept4 = 0;
+ConnectFunctionType libcConnect = 0;
+SocketFunctionType libcSocket = 0;
+
int closeFdAndRestoreErrno(int fd) {
int error = errno;
close(fd);
@@ -32,39 +47,21 @@
return -1;
}
-typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
-typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*);
-typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
-
-// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
-// it's okay that they are read later at runtime without a lock.
-ConnectFunctionType libcConnect = 0;
-AcceptFunctionType libcAccept = 0;
-
-int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
- if (FwmarkClient::shouldSetFwmark(sockfd, addr)) {
- FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0};
- if (!FwmarkClient().send(&command, sizeof(command), sockfd)) {
- return -1;
- }
- }
- return libcConnect(sockfd, addr, addrlen);
-}
-
-int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) {
- int acceptedSocket = libcAccept(sockfd, addr, addrlen);
+// Wraps libc accept4(). After a successful accept, determines the new socket's address family and,
+// if FwmarkClient::shouldSetFwmark() accepts that family, sends an ON_ACCEPT command for the
+// accepted socket. If the family lookup or the command fails, the accepted socket is closed via
+// closeFdAndRestoreErrno() and -1 is returned.
+int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
+    // accept4() overwrites *addrlen with the peer address's full length (which may exceed the
+    // caller's buffer, in which case the address is truncated), so record the caller's buffer size
+    // now: afterwards it is the only way to tell whether addr->sa_family was actually written.
+    socklen_t suppliedAddrLen = (addr && addrlen) ? *addrlen : 0;
+    int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
if (acceptedSocket == -1) {
return -1;
}
- sockaddr socketAddress;
- if (!addr) {
- socklen_t socketAddressLen = sizeof(socketAddress);
- if (getsockname(acceptedSocket, &socketAddress, &socketAddressLen) == -1) {
+    // Trust addr->sa_family only if the caller's buffer could hold a whole sockaddr (and therefore
+    // its family field, wherever that field sits); otherwise it may be stale or uninitialized
+    // caller memory, so ask the kernel for the accepted socket's domain instead.
+    int family;
+    if (addr && suppliedAddrLen >= static_cast<socklen_t>(sizeof(sockaddr))) {
+        family = addr->sa_family;
+    } else {
+        socklen_t familyLen = sizeof(family);
+        if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
return closeFdAndRestoreErrno(acceptedSocket);
}
- addr = &socketAddress;
}
- if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) {
+    if (FwmarkClient::shouldSetFwmark(family)) {
FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0};
if (!FwmarkClient().send(&command, sizeof(command), acceptedSocket)) {
return closeFdAndRestoreErrno(acceptedSocket);
@@ -73,9 +70,29 @@
return acceptedSocket;
}
-// TODO: Convert to C++11 std::atomic<unsigned>.
-volatile sig_atomic_t netIdForProcess = NETID_UNSET;
-volatile sig_atomic_t netIdForResolv = NETID_UNSET;
+// Wraps libc connect(). For a valid descriptor with a non-null destination whose address family
+// passes FwmarkClient::shouldSetFwmark(), an ON_CONNECT command is first sent for the socket; if
+// that send fails, -1 is returned without connecting. Otherwise the call is forwarded to libc.
+int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
+    bool wantsFwmark = sockfd >= 0 && addr && FwmarkClient::shouldSetFwmark(addr->sa_family);
+    if (wantsFwmark) {
+        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0};
+        bool sent = FwmarkClient().send(&command, sizeof(command), sockfd);
+        if (!sent) {
+            return -1;
+        }
+    }
+    return libcConnect(sockfd, addr, addrlen);
+}
+
+// Wraps libc socket(). After creating the socket, if a process-wide network has been selected
+// (netIdForProcess != NETID_UNSET) and the domain passes FwmarkClient::shouldSetFwmark(), the new
+// socket and that netId are passed to setNetworkForSocket(). If that call reports failure, the
+// socket is closed via closeFdAndRestoreErrno() and -1 is returned.
+int netdClientSocket(int domain, int type, int protocol) {
+    int fd = libcSocket(domain, type, protocol);
+    if (fd == -1) {
+        return fd;
+    }
+    // Snapshot the volatile global once so the check and the use below see the same value.
+    unsigned processNetId = netIdForProcess;
+    bool markable = processNetId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain);
+    if (markable && !setNetworkForSocket(processNetId, fd)) {
+        return closeFdAndRestoreErrno(fd);
+    }
+    return fd;
+}
unsigned getNetworkForResolv(unsigned netId) {
if (netId != NETID_UNSET) {
@@ -110,6 +127,14 @@
} // namespace
+// accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
+// Hook for accept4(): given a pointer to libc's accept4() function pointer, remembers the original
+// in libcAccept4 and redirects the pointer to netdClientAccept4. A null pointer, or a pointer to a
+// null function, is left untouched.
+extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
+    if (!function || !*function) {
+        return;
+    }
+    libcAccept4 = *function;
+    *function = netdClientAccept4;
+}
+
extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
if (function && *function) {
libcConnect = *function;
@@ -117,10 +142,10 @@
}
}
-extern "C" void netdClientInitAccept(AcceptFunctionType* function) {
+// Hook for socket(): given a pointer to libc's socket() function pointer, saves the original in
+// libcSocket and redirects the pointer to netdClientSocket, so callers going through the updated
+// pointer get netdClientSocket()'s network assignment. Null inputs are ignored.
+extern "C" void netdClientInitSocket(SocketFunctionType* function) {
if (function && *function) {
- libcAccept = *function;
- *function = netdClientAccept;
+ libcSocket = *function;
+ *function = netdClientSocket;
}
}