libbinder: shut down session on RPC state termination

Previously, when the RPC state was terminated due to an error
condition, a client could still be trying to read data from the
session, but the connection would never be hung up (it would just be
ignored). Now the session's connections are hung up in this case.
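
For illustration, here is a minimal caller-side sketch of the renamed
API (the socket path and helper function below are hypothetical, not
part of this patch):

    #include <binder/IBinder.h>
    #include <binder/RpcSession.h>

    using android::IBinder;
    using android::RpcSession;
    using android::sp;

    // Hypothetical client: connect, use the root object, then shut down.
    bool useAndShutDown() {
        sp<RpcSession> session = RpcSession::make();
        if (!session->setupUnixDomainClient("/dev/socket/example")) {
            return false;
        }

        sp<IBinder> root = session->getRootObject();
        // ... transact on 'root' ...

        // wait=true blocks until the session's threads have exited.
        // Internal error paths call shutdownAndWait(false) instead,
        // since waiting from the session threadpool would hang.
        return session->shutdownAndWait(/*wait=*/true);
    }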

Bug: 183140903
Test: binderRpcTest
Change-Id: I8c281ad2af3889cc3570a8d3b7bf3def8c51ec79
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 93e04f7..62118ff 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -113,17 +113,21 @@
     return state()->getMaxThreads(connection.fd(), sp<RpcSession>::fromExisting(this), maxThreads);
 }
 
-bool RpcSession::shutdown() {
+bool RpcSession::shutdownAndWait(bool wait) {
     std::unique_lock<std::mutex> _l(mMutex);
-    LOG_ALWAYS_FATAL_IF(mForServer.promote() != nullptr, "Can only shut down client session");
     LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Shutdown trigger not installed");
-    LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
 
     mShutdownTrigger->trigger();
-    mShutdownListener->waitForShutdown(_l);
-    mState->terminate();
 
-    LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+    if (wait) {
+        LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
+        mShutdownListener->waitForShutdown(_l);
+        LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+    }
+
+    _l.unlock();
+    mState->clear();
+
     return true;
 }
 
@@ -139,7 +143,7 @@
 status_t RpcSession::sendDecStrong(const RpcAddress& address) {
     ExclusiveConnection connection(sp<RpcSession>::fromExisting(this),
                                    ConnectionUse::CLIENT_REFCOUNT);
-    return state()->sendDecStrong(connection.fd(), address);
+    return state()->sendDecStrong(connection.fd(), sp<RpcSession>::fromExisting(this), address);
 }
 
 std::unique_ptr<RpcSession::FdTrigger> RpcSession::FdTrigger::make() {
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 93f1529..7e731f3 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -137,9 +137,39 @@
     dumpLocked();
 }
 
-void RpcState::terminate() {
+void RpcState::clear() {
     std::unique_lock<std::mutex> _l(mNodeMutex);
-    terminate(_l);
+
+    if (mTerminated) {
+        LOG_ALWAYS_FATAL_IF(!mNodeForAddress.empty(),
+                            "New state should be impossible after terminating!");
+        return;
+    }
+
+    if (SHOULD_LOG_RPC_DETAIL) {
+        ALOGE("RpcState::clear()");
+        dumpLocked();
+    }
+
+    // if the destructor of a binder object makes another RPC call, then calling
+    // decStrong could deadlock. So, we must hold onto these binders until
+    // mNodeMutex is no longer taken.
+    std::vector<sp<IBinder>> tempHoldBinder;
+
+    mTerminated = true;
+    for (auto& [address, node] : mNodeForAddress) {
+        sp<IBinder> binder = node.binder.promote();
+        LOG_ALWAYS_FATAL_IF(binder == nullptr, "Binder %p expected to be owned.", binder.get());
+
+        if (node.sentRef != nullptr) {
+            tempHoldBinder.push_back(node.sentRef);
+        }
+    }
+
+    mNodeForAddress.clear();
+
+    _l.unlock();
+    tempHoldBinder.clear(); // explicit - destroy binders now, after mNodeMutex is released
 }
 
 void RpcState::dumpLocked() {
@@ -170,32 +200,6 @@
     ALOGE("END DUMP OF RpcState");
 }
 
-void RpcState::terminate(std::unique_lock<std::mutex>& lock) {
-    if (SHOULD_LOG_RPC_DETAIL) {
-        ALOGE("RpcState::terminate()");
-        dumpLocked();
-    }
-
-    // if the destructor of a binder object makes another RPC call, then calling
-    // decStrong could deadlock. So, we must hold onto these binders until
-    // mNodeMutex is no longer taken.
-    std::vector<sp<IBinder>> tempHoldBinder;
-
-    mTerminated = true;
-    for (auto& [address, node] : mNodeForAddress) {
-        sp<IBinder> binder = node.binder.promote();
-        LOG_ALWAYS_FATAL_IF(binder == nullptr, "Binder %p expected to be owned.", binder.get());
-
-        if (node.sentRef != nullptr) {
-            tempHoldBinder.push_back(node.sentRef);
-        }
-    }
-
-    mNodeForAddress.clear();
-
-    lock.unlock();
-    tempHoldBinder.clear(); // explicit
-}
 
 RpcState::CommandData::CommandData(size_t size) : mSize(size) {
     // The maximum size for regular binder is 1MB for all concurrent
@@ -218,13 +222,13 @@
     mData.reset(new (std::nothrow) uint8_t[size]);
 }
 
-status_t RpcState::rpcSend(const base::unique_fd& fd, const char* what, const void* data,
-                           size_t size) {
+status_t RpcState::rpcSend(const base::unique_fd& fd, const sp<RpcSession>& session,
+                           const char* what, const void* data, size_t size) {
     LOG_RPC_DETAIL("Sending %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str());
 
     if (size > std::numeric_limits<ssize_t>::max()) {
         ALOGE("Cannot send %s at size %zu (too big)", what, size);
-        terminate();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
 
@@ -235,7 +239,7 @@
         LOG_RPC_DETAIL("Failed to send %s (sent %zd of %zu bytes) on fd %d, error: %s", what, sent,
                        size, fd.get(), strerror(savedErrno));
 
-        terminate();
+        (void)session->shutdownAndWait(false);
         return -savedErrno;
     }
 
@@ -246,7 +250,7 @@
                           const char* what, void* data, size_t size) {
     if (size > std::numeric_limits<ssize_t>::max()) {
         ALOGE("Cannot rec %s at size %zu (too big)", what, size);
-        terminate();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
 
@@ -358,7 +362,11 @@
 
         if (flags & IBinder::FLAG_ONEWAY) {
             asyncNumber = it->second.asyncNumber;
-            if (!nodeProgressAsyncNumber(&it->second, _l)) return DEAD_OBJECT;
+            if (!nodeProgressAsyncNumber(&it->second)) {
+                _l.unlock();
+                (void)session->shutdownAndWait(false);
+                return DEAD_OBJECT;
+            }
         }
     }
 
@@ -390,7 +398,7 @@
            data.dataSize());
 
     if (status_t status =
-                rpcSend(fd, "transaction", transactionData.data(), transactionData.size());
+                rpcSend(fd, session, "transaction", transactionData.data(), transactionData.size());
         status != OK)
         return status;
 
@@ -442,7 +450,7 @@
     if (command.bodySize < sizeof(RpcWireReply)) {
         ALOGE("Expecting %zu but got %" PRId32 " bytes for RpcWireReply. Terminating!",
               sizeof(RpcWireReply), command.bodySize);
-        terminate();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
     RpcWireReply* rpcReply = reinterpret_cast<RpcWireReply*>(data.data());
@@ -457,7 +465,8 @@
     return OK;
 }
 
-status_t RpcState::sendDecStrong(const base::unique_fd& fd, const RpcAddress& addr) {
+status_t RpcState::sendDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
+                                 const RpcAddress& addr) {
     {
         std::lock_guard<std::mutex> _l(mNodeMutex);
         if (mTerminated) return DEAD_OBJECT; // avoid fatal only, otherwise races
@@ -476,10 +485,10 @@
             .command = RPC_COMMAND_DEC_STRONG,
             .bodySize = sizeof(RpcWireAddress),
     };
-    if (status_t status = rpcSend(fd, "dec ref header", &cmd, sizeof(cmd)); status != OK)
+    if (status_t status = rpcSend(fd, session, "dec ref header", &cmd, sizeof(cmd)); status != OK)
         return status;
-    if (status_t status =
-                rpcSend(fd, "dec ref body", &addr.viewRawEmbedded(), sizeof(RpcWireAddress));
+    if (status_t status = rpcSend(fd, session, "dec ref body", &addr.viewRawEmbedded(),
+                                  sizeof(RpcWireAddress));
         status != OK)
         return status;
     return OK;
@@ -538,7 +547,7 @@
     // also can't consider it a fatal error because this would allow any client
     // to kill us, so ending the session for misbehaving client.
     ALOGE("Unknown RPC command %d - terminating session", command.command);
-    terminate();
+    (void)session->shutdownAndWait(false);
     return DEAD_OBJECT;
 }
 status_t RpcState::processTransact(const base::unique_fd& fd, const sp<RpcSession>& session,
@@ -571,7 +580,7 @@
     if (transactionData.size() < sizeof(RpcWireTransaction)) {
         ALOGE("Expecting %zu but got %zu bytes for RpcWireTransaction. Terminating!",
               sizeof(RpcWireTransaction), transactionData.size());
-        terminate();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
     RpcWireTransaction* transaction = reinterpret_cast<RpcWireTransaction*>(transactionData.data());
@@ -600,12 +609,12 @@
             // session.
             ALOGE("While transacting, binder has been deleted at address %s. Terminating!",
                   addr.toString().c_str());
-            terminate();
+            (void)session->shutdownAndWait(false);
             replyStatus = BAD_VALUE;
         } else if (target->localBinder() == nullptr) {
             ALOGE("Unknown binder address or non-local binder, not address %s. Terminating!",
                   addr.toString().c_str());
-            terminate();
+            (void)session->shutdownAndWait(false);
             replyStatus = BAD_VALUE;
         } else if (transaction->flags & IBinder::FLAG_ONEWAY) {
             std::lock_guard<std::mutex> _l(mNodeMutex);
@@ -707,7 +716,11 @@
             // last refcount dropped after this transaction happened
             if (it == mNodeForAddress.end()) return OK;
 
-            if (!nodeProgressAsyncNumber(&it->second, _l)) return DEAD_OBJECT;
+            if (!nodeProgressAsyncNumber(&it->second)) {
+                _l.unlock();
+                (void)session->shutdownAndWait(false);
+                return DEAD_OBJECT;
+            }
 
             if (it->second.asyncTodo.size() == 0) return OK;
             if (it->second.asyncTodo.top().asyncNumber == it->second.asyncNumber) {
@@ -753,7 +766,7 @@
     memcpy(replyData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireReply), reply.data(),
            reply.dataSize());
 
-    return rpcSend(fd, "reply", replyData.data(), replyData.size());
+    return rpcSend(fd, session, "reply", replyData.data(), replyData.size());
 }
 
 status_t RpcState::processDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
@@ -772,7 +785,7 @@
     if (command.bodySize < sizeof(RpcWireAddress)) {
         ALOGE("Expecting %zu but got %" PRId32 " bytes for RpcWireAddress. Terminating!",
               sizeof(RpcWireAddress), command.bodySize);
-        terminate();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
     RpcWireAddress* address = reinterpret_cast<RpcWireAddress*>(commandData.data());
@@ -790,7 +803,8 @@
     if (target == nullptr) {
         ALOGE("While requesting dec strong, binder has been deleted at address %s. Terminating!",
               addr.toString().c_str());
-        terminate();
+        _l.unlock();
+        (void)session->shutdownAndWait(false);
         return BAD_VALUE;
     }
 
@@ -826,12 +840,11 @@
     return ref;
 }
 
-bool RpcState::nodeProgressAsyncNumber(BinderNode* node, std::unique_lock<std::mutex>& lock) {
+bool RpcState::nodeProgressAsyncNumber(BinderNode* node) {
     // 2**64 =~ 10**19 =~ 1000 transactions per second for 585 million years to
     // a single binder
     if (node->asyncNumber >= std::numeric_limits<decltype(node->asyncNumber)>::max()) {
         ALOGE("Out of async transaction IDs. Terminating");
-        terminate(lock);
         return false;
     }
     node->asyncNumber++;
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 81ff458..13c3115 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -65,7 +65,8 @@
                                            uint32_t code, const Parcel& data,
                                            const sp<RpcSession>& session, Parcel* reply,
                                            uint32_t flags);
-    [[nodiscard]] status_t sendDecStrong(const base::unique_fd& fd, const RpcAddress& address);
+    [[nodiscard]] status_t sendDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
+                                         const RpcAddress& address);
 
     enum class CommandType {
         ANY,
@@ -110,11 +111,10 @@
      * WARNING: RpcState is responsible for calling this when the session is
      * no longer recoverable.
      */
-    void terminate();
+    void clear();
 
 private:
     void dumpLocked();
-    void terminate(std::unique_lock<std::mutex>& lock);
 
     // Alternative to std::vector<uint8_t> that doesn't abort on allocation failure and caps
     // large allocations to avoid being requested from allocating too much data.
@@ -130,8 +130,8 @@
         size_t mSize;
     };
 
-    [[nodiscard]] status_t rpcSend(const base::unique_fd& fd, const char* what, const void* data,
-                                   size_t size);
+    [[nodiscard]] status_t rpcSend(const base::unique_fd& fd, const sp<RpcSession>& session,
+                                   const char* what, const void* data, size_t size);
     [[nodiscard]] status_t rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session,
                                   const char* what, void* data, size_t size);
 
@@ -204,9 +204,8 @@
     // dropped after any locks are removed.
     sp<IBinder> tryEraseNode(std::map<RpcAddress, BinderNode>::iterator& it);
     // true - success
-    // false - state terminated, lock gone, halt
-    [[nodiscard]] bool nodeProgressAsyncNumber(BinderNode* node,
-                                               std::unique_lock<std::mutex>& lock);
+    // false - session shutdown, halt
+    [[nodiscard]] bool nodeProgressAsyncNumber(BinderNode* node);
 
     std::mutex mNodeMutex;
     bool mTerminated = false;
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index 2f2c77c..7aa6d02 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -97,14 +97,20 @@
     status_t getRemoteMaxThreads(size_t* maxThreads);
 
     /**
-     * Shuts down the service. Only works for client sessions (server-side
-     * sessions currently only support shutting down the entire server).
+     * Shuts down the service.
+     *
+     * For client sessions, wait can be true or false. For server sessions,
+     * waiting is not currently supported (will abort).
      *
      * Warning: this is currently not active/nice (the server isn't told we're
      * shutting down). Being nicer to the server could potentially make it
      * reclaim resources faster.
+     *
+     * If this is called with 'wait' true, it will block until shutdown has
+     * completed before returning. This will hang if it is called from the
+     * session threadpool (while processing received calls).
      */
-    [[nodiscard]] bool shutdown();
+    [[nodiscard]] bool shutdownAndWait(bool wait);
 
     [[nodiscard]] status_t transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                                     Parcel* reply, uint32_t flags);
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 0a970fb..c846660 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -966,7 +966,7 @@
 
             // since this session has a reverse connection w/ a threadpool, we
             // need to manually shut it down
-            EXPECT_TRUE(proc.proc.sessions.at(0).session->shutdown());
+            EXPECT_TRUE(proc.proc.sessions.at(0).session->shutdownAndWait(true));
 
             proc.expectAlreadyShutdown = true;
         }