Merge "Let resolv_integration_test listen to onPrivateDnsValidationEvent"
am: 6ddb0a2e41

Change-Id: Ib05afddd79261f8c950e68e2c39576902d7519f3
diff --git a/resolver_test.cpp b/resolver_test.cpp
index d7cd263..5014983 100644
--- a/resolver_test.cpp
+++ b/resolver_test.cpp
@@ -171,6 +171,10 @@
         return sDnsMetricsListener->waitForNat64Prefix(status, timeout);
     }
 
+    // Waits (bounded by the listener's event timeout) for |validated| to be reported for |serverAddr|.
+    bool WaitForPrivateDnsValidation(const std::string& serverAddr, bool validated) {
+        return sDnsMetricsListener->waitForPrivateDnsValidation(serverAddr, validated);
+    }
     DnsResponderClient mDnsClient;
 
     // Use a shared static DNS listener for all tests to avoid registering lots of listeners
@@ -1091,7 +1095,7 @@
     // So, wait for private DNS validation done before stopping backend DNS servers.
     for (int i = 0; i < MAXNS; i++) {
         LOG(INFO) << "Waiting for private DNS validation on " << tls[i]->listen_address() << ".";
-        EXPECT_TRUE(tls[i]->waitForQueries(1, 5000));
+        EXPECT_TRUE(WaitForPrivateDnsValidation(tls[i]->listen_address(), true));
         LOG(INFO) << "private DNS validation on " << tls[i]->listen_address() << " done.";
     }
 
@@ -1262,13 +1266,9 @@
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
 
-    const hostent* result;
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
-    result = gethostbyname("tls1");
+    const hostent* result = gethostbyname("tls1");
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.1", ToString(result));
 
@@ -1326,14 +1326,10 @@
     ASSERT_TRUE(tls2.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls1.listen_address(), true));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls2.listen_address(), true));
 
-    const hostent* result;
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls1.waitForQueries(1, 5000));
-    EXPECT_TRUE(tls2.waitForQueries(1, 5000));
-
-    result = gethostbyname("tlsfailover1");
+    const hostent* result = gethostbyname("tlsfailover1");
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.1", ToString(result));
 
@@ -1376,7 +1372,7 @@
 
     // The TLS handshake would fail because the name of TLS server doesn't
     // match with TLS server's certificate.
-    EXPECT_FALSE(tls.waitForQueries(1, 500));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
 
     // The query should fail hard, because a name was specified.
     EXPECT_EQ(nullptr, gethostbyname("badtlsname"));
@@ -1403,9 +1399,7 @@
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
-
-    // Wait for validation to complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
 
     dns.clearQueries();
     ScopedAddrinfo result = safe_getaddrinfo("addrinfotls", nullptr, nullptr);
@@ -1505,13 +1499,16 @@
         } else if (config.mode == OPPORTUNISTIC) {
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
-            // Wait for validation to complete.
-            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+            // Wait for the validation event. If the server is running, the validation should
+            // be successful; otherwise, the validation should be failed.
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
         } else if (config.mode == STRICT) {
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, kDefaultPrivateDnsHostName));
-            // Wait for validation to complete.
-            if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+            // Wait for the validation event.
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
         }
         tls.clearQueries();
 
@@ -2231,22 +2228,21 @@
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
         } else if (config.mode == OPPORTUNISTIC_TLS) {
             if (!tls.running()) {
                 ASSERT_TRUE(tls.startServer());
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, ""));
-            // Wait for validation to complete.
-            EXPECT_TRUE(tls.waitForQueries(1, 5000));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
         } else if (config.mode == STRICT) {
             if (!tls.running()) {
                 ASSERT_TRUE(tls.startServer());
             }
             ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
                                                        kDefaultParams, kDefaultPrivateDnsHostName));
-            // Wait for validation to complete.
-            EXPECT_TRUE(tls.waitForQueries(1, 5000));
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
         }
 
         if (config.method == GETHOSTBYNAME) {
@@ -2303,8 +2299,8 @@
     test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    // Wait for validation complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
     // Shutdown TLS server to get an error. It's similar to no response case but without waiting.
     tls.stopServer();
 
@@ -2334,8 +2330,8 @@
     test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
     ASSERT_TRUE(tls.startServer());
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    // Wait for validation complete.
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
     // Shutdown TLS server to get an error. It's similar to no response case but without waiting.
     tls.stopServer();
     dns.setEdns(test::DNSResponder::Edns::FORMERR_UNCOND);
@@ -3200,7 +3196,7 @@
 
     // Setup OPPORTUNISTIC mode and wait for the validation complete.
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
     tls.clearQueries();
 
     // Start NAT64 prefix discovery and wait for it complete.
@@ -3219,7 +3215,7 @@
     // Setup STRICT mode and wait for the validation complete.
     ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
                                                kDefaultPrivateDnsHostName));
