Fix DnsTlsSocket fast shutdown path
Previously, DnsTlsSocket's destructor told the loop thread to
perform a clean shutdown by closing an IPC file descriptor.
However, the IPC file descriptor is now an eventfd, which does
not alert the listening thread when it is closed.
This change writes INT64_MIN to the eventfd, which sets the counter's
sign bit. The loop thread treats a negative eventfd read as an
indication that the destructor is requesting an immediate close.
Test: Includes regression test.
Bug: 123212403
Bug: 124058672
Change-Id: I6edc26bf504cbfbba7d055b1f8e52ac70e02c6e0
diff --git a/resolv/DnsTlsSocket.cpp b/resolv/DnsTlsSocket.cpp
index b203731..6c31269 100644
--- a/resolv/DnsTlsSocket.cpp
+++ b/resolv/DnsTlsSocket.cpp
@@ -348,6 +348,8 @@
// If we have pending queries, wait for space to write one.
// Otherwise, listen for new queries.
+ // Note: This blocks the destructor until q is empty, i.e. until all pending
+ // queries are sent or have failed to send.
if (!q.empty()) {
fds[SSLFD].events |= POLLOUT;
} else {
@@ -364,7 +366,7 @@
ALOGV("Poll failed: %d", errno);
break;
}
- if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
+ if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
if (!readResponse()) {
ALOGV("SSL remote close or read error.");
break;
@@ -377,23 +379,17 @@
ALOGW("Error during eventfd read");
break;
} else if (res == 0) {
- ALOGV("eventfd closed; disconnecting");
+ ALOGW("eventfd closed; disconnecting");
break;
} else if (res != sizeof(num_queries)) {
ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
break;
- } else if (num_queries <= 0) {
- ALOGE("eventfd reads should always be positive");
+ } else if (num_queries < 0) {
+ ALOGV("Negative eventfd read indicates destructor-initiated shutdown");
break;
}
// Take ownership of all pending queries. (q is always empty here.)
mQueue.swap(q);
- // The writing thread writes to mQueue and then increments mEventFd, so
- // there should be at least num_queries entries in mQueue.
- if (q.size() < (uint64_t) num_queries) {
- ALOGE("Synchronization error");
- break;
- }
} else if (fds[SSLFD].revents & POLLOUT) {
// q cannot be empty here.
// Sending the entire queue here would risk a TCP flow control deadlock, so
@@ -406,8 +402,6 @@
q.pop_front();
}
}
- ALOGV("Closing event FD");
- mEventFd.reset();
ALOGV("Disconnecting");
sslDisconnect();
ALOGV("Calling onClosed");
@@ -418,12 +412,7 @@
DnsTlsSocket::~DnsTlsSocket() {
ALOGV("Destructor");
// This will trigger an orderly shutdown in loop().
- // In principle there is a data race here: If there is an I/O error in the network thread
- // simultaneous with a call to the destructor in a different thread, both threads could
- // attempt to call mEventFd.reset() at the same time. However, the implementation of
- // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
- // after this point, so we don't expect an issue in practice.
- mEventFd.reset();
+ requestLoopShutdown();
{
// Wait for the orderly shutdown to complete.
std::lock_guard guard(mLock);
@@ -441,10 +430,6 @@
}
bool DnsTlsSocket::query(uint16_t id, const Slice query) {
- if (!mEventFd) {
- return false;
- }
-
// Compose the entire message in a single buffer, so that it can be
// sent as a single TLS record.
std::vector<uint8_t> buf(query.size() + 4);
@@ -460,9 +445,25 @@
mQueue.push(std::move(buf));
// Increment the mEventFd counter by 1.
- constexpr int64_t num_queries = 1;
- int written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
- return written == sizeof(num_queries);
+ return incrementEventFd(1);
+}
+
+void DnsTlsSocket::requestLoopShutdown() {
+ // Write a negative number to the eventfd. This triggers an immediate shutdown.
+ incrementEventFd(INT64_MIN);
+}
+
+bool DnsTlsSocket::incrementEventFd(const int64_t count) {
+ if (!mEventFd) {
+ ALOGV("eventfd is not initialized");
+ return false;
+ }
+ int written = write(mEventFd.get(), &count, sizeof(count));
+ if (written != sizeof(count)) {
+ ALOGE("Failed to increment eventfd by %" PRId64, count);
+ return false;
+ }
+ return true;
}
// Read exactly len bytes into buffer or fail with an SSL error code
diff --git a/resolv/DnsTlsSocket.h b/resolv/DnsTlsSocket.h
index 22285f1..2940500 100644
--- a/resolv/DnsTlsSocket.h
+++ b/resolv/DnsTlsSocket.h
@@ -62,7 +62,7 @@
// notified that the socket is closed.
// Note that success here indicates successful sending, not receipt of a response.
// Thread-safe.
- bool query(uint16_t id, const netdutils::Slice query) override;
+ bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
private:
// Lock to be held by the SSL event loop thread. This is not normally in contention.
@@ -96,6 +96,15 @@
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
bool readResponse() REQUIRES(mLock);
+ // Similar to query(), this function uses incrementEventFd to send a message to the
+ // loop thread. However, instead of incrementing the counter by one (indicating a
+ // new query), it wraps the counter to negative, which we use to indicate a shutdown
+ // request.
+ void requestLoopShutdown() EXCLUDES(mLock);
+
+ // This function sends a message to the loop thread by incrementing mEventFd.
+ bool incrementEventFd(int64_t count) EXCLUDES(mLock);
+
// Queue of pending queries. query() pushes items onto the queue and notifies
// the loop thread by incrementing mEventFd. loop() reads items off the queue.
LockedQueue<std::vector<uint8_t>> mQueue;
@@ -103,8 +112,10 @@
// eventfd socket used for notifying the SSL thread when queries are ready to send.
// This socket acts similarly to an atomic counter, incremented by query() and cleared
// by loop(). We have to use a socket because the SSL thread needs to wait in poll()
- // for input from either a remote server or a query thread.
- // EOF indicates a close request.
+ // for input from either a remote server or a query thread. Since eventfd does not have
+ // EOF, we indicate a close request by setting the counter to a negative number.
+ // This file descriptor is opened by initialize(), and closed implicitly when this
+ // object is destroyed.
base::unique_fd mEventFd;
// SSL Socket fields.
diff --git a/resolv/dns_tls_test.cpp b/resolv/dns_tls_test.cpp
index d8626f1..0bcad30 100644
--- a/resolv/dns_tls_test.cpp
+++ b/resolv/dns_tls_test.cpp
@@ -29,6 +29,8 @@
#include "IDnsTlsSocketFactory.h"
#include "IDnsTlsSocketObserver.h"
+#include "dns_responder/dns_tls_frontend.h"
+
#include <chrono>
#include <arpa/inet.h>
#include <android-base/macros.h>
@@ -918,5 +920,44 @@
EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
+class StubObserver : public IDnsTlsSocketObserver {
+ public:
+ bool closed = false;
+ void onResponse(std::vector<uint8_t>) override {}
+
+ void onClosed() override { closed = true; }
+};
+
+TEST(DnsTlsSocketTest, SlowDestructor) {
+ constexpr char tls_addr[] = "127.0.0.3";
+ constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required.
+ // This test doesn't perform any queries, so the backend address can be invalid.
+ constexpr char backend_addr[] = "192.0.2.1";
+ constexpr char backend_port[] = "1";
+
+ test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
+ ASSERT_TRUE(tls.startServer());
+
+ DnsTlsServer server;
+ parseServer(tls_addr, 8530, &server.ss);
+
+ StubObserver observer;
+ ASSERT_FALSE(observer.closed);
+ DnsTlsSessionCache cache;
+ auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
+ ASSERT_TRUE(socket->initialize());
+
+ // Test: Time the socket destructor. This should be fast.
+ auto before = std::chrono::steady_clock::now();
+ socket.reset();
+ auto after = std::chrono::steady_clock::now();
+ auto delay = after - before;
+ ALOGV("Shutdown took %lld ns", delay / std::chrono::nanoseconds{1});
+ EXPECT_TRUE(observer.closed);
+ // Shutdown should complete in milliseconds, but if the shutdown signal is lost
+ // it will wait for the timeout, which is expected to take 20 seconds.
+ EXPECT_LT(delay, std::chrono::seconds{5});
+}
+
} // end of namespace net
} // end of namespace android