Read End from queue, Shutdown(SHUT_WR) socket
When an END marker comes through the queue, shutdown the sockets write
end in the ShmToSocket worker thread.
Also switches the Write to a Send with MSG_NOSIGNAL since ShmToSocket
can now potentially write into a closed socket causing a SIGPIPE.
BUG: 72654144
Change-Id: I74a3d9ffbf81ebea84279ba8e55ebd1ad266df6b
diff --git a/common/commands/socket_forward_proxy/main.cpp b/common/commands/socket_forward_proxy/main.cpp
index cb1fadb..2499c56 100644
--- a/common/commands/socket_forward_proxy/main.cpp
+++ b/common/commands/socket_forward_proxy/main.cpp
@@ -36,6 +36,7 @@
using vsoc::socket_forward::Packet;
using vsoc::socket_forward::SocketForwardRegionView;
+// TODO(haining) accept multiple ports
#ifdef CUTTLEFISH_HOST
DEFINE_uint32(port, 0, "Port from which to forward TCP connections.");
#endif
@@ -48,25 +49,6 @@
: shm_connection_(std::move(shm_connection)),
socket_(std::move(socket)){}
- [[nodiscard]] bool closed() {
- {
- std::lock_guard<std::mutex> guard(closed_lock_);
- if (closed_) {
- return true;
- }
- }
- if (shm_connection_.closed() || !socket_->IsOpen()) {
- std::lock_guard<std::mutex> guard(closed_lock_);
- closed_ = true;
- }
- return closed_;
- }
-
- void close() {
- std::lock_guard<std::mutex> guard(closed_lock_);
- closed_ = true;
- }
-
static void SocketToShm(std::shared_ptr<Worker> worker) {
worker->SocketToShmImpl();
}
@@ -76,23 +58,28 @@
}
private:
+
+ // *packet will be empty if Read returns 0 or error
+ void SocketRecvPacket(Packet* packet) {
+ auto size = socket_->Read(packet->payload(), sizeof packet->payload());
+ if (size < 0) {
+ size = 0;
+ }
+ packet->set_payload_length(size);
+ }
+
void SocketToShmImpl() {
auto shm_sender = shm_connection_.MakeSender();
auto packet = Packet::MakeData();
while (true) {
- if (closed()) {
+ SocketRecvPacket(&packet);
+ if (packet.empty()) {
break;
}
- auto size = socket_->Recv(packet.payload(), sizeof packet.payload(), 0);
- if (size <= 0) {
- break;
- }
- packet.set_payload_length(size);
shm_sender.Send(packet);
}
LOG(INFO) << "Socket to shm exiting";
- close();
}
ssize_t SocketSendAll(const Packet& packet) {
@@ -101,8 +88,9 @@
if (!socket_->IsOpen()) {
return -1;
}
- auto just_written = socket_->Write(packet.payload() + written,
- packet.payload_length() - written);
+ auto just_written = socket_->Send(packet.payload() + written,
+ packet.payload_length() - written,
+ MSG_NOSIGNAL);
if (just_written <= 0) {
LOG(INFO) << "Couldn't write to client: "
<< strerror(socket_->GetErrno());
@@ -113,13 +101,20 @@
return written;
}
+ struct SocketShutdown {
+ cvd::SharedFD socket;
+ SocketShutdown(const SocketShutdown&) = delete;
+ SocketShutdown& operator=(const SocketShutdown&) = delete;
+ ~SocketShutdown() {
+ socket->Shutdown(SHUT_WR);
+ }
+ };
+
void ShmToSocketImpl() {
auto shm_receiver = shm_connection_.MakeReceiver();
+ SocketShutdown shutdown_socket{socket_};
Packet packet{};
while (true) {
- if (closed()) {
- break;
- }
shm_receiver.Recv(&packet);
if (packet.IsEnd()) {
break;
@@ -129,13 +124,10 @@
}
}
LOG(INFO) << "Shm to socket exiting";
- close();
}
SocketForwardRegionView::Connection shm_connection_;
cvd::SharedFD socket_;
- bool closed_{};
- std::mutex closed_lock_;
};
// One thread for reading from shm and writing into a socket.
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index 102a827..6bd7afd 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -137,7 +137,6 @@
queue_pair.port_ = 0;
queue_pair.queue_state_ = QueuePair::INACTIVE;
} else {
- Send(connection_id, {});
MarkThisSideDisconnected(&queue_pair);
}
}
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index 89da339..81110db 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -65,7 +65,7 @@
}
bool empty() const {
- return header_.message_type == Header::DATA && header_.payload_length == 0;
+ return IsData() && header_.payload_length == 0;
}
void set_payload_length(std::uint32_t length) {