netns PF_KEY: part 2

* interaction with userspace -- take netns from userspace socket.
* in the ->notify hook, take the netns from the SA when one is passed,
	otherwise from the explicitly passed km_event (c->net) -- a flush
	notification carries no SA/policy to take it from.
* stub policy migration with init_net for now.

Signed-off-by: Alexey Dobriyan <adobriyan@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/key/af_key.c b/net/key/af_key.c
index e80b264..bb78ef9 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -261,9 +261,9 @@
 #define BROADCAST_REGISTERED	2
 #define BROADCAST_PROMISC_ONLY	4
 static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation,
-			   int broadcast_flags, struct sock *one_sk)
+			   int broadcast_flags, struct sock *one_sk,
+			   struct net *net)
 {
-	struct net *net = &init_net;
 	struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id);
 	struct sock *sk;
 	struct hlist_node *node;
@@ -336,7 +336,7 @@
 		hdr->sadb_msg_seq = 0;
 		hdr->sadb_msg_errno = rc;
 		pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE,
-				&pfk->sk);
+				&pfk->sk, sock_net(&pfk->sk));
 		pfk->dump.skb = NULL;
 	}
 
@@ -375,7 +375,7 @@
 	hdr->sadb_msg_len = (sizeof(struct sadb_msg) /
 			     sizeof(uint64_t));
 
-	pfkey_broadcast(skb, GFP_KERNEL, BROADCAST_ONE, sk);
+	pfkey_broadcast(skb, GFP_KERNEL, BROADCAST_ONE, sk, sock_net(sk));
 
 	return 0;
 }
@@ -653,7 +653,7 @@
 				      xaddr);
 }
 
-static struct  xfrm_state *pfkey_xfrm_state_lookup(struct sadb_msg *hdr, void **ext_hdrs)
+static struct  xfrm_state *pfkey_xfrm_state_lookup(struct net *net, struct sadb_msg *hdr, void **ext_hdrs)
 {
 	struct sadb_sa *sa;
 	struct sadb_address *addr;
@@ -691,7 +691,7 @@
 	if (!xaddr)
 		return NULL;
 
-	return xfrm_state_lookup(&init_net, xaddr, sa->sadb_sa_spi, proto, family);
+	return xfrm_state_lookup(net, xaddr, sa->sadb_sa_spi, proto, family);
 }
 
 #define PFKEY_ALIGN8(a) (1 + (((a) - 1) | (8 - 1)))
@@ -1066,7 +1066,8 @@
 	return __pfkey_xfrm_state2msg(x, 0, hsc);
 }
 
-static struct xfrm_state * pfkey_msg2xfrm_state(struct sadb_msg *hdr,
+static struct xfrm_state * pfkey_msg2xfrm_state(struct net *net,
+						struct sadb_msg *hdr,
 						void **ext_hdrs)
 {
 	struct xfrm_state *x;
@@ -1130,7 +1131,7 @@
 	     (key->sadb_key_bits+7) / 8 > key->sadb_key_len * sizeof(uint64_t)))
 		return ERR_PTR(-EINVAL);
 
-	x = xfrm_state_alloc(&init_net);
+	x = xfrm_state_alloc(net);
 	if (x == NULL)
 		return ERR_PTR(-ENOBUFS);
 
@@ -1306,6 +1307,7 @@
 
 static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	struct sk_buff *resp_skb;
 	struct sadb_x_sa2 *sa2;
 	struct sadb_address *saddr, *daddr;
