Merge qt-r1-dev-plus-aosp-without-vendor (5817612) into stage-aosp-master

No content change.

Bug: 135460123
Change-Id: Ib32a7aaba614052187717674add4b517857f78e0
Merged-In: If96423ed9f30515c5ff7bcba69d789ce68eb6ba3
diff --git a/Android.bp b/Android.bp
index 059fa86..6d566d6 100644
--- a/Android.bp
+++ b/Android.bp
@@ -157,22 +157,23 @@
     srcs: [
         "tests/dns_responder/dns_responder.cpp",
         "dnsresolver_binder_test.cpp",
-        "resolver_test.cpp",
+        "resolv_integration_test.cpp",
     ],
     header_libs: [
         "libnetd_resolv_headers",
     ],
     shared_libs: [
         "libbpf_android",
-        "libbase",
         "libbinder",
         "libcrypto",
+        "liblog",
         "libnetd_client",
         "libssl",
         "libutils",
     ],
     static_libs: [
         "dnsresolver_aidl_interface-cpp",
+        "libbase",
         "libgmock",
         "libnetd_test_dnsresponder",
         "libnetd_test_metrics_listener",
@@ -197,19 +198,23 @@
     //TODO:  drop root privileges and make it be an real unit test.
     defaults: ["netd_defaults"],
     srcs: [
-        "dns_tls_test.cpp",
-        "libnetd_resolv_test.cpp",
-        "res_cache_test.cpp",
+        "resolv_cache_unit_test.cpp",
+        "resolv_tls_unit_test.cpp",
+        "resolv_unit_test.cpp",
     ],
     shared_libs: [
         "libbase",
         "libcrypto",
         "libcutils",
         "libssl",
+        "libbinder_ndk",
     ],
     static_libs: [
         "dnsresolver_aidl_interface-V2-cpp",
+        "dnsresolver_aidl_interface-V2-ndk_platform",
+        "netd_event_listener_interface-V1-ndk_platform",
         "libgmock",
+        "liblog",
         "libnetd_resolv",
         "libnetd_test_dnsresponder",
         "libnetd_test_resolv_utils",
diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp
index 9743acd..dc8cd96 100644
--- a/DnsProxyListener.cpp
+++ b/DnsProxyListener.cpp
@@ -118,18 +118,11 @@
     return (flags & NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS) != 0;
 }
 
-inline bool queryingViaTls(unsigned dns_netid) {
-    // TODO: The simpler PrivateDnsStatus should suffice here.
-    ExternalPrivateDnsStatus privateDnsStatus = {PrivateDnsMode::OFF, 0, {}};
-    gPrivateDnsConfiguration.getStatus(dns_netid, &privateDnsStatus);
-    switch (static_cast<PrivateDnsMode>(privateDnsStatus.mode)) {
+bool queryingViaTls(unsigned dns_netid) {
+    const auto privateDnsStatus = gPrivateDnsConfiguration.getStatus(dns_netid);
+    switch (privateDnsStatus.mode) {
         case PrivateDnsMode::OPPORTUNISTIC:
-            for (int i = 0; i < privateDnsStatus.numServers; i++) {
-                if (privateDnsStatus.serverStatus[i].validation == Validation::success) {
-                    return true;
-                }
-            }
-            return false;
+            return !privateDnsStatus.validatedServers().empty();
         case PrivateDnsMode::STRICT:
             return true;
         default:
diff --git a/DnsResolver.cpp b/DnsResolver.cpp
index 092ff3f..70ecc7e 100644
--- a/DnsResolver.cpp
+++ b/DnsResolver.cpp
@@ -30,7 +30,15 @@
     LOG(INFO) << __func__ << ": Initializing resolver";
     resolv_set_log_severity(android::base::WARNING);
 
-    android::net::gResNetdCallbacks = *callbacks;
+    using android::net::gApiLevel;
+    gApiLevel = android::base::GetUintProperty<uint64_t>("ro.build.version.sdk", 0);
+    using android::net::gResNetdCallbacks;
+    gResNetdCallbacks.check_calling_permission = callbacks->check_calling_permission;
+    gResNetdCallbacks.get_network_context = callbacks->get_network_context;
+    gResNetdCallbacks.log = callbacks->log;
+    if (gApiLevel >= 30) {
+        gResNetdCallbacks.tagSocket = callbacks->tagSocket;
+    }
     android::net::gDnsResolv = android::net::DnsResolver::getInstance();
     return android::net::gDnsResolv->start();
 }
@@ -41,8 +49,14 @@
 namespace {
 
 bool verifyCallbacks() {
-    return gResNetdCallbacks.check_calling_permission && gResNetdCallbacks.get_network_context &&
-           gResNetdCallbacks.log;
+    if (!(gResNetdCallbacks.check_calling_permission && gResNetdCallbacks.get_network_context &&
+          gResNetdCallbacks.log)) {
+        return false;
+    }
+    if (gApiLevel >= 30) {
+        return gResNetdCallbacks.tagSocket != nullptr;
+    }
+    return true;
 }
 
 }  // namespace
@@ -50,6 +64,7 @@
 DnsResolver* gDnsResolv = nullptr;
 ResolverNetdCallbacks gResNetdCallbacks;
 netdutils::Log gDnsResolverLog("dnsResolver");
+uint64_t gApiLevel = 0;
 
 DnsResolver* DnsResolver::getInstance() {
     // Instantiated on first use.
diff --git a/DnsResolver.h b/DnsResolver.h
index 6d8d2ce..2efa341 100644
--- a/DnsResolver.h
+++ b/DnsResolver.h
@@ -44,6 +44,7 @@
 extern DnsResolver* gDnsResolv;
 extern ResolverNetdCallbacks gResNetdCallbacks;
 extern netdutils::Log gDnsResolverLog;
+extern uint64_t gApiLevel;
 
 }  // namespace net
 }  // namespace android
diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp
index 838886a..14aadd4 100644
--- a/DnsTlsSocket.cpp
+++ b/DnsTlsSocket.cpp
@@ -39,6 +39,7 @@
 #include <netdutils/ThreadUtil.h>
 
 #include "private/android_filesystem_config.h"  // AID_DNS
+#include "resolv_private.h"
 
 // NOTE: Inject CA certificate for internal testing -- do NOT enable in production builds
 #ifndef RESOLV_INJECT_CA_CERTIFICATE
@@ -96,9 +97,7 @@
         return Status(errno);
     }
 
-    if (fchown(mSslFd.get(), AID_DNS, -1) == -1) {
-        LOG(WARNING) << "Failed to chown socket: %s" << strerror(errno);
-    }
+    resolv_tag_socket(mSslFd.get(), AID_DNS);
 
     const socklen_t len = sizeof(mMark);
     if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index c91367f..14bac16 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -126,37 +126,13 @@
     const auto netPair = mPrivateDnsTransports.find(netId);
     if (netPair != mPrivateDnsTransports.end()) {
         for (const auto& serverPair : netPair->second) {
-            if (serverPair.second == Validation::success) {
-                status.validatedServers.push_back(serverPair.first);
-            }
+            status.serversMap.emplace(serverPair.first, serverPair.second);
         }
     }
 
     return status;
 }
 
-void PrivateDnsConfiguration::getStatus(unsigned netId, ExternalPrivateDnsStatus* status) {
-    std::lock_guard guard(mPrivateDnsLock);
-
-    const auto mode = mPrivateDnsModes.find(netId);
-    if (mode == mPrivateDnsModes.end()) return;
-    status->mode = mode->second;
-
-    const auto netPair = mPrivateDnsTransports.find(netId);
-    if (netPair != mPrivateDnsTransports.end()) {
-        int count = 0;
-        for (const auto& serverPair : netPair->second) {
-            status->serverStatus[count].ss = serverPair.first.ss;
-            status->serverStatus[count].hostname =
-                    serverPair.first.name.empty() ? "" : serverPair.first.name.c_str();
-            status->serverStatus[count].validation = serverPair.second;
-            count++;
-            if (count >= MAXNS) break;  // Lose the rest
-        }
-        status->numServers = count;
-    }
-}
-
 void PrivateDnsConfiguration::clear(unsigned netId) {
     LOG(DEBUG) << "PrivateDnsConfiguration::clear(" << netId << ")";
     std::lock_guard guard(mPrivateDnsLock);
@@ -242,6 +218,9 @@
     auto& tracker = netPair->second;
     auto serverPair = tracker.find(server);
     if (serverPair == tracker.end()) {
+        // TODO: Consider not adding this server to the tracker since this server is not expected
+        // to be one of the private DNS servers for this network now. This could prevent this
+        // server from being included when dumping status.
         LOG(WARNING) << "Server " << addrToString(&server.ss)
                      << " was removed during private DNS validation";
         success = false;
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index d52db77..5800831 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -14,8 +14,7 @@
  * limitations under the License.
  */
 
-#ifndef NETD_RESOLV_PRIVATEDNSCONFIGURATION_H
-#define NETD_RESOLV_PRIVATEDNSCONFIGURATION_H
+#pragma once
 
 #include <list>
 #include <map>
@@ -37,31 +36,28 @@
 
 struct PrivateDnsStatus {
     PrivateDnsMode mode;
-    std::list<DnsTlsServer> validatedServers;
-};
+    std::map<DnsTlsServer, Validation, AddressComparator> serversMap;
 
-// TODO: remove this C-style struct and use PrivateDnsStatus everywhere.
-struct ExternalPrivateDnsStatus {
-    PrivateDnsMode mode;
-    int numServers;
-    struct PrivateDnsInfo {
-        sockaddr_storage ss;
-        const char* hostname;
-        Validation validation;
-    } serverStatus[MAXNS];
+    std::list<DnsTlsServer> validatedServers() const {
+        std::list<DnsTlsServer> servers;
+
+        for (const auto& pair : serversMap) {
+            if (pair.second == Validation::success) {
+                servers.push_back(pair.first);
+            }
+        }
+        return servers;
+    }
 };
 
 class PrivateDnsConfiguration {
   public:
     int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
-            const std::string& name, const std::string& caCert);
+            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
 
-    PrivateDnsStatus getStatus(unsigned netId);
+    PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock);
 
-    // DEPRECATED, use getStatus() above.
-    void getStatus(unsigned netId, ExternalPrivateDnsStatus* status);
-
-    void clear(unsigned netId);
+    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
 
   private:
     typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
@@ -88,5 +84,3 @@
 
 }  // namespace net
 }  // namespace android
