Kill sockets when a VPN comes up.
1. Change the SockDiag callback function to be a filter that
returns a bool instead of a function that optionally kills a
socket. All existing callbacks basically only existed to kill
sockets under certain conditions, and making them return a
boolean allows reusing the same callback function signature
to filter sockets as well.
2. Add a new SockDiag method to kill sockets based on a UidRanges
object (which contains a number of UID ranges) and a list of
users to skip.
3. Add a new UIDRANGE mode to SockDiagTest to test the above.
4. When UID ranges are added to or removed from the VPN, kill
sockets in those UID ranges unless the socket's UID is in
mProtectableUsers, in which case its creator might have set the
protect bit on the socket's mark. Short of actually being
able to see the socket mark on each socket and basing our
decision on that, this is the best we can do.
Bug: 26976388
Change-Id: I53a30df3feb63254a6451a29fa6041c9b679f9bb
diff --git a/server/SockDiagTest.cpp b/server/SockDiagTest.cpp
index 6425c67..2061a3b 100644
--- a/server/SockDiagTest.cpp
+++ b/server/SockDiagTest.cpp
@@ -24,12 +24,7 @@
#include "NetdConstants.h"
#include "SockDiag.h"
-
-
-#define NUM_SOCKETS 500
-#define START_UID 8000 // START_UID + NUM_SOCKETS must be <= 9999.
-#define CLOSE_UID (START_UID + NUM_SOCKETS - 42) // Close to the end
-
+#include "UidRanges.h"
class SockDiagTest : public ::testing::Test {
};
@@ -104,7 +99,7 @@
if (msg == nullptr) {
EXPECT_FALSE(seenNull);
seenNull = true;
- return 0;
+ return false;
}
EXPECT_EQ(htonl(INADDR_LOOPBACK), msg->id.idiag_src[0]);
v4SocketsSeen++;
@@ -115,7 +110,7 @@
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
- return 0;
+ return false;
};
int v6SocketsSeen = 0;
@@ -125,7 +120,7 @@
if (msg == nullptr) {
EXPECT_FALSE(seenNull);
seenNull = true;
- return 0;
+ return false;
}
struct in6_addr *saddr = (struct in6_addr *) msg->id.idiag_src;
EXPECT_TRUE(
@@ -141,7 +136,7 @@
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
- return 0;
+ return false;
};
SockDiag sd;
@@ -183,6 +178,7 @@
enum MicroBenchmarkTestType {
ADDRESS,
UID,
+ UIDRANGE, // Destroy sockets by a set of UID ranges, minus a set of UIDs to skip.
};
const char *testTypeName(MicroBenchmarkTestType mode) {
@@ -190,6 +186,7 @@
switch((mode)) {
TO_STRING_TYPE(ADDRESS);
TO_STRING_TYPE(UID);
+ TO_STRING_TYPE(UIDRANGE);
}
#undef TO_STRING_TYPE
}
@@ -204,16 +201,44 @@
protected:
SockDiag mSd;
+ // Capacity of the per-test socket and port arrays in TestMicroBenchmark.
+ constexpr static int MAX_SOCKETS = 500;
+ constexpr static int ADDRESS_SOCKETS = 500;
+ constexpr static int UID_SOCKETS = 100;
+ // Socket i is chowned to START_UID + i; the static_assert below keeps
+ // START_UID + MAX_SOCKETS under 9999.
+ constexpr static uid_t START_UID = 8000;
+ constexpr static int CLOSE_UID = START_UID + UID_SOCKETS - 42; // Close to the end
+ static_assert(START_UID + MAX_SOCKETS < 9999, "Too many sockets");
+
+ // Returns how many client/server socket pairs the benchmark opens for the
+ // current test mode.
+ int howManySockets() {
+ MicroBenchmarkTestType mode = GetParam();
+ switch (mode) {
+ case ADDRESS:
+ return ADDRESS_SOCKETS;
+ case UID:
+ case UIDRANGE:
+ // Use the named constant rather than a literal: the count must be
+ // larger than CLOSE_UID - START_UID (UID mode) and must reach the
+ // highest UID in the UIDRANGE ranges, otherwise the sockets those
+ // modes expect to be closed are never created.
+ return UID_SOCKETS;
+ }
+ }
+
+ // Destroys sockets according to the current test mode and returns the value
+ // returned by SockDiag. A negative value is reported as a test failure,
+ // with the error text taken from strerror(-ret).
int destroySockets() {
MicroBenchmarkTestType mode = GetParam();
int ret;
- if (mode == ADDRESS) {
- ret = mSd.destroySockets("::1");
- EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
- } else {
- ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
- EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
- strerror(-ret);
+ switch (mode) {
+ case ADDRESS:
+ ret = mSd.destroySockets("::1");
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+ break;
+ case UID:
+ ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
+ strerror(-ret);
+ break;
+ case UIDRANGE: {
+ // These ranges and skipped UIDs must stay in sync with the
+ // expectations hard-coded in shouldHaveClosedSocket().
+ const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
+ std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
+ UidRanges uidRanges;
+ uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
+ ret = mSd.destroySockets(uidRanges, skipUids);
+ // Check the result like the other modes do, so a failure is not silent.
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID ranges: " << strerror(-ret);
+ break;
+ }
}
return ret;
}
@@ -221,10 +246,22 @@
+ // Returns whether socket number i (owned by UID START_UID + i) is expected
+ // to have been closed by destroySockets() in the current test mode.
bool shouldHaveClosedSocket(int i) {
MicroBenchmarkTestType mode = GetParam();
switch (mode) {
- case ADDRESS:
- return true;
- case UID:
- return i == CLOSE_UID - START_UID;
+ case ADDRESS:
+ return true;
+ case UID:
+ return i == CLOSE_UID - START_UID;
+ case UIDRANGE: {
+ const uid_t socketUid = START_UID + i;
+ // UIDs listed in skipUids are never closed, even when a range covers them.
+ const bool skipped = socketUid == 8007 || socketUid == 8043 ||
+ socketUid == 8098 || socketUid == 8099;
+ // UIDs covered by uidRanges.
+ const bool inRanges = (8005 <= socketUid && socketUid <= 8012) ||
+ socketUid == 8042 ||
+ (8090 <= socketUid && socketUid <= 8099);
+ return !skipped && inRanges;
+ }
}
}
@@ -251,8 +288,10 @@
TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
MicroBenchmarkTestType mode = GetParam();
+ int numSockets = howManySockets();
+
fprintf(stderr, "Benchmarking closing %d sockets based on %s\n",
- NUM_SOCKETS, testTypeName(mode));
+ numSockets, testTypeName(mode));
int listensocket = socket(AF_INET6, SOCK_STREAM, 0);
ASSERT_NE(-1, listensocket) << "Failed to open listen socket";
@@ -263,13 +302,13 @@
using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
- int clientsockets[NUM_SOCKETS], serversockets[NUM_SOCKETS];
- uint16_t clientports[NUM_SOCKETS];
+ int clientsockets[MAX_SOCKETS], serversockets[MAX_SOCKETS];
+ uint16_t clientports[MAX_SOCKETS];
sockaddr_in6 client;
socklen_t clientlen;
auto start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
int s = socket(AF_INET6, SOCK_STREAM, 0);
uid_t uid = START_UID + i;
ASSERT_EQ(0, fchown(s, uid, -1));
@@ -291,7 +330,7 @@
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
checkSocketState(i, clientsockets[i], "Client socket");
checkSocketState(i, serversockets[i], "Server socket");
}
@@ -299,7 +338,7 @@
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
close(clientsockets[i]);
close(serversockets[i]);
}
@@ -309,4 +348,8 @@
close(listensocket);
}
-INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest, testing::Values(ADDRESS, UID));
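+// CLOSE_UID is streamed into a test failure message, which presumably odr-uses it, so
+// the static constexpr member needs an out-of-class definition. Without it the link fails with:
+// "SockDiagTest.cpp:232: error: undefined reference to 'SockDiagMicroBenchmarkTest::CLOSE_UID'".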
+constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
+
+INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
+ testing::Values(ADDRESS, UID, UIDRANGE));