dax: update I/O path to do proper PMEM flushing

Update the DAX I/O path so that all operations that store data (I/O
writes, zeroing blocks, punching holes, etc.) properly synchronize the
stores to media using the PMEM API.  This ensures that the data written
by DAX is durable on media before the operation completes.

Signed-off-by: Ross Zwisler <ross.zwisler@linux.intel.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Dan Williams <dan.j.williams@intel.com>
diff --git a/fs/dax.c b/fs/dax.c
index c3e21cc..e07fecc 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -23,6 +23,7 @@
 #include <linux/memcontrol.h>
 #include <linux/mm.h>
 #include <linux/mutex.h>
+#include <linux/pmem.h>
 #include <linux/sched.h>
 #include <linux/uio.h>
 #include <linux/vmstat.h>
@@ -46,10 +47,7 @@
 			unsigned pgsz = PAGE_SIZE - offset_in_page(addr);
 			if (pgsz > count)
 				pgsz = count;
-			if (pgsz < PAGE_SIZE)
-				memset(addr, 0, pgsz);
-			else
-				clear_page(addr);
+			clear_pmem((void __pmem *)addr, pgsz);
 			addr += pgsz;
 			size -= pgsz;
 			count -= pgsz;
@@ -59,6 +57,7 @@
 		}
 	} while (size);
 
+	wmb_pmem();
 	return 0;
 }
 EXPORT_SYMBOL_GPL(dax_clear_blocks);
@@ -70,15 +69,16 @@
 	return bdev_direct_access(bh->b_bdev, sector, addr, &pfn, bh->b_size);
 }
 
+/* the clear_pmem() calls are ordered by a wmb_pmem() in the caller */
 static void dax_new_buf(void *addr, unsigned size, unsigned first, loff_t pos,
 			loff_t end)
 {
 	loff_t final = end - pos + first; /* The final byte of the buffer */
 
 	if (first > 0)
-		memset(addr, 0, first);
+		clear_pmem((void __pmem *)addr, first);
 	if (final < size)
-		memset(addr + final, 0, size - final);
+		clear_pmem((void __pmem *)addr + final, size - final);
 }
 
 static bool buffer_written(struct buffer_head *bh)
@@ -108,12 +108,13 @@
 	loff_t bh_max = start;
 	void *addr;
 	bool hole = false;
+	bool need_wmb = false;
 
 	if (iov_iter_rw(iter) != WRITE)
 		end = min(end, i_size_read(inode));
 
 	while (pos < end) {
-		unsigned len;
+		size_t len;
 		if (pos == max) {
 			unsigned blkbits = inode->i_blkbits;
 			sector_t block = pos >> blkbits;
@@ -145,18 +146,22 @@
 				retval = dax_get_addr(bh, &addr, blkbits);
 				if (retval < 0)
 					break;
-				if (buffer_unwritten(bh) || buffer_new(bh))
+				if (buffer_unwritten(bh) || buffer_new(bh)) {
 					dax_new_buf(addr, retval, first, pos,
 									end);
+					need_wmb = true;
+				}
 				addr += first;
 				size = retval - first;
 			}
 			max = min(pos + size, end);
 		}
 
-		if (iov_iter_rw(iter) == WRITE)
-			len = copy_from_iter_nocache(addr, max - pos, iter);
-		else if (!hole)
+		if (iov_iter_rw(iter) == WRITE) {
+			len = copy_from_iter_pmem((void __pmem *)addr,
+					max - pos, iter);
+			need_wmb = true;
+		} else if (!hole)
 			len = copy_to_iter(addr, max - pos, iter);
 		else
 			len = iov_iter_zero(max - pos, iter);
@@ -168,6 +173,9 @@
 		addr += len;
 	}
 
+	if (need_wmb)
+		wmb_pmem();
+
 	return (pos == start) ? retval : pos - start;
 }
 
@@ -303,8 +311,10 @@
 		goto out;
 	}
 
-	if (buffer_unwritten(bh) || buffer_new(bh))
-		clear_page(addr);
+	if (buffer_unwritten(bh) || buffer_new(bh)) {
+		clear_pmem((void __pmem *)addr, PAGE_SIZE);
+		wmb_pmem();
+	}
 
 	error = vm_insert_mixed(vma, vaddr, pfn);
 
@@ -542,7 +552,8 @@
 		err = dax_get_addr(&bh, &addr, inode->i_blkbits);
 		if (err < 0)
 			return err;
-		memset(addr + offset, 0, length);
+		clear_pmem((void __pmem *)addr + offset, length);
+		wmb_pmem();
 	}
 
 	return 0;