ceph: clean up statfs

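Drop the dedicated statfs reply msgpool and instead preallocate the
reply message along with each request.  Allocate the request on the
heap and refcount it so that the reply handler cannot touch it after
the waiter has returned, which fixes a use-after-free race.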

Signed-off-by: Sage Weil <sage@newdream.net>
diff --git a/fs/ceph/mon_client.c b/fs/ceph/mon_client.c
index 8fdc011..43cfab0 100644
--- a/fs/ceph/mon_client.c
+++ b/fs/ceph/mon_client.c
@@ -393,16 +393,76 @@
 	rb_insert_color(&new->node, &monc->statfs_request_tree);
 }
 
+/*
+ * kref release callback: drop the request's references on its
+ * request and reply messages, then free the request itself.
+ */
+static void release_statfs_request(struct kref *kref)
+{
+	struct ceph_mon_statfs_request *req =
+		container_of(kref, struct ceph_mon_statfs_request, kref);
+
+	if (req->reply)
+		ceph_msg_put(req->reply);
+	if (req->request)
+		ceph_msg_put(req->request);
+	kfree(req);
+}
+
+/* drop a reference; the final put releases the request */
+static void put_statfs_request(struct ceph_mon_statfs_request *req)
+{
+	kref_put(&req->kref, release_statfs_request);
+}
+
+/* take an additional reference on the request */
+static void get_statfs_request(struct ceph_mon_statfs_request *req)
+{
+	kref_get(&req->kref);
+}
+
+/*
+ * Look up the pending request matching the incoming header's tid and
+ * hand back a reference to that request's reply message.  If no such
+ * request is registered, set *skip and return NULL.
+ */
+static struct ceph_msg *get_statfs_reply(struct ceph_connection *con,
+					 struct ceph_msg_header *hdr,
+					 int *skip)
+{
+	struct ceph_mon_client *monc = con->private;
+	struct ceph_mon_statfs_request *req;
+	u64 tid = le64_to_cpu(hdr->tid);
+	struct ceph_msg *m;
+
+	mutex_lock(&monc->mutex);
+	req = __lookup_statfs(monc, tid);
+	if (!req) {
+		dout("get_statfs_reply %llu dne\n", tid);
+		*skip = 1;
+		m = NULL;
+	} else {
+		dout("get_statfs_reply %llu got %p\n", tid, req->reply);
+		m = ceph_msg_get(req->reply);
+		/*
+		 * we don't need to track the connection reading into
+		 * this reply because we only have one open connection
+		 * at a time, ever.
+		 */
+	}
+	mutex_unlock(&monc->mutex);
+	return m;
+}
+
 static void handle_statfs_reply(struct ceph_mon_client *monc,
 				struct ceph_msg *msg)
 {
 	struct ceph_mon_statfs_request *req;
 	struct ceph_mon_statfs_reply *reply = msg->front.iov_base;
-	u64 tid;
+	u64 tid = le64_to_cpu(msg->hdr.tid);
 
 	if (msg->front.iov_len != sizeof(*reply))
 		goto bad;
-	tid = le64_to_cpu(msg->hdr.tid);
 	dout("handle_statfs_reply %p tid %llu\n", msg, tid);
 
 	mutex_lock(&monc->mutex);
@@ -410,10 +458,13 @@
 	if (req) {
 		*req->buf = reply->st;
 		req->result = 0;
+		get_statfs_request(req);
 	}
 	mutex_unlock(&monc->mutex);
-	if (req)
+	if (req) {
 		complete(&req->completion);
+		put_statfs_request(req);
+	}
 	return;
 
 bad:
@@ -422,67 +473,71 @@
 }
 
 /*
- * (re)send a statfs request
- */
-static int send_statfs(struct ceph_mon_client *monc,
-		       struct ceph_mon_statfs_request *req)
-{
-	struct ceph_msg *msg;
-	struct ceph_mon_statfs *h;
-
-	dout("send_statfs tid %llu\n", req->tid);
-	msg = ceph_msg_new(CEPH_MSG_STATFS, sizeof(*h), 0, 0, NULL);
-	if (IS_ERR(msg))
-		return PTR_ERR(msg);
-	req->request = msg;
-	msg->hdr.tid = cpu_to_le64(req->tid);
-	h = msg->front.iov_base;
-	h->monhdr.have_version = 0;
-	h->monhdr.session_mon = cpu_to_le16(-1);
-	h->monhdr.session_mon_tid = 0;
-	h->fsid = monc->monmap->fsid;
-	ceph_con_send(monc->con, msg);
-	return 0;
-}
-
-/*
  * Do a synchronous statfs().
+ *
+ * Allocate a refcounted request together with its outgoing message and
+ * the reply it will be read into, register it, send it, and sleep
+ * (interruptibly) until the reply handler completes it.  Returns the
+ * request's result, or a negative error if allocation fails or the
+ * wait is interrupted.
  */
 int ceph_monc_do_statfs(struct ceph_mon_client *monc, struct ceph_statfs *buf)
 {
-	struct ceph_mon_statfs_request req;
+	struct ceph_mon_statfs_request *req;
+	struct ceph_mon_statfs *h;
 	int err;
 
-	req.buf = buf;
-	init_completion(&req.completion);
+	req = kzalloc(sizeof(*req), GFP_NOFS);
+	if (!req)
+		return -ENOMEM;
 
-	/* allocate memory for reply */
-	err = ceph_msgpool_resv(&monc->msgpool_statfs_reply, 1);
-	if (err)
-		return err;
+	kref_init(&req->kref);
+	req->buf = buf;
+	init_completion(&req->completion);
+
+	req->request = ceph_msg_new(CEPH_MSG_STATFS, sizeof(*h), 0, 0, NULL);
+	if (IS_ERR(req->request)) {
+		err = PTR_ERR(req->request);
+		/* don't let the release callback put an ERR_PTR */
+		req->request = NULL;
+		goto out;
+	}
+	req->reply = ceph_msg_new(CEPH_MSG_STATFS_REPLY, 1024, 0, 0, NULL);
+	if (IS_ERR(req->reply)) {
+		err = PTR_ERR(req->reply);
+		req->reply = NULL;
+		goto out;
+	}
+
+	/* fill out request */
+	h = req->request->front.iov_base;
+	h->monhdr.have_version = 0;
+	h->monhdr.session_mon = cpu_to_le16(-1);
+	h->monhdr.session_mon_tid = 0;
+	h->fsid = monc->monmap->fsid;
 
 	/* register request */
 	mutex_lock(&monc->mutex);
-	req.tid = ++monc->last_tid;
-	req.last_attempt = jiffies;
-	req.delay = BASE_DELAY_INTERVAL;
-	__insert_statfs(monc, &req);
+	req->tid = ++monc->last_tid;
+	req->request->hdr.tid = cpu_to_le64(req->tid);
+	__insert_statfs(monc, req);
 	monc->num_statfs_requests++;
 	mutex_unlock(&monc->mutex);
 
 	/* send request and wait */
-	err = send_statfs(monc, &req);
-	if (!err)
-		err = wait_for_completion_interruptible(&req.completion);
+	ceph_con_send(monc->con, ceph_msg_get(req->request));
+	err = wait_for_completion_interruptible(&req->completion);
 
 	mutex_lock(&monc->mutex);
-	rb_erase(&req.node, &monc->statfs_request_tree);
+	rb_erase(&req->node, &monc->statfs_request_tree);
 	monc->num_statfs_requests--;
-	ceph_msgpool_resv(&monc->msgpool_statfs_reply, -1);
 	mutex_unlock(&monc->mutex);
 
 	if (!err)
-		err = req.result;
+		err = req->result;
+
+out:
+	put_statfs_request(req);
 	return err;
 }
 
@@ -496,7 +543,7 @@
 
 	for (p = rb_first(&monc->statfs_request_tree); p; p = rb_next(p)) {
 		req = rb_entry(p, struct ceph_mon_statfs_request, node);
-		send_statfs(monc, req);
+		ceph_con_send(monc->con, ceph_msg_get(req->request));
 	}
 }
 
@@ -591,13 +638,9 @@
 			       sizeof(struct ceph_mon_subscribe_ack), 1, false);
 	if (err < 0)
 		goto out_monmap;
-	err = ceph_msgpool_init(&monc->msgpool_statfs_reply,
-				sizeof(struct ceph_mon_statfs_reply), 0, false);
-	if (err < 0)
-		goto out_pool1;
 	err = ceph_msgpool_init(&monc->msgpool_auth_reply, 4096, 1, false);
 	if (err < 0)
-		goto out_pool2;
+		goto out_pool;
 
 	monc->m_auth = ceph_msg_new(CEPH_MSG_AUTH, 4096, 0, 0, NULL);
 	monc->pending_auth = 0;
@@ -624,10 +667,8 @@
 
 out_pool3:
 	ceph_msgpool_destroy(&monc->msgpool_auth_reply);
-out_pool2:
+out_pool:
 	ceph_msgpool_destroy(&monc->msgpool_subscribe_ack);
-out_pool1:
-	ceph_msgpool_destroy(&monc->msgpool_statfs_reply);
 out_monmap:
 	kfree(monc->monmap);
 out:
@@ -652,7 +693,6 @@
 
 	ceph_msg_put(monc->m_auth);
 	ceph_msgpool_destroy(&monc->msgpool_subscribe_ack);
-	ceph_msgpool_destroy(&monc->msgpool_statfs_reply);
 	ceph_msgpool_destroy(&monc->msgpool_auth_reply);
 
 	kfree(monc->monmap);
@@ -773,8 +813,7 @@
 		m = ceph_msgpool_get(&monc->msgpool_subscribe_ack, front_len);
 		break;
 	case CEPH_MSG_STATFS_REPLY:
-		m = ceph_msgpool_get(&monc->msgpool_statfs_reply, front_len);
-		break;
+		return get_statfs_reply(con, hdr, skip);
 	case CEPH_MSG_AUTH_REPLY:
 		m = ceph_msgpool_get(&monc->msgpool_auth_reply, front_len);
 		break;