-
-#endif /* NETD_RESOLV_PRIVATEDNSCONFIGURATION_H */
diff --git a/README-DoT.md b/README-DoT.md
new file mode 100644
index 0000000..56fe772
--- /dev/null
+++ b/README-DoT.md
@@ -0,0 +1,127 @@
+# DNS-over-TLS query forwarder design
+
+## Overview
+
+The DNS-over-TLS query forwarder consists of five classes:
+ * `DnsTlsDispatcher`
+ * `DnsTlsTransport`
+ * `DnsTlsQueryMap`
+ * `DnsTlsSessionCache`
+ * `DnsTlsSocket`
+
+`DnsTlsDispatcher` is a singleton class whose `query` method is the only public
+interface of the DNS-over-TLS forwarder.  `DnsTlsDispatcher` is just a table holding the
+`DnsTlsTransport` for each server (represented by a `DnsTlsServer` struct) and
+network.  `DnsTlsDispatcher` also blocks each query thread, waiting on a
+`std::future` returned by `DnsTlsTransport` that represents the response.
+
+`DnsTlsTransport` sends each query over a `DnsTlsSocket`, opening a
+new one if necessary.  It also has to listen for responses from the
+`DnsTlsSocket`, which happen on a different thread.
+`IDnsTlsSocketObserver` is an interface defining how `DnsTlsSocket` returns
+responses to `DnsTlsTransport`.
+
+`DnsTlsQueryMap` and `DnsTlsSessionCache` are helper classes owned by `DnsTlsTransport`.
+`DnsTlsQueryMap` handles ID renumbering and query-response pairing.
+`DnsTlsSessionCache` allows TLS session resumption.
+
+`DnsTlsSocket` interleaves all queries onto a single socket, and reports all
+responses to `DnsTlsTransport` (through the `IDnsTlsSocketObserver` interface).  It doesn't
+know anything about which queries correspond to which responses, and does not retain
+state to indicate whether there is an outstanding query.
+
+## Threading
+
+### Overall patterns
+
+For clarity, each of the five classes in this design is thread-safe and holds one lock.
+Classes that spawn a helper thread call `thread::join()` in their destructor to ensure
+that it is cleaned up appropriately.
+
+All the classes here make full use of Clang thread annotations (and also null-pointer
+annotations) to minimize the likelihood of a latent threading bug.  The unit tests are
+also heavily threaded to exercise this functionality.
+
+This code creates O(1) threads per socket, and does not create a new thread for each
+query or response.  However, DnsProxyListener does create a thread for each query.
+
+### Threading in `DnsTlsSocket`
+
+`DnsTlsSocket` can receive queries on any thread, and send them over a
+"reliable datagram pipe" (`socketpair()` in `SOCK_SEQPACKET` mode).
+The query method writes a struct (containing a pointer to the query) to the pipe
+from its thread, and the loop thread (which owns the SSL socket)
+reads off the other end of the pipe.  The pipe doesn't actually have a queue "inside";
+instead, any queueing happens by blocking the query thread until the
+socket thread can read the datagram off the other end.
+
+We need to pass messages between threads using a pipe, and not a condition variable
+or a thread-safe queue, because the socket thread has to be blocked
+in `poll()` waiting for data from the server, but also has to be woken
+up on inputs from the query threads.  Therefore, inputs from the query
+threads have to arrive on a socket, so that `poll()` can listen for them.
+(There can only be a single thread because [you can't use different threads
+to read and write in OpenSSL](https://www.openssl.org/blog/blog/2017/02/21/threads/)).
+
+## ID renumbering
+
+`DnsTlsDispatcher` accepts queries that have colliding ID numbers and still sends them on
+a single socket.  To avoid confusion at the server, `DnsTlsQueryMap` assigns each
+query a new ID for transmission, records the mapping from input IDs to sent IDs, and
+applies the inverse mapping to responses before returning them to the caller.
+
+`DnsTlsQueryMap` assigns each new query the ID number one greater than the largest
+ID number of an outstanding query.  This means that ID numbers are initially sequential
+and usually small.  If the largest possible ID number is already in use,
+`DnsTlsQueryMap` will scan the ID space to find an available ID, or fail the query
+if there are no available IDs.  Queries will not block waiting for an ID number to
+become available.
+
+## Time constants
+
+`DnsTlsSocket` imposes a 20-second inactivity timeout.  A socket that has been idle for
+20 seconds will be closed.  This sets the limit of tolerance for slow replies,
+which could happen as a result of malfunctioning authoritative DNS servers.
+If there are any pending queries, `DnsTlsTransport` will retry them.
+
+`DnsTlsQueryMap` imposes a retry limit of 3.  `DnsTlsTransport` will retry the query up
+to 3 times before reporting failure to `DnsTlsDispatcher`.
+This limit helps to ensure proper functioning in the case of a recursive resolver that
+is malfunctioning or is flooded with requests that are stalled due to malfunctioning
+authoritative servers.
+
+`DnsTlsDispatcher` maintains a 5-minute timeout.  Any `DnsTlsTransport` that has had no
+outstanding queries for 5 minutes will be destroyed at the next query on a different
+transport.
+This sets the limit on how long session tickets will be preserved during idle periods,
+because each `DnsTlsTransport` owns a `DnsTlsSessionCache`.  Imposing this timeout
+increases latency on the first query after an idle period, but also helps to avoid
+unbounded memory usage.
+
+`DnsTlsSessionCache` sets a limit of 5 sessions in each cache, expiring the oldest one
+when the limit is reached.  However, because the client code does not currently
+reuse sessions more than once, it should not be possible to hit this limit.
+
+## Testing
+
+Unit tests for DoT are in `resolv_tls_unit_test.cpp`. They cover all the classes except
+`DnsTlsSocket` (which requires `CAP_NET_ADMIN` because it uses `setsockopt(SO_MARK)`) and
+`DnsTlsSessionCache` (which requires integration with libssl).  These classes are
+exercised by the integration tests in `resolv_integration_test.cpp`.
+
+### Dependency Injection
+
+For unit testing, we would like to be able to mock out `DnsTlsSocket`.  This is
+particularly required for unit testing of `DnsTlsDispatcher` and `DnsTlsTransport`.
+To make these unit tests possible, this code uses a dependency injection pattern:
+`DnsTlsSocket` is produced by a `DnsTlsSocketFactory`, and both of these have a
+defined interface.
+
+`DnsTlsDispatcher`'s constructor takes an `IDnsTlsSocketFactory`,
+which in production is a `DnsTlsSocketFactory`.  However, in unit tests, we can
+substitute a test factory that returns a fake socket, so that the unit tests can
+run without actually connecting over TLS to a test server.  (The integration tests
+do actual TLS.)
+
+## Reference
+ * [BoringSSL API docs](https://commondatastorage.googleapis.com/chromium-boringssl-docs/headers.html)
diff --git a/README.md b/README.md
index b002a9c..35371ad 100644
--- a/README.md
+++ b/README.md
@@ -1,128 +1,3 @@
-# DNS-over-TLS query forwarder design
-
-## Overview
-
-The DNS-over-TLS query forwarder consists of five classes:
- * `DnsTlsDispatcher`
- * `DnsTlsTransport`
- * `DnsTlsQueryMap`
- * `DnsTlsSessionCache`
- * `DnsTlsSocket`
-
-`DnsTlsDispatcher` is a singleton class whose `query` method is the DnsTls's
-only public interface.  `DnsTlsDispatcher` is just a table holding the
-`DnsTlsTransport` for each server (represented by a `DnsTlsServer` struct) and
-network.  `DnsTlsDispatcher` also blocks each query thread, waiting on a
-`std::future` returned by `DnsTlsTransport` that represents the response.
-
-`DnsTlsTransport` sends each query over a `DnsTlsSocket`, opening a
-new one if necessary.  It also has to listen for responses from the
-`DnsTlsSocket`, which happen on a different thread.
-`IDnsTlsSocketObserver` is an interface defining how `DnsTlsSocket` returns
-responses to `DnsTlsTransport`.
-
-`DnsTlsQueryMap` and `DnsTlsSessionCache` are helper classes owned by `DnsTlsTransport`.
-`DnsTlsQueryMap` handles ID renumbering and query-response pairing.
-`DnsTlsSessionCache` allows TLS session resumption.
-
-`DnsTlsSocket` interleaves all queries onto a single socket, and reports all
-responses to `DnsTlsTransport` (through the `IDnsTlsObserver` interface).  It doesn't
-know anything about which queries correspond to which responses, and does not retain
-state to indicate whether there is an outstanding query.
-
-## Threading
-
-### Overall patterns
-
-For clarity, each of the five classes in this design is thread-safe and holds one lock.
-Classes that spawn a helper thread call `thread::join()` in their destructor to ensure
-that it is cleaned up appropriately.
-
-All the classes here make full use of Clang thread annotations (and also null-pointer
-annotations) to minimize the likelihood of a latent threading bug.  The unit tests are
-also heavily threaded to exercise this functionality.
-
-This code creates O(1) threads per socket, and does not create a new thread for each
-query or response.  However, DnsProxyListener does create a thread for each query.
-
-### Threading in `DnsTlsSocket`
-
-`DnsTlsSocket` can receive queries on any thread, and send them over a
-"reliable datagram pipe" (`socketpair()` in `SOCK_SEQPACKET` mode).
-The query method writes a struct (containing a pointer to the query) to the pipe
-from its thread, and the loop thread (which owns the SSL socket)
-reads off the other end of the pipe.  The pipe doesn't actually have a queue "inside";
-instead, any queueing happens by blocking the query thread until the
-socket thread can read the datagram off the other end.
-
-We need to pass messages between threads using a pipe, and not a condition variable
-or a thread-safe queue, because the socket thread has to be blocked
-in `poll()` waiting for data from the server, but also has to be woken
-up on inputs from the query threads.  Therefore, inputs from the query
-threads have to arrive on a socket, so that `poll()` can listen for them.
-(There can only be a single thread because [you can't use different threads
-to read and write in OpenSSL](https://www.openssl.org/blog/blog/2017/02/21/threads/)).
-
-## ID renumbering
-
-`DnsTlsDispatcher` accepts queries that have colliding ID numbers and still sends them on
-a single socket.  To avoid confusion at the server, `DnsTlsQueryMap` assigns each
-query a new ID for transmission, records the mapping from input IDs to sent IDs, and
-applies the inverse mapping to responses before returning them to the caller.
-
-`DnsTlsQueryMap` assigns each new query the ID number one greater than the largest
-ID number of an outstanding query.  This means that ID numbers are initially sequential
-and usually small.  If the largest possible ID number is already in use,
-`DnsTlsQueryMap` will scan the ID space to find an available ID, or fail the query
-if there are no available IDs.  Queries will not block waiting for an ID number to
-become available.
-
-## Time constants
-
-`DnsTlsSocket` imposes a 20-second inactivity timeout.  A socket that has been idle for
-20 seconds will be closed.  This sets the limit of tolerance for slow replies,
-which could happen as a result of malfunctioning authoritative DNS servers.
-If there are any pending queries, `DnsTlsTransport` will retry them.
-
-`DnsTlsQueryMap` imposes a retry limit of 3.  `DnsTlsTransport` will retry the query up
-to 3 times before reporting failure to `DnsTlsDispatcher`.
-This limit helps to ensure proper functioning in the case of a recursive resolver that
-is malfunctioning or is flooded with requests that are stalled due to malfunctioning
-authoritative servers.
-
-`DnsTlsDispatcher` maintains a 5-minute timeout.  Any `DnsTlsTransport` that has had no
-outstanding queries for 5 minutes will be destroyed at the next query on a different
-transport.
-This sets the limit on how long session tickets will be preserved during idle periods,
-because each `DnsTlsTransport` owns a `DnsTlsSessionCache`.  Imposing this timeout
-increases latency on the first query after an idle period, but also helps to avoid
-unbounded memory usage.
-
-`DnsTlsSessionCache` sets a limit of 5 sessions in each cache, expiring the oldest one
-when the limit is reached.  However, because the client code does not currently
-reuse sessions more than once, it should not be possible to hit this limit.
-
-## Testing
-
-Unit tests are in `dns_tls_test.cpp`. They cover all the classes except
-`DnsTlsSocket` (which requires `CAP_NET_ADMIN` because it uses `setsockopt(SO_MARK)`) and
-`DnsTlsSessionCache` (which requires integration with libssl).  These classes are
-exercised by the integration tests in `../tests/resolv_test.cpp`.
-
-### Dependency Injection
-
-For unit testing, we would like to be able to mock out `DnsTlsSocket`.  This is
-particularly required for unit testing of `DnsTlsDispatcher` and `DnsTlsTransport`.
-To make these unit tests possible, this code uses a dependency injection pattern:
-`DnsTlsSocket` is produced by a `DnsTlsSocketFactory`, and both of these have a
-defined interface.
-
-`DnsTlsDispatcher`'s constructor takes an `IDnsTlsSocketFactory`,
-which in production is a `DnsTlsSocketFactory`.  However, in unit tests, we can
-substitute a test factory that returns a fake socket, so that the unit tests can
-run without actually connecting over TLS to a test server.  (The integration tests
-do actual TLS.)
-
 ## Logging
 
 This code uses LOG(X) for logging. Log levels are VERBOSE,DEBUG,INFO,WARNING and ERROR.
@@ -136,5 +11,3 @@
 ERROR     4
 Verbose resolver logs could contain PII -- do NOT enable in production builds.
 
-## Reference
- * [BoringSSL API docs](https://commondatastorage.googleapis.com/chromium-boringssl-docs/headers.html)
diff --git a/ResolverController.cpp b/ResolverController.cpp
index 0128a12..a53d05a 100644
--- a/ResolverController.cpp
+++ b/ResolverController.cpp
@@ -249,11 +249,9 @@
     // Serialize the information for binder.
     ResolverStats::encodeAll(res_stats, stats);
 
-    ExternalPrivateDnsStatus privateDnsStatus = {PrivateDnsMode::OFF, 0, {}};
-    gPrivateDnsConfiguration.getStatus(netId, &privateDnsStatus);
-    for (int i = 0; i < privateDnsStatus.numServers; i++) {
-        std::string tlsServer_str = addrToString(&(privateDnsStatus.serverStatus[i].ss));
-        tlsServers->push_back(std::move(tlsServer_str));
+    const auto privateDnsStatus = gPrivateDnsConfiguration.getStatus(netId);
+    for (const auto& pair : privateDnsStatus.serversMap) {
+        tlsServers->push_back(addrToString(&pair.first.ss));
     }
 
     params->resize(IDnsResolver::RESOLVER_PARAMS_COUNT);
@@ -344,20 +342,17 @@
         }
 
         mDns64Configuration.dump(dw, netId);
-        ExternalPrivateDnsStatus privateDnsStatus = {PrivateDnsMode::OFF, 0, {}};
-        gPrivateDnsConfiguration.getStatus(netId, &privateDnsStatus);
+        const auto privateDnsStatus = gPrivateDnsConfiguration.getStatus(netId);
         dw.println("Private DNS mode: %s", getPrivateDnsModeString(privateDnsStatus.mode));
-        if (!privateDnsStatus.numServers) {
+        if (privateDnsStatus.serversMap.size() == 0) {
             dw.println("No Private DNS servers configured");
         } else {
-            dw.println("Private DNS configuration (%u entries)", privateDnsStatus.numServers);
+            dw.println("Private DNS configuration (%u entries)",
+                       static_cast<uint32_t>(privateDnsStatus.serversMap.size()));
             dw.incIndent();
-            for (int i = 0; i < privateDnsStatus.numServers; i++) {
-                dw.println("%s name{%s} status{%s}",
-                           addrToString(&(privateDnsStatus.serverStatus[i].ss)).c_str(),
-                           privateDnsStatus.serverStatus[i].hostname,
-                           validationStatusToString(static_cast<Validation>(
-                                   privateDnsStatus.serverStatus[i].validation)));
+            for (const auto& pair : privateDnsStatus.serversMap) {
+                dw.println("%s name{%s} status{%s}", addrToString(&pair.first.ss).c_str(),
+                           pair.first.name.c_str(), validationStatusToString(pair.second));
             }
             dw.decIndent();
         }
diff --git a/dnsresolver_binder_test.cpp b/dnsresolver_binder_test.cpp
index fcd4918..97c136f 100644
--- a/dnsresolver_binder_test.cpp
+++ b/dnsresolver_binder_test.cpp
@@ -54,7 +54,7 @@
 using android::netdutils::Stopwatch;
 
 // TODO: make this dynamic and stop depending on implementation details.
-// Sync from TEST_NETID in dns_responder_client.cpp as resolver_test.cpp does.
+// Sync from TEST_NETID in dns_responder_client.cpp as resolv_integration_test.cpp does.
 constexpr int TEST_NETID = 30;
 
 class DnsResolverBinderTest : public ::testing::Test {
@@ -158,7 +158,7 @@
     ASSERT_EQ(EEXIST, status.serviceSpecificErrorCode());
 }
 
-// TODO: Move this test to resolver_test.cpp
+// TODO: Move this test to resolv_integration_test.cpp
 TEST_F(DnsResolverBinderTest, RegisterEventListener_onDnsEvent) {
     // The test configs are used to trigger expected events. The expected results are defined in
     // expectedResults.
diff --git a/include/netd_resolv/resolv.h b/include/netd_resolv/resolv.h
index 400d56d..17cf2aa 100644
--- a/include/netd_resolv/resolv.h
+++ b/include/netd_resolv/resolv.h
@@ -81,6 +81,7 @@
 typedef void (*get_network_context_callback)(unsigned netid, uid_t uid,
                                              android_net_context* netcontext);
 typedef void (*log_callback)(const char* msg);
+typedef int (*tagSocketCallback)(int sockFd, uint32_t tag, uid_t uid);
 
 /*
  * Some functions needed by the resolver (e.g. checkCallingPermission()) live in
@@ -92,8 +93,11 @@
     check_calling_permission_callback check_calling_permission;
     get_network_context_callback get_network_context;
     log_callback log;
+    tagSocketCallback tagSocket;
 };
 
+#define TAG_SYSTEM_DNS 0xFFFFFF82
+
 LIBNETD_RESOLV_PUBLIC bool resolv_has_nameservers(unsigned netid);
 
 // Set callbacks and bring DnsResolver up.
diff --git a/res_cache.cpp b/res_cache.cpp
index 83ec09f..ceaaf04 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -58,12 +58,11 @@
 
 #include <server_configurable_flags/get_flags.h>
 
+#include "res_debug.h"
 #include "res_state_ext.h"
 #include "resolv_private.h"
 
-// NOTE: verbose logging MUST NOT be left enabled in production binaries.
-// It floods logs at high rate, and can leak privacy-sensitive information.
-constexpr bool kDumpData = false;
+using android::base::StringAppendF;
 
 /* This code implements a small and *simple* DNS resolver cache.
  *
@@ -147,49 +146,6 @@
  */
 #define CONFIG_MAX_ENTRIES (64 * 2 * 5)
 
-/** BOUNDED BUFFER FORMATTING **/
-
-/* technical note:
- *
- *   the following debugging routines are used to append data to a bounded
- *   buffer they take two parameters that are:
- *
- *   - p : a pointer to the current cursor position in the buffer
- *         this value is initially set to the buffer's address.
- *
- *   - end : the address of the buffer's limit, i.e. of the first byte
- *           after the buffer. this address should never be touched.
- *
- *           IMPORTANT: it is assumed that end > buffer_address, i.e.
- *                      that the buffer is at least one byte.
- *
- *   the bprint_x() functions return the new value of 'p' after the data
- *   has been appended, and also ensure the following:
- *
- *   - the returned value will never be strictly greater than 'end'
- *
- *   - a return value equal to 'end' means that truncation occurred
- *     (in which case, end[-1] will be set to 0)
- *
- *   - after returning from a bprint_x() function, the content of the buffer
- *     is always 0-terminated, even in the event of truncation.
- *
- *  these conventions allow you to call bprint_x() functions multiple times and
- *  only check for truncation at the end of the sequence, as in:
- *
- *     char  buff[1000], *p = buff, *end = p + sizeof(buff);
- *
- *     p = bprint_c(p, end, '"');
- *     p = bprint_s(p, end, my_string);
- *     p = bprint_c(p, end, '"');
- *
- *     if (p >= end) {
- *        // buffer was too small
- *     }
- *
- *     printf( "%s", buff );
- */
-
 /* Defaults used for initializing res_params */
 
 // If successes * 100 / total_samples is less than this value, the server is considered failing
@@ -197,122 +153,6 @@
 // Sample validity in seconds. Set to -1 to disable skipping failing servers.
 #define NSSAMPLE_VALIDITY 1800
 
-/* add a char to a bounded buffer */
-static char* bprint_c(char* p, char* end, int c) {
-    if (p < end) {
-        if (p + 1 == end)
-            *p++ = 0;
-        else {
-            *p++ = (char) c;
-            *p = 0;
-        }
-    }
-    return p;
-}
-
-/* add a sequence of bytes to a bounded buffer */
-static char* bprint_b(char* p, char* end, const char* buf, int len) {
-    int avail = end - p;
-
-    if (avail <= 0 || len <= 0) return p;
-
-    if (avail > len) avail = len;
-
-    memcpy(p, buf, avail);
-    p += avail;
-
-    if (p < end)
-        p[0] = 0;
-    else
-        end[-1] = 0;
-
-    return p;
-}
-
-/* add a string to a bounded buffer */
-static char* bprint_s(char* p, char* end, const char* str) {
-    return bprint_b(p, end, str, strlen(str));
-}
-
-/* add a formatted string to a bounded buffer */
-static char* bprint(char* p, char* end, const char* format, ...) {
-    int avail, n;
-    va_list args;
-
-    avail = end - p;
-
-    if (avail <= 0) return p;
-
-    va_start(args, format);
-    n = vsnprintf(p, avail, format, args);
-    va_end(args);
-
-    /* certain C libraries return -1 in case of truncation */
-    if (n < 0 || n > avail) n = avail;
-
-    p += n;
-    /* certain C libraries do not zero-terminate in case of truncation */
-    if (p == end) p[-1] = 0;
-
-    return p;
-}
-
-/* add a hex value to a bounded buffer, up to 8 digits */
-static char* bprint_hex(char* p, char* end, unsigned value, int numDigits) {
-    char text[sizeof(unsigned) * 2];
-    int nn = 0;
-
-    while (numDigits-- > 0) {
-        text[nn++] = "0123456789abcdef"[(value >> (numDigits * 4)) & 15];
-    }
-    return bprint_b(p, end, text, nn);
-}
-
-/* add the hexadecimal dump of some memory area to a bounded buffer */
-static char* bprint_hexdump(char* p, char* end, const uint8_t* data, int datalen) {
-    int lineSize = 16;
-
-    while (datalen > 0) {
-        int avail = datalen;
-        int nn;
-
-        if (avail > lineSize) avail = lineSize;
-
-        for (nn = 0; nn < avail; nn++) {
-            if (nn > 0) p = bprint_c(p, end, ' ');
-            p = bprint_hex(p, end, data[nn], 2);
-        }
-        for (; nn < lineSize; nn++) {
-            p = bprint_s(p, end, "   ");
-        }
-        p = bprint_s(p, end, "  ");
-
-        for (nn = 0; nn < avail; nn++) {
-            int c = data[nn];
-
-            if (c < 32 || c > 127) c = '.';
-
-            p = bprint_c(p, end, c);
-        }
-        p = bprint_c(p, end, '\n');
-
-        data += avail;
-        datalen -= avail;
-    }
-    return p;
-}
-
-/* dump the content of a query of packet to the log */
-static void dump_bytes(const uint8_t* base, int len) {
-    if (!kDumpData) return;
-
-    char buff[1024];
-    char *p = buff, *end = p + sizeof(buff);
-
-    p = bprint_hexdump(p, end, base, len);
-    LOG(INFO) << __func__ << ": " << buff;
-}
-
 static time_t _time_now(void) {
     struct timeval tv;
 
@@ -581,98 +421,6 @@
     return 1;
 }
 
-/** QUERY DEBUGGING **/
-static char* dnsPacket_bprintQName(DnsPacket* packet, char* bp, char* bend) {
-    const uint8_t* p = packet->cursor;
-    const uint8_t* end = packet->end;
-    int first = 1;
-
-    for (;;) {
-        int c;
-
-        if (p >= end) break;
-
-        c = *p++;
-
-        if (c == 0) {
-            packet->cursor = p;
-            return bp;
-        }
-
-        /* we don't expect label compression in QNAMEs */
-        if (c >= 64) break;
-
-        if (first)
-            first = 0;
-        else
-            bp = bprint_c(bp, bend, '.');
-
-        bp = bprint_b(bp, bend, (const char*) p, c);
-
-        p += c;
-        /* we rely on the bound check at the start
-         * of the loop here */
-    }
-    /* malformed data */
-    bp = bprint_s(bp, bend, "<MALFORMED>");
-    return bp;
-}
-
-static char* dnsPacket_bprintQR(DnsPacket* packet, char* p, char* end) {
-#define QQ(x) \
-    { DNS_TYPE_##x, #x }
-    static const struct {
-        const char* typeBytes;
-        const char* typeString;
-    } qTypes[] = {QQ(A), QQ(PTR), QQ(MX), QQ(AAAA), QQ(ALL), {NULL, NULL}};
-    int nn;
-    const char* typeString = NULL;
-
-    /* dump QNAME */
-    p = dnsPacket_bprintQName(packet, p, end);
-
-    /* dump TYPE */
-    p = bprint_s(p, end, " (");
-
-    for (nn = 0; qTypes[nn].typeBytes != NULL; nn++) {
-        if (_dnsPacket_checkBytes(packet, 2, qTypes[nn].typeBytes)) {
-            typeString = qTypes[nn].typeString;
-            break;
-        }
-    }
-
-    if (typeString != NULL)
-        p = bprint_s(p, end, typeString);
-    else {
-        int typeCode = _dnsPacket_readInt16(packet);
-        p = bprint(p, end, "UNKNOWN-%d", typeCode);
-    }
-
-    p = bprint_c(p, end, ')');
-
-    /* skip CLASS */
-    _dnsPacket_skip(packet, 2);
-    return p;
-}
-
-/* this function assumes the packet has already been checked */
-static char* dnsPacket_bprintQuery(DnsPacket* packet, char* p, char* end) {
-    int qdCount;
-
-    if (packet->base[2] & 0x1) {
-        p = bprint_s(p, end, "RECURSIVE ");
-    }
-
-    _dnsPacket_skip(packet, 4);
-    qdCount = _dnsPacket_readInt16(packet);
-    _dnsPacket_skip(packet, 6);
-
-    for (; qdCount > 0; qdCount--) {
-        p = dnsPacket_bprintQR(packet, p, end);
-    }
-    return p;
-}
-
 /** QUERY HASHING SUPPORT
  **
  ** THE FOLLOWING CODE ASSUMES THAT THE INPUT PACKET HAS ALREADY
@@ -1292,26 +1040,15 @@
     return cache;
 }
 
-static void dump_query(const uint8_t* query, int querylen) {
-    if (!WOULD_LOG(VERBOSE)) return;
+static void cache_dump_mru_locked(Cache* cache) {
+    std::string buf;
 
-    char temp[256], *p = temp, *end = p + sizeof(temp);
-    DnsPacket pack[1];
+    StringAppendF(&buf, "MRU LIST (%2d): ", cache->num_entries);
+    for (Entry* e = cache->mru_list.mru_next; e != &cache->mru_list; e = e->mru_next) {
+        StringAppendF(&buf, " %d", e->id);
+    }
 
-    _dnsPacket_init(pack, query, querylen);
-    p = dnsPacket_bprintQuery(pack, p, end);
-    LOG(VERBOSE) << __func__ << ": " << temp;
-}
-
-static void cache_dump_mru(Cache* cache) {
-    char temp[512], *p = temp, *end = p + sizeof(temp);
-    Entry* e;
-
-    p = bprint(temp, end, "MRU LIST (%2d): ", cache->num_entries);
-    for (e = cache->mru_list.mru_next; e != &cache->mru_list; e = e->mru_next)
-        p = bprint(p, end, " %d", e->id);
-
-    LOG(INFO) << __func__ << ": " << temp;
+    LOG(INFO) << __func__ << ": " << buf;
 }
 
 /* This function tries to find a key within the hash table
@@ -1385,7 +1122,7 @@
         return;
     }
     LOG(INFO) << __func__ << ": Cache full - removing oldest";
-    dump_query(oldest->query, oldest->querylen);
+    res_pquery(oldest->query, oldest->querylen);
     _cache_remove_p(cache, lookup);
 }
 
@@ -1430,7 +1167,6 @@
     Cache* cache;
 
     LOG(INFO) << __func__ << ": lookup";
-    dump_query((u_char*) query, querylen);
 
     /* we don't cache malformed queries */
     if (!entry_init_key(&key, query, querylen)) {
@@ -1494,7 +1230,7 @@
     /* remove stale entries here */
     if (now >= e->expires) {
         LOG(INFO) << __func__ << ": NOT IN CACHE (STALE ENTRY " << *lookup << "DISCARDED)";
-        dump_query(e->query, e->querylen);
+        res_pquery(e->query, e->querylen);
         _cache_remove_p(cache, lookup);
         return RESOLV_CACHE_NOTFOUND;
     }
@@ -1540,14 +1276,6 @@
         return -ENONET;
     }
 
-    LOG(INFO) << __func__ << ": query:";
-    dump_query((u_char*)query, querylen);
-    res_pquery((u_char*)answer, answerlen);
-    if (kDumpData) {
-        LOG(INFO) << __func__ << ": answer:";
-        dump_bytes((u_char*)answer, answerlen);
-    }
-
     lookup = _cache_lookup_p(cache, key);
     e = *lookup;
 
@@ -1582,7 +1310,7 @@
         }
     }
 
-    cache_dump_mru(cache);
+    cache_dump_mru_locked(cache);
     _cache_notify_waiting_tid_locked(cache, key);
 
     return 0;
diff --git a/res_debug.cpp b/res_debug.cpp
index 9d01cc5..f249456 100644
--- a/res_debug.cpp
+++ b/res_debug.cpp
@@ -97,9 +97,10 @@
 
 #define LOG_TAG "resolv"
 
+#include "res_debug.h"
+
 #include <sys/param.h>
 #include <sys/socket.h>
-#include <sys/types.h>
 
 #include <arpa/inet.h>
 #include <arpa/nameser.h>
@@ -112,6 +113,7 @@
 #include <errno.h>
 #include <math.h>
 #include <netdb.h>
+#include <netdutils/Slice.h>
 #include <stdlib.h>
 #include <string.h>
 #include <strings.h>
@@ -128,6 +130,7 @@
 #endif
 
 using android::base::StringAppendF;
+using android::netdutils::Slice;
 
 struct res_sym {
     int number;            /* Identifying number, like T_MX */
@@ -150,6 +153,10 @@
             LOG(VERBOSE) << s;
             return;
         }
+        if (rrnum == 0) {
+            int opcode = ns_msg_getflag(*handle, ns_f_opcode);
+            StringAppendF(&s, ";; %s SECTION:\n", p_section(section, opcode));
+        }
         if (section == ns_s_qd)
             StringAppendF(&s, ";;\t%s, type = %s, class = %s\n", ns_rr_name(rr),
                           p_type(ns_rr_type(rr)), p_class(ns_rr_class(rr)));
@@ -277,6 +284,9 @@
     do_section(&handle, ns_s_an);
     do_section(&handle, ns_s_ns);
     do_section(&handle, ns_s_ar);
+
+    LOG(VERBOSE) << "Hex dump:";
+    LOG(VERBOSE) << android::netdutils::toHex(Slice(const_cast<uint8_t*>(msg), len), 32);
 }
 
 /*
diff --git a/res_debug.h b/res_debug.h
index 2ba31d4..cb47b90 100644
--- a/res_debug.h
+++ b/res_debug.h
@@ -16,6 +16,16 @@
 
 #pragma once
 
-#include <cinttypes>
+#include <sys/types.h>
+
+// TODO: use netdutils::Slice for (msg, len).
+void res_pquery(const u_char* msg, int len);
+
+// Thread-unsafe functions returning pointers to static buffers :-(
+// TODO: switch all res_debug to std::string
+const char* p_type(int type);
+const char* p_section(int section, int opcode);
+const char* p_class(int cl);
+const char* p_rcode(int rcode);
 
 int resolv_set_log_severity(uint32_t logSeverity);
diff --git a/res_query.cpp b/res_query.cpp
index f4289e4..4bfdf04 100644
--- a/res_query.cpp
+++ b/res_query.cpp
@@ -86,6 +86,7 @@
 
 #include <android-base/logging.h>
 
+#include "res_debug.h"
 #include "resolv_cache.h"
 #include "resolv_private.h"
 
diff --git a/res_send.cpp b/res_send.cpp
index 85e8a58..fa3371f 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -107,9 +107,9 @@
 #include "netd_resolv/resolv.h"
 #include "netd_resolv/stats.h"
 #include "private/android_filesystem_config.h"
+#include "res_debug.h"
 #include "res_state_ext.h"
 #include "resolv_cache.h"
-#include "resolv_private.h"
 #include "stats.pb.h"
 
 // TODO: use the namespace something like android::netd_resolv for libnetd_resolv
@@ -539,6 +539,8 @@
                 resplen = res_tls_send(statp, Slice(const_cast<u_char*>(buf), buflen),
                                        Slice(ans, anssiz), rcode, &fallback);
                 if (resplen > 0) {
+                    LOG(DEBUG) << __func__ << ": got answer from DoT";
+                    res_pquery(ans, resplen);
                     if (cache_status == RESOLV_CACHE_NOTFOUND) {
                         resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
                     }
@@ -772,9 +774,7 @@
                     return -1;
             }
         }
-        if (fchown(statp->_vcsock, statp->uid, -1) == -1) {
-            PLOG(WARNING) << __func__ << ": Failed to chown socket";
-        }
+        resolv_tag_socket(statp->_vcsock, statp->uid);
         if (statp->_mark != MARK_UNSET) {
             if (setsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &statp->_mark,
                            sizeof(statp->_mark)) < 0) {
@@ -1017,9 +1017,7 @@
             }
         }
 
-        if (fchown(statp->_u._ext.nssocks[ns], statp->uid, -1) == -1) {
-            PLOG(WARNING) << __func__ << ": Failed to chown socket";
-        }
+        resolv_tag_socket(statp->_u._ext.nssocks[ns], statp->uid);
         if (statp->_mark != MARK_UNSET) {
             if (setsockopt(statp->_u._ext.nssocks[ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
                            sizeof(statp->_mark)) < 0) {
@@ -1223,7 +1221,7 @@
         return -1;
     }
 
-    if (privateDnsStatus.validatedServers.empty()) {
+    if (privateDnsStatus.validatedServers().empty()) {
         if (privateDnsStatus.mode == PrivateDnsMode::OPPORTUNISTIC) {
             *fallback = true;
             return -1;
@@ -1242,12 +1240,14 @@
             // network change.
             for (int i = 0; i < 42; i++) {
                 std::this_thread::sleep_for(std::chrono::milliseconds(100));
-                if (!gPrivateDnsConfiguration.getStatus(netId).validatedServers.empty()) {
+                // Calling getStatus() to merely check if there's any validated server seems
+                // wasteful. Consider adding a new method in PrivateDnsConfiguration for speed ups.
+                if (!gPrivateDnsConfiguration.getStatus(netId).validatedServers().empty()) {
                     privateDnsStatus = gPrivateDnsConfiguration.getStatus(netId);
                     break;
                 }
             }
-            if (privateDnsStatus.validatedServers.empty()) {
+            if (privateDnsStatus.validatedServers().empty()) {
                 return -1;
             }
         }
@@ -1255,7 +1255,7 @@
 
     LOG(INFO) << __func__ << ": performing query over TLS";
 
-    const auto response = sDnsTlsDispatcher.query(privateDnsStatus.validatedServers, statp, query,
+    const auto response = sDnsTlsDispatcher.query(privateDnsStatus.validatedServers(), statp, query,
                                                   answer, &resplen);
 
     LOG(INFO) << __func__ << ": TLS query result: " << static_cast<int>(response);
diff --git a/res_cache_test.cpp b/resolv_cache_unit_test.cpp
similarity index 100%
rename from res_cache_test.cpp
rename to resolv_cache_unit_test.cpp
diff --git a/resolver_test.cpp b/resolv_integration_test.cpp
similarity index 98%
rename from resolver_test.cpp
rename to resolv_integration_test.cpp
index d7cd263..5014983 100644
--- a/resolver_test.cpp
+++ b/resolv_integration_test.cpp
@@ -171,6 +171,10 @@
         return sDnsMetricsListener->waitForNat64Prefix(status, timeout);
     }
 
+    bool WaitForPrivateDnsValidation(std::string serverAddr, bool validated) {
+        return sDnsMetricsListener->waitForPrivateDnsValidation(serverAddr, validated);
+    }
+
     DnsResponderClient mDnsClient;
 
     // Use a shared static DNS listener for all tests to avoid registering lots of listeners
@@ -1091,7 +1095,7 @@
     // So, wait for private DNS validation done before stopping backend DNS servers.
     for (int i = 0; i < MAXNS; i++) {
         LOG(INFO) << "Waiting for private DNS validation on " << tls[i]->listen_address() << ".";
-        EXPECT_TRUE(tls[i]->waitForQueries(1, 5000));
+        EXPECT_TRUE(WaitForPrivateDnsValidation(tls[i]->listen_address(), true));
         LOG(INFO) << "private DNS validation on " << tls[i]->listen_address() << " done.";
     }
 
@@ -1262,13 +1266,9 @@
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
 
-    const hostent* result;
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
-    result = gethostbyname("tls1");
+    const hostent* result = gethostbyname("tls1");
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.1", ToString(result));
 
@@ -1326,14 +1326,10 @@
     ASSERT_TRUE(tls2.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls1.listen_address(), true));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls2.listen_address(), true));
 
-    const hostent* result;
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls1.waitForQueries(1, 5000));
-    EXPECT_TRUE(tls2.waitForQueries(1, 5000));
-
-    result = gethostbyname("tlsfailover1");
+    const hostent* result = gethostbyname("tlsfailover1");
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.1", ToString(result));
 
@@ -1376,7 +1372,7 @@
 
     // The TLS handshake would fail because the name of TLS server doesn't
     // match with TLS server's certificate.
-    EXPECT_FALSE(tls.waitForQueries(1, 500));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
 
     // The query should fail hard, because a name was specified.
     EXPECT_EQ(nullptr, gethostbyname("badtlsname"));
@@ -1403,9 +1399,7 @@
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
 
     dns.clearQueries();
     ScopedAddrinfo result = safe_getaddrinfo("addrinfotls", nullptr, nullptr);
@@ -1505,13 +1499,16 @@
         } else if (config.mode == OPPORTUNISTIC) {
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
-            // Wait for validation to complete.
-            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+            // Wait for the validation event. If the server is running, the validation should
+            // be successful; otherwise, the validation should be failed.
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
         } else if (config.mode == STRICT) {
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, kDefaultPrivateDnsHostName));
-            // Wait for validation to complete.
-            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+            // Wait for the validation event.
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
         }
         tls.clearQueries();
 
@@ -2231,22 +2228,21 @@
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
         } else if (config.mode == OPPORTUNISTIC_TLS) {
             if (!tls.running()) {
                 ASSERT_TRUE(tls.startServer());
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
-            // Wait for validation to complete.
-            EXPECT_TRUE(tls.waitForQueries(1, 5000));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
         } else if (config.mode == STRICT) {
             if (!tls.running()) {
                 ASSERT_TRUE(tls.startServer());
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, kDefaultPrivateDnsHostName));
-            // Wait for validation to complete.
-            EXPECT_TRUE(tls.waitForQueries(1, 5000));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
         }
 
         if (config.method == GETHOSTBYNAME) {
@@ -2303,8 +2299,8 @@
     test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    // Wait for validation complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
     // Shutdown TLS server to get an error. It's similar to no response case but without waiting.
     tls.stopServer();
 
@@ -2334,8 +2330,8 @@
     test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    // Wait for validation complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
     // Shutdown TLS server to get an error. It's similar to no response case but without waiting.
     tls.stopServer();
     dns.setEdns(test::DNSResponder::Edns::FORMERR_UNCOND);
@@ -3200,7 +3196,7 @@
 
     // Setup OPPORTUNISTIC mode and wait for the validation complete.
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
     tls.clearQueries();
 
     // Start NAT64 prefix discovery and wait for it complete.
@@ -3219,7 +3215,7 @@
     // Setup STRICT mode and wait for the validation complete.
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
     tls.clearQueries();
 
     // Start NAT64 prefix discovery and wait for it to complete.
diff --git a/resolv_private.h b/resolv_private.h
index 0e0cbad..f2668b5 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -61,6 +61,7 @@
 #include <string>
 #include <vector>
 
+#include "DnsResolver.h"
 #include "netd_resolv/params.h"
 #include "netd_resolv/resolv.h"
 #include "netd_resolv/stats.h"
@@ -175,18 +176,10 @@
 void putlong(uint32_t, u_char*);
 void putshort(uint16_t, u_char*);
 
-// Thread-unsafe functions returning pointers to static buffers :-(
-// TODO: switch all res_debug to std::string
-const char* p_class(int);
-const char* p_type(int);
-const char* p_rcode(int);
-const char* p_section(int, int);
-
 int res_nameinquery(const char*, int, int, const u_char*, const u_char*);
 int res_queriesmatch(const u_char*, const u_char*, const u_char*, const u_char*);
 /* Things involving a resolver context. */
 int res_ninit(res_state);
-void res_pquery(const u_char*, int);
 
 int res_nquery(res_state, const char*, int, int, u_char*, int, int*);
 int res_nsearch(res_state, const char*, int, int, u_char*, int, int*);
@@ -225,4 +218,16 @@
 
 android::net::IpVersion ipFamilyToIPVersion(int ipFamily);
 
+inline void resolv_tag_socket(int sock, uid_t uid) {
+    if (android::net::gResNetdCallbacks.tagSocket != nullptr) {
+        if (int err = android::net::gResNetdCallbacks.tagSocket(sock, TAG_SYSTEM_DNS, uid)) {
+            LOG(WARNING) << "Failed to tag socket: " << strerror(-err);
+        }
+    }
+
+    if (fchown(sock, uid, -1) == -1) {
+        LOG(WARNING) << "Failed to chown socket: " << strerror(errno);
+    }
+}
+
 #endif  // NETD_RESOLV_PRIVATE_H
diff --git a/dns_tls_test.cpp b/resolv_tls_unit_test.cpp
similarity index 100%
rename from dns_tls_test.cpp
rename to resolv_tls_unit_test.cpp
diff --git a/libnetd_resolv_test.cpp b/resolv_unit_test.cpp
similarity index 86%
rename from libnetd_resolv_test.cpp
rename to resolv_unit_test.cpp
index 2820b07..f9accd0 100644
--- a/libnetd_resolv_test.cpp
+++ b/resolv_unit_test.cpp
@@ -16,10 +16,10 @@
 
 #define LOG_TAG "resolv"
 
-#include <gtest/gtest.h>
-
 #include <android-base/stringprintf.h>
 #include <arpa/inet.h>
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
 #include <netdb.h>
 #include <netdutils/InternetAddresses.h>
 
@@ -39,10 +39,14 @@
 using android::net::NetworkDnsEventReported;
 using android::netdutils::ScopedAddrinfo;
 
-// Minimize class ResolverTest to be class TestBase because class TestBase doesn't need all member
-// functions of class ResolverTest and class DnsResponderClient.
 class TestBase : public ::testing::Test {
   protected:
+    struct DnsMessage {
+        std::string host_name;   // host name
+        ns_type type;            // record type
+        test::DNSHeader header;  // dns header
+    };
+
     void SetUp() override {
         // Create cache for test
         resolv_create_cache_for_net(TEST_NETID);
@@ -52,7 +56,64 @@
         resolv_delete_cache_for_net(TEST_NETID);
     }
 
-    int setResolvers() {
+    test::DNSRecord MakeAnswerRecord(const std::string& name, unsigned rclass, unsigned rtype,
+                                     const std::string& rdata, unsigned ttl = kAnswerRecordTtlSec) {
+        test::DNSRecord record{
+                .name = {.name = name},
+                .rtype = rtype,
+                .rclass = rclass,
+                .ttl = ttl,
+        };
+        EXPECT_TRUE(test::DNSResponder::fillAnswerRdata(rdata, record));
+        return record;
+    }
+
+    DnsMessage MakeDnsMessage(const std::string& qname, ns_type qtype,
+                              const std::vector<std::string>& rdata) {
+        const unsigned qclass = ns_c_in;
+        // Build a DNSHeader in the following format.
+        // Question
+        //   <qname>                IN      <qtype>
+        // Answer
+        //   <qname>                IN      <qtype>     <rdata[0]>
+        //   ..
+        //   <qname>                IN      <qtype>     <rdata[n]>
+        //
+        // Example:
+        // Question
+        //   hello.example.com.     IN      A
+        // Answer
+        //   hello.example.com.     IN      A           1.2.3.1
+        //   ..
+        //   hello.example.com.     IN      A           1.2.3.9
+        test::DNSHeader header(kDefaultDnsHeader);
+
+        // Question section
+        test::DNSQuestion question{
+                .qname = {.name = qname},
+                .qtype = qtype,
+                .qclass = qclass,
+        };
+        header.questions.push_back(std::move(question));
+
+        // Answer section
+        for (const auto& r : rdata) {
+            test::DNSRecord record = MakeAnswerRecord(qname, qclass, qtype, r);
+            header.answers.push_back(std::move(record));
+        }
+        // TODO: Perhaps add support for authority RRs and additional RRs.
+        return {qname, qtype, header};
+    }
+
+    void StartDns(test::DNSResponder& dns, const std::vector<DnsMessage>& messages) {
+        for (const auto& m : messages) {
+            dns.addMappingDnsHeader(m.host_name, m.type, m.header);
+        }
+        ASSERT_TRUE(dns.startServer());
+        dns.clearQueries();
+    }
+
+    int SetResolvers() {
         const std::vector<std::string> servers = {test::kDefaultListenAddr};
         const std::vector<std::string> domains = {"example.com"};
         const res_params params = {
@@ -326,7 +387,7 @@
     test::DNSResponder dns;
     dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     // Want AAAA answer but DNS server has A answer only.
     addrinfo* result = nullptr;
@@ -348,7 +409,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, v4addr);
     dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         int ai_family;
@@ -377,7 +438,7 @@
 TEST_F(ResolvGetAddrInfoTest, IllegalHostname) {
     test::DNSResponder dns;
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
     static constexpr char const* illegalHostnames[] = {
@@ -440,7 +501,7 @@
         dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
         dns.setResponseProbability(0.0);  // always ignore requests and response preset rcode
         ASSERT_TRUE(dns.startServer());
-        ASSERT_EQ(0, setResolvers());
+        ASSERT_EQ(0, SetResolvers());
 
         addrinfo* result = nullptr;
         const addrinfo hints = {.ai_family = AF_UNSPEC};
@@ -457,7 +518,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
     dns.setResponseProbability(0.0);  // always ignore requests and don't response
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     addrinfo* result = nullptr;
     const addrinfo hints = {.ai_family = AF_UNSPEC};
@@ -474,7 +535,7 @@
     dns.addMapping("cnames.example.com.", ns_type::ns_t_cname, "acname.example.com.");
     dns.addMapping("acname.example.com.", ns_type::ns_t_cname, "hello.example.com.");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         const char* name;
@@ -507,7 +568,7 @@
 TEST_F(ResolvGetAddrInfoTest, CnamesBrokenChainByIllegalCname) {
     test::DNSResponder dns;
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         const char* name;
@@ -561,7 +622,7 @@
     dns.addMapping("hello.example.com.", ns_type::ns_t_cname, "a.example.com.");
     dns.addMapping("a.example.com.", ns_type::ns_t_cname, "hello.example.com.");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) {
         SCOPED_TRACE(StringPrintf("family: %d", family));
@@ -576,6 +637,52 @@
     }
 }
 
+TEST_F(ResolvGetAddrInfoTest, MultiAnswerSections) {
+    test::DNSResponder dns(test::DNSResponder::MappingType::DNS_HEADER);
+    // Answer section for query type {A, AAAA}
+    // Type A:
+    //   hello.example.com.   IN    A       1.2.3.1
+    //   hello.example.com.   IN    A       1.2.3.2
+    // Type AAAA:
+    //   hello.example.com.   IN    AAAA    2001:db8::41
+    //   hello.example.com.   IN    AAAA    2001:db8::42
+    StartDns(dns, {MakeDnsMessage(kHelloExampleCom, ns_type::ns_t_a, {"1.2.3.1", "1.2.3.2"}),
+                   MakeDnsMessage(kHelloExampleCom, ns_type::ns_t_aaaa,
+                                  {"2001:db8::41", "2001:db8::42"})});
+    ASSERT_EQ(0, SetResolvers());
+
+    for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) {
+        SCOPED_TRACE(StringPrintf("family: %d", family));
+
+        addrinfo* res = nullptr;
+        // If the socket type is not specified, every address will appear twice, once for
+        // SOCK_STREAM and one for SOCK_DGRAM. Just pick one because the addresses for
+        // the second query of different socket type are responded by the cache.
+        const addrinfo hints = {.ai_family = family, .ai_socktype = SOCK_STREAM};
+        NetworkDnsEventReported event;
+        int rv = resolv_getaddrinfo("hello", nullptr, &hints, &mNetcontext, &res, &event);
+        ScopedAddrinfo result(res);
+        ASSERT_NE(nullptr, result);
+        ASSERT_EQ(0, rv);
+
+        const std::vector<std::string> result_strs = ToStrings(result);
+        if (family == AF_INET) {
+            EXPECT_EQ(1U, GetNumQueries(dns, kHelloExampleCom));
+            EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray({"1.2.3.1", "1.2.3.2"}));
+        } else if (family == AF_INET6) {
+            EXPECT_EQ(1U, GetNumQueries(dns, kHelloExampleCom));
+            EXPECT_THAT(result_strs,
+                        testing::UnorderedElementsAreArray({"2001:db8::41", "2001:db8::42"}));
+        } else if (family == AF_UNSPEC) {
+            EXPECT_EQ(0U, GetNumQueries(dns, kHelloExampleCom));  // no query because of the cache
+            EXPECT_THAT(result_strs,
+                        testing::UnorderedElementsAreArray(
+                                {"1.2.3.1", "1.2.3.2", "2001:db8::41", "2001:db8::42"}));
+        }
+        dns.clearQueries();
+    }
+}
+
 TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) {
     constexpr char host_name[] = "jiababuei.example.com.";
     constexpr char v4addr[] = "1.2.3.4";
@@ -585,7 +692,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, v4addr);
     dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         int ai_family;
@@ -613,7 +720,7 @@
 TEST_F(GetHostByNameForNetContextTest, IllegalHostname) {
     test::DNSResponder dns;
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     // Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
     static constexpr char const* illegalHostnames[] = {
@@ -655,7 +762,7 @@
     test::DNSResponder dns;
     dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
     dns.clearQueries();
 
     // Want AAAA answer but DNS server has A answer only.
@@ -695,7 +802,7 @@
         dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
         dns.setResponseProbability(0.0);  // always ignore requests and response preset rcode
         ASSERT_TRUE(dns.startServer());
-        ASSERT_EQ(0, setResolvers());
+        ASSERT_EQ(0, SetResolvers());
 
         hostent* hp = nullptr;
         NetworkDnsEventReported event;
@@ -712,7 +819,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
     dns.setResponseProbability(0.0);  // always ignore requests and don't response
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     hostent* hp = nullptr;
     NetworkDnsEventReported event;
@@ -728,7 +835,7 @@
     dns.addMapping("cnames.example.com.", ns_type::ns_t_cname, "acname.example.com.");
     dns.addMapping("acname.example.com.", ns_type::ns_t_cname, "hello.example.com.");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         const char* name;
@@ -756,7 +863,7 @@
 TEST_F(GetHostByNameForNetContextTest, CnamesBrokenChainByIllegalCname) {
     test::DNSResponder dns;
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     static const struct TestConfig {
         const char* name;
@@ -809,7 +916,7 @@
     dns.addMapping("hello.example.com.", ns_type::ns_t_cname, "a.example.com.");
     dns.addMapping("a.example.com.", ns_type::ns_t_cname, "hello.example.com.");
     ASSERT_TRUE(dns.startServer());
-    ASSERT_EQ(0, setResolvers());
+    ASSERT_EQ(0, SetResolvers());
 
     for (const auto& family : {AF_INET, AF_INET6}) {
         SCOPED_TRACE(StringPrintf("family: %d", family));
@@ -825,8 +932,6 @@
 // Note that local host file function, files_getaddrinfo(), of resolv_getaddrinfo()
 // is not tested because it only returns a boolean (success or failure) without any error number.
 
-// TODO: Simplify the DNS server configuration, DNSResponder and resolv_set_nameservers, as
-//       ResolverTest does.
 // TODO: Add test for resolv_getaddrinfo().
 //       - DNS response message parsing.
 //           - Unexpected type of resource record (RR).
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.cpp b/tests/dns_metrics_listener/dns_metrics_listener.cpp
index 8f3b646..acfb416 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.cpp
+++ b/tests/dns_metrics_listener/dns_metrics_listener.cpp
@@ -23,9 +23,11 @@
 namespace net {
 namespace metrics {
 
+using android::base::ScopedLockAssertion;
 using std::chrono::milliseconds;
 
 constexpr milliseconds kRetryIntervalMs{20};
+constexpr milliseconds kEventTimeoutMs{5000};
 
 android::binder::Status DnsMetricsListener::onNat64PrefixEvent(int32_t netId, bool added,
                                                                const std::string& prefixString,
@@ -35,6 +37,20 @@
     return android::binder::Status::ok();
 }
 
+android::binder::Status DnsMetricsListener::onPrivateDnsValidationEvent(
+        int32_t netId, const ::android::String16& ipAddress,
+        const ::android::String16& /*hostname*/, bool validated) {
+    const std::string serverAddr(String8(ipAddress.string()));
+    {
+        std::lock_guard guard(mMutex);
+        // Overwrite any earlier record so the latest validation status of the server is kept.
+        mValidationRecords.insert_or_assign({netId, serverAddr}, validated);
+    }
+    // Wake a thread blocked in waitForPrivateDnsValidation(), if any.
+    mCv.notify_one();
+    return android::binder::Status::ok();
+}
+
 bool DnsMetricsListener::waitForNat64Prefix(ExpectNat64PrefixStatus status,
                                             milliseconds timeout) const {
     android::base::Timer t;
@@ -50,6 +66,31 @@
     return false;
 }
 
+bool DnsMetricsListener::waitForPrivateDnsValidation(const std::string& serverAddr,
+                                                     const bool validated) {
+    const auto deadline = std::chrono::steady_clock::now() + kEventTimeoutMs;
+
+    std::unique_lock lock(mMutex);
+    ScopedLockAssertion assume_lock(mMutex);
+
+    // onPrivateDnsValidationEvent() might already have been invoked, so the record is searched
+    // before the first wait. The predicate overload re-checks it once the deadline passes, so a
+    // record that arrives together with the timeout is not missed. Returns false on timeout.
+    return mCv.wait_until(lock, deadline, [&] {
+        ScopedLockAssertion assume_lambda_lock(mMutex);
+        return findAndRemoveValidationRecord({mNetId, serverAddr}, validated);
+    });
+}
+
+bool DnsMetricsListener::findAndRemoveValidationRecord(const ServerKey& key, const bool value) {
+    const auto record = mValidationRecords.find(key);
+    // Leave the record in place unless it exists and carries the expected validation result.
+    if (record == mValidationRecords.end() || record->second != value) return false;
+
+    mValidationRecords.erase(record);
+    return true;
+}
+
 }  // namespace metrics
 }  // namespace net
-}  // namespace android
\ No newline at end of file
+}  // namespace android
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.h b/tests/dns_metrics_listener/dns_metrics_listener.h
index 1f8170d..d933b13 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.h
+++ b/tests/dns_metrics_listener/dns_metrics_listener.h
@@ -16,6 +16,10 @@
 
 #pragma once
 
+#include <condition_variable>
+#include <map>
+#include <utility>
+
 #include <android-base/thread_annotations.h>
 
 #include "base_metrics_listener.h"
@@ -42,20 +46,38 @@
                                                const std::string& prefixString,
                                                int32_t /*prefixLength*/) override;
 
+    android::binder::Status onPrivateDnsValidationEvent(int32_t netId,
+                                                        const ::android::String16& ipAddress,
+                                                        const ::android::String16& /*hostname*/,
+                                                        bool validated) override;
+
     // Wait for expected NAT64 prefix status until timeout.
     bool waitForNat64Prefix(ExpectNat64PrefixStatus status,
                             std::chrono::milliseconds timeout) const;
 
+    // Wait for the expected private DNS validation result until timeout.
+    bool waitForPrivateDnsValidation(const std::string& serverAddr, const bool validated);
+
   private:
+    typedef std::pair<int32_t, std::string> ServerKey;
+
+    // Search mValidationRecords. Return true if |key| exists and its value is equal to
+    // |value|, and then remove it; otherwise, return false.
+    bool findAndRemoveValidationRecord(const ServerKey& key, const bool value) REQUIRES(mMutex);
+
     // Monitor the event which was fired on specific network id.
     const int32_t mNetId;
 
     // The NAT64 prefix of the network |mNetId|. It is updated by the event onNat64PrefixEvent().
     std::string mNat64Prefix GUARDED_BY(mMutex);
 
+    // Used to store the data from onPrivateDnsValidationEvent.
+    std::map<ServerKey, bool> mValidationRecords GUARDED_BY(mMutex);
+
     mutable std::mutex mMutex;
+    std::condition_variable mCv;
 };
 
 }  // namespace metrics
 }  // namespace net
-}  // namespace android
\ No newline at end of file
+}  // namespace android
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index 51eb7dd..4bac4b0 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -434,10 +434,11 @@
 /* DNS responder */
 
 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
-                           ns_rcode error_rcode)
+                           ns_rcode error_rcode, MappingType mapping_type)
     : listen_address_(std::move(listen_address)),
       listen_service_(std::move(listen_service)),
-      error_rcode_(error_rcode) {}
+      error_rcode_(error_rcode),
+      mapping_type_(mapping_type) {}
 
 DNSResponder::~DNSResponder() {
     stopServer();
@@ -445,6 +446,7 @@
 
 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
     std::lock_guard lock(mappings_mutex_);
+    // TODO: Consider using std::map::insert_or_assign().
     auto it = mappings_.find(QueryKey(name, type));
     if (it != mappings_.end()) {
         LOG(INFO) << "Overwriting mapping for (" << name << ", " << dnstype2str(type)
@@ -455,6 +457,23 @@
     mappings_.try_emplace({name, type}, addr);
 }
 
+void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
+                                       const DNSHeader& header) {
+    std::lock_guard lock(mappings_mutex_);
+    // TODO: Consider using std::unordered_map::insert_or_assign().
+    const auto existing = dnsheader_mappings_.find(QueryKey(name, type));
+    if (existing == dnsheader_mappings_.end()) {
+        dnsheader_mappings_.try_emplace({name, type}, header);
+        return;
+    }
+    // TODO: Perhaps replace header pointer with header content once DNSHeader::toString() has
+    // been implemented.
+    LOG(INFO) << "Overwriting mapping for (" << name << ", " << dnstype2str(type)
+              << "), previous header " << static_cast<const void*>(&existing->second)
+              << " new header " << static_cast<const void*>(&header);
+    existing->second = header;
+}
+
 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
     std::lock_guard lock(mappings_mutex_);
     auto it = mappings_.find(QueryKey(name, type));
@@ -463,7 +482,18 @@
         return;
     }
     LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
-               << "), not present";
+               << "), not present in registered mappings";
+}
+
+void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
+    std::lock_guard lock(mappings_mutex_);
+    // unordered_map::erase(key) returns the number of removed entries, so zero means there was
+    // no mapping registered for (name, type).
+    if (dnsheader_mappings_.erase(QueryKey(name, type)) != 0) {
+        return;
+    }
+    LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
+               << "), not present in registered DnsHeader mappings";
 }
 
 void DNSResponder::setResponseProbability(double response_probability) {
@@ -669,7 +699,13 @@
 
     // Make the response. The query has been read into |header| which is used to build and return
     // the response as well.
-    return makeResponse(&header, response, response_len);
+    switch (mapping_type_) {
+        case MappingType::DNS_HEADER:
+            return makeResponseFromDnsHeader(&header, response, response_len);
+        case MappingType::ADDRESS_OR_HOSTNAME:
+        default:
+            return makeResponse(&header, response, response_len);
+    }
 }
 
 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
