Support sending validation request to PrivateDnsConfiguration

Extend PrivateDnsConfiguration to support validation request.

The request is deniable. If the request is denied, no validation
starts. Callers can know if requests are accepted by the return
value of the call.

This change also extends DnsTlsServer to store the mark used by
validation, which helps on preventing running validation with
an unexpected socket mark and resulting in updating wrong validation
state.

Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ib92f6b4dd94ed426bf28cb9756d1514e34f16140
diff --git a/DnsTlsServer.h b/DnsTlsServer.h
index 6a72d88..9750dff 100644
--- a/DnsTlsServer.h
+++ b/DnsTlsServer.h
@@ -74,6 +74,12 @@
     Validation validationState() const { return mValidation; }
     void setValidationState(Validation val) { mValidation = val; }
 
+    // The socket mark used for validation.
+    // Note that the mark of a connection to which the DnsResolver sends app's DNS requests can
+    // be different.
+    // TODO: make it const.
+    uint32_t mark = 0;
+
     // Return whether or not the server can be used for a network. It depends on
     // the resolver configuration.
     bool active() const { return mActive; }
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index 47969b3..8fbd573 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -70,6 +70,7 @@
         DnsTlsServer server(parsed);
         server.name = name;
         server.certificate = caCert;
+        server.mark = mark;
         tmp[ServerIdentity(server)] = server;
     }
 
@@ -140,6 +141,37 @@
     mPrivateDnsTransports.erase(netId);
 }
 
+bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server,
+                                                uint32_t mark) {
+    std::lock_guard guard(mPrivateDnsLock);
+    auto netPair = mPrivateDnsTransports.find(netId);
+    if (netPair == mPrivateDnsTransports.end()) {
+        return false;
+    }
+
+    auto& tracker = netPair->second;
+    const ServerIdentity identity = ServerIdentity(server);
+    auto it = tracker.find(identity);
+    if (it == tracker.end()) {
+        return false;
+    }
+
+    const DnsTlsServer& target = it->second;
+
+    if (!target.active()) return false;
+
+    if (target.validationState() != Validation::success) return false;
+
+    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
+    // This is to protect validation from running on unexpected marks.
+    // Validation should be associated with a mark gotten by system permission.
+    if (target.mark != mark) return false;
+
+    updateServerState(identity, Validation::in_process, netId);
+    startValidation(target, netId, mark);
+    return true;
+}
+
 void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                               uint32_t mark) REQUIRES(mPrivateDnsLock) {
     // Note that capturing |server| and |netId| in this lambda create copies.
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index 722ed71..20b4fa2 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -65,6 +65,11 @@
 
     void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
 
+    // Request |server| to be revalidated on a connection tagged with |mark|.
+    // Return true if the request is accepted; otherwise, return false.
+    bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
+            EXCLUDES(mPrivateDnsLock);
+
     struct ServerIdentity {
         const netdutils::IPAddress ip;
         const std::string name;
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
index f290277..5d7b4f9 100644
--- a/PrivateDnsConfigurationTest.cpp
+++ b/PrivateDnsConfigurationTest.cpp
@@ -254,6 +254,56 @@
     EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
 }
 
+TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
+    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
+
+    testing::InSequence seq;
+
+    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
+        SCOPED_TRACE(config);
+
+        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+        if (config == "SUCCESS") {
+            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
+        } else if (config == "IN_PROGRESS") {
+            backend.setDeferredResp(true);
+        } else {
+            // config = "FAIL"
+            ASSERT_TRUE(backend.stopServer());
+            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
+        }
+        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+        // Wait until the validation state is transitioned.
+        const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
+        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));
+
+        bool requestAccepted = false;
+        if (config == "SUCCESS") {
+            EXPECT_CALL(mObserver,
+                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
+            requestAccepted = true;
+        } else if (config == "IN_PROGRESS") {
+            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
+        }
+
+        EXPECT_EQ(mPdc.requestValidation(kNetId, server, kMark), requestAccepted);
+
+        // Resending the same request or requesting nonexistent servers are denied.
+        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark));
+        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1));
+        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark));
+
+        // Reset the test state.
+        backend.setDeferredResp(false);
+        backend.startServer();
+        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+        mPdc.clear(kNetId);
+    }
+}
+
 // TODO: add ValidationFail_Strict test.
 
 }  // namespace android::net