Give Connection an atransport*.
Bug: http://b/186595076
Test: treehugger
Change-Id: I2675db1fa2b1823e5d73e42cb47f6fcf18cfb4f6
diff --git a/client/usb_libusb.cpp b/client/usb_libusb.cpp
index a877610..8bdf076 100644
--- a/client/usb_libusb.cpp
+++ b/client/usb_libusb.cpp
@@ -123,7 +123,7 @@
if (payload) {
packet->payload = std::move(*payload);
}
- read_callback_(this, std::move(packet));
+ transport_->HandleRead(std::move(packet));
}
void Cleanup(ReadBlock* read_block) REQUIRES(read_mutex_) {
@@ -553,8 +553,8 @@
void OnError(const std::string& error) {
std::call_once(error_flag_, [this, &error]() {
- if (error_callback_) {
- error_callback_(this, error);
+ if (transport_) {
+ transport_->HandleError(error);
}
});
}
diff --git a/daemon/usb.cpp b/daemon/usb.cpp
index ad6f818..775edd1 100644
--- a/daemon/usb.cpp
+++ b/daemon/usb.cpp
@@ -582,7 +582,7 @@
// TODO: Make apacket contain an IOVector so we don't have to coalesce.
packet->payload = std::move(incoming_payload_).coalesce();
- read_callback_(this, std::move(packet));
+ transport_->HandleRead(std::move(packet));
incoming_header_.reset();
// reuse the capacity of the incoming payload while we can.
@@ -678,7 +678,10 @@
void HandleError(const std::string& error) {
std::call_once(error_flag_, [&]() {
- error_callback_(this, error);
+ if (transport_) {
+ transport_->HandleError(error);
+ }
+
if (!stopped_) {
Stop();
}
diff --git a/transport.cpp b/transport.cpp
index fbcf79b..f98eee2 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -278,25 +278,28 @@
Stop();
}
+std::string Connection::Serial() const {
+ return transport_ ? transport_->serial_name() : "<unknown>";
+}
+
BlockingConnectionAdapter::BlockingConnectionAdapter(std::unique_ptr<BlockingConnection> connection)
: underlying_(std::move(connection)) {}
BlockingConnectionAdapter::~BlockingConnectionAdapter() {
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): destructing";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): destructing";
Stop();
}
void BlockingConnectionAdapter::Start() {
std::lock_guard<std::mutex> lock(mutex_);
if (started_) {
- LOG(FATAL) << "BlockingConnectionAdapter(" << this->transport_name_
- << "): started multiple times";
+ LOG(FATAL) << "BlockingConnectionAdapter(" << Serial() << "): started multiple times";
}
StartReadThread();
write_thread_ = std::thread([this]() {
- LOG(INFO) << this->transport_name_ << ": write thread spawning";
+ LOG(INFO) << Serial() << ": write thread spawning";
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
ScopedLockAssertion assume_locked(mutex_);
@@ -316,7 +319,7 @@
break;
}
}
- std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "write failed"); });
+ std::call_once(this->error_flag_, [this]() { transport_->HandleError("write failed"); });
});
started_ = true;
@@ -324,11 +327,11 @@
void BlockingConnectionAdapter::StartReadThread() {
read_thread_ = std::thread([this]() {
- LOG(INFO) << this->transport_name_ << ": read thread spawning";
+ LOG(INFO) << Serial() << ": read thread spawning";
while (true) {
auto packet = std::make_unique<apacket>();
if (!underlying_->Read(packet.get())) {
- PLOG(INFO) << this->transport_name_ << ": read failed";
+ PLOG(INFO) << Serial() << ": read failed";
break;
}
@@ -337,18 +340,17 @@
got_stls_cmd = true;
}
- read_callback_(this, std::move(packet));
+ transport_->HandleRead(std::move(packet));
// If we received the STLS packet, we are about to perform the TLS
// handshake. So this read thread must stop and resume after the
// handshake completes otherwise this will interfere in the process.
if (got_stls_cmd) {
- LOG(INFO) << this->transport_name_
- << ": Received STLS packet. Stopping read thread.";
+ LOG(INFO) << Serial() << ": Received STLS packet. Stopping read thread.";
return;
}
}
- std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
+ std::call_once(this->error_flag_, [this]() { transport_->HandleError("read failed"); });
});
}
@@ -366,18 +368,17 @@
{
std::lock_guard<std::mutex> lock(mutex_);
if (!started_) {
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): not started";
return;
}
if (stopped_) {
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
- << "): already stopped";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): already stopped";
return;
}
}
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): resetting";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): resetting";
this->underlying_->Reset();
Stop();
}
@@ -386,20 +387,19 @@
{
std::lock_guard<std::mutex> lock(mutex_);
if (!started_) {
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): not started";
return;
}
if (stopped_) {
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
- << "): already stopped";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): already stopped";
return;
}
stopped_ = true;
}
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopping";
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): stopping";
this->underlying_->Close();
this->cv_.notify_one();
@@ -417,8 +417,8 @@
read_thread.join();
write_thread.join();
- LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopped";
- std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "requested stop"); });
+ LOG(INFO) << "BlockingConnectionAdapter(" << Serial() << "): stopped";
+ std::call_once(this->error_flag_, [this]() { transport_->HandleError("requested stop"); });
}
bool BlockingConnectionAdapter::Write(std::unique_ptr<apacket> packet) {
@@ -792,30 +792,7 @@
/* don't create transport threads for inaccessible devices */
if (t->GetConnectionState() != kCsNoPerm) {
- // The connection gets a reference to the atransport. It will release it
- // upon a read/write error.
- t->connection()->SetTransportName(t->serial_name());
- t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
- if (!check_header(p.get(), t)) {
- D("%s: remote read: bad header", t->serial.c_str());
- return false;
- }
-
- VLOG(TRANSPORT) << dump_packet(t->serial.c_str(), "from remote", p.get());
- apacket* packet = p.release();
-
- // TODO: Does this need to run on the main thread?
- fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
- return true;
- });
- t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
- LOG(INFO) << t->serial_name() << ": connection terminated: " << error;
- fdevent_run_on_main_thread([t]() {
- handle_offline(t);
- transport_destroy(t);
- });
- });
-
+ t->connection()->SetTransport(t);
t->connection()->Start();
#if ADB_HOST
send_connect(t);
@@ -852,7 +829,7 @@
transport_registration_recv = s[1];
transport_registration_fde =
- fdevent_create(transport_registration_recv, transport_registration_func, nullptr);
+ fdevent_create(transport_registration_recv, transport_registration_func, nullptr);
fdevent_set(transport_registration_fde, FDE_READ);
}
@@ -961,8 +938,8 @@
atransport* result = nullptr;
if (transport_id != 0) {
- *error_out =
- android::base::StringPrintf("no device with transport id '%" PRIu64 "'", transport_id);
+ *error_out = android::base::StringPrintf("no device with transport id '%" PRIu64 "'",
+ transport_id);
} else if (serial) {
*error_out = android::base::StringPrintf("device '%s' not found", serial);
} else if (type == kTransportLocal) {
@@ -1129,6 +1106,28 @@
connection_ = std::shared_ptr<Connection>(std::move(connection));
}
+bool atransport::HandleRead(std::unique_ptr<apacket> p) {
+ if (!check_header(p.get(), this)) {
+ D("%s: remote read: bad header", serial.c_str());
+ return false;
+ }
+
+ VLOG(TRANSPORT) << dump_packet(serial.c_str(), "from remote", p.get());
+ apacket* packet = p.release();
+
+ // TODO: Does this need to run on the main thread?
+ fdevent_run_on_main_thread([packet, this]() { handle_packet(packet, this); });
+ return true;
+}
+
+void atransport::HandleError(const std::string& error) {
+ LOG(INFO) << serial_name() << ": connection terminated: " << error;
+ fdevent_run_on_main_thread([this]() {
+ handle_offline(this);
+ transport_destroy(this);
+ });
+}
+
std::string atransport::connection_state_name() const {
ConnectionState state = GetConnectionState();
switch (state) {
@@ -1268,7 +1267,8 @@
// Parse our |serial| and the given |target| to check if the hostnames and ports match.
std::string serial_host, error;
int serial_port = -1;
- if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr, &error)) {
+ if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr,
+ &error)) {
// |target| may omit the port to default to ours.
std::string target_host;
int target_port = serial_port;
diff --git a/transport.h b/transport.h
index d098b7c..32acc1c 100644
--- a/transport.h
+++ b/transport.h
@@ -106,22 +106,7 @@
Connection() = default;
virtual ~Connection() = default;
- void SetTransportName(std::string transport_name) {
- transport_name_ = std::move(transport_name);
- }
-
- using ReadCallback = std::function<bool(Connection*, std::unique_ptr<apacket>)>;
- void SetReadCallback(ReadCallback callback) {
- CHECK(!read_callback_);
- read_callback_ = callback;
- }
-
- // Called after the Connection has terminated, either by an error or because Stop was called.
- using ErrorCallback = std::function<void(Connection*, const std::string&)>;
- void SetErrorCallback(ErrorCallback callback) {
- CHECK(!error_callback_);
- error_callback_ = callback;
- }
+ void SetTransport(atransport* transport) { transport_ = transport; }
virtual bool Write(std::unique_ptr<apacket> packet) = 0;
@@ -133,9 +118,9 @@
// Stop, and reset the device if it's a USB connection.
virtual void Reset();
- std::string transport_name_;
- ReadCallback read_callback_;
- ErrorCallback error_callback_;
+ std::string Serial() const;
+
+ atransport* transport_ = nullptr;
static std::unique_ptr<Connection> FromFd(unique_fd fd);
};
@@ -295,6 +280,9 @@
return connection_;
}
+ bool HandleRead(std::unique_ptr<apacket> p);
+ void HandleError(const std::string& error);
+
#if ADB_HOST
void SetUsbHandle(usb_handle* h) { usb_handle_ = h; }
usb_handle* GetUsbHandle() { return usb_handle_; }
diff --git a/transport_fd.cpp b/transport_fd.cpp
index b9b4f42..d88d57d 100644
--- a/transport_fd.cpp
+++ b/transport_fd.cpp
@@ -121,7 +121,7 @@
packet->msg = *read_header_;
packet->payload = std::move(payload);
read_header_ = nullptr;
- read_callback_(this, std::move(packet));
+ transport_->HandleRead(std::move(packet));
}
}
}
@@ -145,7 +145,7 @@
thread_ = std::thread([this]() {
std::string error = "connection closed";
Run(&error);
- this->error_callback_(this, error);
+ transport_->HandleError(error);
});
}