Merge "Revert "Test for thread creation failed""
diff --git a/Android.bp b/Android.bp
index f1735ae..6f06eb6 100644
--- a/Android.bp
+++ b/Android.bp
@@ -259,6 +259,7 @@
"DnsQueryLogTest.cpp",
"DnsStatsTest.cpp",
"ExperimentsTest.cpp",
+ "PrivateDnsConfigurationTest.cpp",
],
shared_libs: [
"libcrypto",
diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp
index df5d873..00385eb 100644
--- a/DnsProxyListener.cpp
+++ b/DnsProxyListener.cpp
@@ -84,25 +84,6 @@
}
}
-template<typename T>
-void tryThreadOrError(SocketClient* cli, T* handler) {
- cli->incRef();
-
- const int rval = netdutils::threadLaunch(handler);
- if (rval == 0) {
- // SocketClient decRef() happens in the handler's run() method.
- return;
- }
-
- char* msg = nullptr;
- asprintf(&msg, "%s (%d)", strerror(-rval), -rval);
- cli->sendMsg(ResponseCode::OperationFailed, msg, false);
- free(msg);
-
- delete handler;
- cli->decRef();
-}
-
bool checkAndClearUseLocalNameserversFlag(unsigned* netid) {
if (netid == nullptr || ((*netid) & NETID_USE_LOCAL_NAMESERVERS) == 0) {
return false;
@@ -563,10 +544,23 @@
registerCmd(new GetDnsNetIdCommand());
}
+void DnsProxyListener::Handler::spawn() {
+ const int rval = netdutils::threadLaunch(this);
+ if (rval == 0) {
+ return;
+ }
+
+ char* msg = nullptr;
+ asprintf(&msg, "%s (%d)", strerror(-rval), -rval);
+ mClient->sendMsg(ResponseCode::OperationFailed, msg, false);
+ free(msg);
+ delete this;
+}
+
DnsProxyListener::GetAddrInfoHandler::GetAddrInfoHandler(SocketClient* c, char* host, char* service,
addrinfo* hints,
const android_net_context& netcontext)
- : mClient(c), mHost(host), mService(service), mHints(hints), mNetContext(netcontext) {}
+ : Handler(c), mHost(host), mService(service), mHints(hints), mNetContext(netcontext) {}
DnsProxyListener::GetAddrInfoHandler::~GetAddrInfoHandler() {
free(mHost);
@@ -763,7 +757,6 @@
reportDnsEvent(INetdEventListener::EVENT_GETADDRINFO, mNetContext, latencyUs, rv, event, mHost,
ip_addrs, total_ip_addr_count);
freeaddrinfo(result);
- mClient->decRef();
}
std::string DnsProxyListener::GetAddrInfoHandler::threadName() {
@@ -841,9 +834,7 @@
hints->ai_protocol = ai_protocol;
}
- DnsProxyListener::GetAddrInfoHandler* handler =
- new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, netcontext);
- tryThreadOrError(cli, handler);
+ (new GetAddrInfoHandler(cli, name, service, hints, netcontext))->spawn();
return 0;
}
@@ -888,19 +879,13 @@
netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
}
- DnsProxyListener::ResNSendHandler* handler =
- new DnsProxyListener::ResNSendHandler(cli, argv[3], flags, netcontext);
- tryThreadOrError(cli, handler);
+ (new ResNSendHandler(cli, argv[3], flags, netcontext))->spawn();
return 0;
}
DnsProxyListener::ResNSendHandler::ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
const android_net_context& netcontext)
- : mClient(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}
-
-DnsProxyListener::ResNSendHandler::~ResNSendHandler() {
- mClient->decRef();
-}
+ : Handler(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}
void DnsProxyListener::ResNSendHandler::run() {
LOG(DEBUG) << "ResNSendHandler::run: " << mFlags << " / {" << mNetContext.app_netid << " "
@@ -1090,15 +1075,13 @@
netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
}
- DnsProxyListener::GetHostByNameHandler* handler =
- new DnsProxyListener::GetHostByNameHandler(cli, name, af, netcontext);
- tryThreadOrError(cli, handler);
+ (new GetHostByNameHandler(cli, name, af, netcontext))->spawn();
return 0;
}
DnsProxyListener::GetHostByNameHandler::GetHostByNameHandler(SocketClient* c, char* name, int af,
const android_net_context& netcontext)
- : mClient(c), mName(name), mAf(af), mNetContext(netcontext) {}
+ : Handler(c), mName(name), mAf(af), mNetContext(netcontext) {}
DnsProxyListener::GetHostByNameHandler::~GetHostByNameHandler() {
free(mName);
@@ -1190,7 +1173,6 @@
const int total_ip_addr_count = extractGetHostByNameAnswers(hp, &ip_addrs);
reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYNAME, mNetContext, latencyUs, rv, event,
mName, ip_addrs, total_ip_addr_count);
- mClient->decRef();
}
std::string DnsProxyListener::GetHostByNameHandler::threadName() {
@@ -1242,16 +1224,14 @@
netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
}
- DnsProxyListener::GetHostByAddrHandler* handler = new DnsProxyListener::GetHostByAddrHandler(
- cli, addr, addrLen, addrFamily, netcontext);
- tryThreadOrError(cli, handler);
+ (new GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netcontext))->spawn();
return 0;
}
DnsProxyListener::GetHostByAddrHandler::GetHostByAddrHandler(SocketClient* c, void* address,
int addressLen, int addressFamily,
const android_net_context& netcontext)
- : mClient(c),
+ : Handler(c),
mAddress(address),
mAddressLen(addressLen),
mAddressFamily(addressFamily),
@@ -1351,7 +1331,6 @@
reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYADDR, mNetContext, latencyUs, rv, event,
(hp && hp->h_name) ? hp->h_name : "null", {}, 0);
- mClient->decRef();
}
std::string DnsProxyListener::GetHostByAddrHandler::threadName() {
diff --git a/DnsProxyListener.h b/DnsProxyListener.h
index 34340a0..3f95091 100644
--- a/DnsProxyListener.h
+++ b/DnsProxyListener.h
@@ -38,6 +38,23 @@
static constexpr const char* SOCKET_NAME = "dnsproxyd";
private:
+ class Handler {
+ public:
+ explicit Handler(SocketClient* c) : mClient(c) { mClient->incRef(); }
+ virtual ~Handler() { mClient->decRef(); }
+ Handler(const Handler&) = delete; // a copy would decRef() mClient twice
+ Handler& operator=(const Handler&) = delete;
+ // Attempt to spawn the worker thread, or return an error to the client.
+ // The Handler instance will self-delete in either case.
+ void spawn();
+
+ virtual void run() = 0;
+ virtual std::string threadName() = 0;
+
+ SocketClient* mClient; // ref-counted: incRef() in ctor, decRef() in dtor
+ };
+
+ /* ------ getaddrinfo ------*/
class GetAddrInfoCmd : public FrameworkCommand {
public:
GetAddrInfoCmd();
@@ -45,21 +62,19 @@
int runCommand(SocketClient* c, int argc, char** argv) override;
};
- /* ------ getaddrinfo ------*/
- class GetAddrInfoHandler {
+ class GetAddrInfoHandler : public Handler {
public:
// Note: All of host, service, and hints may be NULL
GetAddrInfoHandler(SocketClient* c, char* host, char* service, addrinfo* hints,
const android_net_context& netcontext);
- ~GetAddrInfoHandler();
+ ~GetAddrInfoHandler() override;
- void run();
- std::string threadName();
+ void run() override;
+ std::string threadName() override;
private:
void doDns64Synthesis(int32_t* rv, addrinfo** res, NetworkDnsEventReported* event);
- SocketClient* mClient; // ref counted
char* mHost; // owned. TODO: convert to std::string.
char* mService; // owned. TODO: convert to std::string.
addrinfo* mHints; // owned
@@ -74,20 +89,19 @@
int runCommand(SocketClient* c, int argc, char** argv) override;
};
- class GetHostByNameHandler {
+ class GetHostByNameHandler : public Handler {
public:
GetHostByNameHandler(SocketClient* c, char* name, int af,
const android_net_context& netcontext);
- ~GetHostByNameHandler();
+ ~GetHostByNameHandler() override;
- void run();
- std::string threadName();
+ void run() override;
+ std::string threadName() override;
private:
void doDns64Synthesis(int32_t* rv, hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
NetworkDnsEventReported* event);
- SocketClient* mClient; // ref counted
char* mName; // owned. TODO: convert to std::string.
int mAf;
android_net_context mNetContext;
@@ -101,20 +115,19 @@
int runCommand(SocketClient* c, int argc, char** argv) override;
};
- class GetHostByAddrHandler {
+ class GetHostByAddrHandler : public Handler {
public:
GetHostByAddrHandler(SocketClient* c, void* address, int addressLen, int addressFamily,
const android_net_context& netcontext);
- ~GetHostByAddrHandler();
+ ~GetHostByAddrHandler() override;
- void run();
- std::string threadName();
+ void run() override;
+ std::string threadName() override;
private:
void doDns64ReverseLookup(hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
NetworkDnsEventReported* event);
- SocketClient* mClient; // ref counted
void* mAddress; // address to lookup; owned
int mAddressLen; // length of address to look up
int mAddressFamily; // address family
@@ -129,17 +142,16 @@
int runCommand(SocketClient* c, int argc, char** argv) override;
};
- class ResNSendHandler {
+ class ResNSendHandler : public Handler {
public:
ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
const android_net_context& netcontext);
- ~ResNSendHandler();
+ ~ResNSendHandler() override = default;
- void run();
- std::string threadName();
+ void run() override;
+ std::string threadName() override;
private:
- SocketClient* mClient; // ref counted
std::string mMsg;
uint32_t mFlags;
android_net_context mNetContext;
diff --git a/DnsResolverService.cpp b/DnsResolverService.cpp
index a2eb598..af449ab 100644
--- a/DnsResolverService.cpp
+++ b/DnsResolverService.cpp
@@ -72,9 +72,9 @@
DnsResolverService::DnsResolverService() {
// register log callback to BnDnsResolver::logFunc
- BnDnsResolver::logFunc =
- std::bind(binderCallLogFn, std::placeholders::_1,
- [](const std::string& msg) { gResNetdCallbacks.log(msg.c_str()); });
+ BnDnsResolver::logFunc = [](const auto& log) {
+ binderCallLogFn(log, [](const std::string& msg) { gResNetdCallbacks.log(msg.c_str()); });
+ };
}
binder_status_t DnsResolverService::start() {
diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp
index df8fce8..41eac9d 100644
--- a/DnsTlsDispatcher.cpp
+++ b/DnsTlsDispatcher.cpp
@@ -41,6 +41,11 @@
mFactory.reset(new DnsTlsSocketFactory());
}
+DnsTlsDispatcher& DnsTlsDispatcher::getInstance() {
+ static DnsTlsDispatcher instance;
+ return instance;
+}
+
std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
const std::list<DnsTlsServer> &tlsServers, unsigned mark) const {
// Our preferred DnsTlsServer order is:
diff --git a/DnsTlsDispatcher.h b/DnsTlsDispatcher.h
index 9eb6dfe..c3dad06 100644
--- a/DnsTlsDispatcher.h
+++ b/DnsTlsDispatcher.h
@@ -37,13 +37,12 @@
// Queries made here are dispatched to an existing or newly constructed DnsTlsTransport.
class DnsTlsDispatcher {
public:
- // Default constructor.
- DnsTlsDispatcher();
-
// Constructor with dependency injection for testing.
explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory)
: mFactory(std::move(factory)) {}
+ static DnsTlsDispatcher& getInstance();
+
// Enqueues |query| for resolution via the given |tlsServers| on the
// network indicated by |mark|; writes the response into |ans|, and stores
// the count of bytes written in |resplen|. Returns a success or error code.
@@ -62,6 +61,8 @@
int* _Nonnull resplen, bool* _Nonnull connectTriggered);
private:
+ DnsTlsDispatcher();
+
// This lock is static so that it can be used to annotate the Transport struct.
// DnsTlsDispatcher is a singleton in practice, so making this static does not change
// the locking behavior.
diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp
index 32b8073..ad12316 100644
--- a/DnsTlsSocket.cpp
+++ b/DnsTlsSocket.cpp
@@ -450,8 +450,20 @@
break;
}
if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
- if (!readResponse()) {
- LOG(DEBUG) << "SSL remote close or read error.";
+ bool readFailed = false;
+
+ // readResponse() reads only one DNS response (consuming exactly its bytes) from ssl.
+ // Keep reading until ssl has no more pending data.
+ // TODO: readResponse() can block until it reads a complete DNS response. Consider
+ // refactoring it to not get blocked in any case.
+ do {
+ if (!readResponse()) {
+ LOG(DEBUG) << "SSL remote close or read error.";
+ readFailed = true;
+ }
+ } while (SSL_pending(mSsl.get()) > 0 && !readFailed);
+
+ if (readFailed) {
break;
}
}
diff --git a/DnsTlsSocket.h b/DnsTlsSocket.h
index 5a4140f..f6736f8 100644
--- a/DnsTlsSocket.h
+++ b/DnsTlsSocket.h
@@ -137,6 +137,9 @@
int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
+
+ // Reads one DNS response. This can block until all bytes of the response have
+ // been read.
bool readResponse() REQUIRES(mLock);
// It is only used for DNS-OVER-TLS internal test.
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index d806fa0..c6ae6a2 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -28,7 +28,6 @@
#include "ResolverEventReporter.h"
#include "netd_resolv/resolv.h"
#include "netdutils/BackoffSequence.h"
-#include "resolv_cache.h"
#include "util.h"
using android::base::StringPrintf;
@@ -90,7 +89,6 @@
} else {
mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
mPrivateDnsTransports.erase(netId);
- resolv_stats_set_servers_for_dot(netId, {});
mPrivateDnsValidateThreads.erase(netId);
// TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet
// finished are considered outdated. Consider signaling the outdated validation threads to
@@ -128,7 +126,7 @@
}
}
- return resolv_stats_set_servers_for_dot(netId, servers);
+ return 0;
}
PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
@@ -168,6 +166,8 @@
return;
}
+ maybeNotifyObserver(server, Validation::in_process, netId);
+
// Note that capturing |server| and |netId| in this lambda create copies.
std::thread validate_thread([this, server, netId, mark] {
setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
@@ -224,12 +224,14 @@
auto netPair = mPrivateDnsTransports.find(netId);
if (netPair == mPrivateDnsTransports.end()) {
LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
+ maybeNotifyObserver(server, Validation::fail, netId);
return DONT_REEVALUATE;
}
const auto mode = mPrivateDnsModes.find(netId);
if (mode == mPrivateDnsModes.end()) {
LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
+ maybeNotifyObserver(server, Validation::fail, netId);
return DONT_REEVALUATE;
}
const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -272,12 +274,17 @@
if (success) {
tracker[server] = Validation::success;
+ maybeNotifyObserver(server, Validation::success, netId);
} else {
// Validation failure is expected if a user is on a captive portal.
// TODO: Trigger a second validation attempt after captive portal login
// succeeds.
tracker[server] = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
: Validation::fail;
+ maybeNotifyObserver(server,
+ (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
+ : Validation::fail,
+ netId);
}
LOG(WARNING) << "Validation " << (success ? "success" : "failed");
@@ -338,5 +345,17 @@
return (iter == tracker.end()) || (iter->second == Validation::fail);
}
+void PrivateDnsConfiguration::setObserver(Observer* observer) {
+ std::lock_guard guard(mPrivateDnsLock);
+ mObserver = observer;
+}
+
+void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
+ uint32_t netId) const {
+ if (mObserver) {
+ mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId);
+ }
+}
+
} // namespace net
} // namespace android
diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h
index 31fb546..5d42935 100644
--- a/PrivateDnsConfiguration.h
+++ b/PrivateDnsConfiguration.h
@@ -91,6 +91,23 @@
// Using the AddressComparator ensures at most one entry per IP address.
std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);
+
+ // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
+ // about to begin or 2) when a validation finishes.
+ class Observer {
+ public:
+ virtual ~Observer() = default;
+ virtual void onValidationStateUpdate(const std::string& server, Validation validation,
+ uint32_t netId) = 0;
+ };
+
+ void setObserver(Observer* observer);
+ void maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
+ uint32_t netId) const REQUIRES(mPrivateDnsLock);
+
+ Observer* mObserver GUARDED_BY(mPrivateDnsLock) = nullptr; // null until setObserver()
+
+ friend class PrivateDnsConfigurationTest;
};
} // namespace net
diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp
new file mode 100644
index 0000000..80fd4bc
--- /dev/null
+++ b/PrivateDnsConfigurationTest.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "PrivateDnsConfiguration.h"
+#include "tests/dns_responder/dns_responder.h"
+#include "tests/dns_responder/dns_tls_frontend.h"
+#include "tests/resolv_test_utils.h"
+
+namespace android::net {
+
+using namespace std::chrono_literals;
+
+class PrivateDnsConfigurationTest : public ::testing::Test {
+ public:
+ static void SetUpTestSuite() {
+ // No teardown is needed: each server's destructor calls stopServer().
+ ASSERT_TRUE(tls1.startServer());
+ ASSERT_TRUE(tls2.startServer());
+ ASSERT_TRUE(backend.startServer());
+ }
+
+ void SetUp() override {
+ mPdc.setObserver(&mObserver);
+
+ // The default and sole action when the observer is notified of onValidationStateUpdate.
+ // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
+ // when mObserver.onValidationStateUpdate is expected to be called, like:
+ //
+ // EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
+ //
+ // This is to ensure that tests can monitor how many validation threads are running. Tests
+ // must wait until every validation thread finishes.
+ ON_CALL(mObserver, onValidationStateUpdate)
+ .WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
+ if (validation == Validation::in_process) {
+ mObserver.runningThreads++;
+ } else if (validation == Validation::success ||
+ validation == Validation::fail) {
+ mObserver.runningThreads--;
+ }
+ std::lock_guard guard(mObserver.lock);
+ mObserver.serverStateMap[server] = validation;
+ });
+ }
+
+ protected:
+ class MockObserver : public PrivateDnsConfiguration::Observer {
+ public:
+ MOCK_METHOD(void, onValidationStateUpdate,
+ (const std::string& server, Validation validation, uint32_t netId), (override));
+
+ std::map<std::string, Validation> getServerStateMap() const {
+ std::lock_guard guard(lock);
+ return serverStateMap;
+ }
+
+ void removeFromServerStateMap(const std::string& server) {
+ std::lock_guard guard(lock);
+ if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
+ serverStateMap.erase(it);
+ }
+
+ // The current number of validation threads running.
+ std::atomic<int> runningThreads = 0;
+
+ mutable std::mutex lock;
+ std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
+ };
+
+ void expectPrivateDnsStatus(PrivateDnsMode mode) {
+ const PrivateDnsStatus status = mPdc.getStatus(kNetId);
+ EXPECT_EQ(status.mode, mode);
+
+ std::map<std::string, Validation> serverStateMap;
+ for (const auto& [server, validation] : status.serversMap) {
+ serverStateMap[ToString(&server.ss)] = validation;
+ }
+ EXPECT_EQ(serverStateMap, mObserver.getServerStateMap());
+ }
+
+ static constexpr uint32_t kNetId = 30;
+ static constexpr uint32_t kMark = 30;
+ static constexpr char kBackend[] = "127.0.2.1";
+ static constexpr char kServer1[] = "127.0.2.2";
+ static constexpr char kServer2[] = "127.0.2.3";
+
+ MockObserver mObserver;
+ PrivateDnsConfiguration mPdc;
+
+ // TODO: Because incorrect CAs result in validation failed in strict mode, have
+ // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
+ inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
+ inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
+ inline static test::DNSResponder backend{kBackend, "53"};
+};
+
+TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
+ testing::InSequence seq;
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
+
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+}
+
+TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
+ ASSERT_TRUE(backend.stopServer());
+
+ testing::InSequence seq;
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
+
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ // Wait until every validation thread has finished; otherwise, the test may crash.
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+ ASSERT_TRUE(backend.startServer());
+}
+
+TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
+ backend.setDeferredResp(true);
+
+ // onValidationStateUpdate() is called in sequence.
+ {
+ testing::InSequence seq;
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0);
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 2; }));
+ mObserver.removeFromServerStateMap(kServer1);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
+ // onValidationStateUpdate() will be caught.
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1, kServer2}, {}, {}), 0);
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ // The status stays unchanged when invalid arguments are passed.
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL);
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+ }
+
+ // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected
+ // server for the network.
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));
+ backend.setDeferredResp(false);
+
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+}
+
+TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
+ for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
+ SCOPED_TRACE(config);
+ backend.setDeferredResp(true);
+
+ testing::InSequence seq;
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
+ expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
+
+ if (config == "OFF") {
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0);
+ } else if (config == "NETWORK_DESTROYED") {
+ mPdc.clear(kNetId);
+ }
+
+ EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
+ backend.setDeferredResp(false);
+
+ ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
+ mObserver.removeFromServerStateMap(kServer1);
+ expectPrivateDnsStatus(PrivateDnsMode::OFF);
+ }
+}
+
+TEST_F(PrivateDnsConfigurationTest, NoValidation) {
+ // No validation may start; without this expectation such a call is only a gMock warning.
+ EXPECT_CALL(mObserver, onValidationStateUpdate).Times(0);
+
+ const auto expectStatus = [&]() {
+ const PrivateDnsStatus status = mPdc.getStatus(kNetId);
+ EXPECT_EQ(status.mode, PrivateDnsMode::OFF);
+ EXPECT_THAT(status.serversMap, testing::IsEmpty());
+ };
+
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL);
+ expectStatus();
+
+ EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0);
+ expectStatus();
+}
+
+// TODO: add ValidationFail_Strict test.
+
+} // namespace android::net
diff --git a/ResolverController.cpp b/ResolverController.cpp
index 6a10326..051f095 100644
--- a/ResolverController.cpp
+++ b/ResolverController.cpp
@@ -201,6 +201,10 @@
int ResolverController::setResolverConfiguration(const ResolverParamsParcel& resolverParams) {
using aidl::android::net::IDnsResolver;
+ if (!has_named_cache(resolverParams.netId)) {
+ return -ENOENT;
+ }
+
// Expect to get the mark with system permission.
android_net_context netcontext;
gResNetdCallbacks.get_network_context(resolverParams.netId, 0 /* uid */, &netcontext);
@@ -223,6 +227,10 @@
return err;
}
+ if (int err = resolv_stats_set_servers_for_dot(resolverParams.netId, tlsServers); err != 0) {
+ return err;
+ }
+
res_params res_params = {};
res_params.sample_validity = resolverParams.sampleValiditySeconds;
res_params.success_threshold = resolverParams.successThreshold;
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index 4804762..b9da5f4 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -1710,7 +1710,7 @@
static int res_queryN_wrapper(const char* name, res_target* target, res_state res, int* herrno) {
const bool parallel_lookup =
- android::net::Experiments::getInstance()->getFlag("parallel_lookup", 1);
+ android::net::Experiments::getInstance()->getFlag("parallel_lookup", 0);
if (parallel_lookup) return res_queryN_parallel(name, target, res, herrno);
return res_queryN(name, target, res, herrno);
diff --git a/res_send.cpp b/res_send.cpp
index 5d93112..720e45b 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -144,8 +144,6 @@
using android::netdutils::Slice;
using android::netdutils::Stopwatch;
-static DnsTlsDispatcher sDnsTlsDispatcher;
-
static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
int* delay);
@@ -838,6 +836,9 @@
else
break;
}
+ LOG(WARNING) << __func__ << ": resplen " << resplen << " exceeds buf size " << anssiz;
+ // return size should never exceed container size
+ resplen = anssiz;
}
/*
* If the calling application has bailed out of
@@ -848,7 +849,7 @@
*/
if (hp->id != anhp->id) {
LOG(DEBUG) << __func__ << ": ld answer (unexpected):";
- res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
+ res_pquery(ans, resplen);
goto read_len;
}
@@ -1252,8 +1253,8 @@
LOG(INFO) << __func__ << ": performing query over TLS";
- const auto response = sDnsTlsDispatcher.query(privateDnsStatus.validatedServers(), statp, query,
- answer, &resplen);
+ const auto response = DnsTlsDispatcher::getInstance().query(privateDnsStatus.validatedServers(),
+ statp, query, answer, &resplen);
LOG(INFO) << __func__ << ": TLS query result: " << static_cast<int>(response);
diff --git a/resolv_cache.h b/resolv_cache.h
index d8a3afc..896ace7 100644
--- a/resolv_cache.h
+++ b/resolv_cache.h
@@ -100,7 +100,6 @@
// Get transport types to a given network.
android::net::NetworkType resolv_get_network_types_for_net(unsigned netid);
-// For test only.
// Return true if the cache is existent in the given network, false otherwise.
bool has_named_cache(unsigned netid);
diff --git a/resolv_unit_test.cpp b/resolv_unit_test.cpp
index 3274ed8..befca9e 100644
--- a/resolv_unit_test.cpp
+++ b/resolv_unit_test.cpp
@@ -123,9 +123,7 @@
dns.clearQueries();
}
- int SetResolvers() {
- return resolv_set_nameservers(TEST_NETID, servers, domains, params);
- }
+ int SetResolvers() { return resolv_set_nameservers(TEST_NETID, servers, domains, params); }
const android_net_context mNetcontext = {
.app_netid = TEST_NETID,
@@ -1139,6 +1137,44 @@
}
}
+// Checks that the resolver does not read out of bounds; needs an HWASan build to hit SIGABRT.
+TEST_F(ResolvGetAddrInfoTest, OverlengthResp) {
+ std::vector<std::string> nameList;
+ // Construct a response long enough to exceed 8192 bytes (the maximum buffer size):
+ // Header: (Transaction ID, Flags, ...) = 12 bytes
+ // Query: 19(Name)+2(Type)+2(Class) = 23 bytes
+ // The 1st answer RR: 19(Name)+2(Type)+2(Class)+4(TTL)+2(Len)+77(CNAME) = 106 bytes
+ // 2nd-50th answer RRs: 49*(77(Name)+2(Type)+2(Class)+4(TTL)+2(Len)+77(CNAME)) = 8036 bytes
+ // The last answer RR: 77(Name)+2(Type)+2(Class)+4(TTL)+2(Len)+4(Address) = 91 bytes
+ // ----------------------------------------------------------------------------------------
+ // Sum: 8268 bytes
+ for (int i = 0; i < 10; i++) {
+ std::string domain(kMaxmiumLabelSize / 2, 'a' + i);
+ for (int j = 0; j < 5; j++) {
+ nameList.push_back(domain + std::string(kMaxmiumLabelSize / 2 + 1, '0' + j) +
+ kExampleComDomain + ".");
+ }
+ }
+ test::DNSResponder dns;
+ dns.addMapping(kHelloExampleCom, ns_type::ns_t_cname, nameList[0]);
+ for (size_t i = 0; i < nameList.size() - 1; i++) {
+ dns.addMapping(nameList[i], ns_type::ns_t_cname, nameList[i + 1]);
+ }
+ dns.addMapping(nameList[nameList.size() - 1], ns_type::ns_t_a, kHelloExampleComAddrV4);
+
+ ASSERT_TRUE(dns.startServer());
+ ASSERT_EQ(0, SetResolvers());
+ addrinfo* result = nullptr;
+ const addrinfo hints = {.ai_family = AF_INET};
+ NetworkDnsEventReported event;
+ int rv = resolv_getaddrinfo("hello", nullptr, &hints, &mNetcontext, &result, &event);
+ ScopedAddrinfo result_cleanup(result);
+ EXPECT_EQ(rv, EAI_FAIL);
+ EXPECT_TRUE(result == nullptr);
+ EXPECT_EQ(GetNumQueriesForProtocol(dns, IPPROTO_UDP, kHelloExampleCom), 2U);
+ EXPECT_EQ(GetNumQueriesForProtocol(dns, IPPROTO_TCP, kHelloExampleCom), 2U);
+}
+
TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) {
constexpr char host_name[] = "jiababuei.example.com.";
constexpr char v4addr[] = "1.2.3.4";
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index 29d0e89..9c7c62e 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -252,7 +252,7 @@
}
std::string DNSQuestion::toString() const {
- char buffer[4096];
+ char buffer[16384];
int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
dnstype2str(qtype), dnsclass2str(qclass));
return std::string(buffer, len);
@@ -291,7 +291,7 @@
}
std::string DNSRecord::toString() const {
- char buffer[4096];
+ char buffer[16384];
int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
dnsclass2str(rclass));
return std::string(buffer, len);
@@ -418,7 +418,7 @@
// TODO: convert all callers to this interface, then delete the old one.
bool DNSHeader::write(std::vector<uint8_t>* out) const {
- char buffer[4096];
+ char buffer[16384];
char* end = this->write(buffer, buffer + sizeof buffer);
if (end == nullptr) return false;
out->insert(out->end(), buffer, end);
@@ -896,7 +896,7 @@
bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
size_t* response_len) const {
- char buffer[4096];
+ char buffer[16384];
size_t buffer_len = sizeof(buffer);
bool ret;
@@ -1059,7 +1059,7 @@
}
void DNSResponder::handleQuery(int protocol) {
- char buffer[4096];
+ char buffer[16384];
sockaddr_storage sa;
socklen_t sa_len = sizeof(sa);
ssize_t len = 0;
@@ -1103,7 +1103,7 @@
}
LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
std::lock_guard lock(cv_mutex_);
- char response[4096];
+ char response[16384];
size_t response_len = sizeof(response);
// TODO: check whether sending malformed packets to DnsResponder
if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index d960254..7e7cb12 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -223,6 +223,11 @@
// client, including cleanup actions.
queries_ += handleRequests(ssl.get(), client.get());
}
+
+ if (passiveClose_) {
+ LOG(DEBUG) << "hold the current connection until next connection request";
+ clientFd = std::move(client);
+ }
}
}
LOG(DEBUG) << "Ending loop";
@@ -230,6 +235,7 @@
int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
int queryCounts = 0;
+ std::vector<uint8_t> reply;
pollfd fds = {.fd = clientFd, .events = POLLIN};
do {
uint8_t queryHeader[2];
@@ -263,16 +269,25 @@
uint8_t responseHeader[2];
responseHeader[0] = rlen >> 8;
responseHeader[1] = rlen;
- if (SSL_write(ssl, responseHeader, 2) != 2) {
- LOG(INFO) << "Failed to write response header";
- return queryCounts;
- }
- if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
- LOG(INFO) << "Failed to write response body";
- return queryCounts;
- }
+ reply.insert(reply.end(), responseHeader, responseHeader + 2);
+ reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);
+
++queryCounts;
- } while (poll(&fds, 1, 1) > 0);
+ if (queryCounts >= delayQueries_) {
+ break;
+ }
+ } while (poll(&fds, 1, delayQueriesTimeout_) > 0);
+
+ if (queryCounts < delayQueries_) {
+ LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
+ << queryCounts << " queries";
+ }
+
+ const int replyLen = reply.size();
+ LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen;
+ if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
+ LOG(WARNING) << "Failed to write response body";
+ }
LOG(DEBUG) << __func__ << " return: " << queryCounts;
return queryCounts;
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 6ba5681..02cdc21 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -62,6 +62,12 @@
void set_chain_length(int length) { chain_length_ = length; }
void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; }
+    // Makes DnsTlsFrontend hold back its responses on a connection until |delay| queries have
+    // been read (or no further query arrives within the delay-queries timeout), and then send
+    // all of the held responses together in a single write.
+    void setDelayQueries(int delay) { delayQueries_.store(delay); }
+    // Timeout, in milliseconds, to wait for each further query while responses are held back.
+    // The value is passed directly to poll().
+    void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_.store(timeout); }
+
+    // When true, the frontend holds the current client connection until the next connection
+    // request arrives.
+    void setPassiveClose(bool passiveClose) { passiveClose_.store(passiveClose); }
+
static constexpr char kDefaultListenAddr[] = "127.0.0.3";
static constexpr char kDefaultListenService[] = "853";
static constexpr char kDefaultBackendAddr[] = "127.0.0.3";
@@ -94,6 +100,9 @@
std::mutex update_mutex_;
int chain_length_ = 1;
std::atomic<bool> hangOnHandshake_ = false;
+ std::atomic<int> delayQueries_ = 1;
+ std::atomic<int> delayQueriesTimeout_ = 1;
+ std::atomic<bool> passiveClose_ = false;
};
} // namespace test
diff --git a/tests/dnsresolver_binder_test.cpp b/tests/dnsresolver_binder_test.cpp
index 6fb73f3..50f446e 100644
--- a/tests/dnsresolver_binder_test.cpp
+++ b/tests/dnsresolver_binder_test.cpp
@@ -21,12 +21,17 @@
#include <netdb.h>
+#include <algorithm>
#include <iostream>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <thread>
#include <vector>
#include <aidl/android/net/IDnsResolver.h>
#include <android-base/file.h>
+#include <android-base/format.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
+#include <android-base/unique_fd.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <gmock/gmock-matchers.h>
@@ -41,9 +46,13 @@
#include "dns_responder_client_ndk.h"
using aidl::android::net::IDnsResolver;
+using aidl::android::net::ResolverHostsParcel;
+using aidl::android::net::ResolverOptionsParcel;
using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::metrics::INetdEventListener;
+using android::base::ReadFdToString;
using android::base::StringPrintf;
+using android::base::unique_fd;
using android::net::ResolverStats;
using android::net::metrics::TestOnDnsEvent;
using android::netdutils::Stopwatch;
@@ -52,6 +61,37 @@
// Sync from TEST_NETID in dns_responder_client.cpp as resolv_integration_test.cpp does.
constexpr int TEST_NETID = 30;
+namespace {
+
+// Dumps |binder| and returns the dump output split into lines (without the trailing '\n').
+// Records a gtest failure and returns an empty vector if the pipe cannot be created.
+std::vector<std::string> dumpService(ndk::SpAIBinder binder) {
+    unique_fd localFd, remoteFd;  // read end, write end
+    const bool success = Pipe(&localFd, &remoteFd);
+    EXPECT_TRUE(success) << "Failed to open pipe for dumping: " << strerror(errno);
+    if (!success) return {};
+
+    // dump() blocks until another thread has consumed all its output, so it runs on its own
+    // thread while this one drains the pipe. The closure owns the write end; ReadFdToString()
+    // reads until EOF, i.e. until that end has been closed.
+    std::thread dumpThread([binder = std::move(binder), remoteFd = std::move(remoteFd)]() {
+        EXPECT_EQ(STATUS_OK, AIBinder_dump(binder.get(), remoteFd.get(), nullptr, 0));
+    });
+
+    std::string dumpContent;
+    EXPECT_TRUE(ReadFdToString(localFd.get(), &dumpContent))
+            << "Error during dump: " << strerror(errno);
+    dumpThread.join();
+
+    // Read-only parsing: an istringstream is sufficient. getline() strips the '\n' delimiter.
+    std::istringstream dumpStream(dumpContent);
+    std::vector<std::string> lines;
+    for (std::string line; std::getline(dumpStream, line);) {
+        lines.push_back(std::move(line));
+    }
+    return lines;
+}
+
+}  // namespace
+
class DnsResolverBinderTest : public ::testing::Test {
public:
DnsResolverBinderTest() {
@@ -65,12 +105,143 @@
}
~DnsResolverBinderTest() {
+ // Check that every call recorded by the test body (mExpectedLogData and
+ // mExpectedLogDataWithPacel) shows up in netd's dump output before tearing down.
+ expectLog();
// Destroy cache for test
mDnsResolver->destroyNetworkCache(TEST_NETID);
}
protected:
+    // Dumps netd and checks that each expected call recorded in mExpectedLogData and
+    // mExpectedLogDataWithPacel appears as the command part of some dump line. For every
+    // missing expectation, the test fails and the dump lines matching the hint regex(es) are
+    // printed to stderr.
+    void expectLog() {
+        ndk::SpAIBinder netdBinder = ndk::SpAIBinder(AServiceManager_getService("netd"));
+        // This could happen when the test isn't running as root, or if netd isn't running.
+        // A gtest assertion is used instead of assert(): assert() is compiled out under NDEBUG,
+        // which would let a null binder reach AIBinder_dump().
+        ASSERT_NE(nullptr, netdBinder.get());
+        // Send the service dump request to netd.
+        const std::vector<std::string> lines = dumpService(netdBinder);
+
+        // Basic regexp to match dump output lines. Matches the beginning and end of the line, and
+        // puts the output of the command itself into the first match group.
+        // Example: " 11-05 00:23:39.481 myCommand(args) <2.02ms>".
+        const std::regex lineRegex(
+                "^ [0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}[.][0-9]{3} "
+                "(.*)"
+                " <[0-9]+[.][0-9]{2}ms>$");
+
+        // Extract the command part of every well-formed line once, rather than re-running
+        // lineRegex over the whole dump for each expectation.
+        std::vector<std::string> commands;
+        for (const auto& line : lines) {
+            std::smatch match;
+            if (std::regex_match(line, match, lineRegex) && match.size() == 2) {
+                commands.push_back(match[1].str());
+            }
+        }
+        const auto hasCommand = [&commands](const std::string& output) {
+            return std::find(commands.begin(), commands.end(), output) != commands.end();
+        };
+
+        // Prints every dump line matching any of |hintRegexes| (once per matching regex) to help
+        // debug a missing expectation. Each regex is compiled once, not once per dump line.
+        const auto printSimilarLines = [&lines](const std::vector<std::string>& hintRegexes) {
+            std::vector<std::regex> compiled;
+            compiled.reserve(hintRegexes.size());
+            for (const auto& hint : hintRegexes) compiled.emplace_back(hint);
+            std::cerr << "Similar lines" << std::endl;
+            for (const auto& line : lines) {
+                for (const auto& re : compiled) {
+                    if (std::regex_search(line, re)) std::cerr << line << std::endl;
+                }
+            }
+        };
+
+        // For each element of testdata, check that the expected output appears in the dump output.
+        // If not, fail the test and use hintRegex to print similar lines to assist in debugging.
+        for (const auto& td : mExpectedLogData) {
+            const bool found = hasCommand(td.output);
+            EXPECT_TRUE(found) << "Didn't find line '" << td.output << "' in dumpsys output.";
+            if (!found) printSimilarLines({td.hintRegex});
+        }
+
+        // The log output is different between R and S, either one is fine for the
+        // test to avoid test compatible issue.
+        // TODO: Remove after S.
+        for (const auto& td : mExpectedLogDataWithPacel) {
+            const bool found =
+                    hasCommand(td.withPacel.output) || hasCommand(td.withoutPacel.output);
+            EXPECT_TRUE(found) << fmt::format("Didn't find line '{}' or '{}' in dumpsys output.",
+                                              td.withPacel.output, td.withoutPacel.output);
+            if (!found) printSimilarLines({td.withPacel.hintRegex, td.withoutPacel.hintRegex});
+        }
+    }
+
+    // One expected call as it should appear in netd's dump output.
+    struct LogData {
+        // Exact command text expected between the timestamp and the latency of a dump line,
+        // e.g. "setLogSeverity(0)".
+        const std::string output;
+        // Diagnostic-only regex: when |output| is missing, dump lines matching it are printed,
+        // which makes it easier to add or debug test cases.
+        const std::string hintRegex;
+    };
+
+    // Two acceptable renderings of the same call (the log format differs between R and S), e.g.
+    // setResolverConfiguration with the full parcel printed and with empty parentheses.
+    // TODO: Remove this struct and below toString methods after S.
+    struct PossibleLogData {
+        LogData withPacel;
+        LogData withoutPacel;
+    };
+
+    // Renders |hosts| as "ResolverHostsParcel{...}" entries joined by ", ".
+    std::string toString(const std::vector<ResolverHostsParcel>& hosts) {
+        std::string rendered;
+        const char* separator = "";
+        for (const auto& host : hosts) {
+            rendered += separator;
+            rendered += fmt::format("ResolverHostsParcel{{ipAddr: {}, hostName: {}}}",
+                                    host.ipAddr, host.hostName);
+            separator = ", ";
+        }
+        return rendered;
+    }
+
+    // Renders |options| in the format expected in the dump output.
+    std::string toString(const ResolverOptionsParcel& options) {
+        const std::string hosts = toString(options.hosts);
+        return fmt::format("ResolverOptionsParcel{{hosts: [{}], tcMode: {}, enforceDnsUid: {}}}",
+                           hosts, options.tcMode, options.enforceDnsUid);
+    }
+
+    // Renders |params| in the format expected in the dump output.
+    std::string toString(const ResolverParamsParcel& params) {
+        // Newlines in the certificate are expected as the two characters '\' 'n'.
+        const std::string escapedCaCertificate =
+                android::base::StringReplace(params.caCertificate, "\n", "\\n", true);
+        const std::string options = toString(params.resolverOptions);
+        return fmt::format(
+                "ResolverParamsParcel{{netId: {}, sampleValiditySeconds: {}, successThreshold: {}, "
+                "minSamples: {}, "
+                "maxSamples: {}, baseTimeoutMsec: {}, retryCount: {}, "
+                "servers: [{}], domains: [{}], "
+                "tlsName: {}, tlsServers: [{}], "
+                "tlsFingerprints: [{}], "
+                "caCertificate: {}, tlsConnectTimeoutMs: {}, "
+                "resolverOptions: {}, transportTypes: [{}]}}",
+                params.netId, params.sampleValiditySeconds, params.successThreshold,
+                params.minSamples, params.maxSamples, params.baseTimeoutMsec, params.retryCount,
+                fmt::join(params.servers, ", "), fmt::join(params.domains, ", "), params.tlsName,
+                fmt::join(params.tlsServers, ", "), fmt::join(params.tlsFingerprints, ", "),
+                escapedCaCertificate, params.tlsConnectTimeoutMs, options,
+                fmt::join(params.transportTypes, ", "));
+    }
+
+    // Builds both acceptable log lines for setResolverConfiguration(|params|). A non-zero
+    // |returnCode| appends the expected ServiceSpecificException suffix to both renderings.
+    PossibleLogData toSetResolverConfigurationLogData(const ResolverParamsParcel& params,
+                                                      int returnCode = 0) {
+        std::string errorSuffix;
+        std::string errorHint;
+        if (returnCode != 0) {
+            errorSuffix = fmt::format(" -> ServiceSpecificException({}, \"{}\")", returnCode,
+                                      strerror(returnCode));
+            errorHint = fmt::format(".*{}", returnCode);
+        }
+        return {
+                {"setResolverConfiguration(" + toString(params) + ")" + errorSuffix,
+                 fmt::format("setResolverConfiguration.*{}", params.netId) + errorHint},
+                {"setResolverConfiguration()" + errorSuffix, "setResolverConfiguration" + errorHint},
+        };
+    }
+
std::shared_ptr<aidl::android::net::IDnsResolver> mDnsResolver;
+ std::vector<LogData> mExpectedLogData;
+ std::vector<PossibleLogData> mExpectedLogDataWithPacel;
};
class TimedOperation : public Stopwatch {
@@ -95,6 +266,9 @@
::ndk::ScopedAStatus status = mDnsResolver->registerEventListener(nullptr);
ASSERT_FALSE(status.isOk());
ASSERT_EQ(EINVAL, status.getServiceSpecificError());
+ mExpectedLogData.push_back(
+ {"registerEventListener() -> ServiceSpecificException(22, \"Invalid argument\")",
+ "registerEventListener.*22"});
}
TEST_F(DnsResolverBinderTest, RegisterEventListener_DuplicateSubscription) {
@@ -104,11 +278,15 @@
std::shared_ptr<DummyListener> dummyListener = ndk::SharedRefBase::make<DummyListener>();
::ndk::ScopedAStatus status = mDnsResolver->registerEventListener(dummyListener);
ASSERT_TRUE(status.isOk()) << status.getMessage();
+ mExpectedLogData.push_back({"registerEventListener()", "registerEventListener.*"});
// Expect to subscribe failed with registered listener instance.
status = mDnsResolver->registerEventListener(dummyListener);
ASSERT_FALSE(status.isOk());
ASSERT_EQ(EEXIST, status.getServiceSpecificError());
+ mExpectedLogData.push_back(
+ {"registerEventListener() -> ServiceSpecificException(17, \"File exists\")",
+ "registerEventListener.*17"});
}
// TODO: Move this test to resolv_integration_test.cpp
@@ -161,6 +339,7 @@
ndk::SharedRefBase::make<TestOnDnsEvent>(expectedResults);
::ndk::ScopedAStatus status = mDnsResolver->registerEventListener(testOnDnsEvent);
ASSERT_TRUE(status.isOk()) << status.getMessage();
+ mExpectedLogData.push_back({"registerEventListener()", "registerEventListener.*"});
// DNS queries.
// Once all expected events of expectedResults are received by the listener, the unit test will
@@ -238,10 +417,13 @@
SCOPED_TRACE(StringPrintf("test case %zu should have passed", i));
SCOPED_TRACE(status.getMessage());
EXPECT_EQ(0, status.getServiceSpecificError());
+ mExpectedLogDataWithPacel.push_back(toSetResolverConfigurationLogData(resolverParams));
} else {
SCOPED_TRACE(StringPrintf("test case %zu should have failed", i));
EXPECT_EQ(EX_SERVICE_SPECIFIC, status.getExceptionCode());
EXPECT_EQ(td.expectedReturnCode, status.getServiceSpecificError());
+ mExpectedLogDataWithPacel.push_back(
+ toSetResolverConfigurationLogData(resolverParams, td.expectedReturnCode));
}
}
}
@@ -252,6 +434,7 @@
resolverParams.transportTypes = {IDnsResolver::TRANSPORT_WIFI, IDnsResolver::TRANSPORT_VPN};
::ndk::ScopedAStatus status = mDnsResolver->setResolverConfiguration(resolverParams);
EXPECT_TRUE(status.isOk()) << status.getMessage();
+ mExpectedLogDataWithPacel.push_back(toSetResolverConfigurationLogData(resolverParams));
// TODO: Find a way to fix a potential deadlock here if it's larger than pipe buffer
// size(65535).
android::base::unique_fd writeFd, readFd;
@@ -268,6 +451,7 @@
auto resolverParams = DnsResponderClient::GetDefaultResolverParamsParcel();
::ndk::ScopedAStatus status = mDnsResolver->setResolverConfiguration(resolverParams);
EXPECT_TRUE(status.isOk()) << status.getMessage();
+ mExpectedLogDataWithPacel.push_back(toSetResolverConfigurationLogData(resolverParams));
android::base::unique_fd writeFd, readFd;
EXPECT_TRUE(Pipe(&readFd, &writeFd));
EXPECT_EQ(mDnsResolver->dump(writeFd.get(), nullptr, 0), 0);
@@ -291,6 +475,7 @@
TEST_NETID, testParams, servers, domains, "", {});
::ndk::ScopedAStatus status = mDnsResolver->setResolverConfiguration(resolverParams);
EXPECT_TRUE(status.isOk()) << status.getMessage();
+ mExpectedLogDataWithPacel.push_back(toSetResolverConfigurationLogData(resolverParams));
std::vector<std::string> res_servers;
std::vector<std::string> res_domains;
@@ -335,46 +520,68 @@
// Create a new network cache.
EXPECT_TRUE(mDnsResolver->createNetworkCache(ANOTHER_TEST_NETID).isOk());
+ mExpectedLogData.push_back({"createNetworkCache(31)", "createNetworkCache.*31"});
// create it again, expect a EEXIST.
EXPECT_EQ(EEXIST,
mDnsResolver->createNetworkCache(ANOTHER_TEST_NETID).getServiceSpecificError());
+ mExpectedLogData.push_back(
+ {"createNetworkCache(31) -> ServiceSpecificException(17, \"File exists\")",
+ "createNetworkCache.*31.*17"});
// destroy it.
EXPECT_TRUE(mDnsResolver->destroyNetworkCache(ANOTHER_TEST_NETID).isOk());
+ mExpectedLogData.push_back({"destroyNetworkCache(31)", "destroyNetworkCache.*31"});
// re-create it
EXPECT_TRUE(mDnsResolver->createNetworkCache(ANOTHER_TEST_NETID).isOk());
+ mExpectedLogData.push_back({"createNetworkCache(31)", "createNetworkCache.*31"});
// destroy it.
EXPECT_TRUE(mDnsResolver->destroyNetworkCache(ANOTHER_TEST_NETID).isOk());
+ mExpectedLogData.push_back({"destroyNetworkCache(31)", "destroyNetworkCache.*31"});
// re-destroy it
EXPECT_TRUE(mDnsResolver->destroyNetworkCache(ANOTHER_TEST_NETID).isOk());
+ mExpectedLogData.push_back({"destroyNetworkCache(31)", "destroyNetworkCache.*31"});
}
TEST_F(DnsResolverBinderTest, FlushNetworkCache) {
SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsResolver.get(), 4);
// cache has beed created in DnsResolverBinderTest constructor
EXPECT_TRUE(mDnsResolver->flushNetworkCache(TEST_NETID).isOk());
+    // The hint regex must name flushNetworkCache (not destroyNetworkCache) so that, on failure,
+    // expectLog() prints the dump lines for the call actually under test. Both strings derive
+    // from TEST_NETID so they stay in sync with the flushed network.
+    mExpectedLogData.push_back({fmt::format("flushNetworkCache({})", TEST_NETID),
+                                fmt::format("flushNetworkCache.*{}", TEST_NETID)});
EXPECT_EQ(ENONET, mDnsResolver->flushNetworkCache(-1).getServiceSpecificError());
+    mExpectedLogData.push_back(
+            {"flushNetworkCache(-1) -> ServiceSpecificException(64, \"Machine is not on the "
+             "network\")",
+             "flushNetworkCache.*-1.*64"});
}
TEST_F(DnsResolverBinderTest, setLogSeverity) {
// Expect fail
EXPECT_EQ(EINVAL, mDnsResolver->setLogSeverity(-1).getServiceSpecificError());
+    // The hint regex must name setLogSeverity (it previously said flushNetworkCache) so that, on
+    // failure, expectLog() prints the dump lines for the call actually under test.
+    mExpectedLogData.push_back(
+            {"setLogSeverity(-1) -> ServiceSpecificException(22, \"Invalid argument\")",
+             "setLogSeverity.*-1.*22"});
// Test set different log level
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_VERBOSE).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(0)", "setLogSeverity.*0"});
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_DEBUG).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(1)", "setLogSeverity.*1"});
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_INFO).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(2)", "setLogSeverity.*2"});
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_WARNING).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(3)", "setLogSeverity.*3"});
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_ERROR).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(4)", "setLogSeverity.*4"});
// Set back to default
EXPECT_TRUE(mDnsResolver->setLogSeverity(IDnsResolver::DNS_RESOLVER_LOG_WARNING).isOk());
+    mExpectedLogData.push_back({"setLogSeverity(3)", "setLogSeverity.*3"});
}
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 5d41a8a..b53d2ba 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -184,8 +184,8 @@
// service.
AIBinder* binder = AServiceManager_getService("dnsresolver");
- ndk::SpAIBinder resolvBinder = ndk::SpAIBinder(binder);
- auto resolvService = aidl::android::net::IDnsResolver::fromBinder(resolvBinder);
+ sResolvBinder = ndk::SpAIBinder(binder);
+ auto resolvService = aidl::android::net::IDnsResolver::fromBinder(sResolvBinder);
ASSERT_NE(nullptr, resolvService.get());
// Subscribe the death recipient to the service IDnsResolver for detecting Netd death.
@@ -371,11 +371,16 @@
// Use a shared static death recipient to monitor the service death. The static death
// recipient could monitor the death not only during the test but also between tests.
static AIBinder_DeathRecipient* sResolvDeathRecipient; // Initialized in SetUpTestSuite.
+
+ // The linked AIBinder_DeathRecipient will be automatically unlinked if the binder is deleted.
+ // The binder needs to be retained throughout tests.
+ static ndk::SpAIBinder sResolvBinder;
};
// Initialize static member of class.
std::shared_ptr<DnsMetricsListener> ResolverTest::sDnsMetricsListener;
AIBinder_DeathRecipient* ResolverTest::sResolvDeathRecipient;
+// Keeps the dnsresolver binder alive so that sResolvDeathRecipient, which is linked to it,
+// is not automatically unlinked between tests.
+ndk::SpAIBinder ResolverTest::sResolvBinder;
TEST_F(ResolverTest, GetHostByName) {
constexpr char nonexistent_host_name[] = "nonexistent.example.com.";
@@ -5448,6 +5453,67 @@
} while (std::next_permutation(serverList.begin(), serverList.end()));
}
+// Verifies that two DoT responses delivered in a single TLS write are both matched to their
+// queries, without timeouts and without the resolver reconnecting and re-sending queries.
+TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) {
+    constexpr char hostname1[] = "query1.example.com.";
+    constexpr char hostname2[] = "query2.example.com.";
+    const std::vector<DnsRecord> records = {
+            {hostname1, ns_type::ns_t_a, "1.2.3.4"},
+            {hostname2, ns_type::ns_t_a, "1.2.3.5"},
+    };
+
+    const std::string addr = getUniqueIPv4Address();
+    test::DNSResponder dns(addr);
+    StartDns(dns, records);
+    test::DnsTlsFrontend tls(addr, "853", addr, "53");
+    ASSERT_TRUE(tls.startServer());
+
+    // Set up resolver to strict mode.
+    auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.servers = {addr};
+    parcel.tlsServers = {addr};
+    parcel.tlsName = kDefaultPrivateDnsHostName;
+    parcel.caCertificate = kCaCert;
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
+    EXPECT_TRUE(tls.waitForQueries(1));
+    tls.clearQueries();
+    dns.clearQueries();
+
+    // Resolves |hostname| and expects exactly the A records configured for it in |records|,
+    // answered within 200 ms. |records| is captured by const reference: it outlives both
+    // threads (they are joined below) and is only read, so no per-thread copy is needed.
+    const auto queryAndCheck = [&records](const std::string& hostname) {
+        SCOPED_TRACE(hostname);
+
+        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+        auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints);
+
+        std::vector<std::string> expectedAnswers;
+        for (const auto& r : records) {
+            if (r.host_name == hostname) expectedAnswers.push_back(r.addr);
+        }
+
+        EXPECT_LE(timeTakenMs, 200);
+        ASSERT_NE(result, nullptr);
+        EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers));
+    };
+
+    // Set tls to reply DNS responses in one TCP packet and not to close the connection from its
+    // side. The timeout is in milliseconds.
+    tls.setDelayQueries(2);
+    tls.setDelayQueriesTimeout(500);
+    tls.setPassiveClose(true);
+
+    // Start sending DNS requests at the same time.
+    std::array<std::thread, 2> threads = {std::thread(queryAndCheck, hostname1),
+                                          std::thread(queryAndCheck, hostname2)};
+    for (auto& thread : threads) thread.join();
+
+    // Also check no additional queries due to DoT reconnection.
+    EXPECT_EQ(tls.queries(), 2);
+}
+
// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
// the packets to address B, which is the address on which the testing server is listening. The