-    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
     tls.clearQueries();
 
     // Start NAT64 prefix discovery and wait for it to complete.
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.cpp b/tests/dns_metrics_listener/dns_metrics_listener.cpp
index 8f3b646..acfb416 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.cpp
+++ b/tests/dns_metrics_listener/dns_metrics_listener.cpp
@@ -23,9 +23,11 @@
 namespace net {
 namespace metrics {
 
+using android::base::ScopedLockAssertion;
 using std::chrono::milliseconds;
 
 constexpr milliseconds kRetryIntervalMs{20};
+constexpr milliseconds kEventTimeoutMs{5000};
 
 android::binder::Status DnsMetricsListener::onNat64PrefixEvent(int32_t netId, bool added,
                                                                const std::string& prefixString,
@@ -35,6 +37,20 @@
     return android::binder::Status::ok();
 }
 
+android::binder::Status DnsMetricsListener::onPrivateDnsValidationEvent(
+        int32_t netId, const ::android::String16& ipAddress,
+        const ::android::String16& /*hostname*/, bool validated) {
+    // The address conversion touches no shared state, so it can happen before locking.
+    const std::string serverAddr(String8(ipAddress.string()));
+    {
+        std::lock_guard guard(mMutex);
+        // Overwrite any earlier record so the most recent validation status wins.
+        mValidationRecords.insert_or_assign(ServerKey{netId, serverAddr}, validated);
+    }
+    mCv.notify_one();
+    return android::binder::Status::ok();
+}
+
 bool DnsMetricsListener::waitForNat64Prefix(ExpectNat64PrefixStatus status,
                                             milliseconds timeout) const {
     android::base::Timer t;
@@ -50,6 +66,31 @@
     return false;
 }
 
+bool DnsMetricsListener::waitForPrivateDnsValidation(const std::string& serverAddr,
+                                                     const bool validated) {
+    const auto deadline = std::chrono::steady_clock::now() + kEventTimeoutMs;
+    std::unique_lock lock(mMutex);
+    ScopedLockAssertion assume_lock(mMutex);
+
+    // The event might already have arrived, so check the records before every wait. On
+    // timeout, check once more: a record stored right at the deadline must not be missed.
+    while (!findAndRemoveValidationRecord({mNetId, serverAddr}, validated)) {
+        if (mCv.wait_until(lock, deadline) == std::cv_status::timeout) {
+            return findAndRemoveValidationRecord({mNetId, serverAddr}, validated);
+        }
+    }
+    return true;
+}
+
+bool DnsMetricsListener::findAndRemoveValidationRecord(const ServerKey& key, const bool value) {
+    const auto found = mValidationRecords.find(key);
+    const bool matches = found != mValidationRecords.end() && found->second == value;
+
+    // Consume the record only when it carries the expected status; otherwise leave it intact.
+    if (matches) mValidationRecords.erase(found);
+    return matches;
+}
+
 }  // namespace metrics
 }  // namespace net
-}  // namespace android
\ No newline at end of file
+}  // namespace android
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.h b/tests/dns_metrics_listener/dns_metrics_listener.h
index 1f8170d..d933b13 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.h
+++ b/tests/dns_metrics_listener/dns_metrics_listener.h
@@ -16,6 +16,10 @@
 
 #pragma once
 
+#include <condition_variable>
+#include <map>
+#include <utility>
+
 #include <android-base/thread_annotations.h>
 
 #include "base_metrics_listener.h"
@@ -42,20 +46,38 @@
                                                const std::string& prefixString,
                                                int32_t /*prefixLength*/) override;
 
+    android::binder::Status onPrivateDnsValidationEvent(int32_t netId,
+                                                        const ::android::String16& ipAddress,
+                                                        const ::android::String16& /*hostname*/,
+                                                        bool validated) override;
+
     // Wait for expected NAT64 prefix status until timeout.
     bool waitForNat64Prefix(ExpectNat64PrefixStatus status,
                             std::chrono::milliseconds timeout) const;
 
+    // Wait for the expected private DNS validation result until timeout.
+    bool waitForPrivateDnsValidation(const std::string& serverAddr, const bool validated);
+
   private:
+    typedef std::pair<int32_t, std::string> ServerKey;
+
+    // Search mValidationRecords. Return true if |key| exists and its value is equal to
+    // |value|, and then remove it; otherwise, return false.
+    bool findAndRemoveValidationRecord(const ServerKey& key, const bool value) REQUIRES(mMutex);
+
     // Monitor the event which was fired on specific network id.
     const int32_t mNetId;
 
     // The NAT64 prefix of the network |mNetId|. It is updated by the event onNat64PrefixEvent().
     std::string mNat64Prefix GUARDED_BY(mMutex);
 
+    // Used to store the data from onPrivateDnsValidationEvent.
+    std::map<ServerKey, bool> mValidationRecords GUARDED_BY(mMutex);
+
     mutable std::mutex mMutex;
+    std::condition_variable mCv;
 };
 
 }  // namespace metrics
 }  // namespace net
-}  // namespace android
\ No newline at end of file
+}  // namespace android