net: mpls: Add support for netconf

Add netconf support to MPLS. Allows userspace to learn and be notified
of changes to the 'input' enable setting per interface.

Acked-by: Nicolas Dichtel <nicolas.dichtel@6wind.com>
Signed-off-by: David Ahern <dsa@cumulusnetworks.com>
Acked-by: Robert Shearman <rshearma@brocade.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c
index 64d3bf2..3818686 100644
--- a/net/mpls/af_mpls.c
+++ b/net/mpls/af_mpls.c
@@ -7,6 +7,7 @@
 #include <linux/if_arp.h>
 #include <linux/ipv6.h>
 #include <linux/mpls.h>
+#include <linux/netconf.h>
 #include <linux/vmalloc.h>
 #include <linux/percpu.h>
 #include <net/ip.h>
@@ -960,15 +961,215 @@
 	return nla_total_size_64bit(sizeof(struct mpls_link_stats));
 }
 
+/* Append one netconf message describing @mdev to @skb.  The ifindex is
+ * always included; NETCONFA_INPUT only when @type is NETCONFA_INPUT or
+ * NETCONFA_ALL.  Returns 0, or -EMSGSIZE if the header or an attribute
+ * could not be added.
+ */
+static int mpls_netconf_fill_devconf(struct sk_buff *skb, struct mpls_dev *mdev,
+				     u32 portid, u32 seq, int event,
+				     unsigned int flags, int type)
+{
+	bool want_input = type == NETCONFA_ALL || type == NETCONFA_INPUT;
+	struct netconfmsg *ncm;
+	struct nlmsghdr *nlh;
+
+	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*ncm), flags);
+	if (!nlh)
+		return -EMSGSIZE;
+
+	ncm = nlmsg_data(nlh);
+	ncm->ncm_family = AF_MPLS;
+
+	if (nla_put_s32(skb, NETCONFA_IFINDEX, mdev->dev->ifindex) < 0)
+		goto nla_put_failure;
+
+	if (want_input &&
+	    nla_put_s32(skb, NETCONFA_INPUT, mdev->input_enabled) < 0)
+		goto nla_put_failure;
+
+	nlmsg_end(skb, nlh);
+	return 0;
+
+nla_put_failure:
+	nlmsg_cancel(skb, nlh);
+	return -EMSGSIZE;
+}
+
+/* Payload size needed for a netconf message carrying the attributes
+ * selected by @type (NETCONFA_IFINDEX is always accounted for).
+ */
+static int mpls_netconf_msgsize_devconf(int type)
+{
+	int size = NLMSG_ALIGN(sizeof(struct netconfmsg));
+
+	size += nla_total_size(4);	/* NETCONFA_IFINDEX */
+
+	if (type == NETCONFA_ALL || type == NETCONFA_INPUT)
+		size += nla_total_size(4);	/* NETCONFA_INPUT */
+
+	return size;
+}
+
+/* Send an RTM_NEWNETCONF notification for @mdev to RTNLGRP_MPLS_NETCONF */
+static void mpls_netconf_notify_devconf(struct net *net, int type,
+					struct mpls_dev *mdev)
+{
+	struct sk_buff *skb;
+	int err;
+
+	skb = nlmsg_new(mpls_netconf_msgsize_devconf(type), GFP_KERNEL);
+	if (!skb) {
+		rtnl_set_sk_err(net, RTNLGRP_MPLS_NETCONF, -ENOBUFS);
+		return;
+	}
+
+	err = mpls_netconf_fill_devconf(skb, mdev, 0, 0, RTM_NEWNETCONF,
+					0, type);
+	if (err < 0) {
+		/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
+		WARN_ON(err == -EMSGSIZE);
+		kfree_skb(skb);
+		rtnl_set_sk_err(net, RTNLGRP_MPLS_NETCONF, err);
+		return;
+	}
+
+	rtnl_notify(skb, net, 0, RTNLGRP_MPLS_NETCONF, NULL, GFP_KERNEL);
+}
+
+static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = {
+	[NETCONFA_IFINDEX]	= { .len = sizeof(int) },	/* int-sized ifindex */
+};
+
+/* RTM_GETNETCONF request handler.  Looks up the device named by the
+ * NETCONFA_IFINDEX attribute and unicasts its full netconf state to the
+ * requesting socket.  Returns a negative errno on failure, otherwise
+ * the rtnl_unicast() result.
+ */
+static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
+				    struct nlmsghdr *nlh)
+{
+	struct net *net = sock_net(in_skb->sk);
+	u32 portid = NETLINK_CB(in_skb).portid;
+	struct nlattr *tb[NETCONFA_MAX + 1];
+	struct net_device *dev;
+	struct mpls_dev *mdev;
+	struct sk_buff *skb;
+	int ifindex;
+	int err;
+
+	err = nlmsg_parse(nlh, sizeof(struct netconfmsg), tb, NETCONFA_MAX,
+			  devconf_mpls_policy);
+	if (err < 0)
+		return err;
+
+	if (!tb[NETCONFA_IFINDEX])
+		return -EINVAL;
+
+	ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]);
+	dev = __dev_get_by_index(net, ifindex);
+	if (!dev)
+		return -EINVAL;
+
+	mdev = mpls_dev_get(dev);
+	if (!mdev)
+		return -EINVAL;
+
+	skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL);
+	if (!skb)
+		return -ENOBUFS;
+
+	err = mpls_netconf_fill_devconf(skb, mdev, portid, nlh->nlmsg_seq,
+					RTM_NEWNETCONF, 0, NETCONFA_ALL);
+	if (err < 0) {
+		/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
+		WARN_ON(err == -EMSGSIZE);
+		kfree_skb(skb);
+		return err;
+	}
+
+	return rtnl_unicast(skb, net, portid);
+}
+
+/* Dump netconf state for every device in the namespace that has an
+ * mpls_dev.  cb->args[0] and cb->args[1] are read as the starting hash
+ * bucket and in-bucket index, and updated with the position reached.
+ */
+static int mpls_netconf_dump_devconf(struct sk_buff *skb,
+				     struct netlink_callback *cb)
+{
+	struct net *net = sock_net(skb->sk);
+	u32 portid = NETLINK_CB(cb->skb).portid;
+	u32 seq = cb->nlh->nlmsg_seq;
+	int s_idx = cb->args[1];
+	int h, idx = s_idx;
+	struct net_device *dev;
+	struct mpls_dev *mdev;
+
+	for (h = cb->args[0]; h < NETDEV_HASHENTRIES; h++, s_idx = 0) {
+		idx = 0;
+		rcu_read_lock();
+		cb->seq = net->dev_base_seq;
+		hlist_for_each_entry_rcu(dev, &net->dev_index_head[h],
+					 index_hlist) {
+			if (idx < s_idx)
+				goto cont;
+			mdev = mpls_dev_get(dev);
+			if (!mdev)
+				goto cont;
+			if (mpls_netconf_fill_devconf(skb, mdev, portid, seq,
+						      RTM_NEWNETCONF,
+						      NLM_F_MULTI,
+						      NETCONFA_ALL) < 0) {
+				rcu_read_unlock();
+				goto done;
+			}
+			nl_dump_check_consistent(cb, nlmsg_hdr(skb));
+cont:
+			idx++;
+		}
+		rcu_read_unlock();
+	}
+done:
+	cb->args[0] = h;
+	cb->args[1] = idx;
+
+	return skb->len;
+}
+
 #define MPLS_PERDEV_SYSCTL_OFFSET(field)	\
 	(&((struct mpls_dev *)0)->field)
 
