xarray: Add XArray iterators

The xa_for_each iterator allows the user to efficiently walk a range
of the array, executing the loop body once for each entry in that
range that matches the filter (XA_PRESENT, or a mark such as
XA_MARK_0).  This commit also includes xa_find() and xa_find_after(),
which are helper functions for xa_for_each() but may also be useful
in their own right.

In the xas family of functions, we have xas_for_each(), xas_find(),
xas_next_entry(), xas_for_each_marked(), xas_find_marked(),
xas_next_marked() and xas_pause().

Signed-off-by: Matthew Wilcox <willy@infradead.org>
diff --git a/lib/test_xarray.c b/lib/test_xarray.c
index fb47225..e3c2d4d 100644
--- a/lib/test_xarray.c
+++ b/lib/test_xarray.c
@@ -75,6 +75,48 @@
 //	XA_BUG_ON(xa, xa_err(xa_store(xa, 0, xa_mk_internal(0), 0)) != -EINVAL);
 }
 
+static noinline void check_xas_retry(struct xarray *xa)
+{
+	XA_STATE(xas, xa, 0);
+	void *entry;
+
+	xa_store_index(xa, 0, GFP_KERNEL);
+	xa_store_index(xa, 1, GFP_KERNEL);
+	/* xas_reload() must return a retry entry once index 1 is erased */
+	rcu_read_lock();
+	XA_BUG_ON(xa, xas_find(&xas, ULONG_MAX) != xa_mk_value(0));
+	xa_erase_index(xa, 1);
+	XA_BUG_ON(xa, !xa_is_retry(xas_reload(&xas)));
+	XA_BUG_ON(xa, xas_retry(&xas, NULL));
+	XA_BUG_ON(xa, xas_retry(&xas, xa_mk_value(0)));
+	xas_reset(&xas);
+	XA_BUG_ON(xa, xas.xa_node != XAS_RESTART);
+	XA_BUG_ON(xa, xas_next_entry(&xas, ULONG_MAX) != xa_mk_value(0));
+	XA_BUG_ON(xa, xas.xa_node != NULL);
+	/* Re-store index 1; xas_reload() must now return an internal entry */
+	XA_BUG_ON(xa, xa_store_index(xa, 1, GFP_KERNEL) != NULL);
+	XA_BUG_ON(xa, !xa_is_internal(xas_reload(&xas)));
+	xas.xa_node = XAS_RESTART;	/* restart without xas_reset() */
+	XA_BUG_ON(xa, xas_next_entry(&xas, ULONG_MAX) != xa_mk_value(0));
+	rcu_read_unlock();
+
+	/* Make sure we can iterate through retry entries at indices 0 and 1 */
+	xas_lock(&xas);
+	xas_set(&xas, 0);
+	xas_store(&xas, XA_RETRY_ENTRY);
+	xas_set(&xas, 1);
+	xas_store(&xas, XA_RETRY_ENTRY);
+	/* Replace each retry entry with a value entry for its own index */
+	xas_set(&xas, 0);
+	xas_for_each(&xas, entry, ULONG_MAX) {
+		xas_store(&xas, xa_mk_value(xas.xa_index));
+	}
+	xas_unlock(&xas);
+
+	xa_erase_index(xa, 0);
+	xa_erase_index(xa, 1);
+}
+
 static noinline void check_xa_load(struct xarray *xa)
 {
 	unsigned long i, j;
@@ -217,6 +259,44 @@
 	XA_BUG_ON(xa, !xa_empty(xa));
 }
 
+static noinline void check_xas_erase(struct xarray *xa)
+{
+	XA_STATE(xas, xa, 0);
+	void *entry;
+	unsigned long i, j;
+
+	for (i = 0; i < 200; i++) {
+		for (j = i; j < 2 * i + 17; j++) {
+			xas_set(&xas, j);
+			do {
+				xas_lock(&xas);
+				xas_store(&xas, xa_mk_value(j));
+				xas_unlock(&xas);
+			} while (xas_nomem(&xas, GFP_KERNEL));
+		}
+		/* Also store an entry at ULONG_MAX, then erase it again */
+		xas_set(&xas, ULONG_MAX);
+		do {
+			xas_lock(&xas);
+			xas_store(&xas, xa_mk_value(0));
+			xas_unlock(&xas);
+		} while (xas_nomem(&xas, GFP_KERNEL));
+
+		xas_lock(&xas);
+		xas_store(&xas, NULL);
+		/* Walk from 0 expecting values i..2i+16; erase as we go */
+		xas_set(&xas, 0);
+		j = i;
+		xas_for_each(&xas, entry, ULONG_MAX) {
+			XA_BUG_ON(xa, entry != xa_mk_value(j));
+			xas_store(&xas, NULL);
+			j++;
+		}
+		xas_unlock(&xas);
+		XA_BUG_ON(xa, !xa_empty(xa));
+	}
+}
+
 static noinline void check_multi_store(struct xarray *xa)
 {
 #ifdef CONFIG_XARRAY_MULTI
@@ -285,16 +365,119 @@
 #endif
 }
 
