USB: gadget: f_mtp: MTP driver cleanup:

Use a work queue instead of a separate thread for file transfer ioctls.
(Note: the file transfer must be done on a kernel thread rather than in
process context so that vfs_read and vfs_write will use the correct address
space for the buffers.)

Enforce the requirement that only one ioctl call may be active at a time,
and remove the mutex in mtp_send_event, which is no longer necessary.

Synchronize access to shared variables to avoid SMP issues.

Fix mismatched calls to fget and fput

Signed-off-by: Mike Lockwood <lockwood@android.com>
diff --git a/drivers/usb/gadget/f_mtp.c b/drivers/usb/gadget/f_mtp.c
index 2c61a26..81075f2 100644
--- a/drivers/usb/gadget/f_mtp.c
+++ b/drivers/usb/gadget/f_mtp.c
@@ -25,8 +25,6 @@
 #include <linux/wait.h>
 #include <linux/err.h>
 #include <linux/interrupt.h>
-#include <linux/kthread.h>
-#include <linux/freezer.h>
 
 #include <linux/types.h>
 #include <linux/file.h>
@@ -56,11 +54,6 @@
 #define TX_REQ_MAX 4
 #define RX_REQ_MAX 2
 
-/* IO Thread commands */
-#define ANDROID_THREAD_QUIT				1
-#define ANDROID_THREAD_SEND_FILE		2
-#define ANDROID_THREAD_RECEIVE_FILE		3
-
 /* ID for Microsoft MTP OS String */
 #define MTP_OS_STRING_ID   0xEE
 
@@ -92,6 +85,8 @@
 
 	/* synchronize access to our device file */
 	atomic_t open_excl;
+	/* to enforce only one ioctl at a time */
+	atomic_t ioctl_excl;
 
 	struct list_head tx_idle;
 
@@ -101,23 +96,19 @@
 	struct usb_request *rx_req[RX_REQ_MAX];
 	struct usb_request *intr_req;
 	int rx_done;
-
-	/* synchronize access to interrupt endpoint */
-	struct mutex intr_mutex;
 	/* true if interrupt endpoint is busy */
 	int intr_busy;
 
-	/* for our file IO thread */
-	struct task_struct			*thread;
-	/* current command for IO thread (or zero for none) */
-	int							thread_command;
-	struct file 				*thread_file;
-	loff_t						thread_file_offset;
-	size_t						thread_file_length;
-	/* used to wait for thread to complete current command */
-	struct completion			thread_wait;
-	/* result from current command */
-	int							thread_result;
+	/* for processing MTP_SEND_FILE and MTP_RECEIVE_FILE
+	 * ioctls on a work queue
+	 */
+	struct workqueue_struct *wq;
+	struct work_struct send_file_work;
+	struct work_struct receive_file_work;
+	struct file *xfer_file;
+	loff_t xfer_file_offset;
+	size_t xfer_file_length;
+	int xfer_result;
 };
 
 static struct usb_interface_descriptor mtp_interface_desc = {
@@ -622,14 +613,23 @@
 	return r;
 }
 