@@ -700,7 +736,7 @@
                     .name = {.name = it->first.name},
                     .rtype = it->first.type,
                     .rclass = ns_class::ns_c_in,
-                    .ttl = 5,  // seconds
+                    .ttl = kAnswerRecordTtlSec,  // seconds
             };
             if (!fillAnswerRdata(it->second, record)) return false;
             answers->push_back(std::move(record));
@@ -802,6 +838,58 @@
     return true;
 }
 
+bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
+                                             size_t* response_len) const {
+    std::lock_guard guard(mappings_mutex_);
+
+    // Only a single question is supported: res_mkquery() sets "qdcount" to one for the QUERY
+    // opcode, and handleDNSRequest() checks for ns_opcode::ns_o_query before making a response,
+    // so only single-question queries need handling here. See also res_mkquery() in
+    // system/netd/resolv/res_mkquery.cpp. Other counts are answered with NOTIMP.
+    // TODO: Perhaps add support for multi-question records.
+    const std::vector<DNSQuestion>& questions = header->questions;
+    if (questions.size() != 1) {
+        LOG(INFO) << "unsupported question count " << questions.size();
+        return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
+    }
+
+    if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
+        LOG(INFO) << "unsupported question class " << questions[0].qclass;
+        return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
+    }
+
+    const std::string name = questions[0].qname.name;
+    const int qtype = questions[0].qtype;
+    const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
+    if (it != dnsheader_mappings_.end()) {
+        // Save the query's "id" and "rd" before *header is overwritten by the mapping below.
+        const unsigned id = header->id;
+        const bool rd = header->rd;
+
+        // Build a response from the registered DNSHeader mapping.
+        *header = it->second;
+        // Restore "ID" and "RD" from the query into the response. See RFC 1035 section 4.1.1.
+        header->id = id;
+        header->rd = rd;
+    } else {
+        // TODO: handle correctly. See also TODO in addAnswerRecords().
+        LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
+                  << ", couldn't build a response from DNSHeader mapping";
+
+        // Like makeResponse() when no mapping is found, only flip the QR flag from query (0) to
+        // response (1) and send the otherwise unmodified query back as the response. The rest of
+        // the header, including rcode, is left exactly as received.
+        header->qr = true;
+    }
+
+    char* response_cur = header->write(response, response + *response_len);
+    if (response_cur == nullptr) {
+        return false;
+    }
+    *response_len = response_cur - response;
+    return true;
+}
+
 void DNSResponder::setDeferredResp(bool deferred_resp) {
     std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
     deferred_resp_ = deferred_resp;
diff --git a/tests/dns_responder/dns_responder.h b/tests/dns_responder/dns_responder.h
index 41e7312..ed4b089 100644
--- a/tests/dns_responder/dns_responder.h
+++ b/tests/dns_responder/dns_responder.h
@@ -31,6 +31,9 @@
 #include <android-base/thread_annotations.h>
 #include "android-base/unique_fd.h"
 
+// Default TTL of the DNS answer record.
+constexpr unsigned kAnswerRecordTtlSec = 5;
+
 namespace test {
 
 struct DNSName {
@@ -73,6 +76,8 @@
     char* writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const;
 };
 
+// TODO: Perhaps rename to DNSMessage. Per RFC 1035 section 4.1, struct DNSHeader is more like a
+// message, because it holds not only the header section but also the question section and RRs.
 struct DNSHeader {
     unsigned id;
     bool ra;
@@ -108,6 +113,7 @@
 
 inline const std::string kDefaultListenAddr = "127.0.0.3";
 inline const std::string kDefaultListenService = "53";
+inline const ns_rcode kDefaultErrorCode = ns_rcode::ns_r_servfail;
 
 /*
  * Simple DNS responder, which replies to queries with the registered response
@@ -122,18 +128,36 @@
         FORMERR_UNCOND,   // DNS server reply FORMERR unconditionally
         DROP              // DNS server not supporting EDNS will not do any response.
     };
+    // Indicate which mapping the DNS server used to build the response.
+    // See also addMapping(), addMappingDnsHeader(), removeMapping(), removeMappingDnsHeader(),
+    // makeResponse(), makeResponseFromDnsHeader().
+    // TODO: Perhaps break class DNSResponder for each mapping.
+    // TODO: Add the mapping from (raw dns query) to (raw dns response).
+    enum class MappingType : uint8_t {
+        ADDRESS_OR_HOSTNAME,  // Use the mapping from (name, type) to (address or hostname)
+        DNS_HEADER,           // Use the mapping from (name, type) to (DNSHeader)
+    };
 
     DNSResponder(std::string listen_address = kDefaultListenAddr,
                  std::string listen_service = kDefaultListenService,
-                 ns_rcode error_rcode = ns_rcode::ns_r_servfail);
+                 ns_rcode error_rcode = kDefaultErrorCode,
+                 DNSResponder::MappingType mapping_type = MappingType::ADDRESS_OR_HOSTNAME);
 
     DNSResponder(ns_rcode error_rcode)
         : DNSResponder(kDefaultListenAddr, kDefaultListenService, error_rcode){};
 
+    DNSResponder(MappingType mapping_type)
+        : DNSResponder(kDefaultListenAddr, kDefaultListenService, kDefaultErrorCode,
+                       mapping_type){};
+
     ~DNSResponder();
 
+    // Functions used for accessing mapping {ADDRESS_OR_HOSTNAME, DNS_HEADER}.
     void addMapping(const std::string& name, ns_type type, const std::string& addr);
+    void addMappingDnsHeader(const std::string& name, ns_type type, const DNSHeader& header);
     void removeMapping(const std::string& name, ns_type type);
+    void removeMappingDnsHeader(const std::string& name, ns_type type);
+
     void setResponseProbability(double response_probability);
     void setEdns(Edns edns);
     bool running() const;
@@ -182,9 +206,13 @@
 
     bool generateErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
                                size_t* response_len) const;
+    // TODO: Change makeErrorResponse and makeResponse{, FromDnsHeader} to use C++ containers
+    // instead of the unsafe pointer + length buffer.
     bool makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
                            size_t* response_len) const;
+    // Build a response from mapping {ADDRESS_OR_HOSTNAME, DNS_HEADER}.
     bool makeResponse(DNSHeader* header, char* response, size_t* response_len) const;
+    bool makeResponseFromDnsHeader(DNSHeader* header, char* response, size_t* response_len) const;
 
     // Add a new file descriptor to be polled by the handler thread.
     bool addFd(int fd, uint32_t events);
@@ -204,6 +232,8 @@
     const std::string listen_service_;
     // Error code to return for requests for an unknown name.
     const ns_rcode error_rcode_;
+    // Mapping type the DNS server used to build the response.
+    const MappingType mapping_type_;
     // Probability that a valid response is being sent instead of being sent
     // instead of returning error_rcode_.
     std::atomic<double> response_probability_ = 1.0;
@@ -217,9 +247,14 @@
     // ignoring the requests.
     std::atomic<Edns> edns_ = Edns::ON;
 
-    // Mappings from (name, type) to registered response and the
-    // mutex protecting them.
+    // Mappings used for building the DNS response by registered mapping items. |mapping_type_|
+    // decides which mapping is used. See also makeResponse{, FromDnsHeader}.
+    // - mappings_: Mapping from (name, type) to (address or hostname).
+    // - dnsheader_mappings_: Mapping from (name, type) to (DNSHeader).
     std::unordered_map<QueryKey, std::string, QueryKeyHash> mappings_ GUARDED_BY(mappings_mutex_);
+    std::unordered_map<QueryKey, DNSHeader, QueryKeyHash> dnsheader_mappings_
+            GUARDED_BY(mappings_mutex_);
+
     mutable std::mutex mappings_mutex_;
     // Query names received so far and the corresponding mutex.
     mutable std::vector<std::pair<std::string, ns_type>> queries_ GUARDED_BY(queries_mutex_);
diff --git a/tests/resolv_test_utils.h b/tests/resolv_test_utils.h
index 7c43ed9..987d289 100644
--- a/tests/resolv_test_utils.h
+++ b/tests/resolv_test_utils.h
@@ -50,6 +50,20 @@
 static constexpr char kBadCharAtTheEndHost[] = "hello.example.com^.";
 static constexpr char kBadCharInTheMiddleOfLabelHost[] = "hello.ex^ample.com.";
 
+static const test::DNSHeader kDefaultDnsHeader = {
+        // The "id" and "rd" values below are placeholders: the DNS responder copies both fields
+        // from the query into the response. See RFC 1035 section 4.1.1.
+        .id = 0,                // placeholder. overwritten with the query's id in the response
+        .ra = false,            // recursive query support is not available
+        .rcode = ns_r_noerror,  // no error
+        .qr = true,             // message is a response
+        .opcode = QUERY,        // a standard query
+        .aa = false,            // not an authoritative answer (AA flag clear)
+        .tr = false,            // message is not truncated
+        .rd = false,            // placeholder. overwritten with the query's rd in the response
+        .ad = false,            // AD flag clear: data is not claimed to be DNSSEC-authenticated
+};
+
 size_t GetNumQueries(const test::DNSResponder& dns, const char* name);
 size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type, const char* name);
 std::string ToString(const hostent* he);