Merge "fastboot: add Socket timeout detection."
diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp
index d49f47f..14ecd93 100644
--- a/fastboot/socket.cpp
+++ b/fastboot/socket.cpp
@@ -48,18 +48,6 @@
     return ret;
 }
 
-bool Socket::SetReceiveTimeout(int timeout_ms) {
-    if (timeout_ms != receive_timeout_ms_) {
-        if (socket_set_receive_timeout(sock_, timeout_ms) == 0) {
-            receive_timeout_ms_ = timeout_ms;
-            return true;
-        }
-        return false;
-    }
-
-    return true;
-}
-
 ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) {
     size_t total = 0;
 
@@ -82,6 +70,40 @@
     return socket_get_local_port(sock_);
 }
 
+// Waits for |sock_| to become readable via select() rather than relying on SO_RCVTIMEO. The
+// Windows setsockopt() documentation says a socket that times out during send() or recv() is left
+// in an indeterminate state and should not be used, yet our UDP protocol re-sends after timeouts.
+// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms740476(v=vs.85).aspx.
+bool Socket::WaitForRecv(int timeout_ms) {
+    receive_timed_out_ = false;
+
+    // Callers pass |timeout_ms| <= 0 to mean "block forever"; report success right away and leave
+    // the actual blocking to the recv() that follows.
+    if (timeout_ms <= 0) {
+        return true;
+    }
+
+    // An invalid socket is not always rejected by select(), which would then sleep |timeout_ms|.
+    if (sock_ == INVALID_SOCKET) {
+        return false;
+    }
+
+    timeval wait;
+    wait.tv_sec = timeout_ms / 1000;
+    wait.tv_usec = (timeout_ms % 1000) * 1000;
+
+    fd_set readable;
+    FD_ZERO(&readable);
+    FD_SET(sock_, &readable);
+
+    // select() yields 1 when |sock_| is readable, 0 when |wait| elapsed first, and a negative
+    // value on error; only the 0 case counts as a retryable timeout.
+    const int ready = TEMP_FAILURE_RETRY(select(sock_ + 1, &readable, nullptr, nullptr, &wait));
+
+    receive_timed_out_ = (ready == 0);
+    return ready == 1;
+}
+
 // Implements the Socket interface for UDP.
 class UdpSocket : public Socket {
   public:
@@ -127,7 +149,7 @@
 }
 
 ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) {
-    if (!SetReceiveTimeout(timeout_ms)) {
+    if (!WaitForRecv(timeout_ms)) {
         return -1;
     }
 
@@ -206,7 +228,7 @@
 }
 
 ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) {
-    if (!SetReceiveTimeout(timeout_ms)) {
+    if (!WaitForRecv(timeout_ms)) {
         return -1;
     }
 
diff --git a/fastboot/socket.h b/fastboot/socket.h
index c0bd7c9..de543db 100644
--- a/fastboot/socket.h
+++ b/fastboot/socket.h
@@ -81,13 +81,17 @@
     virtual bool Send(std::vector<cutils_socket_buffer_t> buffers) = 0;
 
     // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will
-    // block forever. Returns the number of bytes received or -1 on error/timeout. On timeout
-    // errno will be set to EAGAIN or EWOULDBLOCK.
+    // block forever. Returns the number of bytes received or -1 on error/timeout; see
+    // ReceiveTimedOut() to distinguish between the two.
     virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0;
 
     // Calls Receive() until exactly |length| bytes have been received or an error occurs.
     virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms);
 
+    // Returns true if the last Receive() call timed out normally and can be retried; fatal errors
+    // or successful reads will return false. Const: it only reports state and never changes it.
+    bool ReceiveTimedOut() const { return receive_timed_out_; }
+
     // Closes the socket. Returns 0 on success, -1 on error.
     virtual int Close();
 
@@ -102,10 +106,13 @@
     // Protected constructor to force factory function use.
     Socket(cutils_socket_t sock);
 
-    // Update the socket receive timeout if necessary.
-    bool SetReceiveTimeout(int timeout_ms);
+    // Blocks up to |timeout_ms| until a read is possible on |sock_|, and sets |receive_timed_out_|
+    // as appropriate to help distinguish between normal timeouts and fatal errors. Returns true if
+    // a subsequent recv() on |sock_| will complete without blocking or if |timeout_ms| <= 0.
+    bool WaitForRecv(int timeout_ms);
 
     cutils_socket_t sock_ = INVALID_SOCKET;
+    bool receive_timed_out_ = false;
 
     // Non-class functions we want to override during tests to verify functionality. Implementation
     // should call this rather than using socket_send_buffers() directly.
@@ -113,8 +120,6 @@
             socket_send_buffers_function_ = &socket_send_buffers;
 
   private:
-    int receive_timeout_ms_ = 0;
-
     FRIEND_TEST(SocketTest, TestTcpSendBuffers);
     FRIEND_TEST(SocketTest, TestUdpSendBuffers);
 
diff --git a/fastboot/socket_mock.cpp b/fastboot/socket_mock.cpp
index c962f30..2531b53 100644
--- a/fastboot/socket_mock.cpp
+++ b/fastboot/socket_mock.cpp
@@ -55,7 +55,7 @@
         return false;
     }
 
-    bool return_value = events_.front().return_value;
+    bool return_value = events_.front().status;
     events_.pop();
     return return_value;
 }
@@ -76,21 +76,28 @@
         return -1;
     }
 
