udp: dynamically size hash tables at boot time

UDP_HTABLE_SIZE was initialy defined to 128, which is a bit small for
several setups.

4000 active UDP sockets -> 32 sockets per chain in average. An
incoming frame has to lookup all sockets to find best match, so long
chains hurt latency.

Instead of a fixed size hash table that cant be perfect for every
needs, let UDP stack choose its table size at boot time like tcp/ip
route, using alloc_large_system_hash() helper

Add an optional boot parameter, uhash_entries=x so that an admin can
force a size between 256 and 65536 if needed, like thash_entries and
rhash_entries.

dmesg logs two new lines :
[    0.647039] UDP hash table entries: 512 (order: 0, 4096 bytes)
[    0.647099] UDP Lite hash table entries: 512 (order: 0, 4096 bytes)

Maximal size on 64bit arches would be 65536 slots, ie 1 MBytes for non
debugging spinlocks.

Signed-off-by: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/Documentation/kernel-parameters.txt b/Documentation/kernel-parameters.txt
index 6fa7292..02df20b 100644
--- a/Documentation/kernel-parameters.txt
+++ b/Documentation/kernel-parameters.txt
@@ -2589,6 +2589,9 @@
 	uart6850=	[HW,OSS]
 			Format: <io>,<irq>
 
+	uhash_entries=	[KNL,NET]
+			Set number of hash buckets for UDP/UDP-Lite connections
+
 	uhci-hcd.ignore_oc=
 			[USB] Ignore overcurrent events (default N).
 			Some badly-designed motherboards generate lots of
diff --git a/include/linux/udp.h b/include/linux/udp.h
index 0cf5c4c..832361e 100644
--- a/include/linux/udp.h
+++ b/include/linux/udp.h
@@ -45,11 +45,11 @@
 	return (struct udphdr *)skb_transport_header(skb);
 }
 
-#define UDP_HTABLE_SIZE		128
+#define UDP_HTABLE_SIZE_MIN		(CONFIG_BASE_SMALL ? 128 : 256)
 
-static inline int udp_hashfn(struct net *net, const unsigned num)
+static inline int udp_hashfn(struct net *net, unsigned num, unsigned mask)
 {
-	return (num + net_hash_mix(net)) & (UDP_HTABLE_SIZE - 1);
+	return (num + net_hash_mix(net)) & mask;
 }
 
 struct udp_sock {
diff --git a/include/net/udp.h b/include/net/udp.h
index f98abd2..22aa2e7 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -54,12 +54,19 @@
 	struct hlist_nulls_head	head;
 	spinlock_t		lock;
 } __attribute__((aligned(2 * sizeof(long))));
+
 struct udp_table {
-	struct udp_hslot	hash[UDP_HTABLE_SIZE];
+	struct udp_hslot	*hash;
+	unsigned int mask;
+	unsigned int log;
 };
 extern struct udp_table udp_table;
-extern void udp_table_init(struct udp_table *);
-
+extern void udp_table_init(struct udp_table *, const char *);
+static inline struct udp_hslot *udp_hashslot(struct udp_table *table,
+					     struct net *net, unsigned num)
+{
+	return &table->hash[udp_hashfn(net, num, table->mask)];
+}
 
 /* Note: this must match 'valbool' in sock_setsockopt */
 #define UDP_CSUM_NOXMIT		1
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 6ec6a8a..194bcdc 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -106,7 +106,7 @@
 #include <net/xfrm.h>
 #include "udp_impl.h"
 
-struct udp_table udp_table;
+struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
 
 int sysctl_udp_mem[3] __read_mostly;
@@ -121,14 +121,16 @@
 atomic_t udp_memory_allocated;
 EXPORT_SYMBOL(udp_memory_allocated);
 
-#define PORTS_PER_CHAIN (65536 / UDP_HTABLE_SIZE)
+#define MAX_UDP_PORTS 65536
+#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN)
 
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
 			       const struct udp_hslot *hslot,
 			       unsigned long *bitmap,
 			       struct sock *sk,
 			       int (*saddr_comp)(const struct sock *sk1,
-						 const struct sock *sk2))
+						 const struct sock *sk2),
+			       unsigned int log)
 {
 	struct sock *sk2;
 	struct hlist_nulls_node *node;
@@ -142,8 +144,7 @@
 			|| sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
 		    (*saddr_comp)(sk, sk2)) {
 			if (bitmap)
-				__set_bit(sk2->sk_hash / UDP_HTABLE_SIZE,
-					  bitmap);
+				__set_bit(sk2->sk_hash >> log, bitmap);
 			else
 				return 1;
 		}
@@ -180,13 +181,15 @@
 		/*
 		 * force rand to be an odd multiple of UDP_HTABLE_SIZE
 		 */
