Merge "adb: switch apacket over to a std::string payload." am: 581a4ceb00 am: 4d74066425
am: ea9359a957

Change-Id: Ib9b1f7c9c3e59e3cfd3ea492360703d2b7a68119
diff --git a/adb/adb.cpp b/adb/adb.cpp
index ee3503b..ae8020e 100644
--- a/adb/adb.cpp
+++ b/adb/adb.cpp
@@ -105,31 +105,27 @@
 }
 
 uint32_t calculate_apacket_checksum(const apacket* p) {
-    const unsigned char* x = reinterpret_cast<const unsigned char*>(p->data);
     uint32_t sum = 0;
-    size_t count = p->msg.data_length;
-
-    while (count-- > 0) {
-        sum += *x++;
+    for (size_t i = 0; i < p->msg.data_length; ++i) {
+        sum += static_cast<uint8_t>(p->payload[i]);
     }
-
     return sum;
 }
 
 apacket* get_apacket(void)
 {
-    apacket* p = reinterpret_cast<apacket*>(malloc(sizeof(apacket)));
+    apacket* p = new apacket();
     if (p == nullptr) {
       fatal("failed to allocate an apacket");
     }
 
-    memset(p, 0, sizeof(apacket) - MAX_PAYLOAD);
+    memset(&p->msg, 0, sizeof(p->msg));
     return p;
 }
 
 void put_apacket(apacket *p)
 {
-    free(p);
+    delete p;
 }
 
 void handle_online(atransport *t)
@@ -155,8 +151,7 @@
 #define DUMPMAX 32
 void print_packet(const char *label, apacket *p)
 {
-    char *tag;
-    char *x;
+    const char* tag;
     unsigned count;
 
     switch(p->msg.command){
@@ -173,15 +168,15 @@
     fprintf(stderr, "%s: %s %08x %08x %04x \"",
             label, tag, p->msg.arg0, p->msg.arg1, p->msg.data_length);
     count = p->msg.data_length;
-    x = (char*) p->data;
-    if(count > DUMPMAX) {
+    const char* x = p->payload.data();
+    if (count > DUMPMAX) {
         count = DUMPMAX;
         tag = "\n";
     } else {
         tag = "\"\n";
     }
-    while(count-- > 0){
-        if((*x >= ' ') && (*x < 127)) {
+    while (count-- > 0) {
+        if ((*x >= ' ') && (*x < 127)) {
             fputc(*x, stderr);
         } else {
             fputc('.', stderr);
@@ -254,8 +249,8 @@
                    << connection_str.length() << ")";
     }
 
-    memcpy(cp->data, connection_str.c_str(), connection_str.length());
-    cp->msg.data_length = connection_str.length();
+    cp->payload = std::move(connection_str);
+    cp->msg.data_length = cp->payload.size();
 
     send_packet(cp, t);
 }
@@ -329,9 +324,7 @@
     }
 
     t->update_version(p->msg.arg0, p->msg.arg1);
-    std::string banner(reinterpret_cast<const char*>(p->data),
-                       p->msg.data_length);
-    parse_banner(banner, t);
+    parse_banner(p->payload, t);
 
 #if ADB_HOST
     handle_online(t);
@@ -354,6 +347,7 @@
             ((char*) (&(p->msg.command)))[2],
             ((char*) (&(p->msg.command)))[3]);
     print_packet("recv", p);
+    CHECK_EQ(p->payload.size(), p->msg.data_length);
 
     switch(p->msg.command){
     case A_SYNC:
@@ -380,11 +374,11 @@
                 if (t->GetConnectionState() == kCsOffline) {
                     t->SetConnectionState(kCsUnauthorized);
                 }
-                send_auth_response(p->data, p->msg.data_length, t);
+                send_auth_response(p->payload.data(), p->msg.data_length, t);
                 break;
 #else
             case ADB_AUTH_SIGNATURE:
-                if (adbd_auth_verify(t->token, sizeof(t->token), p->data, p->msg.data_length)) {
+                if (adbd_auth_verify(t->token, sizeof(t->token), p->payload)) {
                     adbd_auth_verified(t);
                     t->failed_auth_attempts = 0;
                 } else {
@@ -394,7 +388,7 @@
                 break;
 
             case ADB_AUTH_RSAPUBLICKEY:
-                adbd_auth_confirm_key(p->data, p->msg.data_length, t);
+                adbd_auth_confirm_key(p->payload.data(), p->msg.data_length, t);
                 break;
 #endif
             default:
@@ -406,9 +400,7 @@
 
     case A_OPEN: /* OPEN(local-id, 0, "destination") */
         if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) {
-            char *name = (char*) p->data;
-            name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0;
-            asocket* s = create_local_service_socket(name, t);
+            asocket* s = create_local_service_socket(p->payload.c_str(), t);
             if (s == nullptr) {
                 send_close(0, p->msg.arg0, t);
             } else {
@@ -474,11 +466,7 @@
             asocket* s = find_local_socket(p->msg.arg1, p->msg.arg0);
             if (s) {
                 unsigned rid = p->msg.arg0;
-
-                // TODO: Convert apacket::data to a type that we can move out of.
-                std::string copy(p->data, p->data + p->msg.data_length);
-
-                if (s->enqueue(s, std::move(copy)) == 0) {
+                if (s->enqueue(s, std::move(p->payload)) == 0) {
                     D("Enqueue the socket");
                     send_ready(s->id, rid, t);
                 }
diff --git a/adb/adb.h b/adb/adb.h
index c9c635a..a6d0463 100644
--- a/adb/adb.h
+++ b/adb/adb.h
@@ -74,7 +74,7 @@
 
 struct apacket {
     amessage msg;
-    char data[MAX_PAYLOAD];
+    std::string payload;
 };
 
 uint32_t calculate_apacket_checksum(const apacket* packet);
diff --git a/adb/adb_auth.h b/adb/adb_auth.h
index a6f224f..715e04f 100644
--- a/adb/adb_auth.h
+++ b/adb/adb_auth.h
@@ -49,7 +49,7 @@
 void adbd_auth_verified(atransport *t);
 
 void adbd_cloexec_auth_socket();
-bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len);
+bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig);
 void adbd_auth_confirm_key(const char* data, size_t len, atransport* t);
 
 void send_auth_request(atransport *t);
diff --git a/adb/adb_auth_host.cpp b/adb/adb_auth_host.cpp
index 365bf77..c3aef16 100644
--- a/adb/adb_auth_host.cpp
+++ b/adb/adb_auth_host.cpp
@@ -299,20 +299,25 @@
     return result;
 }
 
-static int adb_auth_sign(RSA* key, const char* token, size_t token_size, char* sig) {
+static std::string adb_auth_sign(RSA* key, const char* token, size_t token_size) {
     if (token_size != TOKEN_SIZE) {
         D("Unexpected token size %zd", token_size);
-        return 0;
+        return std::string();  // not 0: that would build a std::string from a null char*
     }
 
+    std::string result;
+    result.resize(MAX_PAYLOAD);
+
     unsigned int len;
     if (!RSA_sign(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
-                  reinterpret_cast<uint8_t*>(sig), &len, key)) {
-        return 0;
+                  reinterpret_cast<uint8_t*>(&result[0]), &len, key)) {
+        return std::string();
     }
 
+    result.resize(len);
+
     D("adb_auth_sign len=%d", len);
-    return (int)len;
+    return result;
 }
 
 std::string adb_auth_get_userkey() {
@@ -446,13 +451,14 @@
     }
 
     apacket* p = get_apacket();
-    memcpy(p->data, key.c_str(), key.size() + 1);
-
     p->msg.command = A_AUTH;
     p->msg.arg0 = ADB_AUTH_RSAPUBLICKEY;
 
+    p->payload = std::move(key);
+
     // adbd expects a null-terminated string.
-    p->msg.data_length = key.size() + 1;
+    p->payload.push_back('\0');
+    p->msg.data_length = p->payload.size();
     send_packet(p, t);
 }
 
@@ -467,8 +473,8 @@
     LOG(INFO) << "Calling send_auth_response";
     apacket* p = get_apacket();
 
-    int ret = adb_auth_sign(key.get(), token, token_size, p->data);
-    if (!ret) {
+    std::string result = adb_auth_sign(key.get(), token, token_size);
+    if (result.empty()) {
         D("Error signing the token");
         put_apacket(p);
         return;
@@ -476,6 +482,7 @@
 
     p->msg.command = A_AUTH;
     p->msg.arg0 = ADB_AUTH_SIGNATURE;
-    p->msg.data_length = ret;
+    p->payload = std::move(result);
+    p->msg.data_length = p->payload.size();
     send_packet(p, t);
 }
diff --git a/adb/adbd_auth.cpp b/adb/adbd_auth.cpp
index 3488ad1..3fd2b31 100644
--- a/adb/adbd_auth.cpp
+++ b/adb/adbd_auth.cpp
@@ -46,7 +46,7 @@
 
 bool auth_required = true;
 
-bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len) {
+bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig) {
     static constexpr const char* key_paths[] = { "/adb_keys", "/data/misc/adb/adb_keys", nullptr };
 
     for (const auto& path : key_paths) {
@@ -80,7 +80,8 @@
 
                 bool verified =
                     (RSA_verify(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
-                                reinterpret_cast<const uint8_t*>(sig), sig_len, key) == 1);
+                                reinterpret_cast<const uint8_t*>(sig.c_str()), sig.size(),
+                                key) == 1);
                 RSA_free(key);
                 if (verified) return true;
             }
@@ -210,10 +211,10 @@
     }
 
     apacket* p = get_apacket();
-    memcpy(p->data, t->token, sizeof(t->token));
     p->msg.command = A_AUTH;
     p->msg.arg0 = ADB_AUTH_TOKEN;
     p->msg.data_length = sizeof(t->token);
+    p->payload.assign(t->token, t->token + sizeof(t->token));
     send_packet(p, t);
 }
 
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index 307cbfe..0007fed 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -413,15 +413,15 @@
     p->msg.command = A_WRTE;
     p->msg.arg0 = s->peer->id;
     p->msg.arg1 = s->id;
-    p->msg.data_length = data.size();
 
-    if (data.size() > sizeof(p->data)) {
+    if (data.size() > MAX_PAYLOAD) {
         put_apacket(p);
         return -1;
     }
 
-    // TODO: Convert apacket::data to a type that we can move into.
-    memcpy(p->data, data.data(), data.size());
+    p->payload = std::move(data);
+    p->msg.data_length = p->payload.size();
+
     send_packet(p, s->transport);
     return 1;
 }
@@ -482,17 +482,20 @@
 void connect_to_remote(asocket* s, const char* destination) {
     D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
     apacket* p = get_apacket();
-    size_t len = strlen(destination) + 1;
-
-    if (len > (s->get_max_payload() - 1)) {
-        fatal("destination oversized");
-    }
 
     D("LS(%d): connect('%s')", s->id, destination);
     p->msg.command = A_OPEN;
     p->msg.arg0 = s->id;
-    p->msg.data_length = len;
-    strcpy((char*)p->data, destination);
+
+    // adbd expects a null-terminated string.
+    p->payload = destination;
+    p->payload.push_back('\0');
+    p->msg.data_length = p->payload.size();
+
+    if (p->msg.data_length > s->get_max_payload()) {
+        fatal("destination oversized");
+    }
+
     send_packet(p, s->transport);
 }
 
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 9ae1297..14888ab 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -72,12 +72,14 @@
         return false;
     }
 
-    if (packet->msg.data_length > sizeof(packet->data)) {
+    if (packet->msg.data_length > MAX_PAYLOAD) {
         D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length);
         return false;
     }
 
-    if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) {
+    packet->payload.resize(packet->msg.data_length);
+
+    if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
         D("remote local: terminated (data)");
         return false;
     }
@@ -86,13 +88,18 @@
 }
 
 bool FdConnection::Write(apacket* packet) {
-    uint32_t length = packet->msg.data_length;
-
-    if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
+    if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
         D("remote local: write terminated");
         return false;
     }
 
