Add VPN test for DnsResolver
Test {enable, disable}VpnIsolation with a secure VPN and a bypassable VPN
under several different network selection scenarios.
Test: atest
Bug: 159783741
Bug: 159994981
Bug: 161509097
Change-Id: I4abb7311d0d3330efd1f820cef741f4650beb120
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 716dc71..081eda5 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -53,6 +53,7 @@
#include <iterator>
#include <numeric>
#include <thread>
+#include <unordered_set>
#include <DnsProxydProtocol.h> // NETID_USE_LOCAL_NAMESERVERS
#include <aidl/android/net/IDnsResolver.h>
@@ -5192,13 +5193,13 @@
//
// An example of how to use it:
// TEST_F() {
-// ScopedNetwork network = CreateScopedNetwork(V4);
+// ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(V4);
// network.init();
//
// auto dns = network.addIpv4Dns();
// StartDns(dns.dnsServer, {});
//
-// setResolverConfiguration(...);
+// network.setDnsConfiguration();
// network.startTunForwarder();
//
// // Send queries here
@@ -5207,9 +5208,12 @@
class ResolverMultinetworkTest : public ResolverTest {
protected:
enum class ConnectivityType { V4, V6, V4V6 };
+ static constexpr int TEST_NETID_BASE = 10000;
struct DnsServerPair {
- test::DNSResponder& dnsServer;
+ DnsServerPair(std::shared_ptr<test::DNSResponder> server, std::string addr)
+ : dnsServer(server), dnsAddr(addr) {}
+ std::shared_ptr<test::DNSResponder> dnsServer;
std::string dnsAddr; // The DNS server address used for setResolverConfiguration().
// TODO: Add test::DnsTlsFrontend* and std::string for DoT.
};
@@ -5217,40 +5221,117 @@
class ScopedNetwork {
public:
ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
- IDnsResolver* dnsResolvSrv)
+ IDnsResolver* dnsResolvSrv, const char* networkName)
: mNetId(netId),
mConnectivityType(type),
mNetdSrv(netdSrv),
- mDnsResolvSrv(dnsResolvSrv) {
- mIfname = StringPrintf("testtun%d", netId);
+ mDnsResolvSrv(dnsResolvSrv),
+ mNetworkName(networkName) {
+ mIfname = fmt::format("testtun{}", netId);
}
- ~ScopedNetwork() { destroy(); }
+ virtual ~ScopedNetwork() {
+ if (mNetdSrv != nullptr) mNetdSrv->networkDestroy(mNetId);
+ if (mDnsResolvSrv != nullptr) mDnsResolvSrv->destroyNetworkCache(mNetId);
+ }
Result<void> init();
- void destroy();
Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
bool startTunForwarder() { return mTunForwarder->startForwarding(); }
+ bool setDnsConfiguration() const;
+ bool clearDnsConfiguration() const;
unsigned netId() const { return mNetId; }
+ std::string name() const { return mNetworkName; }
- private:
- Result<DnsServerPair> addDns(ConnectivityType connectivity);
- std::string makeIpv4AddrString(unsigned n) const {
- return StringPrintf("192.168.%u.%u", mNetId, n);
- }
- std::string makeIpv6AddrString(unsigned n) const {
- return StringPrintf("2001:db8:%u::%u", mNetId, n);
- }
+ protected:
+ // Subclasses must implement this to decide which kind of network gets created.
+ virtual Result<void> createNetwork() const = 0;
const unsigned mNetId;
const ConnectivityType mConnectivityType;
INetd* mNetdSrv;
IDnsResolver* mDnsResolvSrv;
-
+ const std::string mNetworkName;
std::string mIfname;
std::unique_ptr<TunForwarder> mTunForwarder;
- std::vector<std::unique_ptr<test::DNSResponder>> mDnsServers;
- // TODO: Add std::vector<std::unique_ptr<test::DnsTlsFrontend>>
+ std::vector<DnsServerPair> mDnsServerPairs;
+
+ private:
+ Result<DnsServerPair> addDns(ConnectivityType connectivity);
+ // Assuming mNetId is unique during ResolverMultinetworkTest, make the
+ // address based on it to avoid conflicts.
+ std::string makeIpv4AddrString(uint8_t n) const {
+ return StringPrintf("192.168.%u.%u", (mNetId - TEST_NETID_BASE), n);
+ }
+ std::string makeIpv6AddrString(uint8_t n) const {
+ return StringPrintf("2001:db8:%u::%u", (mNetId - TEST_NETID_BASE), n);
+ }
+ };
+
+ class ScopedPhysicalNetwork : public ScopedNetwork {
+ public:
+ ScopedPhysicalNetwork(unsigned netId, const char* networkName)
+ : ScopedNetwork(netId, ConnectivityType::V4V6, nullptr, nullptr, networkName) {}
+ ScopedPhysicalNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+ IDnsResolver* dnsResolvSrv, const char* name = "Physical")
+ : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name) {}
+
+ protected:
+ Result<void> createNetwork() const override {
+ if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_NONE);
+ !r.isOk()) {
+ return Error() << r.getMessage();
+ }
+ return {};
+ }
+ };
+
+ class ScopedVirtualNetwork : public ScopedNetwork {
+ public:
+ ScopedVirtualNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+ IDnsResolver* dnsResolvSrv, const char* name, bool isSecure)
+ : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name), mIsSecure(isSecure) {}
+ ~ScopedVirtualNetwork() {
+ if (!mVpnIsolationUids.empty()) {
+ const std::vector<int> tmpUids(mVpnIsolationUids.begin(), mVpnIsolationUids.end());
+ mNetdSrv->firewallRemoveUidInterfaceRules(tmpUids);
+ }
+ }
+ // Enable VPN isolation. Ensures that uid can only receive packets on mIfname.
+ Result<void> enableVpnIsolation(int uid) {
+ if (auto r = mNetdSrv->firewallAddUidInterfaceRules(mIfname, {uid}); !r.isOk()) {
+ return Error() << r.getMessage();
+ }
+ mVpnIsolationUids.insert(uid);
+ return {};
+ }
+ Result<void> disableVpnIsolation(int uid) {
+ if (auto r = mNetdSrv->firewallRemoveUidInterfaceRules({static_cast<int>(uid)});
+ !r.isOk()) {
+ return Error() << r.getMessage();
+ }
+ mVpnIsolationUids.erase(uid);
+ return {};
+ }
+ Result<void> addUser(uid_t uid) const { return addUidRange(uid, uid); }
+ Result<void> addUidRange(uid_t from, uid_t to) const {
+ if (auto r = mNetdSrv->networkAddUidRanges(mNetId, {makeUidRangeParcel(from, to)});
+ !r.isOk()) {
+ return Error() << r.getMessage();
+ }
+ return {};
+ }
+
+ protected:
+ Result<void> createNetwork() const override {
+ if (auto r = mNetdSrv->networkCreateVpn(mNetId, mIsSecure); !r.isOk()) {
+ return Error() << r.getMessage();
+ }
+ return {};
+ }
+
+ bool mIsSecure = false;
+ std::unordered_set<int> mVpnIsolationUids;
};
void SetUp() override {
@@ -5259,34 +5340,59 @@
ASSERT_NE(mDnsClient.resolvService(), nullptr);
}
- void TearDown() override { ResolverTest::TearDown(); }
+ void TearDown() override {
+ ResolverTest::TearDown();
+ // Restore default network
+ if (mStoredDefaultNetwork >= 0) {
+ mDnsClient.netdService()->networkSetDefault(mStoredDefaultNetwork);
+ }
+ }
- ScopedNetwork CreateScopedNetwork(ConnectivityType type);
+ ScopedPhysicalNetwork CreateScopedPhysicalNetwork(ConnectivityType type,
+ const char* name = "Physical") {
+ return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService(), name};
+ }
+ ScopedVirtualNetwork CreateScopedVirtualNetwork(ConnectivityType type, bool isSecure,
+ const char* name = "Virtual") {
+ return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService(),
+ name, isSecure};
+ }
void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);
-
- unsigned getFreeNetId() { return mNextNetId++; }
+ void setDefaultNetwork(int netId) {
+ // Save current default network at the first call.
+ std::call_once(defaultNetworkFlag, [&]() {
+ ASSERT_TRUE(mDnsClient.netdService()->networkGetDefault(&mStoredDefaultNetwork).isOk());
+ });
+ ASSERT_TRUE(mDnsClient.netdService()->networkSetDefault(netId).isOk());
+ }
+ unsigned getFreeNetId() {
+ if (mNextNetId == TEST_NETID_BASE + 256) mNextNetId = TEST_NETID_BASE;
+ return mNextNetId++;
+ }
private:
// Use a different netId because this class inherits from the class ResolverTest which
- // always creates TEST_NETID in setup. It's incremented when CreateScopedNetwork() is called.
- // Note: Don't create more than 20 networks in the class since 51 is used for the mock network.
- unsigned mNextNetId = 31;
+ // always creates TEST_NETID in setup. It's incremented when CreateScoped{Physical,
+ // Virtual}Network() is called.
+ // Note: (mNextNetId - TEST_NETID_BASE) must not exceed 255 because that offset
+ // is used to build the test addresses.
+ unsigned mNextNetId = TEST_NETID_BASE;
+ // Use -1 to indicate that the default network has not been modified; a
+ // real netId is always an unsigned value.
+ int mStoredDefaultNetwork = -1;
+ std::once_flag defaultNetworkFlag;
};
-ResolverMultinetworkTest::ScopedNetwork ResolverMultinetworkTest::CreateScopedNetwork(
- ConnectivityType type) {
- return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService()};
-}
-
Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
+ if (mNetdSrv == nullptr || mDnsResolvSrv == nullptr) return Error() << "srv not available";
unique_fd ufd = TunForwarder::createTun(mIfname);
if (!ufd.ok()) {
return Errorf("createTun for {} failed", mIfname);
}
mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));
- if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_SYSTEM); !r.isOk()) {
- return Error() << r.getMessage();
+ if (auto r = createNetwork(); !r.ok()) {
+ return r;
}
if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
return Error() << r.getMessage();
@@ -5317,11 +5423,6 @@
return {};
}
-void ResolverMultinetworkTest::ScopedNetwork::destroy() {
- mNetdSrv->networkDestroy(mNetId);
- mDnsResolvSrv->destroyNetworkCache(mNetId);
-}
-
void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
const std::vector<DnsRecord>& records) {
ResolverTest::StartDns(dns, records);
@@ -5335,7 +5436,7 @@
Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
ConnectivityType type) {
- const int index = mDnsServers.size();
+ const int index = mDnsServerPairs.size();
const int prefixLen = (type == ConnectivityType::V4) ? 32 : 128;
const std::function<std::string(unsigned)> makeIpString =
@@ -5344,9 +5445,12 @@
this, std::placeholders::_1);
std::string src1 = makeIpString(1); // The address from which the resolver will send.
- std::string dst1 = makeIpString(index + 100); // The address to which the resolver will send.
+ std::string dst1 = makeIpString(
+ index + 100 +
+ (mNetId - TEST_NETID_BASE)); // The address to which the resolver will send.
std::string src2 = dst1; // The address translated from src1.
- std::string dst2 = makeIpString(index + 200); // The address translated from dst2.
+ std::string dst2 = makeIpString(
+ index + 200 + (mNetId - TEST_NETID_BASE)); // The address translated from dst1.
if (!mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
!mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
@@ -5357,15 +5461,56 @@
return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
}
- // Create a DNSResponder instance.
- auto& dnsPtr = mDnsServers.emplace_back(std::make_unique<test::DNSResponder>(dst2));
- dnsPtr->setNetwork(mNetId);
- return DnsServerPair{
- .dnsServer = *dnsPtr,
- .dnsAddr = dst1,
- };
+ return mDnsServerPairs.emplace_back(std::make_shared<test::DNSResponder>(mNetId, dst2), dst1);
}
+bool ResolverMultinetworkTest::ScopedNetwork::setDnsConfiguration() const {
+ if (mDnsResolvSrv == nullptr) return false;
+ ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+ parcel.tlsServers.clear();
+ parcel.netId = mNetId;
+ parcel.servers.clear();
+ for (const auto& pair : mDnsServerPairs) {
+ parcel.servers.push_back(pair.dnsAddr);
+ }
+ return mDnsResolvSrv->setResolverConfiguration(parcel).isOk();
+}
+
+bool ResolverMultinetworkTest::ScopedNetwork::clearDnsConfiguration() const {
+ if (mDnsResolvSrv == nullptr) return false;
+ return mDnsResolvSrv->destroyNetworkCache(mNetId).isOk() &&
+ mDnsResolvSrv->createNetworkCache(mNetId).isOk();
+}
+
+namespace {
+
+// Convenience wrapper that calls getaddrinfo the same way the framework does.
+Result<ScopedAddrinfo> android_getaddrinfofornet_wrapper(const char* name, int netId) {
+ // Use the same parameters as libcore/ojluni/src/main/java/java/net/Inet6AddressImpl.java.
+ static const addrinfo hints = {
+ .ai_flags = AI_ADDRCONFIG,
+ .ai_family = AF_UNSPEC,
+ .ai_socktype = SOCK_STREAM,
+ };
+ addrinfo* result = nullptr;
+ if (int r = android_getaddrinfofornet(name, nullptr, &hints, netId, MARK_UNSET, &result)) {
+ return Error() << r;
+ }
+ return ScopedAddrinfo(result);
+}
+
+void expectDnsWorksForUid(const char* name, unsigned netId, uid_t uid,
+ const std::vector<std::string>& expectedResult) {
+ ScopedChangeUID scopedChangeUID(uid);
+ auto result = android_getaddrinfofornet_wrapper(name, netId);
+ ASSERT_RESULT_OK(result);
+ ScopedAddrinfo ai_result(std::move(result.value()));
+ std::vector<std::string> result_strs = ToStrings(ai_result);
+ EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
+}
+
+} // namespace
+
TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
constexpr char host_name[] = "ohayou.example.com.";
@@ -5378,33 +5523,23 @@
SCOPED_TRACE(StringPrintf("ConnectivityType: %d", type));
// Create a network.
- ScopedNetwork network = CreateScopedNetwork(type);
+ ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(type);
ASSERT_RESULT_OK(network.init());
// Add a testing DNS server.
const Result<DnsServerPair> dnsPair =
(type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
ASSERT_RESULT_OK(dnsPair);
- StartDns(dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"},
- {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
+ StartDns(*dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.0"},
+ {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
// Set up resolver and start forwarding.
- ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
- parcel.tlsServers.clear();
- parcel.netId = network.netId();
- parcel.servers = {dnsPair->dnsAddr};
- ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+ ASSERT_TRUE(network.setDnsConfiguration());
ASSERT_TRUE(network.startTunForwarder());
- const addrinfo hints = {
- .ai_flags = AI_ADDRCONFIG,
- .ai_family = AF_UNSPEC,
- .ai_socktype = SOCK_DGRAM,
- };
- addrinfo* raw_ai_result = nullptr;
- EXPECT_EQ(0, android_getaddrinfofornet(host_name, nullptr, &hints, network.netId(),
- MARK_UNSET, &raw_ai_result));
- ScopedAddrinfo ai_result(raw_ai_result);
+ auto result = android_getaddrinfofornet_wrapper(host_name, network.netId());
+ ASSERT_RESULT_OK(result);
+ ScopedAddrinfo ai_result(std::move(result.value()));
std::vector<std::string> result_strs = ToStrings(ai_result);
std::vector<std::string> expectedResult;
size_t expectedQueries = 0;
@@ -5414,11 +5549,11 @@
expectedQueries++;
}
if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
- expectedResult.emplace_back("1.1.1.31");
+ expectedResult.emplace_back("192.0.2.0");
expectedQueries++;
}
EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
- EXPECT_EQ(GetNumQueries(dnsPair->dnsServer, host_name), expectedQueries);
+ EXPECT_EQ(GetNumQueries(*dnsPair->dnsServer, host_name), expectedQueries);
}
}
@@ -5426,24 +5561,20 @@
constexpr char host_name[] = "ohayou.example.com.";
// Create a network and add an ipv4 DNS server.
- auto network =
- std::make_unique<ScopedNetwork>(getFreeNetId(), ConnectivityType::V4V6,
- mDnsClient.netdService(), mDnsClient.resolvService());
+ auto network = std::make_unique<ScopedPhysicalNetwork>(getFreeNetId(), ConnectivityType::V4V6,
+ mDnsClient.netdService(),
+ mDnsClient.resolvService());
ASSERT_RESULT_OK(network->init());
const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
ASSERT_RESULT_OK(dnsPair);
// Set the DNS server unresponsive.
- dnsPair->dnsServer.setResponseProbability(0.0);
- dnsPair->dnsServer.setErrorRcode(static_cast<ns_rcode>(-1));
- StartDns(dnsPair->dnsServer, {});
+ dnsPair->dnsServer->setResponseProbability(0.0);
+ dnsPair->dnsServer->setErrorRcode(static_cast<ns_rcode>(-1));
+ StartDns(*dnsPair->dnsServer, {});
// Set up resolver and start forwarding.
- ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
- parcel.tlsServers.clear();
- parcel.netId = network->netId();
- parcel.servers = {dnsPair->dnsAddr};
- ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+ ASSERT_TRUE(network->setDnsConfiguration());
ASSERT_TRUE(network->startTunForwarder());
// Expect the things happening in order:
@@ -5457,7 +5588,7 @@
});
// Tear down the network as soon as the dns server receives the query.
- const auto condition = [&]() { return GetNumQueries(dnsPair->dnsServer, host_name) == 1U; };
+ const auto condition = [&]() { return GetNumQueries(*dnsPair->dnsServer, host_name) == 1U; };
EXPECT_TRUE(PollForCondition(condition));
network.reset();
@@ -5468,8 +5599,8 @@
SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
constexpr char host_name[] = "ohayou.example.com.";
- ScopedNetwork network1 = CreateScopedNetwork(ConnectivityType::V4V6);
- ScopedNetwork network2 = CreateScopedNetwork(ConnectivityType::V4V6);
+ ScopedPhysicalNetwork network1 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
+ ScopedPhysicalNetwork network2 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
ASSERT_RESULT_OK(network1.init());
ASSERT_RESULT_OK(network2.init());
@@ -5477,39 +5608,174 @@
const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
ASSERT_RESULT_OK(dnsPair1);
ASSERT_RESULT_OK(dnsPair2);
- StartDns(dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"}});
- StartDns(dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.32"}});
+ StartDns(*dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.0"}});
+ StartDns(*dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.1"}});
// Set up resolver for network 1 and start forwarding.
- ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
- parcel.tlsServers.clear();
- parcel.netId = network1.netId();
- parcel.servers = {dnsPair1->dnsAddr};
- ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+ ASSERT_TRUE(network1.setDnsConfiguration());
ASSERT_TRUE(network1.startTunForwarder());
// Set up resolver for network 2 and start forwarding.
- parcel.netId = network2.netId();
- parcel.servers = {dnsPair2->dnsAddr};
- ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+ ASSERT_TRUE(network2.setDnsConfiguration());
ASSERT_TRUE(network2.startTunForwarder());
// Send the same queries to both networks.
int fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
int fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
- expectAnswersValid(fd1, AF_INET, "1.1.1.31");
- expectAnswersValid(fd2, AF_INET, "1.1.1.32");
- EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 1U);
- EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+ expectAnswersValid(fd1, AF_INET, "192.0.2.0");
+ expectAnswersValid(fd2, AF_INET, "192.0.2.1");
+ EXPECT_EQ(GetNumQueries(*dnsPair1->dnsServer, host_name), 1U);
+ EXPECT_EQ(GetNumQueries(*dnsPair2->dnsServer, host_name), 1U);
// Flush the cache of network 1, and send the queries again.
EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(network1.netId()).isOk());
fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
- expectAnswersValid(fd1, AF_INET, "1.1.1.31");
- expectAnswersValid(fd2, AF_INET, "1.1.1.32");
- EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 2U);
- EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+ expectAnswersValid(fd1, AF_INET, "192.0.2.0");
+ expectAnswersValid(fd2, AF_INET, "192.0.2.1");
+ EXPECT_EQ(GetNumQueries(*dnsPair1->dnsServer, host_name), 2U);
+ EXPECT_EQ(GetNumQueries(*dnsPair2->dnsServer, host_name), 1U);
+}
+
+TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
+ SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
+ SKIP_IF_BPF_NOT_SUPPORTED;
+ constexpr char host_name[] = "ohayou.example.com.";
+ constexpr char ipv4_addr[] = "192.0.2.0";
+ constexpr char ipv6_addr[] = "2001:db8:cafe:d00d::31";
+
+ const std::pair<ConnectivityType, std::vector<std::string>> testPairs[] = {
+ {ConnectivityType::V4, {ipv4_addr}},
+ {ConnectivityType::V6, {ipv6_addr}},
+ {ConnectivityType::V4V6, {ipv6_addr, ipv4_addr}},
+ };
+ for (const auto& [type, result] : testPairs) {
+ SCOPED_TRACE(StringPrintf("ConnectivityType: %d", type));
+
+ // Create a network.
+ ScopedPhysicalNetwork underlyingNetwork = CreateScopedPhysicalNetwork(type, "Underlying");
+ ScopedVirtualNetwork bypassableVpnNetwork =
+ CreateScopedVirtualNetwork(type, false, "BypassableVpn");
+ ScopedVirtualNetwork secureVpnNetwork = CreateScopedVirtualNetwork(type, true, "SecureVpn");
+
+ ASSERT_RESULT_OK(underlyingNetwork.init());
+ ASSERT_RESULT_OK(bypassableVpnNetwork.init());
+ ASSERT_RESULT_OK(secureVpnNetwork.init());
+ ASSERT_RESULT_OK(bypassableVpnNetwork.addUser(TEST_UID));
+ ASSERT_RESULT_OK(secureVpnNetwork.addUser(TEST_UID2));
+
+ auto setupDnsFn = [&](std::shared_ptr<test::DNSResponder> dnsServer,
+ ScopedNetwork* nw) -> void {
+ StartDns(*dnsServer, {{host_name, ns_type::ns_t_a, ipv4_addr},
+ {host_name, ns_type::ns_t_aaaa, ipv6_addr}});
+ ASSERT_TRUE(nw->setDnsConfiguration());
+ ASSERT_TRUE(nw->startTunForwarder());
+ };
+ // Add a testing DNS server to networks.
+ const Result<DnsServerPair> underlyingPair = (type == ConnectivityType::V4)
+ ? underlyingNetwork.addIpv4Dns()
+ : underlyingNetwork.addIpv6Dns();
+ ASSERT_RESULT_OK(underlyingPair);
+ const Result<DnsServerPair> bypassableVpnPair = (type == ConnectivityType::V4)
+ ? bypassableVpnNetwork.addIpv4Dns()
+ : bypassableVpnNetwork.addIpv6Dns();
+ ASSERT_RESULT_OK(bypassableVpnPair);
+ const Result<DnsServerPair> secureVpnPair = (type == ConnectivityType::V4)
+ ? secureVpnNetwork.addIpv4Dns()
+ : secureVpnNetwork.addIpv6Dns();
+ ASSERT_RESULT_OK(secureVpnPair);
+ // Set up resolver and start forwarding for networks.
+ setupDnsFn(underlyingPair->dnsServer, &underlyingNetwork);
+ setupDnsFn(bypassableVpnPair->dnsServer, &bypassableVpnNetwork);
+ setupDnsFn(secureVpnPair->dnsServer, &secureVpnNetwork);
+
+ setDefaultNetwork(underlyingNetwork.netId());
+ const unsigned underlyingNetId = underlyingNetwork.netId();
+ const unsigned bypassableVpnNetId = bypassableVpnNetwork.netId();
+ const unsigned secureVpnNetId = secureVpnNetwork.netId();
+ // We've called setNetworkForProcess in SetupOemNetwork, so reset to default first.
+ ScopedSetNetworkForProcess scopedSetNetworkForProcess(NETID_UNSET);
+ auto expectDnsQueryCountsFn = [&](size_t count,
+ std::shared_ptr<test::DNSResponder> dnsServer,
+ unsigned expectedDnsNetId) -> void {
+ EXPECT_EQ(GetNumQueries(*dnsServer, host_name), count);
+ EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(expectedDnsNetId).isOk());
+ dnsServer->clearQueries();
+ // Give DnsResolver some time to clear cache to avoid race.
+ usleep(5 * 1000);
+ };
+
+ // Create an object to represent the default network; do not init it.
+ ScopedPhysicalNetwork defaultNetwork{NETID_UNSET, "Default"};
+
+ // Test VPN with DNS server under 4 different network selection scenarios.
+ // See the test config for the expectation.
+ const struct TestConfig {
+ ScopedNetwork* selectedNetwork;
+ unsigned expectedDnsNetId;
+ std::shared_ptr<test::DNSResponder> expectedDnsServer;
+ } vpnWithDnsServerConfigs[]{
+ // clang-format off
+ // Queries use the bypassable VPN by default.
+ {&defaultNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
+ // Choosing the underlying network works because the VPN is bypassable.
+ {&underlyingNetwork, underlyingNetId, underlyingPair->dnsServer},
+ // Selecting the VPN sends the query on the VPN.
+ {&bypassableVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
+ // TEST_UID does not have access to the secure VPN.
+ {&secureVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
+ // clang-format on
+ };
+ for (const auto& config : vpnWithDnsServerConfigs) {
+ SCOPED_TRACE(fmt::format("Bypassble VPN with DnsServer, selectedNetwork = {}",
+ config.selectedNetwork->name()));
+ expectDnsWorksForUid(host_name, config.selectedNetwork->netId(), TEST_UID, result);
+ expectDnsQueryCountsFn(result.size(), config.expectedDnsServer,
+ config.expectedDnsNetId);
+ }
+
+ std::vector<ScopedNetwork*> nwVec{&defaultNetwork, &underlyingNetwork,
+ &bypassableVpnNetwork, &secureVpnNetwork};
+ // Test the VPN without DNS server with the same combination as before.
+ ASSERT_TRUE(bypassableVpnNetwork.clearDnsConfiguration());
+ // Test bypassable VPN, TEST_UID
+ for (const auto* selectedNetwork : nwVec) {
+ SCOPED_TRACE(fmt::format("Bypassble VPN without DnsServer, selectedNetwork = {}",
+ selectedNetwork->name()));
+ expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID, result);
+ expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
+ }
+
+ // The same test scenario as before plus enableVpnIsolation for secure VPN, TEST_UID2.
+ for (bool enableVpnIsolation : {false, true}) {
+ SCOPED_TRACE(fmt::format("enableVpnIsolation = {}", enableVpnIsolation));
+ if (enableVpnIsolation) {
+ EXPECT_RESULT_OK(secureVpnNetwork.enableVpnIsolation(TEST_UID2));
+ }
+
+ // Test secure VPN without DNS server.
+ ASSERT_TRUE(secureVpnNetwork.clearDnsConfiguration());
+ for (const auto* selectedNetwork : nwVec) {
+ SCOPED_TRACE(fmt::format("Secure VPN without DnsServer, selectedNetwork = {}",
+ selectedNetwork->name()));
+ expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
+ expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
+ }
+
+ // Test secure VPN with DNS server.
+ ASSERT_TRUE(secureVpnNetwork.setDnsConfiguration());
+ for (const auto* selectedNetwork : nwVec) {
+ SCOPED_TRACE(fmt::format("Secure VPN with DnsServer, selectedNetwork = {}",
+ selectedNetwork->name()));
+ expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
+ expectDnsQueryCountsFn(result.size(), secureVpnPair->dnsServer, secureVpnNetId);
+ }
+
+ if (enableVpnIsolation) {
+ EXPECT_RESULT_OK(secureVpnNetwork.disableVpnIsolation(TEST_UID2));
+ }
+ }
+ }
}