Add VPN test for DnsResolver

Test {enable,disable}VpnIsolation together with secure and bypassable
VPNs under several different network selection scenarios.

Test: atest
Bug: 159783741
Bug: 159994981
Bug: 161509097

Change-Id: I4abb7311d0d3330efd1f820cef741f4650beb120
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 716dc71..081eda5 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -53,6 +53,7 @@
 #include <iterator>
 #include <numeric>
 #include <thread>
+#include <unordered_set>
 
 #include <DnsProxydProtocol.h>  // NETID_USE_LOCAL_NAMESERVERS
 #include <aidl/android/net/IDnsResolver.h>
@@ -5192,13 +5193,13 @@
 //
 // An example of how to use it:
 // TEST_F() {
-//     ScopedNetwork network = CreateScopedNetwork(V4);
+//     ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(V4);
 //     network.init();
 //
 //     auto dns = network.addIpv4Dns();
 //     StartDns(dns.dnsServer, {});
 //
-//     setResolverConfiguration(...);
+//     network.setDnsConfiguration();
 //     network.startTunForwarder();
 //
 //     // Send queries here
@@ -5207,9 +5208,15 @@
 class ResolverMultinetworkTest : public ResolverTest {
   protected:
     enum class ConnectivityType { V4, V6, V4V6 };
+    // netIds handed out by getFreeNetId() start here. (netId - TEST_NETID_BASE) is embedded
+    // in the generated test addresses, so it has to stay within [0, 255].
+    static constexpr int TEST_NETID_BASE = 10000;
 
     struct DnsServerPair {
-        test::DNSResponder& dnsServer;
+        DnsServerPair(std::shared_ptr<test::DNSResponder> server, std::string addr)
+            : dnsServer(std::move(server)), dnsAddr(std::move(addr)) {}
+        // Shared by the copy returned to the test and the one kept by the owning network.
+        std::shared_ptr<test::DNSResponder> dnsServer;
         std::string dnsAddr;  // The DNS server address used for setResolverConfiguration().
         // TODO: Add test::DnsTlsFrontend* and std::string for DoT.
     };
@@ -5217,40 +5221,132 @@
     class ScopedNetwork {
       public:
         ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
-                      IDnsResolver* dnsResolvSrv)
+                      IDnsResolver* dnsResolvSrv, const char* networkName)
             : mNetId(netId),
               mConnectivityType(type),
               mNetdSrv(netdSrv),
