MTP: Fix race conditions in MtpServer JNI code

Make sure the previous MtpThread has exited before starting another one,
to avoid EBUSY when opening the MTP kernel driver.

BUG: 3317803

Change-Id: I81dcbac42bcf5f680ed1b1469839bc0b0e49d53d
Signed-off-by: Mike Lockwood <lockwood@android.com>
diff --git a/media/jni/android_mtp_MtpServer.cpp b/media/jni/android_mtp_MtpServer.cpp
index 1452d21..3883fb2 100644
--- a/media/jni/android_mtp_MtpServer.cpp
+++ b/media/jni/android_mtp_MtpServer.cpp
@@ -40,9 +40,6 @@
 
 // ----------------------------------------------------------------------------
 
-static jfieldID field_context;
-static Mutex    sMutex;
-
 // in android_mtp_MtpDatabase.cpp
 extern MtpDatabase* getMtpDatabase(JNIEnv *env, jobject database);
 
@@ -61,96 +58,78 @@
     MtpServer*      mServer;
     String8         mStoragePath;
     uint64_t        mReserveSpace;
-    jobject         mJavaServer;
-    bool            mDone;
+    // guards mServer, mFd and mUsePtp
+    Mutex           mMutex;
+    bool            mUsePtp;
     int             mFd;
 
 public:
-    MtpThread(MtpDatabase* database, const char* storagePath, uint64_t reserveSpace,
-                jobject javaServer)
+    MtpThread(MtpDatabase* database, const char* storagePath, uint64_t reserveSpace)
         :   mDatabase(database),
             mServer(NULL),
             mStoragePath(storagePath),
             mReserveSpace(reserveSpace),
-            mJavaServer(javaServer),
-            mDone(false),
+            mUsePtp(false),
             mFd(-1)
     {
     }
 
     void setPtpMode(bool usePtp) {
-        sMutex.lock();
-        if (mFd >= 0) {
-            ioctl(mFd, MTP_SET_INTERFACE_MODE,
-                    (usePtp ? MTP_INTERFACE_MODE_PTP : MTP_INTERFACE_MODE_MTP));
-        } else {
-            int fd = open("/dev/mtp_usb", O_RDWR);
-            if (fd >= 0) {
-                ioctl(fd, MTP_SET_INTERFACE_MODE,
-                        (usePtp ? MTP_INTERFACE_MODE_PTP : MTP_INTERFACE_MODE_MTP));
-                close(fd);
-            }
-        }
-        sMutex.unlock();
+        // takes effect the next time threadLoop() opens /dev/mtp_usb
+        mMutex.lock();
+        mUsePtp = usePtp;
+        mMutex.unlock();
     }
 
     virtual bool threadLoop() {
-        sMutex.lock();
-
-        while (!mDone) {
-            mFd = open("/dev/mtp_usb", O_RDWR);
-            printf("open returned %d\n", mFd);
-            if (mFd < 0) {
-                LOGE("could not open MTP driver\n");
-                sMutex.unlock();
-                return false;
-            }
+        mMutex.lock();
+        mFd = open("/dev/mtp_usb", O_RDWR);
+        if (mFd >= 0) {
+            ioctl(mFd, MTP_SET_INTERFACE_MODE,
+                    (mUsePtp ? MTP_INTERFACE_MODE_PTP : MTP_INTERFACE_MODE_MTP));
 
             mServer = new MtpServer(mFd, mDatabase, AID_MEDIA_RW, 0664, 0775);
             mServer->addStorage(mStoragePath, mReserveSpace);
 
-            sMutex.unlock();
-
+            mMutex.unlock();
             mServer->run();
-            sleep(1);
-
-            sMutex.lock();
+            mMutex.lock();
 
             close(mFd);
             mFd = -1;
             delete mServer;
             mServer = NULL;
+        } else {
+            LOGE("could not open MTP driver, errno: %d", errno);
+        }
+        mMutex.unlock();
+        // delay a bit before retrying to avoid excessive spin
+        if (!exitPending()) {
+            sleep(1);
         }
 
-        JNIEnv* env = AndroidRuntime::getJNIEnv();
-        env->SetIntField(mJavaServer, field_context, 0);
-        env->DeleteGlobalRef(mJavaServer);
-        sMutex.unlock();
-
-        return false;
-    }
-
-    void stop() {
-        sMutex.lock();
-        mDone = true;
-        sMutex.unlock();
+        // run again; requestExitAndWait() in android_mtp_MtpServer_stop() ends the loop
+        return true;
     }
 
     void sendObjectAdded(MtpObjectHandle handle) {
-        sMutex.lock();
+        mMutex.lock();
         if (mServer)
             mServer->sendObjectAdded(handle);
-        sMutex.unlock();
+        mMutex.unlock();
     }
 
     void sendObjectRemoved(MtpObjectHandle handle) {
-        sMutex.lock();
+        mMutex.lock();
         if (mServer)
             mServer->sendObjectRemoved(handle);
-        sMutex.unlock();
+        mMutex.unlock();
     }
 };
 
