Support multinetwork tests

The integration test used to set up testing DNS servers on the loopback
interface. To support testing multinetwork functionality, make the
test able to send queries to a TUN interface, from which the queries
are forwarded to the testing DNS servers.

To forward packets, implement a forwarder that rewrites the addresses
of packets (v4-to-v4 or v6-to-v6) exchanged between the resolver and
the testing DNS servers, and forwards them in both directions.

Also add three tests:
  GetAddrInfo_AI_ADDRCONFIG
  NetworkDestroyedDuringQueryInFlight
  OneCachePerNetwork

And remove unused libraries from the test:
  libnetd_test_tun_interface
  libnetd_test_utils

Test: cd packages/modules/DnsResolver && atest
Change-Id: I52a52ce59373bc8b9462064c0409b657696c379f
diff --git a/tests/Android.bp b/tests/Android.bp
index df6f39f..112e162 100644
--- a/tests/Android.bp
+++ b/tests/Android.bp
@@ -151,6 +151,7 @@
         "dns_responder/dns_responder.cpp",
         "dnsresolver_binder_test.cpp",
         "resolv_integration_test.cpp",
+        "tun_forwarder.cpp",
     ],
     header_libs: [
         "dnsproxyd_protocol_headers",
@@ -170,13 +171,12 @@
         "libnetd_test_dnsresponder_ndk",
         "libnetd_test_metrics_listener",
         "libnetd_test_resolv_utils",
-        "libnetd_test_tun_interface",
-        "libnetd_test_utils",
         "libnetdutils",
         "libssl",
         "libutils",
         "netd_aidl_interface-ndk_platform",
         "netd_event_listener_interface-ndk_platform",
+        "libipchecksum",
     ],
     // This test talks to the DnsResolver module over a binary protocol on a socket, so keep it as
     // multilib setting is worth because we might be able to get some coverage for the case where
diff --git a/tests/dns_responder/dns_responder.h b/tests/dns_responder/dns_responder.h
index ccd0bfc..8b54405 100644
--- a/tests/dns_responder/dns_responder.h
+++ b/tests/dns_responder/dns_responder.h
@@ -175,6 +175,7 @@
     void setResponseProbability(double response_probability);
     void setResponseProbability(double response_probability, int protocol);
     void setResponseDelayMs(unsigned);
+    void setErrorRcode(ns_rcode error_rcode) { error_rcode_ = error_rcode; }  // Not synchronized.
     void setEdns(Edns edns);
     void setTtl(unsigned ttl);
     bool running() const;
@@ -190,6 +191,16 @@
     void setDeferredResp(bool deferred_resp);
     static bool fillRdata(const std::string& rdatastr, DNSRecord& record);
 
+    // Helpers for binding the listening sockets to a specific network, which is needed only by
+    // multinetwork tests. Binding a socket to a network requires libnetd_client, but DNSResponder
+    // is also used by tests such as resolv_unit_test that don't depend on it. So the socket fds
+    // are exposed and callers perform the binding themselves. Callers MUST NOT close the fds;
+    // DNSResponder still owns them.
+    void setNetwork(unsigned netId) { mNetId = netId; }
+    std::optional<unsigned> getNetwork() const { return mNetId; }
+    int getUdpSocket() const { return udp_socket_.get(); }
+    int getTcpSocket() const { return tcp_socket_.get(); }
+
     // TODO: Make DNSResponder record unknown queries in a vector for improving the debugging.
     // Unit test could dump the unexpected query for further debug if any unexpected failure.
 
@@ -284,8 +295,10 @@
     // Address and service to listen on TCP and UDP.
     const std::string listen_address_;
     const std::string listen_service_;
+
+    // TODO: Consider refactoring atomic members of this class to a single big mutex.
     // Error code to return for requests for an unknown name.
-    const ns_rcode error_rcode_;
+    ns_rcode error_rcode_;
     // Mapping type the DNS server used to build the response.
     const MappingType mapping_type_;
     // Probability that a valid response on TCP is being sent instead of
@@ -340,6 +353,9 @@
     std::condition_variable cv_for_deferred_resp_;
     std::mutex cv_mutex_for_deferred_resp_;
     bool deferred_resp_ GUARDED_BY(cv_mutex_for_deferred_resp_) = false;
+
+    // The network to which the listening sockets will be bound.
+    std::optional<unsigned> mNetId;
 };
 
 }  // namespace test
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index ff5b57b..1ea7d01 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -20,6 +20,7 @@
 #include <android-base/logging.h>
 #include <android-base/parseint.h>
 #include <android-base/properties.h>
+#include <android-base/result.h>
 #include <android-base/stringprintf.h>
 #include <android-base/unique_fd.h>
 #include <android/multinetwork.h>  // ResNsendFlags
@@ -62,13 +63,13 @@
 #include "netid_client.h"  // NETID_UNSET
 #include "params.h"        // MAXNS
 #include "stats.h"         // RCODE_TIMEOUT
-#include "test_utils.h"
 #include "tests/dns_metrics_listener/dns_metrics_listener.h"
 #include "tests/dns_responder/dns_responder.h"
 #include "tests/dns_responder/dns_responder_client_ndk.h"
 #include "tests/dns_responder/dns_tls_certificate.h"
 #include "tests/dns_responder/dns_tls_frontend.h"
 #include "tests/resolv_test_utils.h"
