soreuseport: fast reuseport UDP socket selection

Include a struct sock_reuseport instance when a UDP socket binds to
a specific address for the first time with the reuseport flag set.
When selecting a socket for an incoming UDP packet, use the information
available in sock_reuseport if present.

This required adding an additional parameter (match_wildcard) to the UDP
source address equality function to differentiate between exact and
wildcard matches.
The original use case allowed wildcard matches when checking for
existing port uses during bind.  The new use case of adding a socket
to a reuseport group requires exact address matching.

Performance test (using a machine with 2 CPU sockets and a total of
48 cores):  Create reuseport groups of varying size.  Use one socket
from this group per user thread (pinning each thread to a different
core) calling recvmmsg in a tight loop.  Record number of messages
received per second while saturating a 10G link.
  10 sockets: 18% increase (~2.8M -> 3.3M pkts/s)
  20 sockets: 14% increase (~2.9M -> 3.3M pkts/s)
  40 sockets: 13% increase (~3.0M -> 3.4M pkts/s)

This work is based on a similar implementation written by
Ying Cai <ycai@google.com> for implementing policy-based reuseport
selection.

Signed-off-by: Craig Gallek <kraig@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index ac14ae4..762b01f 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -113,6 +113,7 @@
 #include <trace/events/skb.h>
 #include <net/busy_poll.h>
 #include "udp_impl.h"
+#include <net/sock_reuseport.h>
 
 struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
@@ -137,7 +138,8 @@
 			       unsigned long *bitmap,
 			       struct sock *sk,
 			       int (*saddr_comp)(const struct sock *sk1,
-						 const struct sock *sk2),
+						 const struct sock *sk2,
+						 bool match_wildcard),
 			       unsigned int log)
 {
 	struct sock *sk2;
@@ -152,8 +154,9 @@
 		    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
 		     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
 		    (!sk2->sk_reuseport || !sk->sk_reuseport ||
+		     rcu_access_pointer(sk->sk_reuseport_cb) ||
 		     !uid_eq(uid, sock_i_uid(sk2))) &&
-		    saddr_comp(sk, sk2)) {
+		    saddr_comp(sk, sk2, true)) {
 			if (!bitmap)
 				return 1;
 			__set_bit(udp_sk(sk2)->udp_port_hash >> log, bitmap);
@@ -170,7 +173,8 @@
 				struct udp_hslot *hslot2,
 				struct sock *sk,
 				int (*saddr_comp)(const struct sock *sk1,
-						  const struct sock *sk2))
+						  const struct sock *sk2,
+						  bool match_wildcard))
 {
 	struct sock *sk2;
 	struct hlist_nulls_node *node;
@@ -186,8 +190,9 @@
 		    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
 		     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
 		    (!sk2->sk_reuseport || !sk->sk_reuseport ||
+		     rcu_access_pointer(sk->sk_reuseport_cb) ||
 		     !uid_eq(uid, sock_i_uid(sk2))) &&
-		    saddr_comp(sk, sk2)) {
+		    saddr_comp(sk, sk2, true)) {
 			res = 1;
 			break;
 		}
@@ -196,6 +201,35 @@
 	return res;
 }
 
+static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
+				  int (*saddr_same)(const struct sock *sk1,
+						    const struct sock *sk2,
+						    bool match_wildcard))
+{
+	struct net *net = sock_net(sk);
+	struct hlist_nulls_node *node;
+	kuid_t uid = sock_i_uid(sk);
+	struct sock *sk2;
+
+	sk_nulls_for_each(sk2, node, &hslot->head) {
+		if (net_eq(sock_net(sk2), net) &&
+		    sk2 != sk &&
+		    sk2->sk_family == sk->sk_family &&
+		    ipv6_only_sock(sk2) == ipv6_only_sock(sk) &&
+		    (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) &&
+		    (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+		    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
+		    (*saddr_same)(sk, sk2, false)) {
+			return reuseport_add_sock(sk, sk2);
+		}
+	}
+
+	/* Initial allocation may have already happened via setsockopt */
+	if (!rcu_access_pointer(sk->sk_reuseport_cb))
+		return reuseport_alloc(sk);
+	return 0;
+}
+
 /**
  *  udp_lib_get_port  -  UDP/-Lite port lookup for IPv4 and IPv6
  *
@@ -207,7 +241,8 @@
  */
 int udp_lib_get_port(struct sock *sk, unsigned short snum,
 		     int (*saddr_comp)(const struct sock *sk1,
-				       const struct sock *sk2),
+				       const struct sock *sk2,
+				       bool match_wildcard),
 		     unsigned int hash2_nulladdr)
 {
 	struct udp_hslot *hslot, *hslot2;
@@ -290,6 +325,14 @@
 	udp_sk(sk)->udp_port_hash = snum;
 	udp_sk(sk)->udp_portaddr_hash ^= snum;
 	if (sk_unhashed(sk)) {
+		if (sk->sk_reuseport &&
+		    udp_reuseport_add_sock(sk, hslot, saddr_comp)) {
+			inet_sk(sk)->inet_num = 0;
+			udp_sk(sk)->udp_port_hash = 0;
+			udp_sk(sk)->udp_portaddr_hash ^= snum;
+			goto fail_unlock;
+		}
+
 		sk_nulls_add_node_rcu(sk, &hslot->head);
 		hslot->count++;
 		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
@@ -309,13 +352,22 @@
 }
 EXPORT_SYMBOL(udp_lib_get_port);
 
-static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
+/* match_wildcard == true:  0.0.0.0 equals any IPv4 address
+ * match_wildcard == false: addresses must be exactly the same, i.e.
+ *                          0.0.0.0 only equals 0.0.0.0
+ */
+static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
+				bool match_wildcard)
 {
 	struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
 
-	return 	(!ipv6_only_sock(sk2)  &&
-		 (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr ||
-		   inet1->inet_rcv_saddr == inet2->inet_rcv_saddr));
+	if (!ipv6_only_sock(sk2)) {
+		if (inet1->inet_rcv_saddr == inet2->inet_rcv_saddr)
+			return 1;
+		if (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr)
+			return match_wildcard;
+	}
+	return 0;
 }
 
 static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr,
