xhci: Remove BUG_ON in xhci_get_input_control_ctx.

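Fail gracefully, instead of causing the kernel to panic, if the container
context passed to xhci_get_input_control_ctx() is not an input context
(XHCI_CTX_TYPE_INPUT): return NULL rather than hitting BUG_ON().  Move the
lookup of the input control context pointer up into the callers, which now
warn and bail out when the lookup returns NULL.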

Signed-off-by: Sarah Sharp <sarah.a.sharp@linux.intel.com>
Cc: John Youn <johnyoun@synopsys.com>
diff --git a/drivers/usb/host/xhci-dbg.c b/drivers/usb/host/xhci-dbg.c
index f2e7689..5d5e58f 100644
--- a/drivers/usb/host/xhci-dbg.c
+++ b/drivers/usb/host/xhci-dbg.c
@@ -553,6 +553,11 @@
 	if (ctx->type == XHCI_CTX_TYPE_INPUT) {
 		struct xhci_input_control_ctx *ctrl_ctx =
 			xhci_get_input_control_ctx(xhci, ctx);
+		if (!ctrl_ctx) {
+			xhci_warn(xhci, "Could not get input context, bad type.\n");
+			return;
+		}
+
 		xhci_dbg(xhci, "@%p (virt) @%08llx (dma) %#08x - drop flags\n",
 			 &ctrl_ctx->drop_flags, (unsigned long long)dma,
 			 ctrl_ctx->drop_flags);
diff --git a/drivers/usb/host/xhci-mem.c b/drivers/usb/host/xhci-mem.c
index cc1b3ef..b7487a1 100644
--- a/drivers/usb/host/xhci-mem.c
+++ b/drivers/usb/host/xhci-mem.c
@@ -389,7 +389,9 @@
 struct xhci_input_control_ctx *xhci_get_input_control_ctx(struct xhci_hcd *xhci,
 					      struct xhci_container_ctx *ctx)
 {
-	BUG_ON(ctx->type != XHCI_CTX_TYPE_INPUT);
+	if (ctx->type != XHCI_CTX_TYPE_INPUT)
+		return NULL;
+
 	return (struct xhci_input_control_ctx *)ctx->bytes;
 }
 
diff --git a/drivers/usb/host/xhci-ring.c b/drivers/usb/host/xhci-ring.c
index e02b907..1e57eaf 100644
--- a/drivers/usb/host/xhci-ring.c
+++ b/drivers/usb/host/xhci-ring.c
@@ -1424,6 +1424,10 @@
 		 */
 		ctrl_ctx = xhci_get_input_control_ctx(xhci,
 				virt_dev->in_ctx);
+		if (!ctrl_ctx) {
+			xhci_warn(xhci, "Could not get input context, bad type.\n");
+			break;
+		}
 		/* Input ctx add_flags are the endpoint index plus one */
 		ep_index = xhci_last_valid_endpoint(le32_to_cpu(ctrl_ctx->add_flags)) - 1;
 		/* A usb_set_interface() call directly after clearing a halted
diff --git a/drivers/usb/host/xhci.c b/drivers/usb/host/xhci.c
index 77113c1..6779c92 100644
--- a/drivers/usb/host/xhci.c
+++ b/drivers/usb/host/xhci.c
@@ -1235,19 +1235,25 @@
 				hw_max_packet_size);
 		xhci_dbg(xhci, "Issuing evaluate context command.\n");
 
-		/* Set up the modified control endpoint 0 */
-		xhci_endpoint_copy(xhci, xhci->devs[slot_id]->in_ctx,
-				xhci->devs[slot_id]->out_ctx, ep_index);
-		in_ctx = xhci->devs[slot_id]->in_ctx;
-		ep_ctx = xhci_get_ep_ctx(xhci, in_ctx, ep_index);
-		ep_ctx->ep_info2 &= cpu_to_le32(~MAX_PACKET_MASK);
-		ep_ctx->ep_info2 |= cpu_to_le32(MAX_PACKET(max_packet_size));
-
 		/* Set up the input context flags for the command */
 		/* FIXME: This won't work if a non-default control endpoint
 		 * changes max packet sizes.
 		 */
+		in_ctx = xhci->devs[slot_id]->in_ctx;
 		ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
+		if (!ctrl_ctx) {
+			xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+					__func__);
+			return -ENOMEM;
+		}
+		/* Set up the modified control endpoint 0 */
+		xhci_endpoint_copy(xhci, xhci->devs[slot_id]->in_ctx,
+				xhci->devs[slot_id]->out_ctx, ep_index);
+
+		ep_ctx = xhci_get_ep_ctx(xhci, in_ctx, ep_index);
+		ep_ctx->ep_info2 &= cpu_to_le32(~MAX_PACKET_MASK);
+		ep_ctx->ep_info2 |= cpu_to_le32(MAX_PACKET(max_packet_size));
+
 		ctrl_ctx->add_flags = cpu_to_le32(EP0_FLAG);
 		ctrl_ctx->drop_flags = 0;
 
@@ -1607,6 +1613,12 @@
 	in_ctx = xhci->devs[udev->slot_id]->in_ctx;
 	out_ctx = xhci->devs[udev->slot_id]->out_ctx;
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return 0;
+	}
+
 	ep_index = xhci_get_endpoint_index(&ep->desc);
 	ep_ctx = xhci_get_ep_ctx(xhci, out_ctx, ep_index);
 	/* If the HC already knows the endpoint is disabled,
@@ -1701,8 +1713,13 @@
 	in_ctx = virt_dev->in_ctx;
 	out_ctx = virt_dev->out_ctx;
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
-	ep_index = xhci_get_endpoint_index(&ep->desc);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return 0;
+	}
 
+	ep_index = xhci_get_endpoint_index(&ep->desc);
 	/* If this endpoint is already in use, and the upper layers are trying
 	 * to add it again without dropping it, reject the addition.
 	 */
@@ -1775,12 +1792,18 @@
 	struct xhci_slot_ctx *slot_ctx;
 	int i;
 
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, virt_dev->in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return;
+	}
+
 	/* When a device's add flag and drop flag are zero, any subsequent
 	 * configure endpoint command will leave that endpoint's state
 	 * untouched.  Make sure we don't leave any old state in the input
 	 * endpoint contexts.
 	 */
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, virt_dev->in_ctx);
 	ctrl_ctx->drop_flags = 0;
 	ctrl_ctx->add_flags = 0;
 	slot_ctx = xhci_get_slot_ctx(xhci, virt_dev->in_ctx);
@@ -1887,13 +1910,11 @@
 }
 
 static u32 xhci_count_num_new_endpoints(struct xhci_hcd *xhci,
-		struct xhci_container_ctx *in_ctx)
+		struct xhci_input_control_ctx *ctrl_ctx)
 {
-	struct xhci_input_control_ctx *ctrl_ctx;
 	u32 valid_add_flags;
 	u32 valid_drop_flags;
 
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
 	/* Ignore the slot flag (bit 0), and the default control endpoint flag
 	 * (bit 1).  The default control endpoint is added during the Address
 	 * Device command and is never removed until the slot is disabled.
@@ -1910,13 +1931,11 @@
 }
 
 static unsigned int xhci_count_num_dropped_endpoints(struct xhci_hcd *xhci,
-		struct xhci_container_ctx *in_ctx)
+		struct xhci_input_control_ctx *ctrl_ctx)
 {
-	struct xhci_input_control_ctx *ctrl_ctx;
 	u32 valid_add_flags;
 	u32 valid_drop_flags;
 
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
 	valid_add_flags = ctrl_ctx->add_flags >> 2;
 	valid_drop_flags = ctrl_ctx->drop_flags >> 2;
 
@@ -1938,11 +1957,11 @@
  * Must be called with xhci->lock held.
  */
 static int xhci_reserve_host_resources(struct xhci_hcd *xhci,
-		struct xhci_container_ctx *in_ctx)
+		struct xhci_input_control_ctx *ctrl_ctx)
 {
 	u32 added_eps;
 
-	added_eps = xhci_count_num_new_endpoints(xhci, in_ctx);
+	added_eps = xhci_count_num_new_endpoints(xhci, ctrl_ctx);
 	if (xhci->num_active_eps + added_eps > xhci->limit_active_eps) {
 		xhci_dbg(xhci, "Not enough ep ctxs: "
 				"%u active, need to add %u, limit is %u.\n",
@@ -1963,11 +1982,11 @@
  * Must be called with xhci->lock held.
  */
 static void xhci_free_host_resources(struct xhci_hcd *xhci,
-		struct xhci_container_ctx *in_ctx)
+		struct xhci_input_control_ctx *ctrl_ctx)
 {
 	u32 num_failed_eps;
 
-	num_failed_eps = xhci_count_num_new_endpoints(xhci, in_ctx);
+	num_failed_eps = xhci_count_num_new_endpoints(xhci, ctrl_ctx);
 	xhci->num_active_eps -= num_failed_eps;
 	xhci_dbg(xhci, "Removing %u failed ep ctxs, %u now active.\n",
 			num_failed_eps,
@@ -1981,11 +2000,11 @@
  * Must be called with xhci->lock held.
  */
 static void xhci_finish_resource_reservation(struct xhci_hcd *xhci,
-		struct xhci_container_ctx *in_ctx)
+		struct xhci_input_control_ctx *ctrl_ctx)
 {
 	u32 num_dropped_eps;
 
-	num_dropped_eps = xhci_count_num_dropped_endpoints(xhci, in_ctx);
+	num_dropped_eps = xhci_count_num_dropped_endpoints(xhci, ctrl_ctx);
 	xhci->num_active_eps -= num_dropped_eps;
 	if (num_dropped_eps)
 		xhci_dbg(xhci, "Removing %u dropped ep ctxs, %u now active.\n",
@@ -2480,6 +2499,11 @@
 		old_active_eps = virt_dev->tt_info->active_eps;
 
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -ENOMEM;
+	}
 
 	for (i = 0; i < 31; i++) {
 		if (!EP_IS_ADDED(ctrl_ctx, i) && !EP_IS_DROPPED(ctrl_ctx, i))
@@ -2564,6 +2588,7 @@
 	int timeleft;
 	unsigned long flags;
 	struct xhci_container_ctx *in_ctx;
+	struct xhci_input_control_ctx *ctrl_ctx;
 	struct completion *cmd_completion;
 	u32 *cmd_status;
 	struct xhci_virt_device *virt_dev;
@@ -2576,9 +2601,17 @@
 		in_ctx = command->in_ctx;
 	else
 		in_ctx = virt_dev->in_ctx;
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
+	if (!ctrl_ctx) {
+		/* Drop xhci->lock before the early return. */
+		spin_unlock_irqrestore(&xhci->lock, flags);
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -ENOMEM;
+	}
 
 	if ((xhci->quirks & XHCI_EP_LIMIT_QUIRK) &&
-			xhci_reserve_host_resources(xhci, in_ctx)) {
+			xhci_reserve_host_resources(xhci, ctrl_ctx)) {
 		spin_unlock_irqrestore(&xhci->lock, flags);
 		xhci_warn(xhci, "Not enough host resources, "
 				"active endpoint contexts = %u\n",
@@ -2588,7 +2619,7 @@
 	if ((xhci->quirks & XHCI_SW_BW_CHECKING) &&
 			xhci_reserve_bandwidth(xhci, virt_dev, in_ctx)) {
 		if ((xhci->quirks & XHCI_EP_LIMIT_QUIRK))
-			xhci_free_host_resources(xhci, in_ctx);
+			xhci_free_host_resources(xhci, ctrl_ctx);
 		spin_unlock_irqrestore(&xhci->lock, flags);
 		xhci_warn(xhci, "Not enough bandwidth\n");
 		return -ENOMEM;
@@ -2624,7 +2655,7 @@
 		if (command)
 			list_del(&command->cmd_list);
 		if ((xhci->quirks & XHCI_EP_LIMIT_QUIRK))
-			xhci_free_host_resources(xhci, in_ctx);
+			xhci_free_host_resources(xhci, ctrl_ctx);
 		spin_unlock_irqrestore(&xhci->lock, flags);
 		xhci_dbg(xhci, "FIXME allocate a new ring segment\n");
 		return -ENOMEM;
@@ -2660,9 +2691,9 @@
 		 * Otherwise, clean up the estimate to include dropped eps.
 		 */
 		if (ret)
-			xhci_free_host_resources(xhci, in_ctx);
+			xhci_free_host_resources(xhci, ctrl_ctx);
 		else
-			xhci_finish_resource_reservation(xhci, in_ctx);
+			xhci_finish_resource_reservation(xhci, ctrl_ctx);
 		spin_unlock_irqrestore(&xhci->lock, flags);
 	}
 	return ret;
@@ -2699,6 +2730,11 @@
 
 	/* See section 4.6.6 - A0 = 1; A1 = D0 = D1 = 0 */
 	ctrl_ctx = xhci_get_input_control_ctx(xhci, virt_dev->in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -ENOMEM;
+	}
 	ctrl_ctx->add_flags |= cpu_to_le32(SLOT_FLAG);
 	ctrl_ctx->add_flags &= cpu_to_le32(~EP0_FLAG);
 	ctrl_ctx->drop_flags &= cpu_to_le32(~(SLOT_FLAG | EP0_FLAG));
@@ -2777,10 +2813,9 @@
 static void xhci_setup_input_ctx_for_config_ep(struct xhci_hcd *xhci,
 		struct xhci_container_ctx *in_ctx,
 		struct xhci_container_ctx *out_ctx,
+		struct xhci_input_control_ctx *ctrl_ctx,
 		u32 add_flags, u32 drop_flags)
 {
-	struct xhci_input_control_ctx *ctrl_ctx;
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
 	ctrl_ctx->add_flags = cpu_to_le32(add_flags);
 	ctrl_ctx->drop_flags = cpu_to_le32(drop_flags);
 	xhci_slot_copy(xhci, in_ctx, out_ctx);
@@ -2794,14 +2829,22 @@
 		unsigned int slot_id, unsigned int ep_index,
 		struct xhci_dequeue_state *deq_state)
 {
+	struct xhci_input_control_ctx *ctrl_ctx;
 	struct xhci_container_ctx *in_ctx;
 	struct xhci_ep_ctx *ep_ctx;
 	u32 added_ctxs;
 	dma_addr_t addr;
 
+	in_ctx = xhci->devs[slot_id]->in_ctx;
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return;
+	}
+
 	xhci_endpoint_copy(xhci, xhci->devs[slot_id]->in_ctx,
 			xhci->devs[slot_id]->out_ctx, ep_index);
-	in_ctx = xhci->devs[slot_id]->in_ctx;
 	ep_ctx = xhci_get_ep_ctx(xhci, in_ctx, ep_index);
 	addr = xhci_trb_virt_to_dma(deq_state->new_deq_seg,
 			deq_state->new_deq_ptr);
@@ -2817,7 +2860,8 @@
 
 	added_ctxs = xhci_get_endpoint_flag_from_index(ep_index);
 	xhci_setup_input_ctx_for_config_ep(xhci, xhci->devs[slot_id]->in_ctx,
-			xhci->devs[slot_id]->out_ctx, added_ctxs, added_ctxs);
+			xhci->devs[slot_id]->out_ctx, ctrl_ctx,
+			added_ctxs, added_ctxs);
 }
 
 void xhci_cleanup_stalled_ring(struct xhci_hcd *xhci,
@@ -3075,6 +3119,7 @@
 	struct xhci_hcd *xhci;
 	struct xhci_virt_device *vdev;
 	struct xhci_command *config_cmd;
+	struct xhci_input_control_ctx *ctrl_ctx;
 	unsigned int ep_index;
 	unsigned int num_stream_ctxs;
 	unsigned long flags;
@@ -3096,6 +3141,13 @@
 		xhci_dbg(xhci, "Could not allocate xHCI command structure.\n");
 		return -ENOMEM;
 	}
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, config_cmd->in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		xhci_free_command(xhci, config_cmd);
+		return -ENOMEM;
+	}
 
 	/* Check to make sure all endpoints are not already configured for
 	 * streams.  While we're at it, find the maximum number of streams that
@@ -3162,7 +3214,8 @@
 	 * and add the updated copy from the input context.
 	 */
 	xhci_setup_input_ctx_for_config_ep(xhci, config_cmd->in_ctx,
-			vdev->out_ctx, changed_ep_bitmask, changed_ep_bitmask);
+			vdev->out_ctx, ctrl_ctx,
+			changed_ep_bitmask, changed_ep_bitmask);
 
 	/* Issue and wait for the configure endpoint command */
 	ret = xhci_configure_endpoint(xhci, udev, config_cmd,
@@ -3220,6 +3273,7 @@
 	struct xhci_hcd *xhci;
 	struct xhci_virt_device *vdev;
 	struct xhci_command *command;
+	struct xhci_input_control_ctx *ctrl_ctx;
 	unsigned int ep_index;
 	unsigned long flags;
 	u32 changed_ep_bitmask;
@@ -3242,6 +3296,15 @@
 	 */
 	ep_index = xhci_get_endpoint_index(&eps[0]->desc);
 	command = vdev->eps[ep_index].stream_info->free_streams_command;
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, command->in_ctx);
+	if (!ctrl_ctx) {
+		/* Drop xhci->lock before the early return. */
+		spin_unlock_irqrestore(&xhci->lock, flags);
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -EINVAL;
+	}
+
 	for (i = 0; i < num_eps; i++) {
 		struct xhci_ep_ctx *ep_ctx;
 
@@ -3256,7 +3317,8 @@
 				&vdev->eps[ep_index]);
 	}
 	xhci_setup_input_ctx_for_config_ep(xhci, command->in_ctx,
-			vdev->out_ctx, changed_ep_bitmask, changed_ep_bitmask);
+			vdev->out_ctx, ctrl_ctx,
+			changed_ep_bitmask, changed_ep_bitmask);
 	spin_unlock_irqrestore(&xhci->lock, flags);
 
 	/* Issue and wait for the configure endpoint command,
@@ -3696,6 +3758,12 @@
 	}
 
 	slot_ctx = xhci_get_slot_ctx(xhci, virt_dev->in_ctx);
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, virt_dev->in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -EINVAL;
+	}
 	/*
 	 * If this is the first Set Address since device plug-in or
 	 * virt_device realloaction after a resume with an xHCI power loss,
@@ -3706,7 +3774,6 @@
 	/* Otherwise, update the control endpoint ring enqueue pointer. */
 	else
 		xhci_copy_ep0_dequeue_into_input_ctx(xhci, udev);
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, virt_dev->in_ctx);
 	ctrl_ctx->add_flags = cpu_to_le32(SLOT_FLAG | EP0_FLAG);
 	ctrl_ctx->drop_flags = 0;
 
@@ -3848,10 +3915,17 @@
 	/* Attempt to issue an Evaluate Context command to change the MEL. */
 	virt_dev = xhci->devs[udev->slot_id];
 	command = xhci->lpm_command;
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, command->in_ctx);
+	if (!ctrl_ctx) {
+		spin_unlock_irqrestore(&xhci->lock, flags);
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		return -ENOMEM;
+	}
+
 	xhci_slot_copy(xhci, command->in_ctx, virt_dev->out_ctx);
 	spin_unlock_irqrestore(&xhci->lock, flags);
 
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, command->in_ctx);
 	ctrl_ctx->add_flags |= cpu_to_le32(SLOT_FLAG);
 	slot_ctx = xhci_get_slot_ctx(xhci, command->in_ctx);
 	slot_ctx->dev_info2 &= cpu_to_le32(~((u32) MAX_EXIT));
@@ -4677,6 +4751,13 @@
 		xhci_dbg(xhci, "Could not allocate xHCI command structure.\n");
 		return -ENOMEM;
 	}
+	ctrl_ctx = xhci_get_input_control_ctx(xhci, config_cmd->in_ctx);
+	if (!ctrl_ctx) {
+		xhci_warn(xhci, "%s: Could not get input context, bad type.\n",
+				__func__);
+		xhci_free_command(xhci, config_cmd);
+		return -ENOMEM;
+	}
 
 	spin_lock_irqsave(&xhci->lock, flags);
 	if (hdev->speed == USB_SPEED_HIGH &&
@@ -4688,7 +4769,6 @@
 	}
 
 	xhci_slot_copy(xhci, config_cmd->in_ctx, vdev->out_ctx);
-	ctrl_ctx = xhci_get_input_control_ctx(xhci, config_cmd->in_ctx);
 	ctrl_ctx->add_flags |= cpu_to_le32(SLOT_FLAG);
 	slot_ctx = xhci_get_slot_ctx(xhci, config_cmd->in_ctx);
 	slot_ctx->dev_info |= cpu_to_le32(DEV_HUB);