Support evaluating private DNS by latency

The evaluation is limited to opportunistic mode and is implemented
as a flag-off feature. It is introduced to avoid from using high
latency private DNS servers.

The latency of a server is considered high if it's higher than a
latency threshold which is calculated based on the average latency of
cleartext DNS server:

  latency threshold = std::clamp(3 * mean_do53_latency_ms,
      min_private_dns_latency_threshold_ms,
      max_private_dns_latency_threshold_ms)

, where min_private_dns_latency_threshold_ms is 500 ms by default and
max_private_dns_latency_threshold_ms is 2000 ms by default.

If there's no Do53 average latency for reference, the latency threshold
is min_private_dns_latency_threshold_ms.

The evaluation of a private DNS server works in two phases.

Phase 1: In this phase, Private DNS Validation is being performed,
and the server is not considered validated. The server latency is
evaluated by sending a probe. If the latency is lower than a the
latency threshold, the server state is transitioned to Validation::success.
The evaluation goes to phase 2.

Phase 2: In this phase, the server is considered validated and
DnsResolver can send DNS queries to the server. The server latency
is evaluated by the query response time, and the same latency threshold
is used. If there are several, 10 by default, query response time
failed to meet the time threshold in a row, the server state is
transitioned to Validation::in_process. The evaluation goes to phase 1.

Bug: 188153519
Test: run atest with all the flags off/on
        avoid_bad_private_dns: 0 / 1
        sort_nameservers: 0 / 1
        dot_xport_unusable_threshold: -1 / 20
        dot_query_timeout_ms: -1 / 10000
        min_private_dns_latency_threshold_ms: -1 / 500
        keep_listening_udp: 0 / 1
        parallel_lookup_sleep_time: 2 / 2
        dot_revalidation_threshold: -1 / 10
        max_private_dns_latency_threshold_ms: -1 / 2000
        dot_async_handshake: 0 / 1
        dot_maxtries: 3 / 1
        dot_connect_timeout_ms: 127000 / 10000
        parallel_lookup_release: UNSET / UNSET

Change-Id: Ib681b1ea1417eadac9c013f19549a9fa7c408696
diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp
index ca577e7..c9ca530 100644
--- a/DnsTlsDispatcher.cpp
+++ b/DnsTlsDispatcher.cpp
@@ -182,9 +182,18 @@
     // stuck, this function also gets blocked.
     const int connectCounter = xport->transport.getConnectCounter();
 
+    Stopwatch stopwatch;
     const auto& result = queryInternal(*xport, query);
+    const int64_t timeTaken = saturate_cast<int64_t>(stopwatch.timeTakenUs() / 1000);
     *connectTriggered = (xport->transport.getConnectCounter() > connectCounter);
 
+    const int64_t targetTime = server.latencyThreshold().value_or(INT64_MAX);
+    const bool latencyTooHigh = timeTaken > targetTime;
+    if (latencyTooHigh) {
+        LOG(WARNING) << "DoT query took too long: " << timeTaken << " ms (threshold: " << targetTime
+                     << "ms)";
+    }
+
     DnsTlsTransport::Response code = result.code;
     if (code == DnsTlsTransport::Response::success) {
         if (result.response.size() > ans.size()) {
@@ -206,7 +215,7 @@
         xport->lastUsed = now;
 
         // DoT revalidation specific feature.
-        if (xport->checkRevalidationNecessary(code)) {
+        if (xport->checkRevalidationNecessary(code, latencyTooHigh)) {
             // Even if the revalidation passes, it doesn't guarantee that DoT queries
             // to the xport can stop failing because revalidation creates a new connection
             // to probe while the xport still uses an existing connection. So far, there isn't
@@ -316,13 +325,21 @@
     return (it == mStore.end() ? nullptr : it->second.get());
 }
 
-bool DnsTlsDispatcher::Transport::checkRevalidationNecessary(DnsTlsTransport::Response code) {
+bool DnsTlsDispatcher::Transport::checkRevalidationNecessary(DnsTlsTransport::Response code,
+                                                             bool latencyTooHigh) {
     if (!revalidationEnabled) return false;
 
     if (code == DnsTlsTransport::Response::network_error) {
         continuousfailureCount++;
+        LOG(WARNING) << "continuousfailureCount incremented: network_error, count = "
+                     << continuousfailureCount;
+    } else if (latencyTooHigh) {
+        continuousfailureCount++;
+        LOG(WARNING) << "continuousfailureCount incremented: latency too High, count = "
+                     << continuousfailureCount;
     } else {
         continuousfailureCount = 0;
+        LOG(WARNING) << "continuousfailureCount reset";
     }
 
     // triggerThreshold must be greater than 0 because the value of revalidationEnabled is true.
diff --git a/DnsTlsDispatcher.h b/DnsTlsDispatcher.h
index f7b27df..d24eb49 100644
--- a/DnsTlsDispatcher.h
+++ b/DnsTlsDispatcher.h
@@ -98,7 +98,8 @@
         // whether or not this Transport is usable.
         bool usable() const REQUIRES(sLock);
 
-        bool checkRevalidationNecessary(DnsTlsTransport::Response code) REQUIRES(sLock);
+        bool checkRevalidationNecessary(DnsTlsTransport::Response code, bool latencyTooHigh)
+                REQUIRES(sLock);
 
         std::chrono::milliseconds timeout() const { return mTimeout; }
 
diff --git a/DnsTlsServer.h b/DnsTlsServer.h
index 7d9cbef..d4c4fc9 100644
--- a/DnsTlsServer.h
+++ b/DnsTlsServer.h
@@ -86,10 +86,19 @@
     bool active() const override { return mActive; }
     void setActive(bool val) override { mActive = val; }
 
+    // Setter and getter for the latency constraint.
+    void setLatencyThreshold(std::optional<int64_t> val) { mLatencyThresholdMs = val; }
+    std::optional<int64_t> latencyThreshold() const { return mLatencyThresholdMs; }
+
   private:
     // State, unrelated to the comparison of DnsTlsServer objects.
     Validation mValidation = Validation::unknown_server;
     bool mActive = false;
+
+    // The DNS response time threshold. If it is set, the latency of this server will be
+    // considered when evaluating the server. It is used for the DoT engine to evaluate whether
+    // this server, compared to cleartext DNS servers, has relatively high latency or not.
+    std::optional<int64_t> mLatencyThresholdMs = std::nullopt;
 };
 
 // This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
diff --git a/Experiments.h b/Experiments.h
index f87d517..f03dfe3 100644
--- a/Experiments.h
+++ b/Experiments.h
@@ -49,11 +49,19 @@
     // TODO: Migrate other experiment flags to here.
     // (retry_count, retransmission_time_interval)
     static constexpr const char* const kExperimentFlagKeyList[] = {
-            "keep_listening_udp",   "parallel_lookup_release",    "parallel_lookup_sleep_time",
-            "sort_nameservers",     "dot_async_handshake",        "dot_connect_timeout_ms",
-            "dot_maxtries",         "dot_revalidation_threshold", "dot_xport_unusable_threshold",
+            "keep_listening_udp",
+            "parallel_lookup_release",
+            "parallel_lookup_sleep_time",
+            "sort_nameservers",
+            "dot_async_handshake",
+            "dot_connect_timeout_ms",
+            "dot_maxtries",
+            "dot_revalidation_threshold",
+            "dot_xport_unusable_threshold",
             "dot_query_timeout_ms",
-    };
+            "avoid_bad_private_dns",
+            "min_private_dns_latency_threshold_ms",
+            "max_private_dns_latency_threshold_ms"};
     // This value is used in updateInternal as the default value if any flags can't be found.
     static constexpr int kFlagIntDefault = INT_MIN;
     // For testing.
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index b09c7ce..6939227 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -21,18 +21,23 @@
 #include <android-base/format.h>
 #include <android-base/logging.h>
 #include <android-base/stringprintf.h>
+#include <netdutils/Stopwatch.h>
 #include <netdutils/ThreadUtil.h>
 #include <sys/socket.h>
 
 #include "DnsTlsTransport.h"
+#include "Experiments.h"
 #include "ResolverEventReporter.h"
 #include "netd_resolv/resolv.h"
+#include "resolv_cache.h"
+#include "resolv_private.h"
 #include "util.h"
 
 using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
 using aidl::android::net::resolv::aidl::PrivateDnsValidationEventParcel;
 using android::base::StringPrintf;
 using android::netdutils::setThreadName;
+using android::netdutils::Stopwatch;
 using std::chrono::milliseconds;
 
 namespace android {
@@ -195,6 +200,23 @@
     std::thread validate_thread([this, identity, server, netId, isRevalidation] {
         setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
 
+        const bool avoidBadPrivateDns =
+                Experiments::getInstance()->getFlag("avoid_bad_private_dns", 0);
+        std::optional<int64_t> latencyThreshold;
+        if (avoidBadPrivateDns) {
+            const int maxLatency = Experiments::getInstance()->getFlag(
+                    "max_private_dns_latency_threshold_ms", kMaxPrivateDnsLatencyThresholdMs);
+            const int minLatency = Experiments::getInstance()->getFlag(
+                    "min_private_dns_latency_threshold_ms", kMinPrivateDnsLatencyThresholdMs);
+            const auto do53Latency = resolv_stats_get_average_response_time(netId, PROTO_UDP);
+            const int target =
+                    do53Latency.has_value() ? (3 * do53Latency.value().count() / 1000) : 0;
+
+            // The time is limited to the range [minLatency, maxLatency].
+            latencyThreshold = std::clamp(target, minLatency, maxLatency);
+        }
+        const bool isOpportunisticMode = server.name.empty();
+
         // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
         //
         // Start with a 1 minute delay and backoff to once per hour.
@@ -215,12 +237,24 @@
             // It can take milliseconds to minutes, up to the SYN retry limit.
             LOG(WARNING) << "Validating DnsTlsServer " << server.toIpString() << " with mark 0x"
                          << std::hex << server.validationMark();
-            const bool success = DnsTlsTransport::validate(server, server.validationMark());
-            LOG(WARNING) << "validateDnsTlsServer returned " << success << " for "
-                         << server.toIpString();
 
-            const bool needs_reeval =
-                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
+            Stopwatch stopwatch;
+            const bool gotAnswer = DnsTlsTransport::validate(server, server.validationMark());
+            const int32_t timeTaken = saturate_cast<int32_t>(stopwatch.timeTakenUs() / 1000);
+            LOG(WARNING) << fmt::format("validateDnsTlsServer returned {} for {}, took {}ms",
+                                        gotAnswer, server.toIpString(), timeTaken);
+
+            const int64_t targetTime = latencyThreshold.value_or(INT64_MAX);
+            bool latencyTooHigh = false;
+            if (isOpportunisticMode && timeTaken > targetTime) {
+                latencyTooHigh = true;
+                LOG(WARNING) << "validateDnsTlsServer took too long: threshold is " << targetTime
+                             << "ms";
+            }
+
+            // TODO: combine these boolean variables into a bitwise variable.
+            const bool needs_reeval = this->recordPrivateDnsValidation(
+                    identity, netId, gotAnswer, isRevalidation, latencyTooHigh);
 
             if (!needs_reeval) {
                 break;
@@ -233,6 +267,8 @@
                 break;
             }
         }
+
+        this->updateServerLatencyThreshold(identity, latencyThreshold, netId);
     });
     validate_thread.detach();
 }
@@ -268,8 +304,8 @@
 }
 
 bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
-                                                         unsigned netId, bool success,
-                                                         bool isRevalidation) {
+                                                         unsigned netId, bool gotAnswer,
+                                                         bool isRevalidation, bool latencyTooHigh) {
     constexpr bool NEEDS_REEVALUATION = true;
     constexpr bool DONT_REEVALUATE = false;
 
@@ -290,14 +326,17 @@
     }
 
     bool reevaluationStatus = NEEDS_REEVALUATION;
