Support multinetwork tests
The integration test used to set up the testing DNS servers on the loopback
interface. To support testing multinetwork functionality, make the test
able to send queries to a TUN interface, from which the queries are
forwarded to the testing DNS servers.
To forward packets, implement a forwarder that translates packets
(v4-to-v4 or v6-to-v6) between the resolver and the testing DNS servers
and forwards them in both directions.
Also add three tests:
GetAddrInfo_AI_ADDRCONFIG
NetworkDestroyedDuringQueryInFlight
OneCachePerNetwork
And remove unused libraries from the test:
libnetd_test_tun_interface
libnetd_test_utils
Test: cd packages/modules/DnsResolver && atest
Change-Id: I52a52ce59373bc8b9462064c0409b657696c379f
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index ff5b57b..1ea7d01 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -20,6 +20,7 @@
#include <android-base/logging.h>
#include <android-base/parseint.h>
#include <android-base/properties.h>
+#include <android-base/result.h>
#include <android-base/stringprintf.h>
#include <android-base/unique_fd.h>
#include <android/multinetwork.h> // ResNsendFlags
@@ -62,13 +63,13 @@
#include "netid_client.h" // NETID_UNSET
#include "params.h" // MAXNS
#include "stats.h" // RCODE_TIMEOUT
-#include "test_utils.h"
#include "tests/dns_metrics_listener/dns_metrics_listener.h"
#include "tests/dns_responder/dns_responder.h"
#include "tests/dns_responder/dns_responder_client_ndk.h"
#include "tests/dns_responder/dns_tls_certificate.h"
#include "tests/dns_responder/dns_tls_frontend.h"
#include "tests/resolv_test_utils.h"
+#include "tests/tun_forwarder.h"
// Valid VPN netId range is 100 ~ 65535
constexpr int TEST_VPN_NETID = 65502;
@@ -86,10 +87,13 @@
using aidl::android::net::INetd;
using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::metrics::INetdEventListener;
+using android::base::Error;
using android::base::ParseInt;
+using android::base::Result;
using android::base::StringPrintf;
using android::base::unique_fd;
using android::net::ResolverStats;
+using android::net::TunForwarder;
using android::net::metrics::DnsMetricsListener;
using android::netdutils::enableSockopt;
using android::netdutils::makeSlice;
@@ -5116,3 +5120,337 @@
EXPECT_EQ(dns1.queries().size(), 0U);
EXPECT_EQ(dns2.queries().size(), 0U);
}
+
+// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
+// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
+// the packets to address B, which is the address on which the testing server is listening. The
+// answer packets responded from the testing server go through the reverse path back to the
+// resolver.
+//
+// To achieve that, the test needs to set up an interface with routing rules. Tests are not
+// supposed to start DNS servers on their own; instead, the class provides utilities to help
+// with the setup.
+//
+// An example of how to use it:
+// TEST_F() {
+//   ScopedNetwork network = CreateScopedNetwork(ConnectivityType::V4);
+//   ASSERT_RESULT_OK(network.init());
+//
+//   auto dns = network.addIpv4Dns();
+//   StartDns(dns->dnsServer, {});
+//
+//   mDnsClient.SetResolversFromParcel(...);
+//   network.startTunForwarder();
+//
+//   // Send queries here
+// }
+
+class ResolverMultinetworkTest : public ResolverTest {
+  protected:
+    enum class ConnectivityType { V4, V6, V4V6 };
+
+    // A testing DNS server plus the address the resolver should be configured with. The two
+    // differ: the resolver sends to |dnsAddr|, and the TunForwarder rewrites those packets so
+    // that they reach |dnsServer|.
+    struct DnsServerPair {
+        test::DNSResponder& dnsServer;
+        std::string dnsAddr;  // The DNS server address used for setResolverConfiguration().
+        // TODO: Add test::DnsTlsFrontend* and std::string for DoT.
+    };
+
+    // A test network backed by a TUN interface named "testtun<netId>". init() creates the
+    // network, its resolver cache, addresses and default routes; the destructor calls destroy()
+    // to tear the network and its cache down again.
+    class ScopedNetwork {
+      public:
+        ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+                      IDnsResolver* dnsResolvSrv)
+            : mNetId(netId),
+              mConnectivityType(type),
+              mNetdSrv(netdSrv),
+              mDnsResolvSrv(dnsResolvSrv) {
+            mIfname = StringPrintf("testtun%u", netId);
+        }
+        ~ScopedNetwork() { destroy(); }
+
+        // The destructor destroys the netd network, so a copy would tear it down twice.
+        ScopedNetwork(const ScopedNetwork&) = delete;
+        ScopedNetwork& operator=(const ScopedNetwork&) = delete;
+
+        Result<void> init();
+        void destroy();
+        Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
+        Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
+        // Returns false (instead of crashing) if the TUN forwarder has not been created, e.g.
+        // because init() was not called or failed before creating it.
+        bool startTunForwarder() { return mTunForwarder && mTunForwarder->startForwarding(); }
+        unsigned netId() const { return mNetId; }
+
+      private:
+        Result<DnsServerPair> addDns(ConnectivityType connectivity);
+        // Address n in this network's IPv4 range 192.168.<netId>.0/24.
+        std::string makeIpv4AddrString(unsigned n) const {
+            return StringPrintf("192.168.%u.%u", mNetId, n);
+        }
+        // Address n in this network's IPv6 range 2001:db8:<netId>::/64.
+        std::string makeIpv6AddrString(unsigned n) const {
+            return StringPrintf("2001:db8:%u::%u", mNetId, n);
+        }
+
+        const unsigned mNetId;
+        const ConnectivityType mConnectivityType;
+        INetd* mNetdSrv;
+        IDnsResolver* mDnsResolvSrv;
+
+        std::string mIfname;
+        std::unique_ptr<TunForwarder> mTunForwarder;
+        std::vector<std::unique_ptr<test::DNSResponder>> mDnsServers;
+        // TODO: Add std::vector<std::unique_ptr<test::DnsTlsFrontend>>
+    };
+
+    void SetUp() override {
+        ResolverTest::SetUp();
+        // ScopedNetwork dereferences both services without further checks.
+        ASSERT_NE(mDnsClient.netdService(), nullptr);
+        ASSERT_NE(mDnsClient.resolvService(), nullptr);
+    }
+
+    void TearDown() override { ResolverTest::TearDown(); }
+
+    ScopedNetwork CreateScopedNetwork(ConnectivityType type);
+    void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);
+
+    // Returns a netId not yet handed out by this fixture instance.
+    unsigned getFreeNetId() { return mNextNetId++; }
+
+  private:
+    // Use a different netId because this class inherits from the class ResolverTest which
+    // always creates TEST_NETID in setup. It's incremented when CreateScopedNetwork() is called.
+    // Note: Don't create more than 20 networks in the class since 51 is used for the dummy network.
+    unsigned mNextNetId = 31;
+};
+
+// Builds a ScopedNetwork on the next unused netId, using this fixture's netd and DnsResolver
+// services. Nothing is configured yet; the caller must still call init() on the result.
+ResolverMultinetworkTest::ScopedNetwork ResolverMultinetworkTest::CreateScopedNetwork(
+        ConnectivityType type) {
+    const unsigned netId = getFreeNetId();
+    return ScopedNetwork(netId, type, mDnsClient.netdService(), mDnsClient.resolvService());
+}
+
+// Sets up the network, in order:
+//   1. creates the TUN interface and its forwarder;
+//   2. creates the netd network and its resolver cache;
+//   3. attaches the interface to the network;
+//   4. adds the local address (".1") and default route for each IP family required by the
+//      connectivity type.
+// Returns an error naming the first step that failed. Anything already created is left for
+// the destructor to clean up.
+Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
+    unique_fd ufd = TunForwarder::createTun(mIfname);
+    if (!ufd.ok()) {
+        return Errorf("createTun for {} failed", mIfname);
+    }
+    mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));
+
+    if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_SYSTEM); !r.isOk()) {
+        return Error() << "networkCreatePhysical(" << mNetId << ") failed: " << r.getMessage();
+    }
+    if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
+        return Error() << "createNetworkCache(" << mNetId << ") failed: " << r.getMessage();
+    }
+    if (auto r = mNetdSrv->networkAddInterface(mNetId, mIfname); !r.isOk()) {
+        return Error() << "networkAddInterface(" << mNetId << ", " << mIfname
+                       << ") failed: " << r.getMessage();
+    }
+
+    // Assigns |addr|/|prefixLen| to the TUN interface, then adds |defaultRoute| via it.
+    const auto addAddressAndRoute = [this](const std::string& addr, int prefixLen,
+                                           const std::string& defaultRoute) -> Result<void> {
+        if (auto r = mNetdSrv->interfaceAddAddress(mIfname, addr, prefixLen); !r.isOk()) {
+            return Error() << "interfaceAddAddress(" << mIfname << ", " << addr << "/"
+                           << prefixLen << ") failed: " << r.getMessage();
+        }
+        if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, defaultRoute, ""); !r.isOk()) {
+            return Error() << "networkAddRoute(" << mNetId << ", " << mIfname << ", "
+                           << defaultRoute << ") failed: " << r.getMessage();
+        }
+        return {};
+    };
+
+    if (mConnectivityType == ConnectivityType::V4 || mConnectivityType == ConnectivityType::V4V6) {
+        if (auto r = addAddressAndRoute(makeIpv4AddrString(1), 32, "0.0.0.0/0"); !r.has_value()) {
+            return r;
+        }
+    }
+    if (mConnectivityType == ConnectivityType::V6 || mConnectivityType == ConnectivityType::V4V6) {
+        if (auto r = addAddressAndRoute(makeIpv6AddrString(1), 128, "::/0"); !r.has_value()) {
+            return r;
+        }
+    }
+
+    return {};
+}
+
+// Best-effort teardown: destroys the netd network, then this network's resolver cache.
+// Both statuses are intentionally discarded. This runs from the destructor, possibly after
+// init() failed part-way, so some of the objects may not exist.
+void ResolverMultinetworkTest::ScopedNetwork::destroy() {
+    (void)mNetdSrv->networkDestroy(mNetId);
+    (void)mDnsResolvSrv->destroyNetworkCache(mNetId);
+}
+
+// Starts |dns| serving |records| via ResolverTest::StartDns(). If a network was assigned to the
+// server (DNSResponder::getNetwork() has a value), it then binds the server's UDP and TCP sockets
+// to that network. A binding failure is reported as a test failure; otherwise the server would
+// silently use the wrong network.
+void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
+                                        const std::vector<DnsRecord>& records) {
+    ResolverTest::StartDns(dns, records);
+
+    // Bind the DNSResponder's sockets to the network if specified.
+    if (std::optional<unsigned> netId = dns.getNetwork(); netId.has_value()) {
+        // setNetworkForSocket() returns 0 on success.
+        EXPECT_EQ(0, setNetworkForSocket(netId.value(), dns.getUdpSocket()));
+        EXPECT_EQ(0, setNetworkForSocket(netId.value(), dns.getTcpSocket()));
+    }
+}
+
+// Adds a testing DNS server of the given family (V4 or V6) to this network.
+// The resolver is to send to dst1 (returned as |dnsAddr|). The TunForwarder rewrites
+// (src1 -> dst1) into (src2 -> dst2), where the DNSResponder listens on dst2, and rewrites the
+// answers (dst2 -> src2) back into (dst1 -> src1).
+// Fails if the TUN forwarder has not been created by init().
+Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
+        ConnectivityType type) {
+    if (mTunForwarder == nullptr) {
+        return Error() << "addDns() requires init() to have created the TUN forwarder";
+    }
+
+    const unsigned index = static_cast<unsigned>(mDnsServers.size());
+    const bool isV4 = (type == ConnectivityType::V4);
+    const int prefixLen = isV4 ? 32 : 128;
+
+    const auto makeIpString = [this, isV4](unsigned n) {
+        return isV4 ? makeIpv4AddrString(n) : makeIpv6AddrString(n);
+    };
+
+    // The address from which the resolver will send.
+    const std::string src1 = makeIpString(1);
+    // The address to which the resolver will send.
+    const std::string dst1 = makeIpString(index + 100);
+    // The address translated from src1.
+    const std::string src2 = dst1;
+    // The address translated from dst1; the DNSResponder listens here.
+    const std::string dst2 = makeIpString(index + 200);
+
+    if (!mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
+        !mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
+        return Errorf("Failed to add the rules ({}, {}, {}, {})", src1, dst1, src2, dst2);
+    }
+
+    if (!mNetdSrv->interfaceAddAddress(mIfname, dst2, prefixLen).isOk()) {
+        return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
+    }
+
+    // Create a DNSResponder instance. It is owned by this network and lives as long as it does.
+    auto& dnsPtr = mDnsServers.emplace_back(std::make_unique<test::DNSResponder>(dst2));
+    dnsPtr->setNetwork(mNetId);
+    return DnsServerPair{
+            .dnsServer = *dnsPtr,
+            .dnsAddr = dst1,
+    };
+}
+
+// For each connectivity type, getaddrinfo() with AI_ADDRCONFIG on the network must only query
+// and return the address families the network has local addresses for.
+TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
+    constexpr char host_name[] = "ohayou.example.com.";
+
+    const std::array<ConnectivityType, 3> allTypes = {
+            ConnectivityType::V4,
+            ConnectivityType::V6,
+            ConnectivityType::V4V6,
+    };
+    for (const auto& type : allTypes) {
+        // ConnectivityType is a scoped enum; convert it explicitly to match the %d specifier.
+        SCOPED_TRACE(StringPrintf("ConnectivityType: %d", static_cast<int>(type)));
+
+        // Create a network.
+        ScopedNetwork network = CreateScopedNetwork(type);
+        ASSERT_RESULT_OK(network.init());
+
+        // Add a testing DNS server (IPv6 for both V6 and V4V6 networks).
+        const Result<DnsServerPair> dnsPair =
+                (type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
+        ASSERT_RESULT_OK(dnsPair);
+        StartDns(dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"},
+                                      {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
+
+        // Set up resolver and start forwarding.
+        ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+        parcel.tlsServers.clear();
+        parcel.netId = network.netId();
+        parcel.servers = {dnsPair->dnsAddr};
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+        ASSERT_TRUE(network.startTunForwarder());
+
+        const addrinfo hints = {
+                .ai_flags = AI_ADDRCONFIG,
+                .ai_family = AF_UNSPEC,
+                .ai_socktype = SOCK_DGRAM,
+        };
+        addrinfo* raw_ai_result = nullptr;
+        EXPECT_EQ(0, android_getaddrinfofornet(host_name, nullptr, &hints, network.netId(),
+                                               MARK_UNSET, &raw_ai_result));
+        ScopedAddrinfo ai_result(raw_ai_result);
+        std::vector<std::string> result_strs = ToStrings(ai_result);
+
+        // One answer and one query per address family configured on the network.
+        std::vector<std::string> expectedResult;
+        size_t expectedQueries = 0;
+        if (type == ConnectivityType::V6 || type == ConnectivityType::V4V6) {
+            expectedResult.emplace_back("2001:db8:cafe:d00d::31");
+            expectedQueries++;
+        }
+        if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
+            expectedResult.emplace_back("1.1.1.31");
+            expectedQueries++;
+        }
+        EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
+        EXPECT_EQ(GetNumQueries(dnsPair->dnsServer, host_name), expectedQueries);
+    }
+}
+
+// Destroying a network while a query on it is waiting for an unresponsive server must make the
+// query fail with -ETIMEDOUT rather than hang or crash.
+TEST_F(ResolverMultinetworkTest, NetworkDestroyedDuringQueryInFlight) {
+    constexpr char host_name[] = "ohayou.example.com.";
+
+    // Create a network and add an ipv4 DNS server.
+    auto network =
+            std::make_unique<ScopedNetwork>(getFreeNetId(), ConnectivityType::V4V6,
+                                            mDnsClient.netdService(), mDnsClient.resolvService());
+    ASSERT_RESULT_OK(network->init());
+    const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
+    ASSERT_RESULT_OK(dnsPair);
+    // Copy the netId out before starting the lookup thread. The main thread resets |network|
+    // while the query is in flight, so the thread must never dereference it.
+    const unsigned netId = network->netId();
+
+    // Set the DNS server unresponsive.
+    dnsPair->dnsServer.setResponseProbability(0.0);
+    dnsPair->dnsServer.setErrorRcode(static_cast<ns_rcode>(-1));
+    StartDns(dnsPair->dnsServer, {});
+
+    // Set up resolver and start forwarding.
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.tlsServers.clear();
+    parcel.netId = netId;
+    parcel.servers = {dnsPair->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network->startTunForwarder());
+
+    // Expect the things happening in order:
+    // 1. The thread sends the query to the dns server which is unresponsive.
+    // 2. The network is destroyed while the thread is waiting for the response from the dns server.
+    // 3. After the dns server timeout, the thread retries but fails to connect.
+    std::thread lookup([&]() {
+        const int fd = resNetworkQuery(netId, host_name, ns_c_in, ns_t_a, 0);
+        // Failures are reported as a negative value (a negated errno), not only -1; don't go on
+        // to poll an invalid fd.
+        ASSERT_GE(fd, 0);
+        expectAnswersNotValid(fd, -ETIMEDOUT);
+    });
+
+    // Tear down the network as soon as the dns server receives the query.
+    const auto condition = [&]() { return GetNumQueries(dnsPair->dnsServer, host_name) == 1U; };
+    EXPECT_TRUE(PollForCondition(condition));
+    network.reset();
+
+    lookup.join();
+}
+
+// Each network owns its own resolver cache: flushing one network's cache causes a fresh query
+// only on that network, while the other network keeps answering from its cache.
+TEST_F(ResolverMultinetworkTest, OneCachePerNetwork) {
+    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
+    constexpr char host_name[] = "ohayou.example.com.";
+
+    // Two dual-stack networks, each with its own IPv4 DNS server giving a distinct answer.
+    ScopedNetwork network1 = CreateScopedNetwork(ConnectivityType::V4V6);
+    ScopedNetwork network2 = CreateScopedNetwork(ConnectivityType::V4V6);
+    ASSERT_RESULT_OK(network1.init());
+    ASSERT_RESULT_OK(network2.init());
+
+    const Result<DnsServerPair> dnsPair1 = network1.addIpv4Dns();
+    const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
+    ASSERT_RESULT_OK(dnsPair1);
+    ASSERT_RESULT_OK(dnsPair2);
+    StartDns(dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"}});
+    StartDns(dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.32"}});
+
+    // One parcel is reused for both networks; only netId and servers differ.
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.tlsServers.clear();
+
+    // Network 1: configure the resolver, then start forwarding.
+    parcel.netId = network1.netId();
+    parcel.servers = {dnsPair1->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network1.startTunForwarder());
+
+    // Network 2: configure the resolver, then start forwarding.
+    parcel.netId = network2.netId();
+    parcel.servers = {dnsPair2->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network2.startTunForwarder());
+
+    // Sends the same A query on both networks, then checks that each answer is the one served
+    // by that network's own DNS server.
+    const auto queryBothNetworksAndCheckAnswers = [&]() {
+        const int netFd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
+        const int netFd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
+        expectAnswersValid(netFd1, AF_INET, "1.1.1.31");
+        expectAnswersValid(netFd2, AF_INET, "1.1.1.32");
+    };
+
+    // First round: both servers are queried once.
+    queryBothNetworksAndCheckAnswers();
+    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 1U);
+    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+
+    // Second round after flushing only network 1's cache: only server 1 sees another query.
+    EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(network1.netId()).isOk());
+    queryBothNetworksAndCheckAnswers();
+    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 2U);
+    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+}