Merge "Replace addrinfo with IPSockAddr to store dns addresses"
diff --git a/Android.bp b/Android.bp
index 625e7c8..b646def 100644
--- a/Android.bp
+++ b/Android.bp
@@ -51,6 +51,7 @@
         "res_query.cpp",
         "res_send.cpp",
         "res_stats.cpp",
+        "util.cpp",
         "Dns64Configuration.cpp",
         "DnsProxyListener.cpp",
         "DnsResolver.cpp",
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index e77b8b3..6e2376b 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -1697,7 +1697,7 @@
          * the domain stuff is tried.  Will have a better
          * fix after thread pools are used.
          */
-        _resolv_populate_res_for_net(res);
+        resolv_populate_res_for_net(res);
 
         for (const auto& domain : res->search_domains) {
             ret = res_querydomainN(name, domain.c_str(), target, res, herrno);
diff --git a/res_cache.cpp b/res_cache.cpp
index 25ccea8..de11363 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -61,10 +61,14 @@
 #include "DnsStats.h"
 #include "res_debug.h"
 #include "resolv_private.h"
+#include "util.h"
 
 using android::base::StringAppendF;
 using android::net::DnsQueryEvent;
 using android::net::DnsStats;
+using android::net::PROTO_DOT;
+using android::net::PROTO_TCP;
+using android::net::PROTO_UDP;
 using android::netdutils::DumpWriter;
 using android::netdutils::IPSockAddr;
 
@@ -936,7 +940,7 @@
     struct resolv_cache_info* next;
     int nscount;
     std::vector<std::string> nameservers;
-    struct addrinfo* nsaddrinfo[MAXNS];  // TODO: Use struct sockaddr_storage.
+    std::vector<IPSockAddr> nameserverSockAddrs;
     int revision_id;  // # times the nameservers have been replaced
     res_params params;
     struct res_stats nsstats[MAXNS];
@@ -1563,6 +1567,21 @@
     return res;
 }
 
