Correctly handle the case where sendmsg sends fewer bytes than requested.

Disallow kNonBlocking in UnixSocket as it does not work as expected.

Change-Id: I64747acb832b999e31e7b5cca0c583e0e7974532
Bug: 117139237
diff --git a/include/perfetto/base/unix_socket.h b/include/perfetto/base/unix_socket.h
index 7daf5a2..4a613a0 100644
--- a/include/perfetto/base/unix_socket.h
+++ b/include/perfetto/base/unix_socket.h
@@ -54,6 +54,16 @@
 
 base::ScopedFile CreateSocket();
 
+// Updates |msg| so that a subsequent sendmsg sends only the data that remains
+// after |n| bytes have already been sent.
+// Do not use directly: this is exported for test use only.
+void ShiftMsgHdr(size_t n, struct msghdr* msg);
+
+// Re-enters sendmsg until all the data has been sent or an error occurs.
+// Modifies |msg| while doing so and expects |sockfd| to be a blocking socket.
+// TODO(fmayer): Figure out how to do timeouts here for heapprofd.
+ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags);
+
 // A non-blocking UNIX domain socket in SOCK_STREAM mode. Allows also to
 // transfer file descriptors. None of the methods in this class are blocking.
 // The main design goal is API simplicity and strong guarantees on the
@@ -170,6 +180,8 @@
   // EventListener::OnDisconnect() will be called.
   // If the socket is not connected, Send() will just return false.
   // Does not append a null string terminator to msg in any case.
+  //
+  // DO NOT PASS kNonBlocking, it is broken.
   bool Send(const void* msg,
             size_t len,
             int send_fd = -1,
@@ -179,7 +191,8 @@
             const int* send_fds,
             size_t num_fds,
             BlockingMode blocking = BlockingMode::kNonBlocking);
-  bool Send(const std::string& msg);
+  bool Send(const std::string& msg,  // Sent including its '\0' terminator.
+            BlockingMode blocking = BlockingMode::kNonBlocking);
 
   // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
   // is no data in the buffer to read or an error occurs (in which case a
diff --git a/src/base/unix_socket.cc b/src/base/unix_socket.cc
index e3ed5f2..fb4f5ff 100644
--- a/src/base/unix_socket.cc
+++ b/src/base/unix_socket.cc
@@ -64,6 +64,55 @@
 #pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
 #endif
 
+// Advances |msg| past the first |n| bytes; mutates the caller's iovec array.
+void ShiftMsgHdr(size_t n, struct msghdr* msg) {
+  for (size_t i = 0; i < msg->msg_iovlen; ++i) {
+    struct iovec* vec = &msg->msg_iov[i];
+    if (n < vec->iov_len) {
+      // Partially sent iovec: trim its head and make it the new first entry.
+      vec->iov_base = static_cast<char*>(vec->iov_base) + n;
+      vec->iov_len -= n;
+      msg->msg_iov = vec;
+      msg->msg_iovlen -= i;
+      return;
+    }
+    n -= vec->iov_len;  // This iovec was sent in full; skip over it.
+  }
+  // Every iovec was consumed; |n| must match the total length exactly.
+  PERFETTO_CHECK(n == 0);
+  msg->msg_iovlen = 0;
+  msg->msg_iov = nullptr;
+}
+
+// For the interested reader, Linux kernel dive to verify this is not only a
+// theoretical possibility: sock_stream_sendmsg, if sock_alloc_send_pskb returns
+// NULL [1] (which it does when it gets interrupted [2]), returns early with the
+// amount of bytes already sent.
+//
+// [1]:
+// https://elixir.bootlin.com/linux/v4.18.10/source/net/unix/af_unix.c#L1872
+// [2]: https://elixir.bootlin.com/linux/v4.18.10/source/net/core/sock.c#L2101
+ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags) {
+  // This does not make sense on non-blocking sockets.
+  PERFETTO_DCHECK((fcntl(sockfd, F_GETFL, 0) & O_NONBLOCK) == 0);
+
+  ssize_t total_sent = 0;
+  while (msg->msg_iov) {  // ShiftMsgHdr() nulls msg_iov once all is sent.
+    ssize_t sent = PERFETTO_EINTR(sendmsg(sockfd, msg, flags));
+    if (sent <= 0) {
+      if (sent == -1 && (errno == EAGAIN || errno == EWOULDBLOCK))
+        return total_sent;  // Would block: report the bytes sent so far.
+      return sent;  // Any other failure (or 0): propagate sendmsg's result.
+    }
+    total_sent += sent;
+    ShiftMsgHdr(static_cast<size_t>(sent), msg);
+    // Only send the ancillary data with the first sendmsg call.
+    msg->msg_control = nullptr;
+    msg->msg_controllen = 0;
+  }
+  return total_sent;
+}
+
 ssize_t SockSend(int fd,
                  const void* msg,
                  size_t len,
@@ -90,7 +139,7 @@
     msg_hdr.msg_controllen = cmsg->cmsg_len;
   }
 
