rhashtable: Fix walker behaviour during rehash

Previously, whenever the walker encountered a resize it simply
snapped back to the beginning and started again.  However, this only
works if the rehash started and completed while the walker was
idle.

If the walker attempts to restart while the rehash is still ongoing,
we may miss objects that we should not have missed.

This patch fixes this by making the walker walk the old table
followed by the new table, just like all other readers.  If a
rehash is detected we will still notify our caller of the fact
so that they can prepare for duplicates, but we will simply continue
the walk onto the new table once the old one is finished, either
by us or by the rehasher.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/lib/rhashtable.c b/lib/rhashtable.c
index fc0d451..f7c7607 100644
--- a/lib/rhashtable.c
+++ b/lib/rhashtable.c
@@ -170,6 +170,8 @@
 		return NULL;
 	}
 
+	INIT_LIST_HEAD(&tbl->walkers);
+
 	for (i = 0; i < nbuckets; i++)
 		INIT_RHT_NULLS_HEAD(tbl->buckets[i], ht, i);
 
@@ -264,6 +266,7 @@
 			      struct bucket_table *new_tbl)
 {
 	struct bucket_table *old_tbl = rht_dereference(ht->tbl, ht);
+	struct rhashtable_walker *walker;
 	unsigned old_hash;
 
 	get_random_bytes(&new_tbl->hash_rnd, sizeof(new_tbl->hash_rnd));
@@ -284,6 +287,9 @@
 	/* Publish the new table pointer. */
 	rcu_assign_pointer(ht->tbl, new_tbl);
 
+	list_for_each_entry(walker, &old_tbl->walkers, list)
+		walker->tbl = NULL;
+
 	/* Wait for readers. All new readers will see the new
 	 * table, and thus no references to the old table will
 	 * remain.
@@ -358,7 +364,6 @@
 {
 	struct rhashtable *ht;
 	struct bucket_table *tbl;
-	struct rhashtable_walker *walker;
 
 	ht = container_of(work, struct rhashtable, run_work);
 	mutex_lock(&ht->mutex);
@@ -367,9 +372,6 @@
 
 	tbl = rht_dereference(ht->tbl, ht);
 
-	list_for_each_entry(walker, &ht->walkers, list)
-		walker->resize = true;
-
 	if (rht_grow_above_75(ht, tbl))
 		rhashtable_expand(ht);
 	else if (rht_shrink_below_30(ht, tbl))
@@ -725,11 +727,9 @@
 	if (!iter->walker)
 		return -ENOMEM;
 
-	INIT_LIST_HEAD(&iter->walker->list);
-	iter->walker->resize = false;
-
 	mutex_lock(&ht->mutex);
-	list_add(&iter->walker->list, &ht->walkers);
+	iter->walker->tbl = rht_dereference(ht->tbl, ht);
+	list_add(&iter->walker->list, &iter->walker->tbl->walkers);
 	mutex_unlock(&ht->mutex);
 
 	return 0;
@@ -745,7 +745,8 @@
 void rhashtable_walk_exit(struct rhashtable_iter *iter)
 {
 	mutex_lock(&iter->ht->mutex);
-	list_del(&iter->walker->list);
+	if (iter->walker->tbl)
+		list_del(&iter->walker->list);
 	mutex_unlock(&iter->ht->mutex);
 	kfree(iter->walker);
 }
@@ -767,12 +768,19 @@
  */
 int rhashtable_walk_start(struct rhashtable_iter *iter)
 {
+	struct rhashtable *ht = iter->ht;
+
+	mutex_lock(&ht->mutex);
+
+	if (iter->walker->tbl)
+		list_del(&iter->walker->list);
+
 	rcu_read_lock();
 
-	if (iter->walker->resize) {
-		iter->slot = 0;
-		iter->skip = 0;
-		iter->walker->resize = false;
+	mutex_unlock(&ht->mutex);
+
+	if (!iter->walker->tbl) {
+		iter->walker->tbl = rht_dereference_rcu(ht->tbl, ht);
 		return -EAGAIN;
 	}
 
@@ -794,13 +802,11 @@
  */
 void *rhashtable_walk_next(struct rhashtable_iter *iter)
 {
-	const struct bucket_table *tbl;
+	struct bucket_table *tbl = iter->walker->tbl;
 	struct rhashtable *ht = iter->ht;
 	struct rhash_head *p = iter->p;
 	void *obj = NULL;
 
-	tbl = rht_dereference_rcu(ht->tbl, ht);
-
 	if (p) {
 		p = rht_dereference_bucket_rcu(p->next, tbl, iter->slot);
 		goto next;
@@ -826,16 +832,17 @@
 		iter->skip = 0;
 	}
 
+	iter->walker->tbl = rht_dereference_rcu(ht->future_tbl, ht);
+	if (iter->walker->tbl != tbl) {
+		iter->slot = 0;
+		iter->skip = 0;
+		return ERR_PTR(-EAGAIN);
+	}
+
+	iter->walker->tbl = NULL;
 	iter->p = NULL;
 
 out:
-	if (iter->walker->resize) {
-		iter->p = NULL;
-		iter->slot = 0;
-		iter->skip = 0;
-		iter->walker->resize = false;
-		return ERR_PTR(-EAGAIN);
-	}
 
 	return obj;
 }
@@ -849,7 +856,24 @@
  */
 void rhashtable_walk_stop(struct rhashtable_iter *iter)
 {
+	struct rhashtable *ht;
+	struct bucket_table *tbl = iter->walker->tbl;
+
 	rcu_read_unlock();
+
+	if (!tbl)
+		return;
+
+	ht = iter->ht;
+
+	mutex_lock(&ht->mutex);
+	if (rht_dereference(ht->tbl, ht) == tbl ||
+	    rht_dereference(ht->future_tbl, ht) == tbl)
+		list_add(&iter->walker->list, &tbl->walkers);
+	else
+		iter->walker->tbl = NULL;
+	mutex_unlock(&ht->mutex);
+
 	iter->p = NULL;
 }
 EXPORT_SYMBOL_GPL(rhashtable_walk_stop);
@@ -927,7 +951,6 @@
 	memset(ht, 0, sizeof(*ht));
 	mutex_init(&ht->mutex);
 	memcpy(&ht->p, params, sizeof(*params));
-	INIT_LIST_HEAD(&ht->walkers);
 
 	if (params->locks_mul)
 		ht->p.locks_mul = roundup_pow_of_two(params->locks_mul);