blob: b6021e877523b2b7204eaa1c7dc02e962b9b2926 [file] [log] [blame]
#include "uds/service_endpoint.h"
#include <poll.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <algorithm> // std::min
#include <pdx/service.h>
#include <uds/channel_manager.h>
#include <uds/client_channel_factory.h>
#include <uds/ipc_helper.h>
namespace {
constexpr int kMaxBackLogForSocketListen = 1;
using android::pdx::BorrowedChannelHandle;
using android::pdx::BorrowedHandle;
using android::pdx::ChannelReference;
using android::pdx::FileReference;
using android::pdx::LocalChannelHandle;
using android::pdx::LocalHandle;
using android::pdx::Status;
using android::pdx::uds::ChannelInfo;
using android::pdx::uds::ChannelManager;
// Per-message bookkeeping owned by an Endpoint message: the decoded request
// (header, descriptors, payload) and the reply being assembled.
struct MessageState {
  // Moves the request descriptor at |index| into |handle|. A negative |index|
  // is not a slot number; it is stored into |handle| as-is via Reset().
  // Returns false, leaving |handle| untouched, when |index| is past the end of
  // the received descriptor list.
  bool GetLocalFileHandle(int index, LocalHandle* handle) {
    if (index < 0) {
      handle->Reset(index);
      return true;
    }
    const size_t slot = static_cast<size_t>(index);
    if (slot >= request.file_descriptors.size())
      return false;
    *handle = std::move(request.file_descriptors[slot]);
    return true;
  }

  // Hands the data/event fds of request channel |index| to
  // ChannelManager::CreateHandle() and stores the result in |handle|. A
  // negative |index| becomes a manager-less handle carrying that value.
  // Returns false when |index| is out of range.
  bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
    if (index < 0) {
      *handle = LocalChannelHandle{nullptr, index};
      return true;
    }
    const size_t slot = static_cast<size_t>(index);
    if (slot >= request.channels.size())
      return false;
    auto& entry = request.channels[slot];
    *handle = ChannelManager::Get().CreateHandle(std::move(entry.data_fd),
                                                 std::move(entry.event_fd));
    return true;
  }

  // Queues |handle| for the reply and returns its index in the response
  // descriptor list. An invalid handle is not queued; its raw value is
  // returned instead.
  FileReference PushFileHandle(BorrowedHandle handle) {
    if (handle) {
      response.file_descriptors.push_back(std::move(handle));
      return response.file_descriptors.size() - 1;
    }
    return handle.Get();
  }

  // Queues the channel named by |handle| (its data fd plus the event fd the
  // ChannelManager reports for it) and returns its index in the response
  // channel list. An invalid handle is passed through by value.
  ChannelReference PushChannelHandle(BorrowedChannelHandle handle) {
    if (!handle)
      return handle.value();
    const auto channel_value = handle.value();
    ChannelInfo<BorrowedHandle> info;
    info.data_fd.Reset(channel_value);
    info.event_fd.Reset(ChannelManager::Get().GetEventFd(channel_value));
    response.channels.push_back(std::move(info));
    return response.channels.size() - 1;
  }

  // Queues an explicit data/event fd pair as a channel. Returns -1 unless both
  // handles are valid.
  ChannelReference PushChannelHandle(BorrowedHandle data_fd,
                                     BorrowedHandle event_fd) {
    if (data_fd && event_fd) {
      ChannelInfo<BorrowedHandle> info;
      info.data_fd = std::move(data_fd);
      info.event_fd = std::move(event_fd);
      response.channels.push_back(std::move(info));
      return response.channels.size() - 1;
    }
    return -1;
  }

  // Appends every buffer in the iovec array to the reply payload and returns
  // the total number of bytes appended.
  ssize_t WriteData(const iovec* vector, size_t vector_length) {
    ssize_t appended = 0;
    for (const iovec* it = vector; it != vector + vector_length; ++it) {
      const auto* bytes = static_cast<const uint8_t*>(it->iov_base);
      response_data.insert(response_data.end(), bytes, bytes + it->iov_len);
      appended += it->iov_len;
    }
    return appended;
  }

