ksm: fix oom deadlock

There's a now-obvious deadlock in KSM's out-of-memory handling:
imagine ksmd or KSM_RUN_UNMERGE handling, holding ksm_thread_mutex,
trying to allocate a page to break KSM in an mm which becomes the
OOM victim (quite likely in the unmerge case): it's killed and goes
to exit, and hangs there waiting to acquire ksm_thread_mutex.

Clearly we must not require ksm_thread_mutex in __ksm_exit, however
much simpler that made everything else: perhaps use mmap_sem somehow?
And part of the answer lies in the comments on unmerge_ksm_pages:
__ksm_exit should also leave all the rmap_item removal to ksmd.

But there's a fundamental problem: KSM relies upon mmap_sem to
guarantee the consistency of the mm it's dealing with, yet exit_mmap
tears down an mm without taking mmap_sem.  And bumping mm_users won't
help at all: that just ensures that the pages the OOM killer assumes
are on their way to being freed will not be freed.

The best answer seems to be to move the ksm_exit callout from just
before exit_mmap to the middle of exit_mmap: after the mm's pages
have been freed (if the mmu_gather is flushed), but before its page
tables and vma structures have been freed; and there down_write then
up_write mmap_sem, to serialize with KSM's own reliance on mmap_sem.

But KSM then needs to be careful, whenever it downs mmap_sem, to
check that the mm is not already exiting: there's a danger of using
find_vma on a layout that's being torn apart, or writing into page
tables which have been freed for reuse; and even do_anonymous_page
and __do_fault need to check they're not being called by break_ksm
to reinstate a pte after zap_pte_range has zapped that page table.

Though it might be clearer to add an exiting flag, set while holding
mmap_sem in __ksm_exit, that wouldn't cover the issue of reinstating
a zapped pte.  All we need is to check whether mm_users is 0 - but we
must remember that ksmd may detect that before __ksm_exit is reached.
So add ksm_test_exit(mm), to self-document such checks on mm->mm_users.

__ksm_exit now has to leave clearing up the rmap_items to ksmd, since
that needs ksm_thread_mutex; but it shifts the exiting mm just after the
ksm_scan cursor so that it will soon be dealt with.  __ksm_enter raises
mm_count to hold the mm_struct, and ksmd's exit processing (exactly like
its processing when it finds all VM_MERGEABLEs unmapped) mmdrops it;
KSM_RUN_UNMERGE (which has stopped ksmd) follows a similar procedure.

