IB/cm: Match connection requests based on private data

Extend matching connection requests to listens in the InfiniBand CM to
include private data checks.

This allows applications to listen on the same service identifier,
with private data directing the request to the appropriate application.

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
Signed-off-by: Roland Dreier <rolandd@cisco.com>
diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c
index 86fee43..490fd03 100644
--- a/drivers/infiniband/core/cm.c
+++ b/drivers/infiniband/core/cm.c
@@ -32,7 +32,7 @@
  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  *
- * $Id: cm.c 2821 2005-07-08 17:07:28Z sean.hefty $
+ * $Id: cm.c 4311 2005-12-05 18:42:01Z sean.hefty $
  */
 
 #include <linux/completion.h>
@@ -132,6 +132,7 @@
 	/* todo: use alternate port on send failure */
 	struct cm_av av;
 	struct cm_av alt_av;
+	struct ib_cm_compare_data *compare_data;
 
 	void *private_data;
 	__be64 tid;
@@ -357,6 +358,41 @@
 	return cm_id_priv;
 }
 
+static void cm_mask_copy(u8 *dst, u8 *src, u8 *mask)
+{
+	int i;
+
+	for (i = 0; i < IB_CM_COMPARE_SIZE / sizeof(unsigned long); i++)
+		((unsigned long *) dst)[i] = ((unsigned long *) src)[i] &
+					     ((unsigned long *) mask)[i];
+}
+
+static int cm_compare_data(struct ib_cm_compare_data *src_data,
+			   struct ib_cm_compare_data *dst_data)
+{
+	u8 src[IB_CM_COMPARE_SIZE];
+	u8 dst[IB_CM_COMPARE_SIZE];
+
+	if (!src_data || !dst_data)
+		return 0;
+
+	cm_mask_copy(src, src_data->data, dst_data->mask);
+	cm_mask_copy(dst, dst_data->data, src_data->mask);
+	return memcmp(src, dst, IB_CM_COMPARE_SIZE);
+}
+
+static int cm_compare_private_data(u8 *private_data,
+				   struct ib_cm_compare_data *dst_data)
+{
+	u8 src[IB_CM_COMPARE_SIZE];
+
+	if (!dst_data)
+		return 0;
+
+	cm_mask_copy(src, private_data, dst_data->mask);
+	return memcmp(src, dst_data->data, IB_CM_COMPARE_SIZE);
+}
+
 static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv)
 {
 	struct rb_node **link = &cm.listen_service_table.rb_node;
@@ -364,14 +400,18 @@
 	struct cm_id_private *cur_cm_id_priv;
 	__be64 service_id = cm_id_priv->id.service_id;
 	__be64 service_mask = cm_id_priv->id.service_mask;
+	int data_cmp;
 
 	while (*link) {
 		parent = *link;
 		cur_cm_id_priv = rb_entry(parent, struct cm_id_private,
 					  service_node);
+		data_cmp = cm_compare_data(cm_id_priv->compare_data,
+					   cur_cm_id_priv->compare_data);
 		if ((cur_cm_id_priv->id.service_mask & service_id) ==
 		    (service_mask & cur_cm_id_priv->id.service_id) &&
-		    (cm_id_priv->id.device == cur_cm_id_priv->id.device))
+		    (cm_id_priv->id.device == cur_cm_id_priv->id.device) &&
+		    !data_cmp)
 			return cur_cm_id_priv;
 
 		if (cm_id_priv->id.device < cur_cm_id_priv->id.device)
@@ -380,6 +420,10 @@
 			link = &(*link)->rb_right;
 		else if (service_id < cur_cm_id_priv->id.service_id)
 			link = &(*link)->rb_left;
+		else if (service_id > cur_cm_id_priv->id.service_id)
+			link = &(*link)->rb_right;
+		else if (data_cmp < 0)
+			link = &(*link)->rb_left;
 		else
 			link = &(*link)->rb_right;
 	}
