net: convert sock.sk_wmem_alloc from atomic_t to refcount_t

The refcount_t type and its corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This helps avoid accidental
refcounter overflows that might lead to use-after-free
situations.

Signed-off-by: Elena Reshetova <elena.reshetova@intel.com>
Signed-off-by: Hans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: Kees Cook <keescook@chromium.org>
Signed-off-by: David Windsor <dwindsor@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 9a9c395..1d79137 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -861,12 +861,11 @@ void tcp_wfree(struct sk_buff *skb)
 	struct sock *sk = skb->sk;
 	struct tcp_sock *tp = tcp_sk(sk);
 	unsigned long flags, nval, oval;
-	int wmem;
 
 	/* Keep one reference on sk_wmem_alloc.
 	 * Will be released by sk_free() from here or tcp_tasklet_func()
 	 */
-	wmem = atomic_sub_return(skb->truesize - 1, &sk->sk_wmem_alloc);
+	WARN_ON(refcount_sub_and_test(skb->truesize - 1, &sk->sk_wmem_alloc));
 
 	/* If this softirq is serviced by ksoftirqd, we are likely under stress.
 	 * Wait until our queues (qdisc + devices) are drained.
@@ -875,7 +874,7 @@ void tcp_wfree(struct sk_buff *skb)
 	 * - chance for incoming ACK (processed by another cpu maybe)
 	 *   to migrate this flow (skb->ooo_okay will be eventually set)
 	 */
-	if (wmem >= SKB_TRUESIZE(1) && this_cpu_ksoftirqd() == current)
+	if (refcount_read(&sk->sk_wmem_alloc) >= SKB_TRUESIZE(1) && this_cpu_ksoftirqd() == current)
 		goto out;
 
 	for (oval = READ_ONCE(sk->sk_tsq_flags);; oval = nval) {
@@ -925,7 +924,7 @@ enum hrtimer_restart tcp_pace_kick(struct hrtimer *timer)
 		if (nval != oval)
 			continue;
 
-		if (!atomic_inc_not_zero(&sk->sk_wmem_alloc))
+		if (!refcount_inc_not_zero(&sk->sk_wmem_alloc))
 			break;
 		/* queue this socket to tasklet queue */
 		tsq = this_cpu_ptr(&tsq_tasklet);
@@ -1045,7 +1044,7 @@ static int tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, int clone_it,
 	skb->sk = sk;
 	skb->destructor = skb_is_tcp_pure_ack(skb) ? __sock_wfree : tcp_wfree;
 	skb_set_hash_from_sk(skb, sk);
-	atomic_add(skb->truesize, &sk->sk_wmem_alloc);
+	refcount_add(skb->truesize, &sk->sk_wmem_alloc);
 
 	skb_set_dst_pending_confirm(skb, sk->sk_dst_pending_confirm);
 
@@ -2176,7 +2175,7 @@ static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb,
 	limit = min_t(u32, limit, sysctl_tcp_limit_output_bytes);
 	limit <<= factor;
 
-	if (atomic_read(&sk->sk_wmem_alloc) > limit) {
+	if (refcount_read(&sk->sk_wmem_alloc) > limit) {
 		/* Always send the 1st or 2nd skb in write queue.
 		 * No need to wait for TX completion to call us back,
 		 * after softirq/tasklet schedule.
@@ -2192,7 +2191,7 @@ static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb,
 		 * test again the condition.
 		 */
 		smp_mb__after_atomic();
-		if (atomic_read(&sk->sk_wmem_alloc) > limit)
+		if (refcount_read(&sk->sk_wmem_alloc) > limit)
 			return true;
 	}
 	return false;
@@ -2812,7 +2811,7 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs)
 	/* Do not sent more than we queued. 1/4 is reserved for possible
 	 * copying overhead: fragmentation, tunneling, mangling etc.
 	 */
-	if (atomic_read(&sk->sk_wmem_alloc) >
+	if (refcount_read(&sk->sk_wmem_alloc) >
 	    min_t(u32, sk->sk_wmem_queued + (sk->sk_wmem_queued >> 2),
 		  sk->sk_sndbuf))
 		return -EAGAIN;