Switch to a new way of activating DNS-over-TLS

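This change removes the global database of potential DNS-over-TLS
servers from Netd (addPrivateDnsServer/removePrivateDnsServer).
Instead, DNS-over-TLS is enabled per network through new useTls,
tlsName and tlsFingerprints arguments to setResolverConfiguration().
Pinned or named servers are now mandatory-TLS, not opportunistic.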

Bug: 64753847
Change-Id: I226ffec3f59593bc40cd9019095c5261aae55fa0
Test: Tests pass.  Normal browsing continues to work normally.
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index ab69383..51928d2 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -54,25 +54,13 @@
 
 namespace {
 
-// This comparison ignores ports and fingerprints.
-struct AddressComparator {
-    bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const {
-      if (x.ss.ss_family != y.ss.ss_family) {
-          return x.ss.ss_family < y.ss.ss_family;
-      }
-      // Same address family.  Compare IP addresses.
-      if (x.ss.ss_family == AF_INET) {
-          const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
-          const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
-          return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
-      } else if (x.ss.ss_family == AF_INET6) {
-          const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
-          const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
-          return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
-      }
-      return false;  // Unknown address type.  This is an error.
-    }
-};
+// Returns the numeric host form of |addr| (NI_NUMERICHOST).  Only used for debug logging.
+std::string addrToString(const sockaddr_storage* addr) {
+    char out[NI_MAXHOST] = {0};  // Large enough for IPv6 literals with a %scope suffix.
+    getnameinfo(reinterpret_cast<const sockaddr*>(addr), sizeof(sockaddr_storage), out,
+            sizeof(out), nullptr, 0, NI_NUMERICHOST);
+    return std::string(out);
+}
 
 bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
@@ -95,94 +83,121 @@
     return false;
 }
 
-// Structure for tracking the entire set of known Private DNS servers.
+// Validation status of each private DNS server on a netId.  Failed servers stay in the tracker
+// until removed.  AddressComparator (defined elsewhere) allows at most one entry per IP address.
+typedef std::map<DnsTlsTransport::Server, ResolverController::Validation,
+        AddressComparator> PrivateDnsTracker;
 std::mutex privateDnsLock;
-typedef std::set<DnsTlsTransport::Server, AddressComparator> PrivateDnsSet;
-PrivateDnsSet privateDnsServers GUARDED_BY(privateDnsLock);
-
-// Structure for tracking the validation status of servers on a specific netid.
-// Servers that fail validation are removed from the tracker, and can be retried.
-enum class Validation : bool { in_process, success };
-typedef std::map<DnsTlsTransport::Server, Validation, AddressComparator> PrivateDnsTracker;
 std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
 
-PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
-    PrivateDnsSet set;
-    for (int i = 0; i < numservers; ++i) {
-        sockaddr_storage parsed;
-        if (parseServer(servers[i], port, &parsed)) {
-            set.insert(parsed);
-        }
+// Marks |server| in_process in |tracker| and validates it on a detached thread.  The result is
+// recorded only if |netId|'s tracker still holds this exact server; otherwise it is discarded.
+void checkPrivateDnsProvider(const DnsTlsTransport::Server& server,
+        PrivateDnsTracker& tracker, unsigned netId) REQUIRES(privateDnsLock) {
+    if (DBG) {
+        ALOGD("checkPrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
     }
-    return set;
+    tracker[server] = ResolverController::Validation::in_process;
+    if (DBG) {
+        ALOGD("Server %s marked as in_process.  Tracker now has size %zu",
+                addrToString(&(server.ss)).c_str(), tracker.size());
+    }
+    std::thread validate_thread([server, netId] {
+        // ::validate() is a blocking call that performs network operations.
+        // It can take milliseconds to minutes, up to the SYN retry limit.
+        const bool success = DnsTlsTransport::validate(server, netId);
+        if (DBG) {
+            ALOGD("validateDnsTlsServer returned %d for %s", success,
+                    addrToString(&(server.ss)).c_str());
+        }
+        std::lock_guard<std::mutex> guard(privateDnsLock);
+        auto netPair = privateDnsTransports.find(netId);
+        if (netPair == privateDnsTransports.end()) {
+            ALOGW("netId %u was erased during private DNS validation", netId);
+            return;
+        }
+        auto& tracker = netPair->second;
+        auto serverPair = tracker.find(server);
+        if (serverPair == tracker.end()) {
+            ALOGW("Server %s was removed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            return;  // Don't resurrect an entry that a newer configuration removed.
+        }
+        if (!(serverPair->first == server)) {
+            ALOGW("Server %s was changed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            return;  // Stale result: the replacement entry was queued for its own validation.
+        }
+        if (success) {
+            serverPair->second = ResolverController::Validation::success;
+            if (DBG) {
+                ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
+                        addrToString(&(server.ss)).c_str(), tracker.size());
+            }
+        } else {
+            // Failure is expected if a user is on a captive portal.  TODO: Trigger a second
+            // validation attempt after captive portal login succeeds.
+            if (DBG) {
+                ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
+            }
+            serverPair->second = ResolverController::Validation::fail;
+        }
+    });
+    validate_thread.detach();
 }
 
-void checkPrivateDnsProviders(const unsigned netId, const char** servers, int numservers) {
+// Replaces the private DNS servers tracked for |netId| with |servers| (port 853, all sharing
+// |name| and |fingerprints|) and starts validating new ones.  Returns 0, -EINVAL or -ENOMEM.
+int setPrivateDnsProviders(int32_t netId,
+        const std::vector<std::string>& servers, const std::string& name,
+        const std::set<std::vector<uint8_t>>& fingerprints) {
     if (DBG) {
-        ALOGD("checkPrivateDnsProviders(%u)", netId);
+        ALOGD("setPrivateDnsProviders(%d, %zu, %s, %zu)",
+                netId, servers.size(), name.c_str(), fingerprints.size());
+    }
+    std::set<DnsTlsTransport::Server> tlsServers;  // |servers|, parsed.
+    for (const auto& serverString : servers) {
+        sockaddr_storage parsed;
+        if (!parseServer(serverString.c_str(), 853, &parsed)) {  // 853: DNS over TLS (RFC 7858)
+            return -EINVAL;  // Reject the whole list; no tracker state has been modified yet.
+        }
+        DnsTlsTransport::Server server(parsed);
+        server.name = name;
+        server.fingerprints = fingerprints;
+        tlsServers.insert(server);
     }
 
     std::lock_guard<std::mutex> guard(privateDnsLock);
-    if (privateDnsServers.empty()) {
-        return;
-    }
-
-    // First compute the intersection of the servers to check with the
-    // servers that are permitted to use DNS over TLS.  The intersection
-    // will contain the port number to be used for Private DNS.
-    PrivateDnsSet serversToCheck = parseServers(servers, numservers, 53);
-    PrivateDnsSet intersection;
-    std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
-        serversToCheck.begin(), serversToCheck.end(),
-        std::inserter(intersection, intersection.begin()),
-        AddressComparator());
-    if (intersection.empty()) {
-        return;
-    }
-
+    // Create the tracker if it was not present
     auto netPair = privateDnsTransports.find(netId);
     if (netPair == privateDnsTransports.end()) {
-        // New netId
+        // No TLS tracker yet for this netId.
         bool added;
         std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
         if (!added) {
-            ALOGE("Memory error while checking private DNS for netId %d", netId);
-            return;
+            ALOGE("Memory error while recording private DNS for netId %d", netId);
+            return -ENOMEM;
+        }
+    }
+    auto& tracker = netPair->second;
+    // Remove any servers from the tracker that are not in |servers| exactly.
+    for (auto it = tracker.begin(); it != tracker.end();) {
+        if (tlsServers.count(it->first) == 0) {
+            it = tracker.erase(it);
+        } else {
+            ++it;
         }
     }
 
-    auto& tracker = netPair->second;
-    for (const auto& privateServer : intersection) {
-        if (tracker.count(privateServer) != 0) {
-            continue;
+    // Add any new or changed servers to the tracker, and initiate async checks for them.
+    for (const auto& server : tlsServers) {
+        // Don't probe a server more than once.  This means that the only way to
+        // re-check a failed server is to remove it and re-add it from the netId.
+        if (tracker.count(server) == 0) {
+            checkPrivateDnsProvider(server, tracker, netId);
         }
-        tracker[privateServer] = Validation::in_process;
-        std::thread validate_thread([privateServer, netId] {
-            // ::validate() is a blocking call that performs network operations.
-            // It can take milliseconds to minutes, up to the SYN retry limit.
-            bool success = DnsTlsTransport::validate(privateServer, netId);
-            std::lock_guard<std::mutex> guard(privateDnsLock);
-            auto netPair = privateDnsTransports.find(netId);
-            if (netPair == privateDnsTransports.end()) {
-                ALOGW("netId %u was erased during private DNS validation", netId);
-                return;
-            }
-            auto& tracker = netPair->second;
-            if (privateDnsServers.count(privateServer) == 0) {
-                ALOGW("Server was removed during private DNS validation");
-                success = false;
-            }
-            if (success) {
-                tracker[privateServer] = Validation::success;
-            } else {
-                // Validation failure is expected if a user is on a captive portal.
-                // TODO: Trigger a second validation attempt after captive portal login
-                // succeeds.
-                tracker.erase(privateServer);
-            }
-        });
-        validate_thread.detach();
     }
+    return 0;
 }
 
 void clearPrivateDnsProviders(unsigned netId) {
@@ -198,30 +213,43 @@
 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
         const char** servers, int numservers, const __res_params* params) {
     if (DBG) {
-        ALOGD("setDnsServers netId = %u\n", netId);
+        ALOGD("setDnsServers netId = %u, numservers = %d", netId, numservers);
     }
-    checkPrivateDnsProviders(netId, servers, numservers);
     return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
 }
 
-bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+ResolverController::Validation ResolverController::getTlsStatus(unsigned netId,
+        const sockaddr_storage& insecureServer,
         DnsTlsTransport::Server* secureServer) {
     // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
     // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
     // an atomic_int32_t counter of validated connections, and returning early if it's zero.
+    if (DBG) {
+        ALOGD("getTlsStatus(%u, %s)?", netId, addrToString(&insecureServer).c_str());
+    }
     std::lock_guard<std::mutex> guard(privateDnsLock);
     const auto netPair = privateDnsTransports.find(netId);
     if (netPair == privateDnsTransports.end()) {
-        return false;
+        if (DBG) {
+            ALOGD("Not using TLS: no tracked servers for netId %u", netId);
+        }
+        return Validation::unknown_netid;  // No private DNS tracker for this netId.
     }
     const auto& tracker = netPair->second;
     const auto serverPair = tracker.find(insecureServer);
-    if (serverPair == tracker.end() || serverPair->second != Validation::success) {
-        return false;
+    if (serverPair == tracker.end()) {
+        if (DBG) {
+            ALOGD("Server is not in the tracker (size %zu) for netid %u", tracker.size(), netId);
+        }
+        return Validation::unknown_server;  // Not one of this netId's private DNS servers.
     }
     const auto& validatedServer = serverPair->first;
+    Validation status = serverPair->second;
+    if (DBG) {
+        ALOGD("Server %s has status %d", addrToString(&(validatedServer.ss)).c_str(), (int)status);
+    }
     *secureServer = validatedServer;
-    return true;
+    return status;  // *secureServer was filled in above, whatever the status is.
 }
 
 int ResolverController::clearDnsServers(unsigned netId) {
@@ -316,13 +344,24 @@
 
 int ResolverController::setResolverConfiguration(int32_t netId,
         const std::vector<std::string>& servers, const std::vector<std::string>& domains,
-        const std::vector<int32_t>& params) {
+        const std::vector<int32_t>& params, bool useTls, const std::string& tlsName,
+        const std::set<std::vector<uint8_t>>& tlsFingerprints) {
     using android::net::INetd;
     if (params.size() != INetd::RESOLVER_PARAMS_COUNT) {
         ALOGE("%s: params.size()=%zu", __FUNCTION__, params.size());
         return -EINVAL;
     }
 
+    if (useTls) {  // Track |servers| for DNS-over-TLS and start validating them.
+        int err = setPrivateDnsProviders(netId, servers, tlsName, tlsFingerprints);
+        if (err != 0) {
+            return err;  // Bail out before the bionic conversion below.
+        }
+    } else {  // TLS disabled: clear any private DNS providers for this netId.
+        clearPrivateDnsProviders(netId);
+    }
+
+    // Convert server list to bionic's format.
     auto server_count = std::min<size_t>(MAXNS, servers.size());
     std::vector<const char*> server_ptrs;
     for (size_t i = 0 ; i < server_count ; ++i) {
@@ -424,61 +463,5 @@
     dw.decIndent();
 }
 
-int ResolverController::addPrivateDnsServer(const std::string& server, int32_t port,
-        const std::string& name,
-        const std::string& fingerprintAlgorithm,
-        const std::set<std::vector<uint8_t>>& fingerprints) {
-    using android::net::INetd;
-    if (fingerprintAlgorithm.empty()) {
-        if (!fingerprints.empty()) {
-            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-        }
-    } else if (fingerprintAlgorithm.compare("SHA-256") == 0) {
-        if (fingerprints.empty()) {
-            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-        }
-        for (const auto& fingerprint : fingerprints) {
-            if (fingerprint.size() != SHA256_SIZE) {
-                return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
-            }
-        }
-    } else {
-        return INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM;
-    }
-    if (port <= 0 || port > 0xFFFF) {
-        return INetd::PRIVATE_DNS_BAD_PORT;
-    }
-    sockaddr_storage parsed;
-    if (!parseServer(server.c_str(), port, &parsed)) {
-        return INetd::PRIVATE_DNS_BAD_ADDRESS;
-    }
-    DnsTlsTransport::Server privateServer(parsed);
-    privateServer.fingerprints = fingerprints;
-    privateServer.name = name;
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    // Ensure we overwrite any previous matching server.  This is necessary because equality is
-    // based only on the IP address, not the port or fingerprints.
-    privateDnsServers.erase(privateServer);
-    privateDnsServers.insert(privateServer);
-    if (DBG) {
-        ALOGD("Recorded private DNS server: %s", server.c_str());
-    }
-    return INetd::PRIVATE_DNS_SUCCESS;
-}
-
-int ResolverController::removePrivateDnsServer(const std::string& server) {
-    using android::net::INetd;
-    sockaddr_storage parsed;
-    if (!parseServer(server.c_str(), 0, &parsed)) {
-        return INetd::PRIVATE_DNS_BAD_ADDRESS;
-    }
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    privateDnsServers.erase(parsed);
-    for (auto& pair : privateDnsTransports) {
-        pair.second.erase(parsed);
-    }
-    return INetd::PRIVATE_DNS_SUCCESS;
-}
-
 }  // namespace net
 }  // namespace android