@@ -459,8 +511,14 @@
 			badness = score;
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
+				struct sock *sk2;
 				hash = udp_ehashfn(net, daddr, hnum,
 						   saddr, sport);
+				sk2 = reuseport_select_sock(sk, hash);
+				if (sk2) {
+					result = sk2;
+					goto found;
+				}
 				matches = 1;
 			}
 		} else if (score == badness && reuseport) {
@@ -478,6 +536,7 @@
 	if (get_nulls_value(node) != slot2)
 		goto begin;
 	if (result) {
+found:
 		if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
 			result = NULL;
 		else if (unlikely(compute_score2(result, net, saddr, sport,
@@ -540,8 +599,14 @@
 			badness = score;
 			reuseport = sk->sk_reuseport;
 			if (reuseport) {
+				struct sock *sk2;
 				hash = udp_ehashfn(net, daddr, hnum,
 						   saddr, sport);
+				sk2 = reuseport_select_sock(sk, hash);
+				if (sk2) {
+					result = sk2;
+					goto found;
+				}
 				matches = 1;
 			}
 		} else if (score == badness && reuseport) {
@@ -560,6 +625,7 @@
 		goto begin;
 
 	if (result) {
+found:
 		if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
 			result = NULL;
 		else if (unlikely(compute_score(result, net, saddr, hnum, sport,
@@ -587,7 +653,8 @@
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 			     __be32 daddr, __be16 dport, int dif)
 {
-	return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table);
+	return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif,
+				 &udp_table);
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 
@@ -1398,6 +1465,8 @@
 		hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);
 
 		spin_lock_bh(&hslot->lock);
+		if (rcu_access_pointer(sk->sk_reuseport_cb))
+			reuseport_detach_sock(sk);
 		if (sk_nulls_del_node_init_rcu(sk)) {
 			hslot->count--;
 			inet_sk(sk)->inet_num = 0;
@@ -1425,22 +1494,28 @@
 		hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);
 		nhslot2 = udp_hashslot2(udptable, newhash);
 		udp_sk(sk)->udp_portaddr_hash = newhash;
-		if (hslot2 != nhslot2) {
+
+		if (hslot2 != nhslot2 ||
+		    rcu_access_pointer(sk->sk_reuseport_cb)) {
 			hslot = udp_hashslot(udptable, sock_net(sk),
 					     udp_sk(sk)->udp_port_hash);
 			/* we must lock primary chain too */
 			spin_lock_bh(&hslot->lock);
+			if (rcu_access_pointer(sk->sk_reuseport_cb))
+				reuseport_detach_sock(sk);
 
-			spin_lock(&hslot2->lock);
-			hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node);
-			hslot2->count--;
-			spin_unlock(&hslot2->lock);
+			if (hslot2 != nhslot2) {
+				spin_lock(&hslot2->lock);
+				hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node);
+				hslot2->count--;
+				spin_unlock(&hslot2->lock);
 
-			spin_lock(&nhslot2->lock);
-			hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node,
-						 &nhslot2->head);
-			nhslot2->count++;
-			spin_unlock(&nhslot2->lock);
+				spin_lock(&nhslot2->lock);
+				hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node,
+							 &nhslot2->head);
+				nhslot2->count++;
+				spin_unlock(&nhslot2->lock);
+			}
 
 			spin_unlock_bh(&hslot->lock);
 		}