xfrm: configure policy hash table thresholds by netlink

Allow the local and remote prefix length thresholds of the policy hash
table to be configured via a netlink XFRM_MSG_NEWSPDINFO message.

The prefix length thresholds are specified by the optional
XFRMA_SPD_IPV4_HTHRESH and XFRMA_SPD_IPV6_HTHRESH attributes
(struct xfrmu_spdhthresh).

example:

    struct xfrmu_spdhthresh thresh4 = {
        .lbits = 0,
        .rbits = 24,
    };
    struct xfrmu_spdhthresh thresh6 = {
        .lbits = 0,
        .rbits = 56,
    };
    struct nlmsghdr *hdr;
    struct nl_msg *msg;

    msg = nlmsg_alloc();
    hdr = nlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, XFRM_MSG_NEWSPDINFO, sizeof(__u32), NLM_F_REQUEST);
    nla_put(msg, XFRMA_SPD_IPV4_HTHRESH, sizeof(thresh4), &thresh4);
    nla_put(msg, XFRMA_SPD_IPV6_HTHRESH, sizeof(thresh6), &thresh6);
    nl_send_auto(sk, msg);

The numbers are the minimum policy selector prefix lengths required
for a policy to be put in the hash table.

- lbits is the local threshold (source address for out policies,
  destination address for in and fwd policies).

- rbits is the remote threshold (destination address for out
  policies, source address for in and fwd policies).

The default values are:

XFRMA_SPD_IPV4_HTHRESH: 32 32
XFRMA_SPD_IPV6_HTHRESH: 128 128

Dynamic rebuilding of the SPD is performed when the threshold values
are changed.

The current thresholds can be read via an XFRM_MSG_GETSPDINFO request:
the kernel replies to XFRM_MSG_GETSPDINFO requests with an
XFRM_MSG_NEWSPDINFO message carrying both the XFRMA_SPD_IPV4_HTHRESH
and XFRMA_SPD_IPV6_HTHRESH attributes.

Signed-off-by: Christophe Gouault <christophe.gouault@6wind.com>
Signed-off-by: Steffen Klassert <steffen.klassert@secunet.com>
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index d4db6eb..eaf8a8f 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -964,7 +964,9 @@
 {
 	return NLMSG_ALIGN(4)
 	       + nla_total_size(sizeof(struct xfrmu_spdinfo))
-	       + nla_total_size(sizeof(struct xfrmu_spdhinfo));
+	       + nla_total_size(sizeof(struct xfrmu_spdhinfo))
+	       + nla_total_size(sizeof(struct xfrmu_spdhthresh))
+	       + nla_total_size(sizeof(struct xfrmu_spdhthresh));
 }
 
 static int build_spdinfo(struct sk_buff *skb, struct net *net,
@@ -973,9 +975,11 @@
 	struct xfrmk_spdinfo si;
 	struct xfrmu_spdinfo spc;
 	struct xfrmu_spdhinfo sph;
+	struct xfrmu_spdhthresh spt4, spt6;
 	struct nlmsghdr *nlh;
 	int err;
 	u32 *f;
+	unsigned lseq;
 
 	nlh = nlmsg_put(skb, portid, seq, XFRM_MSG_NEWSPDINFO, sizeof(u32), 0);
 	if (nlh == NULL) /* shouldn't really happen ... */
@@ -993,9 +997,22 @@
 	sph.spdhcnt = si.spdhcnt;
 	sph.spdhmcnt = si.spdhmcnt;
 
+	do {
+		lseq = read_seqbegin(&net->xfrm.policy_hthresh.lock);
+
+		spt4.lbits = net->xfrm.policy_hthresh.lbits4;
+		spt4.rbits = net->xfrm.policy_hthresh.rbits4;
+		spt6.lbits = net->xfrm.policy_hthresh.lbits6;
+		spt6.rbits = net->xfrm.policy_hthresh.rbits6;
+	} while (read_seqretry(&net->xfrm.policy_hthresh.lock, lseq));
+
 	err = nla_put(skb, XFRMA_SPD_INFO, sizeof(spc), &spc);
 	if (!err)
 		err = nla_put(skb, XFRMA_SPD_HINFO, sizeof(sph), &sph);
+	if (!err)
+		err = nla_put(skb, XFRMA_SPD_IPV4_HTHRESH, sizeof(spt4), &spt4);
+	if (!err)
+		err = nla_put(skb, XFRMA_SPD_IPV6_HTHRESH, sizeof(spt6), &spt6);
 	if (err) {
 		nlmsg_cancel(skb, nlh);
 		return err;
@@ -1004,6 +1021,51 @@
 	return nlmsg_end(skb, nlh);
 }
 
+static int xfrm_set_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
+			    struct nlattr **attrs)
+{
+	struct net *net = sock_net(skb->sk);
+	struct xfrmu_spdhthresh *thresh4 = NULL;
+	struct xfrmu_spdhthresh *thresh6 = NULL;
+
+	/* check both optional prefixlen thresholds before updating either */
+	if (attrs[XFRMA_SPD_IPV4_HTHRESH]) {
+		struct nlattr *rta = attrs[XFRMA_SPD_IPV4_HTHRESH];
+
+		if (nla_len(rta) < sizeof(*thresh4))
+			return -EINVAL;
+		thresh4 = nla_data(rta);
+		if (thresh4->lbits > 32 || thresh4->rbits > 32)
+			return -EINVAL;
+	}
+	if (attrs[XFRMA_SPD_IPV6_HTHRESH]) {
+		struct nlattr *rta = attrs[XFRMA_SPD_IPV6_HTHRESH];
+
+		if (nla_len(rta) < sizeof(*thresh6))
+			return -EINVAL;
+		thresh6 = nla_data(rta);
+		if (thresh6->lbits > 128 || thresh6->rbits > 128)
+			return -EINVAL;
+	}
+
+	if (thresh4 || thresh6) {
+		write_seqlock(&net->xfrm.policy_hthresh.lock);
+		if (thresh4) {
+			net->xfrm.policy_hthresh.lbits4 = thresh4->lbits;
+			net->xfrm.policy_hthresh.rbits4 = thresh4->rbits;
+		}
+		if (thresh6) {
+			net->xfrm.policy_hthresh.lbits6 = thresh6->lbits;
+			net->xfrm.policy_hthresh.rbits6 = thresh6->rbits;
+		}
+		write_sequnlock(&net->xfrm.policy_hthresh.lock);
+
+		xfrm_policy_hash_rebuild(net); /* seqlock already dropped */
+	}
+
+	return 0;
+}
+
 static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
 		struct nlattr **attrs)
 {
@@ -2274,6 +2336,7 @@
 	[XFRM_MSG_REPORT      - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_report),
 	[XFRM_MSG_MIGRATE     - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id),
 	[XFRM_MSG_GETSADINFO  - XFRM_MSG_BASE] = sizeof(u32),
+	[XFRM_MSG_NEWSPDINFO  - XFRM_MSG_BASE] = sizeof(u32),
 	[XFRM_MSG_GETSPDINFO  - XFRM_MSG_BASE] = sizeof(u32),
 };
 
