Merge "libbinder: fix RPC setup races" am: 4b02e9cf07 am: deaa84a08c am: 5f58c863c9

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1777153

Change-Id: I2000ff0245b324dbc28dcd01c6fff8a0bc788135
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index c6cf2c5..200d923 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -369,7 +369,7 @@
     return true;
 }
 
-void RpcServer::onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) {
+void RpcServer::onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) {
     auto id = session->mId;
     LOG_ALWAYS_FATAL_IF(id == std::nullopt, "Server sessions must be initialized with ID");
     LOG_RPC_DETAIL("Dropping session with address %s", id->toString().c_str());
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 5fe0b00..1c37651 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -132,6 +132,7 @@
     if (wait) {
         LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
         mShutdownListener->waitForShutdown(_l);
+
         LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
     }
 
@@ -259,7 +260,7 @@
     return OK;
 }
 
-void RpcSession::WaitForShutdownListener::onSessionLockedAllIncomingThreadsEnded(
+void RpcSession::WaitForShutdownListener::onSessionAllIncomingThreadsEnded(
         const sp<RpcSession>& session) {
     (void)session;
     mShutdown = true;
@@ -291,7 +292,13 @@
     // be able to do nested calls (we can't only read from it)
     sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(fd));
 
-    status_t status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+    status_t status;
+
+    if (connection == nullptr) {
+        status = DEAD_OBJECT;
+    } else {
+        status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+    }
 
     return PreJoinSetupResult{
             .connection = std::move(connection),
@@ -358,6 +365,7 @@
     sp<RpcConnection>& connection = setupResult.connection;
 
     if (setupResult.status == OK) {
+        LOG_ALWAYS_FATAL_IF(!connection, "must have connection if setup succeeded");
         JavaThreadAttacher javaThreadAttacher;
         while (true) {
             status_t status = session->state()->getAndExecuteCommand(connection, session,
@@ -373,9 +381,6 @@
               statusToString(setupResult.status).c_str());
     }
 
-    LOG_ALWAYS_FATAL_IF(!session->removeIncomingConnection(connection),
-                        "bad state: connection object guaranteed to be in list");
-
     sp<RpcSession::EventListener> listener;
     {
         std::lock_guard<std::mutex> _l(session->mMutex);
@@ -387,6 +392,12 @@
         listener = session->mEventListener.promote();
     }
 
+    // done after all cleanup, since session shutdown progresses via callbacks here
+    if (connection != nullptr) {
+        LOG_ALWAYS_FATAL_IF(!session->removeIncomingConnection(connection),
+                            "bad state: connection object guaranteed to be in list");
+    }
+
     session = nullptr;
 
     if (listener != nullptr) {
@@ -577,24 +588,35 @@
 
 sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(unique_fd fd) {
     std::lock_guard<std::mutex> _l(mMutex);
+
+    // Refuse new connections once any incoming connection has shut down.
+    // mMaxIncomingConnections is the high-water mark of mIncomingConnections, so
+    // a smaller size means one was already removed. Usually this happens when a
+    // very short-lived session shuts down while it is still accepting connections.
+    if (mIncomingConnections.size() < mMaxIncomingConnections) {
+        return nullptr;  // caller reports this as DEAD_OBJECT
+    }
+
     sp<RpcConnection> session = sp<RpcConnection>::make();
     session->fd = std::move(fd);
     session->exclusiveTid = gettid();
+
     mIncomingConnections.push_back(session);
+    mMaxIncomingConnections = mIncomingConnections.size();

     return session;
 }
 
 bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
-    std::lock_guard<std::mutex> _l(mMutex);
+    std::unique_lock<std::mutex> _l(mMutex);
     if (auto it = std::find(mIncomingConnections.begin(), mIncomingConnections.end(), connection);
         it != mIncomingConnections.end()) {
         mIncomingConnections.erase(it);
         if (mIncomingConnections.size() == 0) {
             sp<EventListener> listener = mEventListener.promote();
             if (listener) {
-                listener->onSessionLockedAllIncomingThreadsEnded(
-                        sp<RpcSession>::fromExisting(this));
+                _l.unlock();
+                listener->onSessionAllIncomingThreadsEnded(sp<RpcSession>::fromExisting(this));
             }
         }
         return true;
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index c8d2857..a8094dd 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -156,7 +156,7 @@
     friend sp<RpcServer>;
     RpcServer();
 
-    void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
+    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
     void onSessionIncomingThreadEnded() override;
 
     static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index fdca2a9..2101df8 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -177,19 +177,19 @@
 
     class EventListener : public virtual RefBase {
     public:
-        virtual void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) = 0;
+        virtual void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) = 0;
         virtual void onSessionIncomingThreadEnded() = 0;
     };
 
     class WaitForShutdownListener : public EventListener {
     public:
-        void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
+        void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
         void onSessionIncomingThreadEnded() override;
         void waitForShutdown(std::unique_lock<std::mutex>& lock);

     private:
         std::condition_variable mCv;
-        bool mShutdown = false;
+        std::atomic<bool> mShutdown{false};  // written by callback outside session lock
     };
 
     struct RpcConnection : public RefBase {
@@ -297,6 +297,7 @@
     // hint index into clients, ++ when sending an async transaction
     size_t mOutgoingConnectionsOffset = 0;
     std::vector<sp<RpcConnection>> mOutgoingConnections;
+    size_t mMaxIncomingConnections = 0;
     std::vector<sp<RpcConnection>> mIncomingConnections;
     std::map<std::thread::id, std::thread> mThreads;
 };