-static int mtp_send_file(struct mtp_dev *dev, struct file *filp,
-	loff_t offset, size_t count)
-{
+/* read from a local file and write to USB */
+static void send_file_work(struct work_struct *data) {
+	struct mtp_dev	*dev = container_of(data, struct mtp_dev, send_file_work);
 	struct usb_composite_dev *cdev = dev->cdev;
 	struct usb_request *req = 0;
-	int r = count, xfer, ret;
+	struct file *filp;
+	loff_t offset;
+	size_t count;
+	int r, xfer, ret;
 
-	DBG(cdev, "mtp_send_file(%lld %d)\n", offset, count);
+	/* read our parameters */
+	smp_rmb();
+	filp = dev->xfer_file;
+	offset = dev->xfer_file_offset;
+	r = count = dev->xfer_file_length;
+
+	DBG(cdev, "send_file_work(%lld %d)\n", offset, count);
 
 	while (count > 0) {
 		/* get an idle tx request to use */
@@ -656,7 +656,7 @@
 		req->length = xfer;
 		ret = usb_ep_queue(dev->ep_in, req, GFP_KERNEL);
 		if (ret < 0) {
-			DBG(cdev, "mtp_write: xfer error %d\n", ret);
+			DBG(cdev, "send_file_work: xfer error %d\n", ret);
 			dev->state = STATE_ERROR;
 			r = -EIO;
 			break;
@@ -671,20 +671,30 @@
 	if (req)
 		req_put(dev, &dev->tx_idle, req);
 
-	DBG(cdev, "mtp_write returning %d\n", r);
-	return r;
+	DBG(cdev, "send_file_work returning %d\n", r);
+	/* write the result */
+	dev->xfer_result = r;
+	smp_wmb();
 }
 
-static int mtp_receive_file(struct mtp_dev *dev, struct file *filp,
-	loff_t offset, size_t count)
+/* read from USB and write to a local file */
+static void receive_file_work(struct work_struct *data)
 {
+	struct mtp_dev	*dev = container_of(data, struct mtp_dev, receive_file_work);
 	struct usb_composite_dev *cdev = dev->cdev;
 	struct usb_request *read_req = NULL, *write_req = NULL;
-	int r = count;
-	int ret;
-	int cur_buf = 0;
+	struct file *filp;
+	loff_t offset;
+	size_t count;
+	int r, ret, cur_buf = 0;
 
-	DBG(cdev, "mtp_receive_file(%d)\n", count);
+	/* read our parameters */
+	smp_rmb();
+	filp = dev->xfer_file;
+	offset = dev->xfer_file_offset;
+	r = count = dev->xfer_file_length;
+
+	DBG(cdev, "receive_file_work(%d)\n", count);
 
 	while (count > 0 || write_req) {
 		if (count > 0) {
@@ -731,64 +741,10 @@
 		}
 	}
 
-	DBG(cdev, "mtp_read returning %d\n", r);
-	return r;
-}
-
-/* Kernel thread for handling file IO operations */
-static int mtp_thread(void *data)
-{
-	struct mtp_dev *dev = (struct mtp_dev *)data;
-	struct usb_composite_dev *cdev = dev->cdev;
-	int flags;
-
-	DBG(cdev, "mtp_thread started\n");
-
-	while (1) {
-		/* wait for a command */
-		while (1) {
-			try_to_freeze();
-			set_current_state(TASK_INTERRUPTIBLE);
-			if (dev->thread_command != 0)
-				break;
-			schedule();
-		}
-		__set_current_state(TASK_RUNNING);
-
-		if (dev->thread_command == ANDROID_THREAD_QUIT) {
-			DBG(cdev, "ANDROID_THREAD_QUIT\n");
-			dev->thread_result = 0;
-			goto done;
-		}
-
-		if (dev->thread_command == ANDROID_THREAD_SEND_FILE)
-			flags = O_RDONLY | O_LARGEFILE;
-		else
-			flags = O_WRONLY | O_LARGEFILE | O_CREAT;
-
-		if (dev->thread_command == ANDROID_THREAD_SEND_FILE) {
-			dev->thread_result = mtp_send_file(dev,
-				dev->thread_file,
-				dev->thread_file_offset,
-				dev->thread_file_length);
-		} else {
-			dev->thread_result = mtp_receive_file(dev,
-				dev->thread_file,
-				dev->thread_file_offset,
-				dev->thread_file_length);
-		}
-
-		if (dev->thread_file) {
-			fput(dev->thread_file);
-			dev->thread_file = NULL;
-		}
-		dev->thread_command = 0;
-		complete(&dev->thread_wait);
-	}
-
-done:
-	DBG(cdev, "android_thread done\n");
-	complete_and_exit(&dev->thread_wait, 0);
+	DBG(cdev, "receive_file_work returning %d\n", r);
+	/* write the result */
+	dev->xfer_result = r;
+	smp_wmb();
 }
 
 static int mtp_send_event(struct mtp_dev *dev, struct mtp_event *event)
@@ -802,29 +758,21 @@
 	if (length < 0 || length > INTR_BUFFER_SIZE)
 		return -EINVAL;
 
-	mutex_lock(&dev->intr_mutex);
-
 	/* wait for a request to complete */
 	ret = wait_event_interruptible(dev->intr_wq, !dev->intr_busy || dev->state == STATE_OFFLINE);
 	if (ret < 0)
-		goto done;
-	if (dev->state == STATE_OFFLINE) {
-		ret = -ENODEV;
-		goto done;
-	}
+		return ret;
+	if (dev->state == STATE_OFFLINE)
+		return -ENODEV;
 	req = dev->intr_req;
-	if (copy_from_user(req->buf, (void __user *)event->data, length)) {
-		ret = -EFAULT;
-		goto done;
-	}
+	if (copy_from_user(req->buf, (void __user *)event->data, length))
+		return -EFAULT;
 	req->length = length;
 	dev->intr_busy = 1;
 	ret = usb_ep_queue(dev->ep_intr, req, GFP_KERNEL);
 	if (ret)
 		dev->intr_busy = 0;
 
-done:
-	mutex_unlock(&dev->intr_mutex);
 	return ret;
 }
 
@@ -834,22 +782,28 @@
 	struct file *filp = NULL;
 	int ret = -EINVAL;
 
+	if (_lock(&dev->ioctl_excl))
+		return -EBUSY;
+
 	switch (code) {
 	case MTP_SEND_FILE:
 	case MTP_RECEIVE_FILE:
 	{
 		struct mtp_file_range	mfr;
+		struct work_struct *work;
 
 		spin_lock_irq(&dev->lock);
 		if (dev->state == STATE_CANCELED) {
 			/* report cancelation to userspace */
 			dev->state = STATE_READY;
 			spin_unlock_irq(&dev->lock);
-			return -ECANCELED;
+			ret = -ECANCELED;
+			goto out;
 		}
 		if (dev->state == STATE_OFFLINE) {
 			spin_unlock_irq(&dev->lock);
-			return -ENODEV;
+			ret = -ENODEV;
+			goto out;
 		}
 		dev->state = STATE_BUSY;
 		spin_unlock_irq(&dev->lock);
@@ -858,29 +812,36 @@
 			ret = -EFAULT;
 			goto fail;
 		}
+		/* hold a reference to the file while we are working with it */
 		filp = fget(mfr.fd);
 		if (!filp) {
 			ret = -EBADF;
 			goto fail;
 		}
 
-		dev->thread_file = filp;
-		dev->thread_file_offset = mfr.offset;
-		dev->thread_file_length = mfr.length;
+		/* write the parameters */
+		dev->xfer_file = filp;
+		dev->xfer_file_offset = mfr.offset;
+		dev->xfer_file_length = mfr.length;
+		smp_wmb();
 
 		if (code == MTP_SEND_FILE)
-			dev->thread_command = ANDROID_THREAD_SEND_FILE;
+			work = &dev->send_file_work;
 		else
-			dev->thread_command = ANDROID_THREAD_RECEIVE_FILE;
+			work = &dev->receive_file_work;
 
-		/* wake up the thread */
-		init_completion(&dev->thread_wait);
-		wake_up_process(dev->thread);
+		/* We do the file transfer on a work queue so it will run
+		 * in kernel context, which is necessary for vfs_read and
+		 * vfs_write to use our buffers in the kernel address space.
+		 */
+		queue_work(dev->wq, work);
+		/* wait for operation to complete */
+		flush_workqueue(dev->wq);
+		fput(filp);
 
-		/* wait for the thread to complete the command */
-		wait_for_completion(&dev->thread_wait);
-		ret = dev->thread_result;
-		DBG(dev->cdev, "thread returned %d\n", ret);
+		/* read the result */
+		smp_rmb();
+		ret = dev->xfer_result;
 		break;
 	}
 	case MTP_SET_INTERFACE_MODE:
@@ -904,21 +865,22 @@
 		 * which would interfere with bulk transfer state.
 		 */
 		if (copy_from_user(&event, (void __user *)value, sizeof(event)))
-			return -EFAULT;
+			ret = -EFAULT;
 		else
-			return mtp_send_event(dev, &event);
+			ret = mtp_send_event(dev, &event);
+		goto out;
 	}
 	}
 
 fail:
-	if (filp)
-		fput(filp);
 	spin_lock_irq(&dev->lock);
 	if (dev->state == STATE_CANCELED)
 		ret = -ECANCELED;
 	else if (dev->state != STATE_OFFLINE)
 		dev->state = STATE_READY;
 	spin_unlock_irq(&dev->lock);
+out:
+	_unlock(&dev->ioctl_excl);
 	DBG(dev->cdev, "ioctl returning %d\n", ret);
 	return ret;
 }
@@ -929,10 +891,6 @@
 	if (_lock(&_mtp_dev->open_excl))
 		return -EBUSY;
 
-	_mtp_dev->thread = kthread_create(mtp_thread, _mtp_dev, "f_mtp");
-	if (IS_ERR(_mtp_dev->thread))
-		return -ENOMEM;
-
 	/* clear any error condition */
 	if (_mtp_dev->state != STATE_OFFLINE)
 		_mtp_dev->state = STATE_READY;
@@ -945,14 +903,6 @@
 {
 	printk(KERN_INFO "mtp_release\n");
 
-	/* tell the thread to quit */
-	if (_mtp_dev->thread) {
-		_mtp_dev->thread_command = ANDROID_THREAD_QUIT;
-		init_completion(&_mtp_dev->thread_wait);
-		wake_up_process(_mtp_dev->thread);
-		wait_for_completion(&_mtp_dev->thread_wait);
-	}
-
 	_unlock(&_mtp_dev->open_excl);
 	return 0;
 }
@@ -1216,13 +1166,18 @@
 	}
 
 	spin_lock_init(&dev->lock);
-	init_completion(&dev->thread_wait);
 	init_waitqueue_head(&dev->read_wq);
 	init_waitqueue_head(&dev->write_wq);
 	init_waitqueue_head(&dev->intr_wq);
 	atomic_set(&dev->open_excl, 0);
+	atomic_set(&dev->ioctl_excl, 0);
 	INIT_LIST_HEAD(&dev->tx_idle);
-	mutex_init(&dev->intr_mutex);
+
+	dev->wq = create_singlethread_workqueue("f_mtp");
+	if (!dev->wq)
+		goto err1;
+	INIT_WORK(&dev->send_file_work, send_file_work);
+	INIT_WORK(&dev->receive_file_work, receive_file_work);
 
 	dev->cdev = c->cdev;
 	dev->function.name = "mtp";
@@ -1254,6 +1209,8 @@
 err2:
 	misc_deregister(&mtp_device);
 err1:
+	if (dev->wq)
+		destroy_workqueue(dev->wq);
 	kfree(dev);
 	printk(KERN_ERR "mtp gadget driver failed to initialize\n");
 	return ret;