genetlink: make multicast groups const, prevent abuse

Register generic netlink multicast groups as an array with
the family and give them contiguous group IDs. Then instead
of passing the global group ID to the various functions that
send messages, pass the ID relative to the family - for most
families that's just 0 because they only have one group.

This avoids the list_head and ID in each group, adding a new
field for the mcast group ID offset to the family.

At the same time, this allows us to prevent abusing groups
again like the quota and dropmon code did, since we can now
check that a family only uses a group it owns.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c
index 36e3a86..7dbc4f7 100644
--- a/net/netlink/genetlink.c
+++ b/net/netlink/genetlink.c
@@ -69,16 +69,20 @@
  * abuses the API and thinks it can statically use group 1.
  * That group will typically conflict with other groups that
  * any proper users use.
+ * Bit 16 is marked as used since it's used for generic netlink
+ * and the code no longer marks pre-reserved IDs as used.
  * Bit 17 is marked as already used since the VFS quota code
  * also abused this API and relied on family == group ID, we
  * cater to that by giving it a static family and group ID.
  */
-static unsigned long mc_group_start = 0x3 | BIT(GENL_ID_VFS_DQUOT);
+static unsigned long mc_group_start = 0x3 | BIT(GENL_ID_CTRL) |
+				      BIT(GENL_ID_VFS_DQUOT);
 static unsigned long *mc_groups = &mc_group_start;
 static unsigned long mc_groups_longs = 1;
 
 static int genl_ctrl_event(int event, struct genl_family *family,
-			   struct genl_multicast_group *grp);
+			   const struct genl_multicast_group *grp,
+			   int grp_id);
 
 static inline unsigned int genl_family_hash(unsigned int id)
 {
@@ -144,66 +148,110 @@
 	return 0;
 }
 