But also give __ksm_exit a fast path: when there's no complication
(no rmap_items attached to mm and it's not at the ksm_scan cursor),
it can safely do all the exiting work itself.  This is not just an
optimization: when ksmd is not running, the raised mm_count would
otherwise leak mm_structs.

Signed-off-by: Hugh Dickins <hugh.dickins@tiscali.co.uk>
Acked-by: Izik Eidus <ieidus@redhat.com>
Cc: Andrea Arcangeli <aarcange@redhat.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/ksm.c b/mm/ksm.c
index 7e4d255d..722e3f2 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -32,6 +32,7 @@
 #include <linux/mmu_notifier.h>
 #include <linux/ksm.h>
 
+#include <asm/tlb.h>
 #include <asm/tlbflush.h>
 
 /*
@@ -347,6 +348,8 @@
 	struct vm_area_struct *vma;
 
 	down_read(&mm->mmap_sem);
+	if (ksm_test_exit(mm))
+		goto out;
 	vma = find_vma(mm, addr);
 	if (!vma || vma->vm_start > addr)
 		goto out;
@@ -365,6 +368,8 @@
 	struct page *page;
 
 	down_read(&mm->mmap_sem);
+	if (ksm_test_exit(mm))
+		goto out;
 	vma = find_vma(mm, addr);
 	if (!vma || vma->vm_start > addr)
 		goto out;
@@ -439,11 +444,11 @@
 	} else if (rmap_item->address & NODE_FLAG) {
 		unsigned char age;
 		/*
-		 * ksm_thread can and must skip the rb_erase, because
+		 * Usually ksmd can and must skip the rb_erase, because
 		 * root_unstable_tree was already reset to RB_ROOT.
-		 * But __ksm_exit has to be careful: do the rb_erase
-		 * if it's interrupting a scan, and this rmap_item was
-		 * inserted by this scan rather than left from before.
+		 * But be careful when an mm is exiting: do the rb_erase
+		 * if this rmap_item was inserted by this scan, rather
+		 * than left over from before.
 		 */
 		age = (unsigned char)(ksm_scan.seqnr - rmap_item->address);
 		BUG_ON(age > 1);
@@ -491,6 +496,8 @@
 	int err = 0;
 
 	for (addr = start; addr < end && !err; addr += PAGE_SIZE) {
+		if (ksm_test_exit(vma->vm_mm))
+			break;
 		if (signal_pending(current))
 			err = -ERESTARTSYS;
 		else
@@ -507,34 +514,50 @@
 	int err = 0;
 
 	spin_lock(&ksm_mmlist_lock);
-	mm_slot = list_entry(ksm_mm_head.mm_list.next,
+	ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
 						struct mm_slot, mm_list);
 	spin_unlock(&ksm_mmlist_lock);
 
-	while (mm_slot != &ksm_mm_head) {
+	for (mm_slot = ksm_scan.mm_slot;
+			mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
 		mm = mm_slot->mm;
 		down_read(&mm->mmap_sem);
 		for (vma = mm->mmap; vma; vma = vma->vm_next) {
+			if (ksm_test_exit(mm))
+				break;
 			if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
 				continue;
 			err = unmerge_ksm_pages(vma,
 						vma->vm_start, vma->vm_end);
-			if (err) {
-				up_read(&mm->mmap_sem);
-				goto out;
-			}
+			if (err)
+				goto error;
 		}
+
 		remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-		up_read(&mm->mmap_sem);
 
 		spin_lock(&ksm_mmlist_lock);
-		mm_slot = list_entry(mm_slot->mm_list.next,
+		ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next,
 						struct mm_slot, mm_list);
-		spin_unlock(&ksm_mmlist_lock);
+		if (ksm_test_exit(mm)) {
+			hlist_del(&mm_slot->link);
+			list_del(&mm_slot->mm_list);
+			spin_unlock(&ksm_mmlist_lock);
+
+			free_mm_slot(mm_slot);
+			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+			up_read(&mm->mmap_sem);
+			mmdrop(mm);
+		} else {
+			spin_unlock(&ksm_mmlist_lock);
+			up_read(&mm->mmap_sem);
+		}
 	}
 
 	ksm_scan.seqnr = 0;
-out:
+	return 0;
+
+error:
+	up_read(&mm->mmap_sem);
 	spin_lock(&ksm_mmlist_lock);
 	ksm_scan.mm_slot = &ksm_mm_head;
 	spin_unlock(&ksm_mmlist_lock);
@@ -755,6 +778,9 @@
 	int err = -EFAULT;
 
 	down_read(&mm1->mmap_sem);
+	if (ksm_test_exit(mm1))
+		goto out;
+
 	vma = find_vma(mm1, addr1);
 	if (!vma || vma->vm_start > addr1)
 		goto out;
@@ -796,6 +822,10 @@
 		return err;
 
 	down_read(&mm1->mmap_sem);
+	if (ksm_test_exit(mm1)) {
+		up_read(&mm1->mmap_sem);
+		goto out;
+	}
 	vma = find_vma(mm1, addr1);
 	if (!vma || vma->vm_start > addr1) {
 		up_read(&mm1->mmap_sem);
@@ -1174,7 +1204,12 @@
 
 	mm = slot->mm;
 	down_read(&mm->mmap_sem);
-	for (vma = find_vma(mm, ksm_scan.address); vma; vma = vma->vm_next) {
+	if (ksm_test_exit(mm))
+		vma = NULL;
+	else
+		vma = find_vma(mm, ksm_scan.address);
+
+	for (; vma; vma = vma->vm_next) {
 		if (!(vma->vm_flags & VM_MERGEABLE))
 			continue;
 		if (ksm_scan.address < vma->vm_start)
@@ -1183,6 +1218,8 @@
 			ksm_scan.address = vma->vm_end;
 
 		while (ksm_scan.address < vma->vm_end) {
+			if (ksm_test_exit(mm))
+				break;
 			*page = follow_page(vma, ksm_scan.address, FOLL_GET);
 			if (*page && PageAnon(*page)) {
 				flush_anon_page(vma, *page, ksm_scan.address);
@@ -1205,6 +1242,11 @@
 		}
 	}
 
+	if (ksm_test_exit(mm)) {
+		ksm_scan.address = 0;
+		ksm_scan.rmap_item = list_entry(&slot->rmap_list,
+						struct rmap_item, link);
+	}
 	/*
 	 * Nuke all the rmap_items that are above this current rmap:
 	 * because there were no VM_MERGEABLE vmas with such addresses.
@@ -1219,24 +1261,29 @@
 		 * We've completed a full scan of all vmas, holding mmap_sem
 		 * throughout, and found no VM_MERGEABLE: so do the same as
 		 * __ksm_exit does to remove this mm from all our lists now.
+		 * This applies either when cleaning up after __ksm_exit
+		 * (but beware: we can reach here even before __ksm_exit),
+		 * or when all VM_MERGEABLE areas have been unmapped (and
+		 * mmap_sem then protects against race with MADV_MERGEABLE).
 		 */
 		hlist_del(&slot->link);
 		list_del(&slot->mm_list);
+		spin_unlock(&ksm_mmlist_lock);
+
 		free_mm_slot(slot);
 		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		up_read(&mm->mmap_sem);
+		mmdrop(mm);
+	} else {
+		spin_unlock(&ksm_mmlist_lock);
+		up_read(&mm->mmap_sem);
 	}
-	spin_unlock(&ksm_mmlist_lock);
-	up_read(&mm->mmap_sem);
 
 	/* Repeat until we've completed scanning the whole list */
 	slot = ksm_scan.mm_slot;
 	if (slot != &ksm_mm_head)
 		goto next_mm;
 
-	/*
-	 * Bump seqnr here rather than at top, so that __ksm_exit
-	 * can skip rb_erase on unstable tree until we run again.
-	 */
 	ksm_scan.seqnr++;
 	return NULL;
 }
@@ -1361,6 +1408,7 @@
 	spin_unlock(&ksm_mmlist_lock);
 
 	set_bit(MMF_VM_MERGEABLE, &mm->flags);
+	atomic_inc(&mm->mm_count);
 
 	if (needs_wakeup)
 		wake_up_interruptible(&ksm_thread_wait);
@@ -1368,41 +1416,45 @@
 	return 0;
 }
 
-void __ksm_exit(struct mm_struct *mm)
+void __ksm_exit(struct mm_struct *mm,
+		struct mmu_gather **tlbp, unsigned long end)
 {
 	struct mm_slot *mm_slot;
+	int easy_to_free = 0;
 
 	/*
-	 * This process is exiting: doesn't hold and doesn't need mmap_sem;
-	 * but we do need to exclude ksmd and other exiters while we modify
-	 * the various lists and trees.
+	 * This process is exiting: if it's straightforward (as is the
+	 * case when ksmd was never running), free mm_slot immediately.
+	 * But if it's at the cursor or has rmap_items linked to it, use
+	 * mmap_sem to synchronize with any break_cows before pagetables
+	 * are freed, and leave the mm_slot on the list for ksmd to free.
+	 * Beware: ksm may already have noticed it exiting and freed the slot.
 	 */
-	mutex_lock(&ksm_thread_mutex);
+
 	spin_lock(&ksm_mmlist_lock);
 	mm_slot = get_mm_slot(mm);
-	if (!list_empty(&mm_slot->rmap_list)) {
-		spin_unlock(&ksm_mmlist_lock);
-		remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-		spin_lock(&ksm_mmlist_lock);
+	if (mm_slot && ksm_scan.mm_slot != mm_slot) {
+		if (list_empty(&mm_slot->rmap_list)) {
+			hlist_del(&mm_slot->link);
+			list_del(&mm_slot->mm_list);
+			easy_to_free = 1;
+		} else {
+			list_move(&mm_slot->mm_list,
+				  &ksm_scan.mm_slot->mm_list);
+		}
 	}
-
-	if (ksm_scan.mm_slot == mm_slot) {
-		ksm_scan.mm_slot = list_entry(
-			mm_slot->mm_list.next, struct mm_slot, mm_list);
-		ksm_scan.address = 0;
-		ksm_scan.rmap_item = list_entry(
-			&ksm_scan.mm_slot->rmap_list, struct rmap_item, link);
-		if (ksm_scan.mm_slot == &ksm_mm_head)
-			ksm_scan.seqnr++;
-	}
-
-	hlist_del(&mm_slot->link);
-	list_del(&mm_slot->mm_list);
 	spin_unlock(&ksm_mmlist_lock);
 
-	free_mm_slot(mm_slot);
-	clear_bit(MMF_VM_MERGEABLE, &mm->flags);
-	mutex_unlock(&ksm_thread_mutex);
+	if (easy_to_free) {
+		free_mm_slot(mm_slot);
+		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		mmdrop(mm);
+	} else if (mm_slot) {
+		tlb_finish_mmu(*tlbp, 0, end);
+		down_write(&mm->mmap_sem);
+		up_write(&mm->mmap_sem);
+		*tlbp = tlb_gather_mmu(mm, 1);
+	}
 }
 
 #define KSM_ATTR_RO(_name) \