net: convert packet_fanout.sk_ref from atomic_t to refcount_t

The refcount_t type and its corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This helps avoid accidental
refcounter overflows that might lead to use-after-free
situations.

Signed-off-by: Elena Reshetova <elena.reshetova@intel.com>
Signed-off-by: Hans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: Kees Cook <keescook@chromium.org>
Signed-off-by: David Windsor <dwindsor@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Change-Id: Ibcded37762b5421cb3b81ca19b1f13a101ac3958
Git-repo: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
Git-commit: fb5c2c17a556d9b00798d6a6b9e624281ee2eb28
[dcagle@codeaurora.org: Change one more atomic ref to refcount]
Signed-off-by: Dennis Cagle <dcagle@codeaurora.org>
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index 267db0d..f15f08b 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -1685,7 +1685,7 @@
 		match->flags = flags;
 		INIT_LIST_HEAD(&match->list);
 		spin_lock_init(&match->lock);
-		atomic_set(&match->sk_ref, 0);
+		refcount_set(&match->sk_ref, 0);
 		fanout_init_data(match);
 		match->prot_hook.type = po->prot_hook.type;
 		match->prot_hook.dev = po->prot_hook.dev;
@@ -1702,19 +1702,19 @@
 	    match->prot_hook.type == po->prot_hook.type &&
 	    match->prot_hook.dev == po->prot_hook.dev) {
 		err = -ENOSPC;
-		if (atomic_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
+		if (refcount_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
 			__dev_remove_pack(&po->prot_hook);
 			po->fanout = match;
 			po->rollover = rollover;
 			rollover = NULL;
-			atomic_inc(&match->sk_ref);
+			refcount_set(&match->sk_ref, refcount_read(&match->sk_ref) + 1);
 			__fanout_link(sk, po);
 			err = 0;
 		}
 	}
 	spin_unlock(&po->bind_lock);
 
-	if (err && !atomic_read(&match->sk_ref)) {
+	if (err && !refcount_read(&match->sk_ref)) {
 		list_del(&match->list);
 		kfree(match);
 	}
@@ -1740,7 +1740,7 @@
 	if (f) {
 		po->fanout = NULL;
 
-		if (atomic_dec_and_test(&f->sk_ref))
+		if (refcount_dec_and_test(&f->sk_ref))
 			list_del(&f->list);
 		else
 			f = NULL;