thp: prepare for DAX huge pages

Add a vma_is_dax() helper macro to test whether the VMA is DAX, and use it
in zap_huge_pmd() and __split_huge_page_pmd().  Also drop the static
qualifier from set_huge_zero_page() so that it has external linkage.

[akpm@linux-foundation.org: fix build]
Signed-off-by: Matthew Wilcox <willy@linux.intel.com>
Cc: Hillf Danton <dhillf@gmail.com>
Cc: "Kirill A. Shutemov" <kirill.shutemov@linux.intel.com>
Cc: Theodore Ts'o <tytso@mit.edu>
Cc: Jan Kara <jack@suse.cz>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index ca475df..9057241 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -16,6 +16,7 @@
 #include <linux/swap.h>
 #include <linux/shrinker.h>
 #include <linux/mm_inline.h>
+#include <linux/dax.h>
 #include <linux/kthread.h>
 #include <linux/khugepaged.h>
 #include <linux/freezer.h>
@@ -794,7 +795,7 @@
 }
 
 /* Caller must hold page table lock. */
-static bool set_huge_zero_page(pgtable_t pgtable, struct mm_struct *mm,
+bool set_huge_zero_page(pgtable_t pgtable, struct mm_struct *mm,
 		struct vm_area_struct *vma, unsigned long haddr, pmd_t *pmd,
 		struct page *zero_page)
 {
@@ -1421,7 +1422,6 @@
 	int ret = 0;
 
 	if (__pmd_trans_huge_lock(pmd, vma, &ptl) == 1) {
-		struct page *page;
 		pgtable_t pgtable;
 		pmd_t orig_pmd;
 		/*
@@ -1433,13 +1433,22 @@
 		orig_pmd = pmdp_huge_get_and_clear_full(tlb->mm, addr, pmd,
 							tlb->fullmm);
 		tlb_remove_pmd_tlb_entry(tlb, pmd, addr);
-		pgtable = pgtable_trans_huge_withdraw(tlb->mm, pmd);
+		if (vma_is_dax(vma)) {
+			if (is_huge_zero_pmd(orig_pmd)) {
+				pgtable = NULL;
+			} else {
+				spin_unlock(ptl);
+				return 1;
+			}
+		} else {
+			pgtable = pgtable_trans_huge_withdraw(tlb->mm, pmd);
+		}
 		if (is_huge_zero_pmd(orig_pmd)) {
 			atomic_long_dec(&tlb->mm->nr_ptes);
 			spin_unlock(ptl);
 			put_huge_zero_page();
 		} else {
-			page = pmd_page(orig_pmd);
+			struct page *page = pmd_page(orig_pmd);
 			page_remove_rmap(page);
 			VM_BUG_ON_PAGE(page_mapcount(page) < 0, page);
 			add_mm_counter(tlb->mm, MM_ANONPAGES, -HPAGE_PMD_NR);
@@ -1448,7 +1457,8 @@
 			spin_unlock(ptl);
 			tlb_remove_page(tlb, page);
 		}
-		pte_free(tlb->mm, pgtable);
+		if (pgtable)
+			pte_free(tlb->mm, pgtable);
 		ret = 1;
 	}
 	return ret;
@@ -2914,7 +2924,7 @@
 		pmd_t *pmd)
 {
 	spinlock_t *ptl;
-	struct page *page;
+	struct page *page = NULL;
 	struct mm_struct *mm = vma->vm_mm;
 	unsigned long haddr = address & HPAGE_PMD_MASK;
 	unsigned long mmun_start;	/* For mmu_notifiers */
@@ -2927,25 +2937,25 @@
 again:
 	mmu_notifier_invalidate_range_start(mm, mmun_start, mmun_end);
 	ptl = pmd_lock(mm, pmd);
-	if (unlikely(!pmd_trans_huge(*pmd))) {
-		spin_unlock(ptl);
-		mmu_notifier_invalidate_range_end(mm, mmun_start, mmun_end);
-		return;
-	}
-	if (is_huge_zero_pmd(*pmd)) {
+	if (unlikely(!pmd_trans_huge(*pmd)))
+		goto unlock;
+	if (vma_is_dax(vma)) {
+		pmdp_huge_clear_flush(vma, haddr, pmd);
+	} else if (is_huge_zero_pmd(*pmd)) {
 		__split_huge_zero_page_pmd(vma, haddr, pmd);
-		spin_unlock(ptl);
-		mmu_notifier_invalidate_range_end(mm, mmun_start, mmun_end);
-		return;
+	} else {
+		page = pmd_page(*pmd);
+		VM_BUG_ON_PAGE(!page_count(page), page);
+		get_page(page);
 	}
-	page = pmd_page(*pmd);
-	VM_BUG_ON_PAGE(!page_count(page), page);
-	get_page(page);
+ unlock:
 	spin_unlock(ptl);
 	mmu_notifier_invalidate_range_end(mm, mmun_start, mmun_end);
 
-	split_huge_page(page);
+	if (!page)
+		return;
 
+	split_huge_page(page);
 	put_page(page);
 
 	/*