-    if (success) {
-        reevaluationStatus = DONT_REEVALUATE;
+    if (gotAnswer) {
+        if (!latencyTooHigh) {
+            reevaluationStatus = DONT_REEVALUATE;
+        }
     } else if (mode->second == PrivateDnsMode::OFF) {
         reevaluationStatus = DONT_REEVALUATE;
     } else if (mode->second == PrivateDnsMode::OPPORTUNISTIC && !isRevalidation) {
         reevaluationStatus = DONT_REEVALUATE;
     }
 
+    bool success = gotAnswer;
     auto& tracker = netPair->second;
     auto serverPair = tracker.find(identity);
     if (serverPair == tracker.end()) {
@@ -312,10 +351,12 @@
         reevaluationStatus = DONT_REEVALUATE;
     }
 
-    // Send private dns validation result to listeners.
-    sendPrivateDnsValidationEvent(identity, netId, success);
+    const bool succeededQuickly = success && !latencyTooHigh;
 
-    if (success) {
+    // Send private dns validation result to listeners.
+    sendPrivateDnsValidationEvent(identity, netId, succeededQuickly);
+
+    if (succeededQuickly) {
         updateServerState(identity, Validation::success, netId);
     } else {
         // Validation failure is expected if a user is on a captive portal.
@@ -325,7 +366,7 @@
                                                                        : Validation::fail;
         updateServerState(identity, result, netId);
     }
-    LOG(WARNING) << "Validation " << (success ? "success" : "failed");
+    LOG(WARNING) << "Validation " << (succeededQuickly ? "success" : "failed");
 
     return reevaluationStatus;
 }
@@ -385,6 +426,24 @@
     return iter->second.get();
 }
 
