block: fix disk->part[] dereferencing race

disk->part[] is protected by its matching bdev's lock.  However,
non-critical accesses like collecting stats and printing out sysfs and
proc information used to be performed without any locking.  Because
partitions can come and go dynamically, a partition can go away
underneath one of those non-critical accesses.  As some of those
accesses are writes, this can theoretically lead to silent corruption.

This patch fixes the race by using RCU for the partition array and the
device reference count to pin partitions.

* Rename disk->part[] to disk->__part[] to make sure no one outside
  the genhd layer proper accesses it directly.

* Use RCU for disk->__part[] dereferencing.

* Implement disk_{get|put}_part(), which can be used to get and put
  partitions from a gendisk, respectively.

* Implement partition iterators (disk_part_iter_*()) to help iterate
  through all partitions safely.

* Functions which require RCU readlock are marked with _rcu suffix.

* Use disk_put_part() in __blkdev_put() instead of directly putting
  the contained kobject.

Signed-off-by: Tejun Heo <tj@kernel.org>
Signed-off-by: Jens Axboe <jens.axboe@oracle.com>
diff --git a/block/blk-core.c b/block/blk-core.c
index a0dc2e7..d6128d9a 100644
--- a/block/blk-core.c
+++ b/block/blk-core.c
@@ -60,7 +60,9 @@
 	if (!blk_fs_request(rq) || !rq->rq_disk)
 		return;
 
-	part = disk_map_sector(rq->rq_disk, rq->sector);
+	rcu_read_lock();
+
+	part = disk_map_sector_rcu(rq->rq_disk, rq->sector);
 	if (!new_io)
 		__all_stat_inc(rq->rq_disk, part, merges[rw], rq->sector);
 	else {
@@ -71,6 +73,8 @@
 			part->in_flight++;
 		}
 	}
+
+	rcu_read_unlock();
 }
 
 void blk_queue_congestion_threshold(struct request_queue *q)
@@ -1557,12 +1561,14 @@
 	}
 
 	if (blk_fs_request(req) && req->rq_disk) {
-		struct hd_struct *part =
-			disk_map_sector(req->rq_disk, req->sector);
 		const int rw = rq_data_dir(req);
+		struct hd_struct *part;
 
+		rcu_read_lock();
+		part = disk_map_sector_rcu(req->rq_disk, req->sector);
 		all_stat_add(req->rq_disk, part, sectors[rw],
 				nr_bytes >> 9, req->sector);
+		rcu_read_unlock();
 	}
 
 	total_bytes = bio_nbytes = 0;
@@ -1746,7 +1752,11 @@
 	if (disk && blk_fs_request(req) && req != &req->q->bar_rq) {
 		unsigned long duration = jiffies - req->start_time;
 		const int rw = rq_data_dir(req);
-		struct hd_struct *part = disk_map_sector(disk, req->sector);
+		struct hd_struct *part;
+
+		rcu_read_lock();
+
+		part = disk_map_sector_rcu(disk, req->sector);
 
 		__all_stat_inc(disk, part, ios[rw], req->sector);
 		__all_stat_add(disk, part, ticks[rw], duration, req->sector);
@@ -1756,6 +1766,8 @@
 			part_round_stats(part);
 			part->in_flight--;
 		}
+
+		rcu_read_unlock();
 	}
 
 	if (req->end_io)
diff --git a/block/blk-merge.c b/block/blk-merge.c
index 9b17da6..eb2a3ca 100644
--- a/block/blk-merge.c
+++ b/block/blk-merge.c
@@ -387,14 +387,19 @@
 	elv_merge_requests(q, req, next);
 
 	if (req->rq_disk) {
-		struct hd_struct *part =
-			disk_map_sector(req->rq_disk, req->sector);
+		struct hd_struct *part;
+
+		rcu_read_lock();
+
+		part = disk_map_sector_rcu(req->rq_disk, req->sector);
 		disk_round_stats(req->rq_disk);
 		req->rq_disk->in_flight--;
 		if (part) {
 			part_round_stats(part);
 			part->in_flight--;
 		}
+
+		rcu_read_unlock();
 	}
 
 	req->ioprio = ioprio_best(req->ioprio, next->ioprio);
diff --git a/block/genhd.c b/block/genhd.c
index fa32d09..b431d65 100644
--- a/block/genhd.c
+++ b/block/genhd.c
@@ -26,6 +26,158 @@
 
 static struct device_type disk_type;
 
