Merge branch 'l2tp-rework-pppol2tp-ioctl-handling'

Guillaume Nault says:

====================
l2tp: rework pppol2tp ioctl handling

The current ioctl() handling code can be simplified. It tests for
non-relevant conditions and uselessly holds sockets. Once useless
code is removed, it becomes even simpler to let pppol2tp_ioctl() handle
commands directly, rather than dispatch them to pppol2tp_tunnel_ioctl()
or pppol2tp_session_ioctl(). That is the approach taken by this series.

Patch #1 and #2 define helper functions aimed at simplifying the rest
of the patch set.

Patch #3 drops useless tests in pppol2tp_ioctl() and avoids holding a
refcount on the socket.

Patches #4, #5 and #6 are the core of the series. They let
pppol2tp_ioctl() handle all ioctls and drop the tunnel and session
specific functions.

Then patch #7 brings a little bit of consolidation.

Finally, patch #8 takes advantage of the simplified code to make
pppol2tp sockets compatible with dev_ioctl(). Certainly not a killer
feature, but it is trivial and it is always nice to see l2tp getting
better integration with the rest of the stack.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/uapi/linux/ppp-ioctl.h b/include/uapi/linux/ppp-ioctl.h
index 784c2e3..88b5f99 100644
--- a/include/uapi/linux/ppp-ioctl.h
+++ b/include/uapi/linux/ppp-ioctl.h
@@ -68,7 +68,7 @@ struct ppp_option_data {
 struct pppol2tp_ioc_stats {
 	__u16		tunnel_id;	/* redundant */
 	__u16		session_id;	/* if zero, get tunnel stats */
-	__u32		using_ipsec:1;	/* valid only for session_id == 0 */
+	__u32		using_ipsec:1;
 	__aligned_u64	tx_packets;
 	__aligned_u64	tx_bytes;
 	__aligned_u64	tx_errors;
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index ac6a00bc..2bd701a 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -203,47 +203,47 @@ struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_get_nth);
 
-/* Lookup a session. A new reference is held on the returned session. */
-struct l2tp_session *l2tp_session_get(const struct net *net,
-				      struct l2tp_tunnel *tunnel,
-				      u32 session_id)
+struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
+					     u32 session_id)
 {
 	struct hlist_head *session_list;
 	struct l2tp_session *session;
 
-	if (!tunnel) {
-		struct l2tp_net *pn = l2tp_pernet(net);
-
-		session_list = l2tp_session_id_hash_2(pn, session_id);
-
-		rcu_read_lock_bh();
-		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
-			if (session->session_id == session_id) {
-				l2tp_session_inc_refcount(session);
-				rcu_read_unlock_bh();
-
-				return session;
-			}
-		}
-		rcu_read_unlock_bh();
-
-		return NULL;
-	}
-
 	session_list = l2tp_session_id_hash(tunnel, session_id);
+
 	read_lock_bh(&tunnel->hlist_lock);
-	hlist_for_each_entry(session, session_list, hlist) {
+	hlist_for_each_entry(session, session_list, hlist)
 		if (session->session_id == session_id) {
 			l2tp_session_inc_refcount(session);
 			read_unlock_bh(&tunnel->hlist_lock);
 
 			return session;
 		}
-	}
 	read_unlock_bh(&tunnel->hlist_lock);
 
 	return NULL;
 }
+EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session);
+
+struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id)
+{
+	struct hlist_head *session_list;
+	struct l2tp_session *session;
+
+	session_list = l2tp_session_id_hash_2(l2tp_pernet(net), session_id);
+
+	rcu_read_lock_bh();
+	hlist_for_each_entry_rcu(session, session_list, global_hlist)
+		if (session->session_id == session_id) {
+			l2tp_session_inc_refcount(session);
+			rcu_read_unlock_bh();
+
+			return session;
+		}
+	rcu_read_unlock_bh();
+
+	return NULL;
+}
 EXPORT_SYMBOL_GPL(l2tp_session_get);
 
 struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth)
@@ -872,7 +872,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb)
 	}
 
 	/* Find the session context */
-	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id);
+	session = l2tp_tunnel_get_session(tunnel, session_id);
 	if (!session || !session->recv_skb) {
 		if (session)
 			l2tp_session_dec_refcount(session);
diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h
index 5804065..8480a0a 100644
--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -15,6 +15,10 @@
 #include <net/dst.h>
 #include <net/sock.h>
 
+#ifdef CONFIG_XFRM
+#include <net/xfrm.h>
+#endif
+
 /* Just some random numbers */
 #define L2TP_TUNNEL_MAGIC	0x42114DDA
 #define L2TP_SESSION_MAGIC	0x0C04EB7D
@@ -192,12 +196,12 @@ static inline void *l2tp_session_priv(struct l2tp_session *session)
 
 struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
 struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth);
+struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
+					     u32 session_id);
 
 void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);
 
-struct l2tp_session *l2tp_session_get(const struct net *net,
-				      struct l2tp_tunnel *tunnel,
-				      u32 session_id);
+struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id);
 struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth);
 struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
 						const char *ifname);
@@ -284,6 +288,21 @@ static inline u32 l2tp_tunnel_dst_mtu(const struct l2tp_tunnel *tunnel)
 	return mtu;
 }
 
+#ifdef CONFIG_XFRM
+static inline bool l2tp_tunnel_uses_xfrm(const struct l2tp_tunnel *tunnel)
+{
+	struct sock *sk = tunnel->sock;
+
+	return sk && (rcu_access_pointer(sk->sk_policy[0]) ||
+		      rcu_access_pointer(sk->sk_policy[1]));
+}
+#else
+static inline bool l2tp_tunnel_uses_xfrm(const struct l2tp_tunnel *tunnel)
+{
+	return false;
+}
+#endif
+
 #define l2tp_printk(ptr, type, func, fmt, ...)				\
 do {									\
 	if (((ptr)->debug) & (type))					\
diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c
index 0bc39cc..35f6f86 100644
--- a/net/l2tp/l2tp_ip.c
+++ b/net/l2tp/l2tp_ip.c
@@ -144,7 +144,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_get(net, NULL, session_id);
+	session = l2tp_session_get(net, session_id);
 	if (!session)
 		goto discard;
 
diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c
index 42f828c..237f1a4 100644
--- a/net/l2tp/l2tp_ip6.c
+++ b/net/l2tp/l2tp_ip6.c
@@ -157,7 +157,7 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_get(net, NULL, session_id);
+	session = l2tp_session_get(net, session_id);
 	if (!session)
 		goto discard;
 
diff --git a/net/l2tp/l2tp_netlink.c b/net/l2tp/l2tp_netlink.c
index 2e1e926..edbd5d1 100644
--- a/net/l2tp/l2tp_netlink.c
+++ b/net/l2tp/l2tp_netlink.c
@@ -66,7 +66,7 @@ static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info)
 		session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
 		tunnel = l2tp_tunnel_get(net, tunnel_id);
 		if (tunnel) {
-			session = l2tp_session_get(net, tunnel, session_id);
+			session = l2tp_tunnel_get_session(tunnel, session_id);
 			l2tp_tunnel_dec_refcount(tunnel);
 		}
 	}
