Change most of PrivateDnsConfiguration methods to use ServerIdentity
Passing DnsTlsServer might be confusing because it's not straightforward
to know if a DnsTlsServer is a copy or onwed by PrivateDnsConfiguration.
This CL changes most of the methods to use ServerIdentity. The methods
can then get the corresponding DnsTlsServer by the new added method
getPrivateDns().
Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp
index d14fdeb..ca577e7 100644
--- a/DnsTlsDispatcher.cpp
+++ b/DnsTlsDispatcher.cpp
@@ -214,8 +214,8 @@
// happens, the xport will be marked as unusable and DoT queries won't be sent to
// it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
// a new xport will be created.
- const auto result =
- PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
+ const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
+ netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
<< std::hex << mark << ", "
<< (result.ok() ? "succeeded" : "failed: " + result.error().message());
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index a2a48b2..7fbad43 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -110,7 +110,7 @@
if (needsValidation(server)) {
updateServerState(identity, Validation::in_process, netId);
- startValidation(server, netId, false);
+ startValidation(identity, netId, false);
}
}
@@ -145,7 +145,7 @@
}
base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
- const DnsTlsServer& server,
+ const ServerIdentity& identity,
uint32_t mark) {
std::lock_guard guard(mPrivateDnsLock);
@@ -159,40 +159,39 @@
return Errorf("Private DNS setting is not opportunistic mode");
}
- auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair == mPrivateDnsTransports.end()) {
- return Errorf("NetId not found in mPrivateDnsTransports");
+ auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) {
+ return result.error();
}
- auto& tracker = netPair->second;
- const ServerIdentity identity = ServerIdentity(server);
- auto it = tracker.find(identity);
- if (it == tracker.end()) {
- return Errorf("Server was removed");
- }
+ const DnsTlsServer* target = result.value();
- const DnsTlsServer& target = it->second;
+ if (!target->active()) return Errorf("Server is not active");
- if (!target.active()) return Errorf("Server is not active");
-
- if (target.validationState() != Validation::success) {
+ if (target->validationState() != Validation::success) {
return Errorf("Server validation state mismatched");
}
// Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
// This is to protect validation from running on unexpected marks.
// Validation should be associated with a mark gotten by system permission.
- if (target.mark != mark) return Errorf("Socket mark mismatched");
+ if (target->mark != mark) return Errorf("Socket mark mismatched");
updateServerState(identity, Validation::in_process, netId);
- startValidation(target, netId, true);
+ startValidation(identity, netId, true);
return {};
}
-void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
- bool isRevalidation) REQUIRES(mPrivateDnsLock) {
- // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
- std::thread validate_thread([this, server, netId, isRevalidation] {
+void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
+ bool isRevalidation) {
+ // This ensures that the thread sends probe at least once in case
+ // the server is removed before the thread starts running.
+ // TODO: consider moving these code to the thread.
+ const auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) return;
+ DnsTlsServer server = *result.value();
+
+ std::thread validate_thread([this, identity, server, netId, isRevalidation] {
setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
// cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@
<< server.toIpString();
const bool needs_reeval =
- this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
+ this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
if (!needs_reeval) {
break;
@@ -240,11 +239,11 @@
validate_thread.detach();
}
-void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
+void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
unsigned netId, bool success) {
LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
- << netId << " for " << server.toIpString() << " with hostname {" << server.name
- << "}";
+ << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
+ << identity.provider << "}";
// Send a validation event to NetdEventListenerService.
const auto& listeners = ResolverEventReporter::getInstance().getListeners();
if (listeners.empty()) {
@@ -252,15 +251,16 @@
<< "Validation event not sent since no INetdEventListener receiver is available.";
}
for (const auto& it : listeners) {
- it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
+ it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
+ success);
}
// Send a validation event to unsolicited event listeners.
const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
const PrivateDnsValidationEventParcel validationEvent = {
.netId = static_cast<int32_t>(netId),
- .ipAddress = server.toIpString(),
- .hostname = server.name,
+ .ipAddress = identity.sockaddr.ip().toString(),
+ .hostname = identity.provider,
.validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
: IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
};
@@ -269,11 +269,11 @@
}
}
-bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
- bool success, bool isRevalidation) {
+bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
+ unsigned netId, bool success,
+ bool isRevalidation) {
constexpr bool NEEDS_REEVALUATION = true;
constexpr bool DONT_REEVALUATE = false;
- const ServerIdentity identity = ServerIdentity(server);
std::lock_guard guard(mPrivateDnsLock);
@@ -303,23 +303,19 @@
auto& tracker = netPair->second;
auto serverPair = tracker.find(identity);
if (serverPair == tracker.end()) {
- LOG(WARNING) << "Server " << server.toIpString()
+ LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
<< " was removed during private DNS validation";
success = false;
reevaluationStatus = DONT_REEVALUATE;
- } else if (!(serverPair->second == server)) {
- LOG(WARNING) << "Server " << server.toIpString()
- << " was changed during private DNS validation";
- success = false;
- reevaluationStatus = DONT_REEVALUATE;
} else if (!serverPair->second.active()) {
- LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
+ LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
+ << " was removed from the configuration";
success = false;
reevaluationStatus = DONT_REEVALUATE;
}
// Send private dns validation result to listeners.
- sendPrivateDnsValidationEvent(server, netId, success);
+ sendPrivateDnsValidationEvent(identity, netId, success);
if (success) {
updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
uint32_t netId) {
- auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair == mPrivateDnsTransports.end()) {
+ const auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) {
notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
return;
}
- auto& tracker = netPair->second;
- if (tracker.find(identity) == tracker.end()) {
- notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
- return;
- }
+ auto* server = result.value();
- tracker[identity].setValidationState(state);
+ server->setValidationState(state);
notifyValidationStateUpdate(identity.sockaddr, state, netId);
RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@
return false;
}
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
+ unsigned netId) {
+ std::lock_guard guard(mPrivateDnsLock);
+ return getPrivateDnsLocked(identity, netId);
+}
+
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
+ const ServerIdentity& identity, unsigned netId) {
+ auto netPair = mPrivateDnsTransports.find(netId);
+ if (netPair == mPrivateDnsTransports.end()) {
+ return Errorf("Failed to get private DNS: netId {} not found", netId);
+ }
+
+ auto iter = netPair->second.find(identity);
+ if (iter == netPair->second.end()) {
+ return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
+ identity.provider);
+ }
+
+ return &iter->second;
+}
+
void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
std::lock_guard guard(mPrivateDnsLock);
mObserver = observer;
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index d2a165c..1fce07c 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -33,6 +33,7 @@
namespace android {
namespace net {
+// TODO: decouple the dependency of DnsTlsServer.
struct PrivateDnsStatus {
PrivateDnsMode mode;
@@ -53,24 +54,6 @@
class PrivateDnsConfiguration {
public:
- // The only instance of PrivateDnsConfiguration.
- static PrivateDnsConfiguration& getInstance() {
- static PrivateDnsConfiguration instance;
- return instance;
- }
-
- int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
- const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
-
- PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
-
- void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
-
- // Request |server| to be revalidated on a connection tagged with |mark|.
- // Returns a Result to indicate if the request is accepted.
- base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
- EXCLUDES(mPrivateDnsLock);
-
struct ServerIdentity {
const netdutils::IPSockAddr sockaddr;
const std::string provider;
@@ -86,6 +69,24 @@
}
};
+ // The only instance of PrivateDnsConfiguration.
+ static PrivateDnsConfiguration& getInstance() {
+ static PrivateDnsConfiguration instance;
+ return instance;
+ }
+
+ int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
+ const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
+
+ PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
+
+ void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
+
+ // Request the server to be revalidated on a connection tagged with |mark|.
+ // Returns a Result to indicate if the request is accepted.
+ base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
+ uint32_t mark) EXCLUDES(mPrivateDnsLock);
+
void setObserver(PrivateDnsValidationObserver* observer);
void dump(netdutils::DumpWriter& dw) const;
@@ -97,23 +98,31 @@
// Launchs a thread to run the validation for |server| on the network |netId|.
// |isRevalidation| is true if this call is due to a revalidation request.
- void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
+ void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
REQUIRES(mPrivateDnsLock);
- bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
+ bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
bool isRevalidation) EXCLUDES(mPrivateDnsLock);
- void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
+ void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
REQUIRES(mPrivateDnsLock);
// Decide if a validation for |server| is needed. Note that servers that have failed
// multiple validation attempts but for which there is still a validating
// thread running are marked as being Validation::in_process.
+ // TODO: decouple the dependency of DnsTlsServer.
bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
REQUIRES(mPrivateDnsLock);
+ // For testing.
+ base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
+ EXCLUDES(mPrivateDnsLock);
+
+ base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
+ REQUIRES(mPrivateDnsLock);
+
mutable std::mutex mPrivateDnsLock;
std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
index b6196a7..fa8371c 100644
--- a/PrivateDnsConfigurationTest.cpp
+++ b/PrivateDnsConfigurationTest.cpp
@@ -28,6 +28,8 @@
class PrivateDnsConfigurationTest : public ::testing::Test {
public:
+ using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
+
static void SetUpTestSuite() {
// stopServer() will be called in their destructor.
ASSERT_TRUE(tls1.startServer());
@@ -100,6 +102,10 @@
return (serverStateMap == mObserver.getServerStateMap());
}
+ bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
+ return mPdc.getPrivateDns(identity, netId).ok();
+ }
+
static constexpr uint32_t kNetId = 30;
static constexpr uint32_t kMark = 30;
static constexpr char kBackend[] = "127.0.2.1";
@@ -230,8 +236,6 @@
}
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
- using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
-
DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
server.name = "dns.example.com";
@@ -254,6 +258,7 @@
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
+ const ServerIdentity identity(server);
testing::InSequence seq;
@@ -281,18 +286,18 @@
EXPECT_CALL(mObserver,
onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
- EXPECT_TRUE(mPdc.requestValidation(kNetId, server, kMark).ok());
+ EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
} else if (config == "IN_PROGRESS") {
EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
- EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
+ EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
} else if (config == "FAIL") {
- EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
+ EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
}
// Resending the same request or requesting nonexistent servers are denied.
- EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
- EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1).ok());
- EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark).ok());
+ EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
+ EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
+ EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());
// Reset the test state.
backend.setDeferredResp(false);
@@ -306,6 +311,26 @@
}
}
+TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
+ const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
+ const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
+
+ EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
+ EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
+
+ // Suppress the warning.
+ EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);
+
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
+ EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
+ EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));
+
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+}
+
// TODO: add ValidationFail_Strict test.
} // namespace android::net