drbd: Also need to check for DRBD_GENLA_F_MANDATORY flags before nla_find_nested()

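This is done by introducing drbd_nla_find_nested(), which runs the
same drbd_nla_check_mandatory() pass as drbd_nla_parse_nested() --
stripping the flag and rejecting unknown mandatory attributes with
-EOPNOTSUPP -- before calling nla_find_nested().

While at it, rename conn_name to resource_name throughout
(T_ctx_conn_name -> T_ctx_resource_name, drbd_check_conn_name() ->
drbd_check_resource_name()) and update the related messages.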

Signed-off-by: Philipp Reisner <philipp.reisner@linbit.com>
Signed-off-by: Lars Ellenberg <lars.ellenberg@linbit.com>
diff --git a/drivers/block/drbd/drbd_nl.c b/drivers/block/drbd/drbd_nl.c
index 5b4090f..24187f1 100644
--- a/drivers/block/drbd/drbd_nl.c
+++ b/drivers/block/drbd/drbd_nl.c
@@ -92,7 +92,7 @@
 #define VOLUME_UNSPECIFIED		(-1U)
 	/* pointer into the request skb,
 	 * limited lifetime! */
-	char *conn_name;
+	char *resource_name;
 
 	/* reply buffer */
 	struct sk_buff *reply_skb;
@@ -191,15 +191,15 @@
 		/* and assign stuff to the global adm_ctx */
 		nla = nested_attr_tb[__nla_type(T_ctx_volume)];
 		adm_ctx.volume = nla ? nla_get_u32(nla) : VOLUME_UNSPECIFIED;
-		nla = nested_attr_tb[__nla_type(T_ctx_conn_name)];
+		nla = nested_attr_tb[__nla_type(T_ctx_resource_name)];
 		if (nla)
-			adm_ctx.conn_name = nla_data(nla);
+			adm_ctx.resource_name = nla_data(nla);
 	} else
 		adm_ctx.volume = VOLUME_UNSPECIFIED;
 
 	adm_ctx.minor = d_in->minor;
 	adm_ctx.mdev = minor_to_mdev(d_in->minor);
-	adm_ctx.tconn = conn_get_by_name(adm_ctx.conn_name);
+	adm_ctx.tconn = conn_get_by_name(adm_ctx.resource_name);
 
 	if (!adm_ctx.mdev && (flags & DRBD_ADM_NEED_MINOR)) {
 		drbd_msg_put_info("unknown minor");
@@ -214,7 +214,8 @@
 	if (adm_ctx.mdev && adm_ctx.tconn &&
 	    adm_ctx.mdev->tconn != adm_ctx.tconn) {
 		pr_warning("request: minor=%u, conn=%s; but that minor belongs to connection %s\n",
-				adm_ctx.minor, adm_ctx.conn_name, adm_ctx.mdev->tconn->name);
+				adm_ctx.minor, adm_ctx.resource_name,
+				adm_ctx.mdev->tconn->name);
 		drbd_msg_put_info("minor exists in different connection");
 		return ERR_INVALID_REQUEST;
 	}
@@ -239,7 +240,7 @@
 static int drbd_adm_finish(struct genl_info *info, int retcode)
 {
 	struct nlattr *nla;
-	const char *conn_name = NULL;
+	const char *resource_name = NULL;
 
 	if (adm_ctx.tconn) {
 		kref_put(&adm_ctx.tconn->kref, &conn_destroy);
@@ -253,9 +254,10 @@
 
 	nla = info->attrs[DRBD_NLA_CFG_CONTEXT];
 	if (nla) {
-		nla = nla_find_nested(nla, __nla_type(T_ctx_conn_name));
-		if (nla)
-			conn_name = nla_data(nla);
+		int maxtype = ARRAY_SIZE(drbd_cfg_context_nl_policy) - 1;
+		nla = drbd_nla_find_nested(maxtype, nla, __nla_type(T_ctx_resource_name));
+		if (nla && !IS_ERR(nla))
+			resource_name = nla_data(nla);
 	}
 
 	drbd_adm_send_reply(adm_ctx.reply_skb, info);
@@ -2526,7 +2528,7 @@
 	return drbd_adm_simple_request_state(skb, info, NS(disk, D_OUTDATED));
 }
 
-int nla_put_drbd_cfg_context(struct sk_buff *skb, const char *conn_name, unsigned vnr)
+int nla_put_drbd_cfg_context(struct sk_buff *skb, const char *resource_name, unsigned vnr)
 {
 	struct nlattr *nla;
 	nla = nla_nest_start(skb, DRBD_NLA_CFG_CONTEXT);
@@ -2534,7 +2536,7 @@
 		goto nla_put_failure;
 	if (vnr != VOLUME_UNSPECIFIED)
 		NLA_PUT_U32(skb, T_ctx_volume, vnr);
-	NLA_PUT_STRING(skb, T_ctx_conn_name, conn_name);
+	NLA_PUT_STRING(skb, T_ctx_resource_name, resource_name);
 	nla_nest_end(skb, nla);
 	return 0;
 
@@ -2778,8 +2780,9 @@
 {
 	const unsigned hdrlen = GENL_HDRLEN + GENL_MAGIC_FAMILY_HDRSZ;
 	struct nlattr *nla;
-	const char *conn_name;
+	const char *resource_name;
 	struct drbd_tconn *tconn;
+	int maxtype;
 
 	/* Is this a followup call? */
 	if (cb->args[0]) {
@@ -2799,12 +2802,15 @@
 	/* No explicit context given.  Dump all. */
 	if (!nla)
 		goto dump;
-	nla = nla_find_nested(nla, __nla_type(T_ctx_conn_name));
+	maxtype = ARRAY_SIZE(drbd_cfg_context_nl_policy) - 1;
+	nla = drbd_nla_find_nested(maxtype, nla, __nla_type(T_ctx_resource_name));
+	if (IS_ERR(nla))
+		return PTR_ERR(nla);
 	/* context given, but no name present? */
 	if (!nla)
 		return -EINVAL;
-	conn_name = nla_data(nla);
-	tconn = conn_get_by_name(conn_name);
+	resource_name = nla_data(nla);
+	tconn = conn_get_by_name(resource_name);
 
 	if (!tconn)
 		return -ENODEV;
@@ -2957,16 +2963,16 @@
 }
 
 static enum drbd_ret_code
-drbd_check_conn_name(const char *name)
+drbd_check_resource_name(const char *name)
 {
 	if (!name || !name[0]) {
-		drbd_msg_put_info("connection name missing");
+		drbd_msg_put_info("resource name missing");
 		return ERR_MANDATORY_TAG;
 	}
 	/* if we want to use these in sysfs/configfs/debugfs some day,
 	 * we must not allow slashes */
 	if (strchr(name, '/')) {
-		drbd_msg_put_info("invalid connection name");
+		drbd_msg_put_info("invalid resource name");
 		return ERR_INVALID_REQUEST;
 	}
 	return NO_ERROR;
@@ -2982,7 +2988,7 @@
 	if (retcode != NO_ERROR)
 		goto out;
 
-	retcode = drbd_check_conn_name(adm_ctx.conn_name);
+	retcode = drbd_check_resource_name(adm_ctx.resource_name);
 	if (retcode != NO_ERROR)
 		goto out;
 
@@ -2995,7 +3001,7 @@
 		goto out;
 	}
 
-	if (!conn_create(adm_ctx.conn_name))
+	if (!conn_create(adm_ctx.resource_name))
 		retcode = ERR_NOMEM;
 out:
 	drbd_adm_finish(info, retcode);
@@ -3213,3 +3219,53 @@
 			"Event seq:%u sib_reason:%u\n",
 			err, seq, sib->sib_reason);
 }
+
+int drbd_nla_check_mandatory(int maxtype, struct nlattr *nla)
+{
+	struct nlattr *attr;
+	int rem;
+
+	/*
+	 * validate_nla() (reached through nla_parse_nested()) silently
+	 * skips attributes whose type exceeds maxtype and has no notion of
+	 * DRBD_GENLA_F_MANDATORY.  Walk the nested attributes ourselves:
+	 * clear the flag on each flagged attribute in turn, so the generic
+	 * validation later sees the plain type, and stop with -EOPNOTSUPP
+	 * as soon as a flagged attribute is of a type we do not know
+	 * (beyond maxtype).  Returns 0 when no such attribute is found.
+	 */
+	nla_for_each_attr(attr, nla_data(nla), nla_len(nla), rem) {
+		if (!(attr->nla_type & DRBD_GENLA_F_MANDATORY))
+			continue;
+		attr->nla_type &= ~DRBD_GENLA_F_MANDATORY;
+		if (nla_type(attr) > maxtype)
+			return -EOPNOTSUPP;
+	}
+	return 0;
+}
+
+int drbd_nla_parse_nested(struct nlattr *tb[], int maxtype, struct nlattr *nla,
+			  const struct nla_policy *policy)
+{
+	/* Strip DRBD_GENLA_F_MANDATORY and refuse unknown mandatory
+	 * attributes first; only then hand over to the stock parser. */
+	int err = drbd_nla_check_mandatory(maxtype, nla);
+
+	if (err)
+		return err;
+	return nla_parse_nested(tb, maxtype, nla, policy);
+}
+
+struct nlattr *drbd_nla_find_nested(int maxtype, struct nlattr *nla, int attrtype)
+{
+	/*
+	 * nla_find_nested() preceded by drbd_nla_check_mandatory(): returns
+	 * ERR_PTR(-EOPNOTSUPP) if a nested attribute flagged mandatory has
+	 * a type beyond maxtype, else nla_find_nested()'s result (NULL when
+	 * attrtype is absent).  NOTE(review): callers earlier in this file
+	 * rely on a prototype being in scope -- confirm a header declares it.
+	 */
+	int err = drbd_nla_check_mandatory(maxtype, nla);
+
+	return err ? ERR_PTR(err) : nla_find_nested(nla, attrtype);
+}