Support RFC 7858 DNS over TLS

This change adds core support for DNS over TLS, plus private APIs for
enabling it.  It does not yet provide any way to enable the
functionality in a development environment or on a real device.

Based on https://android-review.googlesource.com/#/c/373776/

Test: Complete unit+integration tests.  Manual tests look good.
Bug: 34953048
Change-Id: Ib99ac1f631fd2c2c8fbf53bdb05f67f8be7713ac
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index daf4ebb..caf3ee9 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -19,13 +19,19 @@
 
 #include <algorithm>
 #include <cstdlib>
+#include <map>
+#include <mutex>
+#include <set>
 #include <string>
+#include <thread>
+#include <utility>
 #include <vector>
 #include <cutils/log.h>
 #include <net/if.h>
 #include <sys/socket.h>
 #include <netdb.h>
 
+#include <arpa/inet.h>
 // NOTE: <resolv_netid.h> is a private C library header that provides
 //       declarations for _resolv_set_nameservers_for_net and
 //       _resolv_flush_cache_for_net
@@ -37,25 +43,199 @@
 #include <android/net/INetd.h>
 
 #include "DumpWriter.h"
+#include "NetdConstants.h"
 #include "ResolverController.h"
 #include "ResolverStats.h"
+#include "dns/DnsTlsTransport.h"
 
 namespace android {
 namespace net {
 
+namespace {
+
+// One server that has been authorized for DNS over TLS.  Ordering (operator< below) looks
+// only at the IP address, so the port and fingerprints are carried along as payload.
+struct PrivateDnsServer {
+    // Deliberately not explicit: this conversion lets the std::set/std::map containers in
+    // this file be searched and erased using a bare sockaddr_storage.
+    PrivateDnsServer(const sockaddr_storage& addr) : ss(addr) {}
+
+    // IP address and port of the server.
+    const sockaddr_storage ss;
+
+    // Fingerprints the server is checked against; empty when none were configured.
+    // For now these are always SHA-256 digests, the only digest algorithm that is
+    // mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
+    std::set<std::vector<uint8_t>> fingerprints;
+};
+
+// Strict weak ordering for PrivateDnsServer, keyed on (address family, IP address) only.
+// Ports and fingerprints are deliberately ignored, so a Private DNS entry (stored with its
+// TLS port) compares equal to the same address parsed with the cleartext port 53.
+// std::set, std::map and std::set_intersection all require this to be a strict weak ordering.
+bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
+    if (x.ss.ss_family != y.ss.ss_family) {
+        return x.ss.ss_family < y.ss.ss_family;
+    }
+    // Same address family.  Compare IP addresses.
+    if (x.ss.ss_family == AF_INET) {
+        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+        // s_addr is in network byte order, so this order is not numeric, but it is
+        // consistent, which is all the containers need.
+        return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+    } else if (x.ss.ss_family == AF_INET6) {
+        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+        // memcmp() returns <0, 0 or >0; only a negative result means "less than".
+        // Converting its raw value to bool would make any two distinct addresses compare
+        // "less" in both directions and corrupt every set/map keyed on this type.
+        return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr,
+                sizeof(x_sin6.sin6_addr.s6_addr)) < 0;
+    }
+    return false;  // Unknown address type.  This is an error.
+}
+
+// Parses |server|, a numeric IPv4 or IPv6 address string, into |*parsed| with the given
+// |port| (stored in network byte order).  Every other field of |*parsed| is zeroed.
+// Returns false, leaving |*parsed| all-zero, if |server| is neither a valid IPv4 nor a
+// valid IPv6 literal.
+bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
+    // Callers pass uninitialized stack storage.  Zero it so that padding, sin_zero,
+    // sin6_flowinfo and sin6_scope_id are well-defined wherever the address is copied,
+    // stored or used later.
+    *parsed = sockaddr_storage{};
+
+    // Parse into temporaries so a failed attempt cannot leave partial bytes in |*parsed|.
+    in_addr v4;
+    if (inet_pton(AF_INET, server, &v4) == 1) {
+        // IPv4 parse succeeded, so it's IPv4.
+        sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
+        sin->sin_family = AF_INET;
+        sin->sin_port = htons(port);
+        sin->sin_addr = v4;
+        return true;
+    }
+    in6_addr v6;
+    if (inet_pton(AF_INET6, server, &v6) == 1) {
+        // IPv6 parse succeeded, so it's IPv6.
+        sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
+        sin6->sin6_family = AF_INET6;
+        sin6->sin6_port = htons(port);
+        sin6->sin6_addr = v6;
+        return true;
+    }
+    if (DBG) {
+        ALOGW("Failed to parse server address: %s", server);
+    }
+    return false;
+}
+
+// All Private DNS state below is read and written only while holding privateDnsLock.
+std::mutex privateDnsLock;
+
+// Every server authorized for DNS over TLS, independent of any network.  Keyed by IP
+// address alone (see operator< above), so there is at most one entry per address.
+using PrivateDnsSet = std::set<PrivateDnsServer>;
+PrivateDnsSet privateDnsServers;
+
+// Per-netId validation status.  A server sits in a network's tracker while it is being
+// validated (in_process) or once it has passed (success).  Servers that fail validation
+// are erased from the tracker, so a later setDnsServers() call can retry them.
+enum class Validation : bool { in_process, success };
+using PrivateDnsTracker = std::map<PrivateDnsServer, Validation>;
+std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
+
+// Converts the first |numservers| entries of |servers| into a PrivateDnsSet, giving each
+// address the same |port|.  Entries that are not IP address literals are skipped, and a
+// repeated address keeps only its first occurrence (the set compares addresses only).
+PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
+    PrivateDnsSet result;
+    int idx = 0;
+    while (idx < numservers) {
+        sockaddr_storage addr;
+        if (parseServer(servers[idx], port, &addr)) {
+            result.insert(PrivateDnsServer(addr));
+        }
+        ++idx;
+    }
+    return result;
+}
+
+// Returns the port stored in |ss| (network byte order), or 0 for an unknown address family.
+in_port_t serverPort(const sockaddr_storage& ss) {
+    switch (ss.ss_family) {
+        case AF_INET:
+            return reinterpret_cast<const sockaddr_in&>(ss).sin_port;
+        case AF_INET6:
+            return reinterpret_cast<const sockaddr_in6&>(ss).sin6_port;
+        default:
+            return 0;
+    }
+}
+
+// Returns true if |a| and |b| carry the same DNS-over-TLS settings (port and fingerprints).
+// IP addresses are not compared: callers only use this on entries equal under operator<.
+bool sameTlsSettings(const PrivateDnsServer& a, const PrivateDnsServer& b) {
+    return serverPort(a.ss) == serverPort(b.ss) && a.fingerprints == b.fingerprints;
+}
+
+// For every entry of |servers| that is also authorized for DNS over TLS (see
+// addPrivateDnsServer()), starts validating it on |netId| unless it is already being
+// validated or has already passed there.  validateDnsTlsServer() blocks, so each validation
+// runs on its own detached thread and records its result in privateDnsTransports[netId].
+void checkPrivateDnsProviders(const unsigned netId, const char** servers, int numservers) {
+    if (DBG) {
+        ALOGD("checkPrivateDnsProviders(%u)", netId);
+    }
+
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    if (privateDnsServers.empty()) {
+        return;
+    }
+
+    // First compute the intersection of the servers to check with the servers that are
+    // permitted to use DNS over TLS.  std::set_intersection copies elements from the first
+    // range, so the intersection carries the Private DNS port and fingerprints rather than
+    // the cleartext port 53 used to parse |servers|.
+    PrivateDnsSet serversToCheck = parseServers(servers, numservers, 53);
+    PrivateDnsSet intersection;
+    std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
+        serversToCheck.begin(), serversToCheck.end(),
+        std::inserter(intersection, intersection.begin()));
+    if (intersection.empty()) {
+        return;
+    }
+
+    // operator[] creates an empty tracker the first time this netId is seen.
+    auto& tracker = privateDnsTransports[netId];
+    for (const auto& privateServer : intersection) {
+        if (tracker.count(privateServer) != 0) {
+            continue;  // Already being validated, or already validated.
+        }
+        tracker[privateServer] = Validation::in_process;
+        std::thread validate_thread([privateServer, netId] {
+            // validateDnsTlsServer() is a blocking call that performs network operations.
+            // It can take milliseconds to minutes, up to the SYN retry limit.
+            bool success = validateDnsTlsServer(netId,
+                    privateServer.ss, privateServer.fingerprints);
+            std::lock_guard<std::mutex> guard(privateDnsLock);
+            auto netPair = privateDnsTransports.find(netId);
+            if (netPair == privateDnsTransports.end()) {
+                ALOGW("netId %u was erased during private DNS validation", netId);
+                return;
+            }
+            auto& tracker = netPair->second;
+            const auto current = privateDnsServers.find(privateServer);
+            if (current == privateDnsServers.end()) {
+                ALOGW("Server was removed during private DNS validation");
+                success = false;
+            } else if (!sameTlsSettings(*current, privateServer)) {
+                // addPrivateDnsServer() replaced this server's port or fingerprints while we
+                // were validating, so the result describes a stale configuration.  Treat it
+                // as a failure so the next setDnsServers() revalidates with the new settings.
+                ALOGW("Server was reconfigured during private DNS validation");
+                success = false;
+            }
+            if (success) {
+                tracker[privateServer] = Validation::success;
+            } else {
+                // Validation failure is expected if a user is on a captive portal.
+                // TODO: Trigger a second validation attempt after captive portal login
+                // succeeds.
+                tracker.erase(privateServer);
+            }
+        });
+        validate_thread.detach();
+    }
+}
+
+// Discards all Private DNS validation state for |netId|.  Validation threads still running
+// for this network find no tracker when they finish (unless the network has been
+// reconfigured in the meantime) and drop their result.
+void clearPrivateDnsProviders(unsigned netId) {
+    if (DBG) {
+        ALOGD("clearPrivateDnsProviders(%u)", netId);
+    }
+    const std::lock_guard<std::mutex> lock(privateDnsLock);
+    privateDnsTransports.erase(netId);
+}
+
+}  // namespace
+
+// Hands |servers|, |searchDomains| and |params| to the resolver for |netId|.  Before that,
+// any of |servers| that are also authorized for DNS over TLS (see addPrivateDnsServer())
+// are queued for Private DNS validation on this network.
+// Returns the negated result of _resolv_set_nameservers_for_net().
 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
         const char** servers, int numservers, const __res_params* params) {
     if (DBG) {
         ALOGD("setDnsServers netId = %u\n", netId);
     }
+    // Validation itself runs on detached threads, so this call does not wait for it.
     checkPrivateDnsProviders(netId, servers, numservers);
     return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
 }
 
