Merge "netd_integration_test static dependecies"
diff --git a/libnetdutils/Syscalls.cpp b/libnetdutils/Syscalls.cpp
index b0301d2..5354341 100644
--- a/libnetdutils/Syscalls.cpp
+++ b/libnetdutils/Syscalls.cpp
@@ -178,6 +178,14 @@
return file;
}
+ StatusOr<pid_t> fork() const override {
+ pid_t rv = ::fork();
+ if (rv == -1) {
+ return statusFromErrno(errno, "fork() failed");
+ }
+ return rv;
+ }
+
StatusOr<int> vfprintf(FILE* file, const char* format, va_list ap) const override {
auto rv = ::vfprintf(file, format, ap);
if (rv == -1) {
diff --git a/libnetdutils/include/netdutils/MockSyscalls.h b/libnetdutils/include/netdutils/MockSyscalls.h
index fbd6791..149ba59 100644
--- a/libnetdutils/include/netdutils/MockSyscalls.h
+++ b/libnetdutils/include/netdutils/MockSyscalls.h
@@ -62,6 +62,7 @@
MOCK_CONST_METHOD3(vfprintf, StatusOr<int>(FILE* file, const char* format, va_list ap));
MOCK_CONST_METHOD3(vfscanf, StatusOr<int>(FILE* file, const char* format, va_list ap));
MOCK_CONST_METHOD1(fclose, Status(FILE* file));
+ MOCK_CONST_METHOD0(fork, StatusOr<pid_t>());
};
// For the lifetime of this mock, replace the contents of sSyscalls
diff --git a/libnetdutils/include/netdutils/Syscalls.h b/libnetdutils/include/netdutils/Syscalls.h
index 5190da1..0e336b6 100644
--- a/libnetdutils/include/netdutils/Syscalls.h
+++ b/libnetdutils/include/netdutils/Syscalls.h
@@ -20,6 +20,7 @@
#include <memory>
#include <poll.h>
+#include <unistd.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/types.h>
@@ -81,6 +82,8 @@
virtual Status fclose(FILE* file) const = 0;
+ virtual StatusOr<pid_t> fork() const = 0;
+
// va_args helpers
// va_start doesn't work when the preceding argument is a reference
// type so we're forced to use const char*.
diff --git a/netutils_wrappers/Android.mk b/netutils_wrappers/Android.mk
new file mode 100644
index 0000000..ed1af34
--- /dev/null
+++ b/netutils_wrappers/Android.mk
@@ -0,0 +1,47 @@
+# Copyright (C) 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+LOCAL_PATH := $(call my-dir)
+
+###
+### Wrapper binary.
+###
+include $(CLEAR_VARS)
+
+LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CLANG := true
+LOCAL_MODULE := netutils-wrapper-1.0
+LOCAL_SHARED_LIBRARIES := libbase liblog
+LOCAL_SRC_FILES := NetUtilsWrapper-1.0.cpp main.cpp
+LOCAL_MODULE_SYMLINKS := \
+ iptables-wrapper-1.0 \
+ ip6tables-wrapper-1.0 \
+ ndc-wrapper-1.0 \
+ tc-wrapper-1.0 \
+ ip-wrapper-1.0
+
+include $(BUILD_EXECUTABLE)
+
+###
+### Wrapper unit tests.
+###
+include $(CLEAR_VARS)
+
+LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CLANG := true
+LOCAL_MODULE := netutils_wrapper_test
+LOCAL_SHARED_LIBRARIES := libbase liblog
+LOCAL_SRC_FILES := NetUtilsWrapper-1.0.cpp NetUtilsWrapperTest-1.0.cpp
+
+include $(BUILD_NATIVE_TEST)
diff --git a/netutils_wrappers/NetUtilsWrapper-1.0.cpp b/netutils_wrappers/NetUtilsWrapper-1.0.cpp
new file mode 100644
index 0000000..a9fbf3f
--- /dev/null
+++ b/netutils_wrappers/NetUtilsWrapper-1.0.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <regex>
+#include <string>
+
+#include <libgen.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <android-base/strings.h>
+
+#define LOG_TAG "NetUtilsWrapper"
+#include <cutils/log.h>
+
+#include "NetUtilsWrapper.h"
+
+#define SYSTEM_DIRNAME "/system/bin/"
+
+#define OEM_IFACE "[^ ]*oem[0-9]+"
+#define RMNET_IFACE "(r_)?rmnet_(data)?[0-9]+"
+#define VENDOR_IFACE "(" OEM_IFACE "|" RMNET_IFACE ")"
+#define VENDOR_CHAIN "(oem_.*|nm_.*|qcom_.*)"
+
+// List of net utils wrapped by this program
+// The list MUST be in descending order of string length
+const char *netcmds[] = {
+ "ip6tables",
+ "iptables",
+ "ndc",
+ "tc",
+ "ip",
+ NULL,
+};
+
+// List of regular expressions of expected commands.
+const char *EXPECTED_REGEXPS[] = {
+#define CMD "^" SYSTEM_DIRNAME
+ // Create, delete, and manage OEM networks.
+ CMD "ndc network (create|destroy) oem[0-9]+( |$)",
+ CMD "ndc network interface (add|remove) oem[0-9]+ " VENDOR_IFACE,
+ CMD "ndc network route (add|remove) oem[0-9]+ ",
+ CMD "ndc ipfwd (enable|disable) ",
+ CMD "ndc ipfwd (add|remove) .*" VENDOR_IFACE,
+
+ // Manage vendor iptables rules.
+ CMD "ip(6)?tables -w.* (-A|-D|-F|-I|-N|-X) " VENDOR_CHAIN,
+ CMD "ip(6)?tables -w.* (-i|-o) " VENDOR_IFACE,
+
+ // Manage IPsec state.
+ CMD "ip xfrm .*",
+
+ // Manage vendor interfaces.
+ CMD "tc .* dev " VENDOR_IFACE,
+ CMD "ip( -4| -6)? (addr|address) (add|del|delete|flush).* dev " VENDOR_IFACE,
+
+ // Other activities observed on current devices. In future releases, these should be supported
+ // in a way that is less likely to interfere with general Android networking behaviour.
+ CMD "tc qdisc del dev root",
+ CMD "ip( -4| -6)? rule .* goto 13000 prio 11999",
+ CMD "ip( -4| -6)? rule .* prio 25000",
+ CMD "ip(6)?tables -w .* -j " VENDOR_CHAIN,
+ CMD "iptables -w -t mangle -[AD] PREROUTING -m socket --nowildcard --restore-skmark -j ACCEPT",
+ CMD "ndc network interface (add|remove) oem[0-9]+$", // Invalid command: no interface removed.
+#undef CMD
+};
+
+// Joins argv[0..argc) with single spaces and returns true if the resulting command line is
+// matched (std::regex_search, POSIX extended syntax) by at least one entry of EXPECTED_REGEXPS.
+// Only the first rejected command in the lifetime of the process is reported (to the log and to
+// stderr); later rejections return false silently.
+bool checkExpectedCommand(int argc, char **argv) {
+    static bool loggedError = false;
+
+    const std::vector<const char*> args(argv, argv + argc);
+    const std::string fullCmd = android::base::Join(args, ' ');
+
+    for (const char* pattern : EXPECTED_REGEXPS) {
+        const std::regex re(pattern, std::regex_constants::extended);
+        if (std::regex_search(fullCmd, re)) {
+            return true;
+        }
+    }
+
+    if (loggedError) {
+        return false;
+    }
+    ALOGI("Unexpected command: %s", fullCmd.c_str());
+    fprintf(stderr, LOG_TAG ": Unexpected command: %s\n", fullCmd.c_str());
+    loggedError = true;
+    return false;
+}
+
+
+// This is the only gateway for vendor programs to reach net utils.
+// This is the only gateway for vendor programs to reach net utils.
+//
+// The wrapped utility is chosen from the name this binary was invoked as (argv[0]). It is the
+// first entry of netcmds that is a prefix of argv[0]'s basename, e.g. "ip6tables-wrapper-1.0"
+// selects "ip6tables". That utility is exec'd from SYSTEM_DIRNAME only if the full command line
+// passes checkExpectedCommand(). On any failure (no program name, unknown utility, rejected
+// command, allocation or exec failure) the process exits with EXIT_FAILURE. This function does
+// not return.
+int doMain(int argc, char **argv) {
+    // Without a program name there is nothing to dispatch on.
+    if (argc < 1 || argv == NULL || argv[0] == NULL) {
+        exit(EXIT_FAILURE);
+    }
+
+    char *progname = argv[0];
+    char *basename = strrchr(progname, '/');
+    basename = basename ? basename + 1 : progname;
+
+    for (int i = 0; netcmds[i]; ++i) {
+        size_t len = strlen(netcmds[i]);
+        if (strncmp(basename, netcmds[i], len) != 0) {
+            continue;
+        }
+
+        // truncate to match netcmds[i]
+        basename[len] = '\0';
+
+        // hardcode the path to /system so it cannot be overwritten
+        char *cmd;
+        if (asprintf(&cmd, "%s%s", SYSTEM_DIRNAME, basename) == -1) {
+            perror("asprintf");
+            exit(EXIT_FAILURE);
+        }
+        argv[0] = cmd;
+        if (checkExpectedCommand(argc, argv)) {
+            execv(cmd, argv);
+            // execv() only returns on failure.
+            perror("execv");
+        }
+
+        // netcmds is ordered longest-first, so the first prefix match is the intended utility.
+        // Do not keep scanning: basename has now been truncated, and shorter entries such as
+        // "ip" would match it again, re-running the same arguments against a different binary.
+        break;
+    }
+
+    // Invalid command. Reject and fail.
+    exit(EXIT_FAILURE);
+}
diff --git a/netutils_wrappers/NetUtilsWrapper.h b/netutils_wrappers/NetUtilsWrapper.h
new file mode 100644
index 0000000..127addc
--- /dev/null
+++ b/netutils_wrappers/NetUtilsWrapper.h
@@ -0,0 +1,20 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NETUTILS_WRAPPERS_NETUTILSWRAPPER_H
+#define NETUTILS_WRAPPERS_NETUTILSWRAPPER_H
+
+// Number of elements in a statically-sized array. Only valid on true arrays, not on pointers.
+// Guarded so that an existing definition pulled in by another header does not clash.
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(x) (sizeof((x)) / (sizeof(((x)[0]))))
+#endif
+
+// Wrapper entry point: validates the command line and execs the wrapped utility.
+int doMain(int argc, char *argv[]);
+
+// Returns true if the command line argv[0..argc) matches one of the allowed patterns.
+bool checkExpectedCommand(int argc, char **argv);
+
+#endif  // NETUTILS_WRAPPERS_NETUTILSWRAPPER_H
diff --git a/netutils_wrappers/NetUtilsWrapperTest-1.0.cpp b/netutils_wrappers/NetUtilsWrapperTest-1.0.cpp
new file mode 100644
index 0000000..a32cc3b
--- /dev/null
+++ b/netutils_wrappers/NetUtilsWrapperTest-1.0.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include <android-base/strings.h>
+
+#include "NetUtilsWrapper.h"
+
+#define MAX_ARGS 128
+#define VALID true
+#define INVALID false
+
+struct Command {
+ bool valid;
+ std::string cmdString;
+};
+
+std::vector<Command> COMMANDS = {
+ {INVALID, "tc qdisc del dev root"},
+ {VALID, "/system/bin/tc qdisc del dev root"},
+ {VALID, "/system/bin/ip -6 addr add dev r_rmnet_data6 2001:db8::/64"},
+ {INVALID, "/system/bin/ip -6 addr add dev wlan2 2001:db8::/64"},
+ {VALID, "/system/bin/ip6tables -w -A INPUT -j qcom_foo"},
+ {INVALID, "/system/bin/ip6tables -w -A INPUT -j routectrl_MANGLE_INPUT"},
+ {VALID, "/system/bin/ip6tables -w -A INPUT -i rmnet_data9 -j routectrl_MANGLE_INPUT"},
+ {VALID, "/system/bin/ip6tables -w -F nm_pre_ip4"},
+ {INVALID, "/system/bin/ip6tables -w -F INPUT"},
+ {VALID, "/system/bin/ndc network interface add oem10"},
+ {VALID, "/system/bin/ndc network interface add oem10 v_oem9"},
+ {VALID, "/system/bin/ndc network interface add oem10 oem9"},
+ {INVALID, "/system/bin/ndc network interface add 100 v_oem9"},
+ {VALID, "/system/bin/ndc network interface add oem10 r_rmnet_data0"},
+ {VALID, "/system/bin/ip xfrm state"},
+};
+
+TEST(NetUtilsWrapperTest10, TestCommands) {
+ // Overwritten by each test case.
+ char *argv[MAX_ARGS];
+
+ for (const Command& cmd : COMMANDS) {
+ std::vector<std::string> pieces = android::base::Split(cmd.cmdString, " ");
+ ASSERT_LE(pieces.size(), ARRAY_SIZE(argv));
+ for (size_t i = 0; i < pieces.size(); i++) {
+ argv[i] = const_cast<char*>(pieces[i].c_str());
+ }
+ EXPECT_EQ(cmd.valid, checkExpectedCommand(pieces.size(), argv)) <<
+ "Expected command to be " <<
+ (cmd.valid ? "valid" : "invalid") << ", but was " <<
+ (cmd.valid ? "invalid" : "valid") << ": '" << cmd.cmdString << "'";
+ }
+}
diff --git a/netutils_wrappers/main.cpp b/netutils_wrappers/main.cpp
new file mode 100644
index 0000000..e2072a3
--- /dev/null
+++ b/netutils_wrappers/main.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "NetUtilsWrapper.h"
+
+int main(int argc, char *argv[]) {
+    // All of the logic lives in doMain() so that the unit-test binary can link it without main().
+    const int status = doMain(argc, argv);
+    return status;
+}
diff --git a/server/Android.mk b/server/Android.mk
index c98ef9b..9745419 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -180,7 +180,7 @@
IptablesRestoreController.cpp IptablesRestoreControllerTest.cpp \
BandwidthController.cpp BandwidthControllerTest.cpp \
FirewallControllerTest.cpp FirewallController.cpp \
- IdletimerController.cpp \
+ IdletimerController.cpp IdletimerControllerTest.cpp \
NatControllerTest.cpp NatController.cpp \
NetlinkCommands.cpp NetlinkManager.cpp \
RouteController.cpp RouteControllerTest.cpp \
diff --git a/server/BandwidthController.cpp b/server/BandwidthController.cpp
index d875f00..903390b 100644
--- a/server/BandwidthController.cpp
+++ b/server/BandwidthController.cpp
@@ -64,8 +64,6 @@
const char BandwidthController::LOCAL_RAW_PREROUTING[] = "bw_raw_PREROUTING";
const char BandwidthController::LOCAL_MANGLE_POSTROUTING[] = "bw_mangle_POSTROUTING";
-auto BandwidthController::execFunction = android_fork_execvp;
-auto BandwidthController::popenFunction = popen;
auto BandwidthController::iptablesRestoreFunction = execIptablesRestoreWithOutput;
using android::base::Join;
diff --git a/server/BandwidthControllerTest.cpp b/server/BandwidthControllerTest.cpp
index 066f9eb..a0a57da 100644
--- a/server/BandwidthControllerTest.cpp
+++ b/server/BandwidthControllerTest.cpp
@@ -51,8 +51,6 @@
class BandwidthControllerTest : public IptablesBaseTest {
protected:
BandwidthControllerTest() {
- BandwidthController::execFunction = fake_android_fork_exec;
- BandwidthController::popenFunction = fake_popen;
BandwidthController::iptablesRestoreFunction = fakeExecIptablesRestoreWithOutput;
}
BandwidthController mBw;
@@ -378,7 +376,6 @@
std::string expectedError = counters;
EXPECT_EQ(expectedError, err);
- // popen() failing is always an error.
addIptablesRestoreOutput(kIPv4TetherCounters);
ASSERT_EQ(-1, mBw.getTetherStats(&cli, filter, err));
expectNoSocketClientResponse(socketPair[1]);
diff --git a/server/Controllers.cpp b/server/Controllers.cpp
index a25e05a..2b4079c 100644
--- a/server/Controllers.cpp
+++ b/server/Controllers.cpp
@@ -14,6 +14,11 @@
* limitations under the License.
*/
+#include <regex>
+#include <set>
+#include <sstream>
+#include <string>
+
+#include <android-base/strings.h>
#include <android-base/stringprintf.h>
#define LOG_TAG "Netd"
@@ -29,101 +34,153 @@
namespace android {
namespace net {
+using android::base::Join;
+using android::base::StringPrintf;
+using android::base::StringAppendF;
+
auto Controllers::execIptablesRestore = ::execIptablesRestore;
-auto Controllers::execIptablesSilently = ::execIptablesSilently;
+auto Controllers::execIptablesRestoreWithOutput = ::execIptablesRestoreWithOutput;
namespace {
+
/**
* List of module chains to be created, along with explicit ordering. ORDERING
* IS CRITICAL, AND SHOULD BE TRIPLE-CHECKED WITH EACH CHANGE.
*/
-static const char* FILTER_INPUT[] = {
+static const std::vector<const char*> FILTER_INPUT = {
// Bandwidth should always be early in input chain, to make sure we
// correctly count incoming traffic against data plan.
BandwidthController::LOCAL_INPUT,
FirewallController::LOCAL_INPUT,
- NULL,
};
-static const char* FILTER_FORWARD[] = {
+static const std::vector<const char*> FILTER_FORWARD = {
OEM_IPTABLES_FILTER_FORWARD,
FirewallController::LOCAL_FORWARD,
BandwidthController::LOCAL_FORWARD,
NatController::LOCAL_FORWARD,
- NULL,
};
-static const char* FILTER_OUTPUT[] = {
+static const std::vector<const char*> FILTER_OUTPUT = {
OEM_IPTABLES_FILTER_OUTPUT,
FirewallController::LOCAL_OUTPUT,
StrictController::LOCAL_OUTPUT,
BandwidthController::LOCAL_OUTPUT,
- NULL,
};
-static const char* RAW_PREROUTING[] = {
+static const std::vector<const char*> RAW_PREROUTING = {
BandwidthController::LOCAL_RAW_PREROUTING,
IdletimerController::LOCAL_RAW_PREROUTING,
NatController::LOCAL_RAW_PREROUTING,
- NULL,
};
-static const char* MANGLE_POSTROUTING[] = {
+static const std::vector<const char*> MANGLE_POSTROUTING = {
OEM_IPTABLES_MANGLE_POSTROUTING,
BandwidthController::LOCAL_MANGLE_POSTROUTING,
IdletimerController::LOCAL_MANGLE_POSTROUTING,
- NULL,
};
-static const char* MANGLE_INPUT[] = {
+static const std::vector<const char*> MANGLE_INPUT = {
WakeupController::LOCAL_MANGLE_INPUT,
RouteController::LOCAL_MANGLE_INPUT,
- NULL,
};
-static const char* MANGLE_FORWARD[] = {
+static const std::vector<const char*> MANGLE_FORWARD = {
NatController::LOCAL_MANGLE_FORWARD,
- NULL,
};
-static const char* NAT_PREROUTING[] = {
+static const std::vector<const char*> NAT_PREROUTING = {
OEM_IPTABLES_NAT_PREROUTING,
- NULL,
};
-static const char* NAT_POSTROUTING[] = {
+static const std::vector<const char*> NAT_POSTROUTING = {
NatController::LOCAL_NAT_POSTROUTING,
- NULL,
};
+// Commands to create child chains and to match created chains in iptables -S output. Keep in sync.
+static const char* CHILD_CHAIN_TEMPLATE = "-A %s -j %s\n";
+static const std::regex CHILD_CHAIN_REGEX("^-A ([^ ]+) -j ([^ ]+)$",
+ std::regex_constants::extended);
+
} // namespace
/* static */
+// Returns the names of the child chains that parentChain (in table) already jumps to through
+// rules of exactly the form createChildChains() adds ("-A <parent> -j <child>"). Other rules in
+// the chain are ignored. If listing the chain fails, the error is logged and an empty set is
+// returned. target must be V4 or V6; passing V4V6 aborts.
+std::set<std::string> Controllers::findExistingChildChains(const IptablesTarget target,
+                                                           const char* table,
+                                                           const char* parentChain) {
+    if (target == V4V6) {
+        ALOGE("findExistingChildChains only supports one protocol at a time");
+        abort();
+    }
+
+    std::set<std::string> existing;
+
+    // List the current contents of parentChain.
+    //
+    // TODO: there is no guarantee that nothing else modifies the chain in the few milliseconds
+    // between when we list the existing rules and when we delete them. However:
+    // - Since this code is only run on startup, nothing else in netd will be running.
+    // - While vendor code is known to add its own rules to chains created by netd, it should never
+    //   be modifying the rules in childChains or the rules that hook said chains into their parent
+    //   chains.
+    std::string command = StringPrintf("*%s\n-S %s\nCOMMIT\n", table, parentChain);
+    std::string output;
+    if (Controllers::execIptablesRestoreWithOutput(target, command, &output) == -1) {
+        ALOGE("Error listing chain %s in table %s\n", parentChain, table);
+        return existing;
+    }
+
+    // The only rules added by createChildChains are of the simple form "-A <parent> -j <child>".
+    // Find those rules and add each one's child chain to existing. The listing is only read, so an
+    // input-only stream is sufficient.
+    std::istringstream stream(output);
+    std::smatch matches;
+    std::string rule;
+    while (std::getline(stream, rule, '\n')) {
+        if (std::regex_search(rule, matches, CHILD_CHAIN_REGEX) && matches[1] == parentChain) {
+            existing.insert(matches[2]);
+        }
+    }
+
+    return existing;
+}
+
+/* static */
void Controllers::createChildChains(IptablesTarget target, const char* table,
const char* parentChain,
- const char** childChains,
+ const std::vector<const char*>& childChains,
bool exclusive) {
- std::string command = android::base::StringPrintf("*%s\n", table);
+ std::string command = StringPrintf("*%s\n", table);
- // If we're the exclusive owner of this chain, clear it entirely. This saves us from having to
- // run one execIptablesSilently command to delete each child chain. We can't use -D in
- // iptables-restore because it's a fatal error if the rule doesn't exist.
+ // We cannot just clear all the chains we create because vendor code modifies filter OUTPUT and
+ // mangle POSTROUTING directly. So:
+ //
+ // - If we're the exclusive owner of this chain, simply clear it entirely.
+ // - If not, then list the chain's current contents to ensure that if we restart after a crash,
+ // we leave the existing rules alone in the positions they currently occupy. This is faster
+ // than blindly deleting our rules and recreating them, because deleting a rule that doesn't
+ // exist causes iptables-restore to quit, which takes ~30ms per delete. It's also more
+ // correct, because if we delete rules and re-add them, they'll be in the wrong position with
+ // regard to the vendor rules.
+ //
// TODO: Make all chains exclusive once vendor code uses the oem_* rules.
+ std::set<std::string> existingChildChains;
if (exclusive) {
// Just running ":chain -" flushes user-defined chains, but not built-in chains like INPUT.
// Since at this point we don't know if parentChain is a built-in chain, do both.
- command += android::base::StringPrintf(":%s -\n", parentChain);
- command += android::base::StringPrintf("-F %s\n", parentChain);
+ StringAppendF(&command, ":%s -\n", parentChain);
+ StringAppendF(&command, "-F %s\n", parentChain);
+ } else {
+ existingChildChains = findExistingChildChains(target, table, parentChain);
}
- const char** childChain = childChains;
- do {
- if (!exclusive) {
- execIptablesSilently(target, "-t", table, "-D", parentChain, "-j", *childChain, NULL);
+ for (const auto& childChain : childChains) {
+ // Always clear the child chain.
+ StringAppendF(&command, ":%s -\n", childChain);
+ // But only add it to the parent chain if it's not already there.
+ if (existingChildChains.find(childChain) == existingChildChains.end()) {
+ StringAppendF(&command, CHILD_CHAIN_TEMPLATE, parentChain, childChain);
}
- command += android::base::StringPrintf(":%s -\n", *childChain);
- command += android::base::StringPrintf("-A %s -j %s\n", parentChain, *childChain);
- } while (*(++childChain) != NULL);
+ }
command += "COMMIT\n";
execIptablesRestore(target, command);
}
@@ -163,10 +220,10 @@
createChildChains(V4, "nat", "PREROUTING", NAT_PREROUTING, true);
createChildChains(V4, "nat", "POSTROUTING", NAT_POSTROUTING, true);
- // We cannot use createChildChainsFast for all chains because vendor code modifies filter OUTPUT
- // and mangle POSTROUTING directly.
- createChildChains(V4V6, "filter", "OUTPUT", FILTER_OUTPUT, false);
- createChildChains(V4V6, "mangle", "POSTROUTING", MANGLE_POSTROUTING, false);
+ createChildChains(V4, "filter", "OUTPUT", FILTER_OUTPUT, false);
+ createChildChains(V6, "filter", "OUTPUT", FILTER_OUTPUT, false);
+ createChildChains(V4, "mangle", "POSTROUTING", MANGLE_POSTROUTING, false);
+ createChildChains(V6, "mangle", "POSTROUTING", MANGLE_POSTROUTING, false);
}
void Controllers::initIptablesRules() {
diff --git a/server/Controllers.h b/server/Controllers.h
index 0754932..53854cf 100644
--- a/server/Controllers.h
+++ b/server/Controllers.h
@@ -63,10 +63,13 @@
friend class ControllersTest;
void initIptablesRules();
static void initChildChains();
+ static std::set<std::string> findExistingChildChains(const IptablesTarget target,
+ const char* table,
+ const char* parentChain);
static void createChildChains(IptablesTarget target, const char* table, const char* parentChain,
- const char** childChains, bool exclusive);
- static int (*execIptablesSilently)(IptablesTarget target, ...);
+ const std::vector<const char*>& childChains, bool exclusive);
static int (*execIptablesRestore)(IptablesTarget, const std::string&);
+ static int (*execIptablesRestoreWithOutput)(IptablesTarget, const std::string&, std::string *);
};
extern Controllers* gCtls;
diff --git a/server/ControllersTest.cpp b/server/ControllersTest.cpp
index 6f41798..3ca5d81 100644
--- a/server/ControllersTest.cpp
+++ b/server/ControllersTest.cpp
@@ -16,30 +16,60 @@
* ControllersTest.cpp - unit tests for Controllers.cpp
*/
+#include <set>
#include <string>
#include <vector>
+#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include <android-base/strings.h>
+
#include "Controllers.h"
#include "IptablesBaseTest.h"
+using testing::ContainerEq;
+
namespace android {
namespace net {
class ControllersTest : public IptablesBaseTest {
public:
ControllersTest() {
- Controllers::execIptablesSilently = fakeExecIptables;
Controllers::execIptablesRestore = fakeExecIptablesRestore;
+ Controllers::execIptablesRestoreWithOutput = fakeExecIptablesRestoreWithOutput;
}
protected:
void initChildChains() { Controllers::initChildChains(); };
+ std::set<std::string> findExistingChildChains(IptablesTarget a, const char* b, const char*c) {
+ return Controllers::findExistingChildChains(a, b, c);
+ }
};
+TEST_F(ControllersTest, TestFindExistingChildChains) {
+ ExpectedIptablesCommands expectedCmds = {
+ { V6, "*raw\n-S PREROUTING\nCOMMIT\n" },
+ };
+ sIptablesRestoreOutput.push_back(
+ "-P PREROUTING ACCEPT\n"
+ "-A PREROUTING -j bw_raw_PREROUTING\n"
+ "-A PREROUTING -j idletimer_raw_PREROUTING\n"
+ "-A PREROUTING -j natctrl_raw_PREROUTING\n"
+ );
+ std::set<std::string> expectedChains = {
+ "bw_raw_PREROUTING",
+ "idletimer_raw_PREROUTING",
+ "natctrl_raw_PREROUTING",
+ };
+ std::set<std::string> actual = findExistingChildChains(V6, "raw", "PREROUTING");
+ EXPECT_THAT(expectedChains, ContainerEq(actual));
+ expectIptablesRestoreCommands(expectedCmds);
+}
+
TEST_F(ControllersTest, TestInitIptablesRules) {
- ExpectedIptablesCommands expectedRestoreCommands = {
+ // Test what happens when we boot and there are no rules.
+ ExpectedIptablesCommands expected = {
{ V4V6, "*filter\n"
":INPUT -\n"
"-F INPUT\n"
@@ -103,47 +133,134 @@
"-A POSTROUTING -j natctrl_nat_POSTROUTING\n"
"COMMIT\n"
},
- { V4V6, "*filter\n"
- ":oem_out -\n"
- "-A OUTPUT -j oem_out\n"
- ":fw_OUTPUT -\n"
- "-A OUTPUT -j fw_OUTPUT\n"
- ":st_OUTPUT -\n"
- "-A OUTPUT -j st_OUTPUT\n"
- ":bw_OUTPUT -\n"
- "-A OUTPUT -j bw_OUTPUT\n"
- "COMMIT\n"
+ { V4, "*filter\n"
+ "-S OUTPUT\n"
+ "COMMIT\n" },
+ { V4, "*filter\n"
+ ":oem_out -\n"
+ "-A OUTPUT -j oem_out\n"
+ ":fw_OUTPUT -\n"
+ "-A OUTPUT -j fw_OUTPUT\n"
+ ":st_OUTPUT -\n"
+ "-A OUTPUT -j st_OUTPUT\n"
+ ":bw_OUTPUT -\n"
+ "-A OUTPUT -j bw_OUTPUT\n"
+ "COMMIT\n"
},
- { V4V6, "*mangle\n"
- ":oem_mangle_post -\n"
- "-A POSTROUTING -j oem_mangle_post\n"
- ":bw_mangle_POSTROUTING -\n"
- "-A POSTROUTING -j bw_mangle_POSTROUTING\n"
- ":idletimer_mangle_POSTROUTING -\n"
- "-A POSTROUTING -j idletimer_mangle_POSTROUTING\n"
- "COMMIT\n"
+ { V6, "*filter\n"
+ "-S OUTPUT\n"
+ "COMMIT\n" },
+ { V6, "*filter\n"
+ ":oem_out -\n"
+ "-A OUTPUT -j oem_out\n"
+ ":fw_OUTPUT -\n"
+ "-A OUTPUT -j fw_OUTPUT\n"
+ ":st_OUTPUT -\n"
+ "-A OUTPUT -j st_OUTPUT\n"
+ ":bw_OUTPUT -\n"
+ "-A OUTPUT -j bw_OUTPUT\n"
+ "COMMIT\n"
+ },
+ { V4, "*mangle\n"
+ "-S POSTROUTING\n"
+ "COMMIT\n" },
+ { V4, "*mangle\n"
+ ":oem_mangle_post -\n"
+ "-A POSTROUTING -j oem_mangle_post\n"
+ ":bw_mangle_POSTROUTING -\n"
+ "-A POSTROUTING -j bw_mangle_POSTROUTING\n"
+ ":idletimer_mangle_POSTROUTING -\n"
+ "-A POSTROUTING -j idletimer_mangle_POSTROUTING\n"
+ "COMMIT\n"
+ },
+ { V6, "*mangle\n"
+ "-S POSTROUTING\n"
+ "COMMIT\n" },
+ { V6, "*mangle\n"
+ ":oem_mangle_post -\n"
+ "-A POSTROUTING -j oem_mangle_post\n"
+ ":bw_mangle_POSTROUTING -\n"
+ "-A POSTROUTING -j bw_mangle_POSTROUTING\n"
+ ":idletimer_mangle_POSTROUTING -\n"
+ "-A POSTROUTING -j idletimer_mangle_POSTROUTING\n"
+ "COMMIT\n"
},
};
+
+ // Check that we run these commands and these only.
initChildChains();
- expectIptablesRestoreCommands(expectedRestoreCommands);
+ expectIptablesRestoreCommands(expected);
+ expectIptablesRestoreCommands(ExpectedIptablesCommands{});
- std::vector<std::string> expectedIptablesCommands = {
- "-t filter -D OUTPUT -j oem_out",
- "-t filter -D OUTPUT -j fw_OUTPUT",
- "-t filter -D OUTPUT -j st_OUTPUT",
- "-t filter -D OUTPUT -j bw_OUTPUT",
- "-t mangle -D POSTROUTING -j oem_mangle_post",
- "-t mangle -D POSTROUTING -j bw_mangle_POSTROUTING",
- "-t mangle -D POSTROUTING -j idletimer_mangle_POSTROUTING",
- };
- expectIptablesCommands(expectedIptablesCommands);
+ // Now test what happens when some rules exist (e.g., if we crash and restart).
- // ... and nothing more.
- expectedRestoreCommands = {};
- expectIptablesRestoreCommands(expectedRestoreCommands);
+ // First, explicitly tell the iptables test code to return empty output to all the commands we
+ // send. This allows us to tell it to return non-empty output to particular commands in the
+ // following code.
+ for (size_t i = 0; i < expected.size(); i++) {
+ sIptablesRestoreOutput.push_back("");
+ }
- expectedIptablesCommands = {};
- expectIptablesCommands(expectedIptablesCommands);
+ // Define a macro to remove a substring from a string. We use a macro instead of a function so
+ // we can assert in it. In the following code, we use ASSERT_* to check for programming errors
+ // in the test code, and EXPECT_* to check for errors in the actual code.
+#define DELETE_SUBSTRING(substr, str) { \
+ size_t start = (str).find((substr)); \
+ ASSERT_NE(std::string::npos, start); \
+ (str).erase(start, strlen((substr))); \
+ ASSERT_EQ(std::string::npos, (str).find((substr))); \
+ }
+
+ // Now set test expectations.
+
+ // 1. Test that if we find rules that we don't create ourselves, we ignore them.
+ // First check that command #7 is where we list the OUTPUT chain in the (IPv4) filter table:
+ ASSERT_NE(std::string::npos, expected[7].second.find("*filter\n-S OUTPUT\n"));
+ // ... and pretend that when we run that command, we find the following rules. Because we don't
+ // create any of these rules ourselves, our behaviour is unchanged.
+ sIptablesRestoreOutput[7] =
+ "-P OUTPUT ACCEPT\n"
+ "-A OUTPUT -o r_rmnet_data8 -p udp -m udp --dport 1900 -j DROP\n";
+
+ // 2. Test that rules that we create ourselves are not added if they already exist.
+ // Pretend that when we list the OUTPUT chain in the (IPv6) filter table, we find the oem_out
+ // and st_OUTPUT chains:
+ ASSERT_NE(std::string::npos, expected[9].second.find("*filter\n-S OUTPUT\n"));
+ sIptablesRestoreOutput[9] =
+ "-A OUTPUT -j oem_out\n"
+ "-A OUTPUT -j st_OUTPUT\n";
+ // ... and expect that when we populate the OUTPUT chain, we do not re-add them.
+ DELETE_SUBSTRING("-A OUTPUT -j oem_out\n", expected[10].second);
+ DELETE_SUBSTRING("-A OUTPUT -j st_OUTPUT\n", expected[10].second);
+
+ // 3. Now test that when we list the POSTROUTING chain in the mangle table, we find a mixture of
+ // netd-created rules and vendor rules:
+ ASSERT_NE(std::string::npos, expected[13].second.find("*mangle\n-S POSTROUTING\n"));
+ sIptablesRestoreOutput[13] =
+ "-P POSTROUTING ACCEPT\n"
+ "-A POSTROUTING -j oem_mangle_post\n"
+ "-A POSTROUTING -j bw_mangle_POSTROUTING\n"
+ "-A POSTROUTING -j idletimer_mangle_POSTROUTING\n"
+ "-A POSTROUTING -j qcom_qos_reset_POSTROUTING\n"
+ "-A POSTROUTING -j qcom_qos_filter_POSTROUTING\n";
+ // and expect that we don't re-add the netd-created rules that already exist.
+ DELETE_SUBSTRING("-A POSTROUTING -j oem_mangle_post\n", expected[14].second);
+ DELETE_SUBSTRING("-A POSTROUTING -j bw_mangle_POSTROUTING\n", expected[14].second);
+ DELETE_SUBSTRING("-A POSTROUTING -j idletimer_mangle_POSTROUTING\n", expected[14].second);
+
+ // In this last case, also check that our expectations are reasonable.
+ std::string expectedCmd14 =
+ "*mangle\n"
+ ":oem_mangle_post -\n"
+ ":bw_mangle_POSTROUTING -\n"
+ ":idletimer_mangle_POSTROUTING -\n"
+ "COMMIT\n";
+ ASSERT_EQ(expectedCmd14, expected[14].second);
+
+ // Finally, actually test that initChildChains runs the expected commands, and nothing more.
+ initChildChains();
+ expectIptablesRestoreCommands(expected);
+ expectIptablesRestoreCommands(ExpectedIptablesCommands{});
}
} // namespace net
diff --git a/server/IdletimerController.cpp b/server/IdletimerController.cpp
index e6306fd..da19453 100644
--- a/server/IdletimerController.cpp
+++ b/server/IdletimerController.cpp
@@ -95,6 +95,10 @@
#define LOG_NDEBUG 0
+#include <string>
+#include <vector>
+
+#include <stdint.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/socket.h>
@@ -106,6 +110,9 @@
#include <string.h>
#include <cutils/properties.h>
+#include <android-base/strings.h>
+#include <android-base/stringprintf.h>
+
#define LOG_TAG "IdletimerController"
#include <cutils/log.h>
#include <logwrap/logwrap.h>
@@ -113,69 +120,35 @@
#include "IdletimerController.h"
#include "NetdConstants.h"
+using android::base::Join;
+using android::base::StringPrintf;
+
const char* IdletimerController::LOCAL_RAW_PREROUTING = "idletimer_raw_PREROUTING";
const char* IdletimerController::LOCAL_MANGLE_POSTROUTING = "idletimer_mangle_POSTROUTING";
+auto IdletimerController::execIptablesRestore = ::execIptablesRestore;
+
IdletimerController::IdletimerController() {
}
IdletimerController::~IdletimerController() {
}
-/* return 0 or non-zero */
-int IdletimerController::runIpxtablesCmd(int argc, const char **argv) {
- int resIpv4, resIpv6;
-
- // Running for IPv4
- argv[0] = IPTABLES_PATH;
- resIpv4 = android_fork_execvp(argc, (char **)argv, NULL, false, false);
-
- // Running for IPv6
- argv[0] = IP6TABLES_PATH;
- resIpv6 = android_fork_execvp(argc, (char **)argv, NULL, false, false);
-
-#if !LOG_NDEBUG
- std::string full_cmd = argv[0];
- argc--; argv++;
- for (; argc; argc--, argv++) {
- full_cmd += " ";
- full_cmd += argv[0];
- }
- ALOGV("runCmd(%s) res_ipv4=%d, res_ipv6=%d", full_cmd.c_str(), resIpv4, resIpv6);
-#endif
-
- return (resIpv4 == 0 && resIpv6 == 0) ? 0 : -1;
-}
bool IdletimerController::setupIptablesHooks() {
return true;
}
int IdletimerController::setDefaults() {
- int res;
- const char *cmd1[] = {
- NULL, // To be filled inside runIpxtablesCmd
- "-w",
- "-t",
- "raw",
- "-F",
- LOCAL_RAW_PREROUTING
- };
- res = runIpxtablesCmd(ARRAY_SIZE(cmd1), cmd1);
+ std::vector<std::string> cmds = {
+ "*raw",
+ StringPrintf(":%s -", LOCAL_RAW_PREROUTING),
+ "COMMIT",
+ "*mangle",
+ StringPrintf(":%s -", LOCAL_MANGLE_POSTROUTING),
+ "COMMIT\n",
+ };
- if (res)
- return res;
-
- const char *cmd2[] = {
- NULL, // To be filled inside runIpxtablesCmd
- "-w",
- "-t",
- "mangle",
- "-F",
- LOCAL_MANGLE_POSTROUTING
- };
- res = runIpxtablesCmd(ARRAY_SIZE(cmd2), cmd2);
-
- return res;
+ return execIptablesRestore(V4V6, Join(cmds, '\n'));
}
int IdletimerController::enableIdletimerControl() {
@@ -191,70 +164,34 @@
int IdletimerController::modifyInterfaceIdletimer(IptOp op, const char *iface,
uint32_t timeout,
const char *classLabel) {
- int res;
- char timeout_str[11]; //enough to store any 32-bit unsigned decimal
+ if (!isIfaceName(iface)) {
+ errno = ENOENT;
+ return -1;
+ }
- if (!isIfaceName(iface)) {
- errno = ENOENT;
- return -1;
- }
+ const char *addRemove = (op == IptOpAdd) ? "-A" : "-D";
+ std::vector<std::string> cmds = {
+ "*raw",
+ StringPrintf("%s %s -i %s -j IDLETIMER --timeout %u --label %s --send_nl_msg 1",
+ addRemove, LOCAL_RAW_PREROUTING, iface, timeout, classLabel),
+ "COMMIT",
+ "*mangle",
+ StringPrintf("%s %s -o %s -j IDLETIMER --timeout %u --label %s --send_nl_msg 1",
+ addRemove, LOCAL_MANGLE_POSTROUTING, iface, timeout, classLabel),
+ "COMMIT\n",
+ };
- snprintf(timeout_str, sizeof(timeout_str), "%u", timeout);
-
- const char *cmd1[] = {
- NULL, // To be filled inside runIpxtablesCmd
- "-w",
- "-t",
- "raw",
- (op == IptOpAdd) ? "-A" : "-D",
- LOCAL_RAW_PREROUTING,
- "-i",
- iface,
- "-j",
- "IDLETIMER",
- "--timeout",
- timeout_str,
- "--label",
- classLabel,
- "--send_nl_msg",
- "1"
- };
- res = runIpxtablesCmd(ARRAY_SIZE(cmd1), cmd1);
-
- if (res)
- return res;
-
- const char *cmd2[] = {
- NULL, // To be filled inside runIpxtablesCmd
- "-w",
- "-t",
- "mangle",
- (op == IptOpAdd) ? "-A" : "-D",
- LOCAL_MANGLE_POSTROUTING,
- "-o",
- iface,
- "-j",
- "IDLETIMER",
- "--timeout",
- timeout_str,
- "--label",
- classLabel,
- "--send_nl_msg",
- "1"
- };
- res = runIpxtablesCmd(ARRAY_SIZE(cmd2), cmd2);
-
- return res;
+ return execIptablesRestore(V4V6, Join(cmds, '\n'));
}
int IdletimerController::addInterfaceIdletimer(const char *iface,
uint32_t timeout,
const char *classLabel) {
- return modifyInterfaceIdletimer(IptOpAdd, iface, timeout, classLabel);
+ return modifyInterfaceIdletimer(IptOpAdd, iface, timeout, classLabel);
}
int IdletimerController::removeInterfaceIdletimer(const char *iface,
uint32_t timeout,
const char *classLabel) {
- return modifyInterfaceIdletimer(IptOpDelete, iface, timeout, classLabel);
+ return modifyInterfaceIdletimer(IptOpDelete, iface, timeout, classLabel);
}
diff --git a/server/IdletimerController.h b/server/IdletimerController.h
index 98a312e..87e0b4e 100644
--- a/server/IdletimerController.h
+++ b/server/IdletimerController.h
@@ -16,6 +16,10 @@
#ifndef _IDLETIMER_CONTROLLER_H
#define _IDLETIMER_CONTROLLER_H
+#include <stdint.h>
+
+#include "NetdConstants.h"
+
class IdletimerController {
public:
@@ -39,6 +43,9 @@
int runIpxtablesCmd(int argc, const char **cmd);
int modifyInterfaceIdletimer(IptOp op, const char *iface, uint32_t timeout,
const char *classLabel);
+
+ friend class IdletimerControllerTest;
+ static int (*execIptablesRestore)(IptablesTarget, const std::string&);
};
#endif
diff --git a/server/IdletimerControllerTest.cpp b/server/IdletimerControllerTest.cpp
new file mode 100644
index 0000000..ace3fd9
--- /dev/null
+++ b/server/IdletimerControllerTest.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * IdletimerControllerTest.cpp - unit tests for IdletimerController.cpp
+ */
+
+#include <gtest/gtest.h>
+
+#include <android-base/strings.h>
+#include <android-base/stringprintf.h>
+
+#include "IdletimerController.h"
+#include "IptablesBaseTest.h"
+
+using android::base::Join;
+using android::base::StringPrintf;
+
+class IdletimerControllerTest : public IptablesBaseTest {
+protected:
+ IdletimerControllerTest() {
+ IdletimerController::execIptablesRestore = fakeExecIptablesRestore;
+ }
+ IdletimerController mIt;
+};
+
+TEST_F(IdletimerControllerTest, TestSetupIptablesHooks) {
+ mIt.setupIptablesHooks();
+ expectIptablesRestoreCommands(ExpectedIptablesCommands{});
+}
+
+TEST_F(IdletimerControllerTest, TestEnableDisable) {
+ std::vector<std::string> expected = {
+ "*raw\n"
+ ":idletimer_raw_PREROUTING -\n"
+ "COMMIT\n"
+ "*mangle\n"
+ ":idletimer_mangle_POSTROUTING -\n"
+ "COMMIT\n",
+ };
+
+ mIt.enableIdletimerControl();
+ expectIptablesRestoreCommands(expected);
+
+ mIt.enableIdletimerControl();
+ expectIptablesRestoreCommands(expected);
+
+ mIt.disableIdletimerControl();
+ expectIptablesRestoreCommands(expected);
+
+ mIt.disableIdletimerControl();
+ expectIptablesRestoreCommands(expected);
+}
+
+const std::vector<std::string> makeAddRemoveCommands(bool add) {
+ const char *op = add ? "-A" : "-D";
+ std::vector<std::string> cmds = {
+ "*raw",
+ StringPrintf("%s idletimer_raw_PREROUTING -i wlan0 -j IDLETIMER"
+ " --timeout 12345 --label hello --send_nl_msg 1", op),
+ "COMMIT",
+ "*mangle",
+ StringPrintf("%s idletimer_mangle_POSTROUTING -o wlan0 -j IDLETIMER"
+ " --timeout 12345 --label hello --send_nl_msg 1", op),
+ "COMMIT\n",
+ };
+ return { Join(cmds, '\n') };
+}
+
+TEST_F(IdletimerControllerTest, TestAddRemove) {
+ auto expected = makeAddRemoveCommands(true);
+ mIt.addInterfaceIdletimer("wlan0", 12345, "hello");
+ expectIptablesRestoreCommands(expected);
+
+ mIt.addInterfaceIdletimer("wlan0", 12345, "hello");
+ expectIptablesRestoreCommands(expected);
+
+ expected = makeAddRemoveCommands(false);
+ mIt.removeInterfaceIdletimer("wlan0", 12345, "hello");
+ expectIptablesRestoreCommands(expected);
+
+ mIt.removeInterfaceIdletimer("wlan0", 12345, "hello");
+ expectIptablesRestoreCommands(expected);
+}
diff --git a/server/IptablesBaseTest.cpp b/server/IptablesBaseTest.cpp
index b5fd9a0..c9bf67b 100644
--- a/server/IptablesBaseTest.cpp
+++ b/server/IptablesBaseTest.cpp
@@ -61,30 +61,6 @@
return ret;
}
-int IptablesBaseTest::fakeExecIptables(IptablesTarget target, ...) {
- std::string cmd = " -w";
- va_list args;
- va_start(args, target);
- const char *arg;
- do {
- arg = va_arg(args, const char *);
- if (arg != nullptr) {
- cmd += " ";
- cmd += arg;
- }
- } while (arg);
- va_end(args);
-
- if (target == V4 || target == V4V6) {
- sCmds.push_back(IPTABLES_PATH + cmd);
- }
- if (target == V6 || target == V4V6) {
- sCmds.push_back(IP6TABLES_PATH + cmd);
- }
-
- return 0;
-}
-
FILE *IptablesBaseTest::fake_popen(const char * /* cmd */, const char *type) {
if (sPopenContents.empty() || strcmp(type, "r") != 0) {
return NULL;
@@ -120,62 +96,9 @@
return fakeExecIptablesRestoreWithOutput(target, fullCmd, output);
}
-int IptablesBaseTest::expectIptablesCommand(IptablesTarget target, int pos,
- const std::string& cmd) {
-
- if ((unsigned) pos >= sCmds.size()) {
- ADD_FAILURE() << "Expected too many iptables commands, want command "
- << pos + 1 << "/" << sCmds.size();
- return -1;
- }
-
- if (target == V4 || target == V4V6) {
- EXPECT_EQ("/system/bin/iptables -w " + cmd, sCmds[pos++]);
- }
- if (target == V6 || target == V4V6) {
- EXPECT_EQ("/system/bin/ip6tables -w " + cmd, sCmds[pos++]);
- }
-
- return target == V4V6 ? 2 : 1;
-}
-
-void IptablesBaseTest::expectIptablesCommands(const std::vector<std::string>& expectedCmds) {
- ExpectedIptablesCommands expected;
- for (auto cmd : expectedCmds) {
- expected.push_back({ V4V6, cmd });
- }
- expectIptablesCommands(expected);
-}
-
-void IptablesBaseTest::expectIptablesCommands(const ExpectedIptablesCommands& expectedCmds) {
- size_t pos = 0;
- for (size_t i = 0; i < expectedCmds.size(); i ++) {
- auto target = expectedCmds[i].first;
- auto cmd = expectedCmds[i].second;
- int numConsumed = expectIptablesCommand(target, pos, cmd);
- if (numConsumed < 0) {
- // Read past the end of the array.
- break;
- }
- pos += numConsumed;
- }
-
- EXPECT_EQ(pos, sCmds.size());
- sCmds.clear();
-}
-
-void IptablesBaseTest::expectIptablesCommands(
- const std::vector<ExpectedIptablesCommands>& snippets) {
- ExpectedIptablesCommands expected;
- for (const auto& snippet: snippets) {
- expected.insert(expected.end(), snippet.begin(), snippet.end());
- }
- expectIptablesCommands(expected);
-}
-
void IptablesBaseTest::expectIptablesRestoreCommands(const std::vector<std::string>& expectedCmds) {
ExpectedIptablesCommands expected;
- for (auto cmd : expectedCmds) {
+ for (const auto& cmd : expectedCmds) {
expected.push_back({ V4V6, cmd });
}
expectIptablesRestoreCommands(expected);
diff --git a/server/IptablesBaseTest.h b/server/IptablesBaseTest.h
index a8a511f..207a5ee 100644
--- a/server/IptablesBaseTest.h
+++ b/server/IptablesBaseTest.h
@@ -28,16 +28,12 @@
static int fake_android_fork_exec(int argc, char* argv[], int *status, bool, bool);
static int fake_android_fork_execvp(int argc, char* argv[], int *status, bool, bool);
- static int fakeExecIptables(IptablesTarget target, ...);
static int fakeExecIptablesRestore(IptablesTarget target, const std::string& commands);
static int fakeExecIptablesRestoreWithOutput(IptablesTarget target, const std::string& commands,
std::string *output);
static int fakeExecIptablesRestoreCommand(IptablesTarget target, const std::string& table,
const std::string& commands, std::string *output);
static FILE *fake_popen(const char *cmd, const char *type);
- void expectIptablesCommands(const std::vector<std::string>& expectedCmds);
- void expectIptablesCommands(const ExpectedIptablesCommands& expectedCmds);
- void expectIptablesCommands(const std::vector<ExpectedIptablesCommands>& snippets);
void expectIptablesRestoreCommands(const std::vector<std::string>& expectedCmds);
void expectIptablesRestoreCommands(const ExpectedIptablesCommands& expectedCmds);
void setReturnValues(const std::deque<int>& returnValues);
@@ -48,5 +44,4 @@
static std::deque<int> sReturnValues;
static std::deque<std::string> sPopenContents;
static std::deque<std::string> sIptablesRestoreOutput;
- int expectIptablesCommand(IptablesTarget target, int pos, const std::string& cmd);
};
diff --git a/server/IptablesRestoreController.cpp b/server/IptablesRestoreController.cpp
index e346b82..a90224a 100644
--- a/server/IptablesRestoreController.cpp
+++ b/server/IptablesRestoreController.cpp
@@ -24,9 +24,13 @@
#define LOG_TAG "IptablesRestoreController"
#include <android-base/logging.h>
#include <android-base/file.h>
+#include <netdutils/Syscalls.h>
#include "Controllers.h"
+using android::netdutils::StatusOr;
+using android::netdutils::sSyscalls;
+
constexpr char IPTABLES_RESTORE_PATH[] = "/system/bin/iptables-restore";
constexpr char IP6TABLES_RESTORE_PATH[] = "/system/bin/ip6tables-restore";
@@ -113,6 +117,13 @@
};
IptablesRestoreController::IptablesRestoreController() {
+ Init();
+}
+
+IptablesRestoreController::~IptablesRestoreController() {
+}
+
+void IptablesRestoreController::Init() {
// Start the IPv4 and IPv6 processes in parallel, since each one takes 20-30ms.
std::thread v4([this] () { mIpRestore.reset(forkAndExec(IPTABLES_PROCESS)); });
std::thread v6([this] () { mIp6Restore.reset(forkAndExec(IP6TABLES_PROCESS)); });
@@ -120,9 +131,6 @@
v6.join();
}
-IptablesRestoreController::~IptablesRestoreController() {
-}
-
/* static */
IptablesProcess* IptablesRestoreController::forkAndExec(const IptablesProcessType type) {
const char* const cmd = (type == IPTABLES_PROCESS) ?
@@ -142,8 +150,14 @@
return nullptr;
}
- pid_t child_pid = fork();
- if (child_pid == 0) {
+ const auto& sys = sSyscalls.get();
+ StatusOr<pid_t> child_pid = sys.fork();
+ if (!isOk(child_pid)) {
+ ALOGE("fork() failed: %s", strerror(child_pid.status().code()));
+ return nullptr;
+ }
+
+ if (child_pid.value() == 0) {
// The child process. Reads from stdin, writes to stderr and stdout.
// stdin_pipe[1] : The write end of the stdin pipe.
@@ -183,11 +197,6 @@
}
// The parent process. Writes to stdout and stderr and reads from stdin.
- if (child_pid == -1) {
- ALOGE("fork() failed: %s", strerror(errno));
- return nullptr;
- }
-
// stdin_pipe[0] : The read end of the stdin pipe.
// stdout_pipe[1] : The write end of the stdout pipe.
// stderr_pipe[1] : The write end of the stderr pipe.
@@ -197,7 +206,7 @@
ALOGW("close() failed: %s", strerror(errno));
}
- return new IptablesProcess(child_pid, stdin_pipe[1], stdout_pipe[0], stderr_pipe[0]);
+ return new IptablesProcess(child_pid.value(), stdin_pipe[1], stdout_pipe[0], stderr_pipe[0]);
}
// TODO: Return -errno on failure instead of -1.
diff --git a/server/IptablesRestoreController.h b/server/IptablesRestoreController.h
index 6850d0d..b1c8dcd 100644
--- a/server/IptablesRestoreController.h
+++ b/server/IptablesRestoreController.h
@@ -68,6 +68,8 @@
// |POLL_TIMEOUT_MS * MAX_RETRIES|. Chosen so that the overall timeout is 1s.
static int POLL_TIMEOUT_MS;
+ void Init();
+
private:
static IptablesProcess* forkAndExec(const IptablesProcessType type);
diff --git a/server/IptablesRestoreControllerTest.cpp b/server/IptablesRestoreControllerTest.cpp
index 43041ec..017870f 100644
--- a/server/IptablesRestoreControllerTest.cpp
+++ b/server/IptablesRestoreControllerTest.cpp
@@ -20,12 +20,14 @@
#include <sys/socket.h>
#include <sys/un.h>
+#include <gmock/gmock.h>
#include <gtest/gtest.h>
#define LOG_TAG "IptablesRestoreControllerTest"
#include <cutils/log.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
+#include <netdutils/MockSyscalls.h>
#include "IptablesRestoreController.h"
#include "NetdConstants.h"
@@ -37,6 +39,9 @@
using android::base::Join;
using android::base::StringPrintf;
+using android::netdutils::ScopedMockSyscalls;
+using testing::Return;
+using testing::StrictMock;
class IptablesRestoreControllerTest : public ::testing::Test {
public:
@@ -60,6 +65,10 @@
deleteTestChain();
}
+ void Init() {
+ con.Init();
+ }
+
pid_t getIpRestorePid(const IptablesRestoreController::IptablesProcessType type) {
return con.getIpRestorePid(type);
};
@@ -258,15 +267,30 @@
float timeTaken = s.getTimeAndReset();
fprintf(stderr, " Add/del %d UID rules via restore: %.1fms (%.2fms per operation)\n",
iterations, timeTaken, timeTaken / 2 / iterations);
-
- for (int i = 0; i < iterations; i++) {
- EXPECT_EQ(0, execIptables(V4V6, "-I", "fw_powersave", "-m", "owner",
- "--uid-owner", "2000000000", "-j", "RETURN", nullptr));
- EXPECT_EQ(0, execIptables(V4V6, "-D", "fw_powersave", "-m", "owner",
- "--uid-owner", "2000000000", "-j", "RETURN", nullptr));
- }
- timeTaken = s.getTimeAndReset();
- fprintf(stderr, " Add/del %d UID rules via iptables: %.1fms (%.2fms per operation)\n",
- iterations, timeTaken, timeTaken / 2 / iterations);
}
}
+
+TEST_F(IptablesRestoreControllerTest, TestStartup) {
+ // Tests that IptablesRestoreController::Init never sets its processes to null pointers if
+ // fork() succeeds.
+ {
+ // Mock fork(), and check that initializing 100 times never results in a null pointer.
+ constexpr int NUM_ITERATIONS = 100; // Takes 100-150ms on angler.
+ constexpr pid_t FAKE_PID = 2000000001;
+ StrictMock<ScopedMockSyscalls> sys;
+
+ EXPECT_CALL(sys, fork()).Times(NUM_ITERATIONS * 2).WillRepeatedly(Return(FAKE_PID));
+ for (int i = 0; i < NUM_ITERATIONS; i++) {
+ Init();
+ EXPECT_NE(0, getIpRestorePid(IptablesRestoreController::IPTABLES_PROCESS));
+ EXPECT_NE(0, getIpRestorePid(IptablesRestoreController::IP6TABLES_PROCESS));
+ }
+ }
+
+ // The controller is now in an invalid state: the pipes are connected to working iptables
+ // processes, but the PIDs are set to FAKE_PID. Send a malformed command to ensure that the
+ // processes terminate and close the pipes, then send a valid command to have the controller
+ // re-initialize properly now that fork() is no longer mocked.
+ EXPECT_EQ(-1, con.execute(V4V6, "malformed command\n", nullptr));
+ EXPECT_EQ(0, con.execute(V4V6, "#Test\n", nullptr));
+}
diff --git a/server/NatController.cpp b/server/NatController.cpp
index 85a7ee1..58c732d 100644
--- a/server/NatController.cpp
+++ b/server/NatController.cpp
@@ -16,27 +16,32 @@
#define LOG_NDEBUG 0
-#include <stdlib.h>
+#include <string>
+#include <vector>
+
#include <errno.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <string.h>
+#include <arpa/inet.h>
+#include <linux/in.h>
+#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/wait.h>
-#include <fcntl.h>
-#include <netinet/in.h>
-#include <arpa/inet.h>
-#include <string.h>
-#include <cutils/properties.h>
#define LOG_TAG "NatController"
+#include <android-base/strings.h>
#include <android-base/stringprintf.h>
#include <cutils/log.h>
+#include <cutils/properties.h>
#include <logwrap/logwrap.h>
-#include "NetdConstants.h"
#include "NatController.h"
#include "NetdConstants.h"
#include "RouteController.h"
+using android::base::Join;
using android::base::StringPrintf;
const char* NatController::LOCAL_FORWARD = "natctrl_FORWARD";
@@ -60,29 +65,6 @@
bool checkRes;
};
-int NatController::runCmd(int argc, const char **argv) {
- int res;
-
- res = execFunction(argc, (char **)argv, NULL, false, false);
-
-#if !LOG_NDEBUG
- std::string full_cmd = argv[0];
- argc--; argv++;
- /*
- * HACK: Sometimes runCmd() is called with a ridcously large value (32)
- * and it works because the argv[] contains a NULL after the last
- * true argv. So here we use the NULL argv[] to terminate when the argc
- * is horribly wrong, and argc for the normal cases.
- */
- for (; argc && argv[0]; argc--, argv++) {
- full_cmd += " ";
- full_cmd += argv[0];
- }
- ALOGV("runCmd(%s) res=%d", full_cmd.c_str(), res);
-#endif
- return res;
-}
-
int NatController::setupIptablesHooks() {
int res;
res = setDefaults();
@@ -169,27 +151,24 @@
// add this if we are the first added nat
if (natCount == 0) {
- const char *v4Cmd[] = {
- IPTABLES_PATH,
- "-w",
- "-t",
- "nat",
- "-A",
- LOCAL_NAT_POSTROUTING,
- "-o",
- extIface,
- "-j",
- "MASQUERADE"
+ std::vector<std::string> v4Cmds = {
+ "*nat",
+ StringPrintf("-A %s -o %s -j MASQUERADE", LOCAL_NAT_POSTROUTING, extIface),
+ "COMMIT\n"
};
/*
* IPv6 tethering doesn't need the state-based conntrack rules, so
* it unconditionally jumps to the tether counters chain all the time.
*/
- const char *v6Cmd[] = {IP6TABLES_PATH, "-w", "-A", LOCAL_FORWARD,
- "-g", LOCAL_TETHER_COUNTERS_CHAIN};
+ std::vector<std::string> v6Cmds = {
+ "*filter",
+ StringPrintf("-A %s -g %s", LOCAL_FORWARD, LOCAL_TETHER_COUNTERS_CHAIN),
+ "COMMIT\n"
+ };
- if (runCmd(ARRAY_SIZE(v4Cmd), v4Cmd) || runCmd(ARRAY_SIZE(v6Cmd), v6Cmd)) {
+ if (iptablesRestoreFunction(V4, Join(v4Cmds, '\n')) ||
+ iptablesRestoreFunction(V6, Join(v6Cmds, '\n'))) {
ALOGE("Error setting postroute rule: iface=%s", extIface);
// unwind what's been done, but don't care about success - what more could we do?
setDefaults();
@@ -206,204 +185,85 @@
return -1;
}
- /* Always make sure the drop rule is at the end */
- const char *cmd1[] = {
- IPTABLES_PATH,
- "-w",
- "-D",
- LOCAL_FORWARD,
- "-j",
- "DROP"
- };
- runCmd(ARRAY_SIZE(cmd1), cmd1);
- const char *cmd2[] = {
- IPTABLES_PATH,
- "-w",
- "-A",
- LOCAL_FORWARD,
- "-j",
- "DROP"
- };
- runCmd(ARRAY_SIZE(cmd2), cmd2);
-
natCount++;
return 0;
}
-bool NatController::checkTetherCountingRuleExist(const char *pair_name) {
- std::list<std::string>::iterator it;
-
- for (it = ifacePairList.begin(); it != ifacePairList.end(); it++) {
- if (*it == pair_name) {
- /* We already have this counter */
- return true;
- }
- }
- return false;
+bool NatController::checkTetherCountingRuleExist(const std::string& pair_name) {
+ return std::find(ifacePairList.begin(), ifacePairList.end(), pair_name) != ifacePairList.end();
}
-int NatController::setTetherCountingRules(bool add, const char *intIface, const char *extIface) {
-
- /* We only ever add tethering quota rules so that they stick. */
- if (!add) {
- return 0;
- }
- char *pair_name;
- asprintf(&pair_name, "%s_%s", intIface, extIface);
-
- if (checkTetherCountingRuleExist(pair_name)) {
- free(pair_name);
- return 0;
- }
- const char *cmd2b[] = {
- IPTABLES_PATH,
- "-w", "-A", LOCAL_TETHER_COUNTERS_CHAIN, "-i", intIface, "-o", extIface, "-j", "RETURN"
- };
-
- const char *cmd2c[] = {
- IP6TABLES_PATH,
- "-w", "-A", LOCAL_TETHER_COUNTERS_CHAIN, "-i", intIface, "-o", extIface, "-j", "RETURN"
- };
-
- if (runCmd(ARRAY_SIZE(cmd2b), cmd2b) || runCmd(ARRAY_SIZE(cmd2c), cmd2c)) {
- free(pair_name);
- return -1;
- }
- ifacePairList.push_front(pair_name);
- free(pair_name);
-
- asprintf(&pair_name, "%s_%s", extIface, intIface);
- if (checkTetherCountingRuleExist(pair_name)) {
- free(pair_name);
- return 0;
- }
-
- const char *cmd3b[] = {
- IPTABLES_PATH,
- "-w", "-A", LOCAL_TETHER_COUNTERS_CHAIN, "-i", extIface, "-o", intIface, "-j", "RETURN"
- };
-
- const char *cmd3c[] = {
- IP6TABLES_PATH,
- "-w", "-A", LOCAL_TETHER_COUNTERS_CHAIN, "-i", extIface, "-o", intIface, "-j", "RETURN"
- };
-
- if (runCmd(ARRAY_SIZE(cmd3b), cmd3b) || runCmd(ARRAY_SIZE(cmd3c), cmd3c)) {
- // unwind what's been done, but don't care about success - what more could we do?
- free(pair_name);
- return -1;
- }
- ifacePairList.push_front(pair_name);
- free(pair_name);
- return 0;
+/* static */
+std::string NatController::makeTetherCountingRule(const char *if1, const char *if2) {
+ return StringPrintf("-A %s -i %s -o %s -j RETURN", LOCAL_TETHER_COUNTERS_CHAIN, if1, if2);
}
int NatController::setForwardRules(bool add, const char *intIface, const char *extIface) {
- const char *cmd1[] = {
- IPTABLES_PATH,
- "-w",
- add ? "-A" : "-D",
- LOCAL_FORWARD,
- "-i",
- extIface,
- "-o",
- intIface,
- "-m",
- "state",
- "--state",
- "ESTABLISHED,RELATED",
- "-g",
- LOCAL_TETHER_COUNTERS_CHAIN
- };
- int rc = 0;
+ const char *op = add ? "-A" : "-D";
- if (runCmd(ARRAY_SIZE(cmd1), cmd1) && add) {
+ std::string rpfilterCmd = StringPrintf(
+ "*raw\n"
+ "%s %s -i %s -m rpfilter --invert ! -s fe80::/64 -j DROP\n"
+ "COMMIT\n", op, LOCAL_RAW_PREROUTING, intIface);
+ if (iptablesRestoreFunction(V6, rpfilterCmd) == -1 && add) {
return -1;
}
- const char *cmd2[] = {
- IPTABLES_PATH,
- "-w",
- add ? "-A" : "-D",
- LOCAL_FORWARD,
- "-i",
- intIface,
- "-o",
- extIface,
- "-m",
- "state",
- "--state",
- "INVALID",
- "-j",
- "DROP"
+ std::vector<std::string> v4 = {
+ "*filter",
+ StringPrintf("%s %s -i %s -o %s -m state --state ESTABLISHED,RELATED -g %s",
+ op, LOCAL_FORWARD, extIface, intIface, LOCAL_TETHER_COUNTERS_CHAIN),
+ StringPrintf("%s %s -i %s -o %s -m state --state INVALID -j DROP",
+ op, LOCAL_FORWARD, intIface, extIface),
+ StringPrintf("%s %s -i %s -o %s -g %s",
+ op, LOCAL_FORWARD, intIface, extIface, LOCAL_TETHER_COUNTERS_CHAIN),
};
- const char *cmd3[] = {
- IPTABLES_PATH,
- "-w",
- add ? "-A" : "-D",
- LOCAL_FORWARD,
- "-i",
- intIface,
- "-o",
- extIface,
- "-g",
- LOCAL_TETHER_COUNTERS_CHAIN
+ std::vector<std::string> v6 = {
+ "*filter",
};
- const char *cmd4[] = {
- IP6TABLES_PATH,
- "-w",
- "-t",
- "raw",
- add ? "-A" : "-D",
- LOCAL_RAW_PREROUTING,
- "-i",
- intIface,
- "-m",
- "rpfilter",
- "--invert",
- "!",
- "-s",
- "fe80::/64",
- "-j",
- "DROP"
- };
-
- if (runCmd(ARRAY_SIZE(cmd2), cmd2) && add) {
- // bail on error, but only if adding
- rc = -1;
- goto err_invalid_drop;
+ /* We only ever add tethering quota rules so that they stick. */
+ std::string pair1 = StringPrintf("%s_%s", intIface, extIface);
+ if (add && !checkTetherCountingRuleExist(pair1)) {
+ v4.push_back(makeTetherCountingRule(intIface, extIface));
+ v6.push_back(makeTetherCountingRule(intIface, extIface));
+ }
+ std::string pair2 = StringPrintf("%s_%s", extIface, intIface);
+ if (add && !checkTetherCountingRuleExist(pair2)) {
+ v4.push_back(makeTetherCountingRule(extIface, intIface));
+ v6.push_back(makeTetherCountingRule(extIface, intIface));
}
- if (runCmd(ARRAY_SIZE(cmd3), cmd3) && add) {
+ // Always make sure the drop rule is at the end.
+ // TODO: instead of doing this, consider just rebuilding LOCAL_FORWARD completely from scratch
+ // every time, starting with ":natctrl_FORWARD -\n". This method would likely be a bit simpler.
+ if (add) {
+ v4.push_back(StringPrintf("-D %s -j DROP", LOCAL_FORWARD));
+ v4.push_back(StringPrintf("-A %s -j DROP", LOCAL_FORWARD));
+ }
+
+ v4.push_back("COMMIT\n");
+ v6.push_back("COMMIT\n");
+
+ // We only add IPv6 rules here, never remove them.
+ if (iptablesRestoreFunction(V4, Join(v4, '\n')) == -1 ||
+ (add && iptablesRestoreFunction(V6, Join(v6, '\n')) == -1)) {
// unwind what's been done, but don't care about success - what more could we do?
- rc = -1;
- goto err_return;
+ if (add) {
+ setForwardRules(false, intIface, extIface);
+ }
+ return -1;
}
- if (runCmd(ARRAY_SIZE(cmd4), cmd4) && add) {
- rc = -1;
- goto err_rpfilter;
+ if (add && !checkTetherCountingRuleExist(pair1)) {
+ ifacePairList.push_front(pair1);
}
-
- if (setTetherCountingRules(add, intIface, extIface) && add) {
- rc = -1;
- goto err_return;
+ if (add && !checkTetherCountingRuleExist(pair2)) {
+ ifacePairList.push_front(pair2);
}
return 0;
-
-err_rpfilter:
- cmd3[2] = "-D";
- runCmd(ARRAY_SIZE(cmd3), cmd3);
-err_return:
- cmd2[2] = "-D";
- runCmd(ARRAY_SIZE(cmd2), cmd2);
-err_invalid_drop:
- cmd1[2] = "-D";
- runCmd(ARRAY_SIZE(cmd1), cmd1);
- return rc;
}
int NatController::disableNat(const char* intIface, const char* extIface) {
diff --git a/server/NatController.h b/server/NatController.h
index 5541a26..4c711b0 100644
--- a/server/NatController.h
+++ b/server/NatController.h
@@ -17,7 +17,6 @@
#ifndef _NAT_CONTROLLER_H
#define _NAT_CONTROLLER_H
-#include <linux/in.h>
#include <list>
#include <string>
@@ -44,7 +43,8 @@
private:
int natCount;
- bool checkTetherCountingRuleExist(const char *pair_name);
+ static std::string makeTetherCountingRule(const char *if1, const char *if2);
+ bool checkTetherCountingRuleExist(const std::string& pair_name);
int setDefaults();
int runCmd(int argc, const char **argv);
diff --git a/server/NatControllerTest.cpp b/server/NatControllerTest.cpp
index ada8ad7..c24322a 100644
--- a/server/NatControllerTest.cpp
+++ b/server/NatControllerTest.cpp
@@ -32,6 +32,7 @@
#include "NatController.h"
#include "IptablesBaseTest.h"
+using android::base::Join;
using android::base::StringPrintf;
class NatControllerTest : public IptablesBaseTest {
@@ -87,46 +88,78 @@
"COMMIT\n" },
};
- const ExpectedIptablesCommands TWIDDLE_COMMANDS = {
- { V4, "-D natctrl_FORWARD -j DROP" },
- { V4, "-A natctrl_FORWARD -j DROP" },
- };
-
ExpectedIptablesCommands firstNatCommands(const char *extIf) {
+ std::string v4Cmd = StringPrintf(
+ "*nat\n"
+ "-A natctrl_nat_POSTROUTING -o %s -j MASQUERADE\n"
+ "COMMIT\n", extIf);
+ std::string v6Cmd =
+ "*filter\n"
+ "-A natctrl_FORWARD -g natctrl_tether_counters\n"
+ "COMMIT\n";
return {
- { V4, StringPrintf("-t nat -A natctrl_nat_POSTROUTING -o %s -j MASQUERADE", extIf) },
- { V6, "-A natctrl_FORWARD -g natctrl_tether_counters" },
+ { V4, v4Cmd },
+ { V6, v6Cmd },
};
}
ExpectedIptablesCommands startNatCommands(const char *intIf, const char *extIf) {
+ std::string rpfilterCmd = StringPrintf(
+ "*raw\n"
+ "-A natctrl_raw_PREROUTING -i %s -m rpfilter --invert ! -s fe80::/64 -j DROP\n"
+ "COMMIT\n", intIf);
+
+ std::vector<std::string> v4Cmds = {
+ "*filter",
+ StringPrintf("-A natctrl_FORWARD -i %s -o %s -m state --state"
+ " ESTABLISHED,RELATED -g natctrl_tether_counters", extIf, intIf),
+ StringPrintf("-A natctrl_FORWARD -i %s -o %s -m state --state INVALID -j DROP",
+ intIf, extIf),
+ StringPrintf("-A natctrl_FORWARD -i %s -o %s -g natctrl_tether_counters",
+ intIf, extIf),
+ StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN", intIf, extIf),
+ StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN", extIf, intIf),
+ "-D natctrl_FORWARD -j DROP",
+ "-A natctrl_FORWARD -j DROP",
+ "COMMIT\n",
+ };
+
+ std::vector<std::string> v6Cmds = {
+ "*filter",
+ StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN", intIf, extIf),
+ StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN", extIf, intIf),
+ "COMMIT\n",
+ };
+
return {
- { V4, StringPrintf("-A natctrl_FORWARD -i %s -o %s -m state --state"
- " ESTABLISHED,RELATED -g natctrl_tether_counters", extIf, intIf) },
- { V4, StringPrintf("-A natctrl_FORWARD -i %s -o %s -m state --state INVALID -j DROP",
- intIf, extIf) },
- { V4, StringPrintf("-A natctrl_FORWARD -i %s -o %s -g natctrl_tether_counters",
- intIf, extIf) },
- { V6, StringPrintf("-t raw -A natctrl_raw_PREROUTING -i %s -m rpfilter --invert"
- " ! -s fe80::/64 -j DROP", intIf) },
- { V4V6, StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN",
- intIf, extIf) },
- { V4V6, StringPrintf("-A natctrl_tether_counters -i %s -o %s -j RETURN",
- extIf, intIf) },
+ { V6, rpfilterCmd },
+ { V4, Join(v4Cmds, '\n') },
+ { V6, Join(v6Cmds, '\n') },
};
}
ExpectedIptablesCommands stopNatCommands(const char *intIf, const char *extIf) {
- return {
- { V4, StringPrintf("-D natctrl_FORWARD -i %s -o %s -m state --state"
- " ESTABLISHED,RELATED -g natctrl_tether_counters", extIf, intIf) },
- { V4, StringPrintf("-D natctrl_FORWARD -i %s -o %s -m state --state INVALID -j DROP",
- intIf, extIf) },
- { V4, StringPrintf("-D natctrl_FORWARD -i %s -o %s -g natctrl_tether_counters",
- intIf, extIf) },
- { V6, StringPrintf("-t raw -D natctrl_raw_PREROUTING -i %s -m rpfilter --invert"
- " ! -s fe80::/64 -j DROP", intIf) },
+ std::string rpfilterCmd = StringPrintf(
+ "*raw\n"
+ "-D natctrl_raw_PREROUTING -i %s -m rpfilter --invert ! -s fe80::/64 -j DROP\n"
+ "COMMIT\n", intIf);
+
+ std::vector<std::string> v4Cmds = {
+ "*filter",
+ StringPrintf("-D natctrl_FORWARD -i %s -o %s -m state --state"
+ " ESTABLISHED,RELATED -g natctrl_tether_counters", extIf, intIf),
+ StringPrintf("-D natctrl_FORWARD -i %s -o %s -m state --state INVALID -j DROP",
+ intIf, extIf),
+ StringPrintf("-D natctrl_FORWARD -i %s -o %s -g natctrl_tether_counters",
+ intIf, extIf),
+ "COMMIT\n",
};
+
+ return {
+ { V6, rpfilterCmd },
+ { V4, Join(v4Cmds, '\n') },
+ };
+
}
};
@@ -141,28 +174,24 @@
}
TEST_F(NatControllerTest, TestAddAndRemoveNat) {
-
- std::vector<ExpectedIptablesCommands> startFirstNat = {
- firstNatCommands("rmnet0"),
- startNatCommands("wlan0", "rmnet0"),
- TWIDDLE_COMMANDS,
- };
+ ExpectedIptablesCommands expected;
+ ExpectedIptablesCommands setupFirstNatCommands = firstNatCommands("rmnet0");
+ ExpectedIptablesCommands startFirstNatCommands = startNatCommands("wlan0", "rmnet0");
+ expected.insert(expected.end(), setupFirstNatCommands.begin(), setupFirstNatCommands.end());
+ expected.insert(expected.end(), startFirstNatCommands.begin(), startFirstNatCommands.end());
mNatCtrl.enableNat("wlan0", "rmnet0");
- expectIptablesCommands(startFirstNat);
+ expectIptablesRestoreCommands(expected);
- std::vector<ExpectedIptablesCommands> startOtherNat = {
- startNatCommands("usb0", "rmnet0"),
- TWIDDLE_COMMANDS,
- };
+ ExpectedIptablesCommands startOtherNat = startNatCommands("usb0", "rmnet0");
mNatCtrl.enableNat("usb0", "rmnet0");
- expectIptablesCommands(startOtherNat);
+ expectIptablesRestoreCommands(startOtherNat);
ExpectedIptablesCommands stopOtherNat = stopNatCommands("wlan0", "rmnet0");
mNatCtrl.disableNat("wlan0", "rmnet0");
- expectIptablesCommands(stopOtherNat);
+ expectIptablesRestoreCommands(stopOtherNat);
- ExpectedIptablesCommands stopLastNat = stopNatCommands("usb0", "rmnet0");
+ expected = stopNatCommands("usb0", "rmnet0");
+ expected.insert(expected.end(), FLUSH_COMMANDS.begin(), FLUSH_COMMANDS.end());
mNatCtrl.disableNat("usb0", "rmnet0");
- expectIptablesCommands(stopLastNat);
- expectIptablesRestoreCommands(FLUSH_COMMANDS);
+ expectIptablesRestoreCommands(expected);
}
diff --git a/server/NetdConstants.cpp b/server/NetdConstants.cpp
index 58b2f64..5abdacd 100644
--- a/server/NetdConstants.cpp
+++ b/server/NetdConstants.cpp
@@ -38,91 +38,9 @@
const size_t SHA256_SIZE = EVP_MD_size(EVP_sha256());
const char * const OEM_SCRIPT_PATH = "/system/bin/oem-iptables-init.sh";
-const char * const IPTABLES_PATH = "/system/bin/iptables";
-const char * const IP6TABLES_PATH = "/system/bin/ip6tables";
-const char * const TC_PATH = "/system/bin/tc";
-const char * const IP_PATH = "/system/bin/ip";
const char * const ADD = "add";
const char * const DEL = "del";
-static void logExecError(const char* argv[], int res, int status) {
- const char** argp = argv;
- std::string args = "";
- while (*argp) {
- args += *argp;
- args += ' ';
- argp++;
- }
- ALOGE("exec() res=%d, status=%d for %s", res, status, args.c_str());
-}
-
-static int execIptablesCommand(int argc, const char *argv[], bool silent) {
- int res;
- int status;
-
- res = android_fork_execvp(argc, (char **)argv, &status, false,
- !silent);
- if (res || !WIFEXITED(status) || WEXITSTATUS(status)) {
- if (!silent) {
- logExecError(argv, res, status);
- }
- if (res)
- return res;
- if (!WIFEXITED(status))
- return ECHILD;
- }
- return WEXITSTATUS(status);
-}
-
-static int execIptables(IptablesTarget target, bool silent, va_list args) {
- /* Read arguments from incoming va_list; we expect the list to be NULL terminated. */
- std::list<const char*> argsList;
- argsList.push_back(NULL);
- const char* arg;
-
- // Wait to avoid failure due to another process holding the lock
- argsList.push_back("-w");
-
- do {
- arg = va_arg(args, const char *);
- argsList.push_back(arg);
- } while (arg);
-
- int i = 0;
- const char* argv[argsList.size()];
- std::list<const char*>::iterator it;
- for (it = argsList.begin(); it != argsList.end(); it++, i++) {
- argv[i] = *it;
- }
-
- int res = 0;
- if (target == V4 || target == V4V6) {
- argv[0] = IPTABLES_PATH;
- res |= execIptablesCommand(argsList.size(), argv, silent);
- }
- if (target == V6 || target == V4V6) {
- argv[0] = IP6TABLES_PATH;
- res |= execIptablesCommand(argsList.size(), argv, silent);
- }
- return res;
-}
-
-int execIptables(IptablesTarget target, ...) {
- va_list args;
- va_start(args, target);
- int res = execIptables(target, false, args);
- va_end(args);
- return res;
-}
-
-int execIptablesSilently(IptablesTarget target, ...) {
- va_list args;
- va_start(args, target);
- int res = execIptables(target, true, args);
- va_end(args);
- return res;
-}
-
int execIptablesRestoreWithOutput(IptablesTarget target, const std::string& commands,
std::string *output) {
return android::net::gCtls->iptablesRestoreCtrl.execute(target, commands, output);
diff --git a/server/NetdConstants.h b/server/NetdConstants.h
index 446a898..b1117c4 100644
--- a/server/NetdConstants.h
+++ b/server/NetdConstants.h
@@ -34,18 +34,12 @@
extern const size_t SHA256_SIZE;
-extern const char * const IPTABLES_PATH;
-extern const char * const IP6TABLES_PATH;
-extern const char * const IP_PATH;
-extern const char * const TC_PATH;
extern const char * const OEM_SCRIPT_PATH;
extern const char * const ADD;
extern const char * const DEL;
enum IptablesTarget { V4, V6, V4V6 };
-int execIptables(IptablesTarget target, ...);
-int execIptablesSilently(IptablesTarget target, ...);
int execIptablesRestore(IptablesTarget target, const std::string& commands);
int execIptablesRestoreWithOutput(IptablesTarget target, const std::string& commands,
std::string *output);
diff --git a/server/XfrmController.cpp b/server/XfrmController.cpp
index 8437e4b..ea6f696 100644
--- a/server/XfrmController.cpp
+++ b/server/XfrmController.cpp
@@ -67,13 +67,15 @@
namespace android {
namespace net {
-namespace {
-
+// Exposed for testing
constexpr uint32_t ALGO_MASK_AUTH_ALL = ~0;
+// Exposed for testing
constexpr uint32_t ALGO_MASK_CRYPT_ALL = ~0;
-
+// Exposed for testing
constexpr uint8_t REPLAY_WINDOW_SIZE = 4;
+namespace {
+
constexpr uint32_t RAND_SPI_MIN = 1;
constexpr uint32_t RAND_SPI_MAX = 0xFFFFFFFE;
diff --git a/server/XfrmController.h b/server/XfrmController.h
index 739d8d1..904ae88 100644
--- a/server/XfrmController.h
+++ b/server/XfrmController.h
@@ -33,6 +33,13 @@
namespace android {
namespace net {
+// Exposed for testing
+extern const uint32_t ALGO_MASK_AUTH_ALL;
+// Exposed for testing
+extern const uint32_t ALGO_MASK_CRYPT_ALL;
+// Exposed for testing
+extern const uint8_t REPLAY_WINDOW_SIZE;
+
// Suggest we avoid the smallest and largest ints
class XfrmMessage;
class TransportModeSecurityAssociation;
@@ -129,12 +136,39 @@
int ipSecRemoveTransportModeTransform(const android::base::unique_fd& socket);
+ // Exposed for testing
+ static constexpr size_t MAX_ALGO_LENGTH = 128;
+
+ // Exposed for testing
+ struct nlattr_algo_crypt {
+ nlattr hdr;
+ xfrm_algo crypt;
+ uint8_t key[MAX_ALGO_LENGTH];
+ };
+
+ // Exposed for testing
+ struct nlattr_algo_auth {
+ nlattr hdr;
+ xfrm_algo_auth auth;
+ uint8_t key[MAX_ALGO_LENGTH];
+ };
+
+ // Exposed for testing
+ struct nlattr_user_tmpl {
+ nlattr hdr;
+ xfrm_user_tmpl tmpl;
+ };
+
+ // Exposed for testing
+ struct nlattr_encap_tmpl {
+ nlattr hdr;
+ xfrm_encap_tmpl tmpl;
+ };
+
private:
// prevent concurrent modification of XFRM
android::RWLock mLock;
- static constexpr size_t MAX_ALGO_LENGTH = 128;
-
/*
* Below is a redefinition of the xfrm_usersa_info struct that is part
* of the Linux uapi <linux/xfrm.h> to align the structures to a 64-bit
@@ -175,29 +209,6 @@
"struct xfrm_userspi_info has changed and does not match the kernel struct.");
#endif
- struct nlattr_algo_crypt {
- nlattr hdr;
- xfrm_algo crypt;
- uint8_t key[MAX_ALGO_LENGTH];
- };
-
- struct nlattr_algo_auth {
- nlattr hdr;
- xfrm_algo_auth auth;
- uint8_t key[MAX_ALGO_LENGTH];
- };
-
- struct nlattr_user_tmpl {
- nlattr hdr;
- xfrm_user_tmpl tmpl;
- };
-
- struct nlattr_encap_tmpl {
- nlattr hdr;
- xfrm_encap_tmpl tmpl;
- };
-
-
// helper function for filling in the XfrmSaInfo structure
static int fillXfrmSaId(int32_t direction, const std::string& localAddress,
const std::string& remoteAddress, int32_t spi, XfrmSaId* xfrmId);
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 4988023..74b779c 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -183,22 +183,50 @@
if (DBG) {
ALOGD("Checking DNS over TLS fingerprint");
}
- // TODO: Follow the cert chain and check all the way up.
- bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
- if (!cert) {
+
+ // We only care that the chain is internally self-consistent, not that
+ // it chains to a trusted root, so we can ignore some kinds of errors.
+ // TODO: Add a CA root verification mode that respects these errors.
+ int verify_result = SSL_get_verify_result(ssl.get());
+ switch (verify_result) {
+ case X509_V_OK:
+ case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
+ case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
+ case X509_V_ERR_CERT_UNTRUSTED:
+ break;
+ default:
+ ALOGW("Invalid certificate chain, error %d", verify_result);
+ return nullptr;
+ }
+
+ STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
+ if (!chain) {
ALOGW("Server has null certificate");
return nullptr;
}
- std::vector<uint8_t> digest;
- if (!getSPKIDigest(cert.get(), &digest)) {
- ALOGE("Digest computation failed");
- return nullptr;
+ // Chain and its contents are owned by ssl, so we don't need to free explicitly.
+ bool matched = false;
+ for (size_t i = 0; i < sk_X509_num(chain); ++i) {
+ // This appears to be O(N^2), but there doesn't seem to be a straightforward
+ // way to walk a STACK_OF nondestructively in linear time.
+ X509* cert = sk_X509_value(chain, i);
+ std::vector<uint8_t> digest;
+ if (!getSPKIDigest(cert, &digest)) {
+ ALOGE("Digest computation failed");
+ return nullptr;
+ }
+
+ if (mServer.fingerprints.count(digest) > 0) {
+ matched = true;
+ break;
+ }
}
- if (mServer.fingerprints.count(digest) == 0) {
+ if (!matched) {
ALOGW("No matching fingerprint");
return nullptr;
}
+
if (DBG) {
ALOGD("DNS over TLS fingerprint is correct");
}
diff --git a/server/oem_iptables_hook.cpp b/server/oem_iptables_hook.cpp
index 7e4b3cb..4c839a2 100644
--- a/server/oem_iptables_hook.cpp
+++ b/server/oem_iptables_hook.cpp
@@ -27,40 +27,17 @@
#include <logwrap/logwrap.h>
#include "NetdConstants.h"
-static int runIptablesCmd(int argc, const char **argv) {
- int res;
-
- res = android_fork_execvp(argc, (char **)argv, NULL, false, false);
- return res;
-}
-
static bool oemCleanupHooks() {
- const char *cmd1[] = {
- IPTABLES_PATH,
- "-w",
- "-F",
- "oem_out"
- };
- runIptablesCmd(ARRAY_SIZE(cmd1), cmd1);
+ std::string cmd =
+ "*filter\n"
+ ":oem_out -\n"
+ ":oem_fwd -\n"
+ "COMMIT\n"
+ "*nat\n"
+ ":oem_nat_pre -\n"
+ "COMMIT\n";
- const char *cmd2[] = {
- IPTABLES_PATH,
- "-w",
- "-F",
- "oem_fwd"
- };
- runIptablesCmd(ARRAY_SIZE(cmd2), cmd2);
-
- const char *cmd3[] = {
- IPTABLES_PATH,
- "-w",
- "-t",
- "nat",
- "-F",
- "oem_nat_pre"
- };
- runIptablesCmd(ARRAY_SIZE(cmd3), cmd3);
- return true;
+ return (execIptablesRestore(V4V6, cmd) == 0);
}
static bool oemInitChains() {
diff --git a/tests/benchmarks/Android.mk b/tests/benchmarks/Android.mk
index bfcf600..c67d40e 100644
--- a/tests/benchmarks/Android.mk
+++ b/tests/benchmarks/Android.mk
@@ -13,40 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-#
-# Note: netd benchmark can't build on nyc-mr2-dev, because google-benchmark project is out of date
-# and won't be backported, and thus the content of this file is commented out to disable it.
-# In order to run netd benchmark locally you can uncomment the content of this file and follow
-# instructions in ag/1673408 (checkout that commit and build external/google-benchmark and
-# system/netd locally and then run the benchmark locally)
-#
-#
-#LOCAL_PATH := $(call my-dir)
-#
-## APCT build target for metrics tests
-#include $(CLEAR_VARS)
-#LOCAL_MODULE := netd_benchmark
-#LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
-## Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
-#LOCAL_CFLAGS += -Wno-varargs
+LOCAL_PATH := $(call my-dir)
-#EXTRA_LDLIBS := -lpthread
-#LOCAL_SHARED_LIBRARIES += libbase libbinder liblog libnetd_client
-#LOCAL_STATIC_LIBRARIES += libnetd_test_dnsresponder libutils
+# APCT build target for metrics tests
+include $(CLEAR_VARS)
+LOCAL_MODULE := netd_benchmark
+LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
+# Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
+LOCAL_CFLAGS += -Wno-varargs
-#LOCAL_AIDL_INCLUDES := system/netd/server/binder
-#LOCAL_C_INCLUDES += system/netd/include \
-# system/netd/client \
-# system/netd/server \
-# system/netd/server/binder \
-# system/netd/tests/dns_responder \
-# bionic/libc/dns/include
+EXTRA_LDLIBS := -lpthread
+LOCAL_SHARED_LIBRARIES += libbase libbinder liblog libnetd_client
+LOCAL_STATIC_LIBRARIES += libnetd_test_dnsresponder libutils
-#LOCAL_SRC_FILES := main.cpp \
-# connect_benchmark.cpp \
-# dns_benchmark.cpp \
-# ../../server/binder/android/net/metrics/INetdEventListener.aidl
+LOCAL_AIDL_INCLUDES := system/netd/server/binder
+LOCAL_C_INCLUDES += system/netd/include \
+ system/netd/client \
+ system/netd/server \
+ system/netd/server/binder \
+ system/netd/tests/dns_responder \
+ bionic/libc/dns/include
-#LOCAL_MODULE_TAGS := eng tests
+LOCAL_SRC_FILES := main.cpp \
+ connect_benchmark.cpp \
+ dns_benchmark.cpp \
+ ../../server/binder/android/net/metrics/INetdEventListener.aidl
-#include $(BUILD_NATIVE_BENCHMARK)
+LOCAL_MODULE_TAGS := eng tests
+
+include $(BUILD_NATIVE_BENCHMARK)
diff --git a/tests/binder_test.cpp b/tests/binder_test.cpp
index 41fc8c3..b2f362e 100644
--- a/tests/binder_test.cpp
+++ b/tests/binder_test.cpp
@@ -49,6 +49,9 @@
#include "android/net/UidRange.h"
#include "binder/IServiceManager.h"
+#define IP_PATH "/system/bin/ip"
+#define IP6TABLES_PATH "/system/bin/ip6tables"
+#define IPTABLES_PATH "/system/bin/iptables"
#define TUN_DEV "/dev/tun"
using namespace android;
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index 74b8e72..02b01f2 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -112,7 +112,7 @@
return privkey;
}
-bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey) {
+bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
bssl::UniquePtr<X509> cert(X509_new());
if (!cert) {
ALOGE("X509_new failed");
@@ -127,7 +127,7 @@
X509_set_pubkey(cert.get(), privkey);
- if (!X509_sign(cert.get(), privkey, EVP_sha256())) {
+ if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
ALOGE("X509_sign failed");
return nullptr;
}
@@ -151,20 +151,37 @@
SSL_CTX_set_ecdh_auto(ctx_.get(), 1);
- bssl::UniquePtr<EVP_PKEY> key(make_private_key());
- bssl::UniquePtr<X509> cert(make_cert(key.get()));
- if (SSL_CTX_use_certificate(ctx_.get(), cert.get()) <= 0) {
+ // Make certificate chain
+ std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
+ for (int i = 0; i < chain_length_; ++i) {
+ keys[i] = make_private_key();
+ }
+ std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
+ for (int i = 0; i < chain_length_; ++i) {
+ int next = std::min(i + 1, chain_length_ - 1);
+ certs[i] = make_cert(keys[i].get(), keys[next].get());
+ }
+
+ // Install certificate chain.
+ if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
ALOGE("SSL_CTX_use_certificate failed");
return false;
}
-
- if (!getSPKIDigest(cert.get(), &fingerprint_)) {
- ALOGE("getSPKIDigest failed");
+ if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0 ) {
+ ALOGE("SSL_CTX_use_PrivateKey failed");
return false;
}
+ for (int i = 1; i < chain_length_; ++i) {
+ if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
+ ALOGE("SSL_CTX_add1_chain_cert failed");
+ return false;
+ }
+ }
- if (SSL_CTX_use_PrivateKey(ctx_.get(), key.get()) <= 0 ) {
- ALOGE("SSL_CTX_use_PrivateKey failed");
+ // Report the fingerprint of the "middle" cert. For N = 2, this is the root.
+ int fp_index = chain_length_ / 2;
+ if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
+ ALOGE("getSPKIDigest failed");
return false;
}
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 911ea3c..0a2556c 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -59,6 +59,8 @@
bool stopServer();
int queries() const { return queries_; }
bool waitForQueries(int number, int timeoutMs) const;
+ void set_chain_length(int length) { chain_length_ = length; }
+ // Represents a fingerprint from the middle of the certificate chain.
const std::vector<uint8_t>& fingerprint() const { return fingerprint_; }
private:
@@ -76,6 +78,7 @@
std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
+ int chain_length_ = 1;
std::vector<uint8_t> fingerprint_;
};
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index 12d85aa..0217f5b 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -833,33 +833,38 @@
const char* listen_addr = "127.0.0.3";
const char* listen_udp = "53";
const char* listen_tls = "853";
- const char* host_name = "tlsfingerprint.example.com.";
- test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
- dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
- ASSERT_TRUE(dns.startServer());
- std::vector<std::string> servers = { listen_addr };
+ for (int chain_length = 1; chain_length <= 3; ++chain_length) {
+ const char* host_name = StringPrintf("tlsfingerprint%d.example.com.", chain_length).c_str();
+ test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
+ ASSERT_TRUE(dns.startServer());
+ std::vector<std::string> servers = { listen_addr };
- test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
- ASSERT_TRUE(tls.startServer());
- auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
- { base64Encode(tls.fingerprint()) });
- ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
+ test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
+ tls.set_chain_length(chain_length);
+ ASSERT_TRUE(tls.startServer());
+ auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
+ { base64Encode(tls.fingerprint()) });
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
- const hostent* result;
+ const hostent* result;
- // Wait for validation to complete.
- EXPECT_TRUE(tls.waitForQueries(1, 5000));
+ // Wait for validation to complete.
+ EXPECT_TRUE(tls.waitForQueries(1, 5000));
- result = gethostbyname("tlsfingerprint");
- ASSERT_FALSE(result == nullptr);
- EXPECT_EQ("1.2.3.1", ToString(result));
+ result = gethostbyname(StringPrintf("tlsfingerprint%d", chain_length).c_str());
+ EXPECT_FALSE(result == nullptr);
+ if (result) {
+ EXPECT_EQ("1.2.3.1", ToString(result));
- // Wait for query to get counted.
- EXPECT_TRUE(tls.waitForQueries(2, 5000));
+ // Wait for query to get counted.
+ EXPECT_TRUE(tls.waitForQueries(2, 5000));
+ }
- rv = mNetdSrv->removePrivateDnsServer(listen_addr);
- tls.stopServer();
- dns.stopServer();
+ rv = mNetdSrv->removePrivateDnsServer(listen_addr);
+ tls.stopServer();
+ dns.stopServer();
+ }
}
TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {