Read End from queue, Shutdown(SHUT_WR) socket

When an END marker comes through the queue, shutdown the sockets write
end in the ShmToSocket worker thread.

Also switches the Write to a Send with MSG_NOSIGNAL since ShmToSocket
can now potentially write into a closed socket causing a SIGPIPE.

BUG: 72654144
Change-Id: I74a3d9ffbf81ebea84279ba8e55ebd1ad266df6b
diff --git a/common/commands/socket_forward_proxy/main.cpp b/common/commands/socket_forward_proxy/main.cpp
index cb1fadb..2499c56 100644
--- a/common/commands/socket_forward_proxy/main.cpp
+++ b/common/commands/socket_forward_proxy/main.cpp
@@ -36,6 +36,7 @@
 using vsoc::socket_forward::Packet;
 using vsoc::socket_forward::SocketForwardRegionView;
 
+// TODO(haining) accept multiple ports
 #ifdef CUTTLEFISH_HOST
 DEFINE_uint32(port, 0, "Port from which to forward TCP connections.");
 #endif
@@ -48,25 +49,6 @@
       : shm_connection_(std::move(shm_connection)),
         socket_(std::move(socket)){}
 
-  [[nodiscard]] bool closed() {
-    {
-      std::lock_guard<std::mutex> guard(closed_lock_);
-      if (closed_) {
-        return true;
-      }
-    }
-    if (shm_connection_.closed() || !socket_->IsOpen()) {
-      std::lock_guard<std::mutex> guard(closed_lock_);
-      closed_ = true;
-    }
-    return closed_;
-  }
-
-  void close() {
-    std::lock_guard<std::mutex> guard(closed_lock_);
-    closed_ = true;
-  }
-
   static void SocketToShm(std::shared_ptr<Worker> worker) {
     worker->SocketToShmImpl();
   }
@@ -76,23 +58,28 @@
   }
 
  private:
+
+  // *packet will be empty if Read returns 0 or error
+  void SocketRecvPacket(Packet* packet) {
+    auto size = socket_->Read(packet->payload(), sizeof packet->payload());
+    if (size < 0) {
+      size = 0;
+    }
+    packet->set_payload_length(size);
+  }
+
   void SocketToShmImpl() {
     auto shm_sender = shm_connection_.MakeSender();
 
     auto packet = Packet::MakeData();
     while (true) {
-      if (closed()) {
+      SocketRecvPacket(&packet);
+      if (packet.empty()) {
         break;
       }
-      auto size = socket_->Recv(packet.payload(), sizeof packet.payload(), 0);
-      if (size <= 0) {
-        break;
-      }
-      packet.set_payload_length(size);
       shm_sender.Send(packet);
     }
     LOG(INFO) << "Socket to shm exiting";
-    close();
   }
 
   ssize_t SocketSendAll(const Packet& packet) {
@@ -101,8 +88,9 @@
       if (!socket_->IsOpen()) {
         return -1;
       }
-      auto just_written = socket_->Write(packet.payload() + written,
-                                         packet.payload_length() - written);
+      auto just_written = socket_->Send(packet.payload() + written,
+                                         packet.payload_length() - written,
+                                         MSG_NOSIGNAL);
       if (just_written <= 0) {
         LOG(INFO) << "Couldn't write to client: "
                   << strerror(socket_->GetErrno());
@@ -113,13 +101,20 @@
     return written;
   }
 
+  struct SocketShutdown {
+    cvd::SharedFD socket;
+    SocketShutdown(const SocketShutdown&) = delete;
+    SocketShutdown& operator=(const SocketShutdown&) = delete;
+    ~SocketShutdown() {
+      socket->Shutdown(SHUT_WR);
+    }
+  };
+
   void ShmToSocketImpl() {
     auto shm_receiver = shm_connection_.MakeReceiver();
+    SocketShutdown shutdown_socket{socket_};
     Packet packet{};
     while (true) {
-      if (closed()) {
-        break;
-      }
       shm_receiver.Recv(&packet);
       if (packet.IsEnd()) {
         break;
@@ -129,13 +124,10 @@
       }
     }
     LOG(INFO) << "Shm to socket exiting";
-    close();
   }
 
   SocketForwardRegionView::Connection shm_connection_;
   cvd::SharedFD socket_;
-  bool closed_{};
-  std::mutex closed_lock_;
 };
 
 // One thread for reading from shm and writing into a socket.
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index 102a827..6bd7afd 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -137,7 +137,6 @@
     queue_pair.port_ = 0;
     queue_pair.queue_state_ = QueuePair::INACTIVE;
   } else {
-    Send(connection_id, {});
     MarkThisSideDisconnected(&queue_pair);
   }
 }
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index 89da339..81110db 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -65,7 +65,7 @@
   }
 
   bool empty() const {
-    return header_.message_type == Header::DATA && header_.payload_length == 0;
+    return IsData() && header_.payload_length == 0;
   }
 
   void set_payload_length(std::uint32_t length) {