Support RFC 7858 DNS over TLS
This change adds the core capability for DNS over TLS, and creates
private APIs for activating it, but does not provide any way to
activate the functionality in a development environment or on a
real device.
Based on https://android-review.googlesource.com/#/c/373776/
Test: Complete unit+integration tests. Manual tests look good.
Bug: 34953048
Change-Id: Ib99ac1f631fd2c2c8fbf53bdb05f67f8be7713ac
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index bed3785..12d85aa 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -27,6 +27,8 @@
#include <android-base/stringprintf.h>
#include <private/android_filesystem_config.h>
+#include <openssl/base64.h>
+
#include <algorithm>
#include <chrono>
#include <iterator>
@@ -46,6 +48,7 @@
#include "dns_responder.h"
#include "dns_responder_client.h"
+#include "dns_tls_frontend.h"
#include "resolv_params.h"
#include "ResolverStats.h"
@@ -685,3 +688,398 @@
ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
+
+// Base64-encodes |input| with BoringSSL's EVP_EncodeBlock; used to pass certificate
+// fingerprints to addPrivateDnsServer(). Problems are reported as non-fatal test failures,
+// and an empty string is returned if the encoded length cannot be computed.
+static std::string base64Encode(const std::vector<uint8_t>& input) {
+    size_t out_len = 0;
+    if (EVP_EncodedLength(&out_len, input.size()) != 1) {
+        // Do not size a buffer from an unset length.
+        ADD_FAILURE() << "EVP_EncodedLength failed for input size " << input.size();
+        return std::string();
+    }
+    // out_len includes the trailing NUL written by EVP_EncodeBlock. A std::vector is used
+    // instead of a variable-length array, which is a compiler extension, not standard C++.
+    std::vector<uint8_t> output_bytes(out_len);
+    const size_t written = EVP_EncodeBlock(output_bytes.data(), input.data(), input.size());
+    EXPECT_EQ(out_len - 1, written);
+    // Build the string from the reported length rather than relying on the NUL terminator.
+    return std::string(reinterpret_cast<const char*>(output_bytes.data()), written);
+}
+
+// Private DNS is configured for an address where nothing listens on the TLS port; lookups
+// must still be answered by the plain DNSResponder on that address.
+TEST_F(ResolverTest, GetHostByName_TlsMissing) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping("tlsmissing.example.com.", ns_type::ns_t_a, "1.2.3.3");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    // Nothing listens on port 853 here, so validation either fails or hangs. Either way,
+    // queries keep flowing to the DNSResponder.
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "", {});
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    const hostent* const answer = gethostbyname("tlsmissing");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.3", ToString(answer));
+
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    dns.stopServer();
+}
+
+// Test what happens if the specified TLS server replies with garbage.
+TEST_F(ResolverTest, GetHostByName_TlsBroken) {
+    // Closes a descriptor on scope exit, so an ASSERT_* early return cannot leave the fake TLS
+    // listener bound to listen_addr:853 (which later tests in this file also listen on).
+    struct ScopedFd {
+        int fd;
+        ~ScopedFd() {
+            if (fd >= 0) close(fd);
+        }
+    };
+
+    const char* listen_addr = "127.0.0.3";
+    const char* listen_srv = "53";
+    const char* host_name1 = "tlsbroken1.example.com.";
+    const char* host_name2 = "tlsbroken2.example.com.";
+    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
+    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { listen_addr };
+
+    // Bind the specified private DNS socket but don't respond to any client sockets yet.
+    ScopedFd listener{socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)};
+    ASSERT_TRUE(listener.fd >= 0);
+    struct sockaddr_in tlsServer = {
+        .sin_family = AF_INET,
+        .sin_port = htons(853),
+    };
+    // inet_pton returns 1 on success, 0 for an unparsable address and -1 on error.
+    ASSERT_EQ(1, inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
+    ASSERT_FALSE(bind(listener.fd, reinterpret_cast<struct sockaddr*>(&tlsServer),
+                      sizeof(tlsServer)));
+    ASSERT_FALSE(listen(listener.fd, 1));
+
+    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // SetResolversForNetwork should have triggered a validation connection to this address.
+    struct sockaddr_storage cliaddr;
+    socklen_t sin_size = sizeof(cliaddr);
+    ScopedFd conn{accept(listener.fd, reinterpret_cast<struct sockaddr*>(&cliaddr), &sin_size)};
+    // 0 is a valid descriptor; only negative values indicate failure.
+    ASSERT_TRUE(conn.fd >= 0);
+
+    // We've received the new file descriptor but not written to it or closed, so the
+    // validation is still pending. Queries should still flow correctly because the
+    // server is not used until validation succeeds.
+    const hostent* result;
+    result = gethostbyname("tlsbroken1");
+    ASSERT_FALSE(result == nullptr);
+    EXPECT_EQ("1.2.3.1", ToString(result));
+
+    // Now we cause the validation to fail.
+    std::string garbage = "definitely not a valid TLS ServerHello";
+    EXPECT_EQ(static_cast<ssize_t>(garbage.size()),
+              write(conn.fd, garbage.data(), garbage.size()));
+    close(conn.fd);
+    conn.fd = -1;  // Already closed; stop ScopedFd from closing it a second time.
+
+    // Validation failure shouldn't interfere with lookups, because lookups won't be sent
+    // to the TLS server unless validation succeeds.
+    result = gethostbyname("tlsbroken2");
+    ASSERT_FALSE(result == nullptr);
+    EXPECT_EQ("1.2.3.2", ToString(result));
+
+    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    dns.stopServer();
+    // |listener| is closed when it goes out of scope.
+}
+
+// Once the private DNS server validates, lookups go through it. After the TLS frontend stops,
+// lookups fail. After the private DNS server is removed, lookups use the UDP endpoint again.
+TEST_F(ResolverTest, GetHostByName_Tls) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+    const char* const tls_port = "853";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    const struct {
+        const char* name;
+        const char* addr;
+    } mappings[] = {
+        { "tls1.example.com.", "1.2.3.1" },
+        { "tls2.example.com.", "1.2.3.2" },
+        { "tls3.example.com.", "1.2.3.3" },
+    };
+    for (const auto& mapping : mappings) {
+        dns.addMapping(mapping.name, ns_type::ns_t_a, mapping.addr);
+    }
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    test::DnsTlsFrontend tls(server_addr, tls_port, server_addr, udp_port);
+    ASSERT_TRUE(tls.startServer());
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "", {});
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+    const hostent* answer = gethostbyname("tls1");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.1", ToString(answer));
+    // Wait for the lookup above to be counted by the frontend.
+    EXPECT_TRUE(tls.waitForQueries(2, 5000));
+
+    // The server is already validated, so queries keep being routed to it and now fail.
+    tls.stopServer();
+    answer = gethostbyname("tls2");
+    EXPECT_EQ(nullptr, answer);
+    EXPECT_EQ(HOST_NOT_FOUND, h_errno);
+
+    // Dropping the TLS server setting routes queries back to the UDP endpoint.
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    answer = gethostbyname("tls3");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.3", ToString(answer));
+
+    dns.stopServer();
+}
+
+// A SHA-256 fingerprint matching the frontend's certificate lets validation succeed, after
+// which the lookup is counted by the TLS frontend.
+TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+    const char* const tls_port = "853";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping("tlsfingerprint.example.com.", ns_type::ns_t_a, "1.2.3.1");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    test::DnsTlsFrontend tls(server_addr, tls_port, server_addr, udp_port);
+    ASSERT_TRUE(tls.startServer());
+    const std::string pinned = base64Encode(tls.fingerprint());
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "SHA-256", { pinned });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+    const hostent* const answer = gethostbyname("tlsfingerprint");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.1", ToString(answer));
+
+    // Wait for the lookup to be counted by the frontend.
+    EXPECT_TRUE(tls.waitForQueries(2, 5000));
+
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    tls.stopServer();
+    dns.stopServer();
+}
+
+// With a fingerprint that does not match the server certificate, no query reaches the TLS
+// frontend, and the lookup is answered by the DNSResponder instead.
+TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+    const char* const tls_port = "853";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping("badtlsfingerprint.example.com.", ns_type::ns_t_a, "1.2.3.1");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    test::DnsTlsFrontend tls(server_addr, tls_port, server_addr, udp_port);
+    ASSERT_TRUE(tls.startServer());
+    std::vector<uint8_t> wrong_fingerprint = tls.fingerprint();
+    ++wrong_fingerprint[5];  // Alter one byte so the fingerprint no longer matches.
+    const std::string pinned = base64Encode(wrong_fingerprint);
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "SHA-256", { pinned });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // Validation should fail at the fingerprint check, before any query is issued.
+    EXPECT_FALSE(tls.waitForQueries(1, 500));
+
+    const hostent* const answer = gethostbyname("badtlsfingerprint");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.1", ToString(answer));
+
+    // Validation failed, so the lookup must have bypassed the TLS frontend.
+    EXPECT_FALSE(tls.waitForQueries(1, 500));
+
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    tls.stopServer();
+    dns.stopServer();
+}
+
+// Two fingerprints are configured, one wrong and one correct. Validation must succeed because
+// at least one of them matches the server's certificate.
+TEST_F(ResolverTest, GetHostByName_TwoTlsFingerprints) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+    const char* const tls_port = "853";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping("twotlsfingerprints.example.com.", ns_type::ns_t_a, "1.2.3.1");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    test::DnsTlsFrontend tls(server_addr, tls_port, server_addr, udp_port);
+    ASSERT_TRUE(tls.startServer());
+    std::vector<uint8_t> wrong_fingerprint = tls.fingerprint();
+    ++wrong_fingerprint[5];  // Alter one byte so this fingerprint does not match.
+    const std::string wrong_pin = base64Encode(wrong_fingerprint);
+    const std::string right_pin = base64Encode(tls.fingerprint());
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "SHA-256",
+                                            { wrong_pin, right_pin });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+    const hostent* const answer = gethostbyname("twotlsfingerprints");
+    ASSERT_NE(nullptr, answer);
+    EXPECT_EQ("1.2.3.1", ToString(answer));
+
+    // Wait for the lookup to be counted by the frontend.
+    EXPECT_TRUE(tls.waitForQueries(2, 5000));
+
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    tls.stopServer();
+    dns.stopServer();
+}
+
+// After a successful validation with a matching fingerprint, the TLS server is restarted with
+// a new certificate whose fingerprint no longer matches. Lookups routed to it must then fail.
+TEST_F(ResolverTest, GetHostByName_TlsFingerprintGoesBad) {
+    const char* listen_addr = "127.0.0.3";
+    const char* listen_udp = "53";
+    const char* listen_tls = "853";
+    const char* host_name1 = "tlsfingerprintgoesbad1.example.com.";
+    const char* host_name2 = "tlsfingerprintgoesbad2.example.com.";
+    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
+    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { listen_addr };
+
+    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
+    ASSERT_TRUE(tls.startServer());
+    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
+            { base64Encode(tls.fingerprint()) });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    const hostent* result;
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+    result = gethostbyname("tlsfingerprintgoesbad1");
+    ASSERT_FALSE(result == nullptr);
+    EXPECT_EQ("1.2.3.1", ToString(result));
+
+    // Wait for query to get counted.
+    EXPECT_TRUE(tls.waitForQueries(2, 5000));
+
+    // Restart the TLS server. This will generate a new certificate whose fingerprint
+    // no longer matches the stored fingerprint. The restart must succeed. Otherwise the
+    // lookup below would fail simply because nothing is listening, and the test would pass
+    // without exercising the fingerprint mismatch.
+    tls.stopServer();
+    ASSERT_TRUE(tls.startServer());
+
+    result = gethostbyname("tlsfingerprintgoesbad2");
+    ASSERT_TRUE(result == nullptr);
+    EXPECT_EQ(HOST_NOT_FOUND, h_errno);
+
+    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+    tls.stopServer();
+    dns.stopServer();
+}
+
+// Two validated TLS servers are configured. Lookups use the first one. Once it stops,
+// lookups fail over to the second one, without sending any extra queries to the
+// DNSResponders directly.
+TEST_F(ResolverTest, GetHostByName_TlsFailover) {
+    const char* listen_addr1 = "127.0.0.3";
+    const char* listen_addr2 = "127.0.0.4";
+    const char* listen_udp = "53";
+    const char* listen_tls = "853";
+    const char* host_name1 = "tlsfailover1.example.com.";
+    const char* host_name2 = "tlsfailover2.example.com.";
+    test::DNSResponder dns1(listen_addr1, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
+    test::DNSResponder dns2(listen_addr2, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns1.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
+    dns1.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
+    dns2.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.3");
+    dns2.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.4");
+    ASSERT_TRUE(dns1.startServer());
+    ASSERT_TRUE(dns2.startServer());
+    std::vector<std::string> servers = { listen_addr1, listen_addr2 };
+
+    test::DnsTlsFrontend tls1(listen_addr1, listen_tls, listen_addr1, listen_udp);
+    test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
+    ASSERT_TRUE(tls1.startServer());
+    ASSERT_TRUE(tls2.startServer());
+    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr1, 853, "SHA-256",
+            { base64Encode(tls1.fingerprint()) });
+    rv = mNetdSrv->addPrivateDnsServer(listen_addr2, 853, "SHA-256",
+            { base64Encode(tls2.fingerprint()) });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    const hostent* result;
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls1.waitForQueries(1, 5000));
+    EXPECT_TRUE(tls2.waitForQueries(1, 5000));
+
+    result = gethostbyname("tlsfailover1");
+    ASSERT_FALSE(result == nullptr);
+    EXPECT_EQ("1.2.3.1", ToString(result));
+
+    // Wait for query to get counted.
+    EXPECT_TRUE(tls1.waitForQueries(2, 5000));
+    // No new queries should have reached tls2.
+    EXPECT_EQ(1, tls2.queries());
+
+    // Stop tls1. Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
+    tls1.stopServer();
+
+    result = gethostbyname("tlsfailover2");
+    // Check for null before ToString(), as every other lookup here does. A null result means
+    // the failover did not happen.
+    ASSERT_FALSE(result == nullptr);
+    EXPECT_EQ("1.2.3.4", ToString(result));
+
+    // Wait for query to get counted.
+    EXPECT_TRUE(tls2.waitForQueries(2, 5000));
+
+    // No additional queries should have reached the insecure servers.
+    EXPECT_EQ(2U, dns1.queries().size());
+    EXPECT_EQ(2U, dns2.queries().size());
+
+    rv = mNetdSrv->removePrivateDnsServer(listen_addr1);
+    rv = mNetdSrv->removePrivateDnsServer(listen_addr2);
+    tls2.stopServer();
+    dns1.stopServer();
+    dns2.stopServer();
+}
+
+// getaddrinfo() through a validated, fingerprint-checked TLS server. Both the A and the
+// AAAA query are expected to pass through the TLS frontend.
+TEST_F(ResolverTest, GetAddrInfo_Tls) {
+    const char* const server_addr = "127.0.0.3";
+    const char* const udp_port = "53";
+    const char* const tls_port = "853";
+    const char* const fqdn = "addrinfotls.example.com.";
+
+    test::DNSResponder dns(server_addr, udp_port, 250, ns_rcode::ns_r_servfail, 1.0);
+    dns.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.4");
+    dns.addMapping(fqdn, ns_type::ns_t_aaaa, "::1.2.3.4");
+    ASSERT_TRUE(dns.startServer());
+    std::vector<std::string> servers = { server_addr };
+
+    test::DnsTlsFrontend tls(server_addr, tls_port, server_addr, udp_port);
+    ASSERT_TRUE(tls.startServer());
+    const std::string pinned = base64Encode(tls.fingerprint());
+    auto rv = mNetdSrv->addPrivateDnsServer(server_addr, 853, "SHA-256", { pinned });
+    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+
+    // Wait for validation to complete.
+    EXPECT_TRUE(tls.waitForQueries(1, 5000));
+
+    dns.clearQueries();
+    addrinfo* info = nullptr;
+    EXPECT_EQ(0, getaddrinfo("addrinfotls", nullptr, nullptr, &info));
+    EXPECT_LE(1U, GetNumQueries(dns, fqdn));
+
+    // The first result may be either the A or the AAAA answer.
+    const std::string first = ToString(info);
+    EXPECT_TRUE(first == "1.2.3.4" || first == "::1.2.3.4")
+        << ", result_str='" << first << "'";
+    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
+    if (info != nullptr) {
+        freeaddrinfo(info);
+    }
+
+    // Wait for both the A and AAAA queries to be counted.
+    EXPECT_TRUE(tls.waitForQueries(3, 5000));
+
+    rv = mNetdSrv->removePrivateDnsServer(server_addr);
+    tls.stopServer();
+    dns.stopServer();
+}