Slimmer rewrite of socket_forward_proxy
Key differences:
1. No QueueState
2. No sequence number
3. No generation number
Instead, the guest-side monitors all queues for new connections and the
host-side keeps track of which queues are allocated
design: https://docs.google.com/document/d/1z43c9LGeEEU6G-ojNtEeQP9-ezK890MU3GP6df3byYs
Change-Id: If0396de2ef8080ed78e7afc36ac0d661f99b6d3c
Bug: 80104636
Bug: 110707067
Test: run local while restarting the host-side process and guest-side
process every ~10 seconds. Connect to a guest-side server that only
sends (doesn't receive) information to the host. Connect several
host-side clients to a guest-side echo server. Restart without having
sent any data in either direction or in only one direction.
diff --git a/common/frontend/socket_forward_proxy/main.cpp b/common/frontend/socket_forward_proxy/main.cpp
index e361059..b2c16f6 100644
--- a/common/frontend/socket_forward_proxy/main.cpp
+++ b/common/frontend/socket_forward_proxy/main.cpp
@@ -14,9 +14,11 @@
* limitations under the License.
*/
+#include <array>
#include <cstdint>
#include <cstdlib>
#include <iostream>
+#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
@@ -114,14 +116,11 @@
};
void SocketToShm(SocketReceiver socket_receiver,
- SocketForwardRegionView::Sender shm_sender) {
- auto packet = Packet::MakeData();
+ SocketForwardRegionView::ShmSender shm_sender) {
while (true) {
+ auto packet = Packet::MakeData();
socket_receiver.Recv(&packet);
- if (packet.empty()) {
- break;
- }
- if (!shm_sender.Send(packet)) {
+ if (packet.empty() || !shm_sender.Send(packet)) {
break;
}
}
@@ -129,11 +128,12 @@
}
void ShmToSocket(SocketSender socket_sender,
- SocketForwardRegionView::Receiver shm_receiver) {
- Packet packet{};
+ SocketForwardRegionView::ShmReceiver shm_receiver) {
+ auto packet = Packet{};
while (true) {
shm_receiver.Recv(&packet);
- if (packet.IsEnd()) {
+ CHECK(packet.IsData());
+ if (packet.empty()) {
break;
}
if (socket_sender.SendAll(packet) < 0) {
@@ -145,15 +145,12 @@
// One thread for reading from shm and writing into a socket.
// One thread for reading from a socket and writing into shm.
-void LaunchWorkers(std::pair<SocketForwardRegionView::Sender,
- SocketForwardRegionView::Receiver>
- conn,
- cvd::SharedFD socket) {
- // TODO create the SocketSender/Receiver in their respective threads?
- std::thread(
- SocketToShm, SocketReceiver{socket}, std::move(conn.first)).detach();
- std::thread(
- ShmToSocket, SocketSender{socket}, std::move(conn.second)).detach();
+void HandleConnection(SocketForwardRegionView::ShmSenderReceiverPair shm_sender_and_receiver,
+ cvd::SharedFD socket) {
+ auto socket_to_shm =
+ std::thread(SocketToShm, SocketReceiver{socket}, std::move(shm_sender_and_receiver.first));
+ ShmToSocket(SocketSender{socket}, std::move(shm_sender_and_receiver.second));
+ socket_to_shm.join();
}
#ifdef CUTTLEFISH_HOST
@@ -162,44 +159,142 @@
int host_port;
};
+enum class QueueState {
+ kFree,
+ kUsed,
+};
+
+struct SocketConnectionInfo {
+ std::mutex lock{};
+ std::condition_variable cv{};
+ cvd::SharedFD socket{};
+ int guest_port{};
+ QueueState state = QueueState::kFree;
+};
+
+static constexpr auto kNumHostThreads =
+ vsoc::layout::socket_forward::kNumQueues;
+
+using SocketConnectionInfoCollection =
+ std::array<SocketConnectionInfo, kNumHostThreads>;
+
void LaunchConnectionMaintainer(int port) {
std::thread(cvd::EstablishAndMaintainConnection, port).detach();
}
+void MarkAsFree(SocketConnectionInfo* conn) {
+ std::lock_guard<std::mutex> guard{conn->lock};
+ conn->socket = cvd::SharedFD{};
+ conn->guest_port = 0;
+ conn->state = QueueState::kFree;
+}
-[[noreturn]] void host_impl(SocketForwardRegionView* shm,
- std::vector<PortPair> ports, std::size_t index) {
+std::pair<int, cvd::SharedFD> WaitForConnection(SocketConnectionInfo* conn) {
+ std::unique_lock<std::mutex> guard{conn->lock};
+ while (conn->state != QueueState::kUsed) {
+ conn->cv.wait(guard);
+ }
+ return {conn->guest_port, conn->socket};
+}
+
+[[noreturn]] void host_thread(SocketForwardRegionView::ShmConnectionView view,
+ SocketConnectionInfo* conn) {
+ while (true) {
+ int guest_port{};
+ cvd::SharedFD socket{};
+ // TODO structured binding in C++17
+ std::tie(guest_port, socket) = WaitForConnection(conn);
+
+ LOG(INFO) << "Establishing connection to guest port " << guest_port
+ << " with connection_id: " << view.connection_id();
+ HandleConnection(view.EstablishConnection(guest_port), std::move(socket));
+ LOG(INFO) << "Connection to guest port " << guest_port
+ << " closed. Marking queue " << view.connection_id()
+ << " as free.";
+ MarkAsFree(conn);
+ }
+}
+
+bool TryAllocateConnection(SocketConnectionInfo* conn, int guest_port,
+ cvd::SharedFD socket) {
+ bool success = false;
+ {
+ std::lock_guard<std::mutex> guard{conn->lock};
+ if (conn->state == QueueState::kFree) {
+ conn->socket = std::move(socket);
+ conn->guest_port = guest_port;
+ conn->state = QueueState::kUsed;
+ success = true;
+ }
+ }
+ if (success) {
+ conn->cv.notify_one();
+ }
+ return success;
+}
+
+void AllocateWorkers(cvd::SharedFD socket,
+ SocketConnectionInfoCollection* socket_connection_info,
+ int guest_port) {
+ while (true) {
+ for (auto& conn : *socket_connection_info) {
+ if (TryAllocateConnection(&conn, guest_port, socket)) {
+ return;
+ }
+ }
+ LOG(INFO) << "no queues available. sleeping and retrying";
+ sleep(5);
+ }
+}
+
+[[noreturn]] void host_impl(
+ SocketForwardRegionView* shm,
+ SocketConnectionInfoCollection* socket_connection_info,
+ std::vector<PortPair> ports, std::size_t index) {
// launch a worker for the following port before handling the current port.
// recursion (instead of a loop) removes the need fore any join() or having
// the main thread do no work.
if (index + 1 < ports.size()) {
- std::thread(host_impl, shm, ports, index + 1).detach();
+ std::thread(host_impl, shm, socket_connection_info, ports, index + 1)
+ .detach();
}
auto guest_port = ports[index].guest_port;
auto host_port = ports[index].host_port;
- LOG(INFO) << "starting server on " << host_port
- << " for guest port " << guest_port;
+ LOG(INFO) << "starting server on " << host_port << " for guest port "
+ << guest_port;
auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM);
CHECK(server->IsOpen()) << "Could not start server on port " << host_port;
+ // Note: If generically forwarding ports, the adb connection maintainer should
+ // be disabled
LaunchConnectionMaintainer(host_port);
while (true) {
auto client_socket = cvd::SharedFD::Accept(*server);
CHECK(client_socket->IsOpen()) << "error creating client socket";
LOG(INFO) << "client socket accepted";
- auto conn = shm->OpenConnection(guest_port);
- LOG(INFO) << "shm connection opened";
- LaunchWorkers(std::move(conn), std::move(client_socket));
+ AllocateWorkers(std::move(client_socket), socket_connection_info,
+ guest_port);
}
}
[[noreturn]] void host(SocketForwardRegionView* shm,
std::vector<PortPair> ports) {
CHECK(!ports.empty());
- host_impl(shm, ports, 0);
+
+ SocketConnectionInfoCollection socket_connection_info{};
+
+ auto conn_info_iter = std::begin(socket_connection_info);
+ for (auto& shm_connection_view : shm->AllConnections()) {
+ CHECK_NE(conn_info_iter, std::end(socket_connection_info));
+ std::thread(host_thread, std::move(shm_connection_view), &*conn_info_iter)
+ .detach();
+ ++conn_info_iter;
+ }
+ CHECK_EQ(conn_info_iter, std::end(socket_connection_info));
+ host_impl(shm, &socket_connection_info, ports, 0);
}
std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str,
- const std::string& host_ports_str) {
+ const std::string& host_ports_str) {
std::vector<PortPair> ports{};
auto guest_ports = cvd::StrSplit(guest_ports_str, ',');
auto host_ports = cvd::StrSplit(host_ports_str, ',');
@@ -208,7 +303,6 @@
ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])});
}
return ports;
-
}
#else
@@ -224,17 +318,28 @@
}
}
-[[noreturn]] void guest(SocketForwardRegionView* shm) {
- LOG(INFO) << "Starting guest mainloop";
+[[noreturn]] void guest_thread(
+ SocketForwardRegionView::ShmConnectionView view) {
while (true) {
- auto conn = shm->AcceptConnection();
- LOG(INFO) << "shm connection accepted";
- auto sock = OpenSocketConnection(conn.first.port());
- CHECK(sock->IsOpen());
- LOG(INFO) << "socket opened to " << conn.first.port();
- LaunchWorkers(std::move(conn), std::move(sock));
+ LOG(INFO) << "waiting for new connection";
+ auto shm_sender_and_receiver = view.WaitForNewConnection();
+ LOG(INFO) << "new connection for port " << view.port();
+ HandleConnection(std::move(shm_sender_and_receiver), OpenSocketConnection(view.port()));
+ LOG(INFO) << "connection closed on port " << view.port();
}
}
+
+[[noreturn]] void guest(SocketForwardRegionView* shm) {
+ LOG(INFO) << "Starting guest mainloop";
+ auto connection_views = shm->AllConnections();
+ for (auto&& shm_connection_view : connection_views) {
+ std::thread(guest_thread, std::move(shm_connection_view)).detach();
+ }
+ while (true) {
+ sleep(std::numeric_limits<unsigned int>::max());
+ }
+}
+
#endif
SocketForwardRegionView* GetShm() {
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index c8a153b..2dbd2e8 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -24,7 +24,6 @@
using vsoc::layout::socket_forward::Queue;
using vsoc::layout::socket_forward::QueuePair;
-namespace QueueState = vsoc::layout::socket_forward::QueueState;
// store the read and write direction as variables to keep the ifdefs and macros
// in later code to a minimum
constexpr auto ReadDirection = &QueuePair::
@@ -41,29 +40,22 @@
guest_to_host;
#endif
-constexpr auto kOtherSideClosed = QueueState::
-#ifdef CUTTLEFISH_HOST
- GUEST_CLOSED;
-#else
- HOST_CLOSED;
-#endif
-
-constexpr auto kThisSideClosed = QueueState::
-#ifdef CUTTLEFISH_HOST
- HOST_CLOSED;
-#else
- GUEST_CLOSED;
-#endif
-
using vsoc::socket_forward::SocketForwardRegionView;
+vsoc::socket_forward::Packet vsoc::socket_forward::Packet::MakeBegin(
+ std::uint16_t port) {
+ auto packet = MakePacket(Header::BEGIN);
+ std::memcpy(packet.payload(), &port, sizeof port);
+ packet.set_payload_length(sizeof port);
+ return packet;
+}
+
void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
CHECK(packet != nullptr);
do {
(data()->queues_[connection_id].*ReadDirection)
.queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
} while (packet->IsBegin());
- // TODO(haining) check packet generation number
CHECK(!packet->empty()) << "zero-size data message received";
CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
}
@@ -72,228 +64,187 @@
CHECK(!packet.empty());
CHECK_LE(packet.payload_length(), kMaxPayloadSize);
- // NOTE this is check-then-act but I think that it's okay. Worst case is that
- // we send one-too-many packets.
- auto& queue_pair = data()->queues_[connection_id];
- {
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
- if ((queue_pair.*WriteDirection).queue_state_ == kOtherSideClosed) {
- LOG(INFO) << "connection closed, not sending\n";
- return false;
- }
- CHECK((queue_pair.*WriteDirection).queue_state_ != QueueState::INACTIVE);
- }
- // TODO(haining) set packet generation number
(data()->queues_[connection_id].*WriteDirection)
.queue.Write(this, packet.raw_data(), packet.raw_data_length());
return true;
}
-void SocketForwardRegionView::IgnoreUntilBegin(int connection_id,
- std::uint32_t generation) {
+int SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
Packet packet{};
do {
(data()->queues_[connection_id].*ReadDirection)
.queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
- } while (!packet.IsBegin() || packet.generation() < generation);
+ } while (!packet.IsBegin());
+ return packet.port();
}
-bool SocketForwardRegionView::IsOtherSideRecvClosed(int connection_id) {
- auto& queue_pair = data()->queues_[connection_id];
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
- auto& queue = queue_pair.*WriteDirection;
- return queue.queue_state_ == kOtherSideClosed ||
- queue.queue_state_ == QueueState::INACTIVE;
-}
-
-void SocketForwardRegionView::ResetQueueStates(QueuePair* queue_pair) {
- using vsoc::layout::socket_forward::Queue;
- auto guard = make_lock_guard(&queue_pair->queue_state_lock_);
- Queue* queues[] = {&queue_pair->host_to_guest, &queue_pair->guest_to_host};
- for (auto* queue : queues) {
- auto& state = queue->queue_state_;
- switch (state) {
- case QueueState::HOST_CONNECTED:
- case kOtherSideClosed:
- LOG(DEBUG)
- << "host_connected or other side is closed, marking inactive";
- state = QueueState::INACTIVE;
- break;
-
- case QueueState::BOTH_CONNECTED:
- LOG(DEBUG) << "both_connected, marking this side closed";
- state = kThisSideClosed;
- break;
-
- case kThisSideClosed:
- [[fallthrough]];
- case QueueState::INACTIVE:
- LOG(DEBUG) << "inactive or this side closed, not changing state";
- break;
- }
- }
-}
+constexpr int kNumQueues =
+ static_cast<int>(vsoc::layout::socket_forward::kNumQueues);
void SocketForwardRegionView::CleanUpPreviousConnections() {
data()->Recover();
- int connection_id = 0;
- auto current_generation = generation();
- auto begin_packet = Packet::MakeBegin();
- begin_packet.set_generation(current_generation);
- auto end_packet = Packet::MakeEnd();
- end_packet.set_generation(current_generation);
- for (auto&& queue_pair : data()->queues_) {
- std::uint32_t state{};
- {
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
- state = (queue_pair.*WriteDirection).queue_state_;
-#ifndef CUTTLEFISH_HOST
- if (state == QueueState::HOST_CONNECTED) {
- state = (queue_pair.*WriteDirection).queue_state_ =
- (queue_pair.*ReadDirection).queue_state_ =
- QueueState::BOTH_CONNECTED;
- }
-#endif
- }
- if (state == QueueState::BOTH_CONNECTED
-#ifdef CUTTLEFISH_HOST
- || state == QueueState::HOST_CONNECTED
-#endif
- ) {
- LOG(INFO) << "found connected write queue state, sending begin and end";
- Send(connection_id, begin_packet);
- Send(connection_id, end_packet);
- }
- ResetQueueStates(&queue_pair);
- ++connection_id;
- }
- ++data()->generation_num;
-}
-
-void SocketForwardRegionView::MarkQueueDisconnected(
- int connection_id, Queue QueuePair::*direction) {
- auto& queue_pair = data()->queues_[connection_id];
- auto& queue = queue_pair.*direction;
-
-#ifdef CUTTLEFISH_HOST
- // if the host has connected but the guest hasn't seen it yet, wait for the
- // guest to connect so the protocol can follow the normal state transition.
- while (queue.queue_state_ == QueueState::HOST_CONNECTED) {
- LOG(WARNING) << "closing queue[" << connection_id
- << "] in HOST_CONNECTED state. waiting";
- WaitForSignal(&queue.queue_state_, QueueState::HOST_CONNECTED);
- }
-#endif
-
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-
- queue.queue_state_ = queue.queue_state_ == kOtherSideClosed
- ? QueueState::INACTIVE
- : kThisSideClosed;
-}
-
-void SocketForwardRegionView::MarkSendQueueDisconnected(int connection_id) {
- MarkQueueDisconnected(connection_id, WriteDirection);
-}
-
-void SocketForwardRegionView::MarkRecvQueueDisconnected(int connection_id) {
- MarkQueueDisconnected(connection_id, ReadDirection);
-}
-
-int SocketForwardRegionView::port(int connection_id) {
- return data()->queues_[connection_id].port_;
-}
-
-std::uint32_t SocketForwardRegionView::generation() {
- return data()->generation_num;
-}
-
-#ifdef CUTTLEFISH_HOST
-int SocketForwardRegionView::AcquireConnectionID(int port) {
- while (true) {
- int id = 0;
- for (auto&& queue_pair : data()->queues_) {
- LOG(DEBUG) << "locking and checking queue at index " << id;
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
- if (queue_pair.host_to_guest.queue_state_ == QueueState::INACTIVE &&
- queue_pair.guest_to_host.queue_state_ == QueueState::INACTIVE) {
- queue_pair.port_ = port;
- queue_pair.host_to_guest.queue_state_ = QueueState::HOST_CONNECTED;
- queue_pair.guest_to_host.queue_state_ = QueueState::HOST_CONNECTED;
- LOG(DEBUG) << "acquired queue " << id
- << ". current seq_num: " << data()->seq_num;
- ++data()->seq_num;
- SendSignal(layout::Sides::Peer, &data()->seq_num);
- return id;
- }
- ++id;
- }
- LOG(ERROR) << "no remaining shm queues for connection, sleeping.";
- sleep(10);
+ static constexpr auto kRestartPacket = Packet::MakeRestart();
+ for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
+ Send(connection_id, kRestartPacket);
}
}
-std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
-SocketForwardRegionView::OpenConnection(int port) {
- int connection_id = AcquireConnectionID(port);
- LOG(INFO) << "Acquired connection with id " << connection_id;
- auto current_generation = generation();
- return {Sender{this, connection_id, current_generation},
- Receiver{this, connection_id, current_generation}};
-}
-#else
-int SocketForwardRegionView::GetWaitingConnectionID() {
- while (data()->seq_num == last_seq_number_) {
- WaitForSignal(&data()->seq_num, last_seq_number_);
+SocketForwardRegionView::ConnectionViewCollection
+SocketForwardRegionView::AllConnections() {
+ SocketForwardRegionView::ConnectionViewCollection all_queues;
+ for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
+ all_queues.emplace_back(this, connection_id);
}
- ++last_seq_number_;
- int id = 0;
- for (auto&& queue_pair : data()->queues_) {
- LOG(DEBUG) << "locking and checking queue at index " << id;
- auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
- if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) {
- CHECK(queue_pair.guest_to_host.queue_state_ ==
- QueueState::HOST_CONNECTED);
- LOG(DEBUG) << "found waiting connection at index " << id;
- queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED;
- queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED;
- SendSignal(layout::Sides::Peer, &queue_pair.host_to_guest.queue_state_);
- SendSignal(layout::Sides::Peer, &queue_pair.guest_to_host.queue_state_);
- return id;
- }
- ++id;
- }
- return -1;
+ return all_queues;
}
-std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
-SocketForwardRegionView::AcceptConnection() {
- int connection_id = -1;
- while (connection_id < 0) {
- connection_id = GetWaitingConnectionID();
- }
- LOG(INFO) << "Accepted connection with id " << connection_id;
-
- auto current_generation = generation();
- return {Sender{this, connection_id, current_generation},
- Receiver{this, connection_id, current_generation}};
-}
-#endif
-
// --- Connection ---- //
-void SocketForwardRegionView::Receiver::Recv(Packet* packet) {
- if (!got_begin_) {
- view_->IgnoreUntilBegin(connection_id_, generation_);
- got_begin_ = true;
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::Recv(Packet* packet) {
+ std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
+ while (received_packet_free_) {
+ receive_thread_data_cv_.wait(guard);
}
- return view_->Recv(connection_id_, packet);
+ CHECK(received_packet_.IsData());
+ *packet = received_packet_;
+ received_packet_free_ = true;
+ receive_thread_data_cv_.notify_one();
}
-bool SocketForwardRegionView::Sender::closed() const {
- return view_->IsOtherSideRecvClosed(connection_id_);
+bool SocketForwardRegionView::ShmConnectionView::Receiver::GotRecvClosed() const {
+ return received_packet_.IsRecvClosed() || (received_packet_.IsRestart()
+#ifdef CUTTLEFISH_HOST
+ && saw_data_
+#endif
+ );
}
-bool SocketForwardRegionView::Sender::Send(const Packet& packet) {
- return view_->Send(connection_id_, packet);
+bool SocketForwardRegionView::ShmConnectionView::Receiver::ShouldReceiveAnotherPacket() const {
+ return (received_packet_.IsRecvClosed() && !saw_end_) ||
+ (saw_end_ && received_packet_.IsEnd())
+#ifdef CUTTLEFISH_HOST
+ || (received_packet_.IsRestart() && !saw_data_) ||
+ (received_packet_.IsBegin())
+#endif
+ ;
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::ReceivePacket() {
+ view_->region_view()->Recv(view_->connection_id(), &received_packet_);
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForRecvClosed() {
+ if (GotRecvClosed()) {
+ saw_recv_closed_ = true;
+ view_->MarkOtherSideRecvClosed();
+ }
+#ifdef CUTTLEFISH_HOST
+ if (received_packet_.IsData()) {
+ saw_data_ = true;
+ }
+#endif
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForEnd() {
+ if (received_packet_.IsEnd() || received_packet_.IsRestart()) {
+ CHECK(!saw_end_ || received_packet_.IsRestart());
+ saw_end_ = true;
+ }
+}
+
+
+bool SocketForwardRegionView::ShmConnectionView::Receiver::ExpectMorePackets() const {
+ return !saw_recv_closed_ || !saw_end_;
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::UpdatePacketAndSignalAvailable() {
+ if (!received_packet_.IsData()) {
+ static constexpr auto kEmptyPacket = Packet::MakeData();
+ received_packet_ = kEmptyPacket;
+ }
+ received_packet_free_ = false;
+ receive_thread_data_cv_.notify_one();
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::Start() {
+ while (ExpectMorePackets()) {
+ std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
+ while (!received_packet_free_) {
+ receive_thread_data_cv_.wait(guard);
+ }
+
+ do {
+ ReceivePacket();
+ CheckPacketForRecvClosed();
+ } while (ShouldReceiveAnotherPacket());
+
+ if (received_packet_.empty()) {
+ LOG(ERROR) << "Received empty packet.";
+ }
+
+ CheckPacketForEnd();
+
+ UpdatePacketAndSignalAvailable();
+ }
+}
+
+auto SocketForwardRegionView::ShmConnectionView::ResetAndConnect()
+ -> ShmSenderReceiverPair {
+ if (receiver_) {
+ receiver_->Join();
+ }
+
+ {
+ std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+ other_side_receive_closed_ = false;
+ }
+
+#ifdef CUTTLEFISH_HOST
+ region_view()->IgnoreUntilBegin(connection_id());
+ region_view()->Send(connection_id(), Packet::MakeBegin(port_));
+#else
+ region_view()->Send(connection_id(), Packet::MakeBegin(port_));
+ port_ =
+ region_view()->IgnoreUntilBegin(connection_id());
+#endif
+
+ receiver_.reset(new Receiver{this});
+ return {ShmSender{this}, ShmReceiver{this}};
+}
+
+#ifdef CUTTLEFISH_HOST
+auto SocketForwardRegionView::ShmConnectionView::EstablishConnection(int port)
+ -> ShmSenderReceiverPair {
+ port_ = port;
+ return ResetAndConnect();
+}
+#else
+auto SocketForwardRegionView::ShmConnectionView::WaitForNewConnection()
+ -> ShmSenderReceiverPair {
+ port_ = 0;
+ return ResetAndConnect();
+}
+#endif
+
+bool SocketForwardRegionView::ShmConnectionView::Send(const Packet& packet) {
+ if (packet.empty()) {
+ LOG(ERROR) << "Sending empty packet";
+ }
+ if (packet.IsData() && IsOtherSideRecvClosed()) {
+ return false;
+ }
+ return region_view()->Send(connection_id(), packet);
+}
+
+void SocketForwardRegionView::ShmConnectionView::Recv(Packet* packet) {
+ receiver_->Recv(packet);
+}
+
+void SocketForwardRegionView::ShmReceiver::Recv(Packet* packet) {
+ view_->Recv(packet);
+}
+
+bool SocketForwardRegionView::ShmSender::Send(const Packet& packet) {
+ return view_->Send(packet);
}
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index ce6958a..c41517b 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -15,6 +15,7 @@
*/
#pragma once
+#include <cstdlib>
#include <utility>
#include <vector>
#include <memory>
@@ -27,11 +28,12 @@
struct Header {
std::uint32_t payload_length;
- std::uint32_t generation;
enum MessageType : std::uint32_t {
DATA = 0,
BEGIN,
END,
+ RECV_CLOSED, // indicate that this side's receive end is closed
+ RESTART,
};
MessageType message_type;
};
@@ -45,51 +47,73 @@
using Payload = char[kMaxPayloadSize];
Payload payload_data_;
- static Packet MakePacket(Header::MessageType type) {
+ static constexpr Packet MakePacket(Header::MessageType type) {
Packet packet{};
packet.header_.message_type = type;
return packet;
}
public:
- static Packet MakeBegin() { return MakePacket(Header::BEGIN); }
+ // port is only revelant on the host-side.
+ static Packet MakeBegin(std::uint16_t port);
- static Packet MakeEnd() { return MakePacket(Header::END); }
+ static constexpr Packet MakeEnd() { return MakePacket(Header::END); }
+
+ static constexpr Packet MakeRecvClosed() {
+ return MakePacket(Header::RECV_CLOSED);
+ }
+
+ static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); }
// NOTE payload and payload_length must still be set.
- static Packet MakeData() { return MakePacket(Header::DATA); }
+ static constexpr Packet MakeData() { return MakePacket(Header::DATA); }
bool empty() const { return IsData() && header_.payload_length == 0; }
void set_payload_length(std::uint32_t length) {
CHECK_LE(length, sizeof payload_data_);
- header_.message_type = Header::DATA;
header_.payload_length = length;
}
- std::uint32_t generation() const { return header_.generation; }
-
- void set_generation(std::uint32_t generation) {
- header_.generation = generation;
- }
-
Payload& payload() { return payload_data_; }
const Payload& payload() const { return payload_data_; }
- std::uint32_t payload_length() const { return header_.payload_length; }
+ constexpr std::uint32_t payload_length() const {
+ return header_.payload_length;
+ }
- bool IsBegin() const { return header_.message_type == Header::BEGIN; }
+ constexpr bool IsBegin() const {
+ return header_.message_type == Header::BEGIN;
+ }
- bool IsEnd() const { return header_.message_type == Header::END; }
+ constexpr bool IsEnd() const { return header_.message_type == Header::END; }
- bool IsData() const { return header_.message_type == Header::DATA; }
+ constexpr bool IsData() const { return header_.message_type == Header::DATA; }
+
+ constexpr bool IsRecvClosed() const {
+ return header_.message_type == Header::RECV_CLOSED;
+ }
+
+ constexpr bool IsRestart() const {
+ return header_.message_type == Header::RESTART;
+ }
+
+ constexpr std::uint16_t port() const {
+ CHECK(IsBegin());
+ std::uint16_t port_number{};
+ CHECK_EQ(payload_length(), sizeof port_number);
+ std::memcpy(&port_number, payload(), sizeof port_number);
+ return port_number;
+ }
char* raw_data() { return reinterpret_cast<char*>(this); }
const char* raw_data() const { return reinterpret_cast<const char*>(this); }
- size_t raw_data_length() const { return payload_length() + sizeof header_; }
+ constexpr size_t raw_data_length() const {
+ return payload_length() + sizeof header_;
+ }
};
static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
@@ -101,128 +125,186 @@
: public TypedRegionView<SocketForwardRegionView,
layout::socket_forward::SocketForwardLayout> {
private:
-#ifdef CUTTLEFISH_HOST
- int AcquireConnectionID(int port);
-#else
- int GetWaitingConnectionID();
-#endif
-
// Returns an empty data packet if the other side is closed.
void Recv(int connection_id, Packet* packet);
// Returns true on success
bool Send(int connection_id, const Packet& packet);
- // skip everything in the connection queue until seeing a BEGIN for the
- // current generation
- void IgnoreUntilBegin(int connection_id, std::uint32_t generation);
-
- bool IsOtherSideRecvClosed(int connection_id);
-
- void ResetQueueStates(layout::socket_forward::QueuePair* queue_pair);
-
- void MarkQueueDisconnected(int connection_id,
- layout::socket_forward::Queue
- layout::socket_forward::QueuePair::*direction);
+ // skip everything in the connection queue until seeing a BEGIN packet.
+ // returns port from begin packet.
+ int IgnoreUntilBegin(int connection_id);
public:
- // Helper class that will send a ConnectionBegin marker when constructed and a
- // ConnectionEnd marker when destroyed.
- class Sender {
+ class ShmSender;
+ class ShmReceiver;
+
+ using ShmSenderReceiverPair = std::pair<ShmSender, ShmReceiver>;
+
+ class ShmConnectionView {
public:
- explicit Sender(SocketForwardRegionView* view, int connection_id,
- std::uint32_t generation)
- : view_{view, {connection_id, generation}},
- connection_id_{connection_id} {
- auto packet = Packet::MakeBegin();
- packet.set_generation(generation);
- view_->Send(connection_id, packet);
+ ShmConnectionView(SocketForwardRegionView* region_view, int connection_id)
+ : region_view_{region_view}, connection_id_{connection_id} {}
+
+#ifdef CUTTLEFISH_HOST
+ ShmSenderReceiverPair EstablishConnection(int port);
+#else
+ // Should not be called while there is an active ShmSender or ShmReceiver
+ // for this connection.
+ ShmSenderReceiverPair WaitForNewConnection();
+#endif
+
+ int port() const { return port_; }
+
+ bool Send(const Packet& packet);
+ void Recv(Packet* packet);
+
+ ShmConnectionView(const ShmConnectionView&) = delete;
+ ShmConnectionView& operator=(const ShmConnectionView&) = delete;
+
+ // Moving invalidates all existing ShmSenders and ShmReceiver
+ ShmConnectionView(ShmConnectionView&&) = default;
+ ShmConnectionView& operator=(ShmConnectionView&&) = default;
+ ~ShmConnectionView() = default;
+
+ // NOTE should only be used for debugging/logging purposes.
+ // connection_ids are an implementation detail that are currently useful for
+ // debugging, but may go away in the future.
+ int connection_id() const { return connection_id_; }
+
+ private:
+ SocketForwardRegionView* region_view() const { return region_view_; }
+
+ bool IsOtherSideRecvClosed() {
+ std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+ return other_side_receive_closed_;
}
- Sender(const Sender&) = delete;
- Sender& operator=(const Sender&) = delete;
+ void MarkOtherSideRecvClosed() {
+ std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+ other_side_receive_closed_ = true;
+ }
- Sender(Sender&&) = default;
- Sender& operator=(Sender&&) = default;
- ~Sender() = default;
+ void ReceiverThread();
+ ShmSenderReceiverPair ResetAndConnect();
+
+ class Receiver {
+ public:
+ Receiver(ShmConnectionView* view)
+ : view_{view}
+ {
+ receiver_thread_ = std::thread([this] { Start(); });
+ }
+
+ void Recv(Packet* packet);
+
+ void Join() { receiver_thread_.join(); }
+
+ Receiver(const Receiver&) = delete;
+ Receiver& operator=(const Receiver&) = delete;
+
+ ~Receiver() = default;
+ private:
+ void Start();
+ bool GotRecvClosed() const;
+ void ReceivePacket();
+ void CheckPacketForRecvClosed();
+ void CheckPacketForEnd();
+ void UpdatePacketAndSignalAvailable();
+ bool ShouldReceiveAnotherPacket() const;
+ bool ExpectMorePackets() const;
+
+ std::mutex receive_thread_data_lock_;
+ std::condition_variable receive_thread_data_cv_;
+ bool received_packet_free_ = true;
+ Packet received_packet_{};
+
+ ShmConnectionView* view_{};
+ bool saw_recv_closed_ = false;
+ bool saw_end_ = false;
+#ifdef CUTTLEFISH_HOST
+ bool saw_data_ = false;
+#endif
+
+ std::thread receiver_thread_;
+ };
+
+ SocketForwardRegionView* region_view_{};
+ int connection_id_ = -1;
+ int port_ = -1;
+
+ std::unique_ptr<std::mutex> other_side_receive_closed_lock_ =
+ std::unique_ptr<std::mutex>{new std::mutex{}};
+ bool other_side_receive_closed_ = false;
+
+ std::unique_ptr<Receiver> receiver_;
+ };
+
+ class ShmSender {
+ public:
+ explicit ShmSender(ShmConnectionView* view) : view_{view} {}
+
+ ShmSender(const ShmSender&) = delete;
+ ShmSender& operator=(const ShmSender&) = delete;
+
+ ShmSender(ShmSender&&) = default;
+ ShmSender& operator=(ShmSender&&) = default;
+ ~ShmSender() = default;
// Returns true on success
bool Send(const Packet& packet);
- int port() const { return view_->port(connection_id_); }
private:
- bool closed() const;
-
struct EndSender {
- int connection_id = -1;
- std::uint32_t generation{};
- void operator()(SocketForwardRegionView* view) const {
+ void operator()(ShmConnectionView* view) const {
if (view) {
- CHECK(connection_id >= 0);
- auto packet = Packet::MakeEnd();
- packet.set_generation(generation);
- view->Send(connection_id, packet);
- view->MarkSendQueueDisconnected(connection_id);
+ view->Send(Packet::MakeEnd());
}
}
};
+
// Doesn't actually own the View, responsible for sending the End
// indicator and marking the sending side as disconnected.
- std::unique_ptr<SocketForwardRegionView, EndSender> view_;
- int connection_id_{};
+ std::unique_ptr<ShmConnectionView, EndSender> view_;
};
- class Receiver {
+ class ShmReceiver {
public:
- explicit Receiver(SocketForwardRegionView* view, int connection_id,
- std::uint32_t generation)
- : view_{view, {connection_id}},
- connection_id_{connection_id},
- generation_{generation} {}
- Receiver(const Receiver&) = delete;
- Receiver& operator=(const Receiver&) = delete;
+ explicit ShmReceiver(ShmConnectionView* view) : view_{view} {}
+ ShmReceiver(const ShmReceiver&) = delete;
+ ShmReceiver& operator=(const ShmReceiver&) = delete;
- Receiver(Receiver&&) = default;
- Receiver& operator=(Receiver&&) = default;
- ~Receiver() = default;
+ ShmReceiver(ShmReceiver&&) = default;
+ ShmReceiver& operator=(ShmReceiver&&) = default;
+ ~ShmReceiver() = default;
void Recv(Packet* packet);
- int port() const { return view_->port(connection_id_); }
private:
- struct QueueCloser {
- int connection_id = -1;
- void operator()(SocketForwardRegionView* view) const {
+ struct RecvClosedSender {
+ void operator()(ShmConnectionView* view) const {
if (view) {
- CHECK(connection_id >= 0);
- view->MarkRecvQueueDisconnected(connection_id);
+ view->Send(Packet::MakeRecvClosed());
}
}
};
- // Doesn't actually own the View, responsible for marking the receiving
- // side as disconnected
- std::unique_ptr<SocketForwardRegionView, QueueCloser> view_;
- int connection_id_{};
- std::uint32_t generation_{};
- bool got_begin_ = false;
+ // Doesn't actually own the view, responsible for sending the RecvClosed
+ // indicator
+ std::unique_ptr<ShmConnectionView, RecvClosedSender> view_{};
};
+ friend ShmConnectionView;
+
SocketForwardRegionView() = default;
~SocketForwardRegionView() = default;
SocketForwardRegionView(const SocketForwardRegionView&) = delete;
SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
-#ifdef CUTTLEFISH_HOST
- std::pair<Sender, Receiver> OpenConnection(int port);
-#else
- std::pair<Sender, Receiver> AcceptConnection();
-#endif
+ using ConnectionViewCollection = std::vector<ShmConnectionView>;
+ ConnectionViewCollection AllConnections();
int port(int connection_id);
- std::uint32_t generation();
void CleanUpPreviousConnections();
- void MarkSendQueueDisconnected(int connection_id);
- void MarkRecvQueueDisconnected(int connection_id);
private:
#ifndef CUTTLEFISH_HOST
diff --git a/common/vsoc/shm/socket_forward_layout.h b/common/vsoc/shm/socket_forward_layout.h
index 02523af..4a9beda 100644
--- a/common/vsoc/shm/socket_forward_layout.h
+++ b/common/vsoc/shm/socket_forward_layout.h
@@ -27,53 +27,35 @@
constexpr std::size_t kMaxPacketSize = 8192;
constexpr std::size_t kNumQueues = 16;
-namespace QueueState {
-constexpr std::uint32_t INACTIVE = 0;
-constexpr std::uint32_t HOST_CONNECTED = 1;
-constexpr std::uint32_t BOTH_CONNECTED = 2;
-constexpr std::uint32_t HOST_CLOSED = 3;
-constexpr std::uint32_t GUEST_CLOSED = 4;
-// If both are closed then the queue goes back to INACTIVE
-// BOTH_CLOSED = 0,
-} // namespace QueueState
-
struct Queue {
static constexpr size_t layout_size =
- CircularPacketQueue<16, kMaxPacketSize>::layout_size + 4;
+ CircularPacketQueue<16, kMaxPacketSize>::layout_size;
CircularPacketQueue<16, kMaxPacketSize> queue;
- std::atomic_uint32_t queue_state_;
-
bool Recover() { return queue.Recover(); }
};
ASSERT_SHM_COMPATIBLE(Queue);
struct QueuePair {
- static constexpr size_t layout_size = 2 * Queue::layout_size + 8;
+ static constexpr size_t layout_size = 2 * Queue::layout_size;
// Traffic originating from host that proceeds towards guest.
Queue host_to_guest;
// Traffic originating from guest that proceeds towards host.
Queue guest_to_host;
- std::uint32_t port_;
-
- SpinLock queue_state_lock_;
-
bool Recover() {
- // TODO: Put queue_state_ and port_ recovery here, probably after grabbing
bool recovered = false;
recovered = recovered || host_to_guest.Recover();
recovered = recovered || guest_to_host.Recover();
- recovered = recovered || queue_state_lock_.Recover();
return recovered;
}
};
ASSERT_SHM_COMPATIBLE(QueuePair);
struct SocketForwardLayout : public RegionLayout {
- static constexpr size_t layout_size = QueuePair::layout_size * kNumQueues + 8;
+ static constexpr size_t layout_size = QueuePair::layout_size * kNumQueues;
bool Recover() {
bool recovered = false;
@@ -81,14 +63,10 @@
bool rval = i.Recover();
recovered = recovered || rval;
}
- // TODO: consider handling the sequence number here
return recovered;
}
QueuePair queues_[kNumQueues];
- std::atomic_uint32_t seq_num; // incremented for every new connection
- std::atomic_uint32_t
- generation_num; // incremented for every new socket forward process
static const char* region_name;
};