mm: memcontrol: account socket memory in unified hierarchy memory controller

Socket memory can be a significant share of overall memory consumed by
common workloads.  In order to provide reasonable resource isolation in
the unified hierarchy, this type of memory needs to be included in the
tracking/accounting of a cgroup under active memory resource control.

Overhead is only incurred when a non-root control group is created AND
the memory controller is instructed to track and account the memory
footprint of that group.  cgroup.memory=nosocket can be specified on the
boot command line to override any runtime configuration and forcibly
exclude socket memory from active memory resource control.

Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
Acked-by: David S. Miller <davem@davemloft.net>
Reviewed-by: Vladimir Davydov <vdavydov@virtuozzo.com>
Acked-by: Michal Hocko <mhocko@suse.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6aac8d2..60ebc48 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -80,6 +80,9 @@
 
 #define MEM_CGROUP_RECLAIM_RETRIES	5
 
+/* Socket memory accounting disabled? */
+static bool cgroup_memory_nosocket;
+
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
 int do_swap_account __read_mostly;
@@ -1945,6 +1948,26 @@
 	return NOTIFY_OK;
 }
 
+static void reclaim_high(struct mem_cgroup *memcg,
+			 unsigned int nr_pages,
+			 gfp_t gfp_mask)
+{
+	do {	/* handle @memcg first, then each ancestor in turn */
+		if (page_counter_read(&memcg->memory) <= memcg->high)
+			continue;	/* not above its high mark: on to the parent */
+		mem_cgroup_events(memcg, MEMCG_HIGH, 1);
+		try_to_free_mem_cgroup_pages(memcg, nr_pages, gfp_mask, true);
+	} while ((memcg = parent_mem_cgroup(memcg)));
+}
+
+static void high_work_func(struct work_struct *work)
+{
+	struct mem_cgroup *memcg;	/* group embedding @work as high_work */
+
+	memcg = container_of(work, struct mem_cgroup, high_work);
+	reclaim_high(memcg, CHARGE_BATCH, GFP_KERNEL);
+}
+
 /*
  * Scheduled by try_charge() to be executed from the userland return path
  * and reclaims memory over the high limit.
@@ -1952,20 +1975,13 @@
 void mem_cgroup_handle_over_high(void)
 {
 	unsigned int nr_pages = current->memcg_nr_pages_over_high;
-	struct mem_cgroup *memcg, *pos;
+	struct mem_cgroup *memcg;
 
 	if (likely(!nr_pages))
 		return;
 
-	pos = memcg = get_mem_cgroup_from_mm(current->mm);
-
-	do {
-		if (page_counter_read(&pos->memory) <= pos->high)
-			continue;
-		mem_cgroup_events(pos, MEMCG_HIGH, 1);
-		try_to_free_mem_cgroup_pages(pos, nr_pages, GFP_KERNEL, true);
-	} while ((pos = parent_mem_cgroup(pos)));
-
+	memcg = get_mem_cgroup_from_mm(current->mm);
+	reclaim_high(memcg, nr_pages, GFP_KERNEL);
 	css_put(&memcg->css);
 	current->memcg_nr_pages_over_high = 0;
 }
@@ -2100,6 +2116,11 @@
 	 */
 	do {
 		if (page_counter_read(&memcg->memory) > memcg->high) {
+			/* Don't bother a random interrupted task */
+			if (in_interrupt()) {
+				schedule_work(&memcg->high_work);
+				break;
+			}
 			current->memcg_nr_pages_over_high += batch;
 			set_notify_resume(current);
 			break;
@@ -4150,6 +4171,8 @@
 {
 	int node;
 
+	cancel_work_sync(&memcg->high_work);
+
 	mem_cgroup_remove_from_trees(memcg);
 
 	for_each_node(node)
@@ -4196,6 +4219,7 @@
 		page_counter_init(&memcg->kmem, NULL);
 	}
 
+	INIT_WORK(&memcg->high_work, high_work_func);
 	memcg->last_scanned_node = MAX_NUMNODES;
 	INIT_LIST_HEAD(&memcg->oom_notify);
 	memcg->move_charge_at_immigrate = 0;
@@ -4267,6 +4291,11 @@
 	if (ret)
 		return ret;
 
+#ifdef CONFIG_INET
+	if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+		static_key_slow_inc(&memcg_sockets_enabled_key);
+#endif
+
 	/*
 	 * Make sure the memcg is initialized: mem_cgroup_iter()
 	 * orders reading memcg->initialized against its callers
@@ -4313,6 +4342,10 @@
 	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
 	memcg_destroy_kmem(memcg);
+#ifdef CONFIG_INET
+	if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+		static_key_slow_dec(&memcg_sockets_enabled_key);
+#endif
 	__mem_cgroup_free(memcg);
 }
 
@@ -5533,8 +5566,7 @@
 	commit_charge(newpage, memcg, true);
 }
 
-/* Writing them here to avoid exposing memcg's inner layout */
-#if defined(CONFIG_INET) && defined(CONFIG_MEMCG_KMEM)
+#ifdef CONFIG_INET
 
 struct static_key memcg_sockets_enabled_key;
 EXPORT_SYMBOL(memcg_sockets_enabled_key);
@@ -5559,10 +5591,15 @@
 
 	rcu_read_lock();
 	memcg = mem_cgroup_from_task(current);
-	if (memcg != root_mem_cgroup &&
-	    memcg->tcp_mem.active &&
-	    css_tryget_online(&memcg->css))
+	if (memcg == root_mem_cgroup)
+		goto out;
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg->tcp_mem.active)
+		goto out;
+#endif
+	if (css_tryget_online(&memcg->css))
 		sk->sk_memcg = memcg;
+out:
 	rcu_read_unlock();
 }
 EXPORT_SYMBOL(sock_update_memcg);
