Merge "Let resolv_integration_test listen to onPrivateDnsValidationEvent"
diff --git a/resolver_test.cpp b/resolver_test.cpp
index d7cd263..5014983 100644
--- a/resolver_test.cpp
+++ b/resolver_test.cpp
@@ -171,6 +171,10 @@
return sDnsMetricsListener->waitForNat64Prefix(status, timeout);
}
+    // Returns true once the shared metrics listener has seen a private DNS validation event for
+    // |serverAddr| whose result equals |validated|; false if none arrives before its timeout.
+    bool WaitForPrivateDnsValidation(const std::string& serverAddr, bool validated) {
+        return sDnsMetricsListener->waitForPrivateDnsValidation(serverAddr, validated);
+    }
+
DnsResponderClient mDnsClient;
// Use a shared static DNS listener for all tests to avoid registering lots of listeners
@@ -1091,7 +1095,7 @@
// So, wait for private DNS validation done before stopping backend DNS servers.
for (int i = 0; i < MAXNS; i++) {
LOG(INFO) << "Waiting for private DNS validation on " << tls[i]->listen_address() << ".";
- EXPECT_TRUE(tls[i]->waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls[i]->listen_address(), true));
LOG(INFO) << "private DNS validation on " << tls[i]->listen_address() << " done.";
}
@@ -1262,13 +1266,9 @@
test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
ASSERT_TRUE(tls.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
- const hostent* result;
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
- result = gethostbyname("tls1");
+ const hostent* result = gethostbyname("tls1");
ASSERT_FALSE(result == nullptr);
EXPECT_EQ("1.2.3.1", ToString(result));
@@ -1326,14 +1326,10 @@
ASSERT_TRUE(tls2.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
kDefaultPrivateDnsHostName));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls1.listen_address(), true));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls2.listen_address(), true));
- const hostent* result;
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls1.waitForQueries(1, 5000));
- EXPECT_TRUE(tls2.waitForQueries(1, 5000));
-
- result = gethostbyname("tlsfailover1");
+ const hostent* result = gethostbyname("tlsfailover1");
ASSERT_FALSE(result == nullptr);
EXPECT_EQ("1.2.3.1", ToString(result));
@@ -1376,7 +1372,7 @@
// The TLS handshake would fail because the name of TLS server doesn't
// match with TLS server's certificate.
- EXPECT_FALSE(tls.waitForQueries(1, 500));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
// The query should fail hard, because a name was specified.
EXPECT_EQ(nullptr, gethostbyname("badtlsname"));
@@ -1403,9 +1399,7 @@
ASSERT_TRUE(tls.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
kDefaultPrivateDnsHostName));
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
dns.clearQueries();
ScopedAddrinfo result = safe_getaddrinfo("addrinfotls", nullptr, nullptr);
@@ -1505,13 +1499,16 @@
} else if (config.mode == OPPORTUNISTIC) {
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
kDefaultParams, ""));
- // Wait for validation to complete.
- if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+ // Wait for the validation event. If the server is running, the validation should
+ // be successful; otherwise, the validation should be failed.
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
} else if (config.mode == STRICT) {
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
kDefaultParams, kDefaultPrivateDnsHostName));
- // Wait for validation to complete.
- if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+ // Wait for the validation event.
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), config.withWorkingTLS));
}
tls.clearQueries();
@@ -2231,22 +2228,21 @@
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
kDefaultParams, ""));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), false));
} else if (config.mode == OPPORTUNISTIC_TLS) {
if (!tls.running()) {
ASSERT_TRUE(tls.startServer());
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
kDefaultParams, ""));
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
} else if (config.mode == STRICT) {
if (!tls.running()) {
ASSERT_TRUE(tls.startServer());
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
kDefaultParams, kDefaultPrivateDnsHostName));
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
}
if (config.method == GETHOSTBYNAME) {
@@ -2303,8 +2299,8 @@
test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
ASSERT_TRUE(tls.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
- // Wait for validation complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
// Shutdown TLS server to get an error. It's similar to no response case but without waiting.
tls.stopServer();
@@ -2334,8 +2330,8 @@
test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
ASSERT_TRUE(tls.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
- // Wait for validation complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+
// Shutdown TLS server to get an error. It's similar to no response case but without waiting.
tls.stopServer();
dns.setEdns(test::DNSResponder::Edns::FORMERR_UNCOND);
@@ -3200,7 +3196,7 @@
// Setup OPPORTUNISTIC mode and wait for the validation complete.
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
tls.clearQueries();
// Start NAT64 prefix discovery and wait for it complete.
@@ -3219,7 +3215,7 @@
// Setup STRICT mode and wait for the validation complete.
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
kDefaultPrivateDnsHostName));
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
tls.clearQueries();
// Start NAT64 prefix discovery and wait for it to complete.
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.cpp b/tests/dns_metrics_listener/dns_metrics_listener.cpp
index 8f3b646..acfb416 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.cpp
+++ b/tests/dns_metrics_listener/dns_metrics_listener.cpp
@@ -23,9 +23,11 @@
namespace net {
namespace metrics {
+using android::base::ScopedLockAssertion;
using std::chrono::milliseconds;
constexpr milliseconds kRetryIntervalMs{20};
+constexpr milliseconds kEventTimeoutMs{5000};  // Timeout used by waitForPrivateDnsValidation().
android::binder::Status DnsMetricsListener::onNat64PrefixEvent(int32_t netId, bool added,
const std::string& prefixString,
@@ -35,6 +37,20 @@
return android::binder::Status::ok();
}
+android::binder::Status DnsMetricsListener::onPrivateDnsValidationEvent(
+        int32_t netId, const ::android::String16& ipAddress,
+        const ::android::String16& /*hostname*/, bool validated) {
+    const std::string serverAddr(String8(ipAddress.string()).c_str());
+    {
+        std::lock_guard guard(mMutex);
+        // Overwrite any earlier record for this server so that only the latest status is kept.
+        mValidationRecords.insert_or_assign(ServerKey(netId, serverAddr), validated);
+    }
+    // Notify outside the critical section; the waiter re-checks mValidationRecords itself.
+    mCv.notify_one();
+    return android::binder::Status::ok();
+}
+
bool DnsMetricsListener::waitForNat64Prefix(ExpectNat64PrefixStatus status,
milliseconds timeout) const {
android::base::Timer t;
@@ -50,6 +66,31 @@
return false;
}
+bool DnsMetricsListener::waitForPrivateDnsValidation(const std::string& serverAddr,
+                                                     const bool validated) {
+    // Fix the deadline before taking the lock so that lock contention counts against it.
+    const auto deadline = std::chrono::steady_clock::now() + kEventTimeoutMs;
+    std::unique_lock lock(mMutex);
+
+    // onPrivateDnsValidationEvent() might already be invoked, so the predicate is evaluated
+    // before the first wait. Unlike a bare wait_until() loop, the predicate overload also
+    // re-checks the records once the deadline has passed, so an event recorded while this
+    // thread is waking up from the timeout is not missed. A matching record is consumed.
+    // NOTE: records persist until consumed, so an unconsumed record for the same server left
+    // by an earlier event can satisfy this wait immediately.
+    return mCv.wait_until(lock, deadline, [&] {
+        // The predicate always runs with |lock| held; tell the thread-safety analysis so.
+        ScopedLockAssertion assume_lock(mMutex);
+        return findAndRemoveValidationRecord({mNetId, serverAddr}, validated);
+    });
+}
+
+bool DnsMetricsListener::findAndRemoveValidationRecord(const ServerKey& key, const bool value) {
+    const auto record = mValidationRecords.find(key);
+    if (record == mValidationRecords.end() || record->second != value) return false;
+
+    // Consume the record so that the same event cannot satisfy a later wait.
+    mValidationRecords.erase(record);
+    return true;
+}
+
} // namespace metrics
} // namespace net
-} // namespace android
\ No newline at end of file
+} // namespace android
diff --git a/tests/dns_metrics_listener/dns_metrics_listener.h b/tests/dns_metrics_listener/dns_metrics_listener.h
index 1f8170d..d933b13 100644
--- a/tests/dns_metrics_listener/dns_metrics_listener.h
+++ b/tests/dns_metrics_listener/dns_metrics_listener.h
@@ -16,6 +16,10 @@
#pragma once
+#include <condition_variable>
+#include <map>
+#include <utility>
+
#include <android-base/thread_annotations.h>
#include "base_metrics_listener.h"
@@ -42,20 +46,38 @@
const std::string& prefixString,
int32_t /*prefixLength*/) override;
+ android::binder::Status onPrivateDnsValidationEvent(int32_t netId,
+ const ::android::String16& ipAddress,
+ const ::android::String16& /*hostname*/,
+ bool validated) override;
+
// Wait for expected NAT64 prefix status until timeout.
bool waitForNat64Prefix(ExpectNat64PrefixStatus status,
std::chrono::milliseconds timeout) const;
+ // Wait for the expected private DNS validation result until timeout.
+ bool waitForPrivateDnsValidation(const std::string& serverAddr, const bool validated);
+
private:
+ typedef std::pair<int32_t, std::string> ServerKey;
+
+ // Search mValidationRecords. Return true if |key| exists and its value is equal to
+ // |value|, and then remove it; otherwise, return false.
+ bool findAndRemoveValidationRecord(const ServerKey& key, const bool value) REQUIRES(mMutex);
+
// Monitor the event which was fired on specific network id.
const int32_t mNetId;
// The NAT64 prefix of the network |mNetId|. It is updated by the event onNat64PrefixEvent().
std::string mNat64Prefix GUARDED_BY(mMutex);
+ // Used to store the data from onPrivateDnsValidationEvent.
+ std::map<ServerKey, bool> mValidationRecords GUARDED_BY(mMutex);
+
mutable std::mutex mMutex;
+ std::condition_variable mCv;
};
} // namespace metrics
} // namespace net
-} // namespace android
\ No newline at end of file
+} // namespace android