IB/core: Support for CMA multicast join flags

Add UCMA and CMA support for multicast join flags. The flags are
passed in the previously reserved field of the UCMA CM join command.
Currently supporting two join flags indicating two different
multicast JoinStates:

1. Full Member:
   The initiator creates the Multicast group (MCG) if it wasn't
   previously created, can send Multicast messages to the group
   and receive messages from the MCG.

2. Send Only Full Member:
   The initiator creates the Multicast group (MCG) if it wasn't
   previously created, can send Multicast messages to the group
   but doesn't receive any messages from the MCG.

   IB: Send Only Full Member requires a query of ClassPortInfo
       to determine if SM/SA supports this option. If SM/SA
       doesn't support Send-Only there will be no join request
       sent and an error will be returned.

   ETH: When Send Only Full Member is requested no IGMP join
        will be sent.

Signed-off-by: Alex Vesker <valex@mellanox.com>
Reviewed-by: Hal Rosenstock <hal@mellanox.com>
Signed-off-by: Leon Romanovsky <leon@kernel.org>
Signed-off-by: Doug Ledford <dledford@redhat.com>
diff --git a/drivers/infiniband/core/cma.c b/drivers/infiniband/core/cma.c
index f0c91ba..0451307 100644
--- a/drivers/infiniband/core/cma.c
+++ b/drivers/infiniband/core/cma.c
@@ -68,6 +68,7 @@
 MODULE_LICENSE("Dual BSD/GPL");
 
 #define CMA_CM_RESPONSE_TIMEOUT 20
+#define CMA_QUERY_CLASSPORT_INFO_TIMEOUT 3000
 #define CMA_MAX_CM_RETRIES 15
 #define CMA_CM_MRA_SETTING (IB_CM_MRA_FLAG_DELAY | 24)
 #define CMA_IBOE_PACKET_LIFETIME 18
@@ -162,6 +163,14 @@
 	unsigned short		port;
 };
 
+struct class_port_info_context {	/* state shared by a ClassPortInfo query and its callback */
+	struct ib_class_port_info	*class_port_info;	/* buffer the callback copies the reply into */
+	struct ib_device		*device;	/* queried device; named in log messages */
+	struct completion		done;	/* completed by the callback, waited on by the issuer */
+	struct ib_sa_query		*sa_query;	/* slot passed to ib_sa_classport_info_rec_query() */
+	u8				port_num;	/* queried port; printed in log messages */
+};
+
 static int cma_ps_alloc(struct net *net, enum rdma_port_space ps,
 			struct rdma_bind_list *bind_list, int snum)
 {
@@ -306,6 +315,7 @@
 	struct sockaddr_storage	addr;
 	struct kref		mcref;
 	bool			igmp_joined;
+	u8			join_state;
 };
 
 struct cma_work {
@@ -3754,10 +3764,63 @@
 	}
 }
 