+void PrivateDnsConfiguration::updateServerLatencyThreshold(const ServerIdentity& identity,
+                                                           std::optional<int64_t> latencyThreshold,
+                                                           uint32_t netId) {
+    std::lock_guard guard(mPrivateDnsLock);
+
+    const auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) return;
+
+    if (result.value()->isDot()) {
+        DnsTlsServer& server = *static_cast<DnsTlsServer*>(result.value());
+        server.setLatencyThreshold(latencyThreshold);
+        LOG(INFO) << "Set latencyThreshold "
+                  << (latencyThreshold ? std::to_string(latencyThreshold.value()) + "ms"
+                                       : "nullopt")
+                  << " to " << server.toIpString();
+    }
+}
+
 void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
     std::lock_guard guard(mPrivateDnsLock);
     mObserver = observer;
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index f7e6592..3d40040 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -95,6 +95,9 @@
   private:
     typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker;
 
+    static constexpr int kMaxPrivateDnsLatencyThresholdMs = 2000;
+    static constexpr int kMinPrivateDnsLatencyThresholdMs = 500;
+
     PrivateDnsConfiguration() = default;
 
     // Launchs a thread to run the validation for |server| on the network |netId|.
@@ -103,7 +106,8 @@
             REQUIRES(mPrivateDnsLock);
 
     bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
-                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);
+                                    bool isRevalidation, bool latencyTooHigh)
+            EXCLUDES(mPrivateDnsLock);
 
     void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
             REQUIRES(mPrivateDnsLock);
@@ -123,6 +127,10 @@
     base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity,
                                                          unsigned netId) REQUIRES(mPrivateDnsLock);
 
+    void updateServerLatencyThreshold(const ServerIdentity& identity,
+                                      std::optional<int64_t> latencyThreshold, uint32_t netId)
+            EXCLUDES(mPrivateDnsLock);
+
     mutable std::mutex mPrivateDnsLock;
     std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
 
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
index 492e7ec..d7f8159 100644
--- a/PrivateDnsConfigurationTest.cpp
+++ b/PrivateDnsConfigurationTest.cpp
@@ -17,15 +17,44 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#include <android-base/properties.h>
+
+#include "Experiments.h"
 #include "PrivateDnsConfiguration.h"
 #include "tests/dns_responder/dns_responder.h"
 #include "tests/dns_responder/dns_tls_frontend.h"
 #include "tests/resolv_test_utils.h"
 