+/* sysctl handler: notify netconf listeners when a write changes 'input' */
+static int mpls_conf_proc(struct ctl_table *ctl, int write,
+			  void __user *buffer,
+			  size_t *lenp, loff_t *ppos)
+{
+	int oval = *(int *)ctl->data;
+	int ret = proc_dointvec(ctl, write, buffer, lenp, ppos);
+
+	if (write) {
+		struct mpls_dev *mdev = ctl->extra1;
+		/* byte distance, so it is comparable with offsetof() below */
+		int i = (char *)ctl->data - (char *)mdev;
+		struct net *net = ctl->extra2;
+		int val = *(int *)ctl->data;
+
+		if (i == offsetof(struct mpls_dev, input_enabled) &&
+		    val != oval)
+			mpls_netconf_notify_devconf(net, NETCONFA_INPUT,
+						    mdev);
+	}
+
+	return ret;
+}
+
 static const struct ctl_table mpls_dev_table[] = {
 	{
 		.procname	= "input",
 		.maxlen		= sizeof(int),
 		.mode		= 0644,
-		.proc_handler	= proc_dointvec,
+		.proc_handler	= mpls_conf_proc,
 		.data		= MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
 	},
 	{ }
@@ -978,6 +1179,7 @@
 				    struct mpls_dev *mdev)
 {
 	char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
+	struct net *net = dev_net(dev);
 	struct ctl_table *table;
 	int i;
 
@@ -988,8 +1190,11 @@
 	/* Table data contains only offsets relative to the base of
 	 * the mdev at this point, so make them absolute.
 	 */
-	for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++)
+	for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++) {
 		table[i].data = (char *)mdev + (uintptr_t)table[i].data;
+		table[i].extra1 = mdev;
+		table[i].extra2 = net;
+	}
 
 	snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
 
@@ -1041,6 +1246,7 @@
 	if (err)
 		goto free;
 
+	mdev->dev = dev;
 	rcu_assign_pointer(dev->mpls_ptr, mdev);
 
 	return mdev;
@@ -1861,6 +2067,8 @@
 	rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
 	rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
 	rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
+	rtnl_register(PF_MPLS, RTM_GETNETCONF, mpls_netconf_get_devconf,
+		      mpls_netconf_dump_devconf, NULL);
 	err = 0;
 out:
 	return err;