+// This smart pointer is necessary for preventing MtpThread from exiting too early
+static sp<MtpThread> sThread;
+
 #endif // HAVE_ANDROID_OS
 
 static void
@@ -161,9 +136,8 @@
     MtpDatabase* database = getMtpDatabase(env, javaDatabase);
     const char *storagePathStr = env->GetStringUTFChars(storagePath, NULL);
 
-    MtpThread* thread = new MtpThread(database, storagePathStr,
-            reserveSpace, env->NewGlobalRef(thiz));
-    env->SetIntField(thiz, field_context, (int)thread);
+    // create the thread and assign it to the smart pointer
+    sThread = new MtpThread(database, storagePathStr, reserveSpace);
 
     env->ReleaseStringUTFChars(storagePath, storagePathStr);
 #endif
@@ -173,8 +147,9 @@
 android_mtp_MtpServer_start(JNIEnv *env, jobject thiz)
 {
 #ifdef HAVE_ANDROID_OS
-    MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
-    thread->run("MtpThread");
+    MtpThread *thread = sThread.get();
+    if (thread)
+        thread->run("MtpThread");
 #endif // HAVE_ANDROID_OS
 }
 
@@ -182,9 +157,11 @@
 android_mtp_MtpServer_stop(JNIEnv *env, jobject thiz)
 {
 #ifdef HAVE_ANDROID_OS
-    MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
-    if (thread)
-        thread->stop();
+    MtpThread *thread = sThread.get();
+    if (thread) {
+        thread->requestExitAndWait();
+        sThread = NULL;
+    }
 #endif
 }
 
@@ -192,7 +169,7 @@
 android_mtp_MtpServer_send_object_added(JNIEnv *env, jobject thiz, jint handle)
 {
 #ifdef HAVE_ANDROID_OS
-    MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
+    MtpThread *thread = sThread.get();
     if (thread)
         thread->sendObjectAdded(handle);
 #endif
@@ -202,7 +179,7 @@
 android_mtp_MtpServer_send_object_removed(JNIEnv *env, jobject thiz, jint handle)
 {
 #ifdef HAVE_ANDROID_OS
-    MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
+    MtpThread *thread = sThread.get();
     if (thread)
         thread->sendObjectRemoved(handle);
 #endif
@@ -212,7 +189,7 @@
 android_mtp_MtpServer_set_ptp_mode(JNIEnv *env, jobject thiz, jboolean usePtp)
 {
 #ifdef HAVE_ANDROID_OS
-    MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
+    MtpThread *thread = sThread.get();
     if (thread)
         thread->setPtpMode(usePtp);
 #endif
@@ -241,11 +218,6 @@
         LOGE("Can't find android/mtp/MtpServer");
         return -1;
     }
-    field_context = env->GetFieldID(clazz, "mNativeContext", "I");
-    if (field_context == NULL) {
-        LOGE("Can't find MtpServer.mNativeContext");
-        return -1;
-    }
 
     return AndroidRuntime::registerNativeMethods(env,
                 "android/mtp/MtpServer", gMethods, NELEM(gMethods));