-  return PERFETTO_EINTR(sendmsg(fd, &msg_hdr, kNoSigPipe));
+  return SendMsgAll(fd, &msg_hdr, kNoSigPipe);
 }
 
 ssize_t SockReceive(int fd,
@@ -396,8 +445,8 @@
   }
 }
 
-bool UnixSocket::Send(const std::string& msg) {
-  return Send(msg.c_str(), msg.size() + 1);
+bool UnixSocket::Send(const std::string& msg, BlockingMode blocking) {
+  return Send(msg.c_str(), msg.size() + 1, -1, blocking);  // +1: send '\0'.
 }
 
 bool UnixSocket::Send(const void* msg,
@@ -414,6 +463,10 @@
                       const int* send_fds,
                       size_t num_fds,
                       BlockingMode blocking_mode) {
+  // TODO(b/117139237): Non-blocking sends are broken because we do not
+  // properly handle partial sends.
+  PERFETTO_DCHECK(blocking_mode == BlockingMode::kBlocking);
+
   if (state_ != State::kConnected) {
     errno = last_error_ = ENOTCONN;
     return false;
diff --git a/src/base/unix_socket_unittest.cc b/src/base/unix_socket_unittest.cc
index 5cdd6ab..382f9bc 100644
--- a/src/base/unix_socket_unittest.cc
+++ b/src/base/unix_socket_unittest.cc
@@ -41,6 +41,7 @@
 using ::testing::Mock;
 
 constexpr char kSocketName[] = TEST_SOCK_NAME("unix_socket_unittest");
+constexpr auto kBlocking = UnixSocket::BlockingMode::kBlocking;
 
 class MockEventListener : public UnixSocket::EventListener {
  public:
@@ -115,7 +116,7 @@
   auto cli_disconnected = task_runner_.CreateCheckpoint("cli_disconnected");
   EXPECT_CALL(event_listener_, OnDisconnect(cli.get()))
       .WillOnce(InvokeWithoutArgs(cli_disconnected));
-  EXPECT_FALSE(cli->Send("whatever"));
+  EXPECT_FALSE(cli->Send("whatever", kBlocking));
   task_runner_.RunUntilCheckpoint("cli_disconnected");
 }
 
@@ -153,8 +154,8 @@
         ASSERT_EQ("cli>srv", s->ReceiveString());
         srv_did_recv();
       }));
-  ASSERT_TRUE(cli->Send("cli>srv"));
-  ASSERT_TRUE(srv_conn->Send("srv>cli"));
+  ASSERT_TRUE(cli->Send("cli>srv", kBlocking));
+  ASSERT_TRUE(srv_conn->Send("srv>cli", kBlocking));
   task_runner_.RunUntilCheckpoint("cli_did_recv");
   task_runner_.RunUntilCheckpoint("srv_did_recv");
 
@@ -168,8 +169,8 @@
   ASSERT_EQ("", cli->ReceiveString());
   ASSERT_EQ(0u, srv_conn->Receive(&msg, sizeof(msg)));
   ASSERT_EQ("", srv_conn->ReceiveString());
-  ASSERT_FALSE(cli->Send("foo"));
-  ASSERT_FALSE(srv_conn->Send("bar"));
+  ASSERT_FALSE(cli->Send("foo", kBlocking));
+  ASSERT_FALSE(srv_conn->Send("bar", kBlocking));
   srv->Shutdown(true);
   task_runner_.RunUntilCheckpoint("cli_disconnected");
   task_runner_.RunUntilCheckpoint("srv_disconnected");
@@ -244,10 +245,10 @@
 
   int buf_fd[2] = {null_fd.get(), zero_fd.get()};
 
-  ASSERT_TRUE(
-      cli->Send(cli_str, sizeof(cli_str), buf_fd, base::ArraySize(buf_fd)));
+  ASSERT_TRUE(cli->Send(cli_str, sizeof(cli_str), buf_fd,
+                        base::ArraySize(buf_fd), kBlocking));
   ASSERT_TRUE(srv_conn->Send(srv_str, sizeof(srv_str), buf_fd,
-                             base::ArraySize(buf_fd)));
+                             base::ArraySize(buf_fd), kBlocking));
   task_runner_.RunUntilCheckpoint("srv_did_recv");
   task_runner_.RunUntilCheckpoint("cli_did_recv");
 
