dax: update to new mmu_notifier semantic

Replace all mmu_notifier_invalidate_page() calls by *_invalidate_range()
and make sure it is bracketed by calls to *_invalidate_range_start()/end().

Note that because we can not presume the pmd value or pte value we have
to assume the worst and unconditionaly report an invalidation as
happening.

Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Cc: Dan Williams <dan.j.williams@intel.com>
Cc: Ross Zwisler <ross.zwisler@linux.intel.com>
Cc: Bernhard Held <berny156@gmx.de>
Cc: Adam Borowski <kilobyte@angband.pl>
Cc: Andrea Arcangeli <aarcange@redhat.com>
Cc: Radim Krčmář <rkrcmar@redhat.com>
Cc: Wanpeng Li <kernellwp@gmail.com>
Cc: Paolo Bonzini <pbonzini@redhat.com>
Cc: Takashi Iwai <tiwai@suse.de>
Cc: Nadav Amit <nadav.amit@gmail.com>
Cc: Mike Galbraith <efault@gmx.de>
Cc: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Cc: axie <axie@amd.com>
Cc: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/fs/dax.c b/fs/dax.c
index 865d42c..ab925dc 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -646,11 +646,10 @@
 	pte_t pte, *ptep = NULL;
 	pmd_t *pmdp = NULL;
 	spinlock_t *ptl;
-	bool changed;
 
 	i_mmap_lock_read(mapping);
 	vma_interval_tree_foreach(vma, &mapping->i_mmap, index, index) {
-		unsigned long address;
+		unsigned long address, start, end;
 
 		cond_resched();
 
@@ -658,8 +657,13 @@
 			continue;
 
 		address = pgoff_address(index, vma);
-		changed = false;
-		if (follow_pte_pmd(vma->vm_mm, address, &ptep, &pmdp, &ptl))
+
+		/*
+		 * Note because we provide start/end to follow_pte_pmd it will
+		 * call mmu_notifier_invalidate_range_start() on our behalf
+		 * before taking any lock.
+		 */
+		if (follow_pte_pmd(vma->vm_mm, address, &start, &end, &ptep, &pmdp, &ptl))
 			continue;
 
 		if (pmdp) {
@@ -676,7 +680,7 @@
 			pmd = pmd_wrprotect(pmd);
 			pmd = pmd_mkclean(pmd);
 			set_pmd_at(vma->vm_mm, address, pmdp, pmd);
-			changed = true;
+			mmu_notifier_invalidate_range(vma->vm_mm, start, end);
 unlock_pmd:
 			spin_unlock(ptl);
 #endif
@@ -691,13 +695,12 @@
 			pte = pte_wrprotect(pte);
 			pte = pte_mkclean(pte);
 			set_pte_at(vma->vm_mm, address, ptep, pte);
-			changed = true;
+			mmu_notifier_invalidate_range(vma->vm_mm, start, end);
 unlock_pte:
 			pte_unmap_unlock(ptep, ptl);
 		}
 
-		if (changed)
-			mmu_notifier_invalidate_page(vma->vm_mm, address);
+		mmu_notifier_invalidate_range_end(vma->vm_mm, start, end);
 	}
 	i_mmap_unlock_read(mapping);
 }