Use unique_fd instead of int for sockets
Also refactor socket closing method.
Bug: 135717624
Test: atest
Change-Id: Iedbaf2521c4453195708114d7892ddc2cbe22211
diff --git a/res_init.cpp b/res_init.cpp
index 16fd98c..049e225 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -94,23 +94,11 @@
void res_init(ResState* statp, const struct android_net_context* _Nonnull netcontext,
android::net::NetworkDnsEventReported* _Nonnull event) {
- memset(statp, 0, sizeof *statp);
-
statp->netid = netcontext->dns_netid;
statp->uid = netcontext->uid;
statp->pid = netcontext->pid;
+ statp->nscount = 1;
statp->id = arc4random_uniform(65536);
- statp->_mark = netcontext->dns_mark;
- statp->netcontext_flags = netcontext->flags;
- statp->event = event;
-
- statp->ndots = 1;
- statp->_vcsock = -1;
-
- for (int ns = 0; ns < MAXNS; ns++) {
- statp->nssocks[ns] = -1;
- }
-
// The following dummy initialization is probably useless because
// it's overwritten later by resolv_populate_res_for_net().
// TODO: check if it's safe to remove.
@@ -120,28 +108,13 @@
.sin.sin_port = htons(NAMESERVER_PORT),
};
memcpy(&statp->nsaddrs, &u, sizeof(u));
- statp->nscount = 1;
-}
-/*
- * This routine is for closing the socket if a virtual circuit is used and
- * the program wants to close it. This provides support for endhostent()
- * which expects to close the socket.
- *
- * This routine is not expected to be user visible.
- */
-void res_nclose(res_state statp) {
- int ns;
-
- if (statp->_vcsock >= 0) {
- (void) close(statp->_vcsock);
- statp->_vcsock = -1;
- statp->_flags &= ~RES_F_VC;
+ for (auto& sock : statp->nssocks) {
+ sock.reset();
}
- for (ns = 0; ns < MAXNS; ns++) {
- if (statp->nssocks[ns] != -1) {
- close(statp->nssocks[ns]);
- statp->nssocks[ns] = -1;
- }
- }
+ statp->ndots = 1;
+ statp->_mark = netcontext->dns_mark;
+ statp->tcp_nssock.reset();
+ statp->event = event;
+ statp->netcontext_flags = netcontext->flags;
}
diff --git a/res_send.cpp b/res_send.cpp
index a4083b6..332bb78 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -555,7 +555,7 @@
}
if (resplen < 0) {
_resolv_cache_query_failed(statp->netid, buf, buflen, flags);
- res_nclose(statp);
+ statp->closeSockets();
return -terrno;
};
@@ -565,11 +565,11 @@
if (cache_status == RESOLV_CACHE_NOTFOUND) {
resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
}
- res_nclose(statp);
+ statp->closeSockets();
return (resplen);
} // for each ns
} // for each retry
- res_nclose(statp);
+ statp->closeSockets();
terrno = useTcp ? terrno : gotsomewhere ? ETIMEDOUT : ECONNREFUSED;
// TODO: Remove errno once callers stop using it
errno = useTcp ? terrno
@@ -631,25 +631,24 @@
struct timespec now = evNowTime();
/* Are we still talking to whom we want to talk to? */
- if (statp->_vcsock >= 0 && (statp->_flags & RES_F_VC) != 0) {
+ if (statp->tcp_nssock >= 0 && (statp->_flags & RES_F_VC) != 0) {
struct sockaddr_storage peer;
socklen_t size = sizeof peer;
unsigned old_mark;
socklen_t mark_size = sizeof(old_mark);
- if (getpeername(statp->_vcsock, (struct sockaddr*) (void*) &peer, &size) < 0 ||
- !sock_eq((struct sockaddr*) (void*) &peer, nsap) ||
- getsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &old_mark, &mark_size) < 0 ||
+ if (getpeername(statp->tcp_nssock, (struct sockaddr*)(void*)&peer, &size) < 0 ||
+ !sock_eq((struct sockaddr*)(void*)&peer, nsap) ||
+ getsockopt(statp->tcp_nssock, SOL_SOCKET, SO_MARK, &old_mark, &mark_size) < 0 ||
old_mark != statp->_mark) {
- res_nclose(statp);
- statp->_flags &= ~RES_F_VC;
+ statp->closeSockets();
}
}
- if (statp->_vcsock < 0 || (statp->_flags & RES_F_VC) == 0) {
- if (statp->_vcsock >= 0) res_nclose(statp);
+ if (statp->tcp_nssock < 0 || (statp->_flags & RES_F_VC) == 0) {
+ if (statp->tcp_nssock >= 0) statp->closeSockets();
- statp->_vcsock = socket(nsap->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0);
- if (statp->_vcsock < 0) {
+ statp->tcp_nssock.reset(socket(nsap->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0));
+ if (statp->tcp_nssock < 0) {
switch (errno) {
case EPROTONOSUPPORT:
case EPFNOSUPPORT:
@@ -662,9 +661,9 @@
return -1;
}
}
- resolv_tag_socket(statp->_vcsock, statp->uid, statp->pid);
+ resolv_tag_socket(statp->tcp_nssock, statp->uid, statp->pid);
if (statp->_mark != MARK_UNSET) {
- if (setsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &statp->_mark,
+ if (setsockopt(statp->tcp_nssock, SOL_SOCKET, SO_MARK, &statp->_mark,
sizeof(statp->_mark)) < 0) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": setsockopt: ";
@@ -672,17 +671,17 @@
}
}
errno = 0;
- if (random_bind(statp->_vcsock, nsap->sa_family) < 0) {
+ if (random_bind(statp->tcp_nssock, nsap->sa_family) < 0) {
*terrno = errno;
dump_error("bind/vc", nsap, nsaplen);
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
- if (connect_with_timeout(statp->_vcsock, nsap, (socklen_t) nsaplen,
+ if (connect_with_timeout(statp->tcp_nssock, nsap, (socklen_t)nsaplen,
get_timeout(statp, params, ns)) < 0) {
*terrno = errno;
dump_error("connect/vc", nsap, nsaplen);
- res_nclose(statp);
+ statp->closeSockets();
/*
* The way connect_with_timeout() is implemented prevents us from reliably
* determining whether this was really a timeout or e.g. ECONNREFUSED. Since
@@ -705,10 +704,10 @@
{.iov_base = &len, .iov_len = INT16SZ},
{.iov_base = const_cast<uint8_t*>(buf), .iov_len = static_cast<size_t>(buflen)},
};
- if (writev(statp->_vcsock, iov, 2) != (INT16SZ + buflen)) {
+ if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + buflen)) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": write failed: ";
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
/*
@@ -717,14 +716,14 @@
read_len:
cp = ans;
len = INT16SZ;
- while ((n = read(statp->_vcsock, (char*) cp, (size_t) len)) > 0) {
+ while ((n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
cp += n;
if ((len -= n) == 0) break;
}
if (n <= 0) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": read failed: ";
- res_nclose(statp);
+ statp->closeSockets();
/*
* A long running process might get its TCP
* connection reset if the remote server was
@@ -736,10 +735,8 @@
*/
if (*terrno == ECONNRESET && !connreset) {
connreset = 1;
- res_nclose(statp);
goto same_ns;
}
- res_nclose(statp);
return (0);
}
uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans));
@@ -755,18 +752,18 @@
*/
LOG(DEBUG) << __func__ << ": undersized: " << len;
*terrno = EMSGSIZE;
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
cp = ans;
- while (len != 0 && (n = read(statp->_vcsock, (char*) cp, (size_t) len)) > 0) {
+ while (len != 0 && (n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
cp += n;
len -= n;
}
if (n <= 0) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": read(vc): ";
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
@@ -779,7 +776,7 @@
while (len != 0) {
char junk[PACKETSZ];
- n = read(statp->_vcsock, junk, (len > sizeof junk) ? sizeof junk : len);
+ n = read(statp->tcp_nssock, junk, (len > sizeof junk) ? sizeof junk : len);
if (n > 0)
len -= n;
else
@@ -907,7 +904,7 @@
const int nsaplen = sockaddrSize(nsap);
if (statp->nssocks[ns] == -1) {
- statp->nssocks[ns] = socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0);
+ statp->nssocks[ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));
if (statp->nssocks[ns] < 0) {
switch (errno) {
case EPROTONOSUPPORT:
@@ -926,7 +923,7 @@
if (statp->_mark != MARK_UNSET) {
if (setsockopt(statp->nssocks[ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
sizeof(statp->_mark)) < 0) {
- res_nclose(statp);
+ statp->closeSockets();
return -1;
}
}
@@ -936,19 +933,19 @@
// a nameserver without timing out.
if (random_bind(statp->nssocks[ns], nsap->sa_family) < 0) {
dump_error("bind(dg)", nsap, nsaplen);
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
if (connect(statp->nssocks[ns], nsap, (socklen_t)nsaplen) < 0) {
dump_error("connect(dg)", nsap, nsaplen);
- res_nclose(statp);
+ statp->closeSockets();
return (0);
}
LOG(DEBUG) << __func__ << ": new DG socket";
}
if (send(statp->nssocks[ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
PLOG(DEBUG) << __func__ << ": send: ";
- res_nclose(statp);
+ statp->closeSockets();
return 0;
}
@@ -966,7 +963,7 @@
}
if (n < 0) {
PLOG(DEBUG) << __func__ << ": poll: ";
- res_nclose(statp);
+ statp->closeSockets();
return 0;
}
@@ -977,7 +974,7 @@
(sockaddr*)(void*)&from, &fromlen);
if (resplen <= 0) {
PLOG(DEBUG) << __func__ << ": recvfrom: ";
- res_nclose(statp);
+ statp->closeSockets();
return 0;
}
*gotsomewhere = 1;
@@ -985,7 +982,7 @@
// Undersized message.
LOG(DEBUG) << __func__ << ": undersized: " << resplen;
*terrno = EMSGSIZE;
- res_nclose(statp);
+ statp->closeSockets();
return 0;
}
@@ -1003,7 +1000,7 @@
res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
// record the error
statp->_flags |= RES_F_EDNS0ERR;
- res_nclose(statp);
+ statp->closeSockets();
return 0;
}
@@ -1012,7 +1009,7 @@
if (anhp->rcode == SERVFAIL || anhp->rcode == NOTIMP || anhp->rcode == REFUSED) {
LOG(DEBUG) << __func__ << ": server rejected query:";
res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
- res_nclose(statp);
+ statp->closeSockets();
*rcode = anhp->rcode;
return 0;
}
@@ -1021,7 +1018,7 @@
// use TCP with same server.
LOG(DEBUG) << __func__ << ": truncated answer";
*v_circuit = 1;
- res_nclose(statp);
+ statp->closeSockets();
return 1;
}
// All is well, or the error is fatal. Signal that the
diff --git a/resolv_private.h b/resolv_private.h
index a062c1a..e217087 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -49,6 +49,7 @@
#pragma once
#include <android-base/logging.h>
+#include <android-base/unique_fd.h>
#include <net/if.h>
#include <time.h>
#include <string>
@@ -73,6 +74,10 @@
#define RES_TIMEOUT 5000 /* min. milliseconds between retries */
#define RES_DFLRETRY 2 /* Default #/tries. */
+// Flags for res_state->_flags
+#define RES_F_VC 0x00000001 // socket is TCP
+#define RES_F_EDNS0ERR 0x00000004 // EDNS0 caused errors
+
// Holds either a sockaddr_in or a sockaddr_in6.
union sockaddr_union {
struct sockaddr sa;
@@ -82,21 +87,31 @@
constexpr int MAXPACKET = 8 * 1024;
struct ResState {
- unsigned netid; // NetId: cache key and socket mark
- uid_t uid; // uid of the app that sent the DNS lookup
- pid_t pid; // pid of the app that sent the DNS lookup
- int nscount; // number of name srvers
- uint16_t id; // current message id
- std::vector<std::string> search_domains; // domains to search
+ void closeSockets() {
+ tcp_nssock.reset();
+ _flags &= ~RES_F_VC;
+
+ for (auto& sock : nssocks) {
+ sock.reset();
+ }
+ }
+ // clang-format off
+ unsigned netid; // NetId: cache key and socket mark
+ uid_t uid; // uid of the app that sent the DNS lookup
+ pid_t pid; // pid of the app that sent the DNS lookup
+ int nscount; // number of name srvers
+ uint16_t id; // current message id
+ std::vector<std::string> search_domains{}; // domains to search
sockaddr_union nsaddrs[MAXNS];
- int nssocks[MAXNS]; // UDP sockets to nameservers
- unsigned ndots : 4; // threshold for initial abs. query
- unsigned _mark; // If non-0 SET_MARK to _mark on all request sockets
- int _vcsock; // TCP socket (but why not one per nameserver?)
- uint32_t _flags; // See RES_F_* defines below
+ android::base::unique_fd nssocks[MAXNS]; // UDP sockets to nameservers
+ unsigned ndots : 4; // threshold for initial abs. query
+ unsigned _mark; // If non-0 SET_MARK to _mark on all request sockets
+ android::base::unique_fd tcp_nssock; // TCP socket (but why not one per nameserver?)
+ uint32_t _flags = 0; // See RES_F_* defines below
android::net::NetworkDnsEventReported* event;
uint32_t netcontext_flags;
- int tc_mode;
+ int tc_mode = 0;
+ // clang-format on
};
// TODO: remove these legacy aliases
@@ -121,10 +136,6 @@
/* End of stats related definitions */
-// Flags for res_state->_flags
-#define RES_F_VC 0x00000001 // socket is TCP
-#define RES_F_EDNS0ERR 0x00000004 // EDNS0 caused errors
-
/*
* Error code extending h_errno codes defined in bionic/libc/include/netdb.h.
*
@@ -146,7 +157,6 @@
int res_nmkquery(int op, const char* qname, int cl, int type, const uint8_t* data, int datalen,
uint8_t* buf, int buflen, int netcontext_flags);
int res_nsend(res_state, const uint8_t*, int, uint8_t*, int, int*, uint32_t);
-void res_nclose(res_state);
int res_nopt(res_state, int, uint8_t*, int, int);
int getaddrinfo_numeric(const char* hostname, const char* servname, addrinfo hints,