Support RFC 7858 DNS over TLS

This change adds the core capability for DNS over TLS, and creates
private APIs for activating it, but does not provide any way to
activate the functionality in a development environment or on a
real device.

Based on https://android-review.googlesource.com/#/c/373776/

Test: Complete unit+integration tests.  Manual tests look good.
Bug: 34953048
Change-Id: Ib99ac1f631fd2c2c8fbf53bdb05f67f8be7713ac
diff --git a/server/Android.mk b/server/Android.mk
index c8ba20f..8c3d648 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -64,6 +64,7 @@
 
 LOCAL_SHARED_LIBRARIES := \
         libbinder \
+        libcrypto \
         libcutils \
         libdl \
         liblog \
@@ -73,6 +74,7 @@
         libnetutils \
         libnetdutils \
         libnl \
+        libssl \
         libsysutils \
         libbase \
         libutils \
@@ -120,6 +122,7 @@
         oem_iptables_hook.cpp \
         binder/android/net/UidRange.cpp \
         binder/android/net/metrics/INetdEventListener.aidl \
+        dns/DnsTlsTransport.cpp \
 
 LOCAL_AIDL_INCLUDES := $(LOCAL_PATH)/binder
 
@@ -186,6 +189,7 @@
         libnetdaidl \
         libbase \
         libbinder \
+        libcrypto \
         libcutils \
         liblog \
         liblogwrap \
@@ -194,6 +198,7 @@
         libnl \
         libsysutils \
         libutils \
+        libssl \
 
 include $(BUILD_NATIVE_TEST)
 
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index a925a22..8db1da8 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -40,8 +40,10 @@
 #include <utils/String16.h>
 #include <sysutils/SocketClient.h>
 
+#include "Controllers.h"
 #include "Fwmark.h"
 #include "DnsProxyListener.h"
+#include "dns/DnsTlsTransport.h"
 #include "NetdConstants.h"
 #include "NetworkController.h"
 #include "ResponseCode.h"
@@ -76,6 +78,61 @@
     cli->decRef();
 }
 
+thread_local android_net_context thread_netcontext = {};
+
+res_sendhookact qhook(sockaddr* const * nsap, const u_char** buf, int* buflen,
+                      u_char* ans, int anssiz, int* resplen) {
+    if (!thread_netcontext.qhook) {
+        ALOGE("qhook abort: thread qhook is null");
+        return res_goahead;
+    }
+    if (!net::gCtls) {
+        ALOGE("qhook abort: gCtls is null");
+        return res_goahead;
+    }
+    // Safely read the data from nsap without violating strict-aliasing.
+    sockaddr_storage insecureResolver;
+    if ((*nsap)->sa_family == AF_INET) {
+        std::memcpy(&insecureResolver, *nsap, sizeof(sockaddr_in));
+    } else if ((*nsap)->sa_family == AF_INET6) {
+        std::memcpy(&insecureResolver, *nsap, sizeof(sockaddr_in6));
+    } else {
+        ALOGE("qhook abort: unknown address family");
+        return res_goahead;
+    }
+    sockaddr_storage secureResolver;
+    std::set<std::vector<uint8_t>> fingerprints;
+    if (net::gCtls->resolverCtrl.shouldUseTls(thread_netcontext.dns_netid,
+            insecureResolver, &secureResolver, &fingerprints)) {
+        if (DBG) {
+            ALOGD("qhook using TLS");
+        }
+        DnsTlsTransport xport(thread_netcontext.dns_mark, IPPROTO_TCP,
+                              secureResolver, fingerprints);
+        auto response = xport.doQuery(*buf, *buflen, ans, anssiz, resplen);
+        if (response == DnsTlsTransport::Response::success) {
+            if (DBG) {
+                ALOGD("qhook success");
+            }
+            return res_done;
+        }
+        if (DBG) {
+            ALOGW("qhook abort: doQuery failed: %d", (int)response);
+        }
+        // If there was a network error, try a different name server.
+        // Otherwise, fail hard.
+        if (response == DnsTlsTransport::Response::network_error) {
+            return res_nextns;
+        }
+        return res_error;
+    }
+
+    if (DBG) {
+        ALOGD("qhook not using TLS");
+    }
+    return res_goahead;
+}
+
 }  // namespace
 
 DnsProxyListener::DnsProxyListener(const NetworkController* netCtrl, EventReporter* eventReporter) :
@@ -189,6 +246,7 @@
 
     struct addrinfo* result = NULL;
     Stopwatch s;
+    thread_netcontext = mNetContext;
     uint32_t rv = android_getaddrinfofornetcontext(mHost, mService, mHints, &mNetContext, &result);
     const int latencyMs = lround(s.timeTaken());
 
@@ -305,6 +363,7 @@
 
     android_net_context netcontext;
     mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
