thp: keep huge zero page pinned until tlb flush

Andrea has found[1] a race condition on MMU-gather based TLB flush vs
split_huge_page() or shrinker which frees huge zero under us (patch 1/2
and 2/2 respectively).

With new THP refcounting, we don't need patch 1/2: mmu_gather keeps the
page pinned until flush is complete and the pin prevents the page from
being split under us.

We still need patch 2/2.  This is simplified version of Andrea's patch.
We don't need fancy encoding.

[1] http://lkml.kernel.org/r/1447938052-22165-1-git-send-email-aarcange@redhat.com

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Reported-by: Andrea Arcangeli <aarcange@redhat.com>
Reviewed-by: Andrea Arcangeli <aarcange@redhat.com>
Cc: "Aneesh Kumar K.V" <aneesh.kumar@linux.vnet.ibm.com>
Cc: Mel Gorman <mgorman@techsingularity.net>
Cc: Hugh Dickins <hughd@google.com>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Dave Hansen <dave.hansen@intel.com>
Cc: Vlastimil Babka <vbabka@suse.cz>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
index 7008623..d7b9e53 100644
--- a/include/linux/huge_mm.h
+++ b/include/linux/huge_mm.h
@@ -152,6 +152,7 @@
 }
 
 struct page *get_huge_zero_page(void);
+void put_huge_zero_page(void);
 
 #else /* CONFIG_TRANSPARENT_HUGEPAGE */
 #define HPAGE_PMD_SHIFT ({ BUILD_BUG(); 0; })
@@ -208,6 +209,10 @@
 	return false;
 }
 
+static inline void put_huge_zero_page(void)
+{
+	BUILD_BUG();
+}
 
 static inline struct page *follow_devmap_pmd(struct vm_area_struct *vma,
 		unsigned long addr, pmd_t *pmd, int flags)
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 86f9f8b..5346de0 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -232,7 +232,7 @@
 	return READ_ONCE(huge_zero_page);
 }
 
-static void put_huge_zero_page(void)
+void put_huge_zero_page(void)
 {
 	/*
 	 * Counter should never go to zero here. Only shrinker can put
@@ -1684,12 +1684,12 @@
 	if (vma_is_dax(vma)) {
 		spin_unlock(ptl);
 		if (is_huge_zero_pmd(orig_pmd))
-			put_huge_zero_page();
+			tlb_remove_page(tlb, pmd_page(orig_pmd));
 	} else if (is_huge_zero_pmd(orig_pmd)) {
 		pte_free(tlb->mm, pgtable_trans_huge_withdraw(tlb->mm, pmd));
 		atomic_long_dec(&tlb->mm->nr_ptes);
 		spin_unlock(ptl);
-		put_huge_zero_page();
+		tlb_remove_page(tlb, pmd_page(orig_pmd));
 	} else {
 		struct page *page = pmd_page(orig_pmd);
 		page_remove_rmap(page, true);
diff --git a/mm/swap.c b/mm/swap.c
index a0bc206..03aacbc 100644
--- a/mm/swap.c
+++ b/mm/swap.c
@@ -728,6 +728,11 @@
 			zone = NULL;
 		}
 
+		if (is_huge_zero_page(page)) {
+			put_huge_zero_page();
+			continue;
+		}
+
 		page = compound_head(page);
 		if (!put_page_testzero(page))
 			continue;