Give Connection an atransport*.

Bug: http://b/186595076
Test: treehugger
Change-Id: I2675db1fa2b1823e5d73e42cb47f6fcf18cfb4f6
diff --git a/client/usb_libusb.cpp b/client/usb_libusb.cpp
index a877610..8bdf076 100644
--- a/client/usb_libusb.cpp
+++ b/client/usb_libusb.cpp
@@ -123,7 +123,7 @@
         if (payload) {
             packet->payload = std::move(*payload);
         }
-        read_callback_(this, std::move(packet));
+        transport_->HandleRead(std::move(packet));
     }
 
     void Cleanup(ReadBlock* read_block) REQUIRES(read_mutex_) {
@@ -553,8 +553,8 @@
 
     void OnError(const std::string& error) {
         std::call_once(error_flag_, [this, &error]() {
-            if (error_callback_) {
-                error_callback_(this, error);
+            if (transport_) {
+                transport_->HandleError(error);
             }
         });
     }
diff --git a/daemon/usb.cpp b/daemon/usb.cpp
index ad6f818..775edd1 100644
--- a/daemon/usb.cpp
+++ b/daemon/usb.cpp
@@ -582,7 +582,7 @@
 
                 // TODO: Make apacket contain an IOVector so we don't have to coalesce.
                 packet->payload = std::move(incoming_payload_).coalesce();
-                read_callback_(this, std::move(packet));
+                transport_->HandleRead(std::move(packet));
 
                 incoming_header_.reset();
                 // reuse the capacity of the incoming payload while we can.
@@ -678,7 +678,10 @@
 
     void HandleError(const std::string& error) {
         std::call_once(error_flag_, [&]() {
-            error_callback_(this, error);
+            if (transport_) {
+                transport_->HandleError(error);
+            }
+
             if (!stopped_) {
                 Stop();
             }
diff --git a/transport.cpp b/transport.cpp
index fbcf79b..f98eee2 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -278,25 +278,28 @@
     Stop();
 }
 
+std::string Connection::Serial() const {
+    return transport_ ? transport_->serial_name() : "<unknown>";
+}
+
 BlockingConnectionAdapter::BlockingConnectionAdapter(std::unique_ptr<BlockingConnection> connection)
     : underlying_(std::move(connection)) {}
 
 BlockingConnectionAdapter::~BlockingConnectionAdapter() {
-    LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): destructing";
+    LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): destructing";
     Stop();
 }
 
 void BlockingConnectionAdapter::Start() {
     std::lock_guard<std::mutex> lock(mutex_);
     if (started_) {
-        LOG(FATAL) << "BlockingConnectionAdapter(" << this->transport_name_
-                   << "): started multiple times";
+        LOG(FATAL) << "BlockingConnectionAdapter(" << Serial() << "): started multiple times";
     }
 
     StartReadThread();
 
     write_thread_ = std::thread([this]() {
-        LOG(INFO) << this->transport_name_ << ": write thread spawning";
+        LOG(INFO) << Serial() << ": write thread spawning";
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_);
             ScopedLockAssertion assume_locked(mutex_);
@@ -316,7 +319,7 @@
                 break;
             }
         }
-        std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "write failed"); });
+        std::call_once(this->error_flag_, [this]() { transport_->HandleError("write failed"); });
     });
 
     started_ = true;
@@ -324,11 +327,11 @@
 
 void BlockingConnectionAdapter::StartReadThread() {
     read_thread_ = std::thread([this]() {
-        LOG(INFO) << this->transport_name_ << ": read thread spawning";
+        LOG(INFO) << Serial() << ": read thread spawning";
         while (true) {
             auto packet = std::make_unique<apacket>();
             if (!underlying_->Read(packet.get())) {
-                PLOG(INFO) << this->transport_name_ << ": read failed";
+                PLOG(INFO) << Serial() << ": read failed";
                 break;
             }
 
@@ -337,18 +340,17 @@
                 got_stls_cmd = true;
             }
 
-            read_callback_(this, std::move(packet));
+            transport_->HandleRead(std::move(packet));
 
             // If we received the STLS packet, we are about to perform the TLS
             // handshake. So this read thread must stop and resume after the
             // handshake completes otherwise this will interfere in the process.
             if (got_stls_cmd) {
-                LOG(INFO) << this->transport_name_
-                          << ": Received STLS packet. Stopping read thread.";
+                LOG(INFO) << Serial() << ": Received STLS packet. Stopping read thread.";
                 return;
             }
         }
-        std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
+        std::call_once(this->error_flag_, [this]() { transport_->HandleError("read failed"); });
     });
 }
 
@@ -366,18 +368,17 @@
     {
         std::lock_guard<std::mutex> lock(mutex_);
         if (!started_) {
-            LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
+            LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): not started";
             return;
         }
 
         if (stopped_) {
-            LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
-                      << "): already stopped";
+            LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): already stopped";
             return;
         }
     }
 
-    LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): resetting";
+    LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): resetting";
     this->underlying_->Reset();
     Stop();
 }
@@ -386,20 +387,19 @@
     {
         std::lock_guard<std::mutex> lock(mutex_);
         if (!started_) {
-            LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
+            LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): not started";
             return;
         }
 
         if (stopped_) {
-            LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
-                      << "): already stopped";
+            LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): already stopped";
             return;
         }
 
         stopped_ = true;
     }
 
-    LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopping";
+    LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): stopping";
 
     this->underlying_->Close();
     this->cv_.notify_one();
@@ -417,8 +417,8 @@
     read_thread.join();
     write_thread.join();
 