+    netcontext.qhook = &qhook;
 
     if (ai_flags != -1 || ai_family != -1 ||
         ai_socktype != -1 || ai_protocol != -1) {
@@ -370,6 +429,7 @@
 
     android_net_context netcontext;
     mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
+    netcontext.qhook = &qhook;
 
     const int metricsLevel = mDnsProxyListener->mEventReporter->getMetricsReportingLevel();
 
@@ -401,6 +461,7 @@
     }
 
     Stopwatch s;
+    thread_netcontext = mNetContext;
     struct hostent* hp = android_gethostbynamefornetcontext(mName, mAf, &mNetContext);
     const int latencyMs = lround(s.timeTaken());
 
@@ -512,6 +573,7 @@
 
     android_net_context netcontext;
     mDnsProxyListener->mNetCtrl->getNetworkContext(netId, uid, &netcontext);
+    netcontext.qhook = &qhook;
 
     DnsProxyListener::GetHostByAddrHandler* handler =
             new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netcontext);
@@ -543,6 +605,7 @@
     struct hostent* hp;
 
     // NOTE gethostbyaddr should take a void* but bionic thinks it should be char*
+    thread_netcontext = mNetContext;
     hp = android_gethostbyaddrfornetcontext(
             (char*)mAddress, mAddressLen, mAddressFamily, &mNetContext);
 
diff --git a/server/NetdConstants.cpp b/server/NetdConstants.cpp
index 0a0ca5d..58b2f64 100644
--- a/server/NetdConstants.cpp
+++ b/server/NetdConstants.cpp
@@ -20,6 +20,7 @@
 #include <netdb.h>
 #include <net/if.h>
 #include <netinet/in.h>
+#include <openssl/ssl.h>
 #include <stdlib.h>
 #include <string.h>
 #include <sys/wait.h>
@@ -34,6 +35,8 @@
 #include "NetdConstants.h"
 #include "IptablesRestoreController.h"
 
+const size_t SHA256_SIZE = EVP_MD_size(EVP_sha256());
+
 const char * const OEM_SCRIPT_PATH = "/system/bin/oem-iptables-init.sh";
 const char * const IPTABLES_PATH = "/system/bin/iptables";
 const char * const IP6TABLES_PATH = "/system/bin/ip6tables";
diff --git a/server/NetdConstants.h b/server/NetdConstants.h
index 54ed812..446a898 100644
--- a/server/NetdConstants.h
+++ b/server/NetdConstants.h
@@ -32,6 +32,8 @@
 const int PROTECT_MARK = 0x1;
 const int MAX_SYSTEM_UID = AID_APP - 1;
 
+extern const size_t SHA256_SIZE;
+
 extern const char * const IPTABLES_PATH;
 extern const char * const IP6TABLES_PATH;
 extern const char * const IP_PATH;
diff --git a/server/NetdNativeService.cpp b/server/NetdNativeService.cpp
index 67bf9c0..4cdcdf0 100644
--- a/server/NetdNativeService.cpp
+++ b/server/NetdNativeService.cpp
@@ -28,6 +28,8 @@
 #include <binder/IServiceManager.h>
 #include "android/net/BnNetd.h"
 
+#include <openssl/base64.h>
+
 #include "Controllers.h"
 #include "DumpWriter.h"
 #include "EventReporter.h"
@@ -234,6 +236,47 @@
     return binder::Status::ok();
 }
 