+#include <android-base/logging.h>
+
 namespace android::net {
 
 using namespace std::chrono_literals;
 
+const std::string kAvoidBadPrivateDnsFlag(
+        "persist.device_config.netd_native.avoid_bad_private_dns");
+const std::string kMinPrivateDnsLatencyThresholdMsFlag(
+        "persist.device_config.netd_native.min_private_dns_latency_threshold_ms");
+const std::string kMaxPrivateDnsLatencyThresholdMsFlag(
+        "persist.device_config.netd_native.max_private_dns_latency_threshold_ms");
+
+namespace {
+
+class ScopedSystemProperty {
+  public:
+    ScopedSystemProperty(const std::string& key, const std::string& value) : mStoredKey(key) {
+        mStoredValue = android::base::GetProperty(key, "");
+        android::base::SetProperty(key, value);
+    }
+    ~ScopedSystemProperty() { android::base::SetProperty(mStoredKey, mStoredValue); }
+
+  private:
+    std::string mStoredKey;
+    std::string mStoredValue;
+};
+
+}  // namespace
+
 class PrivateDnsConfigurationTest : public ::testing::Test {
   public:
     using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
@@ -71,6 +100,8 @@
                     std::lock_guard guard(mObserver.lock);
                     mObserver.serverStateMap[server] = validation;
                 });
+
+        forceExperimentsInstanceUpdate();
     }
 
   protected:
@@ -118,6 +149,8 @@
         return mPdc.getPrivateDns(identity, netId).ok();
     }
 
+    void forceExperimentsInstanceUpdate() { Experiments::getInstance()->update(); }
+
     static constexpr uint32_t kNetId = 30;
     static constexpr uint32_t kMark = 30;
     static constexpr char kBackend[] = "127.0.2.1";
@@ -194,6 +227,63 @@
     ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
 }
 
+TEST_F(PrivateDnsConfigurationTest, ValidationProbingTime) {
+    // The probing time threshold is 500 milliseconds.
+    ScopedSystemProperty sp1(kMinPrivateDnsLatencyThresholdMsFlag, "500");
+    ScopedSystemProperty sp2(kMaxPrivateDnsLatencyThresholdMsFlag, "1000");
+
+    // TODO: Complete STRICT test after the dependency of DnsTlsFrontend is removed.
+    static const struct TestConfig {
+        std::string dnsMode;
+        bool avoidBadPrivateDns;
+        int probingTime;
+        int expectProbeCount;
+    } testConfigs[] = {
+            // clang-format off
+            {"OPPORTUNISTIC", false,   50, 1},
+            {"OPPORTUNISTIC", false,  750, 1},
+            {"OPPORTUNISTIC", false, 1500, 1},
+            {"OPPORTUNISTIC",  true,   50, 1},
+            {"OPPORTUNISTIC",  true,  750, 2},
+            {"OPPORTUNISTIC",  true, 1500, 2},
+            // clang-format on
+    };
+
+    for (const auto& config : testConfigs) {
+        SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}, {}]", config.dnsMode,
+                                 config.avoidBadPrivateDns, config.probingTime,
+                                 config.expectProbeCount));
+
+        ScopedSystemProperty sp3(kAvoidBadPrivateDnsFlag, (config.avoidBadPrivateDns ? "1" : "0"));
+        forceExperimentsInstanceUpdate();
+
+        // Simulate that validation takes the certain time to complete the first probe.
+        std::thread t([&] {
+            backend.setResponseDelayMs(config.probingTime);
+            std::this_thread::sleep_for(std::chrono::milliseconds(config.probingTime + 500));
+            backend.setResponseDelayMs(0);
+        });
+
+        testing::InSequence seq;
+        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
+                .Times(config.expectProbeCount);
+        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
+
+        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+        // The thread is expected to be joined before the second probe begins.
+        t.join();
+        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+
+        // Reset the state for the next round.
+        mPdc.clear(kNetId);
+        testing::Mock::VerifyAndClearExpectations(&mObserver);
+    }
+
+    backend.setResponseDelayMs(0);
+}
+
 TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
     backend.setDeferredResp(true);
 
