netfilter: Pass socket pointer down through okfn().

On the output paths in particular, we sometimes have to deal with two
socket contexts.  The first, usually skb->sk, is the local socket that
generated the frame.

The second is potentially the socket used to control a tunneling
socket, such as one that encapsulates using UDP.

We do not want to disassociate skb->sk when encapsulating in order
to fix this, because that would break socket memory accounting.

The most extreme case where this can cause huge problems is an
AF_PACKET socket transmitting over a vxlan device.  We hit code
paths doing checks that assume they are dealing with an ipv4
socket, but are actually operating upon the AF_PACKET one.

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv6/ip6_input.c b/net/ipv6/ip6_input.c
index fb97f7f..f2e464e 100644
--- a/net/ipv6/ip6_input.c
+++ b/net/ipv6/ip6_input.c
@@ -46,8 +46,7 @@
 #include <net/xfrm.h>
 #include <net/inet_ecn.h>
 
-
-int ip6_rcv_finish(struct sk_buff *skb)
+int ip6_rcv_finish(struct sock *sk, struct sk_buff *skb)
 {
 	if (sysctl_ip_early_demux && !skb_dst(skb) && skb->sk == NULL) {
 		const struct inet6_protocol *ipprot;
@@ -183,7 +182,8 @@
 	/* Must drop socket now because of tproxy. */
 	skb_orphan(skb);
 
-	return NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, skb, dev, NULL,
+	return NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, NULL, skb,
+		       dev, NULL,
 		       ip6_rcv_finish);
 err:
 	IP6_INC_STATS_BH(net, idev, IPSTATS_MIB_INHDRERRORS);
@@ -198,7 +198,7 @@
  */
 
 
-static int ip6_input_finish(struct sk_buff *skb)
+static int ip6_input_finish(struct sock *sk, struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb_dst(skb)->dev);
 	const struct inet6_protocol *ipprot;
@@ -277,7 +277,8 @@
 
 int ip6_input(struct sk_buff *skb)
 {
-	return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_IN, skb, skb->dev, NULL,
+	return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_IN, NULL, skb,
+		       skb->dev, NULL,
 		       ip6_input_finish);
 }
 
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 654f245..7fde1f2 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -56,7 +56,7 @@
 #include <net/checksum.h>
 #include <linux/mroute6.h>
 
-static int ip6_finish_output2(struct sk_buff *skb)
+static int ip6_finish_output2(struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct net_device *dev = dst->dev;
@@ -70,7 +70,7 @@
 	if (ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr)) {
 		struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
 
-		if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(skb->sk) &&
+		if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) &&
 		    ((mroute6_socket(dev_net(dev), skb) &&
 		     !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) ||
 		     ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr,
@@ -82,7 +82,7 @@
 			 */
 			if (newskb)
 				NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING,
-					newskb, NULL, newskb->dev,
+					sk, newskb, NULL, newskb->dev,
 					dev_loopback_xmit);
 
 			if (ipv6_hdr(skb)->hop_limit == 0) {
@@ -122,14 +122,14 @@
 	return -EINVAL;
 }
 
-static int ip6_finish_output(struct sk_buff *skb)
+static int ip6_finish_output(struct sock *sk, struct sk_buff *skb)
 {
 	if ((skb->len > ip6_skb_dst_mtu(skb) && !skb_is_gso(skb)) ||
 	    dst_allfrag(skb_dst(skb)) ||
 	    (IP6CB(skb)->frag_max_size && skb->len > IP6CB(skb)->frag_max_size))
-		return ip6_fragment(skb, ip6_finish_output2);
+		return ip6_fragment(sk, skb, ip6_finish_output2);
 	else
-		return ip6_finish_output2(skb);
+		return ip6_finish_output2(sk, skb);
 }
 
 int ip6_output(struct sock *sk, struct sk_buff *skb)
@@ -143,7 +143,8 @@
 		return 0;
 	}
 
-	return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, skb, NULL, dev,
+	return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, sk, skb,
+			    NULL, dev,
 			    ip6_finish_output,
 			    !(IP6CB(skb)->flags & IP6SKB_REROUTED));
 }
@@ -223,8 +224,8 @@
 	if ((skb->len <= mtu) || skb->ignore_df || skb_is_gso(skb)) {
 		IP6_UPD_PO_STATS(net, ip6_dst_idev(skb_dst(skb)),
 			      IPSTATS_MIB_OUT, skb->len);
-		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL,
-			       dst->dev, dst_output);
+		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
+			       NULL, dst->dev, dst_output_sk);
 	}
 
 	skb->dev = dst->dev;
