IB/rdmavt: Add per verb driver callback checking

For each verb validate that all requirements for driver callbacks are met.
If a function is called without checking for a valid pointer, it is a
required function. Also document what each callback function does.

Reviewed-by: Mike Marciniszyn <mike.marciniszyn@intel.com>
Signed-off-by: Dennis Dalessandro <dennis.dalessandro@intel.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
diff --git a/drivers/infiniband/sw/rdmavt/vt.c b/drivers/infiniband/sw/rdmavt/vt.c
index f5cb09b..9566a92 100644
--- a/drivers/infiniband/sw/rdmavt/vt.c
+++ b/drivers/infiniband/sw/rdmavt/vt.c
@@ -325,12 +325,375 @@
 	return 0;
 }
 
-/*
- * Check driver override. If driver passes a value use it, otherwise we use our
- * own value.
- */
-#define CHECK_DRIVER_OVERRIDE(rdi, x) \
-	rdi->ibdev.x = rdi->ibdev.x ? : rvt_ ##x
+enum {		/* verb indices switched on by check_support() */
+	MISC,		/* driver callbacks required outside any one verb */
+	QUERY_DEVICE,
+	MODIFY_DEVICE,
+	QUERY_PORT,
+	MODIFY_PORT,
+	QUERY_PKEY,
+	QUERY_GID,
+	ALLOC_UCONTEXT,
+	DEALLOC_UCONTEXT,
+	GET_PORT_IMMUTABLE,
+	CREATE_QP,
+	MODIFY_QP,
+	DESTROY_QP,
+	QUERY_QP,
+	POST_SEND,
+	POST_RECV,
+	POST_SRQ_RECV,
+	CREATE_AH,
+	DESTROY_AH,
+	MODIFY_AH,
+	QUERY_AH,
+	CREATE_SRQ,
+	MODIFY_SRQ,
+	DESTROY_SRQ,
+	QUERY_SRQ,
+	ATTACH_MCAST,
+	DETACH_MCAST,
+	GET_DMA_MR,
+	REG_USER_MR,
+	DEREG_MR,
+	ALLOC_MR,
+	ALLOC_FMR,
+	MAP_PHYS_FMR,
+	UNMAP_FMR,
+	DEALLOC_FMR,
+	MMAP,
+	CREATE_CQ,
+	DESTROY_CQ,
+	POLL_CQ,
+	REQ_NOTFIY_CQ,	/* sic: check_support() uses this exact spelling */
+	RESIZE_CQ,
+	ALLOC_PD,
+	DEALLOC_PD,
+	_VERB_IDX_MAX /* Must always be last: bounds the check_support() loop */
+};
+
+static inline int check_driver_override(struct rvt_dev_info *rdi,
+					size_t offset, void *func)
+{
+	void **cb = (void **)((void *)&rdi->ibdev + offset);
+	int overridden = *cb != NULL;	/* 1: driver set this ibdev callback */
+
+	if (!overridden)
+		*cb = func;		/* empty slot: install rdmavt's func */
+	return overridden;
+}
+
+static int check_support(struct rvt_dev_info *rdi, int verb)
+{
+	switch (verb) {
+	case MISC:
+		/*
+		 * These functions are not part of verbs specifically but are
+		 * required for rdmavt to function.
+		 */
+		if ((!rdi->driver_f.port_callback) ||
+		    (!rdi->driver_f.get_card_name) ||
+		    (!rdi->driver_f.get_pci_dev))
+			return -EINVAL;
+		break;
+
+	case QUERY_DEVICE:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    query_device),
+						    rvt_query_device);
+		break;
+
+	case MODIFY_DEVICE:
+		/*
+		 * rdmavt does not currently support modify device; drivers
+		 * must provide it (a 0 return means the driver set no handler).
+		 */
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 modify_device),
+					   rvt_modify_device))
+			return -EOPNOTSUPP;	/* driver supplied none */
+		break;
+
+	case QUERY_PORT:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 query_port),
+					   rvt_query_port))
+			if (!rdi->driver_f.query_port_state)
+				return -EINVAL;
+		break;
+
+	case MODIFY_PORT:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 modify_port),
+					   rvt_modify_port))
+			if (!rdi->driver_f.cap_mask_chg ||
+			    !rdi->driver_f.shut_down_port)
+				return -EINVAL;
+		break;
+
+	case QUERY_PKEY:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    query_pkey),
+				      rvt_query_pkey);
+		break;
+
+	case QUERY_GID:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 query_gid),
+					   rvt_query_gid))
+			if (!rdi->driver_f.get_guid_be)
+				return -EINVAL;
+		break;
+
+	case ALLOC_UCONTEXT:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    alloc_ucontext),
+				      rvt_alloc_ucontext);
+		break;
+
+	case DEALLOC_UCONTEXT:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    dealloc_ucontext),
+				      rvt_dealloc_ucontext);
+		break;
+
+	case GET_PORT_IMMUTABLE:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    get_port_immutable),
+				      rvt_get_port_immutable);
+		break;
+
+	case CREATE_QP:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 create_qp),
+					   rvt_create_qp))
+			if (!rdi->driver_f.qp_priv_alloc ||
+			    !rdi->driver_f.qp_priv_free ||
+			    !rdi->driver_f.notify_qp_reset ||
+			    !rdi->driver_f.flush_qp_waiters ||
+			    !rdi->driver_f.stop_send_queue ||
+			    !rdi->driver_f.quiesce_qp)
+				return -EINVAL;
+		break;
+
+	case MODIFY_QP:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 modify_qp),
+					   rvt_modify_qp))
+			if (!rdi->driver_f.notify_qp_reset ||
+			    !rdi->driver_f.schedule_send ||
+			    !rdi->driver_f.get_pmtu_from_attr ||
+			    !rdi->driver_f.flush_qp_waiters ||
+			    !rdi->driver_f.stop_send_queue ||
+			    !rdi->driver_f.quiesce_qp ||
+			    !rdi->driver_f.notify_error_qp ||
+			    !rdi->driver_f.mtu_from_qp ||
+			    !rdi->driver_f.mtu_to_path_mtu ||
+			    !rdi->driver_f.shut_down_port ||
+			    !rdi->driver_f.cap_mask_chg)
+				return -EINVAL;
+		break;
+
+	case DESTROY_QP:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 destroy_qp),
+					   rvt_destroy_qp))
+			if (!rdi->driver_f.qp_priv_free ||
+			    !rdi->driver_f.notify_qp_reset ||
+			    !rdi->driver_f.flush_qp_waiters ||
+			    !rdi->driver_f.stop_send_queue ||
+			    !rdi->driver_f.quiesce_qp)
+				return -EINVAL;
+		break;
+
+	case QUERY_QP:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    query_qp),
+						    rvt_query_qp);
+		break;
+
+	case POST_SEND:
+		if (!check_driver_override(rdi, offsetof(struct ib_device,
+							 post_send),
+					   rvt_post_send))
+			if (!rdi->driver_f.schedule_send ||
+			    !rdi->driver_f.do_send)
+				return -EINVAL;
+		break;
+
+	case POST_RECV:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    post_recv),
+				      rvt_post_recv);
+		break;
+	case POST_SRQ_RECV:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    post_srq_recv),
+				      rvt_post_srq_recv);
+		break;
+
+	case CREATE_AH:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    create_ah),
+				      rvt_create_ah);
+		break;
+
+	case DESTROY_AH:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    destroy_ah),
+				      rvt_destroy_ah);
+		break;
+
+	case MODIFY_AH:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    modify_ah),
+				      rvt_modify_ah);
+		break;
+
+	case QUERY_AH:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    query_ah),
+				      rvt_query_ah);
+		break;
+
+	case CREATE_SRQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    create_srq),
+				      rvt_create_srq);
+		break;
+
+	case MODIFY_SRQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    modify_srq),
+				      rvt_modify_srq);
+		break;
+
+	case DESTROY_SRQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    destroy_srq),
+				      rvt_destroy_srq);
+		break;
+
+	case QUERY_SRQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    query_srq),
+				      rvt_query_srq);
+		break;
+
+	case ATTACH_MCAST:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    attach_mcast),
+				      rvt_attach_mcast);
+		break;
+
+	case DETACH_MCAST:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    detach_mcast),
+				      rvt_detach_mcast);
+		break;
+
+	case GET_DMA_MR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    get_dma_mr),
+				      rvt_get_dma_mr);
+		break;
+
+	case REG_USER_MR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    reg_user_mr),
+				      rvt_reg_user_mr);
+		break;
+
+	case DEREG_MR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    dereg_mr),
+				      rvt_dereg_mr);
+		break;
+
+	case ALLOC_FMR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    alloc_fmr),
+				      rvt_alloc_fmr);
+		break;
+
+	case ALLOC_MR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    alloc_mr),
+				      rvt_alloc_mr);
+		break;
+
+	case MAP_PHYS_FMR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    map_phys_fmr),
+				      rvt_map_phys_fmr);
+		break;
+
+	case UNMAP_FMR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    unmap_fmr),
+				      rvt_unmap_fmr);
+		break;
+
+	case DEALLOC_FMR:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    dealloc_fmr),
+				      rvt_dealloc_fmr);
+		break;
+
+	case MMAP:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    mmap),
+				      rvt_mmap);
+		break;
+
+	case CREATE_CQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    create_cq),
+				      rvt_create_cq);
+		break;
+
+	case DESTROY_CQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    destroy_cq),
+				      rvt_destroy_cq);
+		break;
+
+	case POLL_CQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    poll_cq),
+				      rvt_poll_cq);
+		break;
+
+	case REQ_NOTFIY_CQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    req_notify_cq),
+				      rvt_req_notify_cq);
+		break;
+
+	case RESIZE_CQ:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    resize_cq),
+				      rvt_resize_cq);
+		break;
+
+	case ALLOC_PD:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    alloc_pd),
+				      rvt_alloc_pd);
+		break;
+
+	case DEALLOC_PD:
+		check_driver_override(rdi, offsetof(struct ib_device,
+						    dealloc_pd),
+				      rvt_dealloc_pd);
+		break;
+
+	default:
+		return -EINVAL;	/* verb index has no case above */
+	}
+
+	return 0;	/* every check for this verb passed */
+}
 
 /**
  * rvt_register_device - register a driver
@@ -343,35 +706,26 @@
  */
 int rvt_register_device(struct rvt_dev_info *rdi)
 {
-	/* Validate that drivers have provided the right information */
-	int ret = 0;
+	int ret = 0, i;
 
 	if (!rdi)
 		return -EINVAL;
 
-	if ((!rdi->driver_f.port_callback) ||
-	    (!rdi->driver_f.get_card_name) ||
-	    (!rdi->driver_f.get_pci_dev) ||
-	    (!rdi->driver_f.check_ah)) {
-		pr_err("Driver not supporting req func\n");
-		return -EINVAL;
-	}
+	/*
+	 * Check to ensure drivers have set up the required helpers for the
+	 * verbs they want rdmavt to handle
+	 */
+	for (i = 0; i < _VERB_IDX_MAX; i++)
+		if (check_support(rdi, i)) {
+			pr_err("Driver support req not met at %d\n", i);
+			return -EINVAL;
+		}
+
 
 	/* Once we get past here we can use rvt_pr macros and tracepoints */
 	trace_rvt_dbg(rdi, "Driver attempting registration");
 	rvt_mmap_init(rdi);
 
-	/* Dev Ops */
-	CHECK_DRIVER_OVERRIDE(rdi, query_device);
-	CHECK_DRIVER_OVERRIDE(rdi, modify_device);
-	CHECK_DRIVER_OVERRIDE(rdi, query_port);
-	CHECK_DRIVER_OVERRIDE(rdi, modify_port);
-	CHECK_DRIVER_OVERRIDE(rdi, query_pkey);
-	CHECK_DRIVER_OVERRIDE(rdi, query_gid);
-	CHECK_DRIVER_OVERRIDE(rdi, alloc_ucontext);
-	CHECK_DRIVER_OVERRIDE(rdi, dealloc_ucontext);
-	CHECK_DRIVER_OVERRIDE(rdi, get_port_immutable);
-
 	/* Queue Pairs */
 	ret = rvt_driver_qp_init(rdi);
 	if (ret) {
@@ -379,33 +733,15 @@
 		return -EINVAL;
 	}
 
-	CHECK_DRIVER_OVERRIDE(rdi, create_qp);
-	CHECK_DRIVER_OVERRIDE(rdi, modify_qp);
-	CHECK_DRIVER_OVERRIDE(rdi, destroy_qp);
-	CHECK_DRIVER_OVERRIDE(rdi, query_qp);
-	CHECK_DRIVER_OVERRIDE(rdi, post_send);
-	CHECK_DRIVER_OVERRIDE(rdi, post_recv);
-	CHECK_DRIVER_OVERRIDE(rdi, post_srq_recv);
-
 	/* Address Handle */
-	CHECK_DRIVER_OVERRIDE(rdi, create_ah);
-	CHECK_DRIVER_OVERRIDE(rdi, destroy_ah);
-	CHECK_DRIVER_OVERRIDE(rdi, modify_ah);
-	CHECK_DRIVER_OVERRIDE(rdi, query_ah);
 	spin_lock_init(&rdi->n_ahs_lock);
 	rdi->n_ahs_allocated = 0;
 
 	/* Shared Receive Queue */
-	CHECK_DRIVER_OVERRIDE(rdi, create_srq);
-	CHECK_DRIVER_OVERRIDE(rdi, modify_srq);
-	CHECK_DRIVER_OVERRIDE(rdi, destroy_srq);
-	CHECK_DRIVER_OVERRIDE(rdi, query_srq);
 	rvt_driver_srq_init(rdi);
 
 	/* Multicast */
 	rvt_driver_mcast_init(rdi);
-	CHECK_DRIVER_OVERRIDE(rdi, attach_mcast);
-	CHECK_DRIVER_OVERRIDE(rdi, detach_mcast);
 
 	/* Mem Region */
 	ret = rvt_driver_mr_init(rdi);
@@ -414,35 +750,18 @@
 		goto bail_no_mr;
 	}
 
-	CHECK_DRIVER_OVERRIDE(rdi, get_dma_mr);
-	CHECK_DRIVER_OVERRIDE(rdi, reg_user_mr);
-	CHECK_DRIVER_OVERRIDE(rdi, dereg_mr);
-	CHECK_DRIVER_OVERRIDE(rdi, alloc_mr);
-	CHECK_DRIVER_OVERRIDE(rdi, alloc_fmr);
-	CHECK_DRIVER_OVERRIDE(rdi, map_phys_fmr);
-	CHECK_DRIVER_OVERRIDE(rdi, unmap_fmr);
-	CHECK_DRIVER_OVERRIDE(rdi, dealloc_fmr);
-	CHECK_DRIVER_OVERRIDE(rdi, mmap);
-
 	/* Completion queues */
 	ret = rvt_driver_cq_init(rdi);
 	if (ret) {
 		pr_err("Error in driver CQ init.\n");
 		goto bail_mr;
 	}
-	CHECK_DRIVER_OVERRIDE(rdi, create_cq);
-	CHECK_DRIVER_OVERRIDE(rdi, destroy_cq);
-	CHECK_DRIVER_OVERRIDE(rdi, poll_cq);
-	CHECK_DRIVER_OVERRIDE(rdi, req_notify_cq);
-	CHECK_DRIVER_OVERRIDE(rdi, resize_cq);
 
 	/* DMA Operations */
 	rdi->ibdev.dma_ops =
 		rdi->ibdev.dma_ops ? : &rvt_default_dma_mapping_ops;
 
 	/* Protection Domain */
-	CHECK_DRIVER_OVERRIDE(rdi, alloc_pd);
-	CHECK_DRIVER_OVERRIDE(rdi, dealloc_pd);
 	spin_lock_init(&rdi->n_pds_lock);
 	rdi->n_pds_allocated = 0;