l2tp: fix race in l2tp_recv_common()

commit 61b9a047729bb230978178bca6729689d0c50ca2 upstream.

Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically look up a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d186df2 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: Guillaume Nault <g.nault@alphalink.fr>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Amit Pundir <amit.pundir@linaro.org>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index e702cb95..046a5ba 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -278,6 +278,55 @@
 }
 EXPORT_SYMBOL_GPL(l2tp_session_find);
 
+/* Like l2tp_session_find() but takes a reference on the returned session and,
+ * if do_ref is true and session->ref is set, calls session->ref() as well.
+ */
+struct l2tp_session *l2tp_session_get(struct net *net,
+				      struct l2tp_tunnel *tunnel,
+				      u32 session_id, bool do_ref)
+{
+	struct hlist_head *session_list;
+	struct l2tp_session *session;
+
+	if (!tunnel) {	/* no tunnel given: search the per-net hash */
+		struct l2tp_net *pn = l2tp_pernet(net);
+
+		session_list = l2tp_session_id_hash_2(pn, session_id);
+
+		rcu_read_lock_bh();
+		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
+			if (session->session_id == session_id) {
+				l2tp_session_inc_refcount(session);
+				if (do_ref && session->ref)
+					session->ref(session);
+				rcu_read_unlock_bh();
+
+				return session;	/* caller must drop the ref */
+			}
+		}
+		rcu_read_unlock_bh();
+
+		return NULL;	/* no match in the per-net hash */
+	}
+
+	session_list = l2tp_session_id_hash(tunnel, session_id);
+	read_lock_bh(&tunnel->hlist_lock);	/* held across lookup + ref */
+	hlist_for_each_entry(session, session_list, hlist) {
+		if (session->session_id == session_id) {
+			l2tp_session_inc_refcount(session);
+			if (do_ref && session->ref)
+				session->ref(session);
+			read_unlock_bh(&tunnel->hlist_lock);
+
+			return session;	/* caller must drop the ref */
+		}
+	}
+	read_unlock_bh(&tunnel->hlist_lock);
+
+	return NULL;	/* no match in this tunnel's hash */
+}
+EXPORT_SYMBOL_GPL(l2tp_session_get);
+
 struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
 					  bool do_ref)
 {
@@ -637,6 +686,9 @@
  * a data (not control) frame before coming here. Fields up to the
  * session-id have already been parsed and ptr points to the data
  * after the session-id.
+ *
+ * session->ref() must have been called prior to l2tp_recv_common().
+ * session->deref() will be called automatically after skb is processed.
  */
 void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 		      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -646,14 +698,6 @@
 	int offset;
 	u32 ns, nr;
 
-	/* The ref count is increased since we now hold a pointer to
-	 * the session. Take care to decrement the refcnt when exiting
-	 * this function from now on...
-	 */
-	l2tp_session_inc_refcount(session);
-	if (session->ref)
-		(*session->ref)(session);
-
 	/* Parse and check optional cookie */
 	if (session->peer_cookie_len > 0) {
 		if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -806,8 +850,6 @@
 	/* Try to dequeue as many skbs from reorder_q as we can. */
 	l2tp_recv_dequeue(session);
 
-	l2tp_session_dec_refcount(session);
-
 	return;
 
 discard:
@@ -816,8 +858,6 @@
 
 	if (session->deref)
 		(*session->deref)(session);
-
-	l2tp_session_dec_refcount(session);
 }
 EXPORT_SYMBOL(l2tp_recv_common);
 
@@ -924,8 +964,14 @@
 	}
 
 	/* Find the session context */
-	session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
 	if (!session || !session->recv_skb) {
+		if (session) {
+			if (session->deref)
+				session->deref(session);
+			l2tp_session_dec_refcount(session);
+		}
+
 		/* Not found? Pass to userspace to deal with */
 		l2tp_info(tunnel, L2TP_MSG_DATA,
 			  "%s: no session found (%u/%u). Passing up.\n",
@@ -934,6 +980,7 @@
 	}
 
 	l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
+	l2tp_session_dec_refcount(session);
 
 	return 0;