aio: Fix a trinity splat

aio kiocb refcounting was broken: it relied on tracking the number of
available ring buffer entries (which it has to do anyway), and at
shutdown time it waited for completions to be delivered until the
number of available ring buffer entries equalled its initial value.

The problem with that is that the ring buffer is mapped writable into
userspace, so userspace can modify the head and tail pointers to make
the kernel see extra completions, causing free_ioctx() to return (and
the kioctx to be freed) while kiocbs are still outstanding - a
use-after-free.

The fix is to refcount the kiocbs directly. This is more
straightforward, and with the new percpu refcounting code it doesn't
cost any cacheline bouncing - avoiding that was the whole point of the
original scheme.

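Also clean up ioctx_alloc()'s error path, and fix a bug where it didn't
subtract the context's max_reqs back from aio_nr if ioctx_add_table()
failed. aio_nr underflow is now reported with WARN_ON() and clamped to
zero instead of hitting BUG_ON().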

Signed-off-by: Kent Overstreet <kmo@daterainc.com>
diff --git a/fs/aio.c b/fs/aio.c
index 067e3d3..ee77dc1 100644
--- a/fs/aio.c
+++ b/fs/aio.c
@@ -80,6 +80,8 @@
 	struct percpu_ref	users;
 	atomic_t		dead;
 
+	struct percpu_ref	reqs;
+
 	unsigned long		user_id;
 
 	struct __percpu kioctx_cpu *cpu;
@@ -107,7 +109,6 @@
 	struct page		**ring_pages;
 	long			nr_pages;
 
-	struct rcu_head		rcu_head;
 	struct work_struct	free_work;
 
 	struct {
@@ -412,26 +413,34 @@
 	return cancel(kiocb);
 }
 
-static void free_ioctx_rcu(struct rcu_head *head)
+static void free_ioctx(struct work_struct *work)
 {
-	struct kioctx *ctx = container_of(head, struct kioctx, rcu_head);
+	struct kioctx *ctx = container_of(work, struct kioctx, free_work);
 
+	pr_debug("freeing %p\n", ctx);
+	/* Queued by free_ioctx_reqs(): no kiocb holds a ctx->reqs ref now */
+	aio_free_ring(ctx);
 	free_percpu(ctx->cpu);
 	kmem_cache_free(kioctx_cachep, ctx);
 }
 
+static void free_ioctx_reqs(struct percpu_ref *ref)
+{
+	struct kioctx *ctx = container_of(ref, struct kioctx, reqs);
+
+	INIT_WORK(&ctx->free_work, free_ioctx);	/* ctx->reqs hit zero */
+	schedule_work(&ctx->free_work);		/* free from the workqueue */
+}
+
 /*
  * When this function runs, the kioctx has been removed from the "hash table"
  * and ctx->users has dropped to 0, so we know no more kiocbs can be submitted -
  * now it's safe to cancel any that need to be.
  */
-static void free_ioctx(struct work_struct *work)
+static void free_ioctx_users(struct percpu_ref *ref)
 {
-	struct kioctx *ctx = container_of(work, struct kioctx, free_work);
-	struct aio_ring *ring;
+	struct kioctx *ctx = container_of(ref, struct kioctx, users);
 	struct kiocb *req;
-	unsigned cpu, avail;
-	DEFINE_WAIT(wait);
 
 	spin_lock_irq(&ctx->ctx_lock);
 
@@ -445,54 +454,8 @@
 
 	spin_unlock_irq(&ctx->ctx_lock);
 
-	for_each_possible_cpu(cpu) {
-		struct kioctx_cpu *kcpu = per_cpu_ptr(ctx->cpu, cpu);
-
-		atomic_add(kcpu->reqs_available, &ctx->reqs_available);
-		kcpu->reqs_available = 0;
-	}
-
-	while (1) {
-		prepare_to_wait(&ctx->wait, &wait, TASK_UNINTERRUPTIBLE);
-
-		ring = kmap_atomic(ctx->ring_pages[0]);
-		avail = (ring->head <= ring->tail)
-			 ? ring->tail - ring->head
-			 : ctx->nr_events - ring->head + ring->tail;
-
-		atomic_add(avail, &ctx->reqs_available);
-		ring->head = ring->tail;
-		kunmap_atomic(ring);
-
-		if (atomic_read(&ctx->reqs_available) >= ctx->nr_events - 1)
-			break;
-
-		schedule();
-	}
-	finish_wait(&ctx->wait, &wait);
-
-	WARN_ON(atomic_read(&ctx->reqs_available) > ctx->nr_events - 1);
-
-	aio_free_ring(ctx);
-
-	pr_debug("freeing %p\n", ctx);
-
-	/*
-	 * Here the call_rcu() is between the wait_event() for reqs_active to
-	 * hit 0, and freeing the ioctx.
-	 *
-	 * aio_complete() decrements reqs_active, but it has to touch the ioctx
-	 * after to issue a wakeup so we use rcu.
-	 */
-	call_rcu(&ctx->rcu_head, free_ioctx_rcu);
-}
-
-static void free_ioctx_ref(struct percpu_ref *ref)
-{
-	struct kioctx *ctx = container_of(ref, struct kioctx, users);
-
-	INIT_WORK(&ctx->free_work, free_ioctx);
-	schedule_work(&ctx->free_work);
+	percpu_ref_kill(&ctx->reqs);
+	percpu_ref_put(&ctx->reqs);
 }
 
 static int ioctx_add_table(struct kioctx *ctx, struct mm_struct *mm)
@@ -551,6 +514,16 @@
 	}
 }
 
+static void aio_nr_sub(unsigned nr)
+{
+	spin_lock(&aio_nr_lock);	/* aio_nr is guarded by aio_nr_lock */
+	if (WARN_ON(aio_nr - nr > aio_nr))	/* nr > aio_nr: would wrap */
+		aio_nr = 0;			/* clamp instead of wrapping */
+	else
+		aio_nr -= nr;
+	spin_unlock(&aio_nr_lock);
+}
+
 /* ioctx_alloc
  *	Allocates and initializes an ioctx.  Returns an ERR_PTR if it failed.
  */
@@ -588,8 +561,11 @@
 
 	ctx->max_reqs = nr_events;
 
-	if (percpu_ref_init(&ctx->users, free_ioctx_ref))
-		goto out_freectx;
+	if (percpu_ref_init(&ctx->users, free_ioctx_users))
+		goto err;
+
+	if (percpu_ref_init(&ctx->reqs, free_ioctx_reqs))
+		goto err;
 
 	spin_lock_init(&ctx->ctx_lock);
 	spin_lock_init(&ctx->completion_lock);
@@ -600,10 +576,10 @@
 
 	ctx->cpu = alloc_percpu(struct kioctx_cpu);
 	if (!ctx->cpu)
-		goto out_freeref;
+		goto err;
 
 	if (aio_setup_ring(ctx) < 0)
-		goto out_freepcpu;
+		goto err;
 
 	atomic_set(&ctx->reqs_available, ctx->nr_events - 1);
 	ctx->req_batch = (ctx->nr_events - 1) / (num_possible_cpus() * 4);
@@ -615,7 +591,8 @@
 	if (aio_nr + nr_events > (aio_max_nr * 2UL) ||
 	    aio_nr + nr_events < aio_nr) {
 		spin_unlock(&aio_nr_lock);
-		goto out_cleanup;
+		err = -EAGAIN;
+		goto err;
 	}
 	aio_nr += ctx->max_reqs;
 	spin_unlock(&aio_nr_lock);
@@ -624,23 +601,19 @@
 
 	err = ioctx_add_table(ctx, mm);
 	if (err)
-		goto out_cleanup_put;
+		goto err_cleanup;
 
 	pr_debug("allocated ioctx %p[%ld]: mm=%p mask=0x%x\n",
 		 ctx, ctx->user_id, mm, ctx->nr_events);
 	return ctx;
 
-out_cleanup_put:
-	percpu_ref_put(&ctx->users);
-out_cleanup:
-	err = -EAGAIN;
+err_cleanup:
+	aio_nr_sub(ctx->max_reqs);
+err:
 	aio_free_ring(ctx);
-out_freepcpu:
 	free_percpu(ctx->cpu);
-out_freeref:
+	free_percpu(ctx->reqs.pcpu_count);
 	free_percpu(ctx->users.pcpu_count);
-out_freectx:
-	put_aio_ring_file(ctx);
 	kmem_cache_free(kioctx_cachep, ctx);
 	pr_debug("error allocating ioctx %d\n", err);
 	return ERR_PTR(err);
@@ -675,10 +648,7 @@
 		 * -EAGAIN with no ioctxs actually in use (as far as userspace
 		 *  could tell).
 		 */
-		spin_lock(&aio_nr_lock);
-		BUG_ON(aio_nr - ctx->max_reqs > aio_nr);
-		aio_nr -= ctx->max_reqs;
-		spin_unlock(&aio_nr_lock);
+		aio_nr_sub(ctx->max_reqs);
 
 		if (ctx->mmap_size)
 			vm_munmap(ctx->mmap_base, ctx->mmap_size);
@@ -810,6 +780,8 @@
 	if (unlikely(!req))
 		goto out_put;
 
+	percpu_ref_get(&ctx->reqs);
+
 	req->ki_ctx = ctx;
 	return req;
 out_put:
@@ -879,12 +851,6 @@
 		return;
 	}
 
-	/*
-	 * Take rcu_read_lock() in case the kioctx is being destroyed, as we
-	 * need to issue a wakeup after incrementing reqs_available.
-	 */
-	rcu_read_lock();
-
 	if (iocb->ki_list.next) {
 		unsigned long flags;
 
@@ -959,7 +925,7 @@
 	if (waitqueue_active(&ctx->wait))
 		wake_up(&ctx->wait);
 
-	rcu_read_unlock();
+	percpu_ref_put(&ctx->reqs);
 }
 EXPORT_SYMBOL(aio_complete);
 
@@ -1370,6 +1336,7 @@
 	return 0;
 out_put_req:
 	put_reqs_available(ctx, 1);
+	percpu_ref_put(&ctx->reqs);
 	kiocb_free(req);
 	return ret;
 }