vhost: Allow device specific fields per vq

This is useful for any device that wants device-specific fields per vq.
For example, tcm_vhost wants a per-vq field to track requests that are
in flight on the vq. On top of this, follow-up patches can move
device-specific state such as ubufs out of vhost.h and into net.c.

Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Asias He <asias@redhat.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
diff --git a/drivers/vhost/tcm_vhost.c b/drivers/vhost/tcm_vhost.c
index 1677238..99d3480 100644
--- a/drivers/vhost/tcm_vhost.c
+++ b/drivers/vhost/tcm_vhost.c
@@ -74,13 +74,17 @@
 #define VHOST_SCSI_MAX_VQ	128
 #define VHOST_SCSI_MAX_EVENT	128
 
+struct vhost_scsi_virtqueue {
+	struct vhost_virtqueue vq;
+};
+
 struct vhost_scsi {
 	/* Protected by vhost_scsi->dev.mutex */
 	struct tcm_vhost_tpg **vs_tpg;
 	char vs_vhost_wwpn[TRANSPORT_IQN_LEN];
 
 	struct vhost_dev dev;
-	struct vhost_virtqueue vqs[VHOST_SCSI_MAX_VQ];
+	struct vhost_scsi_virtqueue vqs[VHOST_SCSI_MAX_VQ];
 
 	struct vhost_work vs_completion_work; /* cmd completion work item */
 	struct llist_head vs_completion_list; /* cmd completion queue */
@@ -366,7 +370,7 @@
 static struct tcm_vhost_evt *tcm_vhost_allocate_evt(struct vhost_scsi *vs,
 	u32 event, u32 reason)
 {
-	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT];
+	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 	struct tcm_vhost_evt *evt;
 
 	if (vs->vs_events_nr > VHOST_SCSI_MAX_EVENT) {
@@ -409,7 +413,7 @@
 static void tcm_vhost_do_evt_work(struct vhost_scsi *vs,
 	struct tcm_vhost_evt *evt)
 {
-	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT];
+	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 	struct virtio_scsi_event *event = &evt->event;
 	struct virtio_scsi_event __user *eventp;
 	unsigned out, in;
@@ -460,7 +464,7 @@
 {
 	struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
 					vs_event_work);
-	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT];
+	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 	struct tcm_vhost_evt *evt;
 	struct llist_node *llnode;
 
@@ -511,8 +515,10 @@
 		       v_rsp.sense_len);
 		ret = copy_to_user(tv_cmd->tvc_resp, &v_rsp, sizeof(v_rsp));
 		if (likely(ret == 0)) {
+			struct vhost_scsi_virtqueue *q;
 			vhost_add_used(tv_cmd->tvc_vq, tv_cmd->tvc_vq_desc, 0);
-			vq = tv_cmd->tvc_vq - vs->vqs;
+			q = container_of(tv_cmd->tvc_vq, struct vhost_scsi_virtqueue, vq);
+			vq = q - vs->vqs;
 			__set_bit(vq, signal);
 		} else
 			pr_err("Faulted on virtio_scsi_cmd_resp\n");
@@ -523,7 +529,7 @@
 	vq = -1;
 	while ((vq = find_next_bit(signal, VHOST_SCSI_MAX_VQ, vq + 1))
 		< VHOST_SCSI_MAX_VQ)