-		rand = (rand | 1) * UDP_HTABLE_SIZE;
-		for (last = first + UDP_HTABLE_SIZE; first != last; first++) {
-			hslot = &udptable->hash[udp_hashfn(net, first)];
+		rand = (rand | 1) * (udptable->mask + 1);
+		for (last = first + udptable->mask + 1;
+		     first != last;
+		     first++) {
+			hslot = udp_hashslot(udptable, net, first);
 			bitmap_zero(bitmap, PORTS_PER_CHAIN);
 			spin_lock_bh(&hslot->lock);
 			udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
-					    saddr_comp);
+					    saddr_comp, udptable->log);
 
 			snum = first;
 			/*
@@ -196,7 +199,7 @@
 			 */
 			do {
 				if (low <= snum && snum <= high &&
-				    !test_bit(snum / UDP_HTABLE_SIZE, bitmap))
+				    !test_bit(snum >> udptable->log, bitmap))
 					goto found;
 				snum += rand;
 			} while (snum != first);
@@ -204,9 +207,10 @@
 		}
 		goto fail;
 	} else {
-		hslot = &udptable->hash[udp_hashfn(net, snum)];
+		hslot = udp_hashslot(udptable, net, snum);
 		spin_lock_bh(&hslot->lock);
-		if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, saddr_comp))
+		if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk,
+					saddr_comp, 0))
 			goto fail_unlock;
 	}
 found:
@@ -283,7 +287,7 @@
 	struct sock *sk, *result;
 	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
-	unsigned int hash = udp_hashfn(net, hnum);
+	unsigned int hash = udp_hashfn(net, hnum, udptable->mask);
 	struct udp_hslot *hslot = &udptable->hash[hash];
 	int score, badness;
 
