blob: c41517bbcae6691169dd7c5a32a056d66490c651 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16#pragma once
17
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
22
23#include "common/vsoc/lib/typed_region_view.h"
24#include "common/vsoc/shm/socket_forward_layout.h"
25
26namespace vsoc {
27namespace socket_forward {
28
// Fixed-size prefix of every Packet: the payload length followed by the kind
// of message. Member order is part of the in-memory layout; keep
// payload_length first.
struct Header {
  // Kind of message carried by a packet. Backed by a 32-bit unsigned integer.
  enum MessageType : std::uint32_t {
    DATA = 0,
    BEGIN = 1,
    END = 2,
    RECV_CLOSED = 3,  // indicate that this side's receive end is closed
    RESTART = 4,
  };

  // Number of payload bytes in use.
  std::uint32_t payload_length;
  // Which of the MessageType values this packet represents.
  MessageType message_type;
};
40
41constexpr std::size_t kMaxPayloadSize =
42 layout::socket_forward::kMaxPacketSize - sizeof(Header);
43
44struct Packet {
45 private:
46 Header header_;
47 using Payload = char[kMaxPayloadSize];
48 Payload payload_data_;
49
50 static constexpr Packet MakePacket(Header::MessageType type) {
51 Packet packet{};
52 packet.header_.message_type = type;
53 return packet;
54 }
55
56 public:
57 // port is only revelant on the host-side.
58 static Packet MakeBegin(std::uint16_t port);
59
60 static constexpr Packet MakeEnd() { return MakePacket(Header::END); }
61
62 static constexpr Packet MakeRecvClosed() {
63 return MakePacket(Header::RECV_CLOSED);
64 }
65
66 static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); }
67
68 // NOTE payload and payload_length must still be set.
69 static constexpr Packet MakeData() { return MakePacket(Header::DATA); }
70
71 bool empty() const { return IsData() && header_.payload_length == 0; }
72
73 void set_payload_length(std::uint32_t length) {
74 CHECK_LE(length, sizeof payload_data_);
75 header_.payload_length = length;
76 }
77
78 Payload& payload() { return payload_data_; }
79
80 const Payload& payload() const { return payload_data_; }
81
82 constexpr std::uint32_t payload_length() const {
83 return header_.payload_length;
84 }
85
86 constexpr bool IsBegin() const {
87 return header_.message_type == Header::BEGIN;
88 }
89
90 constexpr bool IsEnd() const { return header_.message_type == Header::END; }
91
92 constexpr bool IsData() const { return header_.message_type == Header::DATA; }
93
94 constexpr bool IsRecvClosed() const {
95 return header_.message_type == Header::RECV_CLOSED;
96 }
97
98 constexpr bool IsRestart() const {
99 return header_.message_type == Header::RESTART;
100 }
101
102 constexpr std::uint16_t port() const {
103 CHECK(IsBegin());
104 std::uint16_t port_number{};
105 CHECK_EQ(payload_length(), sizeof port_number);
106 std::memcpy(&port_number, payload(), sizeof port_number);
107 return port_number;
108 }
109
110 char* raw_data() { return reinterpret_cast<char*>(this); }
111
112 const char* raw_data() const { return reinterpret_cast<const char*>(this); }
113
114 constexpr size_t raw_data_length() const {
115 return payload_length() + sizeof header_;
116 }
117};
118
119static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
120static_assert(std::is_pod<Packet>{}, "");
121
// Data sent will start with a uint32_t indicating the number of bytes being
// sent, followed by the data itself.
124class SocketForwardRegionView
125 : public TypedRegionView<SocketForwardRegionView,
126 layout::socket_forward::SocketForwardLayout> {
127 private:
128 // Returns an empty data packet if the other side is closed.
129 void Recv(int connection_id, Packet* packet);
130 // Returns true on success
131 bool Send(int connection_id, const Packet& packet);
132
133 // skip everything in the connection queue until seeing a BEGIN packet.
134 // returns port from begin packet.
135 int IgnoreUntilBegin(int connection_id);
136
137 public:
138 class ShmSender;
139 class ShmReceiver;
140
141 using ShmSenderReceiverPair = std::pair<ShmSender, ShmReceiver>;
142
143 class ShmConnectionView {
144 public:
145 ShmConnectionView(SocketForwardRegionView* region_view, int connection_id)
146 : region_view_{region_view}, connection_id_{connection_id} {}
147
148#ifdef CUTTLEFISH_HOST
149 ShmSenderReceiverPair EstablishConnection(int port);
150#else
151 // Should not be called while there is an active ShmSender or ShmReceiver
152 // for this connection.
153 ShmSenderReceiverPair WaitForNewConnection();
154#endif
155
156 int port() const { return port_; }
157
158 bool Send(const Packet& packet);
159 void Recv(Packet* packet);
160
161 ShmConnectionView(const ShmConnectionView&) = delete;
162 ShmConnectionView& operator=(const ShmConnectionView&) = delete;
163
164 // Moving invalidates all existing ShmSenders and ShmReceiver
165 ShmConnectionView(ShmConnectionView&&) = default;
166 ShmConnectionView& operator=(ShmConnectionView&&) = default;
167 ~ShmConnectionView() = default;
168
169 // NOTE should only be used for debugging/logging purposes.
170 // connection_ids are an implementation detail that are currently useful for
171 // debugging, but may go away in the future.
172 int connection_id() const { return connection_id_; }
173
174 private:
175 SocketForwardRegionView* region_view() const { return region_view_; }
176
177 bool IsOtherSideRecvClosed() {
178 std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
179 return other_side_receive_closed_;
180 }
181
182 void MarkOtherSideRecvClosed() {
183 std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
184 other_side_receive_closed_ = true;
185 }
186
187 void ReceiverThread();
188 ShmSenderReceiverPair ResetAndConnect();
189
190 class Receiver {
191 public:
192 Receiver(ShmConnectionView* view)
193 : view_{view}
194 {
195 receiver_thread_ = std::thread([this] { Start(); });
196 }
197
198 void Recv(Packet* packet);
199
200 void Join() { receiver_thread_.join(); }
201
202 Receiver(const Receiver&) = delete;
203 Receiver& operator=(const Receiver&) = delete;
204
205 ~Receiver() = default;
206 private:
207 void Start();
208 bool GotRecvClosed() const;
209 void ReceivePacket();
210 void CheckPacketForRecvClosed();
211 void CheckPacketForEnd();
212 void UpdatePacketAndSignalAvailable();
213 bool ShouldReceiveAnotherPacket() const;
214 bool ExpectMorePackets() const;
215
216 std::mutex receive_thread_data_lock_;
217 std::condition_variable receive_thread_data_cv_;
218 bool received_packet_free_ = true;
219 Packet received_packet_{};
220
221 ShmConnectionView* view_{};
222 bool saw_recv_closed_ = false;
223 bool saw_end_ = false;
224#ifdef CUTTLEFISH_HOST
225 bool saw_data_ = false;
226#endif
227
228 std::thread receiver_thread_;
229 };
230
231 SocketForwardRegionView* region_view_{};
232 int connection_id_ = -1;
233 int port_ = -1;
234
235 std::unique_ptr<std::mutex> other_side_receive_closed_lock_ =
236 std::unique_ptr<std::mutex>{new std::mutex{}};
237 bool other_side_receive_closed_ = false;
238
239 std::unique_ptr<Receiver> receiver_;
240 };
241
242 class ShmSender {
243 public:
244 explicit ShmSender(ShmConnectionView* view) : view_{view} {}
245
246 ShmSender(const ShmSender&) = delete;
247 ShmSender& operator=(const ShmSender&) = delete;
248
249 ShmSender(ShmSender&&) = default;
250 ShmSender& operator=(ShmSender&&) = default;
251 ~ShmSender() = default;
252
253 // Returns true on success
254 bool Send(const Packet& packet);
255
256 private:
257 struct EndSender {
258 void operator()(ShmConnectionView* view) const {
259 if (view) {
260 view->Send(Packet::MakeEnd());
261 }
262 }
263 };
264
265 // Doesn't actually own the View, responsible for sending the End
266 // indicator and marking the sending side as disconnected.
267 std::unique_ptr<ShmConnectionView, EndSender> view_;
268 };
269
270 class ShmReceiver {
271 public:
272 explicit ShmReceiver(ShmConnectionView* view) : view_{view} {}
273 ShmReceiver(const ShmReceiver&) = delete;
274 ShmReceiver& operator=(const ShmReceiver&) = delete;
275
276 ShmReceiver(ShmReceiver&&) = default;
277 ShmReceiver& operator=(ShmReceiver&&) = default;
278 ~ShmReceiver() = default;
279
280 void Recv(Packet* packet);
281
282 private:
283 struct RecvClosedSender {
284 void operator()(ShmConnectionView* view) const {
285 if (view) {
286 view->Send(Packet::MakeRecvClosed());
287 }
288 }
289 };
290
291 // Doesn't actually own the view, responsible for sending the RecvClosed
292 // indicator
293 std::unique_ptr<ShmConnectionView, RecvClosedSender> view_{};
294 };
295
296 friend ShmConnectionView;
297
298 SocketForwardRegionView() = default;
299 ~SocketForwardRegionView() = default;
300 SocketForwardRegionView(const SocketForwardRegionView&) = delete;
301 SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
302
303 using ConnectionViewCollection = std::vector<ShmConnectionView>;
304 ConnectionViewCollection AllConnections();
305
306 int port(int connection_id);
307 void CleanUpPreviousConnections();
308
309 private:
310#ifndef CUTTLEFISH_HOST
311 std::uint32_t last_seq_number_{};
312#endif
313};
314
315} // namespace socket_forward
316} // namespace vsoc