blob: dca23efd26a481f39de203b5a839c6590e7dac88 [file] [log] [blame]
#include "uds/ipc_helper.h"

#include <alloca.h>
#include <errno.h>
#include <limits.h>
#include <log/log.h>
#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/param.h>
#include <sys/socket.h>

#include <algorithm>

#include <pdx/service.h>
#include <pdx/utility.h>
16
17namespace android {
18namespace pdx {
19namespace uds {
20
// Marker written into MessagePreamble::magic by the sender and compared by
// the receiver. The bytes 0x75 0x64 0x73 0x6d spell 'udsm' in ASCII.
// NOTE(review): this is a mutable global with external linkage; nothing in
// this file writes to it, so it could probably be constexpr — confirm no other
// translation unit declares it before changing the linkage.
uint32_t kMagicPreamble{0x7564736d};
22
// Fixed-size header transferred ahead of every payload. All fields default
// to zero.
struct MessagePreamble {
  uint32_t magic = 0;      // Compared against kMagicPreamble on receive.
  uint32_t data_size = 0;  // Number of payload bytes that follow.
  uint32_t fd_count = 0;   // Number of file descriptors that accompany them.
};
28
29Status<void> SendPayload::Send(int socket_fd) {
30 return Send(socket_fd, nullptr);
31}
32
33Status<void> SendPayload::Send(int socket_fd, const ucred* cred) {
34 MessagePreamble preamble;
35 preamble.magic = kMagicPreamble;
36 preamble.data_size = buffer_.size();
37 preamble.fd_count = file_handles_.size();
38
39 ssize_t ret =
40 RETRY_EINTR(send(socket_fd, &preamble, sizeof(preamble), MSG_NOSIGNAL));
41 if (ret < 0)
42 return ErrorStatus(errno);
43 if (ret != sizeof(preamble))
44 return ErrorStatus(EIO);
45
46 msghdr msg = {};
47 iovec recv_vect = {buffer_.data(), buffer_.size()};
48 msg.msg_iov = &recv_vect;
49 msg.msg_iovlen = 1;
50
51 if (cred || !file_handles_.empty()) {
52 const size_t fd_bytes = file_handles_.size() * sizeof(int);
53 msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
54 (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
55 msg.msg_control = alloca(msg.msg_controllen);
56
57 cmsghdr* control = CMSG_FIRSTHDR(&msg);
58 if (cred) {
59 control->cmsg_level = SOL_SOCKET;
60 control->cmsg_type = SCM_CREDENTIALS;
61 control->cmsg_len = CMSG_LEN(sizeof(ucred));
62 memcpy(CMSG_DATA(control), cred, sizeof(ucred));
63 control = CMSG_NXTHDR(&msg, control);
64 }
65
66 if (fd_bytes) {
67 control->cmsg_level = SOL_SOCKET;
68 control->cmsg_type = SCM_RIGHTS;
69 control->cmsg_len = CMSG_LEN(fd_bytes);
70 memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
71 }
72 }
73
74 ret = RETRY_EINTR(sendmsg(socket_fd, &msg, MSG_NOSIGNAL));
75 if (ret < 0)
76 return ErrorStatus(errno);
77 if (static_cast<size_t>(ret) != buffer_.size())
78 return ErrorStatus(EIO);
79 return {};
80}
81
82// MessageWriter
83void* SendPayload::GetNextWriteBufferSection(size_t size) {
84 return buffer_.grow_by(size);
85}
86
87OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
88
89// OutputResourceMapper
90FileReference SendPayload::PushFileHandle(const LocalHandle& handle) {
91 if (handle) {
92 const int ref = file_handles_.size();
93 file_handles_.push_back(handle.Get());
94 return ref;
95 } else {
96 return handle.Get();
97 }
98}
99
100FileReference SendPayload::PushFileHandle(const BorrowedHandle& handle) {
101 if (handle) {
102 const int ref = file_handles_.size();
103 file_handles_.push_back(handle.Get());
104 return ref;
105 } else {
106 return handle.Get();
107 }
108}
109
110FileReference SendPayload::PushFileHandle(const RemoteHandle& handle) {
111 return handle.Get();
112}
113
114ChannelReference SendPayload::PushChannelHandle(
115 const LocalChannelHandle& /*handle*/) {
116 return -1;
117}
118ChannelReference SendPayload::PushChannelHandle(
119 const BorrowedChannelHandle& /*handle*/) {
120 return -1;
121}
122ChannelReference SendPayload::PushChannelHandle(
123 const RemoteChannelHandle& /*handle*/) {
124 return -1;
125}
126
127Status<void> ReceivePayload::Receive(int socket_fd) {
128 return Receive(socket_fd, nullptr);
129}
130
131Status<void> ReceivePayload::Receive(int socket_fd, ucred* cred) {
132 MessagePreamble preamble;
133 ssize_t ret =
134 RETRY_EINTR(recv(socket_fd, &preamble, sizeof(preamble), MSG_WAITALL));
135 if (ret < 0)
136 return ErrorStatus(errno);
137 if (ret != sizeof(preamble) || preamble.magic != kMagicPreamble)
138 return ErrorStatus(EIO);
139
140 buffer_.resize(preamble.data_size);
141 file_handles_.clear();
142 read_pos_ = 0;
143
144 msghdr msg = {};
145 iovec recv_vect = {buffer_.data(), buffer_.size()};
146 msg.msg_iov = &recv_vect;
147 msg.msg_iovlen = 1;
148
149 if (cred || preamble.fd_count) {
150 const size_t receive_fd_bytes = preamble.fd_count * sizeof(int);
151 msg.msg_controllen =
152 (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
153 (receive_fd_bytes == 0 ? 0 : CMSG_SPACE(receive_fd_bytes));
154 msg.msg_control = alloca(msg.msg_controllen);
155 }
156
157 ret = RETRY_EINTR(recvmsg(socket_fd, &msg, MSG_WAITALL));
158 if (ret < 0)
159 return ErrorStatus(errno);
160 if (static_cast<uint32_t>(ret) != preamble.data_size)
161 return ErrorStatus(EIO);
162
163 bool cred_available = false;
164 file_handles_.reserve(preamble.fd_count);
165 cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
166 while (cmsg) {
167 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
168 cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
169 cred_available = true;
170 memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
171 } else if (cmsg->cmsg_level == SOL_SOCKET &&
172 cmsg->cmsg_type == SCM_RIGHTS) {
173 socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
174 const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
175 size_t fd_count = payload_len / sizeof(int);
176 std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
177 [](int fd) { return LocalHandle{fd}; });
178 }
179 cmsg = CMSG_NXTHDR(&msg, cmsg);
180 }
181
182 if (cred && !cred_available) {
183 return ErrorStatus(EIO);
184 }
185
186 return {};
187}
188
189// MessageReader
190MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
191 return {buffer_.data() + read_pos_, &*buffer_.end()};
192}
193
194void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
195 read_pos_ = PointerDistance(new_start, buffer_.data());
196}
197
198InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
199
200// InputResourceMapper
201bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
202 if (ref < 0) {
203 *handle = LocalHandle{ref};
204 return true;
205 }
206 if (static_cast<size_t>(ref) > file_handles_.size())
207 return false;
208 *handle = std::move(file_handles_[ref]);
209 return true;
210}
211
212bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
213 LocalChannelHandle* /*handle*/) {
214 return false;
215}
216
217Status<void> SendData(int socket_fd, const void* data, size_t size) {
218 ssize_t size_written = RETRY_EINTR(send(socket_fd, data, size, MSG_NOSIGNAL));
219 if (size_written < 0)
220 return ErrorStatus(errno);
221 if (static_cast<size_t>(size_written) != size)
222 return ErrorStatus(EIO);
223 return {};
224}
225
226Status<void> SendDataVector(int socket_fd, const iovec* data, size_t count) {
227 msghdr msg = {};
228 msg.msg_iov = const_cast<iovec*>(data);
229 msg.msg_iovlen = count;
230 ssize_t size_written = RETRY_EINTR(sendmsg(socket_fd, &msg, MSG_NOSIGNAL));
231 if (size_written < 0)
232 return ErrorStatus(errno);
233 if (static_cast<size_t>(size_written) != CountVectorSize(data, count))
234 return ErrorStatus(EIO);
235 return {};
236}
237
238Status<void> ReceiveData(int socket_fd, void* data, size_t size) {
239 ssize_t size_read = RETRY_EINTR(recv(socket_fd, data, size, MSG_WAITALL));
240 if (size_read < 0)
241 return ErrorStatus(errno);
242 if (static_cast<size_t>(size_read) != size)
243 return ErrorStatus(EIO);
244 return {};
245}
246
247Status<void> ReceiveDataVector(int socket_fd, const iovec* data, size_t count) {
248 msghdr msg = {};
249 msg.msg_iov = const_cast<iovec*>(data);
250 msg.msg_iovlen = count;
251 ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd, &msg, MSG_WAITALL));
252 if (size_read < 0)
253 return ErrorStatus(errno);
254 if (static_cast<size_t>(size_read) != CountVectorSize(data, count))
255 return ErrorStatus(EIO);
256 return {};
257}
258
// Returns the sum of iov_len over the first |count| entries of |vector|.
size_t CountVectorSize(const iovec* vector, size_t count) {
  size_t total_bytes = 0;
  for (size_t index = 0; index < count; ++index)
    total_bytes += vector[index].iov_len;
  return total_bytes;
}
264
265void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
266 int opcode, uint32_t send_len, uint32_t max_recv_len,
267 bool is_impulse) {
268 request->op = opcode;
269 request->cred.pid = getpid();
270 request->cred.uid = geteuid();
271 request->cred.gid = getegid();
272 request->send_len = send_len;
273 request->max_recv_len = max_recv_len;
274 request->is_impulse = is_impulse;
275}
276
277Status<void> WaitForEndpoint(const std::string& endpoint_path,
278 int64_t timeout_ms) {
279 // Endpoint path must be absolute.
280 if (endpoint_path.empty() || endpoint_path.front() != '/')
281 return ErrorStatus(EINVAL);
282
283 // Create inotify fd.
284 LocalHandle fd{inotify_init()};
285 if (!fd)
286 return ErrorStatus(errno);
287
288 // Set the inotify fd to non-blocking.
289 int ret = fcntl(fd.Get(), F_GETFL);
290 fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
291
292 // Setup the pollfd.
293 pollfd pfd = {fd.Get(), POLLIN, 0};
294
295 // Find locations of each path separator.
296 std::vector<size_t> separators{0}; // The path is absolute, so '/' is at #0.
297 size_t pos = endpoint_path.find('/', 1);
298 while (pos != std::string::npos) {
299 separators.push_back(pos);
300 pos = endpoint_path.find('/', pos + 1);
301 }
302 separators.push_back(endpoint_path.size());
303
304 // Walk down the path, checking for existence and waiting if needed.
305 pos = 1;
306 size_t links = 0;
307 std::string current;
308 while (pos < separators.size() && links <= MAXSYMLINKS) {
309 std::string previous = current;
310 current = endpoint_path.substr(0, separators[pos]);
311
312 // Check for existence; proceed to setup a watch if not.
313 if (access(current.c_str(), F_OK) < 0) {
314 if (errno != ENOENT)
315 return ErrorStatus(errno);
316
317 // Extract the name of the path component to wait for.
318 std::string next = current.substr(
319 separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
320
321 // Add a watch on the last existing directory we reach.
322 int wd = inotify_add_watch(
323 fd.Get(), previous.c_str(),
324 IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
325 if (wd < 0) {
326 if (errno != ENOENT)
327 return ErrorStatus(errno);
328 // Restart at the beginning if previous was deleted.
329 links = 0;
330 current.clear();
331 pos = 1;
332 continue;
333 }
334
335 // Make sure current didn't get created before the watch was added.
336 ret = access(current.c_str(), F_OK);
337 if (ret < 0) {
338 if (errno != ENOENT)
339 return ErrorStatus(errno);
340
341 bool exit_poll = false;
342 while (!exit_poll) {
343 // Wait for an event or timeout.
344 ret = poll(&pfd, 1, timeout_ms);
345 if (ret <= 0)
346 return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
347
348 // Read events.
349 char buffer[sizeof(inotify_event) + NAME_MAX + 1];
350
351 ret = read(fd.Get(), buffer, sizeof(buffer));
352 if (ret < 0) {
353 if (errno == EAGAIN || errno == EWOULDBLOCK)
354 continue;
355 else
356 return ErrorStatus(errno);
357 } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
358 return ErrorStatus(EIO);
359 }
360
361 auto* event = reinterpret_cast<const inotify_event*>(buffer);
362 auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
363 while (event < end) {
364 std::string event_for;
365 if (event->len > 0)
366 event_for = event->name;
367
368 if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
369 // See if this is the droid we're looking for.
370 if (next == event_for) {
371 exit_poll = true;
372 break;
373 }
374 } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
375 // Restart at the beginning if our watch dir is deleted.
376 links = 0;
377 current.clear();
378 pos = 0;
379 exit_poll = true;
380 break;
381 }
382
383 event = reinterpret_cast<const inotify_event*>(AdvancePointer(
384 event, sizeof(struct inotify_event) + event->len));
385 } // while (event < end)
386 } // while (!exit_poll)
387 } // Current dir doesn't exist.
388 ret = inotify_rm_watch(fd.Get(), wd);
389 if (ret < 0 && errno != EINVAL)
390 return ErrorStatus(errno);
391 } // if (access(current.c_str(), F_OK) < 0)
392
393 // Check for symbolic link and update link count.
394 struct stat stat_buf;
395 ret = lstat(current.c_str(), &stat_buf);
396 if (ret < 0 && errno != ENOENT)
397 return ErrorStatus(errno);
398 else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
399 links++;
400 pos++;
401 } // while (pos < separators.size() && links <= MAXSYMLINKS)
402
403 return {};
404}
405
406} // namespace uds
407} // namespace pdx
408} // namespace android