  // Copies not-yet-consumed request payload into the iovec buffers, advancing
  // the read cursor. Returns the number of bytes copied (0 once exhausted).
  ssize_t ReadData(const iovec* vector, size_t vector_length) {
    size_t remaining = request_data.size() - request_data_read_pos;
    ssize_t copied = 0;
    size_t i = 0;
    while (remaining > 0 && i < vector_length) {
      const iovec& dest = vector[i++];
      const size_t chunk = std::min(remaining, dest.iov_len);
      memcpy(dest.iov_base, request_data.data() + request_data_read_pos, chunk);
      request_data_read_pos += chunk;
      remaining -= chunk;
      copied += chunk;
    }
    return copied;
  }

  // Header and descriptors received from the client.
  android::pdx::uds::RequestHeader<LocalHandle> request;
  // Header and borrowed descriptors to send back in the reply.
  android::pdx::uds::ResponseHeader<BorrowedHandle> response;
  // Owned fds kept open for borrowed entries in |response|; they are closed
  // when this state is destroyed.
  std::vector<LocalHandle> sockets_to_close;
  // Request payload and the ReadData() cursor into it.
  std::vector<uint8_t> request_data;
  size_t request_data_read_pos{0};
  // Reply payload accumulated by WriteData().
  std::vector<uint8_t> response_data;
};
} // anonymous namespace
namespace android {
namespace pdx {
namespace uds {
// Creates the listening UNIX socket at the resolved endpoint path, plus the
// cancel eventfd and the epoll set that MessageReceive() waits on. On any
// failure the error is logged and socket_fd_ is left invalid.
// All fds are created close-on-exec so they do not leak into child processes.
Endpoint::Endpoint(const std::string& endpoint_path, bool blocking)
    : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
      is_blocking_{blocking} {
  LocalHandle fd{socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)};
  if (!fd) {
    ALOGE("Endpoint::Endpoint: Failed to create socket: %s", strerror(errno));
    return;
  }

  sockaddr_un local{};
  local.sun_family = AF_UNIX;
  strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
  local.sun_path[sizeof(local.sun_path) - 1] = '\0';
  // Remove a stale socket file left by a previous instance so bind() succeeds.
  unlink(local.sun_path);
  if (bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local)) ==
      -1) {
    ALOGE("Endpoint::Endpoint: bind error: %s", strerror(errno));
    return;
  }
  if (listen(fd.Get(), kMaxBackLogForSocketListen) == -1) {
    ALOGE("Endpoint::Endpoint: listen error: %s", strerror(errno));
    return;
  }

  cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
  if (!cancel_event_fd_) {
    ALOGE("Endpoint::Endpoint: Failed to create event fd: %s\n",
          strerror(errno));
    return;
  }

  epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
  if (!epoll_fd_) {
    ALOGE("Endpoint::Endpoint: Failed to create epoll fd: %s\n",
          strerror(errno));
    return;
  }

  // Every epoll entry records the watched fd in data.fd. MessageReceive()
  // compares it with socket_fd_ and cancel_event_fd_ to tell the listening
  // socket and the cancel event apart from channel sockets.
  epoll_event socket_event;
  socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  socket_event.data.fd = fd.Get();

