Copy queries synchronously in DnsTlsSocket
Prior to this change, each outgoing query was copied only once,
on DnsTlsSocket's loop thread. This could create a problem
if a misbehaving server sent an erroneous response with a
colliding ID number after the query was given to DnsTlsSocket
but before the copy was made. The erroneous response would
complete the query, causing the caller to deallocate the backing
buffer, resulting in a segfault on copy.
This change moves the copy earlier, onto the calling thread, thus
ensuring that the backing buffer is still valid when it is copied.
Instead of sending the network thread pointers to query buffers,
copies of queries are stored in a shared queue, and the network
thread is notified of new queries via an eventfd.
Bug: 122133500
Test: Integration tests pass; manual testing is good. No regression test.
Change-Id: Ia4e72da561aeef69a17e87bfdc7aa04340c12fd0
diff --git a/resolv/DnsTlsSocket.cpp b/resolv/DnsTlsSocket.cpp
index b4d0081..d1455c1 100644
--- a/resolv/DnsTlsSocket.cpp
+++ b/resolv/DnsTlsSocket.cpp
@@ -19,14 +19,15 @@
#include "netd_resolv/DnsTlsSocket.h"
-#include <algorithm>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <linux/tcp.h>
#include <openssl/err.h>
#include <openssl/sha.h>
+#include <sys/eventfd.h>
#include <sys/poll.h>
+#include <algorithm>
#include "netd_resolv/DnsTlsSessionCache.h"
#include "netd_resolv/IDnsTlsSocketObserver.h"
@@ -163,14 +164,8 @@
if (!mSsl) {
return false;
}
- int sv[2];
- if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
- return false;
- }
- // The two sockets are perfectly symmetrical, so the choice of which one is
- // "in" and which one is "out" is arbitrary.
- mIpcInFd.reset(sv[0]);
- mIpcOutFd.reset(sv[1]);
+
+ mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
// Start the I/O loop.
mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
@@ -338,26 +333,25 @@
void DnsTlsSocket::loop() {
std::lock_guard guard(mLock);
- // Buffer at most one query.
- Query q;
+ std::deque<std::vector<uint8_t>> q;
const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
while (true) {
// poll() ignores negative fds
struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
- enum { SSLFD = 0, IPCFD = 1 };
+ enum { SSLFD = 0, EVENTFD = 1 };
// Always listen for a response from server.
fds[SSLFD].fd = mSslFd.get();
fds[SSLFD].events = POLLIN;
- // If we have a pending query, also wait for space
- // to write it, otherwise listen for a new query.
- if (!q.query.empty()) {
+ // If we have pending queries, wait for space to write one.
+ // Otherwise, listen for new queries.
+ if (!q.empty()) {
fds[SSLFD].events |= POLLOUT;
} else {
- fds[IPCFD].fd = mIpcOutFd.get();
- fds[IPCFD].events = POLLIN;
+ fds[EVENTFD].fd = mEventFd.get();
+ fds[EVENTFD].events = POLLIN;
}
const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs));
@@ -375,28 +369,44 @@
break;
}
}
- if (fds[IPCFD].revents & (POLLIN | POLLERR)) {
- int res = read(mIpcOutFd.get(), &q, sizeof(q));
+ if (fds[EVENTFD].revents & (POLLIN | POLLERR)) {
+ int64_t num_queries;
+ ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries));
if (res < 0) {
- ALOGW("Error during IPC read");
+ ALOGW("Error during eventfd read");
break;
} else if (res == 0) {
- ALOGV("IPC channel closed; disconnecting");
+ ALOGV("eventfd closed; disconnecting");
break;
- } else if (res != sizeof(q)) {
- ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
+ } else if (res != sizeof(num_queries)) {
+ ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
+ break;
+ } else if (num_queries <= 0) {
+ ALOGE("eventfd reads should always be positive");
+ break;
+ }
+ // Take ownership of all pending queries. (q is always empty here.)
+ mQueue.swap(q);
+ // The writing thread writes to mQueue and then increments mEventFd, so
+ // there should be at least num_queries entries in mQueue.
+ if (q.size() < (uint64_t) num_queries) {
+ ALOGE("Synchronization error");
break;
}
} else if (fds[SSLFD].revents & POLLOUT) {
- // query cannot be null here.
- if (!sendQuery(q)) {
+ // q cannot be empty here.
+ // Sending the entire queue here would risk a TCP flow control deadlock, so
+ // we only send a single query on each cycle of this loop.
+ // TODO: Coalesce multiple pending queries if there is enough space in the
+ // write buffer.
+ if (!sendQuery(q.front())) {
break;
}
- q = Query(); // Reset q to empty
+ q.pop_front();
}
}
- ALOGV("Closing IPC read FD");
- mIpcOutFd.reset();
+ ALOGV("Closing event FD");
+ mEventFd.reset();
ALOGV("Disconnecting");
sslDisconnect();
ALOGV("Calling onClosed");
@@ -407,7 +417,12 @@
DnsTlsSocket::~DnsTlsSocket() {
ALOGV("Destructor");
// This will trigger an orderly shutdown in loop().
- mIpcInFd.reset();
+ // In principle there is a data race here: If there is an I/O error in the network thread
+ // simultaneous with a call to the destructor in a different thread, both threads could
+ // attempt to call mEventFd.reset() at the same time. However, the implementation of
+ // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
+ // after this point, so we don't expect an issue in practice.
+ mEventFd.reset();
{
// Wait for the orderly shutdown to complete.
std::lock_guard guard(mLock);
@@ -425,12 +440,28 @@
}
bool DnsTlsSocket::query(uint16_t id, const Slice query) {
- const Query q = { .id = id, .query = query };
- if (!mIpcInFd) {
+ if (!mEventFd) {
return false;
}
- int written = write(mIpcInFd.get(), &q, sizeof(q));
- return written == sizeof(q);
+
+ // Compose the entire message in a single buffer, so that it can be
+ // sent as a single TLS record.
+ std::vector<uint8_t> buf(query.size() + 4);
+ // Write 2-byte length
+ uint16_t len = query.size() + 2; // + 2 for the ID.
+ buf[0] = len >> 8;
+ buf[1] = len;
+ // Write 2-byte ID
+ buf[2] = id >> 8;
+ buf[3] = id;
+ // Copy body
+ std::memcpy(buf.data() + 4, query.base(), query.size());
+
+ mQueue.push(std::move(buf));
+ // Increment the mEventFd counter by 1.
+ constexpr int64_t num_queries = 1;
+ int written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
+ return written == sizeof(num_queries);
}
// Read exactly len bytes into buffer or fail with an SSL error code
@@ -464,20 +495,7 @@
return SSL_ERROR_NONE;
}
-bool DnsTlsSocket::sendQuery(const Query& q) {
- ALOGV("sending query");
- // Compose the entire message in a single buffer, so that it can be
- // sent as a single TLS record.
- std::vector<uint8_t> buf(q.query.size() + 4);
- // Write 2-byte length
- uint16_t len = q.query.size() + 2; // + 2 for the ID.
- buf[0] = len >> 8;
- buf[1] = len;
- // Write 2-byte ID
- buf[2] = q.id >> 8;
- buf[3] = q.id;
- // Copy body
- std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
+bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) {
if (!sslWrite(netdutils::makeSlice(buf))) {
return false;
}