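[NETFILTER]: ctnetlink: use netlink policy

Replace the cta_min_* minimum-size tables and the open-coded
nlattr_bad_size() checks with struct nla_policy tables:

- nested CTA_PROTO_*, CTA_PROTONAT_* and CTA_NAT_* attributes are
  validated by passing a policy to nla_parse_nested(), whose return
  value is now propagated instead of being ignored;
- the top-level CTA_* and CTA_EXPECT_* policies are attached to the
  nfnl_callback entries through their .policy member, replacing the
  nlattr_bad_size() checks in the individual handlers;
- protocol-specific tuple attributes are checked against the l3/l4
  protocol's nla_policy with nla_validate_nested() before its
  nlattr_to_tuple() callback is invoked.
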
Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c
index 9f9bef2..ce35812 100644
--- a/net/netfilter/nf_conntrack_netlink.c
+++ b/net/netfilter/nf_conntrack_netlink.c
@@ -512,16 +512,20 @@
 
 	l3proto = nf_ct_l3proto_find_get(tuple->src.l3num);
 
-	if (likely(l3proto->nlattr_to_tuple))
-		ret = l3proto->nlattr_to_tuple(tb, tuple);
+	if (likely(l3proto->nlattr_to_tuple)) {
+		ret = nla_validate_nested(attr, CTA_IP_MAX,
+					  l3proto->nla_policy);
+		if (ret == 0)
+			ret = l3proto->nlattr_to_tuple(tb, tuple);
+	}
 
 	nf_ct_l3proto_put(l3proto);
 
 	return ret;
 }
 
-static const size_t cta_min_proto[CTA_PROTO_MAX+1] = {
-	[CTA_PROTO_NUM]	= sizeof(u_int8_t),
+static const struct nla_policy proto_nla_policy[CTA_PROTO_MAX+1] = {
+	[CTA_PROTO_NUM]	= { .type = NLA_U8 },
 };
 
 static inline int
@@ -532,10 +536,9 @@
 	struct nf_conntrack_l4proto *l4proto;
 	int ret = 0;
 
-	nla_parse_nested(tb, CTA_PROTO_MAX, attr, NULL);
-
-	if (nlattr_bad_size(tb, CTA_PROTO_MAX, cta_min_proto))
-		return -EINVAL;
+	ret = nla_parse_nested(tb, CTA_PROTO_MAX, attr, proto_nla_policy);
+	if (ret < 0)
+		return ret;
 
 	if (!tb[CTA_PROTO_NUM])
 		return -EINVAL;
@@ -543,8 +546,12 @@
 
 	l4proto = nf_ct_l4proto_find_get(tuple->src.l3num, tuple->dst.protonum);
 
-	if (likely(l4proto->nlattr_to_tuple))
-		ret = l4proto->nlattr_to_tuple(tb, tuple);
+	if (likely(l4proto->nlattr_to_tuple)) {
+		ret = nla_validate_nested(attr, CTA_PROTO_MAX,
+					  l4proto->nla_policy);
+		if (ret == 0)
+			ret = l4proto->nlattr_to_tuple(tb, tuple);
+	}
 
 	nf_ct_l4proto_put(l4proto);
 
@@ -588,9 +595,9 @@
 }
 
 #ifdef CONFIG_NF_NAT_NEEDED
-static const size_t cta_min_protonat[CTA_PROTONAT_MAX+1] = {
-	[CTA_PROTONAT_PORT_MIN]	= sizeof(u_int16_t),
-	[CTA_PROTONAT_PORT_MAX]	= sizeof(u_int16_t),
+static const struct nla_policy protonat_nla_policy[CTA_PROTONAT_MAX+1] = {
+	[CTA_PROTONAT_PORT_MIN]	= { .type = NLA_U16 },
+	[CTA_PROTONAT_PORT_MAX]	= { .type = NLA_U16 },
 };
 
 static int nfnetlink_parse_nat_proto(struct nlattr *attr,
@@ -599,11 +606,11 @@
 {
 	struct nlattr *tb[CTA_PROTONAT_MAX+1];
 	struct nf_nat_protocol *npt;
+	int err;
 
-	nla_parse_nested(tb, CTA_PROTONAT_MAX, attr, NULL);
-
-	if (nlattr_bad_size(tb, CTA_PROTONAT_MAX, cta_min_protonat))
-		return -EINVAL;
+	err = nla_parse_nested(tb, CTA_PROTONAT_MAX, attr, protonat_nla_policy);
+	if (err < 0)
+		return err;
 
 	npt = nf_nat_proto_find_get(ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum);
 
@@ -621,9 +628,9 @@
 	return 0;
 }
 
-static const size_t cta_min_nat[CTA_NAT_MAX+1] = {
-	[CTA_NAT_MINIP]		= sizeof(u_int32_t),
-	[CTA_NAT_MAXIP]		= sizeof(u_int32_t),
+static const struct nla_policy nat_nla_policy[CTA_NAT_MAX+1] = {
+	[CTA_NAT_MINIP]		= { .type = NLA_U32 },
+	[CTA_NAT_MAXIP]		= { .type = NLA_U32 },
 };
 
 static inline int
@@ -635,10 +642,9 @@
 
 	memset(range, 0, sizeof(*range));
 
-	nla_parse_nested(tb, CTA_NAT_MAX, nat, NULL);
-
-	if (nlattr_bad_size(tb, CTA_NAT_MAX, cta_min_nat))
-		return -EINVAL;
+	err = nla_parse_nested(tb, CTA_NAT_MAX, nat, nat_nla_policy);
+	if (err < 0)
+		return err;
 
 	if (tb[CTA_NAT_MINIP])
 		range->min_ip = *(__be32 *)nla_data(tb[CTA_NAT_MINIP]);
@@ -677,12 +683,12 @@
 	return 0;
 }
 
-static const size_t cta_min[CTA_MAX+1] = {
-	[CTA_STATUS] 		= sizeof(u_int32_t),
-	[CTA_TIMEOUT] 		= sizeof(u_int32_t),
-	[CTA_MARK]		= sizeof(u_int32_t),
-	[CTA_USE]		= sizeof(u_int32_t),
-	[CTA_ID]		= sizeof(u_int32_t)
+static const struct nla_policy ct_nla_policy[CTA_MAX+1] = {
+	[CTA_STATUS] 		= { .type = NLA_U32 },
+	[CTA_TIMEOUT] 		= { .type = NLA_U32 },
+	[CTA_MARK]		= { .type = NLA_U32 },
+	[CTA_USE]		= { .type = NLA_U32 },
+	[CTA_ID]		= { .type = NLA_U32 },
 };
 
 static int
@@ -696,9 +702,6 @@
 	u_int8_t u3 = nfmsg->nfgen_family;
 	int err = 0;
 
-	if (nlattr_bad_size(cda, CTA_MAX, cta_min))
-		return -EINVAL;
-
 	if (cda[CTA_TUPLE_ORIG])
 		err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_ORIG, u3);
 	else if (cda[CTA_TUPLE_REPLY])
@@ -754,9 +757,6 @@
 					  ctnetlink_done);
 	}
 
-	if (nlattr_bad_size(cda, CTA_MAX, cta_min))
-		return -EINVAL;
-
 	if (cda[CTA_TUPLE_ORIG])
 		err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_ORIG, u3);
 	else if (cda[CTA_TUPLE_REPLY])
@@ -1045,9 +1045,6 @@
 	u_int8_t u3 = nfmsg->nfgen_family;
 	int err = 0;
 
-	if (nlattr_bad_size(cda, CTA_MAX, cta_min))
-		return -EINVAL;
-
 	if (cda[CTA_TUPLE_ORIG]) {
 		err = ctnetlink_parse_tuple(cda, &otuple, CTA_TUPLE_ORIG, u3);
 		if (err < 0)
@@ -1313,9 +1310,9 @@
 	return skb->len;
 }
 
-static const size_t cta_min_exp[CTA_EXPECT_MAX+1] = {
-	[CTA_EXPECT_TIMEOUT]	= sizeof(u_int32_t),
-	[CTA_EXPECT_ID]		= sizeof(u_int32_t)
+static const struct nla_policy exp_nla_policy[CTA_EXPECT_MAX+1] = {
+	[CTA_EXPECT_TIMEOUT]	= { .type = NLA_U32 },
+	[CTA_EXPECT_ID]		= { .type = NLA_U32 },
 };
 
 static int
@@ -1329,9 +1326,6 @@
 	u_int8_t u3 = nfmsg->nfgen_family;
 	int err = 0;
 
-	if (nlattr_bad_size(cda, CTA_EXPECT_MAX, cta_min_exp))
-		return -EINVAL;
-
 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
 		return netlink_dump_start(ctnl, skb, nlh,
 					  ctnetlink_exp_dump_table,
@@ -1393,9 +1387,6 @@
 	unsigned int i;
 	int err;
 
-	if (nlattr_bad_size(cda, CTA_EXPECT_MAX, cta_min_exp))
-		return -EINVAL;
-
 	if (cda[CTA_EXPECT_TUPLE]) {
 		/* delete a single expect by tuple */
 		err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, u3);
@@ -1534,9 +1525,6 @@
 	u_int8_t u3 = nfmsg->nfgen_family;
 	int err = 0;
 
-	if (nlattr_bad_size(cda, CTA_EXPECT_MAX, cta_min_exp))
-		return -EINVAL;
-
 	if (!cda[CTA_EXPECT_TUPLE]
 	    || !cda[CTA_EXPECT_MASK]
 	    || !cda[CTA_EXPECT_MASTER])
@@ -1577,22 +1565,29 @@
 
 static const struct nfnl_callback ctnl_cb[IPCTNL_MSG_MAX] = {
 	[IPCTNL_MSG_CT_NEW]		= { .call = ctnetlink_new_conntrack,
-					    .attr_count = CTA_MAX, },
+					    .attr_count = CTA_MAX,
+					    .policy = ct_nla_policy },
 	[IPCTNL_MSG_CT_GET] 		= { .call = ctnetlink_get_conntrack,
-					    .attr_count = CTA_MAX, },
+					    .attr_count = CTA_MAX,
+					    .policy = ct_nla_policy },
 	[IPCTNL_MSG_CT_DELETE]  	= { .call = ctnetlink_del_conntrack,
-					    .attr_count = CTA_MAX, },
+					    .attr_count = CTA_MAX,
+					    .policy = ct_nla_policy },
 	[IPCTNL_MSG_CT_GET_CTRZERO] 	= { .call = ctnetlink_get_conntrack,
-					    .attr_count = CTA_MAX, },
+					    .attr_count = CTA_MAX,
+					    .policy = ct_nla_policy },
 };
 
 static const struct nfnl_callback ctnl_exp_cb[IPCTNL_MSG_EXP_MAX] = {
 	[IPCTNL_MSG_EXP_GET]		= { .call = ctnetlink_get_expect,
-					    .attr_count = CTA_EXPECT_MAX, },
+					    .attr_count = CTA_EXPECT_MAX,
+					    .policy = exp_nla_policy },
 	[IPCTNL_MSG_EXP_NEW]		= { .call = ctnetlink_new_expect,
-					    .attr_count = CTA_EXPECT_MAX, },
+					    .attr_count = CTA_EXPECT_MAX,
+					    .policy = exp_nla_policy },
 	[IPCTNL_MSG_EXP_DELETE]		= { .call = ctnetlink_del_expect,
-					    .attr_count = CTA_EXPECT_MAX, },
+					    .attr_count = CTA_EXPECT_MAX,
+					    .policy = exp_nla_policy },
 };
 
 static const struct nfnetlink_subsystem ctnl_subsys = {