udp: enable MSG_PEEK at non-zero offset

Enable peeking at UDP datagrams at the offset specified with socket
option SOL_SOCKET/SO_PEEK_OFF. The offset may fall within any datagram
in the queue; a peek returns data from that position up to the end of
the datagram that contains it.

Implement the SO_PEEK_OFF semantics introduced in commit ef64a54f6e55
("sock: Introduce the SO_PEEK_OFF sock option"). Increase the offset
on peek, decrease it on regular reads.

When peeking, always checksum the packet immediately, to avoid
recomputation on subsequent peeks and final read.

The socket lock is not held for the duration of udp_recvmsg, so
peek and read operations can run concurrently. Only the last store
to sk_peek_off is preserved.

Signed-off-by: Sam Kumar <samanthakumar@google.com>
Signed-off-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/core/datagram.c b/net/core/datagram.c
index fa9dc64..b7de71f 100644
--- a/net/core/datagram.c
+++ b/net/core/datagram.c
@@ -301,16 +301,19 @@
 }
 EXPORT_SYMBOL(skb_free_datagram);
 
-void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb)
+void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len)
 {
 	bool slow;
 
 	if (likely(atomic_read(&skb->users) == 1))
 		smp_rmb();
-	else if (likely(!atomic_dec_and_test(&skb->users)))
+	else if (likely(!atomic_dec_and_test(&skb->users))) {
+		sk_peek_offset_bwd(sk, len);
 		return;
+	}
 
 	slow = lock_sock_fast(sk);
+	sk_peek_offset_bwd(sk, len);
 	skb_orphan(skb);
 	sk_mem_reclaim_partial(sk);
 	unlock_sock_fast(sk, slow);
@@ -318,7 +321,7 @@
 	/* skb is now orphaned, can be freed outside of locked section */
 	__kfree_skb(skb);
 }
-EXPORT_SYMBOL(skb_free_datagram_locked);
+EXPORT_SYMBOL(__skb_free_datagram_locked);
 
 /**
  *	skb_kill_datagram - Free a datagram skbuff forcibly
diff --git a/net/core/sock.c b/net/core/sock.c
index e12197b..2ce76e8 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2187,6 +2187,15 @@
 }
 EXPORT_SYMBOL(__sk_mem_reclaim);
 
+int sk_set_peek_off(struct sock *sk, int val)
+{
+	if (val < 0)
+		return -EINVAL;
+
+	sk->sk_peek_off = val;
+	return 0;
+}
+EXPORT_SYMBOL_GPL(sk_set_peek_off);
 
 /*
  * Set of default routines for initialising struct proto_ops when
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 9e48199..a38b991 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -948,6 +948,7 @@
 	.recvmsg	   = inet_recvmsg,
 	.mmap		   = sock_no_mmap,
 	.sendpage	   = inet_sendpage,
+	.set_peek_off	   = sk_set_peek_off,
 #ifdef CONFIG_COMPAT
 	.compat_setsockopt = compat_sock_common_setsockopt,
 	.compat_getsockopt = compat_sock_common_getsockopt,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index cf747e8..d80312d 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1294,7 +1294,7 @@
 	DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name);
 	struct sk_buff *skb;
 	unsigned int ulen, copied;
-	int peeked, off = 0;
+	int peeked, peeking, off;
 	int err;
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
@@ -1304,15 +1304,16 @@
 		return ip_recv_error(sk, msg, len, addr_len);
 
 try_again:
+	peeking = off = sk_peek_offset(sk, flags);
 	skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0),
 				  &peeked, &off, &err);
 	if (!skb)
-		goto out;
+		return err;
 
 	ulen = skb->len;
 	copied = len;
-	if (copied > ulen)
-		copied = ulen;
+	if (copied > ulen - off)
+		copied = ulen - off;
 	else if (copied < ulen)
 		msg->msg_flags |= MSG_TRUNC;
 
@@ -1322,16 +1323,16 @@
 	 * coverage checksum (UDP-Lite), do it before the copy.
 	 */
 