+/**
+ * disk_get_part - look up a partition and take a reference to it
+ * @disk: disk to look the partition up on
+ * @partno: 1-based partition number (0 is the whole disk and is rejected)
+ *
+ * Look up partition @partno on @disk under rcu_read_lock().  If found,
+ * pin it with get_device(); the caller must drop it with disk_put_part().
+ *
+ * CONTEXT:
+ * Don't care.
+ *
+ * RETURNS:
+ * Referenced partition on success, NULL if out of range or not present.
+ */
+struct hd_struct *disk_get_part(struct gendisk *disk, int partno)
+{
+	struct hd_struct *part;
+
+	if (unlikely(partno < 1 || partno > disk_max_parts(disk)))
+		return NULL;
+	rcu_read_lock();
+	part = rcu_dereference(disk->__part[partno - 1]);
+	if (part)
+		get_device(&part->dev);	/* pin before leaving RCU section */
+	rcu_read_unlock();
+
+	return part;
+}
+EXPORT_SYMBOL_GPL(disk_get_part);
+
+/**
+ * disk_part_iter_init - initialize partition iterator
+ * @piter: iterator to initialize
+ * @disk: disk to iterate over
+ * @flags: DISK_PITER_REVERSE and/or DISK_PITER_INCL_EMPTY
+ *
+ * Prepare @piter for disk_part_iter_next(); finish with disk_part_iter_exit().
+ *
+ * CONTEXT:
+ * Don't care.
+ */
+void disk_part_iter_init(struct disk_part_iter *piter, struct gendisk *disk,
+			  unsigned int flags)
+{
+	piter->disk = disk;
+	piter->part = NULL;	/* nothing to put on the first _next() */
+
+	if (flags & DISK_PITER_REVERSE)
+		piter->idx = disk_max_parts(piter->disk) - 1;
+	else
+		piter->idx = 0;
+
+	piter->flags = flags;
+}
+EXPORT_SYMBOL_GPL(disk_part_iter_init);
+
+/**
+ * disk_part_iter_next - proceed iterator to the next partition and return it
+ * @piter: iterator of interest
+ *
+ * Put the previous partition and return the next one (referenced) or NULL.
+ *
+ * CONTEXT:
+ * Don't care.
+ */
+struct hd_struct *disk_part_iter_next(struct disk_part_iter *piter)
+{
+	int inc, end;
+
+	/* put the last partition */
+	disk_put_part(piter->part);
+	piter->part = NULL;
+
+	rcu_read_lock();
+
+	/* determine iteration parameters */
+	if (piter->flags & DISK_PITER_REVERSE) {
+		inc = -1;
+		end = -1;
+	} else {
+		inc = 1;
+		end = disk_max_parts(piter->disk);
+	}
+
+	/* iterate to the next partition */
+	for (; piter->idx != end; piter->idx += inc) {
+		struct hd_struct *part;
+
+		part = rcu_dereference(piter->disk->__part[piter->idx]);
+		if (!part)
+			continue;
+		if (!(piter->flags & DISK_PITER_INCL_EMPTY) && !part->nr_sects)
+			continue;
+
+		get_device(&part->dev);	/* put on next call or in _exit() */
+		piter->part = part;
+		piter->idx += inc;	/* resume past this slot next time */
+		break;
+	}
+
+	rcu_read_unlock();
+
+	return piter->part;
+}
+EXPORT_SYMBOL_GPL(disk_part_iter_next);
+
+/**
+ * disk_part_iter_exit - finish up partition iteration
+ * @piter: iterator of interest
+ *
+ * Put the last partition returned by @piter; also used to bail out early.
+ *
+ * CONTEXT:
+ * Don't care.
+ */
+void disk_part_iter_exit(struct disk_part_iter *piter)
+{
+	disk_put_part(piter->part);
+	piter->part = NULL;
+}
+EXPORT_SYMBOL_GPL(disk_part_iter_exit);
+
+/**
+ * disk_map_sector_rcu - map sector to partition
+ * @disk: gendisk of interest
+ * @sector: sector to map
+ *
+ * Find out which partition @sector maps to on @disk.  This is
+ * primarily used for stats accounting.  No reference is taken.
+ *
+ * CONTEXT:
+ * RCU read locked.  The returned partition pointer is valid only
+ * until the caller's matching rcu_read_unlock().
+ *
+ * RETURNS:
+ * Found partition on success, NULL if there's no matching partition.
+ */
+struct hd_struct *disk_map_sector_rcu(struct gendisk *disk, sector_t sector)
+{
+	int i;
+
+	for (i = 0; i < disk_max_parts(disk); i++) {
+		struct hd_struct *part = rcu_dereference(disk->__part[i]);
+
+		if (part && part->start_sect <= sector &&
+		    sector < part->start_sect + part->nr_sects)
+			return part;
+	}
+	return NULL;
+}
+EXPORT_SYMBOL_GPL(disk_map_sector_rcu);
+
 /*
  * Can be deleted altogether. Later.
  *
@@ -245,10 +397,12 @@
 	if (partno == 0)
 		devt = disk_devt(disk);
 	else {
-		struct hd_struct *part = disk->part[partno - 1];
+		struct hd_struct *part;
 
+		part = disk_get_part(disk, partno);
 		if (part && part->nr_sects)
 			devt = part_devt(part);
+		disk_put_part(part);
 	}
 
 	if (likely(devt != MKDEV(0, 0)))
@@ -270,8 +424,9 @@
 	class_dev_iter_init(&iter, &block_class, NULL, &disk_type);
 	while ((dev = class_dev_iter_next(&iter))) {
 		struct gendisk *disk = dev_to_disk(dev);
+		struct disk_part_iter piter;
+		struct hd_struct *part;
 		char buf[BDEVNAME_SIZE];
-		int n;
 
 		/*
 		 * Don't show empty devices or things that have been
@@ -298,16 +453,13 @@
 			printk(" (driver?)\n");
 
 		/* now show the partitions */
