Extend DnsTlsServer to store validation state
This change also fixes a bug in PrivateDnsConfiguration which DoT
servers not valid for the network might somehow be counted as
validated servers. For instance, if a validation for DoT server
finishes after the server is removed, the server is mistakenly
deemed as a validated server.
Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: Idee1f34f59dce1451b7b3e87fd20e6795af883ba
diff --git a/DnsTlsServer.h b/DnsTlsServer.h
index 8caa962..18c1ddd 100644
--- a/DnsTlsServer.h
+++ b/DnsTlsServer.h
@@ -27,6 +27,9 @@
namespace android {
namespace net {
+// Validation status of a DNS over TLS server (on a specific netId).
+enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
+
// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
struct DnsTlsServer {
@@ -37,17 +40,21 @@
DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}
// The server location, including IP and port.
+ // TODO: make it const.
sockaddr_storage ss = {};
// The server's hostname. If this string is nonempty, the server must present a
// certificate that indicates this name and has a valid chain to a trusted root CA.
+ // TODO: make it const.
std::string name;
// The certificate of the CA that signed the server's certificate.
// It is used to store temporary test CA certificate for internal tests.
+ // TODO: make it const.
std::string certificate;
// Placeholder. More protocols might be defined in the future.
+ // TODO: make it const.
int protocol = IPPROTO_TCP;
// Exact comparison of DnsTlsServer objects
@@ -55,6 +62,13 @@
bool operator==(const DnsTlsServer& other) const;
bool wasExplicitlyConfigured() const;
+
+ Validation validationState() const { return mValidation; }
+ void setValidationState(Validation val) { mValidation = val; }
+
+ private:
+ // State, unrelated to the comparison of DnsTlsServer objects.
+ Validation mValidation = Validation::unknown_server;
};
// This comparison only checks the IP address. It ignores ports, names, and fingerprints.
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index 53aa56d..13bef15 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -69,7 +69,7 @@
<< ", " << servers.size() << ", " << name << ")";
// Parse the list of servers that has been passed in
- std::set<DnsTlsServer> tlsServers;
+ PrivateDnsTracker tmp;
for (const auto& s : servers) {
sockaddr_storage parsed;
if (!parseServer(s.c_str(), &parsed)) {
@@ -78,13 +78,13 @@
DnsTlsServer server(parsed);
server.name = name;
server.certificate = caCert;
- tlsServers.insert(server);
+ tmp[ServerIdentity(server)] = server;
}
std::lock_guard guard(mPrivateDnsLock);
if (!name.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
- } else if (!tlsServers.empty()) {
+ } else if (!tmp.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
} else {
mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
@@ -112,7 +112,7 @@
// Remove any servers from the tracker that are not in |servers| exactly.
for (auto it = tracker.begin(); it != tracker.end();) {
- if (tlsServers.count(it->first) == 0) {
+ if (tmp.find(it->first) == tmp.end()) {
it = tracker.erase(it);
} else {
++it;
@@ -120,7 +120,7 @@
}
// Add any new or changed servers to the tracker, and initiate async checks for them.
- for (const auto& server : tlsServers) {
+ for (const auto& [identity, server] : tmp) {
if (needsValidation(tracker, server)) {
// This is temporarily required. Consider the following scenario, for example,
// Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
@@ -133,7 +133,10 @@
//
// If we didn't add servers to tracker before needValidateThread(), tracker would
// become empty. We would report s1 validation failed.
- tracker[server] = Validation::in_process;
+ if (tracker.find(identity) == tracker.end()) {
+ tracker[identity] = server;
+ }
+ tracker[identity].setValidationState(Validation::in_process);
LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
<< netId << ". Tracker now has size " << tracker.size();
// This judge must be after "tracker[server] = Validation::in_process;"
@@ -141,7 +144,7 @@
continue;
}
- updateServerState(server, Validation::in_process, netId);
+ updateServerState(identity, Validation::in_process, netId);
startValidation(server, netId, mark);
}
}
@@ -159,8 +162,8 @@
const auto netPair = mPrivateDnsTransports.find(netId);
if (netPair != mPrivateDnsTransports.end()) {
- for (const auto& serverPair : netPair->second) {
- status.serversMap.emplace(serverPair.first, serverPair.second);
+ for (const auto& [_, server] : netPair->second) {
+ status.serversMap.emplace(server, server.validationState());
}
}
@@ -227,20 +230,21 @@
bool success) {
constexpr bool NEEDS_REEVALUATION = true;
constexpr bool DONT_REEVALUATE = false;
+ const ServerIdentity identity = ServerIdentity(server);
std::lock_guard guard(mPrivateDnsLock);
auto netPair = mPrivateDnsTransports.find(netId);
if (netPair == mPrivateDnsTransports.end()) {
LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
- maybeNotifyObserver(server, Validation::fail, netId);
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
return DONT_REEVALUATE;
}
const auto mode = mPrivateDnsModes.find(netId);
if (mode == mPrivateDnsModes.end()) {
LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
- maybeNotifyObserver(server, Validation::fail, netId);
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
return DONT_REEVALUATE;
}
const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -249,16 +253,13 @@
(success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
auto& tracker = netPair->second;
- auto serverPair = tracker.find(server);
+ auto serverPair = tracker.find(identity);
if (serverPair == tracker.end()) {
- // TODO: Consider not adding this server to the tracker since this server is not expected
- // to be one of the private DNS servers for this network now. This could prevent this
- // server from being included when dumping status.
LOG(WARNING) << "Server " << addrToString(&server.ss)
<< " was removed during private DNS validation";
success = false;
reevaluationStatus = DONT_REEVALUATE;
- } else if (!(serverPair->first == server)) {
+ } else if (!(serverPair->second == server)) {
// TODO: It doesn't seem correct to overwrite the tracker entry for
// |server| down below in this circumstance... Fix this.
LOG(WARNING) << "Server " << addrToString(&server.ss)
@@ -282,14 +283,14 @@
}
if (success) {
- updateServerState(server, Validation::success, netId);
+ updateServerState(identity, Validation::success, netId);
} else {
// Validation failure is expected if a user is on a captive portal.
// TODO: Trigger a second validation attempt after captive portal login
// succeeds.
const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
: Validation::fail;
- updateServerState(server, result, netId);
+ updateServerState(identity, result, netId);
}
LOG(WARNING) << "Validation " << (success ? "success" : "failed");
@@ -324,15 +325,22 @@
}
}
-void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state,
+void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
uint32_t netId) {
auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair != mPrivateDnsTransports.end()) {
- auto& tracker = netPair->second;
- tracker[server] = state;
+ if (netPair == mPrivateDnsTransports.end()) {
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
+ return;
}
- maybeNotifyObserver(server, state, netId);
+ auto& tracker = netPair->second;
+ if (tracker.find(identity) == tracker.end()) {
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
+ return;
+ }
+
+ tracker[identity].setValidationState(state);
+ maybeNotifyObserver(identity.ip.toString(), state, netId);
}
void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
@@ -353,8 +361,9 @@
bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
const DnsTlsServer& server) {
- const auto& iter = tracker.find(server);
- return (iter == tracker.end()) || (iter->second == Validation::fail);
+ const ServerIdentity identity = ServerIdentity(server);
+ const auto& iter = tracker.find(identity);
+ return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
}
void PrivateDnsConfiguration::setObserver(Observer* observer) {
@@ -362,10 +371,10 @@
mObserver = observer;
}
-void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
- uint32_t netId) const {
+void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp,
+ Validation validation, uint32_t netId) const {
if (mObserver) {
- mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId);
+ mObserver->onValidationStateUpdate(serverIp, validation, netId);
}
}
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index ffaae46..341c691 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -22,6 +22,7 @@
#include <vector>
#include <android-base/thread_annotations.h>
+#include <netdutils/InternetAddresses.h>
#include "DnsTlsServer.h"
@@ -31,11 +32,10 @@
// The DNS over TLS mode on a specific netId.
enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT };
-// Validation status of a DNS over TLS server (on a specific netId).
-enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
-
struct PrivateDnsStatus {
PrivateDnsMode mode;
+
+ // TODO: change the type to std::vector<DnsTlsServer>.
std::map<DnsTlsServer, Validation, AddressComparator> serversMap;
std::list<DnsTlsServer> validatedServers() const {
@@ -65,8 +65,26 @@
void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
+ struct ServerIdentity {
+ const netdutils::IPAddress ip;
+ const std::string name;
+ const int protocol;
+
+ explicit ServerIdentity(const DnsTlsServer& server)
+ : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()),
+ name(server.name),
+ protocol(server.protocol) {}
+
+ bool operator<(const ServerIdentity& other) const {
+ return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol);
+ }
+ bool operator==(const ServerIdentity& other) const {
+ return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol);
+ }
+ };
+
private:
- typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
+ typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;
typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker;
PrivateDnsConfiguration() = default;
@@ -88,7 +106,7 @@
bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
REQUIRES(mPrivateDnsLock);
- void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId)
+ void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
REQUIRES(mPrivateDnsLock);
std::mutex mPrivateDnsLock;
@@ -99,19 +117,20 @@
std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);
// For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
- // about to begin or 2) when a validation finishes.
- // WARNING: The Observer is notified while the lock is being held. Be careful not to call any
- // method of PrivateDnsConfiguration from the observer.
+ // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
+ // or when the network has been destroyed, |validation| will be Validation::fail.
+ // WARNING: The Observer is notified while the lock is being held. Be careful not to call
+ // any method of PrivateDnsConfiguration from the observer.
// TODO: fix the reentrancy problem.
class Observer {
public:
virtual ~Observer(){};
- virtual void onValidationStateUpdate(const std::string& server, Validation validation,
+ virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation,
uint32_t netId) = 0;
};
void setObserver(Observer* observer);
- void maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
+ void maybeNotifyObserver(const std::string& serverIp, Validation validation,
uint32_t netId) const REQUIRES(mPrivateDnsLock);
Observer* mObserver GUARDED_BY(mPrivateDnsLock);
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
index 80fd4bc..f290277 100644
--- a/PrivateDnsConfigurationTest.cpp
+++ b/PrivateDnsConfigurationTest.cpp
@@ -63,7 +63,8 @@
class MockObserver : public PrivateDnsConfiguration::Observer {
public:
MOCK_METHOD(void, onValidationStateUpdate,
- (const std::string& server, Validation validation, uint32_t netId), (override));
+ (const std::string& serverIp, Validation validation, uint32_t netId),
+ (override));
std::map<std::string, Validation> getServerStateMap() const {
std::lock_guard guard(lock);
@@ -172,6 +173,11 @@
backend.setDeferredResp(false);
ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+
+ // kServer1 is not a present server and thus should not be available from
+ // PrivateDnsConfiguration::getStatus().
+ mObserver.removeFromServerStateMap(kServer1);
+
expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
@@ -218,6 +224,36 @@
expectStatus();
}
+TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
+ using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
+
+ DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
+ server.name = "dns.example.com";
+ server.protocol = 1;
+
+ // Different IP address (port is ignored).
+ DnsTlsServer other = server;
+ EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
+ other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
+ EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
+ other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
+ EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
+
+ // Different provider hostname.
+ other = server;
+ EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
+ other.name = "other.example.com";
+ EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
+ other.name = "";
+ EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
+
+ // Different protocol.
+ other = server;
+ EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
+ other.protocol++;
+ EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
+}
+
// TODO: add ValidationFail_Strict test.
} // namespace android::net
diff --git a/resolv_tls_unit_test.cpp b/resolv_tls_unit_test.cpp
index 78f98cb..919ff6f 100644
--- a/resolv_tls_unit_test.cpp
+++ b/resolv_tls_unit_test.cpp
@@ -800,6 +800,18 @@
EXPECT_FALSE(s2 == s1);
}
+void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
+ EXPECT_TRUE(s1 == s1);
+ EXPECT_TRUE(s2 == s2);
+ EXPECT_TRUE(isAddressEqual(s1, s1));
+ EXPECT_TRUE(isAddressEqual(s2, s2));
+
+ EXPECT_FALSE(s1 < s2);
+ EXPECT_FALSE(s2 < s1);
+ EXPECT_TRUE(s1 == s2);
+ EXPECT_TRUE(s2 == s1);
+}
+
class ServerTest : public BaseTest {};
TEST_F(ServerTest, IPv4) {
@@ -873,6 +885,18 @@
EXPECT_TRUE(s2.wasExplicitlyConfigured());
}
+TEST_F(ServerTest, State) {
+ DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
+ checkEqual(s1, s2);
+ s1.setValidationState(Validation::success);
+ checkEqual(s1, s2);
+ s2.setValidationState(Validation::fail);
+ checkEqual(s1, s2);
+
+ EXPECT_EQ(s1.validationState(), Validation::success);
+ EXPECT_EQ(s2.validationState(), Validation::fail);
+}
+
TEST(QueryMapTest, Basic) {
DnsTlsQueryMap map;