@@ -1356,7 +1358,7 @@
 	}
 
 	if (hdr->sadb_msg_seq) {
-		x = xfrm_find_acq_byseq(&init_net, hdr->sadb_msg_seq);
+		x = xfrm_find_acq_byseq(net, hdr->sadb_msg_seq);
 		if (x && xfrm_addr_cmp(&x->id.daddr, xdaddr, family)) {
 			xfrm_state_put(x);
 			x = NULL;
@@ -1364,7 +1366,7 @@
 	}
 
 	if (!x)
-		x = xfrm_find_acq(&init_net, mode, reqid, proto, xdaddr, xsaddr, 1, family);
+		x = xfrm_find_acq(net, mode, reqid, proto, xdaddr, xsaddr, 1, family);
 
 	if (x == NULL)
 		return -ENOENT;
@@ -1397,13 +1399,14 @@
 
 	xfrm_state_put(x);
 
-	pfkey_broadcast(resp_skb, GFP_KERNEL, BROADCAST_ONE, sk);
+	pfkey_broadcast(resp_skb, GFP_KERNEL, BROADCAST_ONE, sk, net);
 
 	return 0;
 }
 
 static int pfkey_acquire(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	struct xfrm_state *x;
 
 	if (hdr->sadb_msg_len != sizeof(struct sadb_msg)/8)
@@ -1412,14 +1415,14 @@
 	if (hdr->sadb_msg_seq == 0 || hdr->sadb_msg_errno == 0)
 		return 0;
 
-	x = xfrm_find_acq_byseq(&init_net, hdr->sadb_msg_seq);
+	x = xfrm_find_acq_byseq(net, hdr->sadb_msg_seq);
 	if (x == NULL)
 		return 0;
 
 	spin_lock_bh(&x->lock);
 	if (x->km.state == XFRM_STATE_ACQ) {
 		x->km.state = XFRM_STATE_ERROR;
-		wake_up(&init_net.xfrm.km_waitq);
+		wake_up(&net->xfrm.km_waitq);
 	}
 	spin_unlock_bh(&x->lock);
 	xfrm_state_put(x);
@@ -1484,18 +1487,19 @@
 	hdr->sadb_msg_seq = c->seq;
 	hdr->sadb_msg_pid = c->pid;
 
-	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL);
+	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, xs_net(x));
 
 	return 0;
 }
 
 static int pfkey_add(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	struct xfrm_state *x;
 	int err;
 	struct km_event c;
 
-	x = pfkey_msg2xfrm_state(hdr, ext_hdrs);
+	x = pfkey_msg2xfrm_state(net, hdr, ext_hdrs);
 	if (IS_ERR(x))
 		return PTR_ERR(x);
 
@@ -1529,6 +1533,7 @@
 
 static int pfkey_delete(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	struct xfrm_state *x;
 	struct km_event c;
 	int err;
@@ -1538,7 +1543,7 @@
 				     ext_hdrs[SADB_EXT_ADDRESS_DST-1]))
 		return -EINVAL;
 
-	x = pfkey_xfrm_state_lookup(hdr, ext_hdrs);
+	x = pfkey_xfrm_state_lookup(net, hdr, ext_hdrs);
 	if (x == NULL)
 		return -ESRCH;
 
@@ -1570,6 +1575,7 @@
 
 static int pfkey_get(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	__u8 proto;
 	struct sk_buff *out_skb;
 	struct sadb_msg *out_hdr;
@@ -1580,7 +1586,7 @@
 				     ext_hdrs[SADB_EXT_ADDRESS_DST-1]))
 		return -EINVAL;
 
-	x = pfkey_xfrm_state_lookup(hdr, ext_hdrs);
+	x = pfkey_xfrm_state_lookup(net, hdr, ext_hdrs);
 	if (x == NULL)
 		return -ESRCH;
 
@@ -1598,7 +1604,7 @@
 	out_hdr->sadb_msg_reserved = 0;
 	out_hdr->sadb_msg_seq = hdr->sadb_msg_seq;
 	out_hdr->sadb_msg_pid = hdr->sadb_msg_pid;
-	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk);
+	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk, sock_net(sk));
 
 	return 0;
 }
@@ -1699,7 +1705,7 @@
 		return -ENOBUFS;
 	}
 
-	pfkey_broadcast(supp_skb, GFP_KERNEL, BROADCAST_REGISTERED, sk);
+	pfkey_broadcast(supp_skb, GFP_KERNEL, BROADCAST_REGISTERED, sk, sock_net(sk));
 
 	return 0;
 }
@@ -1721,13 +1727,14 @@
 	hdr->sadb_msg_errno = (uint8_t) 0;
 	hdr->sadb_msg_len = (sizeof(struct sadb_msg) / sizeof(uint64_t));
 
-	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL);
+	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, c->net);
 
 	return 0;
 }
 
 static int pfkey_flush(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	unsigned proto;
 	struct km_event c;
 	struct xfrm_audit audit_info;
@@ -1740,14 +1747,14 @@
 	audit_info.loginuid = audit_get_loginuid(current);
 	audit_info.sessionid = audit_get_sessionid(current);
 	audit_info.secid = 0;
-	err = xfrm_state_flush(&init_net, proto, &audit_info);
+	err = xfrm_state_flush(net, proto, &audit_info);
 	if (err)
 		return err;
 	c.data.proto = proto;
 	c.seq = hdr->sadb_msg_seq;
 	c.pid = hdr->sadb_msg_pid;
 	c.event = XFRM_MSG_FLUSHSA;
-	c.net = &init_net;
+	c.net = net;
 	km_state_notify(NULL, &c);
 
 	return 0;
@@ -1777,7 +1784,7 @@
 
 	if (pfk->dump.skb)
 		pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE,
-				&pfk->sk);
+				&pfk->sk, sock_net(&pfk->sk));
 	pfk->dump.skb = out_skb;
 
 	return 0;