-	if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) {
+	if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) {
 		checksum_valid = !udp_lib_checksum_complete(skb);
 		if (!checksum_valid)
 			goto csum_copy_err;
 	}
 
 	if (checksum_valid || skb_csum_unnecessary(skb))
-		err = skb_copy_datagram_msg(skb, 0, msg, copied);
+		err = skb_copy_datagram_msg(skb, off, msg, copied);
 	else {
-		err = skb_copy_and_csum_datagram_msg(skb, 0, msg);
+		err = skb_copy_and_csum_datagram_msg(skb, off, msg);
 
 		if (err == -EINVAL)
 			goto csum_copy_err;
@@ -1344,7 +1345,8 @@
 			UDP_INC_STATS_USER(sock_net(sk),
 					   UDP_MIB_INERRORS, is_udplite);
 		}
-		goto out_free;
+		skb_free_datagram_locked(sk, skb);
+		return err;
 	}
 
 	if (!peeked)
@@ -1368,9 +1370,7 @@
 	if (flags & MSG_TRUNC)
 		err = ulen;
 
-out_free:
-	skb_free_datagram_locked(sk, skb);
-out:
+	__skb_free_datagram_locked(sk, skb, peeking ? -err : err);
 	return err;
 
 csum_copy_err:
diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c
index b11c37c..2b78aad 100644
--- a/net/ipv6/af_inet6.c
+++ b/net/ipv6/af_inet6.c
@@ -561,6 +561,7 @@
 	.recvmsg	   = inet_recvmsg,		/* ok		*/
 	.mmap		   = sock_no_mmap,
 	.sendpage	   = sock_no_sendpage,
+	.set_peek_off	   = sk_set_peek_off,
 #ifdef CONFIG_COMPAT
 	.compat_setsockopt = compat_sock_common_setsockopt,
 	.compat_getsockopt = compat_sock_common_getsockopt,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 84c8d7b..87bd7af 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -357,7 +357,7 @@
 	struct inet_sock *inet = inet_sk(sk);
 	struct sk_buff *skb;
 	unsigned int ulen, copied;
-	int peeked, off = 0;
+	int peeked, peeking, off;
 	int err;
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
@@ -371,15 +371,16 @@
 		return ipv6_recv_rxpmtu(sk, msg, len, addr_len);
 
 try_again:
+	peeking = off = sk_peek_offset(sk, flags);
 	skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0),
 				  &peeked, &off, &err);
 	if (!skb)
-		goto out;
+		return err;
 
 	ulen = skb->len;
 	copied = len;
-	if (copied > ulen)
-		copied = ulen;
+	if (copied > ulen - off)
+		copied = ulen - off;
 	else if (copied < ulen)
 		msg->msg_flags |= MSG_TRUNC;
 
@@ -391,16 +392,16 @@
 	 * coverage checksum (UDP-Lite), do it before the copy.
 	 */
 
-	if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) {
+	if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) {
 		checksum_valid = !udp_lib_checksum_complete(skb);
 		if (!checksum_valid)
 			goto csum_copy_err;
 	}
 
 	if (checksum_valid || skb_csum_unnecessary(skb))
-		err = skb_copy_datagram_msg(skb, 0, msg, copied);
+		err = skb_copy_datagram_msg(skb, off, msg, copied);
 	else {
-		err = skb_copy_and_csum_datagram_msg(skb, 0, msg);
+		err = skb_copy_and_csum_datagram_msg(skb, off, msg);
 		if (err == -EINVAL)
 			goto csum_copy_err;
 	}
@@ -417,7 +418,8 @@
 						    UDP_MIB_INERRORS,
 						    is_udplite);
 		}
-		goto out_free;
+		skb_free_datagram_locked(sk, skb);
+		return err;
 	}
 	if (!peeked) {
 		if (is_udp4)
@@ -465,9 +467,7 @@
 	if (flags & MSG_TRUNC)
 		err = ulen;
 
-out_free:
-	skb_free_datagram_locked(sk, skb);
-out:
+	__skb_free_datagram_locked(sk, skb, peeking ? -err : err);
 	return err;
 
 csum_copy_err: