MTP: Fix some thread safety issues in MTP server start/stop sequence.

Change-Id: Ied1cddc6220fa7394e8de99df9bc37a1208b04ff
Signed-off-by: Mike Lockwood <lockwood@android.com>
diff --git a/media/jni/android_media_MtpServer.cpp b/media/jni/android_media_MtpServer.cpp
index eddad57..1ef2c58 100644
--- a/media/jni/android_media_MtpServer.cpp
+++ b/media/jni/android_media_MtpServer.cpp
@@ -36,6 +36,7 @@
 // ----------------------------------------------------------------------------
 
 static jfieldID field_context;
+static Mutex    sMutex;
 
 // in android_media_MtpDatabase.cpp
 extern MtpDatabase* getMtpDatabase(JNIEnv *env, jobject database);
@@ -53,63 +54,80 @@
 private:
     MtpDatabase*    mDatabase;
     MtpServer*      mServer;
-    String8 mStoragePath;
-    bool mDone;
-    Mutex           mMutex;
+    String8         mStoragePath;
+    bool            mDone;
+    jobject         mJavaServer;
 
 public:
-    MtpThread(MtpDatabase* database, const char* storagePath)
-        : mDatabase(database), mServer(NULL), mStoragePath(storagePath), mDone(false)
+    MtpThread(MtpDatabase* database, const char* storagePath, jobject javaServer)
+        : mDatabase(database),
+            mServer(NULL),
+            mStoragePath(storagePath),
+            mDone(false),
+            mJavaServer(javaServer)
     {
     }
 
     virtual bool threadLoop() {
-        int fd = open("/dev/mtp_usb", O_RDWR);
-        printf("open returned %d\n", fd);
-        if (fd < 0) {
-            LOGE("could not open MTP driver\n");
-            return false;
+        while (1) {
+            int fd = open("/dev/mtp_usb", O_RDWR);
+            printf("open returned %d\n", fd);
+            if (fd < 0) {
+                LOGE("could not open MTP driver\n");
+                break;
+            }
+
+            sMutex.lock();
+            mServer = new MtpServer(fd, mDatabase, AID_SDCARD_RW, 0664, 0775);
+            mServer->addStorage(mStoragePath);
+            sMutex.unlock();
+
+            LOGD("MtpThread mServer->run");
+            mServer->run();
+            close(fd);
+
+            sMutex.lock();
+            delete mServer;
+            mServer = NULL;
+            if (mDone)
+                goto done;
+            sMutex.unlock();
+            // wait a bit before retrying
+            sleep(1);
         }
 
-        mMutex.lock();
-        mServer = new MtpServer(fd, mDatabase, AID_SDCARD_RW, 0664, 0775);
-        mServer->addStorage(mStoragePath);
-        mMutex.unlock();
+        sMutex.lock();
+done:
+        JNIEnv* env = AndroidRuntime::getJNIEnv();
+        env->SetIntField(mJavaServer, field_context, 0);
+        env->DeleteGlobalRef(mJavaServer);
+        sMutex.unlock();
 
-        LOGD("MtpThread mServer->run");
-        mServer->run();
-        close(fd);
-
-        mMutex.lock();
-        delete mServer;
-        mServer = NULL;
-        mMutex.unlock();
-
-        bool done = mDone;
-        if (done)
-            delete this;
-        LOGD("threadLoop returning %s", (done ? "false" : "true"));
-        return !done;
+        LOGD("threadLoop returning");
+        return false;
     }
 
-    void setDone() { mDone = true; }
+    void setDone() {
+        LOGD("setDone");
+        mDone = true; 
+    }
 
     void sendObjectAdded(MtpObjectHandle handle) {
-        mMutex.lock();
+        sMutex.lock();
         if (mServer)
             mServer->sendObjectAdded(handle);
         else
             LOGE("sendObjectAdded called while disconnected\n");
-        mMutex.unlock();
+        sMutex.unlock();
     }
 
     void sendObjectRemoved(MtpObjectHandle handle) {
-        mMutex.lock();
+        sMutex.lock();
         if (mServer)
             mServer->sendObjectRemoved(handle);
         else
             LOGE("sendObjectRemoved called while disconnected\n");
-        mMutex.unlock();
+        sMutex.unlock();
     }
 };
 
@@ -124,7 +142,7 @@
     MtpDatabase* database = getMtpDatabase(env, javaDatabase);
     const char *storagePathStr = env->GetStringUTFChars(storagePath, NULL);
 
-    MtpThread* thread = new MtpThread(database, storagePathStr);
+    MtpThread* thread = new MtpThread(database, storagePathStr, env->NewGlobalRef(thiz));
     env->SetIntField(thiz, field_context, (int)thread);
 
     env->ReleaseStringUTFChars(storagePath, storagePathStr);
@@ -153,11 +171,11 @@
 {
 #ifdef HAVE_ANDROID_OS
     LOGD("stop\n");
+    sMutex.lock();
     MtpThread *thread = (MtpThread *)env->GetIntField(thiz, field_context);
-    if (thread) {
+    if (thread)
         thread->setDone();
-        env->SetIntField(thiz, field_context, 0);
-    }
+    sMutex.unlock();
 #endif
 }