-		for (n = 0; n < disk_max_parts(disk); ++n) {
-			struct hd_struct *part = disk->part[n];
-
-			if (!part || !part->nr_sects)
-				continue;
+		disk_part_iter_init(&piter, disk, 0);
+		while ((part = disk_part_iter_next(&piter)))
 			printk("  %02x%02x %10llu %s\n",
 			       MAJOR(part_devt(part)), MINOR(part_devt(part)),
 			       (unsigned long long)part->nr_sects >> 1,
 			       disk_name(disk, part->partno, buf));
-		}
+		disk_part_iter_exit(&piter);
 	}
 	class_dev_iter_exit(&iter);
 }
@@ -371,7 +523,8 @@
 static int show_partition(struct seq_file *seqf, void *v)
 {
 	struct gendisk *sgp = v;
-	int n;
+	struct disk_part_iter piter;
+	struct hd_struct *part;
 	char buf[BDEVNAME_SIZE];
 
 	/* Don't show non-partitionable removeable devices or empty devices */
@@ -386,17 +539,14 @@
 		MAJOR(disk_devt(sgp)), MINOR(disk_devt(sgp)),
 		(unsigned long long)get_capacity(sgp) >> 1,
 		disk_name(sgp, 0, buf));
-	for (n = 0; n < disk_max_parts(sgp); n++) {
-		struct hd_struct *part = sgp->part[n];
-		if (!part)
-			continue;
-		if (part->nr_sects == 0)
-			continue;
+
+	disk_part_iter_init(&piter, sgp, 0);
+	while ((part = disk_part_iter_next(&piter)))
 		seq_printf(seqf, "%4d  %4d %10llu %s\n",
 			   MAJOR(part_devt(part)), MINOR(part_devt(part)),
 			   (unsigned long long)part->nr_sects >> 1,
 			   disk_name(sgp, part->partno, buf));
-	}
+	disk_part_iter_exit(&piter);
 
 	return 0;
 }
@@ -571,7 +721,7 @@
 	struct gendisk *disk = dev_to_disk(dev);
 
 	kfree(disk->random);
-	kfree(disk->part);
+	kfree(disk->__part);
 	free_disk_stats(disk);
 	kfree(disk);
 }
@@ -596,8 +746,9 @@
 static int diskstats_show(struct seq_file *seqf, void *v)
 {
 	struct gendisk *gp = v;
+	struct disk_part_iter piter;
+	struct hd_struct *hd;
 	char buf[BDEVNAME_SIZE];
-	int n;
 
 	/*
 	if (&gp->dev.kobj.entry == block_class.devices.next)
@@ -624,12 +775,8 @@
 		jiffies_to_msecs(disk_stat_read(gp, time_in_queue)));
 
 	/* now show all non-0 size partitions of it */
-	for (n = 0; n < disk_max_parts(gp); n++) {
-		struct hd_struct *hd = gp->part[n];
-
-		if (!hd || !hd->nr_sects)
-			continue;
-
+	disk_part_iter_init(&piter, gp, 0);
+	while ((hd = disk_part_iter_next(&piter))) {
 		preempt_disable();
 		part_round_stats(hd);
 		preempt_enable();
@@ -650,6 +797,7 @@
 			   jiffies_to_msecs(part_stat_read(hd, time_in_queue))
 			);
 	}
+	disk_part_iter_exit(&piter);
  
 	return 0;
 }