-static struct genl_multicast_group notify_grp;
-
-/**
- * genl_register_mc_group - register a multicast group
- *
- * Registers the specified multicast group and notifies userspace
- * about the new group.
- *
- * Returns 0 on success or a negative error code.
- *
- * @family: The generic netlink family the group shall be registered for.
- * @grp: The group to register, must have a name.
- */
-int genl_register_mc_group(struct genl_family *family,
-			   struct genl_multicast_group *grp)
+/* Reserve n_groups consecutive free mc group IDs, growing the bitmap
+ * as needed. Returns 0 with the first ID in *first_id, or -ENOMEM.
+ */
+static int genl_allocate_reserve_groups(int n_groups, int *first_id)
 {
-	int id;
 	unsigned long *new_groups;
-	int err = 0;
+	int start = 0;
+	int i;
+	int id;
+	bool fits;
 
-	BUG_ON(grp->name[0] == '\0');
-	BUG_ON(memchr(grp->name, '\0', GENL_NAMSIZ) == NULL);
+	do {
+		id = find_next_zero_bit(mc_groups,
+					mc_groups_longs * BITS_PER_LONG, start);
 
-	genl_lock_all();
+		fits = true;
+		for (i = id;
+		     i < min_t(int, id + n_groups,
+			       mc_groups_longs * BITS_PER_LONG);
+		     i++) {
+			if (test_bit(i, mc_groups)) {
+				start = i;
+				fits = false;
+				break;
+			}
+		}
+
+		/* Grow whenever [id, id + n_groups) extends past the bitmap,
+		 * not only when id does, or set_bit() below would overrun it.
+		 */
+		if (id + n_groups > mc_groups_longs * BITS_PER_LONG) {
+			unsigned long new_longs = mc_groups_longs +
+						  BITS_TO_LONGS(n_groups);
+			size_t nlen = new_longs * sizeof(unsigned long);
+
+			if (mc_groups == &mc_group_start) {
+				new_groups = kzalloc(nlen, GFP_KERNEL);
+				if (!new_groups)
+					return -ENOMEM;
+				mc_groups = new_groups;
+				*mc_groups = mc_group_start;
+			} else {
+				new_groups = krealloc(mc_groups, nlen,
+						      GFP_KERNEL);
+				if (!new_groups)
+					return -ENOMEM;
+				mc_groups = new_groups;
+				for (i = 0; i < BITS_TO_LONGS(n_groups); i++)
+					mc_groups[mc_groups_longs + i] = 0;
+			}
+			mc_groups_longs = new_longs;
+		}
+	} while (!fits);
+
+	for (i = id; i < id + n_groups; i++)
+		set_bit(i, mc_groups);
+	*first_id = id;
+	return 0;
+}
+
+static struct genl_family genl_ctrl;
+
+static int genl_validate_assign_mc_groups(struct genl_family *family)
+{
+	int first_id;
+	int n_groups = family->n_mcgrps;
+	int err, i;
+	bool groups_allocated = false;
+
+	if (!n_groups)
+		return 0;
+
+	for (i = 0; i < n_groups; i++) {
+		const struct genl_multicast_group *grp = &family->mcgrps[i];
+
+		if (WARN_ON(grp->name[0] == '\0'))
+			return -EINVAL;
+		if (WARN_ON(memchr(grp->name, '\0', GENL_NAMSIZ) == NULL))
+			return -EINVAL;
+	}
 
 	/* special-case our own group and hacks */
-	if (grp == &notify_grp)
-		id = GENL_ID_CTRL;
-	else if (strcmp(family->name, "NET_DM") == 0)
-		id = 1;
-	else if (strcmp(family->name, "VFS_DQUOT") == 0)
-		id = GENL_ID_VFS_DQUOT;
-	else
-		id = find_first_zero_bit(mc_groups,
-					 mc_groups_longs * BITS_PER_LONG);
-
-
-	if (id >= mc_groups_longs * BITS_PER_LONG) {
-		size_t nlen = (mc_groups_longs + 1) * sizeof(unsigned long);
-
-		if (mc_groups == &mc_group_start) {
-			new_groups = kzalloc(nlen, GFP_KERNEL);
-			if (!new_groups) {
-				err = -ENOMEM;
-				goto out;
-			}
-			mc_groups = new_groups;
-			*mc_groups = mc_group_start;
-		} else {
-			new_groups = krealloc(mc_groups, nlen, GFP_KERNEL);
-			if (!new_groups) {
-				err = -ENOMEM;
-				goto out;
-			}
-			mc_groups = new_groups;
-			mc_groups[mc_groups_longs] = 0;
-		}
-		mc_groups_longs++;
+	if (family == &genl_ctrl) {
+		first_id = GENL_ID_CTRL;
+		BUG_ON(n_groups != 1);
+	} else if (strcmp(family->name, "NET_DM") == 0) {
+		first_id = 1;
+		BUG_ON(n_groups != 1);
+	} else if (strcmp(family->name, "VFS_DQUOT") == 0) {
+		first_id = GENL_ID_VFS_DQUOT;
+		BUG_ON(n_groups != 1);
+	} else {
+		groups_allocated = true;
+		err = genl_allocate_reserve_groups(n_groups, &first_id);
+		if (err)
+			return err;
 	}
 
+	family->mcgrp_offset = first_id;
+
+	/* if still initializing, can't and don't need to realloc bitmaps */
+	if (!init_net.genl_sock)
+		return 0;
+
 	if (family->netnsok) {
 		struct net *net;
 
@@ -219,9 +267,7 @@
 				 * number of _possible_ groups has been
 				 * increased on some sockets which is ok.
 				 */
-				rcu_read_unlock();
-				netlink_table_ungrab();
-				goto out;
+				break;
 			}
 		}
 		rcu_read_unlock();
@@ -229,46 +275,39 @@
 	} else {
 		err = netlink_change_ngroups(init_net.genl_sock,
 					     mc_groups_longs * BITS_PER_LONG);
-		if (err)
-			goto out;
 	}
 
-	grp->id = id;
-	set_bit(id, mc_groups);
-	list_add_tail(&grp->list, &family->mcast_groups);
+	if (groups_allocated && err) {
+		for (i = 0; i < family->n_mcgrps; i++)
+			clear_bit(family->mcgrp_offset + i, mc_groups);
+	}
 
-	genl_ctrl_event(CTRL_CMD_NEWMCAST_GRP, family, grp);
- out:
-	genl_unlock_all();
 	return err;
 }
