Do the A/AAAA lookup in parallel for getaddrinfo

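Run each A/AAAA query in its own thread (via std::async) so that the
lookups proceed concurrently instead of one after another.
The functionality is disabled by default for now; it is only used when
the "parallel_lookup" experiment flag is set to a non-zero value.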

Bug: 2609013
Bug: 135717624
Bug: 151698212
Test: atest

Change-Id: I23cdcdf800d2f9ee42b5f19a5dd2045cdf61d41f
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index 3683920..2d5d1b0 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -53,6 +53,8 @@
 #include <sys/un.h>
 #include <unistd.h>
 
+#include <future>
+
 #include <android-base/logging.h>
 
 #include "netd_resolv/resolv.h"
@@ -61,6 +63,7 @@
 #include "res_init.h"
 #include "resolv_cache.h"
 #include "resolv_private.h"
+#include "util.h"
 
 #define ANY 0
 
@@ -1573,6 +1576,137 @@
 
 /* resolver logic */
 
+namespace {
+
+constexpr int SLEEP_TIME_MS = 2;
+
+int getHerrnoFromRcode(int rcode) {
+    switch (rcode) {
+        // Not an RFC-defined rcode.
+        case RCODE_TIMEOUT:
+            // DNS metrics monitor DNS query timeouts.
+            return NETD_RESOLV_H_ERRNO_EXT_TIMEOUT;  // extended h_errno.
+        // Defined in RFC 1035 section 4.1.1.
+        case NXDOMAIN:
+            return HOST_NOT_FOUND;
+        case SERVFAIL:
+            return TRY_AGAIN;
+        case NOERROR:
+            return NO_DATA;  // callers only map the rcode when no answers came back
+        case FORMERR:
+        case NOTIMP:
+        case REFUSED:
+        default:
+            return NO_RECOVERY;
+    }
+}
+
+struct QueryResult {
+    int ancount;  // answer count read from the reply header; 0 if no query was built
+    int rcode;    // rcode reported by res_nsend(); -1 if no query was built
+    int herrno;   // NO_RECOVERY if the query could not be built
+    NetworkDnsEventReported event;  // per-query event, merged into res->event by the caller
+};
+
+// Builds and sends the query for |t|: answer goes to t->answer, res_nsend()'s result to t->n.
+QueryResult doQuery(const char* name, res_target* t, res_state res) {
+    HEADER* hp = (HEADER*)(void*)t->answer.data();
+
+    hp->rcode = NOERROR;  // default
+
+    const int cl = t->qclass;
+    const int type = t->qtype;
+    const int anslen = t->answer.size();
+
+    LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";
+
+    uint8_t buf[MAXPACKET];
+
+    int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
+                         res->netcontext_flags);
+
+    if (n > 0 &&
+        (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) {
+        n = res_nopt(res, n, buf, sizeof(buf), anslen);
+    }
+
+    NetworkDnsEventReported event;
+    if (n <= 0) {
+        LOG(ERROR) << __func__ << ": res_nmkquery failed";
+        return {
+                .ancount = 0,
+                .rcode = -1,
+                .herrno = NO_RECOVERY,
+                .event = event,
+        };
+    }
+
+    ResState res_temp = fromResState(*res, &event);
+
+    int rcode = NOERROR;
+    n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
+    if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
+        // if the query choked with EDNS0, retry without EDNS0
+        if ((res_temp.netcontext_flags &
+             (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
+            (res_temp._flags & RES_F_EDNS0ERR)) {
+            LOG(DEBUG) << __func__ << ": retry without EDNS0";
+            n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf,
+                             sizeof(buf), res->netcontext_flags);
+            n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
+        }
+    }
+
+    LOG(DEBUG) << __func__ << ": rcode=" << hp->rcode << ", ancount=" << ntohs(hp->ancount);
+
+    t->n = n;
+    return {
+            .ancount = ntohs(hp->ancount),
+            .rcode = rcode,
+            .event = event,
+    };
+}
+
+}  // namespace
+
+static int res_queryN_parallel(const char* name, res_target* target, res_state res, int* herrno) {
+    std::vector<std::future<QueryResult>> results;
+    results.reserve(2);
+    for (res_target* t = target; t; t = t->next) {
+        results.emplace_back(std::async(std::launch::async, doQuery, name, t, res));
+        // Stagger the queries: some gateways drop packets that are sent too close together.
+        if (t->next) usleep(SLEEP_TIME_MS * 1000);
+    }
+
+    int ancount = 0;
+    int rcode = 0;
+
+    for (auto& f : results) {
+        const QueryResult& r = f.get();
+        if (r.herrno == NO_RECOVERY) {  // doQuery() could not build this query
+            *herrno = r.herrno;
+            return -1;
+        }
+        res->event->MergeFrom(r.event);
+        ancount += r.ancount;
+        rcode = r.rcode;  // the last target's rcode is the one mapped to h_errno below
+    }
+
+    if (ancount == 0) {
+        *herrno = getHerrnoFromRcode(rcode);
+        return -1;
+    }
+
+    return ancount;
+}
+
+static int res_queryN_wrapper(const char* name, res_target* target, res_state res, int* herrno) {
+    const bool parallel_lookup = getExperimentFlagInt("parallel_lookup", 0);
+    if (parallel_lookup) return res_queryN_parallel(name, target, res, herrno);  // flag non-zero
+
+    return res_queryN(name, target, res, herrno);  // pre-existing lookup path
+}
+
 /*
  * Formulate a normal query, send, and await answer.
  * Returned answer is placed in supplied buffer "answer".
@@ -1647,29 +1781,7 @@
     }
 
     if (ancount == 0) {
-        switch (rcode) {
-            // Not defined in RFC.
-            case RCODE_TIMEOUT:
-                // DNS metrics monitors DNS query timeout.
-                *herrno = NETD_RESOLV_H_ERRNO_EXT_TIMEOUT;  // extended h_errno.
-                break;
-            // Defined in RFC 1035 section 4.1.1.
-            case NXDOMAIN:
-                *herrno = HOST_NOT_FOUND;
-                break;
-            case SERVFAIL:
-                *herrno = TRY_AGAIN;
-                break;
-            case NOERROR:
-                *herrno = NO_DATA;
-                break;
-            case FORMERR:
-            case NOTIMP:
-            case REFUSED:
-            default:
-                *herrno = NO_RECOVERY;
-                break;
-        }
+        *herrno = getHerrnoFromRcode(rcode);
         return -1;
     }
     return ancount;
@@ -1795,10 +1907,8 @@
     return -1;
 }
 
-/*
- * Perform a call on res_query on the concatenation of name and domain,
- * removing a trailing dot from name if domain is NULL.
- */
+// Perform a call on res_query on the concatenation of name and domain,
+// removing a trailing dot from name if domain is NULL.
 static int res_querydomainN(const char* name, const char* domain, res_target* target, res_state res,
                             int* herrno) {
     char nbuf[MAXDNAME];
@@ -1828,5 +1938,5 @@
         }
         snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain);
     }
-    return res_queryN(longname, target, res, herrno);
+    return res_queryN_wrapper(longname, target, res, herrno);
 }
diff --git a/res_init.cpp b/res_init.cpp
index 334e022..a0004c4 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -91,6 +91,7 @@
 
 #include "netd_resolv/resolv.h"
 #include "resolv_private.h"
+#include "stats.pb.h"
 
 void res_init(ResState* statp, const struct android_net_context* _Nonnull netcontext,
               android::net::NetworkDnsEventReported* _Nonnull event) {
@@ -108,3 +109,25 @@
     statp->event = event;
     statp->netcontext_flags = netcontext->flags;
 }
+
+// TODO: Have some proper constructors for ResState instead of this method and res_init().
+ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event) {
+    ResState resOutput;
+    resOutput.netid = other.netid;
+    resOutput.uid = other.uid;
+    resOutput.pid = other.pid;
+    resOutput.id = other.id;
+
+    resOutput.nsaddrs = other.nsaddrs;
+
+    for (auto& sock : resOutput.nssocks) {
+        sock.reset();  // sockets are not copied from |other|
+    }
+
+    resOutput.ndots = other.ndots;
+    resOutput._mark = other._mark;
+    resOutput.tcp_nssock.reset();  // likewise, the TCP socket is not copied
+    resOutput.event = event;  // report into the caller-supplied event, not other.event
+    resOutput.netcontext_flags = other.netcontext_flags;
+    return resOutput;
+}
diff --git a/res_init.h b/res_init.h
index 4c759c4..0cfd7b9 100644
--- a/res_init.h
+++ b/res_init.h
@@ -16,7 +16,9 @@
 #pragma once
 
 #include "resolv_private.h"
+#include "stats.pb.h"
 
 // TODO: make this a constructor for ResState
 void res_init(ResState* res, const struct android_net_context* netcontext,
               android::net::NetworkDnsEventReported* event);
+ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event);
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index bf23dbc..28fa3c9 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -213,66 +213,68 @@
             SSL_set_fd(ssl.get(), client.get());
 
             LOG(DEBUG) << "Doing SSL handshake";
-            bool success = false;
             if (SSL_accept(ssl.get()) <= 0) {
                 LOG(INFO) << "SSL negotiation failure";
             } else {
                 LOG(DEBUG) << "SSL handshake complete";
-                success = handleOneRequest(ssl.get());
-            }
-
-            if (success) {
                 // Increment queries_ as late as possible, because it represents
                 // a query that is fully processed, and the response returned to the
                 // client, including cleanup actions.
-                ++queries_;
+                queries_ += handleRequests(ssl.get(), client.get());
             }
         }
     }
     LOG(DEBUG) << "Ending loop";
 }
 
-bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
-    uint8_t queryHeader[2];
-    if (SSL_read(ssl, &queryHeader, 2) != 2) {
-        LOG(INFO) << "Not enough header bytes";
-        return false;
-    }
-    const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
-    uint8_t query[qlen];
-    size_t qbytes = 0;
-    while (qbytes < qlen) {
-        int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
-        if (ret <= 0) {
-            LOG(INFO) << "Error while reading query";
-            return false;
+int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
+    int queryCounts = 0;  // number of queries fully answered on this connection
+    pollfd fds = {.fd = clientFd, .events = POLLIN};
+    do {
+        uint8_t queryHeader[2];
+        if (SSL_read(ssl, &queryHeader, 2) != 2) {
+            LOG(INFO) << "Not enough header bytes";
+            return queryCounts;
         }
-        qbytes += ret;
-    }
-    int sent = send(backend_socket_.get(), query, qlen, 0);
-    if (sent != qlen) {
-        LOG(INFO) << "Failed to send query";
-        return false;
-    }
-    const int max_size = 4096;
-    uint8_t recv_buffer[max_size];
-    int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
-    if (rlen <= 0) {
-        LOG(INFO) << "Failed to receive response";
-        return false;
-    }
-    uint8_t responseHeader[2];
-    responseHeader[0] = rlen >> 8;
-    responseHeader[1] = rlen;
-    if (SSL_write(ssl, responseHeader, 2) != 2) {
-        LOG(INFO) << "Failed to write response header";
-        return false;
-    }
-    if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
-        LOG(INFO) << "Failed to write response body";
-        return false;
-    }
-    return true;
+        const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
+        uint8_t query[qlen];
+        size_t qbytes = 0;
+        while (qbytes < qlen) {
+            int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
+            if (ret <= 0) {
+                LOG(INFO) << "Error while reading query";
+                return queryCounts;
+            }
+            qbytes += ret;
+        }
+        int sent = send(backend_socket_.get(), query, qlen, 0);
+        if (sent != qlen) {
+            LOG(INFO) << "Failed to send query";
+            return queryCounts;
+        }
+        const int max_size = 4096;
+        uint8_t recv_buffer[max_size];
+        int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
+        if (rlen <= 0) {
+            LOG(INFO) << "Failed to receive response";
+            return queryCounts;
+        }
+        uint8_t responseHeader[2];
+        responseHeader[0] = rlen >> 8;
+        responseHeader[1] = rlen;
+        if (SSL_write(ssl, responseHeader, 2) != 2) {
+            LOG(INFO) << "Failed to write response header";
+            return queryCounts;
+        }
+        if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
+            LOG(INFO) << "Failed to write response body";
+            return queryCounts;
+        }
+        ++queryCounts;  // counted only after the whole response has been written
+    } while (poll(&fds, 1, 1) > 0);  // keep going while clientFd reports an event within 1 ms
+
+    LOG(DEBUG) << __func__ << " return: " << queryCounts;
+    return queryCounts;
 }
 
 bool DnsTlsFrontend::stopServer() {
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 386db6a..6ba5681 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -69,7 +69,7 @@
 
   private:
     void requestHandler();
-    bool handleOneRequest(SSL* ssl);
+    int handleRequests(SSL* ssl, int clientFd);
 
     // Trigger the handler thread to terminate.
     bool sendToEventFd();
diff --git a/tests/resolv_gold_test.cpp b/tests/resolv_gold_test.cpp
index d8e9a6c..0b8c1b8 100644
--- a/tests/resolv_gold_test.cpp
+++ b/tests/resolv_gold_test.cpp
@@ -245,7 +245,7 @@
 
         // Verify DNS server query status.
         EXPECT_EQ(GetNumQueries(dns, name.c_str()), queries);
-        if (protocol == DnsProtocol::TLS) EXPECT_EQ(tls.queries(), static_cast<int>(queries));
+        if (protocol == DnsProtocol::TLS) EXPECT_TRUE(tls.waitForQueries(queries));
     }
 
     static constexpr res_params kParams = {
@@ -380,7 +380,7 @@
     const std::vector<std::string> result_strs = ToStrings(result);
     EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(
                                      {kHelloExampleComAddrV4, kHelloExampleComAddrV6}));
-    EXPECT_EQ(tls.queries(), 3);
+    EXPECT_TRUE(tls.waitForQueries(3));
 }
 
 // Parameterized test class definition.
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index a1e73f1..7a53bc6 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -1693,7 +1693,7 @@
     // Wait for query to get counted.
     EXPECT_TRUE(tls1.waitForQueries(2));
     // No new queries should have reached tls2.
-    EXPECT_EQ(1, tls2.queries());
+    EXPECT_TRUE(tls2.waitForQueries(1));
 
     // Stop tls1.  Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
     tls1.stopServer();
@@ -3963,7 +3963,7 @@
     std::tie(result, timeTakenMs) = safe_getaddrinfo_time_taken(hostname2, nullptr, hints);
 
     EXPECT_NE(nullptr, result);
-    EXPECT_EQ(1, tls.queries());
+    EXPECT_TRUE(tls.waitForQueries(1));
     EXPECT_EQ(1U, GetNumQueries(dns, hostname2));
     EXPECT_EQ(records.at(1).addr, ToString(result));
 
@@ -4755,3 +4755,37 @@
     std::string result_str = ToString(result);
     EXPECT_TRUE(result_str == "::1.2.3.4") << ", result_str='" << result_str << "'";
 }
+
+TEST_F(ResolverTest, GetAddrInfoParallelLookupTimeout) {
+    constexpr char listen_addr[] = "127.0.0.4";
+    constexpr char host_name[] = "howdy.example.com.";
+    constexpr int TIMING_TOLERANCE_MS = 200;
+    constexpr int DNS_TIMEOUT_MS = 1000;
+    const std::vector<DnsRecord> records = {
+            {host_name, ns_type::ns_t_a, "1.2.3.4"},
+            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
+    };
+    const std::vector<int> params = {300, 25, 8, 8, DNS_TIMEOUT_MS /* BASE_TIMEOUT_MSEC */,
+                                     1 /* retry count */};
+    test::DNSResponder neverRespondDns(listen_addr, "53", static_cast<ns_rcode>(-1));
+    neverRespondDns.setResponseProbability(0.0);
+    StartDns(neverRespondDns, records);
+
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr}, kDefaultSearchDomains, params));
+    neverRespondDns.clearQueries();
+
+    const std::string parallelLookupFlag("persist.device_config.netd_native.parallel_lookup");
+    ScopedSystemProperties scopedSystemProperties(parallelLookupFlag, "1");
+
+    // Use a DNS server that never responds to verify that the A/AAAA queries are sent in parallel.
+    // The resolver parameters are set to a 1s timeout and a retry count of 1.
+    // So we expect safe_getaddrinfo_time_taken() to take ~1s to return when
+    // parallel lookup is enabled. And the DNS server should receive 2 queries.
+    const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
+    auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(host_name, nullptr, hints);
+
+    EXPECT_TRUE(result == nullptr);
+    EXPECT_NEAR(DNS_TIMEOUT_MS, timeTakenMs, TIMING_TOLERANCE_MS)
+            << "took time should approximate equal timeout";
+    EXPECT_EQ(2U, GetNumQueries(neverRespondDns, host_name));
+}