+#include "tests/tun_forwarder.h"
 
 // Valid VPN netId range is 100 ~ 65535
 constexpr int TEST_VPN_NETID = 65502;
@@ -86,10 +87,13 @@
 using aidl::android::net::INetd;
 using aidl::android::net::ResolverParamsParcel;
 using aidl::android::net::metrics::INetdEventListener;
+using android::base::Error;
 using android::base::ParseInt;
+using android::base::Result;
 using android::base::StringPrintf;
 using android::base::unique_fd;
 using android::net::ResolverStats;
+using android::net::TunForwarder;
 using android::net::metrics::DnsMetricsListener;
 using android::netdutils::enableSockopt;
 using android::netdutils::makeSlice;
@@ -5116,3 +5120,337 @@
     EXPECT_EQ(dns1.queries().size(), 0U);
     EXPECT_EQ(dns2.queries().size(), 0U);
 }
+
+// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
+// The resolver sends queries to address A, and then a TunForwarder forwards the packets to
+// address B, which is the address on which the testing server is listening. The answer
+// packets sent by the testing server go through the reverse path back to the
+// resolver.
+//
+// To achieve that, each network needs a TUN interface with routing rules. Tests are not
+// supposed to start DNS servers on their own; instead, use the utilities of this class
+// (ScopedNetwork and StartDns()) to do the setup.
+//
+// An example of how to use it:
+// TEST_F() {
+//     ScopedNetwork network = CreateScopedNetwork(ConnectivityType::V4);
+//     ASSERT_RESULT_OK(network.init());
+//
+//     auto dns = network.addIpv4Dns();
+//     StartDns(dns->dnsServer, {});
+//
+//     setResolverConfiguration(...);
+//     network.startTunForwarder();
+//
+//     // Send queries here
+// }
+
+class ResolverMultinetworkTest : public ResolverTest {
+  protected:
+    enum class ConnectivityType { V4, V6, V4V6 };
+
+    struct DnsServerPair {
+        test::DNSResponder& dnsServer;
+        std::string dnsAddr;  // The DNS server address used for setResolverConfiguration().
+        // TODO: Add test::DnsTlsFrontend* and std::string for DoT.
+    };
+
+    class ScopedNetwork {
+      public:
+        ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
+                      IDnsResolver* dnsResolvSrv)
+            : mNetId(netId),
+              mConnectivityType(type),
+              mNetdSrv(netdSrv),
+              mDnsResolvSrv(dnsResolvSrv),
+              mIfname(StringPrintf("testtun%u", netId)) {}
+        // Tears down the network and its resolver cache (see destroy()).
+        ~ScopedNetwork() { destroy(); }
+
+        Result<void> init();
+        void destroy();
+        Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
+        Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
+        bool startTunForwarder() { return mTunForwarder && mTunForwarder->startForwarding(); }
+        unsigned netId() const { return mNetId; }
+
+      private:
+        Result<DnsServerPair> addDns(ConnectivityType connectivity);
+        std::string makeIpv4AddrString(unsigned n) const {
+            return StringPrintf("192.168.%u.%u", mNetId, n);
+        }
+        std::string makeIpv6AddrString(unsigned n) const {
+            return StringPrintf("2001:db8:%u::%u", mNetId, n);
+        }
+
+        const unsigned mNetId;
+        const ConnectivityType mConnectivityType;
+        INetd* mNetdSrv;
+        IDnsResolver* mDnsResolvSrv;
+
+        std::string mIfname;
+        std::unique_ptr<TunForwarder> mTunForwarder;
+        std::vector<std::unique_ptr<test::DNSResponder>> mDnsServers;
+        // TODO: Add std::vector<std::unique_ptr<test::DnsTlsFrontend>>
+    };
+
+    void SetUp() override {
+        ResolverTest::SetUp();
+        ASSERT_NE(mDnsClient.netdService(), nullptr);
+        ASSERT_NE(mDnsClient.resolvService(), nullptr);
+    }
+
+    void TearDown() override { ResolverTest::TearDown(); }
+
+    ScopedNetwork CreateScopedNetwork(ConnectivityType type);
+    void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);
+
+    unsigned getFreeNetId() { return mNextNetId++; }
+
+  private:
+    // Use a different netId because this class inherits from the class ResolverTest which
+    // always creates TEST_NETID in setup. It's incremented when CreateScopedNetwork() is called.
+    // Note: Don't create more than 20 networks in the class since 51 is used for the dummy network.
+    unsigned mNextNetId = 31;
+};
+
+// Allocates the next free netId and wraps it in a ScopedNetwork; callers must still call init().
+auto ResolverMultinetworkTest::CreateScopedNetwork(ConnectivityType type) -> ScopedNetwork {
+    return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService()};
+}
+
+Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
+    unique_fd ufd = TunForwarder::createTun(mIfname);
+    if (!ufd.ok()) {
+        return Errorf("createTun for {} failed", mIfname);
+    }
+    mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));
+    // Register the network with netd and the resolver, then attach the tun interface to it.
+    if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_SYSTEM); !r.isOk()) {
+        return Error() << "networkCreatePhysical: " << r.getMessage();
+    }
+    if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
+        return Error() << "createNetworkCache: " << r.getMessage();
+    }
+    if (auto r = mNetdSrv->networkAddInterface(mNetId, mIfname); !r.isOk()) {
+        return Error() << "networkAddInterface: " << r.getMessage();
+    }
+    // Give the interface an address and a default route for each requested IP family.
+    if (mConnectivityType == ConnectivityType::V4 || mConnectivityType == ConnectivityType::V4V6) {
+        const std::string v4Addr = makeIpv4AddrString(1);
+        if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v4Addr, 32); !r.isOk()) {
+            return Error() << "interfaceAddAddress(" << v4Addr << "): " << r.getMessage();
+        }
+        if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "0.0.0.0/0", ""); !r.isOk()) {
+            return Error() << "networkAddRoute(0.0.0.0/0): " << r.getMessage();
+        }
+    }
+    if (mConnectivityType == ConnectivityType::V6 || mConnectivityType == ConnectivityType::V4V6) {
+        const std::string v6Addr = makeIpv6AddrString(1);
+        if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v6Addr, 128); !r.isOk()) {
+            return Error() << "interfaceAddAddress(" << v6Addr << "): " << r.getMessage();
+        }
+        if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "::/0", ""); !r.isOk()) {
+            return Error() << "networkAddRoute(::/0): " << r.getMessage();
+        }
+    }
+
+    return {};
+}
+
+void ResolverMultinetworkTest::ScopedNetwork::destroy() {
+    mNetdSrv->networkDestroy(mNetId);             // Errors are ignored: this also runs from the
+    mDnsResolvSrv->destroyNetworkCache(mNetId);  // destructor, even if init() failed midway.
+}
+
+void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
+                                        const std::vector<DnsRecord>& records) {
+    ResolverTest::StartDns(dns, records);
+
+    // Bind the DNSResponder's sockets to the network if specified; binding must succeed.
+    if (std::optional<unsigned> netId = dns.getNetwork(); netId.has_value()) {
+        EXPECT_EQ(0, setNetworkForSocket(netId.value(), dns.getUdpSocket()));
+        EXPECT_EQ(0, setNetworkForSocket(netId.value(), dns.getTcpSocket()));
+    }
+}
+
+Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
+        ConnectivityType type) {
+    const unsigned index = mDnsServers.size();
+    const int prefixLen = (type == ConnectivityType::V4) ? 32 : 128;
+
+    const auto makeIpString = [this, type](unsigned n) {
+        return (type == ConnectivityType::V4) ? makeIpv4AddrString(n)
+                                              : makeIpv6AddrString(n);
+    };
+
+    std::string src1 = makeIpString(1);            // The address from which the resolver will send.
+    std::string dst1 = makeIpString(index + 100);  // The address to which the resolver will send.
+    std::string src2 = dst1;                       // The address translated from src1.
+    std::string dst2 = makeIpString(index + 200);  // The address translated from dst1.
+
+    if (!mTunForwarder || !mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
+        !mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
+        return Errorf("Failed to add the rules ({}, {}, {}, {})", src1, dst1, src2, dst2);
+    }
+
+    if (!mNetdSrv->interfaceAddAddress(mIfname, dst2, prefixLen).isOk()) {
+        return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
+    }
+
+    // Create a DNSResponder instance.
+    auto& dnsPtr = mDnsServers.emplace_back(std::make_unique<test::DNSResponder>(dst2));
+    dnsPtr->setNetwork(mNetId);
+    return DnsServerPair{
+            .dnsServer = *dnsPtr,
+            .dnsAddr = dst1,
+    };
+}
+
+TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
+    constexpr char host_name[] = "ohayou.example.com.";
+    // With AI_ADDRCONFIG, only the families the network has are queried and returned.
+    const std::array<ConnectivityType, 3> allTypes = {
+            ConnectivityType::V4,
+            ConnectivityType::V6,
+            ConnectivityType::V4V6,
+    };
+    for (const auto& type : allTypes) {
+        SCOPED_TRACE(StringPrintf("ConnectivityType: %d", static_cast<int>(type)));
+
+        // Create a network.
+        ScopedNetwork network = CreateScopedNetwork(type);
+        ASSERT_RESULT_OK(network.init());
+
+        // Add a testing DNS server.
+        const Result<DnsServerPair> dnsPair =
+                (type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
+        ASSERT_RESULT_OK(dnsPair);
+        StartDns(dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"},
+                                      {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});
+
+        // Set up resolver and start forwarding.
+        ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+        parcel.tlsServers.clear();
+        parcel.netId = network.netId();
+        parcel.servers = {dnsPair->dnsAddr};
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+        ASSERT_TRUE(network.startTunForwarder());
+
+        const addrinfo hints = {
+                .ai_flags = AI_ADDRCONFIG,
+                .ai_family = AF_UNSPEC,
+                .ai_socktype = SOCK_DGRAM,
+        };
+        addrinfo* raw_ai_result = nullptr;
+        EXPECT_EQ(0, android_getaddrinfofornet(host_name, nullptr, &hints, network.netId(),
+                                               MARK_UNSET, &raw_ai_result));
+        ScopedAddrinfo ai_result(raw_ai_result);
+        std::vector<std::string> result_strs = ToStrings(ai_result);
+        std::vector<std::string> expectedResult;
+        size_t expectedQueries = 0;
+
+        if (type == ConnectivityType::V6 || type == ConnectivityType::V4V6) {
+            expectedResult.emplace_back("2001:db8:cafe:d00d::31");
+            expectedQueries++;
+        }
+        if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
+            expectedResult.emplace_back("1.1.1.31");
+            expectedQueries++;
+        }
+        EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
+        EXPECT_EQ(GetNumQueries(dnsPair->dnsServer, host_name), expectedQueries);
+    }
+}
+
+TEST_F(ResolverMultinetworkTest, NetworkDestroyedDuringQueryInFlight) {
+    constexpr char host_name[] = "ohayou.example.com.";
+
+    // Create a network and add an ipv4 DNS server.
+    auto network =
+            std::make_unique<ScopedNetwork>(getFreeNetId(), ConnectivityType::V4V6,
+                                            mDnsClient.netdService(), mDnsClient.resolvService());
+    ASSERT_RESULT_OK(network->init());
+    const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
+    ASSERT_RESULT_OK(dnsPair);
+
+    // Set the DNS server unresponsive.
+    dnsPair->dnsServer.setResponseProbability(0.0);
+    dnsPair->dnsServer.setErrorRcode(static_cast<ns_rcode>(-1));
+    StartDns(dnsPair->dnsServer, {});
+
+    // Set up resolver and start forwarding.
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.tlsServers.clear();
+    parcel.netId = network->netId();
+    parcel.servers = {dnsPair->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network->startTunForwarder());
+
+    // Expect the things happening in order:
+    // 1. The thread sends the query to the dns server which is unresponsive.
+    // 2. The network is destroyed while the thread is waiting for the response from the dns server.
+    // 3. After the dns server timeout, the thread retries but fails to connect.
+    std::thread lookup([&, netId = network->netId()]() {  // |network| is reset below.
+        int fd = resNetworkQuery(netId, host_name, ns_c_in, ns_t_a, 0);
+        EXPECT_NE(-1, fd);
+        expectAnswersNotValid(fd, -ETIMEDOUT);
+    });
+
+    // Tear down the network as soon as the dns server receives the query.
+    const auto condition = [&]() { return GetNumQueries(dnsPair->dnsServer, host_name) == 1U; };
+    EXPECT_TRUE(PollForCondition(condition));
+    network.reset();
+
+    lookup.join();
+}
+
+TEST_F(ResolverMultinetworkTest, OneCachePerNetwork) {
+    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
+    constexpr char host_name[] = "ohayou.example.com.";
+
+    ScopedNetwork network1 = CreateScopedNetwork(ConnectivityType::V4V6);
+    ScopedNetwork network2 = CreateScopedNetwork(ConnectivityType::V4V6);
+    ASSERT_RESULT_OK(network1.init());
+    ASSERT_RESULT_OK(network2.init());
+    // Give each network its own DNS server, which answers with a different address.
+    const Result<DnsServerPair> dnsPair1 = network1.addIpv4Dns();
+    const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
+    ASSERT_RESULT_OK(dnsPair1);
+    ASSERT_RESULT_OK(dnsPair2);
+    StartDns(dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"}});
+    StartDns(dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.32"}});
+
+    // Set up resolver for network 1 and start forwarding.
+    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
+    parcel.tlsServers.clear();
+    parcel.netId = network1.netId();
+    parcel.servers = {dnsPair1->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network1.startTunForwarder());
+
+    // Set up resolver for network 2 and start forwarding.
+    parcel.netId = network2.netId();
+    parcel.servers = {dnsPair2->dnsAddr};
+    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ASSERT_TRUE(network2.startTunForwarder());
+
+    // Send the same queries to both networks.
+    int fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
+    int fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
+
+    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
+    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
+    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 1U);
+    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+    // Flush only network 1's cache and send the queries again: network 1 has to query its
+    // server again, while network 2 is still answered from its own cache.
+    EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(network1.netId()).isOk());
+    fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
+    fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);
+
+    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
+    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
+    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 2U);
+    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
+}
diff --git a/tests/tun_forwarder.cpp b/tests/tun_forwarder.cpp
new file mode 100644
index 0000000..3e40a78
--- /dev/null
+++ b/tests/tun_forwarder.cpp
@@ -0,0 +1,419 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#define LOG_TAG "TunForwarder"
+
+#include "tun_forwarder.h"
+
+#include <arpa/inet.h>
+#include <linux/if.h>
+#include <linux/if_tun.h>
+#include <linux/ioctl.h>
+#include <netinet/ip6.h>
+#include <netinet/tcp.h>
+#include <netinet/udp.h>
+#include <sys/eventfd.h>
+#include <sys/poll.h>
+
+#include <android-base/logging.h>
+
+extern "C" {
+#include <netutils/checksum.h>
+}
+
+using android::base::Error;
+using android::base::Result;
+using android::base::unique_fd;
+using android::netdutils::Slice;
+
+namespace android::net {
+
+static constexpr int MAXMTU = 1500;  // Largest packet handlePacket() reads, excluding tun_pi.
+static constexpr ssize_t TUN_HDRLEN = sizeof(struct tun_pi);  // Header sizes for bounds checks.
+static constexpr ssize_t IP4_HDRLEN = sizeof(struct iphdr);
+static constexpr ssize_t IP6_HDRLEN = sizeof(struct ip6_hdr);
+static constexpr ssize_t TCP_HDRLEN = sizeof(struct tcphdr);
+static constexpr ssize_t UDP_HDRLEN = sizeof(struct udphdr);
+
+namespace {
+// C's in6_addr has no comparison operators; these back v6pair's operator== and operator<.
+bool operator==(const in6_addr& x, const in6_addr& y) {
+    return std::memcmp(x.s6_addr, y.s6_addr, 16) == 0;
+}
+
+bool operator!=(const in6_addr& x, const in6_addr& y) {
+    return !(x == y);
+}
+
+bool operator<(const in6_addr& x, const in6_addr& y) {
+    return std::memcmp(x.s6_addr, y.s6_addr, 16) < 0;
+}
+
+}  // namespace
+
+// Parses two dotted-quad IPv4 strings into {src, dst}; fails unless both parse.
+Result<TunForwarder::v4pair> TunForwarder::v4pair::makePair(
+        const std::array<std::string, 2>& addrs) {
+    v4pair pair;
+    const bool parsed = inet_pton(AF_INET, addrs[0].c_str(), &pair.src) == 1 &&
+                        inet_pton(AF_INET, addrs[1].c_str(), &pair.dst) == 1;
+    if (parsed) return pair;
+    return Error() << "Failed to make v4pair";
+}
+
+bool TunForwarder::v4pair::operator==(const v4pair& o) const {
+    return src.s_addr == o.src.s_addr && dst.s_addr == o.dst.s_addr;
+}
+
+bool TunForwarder::v4pair::operator<(const v4pair& o) const {
+    return std::tie(src.s_addr, dst.s_addr) < std::tie(o.src.s_addr, o.dst.s_addr);
+}
+
+// Parses two IPv6 address strings into {src, dst}; fails unless both parse.
+Result<TunForwarder::v6pair> TunForwarder::v6pair::makePair(
+        const std::array<std::string, 2>& addrs) {
+    v6pair pair;
+    const bool parsed = inet_pton(AF_INET6, addrs[0].c_str(), &pair.src) == 1 &&
+                        inet_pton(AF_INET6, addrs[1].c_str(), &pair.dst) == 1;
+    if (parsed) return pair;
+    return Error() << "Failed to make v6pair";
+}
+
+bool TunForwarder::v6pair::operator==(const v6pair& o) const {
+    return src == o.src && dst == o.dst;
+}
+
+// Orders by source address first, then by destination address.
+bool TunForwarder::v6pair::operator<(const v6pair& o) const {
+    return (src != o.src) ? (src < o.src) : (dst < o.dst);
+}
+
+TunForwarder::TunForwarder(unique_fd tunFd) : mTunFd(std::move(tunFd)) {
+    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
+    if (!mEventFd.ok()) PLOG(ERROR) << "failed to create eventfd; stopForwarding() won't work";
+}
+
+TunForwarder::~TunForwarder() {
+    stopForwarding();
+    if (mForwarder.joinable()) mForwarder.join();
+}
+
+// Spawns the forwarding thread; returns false if it was ever started before.
+bool TunForwarder::startForwarding() {
+    if (mForwarder.joinable()) return false;
+    mForwarder = std::thread(&TunForwarder::loop, this);
+    return true;
+}
+
+bool TunForwarder::stopForwarding() {
+    return signalEventFd();
+}
+
+// Installs a rule rewriting packets whose {src, dst} equal |from| to {src, dst} = |to|. The IP
+// family is taken from from[0]; returns false if any of the four addresses fails to parse.
+bool TunForwarder::addForwardingRule(const std::array<std::string, 2>& from,
+                                     const std::array<std::string, 2>& to) {
+    if (from[0].find(':') != std::string::npos) {
+        const auto key = v6pair::makePair(from);
+        const auto value = v6pair::makePair(to);
+        if (!key.ok() || !value.ok()) return false;
+        mRulesIpv6[key.value()] = value.value();
+        return true;
+    }
+    const auto key = v4pair::makePair(from);
+    const auto value = v4pair::makePair(to);
+    if (!key.ok() || !value.ok()) return false;
+    mRulesIpv4[key.value()] = value.value();
+    return true;
+}
+
+unique_fd TunForwarder::createTun(const std::string& ifname) {
+    unique_fd fd(open("/dev/tun", O_RDWR | O_NONBLOCK | O_CLOEXEC));
+    if (!fd.ok()) {
+        PLOG(WARNING) << "failed to open /dev/tun";
+        return {};
+    }
+    ifreq ifr = {
+            .ifr_ifru = {.ifru_flags = IFF_TUN},
+    };
+    strlcpy(ifr.ifr_name, ifname.c_str(), sizeof(ifr.ifr_name));
+    // IFF_TUN without IFF_NO_PI: each packet read or written is prefixed with a struct tun_pi.
+    if (ioctl(fd.get(), TUNSETIFF, &ifr) == -1) {
+        PLOG(WARNING) << "failed to bring up tun " << ifr.ifr_name;
+        return {};
+    }
+    // Bring the interface up. SIOCSIFFLAGS is issued on a separate control socket.
+    unique_fd inet6CtrlSock(socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0));
+    ifr.ifr_flags = IFF_UP;
+    if (ioctl(inet6CtrlSock.get(), SIOCSIFFLAGS, &ifr) == -1) {
+        PLOG(WARNING) << "failed on SIOCSIFFLAGS " << ifr.ifr_name;
+        return {};
+    }
+
+    return fd;
+}
+
+// Forwarding thread body. Exits when stopForwarding() is signaled or poll() fails/times out.
+void TunForwarder::loop() {
+    while (true) {
+        struct pollfd wait_fd[] = {
+                {mEventFd.get(), POLLIN, 0},
+                {mTunFd.get(), POLLIN, 0},
+        };
+        const int ret = poll(wait_fd, std::size(wait_fd), kPollTimeoutMs);
+        if (ret < 0 && errno == EINTR) continue;  // Interrupted by a signal: poll again.
+        if (ret <= 0) break;
+
+        if (wait_fd[0].revents & (POLLIN | POLLERR)) {
+            uint64_t value = 0;
+            eventfd_read(mEventFd.get(), &value);
+            break;
+        }
+        if (wait_fd[1].revents & (POLLIN | POLLERR)) {
+            handlePacket(wait_fd[1].fd);
+        }
+    }
+}
+
+void TunForwarder::handlePacket(int fd) const {
+    uint8_t buf[MAXMTU + TUN_HDRLEN];
+    const ssize_t readlen = read(fd, buf, std::size(buf));
+    if (readlen < 0) {
+        PLOG(ERROR) << "failed to read packets from tun";
+        return;
+    } else if (readlen == 0) {
+        LOG(ERROR) << "tun interface removed";  // EOF doesn't set errno, so don't use PLOG.
+        return;
+    }
+
+    // Filter the packet. Only TCP and UDP packets are allowed.
+    const Slice tunPacket(buf, readlen);
+    if (auto result = validatePacket(tunPacket); !result.ok()) {
+        LOG(DEBUG) << "validatePacket failed: " << result.error();
+        return;
+    }
+
+    // Change the packet's addresses and checksums; drop the packet if that fails.
+    if (auto result = translatePacket(tunPacket); !result.ok()) {
+        LOG(ERROR) << "translatePacket failed: " << result.error();
+        return;
+    }
+
+    // Write the new packet to the fd, causing the kernel to receive it on the tun interface.
+    if (write(fd, buf, readlen) != readlen) PLOG(ERROR) << "failed to write packet to tun";
+}
+
+Result<void> TunForwarder::validatePacket(Slice tunPacket) const {
+    if (tunPacket.size() < TUN_HDRLEN) {
+        return Error() << "Too short for a tun header";
+    }
+    // createTun() sets IFF_TUN without IFF_NO_PI, so every packet starts with a struct tun_pi.
+    const tun_pi* const tunHeader = reinterpret_cast<tun_pi*>(tunPacket.base());
+    if (tunHeader->flags != 0) {
+        return Error() << "Unexpected tun flags " << static_cast<int>(tunHeader->flags);
+    }
+    // Dispatch on the ethertype carried in the tun header; only IPv4 and IPv6 are accepted.
+    switch (uint16_t proto = ntohs(tunHeader->proto); proto) {
+        case ETH_P_IP:
+            return validateIpv4Packet(drop(tunPacket, TUN_HDRLEN));
+        case ETH_P_IPV6:
+            return validateIpv6Packet(drop(tunPacket, TUN_HDRLEN));
+        default:
+            return Error() << "Unsupported packet type 0x" << std::hex << static_cast<int>(proto);
+    }
+}
+
+Result<void> TunForwarder::validateIpv4Packet(Slice ipv4Packet) const {
+    if (ipv4Packet.size() < IP4_HDRLEN) {
+        return Error() << "Too short for an ip header";
+    }
+    // Check the version first; print the 4-bit version/ihl fields as numbers, not chars.
+    const iphdr* const ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base());
+    if (ipHeader->version != 4) {
+        return Error() << "IP header version not 4: " << static_cast<int>(ipHeader->version);
+    }
+    if (ipHeader->ihl < 5) {
+        return Error() << "IP header length set to less than 5";
+    }
+    if (ipHeader->ihl * 4 > ipv4Packet.size()) {
+        return Error() << "IP header length set too large: " << static_cast<int>(ipHeader->ihl);
+    }
+    if (mRulesIpv4.find({ipHeader->saddr, ipHeader->daddr}) == mRulesIpv4.end()) {
+        return Error() << "Can't find any v4 rule. Packet hex dump: " << toHex(ipv4Packet, 32);
+    }
+
+    switch (ipHeader->protocol) {
+        case IPPROTO_UDP:
+            return validateUdpPacket(drop(ipv4Packet, ipHeader->ihl * 4));
+        case IPPROTO_TCP:
+            return validateTcpPacket(drop(ipv4Packet, ipHeader->ihl * 4));
+        default:
+            return Error() << "Unsupported transport protocol "
+                           << static_cast<int>(ipHeader->protocol);
+    }
+}
+
+Result<void> TunForwarder::validateIpv6Packet(Slice ipv6Packet) const {
+    if (ipv6Packet.size() < IP6_HDRLEN) {
+        return Error() << "Too short for an ipv6 header";
+    }
+    // Only the fixed header is parsed, so packets with extension headers are rejected below.
+    const ip6_hdr* const ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base());
+    if (mRulesIpv6.find({ipv6Header->ip6_src, ipv6Header->ip6_dst}) == mRulesIpv6.end()) {
+        return Error() << "Can't find any v6 rule. Packet hex dump: " << toHex(ipv6Packet, 32);
+    }
+
+    switch (ipv6Header->ip6_nxt) {
+        case IPPROTO_UDP:
+            return validateUdpPacket(drop(ipv6Packet, IP6_HDRLEN));
+        case IPPROTO_TCP:
+            return validateTcpPacket(drop(ipv6Packet, IP6_HDRLEN));
+        default:
+            return Error() << "Expect TCP/UDP in ipv6 next header: "
+                           << static_cast<int>(ipv6Header->ip6_nxt);
+    }
+}
+
+Result<void> TunForwarder::validateUdpPacket(Slice udpPacket) const {
+    // Only the fixed-size UDP header has to be present, since translation writes nothing but
+    // its checksum field.
+    if (udpPacket.size() >= UDP_HDRLEN) return {};
+    return Error() << "Too short for a udp header";
+}
+
+Result<void> TunForwarder::validateTcpPacket(Slice tcpPacket) const {
+    if (tcpPacket.size() < TCP_HDRLEN) {
+        return Error() << "Too short for a tcp header";
+    }
+    // doff is the header length in 32-bit words; 5 words is the option-less minimum.
+    const tcphdr* const tcpHeader = reinterpret_cast<tcphdr*>(tcpPacket.base());
+    if (tcpHeader->doff < 5) {
+        return Error() << "TCP header length set to less than 5";
+    }
+    if (tcpHeader->doff * 4 > tcpPacket.size()) {
+        return Error() << "TCP header length set too large: " << tcpHeader->doff;
+    }
+    return {};
+}
+
+Result<void> TunForwarder::translatePacket(Slice tunPacket) const {
+    // handlePacket() runs validatePacket() first, so the tun header is known to be readable.
+    const auto* const tunHeader = reinterpret_cast<const tun_pi*>(tunPacket.base());
+    switch (const uint16_t proto = ntohs(tunHeader->proto)) {
+        case ETH_P_IP:
+            return translateIpv4Packet(drop(tunPacket, TUN_HDRLEN));
+        case ETH_P_IPV6:
+            return translateIpv6Packet(drop(tunPacket, TUN_HDRLEN));
+        default:
+            return Error() << "translate: Unsupported packet type 0x" << std::hex << proto;
+    }
+}
+
+Result<void> TunForwarder::translateIpv4Packet(Slice ipv4Packet) const {
+    iphdr* ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base());
+    const size_t ipHeaderLen = ipHeader->ihl * 4;
+    const size_t transport_len = ipv4Packet.size() - ipHeaderLen;
+
+    uint32_t oldPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len);
+    // Rewrite the address pair. validateIpv4Packet() has already checked that a rule matches,
+    // so use the same map lookup rather than a linear scan over all rules.
+    if (const auto it = mRulesIpv4.find({ipHeader->saddr, ipHeader->daddr});
+        it != mRulesIpv4.end()) {
+        ipHeader->saddr = it->second.src.s_addr;
+        ipHeader->daddr = it->second.dst.s_addr;
+    }
+    uint32_t newPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len);
+
+    // The IPv4 header checksum covers the whole header, including any options.
+    ipHeader->check = 0;
+    ipHeader->check = ip_checksum(ipHeader, ipHeaderLen);
+
+    switch (ipHeader->protocol) {
+        case IPPROTO_UDP:
+            translateUdpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
+            break;
+        case IPPROTO_TCP:
+            translateTcpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
+            break;
+        default:
+            return Error() << "translate: Unsupported transport protocol "
+                           << static_cast<int>(ipHeader->protocol);
+    }
+
+    return {};
+}
+
+Result<void> TunForwarder::translateIpv6Packet(Slice ipv6Packet) const {
+    ip6_hdr* ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base());
+    const size_t ipHeaderLen = IP6_HDRLEN;
+    const size_t transport_len = ipv6Packet.size() - ipHeaderLen;
+
+    uint32_t oldPseudoSum =
+            ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt);
+    // Rewrite the address pair. validateIpv6Packet() has already checked that a rule matches,
+    // so use the same map lookup rather than a linear scan over all rules.
+    if (const auto it = mRulesIpv6.find({ipv6Header->ip6_src, ipv6Header->ip6_dst});
+        it != mRulesIpv6.end()) {
+        ipv6Header->ip6_src = it->second.src;
+        ipv6Header->ip6_dst = it->second.dst;
+    }
+    uint32_t newPseudoSum =
+            ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt);
+    // IPv6 has no header checksum; only the transport checksum needs to be adjusted.
+    switch (ipv6Header->ip6_nxt) {
+        case IPPROTO_UDP:
+            translateUdpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
+            break;
+        case IPPROTO_TCP:
+            translateTcpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
+            break;
+        default:
+            return Error() << "translate: Expect TCP/UDP in ipv6 next header: "
+                           << static_cast<int>(ipv6Header->ip6_nxt);
+    }
+
+    return {};
+}
+
+void TunForwarder::translateUdpPacket(Slice udpPacket, uint32_t oldPseudoSum,
+                                      uint32_t newPseudoSum) const {
+    udphdr* udpHeader = reinterpret_cast<udphdr*>(udpPacket.base());
+    if (udpHeader->check) {  // Checksum present: adjust it for the new pseudo-header.
+        udpHeader->check = ip_checksum_adjust(udpHeader->check, oldPseudoSum, newPseudoSum);
+    } else {  // No checksum was sent (allowed in IPv4): compute a full one from scratch.
+        uint32_t tmp = ip_checksum_add(newPseudoSum, udpPacket.base(), udpPacket.size());
+        udpHeader->check = ip_checksum_finish(tmp);
+    }
+
+    // RFC 768: "If the computed checksum is zero, it is transmitted as all ones (the equivalent
+    // in one's complement arithmetic)."
+    if (!udpHeader->check) {
+        udpHeader->check = 0xffff;
+    }
+}
+
+void TunForwarder::translateTcpPacket(Slice tcpPacket, uint32_t oldPseudoSum,
+                                      uint32_t newPseudoSum) const {
+    auto* const tcp = reinterpret_cast<tcphdr*>(tcpPacket.base());  // RFC 1624 incremental update.
+    tcp->check = ip_checksum_adjust(tcp->check, oldPseudoSum, newPseudoSum);
+}
+
+bool TunForwarder::signalEventFd() {
+    return eventfd_write(mEventFd.get(), 1) == 0;  // Wakes up loop(); false if the write fails.
+}
+
+}  // namespace android::net
diff --git a/tests/tun_forwarder.h b/tests/tun_forwarder.h
new file mode 100644
index 0000000..2b0a65a
--- /dev/null
+++ b/tests/tun_forwarder.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include <map>
+#include <thread>
+
+#include <netinet/ip.h>
+
+#include <android-base/result.h>
+#include <android-base/unique_fd.h>
+#include <netdutils/Slice.h>
+
+namespace android::net {
+
+// Given a TUN interface fd, TunForwarder reads packets from the fd, changes their IP header
+// according to a set of forwarding rules (which can be set by addForwardingRule), and sends
+// new packets back to the fd. Only IPv4 and IPv6 packets with recognized source and destination
+// addresses are accepted; other packets are silently ignored.
+class TunForwarder {
+  public:
+    explicit TunForwarder(base::unique_fd tunFd);
+    ~TunForwarder();
+
+    // Adds a forwarding rule that rewrites packet addresses |from| into |to|. Presumably each
+    // array holds {source, destination} of a single IP family (see makePair) - TODO confirm.
+    bool addForwardingRule(const std::array<std::string, 2>& from,
+                           const std::array<std::string, 2>& to);
+    bool startForwarding();
+    bool stopForwarding();
+
+    static base::unique_fd createTun(const std::string& ifname);
+
+  private:
+    // TODO: Consider using IPAddress for v4pair and v6pair. This might require adding
+    // addr4() and addr6() as IPPrefix does.
+    struct v4pair {
+        static base::Result<v4pair> makePair(const std::array<std::string, 2>& addrs);
+        v4pair() = default;
+        v4pair(int32_t srcAddr, int32_t dstAddr)
+            : src{static_cast<in_addr_t>(srcAddr)}, dst{static_cast<in_addr_t>(dstAddr)} {}
+        in_addr src{};
+        in_addr dst{};
+        bool operator==(const v4pair& o) const;
+        bool operator<(const v4pair& o) const;
+    };
+
+    struct v6pair {
+        static base::Result<v6pair> makePair(const std::array<std::string, 2>& addrs);
+        v6pair() = default;
+        v6pair(const in6_addr& srcAddr, const in6_addr& dstAddr) : src(srcAddr), dst(dstAddr) {}
+        in6_addr src{};
+        in6_addr dst{};
+        bool operator==(const v6pair& o) const;
+        bool operator<(const v6pair& o) const;
+    };
+
+    void loop();
+    void handlePacket(int fd) const;
+
+    // Send a signal to terminate the loop thread.
+    bool signalEventFd();
+
+    // Packet validators: each returns an error if the packet is neither UDP nor TCP.
+    base::Result<void> validatePacket(netdutils::Slice tunPacket) const;
+    base::Result<void> validateIpv4Packet(netdutils::Slice ipv4Packet) const;
+    base::Result<void> validateIpv6Packet(netdutils::Slice ipv6Packet) const;
+    base::Result<void> validateUdpPacket(netdutils::Slice udpPacket) const;
+    base::Result<void> validateTcpPacket(netdutils::Slice tcpPacket) const;
+
+    // translatePacket() assumes |tunPacket| is either a UDP or TCP packet, changes the
+    // source/destination addresses, and updates the checksum.
+    base::Result<void> translatePacket(netdutils::Slice tunPacket) const;
+    base::Result<void> translateIpv4Packet(netdutils::Slice ipv4Packet) const;
+    base::Result<void> translateIpv6Packet(netdutils::Slice ipv6Packet) const;
+    void translateUdpPacket(netdutils::Slice udpPacket, uint32_t oldPseudoSum,
+                            uint32_t newPseudoSum) const;
+    void translateTcpPacket(netdutils::Slice tcpPacket, uint32_t oldPseudoSum,
+                            uint32_t newPseudoSum) const;
+
+    std::thread mForwarder;
+    base::unique_fd mTunFd;
+    base::unique_fd mEventFd;  // Written by signalEventFd() to stop the loop thread.
+    std::map<v4pair, v4pair> mRulesIpv4;
+    std::map<v6pair, v6pair> mRulesIpv6;
+
+    static constexpr int kPollTimeoutMs = 5000;
+};
+
+}  // namespace android::net