netfilter: Pass socket pointer down through okfn().
On the output paths in particular, we sometimes have to deal with two
socket contexts. The first, usually skb->sk, is the local socket that
generated the frame.
The second is potentially the socket used to control a tunneling
socket, such as one that encapsulates using UDP.
We do not want to disassociate skb->sk when encapsulating in order
to fix this, because that would break socket memory accounting.
The most extreme case where this can cause huge problems is an
AF_PACKET socket transmitting over a vxlan device. We hit code
paths doing checks that assume they are dealing with an ipv4
socket, but are actually operating upon the AF_PACKET one.
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 654f245..7fde1f2 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -56,7 +56,7 @@
#include <net/checksum.h>
#include <linux/mroute6.h>
-static int ip6_finish_output2(struct sk_buff *skb)
+static int ip6_finish_output2(struct sock *sk, struct sk_buff *skb)
{
struct dst_entry *dst = skb_dst(skb);
struct net_device *dev = dst->dev;
@@ -70,7 +70,7 @@
if (ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr)) {
struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
- if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(skb->sk) &&
+ if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) &&
((mroute6_socket(dev_net(dev), skb) &&
!(IP6CB(skb)->flags & IP6SKB_FORWARDED)) ||
ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr,
@@ -82,7 +82,7 @@
*/
if (newskb)
NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING,
- newskb, NULL, newskb->dev,
+ sk, newskb, NULL, newskb->dev,
dev_loopback_xmit);
if (ipv6_hdr(skb)->hop_limit == 0) {
@@ -122,14 +122,14 @@
return -EINVAL;
}
-static int ip6_finish_output(struct sk_buff *skb)
+static int ip6_finish_output(struct sock *sk, struct sk_buff *skb)
{
if ((skb->len > ip6_skb_dst_mtu(skb) && !skb_is_gso(skb)) ||
dst_allfrag(skb_dst(skb)) ||
(IP6CB(skb)->frag_max_size && skb->len > IP6CB(skb)->frag_max_size))
- return ip6_fragment(skb, ip6_finish_output2);
+ return ip6_fragment(sk, skb, ip6_finish_output2);
else
- return ip6_finish_output2(skb);
+ return ip6_finish_output2(sk, skb);
}
int ip6_output(struct sock *sk, struct sk_buff *skb)
@@ -143,7 +143,8 @@
return 0;
}
- return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, skb, NULL, dev,
+ return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, sk, skb,
+ NULL, dev,
ip6_finish_output,
!(IP6CB(skb)->flags & IP6SKB_REROUTED));
}
@@ -223,8 +224,8 @@
if ((skb->len <= mtu) || skb->ignore_df || skb_is_gso(skb)) {
IP6_UPD_PO_STATS(net, ip6_dst_idev(skb_dst(skb)),
IPSTATS_MIB_OUT, skb->len);
- return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, skb, NULL,
- dst->dev, dst_output);
+ return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, sk, skb,
+ NULL, dst->dev, dst_output_sk);
}
skb->dev = dst->dev;
@@ -316,10 +317,10 @@
return 0;
}
-static inline int ip6_forward_finish(struct sk_buff *skb)
+static inline int ip6_forward_finish(struct sock *sk, struct sk_buff *skb)
{
skb_sender_cpu_clear(skb);
- return dst_output(skb);
+ return dst_output_sk(sk, skb);
}
static unsigned int ip6_dst_mtu_forward(const struct dst_entry *dst)
@@ -511,7 +512,8 @@
IP6_INC_STATS_BH(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTFORWDATAGRAMS);
IP6_ADD_STATS_BH(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTOCTETS, skb->len);
- return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, skb, skb->dev, dst->dev,
+ return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, NULL, skb,
+ skb->dev, dst->dev,
ip6_forward_finish);
error:
@@ -538,7 +540,8 @@
skb_copy_secmark(to, from);
}
-int ip6_fragment(struct sk_buff *skb, int (*output)(struct sk_buff *))
+int ip6_fragment(struct sock *sk, struct sk_buff *skb,
+ int (*output)(struct sock *, struct sk_buff *))
{
struct sk_buff *frag;
struct rt6_info *rt = (struct rt6_info *)skb_dst(skb);
@@ -667,7 +670,7 @@
ip6_copy_metadata(frag, skb);
}
- err = output(skb);
+ err = output(sk, skb);
if (!err)
IP6_INC_STATS(net, ip6_dst_idev(&rt->dst),
IPSTATS_MIB_FRAGCREATES);
@@ -800,7 +803,7 @@
/*
* Put this fragment into the sending queue.
*/
- err = output(frag);
+ err = output(sk, frag);
if (err)
goto fail;