diff --git a/ResolverController.cpp b/ResolverController.cpp
index 0d01d2b..28de275 100644
--- a/ResolverController.cpp
+++ b/ResolverController.cpp
@@ -345,8 +345,13 @@
                        static_cast<uint32_t>(privateDnsStatus.serversMap.size()));
             dw.incIndent();
             for (const auto& [server, validation] : privateDnsStatus.serversMap) {
-                dw.println("%s name{%s} status{%s}", server.toIpString().c_str(),
-                           server.name.c_str(), validationStatusToString(validation));
+                const std::string latencyThreshold =
+                        server.latencyThreshold()
+                                ? std::to_string(server.latencyThreshold().value()) + "ms"
+                                : "nullopt";
+                dw.println("%s name{%s} status{%s} latencyThreshold{%s}",
+                           server.toIpString().c_str(), server.name.c_str(),
+                           validationStatusToString(validation), latencyThreshold.c_str());
             }
             dw.decIndent();
         }
diff --git a/res_cache.cpp b/res_cache.cpp
index 5b03147..9eb143a 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -1929,6 +1929,18 @@
     return false;
 }
 
+std::optional<std::chrono::microseconds> resolv_stats_get_average_response_time(
+        unsigned netid, android::net::Protocol protocol) {
+    if (protocol == android::net::PROTO_UNKNOWN) return std::nullopt;
+
+    std::lock_guard guard(cache_mutex);
+    if (const auto info = find_netconfig_locked(netid); info != nullptr) {
+        return info->dnsStats.getAverageLatencyUs(protocol);
+    }
+
+    return std::nullopt;
+}
+
 static const char* tc_mode_to_str(const int mode) {
     switch (mode) {
         case aidl::android::net::IDnsResolver::TC_MODE_DEFAULT:
diff --git a/resolv_cache.h b/resolv_cache.h
index 15baa14..0ff3dbe 100644
--- a/resolv_cache.h
+++ b/resolv_cache.h
@@ -112,6 +112,12 @@
 bool resolv_stats_add(unsigned netid, const android::netdutils::IPSockAddr& server,
                       const android::net::DnsQueryEvent* record);
 
+// Get the average DNS response time per |protocol| on a network.
+// Return -1 if no such average DNS response time available at the time. However, this is unlikely
+// to happen because of DNS probe in NetworkMonitor, except when a network is just created.
+std::optional<std::chrono::microseconds> resolv_stats_get_average_response_time(
+        unsigned netid, android::net::Protocol protocol);
+
 /* Retrieve a local copy of the stats for the given netid. The buffer must have space for
  * MAXNS __resolver_stats. Returns the revision id of the resolvers used.
  */
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 2134227..9102cdb 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -88,6 +88,12 @@
 const std::string kDotXportUnusableThresholdFlag(
         "persist.device_config.netd_native.dot_xport_unusable_threshold");
 const std::string kDotQueryTimeoutMsFlag("persist.device_config.netd_native.dot_query_timeout_ms");
+const std::string kAvoidBadPrivateDnsFlag(
+        "persist.device_config.netd_native.avoid_bad_private_dns");
+const std::string kMinPrivateDnsLatencyThresholdMsFlag(
+        "persist.device_config.netd_native.min_private_dns_latency_threshold_ms");
+const std::string kMaxPrivateDnsLatencyThresholdMsFlag(
+        "persist.device_config.netd_native.max_private_dns_latency_threshold_ms");
 
 // Semi-public Bionic hook used by the NDK (frameworks/base/native/android/net.c)
 // Tested here for convenience.
@@ -4676,7 +4682,12 @@
         SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}]", config.dnsMode,
                                  config.validationThreshold, config.queries));
         const int queries = config.queries;