+// Decides whether a query to |insecureServer| on |netId| should be upgraded to DNS over TLS.
+// Returns true only if that server's IP address has passed validation on |netId|.  In that
+// case |secureServer| receives the stored Private DNS address (including its TLS port) and
+// |fingerprints| the fingerprints to check.  Otherwise both outputs are left untouched.
+bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+        sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
+    // Every DNS lookup that misses the local cache takes this mutex.  If that ever shows up
+    // as a bottleneck, an atomic count of validated servers could let us return early when
+    // it is zero.
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    const auto trackerIt = privateDnsTransports.find(netId);
+    if (trackerIt == privateDnsTransports.end()) {
+        return false;
+    }
+    const PrivateDnsTracker& tracker = trackerIt->second;
+    // Lookup ignores the port, so the cleartext address matches the Private DNS entry.
+    const auto entry = tracker.find(insecureServer);
+    const bool validated = (entry != tracker.end()) && (entry->second == Validation::success);
+    if (validated) {
+        *secureServer = entry->first.ss;
+        *fingerprints = entry->first.fingerprints;
+    }
+    return validated;
+}
+
+// Passes an empty server list to the resolver for |netId| and discards all Private DNS
+// validation state recorded for that network.  Always returns 0.
 int ResolverController::clearDnsServers(unsigned netId) {
+    // NOTE(review): the return value of this call is ignored, so a failure is not reported.
     _resolv_set_nameservers_for_net(netId, NULL, 0, "", NULL);
     if (DBG) {
         ALOGD("clearDnsServers netId = %u\n", netId);
     }
     clearPrivateDnsProviders(netId);
     return 0;
 }
 
