Merge branch 'vhost-net-next' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f0fd52c..70ac604 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -703,6 +703,10 @@
 		vhost_net_disable_vq(n, vq);
 		rcu_assign_pointer(vq->private_data, sock);
 		vhost_net_enable_vq(n, vq);
+		/* NOTE(review): failure leaves sock installed; check err_vq */
+		r = vhost_init_used(vq);
+		if (r)
+			goto err_vq;
 	}
 
 	mutex_unlock(&vq->mutex);
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 734e1d7..fc9a1d7 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -195,8 +195,13 @@
 						    lockdep_is_held(&vq->mutex));
 		rcu_assign_pointer(vq->private_data, priv);
 
+		r = vhost_init_used(&n->vqs[index]);
+		/* r is checked after unlock: goto err runs without vq->mutex. */
 		mutex_unlock(&vq->mutex);
 
+		if (r)
+			goto err;
+		/* NOTE(review): on failure priv stays installed; check err. */
 		if (oldpriv) {
 			vhost_test_flush_vq(n, index);
 		}
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 5ef2f62..c14c42b 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -629,17 +629,6 @@
 	return 0;
 }
 
-static int init_used(struct vhost_virtqueue *vq,
-		     struct vring_used __user *used)
-{
-	int r = put_user(vq->used_flags, &used->flags);
-
-	if (r)
-		return r;
-	vq->signalled_used_valid = false;
-	return get_user(vq->last_used_idx, &used->idx);
-}
-
 static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
 {
 	struct file *eventfp, *filep = NULL,
@@ -752,10 +741,6 @@
 			}
 		}
 
-		r = init_used(vq, (struct vring_used __user *)(unsigned long)
-			      a.used_user_addr);
-		if (r)
-			break;
 		vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
 		vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
 		vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
@@ -1010,6 +995,87 @@
 	return 0;
 }
 
+/**
+ * vhost_update_used_flags - write vq->used_flags to the used ring
+ * @vq: virtqueue whose used ring flags are written
+ *
+ * When used-ring logging is enabled the write is also logged and
+ * vq->log_ctx (if set) is signalled.
+ *
+ * Return: 0 on success, -EFAULT if the user-memory write faults.
+ */
+static int vhost_update_used_flags(struct vhost_virtqueue *vq)
+{
+	void __user *used;
+	if (__put_user(vq->used_flags, &vq->used->flags))
+		return -EFAULT;
+	if (unlikely(vq->log_used)) {
+		/* Make sure the flag is seen before log. */
+		smp_wmb();
+		/* Log used flag write. */
+		used = &vq->used->flags;
+		log_write(vq->log_base, vq->log_addr +
+			  (used - (void __user *)vq->used),
+			  sizeof vq->used->flags);
+		if (vq->log_ctx)
+			eventfd_signal(vq->log_ctx, 1);
+	}
+	return 0;
+}
+
+/**
+ * vhost_update_avail_event - write an avail event index
+ * @vq: virtqueue to update
+ * @avail_event: value written to the location vhost_avail_event(vq)
+ *
+ * When used-ring logging is enabled the write is also logged (as an
+ * offset into the used ring) and vq->log_ctx (if set) is signalled.
+ *
+ * Return: 0 on success, -EFAULT if the user-memory write faults.
+ */
+static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
+{
+	if (__put_user(avail_event, vhost_avail_event(vq)))
+		return -EFAULT;
+	if (unlikely(vq->log_used)) {
+		void __user *used;
+		/* Make sure the event is seen before log. */
+		smp_wmb();
+		/* Log avail event write */
+		used = vhost_avail_event(vq);
+		log_write(vq->log_base, vq->log_addr +
+			  (used - (void __user *)vq->used),
+			  sizeof *vhost_avail_event(vq));
+		if (vq->log_ctx)
+			eventfd_signal(vq->log_ctx, 1);
+	}
+	return 0;
+}
+
+/**
+ * vhost_init_used - sync vq state with the guest's used ring
+ * @vq: virtqueue to initialise
+ *
+ * Does nothing (returns 0) while vq->private_data is NULL. Otherwise
+ * writes vq->used_flags to the used ring, clears signalled_used_valid
+ * and reloads last_used_idx from used->idx. The callers in vhost-net
+ * and vhost-test invoke it with vq->mutex held.
+ *
+ * Return: 0 on success or a negative errno from the user-memory access.
+ */
+int vhost_init_used(struct vhost_virtqueue *vq)
+{
+	int r;
+	if (!vq->private_data)
+		return 0;
+
+	r = vhost_update_used_flags(vq);
+	if (r)
+		return r;
+	vq->signalled_used_valid = false;
+	return get_user(vq->last_used_idx, &vq->used->idx);
+}
+
 static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
@@ -1481,34 +1547,20 @@
 		return false;
 	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
 	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
-		r = put_user(vq->used_flags, &vq->used->flags);
+		r = vhost_update_used_flags(vq);
 		if (r) {
 			vq_err(vq, "Failed to enable notification at %p: %d\n",
 			       &vq->used->flags, r);
 			return false;
 		}
 	} else {
-		r = put_user(vq->avail_idx, vhost_avail_event(vq));
+		r = vhost_update_avail_event(vq, vq->avail_idx);
 		if (r) {
 			vq_err(vq, "Failed to update avail event index at %p: %d\n",
 			       vhost_avail_event(vq), r);
 			return false;
 		}
 	}
-	if (unlikely(vq->log_used)) {
-		void __user *used;
-		/* Make sure data is seen before log. */
-		smp_wmb();
-		used = vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX) ?
-			&vq->used->flags : vhost_avail_event(vq);
-		/* Log used flags or event index entry write. Both are 16 bit
-		 * fields. */
-		log_write(vq->log_base, vq->log_addr +
-			   (used - (void __user *)vq->used),
-			  sizeof(u16));
-		if (vq->log_ctx)
-			eventfd_signal(vq->log_ctx, 1);
-	}
 	/* They could have slipped one in as we were doing that: make
 	 * sure it's written, then check again. */
 	smp_mb();
@@ -1531,7 +1583,7 @@
 		return;
 	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
 	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
-		r = put_user(vq->used_flags, &vq->used->flags);
+		r = vhost_update_used_flags(vq);
 		if (r)
-			vq_err(vq, "Failed to enable notification at %p: %d\n",
+			vq_err(vq, "Failed to disable notification at %p: %d\n",
 			       &vq->used->flags, r);
@@ -1556,7 +1608,6 @@
 	if (!ubufs)
 		return ERR_PTR(-ENOMEM);
 	kref_init(&ubufs->kref);
-	kref_get(&ubufs->kref);
 	init_waitqueue_head(&ubufs->wait);
 	ubufs->vq = vq;
 	return ubufs;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 1544b78..14c9abf 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -174,6 +174,8 @@
 		      struct vhost_log *log, unsigned int *log_num);
 void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
 
+/* Resync vq state from its used ring; returns 0 or a negative errno. */
+int vhost_init_used(struct vhost_virtqueue *vq);
 int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
 int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
 		     unsigned count);