Change most PrivateDnsConfiguration methods to use ServerIdentity

Passing a DnsTlsServer can be confusing because it's not obvious whether
the DnsTlsServer is a copy or is owned by PrivateDnsConfiguration.

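This CL changes most of the methods to take a ServerIdentity instead.
The methods then look up the corresponding DnsTlsServer with the newly
added method getPrivateDns(), or getPrivateDnsLocked() when
mPrivateDnsLock is already held. Since the server is now looked up by
identity, recordPrivateDnsValidation() no longer checks whether the
server was changed during validation.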

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index a2a48b2..7fbad43 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -110,7 +110,7 @@
 
         if (needsValidation(server)) {
             updateServerState(identity, Validation::in_process, netId);
-            startValidation(server, netId, false);
+            startValidation(identity, netId, false);
         }
     }
 
@@ -145,7 +145,7 @@
 }
 
 base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
-                                                              const DnsTlsServer& server,
+                                                              const ServerIdentity& identity,
                                                               uint32_t mark) {
     std::lock_guard guard(mPrivateDnsLock);
 
@@ -159,40 +159,39 @@
         return Errorf("Private DNS setting is not opportunistic mode");
     }
 
-    auto netPair = mPrivateDnsTransports.find(netId);
-    if (netPair == mPrivateDnsTransports.end()) {
-        return Errorf("NetId not found in mPrivateDnsTransports");
+    auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) {
+        return result.error();
     }
 
-    auto& tracker = netPair->second;
-    const ServerIdentity identity = ServerIdentity(server);
-    auto it = tracker.find(identity);
-    if (it == tracker.end()) {
-        return Errorf("Server was removed");
-    }
+    const DnsTlsServer* target = result.value();
 
-    const DnsTlsServer& target = it->second;
+    if (!target->active()) return Errorf("Server is not active");
 
-    if (!target.active()) return Errorf("Server is not active");
-
-    if (target.validationState() != Validation::success) {
+    if (target->validationState() != Validation::success) {
         return Errorf("Server validation state mismatched");
     }
 
     // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
     // This is to protect validation from running on unexpected marks.
     // Validation should be associated with a mark gotten by system permission.
-    if (target.mark != mark) return Errorf("Socket mark mismatched");
+    if (target->mark != mark) return Errorf("Socket mark mismatched");
 
     updateServerState(identity, Validation::in_process, netId);
-    startValidation(target, netId, true);
+    startValidation(identity, netId, true);
     return {};
 }
 
-void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
-                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
-    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
-    std::thread validate_thread([this, server, netId, isRevalidation] {
+void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
+                                              bool isRevalidation) {
+    // This ensures that the thread sends probe at least once in case
+    // the server is removed before the thread starts running.
+    // TODO: consider moving these code to the thread.
+    const auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) return;
+    DnsTlsServer server = *result.value();
+
+    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
         setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
 
         // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@
                          << server.toIpString();
 
             const bool needs_reeval =
-                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
+                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
 
             if (!needs_reeval) {
                 break;
@@ -240,11 +239,11 @@
     validate_thread.detach();
 }
 
-void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
+void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                             unsigned netId, bool success) {
     LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
-               << netId << " for " << server.toIpString() << " with hostname {" << server.name
-               << "}";
+               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
+               << identity.provider << "}";
     // Send a validation event to NetdEventListenerService.
     const auto& listeners = ResolverEventReporter::getInstance().getListeners();
     if (listeners.empty()) {
@@ -252,15 +251,16 @@
                 << "Validation event not sent since no INetdEventListener receiver is available.";
     }
     for (const auto& it : listeners) {
-        it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
+        it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
+                                        success);
     }
 
     // Send a validation event to unsolicited event listeners.
     const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
     const PrivateDnsValidationEventParcel validationEvent = {
             .netId = static_cast<int32_t>(netId),
-            .ipAddress = server.toIpString(),
-            .hostname = server.name,
+            .ipAddress = identity.sockaddr.ip().toString(),
+            .hostname = identity.provider,
             .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
                                   : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
     };
@@ -269,11 +269,11 @@
     }
 }
 
-bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
-                                                         bool success, bool isRevalidation) {
+bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
+                                                         unsigned netId, bool success,
+                                                         bool isRevalidation) {
     constexpr bool NEEDS_REEVALUATION = true;
     constexpr bool DONT_REEVALUATE = false;
-    const ServerIdentity identity = ServerIdentity(server);
 
     std::lock_guard guard(mPrivateDnsLock);
 
@@ -303,23 +303,19 @@
     auto& tracker = netPair->second;
     auto serverPair = tracker.find(identity);
     if (serverPair == tracker.end()) {
-        LOG(WARNING) << "Server " << server.toIpString()
+        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                      << " was removed during private DNS validation";
         success = false;
         reevaluationStatus = DONT_REEVALUATE;
-    } else if (!(serverPair->second == server)) {
-        LOG(WARNING) << "Server " << server.toIpString()
-                     << " was changed during private DNS validation";
-        success = false;
-        reevaluationStatus = DONT_REEVALUATE;
     } else if (!serverPair->second.active()) {
-        LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
+        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
+                     << " was removed from the configuration";
         success = false;
         reevaluationStatus = DONT_REEVALUATE;
     }
 
     // Send private dns validation result to listeners.
-    sendPrivateDnsValidationEvent(server, netId, success);
+    sendPrivateDnsValidationEvent(identity, netId, success);
 
     if (success) {
         updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@
 
 void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                 uint32_t netId) {
-    auto netPair = mPrivateDnsTransports.find(netId);
-    if (netPair == mPrivateDnsTransports.end()) {
+    const auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) {
         notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
         return;
     }
 
-    auto& tracker = netPair->second;
-    if (tracker.find(identity) == tracker.end()) {
-        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
-        return;
-    }
+    auto* server = result.value();
 
-    tracker[identity].setValidationState(state);
+    server->setValidationState(state);
     notifyValidationStateUpdate(identity.sockaddr, state, netId);
 
     RecordEntry record(netId, identity, state);
@@ -373,6 +365,38 @@
     return false;
 }
 
+// Thread-safe wrapper around getPrivateDnsLocked(): acquires mPrivateDnsLock for the lookup.
+// NOTE(review): the lock is released on return, but the returned pointer still refers to the
+// DnsTlsServer stored in mPrivateDnsTransports. Dereferencing it is only safe while no other
+// thread can modify that map; consider returning a copy if callers need it beyond that.
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
+                                                                   unsigned netId) {
+    std::lock_guard guard(mPrivateDnsLock);
+    return getPrivateDnsLocked(identity, netId);
+}
+
+// Looks up the DnsTlsServer tracked for |identity| on |netId|. The caller must hold
+// mPrivateDnsLock.
+// Returns a non-owning pointer to the entry stored in mPrivateDnsTransports (not a copy), so
+// updates made through it, e.g. setValidationState(), are visible to this configuration.
+// Returns an error if |netId| has no private DNS transports or no server matches |identity|.
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
+        const ServerIdentity& identity, unsigned netId) {
+    auto netPair = mPrivateDnsTransports.find(netId);
+    if (netPair == mPrivateDnsTransports.end()) {
+        return Errorf("Failed to get private DNS: netId {} not found", netId);
+    }
+
+    auto& tracker = netPair->second;
+    auto iter = tracker.find(identity);
+    if (iter == tracker.end()) {
+        return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
+                      identity.provider);
+    }
+
+    return &iter->second;
+}
+
 void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
     std::lock_guard guard(mPrivateDnsLock);
     mObserver = observer;