@@ -5583,15 +5620,30 @@
  */
 bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
-	struct page_counter *counter;
+	gfp_t gfp_mask = GFP_KERNEL;
 
-	if (page_counter_try_charge(&memcg->tcp_mem.memory_allocated,
-				    nr_pages, &counter)) {
-		memcg->tcp_mem.memory_pressure = 0;
-		return true;
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+		struct page_counter *counter;
+
+		if (page_counter_try_charge(&memcg->tcp_mem.memory_allocated,
+					    nr_pages, &counter)) {
+			memcg->tcp_mem.memory_pressure = 0;
+			return true;
+		}
+		page_counter_charge(&memcg->tcp_mem.memory_allocated, nr_pages);
+		memcg->tcp_mem.memory_pressure = 1;
+		return false;
 	}
-	page_counter_charge(&memcg->tcp_mem.memory_allocated, nr_pages);
-	memcg->tcp_mem.memory_pressure = 1;
+#endif
+	/* Don't block in the packet receive path */
+	if (in_softirq())
+		gfp_mask = GFP_NOWAIT;
+
+	if (try_charge(memcg, gfp_mask, nr_pages) == 0)
+		return true;
+
+	try_charge(memcg, gfp_mask|__GFP_NOFAIL, nr_pages);
 	return false;
 }
 
@@ -5602,10 +5654,32 @@
  */
 void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
-	page_counter_uncharge(&memcg->tcp_mem.memory_allocated, nr_pages);
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+		page_counter_uncharge(&memcg->tcp_mem.memory_allocated,
+				      nr_pages);
+		return;
+	}
+#endif
+	page_counter_uncharge(&memcg->memory, nr_pages);
+	css_put_many(&memcg->css, nr_pages);
 }
 
-#endif
+#endif /* CONFIG_INET */
+
+static int __init cgroup_memory(char *s)
+{
+	char *token;
+
+	while ((token = strsep(&s, ",")) != NULL) {
+		if (!*token)
+			continue;	/* empty entry, e.g. from ",," */
+		if (!strcmp(token, "nosocket"))
+			cgroup_memory_nosocket = true;
+	}
+	return 1;	/* __setup() convention: 1 == option consumed */
+}
+__setup("cgroup.memory=", cgroup_memory);
 
 /*
  * subsys_initcall() for memory controller.