IB/srpt: Simplify channel state management

The only allowed channel state changes are those that change
the channel state into a state with a higher numerical value.
This allows to merge the functions srpt_set_ch_state() and
srpt_test_and_set_ch_state() into a single function.

Signed-off-by: Bart Van Assche <bart.vanassche@sandisk.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Sagi Grimberg <sagig@mellanox.com>
Cc: Alex Estrin <alex.estrin@intel.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
diff --git a/drivers/infiniband/ulp/srpt/ib_srpt.c b/drivers/infiniband/ulp/srpt/ib_srpt.c
index 716f429..863fdd1 100644
--- a/drivers/infiniband/ulp/srpt/ib_srpt.c
+++ b/drivers/infiniband/ulp/srpt/ib_srpt.c
@@ -96,37 +96,25 @@
 static void srpt_recv_done(struct ib_cq *cq, struct ib_wc *wc);
 static void srpt_send_done(struct ib_cq *cq, struct ib_wc *wc);
 
-static enum rdma_ch_state
-srpt_set_ch_state(struct srpt_rdma_ch *ch, enum rdma_ch_state new_state)
-{
-	unsigned long flags;
-	enum rdma_ch_state prev;
-
-	spin_lock_irqsave(&ch->spinlock, flags);
-	prev = ch->state;
-	ch->state = new_state;
-	spin_unlock_irqrestore(&ch->spinlock, flags);
-	return prev;
-}
-
-/**
- * srpt_test_and_set_ch_state() - Test and set the channel state.
- *
- * Returns true if and only if the channel state has been set to the new state.
+/*
+ * The only allowed channel state changes are those that change the channel
+ * state into a state with a higher numerical value. Hence the new > prev test.
  */
-static bool
-srpt_test_and_set_ch_state(struct srpt_rdma_ch *ch, enum rdma_ch_state old,
-			   enum rdma_ch_state new)
+static bool srpt_set_ch_state(struct srpt_rdma_ch *ch, enum rdma_ch_state new)
 {
 	unsigned long flags;
 	enum rdma_ch_state prev;
+	bool changed = false;
 
 	spin_lock_irqsave(&ch->spinlock, flags);
 	prev = ch->state;
-	if (prev == old)
+	if (new > prev) {
 		ch->state = new;
+		changed = true;
+	}
 	spin_unlock_irqrestore(&ch->spinlock, flags);
-	return prev == old;
+
+	return changed;
 }
 
 /**
@@ -199,8 +187,7 @@
 		ib_cm_notify(ch->cm_id, event->event);
 		break;
 	case IB_EVENT_QP_LAST_WQE_REACHED:
-		if (srpt_test_and_set_ch_state(ch, CH_DRAINING,
-					       CH_RELEASING))
+		if (srpt_set_ch_state(ch, CH_RELEASING))
 			srpt_release_channel(ch);
 		else
 			pr_debug("%s: state %d - ignored LAST_WQE.\n",
@@ -1947,12 +1934,7 @@
 	spin_lock_irq(&sdev->spinlock);
 	list_for_each_entry(ch, &sdev->rch_list, list) {
 		if (ch->cm_id == cm_id) {
-			do_reset = srpt_test_and_set_ch_state(ch,
-					CH_CONNECTING, CH_DRAINING) ||
-				   srpt_test_and_set_ch_state(ch,
-					CH_LIVE, CH_DRAINING) ||
-				   srpt_test_and_set_ch_state(ch,
-					CH_DISCONNECTING, CH_DRAINING);
+			do_reset = srpt_set_ch_state(ch, CH_DRAINING);
 			break;
 		}
 	}
@@ -2353,7 +2335,7 @@
 	ch = srpt_find_channel(cm_id->context, cm_id);
 	BUG_ON(!ch);
 
-	if (srpt_test_and_set_ch_state(ch, CH_CONNECTING, CH_LIVE)) {
+	if (srpt_set_ch_state(ch, CH_LIVE)) {
 		struct srpt_recv_ioctx *ioctx, *ioctx_tmp;
 
 		ret = srpt_ch_qp_rts(ch, ch->qp);