@@ -300,7 +301,7 @@
         EXPECT_CALL(event_listener_, OnDataAvailable(s))
             .WillOnce(Invoke([](UnixSocket* t) {
               ASSERT_EQ("PING", t->ReceiveString());
-              ASSERT_TRUE(t->Send("PONG"));
+              ASSERT_TRUE(t->Send("PONG", kBlocking));
             }));
       }));
 
@@ -309,7 +310,7 @@
     EXPECT_CALL(event_listener_, OnConnect(cli[i].get(), true))
         .WillOnce(Invoke([](UnixSocket* s, bool success) {
           ASSERT_TRUE(success);
-          ASSERT_TRUE(s->Send("PING"));
+          ASSERT_TRUE(s->Send("PING", kBlocking));
         }));
 
     auto checkpoint = task_runner_.CreateCheckpoint(std::to_string(i));
@@ -356,7 +357,7 @@
         .WillOnce(Invoke(
             [this, tmp_fd, checkpoint, mem](UnixSocket*, UnixSocket* new_conn) {
               ASSERT_EQ(geteuid(), static_cast<uint32_t>(new_conn->peer_uid()));
-              ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd));
+              ASSERT_TRUE(new_conn->Send("txfd", 5, tmp_fd, kBlocking));
               // Wait for the client to change this again.
               EXPECT_CALL(event_listener_, OnDataAvailable(new_conn))
                   .WillOnce(Invoke([checkpoint, mem](UnixSocket* s) {
@@ -391,7 +392,7 @@
 
           // Now change the shared memory and ping the other process.
           memcpy(mem, "rock more", 10);
-          ASSERT_TRUE(s->Send("change notify"));
+          ASSERT_TRUE(s->Send("change notify", kBlocking));
           checkpoint();
         }));
     task_runner_.RunUntilCheckpoint("change_seen_by_client");
@@ -409,7 +410,7 @@
                               int num_frame) {
   char buf[kAtomicWrites_FrameSize];
   memset(buf, static_cast<char>(num_frame), sizeof(buf));
-  if (s->Send(buf, sizeof(buf)))
+  if (s->Send(buf, sizeof(buf), -1, kBlocking))
     return true;
   task_runner->PostTask(
       std::bind(&AtomicWrites_SendAttempt, s, task_runner, num_frame));
@@ -560,8 +561,7 @@
     char buf[1024 * 32] = {};
     tx_task_runner.PostTask([&cli, &buf, all_sent] {
       for (size_t i = 0; i < kTotalBytes / sizeof(buf); i++)
-        cli->Send(buf, sizeof(buf), -1 /*fd*/,
-                  UnixSocket::BlockingMode::kBlocking);
+        cli->Send(buf, sizeof(buf), -1 /*fd*/, kBlocking);
       all_sent();
     });
     tx_task_runner.RunUntilCheckpoint("all_sent", kTimeoutMs);
@@ -608,8 +608,7 @@
     static constexpr size_t kBufSize = 32 * 1024 * 1024;
     std::unique_ptr<char[]> buf(new char[kBufSize]());
     tx_task_runner.PostTask([&cli, &buf, send_done] {
-      bool send_res = cli->Send(buf.get(), kBufSize, -1 /*fd*/,
-                                UnixSocket::BlockingMode::kBlocking);
+      bool send_res = cli->Send(buf.get(), kBufSize, -1 /*fd*/, kBlocking);
       ASSERT_FALSE(send_res);
       send_done();
     });
@@ -620,6 +619,166 @@
   tx_thread.join();
 }
 