+binder::Status NetdNativeService::addPrivateDnsServer(const std::string& server, int32_t port,
+        const std::string& fingerprintAlgorithm, const std::vector<std::string>& fingerprints) {
+    ENFORCE_PERMISSION(CONNECTIVITY_INTERNAL);
+    std::set<std::vector<uint8_t>> decoded_fingerprints;
+    for (const std::string& input : fingerprints) {
+        size_t out_len;
+        if (EVP_DecodedLength(&out_len, input.size()) != 1) {
+            return binder::Status::fromServiceSpecificError(INetd::PRIVATE_DNS_BAD_FINGERPRINT,
+                    "ResolverController error: bad fingerprint length");
+        }
+        // out_len is now an upper bound on the output length.
+        std::vector<uint8_t> decoded(out_len);
+        if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(),
+                reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) {
+            // Possibly shrink the vector if the actual output was smaller than the bound.
+            decoded.resize(out_len);
+        } else {
+            return binder::Status::fromServiceSpecificError(INetd::PRIVATE_DNS_BAD_FINGERPRINT,
+                    "ResolverController error: Base64 parsing failed");
+        }
+        decoded_fingerprints.insert(decoded);
+    }
+    const int err = gCtls->resolverCtrl.addPrivateDnsServer(server, port,
+            fingerprintAlgorithm, decoded_fingerprints);
+    if (err != INetd::PRIVATE_DNS_SUCCESS) {
+        return binder::Status::fromServiceSpecificError(err,
+                String8::format("ResolverController error: %d", err));
+    }
+    return binder::Status::ok();
+}
+
+binder::Status NetdNativeService::removePrivateDnsServer(const std::string& server) {
+    ENFORCE_PERMISSION(CONNECTIVITY_INTERNAL);
+    const int err = gCtls->resolverCtrl.removePrivateDnsServer(server);
+    if (err != INetd::PRIVATE_DNS_SUCCESS) {
+        return binder::Status::fromServiceSpecificError(err,
+                String8::format("ResolverController error: %d", err));
+    }
+    return binder::Status::ok();
+}
+
 binder::Status NetdNativeService::tetherApplyDnsInterfaces(bool *ret) {
     NETD_BIG_LOCK_RPC(CONNECTIVITY_INTERNAL);
 
diff --git a/server/NetdNativeService.h b/server/NetdNativeService.h
index 407c563..138043a 100644
--- a/server/NetdNativeService.h
+++ b/server/NetdNativeService.h
@@ -47,6 +47,10 @@
     binder::Status getResolverInfo(int32_t netId, std::vector<std::string>* servers,
             std::vector<std::string>* domains, std::vector<int32_t>* params,
             std::vector<int32_t>* stats) override;
+    binder::Status addPrivateDnsServer(const std::string& server, int32_t port,
+            const std::string& fingerprintAlgorithm,
+            const std::vector<std::string>& fingerprints) override;
+    binder::Status removePrivateDnsServer(const std::string& server) override;
 
     binder::Status setIPv6AddrGenMode(const std::string& ifName, int32_t mode) override;
 
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index daf4ebb..caf3ee9 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -19,13 +19,19 @@
 
 #include <algorithm>
 #include <cstdlib>
+#include <map>
+#include <mutex>
+#include <set>
 #include <string>
+#include <thread>
+#include <utility>
 #include <vector>
 #include <cutils/log.h>
 #include <net/if.h>
 #include <sys/socket.h>
 #include <netdb.h>
 
+#include <arpa/inet.h>
 // NOTE: <resolv_netid.h> is a private C library header that provides
 //       declarations for _resolv_set_nameservers_for_net and
 //       _resolv_flush_cache_for_net
@@ -37,25 +43,199 @@
 #include <android/net/INetd.h>
 
 #include "DumpWriter.h"
+#include "NetdConstants.h"
 #include "ResolverController.h"
 #include "ResolverStats.h"
+#include "dns/DnsTlsTransport.h"
 
 namespace android {
 namespace net {
 
+namespace {
+
+struct PrivateDnsServer {
+    PrivateDnsServer(const sockaddr_storage& ss) : ss(ss) {}
+    const sockaddr_storage ss;
+    // For now, the fingerprints are always SHA-256.  This is the only digest algorithm
+    // that is mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
+    std::set<std::vector<uint8_t>> fingerprints;
+};
+
+// This comparison ignores ports and fingerprints.
+bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
+    if (x.ss.ss_family != y.ss.ss_family) {
+        return x.ss.ss_family < y.ss.ss_family;
+    }
+    // Same address family.  Compare IP addresses.
+    if (x.ss.ss_family == AF_INET) {
+        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+        return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+    } else if (x.ss.ss_family == AF_INET6) {
+        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+        return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16);
+    }
+    return false;  // Unknown address type.  This is an error.
+}
+
+bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
+    sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
+    if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
+        // IPv4 parse succeeded, so it's IPv4
+        sin->sin_family = AF_INET;
+        sin->sin_port = htons(port);
+        return true;
+    }
+    sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
+    if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
+        // IPv6 parse succeeded, so it's IPv6.
+        sin6->sin6_family = AF_INET6;
+        sin6->sin6_port = htons(port);
+        return true;
+    }
+    if (DBG) {
+        ALOGW("Failed to parse server address: %s", server);
+    }
+    return false;
+}
+
+// Structure for tracking the entire set of known Private DNS servers.
+std::mutex privateDnsLock;
+typedef std::set<PrivateDnsServer> PrivateDnsSet;
+PrivateDnsSet privateDnsServers;
+
+// Structure for tracking the validation status of servers on a specific netid.
+// Servers that fail validation are removed from the tracker, and can be retried.
+enum class Validation : bool { in_process, success };
+typedef std::map<PrivateDnsServer, Validation> PrivateDnsTracker;
+std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
+
+PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
+    PrivateDnsSet set;
+    for (int i = 0; i < numservers; ++i) {
+        sockaddr_storage parsed;
+        if (parseServer(servers[i], port, &parsed)) {
+            set.insert(parsed);
+        }
+    }
+    return set;
+}
+
+void checkPrivateDnsProviders(const unsigned netId, const char** servers, int numservers) {
+    if (DBG) {
+        ALOGD("checkPrivateDnsProviders(%u)", netId);
+    }
+
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    if (privateDnsServers.empty()) {
+        return;
+    }
+
+    // First compute the intersection of the servers to check with the
+    // servers that are permitted to use DNS over TLS.  The intersection
+    // will contain the port number to be used for Private DNS.
+    PrivateDnsSet serversToCheck = parseServers(servers, numservers, 53);
+    PrivateDnsSet intersection;
+    std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
+        serversToCheck.begin(), serversToCheck.end(),
+        std::inserter(intersection, intersection.begin()));
+    if (intersection.empty()) {
+        return;
+    }
+
+    auto netPair = privateDnsTransports.find(netId);
+    if (netPair == privateDnsTransports.end()) {
+        // New netId
+        bool added;
+        std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
+        if (!added) {
+            ALOGE("Memory error while checking private DNS for netId %d", netId);
+            return;
+        }
+    }
+
+    auto& tracker = netPair->second;
+    for (const auto& privateServer : intersection) {
+        if (tracker.count(privateServer) != 0) {
+            continue;
+        }
+        tracker[privateServer] = Validation::in_process;
+        std::thread validate_thread([privateServer, netId] {
+            // validateDnsTlsServer() is a blocking call that performs network operations.
+            // It can take milliseconds to minutes, up to the SYN retry limit.
+            bool success = validateDnsTlsServer(netId,
+                    privateServer.ss, privateServer.fingerprints);
+            std::lock_guard<std::mutex> guard(privateDnsLock);
+            auto netPair = privateDnsTransports.find(netId);
+            if (netPair == privateDnsTransports.end()) {
+                ALOGW("netId %u was erased during private DNS validation", netId);
+                return;
+            }
+            auto& tracker = netPair->second;
+            if (privateDnsServers.count(privateServer) == 0) {
+                ALOGW("Server was removed during private DNS validation");
+                success = false;
+            }
+            if (success) {
+                tracker[privateServer] = Validation::success;
+            } else {
+                // Validation failure is expected if a user is on a captive portal.
+                // TODO: Trigger a second validation attempt after captive portal login
+                // succeeds.
+                tracker.erase(privateServer);
+            }
+        });
+        validate_thread.detach();
+    }
+}
+
+void clearPrivateDnsProviders(unsigned netId) {
+    if (DBG) {
+        ALOGD("clearPrivateDnsProviders(%u)", netId);
+    }
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    privateDnsTransports.erase(netId);
+}
+
+}  // namespace
+
 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
         const char** servers, int numservers, const __res_params* params) {
     if (DBG) {
         ALOGD("setDnsServers netId = %u\n", netId);
     }
+    checkPrivateDnsProviders(netId, servers, numservers);
     return -_resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
 }
 
+bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+        sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
+    // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
+    // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
+    // an atomic_int32_t counter of validated connections, and returning early if it's zero.
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    const auto netPair = privateDnsTransports.find(netId);
+    if (netPair == privateDnsTransports.end()) {
+        return false;
+    }
+    const auto& tracker = netPair->second;
+    const auto serverPair = tracker.find(insecureServer);
+    if (serverPair == tracker.end() || serverPair->second != Validation::success) {
+        return false;
+    }
+    const auto& validatedServer = serverPair->first;
+    *secureServer = validatedServer.ss;
+    *fingerprints = validatedServer.fingerprints;
+    return true;
+}
+
 int ResolverController::clearDnsServers(unsigned netId) {
     _resolv_set_nameservers_for_net(netId, NULL, 0, "", NULL);
     if (DBG) {
         ALOGD("clearDnsServers netId = %u\n", netId);
     }
+    clearPrivateDnsProviders(netId);
     return 0;
 }
 
@@ -250,5 +430,56 @@
     dw.decIndent();
 }
 
+int ResolverController::addPrivateDnsServer(const std::string& server, int32_t port,
+        const std::string& fingerprintAlgorithm,
+        const std::set<std::vector<uint8_t>>& fingerprints) {
+    using android::net::INetd;
+    if (fingerprintAlgorithm.empty()) {
+        if (!fingerprints.empty()) {
+            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+        }
+    } else if (fingerprintAlgorithm.compare("SHA-256") == 0) {
+        if (fingerprints.empty()) {
+            return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+        }
+        for (const auto& fingerprint : fingerprints) {
+            if (fingerprint.size() != SHA256_SIZE) {
+                return INetd::PRIVATE_DNS_BAD_FINGERPRINT;
+            }
+        }
+    } else {
+        return INetd::PRIVATE_DNS_UNKNOWN_ALGORITHM;
+    }
+    if (port <= 0 || port > 0xFFFF) {
+        return INetd::PRIVATE_DNS_BAD_PORT;
+    }
+    sockaddr_storage parsed;
+    if (!parseServer(server.c_str(), port, &parsed)) {
+        return INetd::PRIVATE_DNS_BAD_ADDRESS;
+    }
+    PrivateDnsServer privateServer(parsed);
+    privateServer.fingerprints = fingerprints;
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    // Ensure we overwrite any previous matching server.  This is necessary because equality is
+    // based only on the IP address, not the port or fingerprints.
+    privateDnsServers.erase(privateServer);
+    privateDnsServers.insert(privateServer);
+    return INetd::PRIVATE_DNS_SUCCESS;
+}
+
+int ResolverController::removePrivateDnsServer(const std::string& server) {
+    using android::net::INetd;
+    sockaddr_storage parsed;
+    if (!parseServer(server.c_str(), 0, &parsed)) {
+        return INetd::PRIVATE_DNS_BAD_ADDRESS;
+    }
+    std::lock_guard<std::mutex> guard(privateDnsLock);
+    privateDnsServers.erase(parsed);
+    for (auto& pair : privateDnsTransports) {
+        pair.second.erase(parsed);
+    }
+    return INetd::PRIVATE_DNS_SUCCESS;
+}
+
 }  // namespace net
 }  // namespace android
