Change most of PrivateDnsConfiguration methods to use ServerIdentity

Passing DnsTlsServer might be confusing because it's not straightforward
to know if a DnsTlsServer is a copy or onwed by PrivateDnsConfiguration.

This CL changes most of the methods to use ServerIdentity. The methods
can then get the corresponding DnsTlsServer by the new added method
getPrivateDns().

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp
index d14fdeb..ca577e7 100644
--- a/DnsTlsDispatcher.cpp
+++ b/DnsTlsDispatcher.cpp
@@ -214,8 +214,8 @@
             // happens, the xport will be marked as unusable and DoT queries won't be sent to
             // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
             // a new xport will be created.
-            const auto result =
-                    PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
+            const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
+                    netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
             LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                          << std::hex << mark << ", "
                          << (result.ok() ? "succeeded" : "failed: " + result.error().message());
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index a2a48b2..7fbad43 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -110,7 +110,7 @@
 
         if (needsValidation(server)) {
             updateServerState(identity, Validation::in_process, netId);
-            startValidation(server, netId, false);
+            startValidation(identity, netId, false);
         }
     }
 
@@ -145,7 +145,7 @@
 }
 
 base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
-                                                              const DnsTlsServer& server,
+                                                              const ServerIdentity& identity,
                                                               uint32_t mark) {
     std::lock_guard guard(mPrivateDnsLock);
 
@@ -159,40 +159,39 @@
         return Errorf("Private DNS setting is not opportunistic mode");
     }
 
-    auto netPair = mPrivateDnsTransports.find(netId);
-    if (netPair == mPrivateDnsTransports.end()) {
-        return Errorf("NetId not found in mPrivateDnsTransports");
+    auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) {
+        return result.error();
     }
 
-    auto& tracker = netPair->second;
-    const ServerIdentity identity = ServerIdentity(server);
-    auto it = tracker.find(identity);
-    if (it == tracker.end()) {
-        return Errorf("Server was removed");
-    }
+    const DnsTlsServer* target = result.value();
 
-    const DnsTlsServer& target = it->second;
+    if (!target->active()) return Errorf("Server is not active");
 
-    if (!target.active()) return Errorf("Server is not active");
-
-    if (target.validationState() != Validation::success) {
+    if (target->validationState() != Validation::success) {
         return Errorf("Server validation state mismatched");
     }
 
     // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
     // This is to protect validation from running on unexpected marks.
     // Validation should be associated with a mark gotten by system permission.
-    if (target.mark != mark) return Errorf("Socket mark mismatched");
+    if (target->mark != mark) return Errorf("Socket mark mismatched");
 
     updateServerState(identity, Validation::in_process, netId);
-    startValidation(target, netId, true);
+    startValidation(identity, netId, true);
     return {};
 }
 
-void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
-                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
-    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
-    std::thread validate_thread([this, server, netId, isRevalidation] {
+void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
+                                              bool isRevalidation) {
+    // This ensures that the thread sends probe at least once in case
+    // the server is removed before the thread starts running.
+    // TODO: consider moving these code to the thread.
+    const auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) return;
+    DnsTlsServer server = *result.value();
+
+    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
         setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
 
         // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@
                          << server.toIpString();
 
             const bool needs_reeval =
-                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
+                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
 
             if (!needs_reeval) {
                 break;
@@ -240,11 +239,11 @@
     validate_thread.detach();
 }
 
-void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
+void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                             unsigned netId, bool success) {
     LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
-               << netId << " for " << server.toIpString() << " with hostname {" << server.name
-               << "}";
+               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
+               << identity.provider << "}";
     // Send a validation event to NetdEventListenerService.
     const auto& listeners = ResolverEventReporter::getInstance().getListeners();
     if (listeners.empty()) {
@@ -252,15 +251,16 @@
                 << "Validation event not sent since no INetdEventListener receiver is available.";
     }
     for (const auto& it : listeners) {
-        it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
+        it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
+                                        success);
     }
 
     // Send a validation event to unsolicited event listeners.
     const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
     const PrivateDnsValidationEventParcel validationEvent = {
             .netId = static_cast<int32_t>(netId),
-            .ipAddress = server.toIpString(),
-            .hostname = server.name,
+            .ipAddress = identity.sockaddr.ip().toString(),
+            .hostname = identity.provider,
             .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
                                   : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
     };
