netns: add rtnl cmd to add and get peer netns ids

With this patch, a user can define an id for a peer netns by providing an FD
or a PID. These ids are local to the netns where they are added (i.e. valid
only in this netns).

The main function (i.e. the one exported to other modules), peernet2id(),
returns the id of a peer netns. If no id has been assigned by the user, this
function allocates one.

These ids will be used in netlink messages to point to a peer netns, for
example in the case of an x-netns interface.

Signed-off-by: Nicolas Dichtel <nicolas.dichtel@6wind.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c
index ce780c7..9d1a4ca 100644
--- a/net/core/net_namespace.c
+++ b/net/core/net_namespace.c
@@ -15,6 +15,10 @@
 #include <linux/file.h>
 #include <linux/export.h>
 #include <linux/user_namespace.h>
+#include <linux/net_namespace.h>
+#include <linux/rtnetlink.h>
+#include <net/sock.h>
+#include <net/netlink.h>
 #include <net/net_namespace.h>
 #include <net/netns/generic.h>
 
@@ -144,6 +148,94 @@
 	}
 }
 
+/* Allocate an id for @peer in the id space of @net.
+ *
+ * A non-negative @reqid asks for exactly that id; a negative one takes the
+ * lowest free id. Returns the id, or a negative errno from idr_alloc()
+ * (-ENOSPC if @reqid is already in use). Caller must hold RTNL.
+ */
+static int alloc_netid(struct net *net, struct net *peer, int reqid)
+{
+	int min = 0, max = 0;
+
+	ASSERT_RTNL();
+
+	if (reqid >= 0) {
+		min = reqid;
+		max = reqid + 1;
+	}
+
+	/* An end of 0 means "no upper bound" to idr_alloc(). */
+	return idr_alloc(&net->netns_ids, peer, min, max, GFP_KERNEL);
+}
+
+/* This function is used by idr_for_each(). If net is equal to peer, the
+ * function returns the id so that idr_for_each() stops. Because we cannot
+ * return the id 0 (idr_for_each() will not stop), we return the magic value
+ * NET_ID_ZERO (-1) for it.
+ */
+#define NET_ID_ZERO (-1)
+static int net_eq_idr(int id, void *net, void *peer)
+{
+	if (net_eq(net, peer))
+		return id ? : NET_ID_ZERO;
+	return 0;
+}
+
+/* Return the id @net uses for @peer. If none is assigned, allocate one when
+ * @alloc is true, otherwise return -ENOENT. Caller must hold RTNL.
+ */
+static int __peernet2id(struct net *net, struct net *peer, bool alloc)
+{
+	int id;
+
+	ASSERT_RTNL();
+
+	id = idr_for_each(&net->netns_ids, net_eq_idr, peer);
+	/* Magic value for id 0. */
+	if (id == NET_ID_ZERO)
+		return 0;
+	if (id > 0)
+		return id;
+
+	if (alloc)
+		return alloc_netid(net, peer, -1);
+
+	return -ENOENT;
+}
+
+/* This function returns the id of a peer netns. If no id is assigned, one will
+ * be allocated and returned. NETNSA_NSID_NOT_ASSIGNED is returned if the
+ * allocation fails.
+ */
+int peernet2id(struct net *net, struct net *peer)
+{
+	int id = __peernet2id(net, peer, true);
+
+	return id >= 0 ? id : NETNSA_NSID_NOT_ASSIGNED;
+}
+
+/* Return the netns known to @net as @id with a reference held (release it
+ * with put_net()), or NULL if there is none. A netns whose last reference is
+ * gone stays in the idr until cleanup_net() removes it, so maybe_get_net()
+ * is used instead of get_net() to avoid resurrecting a dying netns.
+ */
+struct net *get_net_ns_by_id(struct net *net, int id)
+{
+	struct net *peer;
+
+	if (id < 0)
+		return NULL;
+
+	rcu_read_lock();
+	peer = idr_find(&net->netns_ids, id);
+	if (peer)
+		peer = maybe_get_net(peer);
+	rcu_read_unlock();
+
+	return peer;
+}
+
 /*
  * setup_net runs the initializers for the network namespace object.
  */
@@ -158,6 +233,7 @@
 	atomic_set(&net->passive, 1);
 	net->dev_base_seq = 1;
 	net->user_ns = user_ns;
+	idr_init(&net->netns_ids);	/* ids this netns assigns to peer netns */
 
 #ifdef NETNS_REFCNT_DEBUG
 	atomic_set(&net->use_count, 0);
@@ -288,6 +364,17 @@
 	list_for_each_entry(net, &net_kill_list, cleanup_list) {
 		list_del_rcu(&net->list);
 		list_add_tail(&net->exit_list, &net_exit_list);
+		/* Drop the ids other netns assigned to the dying one (it is
+		 * already unlinked, so for_each_net() skips it), then free the
+		 * dying netns' own id table.
+		 */
+		for_each_net(tmp) {
+			int id = __peernet2id(tmp, net, false);
+
+			if (id >= 0)
+				idr_remove(&tmp->netns_ids, id);
+		}
+		idr_destroy(&net->netns_ids);
 	}
 	rtnl_unlock();
 
@@ -402,6 +486,148 @@
 	.exit = net_ns_net_exit,
 };
 
