Switch to a new way of activating DNS-over-TLS

This change removes the global database of potential DNS-over-TLS
servers from Netd, and makes pinned or named servers mandatory-TLS,
not opportunistic.

Bug: 64753847
Change-Id: I226ffec3f59593bc40cd9019095c5261aae55fa0
Test: Tests pass.  Normal browsing continues to work normally.
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index c29b483..8114b2d 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -101,10 +101,30 @@
         return res_goahead;
     }
     DnsTlsTransport::Server tlsServer;
-    if (net::gCtls->resolverCtrl.shouldUseTls(thread_netcontext.dns_netid,
-            insecureResolver, &tlsServer)) {
+    auto tlsStatus = net::gCtls->resolverCtrl.getTlsStatus(thread_netcontext.dns_netid,
+            insecureResolver, &tlsServer);
+    if (tlsStatus == ResolverController::Validation::unknown_netid) {
         if (DBG) {
-            ALOGD("qhook using TLS");
+            ALOGD("No TLS for netid %u", thread_netcontext.dns_netid);
+        }
+        return res_goahead;
+    } else if (tlsStatus == ResolverController::Validation::unknown_server) {
+        if (DBG) {
+            ALOGW("Skipping unexpected server in TLS mode");
+        }
+        return res_nextns;
+    } else {
+        if (tlsStatus != ResolverController::Validation::success) {
+            if (DBG) {
+                ALOGD("Server is not ready");
+            }
+            // TLS validation has not completed.  In opportunistic mode, fall back to UDP.
+            // In strict mode, try a different server.
+            bool opportunistic = tlsServer.name.empty() && tlsServer.fingerprints.empty();
+            return opportunistic ? res_goahead : res_nextns;
+        }
+        if (DBG) {
+            ALOGD("Performing query over TLS");
         }
         auto response = DnsTlsTransport::query(tlsServer, thread_netcontext.dns_mark,
                 *buf, *buflen, ans, anssiz, resplen);
@@ -117,18 +137,13 @@
         if (DBG) {
             ALOGW("qhook abort: doQuery failed: %d", (int)response);
         }
-        // If there was a network error, try a different name server.
-        // Otherwise, fail hard.
+        // If there was a network error on a validated server, try a different name server.
         if (response == DnsTlsTransport::Response::network_error) {
             return res_nextns;
         }
+        // There was an internal error.  Fail hard.
         return res_error;
     }
-
-    if (DBG) {
-        ALOGD("qhook not using TLS");
-    }
-    return res_goahead;
 }
 
 }  // namespace
diff --git a/server/NetdNativeService.cpp b/server/NetdNativeService.cpp
index d4e559d..70ab0ad 100644
--- a/server/NetdNativeService.cpp
+++ b/server/NetdNativeService.cpp
@@ -16,9 +16,11 @@
 
 #define LOG_TAG "Netd"
 
+#include <set>
 #include <vector>
 
 #include <android-base/stringprintf.h>
+#include <android-base/strings.h>
 #include <cutils/log.h>
 #include <cutils/properties.h>
 #include <utils/Errors.h>
@@ -199,13 +201,48 @@
     return binder::Status::ok();
 }
 
+// Parse a base64 encoded string into a vector of bytes.
+// On failure, return an empty vector.
+static std::vector<uint8_t> parseBase64(const std::string& input) {
+    std::vector<uint8_t> decoded;
+    size_t out_len;
+    if (EVP_DecodedLength(&out_len, input.size()) != 1) {
+        return decoded;
+    }
+    // out_len is now an upper bound on the output length.
+    decoded.resize(out_len);
+    if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(),
+            reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) {
+        // Possibly shrink the vector if the actual output was smaller than the bound.
+        decoded.resize(out_len);
+    } else {
+        decoded.clear();
+    }
+    if (out_len != SHA256_SIZE) {
+        decoded.clear();
+    }
+    return decoded;
+}
+
 binder::Status NetdNativeService::setResolverConfiguration(int32_t netId,
         const std::vector<std::string>& servers, const std::vector<std::string>& domains,
-        const std::vector<int32_t>& params) {
+        const std::vector<int32_t>& params, bool useTls, const std::string& tlsName,
+        const std::vector<std::string>& tlsFingerprints) {
     // This function intentionally does not lock within Netd, as Bionic is thread-safe.
     ENFORCE_PERMISSION(CONNECTIVITY_INTERNAL);
 
-    int err = gCtls->resolverCtrl.setResolverConfiguration(netId, servers, domains, params);
+    std::set<std::vector<uint8_t>> decoded_fingerprints;
+    for (const std::string& fingerprint : tlsFingerprints) {
+        std::vector<uint8_t> decoded = parseBase64(fingerprint);
+        if (decoded.empty()) {
+            return binder::Status::fromServiceSpecificError(EINVAL,
+                    String8::format("ResolverController error: bad fingerprint"));
+        }
+        decoded_fingerprints.emplace(decoded);
+    }
+
+    int err = gCtls->resolverCtrl.setResolverConfiguration(netId, servers, domains, params,
+            useTls, tlsName, decoded_fingerprints);
     if (err != 0) {
         return binder::Status::fromServiceSpecificError(-err,
                 String8::format("ResolverController error: %s", strerror(-err)));
@@ -227,49 +264,6 @@
     return binder::Status::ok();
 }
 
-binder::Status NetdNativeService::addPrivateDnsServer(const std::string& server, int32_t port,
-        const std::string& name,
-        const std::string& fingerprintAlgorithm,
-        const std::vector<std::string>& fingerprints) {
-    ENFORCE_PERMISSION(CONNECTIVITY_INTERNAL);
-    std::set<std::vector<uint8_t>> decoded_fingerprints;
-    for (const std::string& input : fingerprints) {
-        size_t out_len;
-        if (EVP_DecodedLength(&out_len, input.size()) != 1) {
-            return binder::Status::fromServiceSpecificError(INetd::PRIVATE_DNS_BAD_FINGERPRINT,
-                    "ResolverController error: bad fingerprint length");
-        }
-        // out_len is now an upper bound on the output length.
-        std::vector<uint8_t> decoded(out_len);
-        if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(),
-                reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) {
-            // Possibly shrink the vector if the actual output was smaller than the bound.
-            decoded.resize(out_len);
-        } else {
-            return binder::Status::fromServiceSpecificError(INetd::PRIVATE_DNS_BAD_FINGERPRINT,
-                    "ResolverController error: Base64 parsing failed");
-        }
-        decoded_fingerprints.insert(decoded);
-    }
-    const int err = gCtls->resolverCtrl.addPrivateDnsServer(server, port, name,
-            fingerprintAlgorithm, decoded_fingerprints);
-    if (err != INetd::PRIVATE_DNS_SUCCESS) {
-        return binder::Status::fromServiceSpecificError(err,
-                String8::format("ResolverController error: %d", err));
-    }
-    return binder::Status::ok();
-}
-
-binder::Status NetdNativeService::removePrivateDnsServer(const std::string& server) {
-    ENFORCE_PERMISSION(CONNECTIVITY_INTERNAL);
-    const int err = gCtls->resolverCtrl.removePrivateDnsServer(server);
-    if (err != INetd::PRIVATE_DNS_SUCCESS) {
-        return binder::Status::fromServiceSpecificError(err,
-                String8::format("ResolverController error: %d", err));
-    }
-    return binder::Status::ok();
-}
-
 binder::Status NetdNativeService::tetherApplyDnsInterfaces(bool *ret) {
     NETD_LOCKING_RPC(NETWORK_STACK, gCtls->tetherCtrl.lock)
 
diff --git a/server/NetdNativeService.h b/server/NetdNativeService.h
index d4b7ed7..2f23d8e 100644
--- a/server/NetdNativeService.h
+++ b/server/NetdNativeService.h
@@ -43,15 +43,12 @@
     binder::Status socketDestroy(const std::vector<UidRange>& uids,
             const std::vector<int32_t>& skipUids) override;
     binder::Status setResolverConfiguration(int32_t netId, const std::vector<std::string>& servers,
-            const std::vector<std::string>& domains, const std::vector<int32_t>& params) override;
+            const std::vector<std::string>& domains, const std::vector<int32_t>& params,
+            bool useTls, const std::string& tlsName,
+            const std::vector<std::string>& tlsFingerprints) override;
     binder::Status getResolverInfo(int32_t netId, std::vector<std::string>* servers,
             std::vector<std::string>* domains, std::vector<int32_t>* params,
             std::vector<int32_t>* stats) override;
