adb: switch apacket over to a std::string payload.
Test: python test_device.py with walleye/x86_64 emulator
Change-Id: I0a18941af1cb2279e5019a24ace25741def1202f
diff --git a/adb/adb.cpp b/adb/adb.cpp
index ee3503b..ae8020e 100644
--- a/adb/adb.cpp
+++ b/adb/adb.cpp
@@ -105,31 +105,27 @@
}
uint32_t calculate_apacket_checksum(const apacket* p) {
- const unsigned char* x = reinterpret_cast<const unsigned char*>(p->data);
uint32_t sum = 0;
- size_t count = p->msg.data_length;
-
- while (count-- > 0) {
- sum += *x++;
+ for (size_t i = 0; i < p->msg.data_length; ++i) {
+ sum += static_cast<uint8_t>(p->payload[i]);
}
-
return sum;
}
apacket* get_apacket(void)
{
- apacket* p = reinterpret_cast<apacket*>(malloc(sizeof(apacket)));
+ apacket* p = new apacket();
if (p == nullptr) {
fatal("failed to allocate an apacket");
}
- memset(p, 0, sizeof(apacket) - MAX_PAYLOAD);
+ memset(&p->msg, 0, sizeof(p->msg));
return p;
}
void put_apacket(apacket *p)
{
- free(p);
+ delete p;
}
void handle_online(atransport *t)
@@ -155,8 +151,7 @@
#define DUMPMAX 32
void print_packet(const char *label, apacket *p)
{
- char *tag;
- char *x;
+ const char* tag;
unsigned count;
switch(p->msg.command){
@@ -173,15 +168,15 @@
fprintf(stderr, "%s: %s %08x %08x %04x \"",
label, tag, p->msg.arg0, p->msg.arg1, p->msg.data_length);
count = p->msg.data_length;
- x = (char*) p->data;
- if(count > DUMPMAX) {
+ const char* x = p->payload.data();
+ if (count > DUMPMAX) {
count = DUMPMAX;
tag = "\n";
} else {
tag = "\"\n";
}
- while(count-- > 0){
- if((*x >= ' ') && (*x < 127)) {
+ while (count-- > 0) {
+ if ((*x >= ' ') && (*x < 127)) {
fputc(*x, stderr);
} else {
fputc('.', stderr);
@@ -254,8 +249,8 @@
<< connection_str.length() << ")";
}
- memcpy(cp->data, connection_str.c_str(), connection_str.length());
- cp->msg.data_length = connection_str.length();
+ cp->payload = std::move(connection_str);
+ cp->msg.data_length = cp->payload.size();
send_packet(cp, t);
}
@@ -329,9 +324,7 @@
}
t->update_version(p->msg.arg0, p->msg.arg1);
- std::string banner(reinterpret_cast<const char*>(p->data),
- p->msg.data_length);
- parse_banner(banner, t);
+ parse_banner(p->payload, t);
#if ADB_HOST
handle_online(t);
@@ -354,6 +347,7 @@
((char*) (&(p->msg.command)))[2],
((char*) (&(p->msg.command)))[3]);
print_packet("recv", p);
+ CHECK_EQ(p->payload.size(), p->msg.data_length);
switch(p->msg.command){
case A_SYNC:
@@ -380,11 +374,11 @@
if (t->GetConnectionState() == kCsOffline) {
t->SetConnectionState(kCsUnauthorized);
}
- send_auth_response(p->data, p->msg.data_length, t);
+ send_auth_response(p->payload.data(), p->msg.data_length, t);
break;
#else
case ADB_AUTH_SIGNATURE:
- if (adbd_auth_verify(t->token, sizeof(t->token), p->data, p->msg.data_length)) {
+ if (adbd_auth_verify(t->token, sizeof(t->token), p->payload)) {
adbd_auth_verified(t);
t->failed_auth_attempts = 0;
} else {
@@ -394,7 +388,7 @@
break;
case ADB_AUTH_RSAPUBLICKEY:
- adbd_auth_confirm_key(p->data, p->msg.data_length, t);
+ adbd_auth_confirm_key(p->payload.data(), p->msg.data_length, t);
break;
#endif
default:
@@ -406,9 +400,7 @@
case A_OPEN: /* OPEN(local-id, 0, "destination") */
if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) {
- char *name = (char*) p->data;
- name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0;
- asocket* s = create_local_service_socket(name, t);
+ asocket* s = create_local_service_socket(p->payload.c_str(), t);
if (s == nullptr) {
send_close(0, p->msg.arg0, t);
} else {
@@ -474,11 +466,7 @@
asocket* s = find_local_socket(p->msg.arg1, p->msg.arg0);
if (s) {
unsigned rid = p->msg.arg0;
-
- // TODO: Convert apacket::data to a type that we can move out of.
- std::string copy(p->data, p->data + p->msg.data_length);
-
- if (s->enqueue(s, std::move(copy)) == 0) {
+ if (s->enqueue(s, std::move(p->payload)) == 0) {
D("Enqueue the socket");
send_ready(s->id, rid, t);
}
diff --git a/adb/adb.h b/adb/adb.h
index c9c635a..a6d0463 100644
--- a/adb/adb.h
+++ b/adb/adb.h
@@ -74,7 +74,7 @@
struct apacket {
amessage msg;
- char data[MAX_PAYLOAD];
+ std::string payload;
};
uint32_t calculate_apacket_checksum(const apacket* packet);
diff --git a/adb/adb_auth.h b/adb/adb_auth.h
index a6f224f..715e04f 100644
--- a/adb/adb_auth.h
+++ b/adb/adb_auth.h
@@ -49,7 +49,7 @@
void adbd_auth_verified(atransport *t);
void adbd_cloexec_auth_socket();
-bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len);
+bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig);
void adbd_auth_confirm_key(const char* data, size_t len, atransport* t);
void send_auth_request(atransport *t);
diff --git a/adb/adb_auth_host.cpp b/adb/adb_auth_host.cpp
index 365bf77..c3aef16 100644
--- a/adb/adb_auth_host.cpp
+++ b/adb/adb_auth_host.cpp
@@ -299,20 +299,25 @@
return result;
}
-static int adb_auth_sign(RSA* key, const char* token, size_t token_size, char* sig) {
+static std::string adb_auth_sign(RSA* key, const char* token, size_t token_size) {
if (token_size != TOKEN_SIZE) {
D("Unexpected token size %zd", token_size);
return 0;
}
+ std::string result;
+ result.resize(MAX_PAYLOAD);
+
unsigned int len;
if (!RSA_sign(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
- reinterpret_cast<uint8_t*>(sig), &len, key)) {
- return 0;
+ reinterpret_cast<uint8_t*>(&result[0]), &len, key)) {
+ return std::string();
}
+ result.resize(len);
+
D("adb_auth_sign len=%d", len);
- return (int)len;
+ return result;
}
std::string adb_auth_get_userkey() {
@@ -446,13 +451,14 @@
}
apacket* p = get_apacket();
- memcpy(p->data, key.c_str(), key.size() + 1);
-
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_RSAPUBLICKEY;
+ p->payload = std::move(key);
+
// adbd expects a null-terminated string.
- p->msg.data_length = key.size() + 1;
+ p->payload.push_back('\0');
+ p->msg.data_length = p->payload.size();
send_packet(p, t);
}
@@ -467,8 +473,8 @@
LOG(INFO) << "Calling send_auth_response";
apacket* p = get_apacket();
- int ret = adb_auth_sign(key.get(), token, token_size, p->data);
- if (!ret) {
+ std::string result = adb_auth_sign(key.get(), token, token_size);
+ if (result.empty()) {
D("Error signing the token");
put_apacket(p);
return;
@@ -476,6 +482,7 @@
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_SIGNATURE;
- p->msg.data_length = ret;
+ p->payload = std::move(result);
+ p->msg.data_length = p->payload.size();
send_packet(p, t);
}
diff --git a/adb/adbd_auth.cpp b/adb/adbd_auth.cpp
index 3488ad1..3fd2b31 100644
--- a/adb/adbd_auth.cpp
+++ b/adb/adbd_auth.cpp
@@ -46,7 +46,7 @@
bool auth_required = true;
-bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len) {
+bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig) {
static constexpr const char* key_paths[] = { "/adb_keys", "/data/misc/adb/adb_keys", nullptr };
for (const auto& path : key_paths) {
@@ -80,7 +80,8 @@
bool verified =
(RSA_verify(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
- reinterpret_cast<const uint8_t*>(sig), sig_len, key) == 1);
+ reinterpret_cast<const uint8_t*>(sig.c_str()), sig.size(),
+ key) == 1);
RSA_free(key);
if (verified) return true;
}
@@ -210,10 +211,10 @@
}
apacket* p = get_apacket();
- memcpy(p->data, t->token, sizeof(t->token));
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_TOKEN;
p->msg.data_length = sizeof(t->token);
+ p->payload.assign(t->token, t->token + sizeof(t->token));
send_packet(p, t);
}
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index 307cbfe..0007fed 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -413,15 +413,15 @@
p->msg.command = A_WRTE;
p->msg.arg0 = s->peer->id;
p->msg.arg1 = s->id;
- p->msg.data_length = data.size();
- if (data.size() > sizeof(p->data)) {
+ if (data.size() > MAX_PAYLOAD) {
put_apacket(p);
return -1;
}
- // TODO: Convert apacket::data to a type that we can move into.
- memcpy(p->data, data.data(), data.size());
+ p->payload = std::move(data);
+ p->msg.data_length = p->payload.size();
+
send_packet(p, s->transport);
return 1;
}
@@ -482,17 +482,20 @@
void connect_to_remote(asocket* s, const char* destination) {
D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
apacket* p = get_apacket();
- size_t len = strlen(destination) + 1;
-
- if (len > (s->get_max_payload() - 1)) {
- fatal("destination oversized");
- }
D("LS(%d): connect('%s')", s->id, destination);
p->msg.command = A_OPEN;
p->msg.arg0 = s->id;
- p->msg.data_length = len;
- strcpy((char*)p->data, destination);
+
+ // adbd expects a null-terminated string.
+ p->payload = destination;
+ p->payload.push_back('\0');
+ p->msg.data_length = p->payload.size();
+
+ if (p->msg.data_length > s->get_max_payload()) {
+ fatal("destination oversized");
+ }
+
send_packet(p, s->transport);
}
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 9ae1297..14888ab 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -72,12 +72,14 @@
return false;
}
- if (packet->msg.data_length > sizeof(packet->data)) {
+ if (packet->msg.data_length > MAX_PAYLOAD) {
D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length);
return false;
}
- if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) {
+ packet->payload.resize(packet->msg.data_length);
+
+ if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
D("remote local: terminated (data)");
return false;
}
@@ -86,13 +88,18 @@
}
bool FdConnection::Write(apacket* packet) {
- uint32_t length = packet->msg.data_length;
-
- if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
+ if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
D("remote local: write terminated");
return false;
}
+ if (packet->msg.data_length) {
+ if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
+ D("remote local: write terminated");
+ return false;
+ }
+ }
+
return true;
}
@@ -133,7 +140,7 @@
std::string result = android::base::StringPrintf("%s: %s: [%s] arg0=%s arg1=%s (len=%d) ", name,
func, cmd, arg0, arg1, len);
- result += dump_hex(p->data, len);
+ result += dump_hex(p->payload.data(), p->payload.size());
return result;
}
@@ -191,9 +198,10 @@
apacket* p = 0;
if (read_packet(fd, t->serial, &p)) {
D("%s: failed to read packet from transport socket on fd %d", t->serial, fd);
- } else {
- handle_packet(p, (atransport*)_t);
+ return;
}
+
+ handle_packet(p, (atransport*)_t);
}
}
@@ -243,6 +251,7 @@
p->msg.arg0 = 1;
p->msg.arg1 = ++(t->sync_token);
p->msg.magic = A_SYNC ^ 0xffffffff;
+ D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size());
if (write_packet(t->fd, t->serial, &p)) {
put_apacket(p);
D("%s: failed to write SYNC packet", t->serial);
@@ -336,6 +345,13 @@
if (active) {
D("%s: transport got packet, sending to remote", t->serial);
ATRACE_NAME("write_transport write_remote");
+
+ // Allow sending the payload's implicit null terminator.
+ if (p->msg.data_length != p->payload.size()) {
+ LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = "
+ << p->msg.data_length << ", payload.size() = " << p->payload.size();
+ }
+
if (t->Write(p) != 0) {
D("%s: remote write failed for transport", t->serial);
put_apacket(p);
diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp
index a108699..d7565f6 100644
--- a/adb/transport_usb.cpp
+++ b/adb/transport_usb.cpp
@@ -61,13 +61,12 @@
static int UsbReadPayload(usb_handle* h, apacket* p) {
D("UsbReadPayload(%d)", p->msg.data_length);
- if (p->msg.data_length > sizeof(p->data)) {
+ if (p->msg.data_length > MAX_PAYLOAD) {
return -1;
}
#if CHECK_PACKET_OVERFLOW
size_t usb_packet_size = usb_get_max_packet_size(h);
- CHECK_EQ(0ULL, sizeof(p->data) % usb_packet_size);
// Round the data length up to the nearest packet size boundary.
// The device won't send a zero packet for packet size aligned payloads,
@@ -77,10 +76,18 @@
if (rem_size) {
len += usb_packet_size - rem_size;
}
- CHECK_LE(len, sizeof(p->data));
- return usb_read(h, &p->data, len);
+
+ p->payload.resize(len);
+ int rc = usb_read(h, &p->payload[0], p->payload.size());
+ if (rc != static_cast<int>(p->msg.data_length)) {
+ return -1;
+ }
+
+ p->payload.resize(rc);
+ return rc;
#else
- return usb_read(h, &p->data, p->msg.data_length);
+ p->payload.resize(p->msg.data_length);
+ return usb_read(h, &p->payload[0], p->payload.size());
#endif
}
@@ -120,12 +127,13 @@
}
if (p->msg.data_length) {
- if (p->msg.data_length > sizeof(p->data)) {
+ if (p->msg.data_length > MAX_PAYLOAD) {
PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
return -1;
}
- if (usb_read(usb, p->data, p->msg.data_length)) {
+ p->payload.resize(p->msg.data_length);
+ if (usb_read(usb, &p->payload[0], p->payload.size())) {
PLOG(ERROR) << "remote usb: terminated (data)";
return -1;
}
@@ -152,7 +160,7 @@
return false;
}
- if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
+ if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != 0) {
PLOG(ERROR) << "remote usb: 2 - write terminated";
return false;
}