Merge "Avoid keeping sending queries to an invalid nameserver"
diff --git a/res_stats.cpp b/res_stats.cpp
index 36dde6b..f3b79bb 100644
--- a/res_stats.cpp
+++ b/res_stats.cpp
@@ -118,11 +118,13 @@
     android_net_res_stats_aggregate(stats, &successes, &errors, &timeouts, &internal_errors,
                                     &rtt_avg, &last_sample_time);
     if (successes >= 0 && errors >= 0 && timeouts >= 0) {
-        int total = successes + errors + timeouts;
+        int total = successes + errors + timeouts + internal_errors;
         LOG(INFO) << __func__ << ": NS stats: S " << successes << " + E " << errors << " + T "
                   << timeouts << " + I " << internal_errors << " = " << total
                   << ", rtt = " << rtt_avg << ", min_samples = " << unsigned(params->min_samples);
-        if (total >= params->min_samples && (errors > 0 || timeouts > 0)) {
+        // Require total > 0 explicitly: if min_samples is 0 and all four counters are 0, the
+        // min_samples check alone would let the division below run with total == 0.
+        if (total > 0 && total >= params->min_samples) {
             int success_rate = successes * 100 / total;
             LOG(INFO) << __func__ << ": success rate " << success_rate;
             if (success_rate < params->success_threshold) {
diff --git a/tests/dns_responder/dns_responder_client_ndk.h b/tests/dns_responder/dns_responder_client_ndk.h
index 0b18f42..ca749ed 100644
--- a/tests/dns_responder/dns_responder_client_ndk.h
+++ b/tests/dns_responder/dns_responder_client_ndk.h
@@ -63,10 +63,12 @@
     static void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
                               std::vector<Mapping>* mappings);
 
+    // Deprecated: use SetResolversFromParcel() with a ResolverParamsParcel instead.
     bool SetResolversForNetwork(const std::vector<std::string>& servers = kDefaultServers,
                                 const std::vector<std::string>& domains = kDefaultSearchDomains,
                                 const std::vector<int>& params = kDefaultParams);
 
+    // Deprecated: use SetResolversFromParcel() with a ResolverParamsParcel instead.
     bool SetResolversWithTls(const std::vector<std::string>& servers,
                              const std::vector<std::string>& searchDomains,
                              const std::vector<int>& params, const std::string& name) {
@@ -75,6 +77,7 @@
         return SetResolversWithTls(servers, searchDomains, params, servers, name);
     }
 
+    // Deprecated: use SetResolversFromParcel() with a ResolverParamsParcel instead.
     bool SetResolversWithTls(const std::vector<std::string>& servers,
                              const std::vector<std::string>& searchDomains,
                              const std::vector<int>& params,
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 3c706a1..d49db5f 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -84,6 +84,7 @@
 
 using aidl::android::net::IDnsResolver;
 using aidl::android::net::INetd;
+using aidl::android::net::ResolverParamsParcel;
 using android::base::ParseInt;
 using android::base::StringPrintf;
 using android::base::unique_fd;
@@ -912,6 +913,109 @@
     EXPECT_EQ(0, wait_for_pending_req_timeout_count);
 }
 
+// Servers whose queries end in internal errors must be skipped once minSamples results have
+// been recorded for them, while the working server keeps answering every query.
+TEST_F(ResolverTest, SkipBadServersDueToInternalError) {
+    // listen_addr1 (link-local, no scope ID) and listen_addr2 (IPv4 limited broadcast) are
+    // expected to produce internal errors rather than replies; presumably the local stack
+    // refuses to send to them. Only listen_addr3 has a DNS responder behind it.
+    constexpr char listen_addr1[] = "fe80::1";
+    constexpr char listen_addr2[] = "255.255.255.255";
+    constexpr char listen_addr3[] = "127.0.0.3";
+    constexpr int kNumQueries = 5;
+
+    test::DNSResponder dns(listen_addr3);
+    ASSERT_TRUE(dns.startServer());
+
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.servers = {listen_addr1, listen_addr2, listen_addr3};
+
+    // Bad servers can be distinguished after two attempts.
+    parcel.minSamples = 2;
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+
+    // Query kNumQueries distinct names, presumably so no answer is served from the cache.
+    for (int i = 0; i < kNumQueries; i++) {
+        std::string hostName = StringPrintf("hello%d.com.", i);
+        dns.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
+        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+        EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+    }
+
+    std::vector<std::string> res_servers;
+    std::vector<std::string> res_domains;
+    std::vector<std::string> res_tls_servers;
+    res_params res_params;
+    std::vector<ResolverStats> res_stats;
+    int wait_for_pending_req_timeout_count = 0;
+    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
+            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
+            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
+    // One stats entry per configured server; check the size before indexing into res_stats.
+    ASSERT_EQ(3U, res_stats.size());
+
+    // Verify the result by means of the statistics: each bad server is tried only minSamples
+    // times, and the working server answers every query.
+    EXPECT_EQ(0, res_stats[0].successes);
+    EXPECT_EQ(0, res_stats[1].successes);
+    EXPECT_EQ(kNumQueries, res_stats[2].successes);
+    EXPECT_EQ(parcel.minSamples, res_stats[0].internal_errors);
+    EXPECT_EQ(parcel.minSamples, res_stats[1].internal_errors);
+    EXPECT_EQ(0, res_stats[2].internal_errors);
+}
+
+// A server that never answers must stop being queried once minSamples timeouts have been
+// recorded for it, while the working server answers every query.
+TEST_F(ResolverTest, SkipBadServersDueToTimeout) {
+    constexpr char listen_addr1[] = "127.0.0.3";
+    constexpr char listen_addr2[] = "127.0.0.4";
+    constexpr int kNumQueries = 5;
+
+    // Set dns1 non-responsive and dns2 workable.
+    test::DNSResponder dns1(listen_addr1, test::kDefaultListenService, static_cast<ns_rcode>(-1));
+    test::DNSResponder dns2(listen_addr2);
+    dns1.setResponseProbability(0.0);
+    ASSERT_TRUE(dns1.startServer());
+    ASSERT_TRUE(dns2.startServer());
+
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.servers = {listen_addr1, listen_addr2};
+
+    // Bad servers can be distinguished after two attempts.
+    parcel.minSamples = 2;
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+
+    // Query kNumQueries distinct names, presumably so no answer is served from the cache.
+    for (int i = 0; i < kNumQueries; i++) {
+        std::string hostName = StringPrintf("hello%d.com.", i);
+        dns1.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
+        dns2.addMapping(hostName, ns_type::ns_t_a, "1.2.3.5");
+        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+        EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+    }
+
+    std::vector<std::string> res_servers;
+    std::vector<std::string> res_domains;
+    std::vector<std::string> res_tls_servers;
+    res_params res_params;
+    std::vector<ResolverStats> res_stats;
+    int wait_for_pending_req_timeout_count = 0;
+    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
+            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
+            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
+    // One stats entry per configured server; check the size before indexing into res_stats.
+    ASSERT_EQ(2U, res_stats.size());
+
+    // Verify the result by means of the statistics as well as the query counts: the
+    // non-responsive server receives only minSamples queries, the working one all of them.
+    EXPECT_EQ(0, res_stats[0].successes);
+    EXPECT_EQ(kNumQueries, res_stats[1].successes);
+    EXPECT_EQ(parcel.minSamples, res_stats[0].timeouts);
+    EXPECT_EQ(0, res_stats[1].timeouts);
+    EXPECT_EQ(static_cast<size_t>(parcel.minSamples), dns1.queries().size());
+    EXPECT_EQ(static_cast<size_t>(kNumQueries), dns2.queries().size());
+}
+
 TEST_F(ResolverTest, EmptySetup) {
     std::vector<std::string> servers;
     std::vector<std::string> domains;
@@ -3494,8 +3598,7 @@
     ScopedSystemProperties scopedSystemProperties(kDotConnectTimeoutMsFlag, "100");
 
     // Set up resolver to opportunistic mode with the default configuration.
-    const aidl::android::net::ResolverParamsParcel parcel =
-            DnsResponderClient::GetDefaultResolverParamsParcel();
+    const ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
     ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
     EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
     dns.clearQueries();