blob: 4c21459e05fb20ece68d6c4abc93dbb8ac6bb7dc [file] [log] [blame]
Cody Schuffelen134ff032019-11-22 00:25:32 -08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <array>
#include <cctype>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include <glog/logging.h>
#include <gflags/gflags.h>

#include <sys/socket.h>
#include <unistd.h>

#include "common/libs/fs/shared_fd.h"
#include "common/libs/strings/str_split.h"
#include "common/vsoc/lib/socket_forward_region_view.h"

#ifdef CUTTLEFISH_HOST
#include "host/libs/config/cuttlefish_config.h"
#endif
40
41using vsoc::socket_forward::Packet;
42using vsoc::socket_forward::SocketForwardRegionView;
43
44#ifdef CUTTLEFISH_HOST
45DEFINE_string(guest_ports, "",
46 "Comma-separated list of ports on which to forward TCP "
47 "connections to the guest.");
48DEFINE_string(host_ports, "",
49 "Comma-separated list of ports on which to run TCP servers on "
50 "the host.");
51#endif
52
53namespace {
54// Sends packets, Shutdown(SHUT_WR) on destruction
55class SocketSender {
56 public:
57 explicit SocketSender(cvd::SharedFD socket) : socket_{std::move(socket)} {}
58
59 SocketSender(SocketSender&&) = default;
60 SocketSender& operator=(SocketSender&&) = default;
61
62 SocketSender(const SocketSender&&) = delete;
63 SocketSender& operator=(const SocketSender&) = delete;
64
65 ~SocketSender() {
66 if (socket_.operator->()) { // check that socket_ was not moved-from
67 socket_->Shutdown(SHUT_WR);
68 }
69 }
70
71 ssize_t SendAll(const Packet& packet) {
72 ssize_t written{};
73 while (written < static_cast<ssize_t>(packet.payload_length())) {
74 if (!socket_->IsOpen()) {
75 return -1;
76 }
77 auto just_written =
78 socket_->Send(packet.payload() + written,
79 packet.payload_length() - written, MSG_NOSIGNAL);
80 if (just_written <= 0) {
81 LOG(INFO) << "Couldn't write to client: "
82 << strerror(socket_->GetErrno());
83 return just_written;
84 }
85 written += just_written;
86 }
87 return written;
88 }
89
90 private:
91 cvd::SharedFD socket_;
92};
93
94class SocketReceiver {
95 public:
96 explicit SocketReceiver(cvd::SharedFD socket) : socket_{std::move(socket)} {}
97
98 SocketReceiver(SocketReceiver&&) = default;
99 SocketReceiver& operator=(SocketReceiver&&) = default;
100
101 SocketReceiver(const SocketReceiver&&) = delete;
102 SocketReceiver& operator=(const SocketReceiver&) = delete;
103
104 // *packet will be empty if Read returns 0 or error
105 void Recv(Packet* packet) {
106 auto size = socket_->Read(packet->payload(), sizeof packet->payload());
107 if (size < 0) {
108 size = 0;
109 }
110 packet->set_payload_length(size);
111 }
112
113 private:
114 cvd::SharedFD socket_;
115};
116
117void SocketToShm(SocketReceiver socket_receiver,
118 SocketForwardRegionView::ShmSender shm_sender) {
119 while (true) {
120 auto packet = Packet::MakeData();
121 socket_receiver.Recv(&packet);
122 if (packet.empty() || !shm_sender.Send(packet)) {
123 break;
124 }
125 }
126 LOG(INFO) << "Socket to shm exiting";
127}
128
129void ShmToSocket(SocketSender socket_sender,
130 SocketForwardRegionView::ShmReceiver shm_receiver) {
131 auto packet = Packet{};
132 while (true) {
133 shm_receiver.Recv(&packet);
134 CHECK(packet.IsData());
135 if (packet.empty()) {
136 break;
137 }
138 if (socket_sender.SendAll(packet) < 0) {
139 break;
140 }
141 }
142 LOG(INFO) << "Shm to socket exiting";
143}
144
145// One thread for reading from shm and writing into a socket.
146// One thread for reading from a socket and writing into shm.
147void HandleConnection(SocketForwardRegionView::ShmSenderReceiverPair shm_sender_and_receiver,
148 cvd::SharedFD socket) {
149 auto socket_to_shm =
150 std::thread(SocketToShm, SocketReceiver{socket}, std::move(shm_sender_and_receiver.first));
151 ShmToSocket(SocketSender{socket}, std::move(shm_sender_and_receiver.second));
152 socket_to_shm.join();
153}
154
155#ifdef CUTTLEFISH_HOST
// A guest port and the host port whose TCP server forwards to it.
// Members are zero-initialized so a default-constructed pair is never
// indeterminate; aggregate initialization {guest, host} still works.
struct PortPair {
  int guest_port{};
  int host_port{};
};
160
// Occupancy of one host-side connection slot.
enum class QueueState : int {
  kFree = 0,  // no client assigned; TryAllocateConnection may claim it
  kUsed = 1,  // a client socket has been handed to the slot's worker
};
165
166struct SocketConnectionInfo {
167 std::mutex lock{};
168 std::condition_variable cv{};
169 cvd::SharedFD socket{};
170 int guest_port{};
171 QueueState state = QueueState::kFree;
172};
173
174static constexpr auto kNumHostThreads =
175 vsoc::layout::socket_forward::kNumQueues;
176
177using SocketConnectionInfoCollection =
178 std::array<SocketConnectionInfo, kNumHostThreads>;
179
180void MarkAsFree(SocketConnectionInfo* conn) {
181 std::lock_guard<std::mutex> guard{conn->lock};
182 conn->socket = cvd::SharedFD{};
183 conn->guest_port = 0;
184 conn->state = QueueState::kFree;
185}
186
187std::pair<int, cvd::SharedFD> WaitForConnection(SocketConnectionInfo* conn) {
188 std::unique_lock<std::mutex> guard{conn->lock};
189 while (conn->state != QueueState::kUsed) {
190 conn->cv.wait(guard);
191 }
192 return {conn->guest_port, conn->socket};
193}
194
195[[noreturn]] void host_thread(SocketForwardRegionView::ShmConnectionView view,
196 SocketConnectionInfo* conn) {
197 while (true) {
198 int guest_port{};
199 cvd::SharedFD socket{};
200 // TODO structured binding in C++17
201 std::tie(guest_port, socket) = WaitForConnection(conn);
202
203 LOG(INFO) << "Establishing connection to guest port " << guest_port
204 << " with connection_id: " << view.connection_id();
205 HandleConnection(view.EstablishConnection(guest_port), std::move(socket));
206 LOG(INFO) << "Connection to guest port " << guest_port
207 << " closed. Marking queue " << view.connection_id()
208 << " as free.";
209 MarkAsFree(conn);
210 }
211}
212
213bool TryAllocateConnection(SocketConnectionInfo* conn, int guest_port,
214 cvd::SharedFD socket) {
215 bool success = false;
216 {
217 std::lock_guard<std::mutex> guard{conn->lock};
218 if (conn->state == QueueState::kFree) {
219 conn->socket = std::move(socket);
220 conn->guest_port = guest_port;
221 conn->state = QueueState::kUsed;
222 success = true;
223 }
224 }
225 if (success) {
226 conn->cv.notify_one();
227 }
228 return success;
229}
230
231void AllocateWorkers(cvd::SharedFD socket,
232 SocketConnectionInfoCollection* socket_connection_info,
233 int guest_port) {
234 while (true) {
235 for (auto& conn : *socket_connection_info) {
236 if (TryAllocateConnection(&conn, guest_port, socket)) {
237 return;
238 }
239 }
240 LOG(INFO) << "no queues available. sleeping and retrying";
241 sleep(5);
242 }
243}
244
245[[noreturn]] void host_impl(
246 SocketForwardRegionView* shm,
247 SocketConnectionInfoCollection* socket_connection_info,
248 std::vector<PortPair> ports, std::size_t index) {
249 // launch a worker for the following port before handling the current port.
250 // recursion (instead of a loop) removes the need fore any join() or having
251 // the main thread do no work.
252 if (index + 1 < ports.size()) {
253 std::thread(host_impl, shm, socket_connection_info, ports, index + 1)
254 .detach();
255 }
256 auto guest_port = ports[index].guest_port;
257 auto host_port = ports[index].host_port;
258 LOG(INFO) << "starting server on " << host_port << " for guest port "
259 << guest_port;
260 auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM);
261 CHECK(server->IsOpen()) << "Could not start server on port " << host_port;
262 while (true) {
263 LOG(INFO) << "waiting for client connection";
264 auto client_socket = cvd::SharedFD::Accept(*server);
265 CHECK(client_socket->IsOpen()) << "error creating client socket";
266 LOG(INFO) << "client socket accepted";
267 AllocateWorkers(std::move(client_socket), socket_connection_info,
268 guest_port);
269 }
270}
271
272[[noreturn]] void host(SocketForwardRegionView* shm,
273 std::vector<PortPair> ports) {
274 CHECK(!ports.empty());
275
276 SocketConnectionInfoCollection socket_connection_info{};
277
278 auto conn_info_iter = std::begin(socket_connection_info);
279 for (auto& shm_connection_view : shm->AllConnections()) {
280 CHECK_NE(conn_info_iter, std::end(socket_connection_info));
281 std::thread(host_thread, std::move(shm_connection_view), &*conn_info_iter)
282 .detach();
283 ++conn_info_iter;
284 }
285 CHECK_EQ(conn_info_iter, std::end(socket_connection_info));
286 host_impl(shm, &socket_connection_info, ports, 0);
287}
288
289std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str,
290 const std::string& host_ports_str) {
291 std::vector<PortPair> ports{};
292 auto guest_ports = cvd::StrSplit(guest_ports_str, ',');
293 auto host_ports = cvd::StrSplit(host_ports_str, ',');
294 CHECK(guest_ports.size() == host_ports.size());
295 for (std::size_t i = 0; i < guest_ports.size(); ++i) {
296 ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])});
297 }
298 return ports;
299}
300
301#else
302cvd::SharedFD OpenSocketConnection(int port) {
303 while (true) {
304 auto sock = cvd::SharedFD::SocketLocalClient(port, SOCK_STREAM);
305 if (sock->IsOpen()) {
306 return sock;
307 }
308 LOG(WARNING) << "could not connect on port " << port
309 << ". sleeping for 1 second";
310 sleep(1);
311 }
312}
313
314[[noreturn]] void guest_thread(
315 SocketForwardRegionView::ShmConnectionView view) {
316 while (true) {
317 LOG(INFO) << "waiting for new connection";
318 auto shm_sender_and_receiver = view.WaitForNewConnection();
319 LOG(INFO) << "new connection for port " << view.port();
320 HandleConnection(std::move(shm_sender_and_receiver), OpenSocketConnection(view.port()));
321 LOG(INFO) << "connection closed on port " << view.port();
322 }
323}
324
325[[noreturn]] void guest(SocketForwardRegionView* shm) {
326 LOG(INFO) << "Starting guest mainloop";
327 auto connection_views = shm->AllConnections();
328 for (auto&& shm_connection_view : connection_views) {
329 std::thread(guest_thread, std::move(shm_connection_view)).detach();
330 }
331 while (true) {
332 sleep(std::numeric_limits<unsigned int>::max());
333 }
334}
335
336#endif
337
338SocketForwardRegionView* GetShm() {
339 auto shm = SocketForwardRegionView::GetInstance(
340#ifdef CUTTLEFISH_HOST
341 vsoc::GetDomain().c_str()
342#endif
343 );
344 if (!shm) {
345 LOG(FATAL) << "Could not open SHM. Aborting.";
346 }
347 shm->CleanUpPreviousConnections();
348 return shm;
349}
350
351// makes sure we're running as root on the guest, no-op on the host
352void assert_correct_user() {
353#ifndef CUTTLEFISH_HOST
354 CHECK_EQ(getuid(), 0u) << "must run as root!";
355#endif
356}
357
358} // namespace
359
360int main(int argc, char* argv[]) {
361 gflags::ParseCommandLineFlags(&argc, &argv, true);
362 assert_correct_user();
363
364 auto shm = GetShm();
365 auto worker = shm->StartWorker();
366
367#ifdef CUTTLEFISH_HOST
368 CHECK(!FLAGS_guest_ports.empty()) << "Must specify --guest_ports flag";
369 CHECK(!FLAGS_host_ports.empty()) << "Must specify --host_ports flag";
370 host(shm, ParsePortsList(FLAGS_guest_ports, FLAGS_host_ports));
371#else
372 guest(shm);
373#endif
374}