@@ -250,5 +430,56 @@
     dw.decIndent();
 }
 
+// Authorizes |server| (an IPv4 or IPv6 literal) for DNS over TLS on |port|, replacing any
+// existing entry for the same IP address.
+// |fingerprints| must be empty when |fingerprintAlgorithm| is empty.  When the algorithm is
+// "SHA-256", |fingerprints| must be non-empty with every element SHA256_SIZE bytes long.
+// Any other algorithm is rejected.  Returns an INetd::PRIVATE_DNS_* status code.
+int ResolverController::addPrivateDnsServer(const std::string& server, int32_t port,
+        const std::string& fingerprintAlgorithm,
+        const std::set<std::vector<uint8_t>>& fingerprints) {
+    using android::net::INetd;
+    if (fingerprintAlgorithm.empty()) {
+        if (!fingerprints.empty()) {
+            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+        }
+    } else if (fingerprintAlgorithm == "SHA-256") {
+        if (fingerprints.empty()) {
+            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+        }
+        for (const auto& fingerprint : fingerprints) {
+            if (fingerprint.size() != SHA256_SIZE) {
+                return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+            }
+        }
+    } else {
+        return INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM;
+    }
+    if (port <= 0 || port > 0xFFFF) {
+        return INetd::PRIVATE_DNS_BAD_PORT;
+    }
+    sockaddr_storage parsed;
+    if (!parseServer(server.c_str(), port, &parsed)) {
+        return INetd::PRIVATE_DNS_BAD_ADDRESS;
+    }
+    PrivateDnsServer privateServer(parsed);
+    privateServer.fingerprints = fingerprints;
+
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    // Ensure we overwrite any previous matching server.  This is necessary because equality is
+    // based only on the IP address, not the port or fingerprints.
+    const auto previous = privateDnsServers.find(privateServer);
+    if (previous != privateDnsServers.end()) {
+        // Both entries have the same address family, since operator< considers them equal.
+        const auto portOf = [](const sockaddr_storage& ss) -> in_port_t {
+            return ss.ss_family == AF_INET6
+                    ? reinterpret_cast<const sockaddr_in6&>(ss).sin6_port
+                    : reinterpret_cast<const sockaddr_in&>(ss).sin_port;
+        };
+        const bool changed = portOf(previous->ss) != portOf(privateServer.ss) ||
+                previous->fingerprints != privateServer.fingerprints;
+        if (changed) {
+            // Per-network trackers keep their own copy of the old port and fingerprints (as
+            // map keys), and shouldUseTls() would keep handing those out.  Drop them so the
+            // server is revalidated with the new settings the next time setDnsServers()
+            // lists it on that network.
+            for (auto& netAndTracker : privateDnsTransports) {
+                netAndTracker.second.erase(privateServer);
+            }
+        }
+        privateDnsServers.erase(previous);
+    }
+    privateDnsServers.insert(privateServer);
+    return INetd::PRIVATE_DNS_SUCCESS;
+}
+
+// Revokes DNS-over-TLS authorization for the IP address |server| and drops its validation
+// state on every network.  Results of validations still in flight for it are discarded.
+// Returns INetd::PRIVATE_DNS_BAD_ADDRESS if |server| is not an IP literal, otherwise
+// INetd::PRIVATE_DNS_SUCCESS (even if the address was never registered).
+int ResolverController::removePrivateDnsServer(const std::string& server) {
+    using android::net::INetd;
+    sockaddr_storage target;
+    if (!parseServer(server.c_str(), 0, &target)) {
+        return INetd::PRIVATE_DNS_BAD_ADDRESS;
+    }
+    // The port passed above is irrelevant: entries are matched on IP address alone.
+    const PrivateDnsServer key(target);
+
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    privateDnsServers.erase(key);
+    for (auto it = privateDnsTransports.begin(); it != privateDnsTransports.end(); ++it) {
+        it->second.erase(key);
+    }
+    return INetd::PRIVATE_DNS_SUCCESS;
+}
+
 }  // namespace net
 }  // namespace android