-              mDnsResolvSrv(dnsResolvSrv) {
-            mIfname = StringPrintf("testtun%d", netId);
+              mDnsResolvSrv(dnsResolvSrv),
+              mNetworkName(networkName) {
+            mIfname = fmt::format("testtun{}", netId);
         }
-        ~ScopedNetwork() { destroy(); }
+        // Destroys the netd network and the resolver cache of mNetId. A null service (as in a
+        // placeholder network that was never init()-ed) is skipped.
+        virtual ~ScopedNetwork() {
+            if (mNetdSrv != nullptr) mNetdSrv->networkDestroy(mNetId);
+            if (mDnsResolvSrv != nullptr) mDnsResolvSrv->destroyNetworkCache(mNetId);
+        }
 
         Result<void> init();
-        void destroy();
         Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
         Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
         bool startTunForwarder() { return mTunForwarder->startForwarding(); }
+        // Configures the resolver for mNetId with every DNS server added so far.
+        bool setDnsConfiguration() const;
+        // Drops the resolver configuration by destroying and recreating the network cache.
+        bool clearDnsConfiguration() const;
         unsigned netId() const { return mNetId; }
+        // Label used in SCOPED_TRACE messages.
+        std::string name() const { return mNetworkName; }
 
-      private:
-        Result<DnsServerPair> addDns(ConnectivityType connectivity);
-        std::string makeIpv4AddrString(unsigned n) const {
-            return StringPrintf("192.168.%u.%u", mNetId, n);
-        }
-        std::string makeIpv6AddrString(unsigned n) const {
-            return StringPrintf("2001:db8:%u::%u", mNetId, n);
-        }
+      protected:
+        // Subclasses implement it to decide which kind of network is created.
+        virtual Result<void> createNetwork() const = 0;
 
         const unsigned mNetId;
         const ConnectivityType mConnectivityType;
         INetd* mNetdSrv;
         IDnsResolver* mDnsResolvSrv;
-
+        const std::string mNetworkName;
         std::string mIfname;
         std::unique_ptr<TunForwarder> mTunForwarder;
-        std::vector<std::unique_ptr<test::DNSResponder>> mDnsServers;
-        // TODO: Add std::vector<std::unique_ptr<test::DnsTlsFrontend>>
+        std::vector<DnsServerPair> mDnsServerPairs;
+
+      private:
+        Result<DnsServerPair> addDns(ConnectivityType connectivity);
+        // Assuming mNetId is unique during ResolverMultinetworkTest, make the
+        // address based on it to avoid conflicts.
+        // |n| is unsigned rather than uint8_t: addDns() may pass values above 255, and
+        // truncating them would silently wrap onto another address (e.g. the ".1" source).
+        // Without truncation an oversized IPv4 host part is invalid, so setup should fail.
+        std::string makeIpv4AddrString(unsigned n) const {
+            return StringPrintf("192.168.%u.%u", (mNetId - TEST_NETID_BASE), n);
+        }
+        std::string makeIpv6AddrString(unsigned n) const {
+            return StringPrintf("2001:db8:%u::%u", (mNetId - TEST_NETID_BASE), n);
+        }
+    };
+
+    // A physical network. The (netId, name) constructor makes a placeholder without services:
+    // it cannot be init()-ed and only stands in for a netId (e.g. NETID_UNSET) in tests.
+    class ScopedPhysicalNetwork : public ScopedNetwork {
+      public:
+        ScopedPhysicalNetwork(unsigned netId, const char* networkName)
+            : ScopedNetwork(netId, ConnectivityType::V4V6, nullptr, nullptr, networkName) {}
+        ScopedPhysicalNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+                              IDnsResolver* dnsResolvSrv, const char* name = "Physical")
+            : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name) {}
+
+      protected:
+        Result<void> createNetwork() const override {
+            if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_NONE);
+                !r.isOk()) {
+                return Error() << r.getMessage();
+            }
+            return {};
+        }
+    };
+
+    // A VPN network; |isSecure| selects a secure or a bypassable VPN.
+    class ScopedVirtualNetwork : public ScopedNetwork {
+      public:
+        ScopedVirtualNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+                             IDnsResolver* dnsResolvSrv, const char* name, bool isSecure)
+            : ScopedNetwork(netId, type, netdSrv, dnsResolvSrv, name), mIsSecure(isSecure) {}
+        // Removes the VPN isolation rules still installed by this object, so no uid stays
+        // isolated after the network is gone.
+        ~ScopedVirtualNetwork() override {
+            if (!mVpnIsolationUids.empty()) {
+                const std::vector<int> tmpUids(mVpnIsolationUids.begin(), mVpnIsolationUids.end());
+                mNetdSrv->firewallRemoveUidInterfaceRules(tmpUids);
+            }
+        }
+        // Enable VPN isolation. Ensures that uid can only receive packets on mIfname.
+        Result<void> enableVpnIsolation(int uid) {
+            if (auto r = mNetdSrv->firewallAddUidInterfaceRules(mIfname, {uid}); !r.isOk()) {
+                return Error() << r.getMessage();
+            }
+            mVpnIsolationUids.insert(uid);
+            return {};
+        }
+        // Reverts enableVpnIsolation() for |uid|.
+        Result<void> disableVpnIsolation(int uid) {
+            if (auto r = mNetdSrv->firewallRemoveUidInterfaceRules({uid}); !r.isOk()) {
+                return Error() << r.getMessage();
+            }
+            mVpnIsolationUids.erase(uid);
+            return {};
+        }
+        // Add a single uid, or the uid range [from, to], to this VPN via networkAddUidRanges().
+        Result<void> addUser(uid_t uid) const { return addUidRange(uid, uid); }
+        Result<void> addUidRange(uid_t from, uid_t to) const {
+            if (auto r = mNetdSrv->networkAddUidRanges(mNetId, {makeUidRangeParcel(from, to)});
+                !r.isOk()) {
+                return Error() << r.getMessage();
+            }
+            return {};
+        }
+
+      protected:
+        Result<void> createNetwork() const override {
+            if (auto r = mNetdSrv->networkCreateVpn(mNetId, mIsSecure); !r.isOk()) {
+                return Error() << r.getMessage();
+            }
+            return {};
+        }
+
+        bool mIsSecure = false;
+        // Uids whose VPN isolation is currently enabled; cleaned up by the destructor.
+        std::unordered_set<int> mVpnIsolationUids;
     };
 
     void SetUp() override {
@@ -5259,34 +5340,65 @@
         ASSERT_NE(mDnsClient.resolvService(), nullptr);
     }
 
