Make DnsTlsTransport's query method static
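This refactor removes the need for DnsProxyListener's query hook
to call the DnsTlsTransport constructor. This will allow us to
maintain state for longer (e.g. reusing sockets) without
increasing the complexity of DnsProxyListener.
ResolverController now tracks private DNS servers as
DnsTlsTransport::Server objects, compared by IP address only via
AddressComparator. These replace its private PrivateDnsServer
struct. It validates them with DnsTlsTransport::validate()
instead of validateDnsTlsServer().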
Bug: 63448521
Test: Integration tests pass.
Change-Id: I3ec3713e188ea11b160e61d1d873469c5ad57ae7
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 5266166..238eef7 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -53,31 +53,25 @@
namespace {
-struct PrivateDnsServer {
- PrivateDnsServer(const sockaddr_storage& ss) : ss(ss) {}
- const sockaddr_storage ss;
- // For now, the fingerprints are always SHA-256. This is the only digest algorithm
- // that is mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
- std::set<std::vector<uint8_t>> fingerprints;
-};
-
// This comparison ignores ports and fingerprints.
-bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
- if (x.ss.ss_family != y.ss.ss_family) {
- return x.ss.ss_family < y.ss.ss_family;
+struct AddressComparator {
+ bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const {
+ if (x.ss.ss_family != y.ss.ss_family) {
+ return x.ss.ss_family < y.ss.ss_family;
+ }
+ // Same address family. Compare IP addresses.
+ if (x.ss.ss_family == AF_INET) {
+ const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+ const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+ return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+ } else if (x.ss.ss_family == AF_INET6) {
+ const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+ const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+ return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
+ }
+ return false; // Unknown address type. This is an error.
}
- // Same address family. Compare IP addresses.
- if (x.ss.ss_family == AF_INET) {
- const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
- const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
- return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
- } else if (x.ss.ss_family == AF_INET6) {
- const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
- const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
- return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
- }
- return false; // Unknown address type. This is an error.
-}
+};
bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
@@ -102,13 +96,13 @@
// Structure for tracking the entire set of known Private DNS servers.
std::mutex privateDnsLock;
-typedef std::set<PrivateDnsServer> PrivateDnsSet;
+typedef std::set<DnsTlsTransport::Server, AddressComparator> PrivateDnsSet;
PrivateDnsSet privateDnsServers;
// Structure for tracking the validation status of servers on a specific netid.
// Servers that fail validation are removed from the tracker, and can be retried.
enum class Validation : bool { in_process, success };
-typedef std::map<PrivateDnsServer, Validation> PrivateDnsTracker;
+typedef std::map<DnsTlsTransport::Server, Validation, AddressComparator> PrivateDnsTracker;
std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
@@ -139,7 +133,8 @@
PrivateDnsSet intersection;
std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
serversToCheck.begin(), serversToCheck.end(),
- std::inserter(intersection, intersection.begin()));
+ std::inserter(intersection, intersection.begin()),
+ AddressComparator());
if (intersection.empty()) {
return;
}
@@ -162,10 +157,9 @@
}
tracker[privateServer] = Validation::in_process;
std::thread validate_thread([privateServer, netId] {
- // validateDnsTlsServer() is a blocking call that performs network operations.
+ // DnsTlsTransport::validate() is a blocking call that performs network operations.
// It can take milliseconds to minutes, up to the SYN retry limit.
- bool success = validateDnsTlsServer(netId,
- privateServer.ss, privateServer.fingerprints);
+ bool success = DnsTlsTransport::validate(privateServer, netId);
std::lock_guard<std::mutex> guard(privateDnsLock);
auto netPair = privateDnsTransports.find(netId);
if (netPair == privateDnsTransports.end()) {
@@ -210,7 +204,7 @@
}
bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
- sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
+ DnsTlsTransport::Server* secureServer) {
// This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
// If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
// an atomic_int32_t counter of validated connections, and returning early if it's zero.
@@ -225,8 +219,7 @@
return false;
}
const auto& validatedServer = serverPair->first;
- *secureServer = validatedServer.ss;
- *fingerprints = validatedServer.fingerprints;
+ *secureServer = validatedServer;
return true;
}
@@ -457,13 +450,16 @@
if (!parseServer(server.c_str(), port, &parsed)) {
return INetd::PRIVATE_DNS_BAD_ADDRESS;
}
- PrivateDnsServer privateServer(parsed);
+ DnsTlsTransport::Server privateServer(parsed);
privateServer.fingerprints = fingerprints;
std::lock_guard<std::mutex> guard(privateDnsLock);
// Ensure we overwrite any previous matching server. This is necessary because equality is
// based only on the IP address, not the port or fingerprints.
privateDnsServers.erase(privateServer);
privateDnsServers.insert(privateServer);
+ if (DBG) {
+ ALOGD("Recorded private DNS server: %s", server.c_str());
+ }
return INetd::PRIVATE_DNS_SUCCESS;
}