NVMe: Handle failures from memory allocations in nvme_setup_prps

If any of the memory allocations in nvme_setup_prps fail, handle it by
modifying the passed-in data length to reflect the number of bytes we are
actually able to send.  Also allow the caller to specify the GFP flags
they need; for user-initiated commands, we can use GFP_KERNEL allocations.
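
A minimal sketch of the revised interface, mirroring the hunks below
(the I/O path passes GFP_ATOMIC; the user-command paths can pass
GFP_KERNEL):

	struct nvme_prps *nvme_setup_prps(struct nvme_dev *dev,
				struct nvme_common_command *cmd,
				struct scatterlist *sg, int *len,
				gfp_t gfp);

	/* e.g. from the bio submission path, in atomic context */
	nbio->prps = nvme_setup_prps(nvmeq->dev, &cmnd->common, nbio->sg,
							&length, GFP_ATOMIC);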

The various callers are updated to handle a shortened length.  The main
I/O path already copes with it, since the same thing can happen when
nvme_map_bio is unable to map all the segments of the I/O; the other
callers return -ENOMEM instead of performing a partial I/O.
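
Roughly, the pattern the non-I/O callers follow (simplified from the
ioctl hunks below; tmplen is the scratch copy of the requested length):

	int tmplen = length;

	prps = nvme_setup_prps(dev, &cmd->common, sg, &tmplen, GFP_KERNEL);
	if (tmplen != length)
		err = -ENOMEM;	/* PRP list incomplete; refuse a partial I/O */
	else
		err = nvme_submit_admin_cmd(dev, cmd, NULL);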

Reported-by: Andi Kleen <andi@firstfloor.org>
Signed-off-by: Matthew Wilcox <matthew.r.wilcox@intel.com>
diff --git a/drivers/block/nvme.c b/drivers/block/nvme.c
index 79012c5..ddc21ba 100644
--- a/drivers/block/nvme.c
+++ b/drivers/block/nvme.c
@@ -329,9 +329,11 @@
 /* length is in bytes */
 static struct nvme_prps *nvme_setup_prps(struct nvme_dev *dev,
 					struct nvme_common_command *cmd,
-					struct scatterlist *sg, int length)
+					struct scatterlist *sg, int *len,
+					gfp_t gfp)
 {
 	struct dma_pool *pool;
+	int length = *len;
 	int dma_len = sg_dma_len(sg);
 	u64 dma_addr = sg_dma_address(sg);
 	int offset = offset_in_page(dma_addr);
@@ -361,7 +363,12 @@
 
 	nprps = DIV_ROUND_UP(length, PAGE_SIZE);
 	npages = DIV_ROUND_UP(8 * nprps, PAGE_SIZE);
-	prps = kmalloc(sizeof(*prps) + sizeof(__le64 *) * npages, GFP_ATOMIC);
+	prps = kmalloc(sizeof(*prps) + sizeof(__le64 *) * npages, gfp);
+	if (!prps) {
+		cmd->prp2 = cpu_to_le64(dma_addr);
+		*len = (*len - length) + PAGE_SIZE;
+		return prps;
+	}
 	prp_page = 0;
 	if (nprps <= (256 / 8)) {
 		pool = dev->prp_small_pool;
@@ -371,7 +378,13 @@
 		prps->npages = npages;
 	}
 
-	prp_list = dma_pool_alloc(pool, GFP_ATOMIC, &prp_dma);
+	prp_list = dma_pool_alloc(pool, gfp, &prp_dma);
+	if (!prp_list) {
+		cmd->prp2 = cpu_to_le64(dma_addr);
+		*len = (*len - length) + PAGE_SIZE;
+		kfree(prps);
+		return NULL;
+	}
 	prps->list[prp_page++] = prp_list;
 	prps->first_dma = prp_dma;
 	cmd->prp2 = cpu_to_le64(prp_dma);
@@ -379,7 +392,11 @@
 	for (;;) {
 		if (i == PAGE_SIZE / 8) {
 			__le64 *old_prp_list = prp_list;
-			prp_list = dma_pool_alloc(pool, GFP_ATOMIC, &prp_dma);
+			prp_list = dma_pool_alloc(pool, gfp, &prp_dma);
+			if (!prp_list) {
+				*len = (*len - length);
+				return prps;
+			}
 			prps->list[prp_page++] = prp_list;
 			prp_list[0] = old_prp_list[i - 1];
 			old_prp_list[i - 1] = cpu_to_le64(prp_dma);
@@ -525,7 +542,7 @@
 	cmnd->rw.command_id = cmdid;
 	cmnd->rw.nsid = cpu_to_le32(ns->ns_id);
 	nbio->prps = nvme_setup_prps(nvmeq->dev, &cmnd->common, nbio->sg,
-								length);
+							&length, GFP_ATOMIC);
 	cmnd->rw.slba = cpu_to_le64(bio->bi_sector >> (ns->lba_shift - 9));
 	cmnd->rw.length = cpu_to_le16((length >> ns->lba_shift) - 1);
 	cmnd->rw.control = cpu_to_le16(control);
@@ -1009,15 +1026,18 @@
 					unsigned long addr, unsigned length,
 					struct nvme_command *cmd)
 {
-	int err, nents;
+	int err, nents, tmplen = length;
 	struct scatterlist *sg;
 	struct nvme_prps *prps;
 
 	nents = nvme_map_user_pages(dev, 0, addr, length, &sg);
 	if (nents < 0)
 		return nents;
-	prps = nvme_setup_prps(dev, &cmd->common, sg, length);
-	err = nvme_submit_admin_cmd(dev, cmd, NULL);
+	prps = nvme_setup_prps(dev, &cmd->common, sg, &tmplen, GFP_KERNEL);
+	if (tmplen != length)
+		err = -ENOMEM;
+	else
+		err = nvme_submit_admin_cmd(dev, cmd, NULL);
 	nvme_unmap_user_pages(dev, 0, addr, length, sg, nents);
 	nvme_free_prps(dev, prps);
 	return err ? -EIO : 0;
@@ -1086,7 +1106,7 @@
 	c.rw.apptag = io.apptag;
 	c.rw.appmask = io.appmask;
 	/* XXX: metadata */
-	prps = nvme_setup_prps(dev, &c.common, sg, length);
+	prps = nvme_setup_prps(dev, &c.common, sg, &length, GFP_KERNEL);
 
 	nvmeq = get_nvmeq(ns);
 	/*
@@ -1096,7 +1116,10 @@
 	 * additional races since q_lock already protects against other CPUs.
 	 */
 	put_nvmeq(nvmeq);
-	status = nvme_submit_sync_cmd(nvmeq, &c, NULL, IO_TIMEOUT);
+	if (length != (io.nblocks + 1) << ns->lba_shift)
+		status = -ENOMEM;
+	else
+		status = nvme_submit_sync_cmd(nvmeq, &c, NULL, IO_TIMEOUT);
 
 	nvme_unmap_user_pages(dev, io.opcode & 1, io.addr, length, sg, nents);
 	nvme_free_prps(dev, prps);
@@ -1109,7 +1132,7 @@
 	struct nvme_dev *dev = ns->dev;
 	struct nvme_dlfw dlfw;
 	struct nvme_command c;
-	int nents, status;
+	int nents, status, length;
 	struct scatterlist *sg;
 	struct nvme_prps *prps;
 
@@ -1117,8 +1140,9 @@
 		return -EFAULT;
 	if (dlfw.length >= (1 << 30))
 		return -EINVAL;
+	length = dlfw.length * 4;
 
-	nents = nvme_map_user_pages(dev, 1, dlfw.addr, dlfw.length * 4, &sg);
+	nents = nvme_map_user_pages(dev, 1, dlfw.addr, length, &sg);
 	if (nents < 0)
 		return nents;
 
@@ -1126,9 +1150,11 @@
 	c.dlfw.opcode = nvme_admin_download_fw;
 	c.dlfw.numd = cpu_to_le32(dlfw.length);
 	c.dlfw.offset = cpu_to_le32(dlfw.offset);
-	prps = nvme_setup_prps(dev, &c.common, sg, dlfw.length * 4);
-
-	status = nvme_submit_admin_cmd(dev, &c, NULL);
+	prps = nvme_setup_prps(dev, &c.common, sg, &length, GFP_KERNEL);
+	if (length != dlfw.length * 4)
+		status = -ENOMEM;
+	else
+		status = nvme_submit_admin_cmd(dev, &c, NULL);
 	nvme_unmap_user_pages(dev, 0, dlfw.addr, dlfw.length * 4, sg, nents);
 	nvme_free_prps(dev, prps);
 	return status;