rbd: use reference counting for the snap context

This prevents a race between requests with a given snap context and
header updates that free it. The osd client already expects the
snap context to be reference counted, since it get()s it in
ceph_osdc_build_request() and put()s it when the request completes.
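
For reference, the counting this relies on is the usual get/put pattern
on ceph_snap_context; roughly (a simplified sketch, not the exact
libceph definitions):

	struct ceph_snap_context *ceph_get_snap_context(struct ceph_snap_context *sc)
	{
		if (sc)
			atomic_inc(&sc->nref);	/* take a reference for the caller */
		return sc;
	}

	void ceph_put_snap_context(struct ceph_snap_context *sc)
	{
		if (!sc)
			return;
		if (atomic_dec_and_test(&sc->nref))
			kfree(sc);	/* snaps[] is allocated inline with sc */
	}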

Also remove the second down_read()/up_read() on header_rwsem in
rbd_do_request(), which wasn't actually preventing this race or
protecting any other data.
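
With this change the request path pins the snap context once, under
header_rwsem, and holds that reference for the lifetime of the block
request, along the lines of (abridged from the request-handling hunk
below):

	down_read(&rbd_dev->header_rwsem);
	snapc = ceph_get_snap_context(rbd_dev->header.snapc);
	up_read(&rbd_dev->header_rwsem);

	/* ... build and submit OSD requests against snapc ... */

	ceph_put_snap_context(snapc);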

Signed-off-by: Josh Durgin <josh.durgin@dreamhost.com>
Reviewed-by: Alex Elder <elder@inktank.com>
diff --git a/drivers/block/rbd.c b/drivers/block/rbd.c
index a6bbda2..988f944 100644
--- a/drivers/block/rbd.c
+++ b/drivers/block/rbd.c
@@ -626,7 +626,7 @@
 	kfree(header->object_prefix);
 	kfree(header->snap_sizes);
 	kfree(header->snap_names);
-	kfree(header->snapc);
+	ceph_put_snap_context(header->snapc);
 }
 
 /*
@@ -902,13 +902,10 @@
 	dout("rbd_do_request object_name=%s ofs=%lld len=%lld\n",
 		object_name, len, ofs);
 
-	down_read(&rbd_dev->header_rwsem);
-
 	osdc = &rbd_dev->rbd_client->client->osdc;
 	req = ceph_osdc_alloc_request(osdc, flags, snapc, ops,
 					false, GFP_NOIO, pages, bio);
 	if (!req) {
-		up_read(&rbd_dev->header_rwsem);
 		ret = -ENOMEM;
 		goto done_pages;
 	}
@@ -942,7 +939,6 @@
 				snapc,
 				&mtime,
 				req->r_oid, req->r_oid_len);
-	up_read(&rbd_dev->header_rwsem);
 
 	if (linger_req) {
 		ceph_osdc_set_request_linger(osdc, req);
@@ -1448,6 +1444,7 @@
 		u64 ofs;
 		int num_segs, cur_seg = 0;
 		struct rbd_req_coll *coll;
+		struct ceph_snap_context *snapc;
 
 		/* peek at request from block layer */
 		if (!rq)
@@ -1474,21 +1471,20 @@
 
 		spin_unlock_irq(q->queue_lock);
 
-		if (rbd_dev->snap_id != CEPH_NOSNAP) {
-			bool snap_exists;
+		down_read(&rbd_dev->header_rwsem);
 
-			down_read(&rbd_dev->header_rwsem);
-			snap_exists = rbd_dev->snap_exists;
+		if (rbd_dev->snap_id != CEPH_NOSNAP && !rbd_dev->snap_exists) {
 			up_read(&rbd_dev->header_rwsem);
-
-			if (!snap_exists) {
-				dout("request for non-existent snapshot");
-				spin_lock_irq(q->queue_lock);
-				__blk_end_request_all(rq, -ENXIO);
-				continue;
-			}
+			dout("request for non-existent snapshot");
+			spin_lock_irq(q->queue_lock);
+			__blk_end_request_all(rq, -ENXIO);
+			continue;
 		}
 
+		snapc = ceph_get_snap_context(rbd_dev->header.snapc);
+
+		up_read(&rbd_dev->header_rwsem);
+
 		dout("%s 0x%x bytes at 0x%llx\n",
 		     do_write ? "write" : "read",
 		     size, blk_rq_pos(rq) * SECTOR_SIZE);
@@ -1498,6 +1494,7 @@
 		if (!coll) {
 			spin_lock_irq(q->queue_lock);
 			__blk_end_request_all(rq, -ENOMEM);
+			ceph_put_snap_context(snapc);
 			continue;
 		}
 
@@ -1521,7 +1518,7 @@
 			/* init OSD command: write or read */
 			if (do_write)
 				rbd_req_write(rq, rbd_dev,
-					      rbd_dev->header.snapc,
+					      snapc,
 					      ofs,
 					      op_size, bio,
 					      coll, cur_seg);
@@ -1544,6 +1541,8 @@
 		if (bp)
 			bio_pair_release(bp);
 		spin_lock_irq(q->queue_lock);
+
+		ceph_put_snap_context(snapc);
 	}
 }
 
@@ -1744,7 +1743,8 @@
 	/* rbd_dev->header.object_prefix shouldn't change */
 	kfree(rbd_dev->header.snap_sizes);
 	kfree(rbd_dev->header.snap_names);
-	kfree(rbd_dev->header.snapc);
+	/* osd requests may still refer to snapc */
+	ceph_put_snap_context(rbd_dev->header.snapc);
 
 	rbd_dev->header.image_size = h.image_size;
 	rbd_dev->header.total_snaps = h.total_snaps;