Add two Netd binder calls to set/get resolver config.

setResolverConfiguration() sets the name servers, search domains,
and resolver parameters.
getResolverInfo() returns the configured information and also the
statistics for each server.
Also includes tests for the new functionality.

BUG: 25731675

Change-Id: Idde486f36bb731f9edd240d62dc1795f8e621fe6
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index 55453f2..ef91266 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -27,24 +27,57 @@
 #include <android-base/stringprintf.h>
 #include <private/android_filesystem_config.h>
 
+#include <algorithm>
+#include <chrono>
+#include <iterator>
+#include <numeric>
 #include <thread>
 
+#define LOG_TAG "netd_test"
+// TODO: make this dynamic and stop depending on implementation details.
+#define TEST_OEM_NETWORK "oem29"
+#define TEST_NETID 30
+
 #include "NetdClient.h"
 
 #include <gtest/gtest.h>
-#define LOG_TAG "resolverTest"
+
 #include <utils/Log.h>
+
 #include <testUtil.h>
 
 #include "dns_responder.h"
 #include "resolv_params.h"
+#include "ResolverStats.h"
+
+#include "android/net/INetd.h"
+#include "binder/IServiceManager.h"
 
 using android::base::StringPrintf;
 using android::base::StringAppendF;
+using android::net::ResolverStats;
 
-// TODO: make this dynamic and stop depending on implementation details.
-#define TEST_OEM_NETWORK "oem29"
-#define TEST_NETID 30
+// Emulates the behavior of UnorderedElementsAreArray, which currently cannot be used.
+// TODO: Use UnorderedElementsAreArray once libgmock_host can be compiled. If that is not
+// possible, improve this hacky algorithm, which is O(n**2).
+// Returns true iff a and b hold the same elements the same number of times, in any order.
+template <class A, class B>
+bool UnorderedCompareArray(const A& a, const B& b) {
+    // Containers of different sizes can never match.
+    if (a.size() != b.size()) {
+        return false;
+    }
+    for (const auto& needle : a) {
+        // needle must occur equally often on both sides.
+        const auto matches = [&needle](const auto& candidate) { return needle == candidate; };
+        const size_t occurrences_in_a = std::count_if(a.begin(), a.end(), matches);
+        const size_t occurrences_in_b = std::count_if(b.begin(), b.end(), matches);
+        if (occurrences_in_a != occurrences_in_b) {
+            return false;
+        }
+    }
+    return true;
+}
 
 // The only response code used in this test, see
 // frameworks/base/services/java/com/android/server/NetworkManagementService.java
@@ -80,7 +113,6 @@
     return atoi(buffer);
 }
 
-
 bool expectNetdResult(int expected, const char* sockname, const char* format, ...) {
     char command[256];
     va_list args;
@@ -92,15 +124,26 @@
     return (200 <= expected && expected < 300);
 }
 