@@ -2308,10 +2371,17 @@
 	[XFRMA_ADDRESS_FILTER]	= { .len = sizeof(struct xfrm_address_filter) },
 };
 
+static const struct nla_policy xfrma_spd_policy[XFRMA_SPD_MAX+1] = {
+	[XFRMA_SPD_IPV4_HTHRESH] = { .len = sizeof(struct xfrmu_spdhthresh) },
+	[XFRMA_SPD_IPV6_HTHRESH] = { .len = sizeof(struct xfrmu_spdhthresh) },
+};
+
 static const struct xfrm_link {
 	int (*doit)(struct sk_buff *, struct nlmsghdr *, struct nlattr **);
 	int (*dump)(struct sk_buff *, struct netlink_callback *);
 	int (*done)(struct netlink_callback *);
+	const struct nla_policy *nla_pol;
+	int nla_max;
 } xfrm_dispatch[XFRM_NR_MSGTYPES] = {
 	[XFRM_MSG_NEWSA       - XFRM_MSG_BASE] = { .doit = xfrm_add_sa        },
 	[XFRM_MSG_DELSA       - XFRM_MSG_BASE] = { .doit = xfrm_del_sa        },
@@ -2335,6 +2405,9 @@
 	[XFRM_MSG_GETAE       - XFRM_MSG_BASE] = { .doit = xfrm_get_ae  },
 	[XFRM_MSG_MIGRATE     - XFRM_MSG_BASE] = { .doit = xfrm_do_migrate    },
 	[XFRM_MSG_GETSADINFO  - XFRM_MSG_BASE] = { .doit = xfrm_get_sadinfo   },
+	[XFRM_MSG_NEWSPDINFO  - XFRM_MSG_BASE] = { .doit = xfrm_set_spdinfo,
+						   .nla_pol = xfrma_spd_policy,
+						   .nla_max = XFRMA_SPD_MAX },
 	[XFRM_MSG_GETSPDINFO  - XFRM_MSG_BASE] = { .doit = xfrm_get_spdinfo   },
 };
 
@@ -2371,8 +2444,9 @@
 		}
 	}
 
-	err = nlmsg_parse(nlh, xfrm_msg_min[type], attrs, XFRMA_MAX,
-			  xfrma_policy);
+	err = nlmsg_parse(nlh, xfrm_msg_min[type], attrs,
+			  link->nla_max ? : XFRMA_MAX,
+			  link->nla_pol ? : xfrma_policy);
 	if (err < 0)
 		return err;