Store connections by token instead of by fd

The connections are currently stored by fd. If a connection is removed
via 'removeInputChannel', it is possible to re-create the same
connection and have it keyed by the same fd. When this happens, a race
condition may occur where a socket hangup on this fd would cause the
removal of a newly registered connection.

In this refactor, the connections are no longer stored by fd. The looper
interface for adding fds has two versions:
1) the old one that we are currently using, which is marked as 'do not
use'
2) the new one where a callback object is provided instead.

In this CL, we switch to the new version, which takes a callback object.

There is now also no need to store the input channels in a separate
structure, because we can use the connections collection that's now
keyed by token to find them.

In a future refactor, we should switch to using 'unique_ptr' for the
input channels. Most of the time when we are looking for an input
channel, we are actually interested in finding the corresponding
connection.

If we switch Connection to shared_ptr, we can also look into switching
LooperEventCallback to store a weak pointer to a connection instead of
storing the connection token. This should speed up the handling of
events, by avoiding a map lookup.

Test: ./reinitinput.sh. Observe that it doesn't finish after this patch.
Test: atest inputflinger_tests
Bug: 182478748

Change-Id: I601f765eebfadcaeff3661a10a10c4a4f0477389
diff --git a/services/inputflinger/dispatcher/InputDispatcher.cpp b/services/inputflinger/dispatcher/InputDispatcher.cpp
index 16cb7d7..9a43ed9 100644
--- a/services/inputflinger/dispatcher/InputDispatcher.cpp
+++ b/services/inputflinger/dispatcher/InputDispatcher.cpp
@@ -283,27 +283,6 @@
     return it != map.end() ? it->second : V{};
 }
 
-/**
- * Find the entry in std::unordered_map by value, and remove it.
- * If more than one entry has the same value, then all matching
- * key-value pairs will be removed.
- *
- * Return true if at least one value has been removed.
- */
-template <typename K, typename V>
-static bool removeByValue(std::unordered_map<K, V>& map, const V& value) {
-    bool removed = false;
-    for (auto it = map.begin(); it != map.end();) {
-        if (it->second == value) {
-            it = map.erase(it);
-            removed = true;
-        } else {
-            it++;
-        }
-    }
-    return removed;
-}
-
 static bool haveSameToken(const sp<InputWindowHandle>& first, const sp<InputWindowHandle>& second) {
     if (first == second) {
         return true;
@@ -507,8 +486,8 @@
         drainInboundQueueLocked();
     }
 
-    while (!mConnectionsByFd.empty()) {
-        sp<Connection> connection = mConnectionsByFd.begin()->second;
+    while (!mConnectionsByToken.empty()) {
+        sp<Connection> connection = mConnectionsByToken.begin()->second;
         removeInputChannel(connection->inputChannel->getConnectionToken());
     }
 }
@@ -3297,86 +3276,78 @@
     delete dispatchEntry;
 }
 
-int InputDispatcher::handleReceiveCallback(int fd, int events, void* data) {
-    InputDispatcher* d = static_cast<InputDispatcher*>(data);
+int InputDispatcher::handleReceiveCallback(int events, sp<IBinder> connectionToken) {
+    std::scoped_lock _l(mLock);
+    sp<Connection> connection = getConnectionLocked(connectionToken);
+    if (connection == nullptr) {
+        ALOGW("Received looper callback for unknown input channel token %p.  events=0x%x",
+              connectionToken.get(), events);
+        return 0; // remove the callback
+    }
 
-    { // acquire lock
-        std::scoped_lock _l(d->mLock);
-
-        if (d->mConnectionsByFd.find(fd) == d->mConnectionsByFd.end()) {
-            ALOGE("Received spurious receive callback for unknown input channel.  "
-                  "fd=%d, events=0x%x",
-                  fd, events);
-            return 0; // remove the callback
+    bool notify;
+    if (!(events & (ALOOPER_EVENT_ERROR | ALOOPER_EVENT_HANGUP))) {
+        if (!(events & ALOOPER_EVENT_INPUT)) {
+            ALOGW("channel '%s' ~ Received spurious callback for unhandled poll event.  "
+                  "events=0x%x",
+                  connection->getInputChannelName().c_str(), events);
+            return 1;
         }
 
-        bool notify;
-        sp<Connection> connection = d->mConnectionsByFd[fd];
-        if (!(events & (ALOOPER_EVENT_ERROR | ALOOPER_EVENT_HANGUP))) {
-            if (!(events & ALOOPER_EVENT_INPUT)) {
-                ALOGW("channel '%s' ~ Received spurious callback for unhandled poll event.  "
-                      "events=0x%x",
-                      connection->getInputChannelName().c_str(), events);
+        nsecs_t currentTime = now();
+        bool gotOne = false;
+        status_t status = OK;
+        for (;;) {
+            Result<InputPublisher::ConsumerResponse> result =
+                    connection->inputPublisher.receiveConsumerResponse();
+            if (!result.ok()) {
+                status = result.error().code();
+                break;
+            }
+
+            if (std::holds_alternative<InputPublisher::Finished>(*result)) {
+                const InputPublisher::Finished& finish =
+                        std::get<InputPublisher::Finished>(*result);
+                finishDispatchCycleLocked(currentTime, connection, finish.seq, finish.handled,
+                                          finish.consumeTime);
+            } else if (std::holds_alternative<InputPublisher::Timeline>(*result)) {
+                // TODO(b/167947340): Report this data to LatencyTracker
+            }
+            gotOne = true;
+        }
+        if (gotOne) {
+            runCommandsLockedInterruptible();
+            if (status == WOULD_BLOCK) {
                 return 1;
             }
-
-            nsecs_t currentTime = now();
-            bool gotOne = false;
-            status_t status = OK;
-            for (;;) {
-                Result<InputPublisher::ConsumerResponse> result =
-                        connection->inputPublisher.receiveConsumerResponse();
-                if (!result.ok()) {
-                    status = result.error().code();
-                    break;
-                }
-
-                if (std::holds_alternative<InputPublisher::Finished>(*result)) {
-                    const InputPublisher::Finished& finish =
-                            std::get<InputPublisher::Finished>(*result);
-                    d->finishDispatchCycleLocked(currentTime, connection, finish.seq,
-                                                 finish.handled, finish.consumeTime);
-                } else if (std::holds_alternative<InputPublisher::Timeline>(*result)) {
-                    // TODO(b/167947340): Report this data to LatencyTracker
-                }
-                gotOne = true;
-            }
-            if (gotOne) {
-                d->runCommandsLockedInterruptible();
-                if (status == WOULD_BLOCK) {
-                    return 1;
-                }
-            }
-
-            notify = status != DEAD_OBJECT || !connection->monitor;
-            if (notify) {
-                ALOGE("channel '%s' ~ Failed to receive finished signal.  status=%s(%d)",
-                      connection->getInputChannelName().c_str(), statusToString(status).c_str(),
-                      status);
-            }
-        } else {
-            // Monitor channels are never explicitly unregistered.
-            // We do it automatically when the remote endpoint is closed so don't warn about them.
-            const bool stillHaveWindowHandle =
-                    d->getWindowHandleLocked(connection->inputChannel->getConnectionToken()) !=
-                    nullptr;
-            notify = !connection->monitor && stillHaveWindowHandle;
-            if (notify) {
-                ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred.  "
-                      "events=0x%x",
-                      connection->getInputChannelName().c_str(), events);
-            }
         }
 
-        // Remove the channel.
-        d->removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify);
-        return 0; // remove the callback
-    }             // release lock
+        notify = status != DEAD_OBJECT || !connection->monitor;
+        if (notify) {
+            ALOGE("channel '%s' ~ Failed to receive finished signal.  status=%s(%d)",
+                  connection->getInputChannelName().c_str(), statusToString(status).c_str(),
+                  status);
+        }
+    } else {
+        // Monitor channels are never explicitly unregistered.
+        // We do it automatically when the remote endpoint is closed so don't warn about them.
+        const bool stillHaveWindowHandle =
+                getWindowHandleLocked(connection->inputChannel->getConnectionToken()) != nullptr;
+        notify = !connection->monitor && stillHaveWindowHandle;
+        if (notify) {
+            ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred.  events=0x%x",
+                  connection->getInputChannelName().c_str(), events);
+        }
+    }
+
+    // Remove the channel.
+    removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify);
+    return 0; // remove the callback
 }
 
 void InputDispatcher::synthesizeCancelationEventsForAllConnectionsLocked(
         const CancelationOptions& options) {
-    for (const auto& [fd, connection] : mConnectionsByFd) {
+    for (const auto& [token, connection] : mConnectionsByToken) {
         synthesizeCancelationEventsForConnectionLocked(connection, options);
     }
 }
@@ -4342,11 +4313,11 @@
 
 std::shared_ptr<InputChannel> InputDispatcher::getInputChannelLocked(
         const sp<IBinder>& token) const {
-    size_t count = mInputChannelsByToken.count(token);
-    if (count == 0) {
+    auto connectionIt = mConnectionsByToken.find(token);
+    if (connectionIt == mConnectionsByToken.end()) {
         return nullptr;
     }
-    return mInputChannelsByToken.at(token);
+    return connectionIt->second->inputChannel;
 }
 
 void InputDispatcher::updateWindowHandlesForDisplayLocked(
@@ -4996,13 +4967,13 @@
         dump += INDENT "ReplacedKeys: <empty>\n";
     }
 
-    if (!mConnectionsByFd.empty()) {
+    if (!mConnectionsByToken.empty()) {
         dump += INDENT "Connections:\n";
-        for (const auto& pair : mConnectionsByFd) {
-            const sp<Connection>& connection = pair.second;
+        for (const auto& [token, connection] : mConnectionsByToken) {
             dump += StringPrintf(INDENT2 "%i: channelName='%s', windowName='%s', "
                                          "status=%s, monitor=%s, responsive=%s\n",
-                                 pair.first, connection->getInputChannelName().c_str(),
+                                 connection->inputChannel->getFd().get(),
+                                 connection->getInputChannelName().c_str(),
                                  connection->getWindowName().c_str(), connection->getStatusLabel(),
                                  toString(connection->monitor), toString(connection->responsive));
 
@@ -5050,14 +5021,23 @@
     }
 }
 
+class LooperEventCallback : public LooperCallback {
+public:
+    LooperEventCallback(std::function<int(int events)> callback) : mCallback(callback) {}
+    int handleEvent(int /*fd*/, int events, void* /*data*/) override { return mCallback(events); }
+
+private:
+    std::function<int(int events)> mCallback;
+};
+
 Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputChannel(const std::string& name) {
 #if DEBUG_CHANNEL_CREATION
     ALOGD("channel '%s' ~ createInputChannel", name.c_str());
 #endif
 
-    std::shared_ptr<InputChannel> serverChannel;
+    std::unique_ptr<InputChannel> serverChannel;
     std::unique_ptr<InputChannel> clientChannel;
-    status_t result = openInputChannelPair(name, serverChannel, clientChannel);
+    status_t result = InputChannel::openInputChannelPair(name, serverChannel, clientChannel);
 
     if (result) {
         return base::Error(result) << "Failed to open input channel pair with name " << name;
@@ -5065,13 +5045,20 @@
 
     { // acquire lock
         std::scoped_lock _l(mLock);
-        sp<Connection> connection = new Connection(serverChannel, false /*monitor*/, mIdGenerator);
-
+        const sp<IBinder>& token = serverChannel->getConnectionToken();
         int fd = serverChannel->getFd();
-        mConnectionsByFd[fd] = connection;
-        mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel;
+        sp<Connection> connection =
+                new Connection(std::move(serverChannel), false /*monitor*/, mIdGenerator);
 
-        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this);
+        if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) {
+            ALOGE("Created a new connection, but the token %p is already known", token.get());
+        }
+        mConnectionsByToken.emplace(token, connection);
+
+        std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback,
+                                                            this, std::placeholders::_1, token);
+
+        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr);
     } // release lock
 
     // Wake the looper because some connections have changed.
@@ -5099,18 +5086,21 @@
         }
 
         sp<Connection> connection = new Connection(serverChannel, true /*monitor*/, mIdGenerator);
-
+        const sp<IBinder>& token = serverChannel->getConnectionToken();
         const int fd = serverChannel->getFd();
-        mConnectionsByFd[fd] = connection;
-        mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel;
+
+        if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) {
+            ALOGE("Created a new connection, but the token %p is already known", token.get());
+        }
+        mConnectionsByToken.emplace(token, connection);
+        std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback,
+                                                            this, std::placeholders::_1, token);
 
         auto& monitorsByDisplay =
                 isGestureMonitor ? mGestureMonitorsByDisplay : mGlobalMonitorsByDisplay;
         monitorsByDisplay[displayId].emplace_back(serverChannel, pid);
 
-        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this);
-        ALOGI("Created monitor %s for display %" PRId32 ", gesture=%s, pid=%" PRId32, name.c_str(),
-              displayId, toString(isGestureMonitor), pid);
+        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr);
     }
 
     // Wake the looper because some connections have changed.
@@ -5143,7 +5133,6 @@
     }
 
     removeConnectionLocked(connection);
-    mInputChannelsByToken.erase(connectionToken);
 
     if (connection->monitor) {
         removeMonitorChannelLocked(connectionToken);
@@ -5301,9 +5290,8 @@
         return nullptr;
     }
 
-    for (const auto& pair : mConnectionsByFd) {
-        const sp<Connection>& connection = pair.second;
-        if (connection->inputChannel->getConnectionToken() == inputConnectionToken) {
+    for (const auto& [token, connection] : mConnectionsByToken) {
+        if (token == inputConnectionToken) {
             return connection;
         }
     }
@@ -5321,7 +5309,7 @@
 
 void InputDispatcher::removeConnectionLocked(const sp<Connection>& connection) {
     mAnrTracker.eraseToken(connection->inputChannel->getConnectionToken());
-    removeByValue(mConnectionsByFd, connection);
+    mConnectionsByToken.erase(connection->inputChannel->getConnectionToken());
 }
 
 void InputDispatcher::onDispatchCycleFinishedLocked(nsecs_t currentTime,