mm, mempolicy: task->mempolicy must be NULL before dropping final reference

KASAN allocates memory from the page allocator as part of
kmem_cache_free(), and that can reference current->mempolicy through any
number of allocation functions.  It needs to be NULL'd out before the
final reference is dropped to prevent a use-after-free bug:

	BUG: KASAN: use-after-free in alloc_pages_current+0x363/0x370 at addr ffff88010b48102c
	CPU: 0 PID: 15425 Comm: trinity-c2 Not tainted 4.8.0-rc2+ #140
	...
	Call Trace:
		dump_stack
		kasan_object_err
		kasan_report_error
		__asan_report_load2_noabort
		alloc_pages_current	<-- use after free
		depot_save_stack
		save_stack
		kasan_slab_free
		kmem_cache_free
		__mpol_put		<-- free
		do_exit

This patch sets current->mempolicy to NULL before dropping the final
reference.

Link: http://lkml.kernel.org/r/alpine.DEB.2.10.1608301442180.63329@chino.kir.corp.google.com
Fixes: cd11016e5f52 ("mm, kasan: stackdepot implementation. Enable stackdepot for SLAB")
Signed-off-by: David Rientjes <rientjes@google.com>
Reported-by: Vegard Nossum <vegard.nossum@oracle.com>
Acked-by: Andrey Ryabinin <aryabinin@virtuozzo.com>
Cc: Alexander Potapenko <glider@google.com>
Cc: Dmitry Vyukov <dvyukov@google.com>
Cc: <stable@vger.kernel.org>	[4.6+]
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index d8c4e38..2da72a5 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -2336,6 +2336,23 @@
 	return ret;
 }
 
+/*
+ * Drop the (possibly final) reference to task->mempolicy.  It needs to be
+ * dropped after task->mempolicy is set to NULL so that any allocation done as
+ * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
+ * policy.
+ */
+void mpol_put_task_policy(struct task_struct *task)
+{
+	struct mempolicy *pol;
+
+	task_lock(task);
+	pol = task->mempolicy;
+	task->mempolicy = NULL;
+	task_unlock(task);
+	mpol_put(pol);
+}
+
 static void sp_delete(struct shared_policy *sp, struct sp_node *n)
 {
 	pr_debug("deleting %lx-l%lx\n", n->start, n->end);