blob: cf95fa9b8cdcd39d52e91661ff5fa0bba6a6bc2d [file] [log] [blame]
Ryan Hainingb1899152018-01-29 15:50:37 -08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <utility>

#include "common/vsoc/lib/circqueue_impl.h"
#include "common/vsoc/lib/lock_guard.h"
#include "common/vsoc/lib/socket_forward_region_view.h"
#include "common/vsoc/shm/lock.h"
#include "common/vsoc/shm/socket_forward_layout.h"
24
25using vsoc::layout::socket_forward::QueuePair;
26// store the read and write direction as variables to keep the ifdefs and macros
27// in later code to a minimum
28constexpr auto ReadDirection = &QueuePair::
29#ifdef CUTTLEFISH_HOST
30guest_to_host;
31#else
32host_to_guest;
33#endif
34
35constexpr auto WriteDirection = &QueuePair::
36#ifdef CUTTLEFISH_HOST
37host_to_guest;
38#else
39guest_to_host;
40#endif
41
42using vsoc::socket_forward::Message;
43using vsoc::socket_forward::SocketForwardRegionView;
44
// In-band control markers written in place of a frame's length header.
// Recv() and IgnoreUntilBegin() CHECK-fail on any other non-positive length,
// so these negative values can never be confused with a payload size.
constexpr std::int32_t kConnectionBegin{-1};
constexpr std::int32_t kConnectionEnd{-2};
47
48Message SocketForwardRegionView::Recv(int connection_id) {
49 std::int32_t len{};
50 (data()->queues_[connection_id].*ReadDirection)
51 .Read(this, reinterpret_cast<char*>(&len), sizeof len);
52 if (len == kConnectionEnd) {
53 return {};
54 }
55 CHECK_NE(len, 0) << "zero-size message received";
56 CHECK_GT(len, 0) << "invalid size";
57 Message message(len);
58 (data()->queues_[connection_id].*ReadDirection)
59 .Read(this, reinterpret_cast<char*>(message.data()), message.size());
60 return message;
61}
62
63void SocketForwardRegionView::Send(int connection_id, const Message& message) {
64 if (message.empty()) {
65 return;
66 }
67 std::int32_t len = message.size();
68 (data()->queues_[connection_id].*WriteDirection)
69 .Write(this, reinterpret_cast<const char*>(&len), sizeof len);
70 (data()->queues_[connection_id].*WriteDirection)
71 .Write(this, reinterpret_cast<const char*>(message.data()),
72 message.size());
73}
74
75void SocketForwardRegionView::SendBegin(int connection_id) {
76 (data()->queues_[connection_id].*WriteDirection)
77 .Write(this, reinterpret_cast<const char*>(&kConnectionBegin),
78 sizeof kConnectionBegin);
79}
80
81void SocketForwardRegionView::SendEnd(int connection_id) {
82 (data()->queues_[connection_id].*WriteDirection)
83 .Write(this, reinterpret_cast<const char*>(&kConnectionEnd),
84 sizeof kConnectionEnd);
85}
86
87void SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
88 Message ignored(128);
89 while (true) {
90 std::int32_t len{};
91 (data()->queues_[connection_id].*ReadDirection)
92 .Read(this, reinterpret_cast<char*>(&len), sizeof len);
93 if (len == kConnectionBegin) {
94 break;
95 } else if (len == kConnectionEnd) {
96 continue;
97 }
98
99 CHECK_NE(len, 0) << "zero-size message received";
100 CHECK_GT(len, 0) << "invalid size";
101 ignored.resize(len);
102 (data()->queues_[connection_id].*ReadDirection)
103 .Read(this, reinterpret_cast<char*>(ignored.data()), ignored.size());
104 }
105}
106
#ifdef CUTTLEFISH_HOST
// Host only. Claims the first INACTIVE queue pair (checked under its
// queue_state_lock_), records `port` on it, and moves it to HOST_CONNECTED.
// It then increments data()->seq_num, the value GetWaitingConnectionIDAndPort()
// waits on. Returns the claimed queue index. LOG(FATAL)s when no queue pair is
// INACTIVE.
// NOTE(review): no signal is sent here after bumping seq_num; confirm the
// waiter in GetWaitingConnectionIDAndPort() is woken elsewhere.
int SocketForwardRegionView::AcquireConnectionID(int port) {
  int index = 0;
  for (auto&& candidate : data()->queues_) {
    LOG(DEBUG) << "locking and checking queue at index " << index;
    auto guard = make_lock_guard(&candidate.queue_state_lock_);
    if (candidate.queue_state_ != QueuePair::INACTIVE) {
      ++index;
      continue;
    }
    candidate.port_ = port;
    candidate.queue_state_ = QueuePair::HOST_CONNECTED;
    LOG(DEBUG) << "acquired queue " << index << " . current seq_num: "
               << data()->seq_num;
    ++data()->seq_num;
    return index;
  }
  // TODO(haining) handle all queues being used
  LOG(FATAL) << "no remaining shm queues for connection";
  return -1;
}
#endif
128
129namespace {
130bool OtherSideDisconnected(const QueuePair& queue_pair) {
131 constexpr auto kOtherSideClosed = QueuePair::
132#ifdef CUTTLEFISH_HOST
133 GUEST_CLOSED;
134#else
135 HOST_CLOSED;
136#endif
137 return queue_pair.queue_state_ == kOtherSideClosed;
138}
139
140void MarkThisSideDisconnected(QueuePair* queue_pair) {
141 constexpr auto kThisSideClosed = QueuePair::
142#ifdef CUTTLEFISH_HOST
143 HOST_CLOSED;
144#else
145 GUEST_CLOSED;
146#endif
147 queue_pair->queue_state_ = kThisSideClosed;
148}
149
150} // namespace
151
152bool SocketForwardRegionView::IsOtherSideClosed(int connection_id) {
153 auto& queue_pair = data()->queues_[connection_id];
154 auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
155 return OtherSideDisconnected(queue_pair);
156}
157
158void SocketForwardRegionView::ReleaseConnectionID(int connection_id) {
159 auto& queue_pair = data()->queues_[connection_id];
160 auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
161 if (OtherSideDisconnected(queue_pair)) {
162 queue_pair.port_ = 0;
163 queue_pair.queue_state_ = QueuePair::INACTIVE;
164 } else {
165 Send(connection_id, {});
166 MarkThisSideDisconnected(&queue_pair);
167 }
168}
169
170std::pair<int, int> SocketForwardRegionView::GetWaitingConnectionIDAndPort() {
171 while (data()->seq_num == last_seq_number_) {
172 WaitForSignal(&data()->seq_num, last_seq_number_);
173 }
174 ++last_seq_number_;
175 int id = 0;
176 for (auto&& queue_pair : data()->queues_) {
177 LOG(DEBUG) << "locking and checking queue at index " << id;
178 auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
179 if (queue_pair.queue_state_ == QueuePair::HOST_CONNECTED) {
180 LOG(DEBUG) << "found waiting connection at index " << id;
181 queue_pair.queue_state_ = QueuePair::BOTH_CONNECTED;
182 return {id, queue_pair.port_};
183 }
184 ++id;
185 }
186 return {-1, -1};
187}
188
Ryan Haining15ee31f2018-02-05 13:18:07 -0800189#if defined(CUTTLEFISH_HOST)
190std::shared_ptr<SocketForwardRegionView> SocketForwardRegionView::GetInstance(
191 const char* domain) {
192 return RegionView::GetInstanceImpl<SocketForwardRegionView>(
193 [](std::shared_ptr<SocketForwardRegionView> region, const char* domain) {
194 return region->Open(domain);
195 },
196 domain);
197}
198#else
199std::shared_ptr<SocketForwardRegionView> SocketForwardRegionView::GetInstance()
200{
201 return RegionView::GetInstanceImpl<SocketForwardRegionView>(
202 std::mem_fn(&SocketForwardRegionView::Open));
203}
204#endif
205
Ryan Hainingb1899152018-01-29 15:50:37 -0800206#ifdef CUTTLEFISH_HOST
207SocketForwardRegionView::Connection SocketForwardRegionView::OpenConnection(
208 int port) {
209 return {this, AcquireConnectionID(port), port};
210}
211#else
212SocketForwardRegionView::Connection
213SocketForwardRegionView::AcceptConnection() {
214 int connection_id = -1;
215 int port = -1;
216 while (connection_id < 0) {
217 // TODO(haining) if ever in C++17, structured binding declaration
218 auto id_and_port = GetWaitingConnectionIDAndPort();
219 connection_id = id_and_port.first;
220 port = id_and_port.second;
221 }
222 return {this, connection_id, port};
223}
224#endif
225
226// --- Connection ---- //
227SocketForwardRegionView::Connection::Connection(SocketForwardRegionView* view,
228 int connection_id, int port)
229 : view_{view, {connection_id}}, connection_id_{connection_id}, port_{port} {
230 LOG(INFO) << "opened connection with id " << connection_id_;
231}
232
233SocketForwardRegionView::Sender
234SocketForwardRegionView::Connection::MakeSender() {
235 CHECK(!sender_created_);
236 sender_created_ = true;
237 return Sender{this};
238}
239
240SocketForwardRegionView::Receiver
241SocketForwardRegionView::Connection::MakeReceiver() {
242 CHECK(!receiver_created_);
243 receiver_created_ = true;
244 return Receiver{this};
245}
246
247void SocketForwardRegionView::Connection::IgnoreUntilBegin() {
248 view_->IgnoreUntilBegin(connection_id_);
249}
250
251Message SocketForwardRegionView::Connection::Recv() {
252 return view_->Recv(connection_id_);
253}
254
255bool SocketForwardRegionView::Connection::closed() const {
256 return view_->IsOtherSideClosed(connection_id_);
257}
258
259void SocketForwardRegionView::Connection::SendEnd() {
260 view_->SendEnd(connection_id_);
261}
262
263void SocketForwardRegionView::Connection::SendBegin() {
264 view_->SendBegin(connection_id_);
265}
266
267void SocketForwardRegionView::Connection::Send(const Message& message) {
268 if (closed()) {
269 LOG(INFO) << "connection closed, not sending\n";
270 return;
271 }
272 view_->Send(connection_id_, message);
273}