@@ -269,11 +269,11 @@
     }
 }
 
-bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
-                                                         bool success, bool isRevalidation) {
+bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
+                                                         unsigned netId, bool success,
+                                                         bool isRevalidation) {
     constexpr bool NEEDS_REEVALUATION = true;
     constexpr bool DONT_REEVALUATE = false;
-    const ServerIdentity identity = ServerIdentity(server);
 
     std::lock_guard guard(mPrivateDnsLock);
 
@@ -303,23 +303,19 @@
     auto& tracker = netPair->second;
     auto serverPair = tracker.find(identity);
     if (serverPair == tracker.end()) {
-        LOG(WARNING) << "Server " << server.toIpString()
+        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                      << " was removed during private DNS validation";
         success = false;
         reevaluationStatus = DONT_REEVALUATE;
-    } else if (!(serverPair->second == server)) {
-        LOG(WARNING) << "Server " << server.toIpString()
-                     << " was changed during private DNS validation";
-        success = false;
-        reevaluationStatus = DONT_REEVALUATE;
     } else if (!serverPair->second.active()) {
-        LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
+        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
+                     << " was removed from the configuration";
         success = false;
         reevaluationStatus = DONT_REEVALUATE;
     }
 
     // Send private dns validation result to listeners.
-    sendPrivateDnsValidationEvent(server, netId, success);
+    sendPrivateDnsValidationEvent(identity, netId, success);
 
     if (success) {
         updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@
 
 void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                 uint32_t netId) {
-    auto netPair = mPrivateDnsTransports.find(netId);
-    if (netPair == mPrivateDnsTransports.end()) {
+    const auto result = getPrivateDnsLocked(identity, netId);
+    if (!result.ok()) {
         notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
         return;
     }
 
-    auto& tracker = netPair->second;
-    if (tracker.find(identity) == tracker.end()) {
-        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
-        return;
-    }
+    auto* server = result.value();
 
-    tracker[identity].setValidationState(state);
+    server->setValidationState(state);
     notifyValidationStateUpdate(identity.sockaddr, state, netId);
 
     RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@
     return false;
 }
 
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
+                                                                   unsigned netId) {
+    std::lock_guard guard(mPrivateDnsLock);
+    return getPrivateDnsLocked(identity, netId);
+}
+
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
+        const ServerIdentity& identity, unsigned netId) {
+    auto netPair = mPrivateDnsTransports.find(netId);
+    if (netPair == mPrivateDnsTransports.end()) {
+        return Errorf("Failed to get private DNS: netId {} not found", netId);
+    }
+
+    auto iter = netPair->second.find(identity);
+    if (iter == netPair->second.end()) {
+        return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
+                      identity.provider);
+    }
+
+    return &iter->second;
+}
+
 void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
     std::lock_guard guard(mPrivateDnsLock);
     mObserver = observer;
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index d2a165c..1fce07c 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -33,6 +33,7 @@
 namespace android {
 namespace net {
 
+// TODO: decouple the dependency of DnsTlsServer.
 struct PrivateDnsStatus {
     PrivateDnsMode mode;
 
@@ -53,24 +54,6 @@
 
 class PrivateDnsConfiguration {
   public:
-    // The only instance of PrivateDnsConfiguration.
-    static PrivateDnsConfiguration& getInstance() {
-        static PrivateDnsConfiguration instance;
-        return instance;
-    }
-
-    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
-            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
-
-    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
-
-    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
-
-    // Request |server| to be revalidated on a connection tagged with |mark|.
-    // Returns a Result to indicate if the request is accepted.
-    base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
-            EXCLUDES(mPrivateDnsLock);
-
     struct ServerIdentity {
         const netdutils::IPSockAddr sockaddr;
         const std::string provider;
@@ -86,6 +69,24 @@
         }
     };
 
+    // The only instance of PrivateDnsConfiguration.
+    static PrivateDnsConfiguration& getInstance() {
+        static PrivateDnsConfiguration instance;
+        return instance;
+    }
+
+    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
+            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
+
+    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
+
+    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
+
+    // Request the server to be revalidated on a connection tagged with |mark|.
+    // Returns a Result to indicate if the request is accepted.
+    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
+                                         uint32_t mark) EXCLUDES(mPrivateDnsLock);
+
     void setObserver(PrivateDnsValidationObserver* observer);
 
     void dump(netdutils::DumpWriter& dw) const;
@@ -97,23 +98,31 @@
 
     // Launchs a thread to run the validation for |server| on the network |netId|.
     // |isRevalidation| is true if this call is due to a revalidation request.
-    void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
+    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
             REQUIRES(mPrivateDnsLock);
 
-    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
+    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
                                     bool isRevalidation) EXCLUDES(mPrivateDnsLock);
 
-    void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
+    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
             REQUIRES(mPrivateDnsLock);
 
     // Decide if a validation for |server| is needed. Note that servers that have failed
     // multiple validation attempts but for which there is still a validating
     // thread running are marked as being Validation::in_process.
+    // TODO: decouple the dependency of DnsTlsServer.
     bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
 
     void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
             REQUIRES(mPrivateDnsLock);
 
+    // For testing.
+    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
+            EXCLUDES(mPrivateDnsLock);
+
+    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
+            REQUIRES(mPrivateDnsLock);
+
     mutable std::mutex mPrivateDnsLock;
     std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
 
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
index b6196a7..fa8371c 100644
--- a/PrivateDnsConfigurationTest.cpp
+++ b/PrivateDnsConfigurationTest.cpp
@@ -28,6 +28,8 @@
 
 class PrivateDnsConfigurationTest : public ::testing::Test {
   public:
+    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
+
     static void SetUpTestSuite() {
         // stopServer() will be called in their destructor.
         ASSERT_TRUE(tls1.startServer());
@@ -100,6 +102,10 @@
         return (serverStateMap == mObserver.getServerStateMap());
     }
 
+    bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
+        return mPdc.getPrivateDns(identity, netId).ok();
+    }
+
     static constexpr uint32_t kNetId = 30;
     static constexpr uint32_t kMark = 30;
     static constexpr char kBackend[] = "127.0.2.1";
@@ -230,8 +236,6 @@
 }
 
 TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
-    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
-
     DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
     server.name = "dns.example.com";
 
@@ -254,6 +258,7 @@
 
 TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
     const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
+    const ServerIdentity identity(server);
 
     testing::InSequence seq;
 
@@ -281,18 +286,18 @@
             EXPECT_CALL(mObserver,
                         onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
-            EXPECT_TRUE(mPdc.requestValidation(kNetId, server, kMark).ok());
+            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
         } else if (config == "IN_PROGRESS") {
             EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
-            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
+            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
         } else if (config == "FAIL") {
-            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
+            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
         }
 
         // Resending the same request or requesting nonexistent servers are denied.
-        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
-        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1).ok());
-        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark).ok());
+        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
+        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
+        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());
 
         // Reset the test state.
         backend.setDeferredResp(false);
@@ -306,6 +311,26 @@
     }
 }
 
+TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
+    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
+    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
+
+    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
+    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
+
+    // Suppress the warning.
+    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);
+
+    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+    EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
+    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
+    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));
+
+    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+}
+
 // TODO: add ValidationFail_Strict test.
 
 }  // namespace android::net