blob: 44b9b5b7e7ca43f7e886d03d33e2968f39a66c5d [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <source/HostToGuestComms.h>
#include <https/SafeCallbackable.h>
#include <https/Support.h>
#include <android-base/logging.h>
// Wraps a descriptor that the caller has already created.
//
// When isServer is true the descriptor is stored as the listening socket
// (mServerSock) and mSock starts out as -1; otherwise it is stored as the
// data socket (mSock) and mServerSock stays -1. Either way the descriptor is
// handed to makeFdNonblocking().
HostToGuestComms::HostToGuestComms(
        std::shared_ptr<RunLoop> runLoop,
        bool isServer,
        int fd,
        ReceiveCb onReceive)
    : mRunLoop(runLoop),
      mIsServer(isServer),
      mOnReceive(onReceive),
      mServerSock(isServer ? fd : -1),
      mSock(isServer ? -1 : fd),
      mInBufferLen(0),
      mSendPending(false),
      mConnected(false) {
    makeFdNonblocking(fd);
}
// Creates a fresh AF_VSOCK stream socket addressed by (cid, port).
//
// Server mode: the socket is bound to the address and put into listening
// state (backlog of 4); a failure of either step aborts via CHECK.
// Client mode: the socket and the target address are only recorded here; the
// actual connect is attempted later by onAttemptToConnect(), which start()
// schedules.
HostToGuestComms::HostToGuestComms(
        std::shared_ptr<RunLoop> runLoop,
        bool isServer,
        uint32_t cid,
        uint16_t port,
        ReceiveCb onReceive)
    : mRunLoop(runLoop),
      mIsServer(isServer),
      mOnReceive(onReceive),
      mServerSock(-1),
      mSock(-1),
      mInBufferLen(0),
      mSendPending(false),
      mConnected(false) {
    int s = socket(AF_VSOCK, SOCK_STREAM, 0);
    CHECK_GE(s, 0);

    LOG(INFO) << "HostToGuestComms created socket " << s;

    makeFdNonblocking(s);

    sockaddr_vm addr;
    memset(&addr, 0, sizeof(addr));
    addr.svm_family = AF_VSOCK;
    addr.svm_port = port;
    addr.svm_cid = cid;

    if (!mIsServer) {
        mSock = s;
        mConnectToAddr = addr;
        return;
    }

    // The conditional must be fully parenthesized: operator<< binds tighter
    // than ?:, so without the outer parentheses the stream itself becomes the
    // condition and only the boolean comparison result gets logged.
    LOG(INFO)
        << "Binding to cid "
        << ((addr.svm_cid == VMADDR_CID_ANY)
                ? std::string("VMADDR_CID_ANY")
                : std::to_string(addr.svm_cid));

    int res = bind(s, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));

    if (res) {
        LOG(ERROR)
            << "bind FAILED w/ errno "
            << errno
            << " ("
            << strerror(errno)
            << ")";
    }

    CHECK(!res);

    res = listen(s, 4);
    CHECK(!res);

    mServerSock = s;
}
// Tears down whichever sockets are still open: first the data socket, then
// the listening socket. Each open descriptor is cancelled on the run loop,
// closed, and its member reset to -1.
HostToGuestComms::~HostToGuestComms() {
    const auto release = [this](int &fd) {
        if (fd < 0) {
            return;
        }

        mRunLoop->cancelSocket(fd);
        close(fd);
        fd = -1;
    };

    release(mSock);
    release(mServerSock);
}
// Kicks off the communication channel.
//
// Client mode: schedules the first connection attempt to mConnectToAddr
// 5000ms from now.
// Server mode: registers onServerConnection() for the listening socket.
void HostToGuestComms::start() {
    if (!mIsServer) {
        mRunLoop->postWithDelay(
                std::chrono::milliseconds(5000),
                makeSafeCallback(
                    this,
                    &HostToGuestComms::onAttemptToConnect,
                    mConnectToAddr));

        return;
    }

    mRunLoop->postSocketRecv(
            mServerSock,
            makeSafeCallback(this, &HostToGuestComms::onServerConnection));
}
// Appends `size` bytes from `data` to the outgoing buffer and, if possible,
// schedules a socket write.
//
// With addFraming set, the payload is preceded by a 32-bit length prefix
// copied in native byte order (note that `size` is narrowed to uint32_t for
// that prefix). A zero-length request is ignored entirely.
//
// A write is only scheduled when a data socket exists, the object is either a
// server or a connected client, and no write is already pending; otherwise
// the bytes simply stay queued in mOutBuffer.
void HostToGuestComms::send(const void *data, size_t size, bool addFraming) {
    if (size == 0) {
        return;
    }

    std::lock_guard autoLock(mLock);

    const size_t writePos = mOutBuffer.size();
    const uint32_t lengthPrefix = static_cast<uint32_t>(size);
    const size_t prefixBytes = addFraming ? sizeof(lengthPrefix) : 0;

    mOutBuffer.resize(writePos + prefixBytes + size);

    if (addFraming) {
        memcpy(mOutBuffer.data() + writePos, &lengthPrefix, sizeof(lengthPrefix));
    }

    memcpy(mOutBuffer.data() + writePos + prefixBytes, data, size);

    const bool linkUsable = mSock >= 0 && (mIsServer || mConnected);
    if (mSendPending || !linkUsable) {
        return;
    }

    mSendPending = true;
    mRunLoop->postSocketSend(
            mSock,
            makeSafeCallback(this, &HostToGuestComms::onSocketSend));
}
// Run-loop callback for activity on the listening socket.
//
// Accepts one pending connection. If a data socket is already in use, the
// newcomer is closed immediately; otherwise it becomes the data socket, a
// receive callback is registered for it, and any bytes already queued in
// mOutBuffer get a send scheduled. In every case (including a failed
// accept) the listening socket is re-registered afterwards.
void HostToGuestComms::onServerConnection() {
    const int clientFd = accept(mServerSock, nullptr, nullptr);
    const bool gotClient = clientFd >= 0;

    if (gotClient && mSock >= 0) {
        LOG(INFO) << "Rejecting client, we already have one.";

        // Only one client at a time is supported.
        close(clientFd);
    } else if (gotClient) {
        LOG(INFO) << "Accepted client socket " << clientFd << ".";

        makeFdNonblocking(clientFd);
        mSock = clientFd;

        mRunLoop->postSocketRecv(
                mSock,
                makeSafeCallback(this, &HostToGuestComms::onSocketReceive));

        // Flush anything that was queued before the client showed up.
        std::lock_guard autoLock(mLock);
        if (!mOutBuffer.empty()) {
            CHECK(!mSendPending);
            mSendPending = true;

            mRunLoop->postSocketSend(
                    mSock,
                    makeSafeCallback(
                        this, &HostToGuestComms::onSocketSend));
        }
    }

    mRunLoop->postSocketRecv(
            mServerSock,
            makeSafeCallback(this, &HostToGuestComms::onServerConnection));
}
// Run-loop callback for readable data on the data socket.
//
// Reads in 64KiB chunks until recv() stops returning data, then hands the
// accumulated bytes to drainInBuffer(). If reading stopped because of EAGAIN
// or EWOULDBLOCK, the receive callback is re-registered. If it stopped because
// of end-of-stream or any other error, the socket is cancelled, closed and
// forgotten.
void HostToGuestComms::onSocketReceive() {
    static constexpr size_t kChunkSize = 65536;

    ssize_t bytesRead;
    do {
        // Make sure there is room for one more full chunk past the valid data.
        mInBuffer.resize(mInBufferLen + kChunkSize);

        do {
            bytesRead = recv(
                    mSock, mInBuffer.data() + mInBufferLen, kChunkSize, 0);
        } while (bytesRead < 0 && errno == EINTR);

        if (bytesRead > 0) {
            mInBufferLen += static_cast<size_t>(bytesRead);
        }
    } while (bytesRead > 0);

    // Capture errno before anything else gets a chance to overwrite it.
    const int recvErrno = errno;

    // Complete packets are dispatched even if the peer has since gone away.
    drainInBuffer();

    const bool wouldBlock =
        bytesRead < 0 && (recvErrno == EAGAIN || recvErrno == EWOULDBLOCK);

    if (wouldBlock) {
        mRunLoop->postSocketRecv(
                mSock,
                makeSafeCallback(this, &HostToGuestComms::onSocketReceive));

        return;
    }

    // Either end-of-stream (0) or a hard error.
    LOG(ERROR) << "Client is gone.";

    mRunLoop->cancelSocket(mSock);
    mSendPending = false;
    close(mSock);
    mSock = -1;
}
// Dispatches every complete packet currently sitting at the front of
// mInBuffer.
//
// A packet is a 32-bit length (copied out in native byte order) followed by
// that many payload bytes. For each complete packet the payload is passed to
// mOnReceive (when one is set) and the packet is then removed from the
// buffer. Processing stops at the first incomplete header or payload, which
// is left in place.
void HostToGuestComms::drainInBuffer() {
    constexpr size_t kHeaderSize = sizeof(uint32_t);

    while (mInBufferLen >= kHeaderSize) {
        uint32_t payloadSize;
        memcpy(&payloadSize, mInBuffer.data(), kHeaderSize);

        const size_t frameSize = kHeaderSize + payloadSize;
        if (frameSize > mInBufferLen) {
            // Payload has not fully arrived yet.
            return;
        }

        if (mOnReceive) {
            mOnReceive(mInBuffer.data() + kHeaderSize, payloadSize);
        }

        mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + frameSize);
        mInBufferLen -= frameSize;
    }
}
// Run-loop callback invoked when a send scheduled through postSocketSend()
// fires.
//
// Writes as much of mOutBuffer as the socket accepts. The outcome is one of:
//  - everything was written (or there was nothing to write): just return;
//  - the socket would block (EAGAIN/EWOULDBLOCK): schedule another send;
//  - ::send() returned 0 or failed with any other error: cancel, close and
//    forget the socket.
//
// The result of ::send() and errno are inspected right after the call that
// produced them, so no indeterminate value is ever read when the buffer is
// already empty on entry.
void HostToGuestComms::onSocketSend() {
    std::lock_guard autoLock(mLock);

    CHECK(mSendPending);
    mSendPending = false;

    if (mSock < 0) {
        return;
    }

    while (!mOutBuffer.empty()) {
        ssize_t n;
        do {
            n = ::send(mSock, mOutBuffer.data(), mOutBuffer.size(), 0);
        } while (n < 0 && errno == EINTR);

        if (n > 0) {
            mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + n);
            continue;
        }

        if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            // Socket is full for now; try again once it is writable. The
            // buffer is necessarily non-empty at this point.
            mSendPending = true;
            mRunLoop->postSocketSend(
                    mSock,
                    makeSafeCallback(this, &HostToGuestComms::onSocketSend));

            return;
        }

        LOG(ERROR) << "Client is gone.";

        // Client is gone.
        mRunLoop->cancelSocket(mSock);
        close(mSock);
        mSock = -1;

        return;
    }
}
// Tries to connect the data socket to `addr`.
//
//  - Immediate success: proceeds to onConnected().
//  - EINPROGRESS: registers onCheckConnection() as a send callback on the
//    socket to find out how the attempt ended.
//  - Any other error: logs it and schedules another attempt in 5000ms.
void HostToGuestComms::onAttemptToConnect(const sockaddr_vm &addr) {
    LOG(VERBOSE) << "Attempting to connect to cid " << addr.svm_cid;

    int res;
    do {
        res = connect(
                mSock, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
    } while (res < 0 && errno == EINTR);

    if (res >= 0) {
        onConnected();
        return;
    }

    if (errno != EINPROGRESS) {
        LOG(INFO)
            << "Our attempt to connect to the guest FAILED w/ error "
            << errno
            << " ("
            << strerror(errno)
            << "), will try again shortly.";

        mRunLoop->postWithDelay(
                std::chrono::milliseconds(5000),
                makeSafeCallback(
                    this, &HostToGuestComms::onAttemptToConnect, addr));

        return;
    }

    LOG(VERBOSE) << "EINPROGRESS, waiting to check the connection.";

    mRunLoop->postSocketSend(
            mSock,
            makeSafeCallback(
                this, &HostToGuestComms::onCheckConnection, addr));
}
// Follow-up to a connect() that reported EINPROGRESS.
//
// Reads SO_ERROR from the socket. A value of zero means the connection was
// established and onConnected() runs. Otherwise the socket is cancelled and
// closed, a brand-new non-blocking AF_VSOCK socket replaces it, and another
// connection attempt to `addr` is scheduled in 5000ms.
void HostToGuestComms::onCheckConnection(const sockaddr_vm &addr) {
    int pendingError;

    int res;
    do {
        socklen_t optLen = sizeof(pendingError);
        res = getsockopt(mSock, SOL_SOCKET, SO_ERROR, &pendingError, &optLen);
    } while (res < 0 && errno == EINTR);

    CHECK(!res);

    if (pendingError == 0) {
        onConnected();
        return;
    }

    LOG(VERBOSE)
        << "Connection failed w/ error "
        << pendingError
        << " ("
        << strerror(pendingError)
        << "), will try again shortly.";

    // Is there a better way of cancelling the (failed) connection that
    // somehow is still in progress on the socket and restarting it?
    mRunLoop->cancelSocket(mSock);
    close(mSock);

    mSock = socket(AF_VSOCK, SOCK_STREAM, 0);
    CHECK_GE(mSock, 0);

    makeFdNonblocking(mSock);

    mRunLoop->postWithDelay(
            std::chrono::milliseconds(5000),
            makeSafeCallback(
                this, &HostToGuestComms::onAttemptToConnect, addr));
}
// Called once the client-side connection has been established.
//
// Marks the object as connected, schedules a send if data was queued while
// disconnected, and registers the receive callback for the data socket.
void HostToGuestComms::onConnected() {
    LOG(INFO) << "Connected to guest.";

    std::lock_guard autoLock(mLock);

    mConnected = true;

    CHECK(!mSendPending);

    const bool haveQueuedData = !mOutBuffer.empty();
    if (haveQueuedData) {
        mSendPending = true;

        mRunLoop->postSocketSend(
                mSock,
                makeSafeCallback(this, &HostToGuestComms::onSocketSend));
    }

    mRunLoop->postSocketRecv(
            mSock,
            makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
}