-
 class ResolverTest : public ::testing::Test {
 protected:
+    struct Mapping {  // One generated DNS test record; filled in by SetupMappings().
+        std::string host;   // Unqualified host label, "host<i>".
+        std::string entry;  // Fully qualified name "<host>.<domain>." the responders answer for.
+        std::string ip4;    // Address served for A queries on entry.
+        std::string ip6;    // Address served for AAAA queries on entry.
+    };
+
     virtual void SetUp() {
         // Ensure resolutions go via proxy.
         setenv("ANDROID_DNS_MODE", "", 1);
         uid = getuid();
         pid = getpid();
         SetupOemNetwork();
+
+        // binder setup
+        auto binder = android::defaultServiceManager()->getService(android::String16("netd"));
+        ASSERT_TRUE(binder != nullptr);
+        mNetdSrv = android::interface_cast<android::net::INetd>(binder);
     }
 
     virtual void TearDown() {
@@ -118,6 +161,53 @@
         ASSERT_EQ((unsigned) oemNetId, getNetworkForProcess());
     }
 
+    // Fills *mappings with one "host<i>" entry per (host index, search domain) combination.
+    void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
+            std::vector<Mapping>* mappings) const {
+        mappings->resize(num_hosts * domains.size());
+        size_t slot = 0;
+        for (unsigned host_idx = 0 ; host_idx < num_hosts ; ++host_idx) {
+            for (const std::string& suffix : domains) {
+                ASSERT_TRUE(slot < mappings->size());
+                Mapping& m = (*mappings)[slot++];
+                m.host = StringPrintf("host%u", host_idx);
+                m.entry = StringPrintf("%s.%s.", m.host.c_str(), suffix.c_str());
+                m.ip4 = StringPrintf("192.0.2.%u", host_idx%253 + 1);
+                m.ip6 = StringPrintf("2001:db8::%x", host_idx%65534 + 1);
+            }
+        }
+    }
+
+    // Starts num_servers responders on 127.0.0.100 and up (service "53"), each serving all mappings.
+    void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
+            std::vector<std::unique_ptr<test::DNSResponder>>* dns,
+            std::vector<std::string>* servers) const {
+        ASSERT_TRUE(num_servers != 0 && num_servers < 100);
+        dns->resize(num_servers);
+        servers->resize(num_servers);
+        for (unsigned idx = 0 ; idx < num_servers ; ++idx) {
+            std::string& address = (*servers)[idx];
+            std::unique_ptr<test::DNSResponder>& responder = (*dns)[idx];
+            address = StringPrintf("127.0.0.%u", idx + 100);
+            responder = std::make_unique<test::DNSResponder>(address, "53", 250,
+                    ns_rcode::ns_r_servfail, 1.0);
+            ASSERT_TRUE(responder.get() != nullptr);
+            for (const Mapping& m : mappings) {
+                responder->addMapping(m.entry.c_str(), ns_type::ns_t_a, m.ip4.c_str());
+                responder->addMapping(m.entry.c_str(), ns_type::ns_t_aaaa, m.ip6.c_str());
+            }
+            ASSERT_TRUE(responder->startServer());
+        }
+    }
+
+    void ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns) const {
+        for (const std::unique_ptr<test::DNSResponder>& responder : *dns) {
+            ASSERT_TRUE(responder.get() != nullptr);  // A null entry aborts the whole shutdown.
+            responder->stopServer();
+        }
+        dns->clear();  // Drop the (now stopped) responders.
+    }
+
     void TearDownOemNetwork() {
         if (oemNetId != -1) {
             expectNetdResult(ResponseCodeOK, "netd",
@@ -125,23 +215,27 @@
         }
     }
 
+    // Binder variant. NOTE: takes (servers, domains), the reverse of the string-params overload.
+    bool SetResolversForNetwork(const std::vector<std::string>& servers,
+            const std::vector<std::string>& domains, const std::vector<int>& params) {
+        return mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params).isOk();
+    }
+
     bool SetResolversForNetwork(const std::vector<std::string>& searchDomains,
             const std::vector<std::string>& servers, const std::string& params) {
-        // No use case for empty domains / servers (yet).
-        if (searchDomains.empty() || servers.empty()) return false;
-
-        std::string cmd = StringPrintf("resolver setnetdns %d \"%s", oemNetId,
-                searchDomains[0].c_str());
-        for (size_t i = 1 ; i < searchDomains.size() ; ++i) {
-            cmd += " ";
-            cmd += searchDomains[i];
+        std::string cmd = StringPrintf("resolver setnetdns %d \"", oemNetId);
+        if (!searchDomains.empty()) {
+            cmd += searchDomains[0].c_str();
+            for (size_t i = 1 ; i < searchDomains.size() ; ++i) {
+                cmd += " ";
+                cmd += searchDomains[i];
+            }
         }
-        cmd += "\" ";
+        cmd += "\"";
 
-        cmd += servers[0];
-        for (size_t i = 1 ; i < servers.size() ; ++i) {
+        for (const auto& str : servers) {
             cmd += " ";
-            cmd += servers[i];
+            cmd += str;
         }
 
         if (!params.empty()) {
@@ -157,6 +251,28 @@
         return true;
     }
 
+    // Reads back TEST_NETID's resolver setup and per-server stats via binder; false on failure.
+    bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
+            __res_params* params, std::vector<ResolverStats>* stats) {
+        using android::net::INetd;
+        std::vector<int32_t> wire_params;
+        std::vector<int32_t> wire_stats;
+        auto status = mNetdSrv->getResolverInfo(TEST_NETID, servers, domains, &wire_params,
+                &wire_stats);
+        if (!status.isOk() || wire_params.size() != INetd::RESOLVER_PARAMS_COUNT) return false;
+        *params = __res_params {
+            .sample_validity = static_cast<uint16_t>(
+                    wire_params[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
+            .success_threshold = static_cast<uint8_t>(
+                    wire_params[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
+            .min_samples = static_cast<uint8_t>(
+                    wire_params[INetd::RESOLVER_PARAMS_MIN_SAMPLES]),
+            .max_samples = static_cast<uint8_t>(
+                    wire_params[INetd::RESOLVER_PARAMS_MAX_SAMPLES])
+        };
+        return ResolverStats::decodeAll(wire_stats, stats);
+    }
+
     std::string ToString(const hostent* he) const {
         if (he == nullptr) return "<null>";
         char buffer[INET6_ADDRSTRLEN];
@@ -184,7 +300,6 @@
         auto queries = dns.queries();
         size_t found = 0;
         for (const auto& p : queries) {
-            std::cout << "query " << p.first << "\n";
             if (p.first == name) {
                 ++found;
             }
@@ -197,7 +312,6 @@
         auto queries = dns.queries();
         size_t found = 0;
         for (const auto& p : queries) {
-            std::cout << "query " << p.first << "\n";
             if (p.second == type && p.first == name) {
                 ++found;
             }
@@ -205,9 +319,56 @@
         return found;
     }
 
+    // Resolves random hosts from num_threads concurrent threads, num_queries lookups per thread.
+    void RunGetAddrInfoStressTest_Binder(unsigned num_hosts, unsigned num_threads,
+            unsigned num_queries) {
+        std::vector<std::string> domains = { "example.com" };
+        std::vector<std::unique_ptr<test::DNSResponder>> dns;
+        std::vector<std::string> servers;
+        std::vector<Mapping> mappings;
+        ASSERT_NO_FATAL_FAILURE(SetupMappings(num_hosts, domains, &mappings));
+        ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS, mappings, &dns, &servers));
+        std::vector<int> params = { 300, 25, 8, 8 };
+        ASSERT_TRUE(SetResolversForNetwork(servers, domains, params));
+
+        auto t0 = std::chrono::steady_clock::now();
+        std::vector<std::thread> threads(num_threads);
+        for (std::thread& thread : threads) {
+            thread = std::thread([this, &mappings, num_queries]() {
+                for (unsigned i = 0 ; i < num_queries ; ++i) {
+                    uint32_t ofs = arc4random_uniform(mappings.size());
+                    ASSERT_TRUE(ofs < mappings.size());
+                    const auto& mapping = mappings[ofs];  // Random pick, not the i-th entry.
+                    addrinfo* result = nullptr;
+                    int rv = getaddrinfo(mapping.host.c_str(), nullptr, nullptr, &result);
+                    EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
+                    if (rv == 0) {
+                        std::string result_str = ToString(result);
+                        EXPECT_TRUE(result_str == mapping.ip4 || result_str == mapping.ip6)
+                            << "result='" << result_str << "', ip4='" << mapping.ip4
+                            << "', ip6='" << mapping.ip6 << "'";
+                    }
+                    if (result) {
+                        freeaddrinfo(result);
+                        result = nullptr;
+                    }
+                }
+            });
+        }
+
+        for (std::thread& thread : threads) {
+            thread.join();
+        }
+        auto t1 = std::chrono::steady_clock::now();
+        ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
+                std::chrono::duration<double>(t1 - t0).count());
+        ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
+    }
+
     int pid;
     int uid;
     int oemNetId = -1;
+    android::sp<android::net::INetd> mNetdSrv = nullptr;
     const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
     // <sample validity in s> <success threshold in percent> <min samples> <max samples>
     const std::string mDefaultParams = "300 25 8 8";
@@ -234,13 +395,77 @@
     dns.stopServer();
 }
 
+// Verifies that the INetd resolver parameter offsets number RESOLVER_PARAMS_COUNT and, once
+// sorted, are exactly 0, 1, 2, 3.
+TEST_F(ResolverTest, TestBinderSerialization) {
+    using android::net::INetd;
+    std::vector<int> offsets = {
+        INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY, INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
+        INetd::RESOLVER_PARAMS_MIN_SAMPLES, INetd::RESOLVER_PARAMS_MAX_SAMPLES
+    };
+    const int count = static_cast<int>(offsets.size());
+    EXPECT_EQ(count, INetd::RESOLVER_PARAMS_COUNT);
+    std::sort(offsets.begin(), offsets.end());
+    for (int idx = 0 ; idx < count ; ++idx) {
+        EXPECT_EQ(offsets[idx], idx);
+    }
+}
+
+// Configures four local responders over binder, resolves a host, then reads the setup back.
+TEST_F(ResolverTest, GetHostByName_Binder) {
+    using android::net::INetd;
+    std::vector<std::string> domains = { "example.com" };
+    std::vector<std::unique_ptr<test::DNSResponder>> dns;
+    std::vector<std::string> servers;
+    std::vector<Mapping> mappings;
+    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
+    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(4, mappings, &dns, &servers));
+    ASSERT_EQ(1U, mappings.size());
+    const Mapping& mapping = mappings[0];
+
+    std::vector<int> params = { 300, 25, 8, 8 };
+    ASSERT_TRUE(SetResolversForNetwork(servers, domains, params));
+
+    const hostent* result = gethostbyname(mapping.host.c_str());
+    // Seed with size_t{0}: an int seed would make std::accumulate keep its running total in int.
+    size_t total_queries = std::accumulate(dns.begin(), dns.end(), size_t{0},
+            [this, &mapping](size_t total, auto& d) {
+                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
+            });
+
+    EXPECT_LE(1U, total_queries);
+    ASSERT_FALSE(result == nullptr);
+    ASSERT_EQ(4, result->h_length);
+    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+    EXPECT_EQ(mapping.ip4, ToString(result));
+    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+
+    std::vector<std::string> res_servers;
+    std::vector<std::string> res_domains;
+    __res_params res_params;
+    std::vector<ResolverStats> res_stats;
+    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
+    EXPECT_EQ(servers.size(), res_servers.size());
+    EXPECT_EQ(domains.size(), res_domains.size());
+    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, params.size());
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY], res_params.sample_validity);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD], res_params.success_threshold);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
+    EXPECT_EQ(servers.size(), res_stats.size());
+    EXPECT_TRUE(UnorderedCompareArray(res_servers, servers));
+    EXPECT_TRUE(UnorderedCompareArray(res_domains, domains));
+
+    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
+}
+
 TEST_F(ResolverTest, GetAddrInfo) {
     addrinfo* result = nullptr;
 
     const char* listen_addr = "127.0.0.4";
     const char* listen_addr2 = "127.0.0.5";
     const char* listen_srv = "53";
-    const char* host_name = "howdie.example.com.";
+    const char* host_name = "howdy.example.com.";
     test::DNSResponder dns(listen_addr, listen_srv, 250,
                            ns_rcode::ns_r_servfail, 1.0);
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
@@ -250,26 +475,30 @@
     ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
 
     dns.clearQueries();
-    EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
+    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
     size_t found = GetNumQueries(dns, host_name);
     EXPECT_LE(1U, found);
     // Could be A or AAAA
     std::string result_str = ToString(result);
     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
         << ", result_str='" << result_str << "'";
-    if (result) freeaddrinfo(result);
-    result = nullptr;
+    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
+    if (result) {
+        freeaddrinfo(result);
+        result = nullptr;
+    }
 
     // Verify that the name is cached.
-    size_t old_found = found;
-    EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
+    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
     found = GetNumQueries(dns, host_name);
-    EXPECT_EQ(old_found, found);
+    EXPECT_LE(1U, found);
     result_str = ToString(result);
     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
         << result_str;
-    if (result) freeaddrinfo(result);
-    result = nullptr;
+    if (result) {
+        freeaddrinfo(result);
+        result = nullptr;
+    }
 
     // Change the DNS resolver, ensure that queries are no longer cached.
     dns.clearQueries();
@@ -280,7 +509,7 @@
     ASSERT_TRUE(dns2.startServer());
     servers = { listen_addr2 };
     ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
-    EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
+    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
     found = GetNumQueries(dns, host_name);
     size_t found2 = GetNumQueries(dns2, host_name);
     EXPECT_EQ(0U, found);
@@ -290,8 +519,10 @@
     result_str = ToString(result);
     EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
         << ", result_str='" << result_str << "'";
-    if (result) freeaddrinfo(result);
-    result = nullptr;
+    if (result) {
+        freeaddrinfo(result);
+        result = nullptr;
+    }
     dns.stopServer();
     dns2.stopServer();
 }