-EXPORT_SYMBOL(genl_register_mc_group);
-
-static void __genl_unregister_mc_group(struct genl_family *family,
-				       struct genl_multicast_group *grp)
-{
-	struct net *net;
-
-	netlink_table_grab();
-	rcu_read_lock();
-	for_each_net_rcu(net)
-		__netlink_clear_multicast_users(net->genl_sock, grp->id);
-	rcu_read_unlock();
-	netlink_table_ungrab();
-
-	if (grp->id != 1)
-		clear_bit(grp->id, mc_groups);
-	list_del(&grp->list);
-	genl_ctrl_event(CTRL_CMD_DELMCAST_GRP, family, grp);
-	grp->id = 0;
-}
 
 static void genl_unregister_mc_groups(struct genl_family *family)
 {
-	struct genl_multicast_group *grp, *tmp;
+	struct net *net;
+	int i;
 
-	list_for_each_entry_safe(grp, tmp, &family->mcast_groups, list)
-		__genl_unregister_mc_group(family, grp);
+	netlink_table_grab();
+	rcu_read_lock();
+	for_each_net_rcu(net) {
+		for (i = 0; i < family->n_mcgrps; i++)
+			__netlink_clear_multicast_users(
+				net->genl_sock, family->mcgrp_offset + i);
+	}
+	rcu_read_unlock();
+	netlink_table_ungrab();
+
+	for (i = 0; i < family->n_mcgrps; i++) {
+		int grp_id = family->mcgrp_offset + i;
+
+		if (grp_id != 1)
+			clear_bit(grp_id, mc_groups);
+		genl_ctrl_event(CTRL_CMD_DELMCAST_GRP, family,
+				&family->mcgrps[i], grp_id);
+	}
 }
 
 static int genl_validate_ops(struct genl_family *family)
@@ -314,7 +353,7 @@
  */
 int __genl_register_family(struct genl_family *family)
 {
-	int err = -EINVAL;
+	int err = -EINVAL, i;
 
 	if (family->id && family->id < GENL_MIN_ID)
 		goto errout;
@@ -326,8 +365,6 @@
 	if (err)
 		return err;
 
-	INIT_LIST_HEAD(&family->mcast_groups);
-
 	genl_lock_all();
 
 	if (genl_family_find_byname(family->name)) {
@@ -359,10 +396,18 @@
 	} else
 		family->attrbuf = NULL;
 
+	err = genl_validate_assign_mc_groups(family);
+	if (err)
+		goto errout_locked;
+
 	list_add_tail(&family->family_list, genl_family_chain(family->id));
 	genl_unlock_all();
 
-	genl_ctrl_event(CTRL_CMD_NEWFAMILY, family, NULL);
+	/* send all events */
+	genl_ctrl_event(CTRL_CMD_NEWFAMILY, family, NULL, 0);
+	for (i = 0; i < family->n_mcgrps; i++)
+		genl_ctrl_event(CTRL_CMD_NEWMCAST_GRP, family,
+				&family->mcgrps[i], family->mcgrp_offset + i);
 
 	return 0;
 
@@ -398,7 +443,7 @@
 		genl_unlock_all();
 
 		kfree(family->attrbuf);
-		genl_ctrl_event(CTRL_CMD_DELFAMILY, family, NULL);
+		genl_ctrl_event(CTRL_CMD_DELFAMILY, family, NULL, 0);
 		return 0;
 	}
 
@@ -658,23 +703,26 @@
 		nla_nest_end(skb, nla_ops);
 	}
 
-	if (!list_empty(&family->mcast_groups)) {
-		struct genl_multicast_group *grp;
+	if (family->n_mcgrps) {
 		struct nlattr *nla_grps;
-		int idx = 1;
+		int i;
 
 		nla_grps = nla_nest_start(skb, CTRL_ATTR_MCAST_GROUPS);
 		if (nla_grps == NULL)
 			goto nla_put_failure;
 
-		list_for_each_entry(grp, &family->mcast_groups, list) {
+		for (i = 0; i < family->n_mcgrps; i++) {
 			struct nlattr *nest;
+			const struct genl_multicast_group *grp;
 
-			nest = nla_nest_start(skb, idx++);
+			grp = &family->mcgrps[i];
+
+			nest = nla_nest_start(skb, i + 1);
 			if (nest == NULL)
 				goto nla_put_failure;
 
-			if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID, grp->id) ||
+			if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID,
+					family->mcgrp_offset + i) ||
 			    nla_put_string(skb, CTRL_ATTR_MCAST_GRP_NAME,
 					   grp->name))
 				goto nla_put_failure;
@@ -692,9 +740,9 @@
 }
 
 static int ctrl_fill_mcgrp_info(struct genl_family *family,
-				struct genl_multicast_group *grp, u32 portid,
-				u32 seq, u32 flags, struct sk_buff *skb,
-				u8 cmd)
+				const struct genl_multicast_group *grp,
+				int grp_id, u32 portid, u32 seq, u32 flags,
+				struct sk_buff *skb, u8 cmd)
 {
 	void *hdr;
 	struct nlattr *nla_grps;
@@ -716,7 +764,7 @@
 	if (nest == NULL)
 		goto nla_put_failure;
 
-	if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID, grp->id) ||
+	if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID, grp_id) ||
 	    nla_put_string(skb, CTRL_ATTR_MCAST_GRP_NAME,
 			   grp->name))
 		goto nla_put_failure;