  // The cancel event is level-triggered (no EPOLLONESHOT) so that every
  // waiter observes a cancellation.
  epoll_event cancel_event;
  cancel_event.events = EPOLLIN;
  cancel_event.data.fd = cancel_event_fd_.Get();

  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, fd.Get(), &socket_event) < 0 ||
      epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
                &cancel_event) < 0) {
    ALOGE("Endpoint::Endpoint: Failed to add event fd to epoll fd: %s\n",
          strerror(errno));
    cancel_event_fd_.Close();
    epoll_fd_.Close();
  } else {
    socket_fd_ = std::move(fd);
  }
}
void* Endpoint::AllocateMessageState() { return new MessageState; }
void Endpoint::FreeMessageState(void* state) {
delete static_cast<MessageState*>(state);
}
// Accepts a pending client connection on the listening socket, enables
// SO_PASSCRED on it, reads the first message into |message|, and on success
// registers the new socket as a channel.
// Returns the errno of the first failing step.
Status<void> Endpoint::AcceptConnection(Message* message) {
  sockaddr_un remote;
  socklen_t addrlen = sizeof(remote);
  // SOCK_CLOEXEC: the accepted channel socket must not leak into exec'd
  // children.
  LocalHandle channel_fd{accept4(socket_fd_.Get(),
                                 reinterpret_cast<sockaddr*>(&remote),
                                 &addrlen, SOCK_CLOEXEC)};
  if (!channel_fd) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
          strerror(error));
    return ErrorStatus(error);
  }

  int optval = 1;
  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE(
        "Endpoint::AcceptConnection: Failed to enable the receiving of the "
        "credentials for channel %d: %s",
        channel_fd.Get(), strerror(error));
    return ErrorStatus(error);
  }

  auto status = ReceiveMessageForChannel(channel_fd.Get(), message);
  if (status)
    status = OnNewChannel(std::move(channel_fd));
  return status;
}
// Attaches |service| to this endpoint and wakes any MessageReceive() callers
// blocked in condition_.wait() until a service is set.
// The write happens under service_mutex_, the same mutex MessageReceive()
// waits with. That prevents a data race and a lost wakeup.
// Always returns 0.
int Endpoint::SetService(Service* service) {
  {
    std::lock_guard<std::mutex> lock(service_mutex_);
    service_ = service;
  }
  condition_.notify_all();
  return 0;
}
int Endpoint::SetChannel(int channel_id, Channel* channel) {
std::lock_guard<std::mutex> autolock(channel_mutex_);
auto channel_data = channels_.find(channel_id);
if (channel_data == channels_.end())
return -EINVAL;
channel_data->second.channel_state = channel;
return 0;
}
// Registers |channel_fd| as a new channel with no channel state, taking the
// channel lock. Only the success/error of the registration is reported.
Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
  std::lock_guard<std::mutex> lock{channel_mutex_};
  Status<void> result;
  result.PropagateError(
      OnNewChannelLocked(std::move(channel_fd), /*channel_state=*/nullptr));
  return result;
}
// Registers |channel_fd| as a channel; the caller must hold channel_mutex_.
// The channel gets a fresh eventfd (used by ModifyChannelEvents() to signal
// the client) and an EPOLLONESHOT entry in the endpoint's epoll set.
// Returns a pointer to the stored ChannelData, or the errno of the failing
// step.
// The eventfd is created first so that a failure leaves no half-registered
// channel in the epoll set.
Status<Endpoint::ChannelData*> Endpoint::OnNewChannelLocked(
    LocalHandle channel_fd, Channel* channel_state) {
  ChannelData channel_data;
  // CLOEXEC is a per-descriptor flag, so it does not affect the copy the
  // client receives.
  channel_data.event_fd.Reset(eventfd(0, EFD_CLOEXEC));
  if (!channel_data.event_fd) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE("Endpoint::OnNewChannelLocked: Failed to create event fd: %s\n",
          strerror(error));
    return ErrorStatus(error);
  }

  epoll_event event;
  event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  event.data.fd = channel_fd.Get();
  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE(
        "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
        strerror(error));
    return ErrorStatus(error);
  }

  // The channel's socket fd number doubles as its channel id.
  const int channel_id = channel_fd.Get();
  channel_data.data_fd = std::move(channel_fd);
  channel_data.channel_state = channel_state;
  auto pair = channels_.emplace(channel_id, std::move(channel_data));
  return &pair.first->second;
}
// Re-arms |fd| in the endpoint's epoll set. Entries are registered with
// EPOLLONESHOT, so after one event is delivered the fd stays disabled until
// this EPOLL_CTL_MOD is issued.
// Returns the errno from epoll_ctl on failure.
Status<void> Endpoint::ReenableEpollEvent(int fd) {
  epoll_event event;
  event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  event.data.fd = fd;
  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd, &event) < 0) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE(
        "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
        "endpoint: %s\n",
        strerror(error));
    return ErrorStatus(error);
  }
  return {};
}
int Endpoint::CloseChannel(int channel_id) {
std::lock_guard<std::mutex> autolock(channel_mutex_);
return CloseChannelLocked(channel_id);
}
// Removes channel |channel_id| from the epoll set and drops its ChannelData.
// The caller must hold channel_mutex_.
// Returns -EINVAL for an unknown channel, or -errno if the epoll removal
// failed. The channel is erased from channels_ either way.
int Endpoint::CloseChannelLocked(int channel_id) {
  auto it = channels_.find(channel_id);
  if (it == channels_.end())
    return -EINVAL;

  // A non-null event pointer is passed for EPOLL_CTL_DEL; see BUGS in
  // man 2 epoll_ctl.
  epoll_event unused_event;
  const int del_result =
      epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_id, &unused_event);
  int ret = 0;
  if (del_result < 0) {
    ret = -errno;
    ALOGE(
        "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
        "%s\n",
        strerror(errno));
  }

  channels_.erase(it);
  return ret;
}
// Updates the event mask of channel |channel_id|: bits in |clear_mask| are
// removed, then bits in |set_mask| are added.
//  - Newly setting POLLHUP closes the channel.
//  - Otherwise the channel's eventfd mirrors whether the mask contains any
//    "readable" bit (POLLIN or POLLPRI).
// Returns 0, or -EINVAL for an unknown channel.
//
// The eventfd is created blocking (see OnNewChannelLocked), and
// eventfd_read() on a zero counter would block forever while holding
// channel_mutex_. So the counter is only ever moved between 0 and 1, on
// transitions of the combined POLLIN|POLLPRI state. Tracking the two bits
// separately could drain the counter while one bit was still set, and a later
// read would then block.
int Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
                                  int set_mask) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  auto channel_data = channels_.find(channel_id);
  if (channel_data == channels_.end())
    return -EINVAL;
  const int old_mask = channel_data->second.event_mask;
  const int new_mask = (old_mask & ~clear_mask) | set_mask;

  // EPOLLHUP shares the same bitmask with POLLHUP.
  if ((new_mask & POLLHUP) && !(old_mask & POLLHUP)) {
    CloseChannelLocked(channel_id);
    return 0;
  }

  // EPOLLIN shares the same bitmask with POLLIN and EPOLLPRI shares the same
  // bitmask with POLLPRI. The eventfd is a single signal for "any of these".
  constexpr int kSignalMask = POLLIN | POLLPRI;
  const bool was_signaled = (old_mask & kSignalMask) != 0;
  const bool is_signaled = (new_mask & kSignalMask) != 0;
  const int event_fd = channel_data->second.event_fd.Get();
  if (is_signaled && !was_signaled) {
    eventfd_write(event_fd, 1);
  } else if (!is_signaled && was_signaled) {
    eventfd_t value = 0;
    eventfd_read(event_fd, &value);  // Counter is 1 here, so this won't block.
  }

  channel_data->second.event_mask = new_mask;
  return 0;
}
// Creates a new channel for |message|'s reply: a connected socket pair whose
// local end is registered with this endpoint (with |channel| as its state) and
// whose remote end is queued in the reply together with the channel's eventfd.
// On success *channel_id receives the local socket fd (the new channel id).
// |flags| is currently ignored.
// Returns the errno of the first failing step.
Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
                                                  int /*flags*/,
                                                  Channel* channel,
                                                  int* channel_id) {
  int channel_pair[2] = {};
  // SOCK_CLOEXEC is per-descriptor, so the fd the client receives is
  // unaffected.
  if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE("Endpoint::PushChannel: Failed to create a socket pair: %s",
          strerror(error));
    return ErrorStatus(error);
  }
  LocalHandle local_socket{channel_pair[0]};
  LocalHandle remote_socket{channel_pair[1]};

  int optval = 1;
  if (setsockopt(local_socket.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE(
        "Endpoint::PushChannel: Failed to enable the receiving of the "
        "credentials for channel %d: %s",
        local_socket.Get(), strerror(error));
    return ErrorStatus(error);
  }

  std::lock_guard<std::mutex> autolock(channel_mutex_);
  *channel_id = local_socket.Get();
  auto channel_data = OnNewChannelLocked(std::move(local_socket), channel);
  if (!channel_data)
    return ErrorStatus(channel_data.error());

  // Flags are ignored for now.
  // TODO(xiaohuit): Implement those.
  auto* state = static_cast<MessageState*>(message->GetState());
  ChannelReference ref = state->PushChannelHandle(
      remote_socket.Borrow(), channel_data.get()->event_fd.Borrow());
  // The reply only borrows remote_socket; keep it open until the message
  // state is freed.
  state->sockets_to_close.push_back(std::move(remote_socket));
  return RemoteChannelHandle{ref};
}
// Not supported by this transport yet: every call fails with EFAULT and no
// output parameter is touched.
Status<int> Endpoint::CheckChannel(const Message* /*message*/,
                                   ChannelReference /*ref*/,
                                   Channel** /*channel*/) {
  // TODO(xiaohuit): Implement this.
  const int error = EFAULT;
  return ErrorStatus(error);
}
// Placeholder default handler: logs that it is unimplemented and returns 0
// without acting on the message.
// The log line previously named Endpoint::CheckChannel, which misattributed
// the message; it now names this function.
int Endpoint::DefaultHandleMessage(const MessageInfo& /* info */) {
  ALOGE(
      "Endpoint::DefaultHandleMessage: Not implemented! Endpoint "
      "DefaultHandleMessage does nothing!");
  return 0;
}
// Returns the Channel associated with |channel_id|, or nullptr if the channel
// is not registered.
Channel* Endpoint::GetChannelState(int channel_id) {
  std::lock_guard<std::mutex> lock{channel_mutex_};
  auto it = channels_.find(channel_id);
  if (it == channels_.end())
    return nullptr;
  return it->second.channel_state;
}
int Endpoint::GetChannelSocketFd(int channel_id) {
std::lock_guard<std::mutex> autolock(channel_mutex_);
auto channel_data = channels_.find(channel_id);
return (channel_data != channels_.end()) ? channel_data->second.data_fd.Get()
: -1;
}
int Endpoint::GetChannelEventFd(int channel_id) {
std::lock_guard<std::mutex> autolock(channel_mutex_);
auto channel_data = channels_.find(channel_id);
return (channel_data != channels_.end()) ? channel_data->second.event_fd.Get()
: -1;
}
// Reads one request (header, then payload unless it is an impulse) from
// channel socket |channel_id| and builds |message| from it. The decoded
// request is stored in the message's MessageState.
// Impulses carry their data in the header, so the channel is re-armed in epoll
// right away. Any failure closes the channel and returns the error.
//
// When called from AcceptConnection() the channel is not registered yet. In
// that case info.channel is null and CloseChannel() on failure finds nothing
// to close.
Status<void> Endpoint::ReceiveMessageForChannel(int channel_id,
                                                Message* message) {
  RequestHeader<LocalHandle> request;
  auto status = ReceiveData(channel_id, &request);
  if (!status) {
    CloseChannel(channel_id);
    return status;
  }

  MessageInfo info;
  info.pid = request.cred.pid;
  info.tid = -1;
  info.cid = channel_id;
  info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
                                : GetNextAvailableMessageId();
  info.euid = request.cred.uid;
  info.egid = request.cred.gid;
  info.op = request.op;
  info.flags = 0;
  info.service = service_;
  info.channel = GetChannelState(channel_id);
  info.send_len = request.send_len;
  info.recv_len = request.max_recv_len;
  info.fd_count = request.file_descriptors.size();
  static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
                "Impulse payload sizes must be the same in RequestHeader and "
                "MessageInfo");
  memcpy(info.impulse, request.impulse_payload.data(),
         request.impulse_payload.size());

  // Capture what is needed below before |request| is moved into the message
  // state. Reading a moved-from object relies on unspecified state.
  const auto send_len = request.send_len;
  const bool is_impulse = request.is_impulse;

  *message = Message{info};
  auto* state = static_cast<MessageState*>(message->GetState());
  state->request = std::move(request);

  if (send_len > 0 && !is_impulse) {
    state->request_data.resize(send_len);
    status = ReceiveData(channel_id, state->request_data.data(),
                         state->request_data.size());
  }
  if (status && is_impulse)
    status = ReenableEpollEvent(channel_id);
  if (!status)
    CloseChannel(channel_id);
  return status;
}
// Waits for the next endpoint event and turns it into |message|. It first
// blocks until SetService() has attached a service.
//  - Waits on epoll (forever when blocking, otherwise a non-blocking poll).
//  - Cancel event: returns -ESHUTDOWN.
//  - Listening socket: accepts a new connection and reads its first message.
//  - Peer hang-up (EPOLLRDHUP): synthesizes a CHANNEL_CLOSE message.
//  - Otherwise: reads the next request from the ready channel.
// Returns 0 on success, -ETIMEDOUT if no event was ready in non-blocking mode,
// or a negative errno.
int Endpoint::MessageReceive(Message* message) {
  {
    std::unique_lock<std::mutex> lock(service_mutex_);
    condition_.wait(lock, [this] { return service_ != nullptr; });
  }

  // One event at a time.
  epoll_event event;
  int count = RETRY_EINTR(
      epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
  if (count < 0) {
    const int error = errno;  // Save before ALOGE can overwrite errno.
    ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
          strerror(error));
    return -error;
  } else if (count == 0) {
    return -ETIMEDOUT;
  }

  if (event.data.fd == cancel_event_fd_.Get()) {
    return -ESHUTDOWN;
  }

  if (event.data.fd == socket_fd_.Get()) {
    auto status = AcceptConnection(message);
    // The listening socket is registered with EPOLLONESHOT. It must be re-armed
    // even when this connection failed; otherwise no further connections would
    // ever be accepted.
    auto reenable_status = ReenableEpollEvent(socket_fd_.Get());
    if (!status)
      return -status.error();
    return reenable_status ? 0 : -reenable_status.error();
  }

  int channel_id = event.data.fd;
  if (event.events & EPOLLRDHUP) {
    // The peer hung up: report it to the service as a CHANNEL_CLOSE message.
    MessageInfo info;
    info.pid = -1;
    info.tid = -1;
    info.cid = channel_id;
    info.mid = GetNextAvailableMessageId();
    info.euid = -1;
    info.egid = -1;
    info.op = opcodes::CHANNEL_CLOSE;
    info.flags = 0;
    info.service = service_;
    info.channel = GetChannelState(channel_id);
    info.send_len = 0;
    info.recv_len = 0;
    info.fd_count = 0;
    *message = Message{info};
    return 0;
  }

  auto status = ReceiveMessageForChannel(channel_id, message);
  if (!status)
    return -status.error();
  return 0;
}
// Sends the reply for |message| with |return_code|.
//  - CHANNEL_CLOSE closes the channel instead of replying.
//  - CHANNEL_OPEN with a negative code also closes the channel. Otherwise the
//    reply's return code is replaced by a reference to the channel's eventfd,
//    and any payload is dropped.
// After the header (and payload, if any) is sent, the channel is re-armed in
// epoll. Returns 0, -EBADF for an unknown channel, or a negative error.
int Endpoint::MessageReply(Message* message, int return_code) {
  const int channel_socket = GetChannelSocketFd(message->GetChannelId());
  if (channel_socket < 0)
    return -EBADF;

  auto* state = static_cast<MessageState*>(message->GetState());
  const auto op = message->GetOp();
  if (op == opcodes::CHANNEL_CLOSE)
    return CloseChannel(channel_socket);
  if (op == opcodes::CHANNEL_OPEN) {
    if (return_code < 0)
      return CloseChannel(channel_socket);
    // Hand the client the channel's event fd as the reply value.
    return_code = state->PushFileHandle(
        BorrowedHandle{GetChannelEventFd(channel_socket)});
    state->response_data.clear();  // Just in case...
  }

  state->response.ret_code = return_code;
  state->response.recv_len = state->response_data.size();

  auto status = SendData(channel_socket, state->response);
  if (!status)
    return -status.error();
  if (!state->response_data.empty()) {
    status = SendData(channel_socket, state->response_data.data(),
                      state->response_data.size());
    if (!status)
      return -status.error();
  }
  status = ReenableEpollEvent(channel_socket);
  return status ? 0 : -status.error();
}
// Replies to |message| with a reference to |push_fd|. The fd is only
// borrowed, so the caller keeps ownership.
int Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
  MessageState* const state = static_cast<MessageState*>(message->GetState());
  return MessageReply(
      message, state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)}));
}
// Replies to |message| with a reference to the local channel |handle|, which
// is borrowed for the duration of the reply.
int Endpoint::MessageReplyChannelHandle(Message* message,
                                        const LocalChannelHandle& handle) {
  MessageState* const state = static_cast<MessageState*>(message->GetState());
  return MessageReply(message, state->PushChannelHandle(handle.Borrow()));
}
// Replies to |message| with a reference to a duplicate of the borrowed
// channel |handle|.
// NOTE(review): the duplicate is a temporary that is destroyed when the push
// call returns, before MessageReply() sends the reply. If destroying it closes
// the channel, the reply may reference closed fds. Confirm against
// LocalChannelHandle's destructor.
int Endpoint::MessageReplyChannelHandle(Message* message,
                                        const BorrowedChannelHandle& handle) {
  MessageState* const state = static_cast<MessageState*>(message->GetState());
  const ChannelReference ref = state->PushChannelHandle(handle.Duplicate());
  return MessageReply(message, ref);
}
// Replies to |message| with a remote channel handle's raw value; nothing is
// queued in the message state.
int Endpoint::MessageReplyChannelHandle(Message* message,
                                        const RemoteChannelHandle& handle) {
  const ChannelReference ref = handle.value();
  return MessageReply(message, ref);
}
// Copies unread request payload of |message| into the iovec buffers; see
// MessageState::ReadData().
ssize_t Endpoint::ReadMessageData(Message* message, const iovec* vector,
                                  size_t vector_length) {
  return static_cast<MessageState*>(message->GetState())
      ->ReadData(vector, vector_length);
}
// Appends the iovec buffers to |message|'s reply payload; see
// MessageState::WriteData().
ssize_t Endpoint::WriteMessageData(Message* message, const iovec* vector,
                                   size_t vector_length) {
  return static_cast<MessageState*>(message->GetState())
      ->WriteData(vector, vector_length);
}
// Queues a borrow of the local |handle| in |message|'s reply. The caller keeps
// ownership and must keep it open until the reply is sent.
FileReference Endpoint::PushFileHandle(Message* message,
                                       const LocalHandle& handle) {
  return static_cast<MessageState*>(message->GetState())
      ->PushFileHandle(handle.Borrow());
}
// Queues a duplicate of the borrowed |handle| in |message|'s reply, so the
// reply does not depend on the caller keeping |handle| open.
// Previously the duplicate was passed as a temporary. That temporary is
// destroyed (closing the fd) when the call returns, while the response still
// referenced its fd number. The duplicate is now parked in sockets_to_close,
// the same way PushChannel() keeps its remote socket alive, so it stays open
// until the message state is freed.
FileReference Endpoint::PushFileHandle(Message* message,
                                       const BorrowedHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  auto duplicate = handle.Duplicate();
  const FileReference ref = state->PushFileHandle(duplicate.Borrow());
  state->sockets_to_close.push_back(std::move(duplicate));
  return ref;
}
// Remote handles are passed through unchanged: the reference is the handle's
// own value and nothing is queued in the message state.
FileReference Endpoint::PushFileHandle(Message* /*message*/,
                                       const RemoteHandle& handle) {
  const FileReference ref = handle.Get();
  return ref;
}
// Queues a borrow of the local channel |handle| in |message|'s reply. The
// caller keeps ownership.
ChannelReference Endpoint::PushChannelHandle(Message* message,
                                             const LocalChannelHandle& handle) {
  return static_cast<MessageState*>(message->GetState())
      ->PushChannelHandle(handle.Borrow());
}
// Queues a duplicate of the borrowed channel |handle| in |message|'s reply.
// NOTE(review): the duplicate is a temporary destroyed at the end of this
// call. If that closes the channel, the queued reference outlives its fds.
// Confirm against LocalChannelHandle's destructor.
ChannelReference Endpoint::PushChannelHandle(
    Message* message, const BorrowedChannelHandle& handle) {
  MessageState* const state = static_cast<MessageState*>(message->GetState());
  const ChannelReference ref = state->PushChannelHandle(handle.Duplicate());
  return ref;
}
// Remote channel handles are passed through as their raw value; nothing is
// queued in the message state.
ChannelReference Endpoint::PushChannelHandle(
    Message* /*message*/, const RemoteChannelHandle& handle) {
  const ChannelReference ref = handle.value();
  return ref;
}
// Takes ownership of the request descriptor named by |ref|. An out-of-range
// reference yields a default-constructed handle.
LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
  auto* state = static_cast<MessageState*>(message->GetState());
  LocalHandle result;
  // The success flag is intentionally ignored; |result| stays default on miss.
  static_cast<void>(state->GetLocalFileHandle(ref, &result));
  return result;
}
// Builds a local channel handle for the request channel named by |ref|. An
// out-of-range reference yields a default-constructed handle.
LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
                                              ChannelReference ref) const {
  auto* state = static_cast<MessageState*>(message->GetState());
  LocalChannelHandle result;
  // The success flag is intentionally ignored; |result| stays default on miss.
  static_cast<void>(state->GetLocalChannelHandle(ref, &result));
  return result;
}
// Signals the cancel eventfd. MessageReceive() then sees it ready and returns
// -ESHUTDOWN. Returns 0, or -errno if the write failed.
int Endpoint::Cancel() {
  if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
    return -errno;
  return 0;
}
std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
mode_t /*unused_mode*/,
bool blocking) {
return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
}
} // namespace uds
} // namespace pdx
} // namespace android