IB/core: Add rdma_network_type to wc

Providers should tell IB core the network type of each wc.
This is used to search for the proper GID in the GID table.
For HCAs that can't provide this information, IB core falls
back to inspecting the packet header itself and deriving the
network type (and thus the GID type) from it.

In RDMA-CM, we choose the sgid_index and GID type from all the
matching entries based on a hint from the IP stack, and we set the
hop_limit of the IP packet based on that same hint.

Signed-off-by: Matan Barak <matanb@mellanox.com>
Signed-off-by: Somnath Kotur <Somnath.Kotur@Avagotech.Com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
diff --git a/drivers/infiniband/core/verbs.c b/drivers/infiniband/core/verbs.c
index e397d8b..b25e5fb 100644
--- a/drivers/infiniband/core/verbs.c
+++ b/drivers/infiniband/core/verbs.c
@@ -305,8 +305,61 @@
 }
 EXPORT_SYMBOL(ib_create_ah);
 
+/* Guess the IP version of a received RoCE network header.
+ * The IPv4 header is preceded by 20 bytes that may hold garbage, so a
+ * leading version field of 6 alone does not prove the packet is IPv6.
+ * Returns 4 or 6, or 0 when the header looks like neither.
+ */
+static int ib_get_header_version(const union rdma_network_hdr *hdr)
+{
+	const struct ipv6hdr *v6 = (const struct ipv6hdr *)&hdr->ibgrh;
+	const struct iphdr *v4 = (const struct iphdr *)&hdr->roce4grh;
+	struct iphdr v4_copy;
+
+	/* If it's IPv6, the version must be 6. Otherwise the leading
+	 * bytes are garbled and only the IPv4 version field can be trusted.
+	 */
+	if (v6->version != 6)
+		return (v4->version == 4) ? 4 : 0;
+
+	/* RoCE v2 requires no options, so a real IPv4 header is 5 words */
+	if (v4->ihl != 5)
+		return 6;
+
+	/* Verify the IPv4 checksum. We can't write on scattered buffers,
+	 * so recompute it on a local copy with the check field zeroed.
+	 */
+	memcpy(&v4_copy, v4, sizeof(v4_copy));
+	v4_copy.check = 0;
+	v4_copy.check = ip_fast_csum((u8 *)&v4_copy, 5);
+
+	/* if IPv4 header checksum is OK, believe it */
+	return (v4->check == v4_copy.check) ? 4 : 6;
+}
+
+/* Derive the network type of a received packet: RDMA_NETWORK_IB on IB
+ * ports, otherwise decided by inspecting the header that grh points to.
+ */
+static enum rdma_network_type ib_get_net_type_by_grh(struct ib_device *device,
+						     u8 port_num,
+						     const struct ib_grh *grh)
+{
+	const union rdma_network_hdr *hdr = (const union rdma_network_hdr *)grh;
+
+	if (rdma_protocol_ib(device, port_num))
+		return RDMA_NETWORK_IB;
+
+	if (ib_get_header_version(hdr) == 4)
+		return RDMA_NETWORK_IPV4;
+
+	/* Not IPv4: UDP as next header means IPv6, anything else RoCE v1 */
+	return (grh->next_hdr == IPPROTO_UDP) ? RDMA_NETWORK_IPV6 :
+						RDMA_NETWORK_ROCE_V1;
+}
+
 struct find_gid_index_context {
 	u16 vlan_id;
+	enum ib_gid_type gid_type;	/* entry's GID type must equal this */
 };
 
 static bool find_gid_index(const union ib_gid *gid,
@@ -316,6 +369,9 @@
 	struct find_gid_index_context *ctx =
 		(struct find_gid_index_context *)context;
 
+	if (ctx->gid_type != gid_attr->gid_type)
+		return false;
+
 	if ((!!(ctx->vlan_id != 0xffff) == !is_vlan_dev(gid_attr->ndev)) ||
 	    (is_vlan_dev(gid_attr->ndev) &&
 	     vlan_dev_vlan_id(gid_attr->ndev) != ctx->vlan_id))
@@ -326,14 +382,49 @@
 
 static int get_sgid_index_from_eth(struct ib_device *device, u8 port_num,
 				   u16 vlan_id, const union ib_gid *sgid,
+				   enum ib_gid_type gid_type,
 				   u16 *gid_index)
 {
-	struct find_gid_index_context context = {.vlan_id = vlan_id};
+	struct find_gid_index_context context = {.vlan_id = vlan_id,
+						 .gid_type = gid_type};
 
 	return ib_find_gid_by_filter(device, sgid, port_num, find_gid_index,
 				     &context, gid_index);
 }
 
+/* Extract the source and destination GIDs from a received network header.
+ * RDMA_NETWORK_IPV4: the IPv4 addresses are returned as v4-mapped GIDs.
+ * RDMA_NETWORK_IPV6 and RDMA_NETWORK_IB: the GRH GIDs are copied as-is.
+ * Returns 0 on success, or -EINVAL if sgid or dgid is NULL or net_type
+ * is none of the above.
+ */
+static int get_gids_from_rdma_hdr(union rdma_network_hdr *hdr,
+				  enum rdma_network_type net_type,
+				  union ib_gid *sgid, union ib_gid *dgid)
+{
+	__be32 src_saddr, dst_saddr;
+
+	if (!sgid || !dgid)
+		return -EINVAL;
+
+	switch (net_type) {
+	case RDMA_NETWORK_IPV4:
+		/* map the 4-byte IPv4 addresses into IPv6-formatted GIDs */
+		memcpy(&src_saddr, &hdr->roce4grh.saddr, sizeof(src_saddr));
+		memcpy(&dst_saddr, &hdr->roce4grh.daddr, sizeof(dst_saddr));
+		ipv6_addr_set_v4mapped(src_saddr, (struct in6_addr *)sgid);
+		ipv6_addr_set_v4mapped(dst_saddr, (struct in6_addr *)dgid);
+		return 0;
+	case RDMA_NETWORK_IPV6:
+	case RDMA_NETWORK_IB:
+		*dgid = hdr->ibgrh.dgid;
+		*sgid = hdr->ibgrh.sgid;
+		return 0;
+	default:
+		return -EINVAL;
+	}
+}
+
 int ib_init_ah_from_wc(struct ib_device *device, u8 port_num,
 		       const struct ib_wc *wc, const struct ib_grh *grh,
 		       struct ib_ah_attr *ah_attr)
@@ -341,9 +432,25 @@
 	u32 flow_class;
 	u16 gid_index;
 	int ret;
+	enum rdma_network_type net_type = RDMA_NETWORK_IB;
+	enum ib_gid_type gid_type = IB_GID_TYPE_IB;
+	union ib_gid dgid;
+	union ib_gid sgid;
 
 	memset(ah_attr, 0, sizeof *ah_attr);
 	if (rdma_cap_eth_ah(device, port_num)) {
+		if (wc->wc_flags & IB_WC_WITH_NETWORK_HDR_TYPE)
+			net_type = wc->network_hdr_type;
+		else
+			net_type = ib_get_net_type_by_grh(device, port_num, grh);
+		gid_type = ib_network_to_gid_type(net_type);
+	}
+	ret = get_gids_from_rdma_hdr((union rdma_network_hdr *)grh, net_type,
+				     &sgid, &dgid);
+	if (ret)
+		return ret;
+
+	if (rdma_protocol_roce(device, port_num)) {
 		u16 vlan_id = wc->wc_flags & IB_WC_WITH_VLAN ?
 				wc->vlan_id : 0xffff;
 
@@ -352,7 +459,7 @@
 
 		if (!(wc->wc_flags & IB_WC_WITH_SMAC) ||
 		    !(wc->wc_flags & IB_WC_WITH_VLAN)) {
-			ret = rdma_addr_find_dmac_by_grh(&grh->dgid, &grh->sgid,
+			ret = rdma_addr_find_dmac_by_grh(&dgid, &sgid,
 							 ah_attr->dmac,
 							 wc->wc_flags & IB_WC_WITH_VLAN ?
 							 NULL : &vlan_id,
@@ -362,7 +469,7 @@
 		}
 
 		ret = get_sgid_index_from_eth(device, port_num, vlan_id,
-					      &grh->dgid, &gid_index);
+					      &dgid, gid_type, &gid_index);
 		if (ret)
 			return ret;
 
@@ -377,10 +484,10 @@
 
 	if (wc->wc_flags & IB_WC_GRH) {
 		ah_attr->ah_flags = IB_AH_GRH;
-		ah_attr->grh.dgid = grh->sgid;
+		ah_attr->grh.dgid = sgid;
 
 		if (!rdma_cap_eth_ah(device, port_num)) {
-			ret = ib_find_cached_gid_by_port(device, &grh->dgid,
+			ret = ib_find_cached_gid_by_port(device, &dgid,
 							 IB_GID_TYPE_IB,
 							 port_num, NULL,
 							 &gid_index);
@@ -1020,6 +1127,12 @@
 					ret = -ENXIO;
 				goto out;
 			}
+			if (sgid_attr.gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP)
+				/* TODO: get the hoplimit from the inet/inet6
+				 * device
+				 */
+				qp_attr->ah_attr.grh.hop_limit =
+							IPV6_DEFAULT_HOPLIMIT;
 
 			ifindex = sgid_attr.ndev->ifindex;