Add a BackoffSequence utility; use it for Private DNS validation
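Add a simple, if verbose, BackoffSequence class that encapsulates
RFC 3315 section 14-style retransmission backoff, and use it to retry
failed Private DNS validation (1 minute initial delay, backing off to
1 hour). Private DNS state now lives in a PrivateDnsConfiguration class.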
Test: as follows
- built
- flashed
- booted
- system/netd/tests/runtests.sh pass
- make netdutils_test && \
adb push .../data/nativetest64/netdutils_test/netdutils_test /data/nativetest64/netdutils_test && \
adb shell /data/nativetest64/netdutils_test passes
Bug: 64133961
Bug: 72344805
Change-Id: Ib15a9454e17529a735bca4d9a0e96de8baae84c3
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 2731c8d..ff6b975 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -51,10 +51,13 @@
#include "ResolverStats.h"
#include "dns/DnsTlsTransport.h"
#include "dns/DnsTlsServer.h"
+#include "netdutils/BackoffSequence.h"
namespace android {
-namespace net {
+using netdutils::BackoffSequence;
+
+namespace net {
namespace {
std::string addrToString(const sockaddr_storage* addr) {
@@ -90,166 +93,6 @@
}
}
-std::mutex privateDnsLock;
-std::map<unsigned, PrivateDnsMode> privateDnsModes GUARDED_BY(privateDnsLock);
-// Structure for tracking the validation status of servers on a specific netId.
-// Using the AddressComparator ensures at most one entry per IP address.
-typedef std::map<DnsTlsServer, ResolverController::Validation,
- AddressComparator> PrivateDnsTracker;
-std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
-EventReporter eventReporter;
-android::sp<android::net::metrics::INetdEventListener> netdEventListener;
-
-void validatePrivateDnsProvider(const DnsTlsServer& server,
- PrivateDnsTracker& tracker, unsigned netId) REQUIRES(privateDnsLock) {
- if (DBG) {
- ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
- }
-
- tracker[server] = ResolverController::Validation::in_process;
- if (DBG) {
- ALOGD("Server %s marked as in_process. Tracker now has size %zu",
- addrToString(&(server.ss)).c_str(), tracker.size());
- }
- std::thread validate_thread([server, netId] {
- // ::validate() is a blocking call that performs network operations.
- // It can take milliseconds to minutes, up to the SYN retry limit.
- bool success = DnsTlsTransport::validate(server, netId);
- if (DBG) {
- ALOGD("validateDnsTlsServer returned %d for %s", success,
- addrToString(&(server.ss)).c_str());
- }
- std::lock_guard<std::mutex> guard(privateDnsLock);
- auto netPair = privateDnsTransports.find(netId);
- if (netPair == privateDnsTransports.end()) {
- ALOGW("netId %u was erased during private DNS validation", netId);
- return;
- }
- auto& tracker = netPair->second;
- auto serverPair = tracker.find(server);
- if (serverPair == tracker.end()) {
- ALOGW("Server %s was removed during private DNS validation",
- addrToString(&(server.ss)).c_str());
- success = false;
- }
- if (!(serverPair->first == server)) {
- ALOGW("Server %s was changed during private DNS validation",
- addrToString(&(server.ss)).c_str());
- success = false;
- }
-
- // Send a validation event to NetdEventListenerService.
- if (netdEventListener == nullptr) {
- netdEventListener = eventReporter.getNetdEventListener();
- }
- if (netdEventListener != nullptr) {
- const String16 ipLiteral(addrToString(&(server.ss)).c_str());
- const String16 hostname(server.name.empty() ? "" : server.name.c_str());
- netdEventListener->onPrivateDnsValidationEvent(netId, ipLiteral, hostname, success);
- if (DBG) {
- ALOGD("Sending validation %s event on netId %u for %s with hostname %s",
- success ? "success" : "failure", netId,
- addrToString(&(server.ss)).c_str(), server.name.c_str());
- }
- } else {
- ALOGE("Validation event not sent since NetdEventListenerService is unavailable.");
- }
-
- if (success) {
- tracker[server] = ResolverController::Validation::success;
- if (DBG) {
- ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
- addrToString(&(server.ss)).c_str(), tracker.size());
- }
- } else {
- // Validation failure is expected if a user is on a captive portal.
- // TODO: Trigger a second validation attempt after captive portal login
- // succeeds.
- tracker[server] = ResolverController::Validation::fail;
- if (DBG) {
- ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
- }
- }
- });
- validate_thread.detach();
-}
-
-int setPrivateDnsConfiguration(int32_t netId,
- const std::vector<std::string>& servers, const std::string& name,
- const std::set<std::vector<uint8_t>>& fingerprints) {
- if (DBG) {
- ALOGD("setPrivateDnsConfiguration(%u, %zu, %s, %zu)",
- netId, servers.size(), name.c_str(), fingerprints.size());
- }
-
- const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
-
- // Parse the list of servers that has been passed in
- std::set<DnsTlsServer> tlsServers;
- for (size_t i = 0; i < servers.size(); ++i) {
- sockaddr_storage parsed;
- if (!parseServer(servers[i].c_str(), &parsed)) {
- return -EINVAL;
- }
- DnsTlsServer server(parsed);
- server.name = name;
- server.fingerprints = fingerprints;
- tlsServers.insert(server);
- }
-
- std::lock_guard<std::mutex> guard(privateDnsLock);
- if (explicitlyConfigured) {
- privateDnsModes[netId] = PrivateDnsMode::STRICT;
- } else if (!tlsServers.empty()) {
- privateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
- } else {
- privateDnsModes[netId] = PrivateDnsMode::OFF;
- privateDnsTransports.erase(netId);
- return 0;
- }
-
- // Create the tracker if it was not present
- auto netPair = privateDnsTransports.find(netId);
- if (netPair == privateDnsTransports.end()) {
- // No TLS tracker yet for this netId.
- bool added;
- std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
- if (!added) {
- ALOGE("Memory error while recording private DNS for netId %d", netId);
- return -ENOMEM;
- }
- }
- auto& tracker = netPair->second;
-
- // Remove any servers from the tracker that are not in |servers| exactly.
- for (auto it = tracker.begin(); it != tracker.end();) {
- if (tlsServers.count(it->first) == 0) {
- it = tracker.erase(it);
- } else {
- ++it;
- }
- }
-
- // Add any new or changed servers to the tracker, and initiate async checks for them.
- for (const auto& server : tlsServers) {
- // Don't probe a server more than once. This means that the only way to
- // re-check a failed server is to remove it and re-add it from the netId.
- if (tracker.count(server) == 0) {
- validatePrivateDnsProvider(server, tracker, netId);
- }
- }
- return 0;
-}
-
-void clearPrivateDnsProviders(unsigned netId) {
- if (DBG) {
- ALOGD("clearPrivateDnsProviders(%u)", netId);
- }
- std::lock_guard<std::mutex> guard(privateDnsLock);
- privateDnsModes.erase(netId);
- privateDnsTransports.erase(netId);
-}
-
constexpr const char* validationStatusToString(ResolverController::Validation value) {
switch (value) {
case ResolverController::Validation::in_process: return "in_process";
@@ -261,6 +104,272 @@
}
}
+
+// Owns all Private DNS (DNS-over-TLS) state, keyed by netId: the configured
+// PrivateDnsMode and the validation status of each configured server. All of
+// that state is guarded by mPrivateDnsLock. Server validation runs on detached
+// threads, which report each attempt back through recordPrivateDnsValidation().
+class PrivateDnsConfiguration {
+  public:
+    typedef ResolverController::PrivateDnsStatus PrivateDnsStatus;
+    typedef ResolverController::Validation Validation;
+    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
+
+    // Applies a Private DNS configuration to |netId|.
+    //
+    // A non-empty |name| or |fingerprints| selects STRICT mode; otherwise a
+    // non-empty |servers| selects OPPORTUNISTIC mode; otherwise the mode is OFF
+    // and all tracked servers for |netId| are dropped. Tracked servers that are
+    // not exactly in |servers| are removed, and servers not yet tracked are
+    // validated asynchronously.
+    //
+    // Returns 0 on success, -EINVAL if an entry of |servers| cannot be parsed,
+    // or -ENOMEM if the tracker for |netId| cannot be created.
+    int set(int32_t netId, const std::vector<std::string>& servers, const std::string& name,
+            const std::set<std::vector<uint8_t>>& fingerprints) {
+        if (DBG) {
+            ALOGD("PrivateDnsConfiguration::set(%u, %zu, %s, %zu)",
+                  netId, servers.size(), name.c_str(), fingerprints.size());
+        }
+
+        const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
+
+        // Parse the list of servers that has been passed in.
+        std::set<DnsTlsServer> tlsServers;
+        for (size_t i = 0; i < servers.size(); ++i) {
+            sockaddr_storage parsed;
+            if (!parseServer(servers[i].c_str(), &parsed)) {
+                return -EINVAL;
+            }
+            DnsTlsServer server(parsed);
+            server.name = name;
+            server.fingerprints = fingerprints;
+            tlsServers.insert(server);
+        }
+
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+        if (explicitlyConfigured) {
+            mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
+        } else if (!tlsServers.empty()) {
+            mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
+        } else {
+            mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
+            mPrivateDnsTransports.erase(netId);
+            return 0;
+        }
+
+        // Create the tracker if it was not present.
+        auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair == mPrivateDnsTransports.end()) {
+            // No TLS tracker yet for this netId.
+            bool added;
+            std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
+            if (!added) {
+                ALOGE("Memory error while recording private DNS for netId %d", netId);
+                return -ENOMEM;
+            }
+        }
+        auto& tracker = netPair->second;
+
+        // Remove any servers from the tracker that are not in |servers| exactly.
+        for (auto it = tracker.begin(); it != tracker.end();) {
+            if (tlsServers.count(it->first) == 0) {
+                it = tracker.erase(it);
+            } else {
+                ++it;
+            }
+        }
+
+        // Add any new or changed servers to the tracker, and initiate async
+        // checks for them. Servers that are already tracked are not probed
+        // again here: a server that fails validation is retried with backoff
+        // by its own validation thread, until that thread gives up.
+        for (const auto& server : tlsServers) {
+            if (tracker.count(server) == 0) {
+                validatePrivateDnsProvider(server, tracker, netId);
+            }
+        }
+        return 0;
+    }
+
+    // Returns the Private DNS mode of |netId| (OFF if none was set) together
+    // with the tracked servers whose validation has succeeded.
+    PrivateDnsStatus getStatus(unsigned netId) {
+        PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
+
+        // This mutex is on the critical path of every DNS lookup.
+        //
+        // If the overhead of mutex acquisition proves too high, we could reduce
+        // it by maintaining an atomic_int32_t counter of TLS-enabled netids, or
+        // by using an RWLock.
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        const auto mode = mPrivateDnsModes.find(netId);
+        if (mode == mPrivateDnsModes.end()) return status;
+        status.mode = mode->second;
+
+        const auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair != mPrivateDnsTransports.end()) {
+            for (const auto& serverPair : netPair->second) {
+                if (serverPair.second == Validation::success) {
+                    status.validatedServers.push_back(serverPair.first);
+                }
+            }
+        }
+
+        return status;
+    }
+
+    // Forgets the mode and all tracked servers of |netId|. Validation threads
+    // still running for |netId| stop at their next report, because the netId
+    // is no longer present.
+    void clear(unsigned netId) {
+        if (DBG) {
+            ALOGD("PrivateDnsConfiguration::clear(%u)", netId);
+        }
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+        mPrivateDnsModes.erase(netId);
+        mPrivateDnsTransports.erase(netId);
+    }
+
+    // Writes the mode and per-server validation status of |netId| to |dw|.
+    void dump(DumpWriter& dw, unsigned netId) {
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        const auto mode = mPrivateDnsModes.find(netId);
+        dw.println("Private DNS mode: %s", getPrivateDnsModeString(
+                (mode != mPrivateDnsModes.end()) ? mode->second : PrivateDnsMode::OFF));
+        const auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair == mPrivateDnsTransports.end()) {
+            dw.println("No Private DNS servers configured");
+        } else {
+            const auto& tracker = netPair->second;
+            dw.println("Private DNS configuration (%zu entries)", tracker.size());
+            dw.incIndent();
+            for (const auto& kv : tracker) {
+                const auto& server = kv.first;
+                const auto status = kv.second;
+                dw.println("%s name{%s} status{%s}",
+                        addrToString(&(server.ss)).c_str(),
+                        server.name.c_str(),
+                        validationStatusToString(status));
+            }
+            dw.decIndent();
+        }
+    }
+
+  private:
+    // Marks |server| as in_process in |tracker| and starts a detached thread
+    // that validates it, retrying failures with exponential backoff for as
+    // long as recordPrivateDnsValidation() asks for re-evaluation.
+    void validatePrivateDnsProvider(const DnsTlsServer& server, PrivateDnsTracker& tracker,
+            unsigned netId) REQUIRES(mPrivateDnsLock) {
+        if (DBG) {
+            ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
+        }
+
+        tracker[server] = Validation::in_process;
+        if (DBG) {
+            ALOGD("Server %s marked as in_process. Tracker now has size %zu",
+                    addrToString(&(server.ss)).c_str(), tracker.size());
+        }
+        // Capturing |server| and |netId| in this lambda creates copies, so the
+        // thread does not depend on the caller's objects. |this| is the
+        // file-scope |privateDnsConfiguration| instance, which is expected to
+        // outlive the thread (NOTE(review): not guaranteed during static
+        // destruction at process exit).
+        std::thread validate_thread([this, server, netId] {
+            // On a typical device, /proc/sys/net/ipv4/tcp_syn_retries is 6.
+            //
+            // Start with a 1 minute delay and back off to once per hour.
+            //
+            // Assumptions:
+            //     [1] Each TLS validation is ~10KB of certs+handshake+payload.
+            //     [2] Networks typically provision clients with <=4 nameservers.
+            //     [3] Average month has 30 days.
+            //
+            // Once backed off, each hourly validation pass costs ~40KB
+            // (4 servers x 10KB), so 24 passes per day come to ~30MB per
+            // month in the worst case. If the servers are unreachable instead,
+            // this costs ~600 SYNs per day, i.e. ~17,000 per month
+            // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
+            auto backoff = BackoffSequence<>::Builder()
+                    .withInitialRetransmissionTime(std::chrono::seconds(60))
+                    .withMaximumRetransmissionTime(std::chrono::seconds(3600))
+                    .build();
+
+            while (true) {
+                // ::validate() is a blocking call that performs network operations.
+                // It can take milliseconds to minutes, up to the SYN retry limit.
+                const bool success = DnsTlsTransport::validate(server, netId);
+                if (DBG) {
+                    ALOGD("validateDnsTlsServer returned %d for %s", success,
+                            addrToString(&(server.ss)).c_str());
+                }
+
+                const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
+                if (!needs_reeval) {
+                    break;
+                }
+
+                if (backoff.hasNextTimeout()) {
+                    std::this_thread::sleep_for(backoff.getNextTimeout());
+                } else {
+                    break;
+                }
+            }
+        });
+        validate_thread.detach();
+    }
+
+    // Records the outcome of one validation attempt of |server| on |netId| and
+    // reports it to NetdEventListenerService.
+    //
+    // Returns true if the caller should retry after a backoff delay, which
+    // happens only when validation failed and |server| is still tracked
+    // unchanged on |netId|. Returns false once validation succeeds, or when
+    // |netId| or |server| was removed or reconfigured while validation was
+    // running. In the removed or reconfigured case the result is stale, and
+    // the tracker is left untouched.
+    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success) {
+        constexpr bool NEEDS_REEVALUATION = true;
+        constexpr bool DONT_REEVALUATE = false;
+
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair == mPrivateDnsTransports.end()) {
+            ALOGW("netId %u was erased during private DNS validation", netId);
+            return DONT_REEVALUATE;
+        }
+
+        auto& tracker = netPair->second;
+        auto serverPair = tracker.find(server);
+        // Only a result for the exact server that is still being tracked may
+        // update the tracker.
+        bool isCurrent = true;
+        if (serverPair == tracker.end()) {
+            ALOGW("Server %s was removed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            isCurrent = false;
+        } else if (!(serverPair->first == server)) {
+            ALOGW("Server %s was changed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            isCurrent = false;
+        }
+        if (!isCurrent) {
+            success = false;
+        }
+
+        // Send a validation event to NetdEventListenerService.
+        if (mNetdEventListener == nullptr) {
+            mNetdEventListener = mEventReporter.getNetdEventListener();
+        }
+        if (mNetdEventListener != nullptr) {
+            const String16 ipLiteral(addrToString(&(server.ss)).c_str());
+            const String16 hostname(server.name.empty() ? "" : server.name.c_str());
+            mNetdEventListener->onPrivateDnsValidationEvent(netId, ipLiteral, hostname, success);
+            if (DBG) {
+                ALOGD("Sending validation %s event on netId %u for %s with hostname %s",
+                        success ? "success" : "failure", netId,
+                        addrToString(&(server.ss)).c_str(), server.name.c_str());
+            }
+        } else {
+            ALOGE("Validation event not sent since NetdEventListenerService is unavailable.");
+        }
+
+        if (!isCurrent) {
+            // Writing a stale result would either re-insert a removed server
+            // or overwrite the status of the entry that replaced it. A
+            // re-inserted server would never be validated again if it were
+            // re-added, because set() only probes untracked servers.
+            return DONT_REEVALUATE;
+        }
+
+        if (success) {
+            serverPair->second = Validation::success;
+            if (DBG) {
+                ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
+                        addrToString(&(server.ss)).c_str(), tracker.size());
+            }
+            return DONT_REEVALUATE;
+        }
+
+        // Validation failure is expected if a user is on a captive portal.
+        // Asking for re-evaluation makes the validation thread retry with
+        // backoff, so validation can succeed once the portal has been cleared.
+        serverPair->second = Validation::fail;
+        if (DBG) {
+            ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
+        }
+        return NEEDS_REEVALUATION;
+    }
+
+    EventReporter mEventReporter;
+
+    std::mutex mPrivateDnsLock;
+    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
+    // Structure for tracking the validation status of servers on a specific netId.
+    // Using the AddressComparator ensures at most one entry per IP address.
+    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
+    android::sp<android::net::metrics::INetdEventListener>
+            mNetdEventListener GUARDED_BY(mPrivateDnsLock);
+} privateDnsConfiguration;
+
} // namespace
int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
@@ -272,30 +381,8 @@
}
ResolverController::PrivateDnsStatus
-ResolverController::getPrivateDnsStatus(unsigned netid) const {
- PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
-
- // This mutex is on the critical path of every DNS lookup.
- //
- // If the overhead of mutex acquisition proves too high, we could reduce it
- // by maintaining an atomic_int32_t counter of TLS-enabled netids, or by
- // using an RWLock.
- std::lock_guard<std::mutex> guard(privateDnsLock);
-
- const auto mode = privateDnsModes.find(netid);
- if (mode == privateDnsModes.end()) return status;
- status.mode = mode->second;
-
- const auto netPair = privateDnsTransports.find(netid);
- if (netPair != privateDnsTransports.end()) {
- for (const auto& serverPair : netPair->second) {
- if (serverPair.second == Validation::success) {
- status.validatedServers.push_back(serverPair.first);
- }
- }
- }
-
- return status;
+ResolverController::getPrivateDnsStatus(unsigned netId) const {
+ return privateDnsConfiguration.getStatus(netId);
}
int ResolverController::clearDnsServers(unsigned netId) {
@@ -303,7 +390,7 @@
if (DBG) {
ALOGD("clearDnsServers netId = %u\n", netId);
}
- clearPrivateDnsProviders(netId);
+ privateDnsConfiguration.clear(netId);
return 0;
}
@@ -399,7 +486,7 @@
return -EINVAL;
}
- const int err = setPrivateDnsConfiguration(netId, tlsServers, tlsName, tlsFingerprints);
+ const int err = privateDnsConfiguration.set(netId, tlsServers, tlsName, tlsFingerprints);
if (err != 0) {
return err;
}
@@ -502,29 +589,8 @@
static_cast<unsigned>(params.min_samples),
static_cast<unsigned>(params.max_samples));
}
- {
- std::lock_guard<std::mutex> guard(privateDnsLock);
- const auto& mode = privateDnsModes.find(netId);
- dw.println("Private DNS mode: %s", getPrivateDnsModeString(
- mode != privateDnsModes.end() ? mode->second : PrivateDnsMode::OFF));
- const auto& netPair = privateDnsTransports.find(netId);
- if (netPair == privateDnsTransports.end()) {
- dw.println("No Private DNS servers configured");
- } else {
- const auto& tracker = netPair->second;
- dw.println("Private DNS configuration (%zu entries)", tracker.size());
- dw.incIndent();
- for (const auto& kv : tracker) {
- const auto& server = kv.first;
- const auto status = kv.second;
- dw.println("%s name{%s} status{%s}",
- addrToString(&(server.ss)).c_str(),
- server.name.c_str(),
- validationStatusToString(status));
- }
- dw.decIndent();
- }
- }
+
+ privateDnsConfiguration.dump(dw, netId);
}
dw.decIndent();
}