diff --git a/server/ResolverController.h b/server/ResolverController.h
index 3da9ac9..a6a559d 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -39,6 +39,15 @@
     int setDnsServers(unsigned netId, const char* searchDomains, const char** servers,
             int numservers, const __res_params* params);
 
+    // Given a netId and the address of an insecure (i.e. normal) DNS server, this method checks
+    // if there is a known secure DNS server with the same IP address that has been validated as
+    // accessible on this netId.  If so, it returns true, providing the server's address
+    // (including port) and pin fingerprints (possibly empty) in the output parameters.
+    // TODO: Add support for optional stronger security, by returning true even if the secure
+    // server is not accessible.
+    bool shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
+            sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints);
+
     int clearDnsServers(unsigned netid);
 
     int flushDnsCache(unsigned netid);
@@ -56,6 +65,11 @@
             std::vector<std::string>* domains, std::vector<int32_t>* params,
             std::vector<int32_t>* stats);
     void dump(DumpWriter& dw, unsigned netId);
+
+    int addPrivateDnsServer(const std::string& server, int32_t port,
+            const std::string& fingerprintAlgorithm,
+            const std::set<std::vector<uint8_t>>& fingerprints);
+    int removePrivateDnsServer(const std::string& server);
 };
 
 }  // namespace net
diff --git a/server/binder/android/net/INetd.aidl b/server/binder/android/net/INetd.aidl
index e5bf218..3dd5e1a 100644
--- a/server/binder/android/net/INetd.aidl
+++ b/server/binder/android/net/INetd.aidl
@@ -146,6 +146,44 @@
     void getResolverInfo(int netId, out @utf8InCpp String[] servers,
             out @utf8InCpp String[] domains, out int[] params, out int[] stats);
 
