adb: extract atransport's connection interface.

As step one of refactoring atransport to separate out protocol handling
from its underlying connection, extract atransport's existing
hand-rolled connection vtable out to its own abstract interface.

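This should not change behavior except in two cases: emulators are
now treated as TCP devices for the purposes of `adb disconnect`, and
on the host, a USB packet with a malformed header or a short payload
now terminates the transport instead of being silently skipped.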

Test: python test_device.py, with walleye over USB + TCP
Test: manually connecting and disconnecting devices/emulators
Change-Id: I877b8027e567cc6a7461749432b49f6cb2c2f0d7
diff --git a/adb/Android.mk b/adb/Android.mk
index 0eeafb6..e52f0cb 100644
--- a/adb/Android.mk
+++ b/adb/Android.mk
@@ -11,6 +11,7 @@
 adb_target_sanitize :=
 
 ADB_COMMON_CFLAGS := \
+    -frtti \
     -Wall -Wextra -Werror \
     -Wno-unused-parameter \
     -Wno-missing-field-initializers \
diff --git a/adb/adb.h b/adb/adb.h
index 21e5d4b..b5d6bcb 100644
--- a/adb/adb.h
+++ b/adb/adb.h
@@ -136,9 +136,6 @@
 int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply_fd);
 
 /* initialize a transport object's func pointers and state */
-#if ADB_HOST
-int get_available_local_transport_index();
-#endif
 int init_socket_transport(atransport* t, int s, int port, int local);
 void init_usb_transport(atransport* t, usb_handle* usb);
 
diff --git a/adb/services.cpp b/adb/services.cpp
index aff7012..6dc71cf 100644
--- a/adb/services.cpp
+++ b/adb/services.cpp
@@ -407,14 +407,6 @@
         return;
     }
 
-    // Check if more emulators can be registered. Similar unproblematic
-    // race condition as above.
-    int candidate_slot = get_available_local_transport_index();
-    if (candidate_slot < 0) {
-        *response = "Cannot accept more emulators";
-        return;
-    }
-
     // Preconditions met, try to connect to the emulator.
     std::string error;
     if (!local_connect_arbitrary_ports(console_port, adb_port, &error)) {
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 4a9d91a..5acaaec 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -41,6 +41,7 @@
 
 #include "adb.h"
 #include "adb_auth.h"
+#include "adb_io.h"
 #include "adb_trace.h"
 #include "adb_utils.h"
 #include "diagnose_usb.h"
@@ -65,6 +66,44 @@
     return next++;
 }
 
+// Reads a single packet: the fixed-size message header, then its payload.
+bool FdConnection::Read(apacket* packet) {
+    if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) {
+        D("remote local: read terminated (message)");
+        return false;
+    }
+
+    // The header is untrusted and check_header() only runs after Read() returns, so bound
+    // the payload length here before it is used to fill packet->data.
+    if (packet->msg.data_length > sizeof(packet->data)) {
+        D("remote local: read overflow (data length = %u)", packet->msg.data_length);
+        return false;
+    }
+
+    if (!ReadFdExactly(fd_.get(), packet->data, packet->msg.data_length)) {
+        D("remote local: terminated (data)");
+        return false;
+    }
+
+    return true;
+}
+
+bool FdConnection::Write(apacket* packet) {
+    uint32_t length = packet->msg.data_length;
+
+    if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
+        D("remote local: write terminated");
+        return false;
+    }
+
+    return true;
+}
+
+void FdConnection::Close() {
+    adb_shutdown(fd_.get());
+    fd_.reset();
+}
+
 static std::string dump_packet(const char* name, const char* func, apacket* p) {
     unsigned command = p->msg.command;
     int len = p->msg.data_length;
@@ -220,11 +259,18 @@
 
         {
             ATRACE_NAME("read_transport read_remote");
-            if (t->read_from_remote(p, t) != 0) {
+            if (!t->connection->Read(p)) {
                 D("%s: remote read failed for transport", t->serial);
                 put_apacket(p);
                 break;
             }
+
+            if (!check_header(p, t)) {
+                D("%s: remote read: bad header", t->serial);
+                put_apacket(p);
+                break;
+            }
+
 #if ADB_HOST
             if (p->msg.command == 0) {
                 put_apacket(p);
@@ -626,7 +672,7 @@
     t->ref_count--;
     if (t->ref_count == 0) {
         D("transport: %s unref (kicking and closing)", t->serial);
-        t->close(t);
+        t->connection->Close();
         remove_transport(t);
     } else {
         D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@@ -754,14 +800,14 @@
 }
 
 int atransport::Write(apacket* p) {
-    return write_func_(p, this);
+    return this->connection->Write(p) ? 0 : -1;
 }
 
 void atransport::Kick() {
     if (!kicked_) {
+        D("kicking transport %s", this->serial);
         kicked_ = true;
-        CHECK(kick_func_ != nullptr);
-        kick_func_(this);
+        this->connection->Close();
     }
 }
 
@@ -1083,8 +1129,12 @@
 // This should only be used for transports with connection_state == kCsNoPerm.
 void unregister_usb_transport(usb_handle* usb) {
     std::lock_guard<std::recursive_mutex> lock(transport_lock);
-    transport_list.remove_if(
-        [usb](atransport* t) { return t->usb == usb && t->GetConnectionState() == kCsNoPerm; });
+    transport_list.remove_if([usb](atransport* t) {
+        if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
+            return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
+        }
+        return false;
+    });
 }
 
 bool check_header(apacket* p, atransport* t) {
diff --git a/adb/transport.h b/adb/transport.h
index 86cd992..9700f44 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -28,10 +28,11 @@
 #include <string>
 #include <unordered_set>
 
-#include "adb.h"
-
 #include <openssl/rsa.h>
 
+#include "adb.h"
+#include "adb_unique_fd.h"
+
 typedef std::unordered_set<std::string> FeatureSet;
 
 const FeatureSet& supported_features();
@@ -56,6 +57,55 @@
 
 TransportId NextTransportId();
 
+// Abstraction for a blocking packet transport.
+struct Connection {
+    Connection() = default;
+    Connection(const Connection& copy) = delete;
+    Connection(Connection&& move) = delete;
+
+    // Destroy a Connection. Formerly known as 'Close' in atransport.
+    virtual ~Connection() = default;
+
+    // Read/Write a packet. These functions are concurrently called from a transport's reader/writer
+    // threads.
+    virtual bool Read(apacket* packet) = 0;
+    virtual bool Write(apacket* packet) = 0;
+
+    // Terminate a connection.
+    // This method must be thread-safe, and must cause concurrent Reads/Writes to terminate.
+    // It may be called more than once: atransport calls it from Kick() and on the final unref.
+    // Formerly known as 'Kick' in atransport.
+    virtual void Close() = 0;
+};
+
+// Connection over a file descriptor (e.g. a TCP socket), which it owns.
+struct FdConnection : public Connection {
+    explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
+
+    bool Read(apacket* packet) override final;
+    bool Write(apacket* packet) override final;
+
+    void Close() override;
+
+  private:
+    unique_fd fd_;
+};
+
+// Connection over a USB handle. The handle is released with usb_close() when the connection
+// is destroyed; Close() only kicks it.
+struct UsbConnection : public Connection {
+    explicit UsbConnection(usb_handle* handle) : handle_(handle) {}
+    ~UsbConnection() override;
+
+    bool Read(apacket* packet) override final;
+    bool Write(apacket* packet) override final;
+
+    void Close() override final;
+
+    // Public so that unregister_usb_transport() can match transports by handle.
+    usb_handle* handle_;
+};
+
 class atransport {
   public:
     // TODO(danalbert): We expose waaaaaaay too much stuff because this was
@@ -73,12 +123,6 @@
     }
     virtual ~atransport() {}
 
-    int (*read_from_remote)(apacket* p, atransport* t) = nullptr;
-    void (*close)(atransport* t) = nullptr;
-
-    void SetWriteFunction(int (*write_func)(apacket*, atransport*)) { write_func_ = write_func; }
-    void SetKickFunction(void (*kick_func)(atransport*)) { kick_func_ = kick_func; }
-    bool IsKicked() { return kicked_; }
     int Write(apacket* p);
     void Kick();
 
@@ -95,9 +139,7 @@
     bool online = false;
     TransportType type = kTransportAny;
 
-    // USB handle or socket fd as needed.
-    usb_handle* usb = nullptr;
-    int sfd = -1;
+    std::unique_ptr<Connection> connection;
 
     // Used to identify transports for clients.
     char* serial = nullptr;
@@ -105,22 +147,9 @@
     char* model = nullptr;
     char* device = nullptr;
     char* devpath = nullptr;
-    void SetLocalPortForEmulator(int port) {
-        CHECK_EQ(local_port_for_emulator_, -1);
-        local_port_for_emulator_ = port;
-    }
 
-    bool GetLocalPortForEmulator(int* port) const {
-        if (type == kTransportLocal && local_port_for_emulator_ != -1) {
-            *port = local_port_for_emulator_;
-            return true;
-        }
-        return false;
-    }
-
-    bool IsTcpDevice() const {
-        return type == kTransportLocal && local_port_for_emulator_ == -1;
-    }
+    // Emulators are kTransportLocal too, so they also count as TCP devices.
+    bool IsTcpDevice() const { return type == kTransportLocal; }
 
 #if ADB_HOST
     std::shared_ptr<RSA> NextKey();
@@ -165,10 +194,7 @@
     bool MatchesTarget(const std::string& target) const;
 
 private:
-    int local_port_for_emulator_ = -1;
     bool kicked_ = false;
-    void (*kick_func_)(atransport*) = nullptr;
-    int (*write_func_)(apacket*, atransport*) = nullptr;
 
     // A set of features transmitted in the banner with the initial connection.
     // This is stored in the banner as 'features=feature0,feature1,etc'.
diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp
index d6c84da..560a031 100644
--- a/adb/transport_local.cpp
+++ b/adb/transport_local.cpp
@@ -28,10 +28,12 @@
 #include <condition_variable>
 #include <mutex>
 #include <thread>
+#include <unordered_map>
 #include <vector>
 
 #include <android-base/parsenetaddress.h>
 #include <android-base/stringprintf.h>
+#include <android-base/thread_annotations.h>
 #include <cutils/sockets.h>
 
 #if !ADB_HOST
@@ -40,6 +42,7 @@
 
 #include "adb.h"
 #include "adb_io.h"
+#include "adb_unique_fd.h"
 #include "adb_utils.h"
 #include "sysdeps/chrono.h"
 
@@ -53,48 +56,15 @@
 
 static std::mutex& local_transports_lock = *new std::mutex();
 
-/* we keep a list of opened transports. The atransport struct knows to which
- * local transport it is connected. The list is used to detect when we're
- * trying to connect twice to a given local transport.
- */
-static atransport*  local_transports[ ADB_LOCAL_TRANSPORT_MAX ];
+// We keep a map from emulator port to transport.
+// TODO: weak_ptr?
+static auto& local_transports GUARDED_BY(local_transports_lock) =
+    *new std::unordered_map<int, atransport*>();
 #endif /* ADB_HOST */
 
-static int remote_read(apacket *p, atransport *t)
-{
-    if (!ReadFdExactly(t->sfd, &p->msg, sizeof(amessage))) {
-        D("remote local: read terminated (message)");
-        return -1;
-    }
-
-    if (!check_header(p, t)) {
-        D("bad header: terminated (data)");
-        return -1;
-    }
-
-    if (!ReadFdExactly(t->sfd, p->data, p->msg.data_length)) {
-        D("remote local: terminated (data)");
-        return -1;
-    }
-
-    return 0;
-}
-
-static int remote_write(apacket *p, atransport *t)
-{
-    int   length = p->msg.data_length;
-
-    if(!WriteFdExactly(t->sfd, &p->msg, sizeof(amessage) + length)) {
-        D("remote local: write terminated");
-        return -1;
-    }
-
-    return 0;
-}
-
 bool local_connect(int port) {
     std::string dummy;
-    return local_connect_arbitrary_ports(port-1, port, &dummy) == 0;
+    return local_connect_arbitrary_ports(port - 1, port, &dummy) == 0;
 }
 
 void connect_device(const std::string& address, std::string* response) {
@@ -423,130 +393,96 @@
     std::thread(func, port).detach();
 }
 
-static void remote_kick(atransport *t)
-{
-    int fd = t->sfd;
-    t->sfd = -1;
-    adb_shutdown(fd);
-    adb_close(fd);
-
 #if ADB_HOST
-    int  nn;
-    std::lock_guard<std::mutex> lock(local_transports_lock);
-    for (nn = 0; nn < ADB_LOCAL_TRANSPORT_MAX; nn++) {
-        if (local_transports[nn] == t) {
-            local_transports[nn] = NULL;
-            break;
-        }
-    }
-#endif
-}
+// Connection to an emulator's adb port. init_socket_transport() only creates one once the
+// transport has been registered in local_transports under local_port_; Close() removes that
+// entry, and destruction pushes the port onto retry_ports.
+struct EmulatorConnection : public FdConnection {
+    EmulatorConnection(unique_fd fd, int local_port)
+        : FdConnection(std::move(fd)), local_port_(local_port) {}
 
-static void remote_close(atransport *t)
-{
-    int fd = t->sfd;
-    if (fd != -1) {
-        t->sfd = -1;
-        adb_close(fd);
-    }
-#if ADB_HOST
-    int local_port;
-    if (t->GetLocalPortForEmulator(&local_port)) {
-        VLOG(TRANSPORT) << "remote_close, local_port = " << local_port;
+    ~EmulatorConnection() {
+        VLOG(TRANSPORT) << "remote_close, local_port = " << local_port_;
         std::unique_lock<std::mutex> lock(retry_ports_lock);
         RetryPort port;
-        port.port = local_port;
+        port.port = local_port_;
         port.retry_count = LOCAL_PORT_RETRY_COUNT;
         retry_ports.push_back(port);
         retry_ports_cond.notify_one();
     }
-#endif
-}
 
+    // Close() runs from Kick() and again on the final unref, so only drop the
+    // local_transports entry the first time: by the second call the port may already
+    // belong to a newly registered transport.
+    void Close() override {
+        {
+            std::lock_guard<std::mutex> lock(local_transports_lock);
+            if (registered_) {
+                local_transports.erase(local_port_);
+                registered_ = false;
+            }
+        }
+        FdConnection::Close();
+    }
 
-#if ADB_HOST
+    int local_port_;
+    bool registered_ GUARDED_BY(local_transports_lock) = true;
+};
+
 /* Only call this function if you already hold local_transports_lock. */
 static atransport* find_emulator_transport_by_adb_port_locked(int adb_port)
-{
-    int i;
-    for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) {
-        int local_port;
-        if (local_transports[i] && local_transports[i]->GetLocalPortForEmulator(&local_port)) {
-            if (local_port == adb_port) {
-                return local_transports[i];
-            }
-        }
+    REQUIRES(local_transports_lock) {
+    auto it = local_transports.find(adb_port);
+    if (it == local_transports.end()) {
+        return nullptr;
     }
-    return NULL;
+    return it->second;
 }
 
-std::string getEmulatorSerialString(int console_port)
-{
+std::string getEmulatorSerialString(int console_port) {
     return android::base::StringPrintf("emulator-%d", console_port);
 }
 
-atransport* find_emulator_transport_by_adb_port(int adb_port)
-{
+atransport* find_emulator_transport_by_adb_port(int adb_port) {
     std::lock_guard<std::mutex> lock(local_transports_lock);
-    atransport* result = find_emulator_transport_by_adb_port_locked(adb_port);
-    return result;
+    return find_emulator_transport_by_adb_port_locked(adb_port);
 }
 
-atransport* find_emulator_transport_by_console_port(int console_port)
-{
+atransport* find_emulator_transport_by_console_port(int console_port) {
     return find_transport(getEmulatorSerialString(console_port).c_str());
 }
-
-
-/* Only call this function if you already hold local_transports_lock. */
-int get_available_local_transport_index_locked()
-{
-    int i;
-    for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) {
-        if (local_transports[i] == NULL) {
-            return i;
-        }
-    }
-    return -1;
-}
-
-int get_available_local_transport_index()
-{
-    std::lock_guard<std::mutex> lock(local_transports_lock);
-    int result = get_available_local_transport_index_locked();
-    return result;
-}
 #endif
 
-int init_socket_transport(atransport *t, int s, int adb_port, int local)
-{
-    int  fail = 0;
+int init_socket_transport(atransport* t, int s, int adb_port, int local) {
+    int fail = 0;
 
-    t->SetKickFunction(remote_kick);
-    t->SetWriteFunction(remote_write);
-    t->close = remote_close;
-    t->read_from_remote = remote_read;
-    t->sfd = s;
+    unique_fd fd(s);
     t->sync_token = 1;
     t->type = kTransportLocal;
 
 #if ADB_HOST
+    // Emulator connection: registered under its adb port unless that port is taken.
     if (local) {
         std::lock_guard<std::mutex> lock(local_transports_lock);
-        t->SetLocalPortForEmulator(adb_port);
         atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
-        int index = get_available_local_transport_index_locked();
         if (existing_transport != NULL) {
             D("local transport for port %d already registered (%p)?", adb_port, existing_transport);
             fail = -1;
-        } else if (index < 0) {
+        } else if (local_transports.size() >= ADB_LOCAL_TRANSPORT_MAX) {
             // Too many emulators.
             D("cannot register more emulators. Maximum is %d", ADB_LOCAL_TRANSPORT_MAX);
             fail = -1;
         } else {
-            local_transports[index] = t;
+            // Only registered transports get an EmulatorConnection, since its Close() and
+            // destructor act on this port's local_transports entry and on retry_ports.
+            t->connection.reset(new EmulatorConnection(std::move(fd), adb_port));
+            local_transports[adb_port] = t;
+            return 0;
         }
     }
 #endif
+
+    // Regular tcp connection, or an emulator connection that could not be registered.
+    t->connection.reset(new FdConnection(std::move(fd)));
     return fail;
 }
diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp
index 68689d4..d987d4f 100644
--- a/adb/transport_test.cpp
+++ b/adb/transport_test.cpp
@@ -20,22 +20,27 @@
 
 #include "adb.h"
 
-TEST(transport, kick_transport) {
-  atransport t;
-  static size_t kick_count;
-  kick_count = 0;
-  // Mutate some member so we can test that the function is run.
-  t.SetKickFunction([](atransport* trans) { kick_count++; });
-  ASSERT_FALSE(t.IsKicked());
-  t.Kick();
-  ASSERT_TRUE(t.IsKicked());
-  ASSERT_EQ(1u, kick_count);
-  // A transport can only be kicked once.
-  t.Kick();
-  ASSERT_TRUE(t.IsKicked());
-  ASSERT_EQ(1u, kick_count);
-}
-
+TEST(transport, kick_transport) {
+    // Kick() must close the transport's connection, and only the first time it is called.
+    static size_t close_count;
+    close_count = 0;
+
+    struct FakeConnection : public Connection {
+        bool Read(apacket*) override { return false; }
+        bool Write(apacket*) override { return false; }
+        void Close() override { ++close_count; }
+    };
+
+    atransport t;
+    t.connection.reset(new FakeConnection());
+    t.Kick();
+    ASSERT_EQ(1u, close_count);
+
+    // A transport can only be kicked once.
+    t.Kick();
+    ASSERT_EQ(1u, close_count);
+}
+
 static void DisconnectFunc(void* arg, atransport*) {
     int* count = reinterpret_cast<int*>(arg);
     ++*count;
diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp
index 3474820..73e8e15 100644
--- a/adb/transport_usb.cpp
+++ b/adb/transport_usb.cpp
@@ -80,25 +80,24 @@
 #endif
 }
 
-static int remote_read(apacket* p, atransport* t) {
-    int n = UsbReadMessage(t->usb, &p->msg);
+static int remote_read(apacket* p, usb_handle* usb) {
+    int n = UsbReadMessage(usb, &p->msg);
     if (n < 0) {
         D("remote usb: read terminated (message)");
         return -1;
     }
-    if (static_cast<size_t>(n) != sizeof(p->msg) || !check_header(p, t)) {
-        D("remote usb: check_header failed, skip it");
-        goto err_msg;
-    }
-    if (t->GetConnectionState() == kCsOffline) {
-        // If we read a wrong msg header declaring a large message payload, don't read its payload.
-        // Otherwise we may miss true messages from the device.
-        if (p->msg.command != A_CNXN && p->msg.command != A_AUTH) {
-            goto err_msg;
-        }
+    if (static_cast<size_t>(n) != sizeof(p->msg)) {
+        D("remote usb: read received unexpected header length %d", n);
+        return -1;
     }
     if (p->msg.data_length) {
-        n = UsbReadPayload(t->usb, p);
+        // check_header() now only runs after Read() returns, so bound the untrusted payload
+        // length here before reading into p->data.
+        if (p->msg.data_length > sizeof(p->data)) {
+            D("remote usb: read overflow (data length = %u)", p->msg.data_length);
+            return -1;
+        }
+        n = UsbReadPayload(usb, p);
         if (n < 0) {
             D("remote usb: terminated (data)");
             return -1;
@@ -106,34 +105,30 @@
         if (static_cast<uint32_t>(n) != p->msg.data_length) {
             D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it",
               p->msg.data_length, n);
-            goto err_msg;
+            return -1;
         }
     }
     return 0;
-
-err_msg:
-    p->msg.command = 0;
-    return 0;
 }
 
 #else
 
 // On Android devices, we rely on the kernel to provide buffered read.
 // So we can recover automatically from EOVERFLOW.
-static int remote_read(apacket *p, atransport *t)
-{
-    if (usb_read(t->usb, &p->msg, sizeof(amessage))) {
+static int remote_read(apacket* p, usb_handle* usb) {
+    if (usb_read(usb, &p->msg, sizeof(amessage))) {
         PLOG(ERROR) << "remote usb: read terminated (message)";
         return -1;
     }
 
-    if (!check_header(p, t)) {
-        LOG(ERROR) << "remote usb: check_header failed";
-        return -1;
-    }
-
     if (p->msg.data_length) {
-        if (usb_read(t->usb, p->data, p->msg.data_length)) {
+        // check_header() now only runs after Read() returns, so bound the untrusted payload
+        // length here before reading into p->data.
+        if (p->msg.data_length > sizeof(p->data)) {
+            LOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
+            return -1;
+        }
+        if (usb_read(usb, p->data, p->msg.data_length)) {
             PLOG(ERROR) << "remote usb: terminated (data)";
             return -1;
         }
@@ -143,45 +138,43 @@
 }
 #endif
 
-static int remote_write(apacket *p, atransport *t)
-{
-    unsigned size = p->msg.data_length;
+UsbConnection::~UsbConnection() {
+    usb_close(handle_);
+}
 
-    if (usb_write(t->usb, &p->msg, sizeof(amessage))) {
+bool UsbConnection::Read(apacket* packet) {
+    int rc = remote_read(packet, handle_);
+    return rc == 0;
+}
+
+bool UsbConnection::Write(apacket* packet) {
+    unsigned size = packet->msg.data_length;
+
+    if (usb_write(handle_, &packet->msg, sizeof(packet->msg)) != 0) {
         PLOG(ERROR) << "remote usb: 1 - write terminated";
-        return -1;
+        return false;
     }
-    if (p->msg.data_length == 0) return 0;
-    if (usb_write(t->usb, &p->data, size)) {
+
+    if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
         PLOG(ERROR) << "remote usb: 2 - write terminated";
-        return -1;
+        return false;
     }
 
-    return 0;
+    return true;
 }
 
-static void remote_close(atransport* t) {
-    usb_close(t->usb);
-    t->usb = 0;
-}
-
-static void remote_kick(atransport* t) {
-    usb_kick(t->usb);
+void UsbConnection::Close() {
+    usb_kick(handle_);
 }
 
 void init_usb_transport(atransport* t, usb_handle* h) {
     D("transport: usb");
-    t->close = remote_close;
-    t->SetKickFunction(remote_kick);
-    t->SetWriteFunction(remote_write);
-    t->read_from_remote = remote_read;
+    t->connection.reset(new UsbConnection(h));
     t->sync_token = 1;
     t->type = kTransportUsb;
-    t->usb = h;
 }
 
-int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol)
-{
+int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol) {
     return (usb_class == ADB_CLASS && usb_subclass == ADB_SUBCLASS && usb_protocol == ADB_PROTOCOL);
 }