@@ -316,10 +317,10 @@
 	return 0;
 }
 
-static inline int ip6_forward_finish(struct sk_buff *skb)
+static inline int ip6_forward_finish(struct sock *sk, struct sk_buff *skb)
 {
 	skb_sender_cpu_clear(skb);
-	return dst_output(skb);
+	return dst_output_sk(sk, skb);
 }
 
 static unsigned int ip6_dst_mtu_forward(const struct dst_entry *dst)
@@ -511,7 +512,8 @@
 
 	IP6_INC_STATS_BH(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTFORWDATAGRAMS);
 	IP6_ADD_STATS_BH(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTOCTETS, skb->len);
-	return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, skb, skb->dev, dst->dev,
+	return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, NULL, skb,
+		       skb->dev, dst->dev,
 		       ip6_forward_finish);
 
 error:
@@ -538,7 +540,8 @@
 	skb_copy_secmark(to, from);
 }
 
-int ip6_fragment(struct sk_buff *skb, int (*output)(struct sk_buff *))
+int ip6_fragment(struct sock *sk, struct sk_buff *skb,
+		 int (*output)(struct sock *, struct sk_buff *))
 {
 	struct sk_buff *frag;
 	struct rt6_info *rt = (struct rt6_info *)skb_dst(skb);
@@ -667,7 +670,7 @@
 				ip6_copy_metadata(frag, skb);
 			}
 
-			err = output(skb);
+			err = output(sk, skb);
 			if (!err)
 				IP6_INC_STATS(net, ip6_dst_idev(&rt->dst),
 					      IPSTATS_MIB_FRAGCREATES);
@@ -800,7 +803,7 @@
 		/*
 		 *	Put this fragment into the sending queue.
 		 */
-		err = output(frag);
+		err = output(sk, frag);
 		if (err)
 			goto fail;
 
diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c
index 8493a22..74ceb73 100644
--- a/net/ipv6/ip6mr.c
+++ b/net/ipv6/ip6mr.c
@@ -1986,13 +1986,13 @@
 }
 #endif
 
