ext4: fix races between page faults and hole punching

Currently, page faults and hole punching are completely unsynchronized.
This can result in a page fault faulting a page into a range that we
are punching after truncate_pagecache_range() has been called, and thus
we can end up with a page mapped to disk blocks that will shortly be
freed. Filesystem corruption will shortly follow. Note that the same
race is avoided for truncate by checking the page fault offset against
i_size, but there isn't a similar mechanism available for punching holes.

Fix the problem by creating a new rw semaphore i_mmap_sem in the inode and
grabbing it for writing over truncate, hole punching, and other functions
removing blocks from the extent tree, and for reading over page faults. We
cannot easily use i_data_sem for this since that ranks below transaction
start and we need something ranking above it so that it can be held over
the whole truncate / hole punching operation. Also remove various
workarounds we had in the code to reduce the race window in which a page
fault could have created pages with stale mapping information.

Signed-off-by: Jan Kara <jack@suse.com>
Signed-off-by: Theodore Ts'o <tytso@mit.edu>
diff --git a/fs/ext4/file.c b/fs/ext4/file.c
index 113837e..0d24ebc 100644
--- a/fs/ext4/file.c
+++ b/fs/ext4/file.c
@@ -209,15 +209,18 @@
 {
 	int result;
 	handle_t *handle = NULL;
-	struct super_block *sb = file_inode(vma->vm_file)->i_sb;
+	struct inode *inode = file_inode(vma->vm_file);
+	struct super_block *sb = inode->i_sb;
 	bool write = vmf->flags & FAULT_FLAG_WRITE;
 
 	if (write) {
 		sb_start_pagefault(sb);
 		file_update_time(vma->vm_file);
+		down_read(&EXT4_I(inode)->i_mmap_sem);
 		handle = ext4_journal_start_sb(sb, EXT4_HT_WRITE_PAGE,
 						EXT4_DATA_TRANS_BLOCKS(sb));
-	}
+	} else
+		down_read(&EXT4_I(inode)->i_mmap_sem);
 
 	if (IS_ERR(handle))
 		result = VM_FAULT_SIGBUS;
@@ -228,8 +231,10 @@
 	if (write) {
 		if (!IS_ERR(handle))
 			ext4_journal_stop(handle);
+		up_read(&EXT4_I(inode)->i_mmap_sem);
 		sb_end_pagefault(sb);
-	}
+	} else
+		up_read(&EXT4_I(inode)->i_mmap_sem);
 
 	return result;
 }
@@ -246,10 +251,12 @@
 	if (write) {
 		sb_start_pagefault(sb);
 		file_update_time(vma->vm_file);
+		down_read(&EXT4_I(inode)->i_mmap_sem);
 		handle = ext4_journal_start_sb(sb, EXT4_HT_WRITE_PAGE,
 				ext4_chunk_trans_blocks(inode,
 							PMD_SIZE / PAGE_SIZE));
-	}
+	} else
+		down_read(&EXT4_I(inode)->i_mmap_sem);
 
 	if (IS_ERR(handle))
 		result = VM_FAULT_SIGBUS;
@@ -260,30 +267,71 @@
 	if (write) {
 		if (!IS_ERR(handle))
 			ext4_journal_stop(handle);
+		up_read(&EXT4_I(inode)->i_mmap_sem);
 		sb_end_pagefault(sb);
-	}
+	} else
+		up_read(&EXT4_I(inode)->i_mmap_sem);
 
 	return result;
 }
 
 static int ext4_dax_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-	return dax_mkwrite(vma, vmf, ext4_get_block_dax,
-				ext4_end_io_unwritten);
+	int err;
+	struct inode *inode = file_inode(vma->vm_file);
+
+	sb_start_pagefault(inode->i_sb);
+	file_update_time(vma->vm_file);
+	down_read(&EXT4_I(inode)->i_mmap_sem);
+	err = __dax_mkwrite(vma, vmf, ext4_get_block_dax,
+			    ext4_end_io_unwritten);
+	up_read(&EXT4_I(inode)->i_mmap_sem);
+	sb_end_pagefault(inode->i_sb);
+
+	return err;
+}
+
+/*
+ * Handle write fault for VM_MIXEDMAP mappings. Similarly to ext4_dax_mkwrite()
+ * handler we check for races against truncate. Note that since we cycle through
+ * i_mmap_sem, we are sure that also any hole punching that began before we
+ * were called is finished by now and so if it included part of the file we
+ * are working on, our pte will get unmapped and the check for pte_same() in
+ * wp_pfn_shared() fails. Thus fault gets retried and things work out as
+ * desired.
+ */
+static int ext4_dax_pfn_mkwrite(struct vm_area_struct *vma,
+				struct vm_fault *vmf)
+{
+	struct inode *inode = file_inode(vma->vm_file);
+	struct super_block *sb = inode->i_sb;
+	int ret = VM_FAULT_NOPAGE;
+	loff_t size;
+
+	sb_start_pagefault(sb);
+	file_update_time(vma->vm_file);
+	down_read(&EXT4_I(inode)->i_mmap_sem);
+	size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
+	if (vmf->pgoff >= size)
+		ret = VM_FAULT_SIGBUS;
+	up_read(&EXT4_I(inode)->i_mmap_sem);
+	sb_end_pagefault(sb);
+
+	return ret;
 }
 
 static const struct vm_operations_struct ext4_dax_vm_ops = {
 	.fault		= ext4_dax_fault,
 	.pmd_fault	= ext4_dax_pmd_fault,
 	.page_mkwrite	= ext4_dax_mkwrite,
-	.pfn_mkwrite	= dax_pfn_mkwrite,
+	.pfn_mkwrite	= ext4_dax_pfn_mkwrite,
 };
 #else
 #define ext4_dax_vm_ops	ext4_file_vm_ops
 #endif
 
 static const struct vm_operations_struct ext4_file_vm_ops = {
-	.fault		= filemap_fault,
+	.fault		= ext4_filemap_fault,
 	.map_pages	= filemap_map_pages,
 	.page_mkwrite   = ext4_page_mkwrite,
 };