-    if (events_.front().type != EventType::kReceive) {
+    const Event& event = events_.front();
+    if (event.type != EventType::kReceive) {
         ADD_FAILURE() << "Receive() was called out-of-order";
         return -1;
     }
 
-    if (events_.front().return_value > static_cast<ssize_t>(length)) {
-        ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for "
-                      << events_.front().message;
+    const std::string& message = event.message;
+    if (message.length() > length) {
+        ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for " << message;
         return -1;
     }
 
-    ssize_t return_value = events_.front().return_value;
-    if (return_value > 0) {
-        memcpy(data, events_.front().message.data(), return_value);
+    receive_timed_out_ = event.status;
+    ssize_t return_value = message.length();
+
+    // Empty message indicates failure.
+    if (message.empty()) {
+        return_value = -1;
+    } else {
+        memcpy(data, message.data(), message.length());
     }
+
     events_.pop();
     return return_value;
 }
@@ -124,18 +131,21 @@
 }
 
 void SocketMock::AddReceive(std::string message) {
-    ssize_t return_value = message.length();
-    events_.push(Event(EventType::kReceive, std::move(message), return_value, nullptr));
+    events_.push(Event(EventType::kReceive, std::move(message), false, nullptr));
+}
+
+void SocketMock::AddReceiveTimeout() {
+    events_.push(Event(EventType::kReceive, "", true, nullptr));
 }
 
 void SocketMock::AddReceiveFailure() {
-    events_.push(Event(EventType::kReceive, "", -1, nullptr));
+    events_.push(Event(EventType::kReceive, "", false, nullptr));
 }
 
 void SocketMock::AddAccept(std::unique_ptr<Socket> sock) {
-    events_.push(Event(EventType::kAccept, "", 0, std::move(sock)));
+    events_.push(Event(EventType::kAccept, "", false, std::move(sock)));
 }
 
-SocketMock::Event::Event(EventType _type, std::string _message, ssize_t _return_value,
+SocketMock::Event::Event(EventType _type, std::string _message, bool _status,
                          std::unique_ptr<Socket> _sock)
-        : type(_type), message(_message), return_value(_return_value), sock(std::move(_sock)) {}
+        : type(_type), message(std::move(_message)), status(_status), sock(std::move(_sock)) {}
diff --git a/fastboot/socket_mock.h b/fastboot/socket_mock.h
index 41fe06d..eacd6bb 100644
--- a/fastboot/socket_mock.h
+++ b/fastboot/socket_mock.h
@@ -71,7 +71,10 @@
     // Adds data to provide for Receive().
     void AddReceive(std::string message);
 
-    // Adds a Receive() failure.
+    // Adds a Receive() timeout after which ReceiveTimedOut() will return true.
+    void AddReceiveTimeout();
+
+    // Adds a Receive() failure after which ReceiveTimedOut() will return false.
     void AddReceiveFailure();
 
     // Adds a Socket to return from Accept().
@@ -81,12 +84,12 @@
     enum class EventType { kSend, kReceive, kAccept };
 
     struct Event {
-        Event(EventType _type, std::string _message, ssize_t _return_value,
+        Event(EventType _type, std::string _message, bool _status,
               std::unique_ptr<Socket> _sock);
 
         EventType type;
         std::string message;
-        ssize_t return_value;
+        bool status;  // Return value for Send() or timeout status for Receive().
         std::unique_ptr<Socket> sock;
     };
 
diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp
index cc71075..affbdfd 100644
--- a/fastboot/socket_test.cpp
+++ b/fastboot/socket_test.cpp
@@ -28,7 +28,8 @@
 #include <gtest/gtest-spi.h>
 #include <gtest/gtest.h>
 
-enum { kTestTimeoutMs = 3000 };
+static constexpr int kShortTimeoutMs = 10;
+static constexpr int kTestTimeoutMs = 3000;
 
 // Creates connected sockets |server| and |client|. Returns true on success.
 bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* server,
@@ -87,6 +88,50 @@
     }
 }
 
+TEST(SocketTest, TestReceiveTimeout) {
+    std::unique_ptr<Socket> server, client;
+    char buffer[16];
+    // Receive() with no pending data must return -1 and flag a timeout, for UDP and TCP alike.
+    for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
+        ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
+
+        EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
+        EXPECT_TRUE(server->ReceiveTimedOut());
+
+        EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
+        EXPECT_TRUE(client->ReceiveTimedOut());
+    }
+
+    // UDP will wait for timeout if the other side closes.
+    ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client));
+    EXPECT_EQ(0, server->Close());
+    EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
+    EXPECT_TRUE(client->ReceiveTimedOut());
+}
+
+TEST(SocketTest, TestReceiveFailure) {
+    std::unique_ptr<Socket> server, client;
+    char buffer[16];
+    // Receive() on a socket we closed ourselves must fail without being flagged as a timeout.
+    for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
+        ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
+
+        EXPECT_EQ(0, server->Close());
+        EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
+        EXPECT_FALSE(server->ReceiveTimedOut());
+
+        EXPECT_EQ(0, client->Close());
+        EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
+        EXPECT_FALSE(client->ReceiveTimedOut());
+    }
+
+    // TCP knows right away when the other side closes and returns 0 to indicate EOF.
+    ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kTcp, &server, &client));
+    EXPECT_EQ(0, server->Close());
+    EXPECT_EQ(0, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
+    EXPECT_FALSE(client->ReceiveTimedOut());
+}
+
 // Tests sending and receiving large packets.
 TEST(SocketTest, TestLargePackets) {
     std::string message(1024, '\0');
@@ -290,6 +335,11 @@
 
     mock->AddReceiveFailure();
     EXPECT_FALSE(ReceiveString(mock, "foo"));
+    EXPECT_FALSE(mock->ReceiveTimedOut());
+
+    mock->AddReceiveTimeout();
+    EXPECT_FALSE(ReceiveString(mock, "foo"));
+    EXPECT_TRUE(mock->ReceiveTimedOut());
 
     mock->AddReceive("foo");
     mock->AddReceiveFailure();