+static noinline void check_multi_find(struct xarray *xa)
+{
+#ifdef CONFIG_XARRAY_MULTI
+	unsigned long index;
+
+	xa_store_order(xa, 12, 2, xa_mk_value(12), GFP_KERNEL);
+	XA_BUG_ON(xa, xa_store_index(xa, 16, GFP_KERNEL) != NULL);
+	/* Searching from 0 or from 13 must land on the order-2 entry at 12 */
+	index = 0;
+	XA_BUG_ON(xa, xa_find(xa, &index, ULONG_MAX, XA_PRESENT) !=
+			xa_mk_value(12));
+	XA_BUG_ON(xa, index != 12);
+	index = 13;
+	XA_BUG_ON(xa, xa_find(xa, &index, ULONG_MAX, XA_PRESENT) !=
+			xa_mk_value(12));
+	XA_BUG_ON(xa, (index < 12) || (index >= 16));	/* anywhere in 12-15 */
+	XA_BUG_ON(xa, xa_find_after(xa, &index, ULONG_MAX, XA_PRESENT) !=
+			xa_mk_value(16));
+	XA_BUG_ON(xa, index != 16);	/* skipped the rest of 12-15 */
+
+	xa_erase_index(xa, 12);
+	xa_erase_index(xa, 16);
+	XA_BUG_ON(xa, !xa_empty(xa));
+#endif
+}
+/* Erase an order-i entry at 2^i while xas_for_each() is walking it */
+static noinline void check_multi_find_2(struct xarray *xa)
+{
+	unsigned int max_order = IS_ENABLED(CONFIG_XARRAY_MULTI) ? 10 : 1;
+	unsigned int i, j;
+	void *entry;
+
+	for (i = 0; i < max_order; i++) {
+		unsigned long index = 1UL << i;
+		for (j = 0; j < index; j++) {
+			XA_STATE(xas, xa, j + index);	/* inside the entry */
+			xa_store_index(xa, index - 1, GFP_KERNEL);
+			xa_store_order(xa, index, i, xa_mk_value(index),
+					GFP_KERNEL);
+			rcu_read_lock();
+			xas_for_each(&xas, entry, ULONG_MAX) {
+				xa_erase_index(xa, index);
+			}
+			rcu_read_unlock();
+			xa_erase_index(xa, index - 1);
+			XA_BUG_ON(xa, !xa_empty(xa));
+		}
+	}
+}
+/* xa_find() over all index pairs in 0-99, then the multi-index tests */
+static noinline void check_find(struct xarray *xa)
+{
+	unsigned long i, j, k;
+
+	XA_BUG_ON(xa, !xa_empty(xa));
+
+	/*
+	 * For each pair j < i in 0-99, with both entries stored and marked,
+	 * check xa_find() (XA_PRESENT, then XA_MARK_0) from every start k.
+	 */
+	for (i = 0; i < 100; i++) {
+		XA_BUG_ON(xa, xa_store_index(xa, i, GFP_KERNEL) != NULL);
+		xa_set_mark(xa, i, XA_MARK_0);
+		for (j = 0; j < i; j++) {
+			XA_BUG_ON(xa, xa_store_index(xa, j, GFP_KERNEL) !=
+					NULL);
+			xa_set_mark(xa, j, XA_MARK_0);
+			for (k = 0; k < 100; k++) {
+				unsigned long index = k;
+				void *entry = xa_find(xa, &index, ULONG_MAX,
+								XA_PRESENT);
+				if (k <= j)
+					XA_BUG_ON(xa, index != j);
+				else if (k <= i)
+					XA_BUG_ON(xa, index != i);
+				else
+					XA_BUG_ON(xa, entry != NULL);
+				/* Both entries are marked: same results */
+				index = k;
+				entry = xa_find(xa, &index, ULONG_MAX,
+								XA_MARK_0);
+				if (k <= j)
+					XA_BUG_ON(xa, index != j);
+				else if (k <= i)
+					XA_BUG_ON(xa, index != i);
+				else
+					XA_BUG_ON(xa, entry != NULL);
+			}
+			xa_erase_index(xa, j);
+			XA_BUG_ON(xa, xa_get_mark(xa, j, XA_MARK_0));
+			XA_BUG_ON(xa, !xa_get_mark(xa, i, XA_MARK_0));
+		}
+		xa_erase_index(xa, i);
+		XA_BUG_ON(xa, xa_get_mark(xa, i, XA_MARK_0));
+	}
+	XA_BUG_ON(xa, !xa_empty(xa));
+	check_multi_find(xa);
+	check_multi_find_2(xa);
+}
+
 static DEFINE_XARRAY(array);
 
 static int xarray_checks(void)
 {
 	check_xa_err(&array);
+	check_xas_retry(&array);
 	check_xa_load(&array);
 	check_xa_mark(&array);
 	check_xa_shrink(&array);
+	check_xas_erase(&array);
 	check_cmpxchg(&array);
 	check_multi_store(&array);
+	check_find(&array);
 
 	printk("XArray: %u of %u tests passed\n", tests_passed, tests_run);
 	return (tests_run == tests_passed) ? 0 : -EINVAL;