@@ -782,9 +830,10 @@
 	return skb;
 }
 
-static struct sk_buff *ctrl_build_mcgrp_msg(struct genl_family *family,
-					    struct genl_multicast_group *grp,
-					    u32 portid, int seq, u8 cmd)
+static struct sk_buff *
+ctrl_build_mcgrp_msg(struct genl_family *family,
+		     const struct genl_multicast_group *grp,
+		     int grp_id, u32 portid, int seq, u8 cmd)
 {
 	struct sk_buff *skb;
 	int err;
@@ -793,7 +842,8 @@
 	if (skb == NULL)
 		return ERR_PTR(-ENOBUFS);
 
-	err = ctrl_fill_mcgrp_info(family, grp, portid, seq, 0, skb, cmd);
+	err = ctrl_fill_mcgrp_info(family, grp, grp_id, portid,
+				   seq, 0, skb, cmd);
 	if (err < 0) {
 		nlmsg_free(skb);
 		return ERR_PTR(err);
@@ -856,7 +906,8 @@
 }
 
 static int genl_ctrl_event(int event, struct genl_family *family,
-			   struct genl_multicast_group *grp)
+			   const struct genl_multicast_group *grp,
+			   int grp_id)
 {
 	struct sk_buff *msg;
 
@@ -873,7 +924,7 @@
 	case CTRL_CMD_NEWMCAST_GRP:
 	case CTRL_CMD_DELMCAST_GRP:
 		BUG_ON(!grp);
-		msg = ctrl_build_mcgrp_msg(family, grp, 0, 0, event);
+		msg = ctrl_build_mcgrp_msg(family, grp, grp_id, 0, 0, event);
 		break;
 	default:
 		return -EINVAL;
@@ -884,11 +935,11 @@
 
 	if (!family->netnsok) {
 		genlmsg_multicast_netns(&genl_ctrl, &init_net, msg, 0,
-					GENL_ID_CTRL, GFP_KERNEL);
+					0, GFP_KERNEL);
 	} else {
 		rcu_read_lock();
 		genlmsg_multicast_allns(&genl_ctrl, msg, 0,
-					GENL_ID_CTRL, GFP_ATOMIC);
+					0, GFP_ATOMIC);
 		rcu_read_unlock();
 	}
 
@@ -904,8 +955,8 @@
 	},
 };
 
-static struct genl_multicast_group notify_grp = {
-	.name		= "notify",
+static struct genl_multicast_group genl_ctrl_groups[] = {
+	{ .name = "notify", },
 };
 
 static int __net_init genl_pernet_init(struct net *net)
@@ -945,7 +996,8 @@
 	for (i = 0; i < GENL_FAM_TAB_SIZE; i++)
 		INIT_LIST_HEAD(&family_ht[i]);
 
-	err = genl_register_family_with_ops(&genl_ctrl, genl_ctrl_ops);
+	err = genl_register_family_with_ops_groups(&genl_ctrl, genl_ctrl_ops,
+						   genl_ctrl_groups);
 	if (err < 0)
 		goto problem;
 
@@ -953,10 +1005,6 @@
 	if (err)
 		goto problem;
 
-	err = genl_register_mc_group(&genl_ctrl, &notify_grp);
-	if (err < 0)
-		goto problem;
-
 	return 0;
 
 problem:
@@ -997,6 +1045,9 @@
 int genlmsg_multicast_allns(struct genl_family *family, struct sk_buff *skb,
 			    u32 portid, unsigned int group, gfp_t flags)
 {
+	if (group >= family->n_mcgrps)
+		return -EINVAL;
+	group = family->mcgrp_offset + group;
 	return genlmsg_mcast(skb, portid, group, flags);
 }
 EXPORT_SYMBOL(genlmsg_multicast_allns);
@@ -1011,6 +1062,9 @@
 	if (nlh)
 		report = nlmsg_report(nlh);
 
+	if (group >= family->n_mcgrps)
+		return;
+	group = family->mcgrp_offset + group;
 	nlmsg_notify(sk, skb, portid, group, report, flags);
 }
 EXPORT_SYMBOL(genl_notify);