+    if (packet->msg.data_length) {
+        if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
+            D("remote local: write terminated");
+            return false;
+        }
+    }
+
     return true;
 }
 
@@ -133,7 +140,7 @@
 
     std::string result = android::base::StringPrintf("%s: %s: [%s] arg0=%s arg1=%s (len=%d) ", name,
                                                      func, cmd, arg0, arg1, len);
-    result += dump_hex(p->data, len);
+    result += dump_hex(p->payload.data(), p->payload.size());
     return result;
 }
 
@@ -191,9 +198,10 @@
         apacket* p = 0;
         if (read_packet(fd, t->serial, &p)) {
             D("%s: failed to read packet from transport socket on fd %d", t->serial, fd);
-        } else {
-            handle_packet(p, (atransport*)_t);
+            return;
         }
+
+        handle_packet(p, (atransport*)_t);
     }
 }
 
@@ -243,6 +251,7 @@
     p->msg.arg0 = 1;
     p->msg.arg1 = ++(t->sync_token);
     p->msg.magic = A_SYNC ^ 0xffffffff;
+    D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size());
     if (write_packet(t->fd, t->serial, &p)) {
         put_apacket(p);
         D("%s: failed to write SYNC packet", t->serial);
@@ -336,6 +345,13 @@
             if (active) {
                 D("%s: transport got packet, sending to remote", t->serial);
                 ATRACE_NAME("write_transport write_remote");
+
+                // Allow sending the payload's implicit null terminator.
+                if (p->msg.data_length != p->payload.size()) {
+                    LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = "
+                               << p->msg.data_length << ", payload.size() = " << p->payload.size();
+                }
+
                 if (t->Write(p) != 0) {
                     D("%s: remote write failed for transport", t->serial);
                     put_apacket(p);
diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp
index a108699..d7565f6 100644
--- a/adb/transport_usb.cpp
+++ b/adb/transport_usb.cpp
@@ -61,13 +61,12 @@
 static int UsbReadPayload(usb_handle* h, apacket* p) {
     D("UsbReadPayload(%d)", p->msg.data_length);
 
-    if (p->msg.data_length > sizeof(p->data)) {
+    if (p->msg.data_length > MAX_PAYLOAD) {
         return -1;
     }
 
 #if CHECK_PACKET_OVERFLOW
     size_t usb_packet_size = usb_get_max_packet_size(h);
-    CHECK_EQ(0ULL, sizeof(p->data) % usb_packet_size);
 
     // Round the data length up to the nearest packet size boundary.
     // The device won't send a zero packet for packet size aligned payloads,
@@ -77,10 +76,18 @@
     if (rem_size) {
         len += usb_packet_size - rem_size;
     }
-    CHECK_LE(len, sizeof(p->data));
-    return usb_read(h, &p->data, len);
+
+    p->payload.resize(len);
+    int rc = usb_read(h, &p->payload[0], p->payload.size());
+    if (rc != static_cast<int>(p->msg.data_length)) {
+        return -1;
+    }
+
+    p->payload.resize(rc);
+    return rc;
 #else
-    return usb_read(h, &p->data, p->msg.data_length);
+    p->payload.resize(p->msg.data_length);
+    return usb_read(h, &p->payload[0], p->payload.size());
 #endif
 }
 
@@ -120,12 +127,13 @@
     }
 
     if (p->msg.data_length) {
-        if (p->msg.data_length > sizeof(p->data)) {
+        if (p->msg.data_length > MAX_PAYLOAD) {
             PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
             return -1;
         }
 
-        if (usb_read(usb, p->data, p->msg.data_length)) {
+        p->payload.resize(p->msg.data_length);
+        if (usb_read(usb, &p->payload[0], p->payload.size())) {
             PLOG(ERROR) << "remote usb: terminated (data)";
             return -1;
         }
@@ -152,7 +160,7 @@
         return false;
     }
 
-    if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
+    if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != 0) {
         PLOG(ERROR) << "remote usb: 2 - write terminated";
         return false;
     }