@@ -389,16 +433,20 @@
 }
 
 static struct cm_id_private * cm_find_listen(struct ib_device *device,
-					     __be64 service_id)
+					     __be64 service_id,
+					     u8 *private_data)
 {
 	struct rb_node *node = cm.listen_service_table.rb_node;
 	struct cm_id_private *cm_id_priv;
+	int data_cmp;
 
 	while (node) {
 		cm_id_priv = rb_entry(node, struct cm_id_private, service_node);
+		data_cmp = cm_compare_private_data(private_data,
+						   cm_id_priv->compare_data);
 		if ((cm_id_priv->id.service_mask & service_id) ==
 		     cm_id_priv->id.service_id &&
-		    (cm_id_priv->id.device == device))
+		    (cm_id_priv->id.device == device) && !data_cmp)
 			return cm_id_priv;
 
 		if (device < cm_id_priv->id.device)
@@ -407,6 +455,10 @@
 			node = node->rb_right;
 		else if (service_id < cm_id_priv->id.service_id)
 			node = node->rb_left;
+		else if (service_id > cm_id_priv->id.service_id)
+			node = node->rb_right;
+		else if (data_cmp < 0)
+			node = node->rb_left;
 		else
 			node = node->rb_right;
 	}
@@ -730,15 +782,14 @@
 	wait_for_completion(&cm_id_priv->comp);
 	while ((work = cm_dequeue_work(cm_id_priv)) != NULL)
 		cm_free_work(work);
-	if (cm_id_priv->private_data && cm_id_priv->private_data_len)
-		kfree(cm_id_priv->private_data);
+	kfree(cm_id_priv->compare_data);
+	kfree(cm_id_priv->private_data);
 	kfree(cm_id_priv);
 }
 EXPORT_SYMBOL(ib_destroy_cm_id);
 
-int ib_cm_listen(struct ib_cm_id *cm_id,
-		 __be64 service_id,
-		 __be64 service_mask)
+int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask,
+		 struct ib_cm_compare_data *compare_data)
 {
 	struct cm_id_private *cm_id_priv, *cur_cm_id_priv;
 	unsigned long flags;
@@ -752,7 +803,19 @@
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-	BUG_ON(cm_id->state != IB_CM_IDLE);
+	if (cm_id->state != IB_CM_IDLE)
+		return -EINVAL;
+
+	if (compare_data) {
+		cm_id_priv->compare_data = kzalloc(sizeof *compare_data,
+						   GFP_KERNEL);
+		if (!cm_id_priv->compare_data)
+			return -ENOMEM;
+		cm_mask_copy(cm_id_priv->compare_data->data,
+			     compare_data->data, compare_data->mask);
+		memcpy(cm_id_priv->compare_data->mask, compare_data->mask,
+		       IB_CM_COMPARE_SIZE);
+	}
 
 	cm_id->state = IB_CM_LISTEN;
 
@@ -769,6 +832,8 @@
 
 	if (cur_cm_id_priv) {
 		cm_id->state = IB_CM_IDLE;
+		kfree(cm_id_priv->compare_data);
+		cm_id_priv->compare_data = NULL;
 		ret = -EBUSY;
 	}
 	return ret;
@@ -1241,7 +1306,8 @@
 
 	/* Find matching listen request. */
 	listen_cm_id_priv = cm_find_listen(cm_id_priv->id.device,
-					   req_msg->service_id);
+					   req_msg->service_id,
+					   req_msg->private_data);
 	if (!listen_cm_id_priv) {
 		spin_unlock_irqrestore(&cm.lock, flags);
 		cm_issue_rej(work->port, work->mad_recv_wc,
@@ -2654,7 +2720,8 @@
 		goto out; /* Duplicate message. */
 	}
 	cur_cm_id_priv = cm_find_listen(cm_id->device,
-					sidr_req_msg->service_id);
+					sidr_req_msg->service_id,
+					sidr_req_msg->private_data);
 	if (!cur_cm_id_priv) {
 		rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
 		spin_unlock_irqrestore(&cm.lock, flags);
diff --git a/drivers/infiniband/core/ucm.c b/drivers/infiniband/core/ucm.c
index b396bf7..0136aee 100644
--- a/drivers/infiniband/core/ucm.c
+++ b/drivers/infiniband/core/ucm.c
@@ -648,6 +648,17 @@
 	return result;
 }
 