+bool isValidServer(const std::string& server) {  // True iff getaddrinfo_numeric() accepts it.
+    const addrinfo hints = {
+            .ai_family = AF_UNSPEC,  // Do not restrict the address family.
+            .ai_socktype = SOCK_DGRAM,
+    };
+    addrinfo* result = nullptr;
+    if (int err = getaddrinfo_numeric(server.c_str(), "53", hints, &result); err != 0) {
+        LOG(WARNING) << __func__ << ": getaddrinfo_numeric(" << server
+                     << ") = " << gai_strerror(err);
+        return false;
+    }
+    freeaddrinfo(result);  // Only the return code is needed; release the result list.
+    return true;
+}
+
 }  // namespace
 
 int resolv_set_nameservers(unsigned netid, const std::vector<std::string>& servers,
@@ -1574,24 +1593,11 @@
 
     // Parse the addresses before actually locking or changing any state, in case there is an error.
     // As a side effect this also reduces the time the lock is kept.
-    // TODO: find a better way to replace addrinfo*, something like std::vector<SafeAddrinfo>
-    addrinfo* nsaddrinfo[MAXNS];
-    for (int i = 0; i < numservers; i++) {
-        // The addrinfo structures allocated here are freed in free_nameservers_locked().
-        const addrinfo hints = {
-                .ai_flags = AI_NUMERICHOST,
-                .ai_family = AF_UNSPEC,
-                .ai_socktype = SOCK_DGRAM,
-        };
-        const int rt = getaddrinfo_numeric(nameservers[i].c_str(), "53", hints, &nsaddrinfo[i]);
-        if (rt != 0) {
-            for (int j = 0; j < i; j++) {
-                freeaddrinfo(nsaddrinfo[j]);
-            }
-            LOG(INFO) << __func__ << ": getaddrinfo_numeric(" << nameservers[i]
-                      << ") = " << gai_strerror(rt);
-            return -EINVAL;
-        }
+    std::vector<IPSockAddr> ipSockAddrs;
+    ipSockAddrs.reserve(nameservers.size());
+    for (const auto& server : nameservers) {
+        if (!isValidServer(server)) return -EINVAL;
+        ipSockAddrs.push_back(IPSockAddr::toIPSockAddr(server, 53));
     }
 
     std::lock_guard guard(cache_mutex);
@@ -1607,10 +1613,10 @@
         free_nameservers_locked(cache_info);
         cache_info->nameservers = std::move(nameservers);
         for (int i = 0; i < numservers; i++) {
-            cache_info->nsaddrinfo[i] = nsaddrinfo[i];
             LOG(INFO) << __func__ << ": netid = " << netid
                       << ", addr = " << cache_info->nameservers[i];
         }
+        cache_info->nameserverSockAddrs = std::move(ipSockAddrs);
         cache_info->nscount = numservers;
     } else {
         if (cache_info->params.max_samples != old_max_samples) {
@@ -1621,23 +1627,15 @@
             // under which servers are considered usable.
             res_cache_clear_stats_locked(cache_info);
         }
-        for (int j = 0; j < numservers; j++) {
-            freeaddrinfo(nsaddrinfo[j]);
-        }
     }
 
     // Always update the search paths. Cache-flushing however is not necessary,
     // since the stored cache entries do contain the domain, not just the host name.
     cache_info->search_domains = filter_domains(domains);
 
-    std::vector<IPSockAddr> serverSockAddrs;
-    serverSockAddrs.reserve(cache_info->nameservers.size());
-    for (const auto& server : cache_info->nameservers) {
-        serverSockAddrs.push_back(IPSockAddr::toIPSockAddr(server, 53));
-    }
-
-    if (!cache_info->dnsStats->setServers(serverSockAddrs, android::net::PROTO_TCP) ||
-        !cache_info->dnsStats->setServers(serverSockAddrs, android::net::PROTO_UDP)) {
+    // Setup stats for cleartext dns servers.
+    if (!cache_info->dnsStats->setServers(cache_info->nameserverSockAddrs, PROTO_TCP) ||
+        !cache_info->dnsStats->setServers(cache_info->nameserverSockAddrs, PROTO_UDP)) {
         LOG(WARNING) << __func__ << ": netid = " << netid << ", failed to set dns stats";
         return -EINVAL;
     }
@@ -1659,44 +1657,38 @@
 }
 
 static void free_nameservers_locked(resolv_cache_info* cache_info) {
-    int i;
-    for (i = 0; i < cache_info->nscount; i++) {
-        cache_info->nameservers.clear();
-        if (cache_info->nsaddrinfo[i] != nullptr) {
-            freeaddrinfo(cache_info->nsaddrinfo[i]);
-            cache_info->nsaddrinfo[i] = nullptr;
-        }
-    }
     cache_info->nscount = 0;
+    cache_info->nameservers.clear();  // Drop the textual server list.
+    cache_info->nameserverSockAddrs.clear();  // Drop the stored IPSockAddr entries as well.
     res_cache_clear_stats_locked(cache_info);
 }
 
-void _resolv_populate_res_for_net(res_state statp) {
-    if (statp == NULL) {
+void resolv_populate_res_for_net(ResState* statp) {  // Fill statp from its netid's cache entry.
+    if (statp == nullptr) {
         return;
     }
     LOG(INFO) << __func__ << ": netid=" << statp->netid;
 
     std::lock_guard guard(cache_mutex);
     resolv_cache_info* info = find_cache_info_locked(statp->netid);
-    if (info != NULL) {
-        int nserv;
-        struct addrinfo* ai;
-        for (nserv = 0; nserv < MAXNS; nserv++) {
-            ai = info->nsaddrinfo[nserv];
-            if (ai == NULL) {
-                break;
-            }
+    if (info == nullptr) return;  // No cache entry for this netid: statp is left untouched.
 
-            if ((size_t)ai->ai_addrlen <= sizeof(statp->nsaddrs[0])) {
-                memcpy(&statp->nsaddrs[nserv], ai->ai_addr, ai->ai_addrlen);
-            } else {
-                LOG(INFO) << __func__ << ": found too long addrlen";
-            }
+    // TODO: Convert nsaddrs[] to a C++ container and remove the size-checking.
+    const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
+
+    for (int nserv = 0; nserv < serverNum; nserv++) {
+        sockaddr_storage ss = info->nameserverSockAddrs.at(nserv);  // Copy out as sockaddr_storage.
+
+        if (auto sockaddr_len = sockaddrSize(ss); sockaddr_len != 0) {  // 0: family not recognized.
+            memcpy(&statp->nsaddrs[nserv], &ss, sockaddr_len);
+        } else {  // nsaddrs[nserv] is not written, yet the slot is still counted in nscount below.
+            LOG(WARNING) << __func__ << ": can't get sa_len from "
+                         << info->nameserverSockAddrs.at(nserv);
         }
-        statp->nscount = nserv;
-        statp->search_domains = info->search_domains;
     }
+
+    statp->nscount = serverNum;  // Capped at MAXNS by the std::min() above.
+    statp->search_domains = info->search_domains;
 }
 
 /* Resolver reachability statistics. */
@@ -1745,31 +1737,18 @@
             return -1;
         }
         int i;
-        for (i = 0; i < info->nscount; i++) {
-            // Verify that the following assumptions are held, failure indicates corruption:
-            //  - getaddrinfo() may never return a sockaddr > sockaddr_storage
-            //  - all addresses are valid
-            //  - there is only one address per addrinfo thanks to numeric resolution
-            int addrlen = info->nsaddrinfo[i]->ai_addrlen;
-            if (addrlen < (int) sizeof(struct sockaddr) || addrlen > (int) sizeof(servers[0])) {
-                LOG(INFO) << __func__ << ": nsaddrinfo[" << i << "].ai_addrlen == " << addrlen;
-                errno = EMSGSIZE;
-                return -1;
-            }
-            if (info->nsaddrinfo[i]->ai_addr == NULL) {
-                LOG(INFO) << __func__ << ": nsaddrinfo[" << i << "].ai_addr == NULL";
-                errno = ENOENT;
-                return -1;
-            }
-            if (info->nsaddrinfo[i]->ai_next != NULL) {
-                LOG(INFO) << __func__ << ": nsaddrinfo[" << i << "].ai_next != NULL";
-                errno = ENOTUNIQ;
-                return -1;
-            }
-        }
         *nscount = info->nscount;
+
+        // This should never happen; bail out rather than risk an out-of-bounds access below.
+        if (info->nscount != static_cast<int>(info->nameserverSockAddrs.size())) {
+            LOG(INFO) << __func__ << ": nscount " << info->nscount
+                      << " != " << info->nameserverSockAddrs.size();
+            errno = EFAULT;
+            return -1;
+        }
+
         for (i = 0; i < info->nscount; i++) {
-            memcpy(&servers[i], info->nsaddrinfo[i]->ai_addr, info->nsaddrinfo[i]->ai_addrlen);
+            servers[i] = info->nameserverSockAddrs.at(i);
             stats[i] = info->nsstats[i];
         }
 
diff --git a/res_init.cpp b/res_init.cpp
index e15127f..16fd98c 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -112,7 +112,7 @@
     }
 
     // The following dummy initialization is probably useless because
-    // it's overwritten later by _resolv_populate_res_for_net().
+    // it's overwritten later by resolv_populate_res_for_net().
     // TODO: check if it's safe to remove.
     const sockaddr_union u{
             .sin.sin_addr.s_addr = INADDR_ANY,
diff --git a/res_query.cpp b/res_query.cpp
index be0ff83..7702b44 100644
--- a/res_query.cpp
+++ b/res_query.cpp
@@ -250,7 +250,7 @@
          * be loaded once for the thread instead of each
          * time a query is tried.
          */
-        _resolv_populate_res_for_net(statp);
+        resolv_populate_res_for_net(statp);
 
         for (const auto& domain : statp->search_domains) {
             if (domain == "." || domain == "") ++root_on_list;
diff --git a/res_send.cpp b/res_send.cpp
index e917b13..ef86d0c 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -83,7 +83,6 @@
 
 #include <arpa/inet.h>
 #include <arpa/nameser.h>
-#include <netinet/in.h>
 
 #include <errno.h>
 #include <fcntl.h>
@@ -110,6 +109,7 @@
 #include "res_init.h"
 #include "resolv_cache.h"
 #include "stats.pb.h"
+#include "util.h"
 
 // TODO: use the namespace something like android::netd_resolv for libnetd_resolv
 using android::net::CacheStatus;
@@ -136,7 +136,6 @@
 
 static DnsTlsDispatcher sDnsTlsDispatcher;
 
-static int get_salen(const struct sockaddr*);
 static struct sockaddr* get_nsaddr(res_state, size_t);
 static int send_vc(res_state, res_params* params, const uint8_t*, int, uint8_t*, int, int*, int,
                    time_t*, int*, int*);
@@ -427,7 +426,7 @@
     } else if (cache_status != RESOLV_CACHE_UNSUPPORTED) {
         // had a cache miss for a known network, so populate the thread private
         // data so the normal resolve path can do its thing
-        _resolv_populate_res_for_net(statp);
+        resolv_populate_res_for_net(statp);
     }
     if (statp->nscount == 0) {
         // We have no nameservers configured, so there's no point trying.
@@ -493,7 +492,7 @@
             int delay = 0;
             *rcode = RCODE_INTERNAL_ERROR;
             const sockaddr* nsap = get_nsaddr(statp, ns);
-            nsaplen = get_salen(nsap);
+            nsaplen = sockaddrSize(nsap);
 
         same_ns:
             static const int niflags = NI_NUMERICHOST | NI_NUMERICSERV;
@@ -613,15 +612,6 @@
 
 /* Private */
 
-static int get_salen(const struct sockaddr* sa) {
-    if (sa->sa_family == AF_INET)
-        return (sizeof(struct sockaddr_in));
-    else if (sa->sa_family == AF_INET6)
-        return (sizeof(struct sockaddr_in6));
-    else
-        return (0); /* unknown, die on connect */
-}
-
 static struct sockaddr* get_nsaddr(res_state statp, size_t n) {
     return (struct sockaddr*)(void*)&statp->nsaddrs[n];
 }
@@ -662,7 +652,7 @@
     LOG(INFO) << __func__ << ": using send_vc";
 
     nsap = get_nsaddr(statp, (size_t) ns);
-    nsaplen = get_salen(nsap);
+    nsaplen = sockaddrSize(nsap);
 
     connreset = 0;
 same_ns:
@@ -930,7 +920,7 @@
     int resplen, n, s;
 
     nsap = get_nsaddr(statp, (size_t) ns);
-    nsaplen = get_salen(nsap);
+    nsaplen = sockaddrSize(nsap);
     if (statp->nssocks[ns] == -1) {
         statp->nssocks[ns] = socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0);
         if (statp->nssocks[ns] < 0) {
@@ -1234,7 +1224,7 @@
     assert(event != nullptr);
     ResState res;
     res_init(&res, netContext, event);
-    _resolv_populate_res_for_net(&res);
+    resolv_populate_res_for_net(&res);
     *rcode = NOERROR;
     return res_nsend(&res, msg, msgLen, ans, ansLen, rcode, flags);
 }
diff --git a/resolv_cache.h b/resolv_cache.h
index 6ede2e0..16019f2 100644
--- a/resolv_cache.h
+++ b/resolv_cache.h
@@ -42,7 +42,7 @@
 // The name servers are retrieved from the cache which is associated
 // with the network to which ResState is associated.
 struct ResState;
-void _resolv_populate_res_for_net(ResState* statp);
+void resolv_populate_res_for_net(ResState* statp);
 
 std::vector<unsigned> resolv_list_caches();
 
diff --git a/util.cpp b/util.cpp
new file mode 100644
index 0000000..91a033e
--- /dev/null
+++ b/util.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "util.h"
+
+socklen_t sockaddrSize(const sockaddr* sa) {  // Size of the family-specific struct, else 0.
+    if (sa == nullptr) return 0;
+
+    switch (sa->sa_family) {
+        case AF_INET:
+            return sizeof(sockaddr_in);
+        case AF_INET6:
+            return sizeof(sockaddr_in6);
+        default:
+            return 0;  // Neither AF_INET nor AF_INET6.
+    }
+}
+
+socklen_t sockaddrSize(const sockaddr_storage& ss) {  // Delegates to the sockaddr* overload.
+    return sockaddrSize(reinterpret_cast<const sockaddr*>(&ss));
+}
diff --git a/util.h b/util.h
new file mode 100644
index 0000000..d879011
--- /dev/null
+++ b/util.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include <netinet/in.h>
+
+socklen_t sockaddrSize(const sockaddr* sa);  // sizeof(sockaddr_in/in6) by sa_family; 0 otherwise.
+socklen_t sockaddrSize(const sockaddr_storage& ss);  // Same lookup for a sockaddr_storage.