@@ -1785,7 +1792,8 @@
 
 static int pfkey_dump_sa(struct pfkey_sock *pfk)
 {
-	return xfrm_state_walk(&init_net, &pfk->dump.u.state, dump_sa, (void *) pfk);
+	struct net *net = sock_net(&pfk->sk);
+	return xfrm_state_walk(net, &pfk->dump.u.state, dump_sa, (void *) pfk);
 }
 
 static void pfkey_dump_sa_done(struct pfkey_sock *pfk)
@@ -1826,7 +1834,7 @@
 			return -EINVAL;
 		pfk->promisc = satype;
 	}
-	pfkey_broadcast(skb_clone(skb, GFP_KERNEL), GFP_KERNEL, BROADCAST_ALL, NULL);
+	pfkey_broadcast(skb_clone(skb, GFP_KERNEL), GFP_KERNEL, BROADCAST_ALL, NULL, sock_net(sk));
 	return 0;
 }
 
@@ -1842,7 +1850,7 @@
 	return 0;
 }
 
-static u32 gen_reqid(void)
+static u32 gen_reqid(struct net *net)
 {
 	struct xfrm_policy_walk walk;
 	u32 start;
@@ -1855,7 +1863,7 @@
 		if (reqid == 0)
 			reqid = IPSEC_MANUAL_REQID_MAX+1;
 		xfrm_policy_walk_init(&walk, XFRM_POLICY_TYPE_MAIN);
-		rc = xfrm_policy_walk(&init_net, &walk, check_reqid, (void*)&reqid);
+		rc = xfrm_policy_walk(net, &walk, check_reqid, (void*)&reqid);
 		xfrm_policy_walk_done(&walk);
 		if (rc != -EEXIST)
 			return reqid;
@@ -1866,6 +1874,7 @@
 static int
 parse_ipsecrequest(struct xfrm_policy *xp, struct sadb_x_ipsecrequest *rq)
 {
+	struct net *net = xp_net(xp);
 	struct xfrm_tmpl *t = xp->xfrm_vec + xp->xfrm_nr;
 	int mode;
 
@@ -1885,7 +1894,7 @@
 		t->reqid = rq->sadb_x_ipsecrequest_reqid;
 		if (t->reqid > IPSEC_MANUAL_REQID_MAX)
 			t->reqid = 0;
-		if (!t->reqid && !(t->reqid = gen_reqid()))
+		if (!t->reqid && !(t->reqid = gen_reqid(net)))
 			return -ENOBUFS;
 	}
 
@@ -2156,7 +2165,7 @@
 	out_hdr->sadb_msg_errno = 0;
 	out_hdr->sadb_msg_seq = c->seq;
 	out_hdr->sadb_msg_pid = c->pid;
-	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ALL, NULL);
+	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ALL, NULL, xp_net(xp));
 out:
 	return 0;
 
@@ -2164,6 +2173,7 @@
 
 static int pfkey_spdadd(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	int err = 0;
 	struct sadb_lifetime *lifetime;
 	struct sadb_address *sa;
@@ -2183,7 +2193,7 @@
 	if (!pol->sadb_x_policy_dir || pol->sadb_x_policy_dir >= IPSEC_DIR_MAX)
 		return -EINVAL;
 
-	xp = xfrm_policy_alloc(&init_net, GFP_KERNEL);
+	xp = xfrm_policy_alloc(net, GFP_KERNEL);
 	if (xp == NULL)
 		return -ENOBUFS;
 
@@ -2284,6 +2294,7 @@
 
 static int pfkey_spddelete(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	int err;
 	struct sadb_address *sa;
 	struct sadb_x_policy *pol;
@@ -2333,7 +2344,7 @@
 			return err;
 	}
 
-	xp = xfrm_policy_bysel_ctx(&init_net, XFRM_POLICY_TYPE_MAIN,
+	xp = xfrm_policy_bysel_ctx(net, XFRM_POLICY_TYPE_MAIN,
 				   pol->sadb_x_policy_dir - 1, &sel, pol_ctx,
 				   1, &err);
 	security_xfrm_policy_free(pol_ctx);
@@ -2381,7 +2392,7 @@
 	out_hdr->sadb_msg_errno = 0;
 	out_hdr->sadb_msg_seq = hdr->sadb_msg_seq;
 	out_hdr->sadb_msg_pid = hdr->sadb_msg_pid;
-	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk);
+	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk, xp_net(xp));
 	err = 0;
 
 out:
@@ -2566,6 +2577,7 @@
 
 static int pfkey_spdget(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	unsigned int dir;
 	int err = 0, delete;
 	struct sadb_x_policy *pol;
@@ -2580,7 +2592,7 @@
 		return -EINVAL;
 
 	delete = (hdr->sadb_msg_type == SADB_X_SPDDELETE2);
-	xp = xfrm_policy_byid(&init_net, XFRM_POLICY_TYPE_MAIN, dir,
+	xp = xfrm_policy_byid(net, XFRM_POLICY_TYPE_MAIN, dir,
 			      pol->sadb_x_policy_id, delete, &err);
 	if (xp == NULL)
 		return -ENOENT;
@@ -2634,7 +2646,7 @@
 
 	if (pfk->dump.skb)
 		pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE,
-				&pfk->sk);
+				&pfk->sk, sock_net(&pfk->sk));
 	pfk->dump.skb = out_skb;
 
 	return 0;
