mpls: make RTA_OIF optional

If the user did not specify an oif, try to get it from the via address.
If no device can be found, return -ENODEV.

Signed-off-by: Roopa Prabhu <roopa@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c
index 6e66911..49f1b0e 100644
--- a/net/mpls/af_mpls.c
+++ b/net/mpls/af_mpls.c
@@ -15,6 +15,7 @@
 #include <net/ip_fib.h>
 #include <net/netevent.h>
 #include <net/netns/generic.h>
+#include <net/ip6_route.h>
 #include "internal.h"
 
 #define LABEL_NOT_SPECIFIED (1<<20)
@@ -330,6 +331,70 @@
 	return LABEL_NOT_SPECIFIED;
 }
 
+static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
+{
+	struct net_device *dev = NULL;	/* result: route's device, or NULL if lookup fails */
+	struct rtable *rt;
+	struct in_addr daddr;
+
+	memcpy(&daddr, addr, sizeof(struct in_addr));	/* addr is read as a struct in_addr */
+	rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);	/* only net and daddr are given; rest are 0 */
+	if (IS_ERR(rt))
+		goto errout;
+
+	dev = rt->dst.dev;
+	dev_hold(dev);	/* hold dev before the route is dropped; it is returned held */
+
+	ip_rt_put(rt);
+
+errout:
+	return dev;
+}
+
+static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
+{
+	struct net_device *dev = NULL;	/* result: route's device, or NULL if lookup fails */
+	struct dst_entry *dst;
+	struct flowi6 fl6;
+
+	memset(&fl6, 0, sizeof(fl6));	/* every flow field is zero except daddr, set next */
+	memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));	/* addr is read as a struct in6_addr */
+	dst = ip6_route_output(net, NULL, &fl6);
+	if (dst->error)	/* NOTE(review): dst is dereferenced unchecked; assumes non-NULL - confirm */
+		goto errout;
+
+	dev = dst->dev;
+	dev_hold(dev);	/* hold dev before dst is released below; it is returned held */
+
+errout:
+	dst_release(dst);	/* reached on both the success and the error path */
+
+	return dev;
+}
+
+static struct net_device *find_outdev(struct net *net,
+				      struct mpls_route_config *cfg)
+{
+	struct net_device *dev = NULL;	/* result: output device for cfg, or NULL if none found */
+
+	if (!cfg->rc_ifindex) {	/* no ifindex configured: derive the device from rc_via */
+		switch (cfg->rc_via_table) {
+		case NEIGH_ARP_TABLE:
+			dev = inet_fib_lookup_dev(net, cfg->rc_via);	/* rc_via handled as an IPv4 address */
+			break;
+		case NEIGH_ND_TABLE:
+			dev = inet6_fib_lookup_dev(net, cfg->rc_via);	/* rc_via handled as an IPv6 address */
+			break;
+		case NEIGH_LINK_TABLE:	/* no lookup is attempted here: dev stays NULL */
+			break;
+		}	/* any other rc_via_table value also leaves dev NULL */
+	} else {
+		dev = dev_get_by_index(net, cfg->rc_ifindex);	/* explicit ifindex is used directly */
+	}
+
+	return dev;
+}
+
 static int mpls_route_add(struct mpls_route_config *cfg)
 {
 	struct mpls_route __rcu **platform_label;
@@ -361,7 +426,7 @@
 		goto errout;
 
 	err = -ENODEV;
-	dev = dev_get_by_index(net, cfg->rc_ifindex);
+	dev = find_outdev(net, cfg);
 	if (!dev)
 		goto errout;