udp_tunnel: Pass UDP socket down through udp_tunnel{,6}_xmit_skb().

That way we can make sure the output path of ipv4/ipv6 operates on
the UDP socket rather than whatever random thing happens to be in
skb->sk.

Based upon a patch by Jiri Pirko.

Signed-off-by: David S. Miller <davem@davemloft.net>
Acked-by: Hannes Frederic Sowa <hannes@stressinduktion.org>
diff --git a/net/ipv4/geneve.c b/net/ipv4/geneve.c
index e64f8e9..b77f5e8 100644
--- a/net/ipv4/geneve.c
+++ b/net/ipv4/geneve.c
@@ -136,7 +136,7 @@
 
 	skb_set_inner_protocol(skb, htons(ETH_P_TEB));
 
-	return udp_tunnel_xmit_skb(rt, skb, src, dst,
+	return udp_tunnel_xmit_skb(rt, gs->sock->sk, skb, src, dst,
 				   tos, ttl, df, src_port, dst_port, xnet,
 				   !csum);
 }
diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c
index 6d364ab..4c2c3ba 100644
--- a/net/ipv4/ip_tunnel.c
+++ b/net/ipv4/ip_tunnel.c
@@ -782,7 +782,7 @@
 		return;
 	}
 
-	err = iptunnel_xmit(skb->sk, rt, skb, fl4.saddr, fl4.daddr, protocol,
+	err = iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, protocol,
 			    tos, ttl, df, !net_eq(tunnel->net, dev_net(dev)));
 	iptunnel_xmit_stats(err, &dev->stats, dev->tstats);
 
diff --git a/net/ipv4/udp_tunnel.c b/net/ipv4/udp_tunnel.c
index c83b354..6bb98cc1 100644
--- a/net/ipv4/udp_tunnel.c
+++ b/net/ipv4/udp_tunnel.c
@@ -75,7 +75,7 @@
 }
 EXPORT_SYMBOL_GPL(setup_udp_tunnel_sock);
 
-int udp_tunnel_xmit_skb(struct rtable *rt, struct sk_buff *skb,
+int udp_tunnel_xmit_skb(struct rtable *rt, struct sock *sk, struct sk_buff *skb,
 			__be32 src, __be32 dst, __u8 tos, __u8 ttl,
 			__be16 df, __be16 src_port, __be16 dst_port,
 			bool xnet, bool nocheck)
@@ -92,7 +92,7 @@
 
 	udp_set_csum(nocheck, skb, src, dst, skb->len);
 
-	return iptunnel_xmit(skb->sk, rt, skb, src, dst, IPPROTO_UDP,
+	return iptunnel_xmit(sk, rt, skb, src, dst, IPPROTO_UDP,
 			     tos, ttl, df, xnet);
 }
 EXPORT_SYMBOL_GPL(udp_tunnel_xmit_skb);
diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c
index f724329..b5e6cc1 100644
--- a/net/ipv6/ip6_gre.c
+++ b/net/ipv6/ip6_gre.c
@@ -760,7 +760,7 @@
 
 	skb_set_inner_protocol(skb, protocol);
 
-	ip6tunnel_xmit(skb, dev);
+	ip6tunnel_xmit(NULL, skb, dev);
 	if (ndst)
 		ip6_tnl_dst_store(tunnel, ndst);
 	return 0;
diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c
index b6a211a..5cafd92 100644
--- a/net/ipv6/ip6_tunnel.c
+++ b/net/ipv6/ip6_tunnel.c
@@ -1100,7 +1100,7 @@
 	ipv6h->nexthdr = proto;
 	ipv6h->saddr = fl6->saddr;
 	ipv6h->daddr = fl6->daddr;
-	ip6tunnel_xmit(skb, dev);
+	ip6tunnel_xmit(NULL, skb, dev);
 	if (ndst)
 		ip6_tnl_dst_store(t, ndst);
 	return 0;
diff --git a/net/ipv6/ip6_udp_tunnel.c b/net/ipv6/ip6_udp_tunnel.c
index 32d9b26..bba8903 100644
--- a/net/ipv6/ip6_udp_tunnel.c
+++ b/net/ipv6/ip6_udp_tunnel.c
@@ -62,7 +62,8 @@
 }
 EXPORT_SYMBOL_GPL(udp_sock_create6);
 
