Merge "Fix possible race conditions during channel unregistration." into gingerbread
diff --git a/core/jni/android_view_InputQueue.cpp b/core/jni/android_view_InputQueue.cpp
index 556d367..42f35d1 100644
--- a/core/jni/android_view_InputQueue.cpp
+++ b/core/jni/android_view_InputQueue.cpp
@@ -76,10 +76,14 @@
             STATUS_ZOMBIE
         };
 
-        Connection(const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop);
+        Connection(uint16_t id,
+                const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop);
 
         inline const char* getInputChannelName() const { return inputChannel->getName().string(); }
 
+        // A unique id for this connection.
+        uint16_t id;
+
         Status status;
 
         sp<InputChannel> inputChannel;
@@ -91,29 +95,34 @@
         // The sequence number of the current event being dispatched.
         // This is used as part of the finished token as a way to determine whether the finished
         // token is still valid before sending a finished signal back to the publisher.
-        uint32_t messageSeqNum;
+        uint16_t messageSeqNum;
 
         // True if a message has been received from the publisher but not yet finished.
         bool messageInProgress;
     };
 
     Mutex mLock;
+    uint16_t mNextConnectionId;
     KeyedVector<int32_t, sp<Connection> > mConnectionsByReceiveFd;
 
+    ssize_t getConnectionIndex(const sp<InputChannel>& inputChannel);
+
     static void handleInputChannelDisposed(JNIEnv* env,
             jobject inputChannelObj, const sp<InputChannel>& inputChannel, void* data);
 
     static bool handleReceiveCallback(int receiveFd, int events, void* data);
 
-    static jlong generateFinishedToken(int32_t receiveFd, int32_t messageSeqNum);
+    static jlong generateFinishedToken(int32_t receiveFd,
+            uint16_t connectionId, uint16_t messageSeqNum);
 
     static void parseFinishedToken(jlong finishedToken,
-            int32_t* outReceiveFd, uint32_t* outMessageIndex);
+            int32_t* outReceiveFd, uint16_t* outConnectionId, uint16_t* outMessageIndex);
 };
 
 // ----------------------------------------------------------------------------
 
