tcp: md5: protect md5sig_info with RCU

This patch makes sure we use appropriate memory barriers before
publishing tp->md5sig_info, allowing tcp_md5_do_lookup() to be used from
tcp_v4_send_reset() without holding the socket lock (upcoming patch from
Shawn Lu).

Note that we also need to respect an RCU grace period before freeing
md5sig_info, since the socket itself can be freed without such a grace
period thanks to SLAB_DESTROY_BY_RCU.

Signed-off-by: Eric Dumazet <eric.dumazet@gmail.com>
Cc: Shawn Lu <shawn.lu@ericsson.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index da5d322..567cca9 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -879,14 +879,18 @@
 	struct tcp_md5sig_key *key;
 	struct hlist_node *pos;
 	unsigned int size = sizeof(struct in_addr);
+	struct tcp_md5sig_info *md5sig;
 
-	if (!tp->md5sig_info)
+	/* caller either holds rcu_read_lock() or socket lock */
+	md5sig = rcu_dereference_check(tp->md5sig_info,
+				       sock_owned_by_user(sk));
+	if (!md5sig)
 		return NULL;
 #if IS_ENABLED(CONFIG_IPV6)
 	if (family == AF_INET6)
 		size = sizeof(struct in6_addr);
 #endif
-	hlist_for_each_entry_rcu(key, pos, &tp->md5sig_info->head, node) {
+	hlist_for_each_entry_rcu(key, pos, &md5sig->head, node) {
 		if (key->family != family)
 			continue;
 		if (!memcmp(&key->addr, addr, size))
@@ -932,7 +936,8 @@
 		return 0;
 	}
 
-	md5sig = tp->md5sig_info;
+	md5sig = rcu_dereference_protected(tp->md5sig_info,
+					   sock_owned_by_user(sk));
 	if (!md5sig) {
 		md5sig = kmalloc(sizeof(*md5sig), gfp);
 		if (!md5sig)
@@ -940,7 +945,7 @@
 
 		sk_nocaps_add(sk, NETIF_F_GSO_MASK);
 		INIT_HLIST_HEAD(&md5sig->head);
-		tp->md5sig_info = md5sig;
+		rcu_assign_pointer(tp->md5sig_info, md5sig);
 	}
 
 	key = sock_kmalloc(sk, sizeof(*key), gfp);
@@ -966,6 +971,7 @@
 {
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_key *key;
+	struct tcp_md5sig_info *md5sig;
 
 	key = tcp_md5_do_lookup(sk, (union tcp_md5_addr *)&addr, AF_INET);
 	if (!key)
@@ -973,7 +979,9 @@
 	hlist_del_rcu(&key->node);
 	atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
 	kfree_rcu(key, rcu);
-	if (hlist_empty(&tp->md5sig_info->head))
+	md5sig = rcu_dereference_protected(tp->md5sig_info,
+					   sock_owned_by_user(sk));
+	if (hlist_empty(&md5sig->head))
 		tcp_free_md5sig_pool();
 	return 0;
 }
@@ -984,10 +992,13 @@
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_key *key;
 	struct hlist_node *pos, *n;
+	struct tcp_md5sig_info *md5sig;
 
-	if (!hlist_empty(&tp->md5sig_info->head))
+	md5sig = rcu_dereference_protected(tp->md5sig_info, 1);
+
+	if (!hlist_empty(&md5sig->head))
 		tcp_free_md5sig_pool();
-	hlist_for_each_entry_safe(key, pos, n, &tp->md5sig_info->head, node) {
+	hlist_for_each_entry_safe(key, pos, n, &md5sig->head, node) {
 		hlist_del_rcu(&key->node);
 		atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
 		kfree_rcu(key, rcu);
@@ -1009,12 +1020,9 @@
 	if (sin->sin_family != AF_INET)
 		return -EINVAL;
 
-	if (!cmd.tcpm_key || !cmd.tcpm_keylen) {
-		if (!tcp_sk(sk)->md5sig_info)
-			return -ENOENT;
+	if (!cmd.tcpm_key || !cmd.tcpm_keylen)
 		return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin->sin_addr.s_addr,
 				      AF_INET);
-	}
 
 	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
 		return -EINVAL;
@@ -1896,7 +1904,7 @@
 	/* Clean up the MD5 key list, if any */
 	if (tp->md5sig_info) {
 		tcp_clear_md5_list(sk);
-		kfree(tp->md5sig_info);
+		kfree_rcu(tp->md5sig_info, rcu);
 		tp->md5sig_info = NULL;
 	}
 #endif