@@ -2642,7 +2654,8 @@
 
 static int pfkey_dump_sp(struct pfkey_sock *pfk)
 {
-	return xfrm_policy_walk(&init_net, &pfk->dump.u.policy, dump_sp, (void *) pfk);
+	struct net *net = sock_net(&pfk->sk);
+	return xfrm_policy_walk(net, &pfk->dump.u.policy, dump_sp, (void *) pfk);
 }
 
 static void pfkey_dump_sp_done(struct pfkey_sock *pfk)
@@ -2681,13 +2694,14 @@
 	hdr->sadb_msg_version = PF_KEY_V2;
 	hdr->sadb_msg_errno = (uint8_t) 0;
 	hdr->sadb_msg_len = (sizeof(struct sadb_msg) / sizeof(uint64_t));
-	pfkey_broadcast(skb_out, GFP_ATOMIC, BROADCAST_ALL, NULL);
+	pfkey_broadcast(skb_out, GFP_ATOMIC, BROADCAST_ALL, NULL, c->net);
 	return 0;
 
 }
 
 static int pfkey_spdflush(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
+	struct net *net = sock_net(sk);
 	struct km_event c;
 	struct xfrm_audit audit_info;
 	int err;
@@ -2695,14 +2709,14 @@
 	audit_info.loginuid = audit_get_loginuid(current);
 	audit_info.sessionid = audit_get_sessionid(current);
 	audit_info.secid = 0;
-	err = xfrm_policy_flush(&init_net, XFRM_POLICY_TYPE_MAIN, &audit_info);
+	err = xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, &audit_info);
 	if (err)
 		return err;
 	c.data.type = XFRM_POLICY_TYPE_MAIN;
 	c.event = XFRM_MSG_FLUSHPOLICY;
 	c.pid = hdr->sadb_msg_pid;
 	c.seq = hdr->sadb_msg_seq;
-	c.net = &init_net;
+	c.net = net;
 	km_policy_notify(NULL, 0, &c);
 
 	return 0;
@@ -2742,7 +2756,7 @@
 	int err;
 
 	pfkey_broadcast(skb_clone(skb, GFP_KERNEL), GFP_KERNEL,
-			BROADCAST_PROMISC_ONLY, NULL);
+			BROADCAST_PROMISC_ONLY, NULL, sock_net(sk));
 
 	memset(ext_hdrs, 0, sizeof(ext_hdrs));
 	err = parse_exthdrs(skb, hdr, ext_hdrs);
@@ -2945,13 +2959,13 @@
 	out_hdr->sadb_msg_seq = 0;
 	out_hdr->sadb_msg_pid = 0;
 
-	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL);
+	pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, xs_net(x));
 	return 0;
 }
 
 static int pfkey_send_notify(struct xfrm_state *x, struct km_event *c)
 {
-	struct net *net = &init_net;
+	struct net *net = x ? xs_net(x) : c->net;
 	struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id);
 
 	if (atomic_read(&net_pfkey->socks_nr) == 0)
@@ -3116,12 +3130,13 @@
 		       xfrm_ctx->ctx_len);
 	}
 
-	return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL);
+	return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, xs_net(x));
 }
 
 static struct xfrm_policy *pfkey_compile_policy(struct sock *sk, int opt,
 						u8 *data, int len, int *dir)
 {
+	struct net *net = sock_net(sk);
 	struct xfrm_policy *xp;
 	struct sadb_x_policy *pol = (struct sadb_x_policy*)data;
 	struct sadb_x_sec_ctx *sec_ctx;
@@ -3154,7 +3169,7 @@
 	    (!pol->sadb_x_policy_dir || pol->sadb_x_policy_dir > IPSEC_DIR_OUTBOUND))
 		return NULL;
 
-	xp = xfrm_policy_alloc(&init_net, GFP_ATOMIC);
+	xp = xfrm_policy_alloc(net, GFP_ATOMIC);
 	if (xp == NULL) {
 		*dir = -ENOBUFS;
 		return NULL;
@@ -3313,7 +3328,7 @@
 	n_port->sadb_x_nat_t_port_port = sport;
 	n_port->sadb_x_nat_t_port_reserved = 0;
 
-	return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL);
+	return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, xs_net(x));
 }
 
 #ifdef CONFIG_NET_KEY_MIGRATE
@@ -3504,7 +3519,7 @@
 	}
 
 	/* broadcast migrate message to sockets */
-	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL);
+	pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, &init_net);
 
 	return 0;