netfilter: nf_tables: 64bit stats need some extra synchronization

Use generic u64_stats_sync infrastructure to get proper 64bit stats,
even on 32bit arches, at no extra cost for 64bit arches.

Without this fix, 32bit arches can have some wrong counters at the time
the carry is propagated into upper word.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
diff --git a/include/net/netfilter/nf_tables.h b/include/net/netfilter/nf_tables.h
index 713b0b8..c4d8619 100644
--- a/include/net/netfilter/nf_tables.h
+++ b/include/net/netfilter/nf_tables.h
@@ -6,6 +6,7 @@
 #include <linux/netfilter/nfnetlink.h>
 #include <linux/netfilter/x_tables.h>
 #include <linux/netfilter/nf_tables.h>
+#include <linux/u64_stats_sync.h>
 #include <net/netlink.h>
 
 #define NFT_JUMP_STACK_SIZE	16
@@ -528,8 +529,9 @@
 };
 
 struct nft_stats {
-	u64 bytes;
-	u64 pkts;
+	u64			bytes;
+	u64			pkts;
+	struct u64_stats_sync	syncp;
 };
 
 #define NFT_HOOK_OPS_MAX		2
diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c
index ac03d74..8746ff9 100644
--- a/net/netfilter/nf_tables_api.c
+++ b/net/netfilter/nf_tables_api.c
@@ -644,13 +644,20 @@
 {
 	struct nft_stats *cpu_stats, total;
 	struct nlattr *nest;
+	unsigned int seq;
+	u64 pkts, bytes;
 	int cpu;
 
 	memset(&total, 0, sizeof(total));
 	for_each_possible_cpu(cpu) {
 		cpu_stats = per_cpu_ptr(stats, cpu);
-		total.pkts += cpu_stats->pkts;
-		total.bytes += cpu_stats->bytes;
+		do {
+			seq = u64_stats_fetch_begin_irq(&cpu_stats->syncp);
+			pkts = cpu_stats->pkts;
+			bytes = cpu_stats->bytes;
+		} while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, seq));
+		total.pkts += pkts;
+		total.bytes += bytes;
 	}
 	nest = nla_nest_start(skb, NFTA_CHAIN_COUNTERS);
 	if (nest == NULL)
@@ -875,7 +882,7 @@
 	if (!tb[NFTA_COUNTER_BYTES] || !tb[NFTA_COUNTER_PACKETS])
 		return ERR_PTR(-EINVAL);
 
-	newstats = alloc_percpu(struct nft_stats);
+	newstats = netdev_alloc_pcpu_stats(struct nft_stats);
 	if (newstats == NULL)
 		return ERR_PTR(-ENOMEM);
 
@@ -1091,7 +1098,7 @@
 			}
 			basechain->stats = stats;
 		} else {
-			stats = alloc_percpu(struct nft_stats);
+			stats = netdev_alloc_pcpu_stats(struct nft_stats);
 			if (IS_ERR(stats)) {
 				module_put(type->owner);
 				kfree(basechain);
diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c
index 345acfb..3b90eb2 100644
--- a/net/netfilter/nf_tables_core.c
+++ b/net/netfilter/nf_tables_core.c
@@ -109,7 +109,7 @@
 	struct nft_data data[NFT_REG_MAX + 1];
 	unsigned int stackptr = 0;
 	struct nft_jumpstack jumpstack[NFT_JUMP_STACK_SIZE];
-	struct nft_stats __percpu *stats;
+	struct nft_stats *stats;
 	int rulenum;
 	/*
 	 * Cache cursor to avoid problems in case that the cursor is updated
@@ -205,9 +205,11 @@
 		nft_trace_packet(pkt, basechain, -1, NFT_TRACE_POLICY);
 
 	rcu_read_lock_bh();
-	stats = rcu_dereference(nft_base_chain(basechain)->stats);
-	__this_cpu_inc(stats->pkts);
-	__this_cpu_add(stats->bytes, pkt->skb->len);
+	stats = this_cpu_ptr(rcu_dereference(nft_base_chain(basechain)->stats));
+	u64_stats_update_begin(&stats->syncp);
+	stats->pkts++;
+	stats->bytes += pkt->skb->len;
+	u64_stats_update_end(&stats->syncp);
 	rcu_read_unlock_bh();
 
 	return nft_base_chain(basechain)->policy;