net: Convert TCP/DCCP listening hash tables to use RCU

This is the last step needed to perform fully RCU-protected lookups
in __inet_lookup(): after the established/timewait tables, we now
add RCU lookups to the listening hash table.

The only trick here is that a socket of a given type (TCP IPv4,
TCP IPv6, ...) can now move between two different tables
(established and listening) during an RCU grace period, so we
must use different 'nulls' end-of-chain values for the two tables.

We define a large value:

#define LISTENING_NULLS_BASE (1U << 29)

so that slots in the listening table are guaranteed to have
end-of-chain values different from those of slots in the established
table. A reader can thus still detect whether it finished its lookup
in the right chain.

Signed-off-by: Eric Dumazet <dada1@cosmosbay.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index e0fd681..8fe267f 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -33,7 +33,7 @@
 
 		ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
 		spin_lock(&ilb->lock);
-		__sk_add_node(sk, &ilb->head);
+		__sk_nulls_add_node_rcu(sk, &ilb->head);
 		spin_unlock(&ilb->lock);
 	} else {
 		unsigned int hash;
@@ -118,47 +118,90 @@
 }
 EXPORT_SYMBOL(__inet6_lookup_established);
 
+/*
+ * compute_score - rate how well listening socket @sk matches a lookup.
+ *
+ * Returns -1 if @sk cannot take it: other netns, num other than @hnum,
+ * not PF_INET6, or bound to an rcv_saddr other than @daddr or to a
+ * device other than @dif.  Otherwise returns 1, plus one for a bound
+ * rcv_saddr and one for a bound device, so the most specific listener
+ * scores highest (3 at most).
+ */
+static inline int compute_score(struct sock *sk, struct net *net,
+				const unsigned short hnum,
+				const struct in6_addr *daddr,
+				const int dif)
+{
+	int score = -1;
+
+	if (net_eq(sock_net(sk), net) && inet_sk(sk)->num == hnum &&
+	    sk->sk_family == PF_INET6) {
+		const struct ipv6_pinfo *np = inet6_sk(sk);
+
+		score = 1;
+		if (!ipv6_addr_any(&np->rcv_saddr)) {
+			if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
+				return -1;
+			score++;
+		}
+		if (sk->sk_bound_dev_if) {
+			if (sk->sk_bound_dev_if != dif)
+				return -1;
+			score++;
+		}
+	}
+	return score;
+}
+
 struct sock *inet6_lookup_listener(struct net *net,
 		struct inet_hashinfo *hashinfo, const struct in6_addr *daddr,
 		const unsigned short hnum, const int dif)
 {
 	struct sock *sk;
-	const struct hlist_node *node;
-	struct sock *result = NULL;
-	int score, hiscore = 0;
-	struct inet_listen_hashbucket *ilb;
+	const struct hlist_nulls_node *node;
+	struct sock *result;
+	int score, hiscore;
+	unsigned int hash = inet_lhashfn(net, hnum);
+	struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
 
-	ilb = &hashinfo->listening_hash[inet_lhashfn(net, hnum)];
-	spin_lock(&ilb->lock);
-	sk_for_each(sk, node, &ilb->head) {
-		if (net_eq(sock_net(sk), net) && inet_sk(sk)->num == hnum &&
-				sk->sk_family == PF_INET6) {
-			const struct ipv6_pinfo *np = inet6_sk(sk);
-
-			score = 1;
-			if (!ipv6_addr_any(&np->rcv_saddr)) {
-				if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
-					continue;
-				score++;
-			}
-			if (sk->sk_bound_dev_if) {
-				if (sk->sk_bound_dev_if != dif)
-					continue;
-				score++;
-			}
-			if (score == 3) {
-				result = sk;
-				break;
-			}
-			if (score > hiscore) {
-				hiscore = score;
-				result = sk;
-			}
+	/*
+	 * Lockless walk: only RCU protects the chain, so the _rcu iterator
+	 * is required.  A socket may move to another chain (even to the
+	 * established table) while we walk it; the nulls check catches that.
+	 */
+	rcu_read_lock();
+begin:
+	result = NULL;
+	hiscore = -1;
+	sk_nulls_for_each_rcu(sk, node, &ilb->head) {
+		score = compute_score(sk, net, hnum, daddr, dif);
+		if (score > hiscore) {
+			hiscore = score;
+			result = sk;
 		}
 	}
-	if (result)
-		sock_hold(result);
-	spin_unlock(&ilb->lock);
+	/*
+	 * If the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != hash + LISTENING_NULLS_BASE)
+		goto begin;
+	/*
+	 * Pin the socket unless its refcount already dropped to zero, then
+	 * score it again: it may have changed after we picked it without a
+	 * reference.  If it now scores worse, drop it and restart the lookup.
+	 */
+	if (result) {
+		if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
+			result = NULL;
+		else if (unlikely(compute_score(result, net, hnum, daddr,
+				  dif) < hiscore)) {
+			sock_put(result);
+			goto begin;
+		}
+	}
+	rcu_read_unlock();
 	return result;
 }