vxlan: keep flags and vni in network byte order

Avoid repeated conversions to and from network byte order in the fast path.

To achieve this, define all flag constants in big-endian order and store
the VNI as __be32. To prevent confusion between the actual VNI value and
the VNI field from the header (which contains an additional reserved
byte), strictly distinguish between "vni" and "vni_field".

Signed-off-by: Jiri Benc <jbenc@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/drivers/net/vxlan.c b/drivers/net/vxlan.c
index 524e3b1..4e3d3df 100644
--- a/drivers/net/vxlan.c
+++ b/drivers/net/vxlan.c
@@ -197,9 +197,9 @@
 #endif
 
 /* Virtual Network hash table head */
-static inline struct hlist_head *vni_head(struct vxlan_sock *vs, u32 id)
+static inline struct hlist_head *vni_head(struct vxlan_sock *vs, __be32 vni)
 {
-	return &vs->vni_list[hash_32(id, VNI_HASH_BITS)];
+	return &vs->vni_list[hash_32((__force u32)vni, VNI_HASH_BITS)];
 }
 
 /* Socket hash table head */
@@ -242,12 +242,12 @@
 	return NULL;
 }
 
-static struct vxlan_dev *vxlan_vs_find_vni(struct vxlan_sock *vs, u32 id)
+static struct vxlan_dev *vxlan_vs_find_vni(struct vxlan_sock *vs, __be32 vni)
 {
 	struct vxlan_dev *vxlan;
 
-	hlist_for_each_entry_rcu(vxlan, vni_head(vs, id), hlist) {
-		if (vxlan->default_dst.remote_vni == id)
+	hlist_for_each_entry_rcu(vxlan, vni_head(vs, vni), hlist) {
+		if (vxlan->default_dst.remote_vni == vni)
 			return vxlan;
 	}
 
@@ -255,7 +255,7 @@
 }
 
 /* Look up VNI in a per net namespace table */
-static struct vxlan_dev *vxlan_find_vni(struct net *net, u32 id,
+static struct vxlan_dev *vxlan_find_vni(struct net *net, __be32 vni,
 					sa_family_t family, __be16 port,
 					u32 flags)
 {
@@ -265,7 +265,7 @@
 	if (!vs)
 		return NULL;
 
-	return vxlan_vs_find_vni(vs, id);
+	return vxlan_vs_find_vni(vs, vni);
 }
 
 /* Fill in neighbour message in skbuff. */
@@ -315,7 +315,7 @@
 	    nla_put_be16(skb, NDA_PORT, rdst->remote_port))
 		goto nla_put_failure;
 	if (rdst->remote_vni != vxlan->default_dst.remote_vni &&
-	    nla_put_u32(skb, NDA_VNI, rdst->remote_vni))
+	    nla_put_u32(skb, NDA_VNI, be32_to_cpu(rdst->remote_vni)))
 		goto nla_put_failure;
 	if (rdst->remote_ifindex &&
 	    nla_put_u32(skb, NDA_IFINDEX, rdst->remote_ifindex))
@@ -383,7 +383,7 @@
 	};
 	struct vxlan_rdst remote = {
 		.remote_ip = *ipa, /* goes to NDA_DST */
-		.remote_vni = VXLAN_N_VID,
+		.remote_vni = cpu_to_be32(VXLAN_N_VID),
 	};
 
 	vxlan_fdb_notify(vxlan, &f, &remote, RTM_GETNEIGH);