-    void TearDown() override { ResolverTest::TearDown(); }
+    void TearDown() override {
+        // Tear down in reverse order of setup: restore the default network saved by
+        // setDefaultNetwork() before ResolverTest::TearDown() runs, in case the saved netId
+        // belongs to a network that the base class tears down.
+        if (mStoredDefaultNetwork >= 0) {
+            EXPECT_TRUE(mDnsClient.netdService()->networkSetDefault(mStoredDefaultNetwork).isOk());
+        }
+        ResolverTest::TearDown();
+    }
 
-    ScopedNetwork CreateScopedNetwork(ConnectivityType type);
+    ScopedPhysicalNetwork CreateScopedPhysicalNetwork(ConnectivityType type,
+                                                      const char* name = "Physical") {
+        return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService(), name};
+    }
+    ScopedVirtualNetwork CreateScopedVirtualNetwork(ConnectivityType type, bool isSecure,
+                                                    const char* name = "Virtual") {
+        return {getFreeNetId(), type,    mDnsClient.netdService(), mDnsClient.resolvService(),
+                name,           isSecure};
+    }
     void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);
-
-    unsigned getFreeNetId() { return mNextNetId++; }
+    // Makes |netId| the default network, saving the original default on the first call so
+    // that TearDown() can restore it. Uses ASSERT_*, so wrap calls in ASSERT_NO_FATAL_FAILURE.
+    void setDefaultNetwork(int netId) {
+        std::call_once(defaultNetworkFlag, [&]() {
+            ASSERT_TRUE(mDnsClient.netdService()->networkGetDefault(&mStoredDefaultNetwork).isOk());
+        });
+        ASSERT_TRUE(mDnsClient.netdService()->networkSetDefault(netId).isOk());
+    }
+    // Returns netIds from TEST_NETID_BASE to TEST_NETID_BASE + 255, then wraps around, so that
+    // (netId - TEST_NETID_BASE) always fits in one address byte.
+    unsigned getFreeNetId() {
+        if (mNextNetId == TEST_NETID_BASE + 256) mNextNetId = TEST_NETID_BASE;
+        return mNextNetId++;
+    }
 
   private:
     // Use a different netId because this class inherits from the class ResolverTest which
-    // always creates TEST_NETID in setup. It's incremented when CreateScopedNetwork() is called.
-    // Note: Don't create more than 20 networks in the class since 51 is used for the mock network.
-    unsigned mNextNetId = 31;
+    // always creates TEST_NETID in setup. It is advanced by getFreeNetId(), which is called by
+    // CreateScoped{Physical,Virtual}Network().
+    // Note: (mNextNetId - TEST_NETID_BASE) must not exceed 255 because it is embedded in the
+    // test addresses; getFreeNetId() wraps around accordingly.
+    unsigned mNextNetId = TEST_NETID_BASE;
+    // -1 means the default network was never modified, so there is nothing to restore; a real
+    // netId is unsigned.
+    int mStoredDefaultNetwork = -1;
+    // Makes setDefaultNetwork() save the original default network only once.
+    std::once_flag defaultNetworkFlag;
 };
 
-ResolverMultinetworkTest::ScopedNetwork ResolverMultinetworkTest::CreateScopedNetwork(
-        ConnectivityType type) {
-    return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService()};
-}
-
 Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
