Do the A/AAAA lookup in parallel for getaddrinfo
Create threads for each A/AAAA lookup.
The functionality is disabled for now.
Bug: 2609013
Bug: 135717624
Bug: 151698212
Test: atest
Change-Id: I23cdcdf800d2f9ee42b5f19a5dd2045cdf61d41f
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index 3683920..2d5d1b0 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -53,6 +53,8 @@
#include <sys/un.h>
#include <unistd.h>
+#include <future>
+
#include <android-base/logging.h>
#include "netd_resolv/resolv.h"
@@ -61,6 +63,7 @@
#include "res_init.h"
#include "resolv_cache.h"
#include "resolv_private.h"
+#include "util.h"
#define ANY 0
@@ -1573,6 +1576,137 @@
/* resolver logic */
+namespace {
+
+constexpr int SLEEP_TIME_MS = 2;
+
+int getHerrnoFromRcode(int rcode) {
+ switch (rcode) {
+ // Not defined in RFC.
+ case RCODE_TIMEOUT:
+ // DNS metrics monitors DNS query timeout.
+ return NETD_RESOLV_H_ERRNO_EXT_TIMEOUT; // extended h_errno.
+ // Defined in RFC 1035 section 4.1.1.
+ case NXDOMAIN:
+ return HOST_NOT_FOUND;
+ case SERVFAIL:
+ return TRY_AGAIN;
+ case NOERROR:
+ return NO_DATA;
+ case FORMERR:
+ case NOTIMP:
+ case REFUSED:
+ default:
+ return NO_RECOVERY;
+ }
+}
+
+struct QueryResult {
+ int ancount;
+ int rcode;
+ int herrno;
+ NetworkDnsEventReported event;
+};
+
+QueryResult doQuery(const char* name, res_target* t, res_state res) {
+ HEADER* hp = (HEADER*)(void*)t->answer.data();
+
+ hp->rcode = NOERROR; // default
+
+ const int cl = t->qclass;
+ const int type = t->qtype;
+ const int anslen = t->answer.size();
+
+ LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";
+
+ uint8_t buf[MAXPACKET];
+
+ int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
+ res->netcontext_flags);
+
+ if (n > 0 &&
+ (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) {
+ n = res_nopt(res, n, buf, sizeof(buf), anslen);
+ }
+
+ NetworkDnsEventReported event;
+ if (n <= 0) {
+ LOG(ERROR) << __func__ << ": res_nmkquery failed";
+ return {0, -1, NO_RECOVERY, event};
+ return {
+ .ancount = 0,
+ .rcode = -1,
+ .herrno = NO_RECOVERY,
+ .event = event,
+ };
+ }
+
+ ResState res_temp = fromResState(*res, &event);
+
+ int rcode = NOERROR;
+ n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
+ if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
+ // if the query choked with EDNS0, retry without EDNS0
+ if ((res_temp.netcontext_flags &
+ (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
+ (res_temp._flags & RES_F_EDNS0ERR)) {
+ LOG(DEBUG) << __func__ << ": retry without EDNS0";
+ n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf,
+ sizeof(buf), res->netcontext_flags);
+ n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
+ }
+ }
+
+ LOG(DEBUG) << __func__ << ": rcode=" << hp->rcode << ", ancount=" << ntohs(hp->ancount);
+
+ t->n = n;
+ return {
+ .ancount = ntohs(hp->ancount),
+ .rcode = rcode,
+ .event = event,
+ };
+}
+
+} // namespace
+
+static int res_queryN_parallel(const char* name, res_target* target, res_state res, int* herrno) {
+ std::vector<std::future<QueryResult>> results;
+ results.reserve(2);
+ for (res_target* t = target; t; t = t->next) {
+ results.emplace_back(std::async(std::launch::async, doQuery, name, t, res));
+ // Avoiding gateways drop packets if queries are sent too close together
+ if (t->next) usleep(SLEEP_TIME_MS * 1000);
+ }
+
+ int ancount = 0;
+ int rcode = 0;
+
+ for (auto& f : results) {
+ const QueryResult& r = f.get();
+ if (r.herrno == NO_RECOVERY) {
+ *herrno = r.herrno;
+ return -1;
+ }
+ res->event->MergeFrom(r.event);
+ ancount += r.ancount;
+ rcode = r.rcode;
+ }
+
+ if (ancount == 0) {
+ *herrno = getHerrnoFromRcode(rcode);
+ return -1;
+ }
+
+ return ancount;
+}
+
+static int res_queryN_wrapper(const char* name, res_target* target, res_state res, int* herrno) {
+ const bool parallel_lookup = getExperimentFlagInt("parallel_lookup", 0);
+ if (parallel_lookup) return res_queryN_parallel(name, target, res, herrno);
+
+ return res_queryN(name, target, res, herrno);
+}
+
/*
* Formulate a normal query, send, and await answer.
* Returned answer is placed in supplied buffer "answer".
@@ -1647,29 +1781,7 @@
}
if (ancount == 0) {
- switch (rcode) {
- // Not defined in RFC.
- case RCODE_TIMEOUT:
- // DNS metrics monitors DNS query timeout.
- *herrno = NETD_RESOLV_H_ERRNO_EXT_TIMEOUT; // extended h_errno.
- break;
- // Defined in RFC 1035 section 4.1.1.
- case NXDOMAIN:
- *herrno = HOST_NOT_FOUND;
- break;
- case SERVFAIL:
- *herrno = TRY_AGAIN;
- break;
- case NOERROR:
- *herrno = NO_DATA;
- break;
- case FORMERR:
- case NOTIMP:
- case REFUSED:
- default:
- *herrno = NO_RECOVERY;
- break;
- }
+ *herrno = getHerrnoFromRcode(rcode);
return -1;
}
return ancount;
@@ -1795,10 +1907,8 @@
return -1;
}
-/*
- * Perform a call on res_query on the concatenation of name and domain,
- * removing a trailing dot from name if domain is NULL.
- */
+// Perform a call on res_query on the concatenation of name and domain,
+// removing a trailing dot from name if domain is NULL.
static int res_querydomainN(const char* name, const char* domain, res_target* target, res_state res,
int* herrno) {
char nbuf[MAXDNAME];
@@ -1828,5 +1938,5 @@
}
snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain);
}
- return res_queryN(longname, target, res, herrno);
+ return res_queryN_wrapper(longname, target, res, herrno);
}
diff --git a/res_init.cpp b/res_init.cpp
index 334e022..a0004c4 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -91,6 +91,7 @@
#include "netd_resolv/resolv.h"
#include "resolv_private.h"
+#include "stats.pb.h"
void res_init(ResState* statp, const struct android_net_context* _Nonnull netcontext,
android::net::NetworkDnsEventReported* _Nonnull event) {
@@ -108,3 +109,25 @@
statp->event = event;
statp->netcontext_flags = netcontext->flags;
}
+
+// TODO: Have some proper constructors for ResState instead of this method and res_init().
+ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event) {
+ ResState resOutput;
+ resOutput.netid = other.netid;
+ resOutput.uid = other.uid;
+ resOutput.pid = other.pid;
+ resOutput.id = other.id;
+
+ resOutput.nsaddrs = other.nsaddrs;
+
+ for (auto& sock : resOutput.nssocks) {
+ sock.reset();
+ }
+
+ resOutput.ndots = other.ndots;
+ resOutput._mark = other._mark;
+ resOutput.tcp_nssock.reset();
+ resOutput.event = event;
+ resOutput.netcontext_flags = other.netcontext_flags;
+ return resOutput;
+}
diff --git a/res_init.h b/res_init.h
index 4c759c4..0cfd7b9 100644
--- a/res_init.h
+++ b/res_init.h
@@ -16,7 +16,9 @@
#pragma once
#include "resolv_private.h"
+#include "stats.pb.h"
// TODO: make this a constructor for ResState
void res_init(ResState* res, const struct android_net_context* netcontext,
android::net::NetworkDnsEventReported* event);
+ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event);
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index bf23dbc..28fa3c9 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -213,66 +213,68 @@
SSL_set_fd(ssl.get(), client.get());
LOG(DEBUG) << "Doing SSL handshake";
- bool success = false;
if (SSL_accept(ssl.get()) <= 0) {
LOG(INFO) << "SSL negotiation failure";
} else {
LOG(DEBUG) << "SSL handshake complete";
- success = handleOneRequest(ssl.get());
- }
-
- if (success) {
// Increment queries_ as late as possible, because it represents
// a query that is fully processed, and the response returned to the
// client, including cleanup actions.
- ++queries_;
+ queries_ += handleRequests(ssl.get(), client.get());
}
}
}
LOG(DEBUG) << "Ending loop";
}
-bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
- uint8_t queryHeader[2];
- if (SSL_read(ssl, &queryHeader, 2) != 2) {
- LOG(INFO) << "Not enough header bytes";
- return false;
- }
- const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
- uint8_t query[qlen];
- size_t qbytes = 0;
- while (qbytes < qlen) {
- int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
- if (ret <= 0) {
- LOG(INFO) << "Error while reading query";
- return false;
+int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
+ int queryCounts = 0;
+ pollfd fds = {.fd = clientFd, .events = POLLIN};
+ do {
+ uint8_t queryHeader[2];
+ if (SSL_read(ssl, &queryHeader, 2) != 2) {
+ LOG(INFO) << "Not enough header bytes";
+ return queryCounts;
}
- qbytes += ret;
- }
- int sent = send(backend_socket_.get(), query, qlen, 0);
- if (sent != qlen) {
- LOG(INFO) << "Failed to send query";
- return false;
- }
- const int max_size = 4096;
- uint8_t recv_buffer[max_size];
- int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
- if (rlen <= 0) {
- LOG(INFO) << "Failed to receive response";
- return false;
- }
- uint8_t responseHeader[2];
- responseHeader[0] = rlen >> 8;
- responseHeader[1] = rlen;
- if (SSL_write(ssl, responseHeader, 2) != 2) {
- LOG(INFO) << "Failed to write response header";
- return false;
- }
- if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
- LOG(INFO) << "Failed to write response body";
- return false;
- }
- return true;
+ const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
+ uint8_t query[qlen];
+ size_t qbytes = 0;
+ while (qbytes < qlen) {
+ int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
+ if (ret <= 0) {
+ LOG(INFO) << "Error while reading query";
+ return queryCounts;
+ }
+ qbytes += ret;
+ }
+ int sent = send(backend_socket_.get(), query, qlen, 0);
+ if (sent != qlen) {
+ LOG(INFO) << "Failed to send query";
+ return queryCounts;
+ }
+ const int max_size = 4096;
+ uint8_t recv_buffer[max_size];
+ int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
+ if (rlen <= 0) {
+ LOG(INFO) << "Failed to receive response";
+ return queryCounts;
+ }
+ uint8_t responseHeader[2];
+ responseHeader[0] = rlen >> 8;
+ responseHeader[1] = rlen;
+ if (SSL_write(ssl, responseHeader, 2) != 2) {
+ LOG(INFO) << "Failed to write response header";
+ return queryCounts;
+ }
+ if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
+ LOG(INFO) << "Failed to write response body";
+ return queryCounts;
+ }
+ ++queryCounts;
+ } while (poll(&fds, 1, 1) > 0);
+
+ LOG(DEBUG) << __func__ << " return: " << queryCounts;
+ return queryCounts;
}
bool DnsTlsFrontend::stopServer() {
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 386db6a..6ba5681 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -69,7 +69,7 @@
private:
void requestHandler();
- bool handleOneRequest(SSL* ssl);
+ int handleRequests(SSL* ssl, int clientFd);
// Trigger the handler thread to terminate.
bool sendToEventFd();
diff --git a/tests/resolv_gold_test.cpp b/tests/resolv_gold_test.cpp
index d8e9a6c..0b8c1b8 100644
--- a/tests/resolv_gold_test.cpp
+++ b/tests/resolv_gold_test.cpp
@@ -245,7 +245,7 @@
// Verify DNS server query status.
EXPECT_EQ(GetNumQueries(dns, name.c_str()), queries);
- if (protocol == DnsProtocol::TLS) EXPECT_EQ(tls.queries(), static_cast<int>(queries));
+ if (protocol == DnsProtocol::TLS) EXPECT_TRUE(tls.waitForQueries(queries));
}
static constexpr res_params kParams = {
@@ -380,7 +380,7 @@
const std::vector<std::string> result_strs = ToStrings(result);
EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(
{kHelloExampleComAddrV4, kHelloExampleComAddrV6}));
- EXPECT_EQ(tls.queries(), 3);
+ EXPECT_TRUE(tls.waitForQueries(3));
}
// Parameterized test class definition.
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index a1e73f1..7a53bc6 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -1693,7 +1693,7 @@
// Wait for query to get counted.
EXPECT_TRUE(tls1.waitForQueries(2));
// No new queries should have reached tls2.
- EXPECT_EQ(1, tls2.queries());
+ EXPECT_TRUE(tls2.waitForQueries(1));
// Stop tls1. Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
tls1.stopServer();
@@ -3963,7 +3963,7 @@
std::tie(result, timeTakenMs) = safe_getaddrinfo_time_taken(hostname2, nullptr, hints);
EXPECT_NE(nullptr, result);
- EXPECT_EQ(1, tls.queries());
+ EXPECT_TRUE(tls.waitForQueries(1));
EXPECT_EQ(1U, GetNumQueries(dns, hostname2));
EXPECT_EQ(records.at(1).addr, ToString(result));
@@ -4755,3 +4755,37 @@
std::string result_str = ToString(result);
EXPECT_TRUE(result_str == "::1.2.3.4") << ", result_str='" << result_str << "'";
}
+
+TEST_F(ResolverTest, GetAddrInfoParallelLookupTimeout) {
+ constexpr char listen_addr[] = "127.0.0.4";
+ constexpr char host_name[] = "howdy.example.com.";
+ constexpr int TIMING_TOLERANCE_MS = 200;
+ constexpr int DNS_TIMEOUT_MS = 1000;
+ const std::vector<DnsRecord> records = {
+ {host_name, ns_type::ns_t_a, "1.2.3.4"},
+ {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
+ };
+ const std::vector<int> params = {300, 25, 8, 8, DNS_TIMEOUT_MS /* BASE_TIMEOUT_MSEC */,
+ 1 /* retry count */};
+ test::DNSResponder neverRespondDns(listen_addr, "53", static_cast<ns_rcode>(-1));
+ neverRespondDns.setResponseProbability(0.0);
+ StartDns(neverRespondDns, records);
+
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr}, kDefaultSearchDomains, params));
+ neverRespondDns.clearQueries();
+
+ const std::string udpKeepListeningFlag("persist.device_config.netd_native.parallel_lookup");
+ ScopedSystemProperties scopedSystemProperties(udpKeepListeningFlag, "1");
+
+ // Use a never respond DNS server to verify if the A/AAAA queries are sent in parallel.
+ // The resolver parameters are set to timeout 1s and retry 1 times.
+ // So we expect the safe_getaddrinfo_time_taken() might take ~1s to
+ // return when parallel lookup is enabled. And the DNS server should receive 2 queries.
+ const addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
+ auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(host_name, nullptr, hints);
+
+ EXPECT_TRUE(result == nullptr);
+ EXPECT_NEAR(DNS_TIMEOUT_MS, timeTakenMs, TIMING_TOLERANCE_MS)
+ << "took time should approximate equal timeout";
+ EXPECT_EQ(2U, GetNumQueries(neverRespondDns, host_name));
+}