+TEST_F(UnixSocketTest, ShiftMsgHdrSendPartialFirst) {
+  // Shift by one byte, then by the rest of iov[0], then by all of iov[1].
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  ShiftMsgHdr(1, &hdr);  // One byte into iov[0]: both iovecs must remain.
+  EXPECT_NE(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iov[0].iov_base, &hello[1]);
+  EXPECT_EQ(hdr.msg_iov[1].iov_base, &world[0]);
+  EXPECT_EQ(hdr.msg_iovlen, 2);
+  EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), "ello");
+  EXPECT_EQ(iov[0].iov_len, base::ArraySize(hello) - 1);
+
+  ShiftMsgHdr(base::ArraySize(hello) - 1, &hdr);  // Remainder of iov[0].
+  EXPECT_EQ(hdr.msg_iov, &iov[1]);
+  EXPECT_EQ(hdr.msg_iovlen, 1);
+  EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), world);
+  EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world));
+
+  ShiftMsgHdr(base::ArraySize(world), &hdr);  // All of iov[1]: nothing left.
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iovlen, 0);
+}
+
+TEST_F(UnixSocketTest, ShiftMsgHdrSendFirstAndPartial) {
+  // Shift past all of iov[0] plus one byte of iov[1], then past the rest.
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  ShiftMsgHdr(base::ArraySize(hello) + 1, &hdr);  // iov[0] plus one byte.
+  EXPECT_NE(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iovlen, 1);
+  EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), "orld");
+  EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world) - 1);
+
+  ShiftMsgHdr(base::ArraySize(world) - 1, &hdr);  // Remainder of iov[1].
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iovlen, 0);
+}
+
+TEST_F(UnixSocketTest, ShiftMsgHdrSendEverything) {
+  // Shift by the combined length of both iovecs in a single call.
+  struct iovec iov[2] = {};
+  char hello[] = "hello";
+  char world[] = "world";
+  iov[0].iov_base = &hello[0];
+  iov[0].iov_len = base::ArraySize(hello);
+
+  iov[1].iov_base = &world[0];
+  iov[1].iov_len = base::ArraySize(world);
+
+  struct msghdr hdr = {};
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  ShiftMsgHdr(base::ArraySize(world) + base::ArraySize(hello), &hdr);
+  EXPECT_EQ(hdr.msg_iov, nullptr);
+  EXPECT_EQ(hdr.msg_iovlen, 0);
+}
+
+void Handler(int) {}  // Intentionally empty SIGWINCH handler.
+
+int RollbackSigaction(const struct sigaction* act) {
+  return sigaction(SIGWINCH, act, nullptr);  // Reinstalls |act| for SIGWINCH.
+}
+
+TEST_F(UnixSocketTest, PartialSendMsgAll) {
+  int sv[2];
+  ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
+  base::ScopedFile send_socket(sv[0]);
+  base::ScopedFile recv_socket(sv[1]);
+
+  // Set bufsize to minimum.
+  int bufsize = 1024;
+  ASSERT_EQ(setsockopt(*send_socket, SOL_SOCKET, SO_SNDBUF, &bufsize,
+                       sizeof(bufsize)),
+            0);
+  ASSERT_EQ(setsockopt(*recv_socket, SOL_SOCKET, SO_RCVBUF, &bufsize,
+                       sizeof(bufsize)),
+            0);
+
+  // Send something larger than send + recv kernel buffers combined to make
+  // sendmsg block.
+  char send_buf[8192];
+  // Make MSAN happy.
+  for (size_t i = 0; i < sizeof(send_buf); ++i)
+    send_buf[i] = static_cast<char>(i % 256);
+  char recv_buf[sizeof(send_buf)];
+
+  // Need to install signal handler to cause the interrupt to happen.
+  // man 3 pthread_kill:
+  //   Signal dispositions are process-wide: if a signal handler is
+  //   installed, the handler will be invoked in the thread thread, but if
+  //   the disposition of the signal is "stop", "continue", or "terminate",
+  //   this action will affect the whole process.
+  struct sigaction oldact;
+  struct sigaction newact = {};
+  newact.sa_handler = Handler;
+  ASSERT_EQ(sigaction(SIGWINCH, &newact, &oldact), 0);
+  base::ScopedResource<const struct sigaction*, RollbackSigaction, nullptr>
+      rollback(&oldact);
+
+  auto blocked_thread = pthread_self();
+  std::thread th([blocked_thread, &recv_socket, &recv_buf] {
+    ssize_t rd = PERFETTO_EINTR(read(*recv_socket, recv_buf, 1));
+    ASSERT_EQ(rd, 1);
+    // We are now sure the other thread is in sendmsg, interrupt send.
+    ASSERT_EQ(pthread_kill(blocked_thread, SIGWINCH), 0);
+    // Drain the socket to allow SendMsgAll to succeed.
+    size_t offset = 1;
+    while (offset < sizeof(recv_buf)) {
+      rd = PERFETTO_EINTR(
+          read(*recv_socket, recv_buf + offset, sizeof(recv_buf) - offset));
+      ASSERT_GT(rd, 0);  // 0 would mean EOF and make this loop spin.
+      offset += static_cast<size_t>(rd);
+    }
+  });
+
+  // Test sending the send_buf in several chunks as an iov to exercise the
+  // more complicated code-paths of SendMsgAll.
+  struct msghdr hdr = {};
+  struct iovec iov[4];
+  static_assert(sizeof(send_buf) % base::ArraySize(iov) == 0,
+                "Cannot split buffer into even pieces.");
+  constexpr size_t kChunkSize = sizeof(send_buf) / base::ArraySize(iov);
+  for (size_t i = 0; i < base::ArraySize(iov); ++i) {
+    iov[i].iov_base = send_buf + i * kChunkSize;
+    iov[i].iov_len = kChunkSize;
+  }
+  hdr.msg_iov = iov;
+  hdr.msg_iovlen = base::ArraySize(iov);
+
+  ASSERT_EQ(SendMsgAll(*send_socket, &hdr, 0), sizeof(send_buf));
+  send_socket.reset();
+  th.join();
+  // SendMsgAll must have consumed the whole msghdr (msg_iov reset to null).
+  ASSERT_EQ(hdr.msg_iov, nullptr);
+  ASSERT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
+}
+
 // TODO(primiano): add a test to check that in the case of a peer sending a fd
 // and the other end just doing a recv (without taking it), the fd is closed and
 // not left around.
