cxl: Fix driver use count

cxl keeps a driver use count, which is used with the hash memory model
on p8 to know when to upgrade local TLBIs to global and to trigger
callbacks to manage the MMU for PSL8.

If a process opens a context and closes without attaching or fails the
attachment, the driver use count is never decremented. As a
consequence, TLB invalidations remain global, even if there are no
active cxl contexts.

We should increment the driver use count when the process is attaching
to the cxl adapter, and not on open. It's not needed before the
adapter starts using the context and the use count is decremented on
the detach path, so it makes more sense.

It affects only the user api. The kernel api is already doing The
Right Thing.

Signed-off-by: Frederic Barrat <fbarrat@linux.vnet.ibm.com>
Cc: stable@vger.kernel.org # v4.2+
Fixes: 7bb5d91a4dda ("cxl: Rework context lifetimes")
Acked-by: Andrew Donnellan <andrew.donnellan@au1.ibm.com>
Signed-off-by: Michael Ellerman <mpe@ellerman.id.au>
diff --git a/drivers/misc/cxl/api.c b/drivers/misc/cxl/api.c
index 1a138c8..a0c44d1 100644
--- a/drivers/misc/cxl/api.c
+++ b/drivers/misc/cxl/api.c
@@ -336,6 +336,10 @@
 			mmput(ctx->mm);
 	}
 
+	/*
+	 * Increment driver use count. Enables global TLBIs for hash
+	 * and callbacks to handle the segment table
+	 */
 	cxl_ctx_get();
 
 	if ((rc = cxl_ops->attach_process(ctx, kernel, wed, 0))) {
diff --git a/drivers/misc/cxl/file.c b/drivers/misc/cxl/file.c
index 0761271..4bfad9f 100644
--- a/drivers/misc/cxl/file.c
+++ b/drivers/misc/cxl/file.c
@@ -95,7 +95,6 @@
 
 	pr_devel("afu_open pe: %i\n", ctx->pe);
 	file->private_data = ctx;
-	cxl_ctx_get();
 
 	/* indicate success */
 	rc = 0;
@@ -225,6 +224,12 @@
 	if (ctx->mm)
 		mmput(ctx->mm);
 
+	/*
+	 * Increment driver use count. Enables global TLBIs for hash
+	 * and callbacks to handle the segment table
+	 */
+	cxl_ctx_get();
+
 	trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr);
 
 	if ((rc = cxl_ops->attach_process(ctx, false, work.work_element_descriptor,
@@ -233,6 +238,7 @@
 		cxl_adapter_context_put(ctx->afu->adapter);
 		put_pid(ctx->pid);
 		ctx->pid = NULL;
+		cxl_ctx_put();
 		cxl_context_mm_count_put(ctx);
 		goto out;
 	}