-    binder::Status addPrivateDnsServer(const std::string& server, int32_t port,
-            const std::string& name,
-            const std::string& fingerprintAlgorithm,
-            const std::vector<std::string>& fingerprints) override;
-    binder::Status removePrivateDnsServer(const std::string& server) override;
 
     binder::Status setIPv6AddrGenMode(const std::string& ifName, int32_t mode) override;
 
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index ab69383..51928d2 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -54,25 +54,13 @@
 
 namespace {
 
-// This comparison ignores ports and fingerprints.
-struct AddressComparator {
-    bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const {
-      if (x.ss.ss_family != y.ss.ss_family) {
-          return x.ss.ss_family < y.ss.ss_family;
-      }
-      // Same address family.  Compare IP addresses.
-      if (x.ss.ss_family == AF_INET) {
-          const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
-          const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
-          return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
-      } else if (x.ss.ss_family == AF_INET6) {
-          const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
-          const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
-          return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
-      }
-      return false;  // Unknown address type.  This is an error.
-    }
-};
+// Only used for debug logging
+std::string addrToString(const sockaddr_storage* addr) {
+    char out[INET6_ADDRSTRLEN] = {0};
+    getnameinfo((const sockaddr*)addr, sizeof(sockaddr_storage), out,
+            INET6_ADDRSTRLEN, NULL, 0, NI_NUMERICHOST);
+    return std::string(out);
+}
 
 bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
@@ -95,94 +83,121 @@
     return false;
 }
 
-// Structure for tracking the entire set of known Private DNS servers.
+// Structure for tracking the validation status of servers on a specific netId.
+// Using the AddressComparator ensures at most one entry per IP address.
+typedef std::map<DnsTlsTransport::Server, ResolverController::Validation,
+        AddressComparator> PrivateDnsTracker;
 std::mutex privateDnsLock;
-typedef std::set<DnsTlsTransport::Server, AddressComparator> PrivateDnsSet;
-PrivateDnsSet privateDnsServers GUARDED_BY(privateDnsLock);
-
-// Structure for tracking the validation status of servers on a specific netid.
-// Servers that fail validation are removed from the tracker, and can be retried.
-enum class Validation : bool { in_process, success };
-typedef std::map<DnsTlsTransport::Server, Validation, AddressComparator> PrivateDnsTracker;
 std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
 
-PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
-    PrivateDnsSet set;
-    for (int i = 0; i < numservers; ++i) {
-        sockaddr_storage parsed;
-        if (parseServer(servers[i], port, &parsed)) {
-            set.insert(parsed);
-        }
+void checkPrivateDnsProvider(const DnsTlsTransport::Server& server,
+        PrivateDnsTracker& tracker, unsigned netId) REQUIRES(privateDnsLock) {
+    if (DBG) {
+        ALOGD("checkPrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
     }
-    return set;
+
+    tracker[server] = ResolverController::Validation::in_process;
+    if (DBG) {
+        ALOGD("Server %s marked as in_process.  Tracker now has size %zu",
+                addrToString(&(server.ss)).c_str(), tracker.size());
+    }
+    std::thread validate_thread([server, netId] {
+        // ::validate() is a blocking call that performs network operations.
+        // It can take milliseconds to minutes, up to the SYN retry limit.
+        bool success = DnsTlsTransport::validate(server, netId);
+        if (DBG) {
+            ALOGD("validateDnsTlsServer returned %d for %s", success,
+                    addrToString(&(server.ss)).c_str());
+        }
+        std::lock_guard<std::mutex> guard(privateDnsLock);
+        auto netPair = privateDnsTransports.find(netId);
+        if (netPair == privateDnsTransports.end()) {
+            ALOGW("netId %u was erased during private DNS validation", netId);
+            return;
+        }
+        auto& tracker = netPair->second;
+        auto serverPair = tracker.find(server);
+        if (serverPair == tracker.end()) {
+            ALOGW("Server %s was removed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            success = false;
+        }
+        if (!(serverPair->first == server)) {
+            ALOGW("Server %s was changed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            success = false;
+        }
+        if (success) {
+            tracker[server] = ResolverController::Validation::success;
+            if (DBG) {
+                ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
+                        addrToString(&(server.ss)).c_str(), tracker.size());
+            }
+        } else {
+            // Validation failure is expected if a user is on a captive portal.
+            // TODO: Trigger a second validation attempt after captive portal login
+            // succeeds.
+            if (DBG) {
+                ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
+            }
+            tracker[server] = ResolverController::Validation::fail;
+        }
+    });
+    validate_thread.detach();
 }
 
-void checkPrivateDnsProviders(const unsigned netId, const char** servers, int numservers) {
+int setPrivateDnsProviders(int32_t netId,
+        const std::vector<std::string>& servers, const std::string& name,
+        const std::set<std::vector<uint8_t>>& fingerprints) {
     if (DBG) {
-        ALOGD("checkPrivateDnsProviders(%u)", netId);
+        ALOGD("setPrivateDnsProviders(%u, %zu, %s, %zu)",
+                netId, servers.size(), name.c_str(), fingerprints.size());
+    }
+    // Parse the list of servers that has been passed in
+    std::set<DnsTlsTransport::Server> set;
+    for (size_t i = 0; i < servers.size(); ++i) {
+        sockaddr_storage parsed;
+        if (!parseServer(servers[i].c_str(), 853, &parsed)) {
+            return -EINVAL;
+        }
+        DnsTlsTransport::Server server(parsed);
+        server.name = name;
+        server.fingerprints = fingerprints;
+        set.insert(server);
     }
 
     std::lock_guard<std::mutex> guard(privateDnsLock);
-    if (privateDnsServers.empty()) {
-        return;
-    }
-
-    // First compute the intersection of the servers to check with the
-    // servers that are permitted to use DNS over TLS.  The intersection
-    // will contain the port number to be used for Private DNS.
-    PrivateDnsSet serversToCheck = parseServers(servers, numservers, 53);
-    PrivateDnsSet intersection;
-    std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
-        serversToCheck.begin(), serversToCheck.end(),
-        std::inserter(intersection, intersection.begin()),
-        AddressComparator());
-    if (intersection.empty()) {
-        return;
-    }
-
+    // Create the tracker if it was not present
     auto netPair = privateDnsTransports.find(netId);
     if (netPair == privateDnsTransports.end()) {
-        // New netId
+        // No TLS tracker yet for this netId.
         bool added;
         std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
         if (!added) {
-            ALOGE("Memory error while checking private DNS for netId %d", netId);
-            return;
+            ALOGE("Memory error while recording private DNS for netId %d", netId);
+            return -ENOMEM;
+        }
+    }
+    auto& tracker = netPair->second;
+
+    // Remove any servers from the tracker that are not in |servers| exactly.
+    for (auto it = tracker.begin(); it != tracker.end();) {
+        if (set.count(it->first) == 0) {
+            it = tracker.erase(it);
+        } else {
+            ++it;
         }
     }
 
-    auto& tracker = netPair->second;
-    for (const auto& privateServer : intersection) {
-        if (tracker.count(privateServer) != 0) {
-            continue;
+    // Add any new or changed servers to the tracker, and initiate async checks for them.
+    for (const auto& server : set) {
+        // Don't probe a server more than once.  This means that the only way to
+        // re-check a failed server is to remove it and re-add it from the netId.
+        if (tracker.count(server) == 0) {
+            checkPrivateDnsProvider(server, tracker, netId);
         }
-        tracker[privateServer] = Validation::in_process;
-        std::thread validate_thread([privateServer, netId] {
-            // ::validate() is a blocking call that performs network operations.
-            // It can take milliseconds to minutes, up to the SYN retry limit.
-            bool success = DnsTlsTransport::validate(privateServer, netId);
-            std::lock_guard<std::mutex> guard(privateDnsLock);
-            auto netPair = privateDnsTransports.find(netId);
-            if (netPair == privateDnsTransports.end()) {
-                ALOGW("netId %u was erased during private DNS validation", netId);
-                return;
-            }
-            auto& tracker = netPair->second;
-            if (privateDnsServers.count(privateServer) == 0) {
-                ALOGW("Server was removed during private DNS validation");
-                success = false;
-            }
-            if (success) {
-                tracker[privateServer] = Validation::success;
-            } else {
-                // Validation failure is expected if a user is on a captive portal.
-                // TODO: Trigger a second validation attempt after captive portal login
-                // succeeds.
-                tracker.erase(privateServer);
-            }
-        });
-        validate_thread.detach();
     }
+    return 0;
 }
 
 void clearPrivateDnsProviders(unsigned netId) {
@@ -198,30 +213,43 @@
 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
         const char** servers, int numservers, const __res_params* params) {
     if (DBG) {
-        ALOGD("setDnsServers netId = %u\n", netId);
+        ALOGD("setDnsServers netId = %u, numservers = %d", netId, numservers);
     }
-    checkPrivateDnsProviders(netId, servers, numservers);
     return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
 }
 
-bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+ResolverController::Validation ResolverController::getTlsStatus(unsigned netId,
+        const sockaddr_storage& insecureServer,
         DnsTlsTransport::Server* secureServer) {
     // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
     // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
     // an atomic_int32_t counter of validated connections, and returning early if it's zero.
+    if (DBG) {
+        ALOGD("getTlsStatus(%u, %s)?", netId, addrToString(&insecureServer).c_str());
+    }
     std::lock_guard<std::mutex> guard(privateDnsLock);
     const auto netPair = privateDnsTransports.find(netId);
     if (netPair == privateDnsTransports.end()) {
-        return false;
+        if (DBG) {
+            ALOGD("Not using TLS: no tracked servers for netId %u", netId);
+        }
+        return Validation::unknown_netid;
     }
     const auto& tracker = netPair->second;
     const auto serverPair = tracker.find(insecureServer);
-    if (serverPair == tracker.end() || serverPair->second != Validation::success) {
-        return false;
+    if (serverPair == tracker.end()) {
+        if (DBG) {
+            ALOGD("Server is not in the tracker (size %zu) for netid %u", tracker.size(), netId);
+        }
+        return Validation::unknown_server;
     }
     const auto& validatedServer = serverPair->first;
+    Validation status = serverPair->second;
+    if (DBG) {
+        ALOGD("Server %s has status %d", addrToString(&(validatedServer.ss)).c_str(), (int)status);
+    }
     *secureServer = validatedServer;
-    return true;
+    return status;
 }
 
 int ResolverController::clearDnsServers(unsigned netId) {
@@ -316,13 +344,24 @@
 
 int ResolverController::setResolverConfiguration(int32_t netId,
         const std::vector<std::string>& servers, const std::vector<std::string>& domains,
-        const std::vector<int32_t>& params) {
+        const std::vector<int32_t>& params, bool useTls, const std::string& tlsName,
+        const std::set<std::vector<uint8_t>>& tlsFingerprints) {
     using android::net::INetd;
     if (params.size() != INetd::RESOLVER_PARAMS_COUNT) {
         ALOGE("%s: params.size()=%zu", __FUNCTION__, params.size());
         return -EINVAL;
     }
 
+    if (useTls) {
+        int err = setPrivateDnsProviders(netId, servers, tlsName, tlsFingerprints);
+        if (err != 0) {
+            return err;
+        }
+    } else {
+        clearPrivateDnsProviders(netId);
+    }
+
+    // Convert server list to bionic's format.
     auto server_count = std::min<size_t>(MAXNS, servers.size());
     std::vector<const char*> server_ptrs;
     for (size_t i = 0 ; i < server_count ; ++i) {
@@ -424,61 +463,5 @@
     dw.decIndent();
 }
 
-int ResolverController::addPrivateDnsServer(const std::string& server, int32_t port,
-        const std::string& name,
-        const std::string& fingerprintAlgorithm,
-        const std::set<std::vector<uint8_t>>& fingerprints) {
-    using android::net::INetd;
-    if (fingerprintAlgorithm.empty()) {
-        if (!fingerprints.empty()) {
-            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-        }
-    } else if (fingerprintAlgorithm.compare("SHA-256") == 0) {
-        if (fingerprints.empty()) {
-            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-        }
-        for (const auto& fingerprint : fingerprints) {
-            if (fingerprint.size() != SHA256_SIZE) {
-                return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-            }
-        }
-    } else {
-        return INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM;
-    }
-    if (port <= 0 || port > 0xFFFF) {
-        return INetd::PRIVATE_DNS_BAD_PORT;
-    }
-    sockaddr_storage parsed;
-    if (!parseServer(server.c_str(), port, &parsed)) {
-        return INetd::PRIVATE_DNS_BAD_ADDRESS;
-    }
-    DnsTlsTransport::Server privateServer(parsed);
-    privateServer.fingerprints = fingerprints;
-    privateServer.name = name;
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    // Ensure we overwrite any previous matching server.  This is necessary because equality is
-    // based only on the IP address, not the port or fingerprints.
-    privateDnsServers.erase(privateServer);
-    privateDnsServers.insert(privateServer);
-    if (DBG) {
-        ALOGD("Recorded private DNS server: %s", server.c_str());
-    }
-    return INetd::PRIVATE_DNS_SUCCESS;
-}
-
-int ResolverController::removePrivateDnsServer(const std::string& server) {
-    using android::net::INetd;
-    sockaddr_storage parsed;
-    if (!parseServer(server.c_str(), 0, &parsed)) {
-        return INetd::PRIVATE_DNS_BAD_ADDRESS;
-    }
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    privateDnsServers.erase(parsed);
-    for (auto& pair : privateDnsTransports) {
-        pair.second.erase(parsed);
-    }
-    return INetd::PRIVATE_DNS_SUCCESS;
-}
-
 }  // namespace net
 }  // namespace android
diff --git a/server/ResolverController.h b/server/ResolverController.h
index b283e8b..383b1ab 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -39,13 +39,14 @@
     int setDnsServers(unsigned netId, const char* searchDomains, const char** servers,
             int numservers, const __res_params* params);
 
+    // Validation status of a DNS over TLS server (on a specific netId).
+    enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
+
     // Given a netId and the address of an insecure (i.e. normal) DNS server, this method checks
     // if there is a known secure DNS server with the same IP address that has been validated as
-    // accessible on this netId.  If so, it returns true, providing the server's address
-    // (including port) and pin fingerprints (possibly empty) in the output parameter.
-    // TODO: Add support for optional stronger security, by returning true even if the secure
-    // server is not accessible.
-    bool shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+    // accessible on this netId.  It returns the validation status, and provides the secure server
+    // (including port, name, and fingerprints) in the output parameter.
+    Validation getTlsStatus(unsigned netId, const sockaddr_storage& insecureServer,
             DnsTlsTransport::Server* secureServer);
 
     int clearDnsServers(unsigned netid);
@@ -59,18 +60,15 @@
     // Binder specific functions, which convert between the binder int/string arrays and the
     // actual data structures, and call setDnsServer() / getDnsInfo() for the actual processing.
     int setResolverConfiguration(int32_t netId, const std::vector<std::string>& servers,
-            const std::vector<std::string>& domains, const std::vector<int32_t>& params);
+            const std::vector<std::string>& domains, const std::vector<int32_t>& params,
+            bool useTls, const std::string& tlsName,
+            const std::set<std::vector<uint8_t>>& tlsFingerprints);
 
     int getResolverInfo(int32_t netId, std::vector<std::string>* servers,
             std::vector<std::string>* domains, std::vector<int32_t>* params,
             std::vector<int32_t>* stats);
     void dump(DumpWriter& dw, unsigned netId);
 
-    int addPrivateDnsServer(const std::string& server, int32_t port,
-            const std::string& name,
-            const std::string& fingerprintAlgorithm,
-            const std::set<std::vector<uint8_t>>& fingerprints);
-    int removePrivateDnsServer(const std::string& server);
 };
 
 }  // namespace net
diff --git a/server/binder/android/net/INetd.aidl b/server/binder/android/net/INetd.aidl
index 3280e0c..1b32dd3 100644
--- a/server/binder/android/net/INetd.aidl
+++ b/server/binder/android/net/INetd.aidl
@@ -104,11 +104,16 @@
      * @param params the params to set. This array contains RESOLVER_PARAMS_COUNT integers that
      *   encode the contents of Bionic's __res_params struct, i.e. sample_validity is stored at
      *   position RESOLVER_PARAMS_SAMPLE_VALIDITY, etc.
+     * @param useTls If true, try to contact servers over TLS on port 853.
+     * @param tlsName The TLS subject name to require for all servers, or empty if there is none.
+     * @param tlsFingerprints An array containing TLS public key fingerprints (pins) of which each
+     *   server must match at least one, or empty if there are no pinned keys.
      * @throws ServiceSpecificException in case of failure, with an error code corresponding to the
      *         unix errno.
      */
     void setResolverConfiguration(int netId, in @utf8InCpp String[] servers,
-            in @utf8InCpp String[] domains, in int[] params);
+            in @utf8InCpp String[] domains, in int[] params, boolean useTls,
+            in @utf8InCpp String tlsName, in @utf8InCpp String[] tlsFingerprints);
 
     // Array indices for resolver stats.
     const int RESOLVER_STATS_SUCCESSES = 0;
@@ -147,44 +152,6 @@
     void getResolverInfo(int netId, out @utf8InCpp String[] servers,
             out @utf8InCpp String[] domains, out int[] params, out int[] stats);
 
-    // Private DNS function error codes.
-    const int PRIVATE_DNS_SUCCESS = 0;
-    const int PRIVATE_DNS_BAD_ADDRESS = 1;
-    const int PRIVATE_DNS_BAD_PORT = 2;
-    const int PRIVATE_DNS_UNKNOWN_ALGORITHM = 3;
-    const int PRIVATE_DNS_BAD_FINGERPRINT = 4;
-
-    /**
-     * Adds a server to the list of DNS resolvers that support DNS over TLS.  After this action
-     * succeeds, any subsequent call to setResolverConfiguration will opportunistically use DNS
-     * over TLS if the specified server is on this list and is reachable on that network.
-     *
-     * @param server the DNS server's IP address.  If a private DNS server is already configured
-     *        with this IP address, it will be overwritten.
-     * @param port the port on which the server is listening, typically 853.
-     * @param name the DNS server's name if name validation is desired, otherwise "".
-     * @param fingerprintAlgorithm the hash algorithm used to compute the fingerprints.  This should
-     *        be a name in MessageDigest's format.  Currently "SHA-256" is the only supported
-     *        algorithm. Set this to the empty string to disable fingerprint validation.
-     * @param fingerprints the server's public key fingerprints as Base64 strings.
-     *        These can be generated using MessageDigest and android.util.Base64.encodeToString.
-     *        Currently "SHA-256" is the only supported algorithm. Set this to empty to disable
-     *        fingerprint validation.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *         cause of the the failure.
-     */
-    void addPrivateDnsServer(in @utf8InCpp String server, int port, in @utf8InCpp String name,
-             in @utf8InCpp String fingerprintAlgorithm, in @utf8InCpp String[] fingerprints);
-
-    /**
-     * Remove a server from the list of DNS resolvers that support DNS over TLS.
-     *
-     * @param server the DNS server's IP address.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *         cause of the the failure.
-     */
-    void removePrivateDnsServer(in @utf8InCpp String server);
-
     /**
      * Instruct the tethering DNS server to reevaluated serving interfaces.
      * This is needed to for the DNS server to observe changes in the set
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 4cedac9..b369022 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -16,6 +16,8 @@
 
 #include "dns/DnsTlsTransport.h"
 
+#include <algorithm>
+#include <iterator>
 #include <arpa/inet.h>
 #include <arpa/nameser.h>
 #include <errno.h>
@@ -32,6 +34,72 @@
 #include "NetdConstants.h"
 #include "Permission.h"
 
+namespace {
+
+// Returns a tuple of references to the elements of a.
+auto make_tie(const sockaddr_in& a) {
+    return std::tie(a.sin_port, a.sin_addr.s_addr);
+}
+
+// Returns a tuple of references to the elements of a.
+auto make_tie(const sockaddr_in6& a) {
+    // Skip flowinfo, which is not relevant.
+    return std::tie(
+        a.sin6_port,
+        a.sin6_addr,
+        a.sin6_scope_id
+    );
+}
+
+} // namespace
+
+// These binary operators make sockaddr_storage comparable.  They need to be
+// in the global namespace so that the std::tuple < and == operators can see them.
+static bool operator <(const in6_addr& x, const in6_addr& y) {
+    return std::lexicographical_compare(
+            std::begin(x.s6_addr), std::end(x.s6_addr),
+            std::begin(y.s6_addr), std::end(y.s6_addr));
+}
+
+static bool operator ==(const in6_addr& x, const in6_addr& y) {
+    return std::equal(
+            std::begin(x.s6_addr), std::end(x.s6_addr),
+            std::begin(y.s6_addr), std::end(y.s6_addr));
+}
+
+static bool operator <(const sockaddr_storage& x, const sockaddr_storage& y) {
+    if (x.ss_family != y.ss_family) {
+        return x.ss_family < y.ss_family;
+    }
+    // Same address family.
+    if (x.ss_family == AF_INET) {
+        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x);
+        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y);
+        return make_tie(x_sin) < make_tie(y_sin);
+    } else if (x.ss_family == AF_INET6) {
+        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x);
+        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y);
+        return make_tie(x_sin6) < make_tie(y_sin6);
+    }
+    return false;  // Unknown address type.  This is an error.
+}
+
+static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) {
+    if (x.ss_family != y.ss_family) {
+        return false;
+    }
+    // Same address family.
+    if (x.ss_family == AF_INET) {
+        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x);
+        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y);
+        return make_tie(x_sin) == make_tie(y_sin);
+    } else if (x.ss_family == AF_INET6) {
+        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x);
+        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y);
+        return make_tie(x_sin6) == make_tie(y_sin6);
+    }
+    return false;  // Unknown address type.  This is an error.
+}
 
 namespace android {
 namespace net {
@@ -125,6 +193,44 @@
     return true;
 }
 
+// This comparison ignores ports and fingerprints.
+// TODO: respect IPv6 scope id (e.g. link-local addresses).
+bool AddressComparator::operator() (const DnsTlsTransport::Server& x,
+        const DnsTlsTransport::Server& y) const {
+    if (x.ss.ss_family != y.ss.ss_family) {
+        return x.ss.ss_family < y.ss.ss_family;
+    }
+    // Same address family.
+    if (x.ss.ss_family == AF_INET) {
+        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+        return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+    } else if (x.ss.ss_family == AF_INET6) {
+        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+        return x_sin6.sin6_addr < y_sin6.sin6_addr;
+    }
+    return false;  // Unknown address type.  This is an error.
+}
+
+// Returns a tuple of references to the elements of s.
+auto make_tie(const DnsTlsTransport::Server& s) {
+    return std::tie(
+        s.ss,
+        s.name,
+        s.fingerprints,
+        s.protocol
+    );
+}
+
+bool DnsTlsTransport::Server::operator <(const DnsTlsTransport::Server& other) const {
+    return make_tie(*this) < make_tie(other);
+}
+
+bool DnsTlsTransport::Server::operator ==(const DnsTlsTransport::Server& other) const {
+    return make_tie(*this) == make_tie(other);
+}
+
 SSL* DnsTlsTransport::sslConnect(int fd) {
     if (fd < 0) {
         ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
@@ -140,7 +246,8 @@
     }
 
     bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
-    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
+    // This file descriptor is owned by a unique_fd, so don't let libssl close it.
+    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
     SSL_set_bio(ssl.get(), bio.get(), bio.get());
     bio.release();
 
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index adf49a0..5a066a3 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -43,6 +43,9 @@
         std::set<std::vector<uint8_t>> fingerprints;
         std::string name;
         int protocol = IPPROTO_TCP;
+        // Exact comparison of Server objects
+        bool operator <(const Server& other) const;
+        bool operator ==(const Server& other) const;
     };
 
     enum class Response : uint8_t { success, network_error, limit_error, internal_error };
@@ -85,6 +88,12 @@
     const Server mServer;
 };
 
+// This comparison ignores ports, names, and fingerprints.
+struct AddressComparator {
+    bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const;
+};
+
+
 }  // namespace net
 }  // namespace android
 
diff --git a/tests/binder_test.cpp b/tests/binder_test.cpp
index 209ca9e..b3a160d 100644
--- a/tests/binder_test.cpp
+++ b/tests/binder_test.cpp
@@ -637,83 +637,57 @@
     return std::string(reinterpret_cast<char*>(output_bytes));
 }
 
-TEST_F(BinderTest, TestAddPrivateDnsServer) {
+TEST_F(BinderTest, TestSetResolverConfiguration_Tls) {
     std::vector<uint8_t> fp(SHA256_SIZE);
+    std::vector<uint8_t> short_fp(1);
+    std::vector<uint8_t> long_fp(SHA256_SIZE + 1);
+    std::vector<std::string> test_domains;
+    std::vector<int> test_params = { 300, 25, 8, 8 };
+    unsigned test_netid = 0;
     static const struct TestData {
-        const std::string address;
-        const int port;
-        const std::string name;
-        const std::string fingerprintAlgorithm;
-        const std::set<std::vector<uint8_t>> fingerprints;
+        const std::vector<std::string> servers;
+        const std::string tlsName;
+        const std::vector<std::vector<uint8_t>> tlsFingerprints;
         const int expectedReturnCode;
     } kTestData[] = {
-        { "192.0.2.1", 853, "", "", {}, INetd::PRIVATE_DNS_SUCCESS },
-        { "2001:db8::2", 65535, "host.name", "", {}, INetd::PRIVATE_DNS_SUCCESS },
-        { "192.0.2.3", 443, "@@@@", "SHA-256", { fp }, INetd::PRIVATE_DNS_SUCCESS },
-        { "2001:db8::4", 1, "", "SHA-256", { fp }, INetd::PRIVATE_DNS_SUCCESS },
-        { "192.0.*.5", 853, "", "", {}, INetd::PRIVATE_DNS_BAD_ADDRESS },
-        { "", 853, "", "", {}, INetd::PRIVATE_DNS_BAD_ADDRESS },
-        { "2001:dg8::6", 65535, "", "", {}, INetd::PRIVATE_DNS_BAD_ADDRESS },
-        { "192.0.2.7", 0, "", "SHA-256", { fp }, INetd::PRIVATE_DNS_BAD_PORT },
-        { "2001:db8::8", 65536, "", "", {}, INetd::PRIVATE_DNS_BAD_PORT },
-        { "192.0.2.9", 50053, "", "SHA-512", { fp }, INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM },
-        { "2001:db8::a", 853, "", "", { fp }, INetd::PRIVATE_DNS_BAD_FINGERPRINT },
-        { "192.0.2.11", 853, "", "SHA-256", {}, INetd::PRIVATE_DNS_BAD_FINGERPRINT },
-        { "2001:db8::c", 853, "", "SHA-256", { { 1 } }, INetd::PRIVATE_DNS_BAD_FINGERPRINT },
-        { "192.0.2.12", 853, "", "SHA-256", { std::vector<uint8_t>(SHA256_SIZE + 1) },
-                INetd::PRIVATE_DNS_BAD_FINGERPRINT },
-        { "2001:db8::e", 1, "", "SHA-256", { fp, fp, fp }, INetd::PRIVATE_DNS_SUCCESS },
-        { "192.0.2.14", 853, "", "SHA-256", { fp, { 1 } }, INetd::PRIVATE_DNS_BAD_FINGERPRINT },
+        { {"192.0.2.1"}, "", {}, 0 },
+        { {"2001:db8::2"}, "host.name", {}, 0 },
+        { {"192.0.2.3"}, "@@@@", { fp }, 0 },
+        { {"2001:db8::4"}, "", { fp }, 0 },
+        { {"192.0.*.5"}, "", {}, EINVAL },
+        { {""}, "", {}, EINVAL },
+        { {"2001:dg8::6"}, "", {}, EINVAL },
+        { {"2001:db8::c"}, "", { short_fp }, EINVAL },
+        { {"192.0.2.12"}, "", { long_fp }, EINVAL },
+        { {"2001:db8::e"}, "", { fp, fp, fp }, 0 },
+        { {"192.0.2.14"}, "", { fp, short_fp }, EINVAL },
     };
 
     for (unsigned int i = 0; i < arraysize(kTestData); i++) {
         const auto &td = kTestData[i];
 
         std::vector<std::string> fingerprints;
-        for (const std::vector<uint8_t>& fingerprint : td.fingerprints) {
+        for (const auto& fingerprint : td.tlsFingerprints) {
             fingerprints.push_back(base64Encode(fingerprint));
         }
-        const binder::Status status = mNetd->addPrivateDnsServer(
-                td.address, td.port, td.name, td.fingerprintAlgorithm, fingerprints);
+        binder::Status status = mNetd->setResolverConfiguration(
+                test_netid, td.servers, test_domains, test_params,
+                true, td.tlsName, fingerprints);
 
-        if (td.expectedReturnCode == INetd::PRIVATE_DNS_SUCCESS) {
+        if (td.expectedReturnCode == 0) {
             SCOPED_TRACE(String8::format("test case %d should have passed", i));
             SCOPED_TRACE(status.toString8());
             EXPECT_EQ(0, status.exceptionCode());
         } else {
             SCOPED_TRACE(String8::format("test case %d should have failed", i));
             EXPECT_EQ(binder::Status::EX_SERVICE_SPECIFIC, status.exceptionCode());
+            EXPECT_EQ(td.expectedReturnCode, status.serviceSpecificErrorCode());
         }
-        EXPECT_EQ(td.expectedReturnCode, status.serviceSpecificErrorCode());
     }
-}
-
-TEST_F(BinderTest, TestRemovePrivateDnsServer) {
-    static const struct TestData {
-        const std::string address;
-        const int expectedReturnCode;
-    } kTestData[] = {
-        { "192.0.2.1", INetd::PRIVATE_DNS_SUCCESS },
-        { "2001:db8::2", INetd::PRIVATE_DNS_SUCCESS },
-        { "192.0.*.3", INetd::PRIVATE_DNS_BAD_ADDRESS },
-        { "2001:dg8::4", INetd::PRIVATE_DNS_BAD_ADDRESS },
-        { "", INetd::PRIVATE_DNS_BAD_ADDRESS },
-    };
-
-    for (unsigned int i = 0; i < arraysize(kTestData); i++) {
-        const auto &td = kTestData[i];
-
-        const binder::Status status = mNetd->removePrivateDnsServer(td.address);
-
-        if (td.expectedReturnCode == INetd::PRIVATE_DNS_SUCCESS) {
-            SCOPED_TRACE(String8::format("test case %d should have passed", i));
-            EXPECT_EQ(0, status.exceptionCode());
-        } else {
-            SCOPED_TRACE(String8::format("test case %d should have failed", i));
-            EXPECT_EQ(binder::Status::EX_SERVICE_SPECIFIC, status.exceptionCode());
-        }
-        EXPECT_EQ(td.expectedReturnCode, status.serviceSpecificErrorCode());
-    }
+    // Ensure TLS is disabled before the start of the next test.
+    mNetd->setResolverConfiguration(
+        test_netid, kTestData[0].servers, test_domains, test_params,
+        false, "", {});
 }
 
 void expectNoTestCounterRules() {
diff --git a/tests/dns_responder/dns_responder_client.cpp b/tests/dns_responder/dns_responder_client.cpp
index ff5b556..879e19c 100644
--- a/tests/dns_responder/dns_responder_client.cpp
+++ b/tests/dns_responder/dns_responder_client.cpp
@@ -91,12 +91,22 @@
 
 bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& servers,
         const std::vector<std::string>& domains, const std::vector<int>& params) {
-    auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params);
+    auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params,
+            false, "", {});
     return rv.isOk();
 }
 
-bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& searchDomains,
-            const std::vector<std::string>& servers, const std::string& params) {
+bool DnsResponderClient::SetResolversWithTls(const std::vector<std::string>& servers,
+        const std::vector<std::string>& domains, const std::vector<int>& params,
+        const std::string& name,
+        const std::vector<std::string>& fingerprints) {
+    auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params,
+            true, name, fingerprints);
+    return rv.isOk();
+}
+
+bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& servers,
+        const std::vector<std::string>& searchDomains, const std::string& params) {
     std::string cmd = StringPrintf("resolver setnetdns %d \"", mOemNetId);
     if (!searchDomains.empty()) {
         cmd += searchDomains[0].c_str();
diff --git a/tests/dns_responder/dns_responder_client.h b/tests/dns_responder/dns_responder_client.h
index ed7d38d..2df536c 100644
--- a/tests/dns_responder/dns_responder_client.h
+++ b/tests/dns_responder/dns_responder_client.h
@@ -29,8 +29,15 @@
     bool SetResolversForNetwork(const std::vector<std::string>& servers,
             const std::vector<std::string>& domains, const std::vector<int>& params);
 
-    bool SetResolversForNetwork(const std::vector<std::string>& searchDomains,
-            const std::vector<std::string>& servers, const std::string& params);
+    bool SetResolversForNetwork(const std::vector<std::string>& servers,
+            const std::vector<std::string>& searchDomains,
+            const std::string& params);
+
+    bool SetResolversWithTls(const std::vector<std::string>& servers,
+            const std::vector<std::string>& searchDomains,
+            const std::vector<int>& params,
+            const std::string& name,
+            const std::vector<std::string>& fingerprints);
 
     static void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
             std::vector<std::unique_ptr<test::DNSResponder>>* dns,
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index bd47a05..8e1eb16 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -281,7 +281,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
     ASSERT_TRUE(dns.startServer());
     std::vector<std::string> servers = { listen_addr };
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams));
 
     const hostent* result;
 
@@ -389,7 +389,7 @@
 
 
     std::vector<std::string> servers = { listen_addr };
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams));
     dns.clearQueries();
     dns2.clearQueries();
 
@@ -422,7 +422,7 @@
 
     // Change the DNS resolver, ensure that queries are still cached.
     servers = { listen_addr2 };
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams));
     dns.clearQueries();
     dns2.clearQueries();
 
@@ -456,7 +456,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
     ASSERT_TRUE(dns.startServer());
     std::vector<std::string> servers = { listen_addr };
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams));
 
     addrinfo hints;
     memset(&hints, 0, sizeof(hints));
@@ -480,7 +480,7 @@
     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
     ASSERT_TRUE(dns.startServer());
     std::vector<std::string> servers = { listen_addr };
-    ASSERT_TRUE(SetResolversForNetwork(searchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, searchDomains, mDefaultParams));
 
     dns.clearQueries();
     const hostent* result = gethostbyname("nihao");
@@ -515,7 +515,7 @@
     int sample_count = 8;
     std::string params = StringPrintf("%u %d %d %d", sample_validity, success_threshold,
             sample_count, sample_count);
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, params));
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, params));
 
     // Repeatedly perform resolutions for non-existing domains until MAXNSSAMPLES resolutions have
     // reached the dns0, which is set to fail. No more requests should then arrive at that server
@@ -576,7 +576,7 @@
                 }
             }
             if (serverSubset.empty()) serverSubset = servers;
-            ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, serverSubset,
+            ASSERT_TRUE(SetResolversForNetwork(serverSubset, mDefaultSearchDomains,
                     mDefaultParams));
             addrinfo hints;
             memset(&hints, 0, sizeof(hints));
@@ -644,7 +644,7 @@
     ASSERT_TRUE(dns.startServer());
     std::vector<std::string> servers = { listen_addr };
     std::vector<std::string> domains = { "domain1.org" };
-    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams));
 
     addrinfo hints;
     memset(&hints, 0, sizeof(hints));
@@ -657,7 +657,7 @@
 
     // Test that changing the domain search path on its own works.
     domains = { "domain2.org" };
-    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams));
     dns.clearQueries();
 
     EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
@@ -710,8 +710,7 @@
 
     // There's nothing listening on this address, so validation will either fail or
     /// hang.  Either way, queries will continue to flow to the DNSResponder.
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "", {});
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
 
     const hostent* result;
 
@@ -719,7 +718,8 @@
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.3", ToString(result));
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     dns.stopServer();
 }
 
@@ -749,10 +749,9 @@
     ASSERT_FALSE(bind(s, reinterpret_cast<struct sockaddr*>(&tlsServer), sizeof(tlsServer)));
     ASSERT_FALSE(listen(s, 1));
 
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "", {});
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    // Trigger TLS validation.
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
 
-    // SetResolversForNetwork should have triggered a validation connection to this address.
     struct sockaddr_storage cliaddr;
     socklen_t sin_size = sizeof(cliaddr);
     int new_fd = accept(s, reinterpret_cast<struct sockaddr *>(&cliaddr), &sin_size);
@@ -777,7 +776,8 @@
     ASSERT_FALSE(result == nullptr);
     EXPECT_EQ("1.2.3.2", ToString(result));
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     dns.stopServer();
     close(s);
 }
@@ -798,8 +798,7 @@
 
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "", {});
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "", {}));
 
     const hostent* result;
 
@@ -821,9 +820,9 @@
     EXPECT_TRUE(result == nullptr);
     EXPECT_EQ(HOST_NOT_FOUND, h_errno);
 
-    // Remove the TLS server setting.  Queries should now be routed to the
+    // Reset the resolvers without enabling TLS.  Queries should now be routed to the
     // UDP endpoint.
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains, mDefaultParams_Binder));
 
     result = gethostbyname("tls3");
     ASSERT_FALSE(result == nullptr);
@@ -835,6 +834,7 @@
 TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
     const char* listen_addr = "127.0.0.3";
     const char* listen_udp = "53";
+    const char* listen_tls = "853";
     test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
     ASSERT_TRUE(dns.startServer());
     for (int chain_length = 1; chain_length <= 3; ++chain_length) {
@@ -842,17 +842,11 @@
         dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
         std::vector<std::string> servers = { listen_addr };
 
-        // Run each TLS server on a new port to avoid any possible races related to reopening
-        // sockets that were just closed.
-        int tls_port = 853 + chain_length;
-        const char* listen_tls = std::to_string(tls_port).c_str();
         test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
         tls.set_chain_length(chain_length);
         ASSERT_TRUE(tls.startServer());
-        auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, tls_port, "", "SHA-256",
-                { base64Encode(tls.fingerprint()) });
-        EXPECT_EQ(0, rv.exceptionCode());
-        ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+        ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+                { base64Encode(tls.fingerprint()) }));
 
         const hostent* result;
 
@@ -868,8 +862,8 @@
             EXPECT_TRUE(tls.waitForQueries(2, 5000));
         }
 
-        rv = mNetdSrv->removePrivateDnsServer(listen_addr);
-        EXPECT_EQ(0, rv.exceptionCode());
+        // Clear TLS bit to ensure revalidation.
+        ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
         tls.stopServer();
     }
     dns.stopServer();
@@ -889,25 +883,18 @@
     ASSERT_TRUE(tls.startServer());
     std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
     bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "SHA-256",
-            { base64Encode(bad_fingerprint) });
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
-
-    const hostent* result;
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+            { base64Encode(bad_fingerprint) }));
 
     // The initial validation should fail at the fingerprint check before
     // issuing a query.
     EXPECT_FALSE(tls.waitForQueries(1, 500));
 
-    result = gethostbyname("badtlsfingerprint");
-    ASSERT_FALSE(result == nullptr);
-    EXPECT_EQ("1.2.3.1", ToString(result));
+    // A fingerprint was provided and failed to match, so the query should fail.
+    EXPECT_EQ(nullptr, gethostbyname("badtlsfingerprint"));
 
-    // The query should have bypassed the TLS frontend, because validation
-    // failed.
-    EXPECT_FALSE(tls.waitForQueries(1, 500));
-
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls.stopServer();
     dns.stopServer();
 }
@@ -928,9 +915,8 @@
     ASSERT_TRUE(tls.startServer());
     std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
     bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "SHA-256",
-            { base64Encode(bad_fingerprint), base64Encode(tls.fingerprint()) });
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+            { base64Encode(bad_fingerprint), base64Encode(tls.fingerprint()) }));
 
     const hostent* result;
 
@@ -944,7 +930,8 @@
     // Wait for query to get counted.
     EXPECT_TRUE(tls.waitForQueries(2, 5000));
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls.stopServer();
     dns.stopServer();
 }
@@ -963,9 +950,8 @@
 
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "SHA-256",
-            { base64Encode(tls.fingerprint()) });
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+            { base64Encode(tls.fingerprint()) }));
 
     const hostent* result;
 
@@ -988,7 +974,8 @@
     ASSERT_TRUE(result == nullptr);
     EXPECT_EQ(HOST_NOT_FOUND, h_errno);
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls.stopServer();
     dns.stopServer();
 }
@@ -1014,11 +1001,8 @@
     test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
     ASSERT_TRUE(tls1.startServer());
     ASSERT_TRUE(tls2.startServer());
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr1, 853, "", "SHA-256",
-            { base64Encode(tls1.fingerprint()) });
-    rv = mNetdSrv->addPrivateDnsServer(listen_addr2, 853, "", "SHA-256",
-            { base64Encode(tls2.fingerprint()) });
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+            { base64Encode(tls1.fingerprint()), base64Encode(tls2.fingerprint()) }));
 
     const hostent* result;
 
@@ -1048,8 +1032,8 @@
     EXPECT_EQ(2U, dns1.queries().size());
     EXPECT_EQ(2U, dns2.queries().size());
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr1);
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr2);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls2.stopServer();
     dns1.stopServer();
     dns2.stopServer();
@@ -1067,24 +1051,18 @@
 
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "www.example.com", "", {});
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
-
-    const hostent* result;
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder,
+            "www.example.com", {}));
 
     // The TLS server's certificate doesn't chain to a known CA, and a nonempty name was specified,
     // so the client should fail the TLS handshake before ever issuing a query.
     EXPECT_FALSE(tls.waitForQueries(1, 500));
 
-    result = gethostbyname("badtlsname");
-    ASSERT_FALSE(result == nullptr);
-    EXPECT_EQ("1.2.3.1", ToString(result));
+    // The query should fail hard, because a name was specified.
+    EXPECT_EQ(nullptr, gethostbyname("badtlsname"));
 
-    // The query should have bypassed the TLS frontend, because validation
-    // failed.
-    EXPECT_FALSE(tls.waitForQueries(1, 500));
-
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls.stopServer();
     dns.stopServer();
 }
@@ -1102,9 +1080,8 @@
 
     test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
     ASSERT_TRUE(tls.startServer());
-    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", "SHA-256",
-            { base64Encode(tls.fingerprint()) });
-    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+    ASSERT_TRUE(SetResolversWithTls(servers, mDefaultSearchDomains, mDefaultParams_Binder, "",
+            { base64Encode(tls.fingerprint()) }));
 
     // Wait for validation to complete.
     EXPECT_TRUE(tls.waitForQueries(1, 5000));
@@ -1126,7 +1103,8 @@
     // Wait for both A and AAAA queries to get counted.
     EXPECT_TRUE(tls.waitForQueries(3, 5000));
 
-    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    // Clear TLS bit.
+    ASSERT_TRUE(SetResolversForNetwork(servers, mDefaultSearchDomains,  mDefaultParams_Binder));
     tls.stopServer();
     dns.stopServer();
 }