Use unique_fd instead of int for sockets

Also refactor socket closing method.

Bug: 135717624
Test: atest
Change-Id: Iedbaf2521c4453195708114d7892ddc2cbe22211
diff --git a/res_init.cpp b/res_init.cpp
index 16fd98c..049e225 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -94,23 +94,11 @@
 
 void res_init(ResState* statp, const struct android_net_context* _Nonnull netcontext,
               android::net::NetworkDnsEventReported* _Nonnull event) {
-    memset(statp, 0, sizeof *statp);
-
     statp->netid = netcontext->dns_netid;
     statp->uid = netcontext->uid;
     statp->pid = netcontext->pid;
+    statp->nscount = 1;
     statp->id = arc4random_uniform(65536);
-    statp->_mark = netcontext->dns_mark;
-    statp->netcontext_flags = netcontext->flags;
-    statp->event = event;
-
-    statp->ndots = 1;
-    statp->_vcsock = -1;
-
-    for (int ns = 0; ns < MAXNS; ns++) {
-        statp->nssocks[ns] = -1;
-    }
-
     // The following dummy initialization is probably useless because
     // it's overwritten later by resolv_populate_res_for_net().
     // TODO: check if it's safe to remove.
@@ -120,28 +108,13 @@
             .sin.sin_port = htons(NAMESERVER_PORT),
     };
     memcpy(&statp->nsaddrs, &u, sizeof(u));
-    statp->nscount = 1;
-}
 
-/*
- * This routine is for closing the socket if a virtual circuit is used and
- * the program wants to close it.  This provides support for endhostent()
- * which expects to close the socket.
- *
- * This routine is not expected to be user visible.
- */
-void res_nclose(res_state statp) {
-    int ns;
-
-    if (statp->_vcsock >= 0) {
-        (void) close(statp->_vcsock);
-        statp->_vcsock = -1;
-        statp->_flags &= ~RES_F_VC;
+    for (auto& sock : statp->nssocks) {
+        sock.reset();
     }
-    for (ns = 0; ns < MAXNS; ns++) {
-        if (statp->nssocks[ns] != -1) {
-            close(statp->nssocks[ns]);
-            statp->nssocks[ns] = -1;
-        }
-    }
+    statp->ndots = 1;
+    statp->_mark = netcontext->dns_mark;
+    statp->tcp_nssock.reset();
+    statp->event = event;
+    statp->netcontext_flags = netcontext->flags;
 }
diff --git a/res_send.cpp b/res_send.cpp
index a4083b6..332bb78 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -555,7 +555,7 @@
             }
             if (resplen < 0) {
                 _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
-                res_nclose(statp);
+                statp->closeSockets();
                 return -terrno;
             };
 
@@ -565,11 +565,11 @@
             if (cache_status == RESOLV_CACHE_NOTFOUND) {
                 resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
             }
-            res_nclose(statp);
+            statp->closeSockets();
             return (resplen);
         }  // for each ns
     }  // for each retry
-    res_nclose(statp);
+    statp->closeSockets();
     terrno = useTcp ? terrno : gotsomewhere ? ETIMEDOUT : ECONNREFUSED;
     // TODO: Remove errno once callers stop using it
     errno = useTcp ? terrno
@@ -631,25 +631,24 @@
     struct timespec now = evNowTime();
 
     /* Are we still talking to whom we want to talk to? */
