netlink/genetlink: pass network namespace to bind/unbind

Netlink families can exist in multiple namespaces, and for the most
part multicast subscriptions are per network namespace. Thus it only
makes sense to have bind/unbind notifications per network namespace.

To achieve this, pass the network namespace of a given client socket
to the bind/unbind functions.

Also do this in generic netlink, and there also make sure that any
bind for multicast groups that only exist in init_net is rejected.
Accepting such a bind isn't really a problem, since a client in a
different namespace will never receive any notifications from such
a group, but it can confuse the family if not rejected. (It would
also be possible to accept it silently, without telling the family,
but then it would also have to be ignored on unbind, so that families
that take any kind of action on bind/unbind don't do unnecessary work
for invalid clients like that.)

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/netfilter/nfnetlink.c b/net/netfilter/nfnetlink.c
index 13c2e17..cde4a67 100644
--- a/net/netfilter/nfnetlink.c
+++ b/net/netfilter/nfnetlink.c
@@ -463,7 +463,7 @@
 }
 
 #ifdef CONFIG_MODULES
-static int nfnetlink_bind(int group)
+static int nfnetlink_bind(struct net *net, int group)
 {
 	const struct nfnetlink_subsystem *ss;
 	int type;
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index f29b63f..84ea76c 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -1141,8 +1141,8 @@
 	struct module *module = NULL;
 	struct mutex *cb_mutex;
 	struct netlink_sock *nlk;
-	int (*bind)(int group);
-	void (*unbind)(int group);
+	int (*bind)(struct net *net, int group);
+	void (*unbind)(struct net *net, int group);
 	int err = 0;
 
 	sock->state = SS_UNCONNECTED;
@@ -1251,7 +1251,7 @@
 
 		for (i = 0; i < nlk->ngroups; i++)
 			if (test_bit(i, nlk->groups))
-				nlk->netlink_unbind(i + 1);
+				nlk->netlink_unbind(sock_net(sk), i + 1);
 	}
 	kfree(nlk->groups);
 	nlk->groups = NULL;
@@ -1418,8 +1418,9 @@
 }
 
 static void netlink_undo_bind(int group, long unsigned int groups,
-			      struct netlink_sock *nlk)
+			      struct sock *sk)
 {
+	struct netlink_sock *nlk = nlk_sk(sk);
 	int undo;
 
 	if (!nlk->netlink_unbind)
@@ -1427,7 +1428,7 @@
 
 	for (undo = 0; undo < group; undo++)
 		if (test_bit(undo, &groups))
-			nlk->netlink_unbind(undo);
+			nlk->netlink_unbind(sock_net(sk), undo);
 }
 
 static int netlink_bind(struct socket *sock, struct sockaddr *addr,
@@ -1465,10 +1466,10 @@
 		for (group = 0; group < nlk->ngroups; group++) {
 			if (!test_bit(group, &groups))
 				continue;
-			err = nlk->netlink_bind(group);
+			err = nlk->netlink_bind(net, group);
 			if (!err)
 				continue;
-			netlink_undo_bind(group, groups, nlk);
+			netlink_undo_bind(group, groups, sk);
 			return err;
 		}
 	}
@@ -1478,7 +1479,7 @@
 			netlink_insert(sk, net, nladdr->nl_pid) :
 			netlink_autobind(sock);
 		if (err) {
-			netlink_undo_bind(nlk->ngroups, groups, nlk);
+			netlink_undo_bind(nlk->ngroups, groups, sk);
 			return err;
 		}
 	}
@@ -2129,7 +2130,7 @@
 		if (!val || val - 1 >= nlk->ngroups)
 			return -EINVAL;
 		if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) {
-			err = nlk->netlink_bind(val);
+			err = nlk->netlink_bind(sock_net(sk), val);
 			if (err)
 				return err;
 		}
@@ -2138,7 +2139,7 @@
 					 optname == NETLINK_ADD_MEMBERSHIP);
 		netlink_table_ungrab();
 		if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
-			nlk->netlink_unbind(val);
+			nlk->netlink_unbind(sock_net(sk), val);
 
 		err = 0;
 		break;
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index b20a173..f123a88 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -39,8 +39,8 @@
 	struct mutex		*cb_mutex;
 	struct mutex		cb_def_mutex;
 	void			(*netlink_rcv)(struct sk_buff *skb);
-	int			(*netlink_bind)(int group);
-	void			(*netlink_unbind)(int group);
+	int			(*netlink_bind)(struct net *net, int group);
+	void			(*netlink_unbind)(struct net *net, int group);
 	struct module		*module;
 #ifdef CONFIG_NETLINK_MMAP
 	struct mutex		pg_vec_lock;
@@ -65,8 +65,8 @@
 	unsigned int		groups;
 	struct mutex		*cb_mutex;
 	struct module		*module;
-	int			(*bind)(int group);
-	void			(*unbind)(int group);
+	int			(*bind)(struct net *net, int group);
+	void			(*unbind)(struct net *net, int group);
 	bool			(*compare)(struct net *net, struct sock *sock);
 	int			registered;
 };
diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c
index 05bf40b..91566ed 100644
--- a/net/netlink/genetlink.c
+++ b/net/netlink/genetlink.c
@@ -983,7 +983,7 @@
 	{ .name = "notify", },
 };
 
-static int genl_bind(int group)
+static int genl_bind(struct net *net, int group)
 {
 	int i, err;
 	bool found = false;
@@ -997,8 +997,10 @@
 			    group < f->mcgrp_offset + f->n_mcgrps) {
 				int fam_grp = group - f->mcgrp_offset;
 
-				if (f->mcast_bind)
-					err = f->mcast_bind(fam_grp);
+				if (!f->netnsok && net != &init_net)
+					err = -ENOENT;
+				else if (f->mcast_bind)
+					err = f->mcast_bind(net, fam_grp);
 				else
 					err = 0;
 				found = true;
@@ -1014,7 +1016,7 @@
 	return err;
 }
 
-static void genl_unbind(int group)
+static void genl_unbind(struct net *net, int group)
 {
 	int i;
 	bool found = false;
@@ -1029,7 +1031,7 @@
 				int fam_grp = group - f->mcgrp_offset;
 
 				if (f->mcast_unbind)
-					f->mcast_unbind(fam_grp);
+					f->mcast_unbind(net, fam_grp);
 				found = true;
 				break;
 			}