rhashtable: Fix sleeping inside RCU critical section in walk_stop

The commit 963ecbd41a1026d99ec7537c050867428c397b89 ("rhashtable:
Fix use-after-free in rhashtable_walk_stop") fixed a real bug
but created another one because we may end up sleeping inside an
RCU critical section.

This patch fixes it properly by replacing the mutex with a spin
lock that specifically protects the walker lists.

Reported-by: Sasha Levin <sasha.levin@oracle.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/linux/rhashtable.h b/include/linux/rhashtable.h
index f9ecf32..d7be9cb 100644
--- a/include/linux/rhashtable.h
+++ b/include/linux/rhashtable.h
@@ -133,6 +133,7 @@
  * @p: Configuration parameters
  * @run_work: Deferred worker to expand/shrink asynchronously
  * @mutex: Mutex to protect current/future table swapping
+ * @lock: Spin lock to protect walker list
  * @being_destroyed: True if table is set up for destruction
  */
 struct rhashtable {
@@ -144,6 +145,7 @@
 	struct rhashtable_params	p;
 	struct work_struct		run_work;
 	struct mutex                    mutex;
+	spinlock_t			lock;
 };
 
 /**
diff --git a/lib/rhashtable.c b/lib/rhashtable.c
index 7686c1e..e96ad1a 100644
--- a/lib/rhashtable.c
+++ b/lib/rhashtable.c
@@ -256,8 +256,10 @@
 	/* Publish the new table pointer. */
 	rcu_assign_pointer(ht->tbl, new_tbl);
 
+	spin_lock(&ht->lock);
 	list_for_each_entry(walker, &old_tbl->walkers, list)
 		walker->tbl = NULL;
+	spin_unlock(&ht->lock);
 
 	/* Wait for readers. All new readers will see the new
 	 * table, and thus no references to the old table will
@@ -635,12 +637,12 @@
 
 	ht = iter->ht;
 
-	mutex_lock(&ht->mutex);
+	spin_lock(&ht->lock);
 	if (tbl->rehash < tbl->size)
 		list_add(&iter->walker->list, &tbl->walkers);
 	else
 		iter->walker->tbl = NULL;
-	mutex_unlock(&ht->mutex);
+	spin_unlock(&ht->lock);
 
 	iter->p = NULL;
 
@@ -723,6 +725,7 @@
 
 	memset(ht, 0, sizeof(*ht));
 	mutex_init(&ht->mutex);
+	spin_lock_init(&ht->lock);
 	memcpy(&ht->p, params, sizeof(*params));
 
 	if (params->min_size)