+/* SA callback: copy the ClassPortInfo reply (if any) out and wake the waiter. */
+static void cma_query_sa_classport_info_cb(int status,
+					   struct ib_class_port_info *rec,
+					   void *context)
+{
+	struct class_port_info_context *cb_ctx = context;
+
+	if (WARN_ON(!cb_ctx))	/* nothing to fill in and nobody to wake */
+		return;
+
+	if (status || !rec) {	/* failure: the caller's buffer is left untouched */
+		pr_debug("RDMA CM: %s port %u failed query ClassPortInfo status: %d\n",
+			 cb_ctx->device->name, cb_ctx->port_num, status);
+	} else {
+		memcpy(cb_ctx->class_port_info, rec, sizeof(*rec));
+	}
+
+	complete(&cb_ctx->done);
+}
+
+/*
+ * Query the SA ClassPortInfo and wait for the reply. Returns 0 once the query
+ * has completed, -errno if it was not sent; a failed reply leaves it zeroed.
+ */
+static int cma_query_sa_classport_info(struct ib_device *device, u8 port_num,
+				       struct ib_class_port_info *class_port_info)
+{
+	struct class_port_info_context *cb_ctx;
+	int ret;
+
+	cb_ctx = kmalloc(sizeof(*cb_ctx), GFP_KERNEL);
+	if (!cb_ctx)
+		return -ENOMEM;
+
+	memset(class_port_info, 0, sizeof(*class_port_info));
+	cb_ctx->device = device;
+	cb_ctx->class_port_info = class_port_info;
+	cb_ctx->port_num = port_num;
+	init_completion(&cb_ctx->done);
+	ret = ib_sa_classport_info_rec_query(&sa_client, device, port_num,
+					     CMA_QUERY_CLASSPORT_INFO_TIMEOUT,
+					     GFP_KERNEL, cma_query_sa_classport_info_cb,
+					     cb_ctx, &cb_ctx->sa_query);
+	if (ret < 0)
+		pr_err("RDMA CM: %s port %u failed to send ClassPortInfo query, ret: %d\n",
+		       device->name, port_num, ret);
+	else
+		wait_for_completion(&cb_ctx->done);
+	kfree(cb_ctx);
+	return ret < 0 ? ret : 0;	/* ret >= 0 is a query id, not an error */
+}
+
 static int cma_join_ib_multicast(struct rdma_id_private *id_priv,
 				 struct cma_multicast *mc)
 {
 	struct ib_sa_mcmember_rec rec;
+	struct ib_class_port_info class_port_info;
 	struct rdma_dev_addr *dev_addr = &id_priv->id.route.addr.dev_addr;
 	ib_sa_comp_mask comp_mask;
 	int ret;
@@ -3776,7 +3839,24 @@
 	rec.qkey = cpu_to_be32(id_priv->qkey);
 	rdma_addr_get_sgid(dev_addr, &rec.port_gid);
 	rec.pkey = cpu_to_be16(ib_addr_get_pkey(dev_addr));
-	rec.join_state = 1;
+	rec.join_state = mc->join_state;
+
+	if (rec.join_state == BIT(SENDONLY_FULLMEMBER_JOIN)) {
+		ret = cma_query_sa_classport_info(id_priv->id.device,
+						  id_priv->id.port_num,
+						  &class_port_info);
+
+		if (ret)
+			return ret;
+
+		if (!(ib_get_cpi_capmask2(&class_port_info) &
+		      IB_SA_CAP_MASK2_SENDONLY_FULL_MEM_SUPPORT)) {
+			pr_warn("RDMA CM: %s port %u Unable to multicast join\n"
+				"RDMA CM: SM doesn't support Send Only Full Member option\n",
+				id_priv->id.device->name, id_priv->id.port_num);
+			return -EOPNOTSUPP;
+		}
+	}
 
 	comp_mask = IB_SA_MCMEMBER_REC_MGID | IB_SA_MCMEMBER_REC_PORT_GID |
 		    IB_SA_MCMEMBER_REC_PKEY | IB_SA_MCMEMBER_REC_JOIN_STATE |
@@ -3845,6 +3925,9 @@
 	struct sockaddr *addr = (struct sockaddr *)&mc->addr;
 	struct net_device *ndev = NULL;
 	enum ib_gid_type gid_type;
+	bool send_only;
+
+	send_only = mc->join_state == BIT(SENDONLY_FULLMEMBER_JOIN);
 
 	if (cma_zero_addr((struct sockaddr *)&mc->addr))
 		return -EINVAL;
@@ -3878,12 +3961,14 @@
 	gid_type = id_priv->cma_dev->default_gid_type[id_priv->id.port_num -
 		   rdma_start_port(id_priv->cma_dev->device)];
 	if (addr->sa_family == AF_INET) {
-		if (gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP)
-			err = cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid,
-					    true);
-		if (!err) {
-			mc->igmp_joined = true;
+		if (gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP) {
 			mc->multicast.ib->rec.hop_limit = IPV6_DEFAULT_HOPLIMIT;
+			if (!send_only) {
+				err = cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid,
+						    true);
+				if (!err)
+					mc->igmp_joined = true;
+			}
 		}
 	} else {
 		if (gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP)
@@ -3913,7 +3998,7 @@
 }
 
 int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