-        const int delayQueriesTimeout = dotQueryTimeoutMs + 1000;
+
+        // In order make DoT queries timedout and revalidation not timed out, add a bit longer
+        // timeout for the DoT server. Note that min_private_dns_latency_threshold_ms must be
+        // lower than the 200 ms, because this test might be run when avoid_bad_private_dns is
+        // enabled.
+        const int delayQueriesTimeout = dotQueryTimeoutMs + 200;
 
         ScopedSystemProperties sp1(kDotRevalidationThresholdFlag,
                                    std::to_string(config.validationThreshold));
@@ -4704,13 +4715,12 @@
         // Expect the things happening in order:
         // 1. Configure the DoT server to postpone |queries + 1| DNS queries.
         // 2. Send |queries| DNS queries, they will time out in 1 second.
-        // 3. 1 second later, the DoT server still waits for one more DNS query until
-        //    |delayQueriesTimeout| times out.
+        // 3. 1 second later, the DoT server still waits for one more DNS query.
         // 4. (opportunistic mode only) Meanwhile, DoT revalidation happens. The DnsResolver
         //    creates a new connection and sends a query to the DoT server.
-        // 5. 1 second later, |delayQueriesTimeout| times out. The DoT server flushes all of the
-        //    postponed DNS queries, and handles the query which comes from the revalidation.
-        // 6. (opportunistic mode only) The revalidation succeeds.
+        // 5. 200 milliseconds later, |delayQueriesTimeout| times out. The DoT server flushes
+        //    all of the postponed DNS queries.
+        // 6. (opportunistic mode only) The revalidation starts and then succeeds.
         // 7. Send another DNS query, and expect it will succeed.
         // 8. (opportunistic mode only) If the DoT server has been deemed as unusable, the
         //    DnsResolver skips trying the DoT server.
@@ -4762,6 +4772,118 @@
     }
 }
 