@@ -1013,8 +1017,8 @@
 {
 	if (sk_hashed(sk)) {
 		struct udp_table *udptable = sk->sk_prot->h.udp_table;
-		unsigned int hash = udp_hashfn(sock_net(sk), sk->sk_hash);
-		struct udp_hslot *hslot = &udptable->hash[hash];
+		struct udp_hslot *hslot = udp_hashslot(udptable, sock_net(sk),
+						     sk->sk_hash);
 
 		spin_lock_bh(&hslot->lock);
 		if (sk_nulls_del_node_init_rcu(sk)) {
@@ -1169,7 +1173,7 @@
 				    struct udp_table *udptable)
 {
 	struct sock *sk;
-	struct udp_hslot *hslot = &udptable->hash[udp_hashfn(net, ntohs(uh->dest))];
+	struct udp_hslot *hslot = udp_hashslot(udptable, net, ntohs(uh->dest));
 	int dif;
 
 	spin_lock(&hslot->lock);
@@ -1609,9 +1613,14 @@
 	struct udp_iter_state *state = seq->private;
 	struct net *net = seq_file_net(seq);
 
-	for (state->bucket = start; state->bucket < UDP_HTABLE_SIZE; ++state->bucket) {
+	for (state->bucket = start; state->bucket <= state->udp_table->mask;
+	     ++state->bucket) {
 		struct hlist_nulls_node *node;
 		struct udp_hslot *hslot = &state->udp_table->hash[state->bucket];
+
+		if (hlist_nulls_empty(&hslot->head))
+			continue;
+
 		spin_lock_bh(&hslot->lock);
 		sk_nulls_for_each(sk, node, &hslot->head) {
 			if (!net_eq(sock_net(sk), net))
@@ -1636,7 +1645,7 @@
 	} while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != state->family));
 
 	if (!sk) {
-		if (state->bucket < UDP_HTABLE_SIZE)
+		if (state->bucket <= state->udp_table->mask)
 			spin_unlock_bh(&state->udp_table->hash[state->bucket].lock);
 		return udp_get_first(seq, state->bucket + 1);
 	}
@@ -1656,7 +1665,7 @@
 static void *udp_seq_start(struct seq_file *seq, loff_t *pos)
 {
 	struct udp_iter_state *state = seq->private;
-	state->bucket = UDP_HTABLE_SIZE;
+	state->bucket = MAX_UDP_PORTS;
 
 	return *pos ? udp_get_idx(seq, *pos-1) : SEQ_START_TOKEN;
 }
@@ -1678,7 +1687,7 @@
 {
 	struct udp_iter_state *state = seq->private;
 
-	if (state->bucket < UDP_HTABLE_SIZE)
+	if (state->bucket <= state->udp_table->mask)
 		spin_unlock_bh(&state->udp_table->hash[state->bucket].lock);
 }
 
@@ -1738,7 +1747,7 @@
 	__u16 destp	  = ntohs(inet->dport);
 	__u16 srcp	  = ntohs(inet->sport);
 
-	seq_printf(f, "%4d: %08X:%04X %08X:%04X"
+	seq_printf(f, "%5d: %08X:%04X %08X:%04X"
 		" %02X %08X:%08X %02X:%08lX %08X %5d %8d %lu %d %p %d%n",
 		bucket, src, srcp, dest, destp, sp->sk_state,
 		sk_wmem_alloc_get(sp),
@@ -1804,11 +1813,43 @@
 }
 #endif /* CONFIG_PROC_FS */
 
-void __init udp_table_init(struct udp_table *table)
+static __initdata unsigned long uhash_entries;
+static int __init set_uhash_entries(char *str)
 {
-	int i;
+	if (!str)
+		return 0;
+	uhash_entries = simple_strtoul(str, &str, 0);
+	if (uhash_entries && uhash_entries < UDP_HTABLE_SIZE_MIN)
+		uhash_entries = UDP_HTABLE_SIZE_MIN;
+	return 1;
+}
+__setup("uhash_entries=", set_uhash_entries);
 
-	for (i = 0; i < UDP_HTABLE_SIZE; i++) {
+void __init udp_table_init(struct udp_table *table, const char *name)
+{
+	unsigned int i;
+
+	if (!CONFIG_BASE_SMALL)
+		table->hash = alloc_large_system_hash(name,
+			sizeof(struct udp_hslot),
+			uhash_entries,
+			21, /* one slot per 2 MB */
+			0,
+			&table->log,
+			&table->mask,
+			64 * 1024);
+	/*
+	 * Make sure hash table has the minimum size
+	 */
+	if (CONFIG_BASE_SMALL || table->mask < UDP_HTABLE_SIZE_MIN - 1) {
+		table->hash = kmalloc(UDP_HTABLE_SIZE_MIN *
+				      sizeof(struct udp_hslot), GFP_KERNEL);
+		if (!table->hash)
+			panic(name);
+		table->log = ilog2(UDP_HTABLE_SIZE_MIN);
+		table->mask = UDP_HTABLE_SIZE_MIN - 1;
+	}
+	for (i = 0; i <= table->mask; i++) {
 		INIT_HLIST_NULLS_HEAD(&table->hash[i].head, i);
 		spin_lock_init(&table->hash[i].lock);
 	}
@@ -1818,7 +1859,7 @@
 {
 	unsigned long nr_pages, limit;
 
-	udp_table_init(&udp_table);
+	udp_table_init(&udp_table, "UDP");
 	/* Set the pressure threshold up by the same strategy of TCP. It is a
 	 * fraction of global memory that is up to 1/2 at 256 MB, decreasing
 	 * toward zero with the amount of memory, with a floor of 128 pages.
diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c
index 95248d7..470c504 100644
--- a/net/ipv4/udplite.c
+++ b/net/ipv4/udplite.c
@@ -12,7 +12,7 @@
  */
 #include "udp_impl.h"
 
-struct udp_table 	udplite_table;
+struct udp_table 	udplite_table __read_mostly;
 EXPORT_SYMBOL(udplite_table);
 
 static int udplite_rcv(struct sk_buff *skb)
@@ -110,7 +110,7 @@
 
 void __init udplite4_register(void)
 {
-	udp_table_init(&udplite_table);
+	udp_table_init(&udplite_table, "UDP-Lite");
 	if (proto_register(&udplite_prot, 1))
 		goto out_register_err;
 
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index c6a303e..ff778c1 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -132,7 +132,7 @@
 	struct sock *sk, *result;
 	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
-	unsigned int hash = udp_hashfn(net, hnum);
+	unsigned int hash = udp_hashfn(net, hnum, udptable->mask);
 	struct udp_hslot *hslot = &udptable->hash[hash];
 	int score, badness;
 
@@ -452,7 +452,7 @@
 {
 	struct sock *sk, *sk2;
 	const struct udphdr *uh = udp_hdr(skb);
-	struct udp_hslot *hslot = &udptable->hash[udp_hashfn(net, ntohs(uh->dest))];
+	struct udp_hslot *hslot = udp_hashslot(udptable, net, ntohs(uh->dest));
 	int dif;
 
 	spin_lock(&hslot->lock);
@@ -1197,7 +1197,7 @@
 	destp = ntohs(inet->dport);
 	srcp  = ntohs(inet->sport);
 	seq_printf(seq,
-		   "%4d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X "
+		   "%5d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X "
 		   "%02X %08X:%08X %02X:%08lX %08X %5d %8d %lu %d %p %d\n",
 		   bucket,
 		   src->s6_addr32[0], src->s6_addr32[1],