Change most PrivateDnsConfiguration methods to use ServerIdentity
Passing DnsTlsServer might be confusing because it's not straightforward
to know whether a DnsTlsServer is a copy or is owned by PrivateDnsConfiguration.
This CL changes most of the methods to use ServerIdentity. The methods
can then look up the corresponding DnsTlsServer through the newly added
method getPrivateDns(), or getPrivateDnsLocked() when mPrivateDnsLock is
already held.
Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index a2a48b2..7fbad43 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -110,7 +110,7 @@
if (needsValidation(server)) {
updateServerState(identity, Validation::in_process, netId);
- startValidation(server, netId, false);
+ startValidation(identity, netId, false);
}
}
@@ -145,7 +145,7 @@
}
base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
- const DnsTlsServer& server,
+ const ServerIdentity& identity,
uint32_t mark) {
std::lock_guard guard(mPrivateDnsLock);
@@ -159,40 +159,39 @@
return Errorf("Private DNS setting is not opportunistic mode");
}
- auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair == mPrivateDnsTransports.end()) {
- return Errorf("NetId not found in mPrivateDnsTransports");
+ auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) {
+ return result.error();
}
- auto& tracker = netPair->second;
- const ServerIdentity identity = ServerIdentity(server);
- auto it = tracker.find(identity);
- if (it == tracker.end()) {
- return Errorf("Server was removed");
- }
+ const DnsTlsServer* target = result.value();
- const DnsTlsServer& target = it->second;
+ if (!target->active()) return Errorf("Server is not active");
- if (!target.active()) return Errorf("Server is not active");
-
- if (target.validationState() != Validation::success) {
+ if (target->validationState() != Validation::success) {
return Errorf("Server validation state mismatched");
}
// Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
// This is to protect validation from running on unexpected marks.
// Validation should be associated with a mark gotten by system permission.
- if (target.mark != mark) return Errorf("Socket mark mismatched");
+ if (target->mark != mark) return Errorf("Socket mark mismatched");
updateServerState(identity, Validation::in_process, netId);
- startValidation(target, netId, true);
+ startValidation(identity, netId, true);
return {};
}
-void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
- bool isRevalidation) REQUIRES(mPrivateDnsLock) {
- // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
- std::thread validate_thread([this, server, netId, isRevalidation] {
+void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
+ bool isRevalidation) {
+ // Copying the server here ensures that the thread sends a probe at least once,
+ // even if the server is removed before the thread starts running.
+ // TODO: consider moving this code into the thread.
+ const auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) return;
+ DnsTlsServer server = *result.value();
+
+ std::thread validate_thread([this, identity, server, netId, isRevalidation] {
setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
// cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@
<< server.toIpString();
const bool needs_reeval =
- this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
+ this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
if (!needs_reeval) {
break;
@@ -240,11 +239,11 @@
validate_thread.detach();
}
-void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
+void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
unsigned netId, bool success) {
LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
- << netId << " for " << server.toIpString() << " with hostname {" << server.name
- << "}";
+ << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
+ << identity.provider << "}";
// Send a validation event to NetdEventListenerService.
const auto& listeners = ResolverEventReporter::getInstance().getListeners();
if (listeners.empty()) {
@@ -252,15 +251,16 @@
<< "Validation event not sent since no INetdEventListener receiver is available.";
}
for (const auto& it : listeners) {
- it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
+ it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
+ success);
}
// Send a validation event to unsolicited event listeners.
const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
const PrivateDnsValidationEventParcel validationEvent = {
.netId = static_cast<int32_t>(netId),
- .ipAddress = server.toIpString(),
- .hostname = server.name,
+ .ipAddress = identity.sockaddr.ip().toString(),
+ .hostname = identity.provider,
.validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
: IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
};
@@ -269,11 +269,11 @@
}
}
-bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
- bool success, bool isRevalidation) {
+bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
+ unsigned netId, bool success,
+ bool isRevalidation) {
constexpr bool NEEDS_REEVALUATION = true;
constexpr bool DONT_REEVALUATE = false;
- const ServerIdentity identity = ServerIdentity(server);
std::lock_guard guard(mPrivateDnsLock);
@@ -303,23 +303,19 @@
auto& tracker = netPair->second;
auto serverPair = tracker.find(identity);
if (serverPair == tracker.end()) {
- LOG(WARNING) << "Server " << server.toIpString()
+ LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
<< " was removed during private DNS validation";
success = false;
reevaluationStatus = DONT_REEVALUATE;
- } else if (!(serverPair->second == server)) {
- LOG(WARNING) << "Server " << server.toIpString()
- << " was changed during private DNS validation";
- success = false;
- reevaluationStatus = DONT_REEVALUATE;
} else if (!serverPair->second.active()) {
- LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
+ LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
+ << " was removed from the configuration";
success = false;
reevaluationStatus = DONT_REEVALUATE;
}
// Send private dns validation result to listeners.
- sendPrivateDnsValidationEvent(server, netId, success);
+ sendPrivateDnsValidationEvent(identity, netId, success);
if (success) {
updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
uint32_t netId) {
- auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair == mPrivateDnsTransports.end()) {
+ const auto result = getPrivateDnsLocked(identity, netId);
+ if (!result.ok()) {
notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
return;
}
- auto& tracker = netPair->second;
- if (tracker.find(identity) == tracker.end()) {
- notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
- return;
- }
+ auto* server = result.value();
- tracker[identity].setValidationState(state);
+ server->setValidationState(state);
notifyValidationStateUpdate(identity.sockaddr, state, netId);
RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@
return false;
}
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
+ unsigned netId) {
+ std::lock_guard guard(mPrivateDnsLock);
+ return getPrivateDnsLocked(identity, netId);
+}
+
+base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
+ const ServerIdentity& identity, unsigned netId) {
+ auto netPair = mPrivateDnsTransports.find(netId);
+ if (netPair == mPrivateDnsTransports.end()) {
+ return Errorf("Failed to get private DNS: netId {} not found", netId);
+ }
+
+ auto iter = netPair->second.find(identity);
+ if (iter == netPair->second.end()) {
+ return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
+ identity.provider);
+ }
+
+ return &iter->second;
+}
+
void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
std::lock_guard guard(mPrivateDnsLock);
mObserver = observer;