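ipv6: Pass struct net through ip6_fragment

Add a struct net parameter to ip6_fragment() and to the output callback
it invokes, and likewise to ip6_finish_output2(), so that the network
namespace is handed down by the caller.  The local lookups via
dev_net(dev) in ip6_finish_output2() and dev_net(skb_dst(skb)->dev) in
ip6_fragment() are removed accordingly.

Signed-off-by: Eric W. Biederman <ebiederm@xmission.com>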
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index a598fe2..caf7d14 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -56,11 +56,10 @@
 #include <net/checksum.h>
 #include <linux/mroute6.h>
 
-static int ip6_finish_output2(struct sock *sk, struct sk_buff *skb)
+static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct net_device *dev = dst->dev;
-	struct net *net = dev_net(dev);
 	struct neighbour *neigh;
 	struct in6_addr *nexthop;
 	int ret;
@@ -126,9 +125,9 @@
 	if ((skb->len > ip6_skb_dst_mtu(skb) && !skb_is_gso(skb)) ||
 	    dst_allfrag(skb_dst(skb)) ||
 	    (IP6CB(skb)->frag_max_size && skb->len > IP6CB(skb)->frag_max_size))
-		return ip6_fragment(sk, skb, ip6_finish_output2);
+		return ip6_fragment(net, sk, skb, ip6_finish_output2);
 	else
-		return ip6_finish_output2(sk, skb);
+		return ip6_finish_output2(net, sk, skb);
 }
 
 int ip6_output(struct sock *sk, struct sk_buff *skb)
@@ -554,8 +553,8 @@
 	skb_copy_secmark(to, from);
 }
 
-int ip6_fragment(struct sock *sk, struct sk_buff *skb,
-		 int (*output)(struct sock *, struct sk_buff *))
+int ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb,
+		 int (*output)(struct net *, struct sock *, struct sk_buff *))
 {
 	struct sk_buff *frag;
 	struct rt6_info *rt = (struct rt6_info *)skb_dst(skb);
@@ -568,7 +567,6 @@
 	__be32 frag_id;
 	int ptr, offset = 0, err = 0;
 	u8 *prevhdr, nexthdr = 0;
-	struct net *net = dev_net(skb_dst(skb)->dev);
 
 	hlen = ip6_find_1stfragopt(skb, &prevhdr);
 	nexthdr = *prevhdr;
@@ -688,7 +686,7 @@
 				ip6_copy_metadata(frag, skb);
 			}
 
-			err = output(sk, skb);
+			err = output(net, sk, skb);
 			if (!err)
 				IP6_INC_STATS(net, ip6_dst_idev(&rt->dst),
 					      IPSTATS_MIB_FRAGCREATES);
@@ -816,7 +814,7 @@
 		/*
 		 *	Put this fragment into the sending queue.
 		 */
-		err = output(sk, frag);
+		err = output(net, sk, frag);
 		if (err)
 			goto fail;