Add a BackoffSequence utility; use it for Private DNS validation

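Add a simple, if verbose, BackoffSequence class to encapsulate
RFC 3315 section 14-style retransmission backoff (initial and maximum
retransmission times). Private DNS validation now uses it to retry
failed servers, starting at 60 seconds and backing off to at most once
per hour, instead of probing each server only once. The Private DNS
state and logic also move into a PrivateDnsConfiguration class.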

Test: as follows
    - built
    - flashed
    - booted
    - system/netd/tests/runtests.sh passes
    - the following passes:
      make netdutils_test && \
      adb push .../data/nativetest64/netdutils_test/netdutils_test /data/nativetest64/netdutils_test && \
      adb shell /data/nativetest64/netdutils_test
Bug: 64133961
Bug: 72344805
Change-Id: Ib15a9454e17529a735bca4d9a0e96de8baae84c3
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 2731c8d..ff6b975 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -51,10 +51,13 @@
 #include "ResolverStats.h"
 #include "dns/DnsTlsTransport.h"
 #include "dns/DnsTlsServer.h"
+#include "netdutils/BackoffSequence.h"
 
 namespace android {
-namespace net {
 
+using netdutils::BackoffSequence;
+
+namespace net {
 namespace {
 
 std::string addrToString(const sockaddr_storage* addr) {
@@ -90,166 +93,6 @@
     }
 }
 
-std::mutex privateDnsLock;
-std::map<unsigned, PrivateDnsMode> privateDnsModes GUARDED_BY(privateDnsLock);
-// Structure for tracking the validation status of servers on a specific netId.
-// Using the AddressComparator ensures at most one entry per IP address.
-typedef std::map<DnsTlsServer, ResolverController::Validation,
-        AddressComparator> PrivateDnsTracker;
-std::map<unsigned, PrivateDnsTracker> privateDnsTransports GUARDED_BY(privateDnsLock);
-EventReporter eventReporter;
-android::sp<android::net::metrics::INetdEventListener> netdEventListener;
-
-void validatePrivateDnsProvider(const DnsTlsServer& server,
-        PrivateDnsTracker& tracker, unsigned netId) REQUIRES(privateDnsLock) {
-    if (DBG) {
-        ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
-    }
-
-    tracker[server] = ResolverController::Validation::in_process;
-    if (DBG) {
-        ALOGD("Server %s marked as in_process.  Tracker now has size %zu",
-                addrToString(&(server.ss)).c_str(), tracker.size());
-    }
-    std::thread validate_thread([server, netId] {
-        // ::validate() is a blocking call that performs network operations.
-        // It can take milliseconds to minutes, up to the SYN retry limit.
-        bool success = DnsTlsTransport::validate(server, netId);
-        if (DBG) {
-            ALOGD("validateDnsTlsServer returned %d for %s", success,
-                    addrToString(&(server.ss)).c_str());
-        }
-        std::lock_guard<std::mutex> guard(privateDnsLock);
-        auto netPair = privateDnsTransports.find(netId);
-        if (netPair == privateDnsTransports.end()) {
-            ALOGW("netId %u was erased during private DNS validation", netId);
-            return;
-        }
-        auto& tracker = netPair->second;
-        auto serverPair = tracker.find(server);
-        if (serverPair == tracker.end()) {
-            ALOGW("Server %s was removed during private DNS validation",
-                    addrToString(&(server.ss)).c_str());
-            success = false;
-        }
-        if (!(serverPair->first == server)) {
-            ALOGW("Server %s was changed during private DNS validation",
-                    addrToString(&(server.ss)).c_str());
-            success = false;
-        }
-
-        // Send a validation event to NetdEventListenerService.
-        if (netdEventListener == nullptr) {
-            netdEventListener = eventReporter.getNetdEventListener();
-        }
-        if (netdEventListener != nullptr) {
-            const String16 ipLiteral(addrToString(&(server.ss)).c_str());
-            const String16 hostname(server.name.empty() ? "" : server.name.c_str());
-            netdEventListener->onPrivateDnsValidationEvent(netId, ipLiteral, hostname, success);
-            if (DBG) {
-                ALOGD("Sending validation %s event on netId %u for %s with hostname %s",
-                        success ? "success" : "failure", netId,
-                        addrToString(&(server.ss)).c_str(), server.name.c_str());
-            }
-        } else {
-            ALOGE("Validation event not sent since NetdEventListenerService is unavailable.");
-        }
-
-        if (success) {
-            tracker[server] = ResolverController::Validation::success;
-            if (DBG) {
-                ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
-                        addrToString(&(server.ss)).c_str(), tracker.size());
-            }
-        } else {
-            // Validation failure is expected if a user is on a captive portal.
-            // TODO: Trigger a second validation attempt after captive portal login
-            // succeeds.
-            tracker[server] = ResolverController::Validation::fail;
-            if (DBG) {
-                ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
-            }
-        }
-    });
-    validate_thread.detach();
-}
-
-int setPrivateDnsConfiguration(int32_t netId,
-        const std::vector<std::string>& servers, const std::string& name,
-        const std::set<std::vector<uint8_t>>& fingerprints) {
-    if (DBG) {
-        ALOGD("setPrivateDnsConfiguration(%u, %zu, %s, %zu)",
-                netId, servers.size(), name.c_str(), fingerprints.size());
-    }
-
-    const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
-
-    // Parse the list of servers that has been passed in
-    std::set<DnsTlsServer> tlsServers;
-    for (size_t i = 0; i < servers.size(); ++i) {
-        sockaddr_storage parsed;
-        if (!parseServer(servers[i].c_str(), &parsed)) {
-            return -EINVAL;
-        }
-        DnsTlsServer server(parsed);
-        server.name = name;
-        server.fingerprints = fingerprints;
-        tlsServers.insert(server);
-    }
-
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    if (explicitlyConfigured) {
-        privateDnsModes[netId] = PrivateDnsMode::STRICT;
-    } else if (!tlsServers.empty()) {
-        privateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
-    } else {
-        privateDnsModes[netId] = PrivateDnsMode::OFF;
-        privateDnsTransports.erase(netId);
-        return 0;
-    }
-
-    // Create the tracker if it was not present
-    auto netPair = privateDnsTransports.find(netId);
-    if (netPair == privateDnsTransports.end()) {
-        // No TLS tracker yet for this netId.
-        bool added;
-        std::tie(netPair, added) = privateDnsTransports.emplace(netId, PrivateDnsTracker());
-        if (!added) {
-            ALOGE("Memory error while recording private DNS for netId %d", netId);
-            return -ENOMEM;
-        }
-    }
-    auto& tracker = netPair->second;
-
-    // Remove any servers from the tracker that are not in |servers| exactly.
-    for (auto it = tracker.begin(); it != tracker.end();) {
-        if (tlsServers.count(it->first) == 0) {
-            it = tracker.erase(it);
-        } else {
-            ++it;
-        }
-    }
-
-    // Add any new or changed servers to the tracker, and initiate async checks for them.
-    for (const auto& server : tlsServers) {
-        // Don't probe a server more than once.  This means that the only way to
-        // re-check a failed server is to remove it and re-add it from the netId.
-        if (tracker.count(server) == 0) {
-            validatePrivateDnsProvider(server, tracker, netId);
-        }
-    }
-    return 0;
-}
-
-void clearPrivateDnsProviders(unsigned netId) {
-    if (DBG) {
-        ALOGD("clearPrivateDnsProviders(%u)", netId);
-    }
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-    privateDnsModes.erase(netId);
-    privateDnsTransports.erase(netId);
-}
-
 constexpr const char* validationStatusToString(ResolverController::Validation value) {
     switch (value) {
         case ResolverController::Validation::in_process:     return "in_process";
@@ -261,6 +104,272 @@
     }
 }
 
+// Tracks each netId's Private DNS mode and server validation state; guarded by mPrivateDnsLock.
+class PrivateDnsConfiguration {
+  public:
+    typedef ResolverController::PrivateDnsStatus PrivateDnsStatus;
+    typedef ResolverController::Validation Validation;
+    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
+
+    // Configures |netId|: STRICT if |name| or |fingerprints| is set, else OPPORTUNISTIC if
+    // |servers| is non-empty, else OFF.  Returns 0, or -EINVAL if a server fails to parse.
+    int set(int32_t netId, const std::vector<std::string>& servers, const std::string& name,
+            const std::set<std::vector<uint8_t>>& fingerprints) {
+        if (DBG) {
+            ALOGD("PrivateDnsConfiguration::set(%d, %zu, %s, %zu)",
+                    netId, servers.size(), name.c_str(), fingerprints.size());
+        }
+
+        const bool explicitlyConfigured = !name.empty() || !fingerprints.empty();
+
+        // Parse the list of servers that has been passed in
+        std::set<DnsTlsServer> tlsServers;
+        for (size_t i = 0; i < servers.size(); ++i) {
+            sockaddr_storage parsed;
+            if (!parseServer(servers[i].c_str(), &parsed)) {
+                return -EINVAL;
+            }
+            DnsTlsServer server(parsed);
+            server.name = name;
+            server.fingerprints = fingerprints;
+            tlsServers.insert(server);
+        }
+
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+        if (explicitlyConfigured) {
+            mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
+        } else if (!tlsServers.empty()) {
+            mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
+        } else {
+            mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
+            mPrivateDnsTransports.erase(netId);
+            return 0;
+        }
+
+        // Get this netId's tracker, default-constructing it if it was not present.
+        auto& tracker = mPrivateDnsTransports[netId];
+
+        // Remove any servers from the tracker that are not in |servers| exactly.
+        for (auto it = tracker.begin(); it != tracker.end();) {
+            if (tlsServers.count(it->first) == 0) {
+                it = tracker.erase(it);
+            } else {
+                ++it;
+            }
+        }
+
+        // Add any new or changed servers to the tracker, and initiate async checks for them.
+        for (const auto& server : tlsServers) {
+            // Skip servers that are already tracked: their validation thread retries failures
+            // for as long as its backoff sequence allows.  Changed servers were erased above.
+            if (tracker.count(server) == 0) {
+                validatePrivateDnsProvider(server, tracker, netId);
+            }
+        }
+        return 0;
+    }
+
+    // Returns the mode of |netId| and the servers whose validation has succeeded.
+    PrivateDnsStatus getStatus(unsigned netId) {
+        PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
+
+        // This mutex is on the critical path of every DNS lookup.  If its overhead proves
+        // too high, we could keep an atomic_int32_t counter of TLS-enabled netids, or use
+        // an RWLock.
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        const auto mode = mPrivateDnsModes.find(netId);
+        if (mode == mPrivateDnsModes.end()) return status;
+        status.mode = mode->second;
+
+        const auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair != mPrivateDnsTransports.end()) {
+            for (const auto& serverPair : netPair->second) {
+                if (serverPair.second == Validation::success) {
+                    status.validatedServers.push_back(serverPair.first);
+                }
+            }
+        }
+
+        return status;
+    }
+
+    // Forgets |netId|'s mode and servers; validation threads exit when they next find it gone.
+    void clear(unsigned netId) {
+        if (DBG) {
+            ALOGD("PrivateDnsConfiguration::clear(%u)", netId);
+        }
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+        mPrivateDnsModes.erase(netId);
+        mPrivateDnsTransports.erase(netId);
+    }
+
+    // Writes |netId|'s mode and each tracked server's validation status to |dw|.
+    void dump(DumpWriter& dw, unsigned netId) {
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        const auto mode = mPrivateDnsModes.find(netId);
+        dw.println("Private DNS mode: %s", getPrivateDnsModeString(
+                (mode != mPrivateDnsModes.end()) ? mode->second : PrivateDnsMode::OFF));
+        const auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair == mPrivateDnsTransports.end()) {
+            dw.println("No Private DNS servers configured");
+        } else {
+            const auto& tracker = netPair->second;
+            dw.println("Private DNS configuration (%zu entries)", tracker.size());
+            dw.incIndent();
+            for (const auto& kv : tracker) {
+                const auto& server = kv.first;
+                const auto status = kv.second;
+                dw.println("%s name{%s} status{%s}",
+                        addrToString(&(server.ss)).c_str(),
+                        server.name.c_str(),
+                        validationStatusToString(status));
+            }
+            dw.decIndent();
+        }
+    }
+
+  private:
+    // Marks |server| in_process and validates it on a detached thread, retrying failures
+    // with backoff until recordPrivateDnsValidation() returns false or backoff runs out.
+    void validatePrivateDnsProvider(const DnsTlsServer& server, PrivateDnsTracker& tracker,
+            unsigned netId) REQUIRES(mPrivateDnsLock) {
+        if (DBG) {
+            ALOGD("validatePrivateDnsProvider(%s, %u)", addrToString(&(server.ss)).c_str(), netId);
+        }
+
+        tracker[server] = Validation::in_process;
+        if (DBG) {
+            ALOGD("Server %s marked as in_process.  Tracker now has size %zu",
+                    addrToString(&(server.ss)).c_str(), tracker.size());
+        }
+        // Note that capturing |server| and |netId| in this lambda creates copies.
+        std::thread validate_thread([this, server, netId] {
+            // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
+            //
+            // Start with a 1 minute delay and backoff to once per hour.
+            //
+            // Assumptions:
+            //     [1] Each TLS validation is ~10KB of certs+handshake+payload.
+            //     [2] Networks typically provision clients with <=4 nameservers.
+            //     [3] Average month has 30 days.
+            //
+            // Each hourly validation pass over 4 servers is ~40KB of data, and 24
+            // such passes per day is ~1MB per day, or ~30MB per month in the
+            // worst case. Otherwise, this will cost ~600 SYNs per day
+            // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
+            auto backoff = BackoffSequence<>::Builder()
+                    .withInitialRetransmissionTime(std::chrono::seconds(60))
+                    .withMaximumRetransmissionTime(std::chrono::seconds(3600))
+                    .build();
+
+            while (true) {
+                // ::validate() is a blocking call that performs network operations.
+                // It can take milliseconds to minutes, up to the SYN retry limit.
+                const bool success = DnsTlsTransport::validate(server, netId);
+                if (DBG) {
+                    ALOGD("validateDnsTlsServer returned %d for %s", success,
+                            addrToString(&(server.ss)).c_str());
+                }
+
+                const bool needsReeval = recordPrivateDnsValidation(server, netId, success);
+                if (!needsReeval) {
+                    break;
+                }
+
+                if (backoff.hasNextTimeout()) {
+                    std::this_thread::sleep_for(backoff.getNextTimeout());
+                } else {
+                    break;
+                }
+            }
+        });
+        validate_thread.detach();
+    }
+
+    // Records one validation result for |server| on |netId|, reports it to the event
+    // listener, and returns whether the caller should retry.  If the tracker entry was
+    // removed or reconfigured meanwhile, a failure is reported and nothing is stored.
+    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success) {
+        constexpr bool NEEDS_REEVALUATION = true;
+        constexpr bool DONT_REEVALUATE = false;
+
+        std::lock_guard<std::mutex> guard(mPrivateDnsLock);
+
+        auto netPair = mPrivateDnsTransports.find(netId);
+        if (netPair == mPrivateDnsTransports.end()) {
+            ALOGW("netId %u was erased during private DNS validation", netId);
+            return DONT_REEVALUATE;
+        }
+
+        bool reevaluationStatus = success ? DONT_REEVALUATE : NEEDS_REEVALUATION;
+
+        auto& tracker = netPair->second;
+        auto serverPair = tracker.find(server);
+        // A same-address entry that differs from |server| was re-added by a newer set() call.
+        const bool entryIsCurrent = (serverPair != tracker.end()) && (serverPair->first == server);
+        if (serverPair == tracker.end()) {
+            ALOGW("Server %s was removed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            success = false;
+            reevaluationStatus = DONT_REEVALUATE;
+        } else if (!entryIsCurrent) {
+            ALOGW("Server %s was changed during private DNS validation",
+                    addrToString(&(server.ss)).c_str());
+            success = false;
+            reevaluationStatus = DONT_REEVALUATE;
+        }
+
+        // Send a validation event to NetdEventListenerService.
+        if (mNetdEventListener == nullptr) {
+            mNetdEventListener = mEventReporter.getNetdEventListener();
+        }
+        if (mNetdEventListener != nullptr) {
+            const String16 ipLiteral(addrToString(&(server.ss)).c_str());
+            const String16 hostname(server.name.empty() ? "" : server.name.c_str());
+            mNetdEventListener->onPrivateDnsValidationEvent(netId, ipLiteral, hostname, success);
+            if (DBG) {
+                ALOGD("Sending validation %s event on netId %u for %s with hostname %s",
+                        success ? "success" : "failure", netId,
+                        addrToString(&(server.ss)).c_str(), server.name.c_str());
+            }
+        } else {
+            ALOGE("Validation event not sent since NetdEventListenerService is unavailable.");
+        }
+
+        // Don't overwrite an entry this thread no longer owns, or re-insert a removed one.
+        if (!entryIsCurrent) return reevaluationStatus;
+
+        if (success) {
+            serverPair->second = Validation::success;
+            if (DBG) {
+                ALOGD("Validation succeeded for %s! Tracker now has %zu entries.",
+                        addrToString(&(server.ss)).c_str(), tracker.size());
+            }
+        } else {
+            // Validation failure is expected if a user is on a captive portal.
+            // TODO: Re-validate promptly once captive portal login succeeds.
+            serverPair->second = Validation::fail;
+            if (DBG) {
+                ALOGD("Validation failed for %s!", addrToString(&(server.ss)).c_str());
+            }
+        }
+
+        return reevaluationStatus;
+    }
+
+    EventReporter mEventReporter;
+
+    std::mutex mPrivateDnsLock;
+    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
+    // Structure for tracking the validation status of servers on a specific netId.
+    // Using the AddressComparator ensures at most one entry per IP address.
+    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
+    android::sp<android::net::metrics::INetdEventListener>
+            mNetdEventListener GUARDED_BY(mPrivateDnsLock);
+} privateDnsConfiguration;
+
 }  // namespace
 
 int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
@@ -272,30 +381,8 @@
 }
 
 ResolverController::PrivateDnsStatus
-ResolverController::getPrivateDnsStatus(unsigned netid) const {
-    PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
-
-    // This mutex is on the critical path of every DNS lookup.
-    //
-    // If the overhead of mutex acquisition proves too high, we could reduce it
-    // by maintaining an atomic_int32_t counter of TLS-enabled netids, or by
-    // using an RWLock.
-    std::lock_guard<std::mutex> guard(privateDnsLock);
-
-    const auto mode = privateDnsModes.find(netid);
-    if (mode == privateDnsModes.end()) return status;
-    status.mode = mode->second;
-
-    const auto netPair = privateDnsTransports.find(netid);
-    if (netPair != privateDnsTransports.end()) {
-        for (const auto& serverPair : netPair->second) {
-            if (serverPair.second == Validation::success) {
-                status.validatedServers.push_back(serverPair.first);
-            }
-        }
-    }
-
-    return status;
+ResolverController::getPrivateDnsStatus(unsigned netId) const {
+    return privateDnsConfiguration.getStatus(netId);
 }
 
 int ResolverController::clearDnsServers(unsigned netId) {
@@ -303,7 +390,7 @@
     if (DBG) {
         ALOGD("clearDnsServers netId = %u\n", netId);
     }
-    clearPrivateDnsProviders(netId);
+    privateDnsConfiguration.clear(netId);
     return 0;
 }
 
@@ -399,7 +486,7 @@
         return -EINVAL;
     }
 
-    const int err = setPrivateDnsConfiguration(netId, tlsServers, tlsName, tlsFingerprints);
+    const int err = privateDnsConfiguration.set(netId, tlsServers, tlsName, tlsFingerprints);
     if (err != 0) {
         return err;
     }
@@ -502,29 +589,8 @@
                     static_cast<unsigned>(params.min_samples),
                     static_cast<unsigned>(params.max_samples));
         }
-        {
-            std::lock_guard<std::mutex> guard(privateDnsLock);
-            const auto& mode = privateDnsModes.find(netId);
-            dw.println("Private DNS mode: %s", getPrivateDnsModeString(
-                    mode != privateDnsModes.end() ? mode->second : PrivateDnsMode::OFF));
-            const auto& netPair = privateDnsTransports.find(netId);
-            if (netPair == privateDnsTransports.end()) {
-                dw.println("No Private DNS servers configured");
-            } else {
-                const auto& tracker = netPair->second;
-                dw.println("Private DNS configuration (%zu entries)", tracker.size());
-                dw.incIndent();
-                for (const auto& kv : tracker) {
-                    const auto& server = kv.first;
-                    const auto status = kv.second;
-                    dw.println("%s name{%s} status{%s}",
-                            addrToString(&(server.ss)).c_str(),
-                            server.name.c_str(),
-                            validationStatusToString(status));
-                }
-                dw.decIndent();
-            }
-        }
+
+        privateDnsConfiguration.dump(dw, netId);
     }
     dw.decIndent();
 }