+    if (mNetdSrv == nullptr || mDnsResolvSrv == nullptr) return Error() << "srv not available";
     unique_fd ufd = TunForwarder::createTun(mIfname);
     if (!ufd.ok()) {
         return Errorf("createTun for {} failed", mIfname);
     }
     mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));
 
-    if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_SYSTEM); !r.isOk()) {
-        return Error() << r.getMessage();
+    if (auto r = createNetwork(); !r.ok()) {
+        return r;
     }
     if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
         return Error() << r.getMessage();
@@ -5317,11 +5423,6 @@
     return {};
 }
 
-void ResolverMultinetworkTest::ScopedNetwork::destroy() {
-    mNetdSrv->networkDestroy(mNetId);
-    mDnsResolvSrv->destroyNetworkCache(mNetId);
-}
-
 void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
                                         const std::vector<DnsRecord>& records) {
     ResolverTest::StartDns(dns, records);
@@ -5335,7 +5436,7 @@
 
 Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
         ConnectivityType type) {
-    const int index = mDnsServers.size();
+    const int index = mDnsServerPairs.size();
     const int prefixLen = (type == ConnectivityType::V4) ? 32 : 128;
 
     const std::function<std::string(unsigned)> makeIpString =
@@ -5344,9 +5445,12 @@
                       this, std::placeholders::_1);
 
     std::string src1 = makeIpString(1);            // The address from which the resolver will send.
-    std::string dst1 = makeIpString(index + 100);  // The address to which the resolver will send.
+    std::string dst1 = makeIpString(
+            index + 100 +
+            (mNetId - TEST_NETID_BASE));           // The address to which the resolver will send.
     std::string src2 = dst1;                       // The address translated from src1.
-    std::string dst2 = makeIpString(index + 200);  // The address translated from dst2.
+    std::string dst2 = makeIpString(
+            index + 200 + (mNetId - TEST_NETID_BASE));  // The address translated from dst2.
 
     if (!mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
         !mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
@@ -5357,15 +5461,64 @@
         return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
     }
 
-    // Create a DNSResponder instance.
-    auto& dnsPtr = mDnsServers.emplace_back(std::make_unique<test::DNSResponder>(dst2));
-    dnsPtr->setNetwork(mNetId);
-    return DnsServerPair{
-            .dnsServer = *dnsPtr,
-            .dnsAddr = dst1,
-    };
+    return mDnsServerPairs.emplace_back(std::make_shared<test::DNSResponder>(mNetId, dst2), dst1);
 }
 
+// Sets the resolver configuration of mNetId to all DNS servers added by addIpv{4,6}Dns(), with
+// the default resolver parameters and no DoT servers. Returns false if there is no resolver
+// service (placeholder network) or the call fails.
+bool ResolverMultinetworkTest::ScopedNetwork::setDnsConfiguration() const {
+    if (mDnsResolvSrv == nullptr) return false;
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.tlsServers.clear();
+    parcel.netId = mNetId;
+    parcel.servers.clear();
+    for (const auto& pair : mDnsServerPairs) {
+        parcel.servers.push_back(pair.dnsAddr);
+    }
+    return mDnsResolvSrv->setResolverConfiguration(parcel).isOk();
+}
+
+// Removes the resolver configuration of mNetId by destroying and recreating its cache.
+// Returns false if there is no resolver service or either call fails.
+bool ResolverMultinetworkTest::ScopedNetwork::clearDnsConfiguration() const {
+    if (mDnsResolvSrv == nullptr) return false;
+    return mDnsResolvSrv->destroyNetworkCache(mNetId).isOk() &&
+           mDnsResolvSrv->createNetworkCache(mNetId).isOk();
+}
+
+namespace {
+
+// Convenient wrapper for making getaddrinfo call like framework.
+// On failure, the error carries the gai_strerror() text followed by the returned code.
+Result<ScopedAddrinfo> android_getaddrinfofornet_wrapper(const char* name, int netId) {
+    // Use the same parameter as libcore/ojluni/src/main/java/java/net/Inet6AddressImpl.java.
+    static const addrinfo hints = {
+            .ai_flags = AI_ADDRCONFIG,
+            .ai_family = AF_UNSPEC,
+            .ai_socktype = SOCK_STREAM,
+    };
+    addrinfo* result = nullptr;
+    if (int r = android_getaddrinfofornet(name, nullptr, &hints, netId, MARK_UNSET, &result)) {
+        return Error() << gai_strerror(r) << " (" << r << ")";
+    }
+    return ScopedAddrinfo(result);
+}
+
+// Resolves |name| on |netId| as |uid| and expects exactly the addresses in |expectedResult|,
+// in any order. Uses ASSERT_*, so callers may wrap it in ASSERT_NO_FATAL_FAILURE.
+void expectDnsWorksForUid(const char* name, unsigned netId, uid_t uid,
+                          const std::vector<std::string>& expectedResult) {
+    ScopedChangeUID scopedChangeUID(uid);
+    auto result = android_getaddrinfofornet_wrapper(name, netId);
+    ASSERT_RESULT_OK(result);
+    ScopedAddrinfo ai_result(std::move(result.value()));
+    std::vector<std::string> result_strs = ToStrings(ai_result);
+    EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
+}
+
+}  // namespace
+
 TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
     constexpr char host_name[] = "ohayou.example.com.";
 