@@ -627,7 +627,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
 							   &cfg);
 
 	if (ret >= 0) {
-		session = l2tp_session_get(net, tunnel, session_id);
+		session = l2tp_tunnel_get_session(tunnel, session_id);
 		if (session) {
 			ret = l2tp_session_notify(&l2tp_nl_family, info, session,
 						  L2TP_CMD_SESSION_CREATE);
@@ -710,9 +710,6 @@ static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq, int fl
 	void *hdr;
 	struct nlattr *nest;
 	struct l2tp_tunnel *tunnel = session->tunnel;
-	struct sock *sk = NULL;
-
-	sk = tunnel->sock;
 
 	hdr = genlmsg_put(skb, portid, seq, &l2tp_nl_family, flags, cmd);
 	if (!hdr)
@@ -738,10 +735,8 @@ static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq, int fl
 	    nla_put_u8(skb, L2TP_ATTR_RECV_SEQ, session->recv_seq) ||
 	    nla_put_u8(skb, L2TP_ATTR_SEND_SEQ, session->send_seq) ||
 	    nla_put_u8(skb, L2TP_ATTR_LNS_MODE, session->lns_mode) ||
-#ifdef CONFIG_XFRM
-	    (((sk) && (sk->sk_policy[0] || sk->sk_policy[1])) &&
+	    (l2tp_tunnel_uses_xfrm(tunnel) &&
 	     nla_put_u8(skb, L2TP_ATTR_USING_IPSEC, 1)) ||
-#endif
 	    (session->reorder_timeout &&
 	     nla_put_msecs(skb, L2TP_ATTR_RECV_TIMEOUT,
 			   session->reorder_timeout, L2TP_ATTR_PAD)))
diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c
index 6e2c8e7..62f2d3f 100644
--- a/net/l2tp/l2tp_ppp.c
+++ b/net/l2tp/l2tp_ppp.c
@@ -95,7 +95,6 @@
 #include <net/netns/generic.h>
 #include <net/ip.h>
 #include <net/udp.h>
-#include <net/xfrm.h>
 #include <net/inet_common.h>
 
 #include <asm/byteorder.h>
@@ -758,7 +757,7 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	if (tunnel->peer_tunnel_id == 0)
 		tunnel->peer_tunnel_id = info.peer_tunnel_id;
 
-	session = l2tp_session_get(sock_net(sk), tunnel, info.session_id);
+	session = l2tp_tunnel_get_session(tunnel, info.session_id);
 	if (session) {
 		drop_refcnt = true;
 
@@ -1027,8 +1026,10 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
  ****************************************************************************/
 
 static void pppol2tp_copy_stats(struct pppol2tp_ioc_stats *dest,
-				struct l2tp_stats *stats)
+				const struct l2tp_stats *stats)
 {
+	memset(dest, 0, sizeof(*dest));
+
 	dest->tx_packets = atomic_long_read(&stats->tx_packets);
 	dest->tx_bytes = atomic_long_read(&stats->tx_bytes);
 	dest->tx_errors = atomic_long_read(&stats->tx_errors);
@@ -1039,188 +1040,107 @@ static void pppol2tp_copy_stats(struct pppol2tp_ioc_stats *dest,
 	dest->rx_errors = atomic_long_read(&stats->rx_errors);
 }
 
-/* Session ioctl helper.
- */
-static int pppol2tp_session_ioctl(struct l2tp_session *session,
-				  unsigned int cmd, unsigned long arg)
+static int pppol2tp_tunnel_copy_stats(struct pppol2tp_ioc_stats *stats,
+				      struct l2tp_tunnel *tunnel)
 {
-	int err = 0;
-	struct sock *sk;
-	int val = (int) arg;
-	struct l2tp_tunnel *tunnel = session->tunnel;
-	struct pppol2tp_ioc_stats stats;
+	struct l2tp_session *session;
 
-	l2tp_dbg(session, L2TP_MSG_CONTROL,
-		 "%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n",
-		 session->name, cmd, arg);
+	if (!stats->session_id) {
+		pppol2tp_copy_stats(stats, &tunnel->stats);
+		return 0;
+	}
 
-	sk = pppol2tp_session_get_sock(session);
-	if (!sk)
+	/* If session_id is set, search the corresponding session in the
+	 * context of this tunnel and record the session's statistics.
+	 */
+	session = l2tp_tunnel_get_session(tunnel, stats->session_id);
+	if (!session)
 		return -EBADR;
 
+	if (session->pwtype != L2TP_PWTYPE_PPP) {
+		l2tp_session_dec_refcount(session);
+		return -EBADR;
+	}
+
+	pppol2tp_copy_stats(stats, &session->stats);
+	l2tp_session_dec_refcount(session);
+
+	return 0;
+}
+
+static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd,
+			  unsigned long arg)
+{
+	struct pppol2tp_ioc_stats stats;
+	struct l2tp_session *session;
+	int val;
+
 	switch (cmd) {
 	case PPPIOCGMRU:
 	case PPPIOCGFLAGS:
-		err = -EFAULT;
+		session = sock->sk->sk_user_data;
+		if (!session)
+			return -ENOTCONN;
+
+		/* Not defined for tunnels */
+		if (!session->session_id && !session->peer_session_id)
+			return -ENOSYS;
+
 		if (put_user(0, (int __user *)arg))
-			break;
-		err = 0;
+			return -EFAULT;
 		break;
 
 	case PPPIOCSMRU:
 	case PPPIOCSFLAGS:
-		err = -EFAULT;
+		session = sock->sk->sk_user_data;
+		if (!session)
+			return -ENOTCONN;
+
+		/* Not defined for tunnels */
+		if (!session->session_id && !session->peer_session_id)
+			return -ENOSYS;
+
 		if (get_user(val, (int __user *)arg))
-			break;
-		err = 0;
+			return -EFAULT;
 		break;
 
 	case PPPIOCGL2TPSTATS:
-		err = -ENXIO;
-		if (!(sk->sk_state & PPPOX_CONNECTED))
-			break;
+		session = sock->sk->sk_user_data;
+		if (!session)
+			return -ENOTCONN;
 
-		memset(&stats, 0, sizeof(stats));
-		stats.tunnel_id = tunnel->tunnel_id;
-		stats.session_id = session->session_id;
-		pppol2tp_copy_stats(&stats, &session->stats);
-		if (copy_to_user((void __user *) arg, &stats,
-				 sizeof(stats)))
-			break;
-		l2tp_info(session, L2TP_MSG_CONTROL, "%s: get L2TP stats\n",
-			  session->name);
-		err = 0;
+		/* Session 0 represents the parent tunnel */
+		if (!session->session_id && !session->peer_session_id) {
+			u32 session_id;
+			int err;
+
+			if (copy_from_user(&stats, (void __user *)arg,
+					   sizeof(stats)))
+				return -EFAULT;
+
+			session_id = stats.session_id;
+			err = pppol2tp_tunnel_copy_stats(&stats,
+							 session->tunnel);
+			if (err < 0)
+				return err;
+
+			stats.session_id = session_id;
+		} else {
+			pppol2tp_copy_stats(&stats, &session->stats);
+			stats.session_id = session->session_id;
+		}
+		stats.tunnel_id = session->tunnel->tunnel_id;
+		stats.using_ipsec = l2tp_tunnel_uses_xfrm(session->tunnel);
+
+		if (copy_to_user((void __user *)arg, &stats, sizeof(stats)))
+			return -EFAULT;
 		break;
 
 	default:
-		err = -ENOSYS;
-		break;
+		return -ENOIOCTLCMD;
 	}
 
-	sock_put(sk);
-
-	return err;
-}
-
-/* Tunnel ioctl helper.
- *
- * Note the special handling for PPPIOCGL2TPSTATS below. If the ioctl data
- * specifies a session_id, the session ioctl handler is called. This allows an
- * application to retrieve session stats via a tunnel socket.
- */
-static int pppol2tp_tunnel_ioctl(struct l2tp_tunnel *tunnel,
-				 unsigned int cmd, unsigned long arg)
-{
-	int err = 0;
-	struct sock *sk;
-	struct pppol2tp_ioc_stats stats;
-
-	l2tp_dbg(tunnel, L2TP_MSG_CONTROL,
-		 "%s: pppol2tp_tunnel_ioctl(cmd=%#x, arg=%#lx)\n",
-		 tunnel->name, cmd, arg);
-
-	sk = tunnel->sock;
-	sock_hold(sk);
-
-	switch (cmd) {
-	case PPPIOCGL2TPSTATS:
-		err = -ENXIO;
-		if (!(sk->sk_state & PPPOX_CONNECTED))
-			break;
-
-		if (copy_from_user(&stats, (void __user *) arg,
-				   sizeof(stats))) {
-			err = -EFAULT;
-			break;
-		}
-		if (stats.session_id != 0) {
-			/* resend to session ioctl handler */
-			struct l2tp_session *session =
-				l2tp_session_get(sock_net(sk), tunnel,
-						 stats.session_id);
-
-			if (!session) {
-				err = -EBADR;
-				break;
-			}
-			if (session->pwtype != L2TP_PWTYPE_PPP) {
-				l2tp_session_dec_refcount(session);
-				err = -EBADR;
-				break;
-			}
-
-			err = pppol2tp_session_ioctl(session, cmd, arg);
-			l2tp_session_dec_refcount(session);
-			break;
-		}
-#ifdef CONFIG_XFRM
-		stats.using_ipsec = (sk->sk_policy[0] || sk->sk_policy[1]) ? 1 : 0;
-#endif
-		pppol2tp_copy_stats(&stats, &tunnel->stats);
-		if (copy_to_user((void __user *) arg, &stats, sizeof(stats))) {
-			err = -EFAULT;
-			break;
-		}
-		l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: get L2TP stats\n",
-			  tunnel->name);
-		err = 0;
-		break;
-
-	default:
-		err = -ENOSYS;
-		break;
-	}
-
-	sock_put(sk);
-
-	return err;
-}
-
-/* Main ioctl() handler.
- * Dispatch to tunnel or session helpers depending on the socket.
- */
-static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd,
-			  unsigned long arg)
-{
-	struct sock *sk = sock->sk;
-	struct l2tp_session *session;
-	struct l2tp_tunnel *tunnel;
-	int err;
-
-	if (!sk)
-		return 0;
-
-	err = -EBADF;
-	if (sock_flag(sk, SOCK_DEAD) != 0)
-		goto end;
-
-	err = -ENOTCONN;
-	if ((sk->sk_user_data == NULL) ||
-	    (!(sk->sk_state & (PPPOX_CONNECTED | PPPOX_BOUND))))
-		goto end;
-
-	/* Get session context from the socket */
-	err = -EBADF;
-	session = pppol2tp_sock_to_session(sk);
-	if (session == NULL)
-		goto end;
-
-	/* Special case: if session's session_id is zero, treat ioctl as a
-	 * tunnel ioctl
-	 */
-	if ((session->session_id == 0) &&
-	    (session->peer_session_id == 0)) {
-		tunnel = session->tunnel;
-		err = pppol2tp_tunnel_ioctl(tunnel, cmd, arg);
-		goto end_put_sess;
-	}
-
-	err = pppol2tp_session_ioctl(session, cmd, arg);
-
-end_put_sess:
-	sock_put(sk);
-end:
-	return err;
+	return 0;
 }
 
 /*****************************************************************************