@@ -315,7 +546,10 @@
     EXPECT_EQ(0, getaddrinfo("hola", nullptr, &hints, &result));
     EXPECT_EQ(1U, GetNumQueries(dns, host_name));
     EXPECT_EQ("1.2.3.5", ToString(result));
-    if (result) freeaddrinfo(result);
+    if (result) {
+        freeaddrinfo(result);
+        result = nullptr;
+    }
 }
 
 TEST_F(ResolverTest, MultidomainResolution) {
@@ -375,6 +609,10 @@
     for (int i = 0 ; i < sample_count ; ++i) {
         std::string domain = StringPrintf("nonexistent%d", i);
         getaddrinfo(domain.c_str(), nullptr, &hints, &result);
+        if (result) {
+            freeaddrinfo(result);
+            result = nullptr;
+        }
     }
     // Due to 100% errors for all possible samples, the server should be ignored from now on and
     // only the second one used for all following queries, until NSSAMPLE_VALIDITY is reached.
@@ -383,7 +621,10 @@
     EXPECT_EQ(0, getaddrinfo("ohayou", nullptr, &hints, &result));
     EXPECT_EQ(0U, GetNumQueries(dns0, host_name));
     EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
-    if (result) freeaddrinfo(result);
+    if (result) {
+        freeaddrinfo(result);
+        result = nullptr;
+    }
 }
 
 TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
