Merge "One queue state per direction." into gce-dev
diff --git a/common/commands/socket_forward_proxy/Android.bp b/common/commands/socket_forward_proxy/Android.bp
index 9819e9d..f09f59c 100644
--- a/common/commands/socket_forward_proxy/Android.bp
+++ b/common/commands/socket_forward_proxy/Android.bp
@@ -8,7 +8,6 @@
         "libcuttlefish_fs",
         "vsoc_lib",
         "liblog",
-        "cuttlefish_tcp_socket",
     ],
     static_libs: [
         "libgflags",
diff --git a/common/commands/socket_forward_proxy/main.cpp b/common/commands/socket_forward_proxy/main.cpp
index 2499c56..9ab823b 100644
--- a/common/commands/socket_forward_proxy/main.cpp
+++ b/common/commands/socket_forward_proxy/main.cpp
@@ -42,55 +42,32 @@
 #endif
 
 namespace {
-class Worker {
+// Sends packets, Shutdown(SHUT_WR) on destruction
+class SocketSender {
  public:
-  Worker(SocketForwardRegionView::Connection shm_connection,
-         cvd::SharedFD socket)
-      : shm_connection_(std::move(shm_connection)),
-        socket_(std::move(socket)){}
+  explicit SocketSender(cvd::SharedFD socket) : socket_{std::move(socket)} {}
 
-  static void SocketToShm(std::shared_ptr<Worker> worker) {
-    worker->SocketToShmImpl();
-  }
+  SocketSender(SocketSender&&) = default;
+  SocketSender& operator=(SocketSender&&) = default;
 
-  static void ShmToSocket(std::shared_ptr<Worker> worker) {
-    worker->ShmToSocketImpl();
-  }
+  SocketSender(const SocketSender&) = delete;
+  SocketSender& operator=(const SocketSender&) = delete;
 
- private:
-
-  // *packet will be empty if Read returns 0 or error
-  void SocketRecvPacket(Packet* packet) {
-    auto size = socket_->Read(packet->payload(), sizeof packet->payload());
-    if (size < 0) {
-      size = 0;
+  ~SocketSender() {
+    if (socket_.operator->()) {  // check that socket_ was not moved-from
+      socket_->Shutdown(SHUT_WR);
     }
-    packet->set_payload_length(size);
   }
 
-  void SocketToShmImpl() {
-    auto shm_sender = shm_connection_.MakeSender();
-
-    auto packet = Packet::MakeData();
-    while (true) {
-      SocketRecvPacket(&packet);
-      if (packet.empty()) {
-        break;
-      }
-      shm_sender.Send(packet);
-    }
-    LOG(INFO) << "Socket to shm exiting";
-  }
-
-  ssize_t SocketSendAll(const Packet& packet) {
+  ssize_t SendAll(const Packet& packet) {
     ssize_t written{};
     while (written < static_cast<ssize_t>(packet.payload_length())) {
       if (!socket_->IsOpen()) {
         return -1;
       }
-      auto just_written = socket_->Send(packet.payload() + written,
-                                         packet.payload_length() - written,
-                                         MSG_NOSIGNAL);
+      auto just_written =
+          socket_->Send(packet.payload() + written,
+                        packet.payload_length() - written, MSG_NOSIGNAL);
       if (just_written <= 0) {
         LOG(INFO) << "Couldn't write to client: "
                   << strerror(socket_->GetErrno());
@@ -101,42 +78,73 @@
     return written;
   }
 
-  struct SocketShutdown {
-    cvd::SharedFD socket;
-    SocketShutdown(const SocketShutdown&) = delete;
-    SocketShutdown& operator=(const SocketShutdown&) = delete;
-    ~SocketShutdown() {
-      socket->Shutdown(SHUT_WR);
-    }
-  };
-
-  void ShmToSocketImpl() {
-    auto shm_receiver = shm_connection_.MakeReceiver();
-    SocketShutdown shutdown_socket{socket_};
-    Packet packet{};
-    while (true) {
-      shm_receiver.Recv(&packet);
-      if (packet.IsEnd()) {
-        break;
-      }
-      if (SocketSendAll(packet) < 0) {
-        break;
-      }
-    }
-    LOG(INFO) << "Shm to socket exiting";
-  }
-
-  SocketForwardRegionView::Connection shm_connection_;
+ private:
   cvd::SharedFD socket_;
 };
 
+class SocketReceiver {
+ public:
+  explicit SocketReceiver(cvd::SharedFD socket) : socket_{std::move(socket)} {}
+
+  SocketReceiver(SocketReceiver&&) = default;
+  SocketReceiver& operator=(SocketReceiver&&) = default;
+
+  SocketReceiver(const SocketReceiver&) = delete;
+  SocketReceiver& operator=(const SocketReceiver&) = delete;
+
+  // *packet will be empty if Read returns 0 or error
+  void Recv(Packet* packet) {
+    auto size = socket_->Read(packet->payload(), sizeof packet->payload());
+    if (size < 0) {
+      size = 0;
+    }
+    packet->set_payload_length(size);
+  }
+
+ private:
+  cvd::SharedFD socket_;
+};
+
+void SocketToShm(SocketReceiver socket_receiver,
+                 SocketForwardRegionView::Sender shm_sender) {
+  auto packet = Packet::MakeData();
+  while (true) {
+    socket_receiver.Recv(&packet);
+    if (packet.empty()) {
+      break;
+    }
+    if (!shm_sender.Send(packet)) {
+      break;
+    }
+  }
+  LOG(INFO) << "Socket to shm exiting";
+}
+
+void ShmToSocket(SocketSender socket_sender,
+                 SocketForwardRegionView::Receiver shm_receiver) {
+  Packet packet{};
+  while (true) {
+    shm_receiver.Recv(&packet);
+    if (packet.IsEnd()) {
+      break;
+    }
+    if (socket_sender.SendAll(packet) < 0) {
+      break;
+    }
+  }
+  LOG(INFO) << "Shm to socket exiting";
+}
+
 // One thread for reading from shm and writing into a socket.
 // One thread for reading from a socket and writing into shm.
-void LaunchWorkers(SocketForwardRegionView::Connection conn,
+void LaunchWorkers(std::pair<SocketForwardRegionView::Sender,
+                             SocketForwardRegionView::Receiver>
+                       conn,
                    cvd::SharedFD socket) {
-  auto worker = std::make_shared<Worker>(std::move(conn), std::move(socket));
-  std::thread threads[] = {std::thread(Worker::SocketToShm, worker),
-                           std::thread(Worker::ShmToSocket, worker)};
+  // TODO create the SocketSender/Receivers in their respective threads?
+  std::thread threads[] = {
+      std::thread(SocketToShm, SocketReceiver{socket}, std::move(conn.first)),
+      std::thread(ShmToSocket, SocketSender{socket}, std::move(conn.second))};
   for (auto&& t : threads) {
     t.detach();
   }
@@ -162,9 +170,11 @@
   while (true) {
     auto conn = shm->AcceptConnection();
     LOG(INFO) << "shm connection accepted";
-    auto sock = cvd::SharedFD::SocketLocalClient(conn.port(), SOCK_STREAM);
-    CHECK(sock->IsOpen()) << "Could not open socket to port " << conn.port();
-    LOG(INFO) << "socket opened to " << conn.port();
+    auto sock =
+        cvd::SharedFD::SocketLocalClient(conn.first.port(), SOCK_STREAM);
+    CHECK(sock->IsOpen()) << "Could not open socket to port "
+                          << conn.first.port();
+    LOG(INFO) << "socket opened to " << conn.first.port();
     LaunchWorkers(std::move(conn), std::move(sock));
   }
 }
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index 6bd7afd..35ef4db 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -23,20 +23,35 @@
 #include "common/vsoc/shm/socket_forward_layout.h"
 
 using vsoc::layout::socket_forward::QueuePair;
+using vsoc::layout::socket_forward::QueueState;
 // store the read and write direction as variables to keep the ifdefs and macros
 // in later code to a minimum
 constexpr auto ReadDirection = &QueuePair::
 #ifdef CUTTLEFISH_HOST
-guest_to_host;
+                                   guest_to_host;
 #else
-host_to_guest;
+                                   host_to_guest;
 #endif
 
 constexpr auto WriteDirection = &QueuePair::
 #ifdef CUTTLEFISH_HOST
-host_to_guest;
+                                    host_to_guest;
 #else
-guest_to_host;
+                                    guest_to_host;
+#endif
+
+constexpr auto kOtherSideClosed = QueueState::
+#ifdef CUTTLEFISH_HOST
+    GUEST_CLOSED;
+#else
+    HOST_CLOSED;
+#endif
+
+constexpr auto kThisSideClosed = QueueState::
+#ifdef CUTTLEFISH_HOST
+    HOST_CLOSED;
+#else
+    GUEST_CLOSED;
 #endif
 
 using vsoc::socket_forward::SocketForwardRegionView;
@@ -45,22 +60,32 @@
   CHECK(packet != nullptr);
   do {
     (data()->queues_[connection_id].*ReadDirection)
-        .Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
+        .queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
   } while (packet->IsBegin());
   // TODO(haining) check packet generation number
   CHECK(!packet->empty()) << "zero-size data message received";
   CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
 }
 
-void SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
-  if (packet.empty()) {
-    LOG(WARNING) << "ignoring empty packet (not sending)";
-    return;
+bool SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
+  CHECK(!packet.empty());
+  CHECK_LE(packet.payload_length(), kMaxPayloadSize);
+
+  // NOTE this is check-then-act but I think that it's okay. Worst case is that
+  // we send one-too-many packets.
+  auto& queue_pair = data()->queues_[connection_id];
+  {
+    auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+    if ((queue_pair.*WriteDirection).queue_state_ == kOtherSideClosed) {
+      LOG(INFO) << "connection closed, not sending";
+      return false;
+    }
+    CHECK((queue_pair.*WriteDirection).queue_state_ != QueueState::INACTIVE);
   }
   // TODO(haining) set packet generation number
-  CHECK_LE(packet.payload_length(), kMaxPayloadSize);
   (data()->queues_[connection_id].*WriteDirection)
-      .Write(this, packet.raw_data(), packet.raw_data_length());
+      .queue.Write(this, packet.raw_data(), packet.raw_data_length());
+  return true;
 }
 
 void SocketForwardRegionView::SendBegin(int connection_id) {
@@ -75,21 +100,55 @@
   Packet packet{};
   do {
     (data()->queues_[connection_id].*ReadDirection)
-        .Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
+        .queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
   } while (!packet.IsBegin());  // TODO(haining) check generation number
 }
 
+bool SocketForwardRegionView::IsOtherSideRecvClosed(int connection_id) {
+  auto& queue_pair = data()->queues_[connection_id];
+  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+  auto& queue = queue_pair.*WriteDirection;
+  return queue.queue_state_ == kOtherSideClosed ||
+         queue.queue_state_ == QueueState::INACTIVE;
+}
+
+// TODO merge these two into a helper since the only difference is one
+// Read/Write
+void SocketForwardRegionView::MarkSendQueueDisconnected(int connection_id) {
+  auto& queue_pair = data()->queues_[connection_id];
+  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+  auto& queue = queue_pair.*WriteDirection;
+  queue.queue_state_ = queue.queue_state_ == kOtherSideClosed
+                           ? QueueState::INACTIVE
+                           : kThisSideClosed;
+}
+
+void SocketForwardRegionView::MarkRecvQueueDisconnected(int connection_id) {
+  auto& queue_pair = data()->queues_[connection_id];
+  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+  auto& queue = queue_pair.*ReadDirection;
+  queue.queue_state_ = queue.queue_state_ == kOtherSideClosed
+                           ? QueueState::INACTIVE
+                           : kThisSideClosed;
+}
+
+int SocketForwardRegionView::port(int connection_id) {
+  return data()->queues_[connection_id].port_;
+}
+
 #ifdef CUTTLEFISH_HOST
 int SocketForwardRegionView::AcquireConnectionID(int port) {
   int id = 0;
   for (auto&& queue_pair : data()->queues_) {
     LOG(DEBUG) << "locking and checking queue at index " << id;
     auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-    if (queue_pair.queue_state_ == QueuePair::INACTIVE) {
+    if (queue_pair.host_to_guest.queue_state_ == QueueState::INACTIVE &&
+        queue_pair.guest_to_host.queue_state_ == QueueState::INACTIVE) {
       queue_pair.port_ = port;
-      queue_pair.queue_state_ = QueuePair::HOST_CONNECTED;
-      LOG(DEBUG) << "acquired queue " << id << " . current seq_num: "
-                 << data()->seq_num;
+      queue_pair.host_to_guest.queue_state_ = QueueState::HOST_CONNECTED;
+      queue_pair.guest_to_host.queue_state_ = QueueState::HOST_CONNECTED;
+      LOG(DEBUG) << "acquired queue " << id
+                 << ". current seq_num: " << data()->seq_num;
       ++data()->seq_num;
       return id;
     }
@@ -99,49 +158,15 @@
   LOG(FATAL) << "no remaining shm queues for connection";
   return -1;
 }
-#endif
 
-namespace {
-bool OtherSideDisconnected(const QueuePair& queue_pair) {
-  constexpr auto kOtherSideClosed = QueuePair::
-#ifdef CUTTLEFISH_HOST
-      GUEST_CLOSED;
+std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
+SocketForwardRegionView::OpenConnection(int port) {
+  int connection_id = AcquireConnectionID(port);
+  LOG(INFO) << "Acquired connection with id " << connection_id;
+  return {Sender{this, connection_id}, Receiver{this, connection_id}};
+}
 #else
-      HOST_CLOSED;
-#endif
-  return queue_pair.queue_state_ == kOtherSideClosed;
-}
-
-void MarkThisSideDisconnected(QueuePair* queue_pair) {
-  constexpr auto kThisSideClosed = QueuePair::
-#ifdef CUTTLEFISH_HOST
-      HOST_CLOSED;
-#else
-      GUEST_CLOSED;
-#endif
-  queue_pair->queue_state_ = kThisSideClosed;
-}
-
-}  // namespace
-
-bool SocketForwardRegionView::IsOtherSideClosed(int connection_id) {
-  auto& queue_pair = data()->queues_[connection_id];
-  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-  return OtherSideDisconnected(queue_pair);
-}
-
-void SocketForwardRegionView::ReleaseConnectionID(int connection_id) {
-  auto& queue_pair = data()->queues_[connection_id];
-  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-  if (OtherSideDisconnected(queue_pair)) {
-    queue_pair.port_ = 0;
-    queue_pair.queue_state_ = QueuePair::INACTIVE;
-  } else {
-    MarkThisSideDisconnected(&queue_pair);
-  }
-}
-
-std::pair<int, int> SocketForwardRegionView::GetWaitingConnectionIDAndPort() {
+int SocketForwardRegionView::GetWaitingConnectionID() {
   while (data()->seq_num == last_seq_number_) {
     WaitForSignal(&data()->seq_num, last_seq_number_);
   }
@@ -150,81 +175,43 @@
   for (auto&& queue_pair : data()->queues_) {
     LOG(DEBUG) << "locking and checking queue at index " << id;
     auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-    if (queue_pair.queue_state_ == QueuePair::HOST_CONNECTED) {
+    if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) {
+      CHECK(queue_pair.guest_to_host.queue_state_ ==
+            QueueState::HOST_CONNECTED);
       LOG(DEBUG) << "found waiting connection at index " << id;
-      queue_pair.queue_state_ = QueuePair::BOTH_CONNECTED;
-      return {id, queue_pair.port_};
+      queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED;
+      queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED;
+      return id;
     }
     ++id;
   }
-  return {-1, -1};
+  return -1;
 }
 
-#ifdef CUTTLEFISH_HOST
-SocketForwardRegionView::Connection SocketForwardRegionView::OpenConnection(
-    int port) {
-  return {this, AcquireConnectionID(port), port};
-}
-#else
-SocketForwardRegionView::Connection
+std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
 SocketForwardRegionView::AcceptConnection() {
   int connection_id = -1;
-  int port = -1;
   while (connection_id < 0) {
-    // TODO(haining) if ever in C++17, structured binding declaration
-    auto id_and_port = GetWaitingConnectionIDAndPort();
-    connection_id = id_and_port.first;
-    port = id_and_port.second;
+    connection_id = GetWaitingConnectionID();
   }
-  return {this, connection_id, port};
+  LOG(INFO) << "Accepted connection with id " << connection_id;
+  return {Sender{this, connection_id}, Receiver{this, connection_id}};
 }
 #endif
 
 // --- Connection ---- //
-SocketForwardRegionView::Connection::Connection(SocketForwardRegionView* view,
-                                                int connection_id, int port)
-    : view_{view, {connection_id}}, connection_id_{connection_id}, port_{port} {
-  LOG(INFO) << "opened connection with id " << connection_id_;
-}
-
-SocketForwardRegionView::Sender
-SocketForwardRegionView::Connection::MakeSender() {
-  CHECK(!sender_created_);
-  sender_created_ = true;
-  return Sender{this};
-}
-
-SocketForwardRegionView::Receiver
-SocketForwardRegionView::Connection::MakeReceiver() {
-  CHECK(!receiver_created_);
-  receiver_created_ = true;
-  return Receiver{this};
-}
-
-void SocketForwardRegionView::Connection::IgnoreUntilBegin() {
-  view_->IgnoreUntilBegin(connection_id_);
-}
-
-void SocketForwardRegionView::Connection::Recv(Packet* packet) {
+void SocketForwardRegionView::Receiver::Recv(Packet* packet) {
+  if (!got_begin_) {
+    view_->IgnoreUntilBegin(connection_id_);
+    got_begin_ = true;
+  }
   return view_->Recv(connection_id_, packet);
 }
 
-bool SocketForwardRegionView::Connection::closed() const {
-  return view_->IsOtherSideClosed(connection_id_);
+bool SocketForwardRegionView::Sender::closed() const {
+  return view_->IsOtherSideRecvClosed(connection_id_);
 }
 
-void SocketForwardRegionView::Connection::SendEnd() {
-  view_->SendEnd(connection_id_);
-}
-
-void SocketForwardRegionView::Connection::SendBegin() {
-  view_->SendBegin(connection_id_);
-}
-
-void SocketForwardRegionView::Connection::Send(const Packet& packet) {
-  if (closed()) {
-    LOG(INFO) << "connection closed, not sending\n";
-    return;
-  }
-  view_->Send(connection_id_, packet);
+bool SocketForwardRegionView::Sender::Send(const Packet& packet) {
+  return view_->Send(connection_id_, packet);
 }
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index 81110db..884db83 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -37,7 +37,7 @@
 };
 
 constexpr std::size_t kMaxPayloadSize =
-  layout::socket_forward::kMaxPacketSize - sizeof(Header);
+    layout::socket_forward::kMaxPacketSize - sizeof(Header);
 
 struct Packet {
  private:
@@ -52,21 +52,13 @@
   }
 
  public:
-  static Packet MakeBegin() {
-    return MakePacket(Header::BEGIN);
-  }
+  static Packet MakeBegin() { return MakePacket(Header::BEGIN); }
 
-  static Packet MakeEnd() {
-    return MakePacket(Header::END);
-  }
+  static Packet MakeEnd() { return MakePacket(Header::END); }
 
-  static Packet MakeData() {
-    return MakePacket(Header::DATA);
-  }
+  static Packet MakeData() { return MakePacket(Header::DATA); }
 
-  bool empty() const {
-    return IsData() && header_.payload_length == 0;
-  }
+  bool empty() const { return IsData() && header_.payload_length == 0; }
 
   void set_payload_length(std::uint32_t length) {
     CHECK_LE(length, sizeof payload_data_);
@@ -74,49 +66,29 @@
     header_.payload_length = length;
   }
 
-  std::uint32_t generation() const {
-    return header_.generation;
-  }
+  std::uint32_t generation() const { return header_.generation; }
 
   void set_generation(std::uint32_t generation) {
     header_.generation = generation;
   }
 
-  Payload& payload() {
-    return payload_data_;
-  }
+  Payload& payload() { return payload_data_; }
 
-  const Payload& payload() const {
-    return payload_data_;
-  }
+  const Payload& payload() const { return payload_data_; }
 
-  std::uint32_t payload_length() const {
-    return header_.payload_length;
-  }
+  std::uint32_t payload_length() const { return header_.payload_length; }
 
-  bool IsBegin() const {
-    return header_.message_type == Header::BEGIN;
-  }
+  bool IsBegin() const { return header_.message_type == Header::BEGIN; }
 
-  bool IsEnd() const {
-    return header_.message_type == Header::END;
-  }
+  bool IsEnd() const { return header_.message_type == Header::END; }
 
-  bool IsData() const {
-    return header_.message_type == Header::DATA;
-  }
+  bool IsData() const { return header_.message_type == Header::DATA; }
 
-  char* raw_data() {
-    return reinterpret_cast<char*>(this);
-  }
+  char* raw_data() { return reinterpret_cast<char*>(this); }
 
-  const char* raw_data() const {
-    return reinterpret_cast<const char*>(this);
-  }
+  const char* raw_data() const { return reinterpret_cast<const char*>(this); }
 
-  size_t raw_data_length() const {
-    return payload_length() + sizeof header_;
-  }
+  size_t raw_data_length() const { return payload_length() + sizeof header_; }
 };
 
 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
@@ -125,20 +97,19 @@
 // Data sent will start with a uint32_t indicating the number of bytes being
 // sent, followed be the data itself
 class SocketForwardRegionView
-    : public TypedRegionView<
-        SocketForwardRegionView,
-        layout::socket_forward::SocketForwardLayout> {
+    : public TypedRegionView<SocketForwardRegionView,
+                             layout::socket_forward::SocketForwardLayout> {
  private:
 #ifdef CUTTLEFISH_HOST
   int AcquireConnectionID(int port);
+#else
+  int GetWaitingConnectionID();
 #endif
-  void ReleaseConnectionID(int connection_id);
-  std::pair<int, int> GetWaitingConnectionIDAndPort();
 
   // Returns an empty data packet if the other side is closed.
   void Recv(int connection_id, Packet* packet);
-  // Does nothing if packet is empty
-  void Send(int connection_id, const Packet& packet);
+  // Returns true on success
+  bool Send(int connection_id, const Packet& packet);
 
   void SendBegin(int connection_id);
   void SendEnd(int connection_id);
@@ -146,104 +117,79 @@
   // skip everything in the connection queue until seeing a BEGIN
   void IgnoreUntilBegin(int connection_id);
 
-  bool IsOtherSideClosed(int connection_id);
+  bool IsOtherSideRecvClosed(int connection_id);
 
  public:
-  class Sender;
-  class Receiver;
-
-  // MakeSender and MakeReceiver may only be called once per connection.
-  // Moving a Connection object invalidates any existing Sender or Receiver.
-  class Connection {
-    friend Receiver;
-    friend Sender;
-    friend SocketForwardRegionView;
-
-   public:
-    Connection(const Connection&) = delete;
-    Connection& operator=(const Connection&) = delete;
-
-    Connection(Connection&&) = default;
-    Connection& operator=(Connection&&) = default;
-    ~Connection() = default;
-
-    Sender MakeSender();
-    Receiver MakeReceiver();
-
-    int port() const {
-      return port_;
-    }
-    bool closed() const;
-
-   private:
-    // Sends should be done using a Sender.
-    void Send(const Packet& packet);
-    void SendBegin();
-    void SendEnd();
-
-    // Receives should be done using a Receiver.
-    void Recv(Packet* packet);
-    void IgnoreUntilBegin();
-
-    struct Releaser {
-      int connection_id_;
-      void operator()(SocketForwardRegionView* view) const {
-        if (view) {
-          view->ReleaseConnectionID(connection_id_);
-        }
-      }
-    };
-
-    Connection(SocketForwardRegionView* view, int connection_id, int port);
-
-    // this is a little weird, but I'm using the unique_ptr to release the id
-    // to the view, it's really representing ownership of the connection_id_
-    std::unique_ptr<SocketForwardRegionView, Releaser> view_{};
-    int connection_id_ = -1;
-    int port_ = -1;
-
-    bool receiver_created_ = false;
-    bool sender_created_ = false;
-  };
-
   // Helper class that will send a ConnectionBegin marker when constructed and a
   // ConnectionEnd marker when destroyed.
   class Sender {
    public:
-    explicit Sender(Connection* connection) : connection_{connection} {
-      connection_->SendBegin();
+    explicit Sender(SocketForwardRegionView* view, int connection_id)
+        : view_{view, {connection_id}}, connection_id_{connection_id} {
+      view_->SendBegin(connection_id);
     }
 
-    void Send(const Packet& packet) {
-      connection_->Send(packet);
-    }
+    Sender(const Sender&) = delete;
+    Sender& operator=(const Sender&) = delete;
+
+    Sender(Sender&&) = default;
+    Sender& operator=(Sender&&) = default;
+    ~Sender() = default;
+
+    // Returns true on success
+    bool Send(const Packet& packet);
+    int port() const { return view_->port(connection_id_); }
 
    private:
+    bool closed() const;
+
     struct EndSender {
-      void operator()(Connection* connection) const {
-        if (connection) {
-          connection->SendEnd();
+      int connection_id = -1;
+      void operator()(SocketForwardRegionView* view) const {
+        if (view) {
+          CHECK_GE(connection_id, 0);
+          view->SendEnd(connection_id);
+          view->MarkSendQueueDisconnected(connection_id);
         }
       }
     };
-    // Doesn't actually own the Connection, responsible for sending the End
-    // indicator
-    std::unique_ptr<Connection, EndSender> connection_;
+    // Doesn't actually own the View, responsible for sending the End
+    // indicator and marking the sending side as disconnected.
+    std::unique_ptr<SocketForwardRegionView, EndSender> view_;
+    int connection_id_{};
   };
 
   // Helper class that will wait for a ConnectionBegin marker when constructed
   class Receiver {
    public:
-    explicit Receiver(Connection* connection) : connection_{connection} {
-      connection_->IgnoreUntilBegin();
-    }
+    explicit Receiver(SocketForwardRegionView* view, int connection_id)
+        : view_{view, {connection_id}}, connection_id_{connection_id} {}
+    Receiver(const Receiver&) = delete;
+    Receiver& operator=(const Receiver&) = delete;
 
-    void Recv(Packet* packet) {
-      return connection_->Recv(packet);
-    }
+    Receiver(Receiver&&) = default;
+    Receiver& operator=(Receiver&&) = default;
+    ~Receiver() = default;
+
+    void Recv(Packet* packet);
+    int port() const { return view_->port(connection_id_); }
 
    private:
-    Connection* connection_;
+    struct QueueCloser {
+      int connection_id = -1;
+      void operator()(SocketForwardRegionView* view) const {
+        if (view) {
+          CHECK_GE(connection_id, 0);
+          view->MarkRecvQueueDisconnected(connection_id);
+        }
+      }
+    };
+
+    // Doesn't actually own the View, responsible for marking the receiving
+    // side as disconnected
+    std::unique_ptr<SocketForwardRegionView, QueueCloser> view_{};
+    int connection_id_{};
+    bool got_begin_ = false;
   };
 
   SocketForwardRegionView() = default;
@@ -252,13 +198,19 @@
   SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
 
 #ifdef CUTTLEFISH_HOST
-  Connection OpenConnection(int port);
+  std::pair<Sender, Receiver> OpenConnection(int port);
 #else
-  Connection AcceptConnection();
+  std::pair<Sender, Receiver> AcceptConnection();
 #endif
 
+  int port(int connection_id);
+  void MarkSendQueueDisconnected(int connection_id);
+  void MarkRecvQueueDisconnected(int connection_id);
+
  private:
+#ifndef CUTTLEFISH_HOST
   std::uint32_t last_seq_number_{};
+#endif
 };
 
 }  // namespace socket_forward
diff --git a/common/vsoc/shm/socket_forward_layout.h b/common/vsoc/shm/socket_forward_layout.h
index 9f4de05..45484c4 100644
--- a/common/vsoc/shm/socket_forward_layout.h
+++ b/common/vsoc/shm/socket_forward_layout.h
@@ -27,36 +27,43 @@
 
 constexpr std::size_t kMaxPacketSize = 8192;
 
+enum class QueueState : std::uint32_t {
+  INACTIVE = 0,
+  HOST_CONNECTED = 1,
+  BOTH_CONNECTED = 2,
+  HOST_CLOSED = 3,
+  GUEST_CLOSED = 4,
+  // If both are closed then the queue goes back to INACTIVE
+  // BOTH_CLOSED = 0,
+};
+
+struct Queue {
+  CircularPacketQueue<16, kMaxPacketSize> queue;
+
+  QueueState queue_state_;
+
+  bool Recover() {
+    return queue.Recover();
+  }
+};
+
 struct QueuePair {
   // Traffic originating from host that proceeds towards guest.
-  CircularPacketQueue<16, kMaxPacketSize> host_to_guest;
+  Queue host_to_guest;
   // Traffic originating from guest that proceeds towards host.
-  CircularPacketQueue<16, kMaxPacketSize> guest_to_host;
+  Queue guest_to_host;
 
-  enum QueueState : std::uint32_t {
-    INACTIVE = 0,
-    HOST_CONNECTED = 1,
-    BOTH_CONNECTED = 2,
-    HOST_CLOSED = 3,
-    GUEST_CLOSED = 4,
-    // If both are closed then the queue goes back to INACTIVE
-    // BOTH_CLOSED = 0,
-  };
-  QueueState queue_state_;
   std::uint32_t port_;
 
   SpinLock queue_state_lock_;
 
+
   bool Recover() {
-    bool recovered = false;
-    bool rval = host_to_guest.Recover();
-    recovered = recovered || rval;
-    rval = guest_to_host.Recover();
-    recovered = recovered || rval;
-    rval = queue_state_lock_.Recover();
-    recovered = recovered || rval;
     // TODO: Put queue_state_ and port_ recovery here, probably after grabbing
-    // the queue_state_lock_.
+    // the queue_state_lock_.
+    bool recovered = host_to_guest.Recover();
+    recovered = guest_to_host.Recover() || recovered;  // call first: no short-circuit
+    recovered = queue_state_lock_.Recover() || recovered;
     return recovered;
   }
 };
diff --git a/common/vsoc/shm/version.h b/common/vsoc/shm/version.h
index 5d2de62..e1c715b 100644
--- a/common/vsoc/shm/version.h
+++ b/common/vsoc/shm/version.h
@@ -149,7 +149,10 @@
 constexpr uint32_t version = 0;
 constexpr std::size_t kNumQueues = 16;
 constexpr std::size_t SocketForwardLayout_size =
-    (65548 * 2  + 4 + 4 + 4) * kNumQueues // queues + state + port + lock
+    ((((65548  + 4) // queue + state
+       * 2) // host_to_guest and guest_to_host
+      + 4 + 4) // port and state_lock
+     * kNumQueues)
     + 4; // seq_num
 }  // namespace socket_forward