IPoIB: Fix usage of uninitialized multicast objects

The driver should avoid calling ib_sa_free_multicast on the mcast->mc
object until it finishes its initialization state.  Otherwise we can
crash when ipoib_mcast_dev_flush() attempts to use the uninitialized
multicast object.

Instead, only call wait_for_completion() for multicast entries that
started the join process, meaning that ib_sa_join_multicast() finished.

Signed-off-by: Erez Shitrit <erezsh@mellanox.com>
Signed-off-by: Or Gerlitz <ogerlitz@mellanox.com>
Signed-off-by: Roland Dreier <roland@purestorage.com>
diff --git a/drivers/infiniband/ulp/ipoib/ipoib.h b/drivers/infiniband/ulp/ipoib/ipoib.h
index ec9190e..c639f90 100644
--- a/drivers/infiniband/ulp/ipoib/ipoib.h
+++ b/drivers/infiniband/ulp/ipoib/ipoib.h
@@ -101,6 +101,7 @@
 	IPOIB_MCAST_FLAG_SENDONLY = 1,
 	IPOIB_MCAST_FLAG_BUSY	  = 2,	/* joining or already joined */
 	IPOIB_MCAST_FLAG_ATTACHED = 3,
+	IPOIB_MCAST_JOIN_STARTED  = 4,
 
 	MAX_SEND_CQE		  = 16,
 	IPOIB_CM_COPYBREAK	  = 256,
@@ -151,6 +152,7 @@
 	struct sk_buff_head pkt_queue;
 
 	struct net_device *dev;
+	struct completion done;
 };
 
 struct ipoib_rx_buf {
diff --git a/drivers/infiniband/ulp/ipoib/ipoib_multicast.c b/drivers/infiniband/ulp/ipoib/ipoib_multicast.c
index cecb98a..780a2a0 100644
--- a/drivers/infiniband/ulp/ipoib/ipoib_multicast.c
+++ b/drivers/infiniband/ulp/ipoib/ipoib_multicast.c
@@ -386,8 +386,10 @@
 			mcast->mcmember.mgid.raw, status);
 
 	/* We trap for port events ourselves. */
-	if (status == -ENETRESET)
-		return 0;
+	if (status == -ENETRESET) {
+		status = 0;
+		goto out;
+	}
 
 	if (!status)
 		status = ipoib_mcast_join_finish(mcast, &multicast->rec);
@@ -407,7 +409,8 @@
 		if (mcast == priv->broadcast)
 			queue_work(ipoib_workqueue, &priv->carrier_on_task);
 
-		return 0;
+		status = 0;
+		goto out;
 	}
 
 	if (mcast->logcount++ < 20) {
@@ -434,7 +437,8 @@
 				   mcast->backoff * HZ);
 	spin_unlock_irq(&priv->lock);
 	mutex_unlock(&mcast_mutex);
-
+out:
+	complete(&mcast->done);
 	return status;
 }
 
@@ -484,11 +488,15 @@
 	}
 
 	set_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags);
+	init_completion(&mcast->done);
+	set_bit(IPOIB_MCAST_JOIN_STARTED, &mcast->flags);
+
 	mcast->mc = ib_sa_join_multicast(&ipoib_sa_client, priv->ca, priv->port,
 					 &rec, comp_mask, GFP_KERNEL,
 					 ipoib_mcast_join_complete, mcast);
 	if (IS_ERR(mcast->mc)) {
 		clear_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags);
+		complete(&mcast->done);
 		ret = PTR_ERR(mcast->mc);
 		ipoib_warn(priv, "ib_sa_join_multicast failed, status %d\n", ret);
 
@@ -751,6 +759,11 @@
 
 	spin_unlock_irqrestore(&priv->lock, flags);
 
+	/* seperate between the wait to the leave*/
+	list_for_each_entry_safe(mcast, tmcast, &remove_list, list)
+		if (test_bit(IPOIB_MCAST_JOIN_STARTED, &mcast->flags))
+			wait_for_completion(&mcast->done);
+
 	list_for_each_entry_safe(mcast, tmcast, &remove_list, list) {
 		ipoib_mcast_leave(dev, mcast);
 		ipoib_mcast_free(mcast);