@@ -425,9 +666,47 @@
             addrinfo* result = nullptr;
             int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
             EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
+            if (result) {
+                freeaddrinfo(result);
+                result = nullptr;
+            }
         });
     }
     for (std::thread& thread : threads) {
         thread.join();
     }
 }
+
+// Binder-configured stress run with a small pool: 100 hosts, queried by 100 threads issuing
+// 100 lookups each.
+TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
+    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(
+            /* num_hosts = */ 100, /* num_threads = */ 100, /* num_queries = */ 100));
+}
+
+// Binder-configured stress run with a large pool: 100000 hosts, queried by 100 threads issuing
+// 100 lookups each.
+TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100000) {
+    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(
+            /* num_hosts = */ 100000, /* num_threads = */ 100, /* num_queries = */ 100));
+}
+
+// Setting empty server and domain lists must succeed and read back empty, parameters intact.
+TEST_F(ResolverTest, EmptySetup) {
+    using android::net::INetd;
+    const std::vector<std::string> none;
+    const std::vector<int> params = { 300, 25, 8, 8 };
+    ASSERT_TRUE(SetResolversForNetwork(none, none, params));
+    std::vector<std::string> got_servers;
+    std::vector<std::string> got_domains;
+    __res_params got_params;
+    std::vector<ResolverStats> got_stats;
+    ASSERT_TRUE(GetResolverInfo(&got_servers, &got_domains, &got_params, &got_stats));
+    EXPECT_EQ(0U, got_servers.size());
+    EXPECT_EQ(0U, got_domains.size());
+    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, params.size());
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY], got_params.sample_validity);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD], got_params.success_threshold);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_MIN_SAMPLES], got_params.min_samples);
+    EXPECT_EQ(params[INetd::RESOLVER_PARAMS_MAX_SAMPLES], got_params.max_samples);
+}