Merge "Remove support for DNS-over-TLS certificate fingerprints." am: 02eda0692e
am: 843cb17ce3
Change-Id: I792abfa8d2ed7794651b99e78dee2b8dcd4e1296
diff --git a/Android.bp b/Android.bp
index 243a443..059fa86 100644
--- a/Android.bp
+++ b/Android.bp
@@ -70,7 +70,7 @@
// on system ABIs
stl: "libc++_static",
static_libs: [
- "dnsresolver_aidl_interface-V2-ndk_platform",
+ "dnsresolver_aidl_interface-ndk_platform",
"libbase",
"libcrypto",
"libcutils",
@@ -100,6 +100,7 @@
debuggable: {
cppflags: [
"-DRESOLV_ALLOW_VERBOSE_LOGGING=1",
+ "-DRESOLV_INJECT_CA_CERTIFICATE=1",
],
},
},
@@ -171,6 +172,7 @@
"libutils",
],
static_libs: [
+ "dnsresolver_aidl_interface-cpp",
"libgmock",
"libnetd_test_dnsresponder",
"libnetd_test_metrics_listener",
@@ -180,7 +182,6 @@
"libnetdutils",
"netd_aidl_interface-V2-cpp",
"netd_event_listener_interface-V1-cpp",
- "dnsresolver_aidl_interface-V2-cpp",
],
compile_multilib: "both",
sanitize: {
diff --git a/DnsResolverService.cpp b/DnsResolverService.cpp
index f9a66e8..3b83381 100644
--- a/DnsResolverService.cpp
+++ b/DnsResolverService.cpp
@@ -27,8 +27,6 @@
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <netdutils/DumpWriter.h>
-#include <netdutils/NetworkConstants.h> // SHA256_SIZE
-#include <openssl/base64.h>
#include <private/android_filesystem_config.h> // AID_SYSTEM
#include "DnsResolver.h"
@@ -164,33 +162,6 @@
return ::ndk::ScopedAStatus(AStatus_fromExceptionCodeWithMessage(EX_SECURITY, err.c_str()));
}
-namespace {
-
-// Parse a base64 encoded string into a vector of bytes.
-// On failure, return an empty vector.
-static std::vector<uint8_t> parseBase64(const std::string& input) {
- std::vector<uint8_t> decoded;
- size_t out_len;
- if (EVP_DecodedLength(&out_len, input.size()) != 1) {
- return decoded;
- }
- // out_len is now an upper bound on the output length.
- decoded.resize(out_len);
- if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(),
- reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) {
- // Possibly shrink the vector if the actual output was smaller than the bound.
- decoded.resize(out_len);
- } else {
- decoded.clear();
- }
- if (out_len != android::netdutils::SHA256_SIZE) {
- decoded.clear();
- }
- return decoded;
-}
-
-} // namespace
-
::ndk::ScopedAStatus DnsResolverService::setResolverConfiguration(
const ResolverParamsParcel& resolverParams) {
// Locking happens in PrivateDnsConfiguration and res_* functions.
@@ -203,21 +174,9 @@
resolverParams.sampleValiditySeconds, resolverParams.successThreshold,
resolverParams.minSamples, resolverParams.maxSamples,
resolverParams.baseTimeoutMsec, resolverParams.retryCount,
- resolverParams.tlsServers, resolverParams.tlsFingerprints);
+ resolverParams.tlsName, resolverParams.tlsServers);
- std::set<std::vector<uint8_t>> decoded_fingerprints;
- for (const std::string& fingerprint : resolverParams.tlsFingerprints) {
- std::vector<uint8_t> decoded = parseBase64(fingerprint);
- if (decoded.empty()) {
- return ::ndk::ScopedAStatus(AStatus_fromServiceSpecificErrorWithMessage(
- EINVAL, "ResolverController error: bad fingerprint"));
- }
- decoded_fingerprints.emplace(decoded);
- }
-
- int res =
- gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams, decoded_fingerprints);
-
+ int res = gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams);
gResNetdCallbacks.log(entry.returns(res).withAutomaticDuration().toString().c_str());
return statusFromErrcode(res);
diff --git a/DnsTlsServer.cpp b/DnsTlsServer.cpp
index a97c672..2b8e417 100644
--- a/DnsTlsServer.cpp
+++ b/DnsTlsServer.cpp
@@ -88,7 +88,7 @@
namespace android {
namespace net {
-// This comparison ignores ports and fingerprints.
+// This comparison ignores ports and certificates.
bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const {
if (x.ss.ss_family != y.ss.ss_family) {
return x.ss.ss_family < y.ss.ss_family;
@@ -112,7 +112,6 @@
return std::tie(
s.ss,
s.name,
- s.fingerprints,
s.protocol
);
}
@@ -126,7 +125,7 @@
}
bool DnsTlsServer::wasExplicitlyConfigured() const {
- return !name.empty() || !fingerprints.empty();
+ return !name.empty();
}
} // namespace net
diff --git a/DnsTlsServer.h b/DnsTlsServer.h
index 7fc4a35..c3454cb 100644
--- a/DnsTlsServer.h
+++ b/DnsTlsServer.h
@@ -47,16 +47,14 @@
// The server location, including IP and port.
sockaddr_storage ss = {};
- // A set of SHA256 public key fingerprints. If this set is nonempty, the server
- // must present a self-consistent certificate chain that contains a certificate
- // whose public key matches one of these fingerprints. Otherwise, the client will
- // terminate the connection.
- std::set<std::vector<uint8_t>> fingerprints;
-
// The server's hostname. If this string is nonempty, the server must present a
// certificate that indicates this name and has a valid chain to a trusted root CA.
std::string name;
+ // PEM-encoded certificate of the CA that signed the server's certificate.
+ // Only used to inject a temporary test CA for internal tests.
+ std::string certificate;
+
// Placeholder. More protocols might be defined in the future.
int protocol = IPPROTO_TCP;
diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp
index eafbf65..a1068f6 100644
--- a/DnsTlsSocket.cpp
+++ b/DnsTlsSocket.cpp
@@ -39,6 +39,11 @@
#include "netdutils/SocketOption.h"
#include "private/android_filesystem_config.h" // AID_DNS
+// NOTE: Inject CA certificate for internal testing -- do NOT enable in production builds
+#ifndef RESOLV_INJECT_CA_CERTIFICATE
+#define RESOLV_INJECT_CA_CERTIFICATE 0
+#endif
+
namespace android {
using netdutils::enableSockopt;
@@ -51,7 +56,6 @@
namespace {
constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
-constexpr size_t SHA256_SIZE = SHA256_DIGEST_LENGTH;
int waitForReading(int fd) {
struct pollfd fds = { .fd = fd, .events = POLLIN };
@@ -121,31 +125,27 @@
return netdutils::status::ok;
}
-bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
- int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
- unsigned char spki[spki_len];
- unsigned char* temp = spki;
- if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
- LOG(WARNING) << "SPKI length mismatch";
+bool DnsTlsSocket::setTestCaCertificate() {
+ bssl::UniquePtr<BIO> bio(
+ BIO_new_mem_buf(mServer.certificate.data(), mServer.certificate.size()));
+ bssl::UniquePtr<X509> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+ if (!cert) {
+ LOG(ERROR) << "Failed to read cert";
return false;
}
- out->resize(SHA256_SIZE);
- unsigned int digest_len = 0;
- int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr);
- if (ret != 1) {
- LOG(WARNING) << "Server cert digest extraction failed";
- return false;
- }
- if (digest_len != out->size()) {
- LOG(WARNING) << "Wrong digest length: " << digest_len;
+
+ X509_STORE* cert_store = SSL_CTX_get_cert_store(mSslCtx.get());
+ if (!X509_STORE_add_cert(cert_store, cert.get())) {
+ LOG(ERROR) << "Failed to add cert";
return false;
}
return true;
}
+// TODO: Try to use static sSslCtx instead of mSslCtx
bool DnsTlsSocket::initialize() {
- // This method should only be called once, at the beginning, so locking should be
- // unnecessary. This lock only serves to help catch bugs in code that calls this method.
+ // This method is called every time a new SSL connection is created.
+ // This lock only serves to help catch bugs in code that calls this method.
std::lock_guard guard(mLock);
if (mSslCtx) {
// This is a bug in the caller.
@@ -156,12 +156,21 @@
return false;
}
- // Load system CA certs for hostname verification.
+ // Load CA certs for hostname verification: the system CA dir, or an injected test CA when RESOLV_INJECT_CA_CERTIFICATE is enabled.
//
// For discussion of alternative, sustainable approaches see b/71909242.
- if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
- LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir;
- return false;
+ if (RESOLV_INJECT_CA_CERTIFICATE && !mServer.certificate.empty()) {
+ // Test-only: trust the CA from ResolverParamsParcel.caCertificate INSTEAD of kCaCertDir.
+ LOG(WARNING) << "Using injected test CA certificate instead of system CA certs";
+ if (!setTestCaCertificate()) {
+ LOG(ERROR) << "Failed to set test CA certificate";
+ return false;
+ }
+ } else {
+ if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
+ LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir;
+ return false;
+ }
+ }
// Enable TLS false start
@@ -210,8 +219,9 @@
}
if (!mServer.name.empty()) {
+ LOG(VERBOSE) << "Checking DNS over TLS hostname = " << mServer.name;
if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
- LOG(ERROR) << "ailed to set SNI to " << mServer.name;
+ LOG(ERROR) << "Failed to set SNI to " << mServer.name;
return nullptr;
}
X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
@@ -258,56 +268,6 @@
}
}
- // TODO: Call SSL_shutdown before discarding the session if validation fails.
- if (!mServer.fingerprints.empty()) {
- LOG(DEBUG) << "Checking DNS over TLS fingerprint";
-
- // We only care that the chain is internally self-consistent, not that
- // it chains to a trusted root, so we can ignore some kinds of errors.
- // TODO: Add a CA root verification mode that respects these errors.
- int verify_result = SSL_get_verify_result(ssl.get());
- switch (verify_result) {
- case X509_V_OK:
- case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
- case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
- case X509_V_ERR_CERT_UNTRUSTED:
- break;
- default:
- LOG(WARNING) << "Invalid certificate chain, error " << verify_result;
- return nullptr;
- }
-
- STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
- if (!chain) {
- LOG(WARNING) << "Server has null certificate";
- return nullptr;
- }
- // Chain and its contents are owned by ssl, so we don't need to free explicitly.
- bool matched = false;
- for (size_t i = 0; i < sk_X509_num(chain); ++i) {
- // This appears to be O(N^2), but there doesn't seem to be a straightforward
- // way to walk a STACK_OF nondestructively in linear time.
- X509* cert = sk_X509_value(chain, i);
- std::vector<uint8_t> digest;
- if (!getSPKIDigest(cert, &digest)) {
- LOG(ERROR) << "Digest computation failed";
- return nullptr;
- }
-
- if (mServer.fingerprints.count(digest) > 0) {
- matched = true;
- break;
- }
- }
-
- if (!matched) {
- LOG(WARNING) << "No matching fingerprint";
- return nullptr;
- }
-
- LOG(DEBUG) << "DNS over TLS fingerprint is correct";
- }
-
LOG(DEBUG) << mMark << " handshake complete";
return ssl;
diff --git a/DnsTlsSocket.h b/DnsTlsSocket.h
index 2940500..1439ea9 100644
--- a/DnsTlsSocket.h
+++ b/DnsTlsSocket.h
@@ -96,6 +96,9 @@
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
bool readResponse() REQUIRES(mLock);
+ // Only used by internal DNS-over-TLS tests.
+ bool setTestCaCertificate() REQUIRES(mLock);
+
// Similar to query(), this function uses incrementEventFd to send a message to the
// loop thread. However, instead of incrementing the counter by one (indicating a
// new query), it wraps the counter to negative, which we use to indicate a shutdown
diff --git a/DnsTlsTransport.h b/DnsTlsTransport.h
index 6c98fa6..3e43c7e 100644
--- a/DnsTlsTransport.h
+++ b/DnsTlsTransport.h
@@ -52,9 +52,9 @@
// Given a |query|, this method sends it to the server and returns the result asynchronously.
std::future<Result> query(const netdutils::Slice query) EXCLUDES(mLock);
- // Check that a given TLS server is fully working on the specified netid, and has the
- // provided SHA-256 fingerprint (if nonempty). This function is used in ResolverController
- // to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
+ // Check that a given TLS server is fully working on the specified netid.
+ // This function is used in ResolverController to ensure that we don't enable DNS over TLS
+ // on networks where it doesn't actually work.
static bool validate(const DnsTlsServer& server, unsigned netid, uint32_t mark);
// Implement IDnsTlsSocketObserver
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index 8b5df4b..e974d50 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -54,11 +54,9 @@
int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
const std::vector<std::string>& servers, const std::string& name,
- const std::set<std::vector<uint8_t>>& fingerprints) {
+ const std::string& caCert) {
LOG(VERBOSE) << "PrivateDnsConfiguration::set(" << netId << ", " << mark << ", "
- << servers.size() << ", " << name << ", " << fingerprints.size() << ")";
-
- const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
+ << servers.size() << ", " << name << ")";
// Parse the list of servers that has been passed in
std::set<DnsTlsServer> tlsServers;
@@ -69,12 +67,12 @@
}
DnsTlsServer server(parsed);
server.name = name;
- server.fingerprints = fingerprints;
+ server.certificate = caCert;
tlsServers.insert(server);
}
std::lock_guard guard(mPrivateDnsLock);
- if (explicitlyConfigured) {
+ if (!name.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
} else if (!tlsServers.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index 50fb54d..d52db77 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -54,7 +54,7 @@
class PrivateDnsConfiguration {
public:
int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
- const std::string& name, const std::set<std::vector<uint8_t>>& fingerprints);
+ const std::string& name, const std::string& caCert);
PrivateDnsStatus getStatus(unsigned netId);
diff --git a/ResolverController.cpp b/ResolverController.cpp
index 8296826..0128a12 100644
--- a/ResolverController.cpp
+++ b/ResolverController.cpp
@@ -195,9 +195,7 @@
return resolv_create_cache_for_net(netId);
}
-int ResolverController::setResolverConfiguration(
- const ResolverParamsParcel& resolverParams,
- const std::set<std::vector<uint8_t>>& tlsFingerprints) {
+int ResolverController::setResolverConfiguration(const ResolverParamsParcel& resolverParams) {
using aidl::android::net::IDnsResolver;
// At private DNS validation time, we only know the netId, so we have to guess/compute the
@@ -213,8 +211,10 @@
if (tlsServers.size() > MAXNS) {
tlsServers.resize(MAXNS);
}
- const int err = gPrivateDnsConfiguration.set(resolverParams.netId, fwmark.intValue, tlsServers,
- resolverParams.tlsName, tlsFingerprints);
+ const int err =
+ gPrivateDnsConfiguration.set(resolverParams.netId, fwmark.intValue, tlsServers,
+ resolverParams.tlsName, resolverParams.caCertificate);
+
if (err != 0) {
return err;
}
diff --git a/ResolverController.h b/ResolverController.h
index 4649225..a4ea477 100644
--- a/ResolverController.h
+++ b/ResolverController.h
@@ -43,8 +43,7 @@
// Binder specific functions, which convert between the ResolverParamsParcel and the
// actual data structures, and call setDnsServer() / getDnsInfo() for the actual processing.
- int setResolverConfiguration(const aidl::android::net::ResolverParamsParcel& resolverParams,
- const std::set<std::vector<uint8_t>>& tlsFingerprints);
+ int setResolverConfiguration(const aidl::android::net::ResolverParamsParcel& resolverParams);
int getResolverInfo(int32_t netId, std::vector<std::string>* servers,
std::vector<std::string>* domains, std::vector<std::string>* tlsServers,
diff --git a/binder/android/net/ResolverParamsParcel.aidl b/binder/android/net/ResolverParamsParcel.aidl
index d25880f..4f266d3 100644
--- a/binder/android/net/ResolverParamsParcel.aidl
+++ b/binder/android/net/ResolverParamsParcel.aidl
@@ -81,6 +81,12 @@
* An array containing TLS public key fingerprints (pins) of which each server must match
* at least one, or empty if there are no pinned keys.
*/
- // DEPRECATED: remove tlsFingerprints in new code
+ // DEPRECATED: no longer used.
@utf8InCpp String[] tlsFingerprints;
+
+ /**
+ * PEM-encoded certificate of the CA that signed the server certificate; only used by DNS-over-TLS tests.
+ * Only honored by resolver builds with RESOLV_INJECT_CA_CERTIFICATE enabled.
+ */
+ @utf8InCpp String caCertificate;
}
diff --git a/dns_tls_test.cpp b/dns_tls_test.cpp
index bec2732..72e6897 100644
--- a/dns_tls_test.cpp
+++ b/dns_tls_test.cpp
@@ -62,9 +62,6 @@
LOG(ERROR) << "Failed to parse server address: " << server;
}
-bytevec FINGERPRINT1 = { 1 };
-bytevec FINGERPRINT2 = { 2 };
-
std::string SERVERNAME1 = "dns.example.com";
std::string SERVERNAME2 = "dns.example.org";
@@ -78,7 +75,6 @@
parseServer("2001:db8::2", 853, &V6ADDR2);
SERVER1 = DnsTlsServer(V4ADDR1);
- SERVER1.fingerprints.insert(FINGERPRINT1);
SERVER1.name = SERVERNAME1;
}
@@ -795,29 +791,6 @@
EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
-TEST_F(ServerTest, Fingerprint) {
- DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
-
- s1.fingerprints.insert(FINGERPRINT1);
- checkUnequal(s1, s2);
- EXPECT_TRUE(isAddressEqual(s1, s2));
-
- s2.fingerprints.insert(FINGERPRINT2);
- checkUnequal(s1, s2);
- EXPECT_TRUE(isAddressEqual(s1, s2));
-
- s2.fingerprints.insert(FINGERPRINT1);
- checkUnequal(s1, s2);
- EXPECT_TRUE(isAddressEqual(s1, s2));
-
- s1.fingerprints.insert(FINGERPRINT2);
- EXPECT_EQ(s1, s2);
- EXPECT_TRUE(isAddressEqual(s1, s2));
-
- EXPECT_TRUE(s1.wasExplicitlyConfigured());
- EXPECT_TRUE(s2.wasExplicitlyConfigured());
-}
-
TEST(QueryMapTest, Basic) {
DnsTlsQueryMap map;
diff --git a/dnsresolver_binder_test.cpp b/dnsresolver_binder_test.cpp
index a77e6ec..fcd4918 100644
--- a/dnsresolver_binder_test.cpp
+++ b/dnsresolver_binder_test.cpp
@@ -29,9 +29,7 @@
#include <binder/IServiceManager.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
-#include <netdutils/NetworkConstants.h> // SHA256_SIZE
#include <netdutils/Stopwatch.h>
-#include <openssl/base64.h>
#include "tests/dns_metrics_listener/base_metrics_listener.h"
#include "tests/dns_metrics_listener/test_metrics.h"
@@ -95,22 +93,12 @@
namespace {
-std::string base64Encode(const std::vector<uint8_t>& input) {
- size_t out_len;
- EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));
- // out_len includes the trailing NULL.
- uint8_t output_bytes[out_len];
- EXPECT_EQ(out_len - 1, EVP_EncodeBlock(output_bytes, input.data(), input.size()));
- return std::string(reinterpret_cast<char*>(output_bytes));
-}
-
// TODO: Convert tests to ResolverParamsParcel and delete this stub.
ResolverParamsParcel makeResolverParamsParcel(int netId, const std::vector<int>& params,
const std::vector<std::string>& servers,
const std::vector<std::string>& domains,
const std::string& tlsHostname,
- const std::vector<std::string>& tlsServers,
- const std::vector<std::string>& tlsFingerprints) {
+ const std::vector<std::string>& tlsServers) {
using android::net::IDnsResolver;
ResolverParamsParcel paramsParcel;
@@ -133,7 +121,7 @@
paramsParcel.domains = domains;
paramsParcel.tlsName = tlsHostname;
paramsParcel.tlsServers = tlsServers;
- paramsParcel.tlsFingerprints = tlsFingerprints;
+ paramsParcel.tlsFingerprints = {};
return paramsParcel;
}
@@ -250,43 +238,47 @@
dnsClient.TearDown();
}
+// TODO: Add test cases with more than one server.
TEST_F(DnsResolverBinderTest, SetResolverConfiguration_Tls) {
const std::vector<std::string> LOCALLY_ASSIGNED_DNS{"8.8.8.8", "2001:4860:4860::8888"};
- std::vector<uint8_t> fp(android::netdutils::SHA256_SIZE);
- std::vector<uint8_t> short_fp(1);
- std::vector<uint8_t> long_fp(android::netdutils::SHA256_SIZE + 1);
- std::vector<std::string> test_domains;
+ static const std::vector<std::string> valid_v4_addr = {"192.0.2.1"};
+ static const std::vector<std::string> valid_v6_addr = {"2001:db8::2"};
+ static const std::vector<std::string> invalid_v4_addr = {"192.0.*.5"};
+ static const std::vector<std::string> invalid_v6_addr = {"2001:dg8::6"};
+ constexpr char valid_tls_name[] = "example.com";
std::vector<int> test_params = {300, 25, 8, 8};
+ // Enumerate valid and invalid v4/v6 addresses combined with several different TLS names
+ // as input data, and verify the returned binder status.
static const struct TestData {
const std::vector<std::string> servers;
const std::string tlsName;
- const std::vector<std::vector<uint8_t>> tlsFingerprints;
const int expectedReturnCode;
} kTlsTestData[] = {
- {{"192.0.2.1"}, "", {}, 0},
- {{"2001:db8::2"}, "host.name", {}, 0},
- {{"192.0.2.3"}, "@@@@", {fp}, 0},
- {{"2001:db8::4"}, "", {fp}, 0},
- {{}, "", {}, 0},
- {{""}, "", {}, EINVAL},
- {{"192.0.*.5"}, "", {}, EINVAL},
- {{"2001:dg8::6"}, "", {}, EINVAL},
- {{"2001:db8::c"}, "", {short_fp}, EINVAL},
- {{"192.0.2.12"}, "", {long_fp}, EINVAL},
- {{"2001:db8::e"}, "", {fp, fp, fp}, 0},
- {{"192.0.2.14"}, "", {fp, short_fp}, EINVAL},
+ {valid_v4_addr, valid_tls_name, 0},
+ {valid_v4_addr, "host.com", 0},
+ {valid_v4_addr, "@@@@", 0},
+ {valid_v4_addr, "", 0},
+ {valid_v6_addr, valid_tls_name, 0},
+ {valid_v6_addr, "host.com", 0},
+ {valid_v6_addr, "@@@@", 0},
+ {valid_v6_addr, "", 0},
+ {invalid_v4_addr, valid_tls_name, EINVAL},
+ {invalid_v4_addr, "host.com", EINVAL},
+ {invalid_v4_addr, "@@@@", EINVAL},
+ {invalid_v4_addr, "", EINVAL},
+ {invalid_v6_addr, valid_tls_name, EINVAL},
+ {invalid_v6_addr, "host.com", EINVAL},
+ {invalid_v6_addr, "@@@@", EINVAL},
+ {invalid_v6_addr, "", EINVAL},
+ {{}, "", 0},
+ {{""}, "", EINVAL},
};
for (size_t i = 0; i < std::size(kTlsTestData); i++) {
const auto& td = kTlsTestData[i];
- std::vector<std::string> fingerprints;
- for (const auto& fingerprint : td.tlsFingerprints) {
- fingerprints.push_back(base64Encode(fingerprint));
- }
- const auto resolverParams =
- makeResolverParamsParcel(TEST_NETID, test_params, LOCALLY_ASSIGNED_DNS,
- test_domains, td.tlsName, td.servers, fingerprints);
+ const auto resolverParams = makeResolverParamsParcel(
+ TEST_NETID, test_params, LOCALLY_ASSIGNED_DNS, {}, td.tlsName, td.servers);
binder::Status status = mDnsResolver->setResolverConfiguration(resolverParams);
if (td.expectedReturnCode == 0) {
@@ -312,7 +304,7 @@
3, // retry count
};
const auto resolverParams =
- makeResolverParamsParcel(TEST_NETID, testParams, servers, domains, "", {}, {});
+ makeResolverParamsParcel(TEST_NETID, testParams, servers, domains, "", {});
binder::Status status = mDnsResolver->setResolverConfiguration(resolverParams);
EXPECT_TRUE(status.isOk()) << status.exceptionMessage();
diff --git a/resolver_test.cpp b/resolver_test.cpp
index 1007045..d7cd263 100644
--- a/resolver_test.cpp
+++ b/resolver_test.cpp
@@ -71,6 +71,11 @@
// Use maximum reserved appId for applications to avoid conflict with existing uids.
static const int TEST_UID = 99999;
+// Currently the hostname of the TLS server must match the CN field of the server's certificate.
+// DNS-over-TLS tests inject a test CA whose hostname is "example.com".
+static const std::string kDefaultPrivateDnsHostName = "example.com";
+static const std::string kDefaultIncorrectPrivateDnsHostName = "www.example.com";
+
// Semi-public Bionic hook used by the NDK (frameworks/base/native/android/net.c)
// Tested here for convenience.
extern "C" int android_getaddrinfofornet(const char* hostname, const char* servname,
@@ -1048,20 +1053,10 @@
testing::ElementsAreArray(res_domains2));
}
-static std::string base64Encode(const std::vector<uint8_t>& input) {
- size_t out_len;
- EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));
- // out_len includes the trailing NULL.
- uint8_t output_bytes[out_len];
- EXPECT_EQ(out_len - 1, EVP_EncodeBlock(output_bytes, input.data(), input.size()));
- return std::string(reinterpret_cast<char*>(output_bytes));
-}
-
// If we move this function to dns_responder_client, it will complicate the dependency need of
// dns_tls_frontend.h.
static void setupTlsServers(const std::vector<std::string>& servers,
- std::vector<std::unique_ptr<test::DnsTlsFrontend>>* tls,
- std::vector<std::string>* fingerprints) {
+ std::vector<std::unique_ptr<test::DnsTlsFrontend>>* tls) {
constexpr char listen_udp[] = "53";
constexpr char listen_tls[] = "853";
@@ -1069,7 +1064,6 @@
auto t = std::make_unique<test::DnsTlsFrontend>(server, listen_tls, server, listen_udp);
t = std::make_unique<test::DnsTlsFrontend>(server, listen_tls, server, listen_udp);
t->startServer();
- fingerprints->push_back(base64Encode(t->fingerprint()));
tls->push_back(std::move(t));
}
}
@@ -1079,7 +1073,6 @@
std::vector<std::unique_ptr<test::DNSResponder>> dns;
std::vector<std::unique_ptr<test::DnsTlsFrontend>> tls;
std::vector<std::string> servers;
- std::vector<std::string> fingerprints;
std::vector<DnsResponderClient::Mapping> mappings;
for (unsigned i = 0; i < MAXDNSRCH + 1; i++) {
@@ -1087,9 +1080,10 @@
}
ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupMappings(1, domains, &mappings));
ASSERT_NO_FATAL_FAILURE(mDnsClient.SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));
- ASSERT_NO_FATAL_FAILURE(setupTlsServers(servers, &tls, &fingerprints));
+ ASSERT_NO_FATAL_FAILURE(setupTlsServers(servers, &tls));
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, domains, kDefaultParams, "", fingerprints));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, domains, kDefaultParams,
+ kDefaultPrivateDnsHostName));
// If the private DNS validation hasn't completed yet before backend DNS servers stop,
// TLS servers will get stuck in handleOneRequest(), which causes this test stuck in
@@ -1177,8 +1171,7 @@
// There's nothing listening on this address, so validation will either fail or
/// hang. Either way, queries will continue to flow to the DNSResponder.
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
const hostent* result;
@@ -1218,8 +1211,7 @@
ASSERT_FALSE(listen(s, 1));
// Trigger TLS validation.
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
struct sockaddr_storage cliaddr;
socklen_t sin_size = sizeof(cliaddr);
@@ -1269,8 +1261,7 @@
test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
const hostent* result;
@@ -1306,152 +1297,6 @@
EXPECT_EQ("1.2.3.3", ToString(result));
}
-TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_udp[] = "53";
- constexpr char listen_tls[] = "853";
- test::DNSResponder dns;
- ASSERT_TRUE(dns.startServer());
- for (int chain_length = 1; chain_length <= 3; ++chain_length) {
- std::string host_name = StringPrintf("tlsfingerprint%d.example.com.", chain_length);
- dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
- std::vector<std::string> servers = { listen_addr };
-
- test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
- tls.set_chain_length(chain_length);
- ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
- "", {base64Encode(tls.fingerprint())}));
-
- const hostent* result;
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
- result = gethostbyname(StringPrintf("tlsfingerprint%d", chain_length).c_str());
- EXPECT_FALSE(result == nullptr);
- if (result) {
- EXPECT_EQ("1.2.3.1", ToString(result));
-
- // Wait for query to get counted.
- EXPECT_TRUE(tls.waitForQueries(2, 5000));
- }
-
- // Clear TLS bit to ensure revalidation.
- ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
- tls.stopServer();
- }
-}
-
-TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_udp[] = "53";
- constexpr char listen_tls[] = "853";
- constexpr char host_name[] = "badtlsfingerprint.example.com.";
-
- test::DNSResponder dns;
- StartDns(dns, {{host_name, ns_type::ns_t_a, "1.2.3.1"}});
- std::vector<std::string> servers = { listen_addr };
-
- test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
- ASSERT_TRUE(tls.startServer());
- std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
- bad_fingerprint[5] += 1; // Corrupt the fingerprint.
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(bad_fingerprint)}));
-
- // The initial validation should fail at the fingerprint check before
- // issuing a query.
- EXPECT_FALSE(tls.waitForQueries(1, 500));
-
- // A fingerprint was provided and failed to match, so the query should fail.
- EXPECT_EQ(nullptr, gethostbyname("badtlsfingerprint"));
-
- // Clear TLS bit.
- ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
-}
-
-// Test that we can pass two different fingerprints, and connection succeeds as long as
-// at least one of them matches the server.
-TEST_F(ResolverTest, GetHostByName_TwoTlsFingerprints) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_udp[] = "53";
- constexpr char listen_tls[] = "853";
- constexpr char host_name[] = "twotlsfingerprints.example.com.";
-
- test::DNSResponder dns;
- StartDns(dns, {{host_name, ns_type::ns_t_a, "1.2.3.1"}});
- std::vector<std::string> servers = { listen_addr };
-
- test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
- ASSERT_TRUE(tls.startServer());
- std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
- bad_fingerprint[5] += 1; // Corrupt the fingerprint.
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(
- servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(bad_fingerprint), base64Encode(tls.fingerprint())}));
-
- const hostent* result;
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
- result = gethostbyname("twotlsfingerprints");
- ASSERT_FALSE(result == nullptr);
- EXPECT_EQ("1.2.3.1", ToString(result));
-
- // Wait for query to get counted.
- EXPECT_TRUE(tls.waitForQueries(2, 5000));
-
- // Clear TLS bit.
- ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
-}
-
-TEST_F(ResolverTest, GetHostByName_TlsFingerprintGoesBad) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_udp[] = "53";
- constexpr char listen_tls[] = "853";
- constexpr char host_name1[] = "tlsfingerprintgoesbad1.example.com.";
- constexpr char host_name2[] = "tlsfingerprintgoesbad2.example.com.";
- const std::vector<DnsRecord> records = {
- {host_name1, ns_type::ns_t_a, "1.2.3.1"},
- {host_name2, ns_type::ns_t_a, "1.2.3.2"},
- };
-
- test::DNSResponder dns;
- StartDns(dns, records);
- std::vector<std::string> servers = { listen_addr };
-
- test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
- ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(tls.fingerprint())}));
-
- const hostent* result;
-
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
-
- result = gethostbyname("tlsfingerprintgoesbad1");
- ASSERT_FALSE(result == nullptr);
- EXPECT_EQ("1.2.3.1", ToString(result));
-
- // Wait for query to get counted.
- EXPECT_TRUE(tls.waitForQueries(2, 5000));
-
- // Restart the TLS server. This will generate a new certificate whose fingerprint
- // no longer matches the stored fingerprint.
- tls.stopServer();
- tls.startServer();
-
- result = gethostbyname("tlsfingerprintgoesbad2");
- ASSERT_TRUE(result == nullptr);
- EXPECT_EQ(HOST_NOT_FOUND, h_errno);
-
- // Clear TLS bit.
- ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
-}
-
TEST_F(ResolverTest, GetHostByName_TlsFailover) {
constexpr char listen_addr1[] = "127.0.0.3";
constexpr char listen_addr2[] = "127.0.0.4";
@@ -1479,9 +1324,8 @@
test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
ASSERT_TRUE(tls1.startServer());
ASSERT_TRUE(tls2.startServer());
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(
- servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(tls1.fingerprint()), base64Encode(tls2.fingerprint())}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
+ kDefaultPrivateDnsHostName));
const hostent* result;
@@ -1528,10 +1372,10 @@
test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
ASSERT_TRUE(tls.startServer());
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
- "www.example.com", {}));
+ kDefaultIncorrectPrivateDnsHostName));
- // The TLS server's certificate doesn't chain to a known CA, and a nonempty name was specified,
- // so the client should fail the TLS handshake before ever issuing a query.
+ // The TLS handshake should fail because the private DNS hostname doesn't
+ // match the name in the TLS server's certificate.
EXPECT_FALSE(tls.waitForQueries(1, 500));
// The query should fail hard, because a name was specified.
@@ -1557,8 +1401,8 @@
test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(tls.fingerprint())}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
+ kDefaultPrivateDnsHostName));
// Wait for validation to complete.
EXPECT_TRUE(tls.waitForQueries(1, 5000));
@@ -1590,8 +1434,6 @@
const unsigned BYPASS_NETID = NETID_USE_LOCAL_NAMESERVERS | TEST_NETID;
- const std::vector<uint8_t> NOOP_FINGERPRINT(android::netdutils::SHA256_SIZE, 0U);
-
const char ADDR4[] = "192.0.2.1";
const char ADDR6[] = "2001:db8::1";
@@ -1662,17 +1504,12 @@
kDefaultParams));
} else if (config.mode == OPPORTUNISTIC) {
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
- kDefaultParams, "", {}));
+ kDefaultParams, ""));
// Wait for validation to complete.
if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
} else if (config.mode == STRICT) {
- // We use the existence of fingerprints to trigger strict mode,
- // rather than hostname validation.
- const auto& fingerprint =
- (config.withWorkingTLS) ? tls.fingerprint() : NOOP_FINGERPRINT;
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
- kDefaultParams, "",
- {base64Encode(fingerprint)}));
+ kDefaultParams, kDefaultPrivateDnsHostName));
// Wait for validation to complete.
if (config.withWorkingTLS) EXPECT_TRUE(tls.waitForQueries(1, 5000));
}
@@ -1724,7 +1561,6 @@
}
TEST_F(ResolverTest, StrictMode_NoTlsServers) {
- const std::vector<uint8_t> NOOP_FINGERPRINT(android::netdutils::SHA256_SIZE, 0U);
constexpr char cleartext_addr[] = "127.0.0.53";
const std::vector<std::string> servers = { cleartext_addr };
constexpr char host_name[] = "strictmode.notlsips.example.com.";
@@ -1736,8 +1572,8 @@
test::DNSResponder dns(cleartext_addr);
StartDns(dns, records);
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, {},
- "", {base64Encode(NOOP_FINGERPRINT)}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
+ kDefaultIncorrectPrivateDnsHostName));
addrinfo* ai_result = nullptr;
EXPECT_NE(0, getaddrinfo(host_name, nullptr, nullptr, &ai_result));
@@ -2309,7 +2145,6 @@
const char STRICT[] = "strict";
const char GETHOSTBYNAME[] = "gethostbyname";
const char GETADDRINFO[] = "getaddrinfo";
- const std::vector<uint8_t> NOOP_FINGERPRINT(android::netdutils::SHA256_SIZE, 0U);
const char ADDR4[] = "192.0.2.1";
const char CLEARTEXT_ADDR[] = "127.0.0.53";
const char CLEARTEXT_PORT[] = "53";
@@ -2395,13 +2230,13 @@
ASSERT_TRUE(tls.stopServer());
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
- kDefaultParams, "", {}));
+ kDefaultParams, ""));
} else if (config.mode == OPPORTUNISTIC_TLS) {
if (!tls.running()) {
ASSERT_TRUE(tls.startServer());
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
- kDefaultParams, "", {}));
+ kDefaultParams, ""));
// Wait for validation to complete.
EXPECT_TRUE(tls.waitForQueries(1, 5000));
} else if (config.mode == STRICT) {
@@ -2409,8 +2244,7 @@
ASSERT_TRUE(tls.startServer());
}
ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains,
- kDefaultParams, "",
- {base64Encode(tls.fingerprint())}));
+ kDefaultParams, kDefaultPrivateDnsHostName));
// Wait for validation to complete.
EXPECT_TRUE(tls.waitForQueries(1, 5000));
}
@@ -2468,8 +2302,7 @@
dns.setEdns(test::DNSResponder::Edns::FORMERR_ON_EDNS);
test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
// Wait for validation complete.
EXPECT_TRUE(tls.waitForQueries(1, 5000));
// Shutdown TLS server to get an error. It's similar to no response case but without waiting.
@@ -2500,8 +2333,7 @@
ASSERT_TRUE(dns.startServer());
test::DnsTlsFrontend tls(CLEARTEXT_ADDR, TLS_PORT, CLEARTEXT_ADDR, CLEARTEXT_PORT);
ASSERT_TRUE(tls.startServer());
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
// Wait for validation complete.
EXPECT_TRUE(tls.waitForQueries(1, 5000));
// Shutdown TLS server to get an error. It's similar to no response case but without waiting.
@@ -3367,8 +3199,7 @@
ASSERT_TRUE(tls.startServer());
// Setup OPPORTUNISTIC mode and wait for the validation complete.
- ASSERT_TRUE(
- mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "", {}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, ""));
EXPECT_TRUE(tls.waitForQueries(1, 5000));
tls.clearQueries();
@@ -3386,8 +3217,8 @@
dns.clearQueries();
// Setup STRICT mode and wait for the validation complete.
- ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams, "",
- {base64Encode(tls.fingerprint())}));
+ ASSERT_TRUE(mDnsClient.SetResolversWithTls(servers, kDefaultSearchDomains, kDefaultParams,
+ kDefaultPrivateDnsHostName));
EXPECT_TRUE(tls.waitForQueries(1, 5000));
tls.clearQueries();
diff --git a/tests/dns_responder/Android.bp b/tests/dns_responder/Android.bp
index 3c90eae..0f07088 100644
--- a/tests/dns_responder/Android.bp
+++ b/tests/dns_responder/Android.bp
@@ -2,13 +2,12 @@
name: "libnetd_test_dnsresponder",
defaults: ["netd_defaults"],
shared_libs: [
+ "dnsresolver_aidl_interface-cpp",
"libbase",
"libbinder",
"libnetd_client",
"libnetdutils",
"libssl",
- "libutils",
- "dnsresolver_aidl_interface-V2-cpp",
"netd_aidl_interface-V2-cpp",
],
srcs: [
diff --git a/tests/dns_responder/dns_responder_client.cpp b/tests/dns_responder/dns_responder_client.cpp
index cfa7d7e..75d0e1d 100644
--- a/tests/dns_responder/dns_responder_client.cpp
+++ b/tests/dns_responder/dns_responder_client.cpp
@@ -35,6 +35,27 @@
using android::net::INetd;
using android::net::ResolverParamsParcel;
+static const char kCaCert[] = R"(
+-----BEGIN CERTIFICATE-----
+MIIC4TCCAcmgAwIBAgIUQUHZnWhL6M4qcS+I0lLkMyqf3VMwDQYJKoZIhvcNAQEL
+BQAwADAeFw0xOTA2MTAwODM3MzlaFw0yOTA2MDcwODM3MzlaMAAwggEiMA0GCSqG
+SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCapRbBg6dRT4id4DxmlyktomE8gpm4W+VA
+ZOivhKat4CvGfVjVIAUYxV7LOGREkkT8Qhn5/gU0lShsnURzEDWY+IjMDDw+kRAm
+iFAlMRnCobTp/tenseNRB2tDuUhkRbzaT6qaidPbKy099p909gxf4YqsgY2NfsY2
+JkideqIkVq2hvLehsu3BgiK06TGUgxOHfj74vx7iGyujq1v38J1hlox5vj/svJF6
+jVdDw8p2UkJbO2N9u3al0nNSMG+MCgd3cvKUySTnjedYXsYB0WyH/JZn//KDq6o+
+as6eQVHuH1fu+8XPzBNATlkHzy+YAs7T+UWbkoa1F8wIElVQg66lAgMBAAGjUzBR
+MB0GA1UdDgQWBBShu/e54D3VdqdLOWo9Ou5hbjaIojAfBgNVHSMEGDAWgBShu/e5
+4D3VdqdLOWo9Ou5hbjaIojAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA
+A4IBAQBFkEGqqzzdQlhP5m1kzh+SiUCwekzSump0CSk0JAXAdeLNbWs3H+pE1/hM
+Fx7oFonoX5O6pi68JfcIP0u7wNuZkKqubUl4gG6aHDkAad2oeTov0Be7EKt8Ekwf
+tmFWVQQrF9otlG3Stn4vmE5zVNGQXDgRsNKPekSo0XJz757q5WgblauB71Rekvio
+TCUXXt3jf3SuovcUFjgBkaohikBRbLiPWZrW4y0XUsgBKI6sLtiSZOPiNevY2xAR
+y7mCSmi4wP7vtUQ5G8znkAMKoo0FzyfjSogGQeREUM8Oe9Mmh/D39sq/D4TsiAxt
+Pwl59DlzlHHJhmOL+SCGciBX4X7p
+-----END CERTIFICATE-----
+)";
+
void DnsResponderClient::SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
std::vector<Mapping>* mappings) {
mappings->resize(num_hosts * domains.size());
@@ -52,11 +73,12 @@
// TODO: Use SetResolverConfiguration() with ResolverParamsParcel struct directly.
// DEPRECATED: Use SetResolverConfiguration() in new code
-static ResolverParamsParcel makeResolverParamsParcel(
- int netId, const std::vector<int>& params, const std::vector<std::string>& servers,
- const std::vector<std::string>& domains, const std::string& tlsHostname,
- const std::vector<std::string>& tlsServers,
- const std::vector<std::string>& tlsFingerprints) {
+static ResolverParamsParcel makeResolverParamsParcel(int netId, const std::vector<int>& params,
+ const std::vector<std::string>& servers,
+ const std::vector<std::string>& domains,
+ const std::string& tlsHostname,
+ const std::vector<std::string>& tlsServers,
+ const std::string& caCert) {
using android::net::IDnsResolver;
ResolverParamsParcel paramsParcel;
@@ -79,7 +101,8 @@
paramsParcel.domains = domains;
paramsParcel.tlsName = tlsHostname;
paramsParcel.tlsServers = tlsServers;
- paramsParcel.tlsFingerprints = tlsFingerprints;
+ paramsParcel.tlsFingerprints = {};
+ paramsParcel.caCertificate = caCert;
return paramsParcel;
}
@@ -88,7 +111,7 @@
const std::vector<std::string>& domains,
const std::vector<int>& params) {
const auto& resolverParams =
- makeResolverParamsParcel(TEST_NETID, params, servers, domains, "", {}, {});
+ makeResolverParamsParcel(TEST_NETID, params, servers, domains, "", {}, "");
const auto rv = mDnsResolvSrv->setResolverConfiguration(resolverParams);
return rv.isOk();
}
@@ -97,10 +120,9 @@
const std::vector<std::string>& domains,
const std::vector<int>& params,
const std::vector<std::string>& tlsServers,
- const std::string& name,
- const std::vector<std::string>& fingerprints) {
+ const std::string& name) {
const auto& resolverParams = makeResolverParamsParcel(TEST_NETID, params, servers, domains,
- name, tlsServers, fingerprints);
+ name, tlsServers, kCaCert);
const auto rv = mDnsResolvSrv->setResolverConfiguration(resolverParams);
if (!rv.isOk()) LOG(ERROR) << "SetResolversWithTls() -> " << rv.toString8();
return rv.isOk();
diff --git a/tests/dns_responder/dns_responder_client.h b/tests/dns_responder/dns_responder_client.h
index 6e7f285..0b2966b 100644
--- a/tests/dns_responder/dns_responder_client.h
+++ b/tests/dns_responder/dns_responder_client.h
@@ -60,18 +60,16 @@
bool SetResolversWithTls(const std::vector<std::string>& servers,
const std::vector<std::string>& searchDomains,
- const std::vector<int>& params, const std::string& name,
- const std::vector<std::string>& fingerprints) {
+ const std::vector<int>& params, const std::string& name) {
// Pass servers as both network-assigned and TLS servers. Tests can
// determine on which server and by which protocol queries arrived.
- return SetResolversWithTls(servers, searchDomains, params, servers, name, fingerprints);
+ return SetResolversWithTls(servers, searchDomains, params, servers, name);
}
bool SetResolversWithTls(const std::vector<std::string>& servers,
const std::vector<std::string>& searchDomains,
const std::vector<int>& params,
- const std::vector<std::string>& tlsServers, const std::string& name,
- const std::vector<std::string>& fingerprints);
+ const std::vector<std::string>& tlsServers, const std::string& name);
static void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
std::vector<std::unique_ptr<test::DNSResponder>>* dns,
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index 3a86aca..3b25324 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -21,6 +21,7 @@
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
+#include <openssl/x509.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
@@ -30,35 +31,92 @@
#define LOG_TAG "DnsTlsFrontend"
#include <android-base/logging.h>
#include <netdutils/InternetAddresses.h>
-#include <netdutils/NetworkConstants.h> // SHA256_SIZE
#include <netdutils/SocketOption.h>
using android::netdutils::enableSockopt;
using android::netdutils::ScopedAddrinfo;
namespace {
+/*
+ * The test certificate, private key, and root CA files can be generated with
+ * openssl using the following commands:
+ *
+ * Create CA certificate:
+ * $ openssl genrsa 2048 > ca_key.pem
+ * $ openssl req -new -sha256 -x509 -nodes -days 3650 -key ca_key.pem -out ca_certificate.pem -subj
+ * '/C=/ST=/L=/CN=/emailAddress='
+ *
+ * Create server certificate:
+ * $ openssl req -sha256 -newkey rsa:2048 -days 3650 -nodes -keyout server_key.pem -out
+ * server_req.pem -subj '/C=/ST=/L=/CN=example.com/emailAddress='
+ * $ openssl rsa -in server_key.pem -out server_key.pem
+ * $ openssl x509 -sha256 -req -in server_req.pem -days 3650 -CA
+ * ca_certificate.pem -CAkey ca_key.pem -set_serial 01 -out server_certificate.pem
+ *
+ * Verify the certificate:
+ * $ openssl verify -CAfile ca_certificate.pem server_certificate.pem
+ */
+// server_certificate.pem (subject CN=example.com).
+// Loaded by startServer() as the frontend's TLS certificate; its key is kPrivatekey.
+static const char kCertificate[] = R"(
+-----BEGIN CERTIFICATE-----
+MIICijCCAXICAQEwDQYJKoZIhvcNAQELBQAwADAeFw0xOTA2MTAwODM3MzlaFw0y
+OTA2MDcwODM3MzlaMBYxFDASBgNVBAMMC2V4YW1wbGUuY29tMIIBIjANBgkqhkiG
+9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuo/v4VuY0Ees5HRx+NwTGm/bgToUFjq9R4z4
+FX+j8yyohxS8OxQZzpKu8JJytyPPi+SnXqZB25usGBPJHapD1Q5YYCIZF9EBztIq
+nEDbxvcWBrv7NDDhPMQ6v5YFhAIUN3a1yBESBWQOWsNkwJw04Wc4agZrhhnG/vS7
+gu1gn+CnaDYupAmHrGS9cSV/B9ZCpLhis2JxmULgdz6ZBee/x8dHHFd1Qeb/+G8j
+hBqhYbQK7ZFLmIO3DXrlP/ONXJ8IE2+PPDloiotkY5ar/1ZbRQS9fSKM9J6pipOE
+bAI1QF+tEn1bnaLfJfoMHIcb0p5xr04OALUZOGw4iVfxulMRIQIDAQABMA0GCSqG
+SIb3DQEBCwUAA4IBAQAuI2NjdWiD2lwmRraW6C7VBF+Sf+9QlzTVzSjuDbPkkYIo
+YWpeYsEeFO5NlxxXl77iu4MqznSAOK8GCiNDCCulDNWRhd5lcO1dVHLcIFYKZ+xv
+6IuH3vh60qJ2hoZbalwflnMQklqh3745ZyOH79dzKTFvlWyNJ2hQgP9bZ2g8F4od
+dS7aOwvx3DCv46b7vBJMKd53ZCdHubfFebDcGxc60DUR0fLSI/o1MJgriODZ1SX7
+sxwzrxvbJW0T+gJOL0C0lE6D84F9oL2u3ef17fC5u1bRd/77pzjTM+dQe7sZspCz
+iboTujdUqo+NSdWgwPUTGTYQg/1i9Qe0vjc0YplY
+-----END CERTIFICATE-----
+)";
-// Copied from DnsTlsTransport.
-bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
- int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
- unsigned char spki[spki_len];
- unsigned char* temp = spki;
- if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
- LOG(ERROR) << "SPKI length mismatch";
- return false;
- }
- out->resize(android::netdutils::SHA256_SIZE);
- unsigned int digest_len = 0;
- int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr);
- if (ret != 1) {
- LOG(ERROR) << "Server cert digest extraction failed";
- return false;
- }
- if (digest_len != out->size()) {
- LOG(ERROR) << "Wrong digest length: " << digest_len;
- return false;
- }
- return true;
+// server_key.pem
+static const char kPrivatekey[] = R"(
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAuo/v4VuY0Ees5HRx+NwTGm/bgToUFjq9R4z4FX+j8yyohxS8
+OxQZzpKu8JJytyPPi+SnXqZB25usGBPJHapD1Q5YYCIZF9EBztIqnEDbxvcWBrv7
+NDDhPMQ6v5YFhAIUN3a1yBESBWQOWsNkwJw04Wc4agZrhhnG/vS7gu1gn+CnaDYu
+pAmHrGS9cSV/B9ZCpLhis2JxmULgdz6ZBee/x8dHHFd1Qeb/+G8jhBqhYbQK7ZFL
+mIO3DXrlP/ONXJ8IE2+PPDloiotkY5ar/1ZbRQS9fSKM9J6pipOEbAI1QF+tEn1b
+naLfJfoMHIcb0p5xr04OALUZOGw4iVfxulMRIQIDAQABAoIBACDLLF9wumviLYH6
+9g3IoZMEFpGgo+dEbAEnxnQA+9DDCNy1yGCaJ+8n2ZhwJboLkXAFwWXh07HGq3mQ
+AMo2I7ZPzzkWxVJqaubwCo1s2TUgOb71TDLgZLdJxwnmVRHfS650L3/7gC9yZxON
+RSiWTLVSb5gziLMJ1PD8E/nvwAxaJDlT6vzqwRbnHBkQoumTmds2ecLJd2/6pfl4
+bMhtIKA3ULqnJlqlRt6ds/pWU9ttmXEX52uaGhzaF7PRomOW5pKR6CyBzNCn/RNF
+ZPIINW1TVWss9NMZsJLdIzs7Oon5gQYil9rU2uiA5ZUanYDIL9DOMrfAM3hfUuFq
+ZOhfBAECgYEA36CT81EkdDE7pum/kIuCG3wDEls+xNbWmF76IJAxnolJzKvJsdJA
+0za/l1Qe3/bRYHZWKc7by45LFYefOsS29jqBnBBMLurI7OLtcXqkTSSm11AfsDDI
+gw4bKs81TYdHhnbIDGeApfSWOGXgDM/j4N3stuvY1lOIocXqKMomZVMCgYEA1ZHD
+jtxeAmCqzJHIJ4DOY0Y2RR3Bq3ue/mc8gmV9wDyJMMvBpvOoRkUjwbKZToi4Px30
+5fn6SCRtOKfEL0b9LV7JFsMr84Zoj8kjtnE0BdNfQqdE/uWltpATl6YUPlzqZTGs
+HBGVpsNCzYkjFu9m/zIiryCHY8Uut3VEZmNJjTsCgYEAgADBTzAuBpg7xeHcdhd0
+xNiqRXKXLkKvGQ6ca9E9pbp91LqsO63Wz09yQWO0PIxh8q4pycqPQyfS0KMNwKzi
+8XQxxiwJ/30Cv51xPlht/X4yReKmEMsLqwCDCnEK2LLLfSs2fOst11B2QBgINC03
+CfrdySKcvqmX9sl7rBdx/OMCgYB9t2o4RDwKhkDEXuRFbKsRARmdIeEJQqHa+4ZA
+8+FMMdZIJQj/b9qUUsqzkKBx/EUI0meAoN/Va6vnd8oiUlViSbNxdL4AghQ2353o
+HUcUTtJ6d+BDc4dSqgj+ccLk2ukXXGAFvcwr+DDwsFM5gv9MJYUJNcq8ziurzpnO
+848uVQKBgEmyAa2jt1qNpAvxU0MakJIuKhQl2b6/54EKi9WKqIMs1+rKk6O/Ck3n
++tEWqHhZ4uCRmvTgpOM821l4fTHsoJ8IGWV0mwfk95pEL+g/eBLExR4etMqaW9uz
+x8vnVTKNzZsAVgRcemcLqzuyuMg+/ZnH+YNMzMl0Nbkt+kE3FhfM
+-----END RSA PRIVATE KEY-----
+)";
+
+static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) {
+ bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs)));
+ return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+}
+
+// Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct.
+static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) {
+ bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key)));
+ return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
}
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
@@ -68,71 +126,11 @@
return std::string();
}
-bssl::UniquePtr<EVP_PKEY> make_private_key() {
- bssl::UniquePtr<BIGNUM> e(BN_new());
- if (!e) {
- LOG(ERROR) << "BN_new failed";
- return nullptr;
- }
- if (!BN_set_word(e.get(), RSA_F4)) {
- LOG(ERROR) << "BN_set_word failed";
- return nullptr;
- }
-
- bssl::UniquePtr<RSA> rsa(RSA_new());
- if (!rsa) {
- LOG(ERROR) << "RSA_new failed";
- return nullptr;
- }
- if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr)) {
- LOG(ERROR) << "RSA_generate_key_ex failed";
- return nullptr;
- }
-
- bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
- if (!privkey) {
- LOG(ERROR) << "EVP_PKEY_new failed";
- return nullptr;
- }
- if (!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
- LOG(ERROR) << "EVP_PKEY_assign_RSA failed";
- return nullptr;
- }
-
- // |rsa| is now owned by |privkey|, so no need to free it.
- rsa.release();
- return privkey;
-}
-
-bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
- bssl::UniquePtr<X509> cert(X509_new());
- if (!cert) {
- LOG(ERROR) << "X509_new failed";
- return nullptr;
- }
-
- ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
-
- // Set one hour expiration.
- X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
- X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
-
- X509_set_pubkey(cert.get(), privkey);
-
- if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
- LOG(ERROR) << "X509_sign failed";
- return nullptr;
- }
-
- return cert;
-}
-
} // namespace
namespace test {
bool DnsTlsFrontend::startServer() {
- SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
// reset queries_ to 0 every time startServer called
@@ -147,37 +145,20 @@
SSL_CTX_set_ecdh_auto(ctx_.get(), 1);
- // Make certificate chain
- std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
- for (int i = 0; i < chain_length_; ++i) {
- keys[i] = make_private_key();
- }
- std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
- for (int i = 0; i < chain_length_; ++i) {
- int next = std::min(i + 1, chain_length_ - 1);
- certs[i] = make_cert(keys[i].get(), keys[next].get());
+ bssl::UniquePtr<X509> server_cert(stringToX509Certs(kCertificate));
+ if (!server_cert) {
+ LOG(ERROR) << "stringToX509Certs failed";
return false;
}
- // Install certificate chain.
- if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
+ if (SSL_CTX_use_certificate(ctx_.get(), server_cert.get()) <= 0) {
LOG(ERROR) << "SSL_CTX_use_certificate failed";
return false;
}
- if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
- LOG(ERROR) << "SSL_CTX_use_PrivateKey failed";
- return false;
- }
- for (int i = 1; i < chain_length_; ++i) {
- if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
- LOG(ERROR) << "SSL_CTX_add1_chain_cert failed";
- return false;
- }
- }
- // Report the fingerprint of the "middle" cert. For N = 2, this is the root.
- int fp_index = chain_length_ / 2;
- if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
- LOG(ERROR) << "getSPKIDigest failed";
+ bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey));
+ if (!private_key || SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) {
+ LOG(ERROR) << "Error loading server RSA Private Key data.";
return false;
}
@@ -367,7 +348,6 @@
backend_socket_.reset();
event_fd_.reset();
ctx_.reset();
- fingerprint_.clear();
LOG(INFO) << "frontend stopped successfully";
return true;
}
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 9e0b165..5af34b4 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -55,8 +55,6 @@
void clearQueries() { queries_ = 0; }
bool waitForQueries(int number, int timeoutMs) const;
void set_chain_length(int length) { chain_length_ = length; }
- // Represents a fingerprint from the middle of the certificate chain.
- const std::vector<uint8_t>& fingerprint() const { return fingerprint_; }
private:
void requestHandler();
@@ -83,7 +81,6 @@
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
int chain_length_ = 1;
- std::vector<uint8_t> fingerprint_;
};
} // namespace test