diff --git a/src/ipc/client_impl_unittest.cc b/src/ipc/client_impl_unittest.cc
index 445e3d0..050fcca 100644
--- a/src/ipc/client_impl_unittest.cc
+++ b/src/ipc/client_impl_unittest.cc
@@ -182,7 +182,8 @@
   void Reply(const Frame& frame) {
     auto buf = BufferedFrameDeserializer::Serialize(frame);
     ASSERT_TRUE(client_sock->is_connected());
-    EXPECT_TRUE(client_sock->Send(buf.data(), buf.size(), next_reply_fd));
+    EXPECT_TRUE(client_sock->Send(buf.data(), buf.size(), next_reply_fd,
+                                  base::UnixSocket::BlockingMode::kBlocking));
     next_reply_fd = -1;
   }
 
diff --git a/src/ipc/host_impl_unittest.cc b/src/ipc/host_impl_unittest.cc
index 3ff1f99..26b56c8 100644
--- a/src/ipc/host_impl_unittest.cc
+++ b/src/ipc/host_impl_unittest.cc
@@ -152,7 +152,8 @@
 
   void SendFrame(const Frame& frame, int fd = -1) {
     std::string buf = BufferedFrameDeserializer::Serialize(frame);
-    ASSERT_TRUE(sock_->Send(buf.data(), buf.size(), fd));
+    ASSERT_TRUE(sock_->Send(buf.data(), buf.size(), fd,
+                            base::UnixSocket::BlockingMode::kBlocking));
   }
 
   BufferedFrameDeserializer frame_deserializer_;
diff --git a/src/profiling/memory/socket_listener_unittest.cc b/src/profiling/memory/socket_listener_unittest.cc
index 58d5e8f..dd3c1a7 100644
--- a/src/profiling/memory/socket_listener_unittest.cc
+++ b/src/profiling/memory/socket_listener_unittest.cc
@@ -71,8 +71,10 @@
                              base::ScopedFile(open("/dev/null", O_RDONLY))};
   int raw_fds[2] = {*fds[0], *fds[1]};
   ASSERT_TRUE(client_socket->Send(&size, sizeof(size), raw_fds,
-                                  base::ArraySize(raw_fds)));
-  ASSERT_TRUE(client_socket->Send("1", 1));
+                                  base::ArraySize(raw_fds),
+                                  base::UnixSocket::BlockingMode::kBlocking));
+  ASSERT_TRUE(client_socket->Send("1", 1, -1,
+                                  base::UnixSocket::BlockingMode::kBlocking));
 
   task_runner.RunUntilCheckpoint("callback.called");
 }
diff --git a/src/profiling/memory/wire_protocol.cc b/src/profiling/memory/wire_protocol.cc
index 39373f9..57cea48 100644
--- a/src/profiling/memory/wire_protocol.cc
+++ b/src/profiling/memory/wire_protocol.cc
@@ -17,6 +17,7 @@
 #include "src/profiling/memory/wire_protocol.h"
 
 #include "perfetto/base/logging.h"
+#include "perfetto/base/unix_socket.h"
 #include "perfetto/base/utils.h"
 
 #include <sys/socket.h>
@@ -68,7 +69,7 @@
     total_size = iovecs[1].iov_len + iovecs[2].iov_len;
   }
 
-  ssize_t sent = sendmsg(sock, &hdr, MSG_NOSIGNAL);
+  ssize_t sent = base::SendMsgAll(sock, &hdr, MSG_NOSIGNAL);
   return sent == static_cast<ssize_t>(total_size + sizeof(total_size));
 }