-		vhost_signal(&vs->dev, &vs->vqs[vq]);
+		vhost_signal(&vs->dev, &vs->vqs[vq].vq);
 }
 
 static struct tcm_vhost_cmd *vhost_scsi_allocate_cmd(
@@ -938,7 +944,7 @@
 
 static void vhost_scsi_flush_vq(struct vhost_scsi *vs, int index)
 {
-	vhost_poll_flush(&vs->dev.vqs[index].poll);
+	vhost_poll_flush(&vs->vqs[index].vq.poll);
 }
 
 static void vhost_scsi_flush(struct vhost_scsi *vs)
@@ -975,7 +981,7 @@
 	/* Verify that ring has been setup correctly. */
 	for (index = 0; index < vs->dev.nvqs; ++index) {
 		/* Verify that ring has been setup correctly. */
-		if (!vhost_vq_access_ok(&vs->vqs[index])) {
+		if (!vhost_vq_access_ok(&vs->vqs[index].vq)) {
 			ret = -EFAULT;
 			goto out;
 		}
@@ -1022,7 +1028,7 @@
 		memcpy(vs->vs_vhost_wwpn, t->vhost_wwpn,
 		       sizeof(vs->vs_vhost_wwpn));
 		for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
-			vq = &vs->vqs[i];
+			vq = &vs->vqs[i].vq;
 			/* Flushing the vhost_work acts as synchronize_rcu */
 			mutex_lock(&vq->mutex);
 			rcu_assign_pointer(vq->private_data, vs_tpg);
@@ -1063,7 +1069,7 @@
 	mutex_lock(&vs->dev.mutex);
 	/* Verify that ring has been setup correctly. */
 	for (index = 0; index < vs->dev.nvqs; ++index) {
-		if (!vhost_vq_access_ok(&vs->vqs[index])) {
+		if (!vhost_vq_access_ok(&vs->vqs[index].vq)) {
 			ret = -EFAULT;
 			goto err_dev;
 		}
@@ -1103,7 +1109,7 @@
 	}
 	if (match) {
 		for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
-			vq = &vs->vqs[i];
+			vq = &vs->vqs[i].vq;
 			/* Flushing the vhost_work acts as synchronize_rcu */
 			mutex_lock(&vq->mutex);
 			rcu_assign_pointer(vq->private_data, NULL);
@@ -1151,24 +1157,36 @@
 static int vhost_scsi_open(struct inode *inode, struct file *f)
 {
 	struct vhost_scsi *s;
+	struct vhost_virtqueue **vqs;
 	int r, i;
 
 	s = kzalloc(sizeof(*s), GFP_KERNEL);
 	if (!s)
 		return -ENOMEM;
 
+	vqs = kmalloc(VHOST_SCSI_MAX_VQ * sizeof(*vqs), GFP_KERNEL);
+	if (!vqs) {
+		kfree(s);
+		return -ENOMEM;
+	}
+
 	vhost_work_init(&s->vs_completion_work, vhost_scsi_complete_cmd_work);
 	vhost_work_init(&s->vs_event_work, tcm_vhost_evt_work);
 
 	s->vs_events_nr = 0;
 	s->vs_events_missed = false;
 
-	s->vqs[VHOST_SCSI_VQ_CTL].handle_kick = vhost_scsi_ctl_handle_kick;
-	s->vqs[VHOST_SCSI_VQ_EVT].handle_kick = vhost_scsi_evt_handle_kick;
-	for (i = VHOST_SCSI_VQ_IO; i < VHOST_SCSI_MAX_VQ; i++)
-		s->vqs[i].handle_kick = vhost_scsi_handle_kick;
-	r = vhost_dev_init(&s->dev, s->vqs, VHOST_SCSI_MAX_VQ);
+	vqs[VHOST_SCSI_VQ_CTL] = &s->vqs[VHOST_SCSI_VQ_CTL].vq;
+	vqs[VHOST_SCSI_VQ_EVT] = &s->vqs[VHOST_SCSI_VQ_EVT].vq;
+	s->vqs[VHOST_SCSI_VQ_CTL].vq.handle_kick = vhost_scsi_ctl_handle_kick;
+	s->vqs[VHOST_SCSI_VQ_EVT].vq.handle_kick = vhost_scsi_evt_handle_kick;
+	for (i = VHOST_SCSI_VQ_IO; i < VHOST_SCSI_MAX_VQ; i++) {
+		vqs[i] = &s->vqs[i].vq;
+		s->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
+	}
+	r = vhost_dev_init(&s->dev, vqs, VHOST_SCSI_MAX_VQ);
 	if (r < 0) {
+		kfree(vqs);
 		kfree(s);
 		return r;
 	}
@@ -1190,6 +1208,7 @@
 	vhost_dev_cleanup(&s->dev, false);
 	/* Jobs can re-queue themselves in evt kick handler. Do extra flush. */
 	vhost_scsi_flush(s);
+	kfree(s->dev.vqs);
 	kfree(s);
 	return 0;
 }
@@ -1205,7 +1224,7 @@
 	u32 events_missed;
 	u64 features;
 	int r, abi_version = VHOST_SCSI_ABI_VERSION;
-	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT];
+	struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 
 	switch (ioctl) {
 	case VHOST_SCSI_SET_ENDPOINT:
@@ -1333,7 +1352,7 @@
 	else
 		reason = VIRTIO_SCSI_EVT_RESET_REMOVED;
 
-	vq = &vs->vqs[VHOST_SCSI_VQ_EVT];
+	vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 	mutex_lock(&vq->mutex);
 	tcm_vhost_send_evt(vs, tpg, lun,
 			VIRTIO_SCSI_T_TRANSPORT_RESET, reason);