virtio: add context flag to find vqs

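Add a per-vq boolean context flag to find_vqs and pass it through to the
vring creation functions, which allows maintaining extra context per vq.
For ease of use, passing in NULL for the flag array is legal and disables
the feature for all vqs.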

Includes fixes by Christian for s390, acked by Cornelia.

Signed-off-by: Christian Borntraeger <borntraeger@de.ibm.com>
Acked-by: Cornelia Huck <cornelia.huck@de.ibm.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
diff --git a/drivers/misc/mic/vop/vop_main.c b/drivers/misc/mic/vop/vop_main.c
index c2e29d7..a341938 100644
--- a/drivers/misc/mic/vop/vop_main.c
+++ b/drivers/misc/mic/vop/vop_main.c
@@ -278,7 +278,7 @@
 static struct virtqueue *vop_find_vq(struct virtio_device *dev,
 				     unsigned index,
 				     void (*callback)(struct virtqueue *vq),
-				     const char *name)
+				     const char *name, bool ctx)
 {
 	struct _vop_vdev *vdev = to_vopvdev(dev);
 	struct vop_device *vpdev = vdev->vpdev;
@@ -314,6 +314,7 @@
 				le16_to_cpu(config.num), MIC_VIRTIO_RING_ALIGN,
 				dev,
 				false,
+				ctx,
 				(void __force *)va, vop_notify, callback, name);
 	if (!vq) {
 		err = -ENOMEM;
@@ -374,7 +375,8 @@
 static int vop_find_vqs(struct virtio_device *dev, unsigned nvqs,
 			struct virtqueue *vqs[],
 			vq_callback_t *callbacks[],
-			const char * const names[], struct irq_affinity *desc)
+			const char * const names[], const bool *ctx,
+			struct irq_affinity *desc)
 {
 	struct _vop_vdev *vdev = to_vopvdev(dev);
 	struct vop_device *vpdev = vdev->vpdev;
@@ -388,7 +390,8 @@
 	for (i = 0; i < nvqs; ++i) {
 		dev_dbg(_vop_dev(vdev), "%s: %d: %s\n",
 			__func__, i, names[i]);
-		vqs[i] = vop_find_vq(dev, i, callbacks[i], names[i]);
+		vqs[i] = vop_find_vq(dev, i, callbacks[i], names[i],
+				     ctx ? ctx[i] : false);
 		if (IS_ERR(vqs[i])) {
 			err = PTR_ERR(vqs[i]);
 			goto error;
diff --git a/drivers/remoteproc/remoteproc_virtio.c b/drivers/remoteproc/remoteproc_virtio.c
index 0142cc3..2946348 100644
--- a/drivers/remoteproc/remoteproc_virtio.c
+++ b/drivers/remoteproc/remoteproc_virtio.c
@@ -71,7 +71,7 @@
 static struct virtqueue *rp_find_vq(struct virtio_device *vdev,
 				    unsigned int id,
 				    void (*callback)(struct virtqueue *vq),
-				    const char *name)
+				    const char *name, bool ctx)
 {
 	struct rproc_vdev *rvdev = vdev_to_rvdev(vdev);
 	struct rproc *rproc = vdev_to_rproc(vdev);
@@ -103,8 +103,8 @@
 	 * Create the new vq, and tell virtio we're not interested in
 	 * the 'weak' smp barriers, since we're talking with a real device.
 	 */
-	vq = vring_new_virtqueue(id, len, rvring->align, vdev, false, addr,
-				 rproc_virtio_notify, callback, name);
+	vq = vring_new_virtqueue(id, len, rvring->align, vdev, false, ctx,
+				 addr, rproc_virtio_notify, callback, name);
 	if (!vq) {
 		dev_err(dev, "vring_new_virtqueue %s failed\n", name);
 		rproc_free_vring(rvring);
@@ -138,12 +138,14 @@
 				 struct virtqueue *vqs[],
 				 vq_callback_t *callbacks[],
 				 const char * const names[],
+				 const bool * ctx,
 				 struct irq_affinity *desc)
 {
 	int i, ret;
 
 	for (i = 0; i < nvqs; ++i) {
-		vqs[i] = rp_find_vq(vdev, i, callbacks[i], names[i]);
+		vqs[i] = rp_find_vq(vdev, i, callbacks[i], names[i],
+				    ctx ? ctx[i] : false);
 		if (IS_ERR(vqs[i])) {
 			ret = PTR_ERR(vqs[i]);
 			goto error;
diff --git a/drivers/s390/virtio/kvm_virtio.c b/drivers/s390/virtio/kvm_virtio.c
index 2ce0b3e..a99d09a 100644
--- a/drivers/s390/virtio/kvm_virtio.c
+++ b/drivers/s390/virtio/kvm_virtio.c
@@ -189,7 +189,7 @@
 static struct virtqueue *kvm_find_vq(struct virtio_device *vdev,
 				     unsigned index,
 				     void (*callback)(struct virtqueue *vq),
-				     const char *name)
+				     const char *name, bool ctx)
 {
 	struct kvm_device *kdev = to_kvmdev(vdev);
 	struct kvm_vqconfig *config;
@@ -211,7 +211,7 @@
 		goto out;
 
 	vq = vring_new_virtqueue(index, config->num, KVM_S390_VIRTIO_RING_ALIGN,
-				 vdev, true, (void *) config->address,
+				 vdev, true, ctx, (void *) config->address,
 				 kvm_notify, callback, name);
 	if (!vq) {
 		err = -ENOMEM;
@@ -256,6 +256,7 @@
 			struct virtqueue *vqs[],
 			vq_callback_t *callbacks[],
 			const char * const names[],
+			const bool *ctx,
 			struct irq_affinity *desc)
 {
 	struct kvm_device *kdev = to_kvmdev(vdev);
@@ -266,7 +267,8 @@
 		return -ENOENT;
 
 	for (i = 0; i < nvqs; ++i) {
-		vqs[i] = kvm_find_vq(vdev, i, callbacks[i], names[i]);
+		vqs[i] = kvm_find_vq(vdev, i, callbacks[i], names[i],
+				     ctx ? ctx[i] : false);
 		if (IS_ERR(vqs[i]))
 			goto error;
 	}
diff --git a/drivers/s390/virtio/virtio_ccw.c b/drivers/s390/virtio/virtio_ccw.c
index 0ed209f..2a76ea7 100644
--- a/drivers/s390/virtio/virtio_ccw.c
+++ b/drivers/s390/virtio/virtio_ccw.c
@@ -484,7 +484,7 @@
 
 static struct virtqueue *virtio_ccw_setup_vq(struct virtio_device *vdev,
 					     int i, vq_callback_t *callback,
-					     const char *name,
+					     const char *name, bool ctx,
 					     struct ccw1 *ccw)
 {
 	struct virtio_ccw_device *vcdev = to_vc_device(vdev);
@@ -522,7 +522,7 @@
 	}
 
 	vq = vring_new_virtqueue(i, info->num, KVM_VIRTIO_CCW_RING_ALIGN, vdev,
-				 true, info->queue, virtio_ccw_kvm_notify,
+				 true, ctx, info->queue, virtio_ccw_kvm_notify,
 				 callback, name);
 	if (!vq) {
 		/* For now, we fail if we can't get the requested size. */
@@ -629,6 +629,7 @@
 			       struct virtqueue *vqs[],
 			       vq_callback_t *callbacks[],
 			       const char * const names[],
+			       const bool *ctx,
 			       struct irq_affinity *desc)
 {
 	struct virtio_ccw_device *vcdev = to_vc_device(vdev);
@@ -642,7 +643,7 @@
 
 	for (i = 0; i < nvqs; ++i) {
 		vqs[i] = virtio_ccw_setup_vq(vdev, i, callbacks[i], names[i],
-					     ccw);
+					     ctx ? ctx[i] : false, ccw);
 		if (IS_ERR(vqs[i])) {
 			ret = PTR_ERR(vqs[i]);
 			vqs[i] = NULL;
diff --git a/drivers/virtio/virtio_mmio.c b/drivers/virtio/virtio_mmio.c
index 78343b8..74dc717 100644
--- a/drivers/virtio/virtio_mmio.c
+++ b/drivers/virtio/virtio_mmio.c
@@ -351,7 +351,7 @@
 
 static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned index,
 				  void (*callback)(struct virtqueue *vq),
-				  const char *name)
+				  const char *name, bool ctx)
 {
 	struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
 	struct virtio_mmio_vq_info *info;
@@ -388,7 +388,7 @@
 
 	/* Create the vring */
 	vq = vring_create_virtqueue(index, num, VIRTIO_MMIO_VRING_ALIGN, vdev,
-				 true, true, vm_notify, callback, name);
+				 true, true, ctx, vm_notify, callback, name);
 	if (!vq) {
 		err = -ENOMEM;
 		goto error_new_virtqueue;
@@ -447,6 +447,7 @@
 		       struct virtqueue *vqs[],
 		       vq_callback_t *callbacks[],
 		       const char * const names[],
+		       const bool *ctx,
 		       struct irq_affinity *desc)
 {
 	struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
@@ -459,7 +460,8 @@
 		return err;
 
 	for (i = 0; i < nvqs; ++i) {
-		vqs[i] = vm_setup_vq(vdev, i, callbacks[i], names[i]);
+		vqs[i] = vm_setup_vq(vdev, i, callbacks[i], names[i],
+				     ctx ? ctx[i] : false);
 		if (IS_ERR(vqs[i])) {
 			vm_del_vqs(vdev);
 			return PTR_ERR(vqs[i]);
diff --git a/drivers/virtio/virtio_pci_common.c b/drivers/virtio/virtio_pci_common.c
index 698d5d0..007a4f3 100644
--- a/drivers/virtio/virtio_pci_common.c
+++ b/drivers/virtio/virtio_pci_common.c
@@ -172,6 +172,7 @@
 static struct virtqueue *vp_setup_vq(struct virtio_device *vdev, unsigned index,
 				     void (*callback)(struct virtqueue *vq),
 				     const char *name,
+				     bool ctx,
 				     u16 msix_vec)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
@@ -183,7 +184,7 @@
 	if (!info)
 		return ERR_PTR(-ENOMEM);
 
-	vq = vp_dev->setup_vq(vp_dev, info, index, callback, name,
+	vq = vp_dev->setup_vq(vp_dev, info, index, callback, name, ctx,
 			      msix_vec);
 	if (IS_ERR(vq))
 		goto out_info;
@@ -274,6 +275,7 @@
 static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
 		const char * const names[], bool per_vq_vectors,
+		const bool *ctx,
 		struct irq_affinity *desc)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
@@ -315,6 +317,7 @@
 		else
 			msix_vec = VP_MSIX_VQ_VECTOR;
 		vqs[i] = vp_setup_vq(vdev, i, callbacks[i], names[i],
+				     ctx ? ctx[i] : false,
 				     msix_vec);
 		if (IS_ERR(vqs[i])) {
 			err = PTR_ERR(vqs[i]);
@@ -345,7 +348,7 @@
 
 static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
-		const char * const names[])
+		const char * const names[], const bool *ctx)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	int i, err;
@@ -367,6 +370,7 @@
 			continue;
 		}
 		vqs[i] = vp_setup_vq(vdev, i, callbacks[i], names[i],
+				     ctx ? ctx[i] : false,
 				     VIRTIO_MSI_NO_VECTOR);
 		if (IS_ERR(vqs[i])) {
 			err = PTR_ERR(vqs[i]);
@@ -383,20 +387,21 @@
 /* the config->find_vqs() implementation */
 int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
-		const char * const names[], struct irq_affinity *desc)
+		const char * const names[], const bool *ctx,
+		struct irq_affinity *desc)
 {
 	int err;
 
 	/* Try MSI-X with one vector per queue. */
-	err = vp_find_vqs_msix(vdev, nvqs, vqs, callbacks, names, true, desc);
+	err = vp_find_vqs_msix(vdev, nvqs, vqs, callbacks, names, true, ctx, desc);
 	if (!err)
 		return 0;
 	/* Fallback: MSI-X with one vector for config, one shared for queues. */
-	err = vp_find_vqs_msix(vdev, nvqs, vqs, callbacks, names, false, desc);
+	err = vp_find_vqs_msix(vdev, nvqs, vqs, callbacks, names, false, ctx, desc);
 	if (!err)
 		return 0;
 	/* Finally fall back to regular interrupts. */
-	return vp_find_vqs_intx(vdev, nvqs, vqs, callbacks, names);
+	return vp_find_vqs_intx(vdev, nvqs, vqs, callbacks, names, ctx);
 }
 
 const char *vp_bus_name(struct virtio_device *vdev)
diff --git a/drivers/virtio/virtio_pci_common.h b/drivers/virtio/virtio_pci_common.h
index e96334a..135ee3c 100644
--- a/drivers/virtio/virtio_pci_common.h
+++ b/drivers/virtio/virtio_pci_common.h
@@ -102,6 +102,7 @@
 				      unsigned idx,
 				      void (*callback)(struct virtqueue *vq),
 				      const char *name,
+				      bool ctx,
 				      u16 msix_vec);
 	void (*del_vq)(struct virtio_pci_vq_info *info);
 
@@ -131,7 +132,8 @@
 /* the config->find_vqs() implementation */
 int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
-		const char * const names[], struct irq_affinity *desc);
+		const char * const names[], const bool *ctx,
+		struct irq_affinity *desc);
 const char *vp_bus_name(struct virtio_device *vdev);
 
 /* Setup the affinity for a virtqueue:
diff --git a/drivers/virtio/virtio_pci_legacy.c b/drivers/virtio/virtio_pci_legacy.c
index 4bfa48f..2780886 100644
--- a/drivers/virtio/virtio_pci_legacy.c
+++ b/drivers/virtio/virtio_pci_legacy.c
@@ -116,6 +116,7 @@
 				  unsigned index,
 				  void (*callback)(struct virtqueue *vq),
 				  const char *name,
+				  bool ctx,
 				  u16 msix_vec)
 {
 	struct virtqueue *vq;
@@ -135,7 +136,8 @@
 	/* create the vring */
 	vq = vring_create_virtqueue(index, num,
 				    VIRTIO_PCI_VRING_ALIGN, &vp_dev->vdev,
-				    true, false, vp_notify, callback, name);
+				    true, false, ctx,
+				    vp_notify, callback, name);
 	if (!vq)
 		return ERR_PTR(-ENOMEM);
 
diff --git a/drivers/virtio/virtio_pci_modern.c b/drivers/virtio/virtio_pci_modern.c
index 8978f10..2555d80 100644
--- a/drivers/virtio/virtio_pci_modern.c
+++ b/drivers/virtio/virtio_pci_modern.c
@@ -297,6 +297,7 @@
 				  unsigned index,
 				  void (*callback)(struct virtqueue *vq),
 				  const char *name,
+				  bool ctx,
 				  u16 msix_vec)
 {
 	struct virtio_pci_common_cfg __iomem *cfg = vp_dev->common;
@@ -328,7 +329,8 @@
 	/* create the vring */
 	vq = vring_create_virtqueue(index, num,
 				    SMP_CACHE_BYTES, &vp_dev->vdev,
-				    true, true, vp_notify, callback, name);
+				    true, true, ctx,
+				    vp_notify, callback, name);
 	if (!vq)
 		return ERR_PTR(-ENOMEM);
 
@@ -387,12 +389,14 @@
 }
 
 static int vp_modern_find_vqs(struct virtio_device *vdev, unsigned nvqs,
-		struct virtqueue *vqs[], vq_callback_t *callbacks[],
-		const char * const names[], struct irq_affinity *desc)
+			      struct virtqueue *vqs[],
+			      vq_callback_t *callbacks[],
+			      const char * const names[], const bool *ctx,
+			      struct irq_affinity *desc)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	struct virtqueue *vq;
-	int rc = vp_find_vqs(vdev, nvqs, vqs, callbacks, names, desc);
+	int rc = vp_find_vqs(vdev, nvqs, vqs, callbacks, names, ctx, desc);
 
 	if (rc)
 		return rc;
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index 409aeaa..b23b5fa 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -916,6 +916,7 @@
 					struct vring vring,
 					struct virtio_device *vdev,
 					bool weak_barriers,
+					bool context,
 					bool (*notify)(struct virtqueue *),
 					void (*callback)(struct virtqueue *),
 					const char *name)
@@ -1019,6 +1020,7 @@
 	struct virtio_device *vdev,
 	bool weak_barriers,
 	bool may_reduce_num,
+	bool context,
 	bool (*notify)(struct virtqueue *),
 	void (*callback)(struct virtqueue *),
 	const char *name)
@@ -1058,7 +1060,7 @@
 	queue_size_in_bytes = vring_size(num, vring_align);
 	vring_init(&vring, num, queue, vring_align);
 
-	vq = __vring_new_virtqueue(index, vring, vdev, weak_barriers,
+	vq = __vring_new_virtqueue(index, vring, vdev, weak_barriers, context,
 				   notify, callback, name);
 	if (!vq) {
 		vring_free_queue(vdev, queue_size_in_bytes, queue,
@@ -1079,6 +1081,7 @@
 				      unsigned int vring_align,
 				      struct virtio_device *vdev,
 				      bool weak_barriers,
+				      bool context,
 				      void *pages,
 				      bool (*notify)(struct virtqueue *vq),
 				      void (*callback)(struct virtqueue *vq),
@@ -1086,7 +1089,7 @@
 {
 	struct vring vring;
 	vring_init(&vring, num, pages, vring_align);
-	return __vring_new_virtqueue(index, vring, vdev, weak_barriers,
+	return __vring_new_virtqueue(index, vring, vdev, weak_barriers, context,
 				     notify, callback, name);
 }
 EXPORT_SYMBOL_GPL(vring_new_virtqueue);