Correctly handle the case where sendmsg sends fewer bytes than requested.
Disallow kNonBlocking in UnixSocket as it does not work as expected.
Change-Id: I64747acb832b999e31e7b5cca0c583e0e7974532
Bug: 117139237
diff --git a/include/perfetto/base/unix_socket.h b/include/perfetto/base/unix_socket.h
index 7daf5a2..4a613a0 100644
--- a/include/perfetto/base/unix_socket.h
+++ b/include/perfetto/base/unix_socket.h
@@ -54,6 +54,16 @@
base::ScopedFile CreateSocket();
+// Updates |msg| (and its iovecs, in place) so that a subsequent sendmsg() only
+// sends the data that remains after the first |n| bytes have been sent.
+// Do not call this directly: it is exported for test use only.
+void ShiftMsgHdr(size_t n, struct msghdr* msg);
+
+// Calls sendmsg() repeatedly until all of |msg| has been sent or an error
+// occurs. Only meant for blocking sockets; |msg| is modified in the process.
+// TODO(fmayer): Figure out how to do timeouts here for heapprofd.
+ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags);
+
// A non-blocking UNIX domain socket in SOCK_STREAM mode. Allows also to
// transfer file descriptors. None of the methods in this class are blocking.
// The main design goal is API simplicity and strong guarantees on the
@@ -170,6 +180,8 @@
// EventListener::OnDisconnect() will be called.
// If the socket is not connected, Send() will just return false.
// Does not append a null string terminator to msg in any case.
+ //
+ // DO NOT PASS kNonBlocking, it is broken.
bool Send(const void* msg,
size_t len,
int send_fd = -1,
@@ -179,7 +191,8 @@
const int* send_fds,
size_t num_fds,
BlockingMode blocking = BlockingMode::kNonBlocking);
- bool Send(const std::string& msg);
+ // Overload for std::string payloads: sends |msg| including its null
+ // terminator. The kNonBlocking warning above applies here as well.
+ bool Send(const std::string& msg,
+           BlockingMode blocking = BlockingMode::kNonBlocking);
// Returns the number of bytes (<= |len|) written in |msg| or 0 if there
// is no data in the buffer to read or an error occurs (in which case a
diff --git a/src/base/unix_socket.cc b/src/base/unix_socket.cc
index e3ed5f2..fb4f5ff 100644
--- a/src/base/unix_socket.cc
+++ b/src/base/unix_socket.cc
@@ -64,6 +64,55 @@
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
#endif
+// Advances |msg| past the first |n| bytes of its payload, so that the next
+// sendmsg() call only transmits what has not been sent yet.
+//
+// Fully consumed iovecs are skipped by moving |msg->msg_iov| forward and
+// shrinking |msg->msg_iovlen|; a partially consumed iovec is trimmed in place.
+// Once every byte has been consumed, |msg->msg_iov| is set to nullptr and
+// |msg->msg_iovlen| to 0. |n| must not exceed the total payload size
+// (enforced by PERFETTO_CHECK).
+void ShiftMsgHdr(size_t n, struct msghdr* msg) {
+  // The type of msg_iovlen is not the same on every platform (size_t with
+  // glibc, int per POSIX), so derive the index type from the field itself
+  // rather than hard-coding size_t.
+  using LenType = decltype(msg->msg_iovlen);
+  for (LenType i = 0; i < msg->msg_iovlen; ++i) {
+    struct iovec* vec = &msg->msg_iov[i];
+    if (n < vec->iov_len) {
+      // We sent a part of this iovec: trim it and make it the new first one.
+      vec->iov_base = static_cast<char*>(vec->iov_base) + n;
+      vec->iov_len -= n;
+      msg->msg_iov = vec;
+      msg->msg_iovlen -= i;
+      return;
+    }
+    // We sent the whole iovec.
+    n -= vec->iov_len;
+  }
+  // We sent all the iovecs.
+  PERFETTO_CHECK(n == 0);
+  msg->msg_iovlen = 0;
+  msg->msg_iov = nullptr;
+}
+
+// Sends the whole of |msg| on |sockfd|, re-entering sendmsg() whenever it
+// sends fewer bytes than requested. Only meaningful for blocking sockets.
+//
+// |msg| is consumed in the process: its iovecs are advanced via ShiftMsgHdr()
+// and its ancillary data is cleared after the first successful sendmsg().
+//
+// Returns the total number of bytes sent once everything went out. If
+// sendmsg() fails with EAGAIN/EWOULDBLOCK, returns the number of bytes sent so
+// far; on any other failure (or if sendmsg() returns 0), returns sendmsg()'s
+// return value as is.
+//
+// For the interested reader, Linux kernel dive to verify this is not only a
+// theoretical possibility: sock_stream_sendmsg, if sock_alloc_send_pskb returns
+// NULL [1] (which it does when it gets interrupted [2]), returns early with the
+// amount of bytes already sent.
+//
+// [1]:
+// https://elixir.bootlin.com/linux/v4.18.10/source/net/unix/af_unix.c#L1872
+// [2]: https://elixir.bootlin.com/linux/v4.18.10/source/net/core/sock.c#L2101
+ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags) {
+  // This does not make sense on non-blocking sockets.
+  PERFETTO_DCHECK((fcntl(sockfd, F_GETFL, 0) & O_NONBLOCK) == 0);
+
+  ssize_t total_sent = 0;
+  // ShiftMsgHdr() nulls out msg_iov once every byte has been consumed.
+  while (msg->msg_iov) {
+    ssize_t sent = PERFETTO_EINTR(sendmsg(sockfd, msg, flags));
+    if (sent <= 0) {
+      if (sent == -1 && (errno == EAGAIN || errno == EWOULDBLOCK))
+        return total_sent;
+      return sent;
+    }
+    total_sent += sent;
+    ShiftMsgHdr(static_cast<size_t>(sent), msg);
+    // Only send the ancillary data with the first sendmsg call.
+    msg->msg_control = nullptr;
+    msg->msg_controllen = 0;
+  }
+  return total_sent;
+}
+
ssize_t SockSend(int fd,
const void* msg,
size_t len,
@@ -90,7 +139,7 @@
msg_hdr.msg_controllen = cmsg->cmsg_len;
}
- return PERFETTO_EINTR(sendmsg(fd, &msg_hdr, kNoSigPipe));
+ return SendMsgAll(fd, &msg_hdr, kNoSigPipe);
}
ssize_t SockReceive(int fd,
@@ -396,8 +445,8 @@
}
}
-bool UnixSocket::Send(const std::string& msg) {
- return Send(msg.c_str(), msg.size() + 1);
+// std::string overload: forwards to the raw-buffer Send() without a file
+// descriptor. The length is size() + 1, so the terminating null character of
+// c_str() is part of what gets sent.
+bool UnixSocket::Send(const std::string& msg, BlockingMode blocking) {
+  const size_t len_with_terminator = msg.size() + 1;
+  return Send(msg.c_str(), len_with_terminator, -1 /*send_fd*/, blocking);
}
bool UnixSocket::Send(const void* msg,
@@ -414,6 +463,10 @@
const int* send_fds,
size_t num_fds,
BlockingMode blocking_mode) {
+ // TODO(b/117139237): Non-blocking sends are broken because we do not
+ // properly handle partial sends.
+ PERFETTO_DCHECK(blocking_mode == BlockingMode::kBlocking);
+
if (state_ != State::kConnected) {
errno = last_error_ = ENOTCONN;
return false;
diff --git a/src/base/unix_socket_unittest.cc b/src/base/unix_socket_unittest.cc
index 5cdd6ab..382f9bc 100644
--- a/src/base/unix_socket_unittest.cc
+++ b/src/base/unix_socket_unittest.cc
@@ -41,6 +41,7 @@
using ::testing::Mock;
constexpr char kSocketName[] = TEST_SOCK_NAME("unix_socket_unittest");
+constexpr auto kBlocking = UnixSocket::BlockingMode::kBlocking;
class MockEventListener : public UnixSocket::EventListener {
public:
@@ -115,7 +116,7 @@
auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
.WillOnce(InvokeWithoutArgs(cli_disconnected));
- EXPECT_FALSE(cli->Send("whatever"));
+ EXPECT_FALSE(cli->Send("whatever", kBlocking));
task_runner_.RunUntilCheckpoint("cli_disconnected");
}
@@ -153,8 +154,8 @@
ASSERT_EQ("cli>srv", s->ReceiveString());
srv_did_recv();
}));
- ASSERT_TRUE(cli->Send("cli>srv"));
- ASSERT_TRUE(srv_conn->Send("srv>cli"));
+ ASSERT_TRUE(cli->Send("cli>srv", kBlocking));
+ ASSERT_TRUE(srv_conn->Send("srv>cli", kBlocking));
task_runner_.RunUntilCheckpoint("cli_did_recv");
task_runner_.RunUntilCheckpoint("srv_did_recv");
@@ -168,8 +169,8 @@
ASSERT_EQ("", cli->ReceiveString());
ASSERT_EQ(0u, srv_conn->Receive(&msg, sizeof(msg)));
ASSERT_EQ("", srv_conn->ReceiveString());
- ASSERT_FALSE(cli->Send("foo"));
- ASSERT_FALSE(srv_conn->Send("bar"));
+ ASSERT_FALSE(cli->Send("foo", kBlocking));
+ ASSERT_FALSE(srv_conn->Send("bar", kBlocking));
srv->Shutdown(true);
task_runner_.RunUntilCheckpoint("cli_disconnected");
task_runner_.RunUntilCheckpoint("srv_disconnected");
@@ -244,10 +245,10 @@
int buf_fd[2] = {null_fd.get(), zero_fd.get()};
- ASSERT_TRUE(
- cli->Send(cli_str, sizeof(cli_str), buf_fd, base::ArraySize(buf_fd)));
+ ASSERT_TRUE(cli->Send(cli_str, sizeof(cli_str), buf_fd,
+ base::ArraySize(buf_fd), kBlocking));
ASSERT_TRUE(srv_conn->Send(srv_str, sizeof(srv_str), buf_fd,
- base::ArraySize(buf_fd)));
+ base::ArraySize(buf_fd), kBlocking));
task_runner_.RunUntilCheckpoint("srv_did_recv");
task_runner_.RunUntilCheckpoint("cli_did_recv");
@@ -300,7 +301,7 @@
EXPECT_CALL(event_listener_, OnDataAvailable(s))
.WillOnce(Invoke([](UnixSocket* t) {
ASSERT_EQ("PING", t->ReceiveString());
- ASSERT_TRUE(t->Send("PONG"));
+ ASSERT_TRUE(t->Send("PONG", kBlocking));
}));
}));
@@ -309,7 +310,7 @@
EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true))
.WillOnce(Invoke([](UnixSocket* s, bool success) {
ASSERT_TRUE(success);
- ASSERT_TRUE(s->Send("PING"));
+ ASSERT_TRUE(s->Send("PING", kBlocking));
}));
auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
@@ -356,7 +357,7 @@
.WillOnce(Invoke(
[this, tmp_fd, checkpoint, mem](UnixSocket*, UnixSocket* new_conn) {
ASSERT_EQ(geteuid(), static_cast<uint32_t>(new_conn->peer_uid()));
- ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd));
+ ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd, kBlocking));
// Wait for the client to change this again.
EXPECT_CALL(event_listener_, OnDataAvailable(new_conn))
.WillOnce(Invoke([checkpoint, mem](UnixSocket* s) {
@@ -391,7 +392,7 @@
// Now change the shared memory and ping the other process.
memcpy(mem, "rock more", 10);
- ASSERT_TRUE(s->Send("change notify"));
+ ASSERT_TRUE(s->Send("change notify", kBlocking));
checkpoint();
}));
task_runner_.RunUntilCheckpoint("change_seen_by_client");
@@ -409,7 +410,7 @@
int num_frame) {
char buf[kAtomicWrites_FrameSize];
memset(buf, static_cast<char>(num_frame), sizeof(buf));
- if (s->Send(buf, sizeof(buf)))
+ if (s->Send(buf, sizeof(buf), -1, kBlocking))
return true;
task_runner->PostTask(
std::bind(&AtomicWrites_SendAttempt, s, task_runner, num_frame));
@@ -560,8 +561,7 @@
char buf[1024 * 32] = {};
tx_task_runner.PostTask([&cli, &buf, all_sent] {
for (size_t i = 0; i < kTotalBytes / sizeof(buf); i++)
- cli->Send(buf, sizeof(buf), -1 /*fd*/,
- UnixSocket::BlockingMode::kBlocking);
+ cli->Send(buf, sizeof(buf), -1 /*fd*/, kBlocking);
all_sent();
});
tx_task_runner.RunUntilCheckpoint("all_sent", kTimeoutMs);
@@ -608,8 +608,7 @@
static constexpr size_t kBufSize = 32 * 1024 * 1024;
std::unique_ptr<char[]> buf(new char[kBufSize]());
tx_task_runner.PostTask([&cli, &buf, send_done] {
- bool send_res = cli->Send(buf.get(), kBufSize, -1 /*fd*/,
- UnixSocket::BlockingMode::kBlocking);
+ bool send_res = cli->Send(buf.get(), kBufSize, -1 /*fd*/, kBlocking);
ASSERT_FALSE(send_res);
send_done();
});
@@ -620,6 +619,166 @@
tx_thread.join();
}
+// Consumes the two-iovec message in three steps: one byte of the first iovec,
+// then the rest of the first iovec, then the whole second iovec.
+TEST_F(UnixSocketTest, ShiftMsgHdrSendPartialFirst) {
+  // Send a part of the first iov, then send the rest.
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  // The type of msg_iovlen differs between platforms, hence the casts to int
+  // before comparing against integer literals.
+  ShiftMsgHdr(1, &hdr);
+  EXPECT_NE(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iov[0].iov_base, &hello[1]);
+  EXPECT_EQ(hdr.msg_iov[1].iov_base, &world[0]);
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 2);
+  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "ello");
+  EXPECT_EQ(iov[0].iov_len, base::ArraySize(hello) - 1);
+
+  ShiftMsgHdr(base::ArraySize(hello) - 1, &hdr);
+  EXPECT_EQ(hdr.msg_iov, &iov[1]);
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 1);
+  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), world);
+  EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world));
+
+  ShiftMsgHdr(base::ArraySize(world), &hdr);
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
+}
+
+// Consumes the whole first iovec plus one byte of the second in a single
+// step, then the remainder of the second iovec.
+TEST_F(UnixSocketTest, ShiftMsgHdrSendFirstAndPartial) {
+  // Send first iov and part of the second iov, then send the rest.
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  // The type of msg_iovlen differs between platforms, hence the casts to int
+  // before comparing against integer literals.
+  ShiftMsgHdr(base::ArraySize(hello) + 1, &hdr);
+  EXPECT_NE(hdr.msg_iov, nullptr);
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 1);
+  EXPECT_STREQ(static_cast<char*>(hdr.msg_iov[0].iov_base), "orld");
+  EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world) - 1);
+
+  ShiftMsgHdr(base::ArraySize(world) - 1, &hdr);
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
+}
+
+// Consumes both iovecs in a single step and expects an emptied msghdr.
+TEST_F(UnixSocketTest, ShiftMsgHdrSendEverything) {
+  // Send everything at once.
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  ShiftMsgHdr(base::ArraySize(world) + base::ArraySize(hello), &hdr);
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  // The type of msg_iovlen differs between platforms, hence the cast to int
+  // before comparing against an integer literal.
+  EXPECT_EQ(static_cast<int>(hdr.msg_iovlen), 0);
+}
+
+// Intentionally empty signal handler; it is installed for SIGWINCH by
+// PartialSendMsgAll so that delivering the signal interrupts sendmsg().
+void Handler(int /*signum*/) {}
+
+// Installs |act| as the SIGWINCH disposition, discarding the previous one, and
+// returns the result of sigaction(). Used as the cleanup function of the
+// ScopedResource in PartialSendMsgAll.
+int RollbackSigaction(const struct sigaction* act) {
+  struct sigaction* const discard_previous = nullptr;
+  return sigaction(SIGWINCH, act, discard_previous);
+}
+
+// Checks that SendMsgAll delivers the complete payload even when the blocking
+// sendmsg() is interrupted by a signal after part of the data went out.
+//
+// The sender (this thread) writes more than the socket buffers can hold, so
+// sendmsg() blocks. The receiver thread reads one byte, signals the sender
+// with SIGWINCH and then drains the remaining data.
+TEST_F(UnixSocketTest, PartialSendMsgAll) {
+  int sv[2];
+  ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
+  base::ScopedFile send_socket(sv[0]);
+  base::ScopedFile recv_socket(sv[1]);
+
+  // Set bufsize to minimum.
+  int bufsize = 1024;
+  ASSERT_EQ(setsockopt(*send_socket, SOL_SOCKET, SO_SNDBUF, &bufsize,
+                       sizeof(bufsize)),
+            0);
+  ASSERT_EQ(setsockopt(*recv_socket, SOL_SOCKET, SO_RCVBUF, &bufsize,
+                       sizeof(bufsize)),
+            0);
+
+  // Send something larger than send + recv kernel buffers combined to make
+  // sendmsg block.
+  char send_buf[8192];
+  // Make MSAN happy.
+  for (size_t i = 0; i < sizeof(send_buf); ++i)
+    send_buf[i] = static_cast<char>(i % 256);
+  char recv_buf[sizeof(send_buf)];
+
+  // Need to install signal handler to cause the interrupt to happen.
+  // man 3 pthread_kill:
+  //   Signal dispositions are process-wide: if a signal handler is
+  //   installed, the handler will be invoked in the thread thread, but if
+  //   the disposition of the signal is "stop", "continue", or "terminate",
+  //   this action will affect the whole process.
+  struct sigaction oldact;
+  struct sigaction newact = {};
+  newact.sa_handler = Handler;
+  ASSERT_EQ(sigaction(SIGWINCH, &newact, &oldact), 0);
+  base::ScopedResource<const struct sigaction*, RollbackSigaction, nullptr>
+      rollback(&oldact);
+
+  auto blocked_thread = pthread_self();
+  std::thread th([blocked_thread, &recv_socket, &recv_buf] {
+    ssize_t rd = PERFETTO_EINTR(read(*recv_socket, recv_buf, 1));
+    ASSERT_EQ(rd, 1);
+    // We are now sure the other thread is in sendmsg, interrupt send.
+    ASSERT_EQ(pthread_kill(blocked_thread, SIGWINCH), 0);
+    // Drain the socket to allow SendMsgAll to succeed.
+    size_t offset = 1;
+    while (offset < sizeof(recv_buf)) {
+      rd = PERFETTO_EINTR(
+          read(*recv_socket, recv_buf + offset, sizeof(recv_buf) - offset));
+      // read() returning 0 means the sender closed the socket before all the
+      // data arrived. Bail out instead of spinning forever on EOF.
+      ASSERT_GT(rd, 0);
+      offset += static_cast<size_t>(rd);
+    }
+  });
+
+  // Test sending the send_buf in several chunks as an iov to exercise the
+  // more complicated code-paths of SendMsgAll.
+  struct msghdr hdr = {};
+  struct iovec iov[4];
+  static_assert(sizeof(send_buf) % base::ArraySize(iov) == 0,
+                "Cannot split buffer into even pieces.");
+  constexpr size_t kChunkSize = sizeof(send_buf) / base::ArraySize(iov);
+  for (size_t i = 0; i < base::ArraySize(iov); ++i) {
+    iov[i].iov_base = send_buf + i * kChunkSize;
+    iov[i].iov_len = kChunkSize;
+  }
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  // Close the sending end and join the receiver before asserting anything:
+  // a fatal assertion failure returns from the test body, and destroying a
+  // std::thread that is still joinable would terminate the process.
+  const ssize_t sent = SendMsgAll(*send_socket, &hdr, 0);
+  send_socket.reset();
+  th.join();
+  ASSERT_EQ(sent, static_cast<ssize_t>(sizeof(send_buf)));
+  // SendMsgAll consumes |hdr| while sending; a null msg_iov means that every
+  // iovec has been sent in full.
+  ASSERT_EQ(hdr.msg_iov, nullptr);
+  ASSERT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
+}
+
// TODO(primiano): add a test to check that in the case of a peer sending a fd
// and the other end just doing a recv (without taking it), the fd is closed and
// not left around.
diff --git a/src/ipc/client_impl_unittest.cc b/src/ipc/client_impl_unittest.cc
index 445e3d0..050fcca 100644
--- a/src/ipc/client_impl_unittest.cc
+++ b/src/ipc/client_impl_unittest.cc
@@ -182,7 +182,8 @@
void Reply(const Frame& frame) {
auto buf = BufferedFrameDeserializer::Serialize(frame);
ASSERT_TRUE(client_sock->is_connected());
- EXPECT_TRUE(client_sock->Send(buf.data(), buf.size(), next_reply_fd));
+ EXPECT_TRUE(client_sock->Send(buf.data(), buf.size(), next_reply_fd,
+ base::UnixSocket::BlockingMode::kBlocking));
next_reply_fd = -1;
}
diff --git a/src/ipc/host_impl_unittest.cc b/src/ipc/host_impl_unittest.cc
index 3ff1f99..26b56c8 100644
--- a/src/ipc/host_impl_unittest.cc
+++ b/src/ipc/host_impl_unittest.cc
@@ -152,7 +152,8 @@
void SendFrame(const Frame& frame, int fd = -1) {
std::string buf = BufferedFrameDeserializer::Serialize(frame);
- ASSERT_TRUE(sock_->Send(buf.data(), buf.size(), fd));
+ ASSERT_TRUE(sock_->Send(buf.data(), buf.size(), fd,
+ base::UnixSocket::BlockingMode::kBlocking));
}
BufferedFrameDeserializer frame_deserializer_;
diff --git a/src/profiling/memory/socket_listener_unittest.cc b/src/profiling/memory/socket_listener_unittest.cc
index 58d5e8f..dd3c1a7 100644
--- a/src/profiling/memory/socket_listener_unittest.cc
+++ b/src/profiling/memory/socket_listener_unittest.cc
@@ -71,8 +71,10 @@
base::ScopedFile(open("/dev/null", O_RDONLY))};
int raw_fds[2] = {*fds[0], *fds[1]};
ASSERT_TRUE(client_socket->Send(&size, sizeof(size), raw_fds,
- base::ArraySize(raw_fds)));
- ASSERT_TRUE(client_socket->Send("1", 1));
+ base::ArraySize(raw_fds),
+ base::UnixSocket::BlockingMode::kBlocking));
+ ASSERT_TRUE(client_socket->Send("1", 1, -1,
+ base::UnixSocket::BlockingMode::kBlocking));
task_runner.RunUntilCheckpoint("callback.called");
}
diff --git a/src/profiling/memory/wire_protocol.cc b/src/profiling/memory/wire_protocol.cc
index 39373f9..57cea48 100644
--- a/src/profiling/memory/wire_protocol.cc
+++ b/src/profiling/memory/wire_protocol.cc
@@ -17,6 +17,7 @@
#include "src/profiling/memory/wire_protocol.h"
#include "perfetto/base/logging.h"
+#include "perfetto/base/unix_socket.h"
#include "perfetto/base/utils.h"
#include <sys/socket.h>
@@ -68,7 +69,7 @@
total_size = iovecs[1].iov_len + iovecs[2].iov_len;
}
- ssize_t sent = sendmsg(sock, &hdr, MSG_NOSIGNAL);
+ ssize_t sent = base::SendMsgAll(sock, &hdr, MSG_NOSIGNAL);
return sent == static_cast<ssize_t>(total_size + sizeof(total_size));
}