Merge "Fix kick_transport test." am: 63e0819
am: 119ca23

* commit '119ca2345d97db04d4114f5031d172f3694350d4':
  Fix kick_transport test.

Change-Id: I4625174a1d62edfa8a1db52ced70356798d53c95
diff --git a/transport.cpp b/transport.cpp
index 413b362..cbb9c9c 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -295,20 +295,12 @@
     transport_unref(t);
 }
 
-static void kick_transport_locked(atransport* t) {
-    CHECK(t != nullptr);
-    if (!t->kicked) {
-        t->kicked = true;
-        t->kick(t);
-    }
-}
-
 void kick_transport(atransport* t) {
     adb_mutex_lock(&transport_lock);
     // As kick_transport() can be called from threads without guarantee that t is valid,
     // check if the transport is in transport_list first.
     if (std::find(transport_list.begin(), transport_list.end(), t) != transport_list.end()) {
-        kick_transport_locked(t);
+        t->Kick();  // No-op if t has already been kicked.
     }
     adb_mutex_unlock(&transport_lock);
 }
@@ -621,7 +613,7 @@
     t->ref_count--;
     if (t->ref_count == 0) {
         D("transport: %s unref (kicking and closing)", t->serial);
-        kick_transport_locked(t);
+        t->Kick();
         t->close(t);
         remove_transport(t);
     } else {
@@ -748,6 +740,14 @@
     return result;
 }
 
+void atransport::Kick() {  // Only the first call invokes kick_func_; later calls are no-ops.
+    if (!kicked_) {
+        kicked_ = true;  // Set before the callback, so a nested Kick() returns immediately.
+        CHECK(kick_func_ != nullptr);  // SetKickFunction() must have been called.
+        kick_func_(this);
+    }
+}
+
 const std::string atransport::connection_state_name() const {
     switch (connection_state) {
         case kCsOffline: return "offline";
@@ -926,10 +926,7 @@
 void close_usb_devices() {
     adb_mutex_lock(&transport_lock);
     for (const auto& t : transport_list) {
-        if (!t->kicked) {
-            t->kicked = 1;
-            t->kick(t);
-        }
+        t->Kick();  // Kick() itself skips transports that were already kicked.
     }
     adb_mutex_unlock(&transport_lock);
 }
@@ -1000,7 +997,7 @@
             // the read_transport thread will notify the main thread to make this transport
             // offline. Then the main thread will notify the write_transport thread to exit.
             // Finally, this transport will be closed and freed in the main thread.
-            kick_transport_locked(t);
+            t->Kick();
         }
     }
     adb_mutex_unlock(&transport_lock);
diff --git a/transport.h b/transport.h
index 5857249..35d7b50 100644
--- a/transport.h
+++ b/transport.h
@@ -60,7 +60,16 @@
     int (*read_from_remote)(apacket* p, atransport* t) = nullptr;
     int (*write_to_remote)(apacket* p, atransport* t) = nullptr;
     void (*close)(atransport* t) = nullptr;
-    void (*kick)(atransport* t) = nullptr;
+    // Sets the callback run by Kick(); Kick() CHECK-fails if this was never called.
+    void SetKickFunction(void (*kick_func)(atransport*)) {
+        kick_func_ = kick_func;
+    }
+    // Returns true once Kick() has been called on this transport.
+    bool IsKicked() const {
+        return kicked_;
+    }
+    // Runs the kick function on the first call only; later calls are no-ops.
+    void Kick();
 
     int fd = -1;
     int transport_socket = -1;
@@ -82,7 +91,6 @@
     char* device = nullptr;
     char* devpath = nullptr;
     int adb_port = -1;  // Use for emulators (local transport)
-    bool kicked = false;
 
     void* key = nullptr;
     unsigned char token[TOKEN_SIZE] = {};
@@ -123,6 +131,9 @@
     bool MatchesTarget(const std::string& target) const;
 
 private:
+    bool kicked_ = false;  // Set by the first Kick().
+    void (*kick_func_)(atransport*) = nullptr;  // Installed via SetKickFunction().
+
     // A set of features transmitted in the banner with the initial connection.
     // This is stored in the banner as 'features=feature0,feature1,etc'.
     FeatureSet features_;