@@ -5378,33 +5523,23 @@
         SCOPED_TRACE(StringPrintf("ConnectivityType: %d", type));
 
         // Create a network.
-        ScopedNetwork network = CreateScopedNetwork(type);
+        ScopedPhysicalNetwork network = CreateScopedPhysicalNetwork(type);
         ASSERT_RESULT_OK(network.init());
 
         // Add a testing DNS server.
         const Result<DnsServerPair> dnsPair =
                 (type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
         ASSERT_RESULT_OK(dnsPair);
-        StartDns(dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"},
-                                      {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
+        StartDns(*dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.0"},
+                                       {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
 
         // Set up resolver and start forwarding.
-        ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
-        parcel.tlsServers.clear();
-        parcel.netId = network.netId();
-        parcel.servers = {dnsPair->dnsAddr};
-        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+        ASSERT_TRUE(network.setDnsConfiguration());
         ASSERT_TRUE(network.startTunForwarder());
 
-        const addrinfo hints = {
-                .ai_flags = AI_ADDRCONFIG,
-                .ai_family = AF_UNSPEC,
-                .ai_socktype = SOCK_DGRAM,
-        };
-        addrinfo* raw_ai_result = nullptr;
-        EXPECT_EQ(0, android_getaddrinfofornet(host_name, nullptr, &hints, network.netId(),
-                                               MARK_UNSET, &raw_ai_result));
-        ScopedAddrinfo ai_result(raw_ai_result);
+        auto result = android_getaddrinfofornet_wrapper(host_name, network.netId());
+        ASSERT_RESULT_OK(result);
+        ScopedAddrinfo ai_result(std::move(result.value()));
         std::vector<std::string> result_strs = ToStrings(ai_result);
         std::vector<std::string> expectedResult;
         size_t expectedQueries = 0;
@@ -5414,11 +5549,11 @@
             expectedQueries++;
         }
         if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
-            expectedResult.emplace_back("1.1.1.31");
+            expectedResult.emplace_back("192.0.2.0");
             expectedQueries++;
         }
         EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
-        EXPECT_EQ(GetNumQueries(dnsPair->dnsServer, host_name), expectedQueries);
+        EXPECT_EQ(GetNumQueries(*dnsPair->dnsServer, host_name), expectedQueries);
     }
 }
 
@@ -5426,24 +5561,20 @@
     constexpr char host_name[] = "ohayou.example.com.";
 
     // Create a network and add an ipv4 DNS server.
-    auto network =
-            std::make_unique<ScopedNetwork>(getFreeNetId(), ConnectivityType::V4V6,
-                                            mDnsClient.netdService(), mDnsClient.resolvService());
+    auto network = std::make_unique<ScopedPhysicalNetwork>(getFreeNetId(), ConnectivityType::V4V6,
+                                                           mDnsClient.netdService(),
+                                                           mDnsClient.resolvService());
     ASSERT_RESULT_OK(network->init());
     const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
     ASSERT_RESULT_OK(dnsPair);
 
     // Set the DNS server unresponsive.
-    dnsPair->dnsServer.setResponseProbability(0.0);
-    dnsPair->dnsServer.setErrorRcode(static_cast<ns_rcode>(-1));
-    StartDns(dnsPair->dnsServer, {});
+    dnsPair->dnsServer->setResponseProbability(0.0);
+    dnsPair->dnsServer->setErrorRcode(static_cast<ns_rcode>(-1));
+    StartDns(*dnsPair->dnsServer, {});
 
     // Set up resolver and start forwarding.
-    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
-    parcel.tlsServers.clear();
-    parcel.netId = network->netId();
-    parcel.servers = {dnsPair->dnsAddr};
-    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network->setDnsConfiguration());
     ASSERT_TRUE(network->startTunForwarder());
 
     // Expect the things happening in order:
@@ -5457,7 +5588,7 @@
     });
 
     // Tear down the network as soon as the dns server receives the query.
-    const auto condition = [&]() { return GetNumQueries(dnsPair->dnsServer, host_name) == 1U; };
+    const auto condition = [&]() { return GetNumQueries(*dnsPair->dnsServer, host_name) == 1U; };
     EXPECT_TRUE(PollForCondition(condition));
     network.reset();
 
@@ -5468,8 +5599,8 @@
     SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
     constexpr char host_name[] = "ohayou.example.com.";
 
-    ScopedNetwork network1 = CreateScopedNetwork(ConnectivityType::V4V6);
-    ScopedNetwork network2 = CreateScopedNetwork(ConnectivityType::V4V6);
+    ScopedPhysicalNetwork network1 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
+    ScopedPhysicalNetwork network2 = CreateScopedPhysicalNetwork(ConnectivityType::V4V6);
     ASSERT_RESULT_OK(network1.init());
     ASSERT_RESULT_OK(network2.init());
 
@@ -5477,39 +5608,174 @@
     const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
     ASSERT_RESULT_OK(dnsPair1);
     ASSERT_RESULT_OK(dnsPair2);
-    StartDns(dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"}});
-    StartDns(dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.32"}});
+    StartDns(*dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.0"}});
+    StartDns(*dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "192.0.2.1"}});
 
     // Set up resolver for network 1 and start forwarding.
-    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
-    parcel.tlsServers.clear();
-    parcel.netId = network1.netId();
-    parcel.servers = {dnsPair1->dnsAddr};
-    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network1.setDnsConfiguration());
     ASSERT_TRUE(network1.startTunForwarder());
 
     // Set up resolver for network 2 and start forwarding.
-    parcel.netId = network2.netId();
-    parcel.servers = {dnsPair2->dnsAddr};
-    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network2.setDnsConfiguration());
     ASSERT_TRUE(network2.startTunForwarder());
 
     // Send the same queries to both networks.
     int fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
     int fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
 
-    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
-    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
-    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 1U);
-    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+    expectAnswersValid(fd1, AF_INET, "192.0.2.0");
+    expectAnswersValid(fd2, AF_INET, "192.0.2.1");
+    EXPECT_EQ(GetNumQueries(*dnsPair1->dnsServer, host_name), 1U);
+    EXPECT_EQ(GetNumQueries(*dnsPair2->dnsServer, host_name), 1U);
 
     // Flush the cache of network 1, and send the queries again.
     EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(network1.netId()).isOk());
     fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
     fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
 
