net: tcp_memcontrol: simplify linkage between socket and page counter

No protocol other than TCP is going to get separate socket memory
counters in the future, so the per-protocol cg_proto indirection
(sk->sk_cgrp, proto->proto_cgroup) serves no purpose.  Remove it and
link sockets directly to their owning memory cgroup via sk->sk_memcg.

Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
Reviewed-by: Vladimir Davydov <vdavydov@virtuozzo.com>
Acked-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/net/core/sock.c b/net/core/sock.c
index 89ae859..3535bff 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -195,44 +195,6 @@
 }
 EXPORT_SYMBOL(sk_net_capable);
 
-
-#ifdef CONFIG_MEMCG_KMEM
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-	struct proto *proto;
-	int ret = 0;
-
-	mutex_lock(&proto_list_mutex);
-	list_for_each_entry(proto, &proto_list, node) {
-		if (proto->init_cgroup) {
-			ret = proto->init_cgroup(memcg, ss);
-			if (ret)
-				goto out;
-		}
-	}
-
-	mutex_unlock(&proto_list_mutex);
-	return ret;
-out:
-	list_for_each_entry_continue_reverse(proto, &proto_list, node)
-		if (proto->destroy_cgroup)
-			proto->destroy_cgroup(memcg);
-	mutex_unlock(&proto_list_mutex);
-	return ret;
-}
-
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
-{
-	struct proto *proto;
-
-	mutex_lock(&proto_list_mutex);
-	list_for_each_entry_reverse(proto, &proto_list, node)
-		if (proto->destroy_cgroup)
-			proto->destroy_cgroup(memcg);
-	mutex_unlock(&proto_list_mutex);
-}
-#endif
-
 /*
  * Each address family might have different locking rules, so we have
  * one slock key per address family:
@@ -1601,7 +1563,7 @@
 		sk_set_socket(newsk, NULL);
 		newsk->sk_wq = NULL;
 
-		if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+		if (mem_cgroup_sockets_enabled && sk->sk_memcg)
 			sock_update_memcg(newsk);
 
 		if (newsk->sk_prot->sockets_allocated)
@@ -2089,8 +2051,8 @@
 
 	allocated = sk_memory_allocated_add(sk, amt);
 
-	if (mem_cgroup_sockets_enabled && sk->sk_cgrp &&
-	    !mem_cgroup_charge_skmem(sk->sk_cgrp, amt))
+	if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+	    !mem_cgroup_charge_skmem(sk->sk_memcg, amt))
 		goto suppress_allocation;
 
 	/* Under limit. */
@@ -2153,8 +2115,8 @@
 
 	sk_memory_allocated_sub(sk, amt);
 
-	if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-		mem_cgroup_uncharge_skmem(sk->sk_cgrp, amt);
+	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+		mem_cgroup_uncharge_skmem(sk->sk_memcg, amt);
 
 	return 0;
 }
@@ -2171,8 +2133,8 @@
 	sk_memory_allocated_sub(sk, amount);
 	sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
 
-	if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-		mem_cgroup_uncharge_skmem(sk->sk_cgrp, amount);
+	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+		mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
 
 	if (sk_under_memory_pressure(sk) &&
 	    (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index eb39e02..c7d1fb5 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1819,7 +1819,7 @@
 
 	sk_sockets_allocated_dec(sk);
 
-	if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
 		sock_release_memcg(sk);
 }
 EXPORT_SYMBOL(tcp_v4_destroy_sock);
@@ -2344,11 +2344,6 @@
 	.compat_setsockopt	= compat_tcp_setsockopt,
 	.compat_getsockopt	= compat_tcp_getsockopt,
 #endif
-#ifdef CONFIG_MEMCG_KMEM
-	.init_cgroup		= tcp_init_cgroup,
-	.destroy_cgroup		= tcp_destroy_cgroup,
-	.proto_cgroup		= tcp_proto_cgroup,
-#endif
 	.diag_destroy		= tcp_abort,
 };
 EXPORT_SYMBOL(tcp_prot);
diff --git a/net/ipv4/tcp_memcontrol.c b/net/ipv4/tcp_memcontrol.c
index ef4268d..e507825 100644
--- a/net/ipv4/tcp_memcontrol.c
+++ b/net/ipv4/tcp_memcontrol.c
@@ -8,60 +8,47 @@
 
 int tcp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
 {
+	struct mem_cgroup *parent = parent_mem_cgroup(memcg);
+	struct page_counter *counter_parent = NULL;
 	/*
 	 * The root cgroup does not use page_counters, but rather,
 	 * rely on the data already collected by the network
 	 * subsystem
 	 */
-	struct mem_cgroup *parent = parent_mem_cgroup(memcg);
-	struct page_counter *counter_parent = NULL;
-	struct cg_proto *cg_proto, *parent_cg;
-
-	cg_proto = tcp_prot.proto_cgroup(memcg);
-	if (!cg_proto)
+	if (memcg == root_mem_cgroup)
 		return 0;
 
-	cg_proto->memory_pressure = 0;
-	cg_proto->memcg = memcg;
+	memcg->tcp_mem.memory_pressure = 0;
 
-	parent_cg = tcp_prot.proto_cgroup(parent);
-	if (parent_cg)
-		counter_parent = &parent_cg->memory_allocated;
+	if (parent)
+		counter_parent = &parent->tcp_mem.memory_allocated;
 
-	page_counter_init(&cg_proto->memory_allocated, counter_parent);
+	page_counter_init(&memcg->tcp_mem.memory_allocated, counter_parent);
 
 	return 0;
 }
