net: convert sock.sk_refcnt from atomic_t to refcount_t

The refcount_t type and its corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This helps avoid accidental
refcounter overflows that might lead to use-after-free
situations.

This patch uses refcount_inc_not_zero() instead of
atomic_inc_not_zero_hint() due to the absence of a _hint()
version in the refcount API. If the _hint() version must
be used, we might need to revisit the API.

Signed-off-by: Elena Reshetova <elena.reshetova@intel.com>
Signed-off-by: Hans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: Kees Cook <keescook@chromium.org>
Signed-off-by: David Windsor <dwindsor@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/sock.h b/include/net/sock.h
index 5284e50..60200f4 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -66,6 +66,7 @@
 #include <linux/poll.h>
 
 #include <linux/atomic.h>
+#include <linux/refcount.h>
 #include <net/dst.h>
 #include <net/checksum.h>
 #include <net/tcp_states.h>
@@ -219,7 +220,7 @@
 		u32		skc_tw_rcv_nxt; /* struct tcp_timewait_sock  */
 	};
 
-	atomic_t		skc_refcnt;
+	refcount_t		skc_refcnt;
 	/* private: */
 	int                     skc_dontcopy_end[0];
 	union {
@@ -611,7 +612,7 @@
 
 static __always_inline void sock_hold(struct sock *sk)
 {
-	atomic_inc(&sk->sk_refcnt);
+	refcount_inc(&sk->sk_refcnt);
 }
 
 /* Ungrab socket in the context, which assumes that socket refcnt
@@ -619,7 +620,7 @@
  */
 static __always_inline void __sock_put(struct sock *sk)
 {
-	atomic_dec(&sk->sk_refcnt);
+	refcount_dec(&sk->sk_refcnt);
 }
 
 static inline bool sk_del_node_init(struct sock *sk)
@@ -628,7 +629,7 @@
 
 	if (rc) {
 		/* paranoid for a while -acme */
-		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
+		WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
 		__sock_put(sk);
 	}
 	return rc;
@@ -650,7 +651,7 @@
 
 	if (rc) {
 		/* paranoid for a while -acme */
-		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
+		WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
 		__sock_put(sk);
 	}
 	return rc;
@@ -1144,9 +1145,9 @@
 
 static inline void sk_refcnt_debug_release(const struct sock *sk)
 {
-	if (atomic_read(&sk->sk_refcnt) != 1)
+	if (refcount_read(&sk->sk_refcnt) != 1)
 		printk(KERN_DEBUG "Destruction of the %s socket %p delayed, refcnt=%d\n",
-		       sk->sk_prot->name, sk, atomic_read(&sk->sk_refcnt));
+		       sk->sk_prot->name, sk, refcount_read(&sk->sk_refcnt));
 }
 #else /* SOCK_REFCNT_DEBUG */
 #define sk_refcnt_debug_inc(sk) do { } while (0)
@@ -1636,7 +1637,7 @@
 /* Ungrab socket and destroy it, if it was the last reference. */
 static inline void sock_put(struct sock *sk)
 {
-	if (atomic_dec_and_test(&sk->sk_refcnt))
+	if (refcount_dec_and_test(&sk->sk_refcnt))
 		sk_free(sk);
 }
 /* Generic version of sock_put(), dealing with all sockets