-    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
-    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
-    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 2U);
-    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+    expectAnswersValid(fd1, AF_INET, "192.0.2.0");
+    expectAnswersValid(fd2, AF_INET, "192.0.2.1");
+    EXPECT_EQ(GetNumQueries(*dnsPair1->dnsServer, host_name), 2U);
+    EXPECT_EQ(GetNumQueries(*dnsPair2->dnsServer, host_name), 1U);
+}
+
+TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
+    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
+    SKIP_IF_BPF_NOT_SUPPORTED;
+    constexpr char host_name[] = "ohayou.example.com.";
+    constexpr char ipv4_addr[] = "192.0.2.0";
+    constexpr char ipv6_addr[] = "2001:db8:cafe:d00d::31";
+    // For each connectivity type, the addresses that every lookup below must return.
+    const std::pair<ConnectivityType, std::vector<std::string>> testPairs[] = {
+            {ConnectivityType::V4, {ipv4_addr}},
+            {ConnectivityType::V6, {ipv6_addr}},
+            {ConnectivityType::V4V6, {ipv6_addr, ipv4_addr}},
+    };
+    for (const auto& [type, result] : testPairs) {
+        SCOPED_TRACE(StringPrintf("ConnectivityType: %d", static_cast<int>(type)));
+
+        // Create the underlying network, a bypassable VPN and a secure VPN.
+        ScopedPhysicalNetwork underlyingNetwork = CreateScopedPhysicalNetwork(type, "Underlying");
+        ScopedVirtualNetwork bypassableVpnNetwork =
+                CreateScopedVirtualNetwork(type, false, "BypassableVpn");
+        ScopedVirtualNetwork secureVpnNetwork = CreateScopedVirtualNetwork(type, true, "SecureVpn");
+
+        ASSERT_RESULT_OK(underlyingNetwork.init());
+        ASSERT_RESULT_OK(bypassableVpnNetwork.init());
+        ASSERT_RESULT_OK(secureVpnNetwork.init());
+        ASSERT_RESULT_OK(bypassableVpnNetwork.addUser(TEST_UID));
+        ASSERT_RESULT_OK(secureVpnNetwork.addUser(TEST_UID2));
+
+        // Starts |dnsServer|, configures the resolver of |nw| and starts forwarding.
+        auto setupDnsFn = [&](const std::shared_ptr<test::DNSResponder>& dnsServer,
+                              ScopedNetwork* nw) {
+            StartDns(*dnsServer, {{host_name, ns_type::ns_t_a, ipv4_addr},
+                                  {host_name, ns_type::ns_t_aaaa, ipv6_addr}});
+            ASSERT_TRUE(nw->setDnsConfiguration());
+            ASSERT_TRUE(nw->startTunForwarder());
+        };
+        // Add a testing DNS server matching the connectivity type to a network.
+        auto addDnsFn = [](ScopedNetwork& nw, ConnectivityType t) {
+            return (t == ConnectivityType::V4) ? nw.addIpv4Dns() : nw.addIpv6Dns();
+        };
+        const Result<DnsServerPair> underlyingPair = addDnsFn(underlyingNetwork, type);
+        ASSERT_RESULT_OK(underlyingPair);
+        const Result<DnsServerPair> bypassableVpnPair = addDnsFn(bypassableVpnNetwork, type);
+        ASSERT_RESULT_OK(bypassableVpnPair);
+        const Result<DnsServerPair> secureVpnPair = addDnsFn(secureVpnNetwork, type);
+        ASSERT_RESULT_OK(secureVpnPair);
+        // Set up resolver and start forwarding for networks. ASSERT_NO_FATAL_FAILURE is
+        // needed because an ASSERT failing inside a lambda only returns from the lambda.
+        ASSERT_NO_FATAL_FAILURE(setupDnsFn(underlyingPair->dnsServer, &underlyingNetwork));
+        ASSERT_NO_FATAL_FAILURE(setupDnsFn(bypassableVpnPair->dnsServer, &bypassableVpnNetwork));
+        ASSERT_NO_FATAL_FAILURE(setupDnsFn(secureVpnPair->dnsServer, &secureVpnNetwork));
+
+        ASSERT_NO_FATAL_FAILURE(setDefaultNetwork(underlyingNetwork.netId()));
+        const unsigned underlyingNetId = underlyingNetwork.netId();
+        const unsigned bypassableVpnNetId = bypassableVpnNetwork.netId();
+        const unsigned secureVpnNetId = secureVpnNetwork.netId();
+        // We've called setNetworkForProcess in SetupOemNetwork, so reset to default first.
+        ScopedSetNetworkForProcess scopedSetNetworkForProcess(NETID_UNSET);
+        // Checks |dnsServer| got |count| queries, then resets it and flushes |expectedDnsNetId|.
+        auto expectDnsQueryCountsFn = [&](size_t count,
+                                          const std::shared_ptr<test::DNSResponder>& dnsServer,
+                                          unsigned expectedDnsNetId) {
+            EXPECT_EQ(GetNumQueries(*dnsServer, host_name), count);
+            EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(expectedDnsNetId).isOk());
+            dnsServer->clearQueries();
+            // Give DnsResolver some time to clear cache to avoid race.
+            usleep(5 * 1000);
+        };
+
+        // Create an object to represent the default network (NETID_UNSET); do not init it.
+        ScopedPhysicalNetwork defaultNetwork{NETID_UNSET, "Default"};
+
+        // Test VPN with DNS server under 4 different network selection scenarios.
+        // See the test config for the expectation.
+        const struct TestConfig {
+            ScopedNetwork* selectedNetwork;
+            unsigned expectedDnsNetId;
+            std::shared_ptr<test::DNSResponder> expectedDnsServer;
+        } vpnWithDnsServerConfigs[]{
+                // clang-format off
+                // Queries use the bypassable VPN by default.
+                {&defaultNetwork,       bypassableVpnNetId, bypassableVpnPair->dnsServer},
+                // Choosing the underlying network works because the VPN is bypassable.
+                {&underlyingNetwork,    underlyingNetId,    underlyingPair->dnsServer},
+                // Selecting the VPN sends the query on the VPN.
+                {&bypassableVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
+                // TEST_UID does not have access to the secure VPN.
+                {&secureVpnNetwork,     bypassableVpnNetId, bypassableVpnPair->dnsServer},
+                // clang-format on
+        };
+        for (const auto& config : vpnWithDnsServerConfigs) {
+            SCOPED_TRACE(fmt::format("Bypassable VPN with DnsServer, selectedNetwork = {}",
+                                     config.selectedNetwork->name()));
+            expectDnsWorksForUid(host_name, config.selectedNetwork->netId(), TEST_UID, result);
+            expectDnsQueryCountsFn(result.size(), config.expectedDnsServer,
+                                   config.expectedDnsNetId);
+        }
+
+        std::vector<ScopedNetwork*> nwVec{&defaultNetwork, &underlyingNetwork,
+                                          &bypassableVpnNetwork, &secureVpnNetwork};
+        // Test the VPN without DNS server with the same combination as before.
+        ASSERT_TRUE(bypassableVpnNetwork.clearDnsConfiguration());
+        // Test bypassable VPN, TEST_UID
+        for (const auto* selectedNetwork : nwVec) {
+            SCOPED_TRACE(fmt::format("Bypassable VPN without DnsServer, selectedNetwork = {}",
+                                     selectedNetwork->name()));
+            expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID, result);
+            expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
+        }
+
+        // The same test scenario as before plus enableVpnIsolation for secure VPN, TEST_UID2.
+        for (bool enableVpnIsolation : {false, true}) {
+            SCOPED_TRACE(fmt::format("enableVpnIsolation = {}", enableVpnIsolation));
+            if (enableVpnIsolation) {
+                EXPECT_RESULT_OK(secureVpnNetwork.enableVpnIsolation(TEST_UID2));
+            }
+
+            // Test secure VPN without DNS server.
+            ASSERT_TRUE(secureVpnNetwork.clearDnsConfiguration());
+            for (const auto* selectedNetwork : nwVec) {
+                SCOPED_TRACE(fmt::format("Secure VPN without DnsServer, selectedNetwork = {}",
+                                         selectedNetwork->name()));
+                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
+                expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
+            }
+
+            // Test secure VPN with DNS server.
+            ASSERT_TRUE(secureVpnNetwork.setDnsConfiguration());
+            for (const auto* selectedNetwork : nwVec) {
+                SCOPED_TRACE(fmt::format("Secure VPN with DnsServer, selectedNetwork = {}",
+                                         selectedNetwork->name()));
+                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
+                expectDnsQueryCountsFn(result.size(), secureVpnPair->dnsServer, secureVpnNetId);
+            }
+
+            if (enableVpnIsolation) {
+                EXPECT_RESULT_OK(secureVpnNetwork.disableVpnIsolation(TEST_UID2));
+            }
+        }
+    }
 }