+// Verifies that the DnsResolver re-validates the DoT server when Avoid Bad Private DNS feature
+// is enabled.
+TEST_F(ResolverTest, TlsServerRevalidation_AvoidBadPrivateDns) {
+    constexpr uint32_t cacheFlag = ANDROID_RESOLV_NO_CACHE_LOOKUP;
+    constexpr int kMinPrivateDnsLatencyThresholdMs = 500;
+    constexpr int kMaxPrivateDnsLatencyThresholdMs = 2000;
+    constexpr int kDotRevalidationThreshold = 5;
+    constexpr int TIMING_TOLERANCE_MS = 200;
+    constexpr char hostname[] = "hello.example.com.";
+    const std::vector<DnsRecord> records = {
+            {hostname, ns_type::ns_t_a, "1.2.3.4"},
+    };
+
+    for (const bool featureEnabled : {true, false}) {
+        SCOPED_TRACE(fmt::format("testConfig: [{}]", featureEnabled));
+        const std::string addr = getUniqueIPv4Address();
+        test::DNSResponder dns(addr);
+        test::DnsTlsFrontend tls(addr, "853", addr, "53");
+        StartDns(dns, records);
+        ASSERT_TRUE(tls.startServer());
+
+        ScopedSystemProperties sp1(kAvoidBadPrivateDnsFlag, (featureEnabled ? "1" : "0"));
+        ScopedSystemProperties sp2(kMinPrivateDnsLatencyThresholdMsFlag,
+                                   std::to_string(kMinPrivateDnsLatencyThresholdMs));
+        ScopedSystemProperties sp3(kMaxPrivateDnsLatencyThresholdMsFlag,
+                                   std::to_string(kMaxPrivateDnsLatencyThresholdMs));
+        ScopedSystemProperties sp4(kDotRevalidationThresholdFlag,
+                                   std::to_string(kDotRevalidationThreshold));
+        ScopedSystemProperties sp5(kDotXportUnusableThresholdFlag, "20");
+        ScopedSystemProperties sp6(kDotQueryTimeoutMsFlag, "10000");
+        resetNetwork();
+
+        const auto setLatencyAndSendQueriesAndWait = [&](int queries,
+                                                         std::chrono::milliseconds latency,
+                                                         bool bypassPrivateDns = false) {
+            std::thread worker([&]() {
+                dns.setDeferredResp(true);
+                std::this_thread::sleep_for(latency);
+                dns.setDeferredResp(false);
+            });
+
+            const unsigned netId =
+                    TEST_NETID | (bypassPrivateDns ? NETID_USE_LOCAL_NAMESERVERS : 0);
+            std::vector<std::thread> threads(queries);
+            Stopwatch stopwatch;
+            for (std::thread& thread : threads) {
+                thread = std::thread([&]() {
+                    int fd = resNetworkQuery(netId, hostname, ns_c_in, ns_t_a, cacheFlag);
+                    expectAnswersValid(fd, AF_INET, "1.2.3.4");
+                });
+            }
+            for (std::thread& thread : threads) {
+                thread.join();
+            }
+            EXPECT_NEAR(latency.count(), stopwatch.getTimeAndResetUs() / 1000, TIMING_TOLERANCE_MS)
+                    << "took time should approximate equal timeout";
+
+            worker.join();
+        };
+
+        const auto expectQueryCount = [&](int frontendQueries, size_t backendQueries) {
+            EXPECT_TRUE(tls.waitForQueries(frontendQueries));
+            EXPECT_EQ(backendQueries, dns.queries().size());
+        };
+
+        // Set up opportunistic mode, and wait for the validation complete.
+        auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+        parcel.servers = {addr};
+        parcel.tlsServers = {addr};
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+        tls.clearQueries();
+        dns.clearQueries();
+
+        // Set the server can handle |kDotRevalidationThreshold| queries in one connection.
+        // The 250 ms timeout is to ensure Private DNS Validation can always pass because the
+        // probing time is shorter than kMinPrivateDnsLatencyThresholdMs.
+        tls.setDelayQueries(kDotRevalidationThreshold);
+        tls.setDelayQueriesTimeout(kMinPrivateDnsLatencyThresholdMs / 2);
+        int expectedDotQueries = 0;
+
+        // Simulate that DoT latency is 200 ms. As there is no Do53 latency stats,
+        // kMinPrivateDnsLatencyThresholdMs will be chosen as the threshold. No revalidation will be
+        // triggered because 200 ms < kMinPrivateDnsLatencyThresholdMs.
+        setLatencyAndSendQueriesAndWait(5, 200ms);
+        expectedDotQueries += 5;
+        EXPECT_TRUE(tls.waitForQueries(expectedDotQueries));
+        expectQueryCount(expectedDotQueries, expectedDotQueries);
+
+        // Simulate that DoT latency is 800 ms. A revalidation will be triggered (if the feature is
+        // enabled) because 800 ms > kMinPrivateDnsLatencyThresholdMs.
+        setLatencyAndSendQueriesAndWait(5, 800ms);
+        expectedDotQueries += 5;
+        if (featureEnabled) {
+            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+            expectedDotQueries += 1;
+        }
+        expectQueryCount(expectedDotQueries, expectedDotQueries);
+
+        // Add some Do53 stats with latency 400 ms.
+        setLatencyAndSendQueriesAndWait(1, 400ms, true);
+        expectQueryCount(expectedDotQueries, expectedDotQueries + 1);
+
+        // Simulate that DoT latency is 800 ms again. No revalidation will be triggered because
+        // 800 ms < 3 * 400 ms.
+        setLatencyAndSendQueriesAndWait(5, 800ms);
+        expectedDotQueries += 5;
+        expectQueryCount(expectedDotQueries, expectedDotQueries + 1);
+        EXPECT_FALSE(hasUncaughtPrivateDnsValidation(addr));
+    }
+}
+
 TEST_F(ResolverTest, FlushNetworkCache) {
     SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
     test::DNSResponder dns;