+static const struct nla_policy rtnl_net_policy[NETNSA_MAX + 1] = {
+	[NETNSA_NONE]		= { .type = NLA_UNSPEC },
+	[NETNSA_NSID]		= { .type = NLA_S32 },
+	[NETNSA_PID]		= { .type = NLA_U32 },
+	[NETNSA_FD]		= { .type = NLA_U32 },
+};
+
+/* Resolve the peer netns designated by NETNSA_PID or NETNSA_FD (NETNSA_PID
+ * takes precedence). On success the returned netns holds a reference that
+ * the caller must drop with put_net(); otherwise an ERR_PTR() is returned:
+ * the lookup's error, or -EINVAL if neither attribute is present.
+ */
+static struct net *rtnl_net_get_peer(struct nlattr *tb[])
+{
+	if (tb[NETNSA_PID])
+		return get_net_ns_by_pid(nla_get_u32(tb[NETNSA_PID]));
+	if (tb[NETNSA_FD])
+		return get_net_ns_by_fd(nla_get_u32(tb[NETNSA_FD]));
+	return ERR_PTR(-EINVAL);
+}
+
+/* RTM_NEWNSID handler: give the peer netns designated by NETNSA_PID or
+ * NETNSA_FD the id carried in NETNSA_NSID. Fails with -EEXIST if this netns
+ * already has an id for that peer. A negative NETNSA_NSID lets alloc_netid()
+ * pick the lowest free id. Returns 0 or a negative errno.
+ */
+static int rtnl_net_newid(struct sk_buff *skb, struct nlmsghdr *nlh)
+{
+	struct net *net = sock_net(skb->sk);
+	struct nlattr *tb[NETNSA_MAX + 1];
+	struct net *peer;
+	int nsid, err;
+
+	err = nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX,
+			  rtnl_net_policy);
+	if (err < 0)
+		return err;
+	if (!tb[NETNSA_NSID])
+		return -EINVAL;
+	nsid = nla_get_s32(tb[NETNSA_NSID]);
+
+	peer = rtnl_net_get_peer(tb);
+	if (IS_ERR(peer))
+		return PTR_ERR(peer);
+
+	if (__peernet2id(net, peer, false) >= 0) {
+		err = -EEXIST;
+		goto out;
+	}
+
+	/* alloc_netid() returns the new id on success; report plain 0. */
+	err = alloc_netid(net, peer, nsid);
+	if (err > 0)
+		err = 0;
+out:
+	put_net(peer);
+	return err;
+}
+
+/* Payload size of an rtnl_net_fill() message: rtgenmsg + NETNSA_NSID. */
+static int rtnl_net_get_size(void)
+{
+	return NLMSG_ALIGN(sizeof(struct rtgenmsg))
+	       + nla_total_size(sizeof(s32)) /* NETNSA_NSID */
+	       ;
+}
+
+/* Build one @cmd message into @skb carrying the id @net uses for @peer, or
+ * NETNSA_NSID_NOT_ASSIGNED if it has none (no id is allocated here).
+ * Returns 0, or -EMSGSIZE if @skb is too small. Caller must hold RTNL.
+ */
+static int rtnl_net_fill(struct sk_buff *skb, u32 portid, u32 seq, int flags,
+			 int cmd, struct net *net, struct net *peer)
+{
+	struct nlmsghdr *nlh;
+	struct rtgenmsg *rth;
+	int id;
+
+	ASSERT_RTNL();
+
+	nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rth), flags);
+	if (!nlh)
+		return -EMSGSIZE;
+
+	rth = nlmsg_data(nlh);
+	rth->rtgen_family = AF_UNSPEC;
+
+	id = __peernet2id(net, peer, false);
+	if (id < 0)
+		id = NETNSA_NSID_NOT_ASSIGNED;
+	if (nla_put_s32(skb, NETNSA_NSID, id))
+		goto nla_put_failure;
+
+	nlmsg_end(skb, nlh);
+	return 0;
+
+nla_put_failure:
+	nlmsg_cancel(skb, nlh);
+	return -EMSGSIZE;
+}
+
+/* RTM_GETNSID handler: unicast to the requester the id this netns uses for
+ * the peer designated by NETNSA_PID or NETNSA_FD. No id is allocated.
+ */
+static int rtnl_net_getid(struct sk_buff *skb, struct nlmsghdr *nlh)
+{
+	struct net *net = sock_net(skb->sk);
+	struct nlattr *tb[NETNSA_MAX + 1];
+	struct sk_buff *msg;
+	struct net *peer;
+	int err;
+
+	err = nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX,
+			  rtnl_net_policy);
+	if (err < 0)
+		return err;
+
+	peer = rtnl_net_get_peer(tb);
+	if (IS_ERR(peer))
+		return PTR_ERR(peer);
+
+	msg = nlmsg_new(rtnl_net_get_size(), GFP_KERNEL);
+	if (!msg) {
+		err = -ENOMEM;
+		goto out;
+	}
+
+	err = rtnl_net_fill(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
+			    RTM_GETNSID, net, peer);
+	if (err < 0)
+		goto err_out;
+
+	err = rtnl_unicast(msg, net, NETLINK_CB(skb).portid);
+	goto out;
+
+err_out:
+	nlmsg_free(msg);
+out:
+	put_net(peer);
+	return err;
+}
+
 static int __init net_ns_init(void)
 {
 	struct net_generic *ng;
@@ -435,6 +643,10 @@
 
 	register_pernet_subsys(&net_ns_ops);
 
+	/* Netns id commands: request (doit) handlers only, no dump handler. */
+	rtnl_register(PF_UNSPEC, RTM_NEWNSID, rtnl_net_newid, NULL, NULL);
+	rtnl_register(PF_UNSPEC, RTM_GETNSID, rtnl_net_getid, NULL, NULL);
+
 	return 0;
 }