+    // Private DNS function error codes.
+    const int PRIVATE_DNS_SUCCESS = 0;
+    const int PRIVATE_DNS_BAD_ADDRESS = 1;
+    const int PRIVATE_DNS_BAD_PORT = 2;
+    const int PRIVATE_DNS_UNKNOWN_ALGORITHM = 3;
+    const int PRIVATE_DNS_BAD_FINGERPRINT = 4;
+
+    /**
+     * Adds a server to the list of DNS resolvers that support DNS over TLS.  After this action
+     * succeeds, any subsequent call to setResolverConfiguration will opportunistically use DNS
+     * over TLS if the specified server is on this list and is reachable on that network.
+     *
+     * @param server the DNS server's IP address.  If a private DNS server is already configured
+     *        with this IP address, it will be overwritten.
+     * @param port the port on which the server is listening, typically 853.
+     * @param fingerprintAlgorithm the hash algorithm used to compute the fingerprints.  This should
+     *        be a name in MessageDigest's format.  Currently "SHA-256" is the only supported
+     *        algorithm. Set this to the empty string to disable fingerprint validation.
+     * @param fingerprints the server's public key fingerprints as Base64 strings.
+     *        These can be generated using MessageDigest and android.util.Base64.encodeToString.
+     *        Currently "SHA-256" is the only supported algorithm. Set this to empty to disable
+     *        fingerprint validation.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *         cause of the the failure.
+     * @return true if the arguments were successfully parsed and recognized.
+     */
+    void addPrivateDnsServer(in @utf8InCpp String server, int port,
+             in @utf8InCpp String fingerprintAlgorithm, in @utf8InCpp String[] fingerprints);
+
+    /**
+     * Remove a server from the list of DNS resolvers that support DNS over TLS.
+     *
+     * @param server the DNS server's IP address.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *         cause of the the failure.
+     */
+    void removePrivateDnsServer(in @utf8InCpp String server);
+
     /**
      * Instruct the tethering DNS server to reevaluated serving interfaces.
      * This is needed to for the DNS server to observe changes in the set
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
new file mode 100644
index 0000000..8d27d20
--- /dev/null
+++ b/server/dns/DnsTlsTransport.cpp
@@ -0,0 +1,435 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "dns/DnsTlsTransport.h"
+
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+#include <errno.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include <stdlib.h>
+
+#define LOG_TAG "DnsTlsTransport"
+#define DBG 0
+
+#include "log/log.h"
+#include "Fwmark.h"
+#undef ADD  // already defined in nameser.h
+#include "NetdConstants.h"
+#include "Permission.h"
+
+
+namespace android {
+namespace net {
+
+namespace {
+
+bool setNonBlocking(int fd, bool enabled) {
+    int flags = fcntl(fd, F_GETFL);
+    if (flags < 0) return false;
+
+    if (enabled) {
+        flags |= O_NONBLOCK;
+    } else {
+        flags &= ~O_NONBLOCK;
+    }
+    return (fcntl(fd, F_SETFL, flags) == 0);
+}
+
+int waitForReading(int fd) {
+    fd_set fds;
+    FD_ZERO(&fds);
+    FD_SET(fd, &fds);
+    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
+    if (DBG && ret <= 0) {
+        ALOGD("select");
+    }
+    return ret;
+}
+
+int waitForWriting(int fd) {
+    fd_set fds;
+    FD_ZERO(&fds);
+    FD_SET(fd, &fds);
+    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
+    if (DBG && ret <= 0) {
+        ALOGD("select");
+    }
+    return ret;
+}
+
+}  // namespace
+
+android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
+    android::base::unique_fd fd;
+    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
+    switch (mProtocol) {
+        case IPPROTO_TCP:
+            type |= SOCK_STREAM;
+            break;
+        default:
+            errno = EPROTONOSUPPORT;
+            return fd;
+    }
+
+    fd.reset(socket(mAddr.ss_family, type, mProtocol));
+    if (fd.get() == -1) {
+        return fd;
+    }
+
+    const socklen_t len = sizeof(mMark);
+    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
+        fd.reset();
+    } else if (connect(fd.get(),
+            reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
+        && errno != EINPROGRESS) {
+        fd.reset();
+    }
+
+    return fd;
+}
+
+bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
+    int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
+    unsigned char spki[spki_len];
+    unsigned char* temp = spki;
+    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
+        ALOGW("SPKI length mismatch");
+        return false;
+    }
+    out->resize(SHA256_SIZE);
+    unsigned int digest_len = 0;
+    int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
+    if (ret != 1) {
+        ALOGW("Server cert digest extraction failed");
+        return false;
+    }
+    if (digest_len != out->size()) {
+        ALOGW("Wrong digest length: %d", digest_len);
+        return false;
+    }
+    return true;
+}
+
+SSL* DnsTlsTransport::sslConnect(int fd) {
+    if (fd < 0) {
+        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+        return nullptr;
+    }
+
+    // Set up TLS context.
+    bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
+    if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
+        !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
+        ALOGD("failed to min/max TLS versions");
+        return nullptr;
+    }
+
+    bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
+    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
+    SSL_set_bio(ssl.get(), bio.get(), bio.get());
+    bio.release();
+
+    if (!setNonBlocking(fd, false)) {
+        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+        return nullptr;
+    }
+
+    for (;;) {
+        if (DBG) {
+            ALOGD("%u Calling SSL_connect", mMark);
+        }
+        int ret = SSL_connect(ssl.get());
+        if (DBG) {
+            ALOGD("%u SSL_connect returned %d", mMark, ret);
+        }
+        if (ret == 1) break;  // SSL handshake complete;
+
+        const int ssl_err = SSL_get_error(ssl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_READ:
+                if (waitForReading(fd) != 1) {
+                    ALOGW("SSL_connect read error");
+                    return nullptr;
+                }
+                break;
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(fd) != 1) {
+                    ALOGW("SSL_connect write error");
+                    return nullptr;
+                }
+                break;
+            default:
+                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
+                return nullptr;
+        }
+    }
+
+    if (!mFingerprints.empty()) {
+        if (DBG) {
+            ALOGD("Checking DNS over TLS fingerprint");
+        }
+        // TODO: Follow the cert chain and check all the way up.
+        bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
+        if (!cert) {
+            ALOGW("Server has null certificate");
+            return nullptr;
+        }
+        std::vector<uint8_t> digest;
+        if (!getSPKIDigest(cert.get(), &digest)) {
+            ALOGE("Digest computation failed");
+            return nullptr;
+        }
+
+        if (mFingerprints.count(digest) == 0) {
+            ALOGW("No matching fingerprint");
+            return nullptr;
+        }
+        if (DBG) {
+            ALOGD("DNS over TLS fingerprint is correct");
+        }
+    }
+
+    if (DBG) {
+        ALOGD("%u handshake complete", mMark);
+    }
+    return ssl.release();
+}
+
+bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
+    if (DBG) {
+        ALOGD("%u Writing %d bytes", mMark, len);
+    }
+    for (;;) {
+        int ret = SSL_write(ssl, buffer, len);
+        if (ret == len) break;  // SSL write complete;
+
+        if (ret < 1) {
+            const int ssl_err = SSL_get_error(ssl, ret);
+            switch (ssl_err) {
+                case SSL_ERROR_WANT_WRITE:
+                    if (waitForWriting(fd) != 1) {
+                        if (DBG) {
+                            ALOGW("SSL_write error");
+                        }
+                        return false;
+                    }
+                    continue;
+                case 0:
+                    break;  // SSL write complete;
+                default:
+                    if (DBG) {
+                        ALOGW("SSL_write error %d", ssl_err);
+                    }
+                    return false;
+            }
+        }
+    }
+    if (DBG) {
+        ALOGD("%u Wrote %d bytes", mMark, len);
+    }
+    return true;
+}
+
+// Read exactly len bytes into buffer or fail
+bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
+    int remaining = len;
+    while (remaining > 0) {
+        int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
+        if (ret == 0) {
+            ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
+            return false;
+        }
+
+        if (ret < 0) {
+            const int ssl_err = SSL_get_error(ssl, ret);
+            if (ssl_err == SSL_ERROR_WANT_READ) {
+                if (waitForReading(fd) != 1) {
+                    if (DBG) {
+                        ALOGW("SSL_read error");
+                    }
+                    return false;
+                }
+                continue;
+            } else {
+                if (DBG) {
+                    ALOGW("SSL_read error %d", ssl_err);
+                }
+                return false;
+            }
+        }
+
+        remaining -= ret;
+    }
+    return true;
+}
+
+DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
+        uint8_t *response, size_t limit, int *resplen) {
+    *resplen = 0;  // Zero indicates an error.
+
+    if (DBG) {
+        ALOGD("%u connecting TCP socket", mMark);
+    }
+    android::base::unique_fd fd(makeConnectedSocket());
+    if (DBG) {
+        ALOGD("%u connecting SSL", mMark);
+    }
+    bssl::UniquePtr<SSL> ssl(sslConnect(fd));
+    if (ssl == nullptr) {
+        if (DBG) {
+            ALOGW("%u SSL connection failed", mMark);
+        }
+        return Response::network_error;
+    }
+
+    uint8_t queryHeader[2];
+    queryHeader[0] = qlen >> 8;
+    queryHeader[1] = qlen;
+    if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
+        return Response::network_error;
+    }
+    if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
+        return Response::network_error;
+    }
+    if (DBG) {
+        ALOGD("%u SSL_write complete", mMark);
+    }
+
+    uint8_t responseHeader[2];
+    if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
+        if (DBG) {
+            ALOGW("%u Failed to read 2-byte length header", mMark);
+        }
+        return Response::network_error;
+    }
+    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
+    if (DBG) {
+        ALOGD("%u Expecting response of size %i", mMark, responseSize);
+    }
+    if (responseSize > limit) {
+        ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
+        return Response::limit_error;
+    }
+    if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
+        if (DBG) {
+            ALOGW("%u Failed to read %i bytes", mMark, responseSize);
+        }
+        return Response::network_error;
+    }
+    if (DBG) {
+        ALOGD("%u SSL_read complete", mMark);
+    }
+
+    if (response[0] != query[0] || response[1] != query[1]) {
+        ALOGE("reply query ID != query ID");
+        return Response::internal_error;
+    }
+
+    SSL_shutdown(ssl.get());
+
+    *resplen = responseSize;
+    return Response::success;
+}
+
+bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
+        const std::set<std::vector<uint8_t>>& fingerprints) {
+    if (DBG) {
+        ALOGD("Beginning validation on %u", netid);
+    }
+    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
+    // order to prove that it is actually a working DNS over TLS server.
+    static const char kDnsSafeChars[] =
+            "abcdefhijklmnopqrstuvwxyz"
+            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
+            "0123456789";
+    const auto c = [](uint8_t rnd) -> uint8_t {
+        return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
+    };
+    uint8_t rnd[8];
+    arc4random_buf(rnd, ARRAY_SIZE(rnd));
+    // We could try to use res_mkquery() here, but it's basically the same.
+    uint8_t query[] = {
+        rnd[6], rnd[7],  // [0-1]   query ID
+        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
+        0, 1,  // [4-5]   QDCOUNT (number of queries)
+        0, 0,  // [6-7]   ANCOUNT (number of answers)
+        0, 0,  // [8-9]   NSCOUNT (number of name server records)
+        0, 0,  // [10-11] ARCOUNT (number of additional records)
+        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
+            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
+        6, 'm', 'e', 't', 'r', 'i', 'c',
+        7, 'g', 's', 't', 'a', 't', 'i', 'c',
+        3, 'c', 'o', 'm',
+        0,  // null terminator of FQDN (root TLD)
+        0, ns_t_aaaa,  // QTYPE
+        0, ns_c_in     // QCLASS
+    };
+    const int qlen = ARRAY_SIZE(query);
+
+    const int kRecvBufSize = 4 * 1024;
+    uint8_t recvbuf[kRecvBufSize];
+
+    // At validation time, we only know the netId, so we have to guess/compute the
+    // corresponding socket mark.
+    Fwmark fwmark;
+    fwmark.permission = PERMISSION_SYSTEM;
+    fwmark.explicitlySelected = true;
+    fwmark.protectedFromVpn = true;
+    fwmark.netId = netid;
+    unsigned mark = fwmark.intValue;
+    DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
+    int replylen = 0;
+    xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
+    if (replylen == 0) {
+        if (DBG) {
+            ALOGD("doQuery failed");
+        }
+        return false;
+    }
+
+    if (replylen < NS_HFIXEDSZ) {
+        if (DBG) {
+            ALOGW("short response: %d", replylen);
+        }
+        return false;
+    }
+
+    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
+    if (qdcount != 1) {
+        ALOGW("reply query count != 1: %d", qdcount);
+        return false;
+    }
+
+    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
+    if (DBG) {
+        ALOGD("%u answer count: %d", netid, ancount);
+    }
+
+    // TODO: Further validate the response contents (check for valid AAAA record, ...).
+    // Note that currently, integration tests rely on this function accepting a
+    // response with zero records.
+#if 0
+    for (int i = 0; i < resplen; i++) {
+        ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
+    }
+#endif
+    return true;
+}
+
+}  // namespace net
+}  // namespace android
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
new file mode 100644
index 0000000..b9e9f7f
--- /dev/null
+++ b/server/dns/DnsTlsTransport.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSTRANSPORT_H
+#define _DNS_DNSTLSTRANSPORT_H
+
+#include <netinet/in.h>
+#include <set>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <vector>
+
+#include "android-base/unique_fd.h"
+
+// Forward declaration.
+typedef struct ssl_st SSL;
+
+namespace android {
+namespace net {
+
+class DnsTlsTransport {
+public:
+    DnsTlsTransport(unsigned mark, int protocol, const sockaddr_storage &ss,
+            const std::set<std::vector<uint8_t>>& fingerprints)
+            : mMark(mark), mProtocol(protocol), mAddr(ss), mFingerprints(fingerprints)
+            {}
+    ~DnsTlsTransport() {}
+
+    enum class Response : uint8_t { success, network_error, limit_error, internal_error };
+
+    // Given a |query| of length |qlen|, sends it to the server and writes the
+    // response into |ans|, which can accept up to |anssiz| bytes.  Indicates
+    // the number of bytes written in |resplen|.  If |resplen| is zero, an
+    // error has occurred.
+    Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+
+private:
+    // On success, returns a non-blocking socket connected to mAddr (the
+    // connection will likely be in progress if mProtocol is IPPROTO_TCP).
+    // On error, returns -1 with errno set appropriately.
+    android::base::unique_fd makeConnectedSocket() const;
+
+    SSL* sslConnect(int fd);
+
+    // Writes a buffer to the socket.
+    bool sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len);
+
+    // Reads exactly the specified number of bytes from the socket.  Blocking.
+    // Returns false if the socket closes before enough bytes can be read.
+    bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
+
+    const unsigned mMark;  // Socket mark
+    const int mProtocol;
+    const sockaddr_storage mAddr;
+    const std::set<std::vector<uint8_t>> mFingerprints;
+};
+
+// Check that a given TLS server (ss) is fully working on the specified netid, and has a
+// provided SHA-256 fingerprint (if nonempty).  This function is used in ResolverController
+// to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
+bool validateDnsTlsServer(unsigned netid, const sockaddr_storage& ss,
+        const std::set<std::vector<uint8_t>>& fingerprints);
+
+}  // namespace net
+}  // namespace android
+
+#endif  // _DNS_DNSTLSTRANSPORT_H