-NativeInputQueue::NativeInputQueue() {
+NativeInputQueue::NativeInputQueue() :
+        mNextConnectionId(0) {
 }
 
 NativeInputQueue::~NativeInputQueue() {
@@ -134,18 +143,17 @@
 
     sp<PollLoop> pollLoop = android_os_MessageQueue_getPollLoop(env, messageQueueObj);
 
-    int receiveFd;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        if (mConnectionsByReceiveFd.indexOfKey(receiveFd) >= 0) {
+        if (getConnectionIndex(inputChannel) >= 0) {
             LOGW("Attempted to register already registered input channel '%s'",
                     inputChannel->getName().string());
             return BAD_VALUE;
         }
 
-        sp<Connection> connection = new Connection(inputChannel, pollLoop);
+        uint16_t connectionId = mNextConnectionId++;
+        sp<Connection> connection = new Connection(connectionId, inputChannel, pollLoop);
         status_t result = connection->inputConsumer.initialize();
         if (result) {
             LOGW("Failed to initialize input consumer for input channel '%s', status=%d",
@@ -155,13 +163,14 @@
 
         connection->inputHandlerObjGlobal = env->NewGlobalRef(inputHandlerObj);
 
+        int32_t receiveFd = inputChannel->getReceivePipeFd();
         mConnectionsByReceiveFd.add(receiveFd, connection);
+
+        pollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
     } // release lock
 
     android_view_InputChannel_setDisposeCallback(env, inputChannelObj,
             handleInputChannelDisposed, this);
-
-    pollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
     return OK;
 }
 
@@ -177,38 +186,56 @@
     LOGD("channel '%s' - Unregistered", inputChannel->getName().string());
 #endif
 
-    int32_t receiveFd;
-    sp<Connection> connection;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
+        ssize_t connectionIndex = getConnectionIndex(inputChannel);
         if (connectionIndex < 0) {
             LOGW("Attempted to unregister already unregistered input channel '%s'",
                     inputChannel->getName().string());
             return BAD_VALUE;
         }
 
-        connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
+        sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
         mConnectionsByReceiveFd.removeItemsAt(connectionIndex);
 
         connection->status = Connection::STATUS_ZOMBIE;
 
+        connection->pollLoop->removeCallback(inputChannel->getReceivePipeFd());
+
         env->DeleteGlobalRef(connection->inputHandlerObjGlobal);
         connection->inputHandlerObjGlobal = NULL;
+
+        if (connection->messageInProgress) {
+            LOGI("Sending finished signal for input channel '%s' since it is being unregistered "
+                    "while an input message is still in progress.",
+                    connection->getInputChannelName());
+            connection->messageInProgress = false;
+            connection->inputConsumer.sendFinishedSignal(); // ignoring result
+        }
     } // release lock
 
     android_view_InputChannel_setDisposeCallback(env, inputChannelObj, NULL, NULL);
-
-    connection->pollLoop->removeCallback(receiveFd);
     return OK;
 }
 
+// Looks up the connection registered for exactly this channel, or returns -1.
+ssize_t NativeInputQueue::getConnectionIndex(const sp<InputChannel>& inputChannel) {
+    const ssize_t index = mConnectionsByReceiveFd.indexOfKey(inputChannel->getReceivePipeFd());
+    if (index < 0) {
+        return -1;
+    }
+
+    // The fd key alone may not identify the channel, so require the same channel object.
+    const sp<Connection>& candidate = mConnectionsByReceiveFd.valueAt(index);
+    return candidate->inputChannel.get() == inputChannel.get() ? index : -1;
+}
+
 status_t NativeInputQueue::finished(JNIEnv* env, jlong finishedToken, bool ignoreSpuriousFinish) {
     int32_t receiveFd;
-    uint32_t messageSeqNum;
-    parseFinishedToken(finishedToken, &receiveFd, &messageSeqNum);
+    uint16_t connectionId;
+    uint16_t messageSeqNum;
+    parseFinishedToken(finishedToken, &receiveFd, &connectionId, &messageSeqNum);
 
     { // acquire lock
         AutoMutex _l(mLock);
@@ -216,16 +243,25 @@
         ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
         if (connectionIndex < 0) {
             if (! ignoreSpuriousFinish) {
-                LOGW("Attempted to finish input on channel that is no longer registered.");
+                LOGI("Ignoring finish signal on channel that is no longer registered.");
             }
             return DEAD_OBJECT;
         }
 
         sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
+        if (connectionId != connection->id) {
+            if (! ignoreSpuriousFinish) {
+                LOGI("Ignoring finish signal on channel that is no longer registered.");
+            }
+            return DEAD_OBJECT;
+        }
+
         if (messageSeqNum != connection->messageSeqNum || ! connection->messageInProgress) {
             if (! ignoreSpuriousFinish) {
-                LOGW("Attempted to finish input twice on channel '%s'.",
-                        connection->getInputChannelName());
+                LOGW("Attempted to finish input twice on channel '%s'.  "
+                        "finished messageSeqNum=%d, current messageSeqNum=%d, messageInProgress=%d",
+                        connection->getInputChannelName(),
+                        messageSeqNum, connection->messageSeqNum, connection->messageInProgress);
             }
             return INVALID_OPERATION;
         }
@@ -312,7 +348,7 @@
         connection->messageInProgress = true;
         connection->messageSeqNum += 1;
 
-        finishedToken = generateFinishedToken(receiveFd, connection->messageSeqNum);
+        finishedToken = generateFinishedToken(receiveFd, connection->id, connection->messageSeqNum);
 
         inputHandlerObjLocal = env->NewLocalRef(connection->inputHandlerObjGlobal);
     } // release lock
@@ -384,20 +420,23 @@
     return true;
 }
 
-jlong NativeInputQueue::generateFinishedToken(int32_t receiveFd, int32_t messageSeqNum) {
-    return (jlong(receiveFd) << 32) | jlong(messageSeqNum);
+jlong NativeInputQueue::generateFinishedToken(int32_t receiveFd, uint16_t connectionId,
+        uint16_t messageSeqNum) {  // Layout: fd in bits 32-63, id in 16-31, seq num in 0-15.
+    return (jlong(receiveFd) << 32) | (jlong(connectionId) << 16) | jlong(messageSeqNum);
 }
 
 void NativeInputQueue::parseFinishedToken(jlong finishedToken,
-        int32_t* outReceiveFd, uint32_t* outMessageIndex) {
+        int32_t* outReceiveFd, uint16_t* outConnectionId, uint16_t* outMessageIndex) {
     *outReceiveFd = int32_t(finishedToken >> 32);
-    *outMessageIndex = uint32_t(finishedToken & 0xffffffff);
+    *outConnectionId = uint16_t(finishedToken >> 16);  // truncation keeps bits 16-31
+    *outMessageIndex = uint16_t(finishedToken);  // bits 0-15: message sequence number
 }
 
 // ----------------------------------------------------------------------------
 
-NativeInputQueue::Connection::Connection(const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop) :
-    status(STATUS_NORMAL), inputChannel(inputChannel), inputConsumer(inputChannel),
+NativeInputQueue::Connection::Connection(uint16_t id,  // id comes from mNextConnectionId++
+        const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop) :  // params shadow members
+    id(id), status(STATUS_NORMAL), inputChannel(inputChannel), inputConsumer(inputChannel),
     pollLoop(pollLoop), inputHandlerObjGlobal(NULL),
     messageSeqNum(0), messageInProgress(false) {
 }
diff --git a/include/ui/InputDispatcher.h b/include/ui/InputDispatcher.h
index d3495fe..2505cb0 100644
--- a/include/ui/InputDispatcher.h
+++ b/include/ui/InputDispatcher.h
@@ -554,6 +554,8 @@
     // All registered connections mapped by receive pipe file descriptor.
     KeyedVector<int, sp<Connection> > mConnectionsByReceiveFd;
 
+    ssize_t getConnectionIndex(const sp<InputChannel>& inputChannel);
+
     // Active connections are connections that have a non-empty outbound queue.
     // We don't use a ref-counted pointer here because we explicitly abort connections
     // during unregistration which causes the connection's outbound queue to be cleared
diff --git a/libs/ui/InputDispatcher.cpp b/libs/ui/InputDispatcher.cpp
index b53f140..13030b5 100644
--- a/libs/ui/InputDispatcher.cpp
+++ b/libs/ui/InputDispatcher.cpp
@@ -433,8 +433,7 @@
     for (size_t i = 0; i < mCurrentInputTargets.size(); i++) {
         const InputTarget& inputTarget = mCurrentInputTargets.itemAt(i);
 
-        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(
-                inputTarget.inputChannel->getReceivePipeFd());
+        ssize_t connectionIndex = getConnectionIndex(inputTarget.inputChannel);
         if (connectionIndex >= 0) {
             sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
             prepareDispatchCycleLocked(currentTime, connection, eventEntry, & inputTarget,
@@ -1367,12 +1366,10 @@
     LOGD("channel '%s' ~ registerInputChannel", inputChannel->getName().string());
 #endif
 
-    int receiveFd;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        if (mConnectionsByReceiveFd.indexOfKey(receiveFd) >= 0) {
+        if (getConnectionIndex(inputChannel) >= 0) {
             LOGW("Attempted to register already registered input channel '%s'",
                     inputChannel->getName().string());
             return BAD_VALUE;
@@ -1386,12 +1383,13 @@
             return status;
         }
 
+        int32_t receiveFd = inputChannel->getReceivePipeFd();
         mConnectionsByReceiveFd.add(receiveFd, connection);
 
+        mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
+
         runCommandsLockedInterruptible();
     } // release lock
-
-    mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
     return OK;
 }
 
@@ -1400,12 +1398,10 @@
     LOGD("channel '%s' ~ unregisterInputChannel", inputChannel->getName().string());
 #endif
 
-    int32_t receiveFd;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
+        ssize_t connectionIndex = getConnectionIndex(inputChannel);
         if (connectionIndex < 0) {
             LOGW("Attempted to unregister already unregistered input channel '%s'",
                     inputChannel->getName().string());
@@ -1417,20 +1413,32 @@
 
         connection->status = Connection::STATUS_ZOMBIE;
 
+        mPollLoop->removeCallback(inputChannel->getReceivePipeFd());
+
         nsecs_t currentTime = now();
         abortDispatchCycleLocked(currentTime, connection, true /*broken*/);
 
         runCommandsLockedInterruptible();
     } // release lock
 
-    mPollLoop->removeCallback(receiveFd);
-
     // Wake the poll loop because removing the connection may have changed the current
     // synchronization state.
     mPollLoop->wake();
     return OK;
 }
 
+// Looks up the connection registered for exactly this channel, or returns -1.
+ssize_t InputDispatcher::getConnectionIndex(const sp<InputChannel>& inputChannel) {
+    const ssize_t index = mConnectionsByReceiveFd.indexOfKey(inputChannel->getReceivePipeFd());
+    if (index < 0) {
+        return -1;
+    }
+
+    // The fd key alone may not identify the channel, so require the same channel object.
+    const sp<Connection>& candidate = mConnectionsByReceiveFd.valueAt(index);
+    return candidate->inputChannel.get() == inputChannel.get() ? index : -1;
+}
+
 void InputDispatcher::activateConnectionLocked(Connection* connection) {
     for (size_t i = 0; i < mActiveConnections.size(); i++) {
         if (mActiveConnections.itemAt(i) == connection) {
diff --git a/services/jni/com_android_server_InputManager.cpp b/services/jni/com_android_server_InputManager.cpp
index ebe71ab..ba58b43 100644
--- a/services/jni/com_android_server_InputManager.cpp
+++ b/services/jni/com_android_server_InputManager.cpp
@@ -319,9 +319,9 @@
     bool isScreenOn();
     bool isScreenBright();
 
-    // Weak references to all currently registered input channels by receive fd.
+    // Weak references to all currently registered input channels by connection pointer.
     Mutex mInputChannelRegistryLock;
-    KeyedVector<int, jweak> mInputChannelObjWeakByReceiveFd;
+    KeyedVector<InputChannel*, jweak> mInputChannelObjWeakTable;
 
     jobject getInputChannelObjLocal(JNIEnv* env, const sp<InputChannel>& inputChannel);
 
@@ -509,8 +509,7 @@
     {
         AutoMutex _l(mInputChannelRegistryLock);
 
-        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
-                inputChannel->getReceivePipeFd());
+        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
         if (index >= 0) {
             LOGE("Input channel object '%s' has already been registered",
                     inputChannel->getName().string());
@@ -518,8 +517,7 @@
             goto DeleteWeakRef;
         }
 
-        mInputChannelObjWeakByReceiveFd.add(inputChannel->getReceivePipeFd(),
-                inputChannelObjWeak);
+        mInputChannelObjWeakTable.add(inputChannel.get(), inputChannelObjWeak);
     }
 
     status = mInputManager->registerInputChannel(inputChannel);
@@ -534,7 +532,7 @@
     // Failed!
     {
         AutoMutex _l(mInputChannelRegistryLock);
-        mInputChannelObjWeakByReceiveFd.removeItem(inputChannel->getReceivePipeFd());
+        mInputChannelObjWeakTable.removeItem(inputChannel.get());
     }
 
 DeleteWeakRef:
@@ -548,16 +546,15 @@
     {
         AutoMutex _l(mInputChannelRegistryLock);
 
-        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
-                inputChannel->getReceivePipeFd());
+        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
         if (index < 0) {
             LOGE("Input channel object '%s' is not currently registered",
                     inputChannel->getName().string());
             return INVALID_OPERATION;
         }
 
-        inputChannelObjWeak = mInputChannelObjWeakByReceiveFd.valueAt(index);
-        mInputChannelObjWeakByReceiveFd.removeItemsAt(index);
+        inputChannelObjWeak = mInputChannelObjWeakTable.valueAt(index);
+        mInputChannelObjWeakTable.removeItemsAt(index);
     }
 
     env->DeleteWeakGlobalRef(inputChannelObjWeak);
@@ -572,13 +569,12 @@
     {
         AutoMutex _l(mInputChannelRegistryLock);
 
-        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
-                inputChannel->getReceivePipeFd());
+        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
         if (index < 0) {
             return NULL;
         }
 
-        jweak inputChannelObjWeak = mInputChannelObjWeakByReceiveFd.valueAt(index);
+        jweak inputChannelObjWeak = mInputChannelObjWeakTable.valueAt(index);
         return env->NewLocalRef(inputChannelObjWeak);
     }
 }