-EXPORT_SYMBOL(tcp_init_cgroup);
 
 void tcp_destroy_cgroup(struct mem_cgroup *memcg)
 {
-	struct cg_proto *cg_proto;
-
-	cg_proto = tcp_prot.proto_cgroup(memcg);
-	if (!cg_proto)
+	if (memcg == root_mem_cgroup)
 		return;
 
-	if (cg_proto->active)
+	if (memcg->tcp_mem.active)
 		static_key_slow_dec(&memcg_socket_limit_enabled);
-
 }
-EXPORT_SYMBOL(tcp_destroy_cgroup);
 
 static int tcp_update_limit(struct mem_cgroup *memcg, unsigned long nr_pages)
 {
-	struct cg_proto *cg_proto;
 	int ret;
 
-	cg_proto = tcp_prot.proto_cgroup(memcg);
-	if (!cg_proto)
+	if (memcg == root_mem_cgroup)
 		return -EINVAL;
 
-	ret = page_counter_limit(&cg_proto->memory_allocated, nr_pages);
+	ret = page_counter_limit(&memcg->tcp_mem.memory_allocated, nr_pages);
 	if (ret)
 		return ret;
 
-	if (!cg_proto->active) {
+	if (!memcg->tcp_mem.active) {
 		/*
 		 * The active flag needs to be written after the static_key
 		 * update. This is what guarantees that the socket activation
@@ -79,7 +66,7 @@
 		 * patched in yet.
 		 */
 		static_key_slow_inc(&memcg_socket_limit_enabled);
-		cg_proto->active = true;
+		memcg->tcp_mem.active = true;
 	}
 
 	return 0;
@@ -123,32 +110,32 @@
 static u64 tcp_cgroup_read(struct cgroup_subsys_state *css, struct cftype *cft)
 {
 	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
-	struct cg_proto *cg_proto = tcp_prot.proto_cgroup(memcg);
 	u64 val;
 
 	switch (cft->private) {
 	case RES_LIMIT:
-		if (!cg_proto)
-			return PAGE_COUNTER_MAX;
-		val = cg_proto->memory_allocated.limit;
+		if (memcg == root_mem_cgroup)
+			val = PAGE_COUNTER_MAX;
+		else
+			val = memcg->tcp_mem.memory_allocated.limit;
 		val *= PAGE_SIZE;
 		break;
 	case RES_USAGE:
-		if (!cg_proto)
+		if (memcg == root_mem_cgroup)
 			val = atomic_long_read(&tcp_memory_allocated);
 		else
-			val = page_counter_read(&cg_proto->memory_allocated);
+			val = page_counter_read(&memcg->tcp_mem.memory_allocated);
 		val *= PAGE_SIZE;
 		break;
 	case RES_FAILCNT:
-		if (!cg_proto)
+		if (memcg == root_mem_cgroup)
 			return 0;
-		val = cg_proto->memory_allocated.failcnt;
+		val = memcg->tcp_mem.memory_allocated.failcnt;
 		break;
 	case RES_MAX_USAGE:
-		if (!cg_proto)
+		if (memcg == root_mem_cgroup)
 			return 0;
-		val = cg_proto->memory_allocated.watermark;
+		val = memcg->tcp_mem.memory_allocated.watermark;
 		val *= PAGE_SIZE;
 		break;
 	default:
@@ -161,19 +148,17 @@
 				char *buf, size_t nbytes, loff_t off)
 {
 	struct mem_cgroup *memcg;
-	struct cg_proto *cg_proto;
 
 	memcg = mem_cgroup_from_css(of_css(of));
-	cg_proto = tcp_prot.proto_cgroup(memcg);
-	if (!cg_proto)
+	if (memcg == root_mem_cgroup)
 		return nbytes;
 
 	switch (of_cft(of)->private) {
 	case RES_MAX_USAGE:
-		page_counter_reset_watermark(&cg_proto->memory_allocated);
+		page_counter_reset_watermark(&memcg->tcp_mem.memory_allocated);
 		break;
 	case RES_FAILCNT:
-		cg_proto->memory_allocated.failcnt = 0;
+		memcg->tcp_mem.memory_allocated.failcnt = 0;
 		break;
 	}
 
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 493b489..fda379c 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -2821,8 +2821,8 @@
 	sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
 	sk_memory_allocated_add(sk, amt);
 
-	if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-		mem_cgroup_charge_skmem(sk->sk_cgrp, amt);
+	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+		mem_cgroup_charge_skmem(sk->sk_memcg, amt);
 }
 
 /* Send a FIN. The caller locks the socket for us.
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index db9f1c3..4ad8edb 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -1889,9 +1889,6 @@
 	.compat_setsockopt	= compat_tcp_setsockopt,
 	.compat_getsockopt	= compat_tcp_getsockopt,
 #endif
-#ifdef CONFIG_MEMCG_KMEM
-	.proto_cgroup		= tcp_proto_cgroup,
-#endif
 	.clear_sk		= tcp_v6_clear_sk,
 	.diag_destroy		= tcp_abort,
 };