-int udp_tunnel6_xmit_skb(struct dst_entry *dst, struct sk_buff *skb,
+int udp_tunnel6_xmit_skb(struct dst_entry *dst, struct sock *sk,
+			 struct sk_buff *skb,
 			 struct net_device *dev, struct in6_addr *saddr,
 			 struct in6_addr *daddr,
 			 __u8 prio, __u8 ttl, __be16 src_port,
@@ -97,7 +98,7 @@
 	ip6h->daddr	  = *daddr;
 	ip6h->saddr	  = *saddr;
 
-	ip6tunnel_xmit(skb, dev);
+	ip6tunnel_xmit(sk, skb, dev);
 	return 0;
 }
 EXPORT_SYMBOL_GPL(udp_tunnel6_xmit_skb);
diff --git a/net/ipv6/output_core.c b/net/ipv6/output_core.c
index 7d1131d..85892af 100644
--- a/net/ipv6/output_core.c
+++ b/net/ipv6/output_core.c
@@ -136,7 +136,7 @@
 EXPORT_SYMBOL(ip6_dst_hoplimit);
 #endif
 
-int __ip6_local_out(struct sk_buff *skb)
+static int __ip6_local_out_sk(struct sock *sk, struct sk_buff *skb)
 {
 	int len;
 
@@ -146,19 +146,30 @@
 	ipv6_hdr(skb)->payload_len = htons(len);
 	IP6CB(skb)->nhoff = offsetof(struct ipv6hdr, nexthdr);
 
-	return nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb->sk, skb,
+	return nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
 		       NULL, skb_dst(skb)->dev, dst_output_sk);
 }
+
+int __ip6_local_out(struct sk_buff *skb)
+{
+	return __ip6_local_out_sk(skb->sk, skb);
+}
 EXPORT_SYMBOL_GPL(__ip6_local_out);
 
-int ip6_local_out(struct sk_buff *skb)
+int ip6_local_out_sk(struct sock *sk, struct sk_buff *skb)
 {
 	int err;
 
-	err = __ip6_local_out(skb);
+	err = __ip6_local_out_sk(sk, skb);
 	if (likely(err == 1))
-		err = dst_output(skb);
+		err = dst_output_sk(sk, skb);
 
 	return err;
 }
+EXPORT_SYMBOL_GPL(ip6_local_out_sk);
+
+int ip6_local_out(struct sk_buff *skb)
+{
+	return ip6_local_out_sk(skb->sk, skb);
+}
 EXPORT_SYMBOL_GPL(ip6_local_out);
diff --git a/net/openvswitch/vport-vxlan.c b/net/openvswitch/vport-vxlan.c
index 3277a75..6d39766 100644
--- a/net/openvswitch/vport-vxlan.c
+++ b/net/openvswitch/vport-vxlan.c
@@ -222,7 +222,8 @@
 {
 	struct net *net = ovs_dp_get_net(vport->dp);
 	struct vxlan_port *vxlan_port = vxlan_vport(vport);
-	__be16 dst_port = inet_sk(vxlan_port->vs->sock->sk)->inet_sport;
+	struct sock *sk = vxlan_port->vs->sock->sk;
+	__be16 dst_port = inet_sk(sk)->inet_sport;
 	const struct ovs_key_ipv4_tunnel *tun_key;
 	struct vxlan_metadata md = {0};
 	struct rtable *rt;
@@ -255,7 +256,7 @@
 	vxflags = vxlan_port->exts |
 		      (tun_key->tun_flags & TUNNEL_CSUM ? VXLAN_F_UDP_CSUM : 0);
 
-	err = vxlan_xmit_skb(rt, skb, fl.saddr, tun_key->ipv4_dst,
+	err = vxlan_xmit_skb(rt, sk, skb, fl.saddr, tun_key->ipv4_dst,
 			     tun_key->ipv4_tos, tun_key->ipv4_ttl, df,
 			     src_port, dst_port,
 			     &md, false, vxflags);
diff --git a/net/tipc/udp_media.c b/net/tipc/udp_media.c
index ef3d7aa..66deebc 100644
--- a/net/tipc/udp_media.c
+++ b/net/tipc/udp_media.c
@@ -176,7 +176,8 @@
 			goto tx_error;
 		}
 		ttl = ip4_dst_hoplimit(&rt->dst);
-		err = udp_tunnel_xmit_skb(rt, clone, src->ipv4.s_addr,
+		err = udp_tunnel_xmit_skb(rt, ub->ubsock->sk, clone,
+					  src->ipv4.s_addr,
 					  dst->ipv4.s_addr, 0, ttl, 0,
 					  src->udp_port, dst->udp_port,
 					  false, true);
@@ -197,7 +198,8 @@
 		if (err)
 			goto tx_error;
 		ttl = ip6_dst_hoplimit(ndst);
-		err = udp_tunnel6_xmit_skb(ndst, clone, ndst->dev, &src->ipv6,
+		err = udp_tunnel6_xmit_skb(ndst, ub->ubsock->sk, clone,
+					   ndst->dev, &src->ipv6,
 					   &dst->ipv6, 0, ttl, src->udp_port,
 					   dst->udp_port, false);
 #endif