-			void *context)
+			u8 join_state, void *context)
 {
 	struct rdma_id_private *id_priv;
 	struct cma_multicast *mc;
@@ -3932,6 +4017,7 @@
 	mc->context = context;
 	mc->id_priv = id_priv;
 	mc->igmp_joined = false;
+	mc->join_state = join_state;
 	spin_lock(&id_priv->lock);
 	list_add(&mc->list, &id_priv->mc_list);
 	spin_unlock(&id_priv->lock);
diff --git a/drivers/infiniband/core/multicast.c b/drivers/infiniband/core/multicast.c
index a83ec28..3a3c5d7 100644
--- a/drivers/infiniband/core/multicast.c
+++ b/drivers/infiniband/core/multicast.c
@@ -93,18 +93,6 @@
 
 struct mcast_member;
 
-/*
-* There are 4 types of join states:
-* FullMember, NonMember, SendOnlyNonMember, SendOnlyFullMember.
-*/
-enum {
-	FULLMEMBER_JOIN,
-	NONMEMBER_JOIN,
-	SENDONLY_NONMEBER_JOIN,
-	SENDONLY_FULLMEMBER_JOIN,
-	NUM_JOIN_MEMBERSHIP_TYPES,
-};
-
 struct mcast_group {
 	struct ib_sa_mcmember_rec rec;
 	struct rb_node		node;
diff --git a/drivers/infiniband/core/ucma.c b/drivers/infiniband/core/ucma.c
index c0f3826..2825ece 100644
--- a/drivers/infiniband/core/ucma.c
+++ b/drivers/infiniband/core/ucma.c
@@ -106,6 +106,7 @@
 	int			events_reported;
 
 	u64			uid;
+	u8			join_state;
 	struct list_head	list;
 	struct sockaddr_storage	addr;
 };
@@ -1317,12 +1318,20 @@
 	struct ucma_multicast *mc;
 	struct sockaddr *addr;
 	int ret;
+	u8 join_state;
 
 	if (out_len < sizeof(resp))
 		return -ENOSPC;
 
 	addr = (struct sockaddr *) &cmd->addr;
-	if (cmd->reserved || !cmd->addr_size || (cmd->addr_size != rdma_addr_size(addr)))
+	if (!cmd->addr_size || (cmd->addr_size != rdma_addr_size(addr)))
+		return -EINVAL;
+
+	if (cmd->join_flags == RDMA_MC_JOIN_FLAG_FULLMEMBER)
+		join_state = BIT(FULLMEMBER_JOIN);
+	else if (cmd->join_flags == RDMA_MC_JOIN_FLAG_SENDONLY_FULLMEMBER)
+		join_state = BIT(SENDONLY_FULLMEMBER_JOIN);
+	else
 		return -EINVAL;
 
 	ctx = ucma_get_ctx(file, cmd->id);
@@ -1335,10 +1344,11 @@
 		ret = -ENOMEM;
 		goto err1;
 	}
-
+	mc->join_state = join_state;
 	mc->uid = cmd->uid;
 	memcpy(&mc->addr, addr, cmd->addr_size);
-	ret = rdma_join_multicast(ctx->cm_id, (struct sockaddr *) &mc->addr, mc);
+	ret = rdma_join_multicast(ctx->cm_id, (struct sockaddr *)&mc->addr,
+				  join_state, mc);
 	if (ret)
 		goto err2;
 
@@ -1382,7 +1392,7 @@
 	join_cmd.uid = cmd.uid;
 	join_cmd.id = cmd.id;
 	join_cmd.addr_size = rdma_addr_size((struct sockaddr *) &cmd.addr);
-	join_cmd.reserved = 0;
+	join_cmd.join_flags = RDMA_MC_JOIN_FLAG_FULLMEMBER;
 	memcpy(&join_cmd.addr, &cmd.addr, join_cmd.addr_size);
 
 	return ucma_process_join(file, &join_cmd, out_len);