rds: tcp: use rds_destroy_pending() to synchronize netns/module teardown and rds connection/workq management

An rds_connection can get added during netns deletion between lines 528
and 529 of

  506 static void rds_tcp_kill_sock(struct net *net)
  :
  /* code to pull out all the rds_connections that should be destroyed */
  :
  528         spin_unlock_irq(&rds_tcp_conn_lock);
  529         list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
  530                 rds_conn_destroy(tc->t_cpath->cp_conn);

Such an rds_connection would be missed by the rds_conn_destroy()
loop (which cancels all pending work) and, if its work was scheduled
after the netns deletion, could trigger a use-after-free.

A similar race window exists for the module unload path
in rds_tcp_exit -> rds_tcp_destroy_conns.

Concurrency with netns deletion (rds_tcp_kill_sock()) must be handled
by checking check_net() before enqueuing new work or adding new
connections.
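
Work-enqueue sites therefore take the following form (this is the
pattern used throughout the diff below; the rcu_read_lock() section
pairs with the synchronize_rcu() in the teardown paths):

  rcu_read_lock();
  if (!rds_destroy_pending(cp->cp_conn))
          queue_delayed_work(rds_wq, &cp->cp_recv_w, 0);
  rcu_read_unlock();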

Concurrency with module unload is handled by maintaining a
module-specific flag that is set at the start of the module's exit
function and checked before enqueuing new work or adding new
connections.
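
For the TCP transport the flag and its accessors are as below
(rds_ib gains an analogous rds_ib_unloading flag, and each transport
exposes its check through the new t_unloading hook in struct
rds_transport):

  static atomic_t rds_tcp_unloading = ATOMIC_INIT(0);

  static void rds_tcp_set_unloading(void)
  {
          atomic_set(&rds_tcp_unloading, 1);
  }

  static bool rds_tcp_is_unloading(struct rds_connection *conn)
  {
          return atomic_read(&rds_tcp_unloading) != 0;
  }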

This commit refactors existing RDS_DESTROY_PENDING checks added by
commit 3db6e0d172c9 ("rds: use RCU to synchronize work-enqueue with
connection teardown") and consolidates all the concurrency checks
listed above into the function rds_destroy_pending().
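
The consolidated helper, added to net/rds/rds.h below, checks both
the netns and the module-unload condition:

  static inline bool rds_destroy_pending(struct rds_connection *conn)
  {
          return !check_net(rds_conn_net(conn)) ||
                 (conn->c_trans->t_unloading && conn->c_trans->t_unloading(conn));
  }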

Signed-off-by: Sowmini Varadhan <sowmini.varadhan@oracle.com>
Acked-by: Santosh Shilimkar <santosh.shilimkar@oracle.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/rds/cong.c b/net/rds/cong.c
index 8d19fd2..63da9d2 100644
--- a/net/rds/cong.c
+++ b/net/rds/cong.c
@@ -223,7 +223,7 @@
 
 		rcu_read_lock();
 		if (!test_and_set_bit(0, &conn->c_map_queued) &&
-		    !test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+		    !rds_destroy_pending(cp->cp_conn)) {
 			rds_stats_inc(s_cong_update_queued);
 			/* We cannot inline the call to rds_send_xmit() here
 			 * for two reasons (both pertaining to a TCP transport):
diff --git a/net/rds/connection.c b/net/rds/connection.c
index b10c0ef..94e190fe 100644
--- a/net/rds/connection.c
+++ b/net/rds/connection.c
@@ -220,8 +220,13 @@
 				     is_outgoing);
 		conn->c_path[i].cp_index = i;
 	}
-	ret = trans->conn_alloc(conn, gfp);
+	rcu_read_lock();
+	if (rds_destroy_pending(conn))
+		ret = -ENETDOWN;
+	else
+		ret = trans->conn_alloc(conn, gfp);
 	if (ret) {
+		rcu_read_unlock();
 		kfree(conn->c_path);
 		kmem_cache_free(rds_conn_slab, conn);
 		conn = ERR_PTR(ret);
@@ -283,6 +288,7 @@
 		}
 	}
 	spin_unlock_irqrestore(&rds_conn_lock, flags);
+	rcu_read_unlock();
 
 out:
 	return conn;
@@ -382,13 +388,10 @@
 {
 	struct rds_message *rm, *rtmp;
 
-	set_bit(RDS_DESTROY_PENDING, &cp->cp_flags);
-
 	if (!cp->cp_transport_data)
 		return;
 
 	/* make sure lingering queued work won't try to ref the conn */
-	synchronize_rcu();
 	cancel_delayed_work_sync(&cp->cp_send_w);
 	cancel_delayed_work_sync(&cp->cp_recv_w);
 
@@ -691,7 +694,7 @@
 	atomic_set(&cp->cp_state, RDS_CONN_ERROR);
 
 	rcu_read_lock();
-	if (!destroy && test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+	if (!destroy && rds_destroy_pending(cp->cp_conn)) {
 		rcu_read_unlock();
 		return;
 	}
@@ -714,7 +717,7 @@
 void rds_conn_path_connect_if_down(struct rds_conn_path *cp)
 {
 	rcu_read_lock();
-	if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+	if (rds_destroy_pending(cp->cp_conn)) {
 		rcu_read_unlock();
 		return;
 	}
diff --git a/net/rds/ib.c b/net/rds/ib.c
index ff0c980..50a88f3 100644
--- a/net/rds/ib.c
+++ b/net/rds/ib.c
@@ -48,6 +48,7 @@
 static unsigned int rds_ib_mr_1m_pool_size = RDS_MR_1M_POOL_SIZE;
 static unsigned int rds_ib_mr_8k_pool_size = RDS_MR_8K_POOL_SIZE;
 unsigned int rds_ib_retry_count = RDS_IB_DEFAULT_RETRY_COUNT;
+static atomic_t rds_ib_unloading;
 
 module_param(rds_ib_mr_1m_pool_size, int, 0444);
 MODULE_PARM_DESC(rds_ib_mr_1m_pool_size, " Max number of 1M mr per HCA");
@@ -378,8 +379,23 @@
 	flush_workqueue(rds_wq);
 }
 
+static void rds_ib_set_unloading(void)
+{
+	atomic_set(&rds_ib_unloading, 1);
+}
+
+static bool rds_ib_is_unloading(struct rds_connection *conn)
+{
+	struct rds_conn_path *cp = &conn->c_path[0];
+
+	return (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags) ||
+		atomic_read(&rds_ib_unloading) != 0);
+}
+
 void rds_ib_exit(void)
 {
+	rds_ib_set_unloading();
+	synchronize_rcu();
 	rds_info_deregister_func(RDS_INFO_IB_CONNECTIONS, rds_ib_ic_info);
 	rds_ib_unregister_client();
 	rds_ib_destroy_nodev_conns();
@@ -413,6 +429,7 @@
 	.flush_mrs		= rds_ib_flush_mrs,
 	.t_owner		= THIS_MODULE,
 	.t_name			= "infiniband",
+	.t_unloading		= rds_ib_is_unloading,
 	.t_type			= RDS_TRANS_IB
 };
 
diff --git a/net/rds/ib_cm.c b/net/rds/ib_cm.c
index 80fb6f6..eea1d86 100644
--- a/net/rds/ib_cm.c
+++ b/net/rds/ib_cm.c
@@ -117,6 +117,7 @@
 			  &conn->c_laddr, &conn->c_faddr,
 			  RDS_PROTOCOL_MAJOR(conn->c_version),
 			  RDS_PROTOCOL_MINOR(conn->c_version));
+		set_bit(RDS_DESTROY_PENDING, &conn->c_path[0].cp_flags);
 		rds_conn_destroy(conn);
 		return;
 	} else {
diff --git a/net/rds/rds.h b/net/rds/rds.h
index 374ae83..7301b9b 100644
--- a/net/rds/rds.h
+++ b/net/rds/rds.h
@@ -518,6 +518,7 @@
 	void (*sync_mr)(void *trans_private, int direction);
 	void (*free_mr)(void *trans_private, int invalidate);
 	void (*flush_mrs)(void);
+	bool (*t_unloading)(struct rds_connection *conn);
 };
 
 struct rds_sock {
@@ -862,6 +863,12 @@
 		__rds_put_mr_final(mr);
 }
 
+static inline bool rds_destroy_pending(struct rds_connection *conn)
+{
+	return !check_net(rds_conn_net(conn)) ||
+	       (conn->c_trans->t_unloading && conn->c_trans->t_unloading(conn));
+}
+
 /* stats.c */
 DECLARE_PER_CPU_SHARED_ALIGNED(struct rds_statistics, rds_stats);
 #define rds_stats_inc_which(which, member) do {		\
diff --git a/net/rds/send.c b/net/rds/send.c
index d3e32d1..b1b0022 100644
--- a/net/rds/send.c
+++ b/net/rds/send.c
@@ -162,7 +162,7 @@
 		goto out;
 	}
 
-	if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+	if (rds_destroy_pending(cp->cp_conn)) {
 		release_in_xmit(cp);
 		ret = -ENETUNREACH; /* dont requeue send work */
 		goto out;
@@ -444,7 +444,7 @@
 			if (batch_count < send_batch_count)
 				goto restart;
 			rcu_read_lock();
-			if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+			if (rds_destroy_pending(cp->cp_conn))
 				ret = -ENETUNREACH;
 			else
 				queue_delayed_work(rds_wq, &cp->cp_send_w, 1);
@@ -1162,7 +1162,7 @@
 	else
 		cpath = &conn->c_path[0];
 
-	if (test_bit(RDS_DESTROY_PENDING, &cpath->cp_flags)) {
+	if (rds_destroy_pending(conn)) {
 		ret = -EAGAIN;
 		goto out;
 	}
@@ -1209,7 +1209,7 @@
 	if (ret == -ENOMEM || ret == -EAGAIN) {
 		ret = 0;
 		rcu_read_lock();
-		if (test_bit(RDS_DESTROY_PENDING, &cpath->cp_flags))
+		if (rds_destroy_pending(cpath->cp_conn))
 			ret = -ENETUNREACH;
 		else
 			queue_delayed_work(rds_wq, &cpath->cp_send_w, 1);
@@ -1295,7 +1295,7 @@
 
 	/* schedule the send work on rds_wq */
 	rcu_read_lock();
-	if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+	if (!rds_destroy_pending(cp->cp_conn))
 		queue_delayed_work(rds_wq, &cp->cp_send_w, 1);
 	rcu_read_unlock();
 
diff --git a/net/rds/tcp.c b/net/rds/tcp.c
index 9920d2f..44c4652 100644
--- a/net/rds/tcp.c
+++ b/net/rds/tcp.c
@@ -49,6 +49,7 @@
 /* Track rds_tcp_connection structs so they can be cleaned up */
 static DEFINE_SPINLOCK(rds_tcp_conn_lock);
 static LIST_HEAD(rds_tcp_conn_list);
+static atomic_t rds_tcp_unloading = ATOMIC_INIT(0);
 
 static struct kmem_cache *rds_tcp_conn_slab;
 
@@ -274,14 +275,13 @@
 static void rds_tcp_conn_free(void *arg)
 {
 	struct rds_tcp_connection *tc = arg;
-	unsigned long flags;
 
 	rdsdebug("freeing tc %p\n", tc);
 
-	spin_lock_irqsave(&rds_tcp_conn_lock, flags);
+	spin_lock_bh(&rds_tcp_conn_lock);
 	if (!tc->t_tcp_node_detached)
 		list_del(&tc->t_tcp_node);
-	spin_unlock_irqrestore(&rds_tcp_conn_lock, flags);
+	spin_unlock_bh(&rds_tcp_conn_lock);
 
 	kmem_cache_free(rds_tcp_conn_slab, tc);
 }
@@ -296,7 +296,7 @@
 		tc = kmem_cache_alloc(rds_tcp_conn_slab, gfp);
 		if (!tc) {
 			ret = -ENOMEM;
-			break;
+			goto fail;
 		}
 		mutex_init(&tc->t_conn_path_lock);
 		tc->t_sock = NULL;
@@ -306,14 +306,19 @@
 
 		conn->c_path[i].cp_transport_data = tc;
 		tc->t_cpath = &conn->c_path[i];
+		tc->t_tcp_node_detached = true;
 
-		spin_lock_irq(&rds_tcp_conn_lock);
-		tc->t_tcp_node_detached = false;
-		list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
-		spin_unlock_irq(&rds_tcp_conn_lock);
 		rdsdebug("rds_conn_path [%d] tc %p\n", i,
 			 conn->c_path[i].cp_transport_data);
 	}
+	spin_lock_bh(&rds_tcp_conn_lock);
+	for (i = 0; i < RDS_MPATH_WORKERS; i++) {
+		tc = conn->c_path[i].cp_transport_data;
+		tc->t_tcp_node_detached = false;
+		list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
+	}
+	spin_unlock_bh(&rds_tcp_conn_lock);
+fail:
 	if (ret) {
 		for (j = 0; j < i; j++)
 			rds_tcp_conn_free(conn->c_path[j].cp_transport_data);
@@ -332,6 +337,16 @@
 	return false;
 }
 
+static void rds_tcp_set_unloading(void)
+{
+	atomic_set(&rds_tcp_unloading, 1);
+}
+
+static bool rds_tcp_is_unloading(struct rds_connection *conn)
+{
+	return atomic_read(&rds_tcp_unloading) != 0;
+}
+
 static void rds_tcp_destroy_conns(void)
 {
 	struct rds_tcp_connection *tc, *_tc;
@@ -370,6 +385,7 @@
 	.t_type			= RDS_TRANS_TCP,
 	.t_prefer_loopback	= 1,
 	.t_mp_capable		= 1,
+	.t_unloading		= rds_tcp_is_unloading,
 };
 
 static unsigned int rds_tcp_netid;
@@ -513,7 +529,7 @@
 
 	rtn->rds_tcp_listen_sock = NULL;
 	rds_tcp_listen_stop(lsock, &rtn->rds_tcp_accept_w);
-	spin_lock_irq(&rds_tcp_conn_lock);
+	spin_lock_bh(&rds_tcp_conn_lock);
 	list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
 		struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -526,7 +542,7 @@
 			tc->t_tcp_node_detached = true;
 		}
 	}
-	spin_unlock_irq(&rds_tcp_conn_lock);
+	spin_unlock_bh(&rds_tcp_conn_lock);
 	list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
 		rds_conn_destroy(tc->t_cpath->cp_conn);
 }
@@ -574,7 +590,7 @@
 {
 	struct rds_tcp_connection *tc, *_tc;
 
-	spin_lock_irq(&rds_tcp_conn_lock);
+	spin_lock_bh(&rds_tcp_conn_lock);
 	list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
 		struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -584,7 +600,7 @@
 		/* reconnect with new parameters */
 		rds_conn_path_drop(tc->t_cpath, false);
 	}
-	spin_unlock_irq(&rds_tcp_conn_lock);
+	spin_unlock_bh(&rds_tcp_conn_lock);
 }
 
 static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write,
@@ -607,6 +623,8 @@
 
 static void rds_tcp_exit(void)
 {
+	rds_tcp_set_unloading();
+	synchronize_rcu();
 	rds_info_deregister_func(RDS_INFO_TCP_SOCKETS, rds_tcp_tc_info);
 	unregister_pernet_subsys(&rds_tcp_net_ops);
 	if (unregister_netdevice_notifier(&rds_tcp_dev_notifier))
diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
index 534c67a..d999e70 100644
--- a/net/rds/tcp_connect.c
+++ b/net/rds/tcp_connect.c
@@ -170,7 +170,7 @@
 		 cp->cp_conn, tc, sock);
 
 	if (sock) {
-		if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+		if (rds_destroy_pending(cp->cp_conn))
 			rds_tcp_set_linger(sock);
 		sock->ops->shutdown(sock, RCV_SHUTDOWN | SEND_SHUTDOWN);
 		lock_sock(sock->sk);
diff --git a/net/rds/tcp_recv.c b/net/rds/tcp_recv.c
index dd707b9..b9fbd2e 100644
--- a/net/rds/tcp_recv.c
+++ b/net/rds/tcp_recv.c
@@ -323,7 +323,7 @@
 
 	if (rds_tcp_read_sock(cp, GFP_ATOMIC) == -ENOMEM) {
 		rcu_read_lock();
-		if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+		if (!rds_destroy_pending(cp->cp_conn))
 			queue_delayed_work(rds_wq, &cp->cp_recv_w, 0);
 		rcu_read_unlock();
 	}
diff --git a/net/rds/tcp_send.c b/net/rds/tcp_send.c
index 16f6574..7df869d 100644
--- a/net/rds/tcp_send.c
+++ b/net/rds/tcp_send.c
@@ -204,7 +204,7 @@
 
 	rcu_read_lock();
 	if ((refcount_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf &&
-	    !test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+	    !rds_destroy_pending(cp->cp_conn))
 		queue_delayed_work(rds_wq, &cp->cp_send_w, 0);
 	rcu_read_unlock();
 
diff --git a/net/rds/threads.c b/net/rds/threads.c
index eb76db1..c52861d 100644
--- a/net/rds/threads.c
+++ b/net/rds/threads.c
@@ -88,7 +88,7 @@
 	cp->cp_reconnect_jiffies = 0;
 	set_bit(0, &cp->cp_conn->c_map_queued);
 	rcu_read_lock();
-	if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+	if (!rds_destroy_pending(cp->cp_conn)) {
 		queue_delayed_work(rds_wq, &cp->cp_send_w, 0);
 		queue_delayed_work(rds_wq, &cp->cp_recv_w, 0);
 	}
@@ -138,7 +138,7 @@
 	if (cp->cp_reconnect_jiffies == 0) {
 		cp->cp_reconnect_jiffies = rds_sysctl_reconnect_min_jiffies;
 		rcu_read_lock();
-		if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+		if (!rds_destroy_pending(cp->cp_conn))
 			queue_delayed_work(rds_wq, &cp->cp_conn_w, 0);
 		rcu_read_unlock();
 		return;
@@ -149,7 +149,7 @@
 		 rand % cp->cp_reconnect_jiffies, cp->cp_reconnect_jiffies,
 		 conn, &conn->c_laddr, &conn->c_faddr);
 	rcu_read_lock();
-	if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+	if (!rds_destroy_pending(cp->cp_conn))
 		queue_delayed_work(rds_wq, &cp->cp_conn_w,
 				   rand % cp->cp_reconnect_jiffies);
 	rcu_read_unlock();