+static int ucm_validate_listen(__be64 service_id, __be64 service_mask)
+{
+	service_id &= service_mask;
+
+	if (((service_id & IB_CMA_SERVICE_ID_MASK) == IB_CMA_SERVICE_ID) ||
+	    ((service_id & IB_SDP_SERVICE_ID_MASK) == IB_SDP_SERVICE_ID))
+		return -EINVAL;
+
+	return 0;
+}
+
 static ssize_t ib_ucm_listen(struct ib_ucm_file *file,
 			     const char __user *inbuf,
 			     int in_len, int out_len)
@@ -663,7 +674,13 @@
 	if (IS_ERR(ctx))
 		return PTR_ERR(ctx);
 
-	result = ib_cm_listen(ctx->cm_id, cmd.service_id, cmd.service_mask);
+	result = ucm_validate_listen(cmd.service_id, cmd.service_mask);
+	if (result)
+		goto out;
+
+	result = ib_cm_listen(ctx->cm_id, cmd.service_id, cmd.service_mask,
+			      NULL);
+out:
 	ib_ucm_ctx_put(ctx);
 	return result;
 }
diff --git a/include/rdma/ib_cm.h b/include/rdma/ib_cm.h
index 0a9fcd5..8f394f0 100644
--- a/include/rdma/ib_cm.h
+++ b/include/rdma/ib_cm.h
@@ -32,7 +32,7 @@
  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  *
- * $Id: ib_cm.h 2730 2005-06-28 16:43:03Z sean.hefty $
+ * $Id: ib_cm.h 4311 2005-12-05 18:42:01Z sean.hefty $
  */
 #if !defined(IB_CM_H)
 #define IB_CM_H
@@ -102,7 +102,8 @@
 	IB_CM_APR_INFO_LENGTH		 = 72,
 	IB_CM_SIDR_REQ_PRIVATE_DATA_SIZE = 216,
 	IB_CM_SIDR_REP_PRIVATE_DATA_SIZE = 136,
-	IB_CM_SIDR_REP_INFO_LENGTH	 = 72
+	IB_CM_SIDR_REP_INFO_LENGTH	 = 72,
+	IB_CM_COMPARE_SIZE		 = 64
 };
 
 struct ib_cm_id;
@@ -238,7 +239,6 @@
 	u32			qpn;
 	void			*info;
 	u8			info_len;
-
 };
 
 struct ib_cm_event {
@@ -317,6 +317,15 @@
 
 #define IB_SERVICE_ID_AGN_MASK	__constant_cpu_to_be64(0xFF00000000000000ULL)
 #define IB_CM_ASSIGN_SERVICE_ID __constant_cpu_to_be64(0x0200000000000000ULL)
+#define IB_CMA_SERVICE_ID	__constant_cpu_to_be64(0x0000000001000000ULL)
+#define IB_CMA_SERVICE_ID_MASK	__constant_cpu_to_be64(0xFFFFFFFFFF000000ULL)
+#define IB_SDP_SERVICE_ID	__constant_cpu_to_be64(0x0000000000010000ULL)
+#define IB_SDP_SERVICE_ID_MASK	__constant_cpu_to_be64(0xFFFFFFFFFFFF0000ULL)
+
+struct ib_cm_compare_data {
+	u8  data[IB_CM_COMPARE_SIZE];
+	u8  mask[IB_CM_COMPARE_SIZE];
+};
 
 /**
  * ib_cm_listen - Initiates listening on the specified service ID for
@@ -330,10 +339,12 @@
  *   range of service IDs.  If set to 0, the service ID is matched
  *   exactly.  This parameter is ignored if %service_id is set to
  *   IB_CM_ASSIGN_SERVICE_ID.
+ * @compare_data: This parameter is optional.  It specifies data that must
+ *   appear in the private data of a connection request for the specified
+ *   listen request.
  */
-int ib_cm_listen(struct ib_cm_id *cm_id,
-		 __be64 service_id,
-		 __be64 service_mask);
+int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask,
+		 struct ib_cm_compare_data *compare_data);
 
 struct ib_cm_req_param {
 	struct ib_sa_path_rec	*primary_path;