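Add unit tests for shared and unshared interface quotas

Route BandwidthController's reads and writes of /proc/net/xt_quota
through netdutils::Syscalls so that the new tests can mock them.

Test: as follows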
    - built
    - flashed
    - booted
    - "runtest -x .../netd_unit_test.cpp" passes
Bug: 28362720
Bug: 38143143

Change-Id: I0b962898f9e3d7e86d5c0d0d01b79b3e3543b5ee
diff --git a/server/BandwidthController.cpp b/server/BandwidthController.cpp
index 9aace9d..69addb3 100644
--- a/server/BandwidthController.cpp
+++ b/server/BandwidthController.cpp
@@ -22,15 +22,14 @@
  * If they ever were to allow it, then netd/ would need some tweaking.
  */
 
-#include <string>
-#include <vector>
-
+#include <ctype.h>
 #include <errno.h>
 #include <fcntl.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
-#include <ctype.h>
+#include <string>
+#include <vector>
 
 #define __STDC_FORMAT_MACROS 1
 #include <inttypes.h>
@@ -51,9 +50,10 @@
 #include <cutils/properties.h>
 #include <logwrap/logwrap.h>
 
-#include "NetdConstants.h"
+#include <netdutils/Syscalls.h>
 #include "BandwidthController.h"
-#include "NatController.h"  /* For LOCAL_TETHER_COUNTERS_CHAIN */
+#include "NatController.h" /* For LOCAL_TETHER_COUNTERS_CHAIN */
+#include "NetdConstants.h"
 #include "ResponseCode.h"
 
 /* Alphabetical */
@@ -70,6 +70,8 @@
 
 using android::base::StringAppendF;
 using android::base::StringPrintf;
+using android::netdutils::StatusOr;
+using android::netdutils::UniqueFile;
 
 namespace {
 
@@ -625,21 +627,33 @@
 }
 
+// Reads the current value of the xt_quota entry named |iface| from
+// /proc/net/xt_quota/<iface> into |*bytes|.
+// Returns 0 on success; -1 if the name is invalid, the file cannot be opened
+// or read, or no integer could be parsed (|*bytes| is then not written).
 int BandwidthController::getInterfaceQuota(const std::string& iface, int64_t* bytes) {
-    FILE *fp;
+    const auto& sys = android::netdutils::sSyscalls.get();
     const std::string fname = "/proc/net/xt_quota/" + iface;
-    int scanRes;
 
     if (!isIfaceName(iface)) return -1;
 
-    fp = fopen(fname.c_str(), "re");
-    if (!fp) {
-        ALOGE("Reading quota %s failed (%s)", iface.c_str(), strerror(errno));
+    StatusOr<UniqueFile> file = sys.fopen(fname, "re");
+    if (!isOk(file)) {
+        ALOGE("Reading quota %s failed (%s)", iface.c_str(), toString(file).c_str());
         return -1;
     }
-    scanRes = fscanf(fp, "%" SCNd64, bytes);
-    ALOGV("Read quota res=%d bytes=%" PRId64, scanRes, *bytes);
-    fclose(fp);
-    return scanRes == 1 ? 0 : -1;
+    auto rv = sys.fscanf(file.value().get(), "%" SCNd64, bytes);
+    if (!isOk(rv)) {
+        ALOGE("Reading quota %s failed (%s)", iface.c_str(), toString(rv).c_str());
+        return -1;
+    }
+    if (rv.value() != 1) {
+        // fscanf converted nothing, so *bytes was not written: report the
+        // failure instead of returning -1 silently and logging a stale value.
+        ALOGE("Reading quota %s failed (fscanf returned %d)", iface.c_str(), rv.value());
+        return -1;
+    }
+    ALOGV("Read quota res=%d bytes=%" PRId64, rv.value(), *bytes);
+    return 0;
 }
 
 int BandwidthController::removeInterfaceQuota(const std::string& iface) {
@@ -668,23 +672,31 @@
 }
 
+// Sets the quota named |quotaName| to |bytes| by writing the value to
+// /proc/net/xt_quota/<quotaName>.
+// Returns 0 on success; -1 if the name is invalid or opening or writing
+// the file fails.
 int BandwidthController::updateQuota(const std::string& quotaName, int64_t bytes) {
-    FILE *fp;
-    char *fname;
+    const auto& sys = android::netdutils::sSyscalls.get();
+    const std::string fname = "/proc/net/xt_quota/" + quotaName;
 
     if (!isIfaceName(quotaName)) {
         ALOGE("updateQuota: Invalid quotaName \"%s\"", quotaName.c_str());
         return -1;
     }
 
-    asprintf(&fname, "/proc/net/xt_quota/%s", quotaName.c_str());
-    fp = fopen(fname, "we");
-    free(fname);
-    if (!fp) {
-        ALOGE("Updating quota %s failed (%s)", quotaName.c_str(), strerror(errno));
+    StatusOr<UniqueFile> file = sys.fopen(fname, "we");
+    if (!isOk(file)) {
+        ALOGE("Updating quota %s failed (%s)", quotaName.c_str(), toString(file).c_str());
         return -1;
     }
-    fprintf(fp, "%" PRId64"\n", bytes);
-    fclose(fp);
+    // Check the write, as getInterfaceQuota() checks its read.
+    // NOTE(review): UniqueFile closes the file, so an error that only surfaces
+    // when the buffered write is flushed at close time is still not reported.
+    auto rv = sys.fprintf(file.value().get(), "%" PRId64 "\n", bytes);
+    if (!isOk(rv)) {
+        ALOGE("Updating quota %s failed (%s)", quotaName.c_str(), toString(rv).c_str());
+        return -1;
+    }
     return 0;
 }
 
diff --git a/server/BandwidthControllerTest.cpp b/server/BandwidthControllerTest.cpp
index 954db57..dd46612 100644
--- a/server/BandwidthControllerTest.cpp
+++ b/server/BandwidthControllerTest.cpp
@@ -30,15 +30,25 @@
 #include <android-base/strings.h>
 #include <android-base/stringprintf.h>
 
+#include <netdutils/MockSyscalls.h>
 #include "BandwidthController.h"
 #include "IptablesBaseTest.h"
 #include "tun_interface.h"
 
+using ::testing::ByMove;
+using ::testing::Invoke;
+using ::testing::Return;
+using ::testing::StrictMock;
+using ::testing::Test;
+using ::testing::_;
+
 using android::base::StringPrintf;
 using android::net::TunInterface;
+using android::netdutils::status::ok;
+using android::netdutils::UniqueFile;
 
 class BandwidthControllerTest : public IptablesBaseTest {
-public:
+protected:
     BandwidthControllerTest() {
         BandwidthController::execFunction = fake_android_fork_exec;
         BandwidthController::popenFunction = fake_popen;
@@ -112,6 +122,21 @@
     int runIptablesAlertFwdCmd(IptOp a, const char *b, int64_t c) {
         return mBw.runIptablesAlertFwdCmd(a, b, c);
     }
+
+    // Expects one quota write of |quota| via the mocked fopen, vfprintf and fclose.
+    void expectUpdateQuota(uint64_t quota) {
+        static uintptr_t dummy;  // Static so the fake FILE* stays valid after we return.
+        FILE* dummyFile = reinterpret_cast<FILE*>(&dummy);
+        EXPECT_CALL(mSyscalls, fopen(_, _)).WillOnce(Return(ByMove(UniqueFile(dummyFile))));
+        EXPECT_CALL(mSyscalls, vfprintf(dummyFile, _, _))
+            .WillOnce(Invoke([quota](FILE*, const std::string&, va_list ap) {
+                EXPECT_EQ(quota, va_arg(ap, uint64_t));
+                return 0;
+            }));
+        EXPECT_CALL(mSyscalls, fclose(dummyFile)).WillOnce(Return(ok));
+    }
+
+    StrictMock<android::netdutils::ScopedMockSyscalls> mSyscalls;
 };
 
 TEST_F(BandwidthControllerTest, TestSetupIptablesHooks) {
@@ -355,38 +380,46 @@
     clearIptablesRestoreOutput();
 }
 
-const std::vector<std::string> makeInterfaceQuotaCommands(const char *iface, int ruleIndex,
+// Iptables commands expected when setInterfaceQuota() installs a new quota on |iface|.
+const std::vector<std::string> makeInterfaceQuotaCommands(const std::string& iface, int ruleIndex,
                                                           int64_t quota) {
+    const std::string chain = "bw_costly_" + iface;
+    const char* c_chain = chain.c_str();
+    const char* c_iface = iface.c_str();
     std::vector<std::string> cmds = {
-        StringPrintf("-F bw_costly_%s", iface),
-        StringPrintf("-N bw_costly_%s", iface),
-        StringPrintf("-A bw_costly_%s -j bw_penalty_box", iface),
-        StringPrintf("-D bw_INPUT -i %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-I bw_INPUT %d -i %s --jump bw_costly_%s", ruleIndex, iface, iface),
-        StringPrintf("-D bw_OUTPUT -o %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-I bw_OUTPUT %d -o %s --jump bw_costly_%s", ruleIndex, iface, iface),
-        StringPrintf("-D bw_FORWARD -o %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-A bw_FORWARD -o %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-A bw_costly_%s -m quota2 ! --quota %" PRIu64 " --name %s --jump REJECT",
-                     iface, quota, iface),
+        StringPrintf("-F %s", c_chain),
+        StringPrintf("-N %s", c_chain),
+        StringPrintf("-A %s -j bw_penalty_box", c_chain),
+        StringPrintf("-D bw_INPUT -i %s --jump %s", c_iface, c_chain),
+        StringPrintf("-I bw_INPUT %d -i %s --jump %s", ruleIndex, c_iface, c_chain),
+        StringPrintf("-D bw_OUTPUT -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-I bw_OUTPUT %d -o %s --jump %s", ruleIndex, c_iface, c_chain),
+        StringPrintf("-D bw_FORWARD -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-A bw_FORWARD -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-A %s -m quota2 ! --quota %" PRId64 " --name %s --jump REJECT", c_chain,
+                     quota, c_iface),
     };
     return cmds;
 }
 
-const std::vector<std::string> removeInterfaceQuotaCommands(const char *iface) {
+// Iptables commands expected when the per-interface quota on |iface| is removed.
+const std::vector<std::string> removeInterfaceQuotaCommands(const std::string& iface) {
+    const std::string costlyChain = "bw_costly_" + iface;
+    const char* const chainName = costlyChain.c_str();
     std::vector<std::string> cmds = {
-        StringPrintf("-D bw_INPUT -i %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-D bw_OUTPUT -o %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-D bw_FORWARD -o %s --jump bw_costly_%s", iface, iface),
-        StringPrintf("-F bw_costly_%s", iface),
-        StringPrintf("-X bw_costly_%s", iface),
+        StringPrintf("-D bw_INPUT -i %s --jump %s", iface.c_str(), chainName),
+        StringPrintf("-D bw_OUTPUT -o %s --jump %s", iface.c_str(), chainName),
+        StringPrintf("-D bw_FORWARD -o %s --jump %s", iface.c_str(), chainName),
+        StringPrintf("-F %s", chainName),
+        StringPrintf("-X %s", chainName),
     };
     return cmds;
 }
 
 TEST_F(BandwidthControllerTest, TestSetInterfaceQuota) {
-    const char *iface = mTun.name().c_str();
-    std::vector<std::string> expected = makeInterfaceQuotaCommands(iface, 1, 123456);
+    constexpr uint64_t kOldQuota = 123456;
+    const std::string iface = mTun.name();
+    std::vector<std::string> expected = makeInterfaceQuotaCommands(iface, 1, kOldQuota);
 
     // prepCostlyInterface assumes that exactly one of the "-F chain" and "-N chain" commands fails.
     // So pretend that the first two commands (the IPv4 -F and the IPv6 -F) fail.
@@ -395,7 +428,13 @@
     returnValues[1] = 1;
     setReturnValues(returnValues);
 
-    EXPECT_EQ(0, mBw.setInterfaceQuota(iface, 123456));
+    EXPECT_EQ(0, mBw.setInterfaceQuota(iface, kOldQuota));
+    expectIptablesCommands(expected);
+
+    constexpr uint64_t kNewQuota = kOldQuota + 1;
+    expected = {};
+    expectUpdateQuota(kNewQuota);
+    EXPECT_EQ(0, mBw.setInterfaceQuota(iface, kNewQuota));
     expectIptablesCommands(expected);
 
     expected = removeInterfaceQuotaCommands(iface);
@@ -403,6 +442,105 @@
     expectIptablesCommands(expected);
 }
 
+// Iptables commands expected when setInterfaceSharedQuota() attaches |iface| to bw_costly_shared.
+const std::vector<std::string> makeInterfaceSharedQuotaCommands(const std::string& iface,
+                                                                int ruleIndex, int64_t quota) {
+    const char* const c_chain = "bw_costly_shared";
+    const char* c_iface = iface.c_str();
+    std::vector<std::string> cmds = {
+        StringPrintf("-D bw_INPUT -i %s --jump %s", c_iface, c_chain),
+        StringPrintf("-I bw_INPUT %d -i %s --jump %s", ruleIndex, c_iface, c_chain),
+        StringPrintf("-D bw_OUTPUT -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-I bw_OUTPUT %d -o %s --jump %s", ruleIndex, c_iface, c_chain),
+        StringPrintf("-D bw_FORWARD -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-A bw_FORWARD -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-I %s -m quota2 ! --quota %" PRId64 " --name shared --jump REJECT", c_chain,
+                     quota),
+    };
+    return cmds;
+}
+
+// Iptables commands expected when removeInterfaceSharedQuota() detaches |iface|.
+const std::vector<std::string> removeInterfaceSharedQuotaCommands(const std::string& iface,
+                                                                  int64_t quota) {
+    const char* const c_chain = "bw_costly_shared";
+    const char* c_iface = iface.c_str();
+    std::vector<std::string> cmds = {
+        StringPrintf("-D bw_INPUT -i %s --jump %s", c_iface, c_chain),
+        StringPrintf("-D bw_OUTPUT -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-D bw_FORWARD -o %s --jump %s", c_iface, c_chain),
+        StringPrintf("-D %s -m quota2 ! --quota %" PRId64
+                     " --name shared --jump REJECT", c_chain, quota),
+    };
+    return cmds;
+}
+
+TEST_F(BandwidthControllerTest, TestSetInterfaceSharedQuotaDuplicate) {
+    constexpr uint64_t kQuota = 123456;
+    const std::string iface = mTun.name();
+    std::vector<std::string> expected = makeInterfaceSharedQuotaCommands(iface, 1, kQuota);
+    EXPECT_EQ(0, mBw.setInterfaceSharedQuota(iface, kQuota));
+    expectIptablesCommands(expected);
+
+    expected = {};  // Same quota again: no iptables commands and no quota file write.
+    EXPECT_EQ(0, mBw.setInterfaceSharedQuota(iface, kQuota));
+    expectIptablesCommands(expected);
+
+    expected = removeInterfaceSharedQuotaCommands(iface, kQuota);
+    EXPECT_EQ(0, mBw.removeInterfaceSharedQuota(iface));
+    expectIptablesCommands(expected);
+}
+
+TEST_F(BandwidthControllerTest, TestSetInterfaceSharedQuotaUpdate) {
+    constexpr uint64_t kOldQuota = 123456;
+    const std::string iface = mTun.name();
+    auto expected = makeInterfaceSharedQuotaCommands(iface, 1, kOldQuota);
+    EXPECT_EQ(0, mBw.setInterfaceSharedQuota(iface, kOldQuota));
+    expectIptablesCommands(expected);
+
+    constexpr uint64_t kNewQuota = kOldQuota + 1;
+    expected.clear();  // The new limit goes to the quota file; no iptables commands.
+    expectUpdateQuota(kNewQuota);
+    EXPECT_EQ(0, mBw.setInterfaceSharedQuota(iface, kNewQuota));
+    expectIptablesCommands(expected);
+
+    expected = removeInterfaceSharedQuotaCommands(iface, kNewQuota);
+    EXPECT_EQ(0, mBw.removeInterfaceSharedQuota(iface));
+    expectIptablesCommands(expected);
+}
+
+TEST_F(BandwidthControllerTest, TestSetInterfaceSharedQuotaTwoInterfaces) {
+    constexpr uint64_t kQuota = 123456;
+    const std::vector<std::string> ifaces{
+        {"a" + mTun.name()},
+        {"b" + mTun.name()},
+    };
+
+    for (const auto& iface : ifaces) {
+        const bool isFirst = (iface == ifaces.front());
+        auto expected = makeInterfaceSharedQuotaCommands(iface, 1, kQuota);
+        if (!isFirst) {
+            // The quota rule is only added when the number of interfaces
+            // sharing the quota goes from 0 to 1.
+            expected.pop_back();
+        }
+        EXPECT_EQ(0, mBw.setInterfaceSharedQuota(iface, kQuota));
+        expectIptablesCommands(expected);
+    }
+
+    for (const auto& iface : ifaces) {
+        const bool isLast = (iface == ifaces.back());
+        auto expected = removeInterfaceSharedQuotaCommands(iface, kQuota);
+        if (!isLast) {
+            // The quota rule is only removed when the number of interfaces
+            // sharing the quota goes from 1 to 0.
+            expected.pop_back();
+        }
+        EXPECT_EQ(0, mBw.removeInterfaceSharedQuota(iface));
+        expectIptablesCommands(expected);
+    }
+}
+
 TEST_F(BandwidthControllerTest, IptablesAlertCmd) {
     std::vector<std::string> expected = {
         "*filter\n"