IB/umad: Simplify and fix locking

In addition to being overly complex, the locking in user_mad.c is
broken: there were multiple reports of deadlocks and lockdep warnings.
In particular, it seems that a single thread may end up trying to take
the same rwsem for reading more than once, which is explicitly
forbidden in the comments in <linux/rwsem.h>.
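
For illustration only (a minimal made-up sketch, not code from
user_mad.c): once a writer queues on the rwsem, a nested down_read()
in the same thread sleeps behind that writer, which in turn waits for
the first read lock to go away, so the thread deadlocks against itself:

	static DECLARE_RWSEM(sem);

	static void nested_read(void)
	{
		down_read(&sem);
		/* another thread calls down_write(&sem) here and blocks */
		down_read(&sem);	/* sleeps behind the queued writer */
		up_read(&sem);
		up_read(&sem);
	}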

To solve this, we change the locking to use plain mutexes instead of
rwsems.  There is one mutex per open file, which protects the contents
of the struct ib_umad_file, including the array of agents and the list
of queued packets; and there is one mutex per struct ib_umad_port,
which protects its contents, including the list of open files.  We
never hold the file mutex across calls to functions such as
ib_unregister_mad_agent(), which can call back into other ib_umad code
to queue a packet, and we hold the port mutex whenever we need to make
sure that a device is not hot-unplugged out from under us.
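
As a rough sketch of the resulting discipline (it mirrors
ib_umad_unreg_agent() in the patch below, with error handling
omitted), the port mutex is the outer lock and the file mutex the
inner one:

	mutex_lock(&file->port->file_mutex);	/* pin the port/device */
	mutex_lock(&file->mutex);		/* protect file contents */

	agent = __get_agent(file, id);
	file->agent[id] = NULL;

	mutex_unlock(&file->mutex);		/* drop before unregistering... */

	if (agent)
		ib_unregister_mad_agent(agent);	/* ...which may queue a packet
						   and retake file->mutex */

	mutex_unlock(&file->port->file_mutex);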

This even makes things nicer for users of the -rt patch, since we
remove calls to downgrade_write() (which is not implemented in -rt).

Signed-off-by: Roland Dreier <rolandd@cisco.com>
diff --git a/drivers/infiniband/core/user_mad.c b/drivers/infiniband/core/user_mad.c
index b53eac4..4e91510 100644
--- a/drivers/infiniband/core/user_mad.c
+++ b/drivers/infiniband/core/user_mad.c
@@ -2,6 +2,7 @@
  * Copyright (c) 2004 Topspin Communications.  All rights reserved.
  * Copyright (c) 2005 Voltaire, Inc. All rights reserved.
  * Copyright (c) 2005 Sun Microsystems, Inc. All rights reserved.
+ * Copyright (c) 2008 Cisco. All rights reserved.
  *
  * This software is available to you under a choice of one of two
  * licenses.  You may choose to be licensed under the terms of the GNU
@@ -42,7 +43,7 @@
 #include <linux/cdev.h>
 #include <linux/dma-mapping.h>
 #include <linux/poll.h>
-#include <linux/rwsem.h>
+#include <linux/mutex.h>
 #include <linux/kref.h>
 #include <linux/compat.h>
 
@@ -94,7 +95,7 @@
 	struct class_device   *sm_class_dev;
 	struct semaphore       sm_sem;
 
-	struct rw_semaphore    mutex;
+	struct mutex	       file_mutex;
 	struct list_head       file_list;
 
 	struct ib_device      *ib_dev;
@@ -110,11 +111,11 @@
 };
 
 struct ib_umad_file {
+	struct mutex		mutex;
 	struct ib_umad_port    *port;
 	struct list_head	recv_list;
 	struct list_head	send_list;
 	struct list_head	port_list;
-	spinlock_t		recv_lock;
 	spinlock_t		send_lock;
 	wait_queue_head_t	recv_wait;
 	struct ib_mad_agent    *agent[IB_UMAD_MAX_AGENTS];
@@ -156,7 +157,7 @@
 		sizeof (struct ib_user_mad_hdr_old);
 }
 
-/* caller must hold port->mutex at least for reading */
+/* caller must hold file->mutex */
 static struct ib_mad_agent *__get_agent(struct ib_umad_file *file, int id)
 {
 	return file->agents_dead ? NULL : file->agent[id];
@@ -168,32 +169,30 @@
 {
 	int ret = 1;
 
-	down_read(&file->port->mutex);
+	mutex_lock(&file->mutex);
 
 	for (packet->mad.hdr.id = 0;
 	     packet->mad.hdr.id < IB_UMAD_MAX_AGENTS;
 	     packet->mad.hdr.id++)
 		if (agent == __get_agent(file, packet->mad.hdr.id)) {
-			spin_lock_irq(&file->recv_lock);
 			list_add_tail(&packet->list, &file->recv_list);
-			spin_unlock_irq(&file->recv_lock);
 			wake_up_interruptible(&file->recv_wait);
 			ret = 0;
 			break;
 		}
 
-	up_read(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 
 	return ret;
 }
 
 static void dequeue_send(struct ib_umad_file *file,
 			 struct ib_umad_packet *packet)
- {
+{
 	spin_lock_irq(&file->send_lock);
 	list_del(&packet->list);
 	spin_unlock_irq(&file->send_lock);
- }
+}
 
 static void send_handler(struct ib_mad_agent *agent,
 			 struct ib_mad_send_wc *send_wc)
@@ -341,10 +340,10 @@
 	if (count < hdr_size(file))
 		return -EINVAL;
 
-	spin_lock_irq(&file->recv_lock);
+	mutex_lock(&file->mutex);
 
 	while (list_empty(&file->recv_list)) {
-		spin_unlock_irq(&file->recv_lock);
+		mutex_unlock(&file->mutex);
 
 		if (filp->f_flags & O_NONBLOCK)
 			return -EAGAIN;
@@ -353,13 +352,13 @@
 					     !list_empty(&file->recv_list)))
 			return -ERESTARTSYS;
 
-		spin_lock_irq(&file->recv_lock);
+		mutex_lock(&file->mutex);
 	}
 
 	packet = list_entry(file->recv_list.next, struct ib_umad_packet, list);
 	list_del(&packet->list);
 
-	spin_unlock_irq(&file->recv_lock);
+	mutex_unlock(&file->mutex);
 
 	if (packet->recv_wc)
 		ret = copy_recv_mad(file, buf, packet, count);
@@ -368,9 +367,9 @@
 
 	if (ret < 0) {
 		/* Requeue packet */
-		spin_lock_irq(&file->recv_lock);
+		mutex_lock(&file->mutex);
 		list_add(&packet->list, &file->recv_list);
-		spin_unlock_irq(&file->recv_lock);
+		mutex_unlock(&file->mutex);
 	} else {
 		if (packet->recv_wc)
 			ib_free_recv_mad(packet->recv_wc);
@@ -481,7 +480,7 @@
 		goto err;
 	}
 
-	down_read(&file->port->mutex);
+	mutex_lock(&file->mutex);
 
 	agent = __get_agent(file, packet->mad.hdr.id);
 	if (!agent) {
@@ -577,7 +576,7 @@
 	if (ret)
 		goto err_send;
 
-	up_read(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 	return count;
 
 err_send:
@@ -587,7 +586,7 @@
 err_ah:
 	ib_destroy_ah(ah);
 err_up:
-	up_read(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 err:
 	kfree(packet);
 	return ret;
@@ -613,11 +612,12 @@
 {
 	struct ib_user_mad_reg_req ureq;
 	struct ib_mad_reg_req req;
-	struct ib_mad_agent *agent;
+	struct ib_mad_agent *agent = NULL;
 	int agent_id;
 	int ret;
 
-	down_write(&file->port->mutex);
+	mutex_lock(&file->port->file_mutex);
+	mutex_lock(&file->mutex);
 
 	if (!file->port->ib_dev) {
 		ret = -EPIPE;
@@ -666,13 +666,13 @@
 				      send_handler, recv_handler, file);
 	if (IS_ERR(agent)) {
 		ret = PTR_ERR(agent);
+		agent = NULL;
 		goto out;
 	}
 
 	if (put_user(agent_id,
 		     (u32 __user *) (arg + offsetof(struct ib_user_mad_reg_req, id)))) {
 		ret = -EFAULT;
-		ib_unregister_mad_agent(agent);
 		goto out;
 	}
 
@@ -690,7 +690,13 @@
 	ret = 0;
 
 out:
-	up_write(&file->port->mutex);
+	mutex_unlock(&file->mutex);
+
+	if (ret && agent)
+		ib_unregister_mad_agent(agent);
+
+	mutex_unlock(&file->port->file_mutex);
+
 	return ret;
 }
 
@@ -703,7 +709,8 @@
 	if (get_user(id, arg))
 		return -EFAULT;
 
-	down_write(&file->port->mutex);
+	mutex_lock(&file->port->file_mutex);
+	mutex_lock(&file->mutex);
 
 	if (id < 0 || id >= IB_UMAD_MAX_AGENTS || !__get_agent(file, id)) {
 		ret = -EINVAL;
@@ -714,11 +721,13 @@
 	file->agent[id] = NULL;
 
 out:
-	up_write(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 
 	if (agent)
 		ib_unregister_mad_agent(agent);
 
+	mutex_unlock(&file->port->file_mutex);
+
 	return ret;
 }
 
@@ -726,12 +735,12 @@
 {
 	int ret = 0;
 
-	down_write(&file->port->mutex);
+	mutex_lock(&file->mutex);
 	if (file->already_used)
 		ret = -EINVAL;
 	else
 		file->use_pkey_index = 1;
-	up_write(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 
 	return ret;
 }
@@ -783,7 +792,7 @@
 	if (!port)
 		return -ENXIO;
 
-	down_write(&port->mutex);
+	mutex_lock(&port->file_mutex);
 
 	if (!port->ib_dev) {
 		ret = -ENXIO;
@@ -797,7 +806,7 @@
 		goto out;
 	}
 
-	spin_lock_init(&file->recv_lock);
+	mutex_init(&file->mutex);
 	spin_lock_init(&file->send_lock);
 	INIT_LIST_HEAD(&file->recv_list);
 	INIT_LIST_HEAD(&file->send_list);
@@ -809,7 +818,7 @@
 	list_add_tail(&file->port_list, &port->file_list);
 
 out:
-	up_write(&port->mutex);
+	mutex_unlock(&port->file_mutex);
 	return ret;
 }
 
@@ -821,7 +830,8 @@
 	int already_dead;
 	int i;
 
-	down_write(&file->port->mutex);
+	mutex_lock(&file->port->file_mutex);
+	mutex_lock(&file->mutex);
 
 	already_dead = file->agents_dead;
 	file->agents_dead = 1;
@@ -834,14 +844,14 @@
 
 	list_del(&file->port_list);
 
-	downgrade_write(&file->port->mutex);
+	mutex_unlock(&file->mutex);
 
 	if (!already_dead)
 		for (i = 0; i < IB_UMAD_MAX_AGENTS; ++i)
 			if (file->agent[i])
 				ib_unregister_mad_agent(file->agent[i]);
 
-	up_read(&file->port->mutex);
+	mutex_unlock(&file->port->file_mutex);
 
 	kfree(file);
 	kref_put(&dev->ref, ib_umad_release_dev);
@@ -914,10 +924,10 @@
 	};
 	int ret = 0;
 
-	down_write(&port->mutex);
+	mutex_lock(&port->file_mutex);
 	if (port->ib_dev)
 		ret = ib_modify_port(port->ib_dev, port->port_num, 0, &props);
-	up_write(&port->mutex);
+	mutex_unlock(&port->file_mutex);
 
 	up(&port->sm_sem);
 
@@ -981,7 +991,7 @@
 	port->ib_dev   = device;
 	port->port_num = port_num;
 	init_MUTEX(&port->sm_sem);
-	init_rwsem(&port->mutex);
+	mutex_init(&port->file_mutex);
 	INIT_LIST_HEAD(&port->file_list);
 
 	port->dev = cdev_alloc();
@@ -1052,6 +1062,7 @@
 static void ib_umad_kill_port(struct ib_umad_port *port)
 {
 	struct ib_umad_file *file;
+	int already_dead;
 	int id;
 
 	class_set_devdata(port->class_dev,    NULL);
@@ -1067,42 +1078,22 @@
 	umad_port[port->dev_num] = NULL;
 	spin_unlock(&port_lock);
 
-	down_write(&port->mutex);
+	mutex_lock(&port->file_mutex);
 
 	port->ib_dev = NULL;
 
-	/*
-	 * Now go through the list of files attached to this port and
-	 * unregister all of their MAD agents.  We need to hold
-	 * port->mutex while doing this to avoid racing with
-	 * ib_umad_close(), but we can't hold the mutex for writing
-	 * while calling ib_unregister_mad_agent(), since that might
-	 * deadlock by calling back into queue_packet().  So we
-	 * downgrade our lock to a read lock, and then drop and
-	 * reacquire the write lock for the next iteration.
-	 *
-	 * We do list_del_init() on the file's list_head so that the
-	 * list_del in ib_umad_close() is still OK, even after the
-	 * file is removed from the list.
-	 */
-	while (!list_empty(&port->file_list)) {
-		file = list_entry(port->file_list.next, struct ib_umad_file,
-				  port_list);
-
+	list_for_each_entry(file, &port->file_list, port_list) {
+		mutex_lock(&file->mutex);
+		already_dead = file->agents_dead;
 		file->agents_dead = 1;
-		list_del_init(&file->port_list);
-
-		downgrade_write(&port->mutex);
+		mutex_unlock(&file->mutex);
 
 		for (id = 0; id < IB_UMAD_MAX_AGENTS; ++id)
 			if (file->agent[id])
 				ib_unregister_mad_agent(file->agent[id]);
-
-		up_read(&port->mutex);
-		down_write(&port->mutex);
 	}
 
-	up_write(&port->mutex);
+	mutex_unlock(&port->file_mutex);
 
 	clear_bit(port->dev_num, dev_map);
 }