RDS: Implement masked atomic operations

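Add two CMSGs for masked versions of cswp and fadd. The args
struct is modified to use a union for the different atomic op
types' arguments. Change IB to issue masked atomic ops; nonmasked
requests are converted to masked ops with default masks (all-ones
compare and swap masks for cswp, a zero nocarry mask for fadd).
The atomic op in rds_message is similarly changed to use a union.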

Signed-off-by: Andy Grover <andy.grover@oracle.com>
diff --git a/net/rds/ib_send.c b/net/rds/ib_send.c
index 808544a..71f373c 100644
--- a/net/rds/ib_send.c
+++ b/net/rds/ib_send.c
@@ -807,13 +807,17 @@
 	send->s_queued = jiffies;
 
 	if (op->op_type == RDS_ATOMIC_TYPE_CSWP) {
-		send->s_wr.opcode = IB_WR_ATOMIC_CMP_AND_SWP;
-		send->s_wr.wr.atomic.compare_add = op->op_compare;
-		send->s_wr.wr.atomic.swap = op->op_swap_add;
+		send->s_wr.opcode = IB_WR_MASKED_ATOMIC_CMP_AND_SWP;
+		send->s_wr.wr.atomic.compare_add = op->op_m_cswp.compare;
+		send->s_wr.wr.atomic.swap = op->op_m_cswp.swap;
+		send->s_wr.wr.atomic.compare_add_mask = op->op_m_cswp.compare_mask;
+		send->s_wr.wr.atomic.swap_mask = op->op_m_cswp.swap_mask;
 	} else { /* FADD */
-		send->s_wr.opcode = IB_WR_ATOMIC_FETCH_AND_ADD;
-		send->s_wr.wr.atomic.compare_add = op->op_swap_add;
+		send->s_wr.opcode = IB_WR_MASKED_ATOMIC_FETCH_AND_ADD;
+		send->s_wr.wr.atomic.compare_add = op->op_m_fadd.add;
 		send->s_wr.wr.atomic.swap = 0;
+		send->s_wr.wr.atomic.compare_add_mask = op->op_m_fadd.nocarry_mask;
+		send->s_wr.wr.atomic.swap_mask = 0;
 	}
 	nr_sig = rds_ib_set_wr_signal_state(ic, send, op->op_notify);
 	send->s_wr.num_sge = 1;
diff --git a/net/rds/rdma.c b/net/rds/rdma.c
index 48781fe..4806467 100644
--- a/net/rds/rdma.c
+++ b/net/rds/rdma.c
@@ -738,13 +738,34 @@
 
 	args = CMSG_DATA(cmsg);
 
-	if (cmsg->cmsg_type == RDS_CMSG_ATOMIC_CSWP) {
-		rm->atomic.op_type = RDS_ATOMIC_TYPE_CSWP;
-		rm->atomic.op_swap_add = args->cswp.swap;
-		rm->atomic.op_compare = args->cswp.compare;
-	} else {
+	/* Nonmasked & masked cmsg ops converted to masked hw ops */
+	switch (cmsg->cmsg_type) {
+	case RDS_CMSG_ATOMIC_FADD:
 		rm->atomic.op_type = RDS_ATOMIC_TYPE_FADD;
-		rm->atomic.op_swap_add = args->fadd.add;
+		rm->atomic.op_m_fadd.add = args->fadd.add;
+		rm->atomic.op_m_fadd.nocarry_mask = 0;
+		break;
+	case RDS_CMSG_MASKED_ATOMIC_FADD:
+		rm->atomic.op_type = RDS_ATOMIC_TYPE_FADD;
+		rm->atomic.op_m_fadd.add = args->m_fadd.add;
+		rm->atomic.op_m_fadd.nocarry_mask = args->m_fadd.nocarry_mask;
+		break;
+	case RDS_CMSG_ATOMIC_CSWP:
+		rm->atomic.op_type = RDS_ATOMIC_TYPE_CSWP;
+		rm->atomic.op_m_cswp.compare = args->cswp.compare;
+		rm->atomic.op_m_cswp.swap = args->cswp.swap;
+		rm->atomic.op_m_cswp.compare_mask = ~0;
+		rm->atomic.op_m_cswp.swap_mask = ~0;
+		break;
+	case RDS_CMSG_MASKED_ATOMIC_CSWP:
+		rm->atomic.op_type = RDS_ATOMIC_TYPE_CSWP;
+		rm->atomic.op_m_cswp.compare = args->m_cswp.compare;
+		rm->atomic.op_m_cswp.swap = args->m_cswp.swap;
+		rm->atomic.op_m_cswp.compare_mask = args->m_cswp.compare_mask;
+		rm->atomic.op_m_cswp.swap_mask = args->m_cswp.swap_mask;
+		break;
+	default:
+		BUG(); /* should never happen */
 	}
 
 	rm->atomic.op_notify = !!(args->flags & RDS_RDMA_NOTIFY_ME);
diff --git a/net/rds/rds.h b/net/rds/rds.h
index aadaddb..8103dcf 100644
--- a/net/rds/rds.h
+++ b/net/rds/rds.h
@@ -316,8 +316,18 @@
 	struct {
 		struct rm_atomic_op {
 			int			op_type;
-			uint64_t		op_swap_add;
-			uint64_t		op_compare;
+			union {
+				struct {
+					uint64_t	compare;
+					uint64_t	swap;
+					uint64_t	compare_mask;
+					uint64_t	swap_mask;
+				} op_m_cswp;
+				struct {
+					uint64_t	add;
+					uint64_t	nocarry_mask;
+				} op_m_fadd;
+			};
 
 			u32			op_rkey;
 			u64			op_remote_addr;
diff --git a/net/rds/send.c b/net/rds/send.c
index 81471b2..9b951a0 100644
--- a/net/rds/send.c
+++ b/net/rds/send.c
@@ -843,6 +843,8 @@
 
 		case RDS_CMSG_ATOMIC_CSWP:
 		case RDS_CMSG_ATOMIC_FADD:
+		case RDS_CMSG_MASKED_ATOMIC_CSWP:
+		case RDS_CMSG_MASKED_ATOMIC_FADD:
 			cmsg_groups |= 1;
 			size += sizeof(struct scatterlist);
 			break;
@@ -894,6 +896,8 @@
 			break;
 		case RDS_CMSG_ATOMIC_CSWP:
 		case RDS_CMSG_ATOMIC_FADD:
+		case RDS_CMSG_MASKED_ATOMIC_CSWP:
+		case RDS_CMSG_MASKED_ATOMIC_FADD:
 			ret = rds_cmsg_atomic(rs, rm, cmsg);
 			break;