Improve resolver cache lock and thread synchronization
This change comprises:
1. Replace pthread_mutex with std::mutex to get RAII-style locking.
2. Use a single condition variable to avoid races between threads.
3. Add a predicate to the wait so that spurious wakeups are ignored
   (a minimal sketch of the resulting pattern follows this list).
4. Add a parameter to the GetResolverInfo API so that test cases can
   tell whether a timeout happened among concurrent DNS queries, and
   dump it in the bugreport (a usage sketch follows the test script
   below).
5. Verify that the test cases GetAddrInfoV6_concurrent,
   GetAddrInfoStressTest_Binder_100 and
   GetAddrInfoStressTest_Binder_100000 do not pass merely because of a
   timeout on concurrent DNS queries.
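A minimal sketch of the wait-with-predicate pattern described in items 1-3
(illustrative names only, not the actual res_cache.cpp symbols; a simplified
pending-request set stands in for the real linked list):
----------------------------
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <set>

std::mutex cache_mutex;              // std::lock_guard/std::unique_lock give RAII unlock
std::condition_variable cv;          // one cv shared by every waiting lookup
std::set<unsigned> pending_hashes;   // hashes of queries that are in flight
int timeout_count = 0;               // bumped when a waiter gives up

// Wait until the pending query with |hash| completes or 20 seconds pass.
// Returns false on timeout.
bool waitForPendingRequest(unsigned hash) {
    std::unique_lock lock(cache_mutex);
    // The predicate is re-evaluated after every wakeup, so spurious wakeups
    // simply put the thread back to sleep.
    const bool completed = cv.wait_for(lock, std::chrono::seconds(20),
                                       [&] { return pending_hashes.count(hash) == 0; });
    if (!completed) ++timeout_count;
    return completed;
}

// Called by the thread that received the answer: remove the entry, then notify.
void completePendingRequest(unsigned hash) {
    {
        std::lock_guard guard(cache_mutex);
        pending_hashes.erase(hash);
    }
    cv.notify_all();
}
----------------------------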
Bug: 120794954
Test: runtests.sh pass
Test: run ResolverTest.GetAddrInfoV6_concurrent 100 times
Test: run ResolverTest.GetAddrInfoStressTest_Binder_100 100 times
Test: run ResolverTest.GetAddrInfoStressTest_Binder_100000 100 times
The test script:
----------------------------
for i in $(seq 1 100)
do
  adb shell resolv_integration_test --gtest_filter=ResolverTest.GetAddrInfoV6_concurrent
  adb shell resolv_integration_test --gtest_filter=ResolverTest.GetAddrInfoStressTest_Binder_100
  adb shell resolv_integration_test --gtest_filter=ResolverTest.GetAddrInfoStressTest_Binder_100000
done
exit 0
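A hedged usage sketch of the new out-parameter described in item 4 (the helper
name and include path are illustrative, not part of this change):
----------------------------
#include <netd_resolv/stats.h>  // assumed include path for the declaration

// Hypothetical helper: returns wait_for_pending_req_timeout_count for |netid|,
// or -1 on failure (no resolver info for that network, or an internal error).
int getPendingReqTimeoutCount(unsigned netid) {
    int nscount = 0, dcount = 0, timeout_count = 0;
    sockaddr_storage servers[MAXNS];
    char domains[MAXDNSRCH][MAXDNSRCHPATH];
    __res_params params;
    res_stats stats[MAXNS];
    const int revision_id = android_net_res_stats_get_info_for_net(
            netid, &nscount, servers, &dcount, domains, &params, stats, &timeout_count);
    return (revision_id == -1) ? -1 : timeout_count;
}
----------------------------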
Change-Id: I4bdc394ba7ded7a6b7239f2d35b559a4262cb7b9
diff --git a/resolv/dns_responder/dns_responder.cpp b/resolv/dns_responder/dns_responder.cpp
index 3da5f70..fe116b8 100644
--- a/resolv/dns_responder/dns_responder.cpp
+++ b/resolv/dns_responder/dns_responder.cpp
@@ -723,6 +723,13 @@
size_t response_len = sizeof(response);
if (handleDNSRequest(buffer, len, response, &response_len) &&
response_len > 0) {
+ // Place the wait after handleDNSRequest() so that a test case can check the number of
+ // queries received before the response is sent.
+ std::unique_lock guard(cv_mutex_for_deferred_resp_);
+ cv_for_deferred_resp_.wait(guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) {
+ return !deferred_resp_;
+ });
+
len = sendto(socket_, response, response_len, 0,
reinterpret_cast<const sockaddr*>(&sa), sa_len);
std::string host_str =
@@ -916,4 +923,12 @@
return true;
}
+void DNSResponder::setDeferredResp(bool deferred_resp) {
+ std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
+ deferred_resp_ = deferred_resp;
+ if (!deferred_resp_) {
+ cv_for_deferred_resp_.notify_one();
+ }
+}
+
} // namespace test
diff --git a/resolv/dns_responder/dns_responder.h b/resolv/dns_responder/dns_responder.h
index 0005e44..360ddfc 100644
--- a/resolv/dns_responder/dns_responder.h
+++ b/resolv/dns_responder/dns_responder.h
@@ -79,6 +79,7 @@
void clearQueries();
std::condition_variable& getCv() { return cv; }
std::mutex& getCvMutex() { return cv_mutex_; }
+ void setDeferredResp(bool deferred_resp);
private:
// Key used for accessing mappings.
@@ -159,6 +160,10 @@
std::mutex update_mutex_;
std::condition_variable cv;
std::mutex cv_mutex_;
+
+ std::condition_variable cv_for_deferred_resp_;
+ std::mutex cv_mutex_for_deferred_resp_;
+ bool deferred_resp_ GUARDED_BY(cv_mutex_for_deferred_resp_) = false;
};
} // namespace test
diff --git a/resolv/include/netd_resolv/resolv_stub.h b/resolv/include/netd_resolv/resolv_stub.h
index 9e09b44..0c4dcd9 100644
--- a/resolv/include/netd_resolv/resolv_stub.h
+++ b/resolv/include/netd_resolv/resolv_stub.h
@@ -50,7 +50,8 @@
int (*android_net_res_stats_get_info_for_net)(unsigned netid, int* nscount,
sockaddr_storage servers[MAXNS], int* dcount,
char domains[MAXDNSRCH][MAXDNSRCHPATH],
- __res_params* params, res_stats stats[MAXNS]);
+ __res_params* params, res_stats stats[MAXNS],
+ int* wait_for_pending_req_timeout_count);
void (*android_net_res_stats_get_usable_servers)(const __res_params* params, res_stats stats[],
int nscount, bool valid_servers[]);
diff --git a/resolv/include/netd_resolv/stats.h b/resolv/include/netd_resolv/stats.h
index c504115..e078e72 100644
--- a/resolv/include/netd_resolv/stats.h
+++ b/resolv/include/netd_resolv/stats.h
@@ -51,7 +51,8 @@
LIBNETD_RESOLV_PUBLIC int android_net_res_stats_get_info_for_net(
unsigned netid, int* nscount, sockaddr_storage servers[MAXNS], int* dcount,
- char domains[MAXDNSRCH][MAXDNSRCHPATH], __res_params* params, res_stats stats[MAXNS]);
+ char domains[MAXDNSRCH][MAXDNSRCHPATH], __res_params* params, res_stats stats[MAXNS],
+ int* wait_for_pending_req_timeout_count);
// Returns an array of bools indicating which servers are considered good
LIBNETD_RESOLV_PUBLIC void android_net_res_stats_get_usable_servers(const __res_params* params,
diff --git a/resolv/res_cache.cpp b/resolv/res_cache.cpp
index e7a3b9b..756844c 100644
--- a/resolv/res_cache.cpp
+++ b/resolv/res_cache.cpp
@@ -32,13 +32,13 @@
constexpr bool kDumpData = false;
#define LOG_TAG "res_cache"
-#include <pthread.h>
#include <resolv.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
+#include <mutex>
#include <arpa/inet.h>
#include <arpa/nameser.h>
@@ -48,6 +48,7 @@
#include <netdb.h>
#include <android-base/logging.h>
+#include <android-base/thread_annotations.h>
#include "res_state_ext.h"
#include "resolv_cache.h"
@@ -1122,13 +1123,7 @@
*/
/* Maximum time for a thread to wait for a pending request */
-#define PENDING_REQUEST_TIMEOUT 20;
-
-typedef struct pending_req_info {
- unsigned int hash;
- pthread_cond_t cond;
- struct pending_req_info* next;
-} PendingReqInfo;
+constexpr int PENDING_REQUEST_TIMEOUT = 20;
typedef struct resolv_cache {
int max_entries;
@@ -1136,7 +1131,10 @@
Entry mru_list;
int last_id;
Entry* entries;
- PendingReqInfo pending_requests;
+ struct pending_req_info {
+ unsigned int hash;
+ struct pending_req_info* next;
+ } pending_requests;
} Cache;
struct resolv_cache_info {
@@ -1151,102 +1149,87 @@
struct res_stats nsstats[MAXNS];
char defdname[MAXDNSRCHPATH];
int dnsrch_offset[MAXDNSRCH + 1]; // offsets into defdname
+ int wait_for_pending_req_timeout_count;
};
-static pthread_once_t _res_cache_once = PTHREAD_ONCE_INIT;
-static void res_cache_init(void);
+// A helper class for the Clang Thread Safety Analysis to deal with
+// std::unique_lock.
+class SCOPED_CAPABILITY ScopedAssumeLocked {
+ public:
+ ScopedAssumeLocked(std::mutex& mutex) ACQUIRE(mutex) {}
+ ~ScopedAssumeLocked() RELEASE() {}
+};
-// lock protecting everything in the _resolve_cache_info structs (next ptr, etc)
-static pthread_mutex_t res_cache_list_lock;
+// lock protecting everything in the resolv_cache_info structs (next ptr, etc)
+static std::mutex cache_mutex;
+static std::condition_variable cv;
/* gets cache associated with a network, or NULL if none exists */
-static struct resolv_cache* find_named_cache_locked(unsigned netid);
+static struct resolv_cache* find_named_cache_locked(unsigned netid) REQUIRES(cache_mutex);
+static resolv_cache* get_res_cache_for_net_locked(unsigned netid) REQUIRES(cache_mutex);
-static void _cache_flush_pending_requests_locked(struct resolv_cache* cache) {
- struct pending_req_info *ri, *tmp;
- if (cache) {
- ri = cache->pending_requests.next;
+static void cache_flush_pending_requests_locked(struct resolv_cache* cache) {
+ resolv_cache::pending_req_info *ri, *tmp;
+ if (!cache) return;
- while (ri) {
- tmp = ri;
- ri = ri->next;
- pthread_cond_broadcast(&tmp->cond);
+ ri = cache->pending_requests.next;
- pthread_cond_destroy(&tmp->cond);
- free(tmp);
- }
-
- cache->pending_requests.next = NULL;
- }
-}
-
-/* Return 0 if no pending request is found matching the key.
- * If a matching request is found the calling thread will wait until
- * the matching request completes, then update *cache and return 1. */
-static int _cache_check_pending_request_locked(struct resolv_cache** cache, Entry* key,
- unsigned netid) {
- struct pending_req_info *ri, *prev;
- int exist = 0;
-
- if (*cache && key) {
- ri = (*cache)->pending_requests.next;
- prev = &(*cache)->pending_requests;
- while (ri) {
- if (ri->hash == key->hash) {
- exist = 1;
- break;
- }
- prev = ri;
- ri = ri->next;
- }
-
- if (!exist) {
- ri = (struct pending_req_info*) calloc(1, sizeof(struct pending_req_info));
- if (ri) {
- ri->hash = key->hash;
- pthread_cond_init(&ri->cond, NULL);
- prev->next = ri;
- }
- } else {
- struct timespec ts = {0, 0};
- VLOG << "Waiting for previous request";
- ts.tv_sec = _time_now() + PENDING_REQUEST_TIMEOUT;
- pthread_cond_timedwait(&ri->cond, &res_cache_list_lock, &ts);
- /* Must update *cache as it could have been deleted. */
- *cache = find_named_cache_locked(netid);
- }
+ while (ri) {
+ tmp = ri;
+ ri = ri->next;
+ free(tmp);
}
- return exist;
+ cache->pending_requests.next = NULL;
+ cv.notify_all();
}
-/* notify any waiting thread that waiting on a request
- * matching the key has been added to the cache */
-static void _cache_notify_waiting_tid_locked(struct resolv_cache* cache, Entry* key) {
- struct pending_req_info *ri, *prev;
+// Returns true if there is a pending request in |cache| matching |key|.
+// Returns false if no pending request matches the key; optionally links
+// a new one if append_if_not_found is true.
+static bool cache_has_pending_request_locked(resolv_cache* cache, const Entry* key,
+ bool append_if_not_found) {
+ if (!cache || !key) return false;
- if (cache && key) {
- ri = cache->pending_requests.next;
- prev = &cache->pending_requests;
- while (ri) {
- if (ri->hash == key->hash) {
- pthread_cond_broadcast(&ri->cond);
- break;
- }
- prev = ri;
- ri = ri->next;
+ resolv_cache::pending_req_info* ri = cache->pending_requests.next;
+ resolv_cache::pending_req_info* prev = &cache->pending_requests;
+ while (ri) {
+ if (ri->hash == key->hash) {
+ return true;
}
+ prev = ri;
+ ri = ri->next;
+ }
- // remove item from list and destroy
+ if (append_if_not_found) {
+ ri = (resolv_cache::pending_req_info*)calloc(1, sizeof(resolv_cache::pending_req_info));
if (ri) {
- prev->next = ri->next;
- pthread_cond_destroy(&ri->cond);
- free(ri);
+ ri->hash = key->hash;
+ prev->next = ri;
}
}
+ return false;
+}
+
+// Notify all threads that the cache entry |key| has become available
+static void _cache_notify_waiting_tid_locked(struct resolv_cache* cache, const Entry* key) {
+ if (!cache || !key) return;
+
+ resolv_cache::pending_req_info* ri = cache->pending_requests.next;
+ resolv_cache::pending_req_info* prev = &cache->pending_requests;
+ while (ri) {
+ if (ri->hash == key->hash) {
+ // remove item from list and destroy
+ prev->next = ri->next;
+ free(ri);
+ cv.notify_all();
+ return;
+ }
+ prev = ri;
+ ri = ri->next;
+ }
}
-/* notify the cache that the query failed */
void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags) {
// We should not notify with these flags.
if (flags & (ANDROID_RESOLV_NO_CACHE_STORE | ANDROID_RESOLV_NO_CACHE_LOOKUP)) {
@@ -1257,19 +1240,15 @@
if (!entry_init_key(key, query, querylen)) return;
- pthread_mutex_lock(&res_cache_list_lock);
+ std::lock_guard guard(cache_mutex);
cache = find_named_cache_locked(netid);
if (cache) {
_cache_notify_waiting_tid_locked(cache, key);
}
-
- pthread_mutex_unlock(&res_cache_list_lock);
}
-static resolv_cache_info* find_cache_info_locked(unsigned netid);
-
static void cache_flush_locked(Cache* cache) {
int nn;
@@ -1284,7 +1263,7 @@
}
// flush pending request
- _cache_flush_pending_requests_locked(cache);
+ cache_flush_pending_requests_locked(cache);
cache->mru_list.mru_next = cache->mru_list.mru_prev = &cache->mru_list;
cache->num_entries = 0;
@@ -1293,7 +1272,7 @@
VLOG << "*** DNS CACHE FLUSHED ***";
}
-static struct resolv_cache* _resolv_cache_create(void) {
+static resolv_cache* resolv_cache_create() {
struct resolv_cache* cache;
cache = (struct resolv_cache*) calloc(sizeof(*cache), 1);
@@ -1465,60 +1444,74 @@
}
}
+// gets a resolv_cache_info associated with a network, or NULL if not found
+static resolv_cache_info* find_cache_info_locked(unsigned netid) REQUIRES(cache_mutex);
+
ResolvCacheStatus _resolv_cache_lookup(unsigned netid, const void* query, int querylen,
void* answer, int answersize, int* answerlen,
uint32_t flags) {
if (flags & ANDROID_RESOLV_NO_CACHE_LOOKUP) {
return RESOLV_CACHE_SKIP;
}
- Entry key[1];
+ Entry key;
Entry** lookup;
Entry* e;
time_t now;
Cache* cache;
- ResolvCacheStatus result = RESOLV_CACHE_NOTFOUND;
-
VLOG << __func__ << ": lookup";
dump_query((u_char*) query, querylen);
/* we don't cache malformed queries */
- if (!entry_init_key(key, query, querylen)) {
+ if (!entry_init_key(&key, query, querylen)) {
VLOG << __func__ << ": unsupported query";
return RESOLV_CACHE_UNSUPPORTED;
}
/* lookup cache */
- pthread_once(&_res_cache_once, res_cache_init);
- pthread_mutex_lock(&res_cache_list_lock);
-
+ std::unique_lock lock(cache_mutex);
+ ScopedAssumeLocked assume_lock(cache_mutex);
cache = find_named_cache_locked(netid);
if (cache == NULL) {
- result = RESOLV_CACHE_UNSUPPORTED;
- goto Exit;
+ return RESOLV_CACHE_UNSUPPORTED;
}
/* see the description of _lookup_p to understand this.
* the function always return a non-NULL pointer.
*/
- lookup = _cache_lookup_p(cache, key);
+ lookup = _cache_lookup_p(cache, &key);
e = *lookup;
if (e == NULL) {
VLOG << "NOT IN CACHE";
// If it is no-cache-store mode, we won't wait for possible query.
if (flags & ANDROID_RESOLV_NO_CACHE_STORE) {
- result = RESOLV_CACHE_SKIP;
- goto Exit;
+ return RESOLV_CACHE_SKIP;
}
- // calling thread will wait if an outstanding request is found
- // that matching this query
- if (!_cache_check_pending_request_locked(&cache, key, netid) || cache == NULL) {
- goto Exit;
+
+ if (!cache_has_pending_request_locked(cache, &key, true)) {
+ return RESOLV_CACHE_NOTFOUND;
+
} else {
- lookup = _cache_lookup_p(cache, key);
+ VLOG << "Waiting for previous request";
+ // Wait until (1) a timeout occurs, OR
+ // (2) the cv is notified AND there is no pending request matching |key|
+ // (the cv notifier deletes the pending request before notifying).
+ bool ret = cv.wait_for(lock, std::chrono::seconds(PENDING_REQUEST_TIMEOUT),
+ [netid, &cache, &key]() REQUIRES(cache_mutex) {
+ // Must update cache as it could have been deleted
+ cache = find_named_cache_locked(netid);
+ return !cache_has_pending_request_locked(cache, &key, false);
+ });
+ if (ret == false) {
+ resolv_cache_info* info = find_cache_info_locked(netid);
+ if (info != NULL) {
+ info->wait_for_pending_req_timeout_count++;
+ }
+ }
+ lookup = _cache_lookup_p(cache, &key);
e = *lookup;
if (e == NULL) {
- goto Exit;
+ return RESOLV_CACHE_NOTFOUND;
}
}
}
@@ -1530,15 +1523,14 @@
VLOG << " NOT IN CACHE (STALE ENTRY " << *lookup << "DISCARDED)";
dump_query(e->query, e->querylen);
_cache_remove_p(cache, lookup);
- goto Exit;
+ return RESOLV_CACHE_NOTFOUND;
}
*answerlen = e->answerlen;
if (e->answerlen > answersize) {
/* NOTE: we return UNSUPPORTED if the answer buffer is too short */
- result = RESOLV_CACHE_UNSUPPORTED;
VLOG << " ANSWER TOO LONG";
- goto Exit;
+ return RESOLV_CACHE_UNSUPPORTED;
}
memcpy(answer, e->answer, e->answerlen);
@@ -1549,12 +1541,8 @@
entry_mru_add(e, &cache->mru_list);
}
- VLOG << "FOUND IN CACHE entry=" << e;
- result = RESOLV_CACHE_FOUND;
-
-Exit:
- pthread_mutex_unlock(&res_cache_list_lock);
- return result;
+ VLOG << " FOUND IN CACHE entry=" << e;
+ return RESOLV_CACHE_FOUND;
}
void _resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
@@ -1572,11 +1560,11 @@
return;
}
- pthread_mutex_lock(&res_cache_list_lock);
+ std::lock_guard guard(cache_mutex);
cache = find_named_cache_locked(netid);
if (cache == NULL) {
- goto Exit;
+ return;
}
VLOG << __func__ << ": query:";
@@ -1592,7 +1580,8 @@
if (e != NULL) { /* should not happen */
VLOG << __func__ << ": ALREADY IN CACHE (" << e << ") ? IGNORING ADD";
- goto Exit;
+ _cache_notify_waiting_tid_locked(cache, key);
+ return;
}
if (cache->num_entries >= cache->max_entries) {
@@ -1605,7 +1594,8 @@
e = *lookup;
if (e != NULL) {
VLOG << __func__ << ": ALREADY IN CACHE (" << e << ") ? IGNORING ADD";
- goto Exit;
+ _cache_notify_waiting_tid_locked(cache, key);
+ return;
}
}
@@ -1617,24 +1607,18 @@
_cache_add_p(cache, lookup, e);
}
}
- cache_dump_mru(cache);
-Exit:
- if (cache != NULL) {
- _cache_notify_waiting_tid_locked(cache, key);
- }
- pthread_mutex_unlock(&res_cache_list_lock);
+ cache_dump_mru(cache);
+ _cache_notify_waiting_tid_locked(cache, key);
}
-// Head of the list of caches. Protected by _res_cache_list_lock.
-static struct resolv_cache_info res_cache_list;
+// Head of the list of caches.
+static struct resolv_cache_info res_cache_list GUARDED_BY(cache_mutex);
// insert resolv_cache_info into the list of resolv_cache_infos
static void insert_cache_info_locked(resolv_cache_info* cache_info);
// creates a resolv_cache_info
static resolv_cache_info* create_cache_info();
-// gets a resolv_cache_info associated with a network, or NULL if not found
-static resolv_cache_info* find_cache_info_locked(unsigned netid);
// empty the nameservers set for the named cache
static void free_nameservers_locked(resolv_cache_info* cache_info);
// return 1 if the provided list of name servers differs from the list of name servers
@@ -1644,20 +1628,11 @@
// clears the stats samples contained within the given cache_info
static void res_cache_clear_stats_locked(resolv_cache_info* cache_info);
-static void res_cache_init(void) {
- memset(&res_cache_list, 0, sizeof(res_cache_list));
- pthread_mutex_init(&res_cache_list_lock, NULL);
-}
-
// public API for netd to query if name server is set on specific netid
bool resolv_has_nameservers(unsigned netid) {
- pthread_once(&_res_cache_once, res_cache_init);
- pthread_mutex_lock(&res_cache_list_lock);
+ std::lock_guard guard(cache_mutex);
resolv_cache_info* info = find_cache_info_locked(netid);
- const bool ret = (info != nullptr) && (info->nscount > 0);
- pthread_mutex_unlock(&res_cache_list_lock);
-
- return ret;
+ return (info != nullptr) && (info->nscount > 0);
}
// look up the named cache, and creates one if needed
@@ -1666,7 +1641,7 @@
if (!cache) {
resolv_cache_info* cache_info = create_cache_info();
if (cache_info) {
- cache = _resolv_cache_create();
+ cache = resolv_cache_create();
if (cache) {
cache_info->cache = cache;
cache_info->netid = netid;
@@ -1680,8 +1655,7 @@
}
void resolv_delete_cache_for_net(unsigned netid) {
- pthread_once(&_res_cache_once, res_cache_init);
- pthread_mutex_lock(&res_cache_list_lock);
+ std::lock_guard guard(cache_mutex);
struct resolv_cache_info* prev_cache_info = &res_cache_list;
@@ -1700,14 +1674,13 @@
prev_cache_info = prev_cache_info->next;
}
-
- pthread_mutex_unlock(&res_cache_list_lock);
}
static resolv_cache_info* create_cache_info() {
return (struct resolv_cache_info*) calloc(sizeof(struct resolv_cache_info), 1);
}
+// TODO: convert this to a simple and efficient C++ container.
static void insert_cache_info_locked(struct resolv_cache_info* cache_info) {
struct resolv_cache_info* last;
for (last = &res_cache_list; last->next; last = last->next) {}
@@ -1771,9 +1744,7 @@
}
}
- pthread_once(&_res_cache_once, res_cache_init);
- pthread_mutex_lock(&res_cache_list_lock);
-
+ std::lock_guard guard(cache_mutex);
// creates the cache if not created
get_res_cache_for_net_locked(netid);
@@ -1847,7 +1818,6 @@
*offset = -1; /* cache_info->dnsrch_offset has MAXDNSRCH+1 items */
}
- pthread_mutex_unlock(&res_cache_list_lock);
return 0;
}
@@ -1898,9 +1868,7 @@
return;
}
- pthread_once(&_res_cache_once, res_cache_init);
- pthread_mutex_lock(&res_cache_list_lock);
-
+ std::lock_guard guard(cache_mutex);
resolv_cache_info* info = find_cache_info_locked(statp->netid);
if (info != NULL) {
int nserv;
@@ -1938,7 +1906,6 @@
*pp++ = &statp->defdname[0] + *p++;
}
}
- pthread_mutex_unlock(&res_cache_list_lock);
}
/* Resolver reachability statistics. */
@@ -1970,14 +1937,14 @@
struct sockaddr_storage servers[MAXNS], int* dcount,
char domains[MAXDNSRCH][MAXDNSRCHPATH],
struct __res_params* params,
- struct res_stats stats[MAXNS]) {
+ struct res_stats stats[MAXNS],
+ int* wait_for_pending_req_timeout_count) {
int revision_id = -1;
- pthread_mutex_lock(&res_cache_list_lock);
+ std::lock_guard guard(cache_mutex);
resolv_cache_info* info = find_cache_info_locked(netid);
if (info) {
if (info->nscount > MAXNS) {
- pthread_mutex_unlock(&res_cache_list_lock);
VLOG << __func__ << ": nscount " << info->nscount << " > MAXNS " << MAXNS;
errno = EFAULT;
return -1;
@@ -1990,19 +1957,16 @@
// - there is only one address per addrinfo thanks to numeric resolution
int addrlen = info->nsaddrinfo[i]->ai_addrlen;
if (addrlen < (int) sizeof(struct sockaddr) || addrlen > (int) sizeof(servers[0])) {
- pthread_mutex_unlock(&res_cache_list_lock);
VLOG << __func__ << ": nsaddrinfo[" << i << "].ai_addrlen == " << addrlen;
errno = EMSGSIZE;
return -1;
}
if (info->nsaddrinfo[i]->ai_addr == NULL) {
- pthread_mutex_unlock(&res_cache_list_lock);
VLOG << __func__ << ": nsaddrinfo[" << i << "].ai_addr == NULL";
errno = ENOENT;
return -1;
}
if (info->nsaddrinfo[i]->ai_next != NULL) {
- pthread_mutex_unlock(&res_cache_list_lock);
VLOG << __func__ << ": nsaddrinfo[" << i << "].ai_next != NULL";
errno = ENOTUNIQ;
return -1;
@@ -2028,38 +1992,32 @@
*dcount = i;
*params = info->params;
revision_id = info->revision_id;
+ *wait_for_pending_req_timeout_count = info->wait_for_pending_req_timeout_count;
}
- pthread_mutex_unlock(&res_cache_list_lock);
return revision_id;
}
int resolv_cache_get_resolver_stats(unsigned netid, __res_params* params, res_stats stats[MAXNS]) {
- int revision_id = -1;
- pthread_mutex_lock(&res_cache_list_lock);
-
+ std::lock_guard guard(cache_mutex);
resolv_cache_info* info = find_cache_info_locked(netid);
if (info) {
memcpy(stats, info->nsstats, sizeof(info->nsstats));
*params = info->params;
- revision_id = info->revision_id;
+ return info->revision_id;
}
- pthread_mutex_unlock(&res_cache_list_lock);
- return revision_id;
+ return -1;
}
void _resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, int ns,
const res_sample* sample, int max_samples) {
if (max_samples <= 0) return;
- pthread_mutex_lock(&res_cache_list_lock);
-
+ std::lock_guard guard(cache_mutex);
resolv_cache_info* info = find_cache_info_locked(netid);
if (info && info->revision_id == revision_id) {
_res_cache_add_stats_sample_locked(&info->nsstats[ns], sample, max_samples);
}
-
- pthread_mutex_unlock(&res_cache_list_lock);
}
diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp
index 0cce121..4be5f48 100644
--- a/resolv/resolver_test.cpp
+++ b/resolv/resolver_test.cpp
@@ -120,12 +120,16 @@
bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
std::vector<std::string>* tlsServers, __res_params* params,
- std::vector<ResolverStats>* stats) {
+ std::vector<ResolverStats>* stats,
+ int* wait_for_pending_req_timeout_count) {
using android::net::INetd;
std::vector<int32_t> params32;
std::vector<int32_t> stats32;
+ std::vector<int32_t> wait_for_pending_req_timeout_count32{0};
auto rv = mDnsClient.netdService()->getResolverInfo(TEST_NETID, servers, domains,
- tlsServers, &params32, &stats32);
+ tlsServers, &params32, &stats32,
+ &wait_for_pending_req_timeout_count32);
+
if (!rv.isOk() || params32.size() != static_cast<size_t>(INetd::RESOLVER_PARAMS_COUNT)) {
return false;
}
@@ -140,6 +144,7 @@
params32[INetd::RESOLVER_PARAMS_MAX_SAMPLES]),
.base_timeout_msec = params32[INetd::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC],
};
+ *wait_for_pending_req_timeout_count = wait_for_pending_req_timeout_count32[0];
return ResolverStats::decodeAll(stats32, stats);
}
@@ -272,6 +277,17 @@
auto t1 = std::chrono::steady_clock::now();
ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
std::chrono::duration<double>(t1 - t0).count());
+
+ std::vector<std::string> res_servers;
+ std::vector<std::string> res_domains;
+ std::vector<std::string> res_tls_servers;
+ __res_params res_params;
+ std::vector<ResolverStats> res_stats;
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
+ EXPECT_EQ(0, wait_for_pending_req_timeout_count);
+
ASSERT_NO_FATAL_FAILURE(mDnsClient.ShutdownDNSServers(&dns));
}
@@ -463,8 +479,9 @@
std::vector<std::string> res_tls_servers;
__res_params res_params;
std::vector<ResolverStats> res_stats;
- ASSERT_TRUE(
- GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params, &res_stats));
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
EXPECT_EQ(servers.size(), res_servers.size());
EXPECT_EQ(domains.size(), res_domains.size());
EXPECT_EQ(0U, res_tls_servers.size());
@@ -571,6 +588,95 @@
EXPECT_EQ(kIp6LocalHostAddr, ToString(result));
}
+// Verify that the resolver correctly handles multiple simultaneous queries.
+// step 1: set dns server#1 into deferred responding mode.
+// step 2: thread#1 queries "hello.example.com." --> resolver sends the query to server#1.
+// step 3: thread#2 queries "hello.example.com." --> resolver holds the request and waits for
+// the response to the pending query previously sent by thread#1.
+// step 4: thread#3 queries "konbanha.example.com." --> resolver sends the query to server#3,
+// which responds to the resolver immediately.
+// step 5: check that server#1 got 1 query from thread#1, server#2 got 0 queries, and server#3 got 1 query.
+// step 6: resume dns server#1 so it responds to the dns query from step 2.
+// step 7: thread#1 and thread#2 should return from their DNS queries after step 6. Also, check that
+// the number of queries on server#2 is still 0, to ensure thread#2 did not wake up unexpectedly
+// before being signaled by thread#1.
+TEST_F(ResolverTest, GetAddrInfoV4_deferred_resp) {
+ const char* listen_addr1 = "127.0.0.9";
+ const char* listen_addr2 = "127.0.0.10";
+ const char* listen_addr3 = "127.0.0.11";
+ const char* listen_srv = "53";
+ const char* host_name_deferred = "hello.example.com.";
+ const char* host_name_normal = "konbanha.example.com.";
+ test::DNSResponder dns1(listen_addr1, listen_srv, 250, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns2(listen_addr2, listen_srv, 250, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns3(listen_addr3, listen_srv, 250, ns_rcode::ns_r_servfail);
+ dns1.addMapping(host_name_deferred, ns_type::ns_t_a, "1.2.3.4");
+ dns2.addMapping(host_name_deferred, ns_type::ns_t_a, "1.2.3.4");
+ dns3.addMapping(host_name_normal, ns_type::ns_t_a, "1.2.3.5");
+ ASSERT_TRUE(dns1.startServer());
+ ASSERT_TRUE(dns2.startServer());
+ ASSERT_TRUE(dns3.startServer());
+ const std::vector<std::string> servers_for_t1 = {listen_addr1};
+ const std::vector<std::string> servers_for_t2 = {listen_addr2};
+ const std::vector<std::string> servers_for_t3 = {listen_addr3};
+ addrinfo hints = {.ai_family = AF_INET};
+ const std::vector<int> params = {300, 25, 8, 8, 5000};
+ bool t3_task_done = false;
+
+ dns1.setDeferredResp(true);
+ std::thread t1([&, this]() {
+ ASSERT_TRUE(
+ mDnsClient.SetResolversForNetwork(servers_for_t1, kDefaultSearchDomains, params));
+ ScopedAddrinfo result = safe_getaddrinfo(host_name_deferred, nullptr, &hints);
+ // t3's dns query should be answered first
+ EXPECT_TRUE(t3_task_done);
+ EXPECT_EQ(1U, GetNumQueries(dns1, host_name_deferred));
+ EXPECT_TRUE(result != nullptr);
+ EXPECT_EQ("1.2.3.4", ToString(result));
+ });
+
+ // ensuring t1 and t2 handler functions are processed in order
+ usleep(100 * 1000);
+ std::thread t2([&, this]() {
+ ASSERT_TRUE(
+ mDnsClient.SetResolversForNetwork(servers_for_t2, kDefaultSearchDomains, params));
+ ScopedAddrinfo result = safe_getaddrinfo(host_name_deferred, nullptr, &hints);
+ EXPECT_TRUE(t3_task_done);
+ EXPECT_EQ(0U, GetNumQueries(dns2, host_name_deferred));
+ EXPECT_TRUE(result != nullptr);
+ EXPECT_EQ("1.2.3.4", ToString(result));
+
+ std::vector<std::string> res_servers;
+ std::vector<std::string> res_domains;
+ std::vector<std::string> res_tls_servers;
+ __res_params res_params;
+ std::vector<ResolverStats> res_stats;
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
+ EXPECT_EQ(0, wait_for_pending_req_timeout_count);
+ });
+
+ // ensuring t2 and t3 handler functions are processed in order
+ usleep(100 * 1000);
+ std::thread t3([&, this]() {
+ ASSERT_TRUE(
+ mDnsClient.SetResolversForNetwork(servers_for_t3, kDefaultSearchDomains, params));
+ ScopedAddrinfo result = safe_getaddrinfo(host_name_normal, nullptr, &hints);
+ EXPECT_EQ(1U, GetNumQueries(dns1, host_name_deferred));
+ EXPECT_EQ(0U, GetNumQueries(dns2, host_name_deferred));
+ EXPECT_EQ(1U, GetNumQueries(dns3, host_name_normal));
+ EXPECT_TRUE(result != nullptr);
+ EXPECT_EQ("1.2.3.5", ToString(result));
+
+ t3_task_done = true;
+ dns1.setDeferredResp(false);
+ });
+ t3.join();
+ t1.join();
+ t2.join();
+}
+
TEST_F(ResolverTest, MultidomainResolution) {
constexpr char host_name[] = "nihao.example2.com.";
std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };
@@ -732,6 +838,16 @@
for (std::thread& thread : threads) {
thread.join();
}
+
+ std::vector<std::string> res_servers;
+ std::vector<std::string> res_domains;
+ std::vector<std::string> res_tls_servers;
+ __res_params res_params;
+ std::vector<ResolverStats> res_stats;
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
+ EXPECT_EQ(0, wait_for_pending_req_timeout_count);
}
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
@@ -758,8 +874,9 @@
std::vector<std::string> res_tls_servers;
__res_params res_params;
std::vector<ResolverStats> res_stats;
- ASSERT_TRUE(
- GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params, &res_stats));
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
EXPECT_EQ(0U, res_servers.size());
EXPECT_EQ(0U, res_domains.size());
EXPECT_EQ(0U, res_tls_servers.size());
@@ -864,8 +981,9 @@
std::vector<std::string> res_tls_servers;
__res_params res_params;
std::vector<ResolverStats> res_stats;
- ASSERT_TRUE(
- GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params, &res_stats));
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
// Check the size of the stats and its contents.
EXPECT_EQ(static_cast<size_t>(MAXNS), res_servers.size());
@@ -915,8 +1033,9 @@
std::vector<std::string> res_tls_servers;
__res_params res_params;
std::vector<ResolverStats> res_stats;
- ASSERT_TRUE(
- GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params, &res_stats));
+ int wait_for_pending_req_timeout_count;
+ ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_tls_servers, &res_params,
+ &res_stats, &wait_for_pending_req_timeout_count));
EXPECT_EQ(1, res_stats[0].timeouts);
EXPECT_EQ(1, res_stats[1].errors);