USB: xHCI: Add pointer to udev in struct xhci_virt_device

Add a pointer to the udev in struct xhci_virt_device. When allocating a
new virt_device, make the pointer point to the corresponding udev.

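Modify xhci_check_args() to check whether virt_dev->udev matches the
target udev, to make sure the command is issued to the right device.
The check is selected by a new check_virt_dev argument, and callers now
rely on it instead of open-coding their own xhci->devs[slot_id] checks.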

Signed-off-by: Andiry Xu <andiry.xu@amd.com>
Signed-off-by: Sarah Sharp <sarah.a.sharp@linux.intel.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@suse.de>

diff --git a/drivers/usb/host/xhci.c b/drivers/usb/host/xhci.c
index d5c550e..0bec040 100644
--- a/drivers/usb/host/xhci.c
+++ b/drivers/usb/host/xhci.c
@@ -607,7 +607,11 @@
  * returns 0 this is a root hub; returns -EINVAL for NULL pointers.
  */
 int xhci_check_args(struct usb_hcd *hcd, struct usb_device *udev,
-		struct usb_host_endpoint *ep, int check_ep, const char *func) {
+		struct usb_host_endpoint *ep, int check_ep, bool check_virt_dev,
+		const char *func) {
+	struct xhci_hcd	*xhci;
+	struct xhci_virt_device	*virt_dev;
+
 	if (!hcd || (check_ep && !ep) || !udev) {
 		printk(KERN_DEBUG "xHCI %s called with invalid args\n",
 				func);
@@ -618,11 +622,24 @@
 				func);
 		return 0;
 	}
-	if (!udev->slot_id) {
-		printk(KERN_DEBUG "xHCI %s called with unaddressed device\n",
-				func);
-		return -EINVAL;
+
+	if (check_virt_dev) {
+		xhci = hcd_to_xhci(hcd);
+		if (!udev->slot_id || !xhci->devs
+			|| !xhci->devs[udev->slot_id]) {
+			printk(KERN_DEBUG "xHCI %s called with unaddressed "
+						"device\n", func);
+			return -EINVAL;
+		}
+
+		virt_dev = xhci->devs[udev->slot_id];
+		if (virt_dev->udev != udev) {
+			printk(KERN_DEBUG "xHCI %s called with udev and "
+					  "virt_dev does not match\n", func);
+			return -EINVAL;
+		}
 	}
+
 	return 1;
 }
 
@@ -704,18 +721,13 @@
 	struct urb_priv	*urb_priv;
 	int size, i;
 
-	if (!urb || xhci_check_args(hcd, urb->dev, urb->ep, true, __func__) <= 0)
+	if (!urb || xhci_check_args(hcd, urb->dev, urb->ep,
+					true, true, __func__) <= 0)
 		return -EINVAL;
 
 	slot_id = urb->dev->slot_id;
 	ep_index = xhci_get_endpoint_index(&urb->ep->desc);
 
-	if (!xhci->devs || !xhci->devs[slot_id]) {
-		if (!in_interrupt())
-			dev_warn(&urb->dev->dev, "WARN: urb submitted for dev with no Slot ID\n");
-		ret = -EINVAL;
-		goto exit;
-	}
 	if (!HCD_HW_ACCESSIBLE(hcd)) {
 		if (!in_interrupt())
 			xhci_dbg(xhci, "urb submitted during PCI suspend\n");
@@ -991,7 +1003,7 @@
 	u32 new_add_flags, new_drop_flags, new_slot_info;
 	int ret;
 
-	ret = xhci_check_args(hcd, udev, ep, 1, __func__);
+	ret = xhci_check_args(hcd, udev, ep, 1, true, __func__);
 	if (ret <= 0)
 		return ret;
 	xhci = hcd_to_xhci(hcd);
@@ -1004,12 +1016,6 @@
 		return 0;
 	}
 
-	if (!xhci->devs || !xhci->devs[udev->slot_id]) {
-		xhci_warn(xhci, "xHCI %s called with unaddressed device\n",
-				__func__);
-		return -EINVAL;
-	}
-
 	in_ctx = xhci->devs[udev->slot_id]->in_ctx;
 	out_ctx = xhci->devs[udev->slot_id]->out_ctx;
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
@@ -1078,7 +1084,7 @@
 	u32 new_add_flags, new_drop_flags, new_slot_info;
 	int ret = 0;
 
-	ret = xhci_check_args(hcd, udev, ep, 1, __func__);
+	ret = xhci_check_args(hcd, udev, ep, 1, true, __func__);
 	if (ret <= 0) {
 		/* So we won't queue a reset ep command for a root hub */
 		ep->hcpriv = NULL;
@@ -1098,12 +1104,6 @@
 		return 0;
 	}
 
-	if (!xhci->devs || !xhci->devs[udev->slot_id]) {
-		xhci_warn(xhci, "xHCI %s called with unaddressed device\n",
-				__func__);
-		return -EINVAL;
-	}
-
 	in_ctx = xhci->devs[udev->slot_id]->in_ctx;
 	out_ctx = xhci->devs[udev->slot_id]->out_ctx;
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
@@ -1346,16 +1346,11 @@
 	struct xhci_input_control_ctx *ctrl_ctx;
 	struct xhci_slot_ctx *slot_ctx;
 
-	ret = xhci_check_args(hcd, udev, NULL, 0, __func__);
+	ret = xhci_check_args(hcd, udev, NULL, 0, true, __func__);
 	if (ret <= 0)
 		return ret;
 	xhci = hcd_to_xhci(hcd);
 
-	if (!udev->slot_id || !xhci->devs || !xhci->devs[udev->slot_id]) {
-		xhci_warn(xhci, "xHCI %s called with unaddressed device\n",
-				__func__);
-		return -EINVAL;
-	}
 	xhci_dbg(xhci, "%s called for udev %p\n", __func__, udev);
 	virt_dev = xhci->devs[udev->slot_id];
 
@@ -1405,16 +1400,11 @@
 	struct xhci_virt_device	*virt_dev;
 	int i, ret;
 
-	ret = xhci_check_args(hcd, udev, NULL, 0, __func__);
+	ret = xhci_check_args(hcd, udev, NULL, 0, true, __func__);
 	if (ret <= 0)
 		return;
 	xhci = hcd_to_xhci(hcd);
 
-	if (!xhci->devs || !xhci->devs[udev->slot_id]) {
-		xhci_warn(xhci, "xHCI %s called with unaddressed device\n",
-				__func__);
-		return;
-	}
 	xhci_dbg(xhci, "%s called for udev %p\n", __func__, udev);
 	virt_dev = xhci->devs[udev->slot_id];
 	/* Free any rings allocated for added endpoints */
@@ -1575,7 +1565,7 @@
 
 	if (!ep)
 		return -EINVAL;
-	ret = xhci_check_args(xhci_to_hcd(xhci), udev, ep, 1, __func__);
+	ret = xhci_check_args(xhci_to_hcd(xhci), udev, ep, 1, true, __func__);
 	if (ret <= 0)
 		return -EINVAL;
 	if (ep->ss_ep_comp.bmAttributes == 0) {
@@ -1965,17 +1955,12 @@
 	int timeleft;
 	int last_freed_endpoint;
 
-	ret = xhci_check_args(hcd, udev, NULL, 0, __func__);
+	ret = xhci_check_args(hcd, udev, NULL, 0, true, __func__);
 	if (ret <= 0)
 		return ret;
 	xhci = hcd_to_xhci(hcd);
 	slot_id = udev->slot_id;
 	virt_dev = xhci->devs[slot_id];
-	if (!virt_dev) {
-		xhci_dbg(xhci, "%s called with invalid slot ID %u\n",
-				__func__, slot_id);
-		return -EINVAL;
-	}
 
 	xhci_dbg(xhci, "Resetting device with slot ID %u\n", slot_id);
 	/* Allocate the command structure that holds the struct completion.
@@ -2077,13 +2062,13 @@
 	struct xhci_virt_device *virt_dev;
 	unsigned long flags;
 	u32 state;
-	int i;
+	int i, ret;
 
-	if (udev->slot_id == 0)
+	ret = xhci_check_args(hcd, udev, NULL, 0, true, __func__);
+	if (ret <= 0)
 		return;
+
 	virt_dev = xhci->devs[udev->slot_id];
-	if (!virt_dev)
-		return;
 
 	/* Stop any wayward timer functions (which may grab the lock) */
 	for (i = 0; i < 31; ++i) {