@@ -452,7 +452,7 @@
 /* caller should hold vxlan->hash_lock */
 static struct vxlan_rdst *vxlan_fdb_find_rdst(struct vxlan_fdb *f,
 					      union vxlan_addr *ip, __be16 port,
-					      __u32 vni, __u32 ifindex)
+					      __be32 vni, __u32 ifindex)
 {
 	struct vxlan_rdst *rd;
 
@@ -469,7 +469,8 @@
 
 /* Replace destination of unicast mac */
 static int vxlan_fdb_replace(struct vxlan_fdb *f,
-			     union vxlan_addr *ip, __be16 port, __u32 vni, __u32 ifindex)
+			     union vxlan_addr *ip, __be16 port, __be32 vni,
+			     __u32 ifindex)
 {
 	struct vxlan_rdst *rd;
 
@@ -491,7 +492,7 @@
 
 /* Add/update destinations for multicast */
 static int vxlan_fdb_append(struct vxlan_fdb *f,
-			    union vxlan_addr *ip, __be16 port, __u32 vni,
+			    union vxlan_addr *ip, __be16 port, __be32 vni,
 			    __u32 ifindex, struct vxlan_rdst **rdp)
 {
 	struct vxlan_rdst *rd;
@@ -523,7 +524,8 @@
 static struct vxlanhdr *vxlan_gro_remcsum(struct sk_buff *skb,
 					  unsigned int off,
 					  struct vxlanhdr *vh, size_t hdrlen,
-					  u32 data, struct gro_remcsum *grc,
+					  __be32 vni_field,
+					  struct gro_remcsum *grc,
 					  bool nopartial)
 {
 	size_t start, offset;
@@ -534,10 +536,8 @@
 	if (!NAPI_GRO_CB(skb)->csum_valid)
 		return NULL;
 
-	start = (data & VXLAN_RCO_MASK) << VXLAN_RCO_SHIFT;
-	offset = start + ((data & VXLAN_RCO_UDP) ?
-			  offsetof(struct udphdr, check) :
-			  offsetof(struct tcphdr, check));
+	start = vxlan_rco_start(vni_field);
+	offset = start + vxlan_rco_offset(vni_field);
 
 	vh = skb_gro_remcsum_process(skb, (void *)vh, off, hdrlen,
 				     start, offset, grc, nopartial);
@@ -557,7 +557,7 @@
 	int flush = 1;
 	struct vxlan_sock *vs = container_of(uoff, struct vxlan_sock,
 					     udp_offloads);
-	u32 flags;
+	__be32 flags;
 	struct gro_remcsum grc;
 
 	skb_gro_remcsum_init(&grc);
@@ -573,11 +573,11 @@
 
 	skb_gro_postpull_rcsum(skb, vh, sizeof(struct vxlanhdr));
 
-	flags = ntohl(vh->vx_flags);
+	flags = vh->vx_flags;
 
 	if ((flags & VXLAN_HF_RCO) && (vs->flags & VXLAN_F_REMCSUM_RX)) {
 		vh = vxlan_gro_remcsum(skb, off_vx, vh, sizeof(struct vxlanhdr),
-				       ntohl(vh->vx_vni), &grc,
+				       vh->vx_vni, &grc,
 				       !!(vs->flags &
 					  VXLAN_F_REMCSUM_NOPARTIAL));
 
@@ -668,7 +668,7 @@
 static int vxlan_fdb_create(struct vxlan_dev *vxlan,
 			    const u8 *mac, union vxlan_addr *ip,
 			    __u16 state, __u16 flags,
-			    __be16 port, __u32 vni, __u32 ifindex,
+			    __be16 port, __be32 vni, __u32 ifindex,
 			    __u8 ndm_flags)
 {
 	struct vxlan_rdst *rd = NULL;
@@ -777,7 +777,8 @@
 }
 
 static int vxlan_fdb_parse(struct nlattr *tb[], struct vxlan_dev *vxlan,
-			   union vxlan_addr *ip, __be16 *port, u32 *vni, u32 *ifindex)
+			   union vxlan_addr *ip, __be16 *port, __be32 *vni,
+			   u32 *ifindex)
 {
 	struct net *net = dev_net(vxlan->dev);
 	int err;
@@ -810,7 +811,7 @@
 	if (tb[NDA_VNI]) {
 		if (nla_len(tb[NDA_VNI]) != sizeof(u32))
 			return -EINVAL;
-		*vni = nla_get_u32(tb[NDA_VNI]);
+		*vni = cpu_to_be32(nla_get_u32(tb[NDA_VNI]));
 	} else {
 		*vni = vxlan->default_dst.remote_vni;
 	}
@@ -840,7 +841,8 @@
 	/* struct net *net = dev_net(vxlan->dev); */
 	union vxlan_addr ip;
 	__be16 port;
-	u32 vni, ifindex;
+	__be32 vni;
+	u32 ifindex;
 	int err;
 
 	if (!(ndm->ndm_state & (NUD_PERMANENT|NUD_REACHABLE))) {
@@ -877,7 +879,8 @@
 	struct vxlan_rdst *rd = NULL;
 	union vxlan_addr ip;
 	__be16 port;
-	u32 vni, ifindex;
+	__be32 vni;
+	u32 ifindex;
 	int err;
 
 	err = vxlan_fdb_parse(tb, vxlan, &ip, &port, &vni, &ifindex);
@@ -1133,17 +1136,16 @@
 }
 
 static struct vxlanhdr *vxlan_remcsum(struct sk_buff *skb, struct vxlanhdr *vh,
-				      size_t hdrlen, u32 data, bool nopartial)
+				      size_t hdrlen, __be32 vni_field,
+				      bool nopartial)
 {
 	size_t start, offset, plen;
 
 	if (skb->remcsum_offload)
 		return vh;
 
-	start = (data & VXLAN_RCO_MASK) << VXLAN_RCO_SHIFT;
-	offset = start + ((data & VXLAN_RCO_UDP) ?
-			  offsetof(struct udphdr, check) :
-			  offsetof(struct tcphdr, check));
+	start = vxlan_rco_start(vni_field);
+	offset = start + vxlan_rco_offset(vni_field);
 
 	plen = hdrlen + offset + sizeof(u16);
 
@@ -1159,7 +1161,7 @@
 }
 
 static void vxlan_rcv(struct vxlan_sock *vs, struct sk_buff *skb,
-		      struct vxlan_metadata *md, u32 vni,
+		      struct vxlan_metadata *md, __be32 vni,
 		      struct metadata_dst *tun_dst)
 {
 	struct iphdr *oip = NULL;
@@ -1257,7 +1259,7 @@
 {
 	struct metadata_dst *tun_dst = NULL;
 	struct vxlan_sock *vs;
-	u32 flags, vni;
+	__be32 flags, vni_field;
 	struct vxlan_metadata _md;
 	struct vxlan_metadata *md = &_md;
 
@@ -1265,8 +1267,8 @@
 	if (!pskb_may_pull(skb, VXLAN_HLEN))
 		goto error;
 
-	flags = ntohl(vxlan_hdr(skb)->vx_flags);
-	vni = ntohl(vxlan_hdr(skb)->vx_vni);
+	flags = vxlan_hdr(skb)->vx_flags;
+	vni_field = vxlan_hdr(skb)->vx_vni;
 
 	if (flags & VXLAN_HF_VNI) {
 		flags &= ~VXLAN_HF_VNI;
@@ -1283,17 +1285,18 @@
 		goto drop;
 
 	if ((flags & VXLAN_HF_RCO) && (vs->flags & VXLAN_F_REMCSUM_RX)) {
-		if (!vxlan_remcsum(skb, vxlan_hdr(skb), sizeof(struct vxlanhdr), vni,
+		if (!vxlan_remcsum(skb, vxlan_hdr(skb), sizeof(struct vxlanhdr),
+				   vni_field,
 				   !!(vs->flags & VXLAN_F_REMCSUM_NOPARTIAL)))
 			goto drop;
 
 		flags &= ~VXLAN_HF_RCO;
-		vni &= VXLAN_VNI_MASK;
+		vni_field &= VXLAN_VNI_MASK;
 	}
 
 	if (vxlan_collect_metadata(vs)) {
 		tun_dst = udp_tun_rx_dst(skb, vxlan_get_sk_family(vs), TUNNEL_KEY,
-					 cpu_to_be64(vni >> 8), sizeof(*md));
+					 vxlan_vni(vni_field), sizeof(*md));
 
 		if (!tun_dst)
 			goto drop;
@@ -1324,7 +1327,7 @@
 		flags &= ~VXLAN_GBP_USED_BITS;
 	}
 
-	if (flags || vni & ~VXLAN_VNI_MASK) {
+	if (flags || vni_field & ~VXLAN_VNI_MASK) {
 		/* If there are any unprocessed flags remaining treat
 		 * this as a malformed packet. This behavior diverges from
 		 * VXLAN RFC (RFC7348) which stipulates that bits in reserved
@@ -1337,7 +1340,7 @@
 		goto bad_flags;
 	}
 
-	vxlan_rcv(vs, skb, md, vni >> 8, tun_dst);
+	vxlan_rcv(vs, skb, md, vxlan_vni(vni_field), tun_dst);
 	return 0;
 
 drop:
@@ -1680,7 +1683,7 @@
 		return;
 
 	gbp = (struct vxlanhdr_gbp *)vxh;
-	vxh->vx_flags |= htonl(VXLAN_HF_GBP);
+	vxh->vx_flags |= VXLAN_HF_GBP;
 
 	if (md->gbp & VXLAN_GBP_DONT_LEARN)
 		gbp->dont_learn = 1;
@@ -1700,7 +1703,6 @@
 	int min_headroom;
 	int err;
 	int type = udp_sum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
-	u16 hdrlen = sizeof(struct vxlanhdr);
 
 	if ((vxflags & VXLAN_F_REMCSUM_TX) &&
 	    skb->ip_summed == CHECKSUM_PARTIAL) {
@@ -1733,18 +1735,15 @@
 		return PTR_ERR(skb);
 
 	vxh = (struct vxlanhdr *) __skb_push(skb, sizeof(*vxh));
-	vxh->vx_flags = htonl(VXLAN_HF_VNI);
-	vxh->vx_vni = vni;
+	vxh->vx_flags = VXLAN_HF_VNI;
+	vxh->vx_vni = vxlan_vni_field(vni);
 
 	if (type & SKB_GSO_TUNNEL_REMCSUM) {
-		u32 data = (skb_checksum_start_offset(skb) - hdrlen) >>
-			   VXLAN_RCO_SHIFT;
+		unsigned int start;
 
-		if (skb->csum_offset == offsetof(struct udphdr, check))
-			data |= VXLAN_RCO_UDP;
-
-		vxh->vx_vni |= htonl(data);
-		vxh->vx_flags |= htonl(VXLAN_HF_RCO);
+		start = skb_checksum_start_offset(skb) - sizeof(struct vxlanhdr);
+		vxh->vx_vni |= vxlan_compute_rco(start, skb->csum_offset);
+		vxh->vx_flags |= VXLAN_HF_RCO;
 
 		if (!skb_is_gso(skb)) {
 			skb->ip_summed = CHECKSUM_NONE;
@@ -1892,7 +1891,7 @@
 	struct vxlan_metadata _md;
 	struct vxlan_metadata *md = &_md;
 	__be16 src_port = 0, dst_port;
-	u32 vni;
+	__be32 vni;
 	__be16 df = 0;
 	__u8 tos, ttl;
 	int err;
@@ -1914,7 +1913,7 @@
 			goto drop;
 		}
 		dst_port = info->key.tp_dst ? : vxlan->cfg.dst_port;
-		vni = be64_to_cpu(info->key.tun_id);
+		vni = vxlan_tun_id_to_vni(info->key.tun_id);
 		remote_ip.sa.sa_family = ip_tunnel_info_af(info);
 		if (remote_ip.sa.sa_family == AF_INET)
 			remote_ip.sin.sin_addr.s_addr = info->key.u.ipv4.dst;
@@ -2007,7 +2006,7 @@
 		tos = ip_tunnel_ecn_encap(tos, old_iph, skb);
 		ttl = ttl ? : ip4_dst_hoplimit(&rt->dst);
 		err = vxlan_build_skb(skb, &rt->dst, sizeof(struct iphdr),
-				      htonl(vni << 8), md, flags, udp_sum);
+				      vni, md, flags, udp_sum);
 		if (err < 0)
 			goto xmit_tx_error;
 
@@ -2065,7 +2064,7 @@
 		ttl = ttl ? : ip6_dst_hoplimit(ndst);
 		skb_scrub_packet(skb, xnet);
 		err = vxlan_build_skb(skb, ndst, sizeof(struct ipv6hdr),
-				      htonl(vni << 8), md, flags, udp_sum);
+				      vni, md, flags, udp_sum);
 		if (err < 0) {
 			dst_release(ndst);
 			return;
@@ -2222,7 +2221,7 @@
 static void vxlan_vs_add_dev(struct vxlan_sock *vs, struct vxlan_dev *vxlan)
 {
 	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
-	__u32 vni = vxlan->default_dst.remote_vni;
+	__be32 vni = vxlan->default_dst.remote_vni;
 
 	spin_lock(&vn->sock_lock);
 	hlist_add_head_rcu(&vxlan->hlist, vni_head(vs, vni));
@@ -2837,7 +2836,7 @@
 	memset(&conf, 0, sizeof(conf));
 
 	if (data[IFLA_VXLAN_ID])
-		conf.vni = nla_get_u32(data[IFLA_VXLAN_ID]);
+		conf.vni = cpu_to_be32(nla_get_u32(data[IFLA_VXLAN_ID]));
 
 	if (data[IFLA_VXLAN_GROUP]) {
 		conf.remote_ip.sin.sin_addr.s_addr = nla_get_in_addr(data[IFLA_VXLAN_GROUP]);
@@ -2941,7 +2940,7 @@
 		break;
 
 	case -EEXIST:
-		pr_info("duplicate VNI %u\n", conf.vni);
+		pr_info("duplicate VNI %u\n", be32_to_cpu(conf.vni));
 		break;
 	}
 
@@ -2999,7 +2998,7 @@
 		.high = htons(vxlan->cfg.port_max),
 	};
 
-	if (nla_put_u32(skb, IFLA_VXLAN_ID, dst->remote_vni))
+	if (nla_put_u32(skb, IFLA_VXLAN_ID, be32_to_cpu(dst->remote_vni)))
 		goto nla_put_failure;
 
 	if (!vxlan_addr_any(&dst->remote_ip)) {