@@ -703,12 +851,16 @@
 		if (partno == 0)
 			devt = disk_devt(disk);
 		else {
-			struct hd_struct *part = disk->part[partno - 1];
+			struct hd_struct *part;
 
-			if (!part || !part->nr_sects)
+			part = disk_get_part(disk, partno);
+			if (!part || !part->nr_sects) {
+				disk_put_part(part);
 				continue;
+			}
 
 			devt = part_devt(part);
+			disk_put_part(part);
 		}
 		break;
 	}
@@ -735,9 +887,9 @@
 		}
 		if (minors > 1) {
 			int size = (minors - 1) * sizeof(struct hd_struct *);
-			disk->part = kmalloc_node(size,
+			disk->__part = kmalloc_node(size,
 				GFP_KERNEL | __GFP_ZERO, node_id);
-			if (!disk->part) {
+			if (!disk->__part) {
 				free_disk_stats(disk);
 				kfree(disk);
 				return NULL;
@@ -798,10 +950,14 @@
 
 void set_disk_ro(struct gendisk *disk, int flag)
 {
-	int i;
+	struct disk_part_iter piter;
+	struct hd_struct *part;
+
 	disk->policy = flag;
-	for (i = 0; i < disk_max_parts(disk); i++)
-		if (disk->part[i]) disk->part[i]->policy = flag;
+	disk_part_iter_init(&piter, disk, DISK_PITER_INCL_EMPTY);
+	while ((part = disk_part_iter_next(&piter)))
+		part->policy = flag;
+	disk_part_iter_exit(&piter);
 }
 
 EXPORT_SYMBOL(set_disk_ro);
diff --git a/block/ioctl.c b/block/ioctl.c
index 403f7d7e..a5f672a 100644
--- a/block/ioctl.c
+++ b/block/ioctl.c
@@ -12,11 +12,12 @@
 {
 	struct block_device *bdevp;
 	struct gendisk *disk;
+	struct hd_struct *part;
 	struct blkpg_ioctl_arg a;
 	struct blkpg_partition p;
+	struct disk_part_iter piter;
 	long long start, length;
 	int partno;
-	int i;
 	int err;
 
 	if (!capable(CAP_SYS_ADMIN))
@@ -47,28 +48,33 @@
 			mutex_lock(&bdev->bd_mutex);
 
 			/* overlap? */
-			for (i = 0; i < disk_max_parts(disk); i++) {
-				struct hd_struct *s = disk->part[i];
-
-				if (!s)
-					continue;
-				if (!(start+length <= s->start_sect ||
-				      start >= s->start_sect + s->nr_sects)) {
+			disk_part_iter_init(&piter, disk,
+					    DISK_PITER_INCL_EMPTY);
+			while ((part = disk_part_iter_next(&piter))) {
+				if (!(start + length <= part->start_sect ||
+				      start >= part->start_sect + part->nr_sects)) {
+					disk_part_iter_exit(&piter);
 					mutex_unlock(&bdev->bd_mutex);
 					return -EBUSY;
 				}
 			}
+			disk_part_iter_exit(&piter);
+
 			/* all seems OK */
 			err = add_partition(disk, partno, start, length,
 					    ADDPART_FLAG_NONE);
 			mutex_unlock(&bdev->bd_mutex);
 			return err;
 		case BLKPG_DEL_PARTITION:
-			if (!disk->part[partno - 1])
+			part = disk_get_part(disk, partno);
+			if (!part)
 				return -ENXIO;
-			bdevp = bdget_disk(disk, partno);
+
+			bdevp = bdget(part_devt(part));
+			disk_put_part(part);
 			if (!bdevp)
 				return -ENOMEM;
+
 			mutex_lock(&bdevp->bd_mutex);
 			if (bdevp->bd_openers) {
 				mutex_unlock(&bdevp->bd_mutex);