Refactor ResState to store nameserver addresses by IPSockAddr
IPSockAddr is more safer and convenient to store socket addresses,
to compare two socket addresses, and to make the code more readable.
The change also removes get_nsaddr(), a static function in res_send.cpp.
Bug: 137169582
Test: cd packages/modules/DnsResolver && atest
Change-Id: I694c293139b01a39c40cc50ba8c4f067a2ac4b07
diff --git a/res_cache.cpp b/res_cache.cpp
index 4c7e187..744a97d 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -1671,21 +1671,8 @@
NetConfig* info = find_netconfig_locked(statp->netid);
if (info == nullptr) return;
- // TODO: Convert nsaddrs[] to c++ container and remove the size-checking.
- const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
-
- for (int nserv = 0; nserv < serverNum; nserv++) {
- sockaddr_storage ss = info->nameserverSockAddrs.at(nserv);
-
- if (auto sockaddr_len = sockaddrSize(ss); sockaddr_len != 0) {
- memcpy(&statp->nsaddrs[nserv], &ss, sockaddr_len);
- } else {
- LOG(WARNING) << __func__ << ": can't get sa_len from "
- << info->nameserverSockAddrs.at(nserv);
- }
- }
-
- statp->nscount = serverNum;
+ statp->nscount = static_cast<int>(info->nameserverSockAddrs.size());
+ statp->nsaddrs = info->nameserverSockAddrs;
statp->search_domains = info->search_domains;
statp->tc_mode = info->tc_mode;
}
@@ -1811,18 +1798,18 @@
return -1;
}
-void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, const sockaddr* sa,
+void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id,
+ const IPSockAddr& serverSockAddr,
const res_sample& sample, int max_samples) {
- if (max_samples <= 0 || sa == nullptr) return;
+ if (max_samples <= 0) return;
std::lock_guard guard(cache_mutex);
NetConfig* info = find_netconfig_locked(netid);
if (info && info->revision_id == revision_id) {
const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
- const IPSockAddr ipsa = IPSockAddr::toIPSockAddr(*sa);
for (int ns = 0; ns < serverNum; ns++) {
- if (ipsa == info->nameserverSockAddrs.at(ns)) {
+ if (serverSockAddr == info->nameserverSockAddrs[ns]) {
res_cache_add_stats_sample_locked(&info->nsstats[ns], sample, max_samples);
return;
}
diff --git a/res_init.cpp b/res_init.cpp
index 049e225..ca054fc 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -97,17 +97,8 @@
statp->netid = netcontext->dns_netid;
statp->uid = netcontext->uid;
statp->pid = netcontext->pid;
- statp->nscount = 1;
+ statp->nscount = 0;
statp->id = arc4random_uniform(65536);
- // The following dummy initialization is probably useless because
- // it's overwritten later by resolv_populate_res_for_net().
- // TODO: check if it's safe to remove.
- const sockaddr_union u{
- .sin.sin_addr.s_addr = INADDR_ANY,
- .sin.sin_family = AF_INET,
- .sin.sin_port = htons(NAMESERVER_PORT),
- };
- memcpy(&statp->nsaddrs, &u, sizeof(u));
for (auto& sock : statp->nssocks) {
sock.reset();
diff --git a/res_send.cpp b/res_send.cpp
index 4db1b8b..cedb6a5 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -138,11 +138,13 @@
static DnsTlsDispatcher sDnsTlsDispatcher;
-static struct sockaddr* get_nsaddr(res_state, size_t);
-static int send_vc(res_state, res_params* params, const uint8_t*, int, uint8_t*, int, int*, int,
- time_t*, int*, int*);
-static int send_dg(res_state, res_params* params, const uint8_t*, int, uint8_t*, int, int*, int,
- int*, int*, time_t*, int*, int*);
+static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
+ uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
+ int* delay);
+static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
+ uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
+ int* gotsomewhere, time_t* at, int* rcode, int* delay);
+
static void dump_error(const char*, const struct sockaddr*, int);
static int sock_eq(struct sockaddr*, struct sockaddr*);
@@ -288,13 +290,13 @@
static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
const sockaddr_in *inp, *srv;
const sockaddr_in6 *in6p, *srv6;
- int ns;
switch (sa->sa_family) {
case AF_INET:
inp = (const struct sockaddr_in*) (const void*) sa;
- for (ns = 0; ns < statp->nscount; ns++) {
- srv = (struct sockaddr_in*) (void*) get_nsaddr(statp, (size_t) ns);
+ for (const IPSockAddr& ipsa : statp->nsaddrs) {
+ sockaddr_storage ss = ipsa;
+ srv = reinterpret_cast<sockaddr_in*>(&ss);
if (srv->sin_family == inp->sin_family && srv->sin_port == inp->sin_port &&
(srv->sin_addr.s_addr == INADDR_ANY ||
srv->sin_addr.s_addr == inp->sin_addr.s_addr))
@@ -303,8 +305,9 @@
break;
case AF_INET6:
in6p = (const struct sockaddr_in6*) (const void*) sa;
- for (ns = 0; ns < statp->nscount; ns++) {
- srv6 = (struct sockaddr_in6*) (void*) get_nsaddr(statp, (size_t) ns);
+ for (const IPSockAddr& ipsa : statp->nsaddrs) {
+ sockaddr_storage ss = ipsa;
+ srv6 = reinterpret_cast<sockaddr_in6*>(&ss);
if (srv6->sin6_family == in6p->sin6_family && srv6->sin6_port == in6p->sin6_port &&
#ifdef HAVE_SIN6_SCOPE_ID
(srv6->sin6_scope_id == 0 || srv6->sin6_scope_id == in6p->sin6_scope_id) &&
@@ -484,20 +487,15 @@
int terrno = ETIMEDOUT;
for (int attempt = 0; attempt < retryTimes; ++attempt) {
- for (int ns = 0; ns < statp->nscount; ++ns) {
+ for (size_t ns = 0; ns < statp->nsaddrs.size(); ++ns) {
if (!usable_servers[ns]) continue;
*rcode = RCODE_INTERNAL_ERROR;
// Get server addr
- const sockaddr* nsap = get_nsaddr(statp, ns);
- const int nsaplen = sockaddrSize(nsap);
-
- static const int niflags = NI_NUMERICHOST | NI_NUMERICSERV;
- char abuf[NI_MAXHOST];
- if (getnameinfo(nsap, (socklen_t)nsaplen, abuf, sizeof(abuf), NULL, 0, niflags) == 0)
- LOG(DEBUG) << __func__ << ": Querying server (# " << ns + 1
- << ") address = " << abuf;
+ const IPSockAddr& serverSockAddr = statp->nsaddrs[ns];
+ LOG(DEBUG) << __func__ << ": Querying server (# " << ns + 1
+ << ") address = " << serverSockAddr.toString();
::android::net::Protocol query_proto = useTcp ? PROTO_TCP : PROTO_UDP;
time_t now = 0;
@@ -533,7 +531,7 @@
dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
dnsQueryEvent->set_dns_server_index(ns);
- dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(nsap->sa_family));
+ dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(serverSockAddr.family()));
dnsQueryEvent->set_retry_times(retry_count_for_event);
dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
dnsQueryEvent->set_protocol(query_proto);
@@ -545,9 +543,9 @@
if (shouldRecordStats) {
res_sample sample;
_res_stats_set_sample(&sample, now, *rcode, delay);
- resolv_cache_add_resolver_stats_sample(statp->netid, revision_id, nsap, sample,
- params.max_samples);
- resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(*nsap), dnsQueryEvent);
+ resolv_cache_add_resolver_stats_sample(statp->netid, revision_id, serverSockAddr,
+ sample, params.max_samples);
+ resolv_stats_add(statp->netid, serverSockAddr, dnsQueryEvent);
}
if (resplen == 0) continue;
@@ -582,12 +580,6 @@
return -terrno;
}
-/* Private */
-
-static struct sockaddr* get_nsaddr(res_state statp, size_t n) {
- return (struct sockaddr*)(void*)&statp->nsaddrs[n];
-}
-
static struct timespec get_timeout(res_state statp, const res_params* params, const int ns) {
int msec;
// Legacy algorithm which scales the timeout by nameserver number.
@@ -610,7 +602,7 @@
}
static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, int ns, time_t* at, int* rcode,
+ uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
int* delay) {
*at = time(NULL);
*delay = 0;
@@ -623,7 +615,14 @@
LOG(INFO) << __func__ << ": using send_vc";
- nsap = get_nsaddr(statp, (size_t) ns);
+ // It should never happen, but just in case.
+ if (ns >= statp->nsaddrs.size()) {
+ LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
+ return -1;
+ }
+
+ sockaddr_storage ss = statp->nsaddrs[ns];
+ nsap = reinterpret_cast<sockaddr*>(&ss);
nsaplen = sockaddrSize(nsap);
connreset = 0;
@@ -897,12 +896,18 @@
}
static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, int ns, int* v_circuit, int* gotsomewhere,
- time_t* at, int* rcode, int* delay) {
+ uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
+ int* gotsomewhere, time_t* at, int* rcode, int* delay) {
+ // It should never happen, but just in case.
+ if (ns >= statp->nsaddrs.size()) {
+ LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
+ return -1;
+ }
+
*at = time(nullptr);
*delay = 0;
-
- const sockaddr* nsap = get_nsaddr(statp, (size_t)ns);
+ const sockaddr_storage ss = statp->nsaddrs[ns];
+ const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
const int nsaplen = sockaddrSize(nsap);
if (statp->nssocks[ns] == -1) {
diff --git a/resolv_cache_unit_test.cpp b/resolv_cache_unit_test.cpp
index 03f8806..4a29258 100644
--- a/resolv_cache_unit_test.cpp
+++ b/resolv_cache_unit_test.cpp
@@ -38,6 +38,8 @@
using namespace std::chrono_literals;
+using android::netdutils::IPSockAddr;
+
constexpr int TEST_NETID = 30;
constexpr int TEST_NETID_2 = 31;
constexpr int DNS_PORT = 53;
@@ -190,9 +192,9 @@
return resolv_set_nameservers(netId, setup.servers, setup.domains, setup.params);
}
- void cacheAddStats(uint32_t netId, int revision_id, const sockaddr* sa,
+ void cacheAddStats(uint32_t netId, int revision_id, const IPSockAddr& ipsa,
const res_sample& sample, int max_samples) {
- resolv_cache_add_resolver_stats_sample(netId, revision_id, sa, sample, max_samples);
+ resolv_cache_add_resolver_stats_sample(netId, revision_id, ipsa, sample, max_samples);
}
int cacheFlush(uint32_t netId) { return resolv_flush_cache_for_net(netId); }
@@ -735,8 +737,7 @@
res_sample sample = {.at = time(NULL), .rtt = 100, .rcode = ns_r_noerror};
sockaddr_in sin = {.sin_family = AF_INET, .sin_port = htons(DNS_PORT)};
ASSERT_TRUE(inet_pton(AF_INET, setup.servers[0].c_str(), &sin.sin_addr));
- cacheAddStats(TEST_NETID, 1 /*revision_id*/, (const sockaddr*)&sin, sample,
- setup.params.max_samples);
+ cacheAddStats(TEST_NETID, 1 /*revision_id*/, IPSockAddr(sin), sample, setup.params.max_samples);
const CacheStats cacheStats = {
.setup = setup,
diff --git a/resolv_private.h b/resolv_private.h
index e217087..c9e7f8f 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -55,6 +55,8 @@
#include <string>
#include <vector>
+#include <netdutils/InternetAddresses.h>
+
#include "DnsResolver.h"
#include "netd_resolv/resolv.h"
#include "params.h"
@@ -102,7 +104,7 @@
int nscount; // number of name srvers
uint16_t id; // current message id
std::vector<std::string> search_domains{}; // domains to search
- sockaddr_union nsaddrs[MAXNS];
+ std::vector<android::netdutils::IPSockAddr> nsaddrs;
android::base::unique_fd nssocks[MAXNS]; // UDP sockets to nameservers
unsigned ndots : 4; // threshold for initial abs. query
unsigned _mark; // If non-0 SET_MARK to _mark on all request sockets
@@ -125,7 +127,8 @@
/* Add a sample to the shared struct for the given netid and server, provided that the
* revision_id of the stored servers has not changed.
*/
-void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, const sockaddr* sa,
+void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id,
+ const android::netdutils::IPSockAddr& serverSockAddr,
const res_sample& sample, int max_samples);
// Calculate the round-trip-time from start time t0 and end time t1.