IB: split struct ib_send_wr

This patch splits up struct ib_send_wr so that all non-trivial verbs
use their own structure which embeds struct ib_send_wr.  This dramatically
shrinks the size of a WR for the most common operations:

sizeof(struct ib_send_wr) (old):	96

sizeof(struct ib_send_wr):		48
sizeof(struct ib_rdma_wr):		64
sizeof(struct ib_atomic_wr):		96
sizeof(struct ib_ud_wr):		88
sizeof(struct ib_fast_reg_wr):		88
sizeof(struct ib_bind_mw_wr):		96
sizeof(struct ib_sig_handover_wr):	80

And with Sagi's pending MR rework, the fast registration WR will also be
down to a reasonable size:

sizeof(struct ib_fastreg_wr):		64
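
The conversion follows one pattern throughout: each non-trivial verb gets a
wrapper struct that embeds struct ib_send_wr as its first member, plus a
container_of() accessor, in the same shape as the umr_wr() helper added to
mlx5_ib.h below.  As an illustrative sketch for the RDMA case (field names
match the rdma_wr() users in the qp.c hunks):

	/* wrapper embedding the generic WR as its first member */
	struct ib_rdma_wr {
		struct ib_send_wr	wr;
		u64			remote_addr;
		u32			rkey;
	};

	/* recover the wrapper from the generic WR seen by the driver */
	static inline struct ib_rdma_wr *rdma_wr(struct ib_send_wr *wr)
	{
		return container_of(wr, struct ib_rdma_wr, wr);
	}

Callers build the wrapper and post &rdma_wr->wr, while drivers reach the
verb-specific fields through rdma_wr(wr)->remote_addr / ->rkey, as in the
set_raddr_seg() conversions below.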

Signed-off-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Bart Van Assche <bart.vanassche@sandisk.com> [srp, srpt]
Reviewed-by: Chuck Lever <chuck.lever@oracle.com> [sunrpc]
Tested-by: Haggai Eran <haggaie@mellanox.com>
Tested-by: Sagi Grimberg <sagig@mellanox.com>
Tested-by: Steve Wise <swise@opengridcomputing.com>
diff --git a/drivers/infiniband/hw/mlx5/mlx5_ib.h b/drivers/infiniband/hw/mlx5/mlx5_ib.h
index 22123b7..29f3ecd 100644
--- a/drivers/infiniband/hw/mlx5/mlx5_ib.h
+++ b/drivers/infiniband/hw/mlx5/mlx5_ib.h
@@ -245,6 +245,7 @@
 };
 
 struct mlx5_umr_wr {
+	struct ib_send_wr		wr;
 	union {
 		u64			virt_addr;
 		u64			offset;
@@ -257,6 +258,11 @@
 	u32				mkey;
 };
 
+static inline struct mlx5_umr_wr *umr_wr(struct ib_send_wr *wr)
+{
+	return container_of(wr, struct mlx5_umr_wr, wr);
+}
+
 struct mlx5_shared_mr_info {
 	int mr_id;
 	struct ib_umem		*umem;
diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c
index 54a15b5..b30d4ae 100644
--- a/drivers/infiniband/hw/mlx5/mr.c
+++ b/drivers/infiniband/hw/mlx5/mr.c
@@ -687,7 +687,7 @@
 			     int access_flags)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
-	struct mlx5_umr_wr *umrwr = (struct mlx5_umr_wr *)&wr->wr.fast_reg;
+	struct mlx5_umr_wr *umrwr = umr_wr(wr);
 
 	sg->addr = dma;
 	sg->length = ALIGN(sizeof(u64) * n, 64);
@@ -715,7 +715,7 @@
 static void prep_umr_unreg_wqe(struct mlx5_ib_dev *dev,
 			       struct ib_send_wr *wr, u32 key)
 {
-	struct mlx5_umr_wr *umrwr = (struct mlx5_umr_wr *)&wr->wr.fast_reg;
+	struct mlx5_umr_wr *umrwr = umr_wr(wr);
 
 	wr->send_flags = MLX5_IB_SEND_UMR_UNREG | MLX5_IB_SEND_UMR_FAIL_IF_FREE;
 	wr->opcode = MLX5_IB_WR_UMR;
@@ -752,7 +752,8 @@
 	struct device *ddev = dev->ib_dev.dma_device;
 	struct umr_common *umrc = &dev->umrc;
 	struct mlx5_ib_umr_context umr_context;
-	struct ib_send_wr wr, *bad;
+	struct mlx5_umr_wr umrwr;
+	struct ib_send_wr *bad;
 	struct mlx5_ib_mr *mr;
 	struct ib_sge sg;
 	int size;
@@ -798,14 +799,14 @@
 		goto free_pas;
 	}
 
-	memset(&wr, 0, sizeof(wr));
-	wr.wr_id = (u64)(unsigned long)&umr_context;
-	prep_umr_reg_wqe(pd, &wr, &sg, dma, npages, mr->mmr.key, page_shift,
-			 virt_addr, len, access_flags);
+	memset(&umrwr, 0, sizeof(umrwr));
+	umrwr.wr.wr_id = (u64)(unsigned long)&umr_context;
+	prep_umr_reg_wqe(pd, &umrwr.wr, &sg, dma, npages, mr->mmr.key,
+			 page_shift, virt_addr, len, access_flags);
 
 	mlx5_ib_init_umr_context(&umr_context);
 	down(&umrc->sem);
-	err = ib_post_send(umrc->qp, &wr, &bad);
+	err = ib_post_send(umrc->qp, &umrwr.wr, &bad);
 	if (err) {
 		mlx5_ib_warn(dev, "post send failed, err %d\n", err);
 		goto unmap_dma;
@@ -851,8 +852,8 @@
 	int size;
 	__be64 *pas;
 	dma_addr_t dma;
-	struct ib_send_wr wr, *bad;
-	struct mlx5_umr_wr *umrwr = (struct mlx5_umr_wr *)&wr.wr.fast_reg;
+	struct ib_send_wr *bad;
+	struct mlx5_umr_wr wr;
 	struct ib_sge sg;
 	int err = 0;
 	const int page_index_alignment = MLX5_UMR_MTT_ALIGNMENT / sizeof(u64);
@@ -917,26 +918,26 @@
 		dma_sync_single_for_device(ddev, dma, size, DMA_TO_DEVICE);
 
 		memset(&wr, 0, sizeof(wr));
-		wr.wr_id = (u64)(unsigned long)&umr_context;
+		wr.wr.wr_id = (u64)(unsigned long)&umr_context;
 
 		sg.addr = dma;
 		sg.length = ALIGN(npages * sizeof(u64),
 				MLX5_UMR_MTT_ALIGNMENT);
 		sg.lkey = dev->umrc.pd->local_dma_lkey;
 
-		wr.send_flags = MLX5_IB_SEND_UMR_FAIL_IF_FREE |
+		wr.wr.send_flags = MLX5_IB_SEND_UMR_FAIL_IF_FREE |
 				MLX5_IB_SEND_UMR_UPDATE_MTT;
-		wr.sg_list = &sg;
-		wr.num_sge = 1;
-		wr.opcode = MLX5_IB_WR_UMR;
-		umrwr->npages = sg.length / sizeof(u64);
-		umrwr->page_shift = PAGE_SHIFT;
-		umrwr->mkey = mr->mmr.key;
-		umrwr->target.offset = start_page_index;
+		wr.wr.sg_list = &sg;
+		wr.wr.num_sge = 1;
+		wr.wr.opcode = MLX5_IB_WR_UMR;
+		wr.npages = sg.length / sizeof(u64);
+		wr.page_shift = PAGE_SHIFT;
+		wr.mkey = mr->mmr.key;
+		wr.target.offset = start_page_index;
 
 		mlx5_ib_init_umr_context(&umr_context);
 		down(&umrc->sem);
-		err = ib_post_send(umrc->qp, &wr, &bad);
+		err = ib_post_send(umrc->qp, &wr.wr, &bad);
 		if (err) {
 			mlx5_ib_err(dev, "UMR post send failed, err %d\n", err);
 		} else {
@@ -1122,16 +1123,17 @@
 {
 	struct umr_common *umrc = &dev->umrc;
 	struct mlx5_ib_umr_context umr_context;
-	struct ib_send_wr wr, *bad;
+	struct mlx5_umr_wr umrwr;
+	struct ib_send_wr *bad;
 	int err;
 
-	memset(&wr, 0, sizeof(wr));
-	wr.wr_id = (u64)(unsigned long)&umr_context;
-	prep_umr_unreg_wqe(dev, &wr, mr->mmr.key);
+	memset(&umrwr.wr, 0, sizeof(umrwr));
+	umrwr.wr.wr_id = (u64)(unsigned long)&umr_context;
+	prep_umr_unreg_wqe(dev, &umrwr.wr, mr->mmr.key);
 
 	mlx5_ib_init_umr_context(&umr_context);
 	down(&umrc->sem);
-	err = ib_post_send(umrc->qp, &wr, &bad);
+	err = ib_post_send(umrc->qp, &umrwr.wr, &bad);
 	if (err) {
 		up(&umrc->sem);
 		mlx5_ib_dbg(dev, "err %d\n", err);
diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c
index 6f521a3..d4c36af 100644
--- a/drivers/infiniband/hw/mlx5/qp.c
+++ b/drivers/infiniband/hw/mlx5/qp.c
@@ -1838,9 +1838,9 @@
 static void set_datagram_seg(struct mlx5_wqe_datagram_seg *dseg,
 			     struct ib_send_wr *wr)
 {
-	memcpy(&dseg->av, &to_mah(wr->wr.ud.ah)->av, sizeof(struct mlx5_av));
-	dseg->av.dqp_dct = cpu_to_be32(wr->wr.ud.remote_qpn | MLX5_EXTENDED_UD_AV);
-	dseg->av.key.qkey.qkey = cpu_to_be32(wr->wr.ud.remote_qkey);
+	memcpy(&dseg->av, &to_mah(ud_wr(wr)->ah)->av, sizeof(struct mlx5_av));
+	dseg->av.dqp_dct = cpu_to_be32(ud_wr(wr)->remote_qpn | MLX5_EXTENDED_UD_AV);
+	dseg->av.key.qkey.qkey = cpu_to_be32(ud_wr(wr)->remote_qkey);
 }
 
 static void set_data_ptr_seg(struct mlx5_wqe_data_seg *dseg, struct ib_sge *sg)
@@ -1908,7 +1908,7 @@
 	}
 
 	umr->flags = (1 << 5); /* fail if not free */
-	umr->klm_octowords = get_klm_octo(wr->wr.fast_reg.page_list_len);
+	umr->klm_octowords = get_klm_octo(fast_reg_wr(wr)->page_list_len);
 	umr->mkey_mask = frwr_mkey_mask();
 }
 
@@ -1952,7 +1952,7 @@
 static void set_reg_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
 				struct ib_send_wr *wr)
 {
-	struct mlx5_umr_wr *umrwr = (struct mlx5_umr_wr *)&wr->wr.fast_reg;
+	struct mlx5_umr_wr *umrwr = umr_wr(wr);
 
 	memset(umr, 0, sizeof(*umr));
 
@@ -1996,20 +1996,20 @@
 		return;
 	}
 
-	seg->flags = get_umr_flags(wr->wr.fast_reg.access_flags) |
+	seg->flags = get_umr_flags(fast_reg_wr(wr)->access_flags) |
 		     MLX5_ACCESS_MODE_MTT;
 	*writ = seg->flags & (MLX5_PERM_LOCAL_WRITE | IB_ACCESS_REMOTE_WRITE);
-	seg->qpn_mkey7_0 = cpu_to_be32((wr->wr.fast_reg.rkey & 0xff) | 0xffffff00);
+	seg->qpn_mkey7_0 = cpu_to_be32((fast_reg_wr(wr)->rkey & 0xff) | 0xffffff00);
 	seg->flags_pd = cpu_to_be32(MLX5_MKEY_REMOTE_INVAL);
-	seg->start_addr = cpu_to_be64(wr->wr.fast_reg.iova_start);
-	seg->len = cpu_to_be64(wr->wr.fast_reg.length);
-	seg->xlt_oct_size = cpu_to_be32((wr->wr.fast_reg.page_list_len + 1) / 2);
-	seg->log2_page_size = wr->wr.fast_reg.page_shift;
+	seg->start_addr = cpu_to_be64(fast_reg_wr(wr)->iova_start);
+	seg->len = cpu_to_be64(fast_reg_wr(wr)->length);
+	seg->xlt_oct_size = cpu_to_be32((fast_reg_wr(wr)->page_list_len + 1) / 2);
+	seg->log2_page_size = fast_reg_wr(wr)->page_shift;
 }
 
 static void set_reg_mkey_segment(struct mlx5_mkey_seg *seg, struct ib_send_wr *wr)
 {
-	struct mlx5_umr_wr *umrwr = (struct mlx5_umr_wr *)&wr->wr.fast_reg;
+	struct mlx5_umr_wr *umrwr = umr_wr(wr);
 
 	memset(seg, 0, sizeof(*seg));
 	if (wr->send_flags & MLX5_IB_SEND_UMR_UNREG) {
@@ -2034,15 +2034,15 @@
 			   struct mlx5_ib_pd *pd,
 			   int writ)
 {
-	struct mlx5_ib_fast_reg_page_list *mfrpl = to_mfrpl(wr->wr.fast_reg.page_list);
-	u64 *page_list = wr->wr.fast_reg.page_list->page_list;
+	struct mlx5_ib_fast_reg_page_list *mfrpl = to_mfrpl(fast_reg_wr(wr)->page_list);
+	u64 *page_list = fast_reg_wr(wr)->page_list->page_list;
 	u64 perm = MLX5_EN_RD | (writ ? MLX5_EN_WR : 0);
 	int i;
 
-	for (i = 0; i < wr->wr.fast_reg.page_list_len; i++)
+	for (i = 0; i < fast_reg_wr(wr)->page_list_len; i++)
 		mfrpl->mapped_page_list[i] = cpu_to_be64(page_list[i] | perm);
 	dseg->addr = cpu_to_be64(mfrpl->map);
-	dseg->byte_count = cpu_to_be32(ALIGN(sizeof(u64) * wr->wr.fast_reg.page_list_len, 64));
+	dseg->byte_count = cpu_to_be32(ALIGN(sizeof(u64) * fast_reg_wr(wr)->page_list_len, 64));
 	dseg->lkey = cpu_to_be32(pd->ibpd.local_dma_lkey);
 }
 
@@ -2224,22 +2224,22 @@
 	return 0;
 }
 
-static int set_sig_data_segment(struct ib_send_wr *wr, struct mlx5_ib_qp *qp,
-				void **seg, int *size)
+static int set_sig_data_segment(struct ib_sig_handover_wr *wr,
+				struct mlx5_ib_qp *qp, void **seg, int *size)
 {
-	struct ib_sig_attrs *sig_attrs = wr->wr.sig_handover.sig_attrs;
-	struct ib_mr *sig_mr = wr->wr.sig_handover.sig_mr;
+	struct ib_sig_attrs *sig_attrs = wr->sig_attrs;
+	struct ib_mr *sig_mr = wr->sig_mr;
 	struct mlx5_bsf *bsf;
-	u32 data_len = wr->sg_list->length;
-	u32 data_key = wr->sg_list->lkey;
-	u64 data_va = wr->sg_list->addr;
+	u32 data_len = wr->wr.sg_list->length;
+	u32 data_key = wr->wr.sg_list->lkey;
+	u64 data_va = wr->wr.sg_list->addr;
 	int ret;
 	int wqe_size;
 
-	if (!wr->wr.sig_handover.prot ||
-	    (data_key == wr->wr.sig_handover.prot->lkey &&
-	     data_va == wr->wr.sig_handover.prot->addr &&
-	     data_len == wr->wr.sig_handover.prot->length)) {
+	if (!wr->prot ||
+	    (data_key == wr->prot->lkey &&
+	     data_va == wr->prot->addr &&
+	     data_len == wr->prot->length)) {
 		/**
 		 * Source domain doesn't contain signature information
 		 * or data and protection are interleaved in memory.
@@ -2273,8 +2273,8 @@
 		struct mlx5_stride_block_ctrl_seg *sblock_ctrl;
 		struct mlx5_stride_block_entry *data_sentry;
 		struct mlx5_stride_block_entry *prot_sentry;
-		u32 prot_key = wr->wr.sig_handover.prot->lkey;
-		u64 prot_va = wr->wr.sig_handover.prot->addr;
+		u32 prot_key = wr->prot->lkey;
+		u64 prot_va = wr->prot->addr;
 		u16 block_size = sig_attrs->mem.sig.dif.pi_interval;
 		int prot_size;
 
@@ -2326,16 +2326,16 @@
 }
 
 static void set_sig_mkey_segment(struct mlx5_mkey_seg *seg,
-				 struct ib_send_wr *wr, u32 nelements,
+				 struct ib_sig_handover_wr *wr, u32 nelements,
 				 u32 length, u32 pdn)
 {
-	struct ib_mr *sig_mr = wr->wr.sig_handover.sig_mr;
+	struct ib_mr *sig_mr = wr->sig_mr;
 	u32 sig_key = sig_mr->rkey;
 	u8 sigerr = to_mmr(sig_mr)->sig->sigerr_count & 1;
 
 	memset(seg, 0, sizeof(*seg));
 
-	seg->flags = get_umr_flags(wr->wr.sig_handover.access_flags) |
+	seg->flags = get_umr_flags(wr->access_flags) |
 				   MLX5_ACCESS_MODE_KLM;
 	seg->qpn_mkey7_0 = cpu_to_be32((sig_key & 0xff) | 0xffffff00);
 	seg->flags_pd = cpu_to_be32(MLX5_MKEY_REMOTE_INVAL | sigerr << 26 |
@@ -2346,7 +2346,7 @@
 }
 
 static void set_sig_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
-				struct ib_send_wr *wr, u32 nelements)
+				u32 nelements)
 {
 	memset(umr, 0, sizeof(*umr));
 
@@ -2357,37 +2357,37 @@
 }
 
 
-static int set_sig_umr_wr(struct ib_send_wr *wr, struct mlx5_ib_qp *qp,
+static int set_sig_umr_wr(struct ib_send_wr *send_wr, struct mlx5_ib_qp *qp,
 			  void **seg, int *size)
 {
-	struct mlx5_ib_mr *sig_mr = to_mmr(wr->wr.sig_handover.sig_mr);
+	struct ib_sig_handover_wr *wr = sig_handover_wr(send_wr);
+	struct mlx5_ib_mr *sig_mr = to_mmr(wr->sig_mr);
 	u32 pdn = get_pd(qp)->pdn;
 	u32 klm_oct_size;
 	int region_len, ret;
 
-	if (unlikely(wr->num_sge != 1) ||
-	    unlikely(wr->wr.sig_handover.access_flags &
-		     IB_ACCESS_REMOTE_ATOMIC) ||
+	if (unlikely(wr->wr.num_sge != 1) ||
+	    unlikely(wr->access_flags & IB_ACCESS_REMOTE_ATOMIC) ||
 	    unlikely(!sig_mr->sig) || unlikely(!qp->signature_en) ||
 	    unlikely(!sig_mr->sig->sig_status_checked))
 		return -EINVAL;
 
 	/* length of the protected region, data + protection */
-	region_len = wr->sg_list->length;
-	if (wr->wr.sig_handover.prot &&
-	    (wr->wr.sig_handover.prot->lkey != wr->sg_list->lkey  ||
-	     wr->wr.sig_handover.prot->addr != wr->sg_list->addr  ||
-	     wr->wr.sig_handover.prot->length != wr->sg_list->length))
-		region_len += wr->wr.sig_handover.prot->length;
+	region_len = wr->wr.sg_list->length;
+	if (wr->prot &&
+	    (wr->prot->lkey != wr->wr.sg_list->lkey  ||
+	     wr->prot->addr != wr->wr.sg_list->addr  ||
+	     wr->prot->length != wr->wr.sg_list->length))
+		region_len += wr->prot->length;
 
 	/**
 	 * KLM octoword size - if protection was provided
 	 * then we use strided block format (3 octowords),
 	 * else we use single KLM (1 octoword)
 	 **/
-	klm_oct_size = wr->wr.sig_handover.prot ? 3 : 1;
+	klm_oct_size = wr->prot ? 3 : 1;
 
-	set_sig_umr_segment(*seg, wr, klm_oct_size);
+	set_sig_umr_segment(*seg, klm_oct_size);
 	*seg += sizeof(struct mlx5_wqe_umr_ctrl_seg);
 	*size += sizeof(struct mlx5_wqe_umr_ctrl_seg) / 16;
 	if (unlikely((*seg == qp->sq.qend)))
@@ -2454,8 +2454,8 @@
 	if (unlikely((*seg == qp->sq.qend)))
 		*seg = mlx5_get_send_wqe(qp, 0);
 	if (!li) {
-		if (unlikely(wr->wr.fast_reg.page_list_len >
-			     wr->wr.fast_reg.page_list->max_page_list_len))
+		if (unlikely(fast_reg_wr(wr)->page_list_len >
+			     fast_reg_wr(wr)->page_list->max_page_list_len))
 			return	-ENOMEM;
 
 		set_frwr_pages(*seg, wr, mdev, pd, writ);
@@ -2636,8 +2636,8 @@
 			case IB_WR_RDMA_READ:
 			case IB_WR_RDMA_WRITE:
 			case IB_WR_RDMA_WRITE_WITH_IMM:
-				set_raddr_seg(seg, wr->wr.rdma.remote_addr,
-					      wr->wr.rdma.rkey);
+				set_raddr_seg(seg, rdma_wr(wr)->remote_addr,
+					      rdma_wr(wr)->rkey);
 				seg += sizeof(struct mlx5_wqe_raddr_seg);
 				size += sizeof(struct mlx5_wqe_raddr_seg) / 16;
 				break;
@@ -2666,7 +2666,7 @@
 			case IB_WR_FAST_REG_MR:
 				next_fence = MLX5_FENCE_MODE_INITIATOR_SMALL;
 				qp->sq.wr_data[idx] = IB_WR_FAST_REG_MR;
-				ctrl->imm = cpu_to_be32(wr->wr.fast_reg.rkey);
+				ctrl->imm = cpu_to_be32(fast_reg_wr(wr)->rkey);
 				err = set_frwr_li_wr(&seg, wr, &size, mdev, to_mpd(ibqp->pd), qp);
 				if (err) {
 					mlx5_ib_warn(dev, "\n");
@@ -2678,7 +2678,7 @@
 
 			case IB_WR_REG_SIG_MR:
 				qp->sq.wr_data[idx] = IB_WR_REG_SIG_MR;
-				mr = to_mmr(wr->wr.sig_handover.sig_mr);
+				mr = to_mmr(sig_handover_wr(wr)->sig_mr);
 
 				ctrl->imm = cpu_to_be32(mr->ibmr.rkey);
 				err = set_sig_umr_wr(wr, qp, &seg, &size);
@@ -2706,7 +2706,7 @@
 					goto out;
 				}
 
-				err = set_psv_wr(&wr->wr.sig_handover.sig_attrs->mem,
+				err = set_psv_wr(&sig_handover_wr(wr)->sig_attrs->mem,
 						 mr->sig->psv_memory.psv_idx, &seg,
 						 &size);
 				if (err) {
@@ -2728,7 +2728,7 @@
 				}
 
 				next_fence = MLX5_FENCE_MODE_INITIATOR_SMALL;
-				err = set_psv_wr(&wr->wr.sig_handover.sig_attrs->wire,
+				err = set_psv_wr(&sig_handover_wr(wr)->sig_attrs->wire,
 						 mr->sig->psv_wire.psv_idx, &seg,
 						 &size);
 				if (err) {
@@ -2752,8 +2752,8 @@
 			switch (wr->opcode) {
 			case IB_WR_RDMA_WRITE:
 			case IB_WR_RDMA_WRITE_WITH_IMM:
-				set_raddr_seg(seg, wr->wr.rdma.remote_addr,
-					      wr->wr.rdma.rkey);
+				set_raddr_seg(seg, rdma_wr(wr)->remote_addr,
+					      rdma_wr(wr)->rkey);
 				seg  += sizeof(struct mlx5_wqe_raddr_seg);
 				size += sizeof(struct mlx5_wqe_raddr_seg) / 16;
 				break;
@@ -2780,7 +2780,7 @@
 				goto out;
 			}
 			qp->sq.wr_data[idx] = MLX5_IB_WR_UMR;
-			ctrl->imm = cpu_to_be32(wr->wr.fast_reg.rkey);
+			ctrl->imm = cpu_to_be32(umr_wr(wr)->mkey);
 			set_reg_umr_segment(seg, wr);
 			seg += sizeof(struct mlx5_wqe_umr_ctrl_seg);
 			size += sizeof(struct mlx5_wqe_umr_ctrl_seg) / 16;