diff --git a/transport_local.cpp b/transport_local.cpp
index f6c9df4..4121f47 100644
--- a/transport_local.cpp
+++ b/transport_local.cpp
@@ -388,7 +388,7 @@
 {
     int  fail = 0;
 
-    t->kick = remote_kick;
+    t->SetKickFunction(remote_kick);  // Run at most once, by atransport::Kick().
     t->close = remote_close;
     t->read_from_remote = remote_read;
     t->write_to_remote = remote_write;
diff --git a/transport_test.cpp b/transport_test.cpp
index 2028ecc..a6db07a 100644
--- a/transport_test.cpp
+++ b/transport_test.cpp
@@ -20,47 +20,6 @@
 
 #include "adb.h"
 
-class TestTransport : public atransport {
-public:
-    bool operator==(const atransport& rhs) const {
-        EXPECT_EQ(read_from_remote, rhs.read_from_remote);
-        EXPECT_EQ(write_to_remote, rhs.write_to_remote);
-        EXPECT_EQ(close, rhs.close);
-        EXPECT_EQ(kick, rhs.kick);
-
-        EXPECT_EQ(fd, rhs.fd);
-        EXPECT_EQ(transport_socket, rhs.transport_socket);
-
-        EXPECT_EQ(
-            0, memcmp(&transport_fde, &rhs.transport_fde, sizeof(fdevent)));
-
-        EXPECT_EQ(ref_count, rhs.ref_count);
-        EXPECT_EQ(sync_token, rhs.sync_token);
-        EXPECT_EQ(connection_state, rhs.connection_state);
-        EXPECT_EQ(online, rhs.online);
-        EXPECT_EQ(type, rhs.type);
-
-        EXPECT_EQ(usb, rhs.usb);
-        EXPECT_EQ(sfd, rhs.sfd);
-
-        EXPECT_EQ(serial, rhs.serial);
-        EXPECT_EQ(product, rhs.product);
-        EXPECT_EQ(model, rhs.model);
-        EXPECT_EQ(device, rhs.device);
-        EXPECT_EQ(devpath, rhs.devpath);
-        EXPECT_EQ(adb_port, rhs.adb_port);
-        EXPECT_EQ(kicked, rhs.kicked);
-
-        EXPECT_EQ(key, rhs.key);
-        EXPECT_EQ(0, memcmp(token, rhs.token, TOKEN_SIZE));
-        EXPECT_EQ(failed_auth_attempts, rhs.failed_auth_attempts);
-
-        EXPECT_EQ(features(), rhs.features());
-
-        return true;
-    }
-};
-
 class TransportSetup {
 public:
   TransportSetup() {
@@ -83,35 +42,19 @@
 static TransportSetup g_TransportSetup;
 
 TEST(transport, kick_transport) {
-  TestTransport t;
-
+  atransport t;
+  static size_t kick_count;  // static so the captureless kick lambda can use it.
+  kick_count = 0;
-  // Mutate some member so we can test that the function is run.
-  t.kick = [](atransport* trans) { trans->fd = 42; };
-
-  TestTransport expected;
-  expected.kick = t.kick;
-  expected.fd = 42;
-  expected.kicked = 1;
-
-  kick_transport(&t);
-  ASSERT_EQ(42, t.fd);
-  ASSERT_EQ(1, t.kicked);
-  ASSERT_EQ(expected, t);
-}
-
-TEST(transport, kick_transport_already_kicked) {
-  // Ensure that the transport is not modified if the transport has already been
-  // kicked.
-  TestTransport t;
-  t.kicked = 1;
-  t.kick = [](atransport*) { FAIL() << "Kick should not have been called"; };
-
-  TestTransport expected;
-  expected.kicked = 1;
-  expected.kick = t.kick;
-
-  kick_transport(&t);
-  ASSERT_EQ(expected, t);
+  // Count kick-function calls so we can test that it runs exactly once.
+  t.SetKickFunction([](atransport*) { kick_count++; });
+  ASSERT_FALSE(t.IsKicked());
+  t.Kick();
+  ASSERT_TRUE(t.IsKicked());
+  ASSERT_EQ(1u, kick_count);
+  // A transport can only be kicked once.
+  t.Kick();
+  ASSERT_TRUE(t.IsKicked());
+  ASSERT_EQ(1u, kick_count);
 }
 
 static void DisconnectFunc(void* arg, atransport*) {
diff --git a/transport_usb.cpp b/transport_usb.cpp
index 263f9e7..d05d928 100644
--- a/transport_usb.cpp
+++ b/transport_usb.cpp
@@ -84,7 +84,7 @@
 {
     D("transport: usb");
     t->close = remote_close;
-    t->kick = remote_kick;
+    t->SetKickFunction(remote_kick);  // Run at most once, by atransport::Kick().
     t->read_from_remote = remote_read;
     t->write_to_remote = remote_write;
     t->sync_token = 1;