-    if (statp->_vcsock >= 0 && (statp->_flags & RES_F_VC) != 0) {
+    if (statp->tcp_nssock >= 0 && (statp->_flags & RES_F_VC) != 0) {
         struct sockaddr_storage peer;
         socklen_t size = sizeof peer;
         unsigned old_mark;
         socklen_t mark_size = sizeof(old_mark);
-        if (getpeername(statp->_vcsock, (struct sockaddr*) (void*) &peer, &size) < 0 ||
-            !sock_eq((struct sockaddr*) (void*) &peer, nsap) ||
-            getsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &old_mark, &mark_size) < 0 ||
+        if (getpeername(statp->tcp_nssock, (struct sockaddr*)(void*)&peer, &size) < 0 ||
+            !sock_eq((struct sockaddr*)(void*)&peer, nsap) ||
+            getsockopt(statp->tcp_nssock, SOL_SOCKET, SO_MARK, &old_mark, &mark_size) < 0 ||
             old_mark != statp->_mark) {
-            res_nclose(statp);
-            statp->_flags &= ~RES_F_VC;
+            statp->closeSockets();
         }
     }
 
-    if (statp->_vcsock < 0 || (statp->_flags & RES_F_VC) == 0) {
-        if (statp->_vcsock >= 0) res_nclose(statp);
+    if (statp->tcp_nssock < 0 || (statp->_flags & RES_F_VC) == 0) {
+        if (statp->tcp_nssock >= 0) statp->closeSockets();
 
-        statp->_vcsock = socket(nsap->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0);
-        if (statp->_vcsock < 0) {
+        statp->tcp_nssock.reset(socket(nsap->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0));
+        if (statp->tcp_nssock < 0) {
             switch (errno) {
                 case EPROTONOSUPPORT:
                 case EPFNOSUPPORT:
@@ -662,9 +661,9 @@
                     return -1;
             }
         }
-        resolv_tag_socket(statp->_vcsock, statp->uid, statp->pid);
+        resolv_tag_socket(statp->tcp_nssock, statp->uid, statp->pid);
         if (statp->_mark != MARK_UNSET) {
-            if (setsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &statp->_mark,
+            if (setsockopt(statp->tcp_nssock, SOL_SOCKET, SO_MARK, &statp->_mark,
                            sizeof(statp->_mark)) < 0) {
                 *terrno = errno;
                 PLOG(DEBUG) << __func__ << ": setsockopt: ";
@@ -672,17 +671,17 @@
             }
         }
         errno = 0;
-        if (random_bind(statp->_vcsock, nsap->sa_family) < 0) {
+        if (random_bind(statp->tcp_nssock, nsap->sa_family) < 0) {
             *terrno = errno;
             dump_error("bind/vc", nsap, nsaplen);
-            res_nclose(statp);
+            statp->closeSockets();
             return (0);
         }
-        if (connect_with_timeout(statp->_vcsock, nsap, (socklen_t) nsaplen,
+        if (connect_with_timeout(statp->tcp_nssock, nsap, (socklen_t)nsaplen,
                                  get_timeout(statp, params, ns)) < 0) {
             *terrno = errno;
             dump_error("connect/vc", nsap, nsaplen);
-            res_nclose(statp);
+            statp->closeSockets();
             /*
              * The way connect_with_timeout() is implemented prevents us from reliably
              * determining whether this was really a timeout or e.g. ECONNREFUSED. Since
@@ -705,10 +704,10 @@
             {.iov_base = &len, .iov_len = INT16SZ},
             {.iov_base = const_cast<uint8_t*>(buf), .iov_len = static_cast<size_t>(buflen)},
     };
-    if (writev(statp->_vcsock, iov, 2) != (INT16SZ + buflen)) {
+    if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + buflen)) {
         *terrno = errno;
         PLOG(DEBUG) << __func__ << ": write failed: ";
-        res_nclose(statp);
+        statp->closeSockets();
         return (0);
     }
     /*
@@ -717,14 +716,14 @@
 read_len:
     cp = ans;
     len = INT16SZ;
-    while ((n = read(statp->_vcsock, (char*) cp, (size_t) len)) > 0) {
+    while ((n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
         cp += n;
         if ((len -= n) == 0) break;
     }
     if (n <= 0) {
         *terrno = errno;
         PLOG(DEBUG) << __func__ << ": read failed: ";
-        res_nclose(statp);
+        statp->closeSockets();
         /*
          * A long running process might get its TCP
          * connection reset if the remote server was
@@ -736,10 +735,8 @@
          */
         if (*terrno == ECONNRESET && !connreset) {
             connreset = 1;
-            res_nclose(statp);
             goto same_ns;
         }
-        res_nclose(statp);
         return (0);
     }
     uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans));
@@ -755,18 +752,18 @@
          */
         LOG(DEBUG) << __func__ << ": undersized: " << len;
         *terrno = EMSGSIZE;
-        res_nclose(statp);
+        statp->closeSockets();
         return (0);
     }
     cp = ans;
-    while (len != 0 && (n = read(statp->_vcsock, (char*) cp, (size_t) len)) > 0) {
+    while (len != 0 && (n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
         cp += n;
         len -= n;
     }
     if (n <= 0) {
         *terrno = errno;
         PLOG(DEBUG) << __func__ << ": read(vc): ";
-        res_nclose(statp);
+        statp->closeSockets();
         return (0);
     }
 
@@ -779,7 +776,7 @@
         while (len != 0) {
             char junk[PACKETSZ];
 
-            n = read(statp->_vcsock, junk, (len > sizeof junk) ? sizeof junk : len);
+            n = read(statp->tcp_nssock, junk, (len > sizeof junk) ? sizeof junk : len);
             if (n > 0)
                 len -= n;
             else
@@ -907,7 +904,7 @@
     const int nsaplen = sockaddrSize(nsap);
 
     if (statp->nssocks[ns] == -1) {
-        statp->nssocks[ns] = socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0);
+        statp->nssocks[ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));
         if (statp->nssocks[ns] < 0) {
             switch (errno) {
                 case EPROTONOSUPPORT:
@@ -926,7 +923,7 @@
         if (statp->_mark != MARK_UNSET) {
             if (setsockopt(statp->nssocks[ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
                            sizeof(statp->_mark)) < 0) {
-                res_nclose(statp);
+                statp->closeSockets();
                 return -1;
             }
         }
@@ -936,19 +933,19 @@
         // a nameserver without timing out.
         if (random_bind(statp->nssocks[ns], nsap->sa_family) < 0) {
             dump_error("bind(dg)", nsap, nsaplen);
-            res_nclose(statp);
+            statp->closeSockets();
             return (0);
         }
         if (connect(statp->nssocks[ns], nsap, (socklen_t)nsaplen) < 0) {
             dump_error("connect(dg)", nsap, nsaplen);
-            res_nclose(statp);
+            statp->closeSockets();
             return (0);
         }
         LOG(DEBUG) << __func__ << ": new DG socket";
     }
     if (send(statp->nssocks[ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
         PLOG(DEBUG) << __func__ << ": send: ";
-        res_nclose(statp);
+        statp->closeSockets();
         return 0;
     }
 
@@ -966,7 +963,7 @@
         }
         if (n < 0) {
             PLOG(DEBUG) << __func__ << ": poll: ";
-            res_nclose(statp);
+            statp->closeSockets();
             return 0;
         }
 
@@ -977,7 +974,7 @@
                                (sockaddr*)(void*)&from, &fromlen);
         if (resplen <= 0) {
             PLOG(DEBUG) << __func__ << ": recvfrom: ";
-            res_nclose(statp);
+            statp->closeSockets();
             return 0;
         }
         *gotsomewhere = 1;
@@ -985,7 +982,7 @@
             // Undersized message.
             LOG(DEBUG) << __func__ << ": undersized: " << resplen;
             *terrno = EMSGSIZE;
-            res_nclose(statp);
+            statp->closeSockets();
             return 0;
         }
 
@@ -1003,7 +1000,7 @@
             res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
             // record the error
             statp->_flags |= RES_F_EDNS0ERR;
-            res_nclose(statp);
+            statp->closeSockets();
             return 0;
         }
 
@@ -1012,7 +1009,7 @@
         if (anhp->rcode == SERVFAIL || anhp->rcode == NOTIMP || anhp->rcode == REFUSED) {
             LOG(DEBUG) << __func__ << ": server rejected query:";
             res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
-            res_nclose(statp);
+            statp->closeSockets();
             *rcode = anhp->rcode;
             return 0;
         }
@@ -1021,7 +1018,7 @@
             // use TCP with same server.
             LOG(DEBUG) << __func__ << ": truncated answer";
             *v_circuit = 1;
-            res_nclose(statp);
+            statp->closeSockets();
             return 1;
         }
         // All is well, or the error is fatal. Signal that the
diff --git a/resolv_private.h b/resolv_private.h
index a062c1a..e217087 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -49,6 +49,7 @@
 #pragma once
 
 #include <android-base/logging.h>
+#include <android-base/unique_fd.h>
 #include <net/if.h>
 #include <time.h>
 #include <string>
@@ -73,6 +74,10 @@
 #define RES_TIMEOUT 5000 /* min. milliseconds between retries */
 #define RES_DFLRETRY 2    /* Default #/tries. */
 
+// Flags for res_state->_flags
+#define RES_F_VC 0x00000001        // socket is TCP
+#define RES_F_EDNS0ERR 0x00000004  // EDNS0 caused errors
+
 // Holds either a sockaddr_in or a sockaddr_in6.
 union sockaddr_union {
     struct sockaddr sa;
@@ -82,21 +87,31 @@
 constexpr int MAXPACKET = 8 * 1024;
 
 struct ResState {
-    unsigned netid;                           // NetId: cache key and socket mark
-    uid_t uid;                                // uid of the app that sent the DNS lookup
-    pid_t pid;                                // pid of the app that sent the DNS lookup
-    int nscount;                              // number of name srvers
-    uint16_t id;                              // current message id
-    std::vector<std::string> search_domains;  // domains to search
+    void closeSockets() {
+        tcp_nssock.reset();
+        _flags &= ~RES_F_VC;
+
+        for (auto& sock : nssocks) {
+            sock.reset();
+        }
+    }
+    // clang-format off
+    unsigned netid;                             // NetId: cache key and socket mark
+    uid_t uid;                                  // uid of the app that sent the DNS lookup
+    pid_t pid;                                  // pid of the app that sent the DNS lookup
+    int nscount;                                // number of name srvers
+    uint16_t id;                                // current message id
+    std::vector<std::string> search_domains{};  // domains to search
     sockaddr_union nsaddrs[MAXNS];
-    int nssocks[MAXNS];                       // UDP sockets to nameservers
-    unsigned ndots : 4;                       // threshold for initial abs. query
-    unsigned _mark;                           // If non-0 SET_MARK to _mark on all request sockets
-    int _vcsock;                              // TCP socket (but why not one per nameserver?)
-    uint32_t _flags;                          // See RES_F_* defines below
+    android::base::unique_fd nssocks[MAXNS];    // UDP sockets to nameservers
+    unsigned ndots : 4;                         // threshold for initial abs. query
+    unsigned _mark;                             // If non-0 SET_MARK to _mark on all request sockets
+    android::base::unique_fd tcp_nssock;        // TCP socket (but why not one per nameserver?)
+    uint32_t _flags = 0;                        // See RES_F_* defines below
     android::net::NetworkDnsEventReported* event;
     uint32_t netcontext_flags;
-    int tc_mode;
+    int tc_mode = 0;
+    // clang-format on
 };
 
 // TODO: remove these legacy aliases
@@ -121,10 +136,6 @@
 
 /* End of stats related definitions */
 
-// Flags for res_state->_flags
-#define RES_F_VC 0x00000001        // socket is TCP
-#define RES_F_EDNS0ERR 0x00000004  // EDNS0 caused errors
-
 /*
  * Error code extending h_errno codes defined in bionic/libc/include/netdb.h.
  *
@@ -146,7 +157,6 @@
 int res_nmkquery(int op, const char* qname, int cl, int type, const uint8_t* data, int datalen,
                  uint8_t* buf, int buflen, int netcontext_flags);
 int res_nsend(res_state, const uint8_t*, int, uint8_t*, int, int*, uint32_t);
-void res_nclose(res_state);
 int res_nopt(res_state, int, uint8_t*, int, int);
 
 int getaddrinfo_numeric(const char* hostname, const char* servname, addrinfo hints,