-static inline int ip6mr_forward2_finish(struct sk_buff *skb)
+static inline int ip6mr_forward2_finish(struct sock *sk, struct sk_buff *skb)
 {
 	IP6_INC_STATS_BH(dev_net(skb_dst(skb)->dev), ip6_dst_idev(skb_dst(skb)),
 			 IPSTATS_MIB_OUTFORWDATAGRAMS);
 	IP6_ADD_STATS_BH(dev_net(skb_dst(skb)->dev), ip6_dst_idev(skb_dst(skb)),
 			 IPSTATS_MIB_OUTOCTETS, skb->len);
-	return dst_output(skb);
+	return dst_output_sk(sk, skb);
 }
 
 /*
@@ -2064,7 +2064,8 @@
 
 	IP6CB(skb)->flags |= IP6SKB_FORWARDED;
 
-	return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, skb, skb->dev, dev,
+	return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, NULL, skb,
+		       skb->dev, dev,
 		       ip6mr_forward2_finish);
 
 out_free:
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index fac1f27..083b292 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -1644,8 +1644,9 @@
 
 	payload_len = skb->len;
 
-	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL, skb->dev,
-		      dst_output);
+	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT,
+		      net->ipv6.igmp_sk, skb, NULL, skb->dev,
+		      dst_output_sk);
 out:
 	if (!err) {
 		ICMP6MSGOUT_INC_STATS(net, idev, ICMPV6_MLD2_REPORT);
@@ -2007,8 +2008,8 @@
 	}
 
 	skb_dst_set(skb, dst);
-	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL, skb->dev,
-		      dst_output);
+	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
+		      NULL, skb->dev, dst_output_sk);
 out:
 	if (!err) {
 		ICMP6MSGOUT_INC_STATS(net, idev, type);
diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c
index 71fde6c..96f153c 100644
--- a/net/ipv6/ndisc.c
+++ b/net/ipv6/ndisc.c
@@ -463,8 +463,9 @@
 	idev = __in6_dev_get(dst->dev);
 	IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_OUT, skb->len);
 
-	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL, dst->dev,
-		      dst_output);
+	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
+		      NULL, dst->dev,
+		      dst_output_sk);
 	if (!err) {
 		ICMP6MSGOUT_INC_STATS(net, idev, type);
 		ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS);
diff --git a/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c b/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c
index e2b8820..a45db0b 100644
--- a/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c
+++ b/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c
@@ -75,7 +75,7 @@
 
 	nf_ct_frag6_consume_orig(reasm);
 
-	NF_HOOK_THRESH(NFPROTO_IPV6, ops->hooknum, reasm,
+	NF_HOOK_THRESH(NFPROTO_IPV6, ops->hooknum, state->sk, reasm,
 		       state->in, state->out,
 		       state->okfn, NF_IP6_PRI_CONNTRACK_DEFRAG + 1);
 
diff --git a/net/ipv6/output_core.c b/net/ipv6/output_core.c
index 4016a6e..7d1131d 100644
--- a/net/ipv6/output_core.c
+++ b/net/ipv6/output_core.c
@@ -146,8 +146,8 @@
 	ipv6_hdr(skb)->payload_len = htons(len);
 	IP6CB(skb)->nhoff = offsetof(struct ipv6hdr, nexthdr);
 
-	return nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL,
-		       skb_dst(skb)->dev, dst_output);
+	return nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb->sk, skb,
+		       NULL, skb_dst(skb)->dev, dst_output_sk);
 }
 EXPORT_SYMBOL_GPL(__ip6_local_out);
 
diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c
index 79ccdb4..8072bd4 100644
--- a/net/ipv6/raw.c
+++ b/net/ipv6/raw.c
@@ -652,8 +652,8 @@
 		goto error_fault;
 
 	IP6_UPD_PO_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUT, skb->len);
-	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL,
-		      rt->dst.dev, dst_output);
+	err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
+		      NULL, rt->dst.dev, dst_output_sk);
 	if (err > 0)
 		err = net_xmit_errno(err);
 	if (err)
diff --git a/net/ipv6/xfrm6_input.c b/net/ipv6/xfrm6_input.c
index f48fbe4..74bd178 100644
--- a/net/ipv6/xfrm6_input.c
+++ b/net/ipv6/xfrm6_input.c
@@ -42,7 +42,8 @@
 	ipv6_hdr(skb)->payload_len = htons(skb->len);
 	__skb_push(skb, skb->data - skb_network_header(skb));
 
-	NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, skb, skb->dev, NULL,
+	NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, NULL, skb,
+		skb->dev, NULL,
 		ip6_rcv_finish);
 	return -1;
 }
diff --git a/net/ipv6/xfrm6_output.c b/net/ipv6/xfrm6_output.c
index 010f8bd..09c76a7 100644
--- a/net/ipv6/xfrm6_output.c
+++ b/net/ipv6/xfrm6_output.c
@@ -120,7 +120,7 @@
 }
 EXPORT_SYMBOL(xfrm6_prepare_output);
 
-int xfrm6_output_finish(struct sk_buff *skb)
+int xfrm6_output_finish(struct sock *sk, struct sk_buff *skb)
 {
 	memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));
 
@@ -128,10 +128,10 @@
 	IP6CB(skb)->flags |= IP6SKB_XFRM_TRANSFORMED;
 #endif
 
-	return xfrm_output(skb);
+	return xfrm_output(sk, skb);
 }
 
-static int __xfrm6_output(struct sk_buff *skb)
+static int __xfrm6_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct xfrm_state *x = dst->xfrm;
@@ -140,7 +140,7 @@
 #ifdef CONFIG_NETFILTER
 	if (!x) {
 		IP6CB(skb)->flags |= IP6SKB_REROUTED;
-		return dst_output(skb);
+		return dst_output_sk(sk, skb);
 	}
 #endif
 
@@ -160,14 +160,15 @@
 	if (x->props.mode == XFRM_MODE_TUNNEL &&
 	    ((skb->len > mtu && !skb_is_gso(skb)) ||
 		dst_allfrag(skb_dst(skb)))) {
-			return ip6_fragment(skb, x->outer_mode->afinfo->output_finish);
+		return ip6_fragment(sk, skb,
+				    x->outer_mode->afinfo->output_finish);
 	}
-	return x->outer_mode->afinfo->output_finish(skb);
+	return x->outer_mode->afinfo->output_finish(sk, skb);
 }
 
 int xfrm6_output(struct sock *sk, struct sk_buff *skb)
 {
-	return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, skb,
+	return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, sk, skb,
 			    NULL, skb_dst(skb)->dev, __xfrm6_output,
 			    !(IP6CB(skb)->flags & IP6SKB_REROUTED));
 }