-    LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopped";
-    std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "requested stop"); });
+    LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): stopped";
+    std::call_once(this->error_flag_, [this]() { transport_->HandleError("requested stop"); });
 }
 
 bool BlockingConnectionAdapter::Write(std::unique_ptr<apacket> packet) {
@@ -792,30 +792,7 @@
 
     /* don't create transport threads for inaccessible devices */
     if (t->GetConnectionState() != kCsNoPerm) {
-        // The connection gets a reference to the atransport. It will release it
-        // upon a read/write error.
-        t->connection()->SetTransportName(t->serial_name());
-        t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
-            if (!check_header(p.get(), t)) {
-                D("%s: remote read: bad header", t->serial.c_str());
-                return false;
-            }
-
-            VLOG(TRANSPORT) << dump_packet(t->serial.c_str(), "from remote", p.get());
-            apacket* packet = p.release();
-
-            // TODO: Does this need to run on the main thread?
-            fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
-            return true;
-        });
-        t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
-            LOG(INFO) << t->serial_name() << ": connection terminated: " << error;
-            fdevent_run_on_main_thread([t]() {
-                handle_offline(t);
-                transport_destroy(t);
-            });
-        });
-
+        t->connection()->SetTransport(t);
         t->connection()->Start();
 #if ADB_HOST
         send_connect(t);
@@ -852,7 +829,7 @@
     transport_registration_recv = s[1];
 
     transport_registration_fde =
-        fdevent_create(transport_registration_recv, transport_registration_func, nullptr);
+            fdevent_create(transport_registration_recv, transport_registration_func, nullptr);
     fdevent_set(transport_registration_fde, FDE_READ);
 }
 
@@ -961,8 +938,8 @@
     atransport* result = nullptr;
 
     if (transport_id != 0) {
-        *error_out =
-            android::base::StringPrintf("no device with transport id '%" PRIu64 "'", transport_id);
+        *error_out = android::base::StringPrintf("no device with transport id '%" PRIu64 "'",
+                                                 transport_id);
     } else if (serial) {
         *error_out = android::base::StringPrintf("device '%s' not found", serial);
     } else if (type == kTransportLocal) {
@@ -1129,6 +1106,28 @@
     connection_ = std::shared_ptr<Connection>(std::move(connection));
 }
 
+bool atransport::HandleRead(std::unique_ptr<apacket> p) {
+    if (!check_header(p.get(), this)) {
+        D("%s: remote read: bad header", serial.c_str());
+        return false;
+    }
+
+    VLOG(TRANSPORT) << dump_packet(serial.c_str(), "from remote", p.get());
+    apacket* packet = p.release();
+
+    // TODO: Does this need to run on the main thread?
+    fdevent_run_on_main_thread([packet, this]() { handle_packet(packet, this); });
+    return true;
+}
+
+void atransport::HandleError(const std::string& error) {
+    LOG(INFO) << serial_name() << ": connection terminated: " << error;
+    fdevent_run_on_main_thread([this]() {
+        handle_offline(this);
+        transport_destroy(this);
+    });
+}
+
 std::string atransport::connection_state_name() const {
     ConnectionState state = GetConnectionState();
     switch (state) {
@@ -1268,7 +1267,8 @@
             // Parse our |serial| and the given |target| to check if the hostnames and ports match.
             std::string serial_host, error;
             int serial_port = -1;
-            if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr, &error)) {
+            if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr,
+                                               &error)) {
                 // |target| may omit the port to default to ours.
                 std::string target_host;
                 int target_port = serial_port;
diff --git a/transport.h b/transport.h
index d098b7c..32acc1c 100644
--- a/transport.h
+++ b/transport.h
@@ -106,22 +106,7 @@
     Connection() = default;
     virtual ~Connection() = default;
 
-    void SetTransportName(std::string transport_name) {
-        transport_name_ = std::move(transport_name);
-    }
-
-    using ReadCallback = std::function<bool(Connection*, std::unique_ptr<apacket>)>;
-    void SetReadCallback(ReadCallback callback) {
-        CHECK(!read_callback_);
-        read_callback_ = callback;
-    }
-
-    // Called after the Connection has terminated, either by an error or because Stop was called.
-    using ErrorCallback = std::function<void(Connection*, const std::string&)>;
-    void SetErrorCallback(ErrorCallback callback) {
-        CHECK(!error_callback_);
-        error_callback_ = callback;
-    }
+    void SetTransport(atransport* transport) { transport_ = transport; }
 
     virtual bool Write(std::unique_ptr<apacket> packet) = 0;
 
@@ -133,9 +118,9 @@
     // Stop, and reset the device if it's a USB connection.
     virtual void Reset();
 
-    std::string transport_name_;
-    ReadCallback read_callback_;
-    ErrorCallback error_callback_;
+    std::string Serial() const;
+
+    atransport* transport_ = nullptr;
 
     static std::unique_ptr<Connection> FromFd(unique_fd fd);
 };
@@ -295,6 +280,9 @@
         return connection_;
     }
 
+    bool HandleRead(std::unique_ptr<apacket> p);
+    void HandleError(const std::string& error);
+
 #if ADB_HOST
     void SetUsbHandle(usb_handle* h) { usb_handle_ = h; }
     usb_handle* GetUsbHandle() { return usb_handle_; }
diff --git a/transport_fd.cpp b/transport_fd.cpp
index b9b4f42..d88d57d 100644
--- a/transport_fd.cpp
+++ b/transport_fd.cpp
@@ -121,7 +121,7 @@
                         packet->msg = *read_header_;
                         packet->payload = std::move(payload);
                         read_header_ = nullptr;
-                        read_callback_(this, std::move(packet));
+                        transport_->HandleRead(std::move(packet));
                     }
                 }
             }
@@ -145,7 +145,7 @@
         thread_ = std::thread([this]() {
             std::string error = "connection closed";
             Run(&error);
-            this->error_callback_(this, error);
+            transport_->HandleError(error);
         });
     }