Merge branch 'udpmem'

Paolo Abeni says:

====================
udp: refactor memory accounting

This patch series refactors the udp memory accounting, replacing the
generic implementation with a custom one, in order to remove the need for
locking the socket on the enqueue and dequeue operations. The socket backlog
usage is dropped as well.

The first patch factors out pieces of some queue and memory management
socket helpers, so that they can later be used by the udp memory accounting
functions.
The second patch adds the memory accounting helpers, without using them yet.
The third patch replaces the old rx memory accounting path for udp over ipv4 and
udp over ipv6. In-kernel UDP users are updated as well.

The memory accounting scheme is described in detail in the individual patches'
commit messages.

The performance gain depends on the specific scenario; with few flows (and
little contention in the original code) the differences are in the noise range,
while with several flows contending the same socket, the measured speed-up
is significant (e.g. even over 100% in case of extreme contention).

Many thanks to Eric Dumazet for the repeated reviews and suggestions.

v5 -> v6:
 - do not orphan the skb on enqueue, since skb_steal_sock() already does
   the work for us

v4 -> v5:
 - use the receive queue spin lock to protect the memory accounting
 - several minor clean-ups

v3 -> v4:
 - simplified the locking scheme: always use a plain spinlock

v2 -> v3:
 - do not set the now unused backlog_rcv callback

v1 -> v2:
 - slightly changed the memory accounting scheme; we now perform lazy reclaim
 - fixed forward_alloc updating issue
 - fixed memory counter integer overflows
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/sock.h b/include/net/sock.h
index ebf75db..2764895 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1274,7 +1274,9 @@
 /*
  * Functions for memory accounting
  */
+int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind);
 int __sk_mem_schedule(struct sock *sk, int size, int kind);
+void __sk_mem_reduce_allocated(struct sock *sk, int amount);
 void __sk_mem_reclaim(struct sock *sk, int amount);
 
 #define SK_MEM_QUANTUM ((int)PAGE_SIZE)
@@ -1950,6 +1952,8 @@
 
 void sk_stop_timer(struct sock *sk, struct timer_list *timer);
 
+int __sk_queue_drop_skb(struct sock *sk, struct sk_buff *skb,
+			unsigned int flags);
 int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
 int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
 
diff --git a/include/net/udp.h b/include/net/udp.h
index ea53a87..18f1e6b 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -246,6 +246,9 @@
 }
 
 /* net/ipv4/udp.c */
+void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len);
+int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb);
+
 void udp_v4_early_demux(struct sk_buff *skb);
 int udp_get_port(struct sock *sk, unsigned short snum,
 		 int (*saddr_cmp)(const struct sock *,
@@ -258,6 +261,7 @@
 void udp4_hwcsum(struct sk_buff *skb, __be32 src, __be32 dst);
 int udp_rcv(struct sk_buff *skb);
 int udp_ioctl(struct sock *sk, int cmd, unsigned long arg);
+int udp_init_sock(struct sock *sk);
 int udp_disconnect(struct sock *sk, int flags);
 unsigned int udp_poll(struct file *file, struct socket *sock, poll_table *wait);
 struct sk_buff *skb_udp_tunnel_segment(struct sk_buff *skb,
diff --git a/net/core/datagram.c b/net/core/datagram.c
index b7de71f..bfb973a 100644
--- a/net/core/datagram.c
+++ b/net/core/datagram.c
@@ -323,6 +323,27 @@
 }
 EXPORT_SYMBOL(__skb_free_datagram_locked);
 
+int __sk_queue_drop_skb(struct sock *sk, struct sk_buff *skb,
+			unsigned int flags)
+{
+	int err = 0;
+
+	if (flags & MSG_PEEK) {
+		err = -ENOENT;
+		spin_lock_bh(&sk->sk_receive_queue.lock);
+		if (skb == skb_peek(&sk->sk_receive_queue)) {
+			__skb_unlink(skb, &sk->sk_receive_queue);
+			atomic_dec(&skb->users);
+			err = 0;
+		}
+		spin_unlock_bh(&sk->sk_receive_queue.lock);
+	}
+
+	atomic_inc(&sk->sk_drops);
+	return err;
+}
+EXPORT_SYMBOL(__sk_queue_drop_skb);
+
 /**
  *	skb_kill_datagram - Free a datagram skbuff forcibly
  *	@sk: socket
@@ -346,23 +367,10 @@
 
 int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags)
 {
-	int err = 0;
-
-	if (flags & MSG_PEEK) {
-		err = -ENOENT;
-		spin_lock_bh(&sk->sk_receive_queue.lock);
-		if (skb == skb_peek(&sk->sk_receive_queue)) {
-			__skb_unlink(skb, &sk->sk_receive_queue);
-			atomic_dec(&skb->users);
-			err = 0;
-		}
-		spin_unlock_bh(&sk->sk_receive_queue.lock);
-	}
+	int err = __sk_queue_drop_skb(sk, skb, flags);
 
 	kfree_skb(skb);
-	atomic_inc(&sk->sk_drops);
 	sk_mem_reclaim_partial(sk);
-
 	return err;
 }
 EXPORT_SYMBOL(skb_kill_datagram);
diff --git a/net/core/sock.c b/net/core/sock.c
index c73e28f..d8e4532e 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2091,24 +2091,18 @@
 EXPORT_SYMBOL(sk_wait_data);
 
 /**
- *	__sk_mem_schedule - increase sk_forward_alloc and memory_allocated
+ *	__sk_mem_raise_allocated - increase memory_allocated
  *	@sk: socket
  *	@size: memory size to allocate
+ *	@amt: pages to allocate
  *	@kind: allocation type
  *
- *	If kind is SK_MEM_SEND, it means wmem allocation. Otherwise it means
- *	rmem allocation. This function assumes that protocols which have
- *	memory_pressure use sk_wmem_queued as write buffer accounting.
+ *	Similar to __sk_mem_schedule(), but does not update sk_forward_alloc
  */
-int __sk_mem_schedule(struct sock *sk, int size, int kind)
+int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
 {
 	struct proto *prot = sk->sk_prot;
-	int amt = sk_mem_pages(size);
-	long allocated;
-
-	sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
-
-	allocated = sk_memory_allocated_add(sk, amt);
+	long allocated = sk_memory_allocated_add(sk, amt);
 
 	if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
 	    !mem_cgroup_charge_skmem(sk->sk_memcg, amt))
@@ -2169,9 +2163,6 @@
 
 	trace_sock_exceed_buf_limit(sk, prot, allocated);
 
-	/* Alas. Undo changes. */
-	sk->sk_forward_alloc -= amt * SK_MEM_QUANTUM;
-
 	sk_memory_allocated_sub(sk, amt);
 
 	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
@@ -2179,18 +2170,40 @@
 
 	return 0;
 }
+EXPORT_SYMBOL(__sk_mem_raise_allocated);
+
+/**
+ *	__sk_mem_schedule - increase sk_forward_alloc and memory_allocated
+ *	@sk: socket
+ *	@size: memory size to allocate
+ *	@kind: allocation type
+ *
+ *	If kind is SK_MEM_SEND, it means wmem allocation. Otherwise it means
+ *	rmem allocation. This function assumes that protocols which have
+ *	memory_pressure use sk_wmem_queued as write buffer accounting.
+ */
+int __sk_mem_schedule(struct sock *sk, int size, int kind)
+{
+	int ret, amt = sk_mem_pages(size);
+
+	sk->sk_forward_alloc += amt << SK_MEM_QUANTUM_SHIFT;
+	ret = __sk_mem_raise_allocated(sk, size, amt, kind);
+	if (!ret)
+		sk->sk_forward_alloc -= amt << SK_MEM_QUANTUM_SHIFT;
+	return ret;
+}
 EXPORT_SYMBOL(__sk_mem_schedule);
 
 /**
- *	__sk_mem_reclaim - reclaim memory_allocated
+ *	__sk_mem_reduce_allocated - reclaim memory_allocated
  *	@sk: socket
- *	@amount: number of bytes (rounded down to a SK_MEM_QUANTUM multiple)
+ *	@amount: number of quanta
+ *
+ *	Similar to __sk_mem_reclaim(), but does not update sk_forward_alloc
  */
-void __sk_mem_reclaim(struct sock *sk, int amount)
+void __sk_mem_reduce_allocated(struct sock *sk, int amount)
 {
-	amount >>= SK_MEM_QUANTUM_SHIFT;
 	sk_memory_allocated_sub(sk, amount);
-	sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
 
 	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
 		mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
@@ -2199,6 +2212,19 @@
 	    (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
 		sk_leave_memory_pressure(sk);
 }
+EXPORT_SYMBOL(__sk_mem_reduce_allocated);
+
+/**
+ *	__sk_mem_reclaim - reclaim sk_forward_alloc and memory_allocated
+ *	@sk: socket
+ *	@amount: number of bytes (rounded down to a SK_MEM_QUANTUM multiple)
+ */
+void __sk_mem_reclaim(struct sock *sk, int amount)
+{
+	amount >>= SK_MEM_QUANTUM_SHIFT;
+	sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
+	__sk_mem_reduce_allocated(sk, amount);
+}
 EXPORT_SYMBOL(__sk_mem_reclaim);
 
 int sk_set_peek_off(struct sock *sk, int val)
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 7d96dc2..c833271 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1172,6 +1172,112 @@
 	return ret;
 }
 
+static void udp_rmem_release(struct sock *sk, int size, int partial)
+{
+	int amt;
+
+	atomic_sub(size, &sk->sk_rmem_alloc);
+
+	spin_lock_bh(&sk->sk_receive_queue.lock);
+	sk->sk_forward_alloc += size;
+	amt = (sk->sk_forward_alloc - partial) & ~(SK_MEM_QUANTUM - 1);
+	sk->sk_forward_alloc -= amt;
+	spin_unlock_bh(&sk->sk_receive_queue.lock);
+
+	if (amt)
+		__sk_mem_reduce_allocated(sk, amt >> SK_MEM_QUANTUM_SHIFT);
+}
+
+static void udp_rmem_free(struct sk_buff *skb)
+{
+	udp_rmem_release(skb->sk, skb->truesize, 1);
+}
+
+int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb)
+{
+	struct sk_buff_head *list = &sk->sk_receive_queue;
+	int rmem, delta, amt, err = -ENOMEM;
+	int size = skb->truesize;
+
+	/* try to avoid the costly atomic add/sub pair when the receive
+	 * queue is full; always allow at least a packet
+	 */
+	rmem = atomic_read(&sk->sk_rmem_alloc);
+	if (rmem && (rmem + size > sk->sk_rcvbuf))
+		goto drop;
+
+	/* we drop only if the receive buf is full and the receive
+	 * queue contains some other skb
+	 */
+	rmem = atomic_add_return(size, &sk->sk_rmem_alloc);
+	if ((rmem > sk->sk_rcvbuf) && (rmem > size))
+		goto uncharge_drop;
+
+	spin_lock(&list->lock);
+	if (size >= sk->sk_forward_alloc) {
+		amt = sk_mem_pages(size);
+		delta = amt << SK_MEM_QUANTUM_SHIFT;
+		if (!__sk_mem_raise_allocated(sk, delta, amt, SK_MEM_RECV)) {
+			err = -ENOBUFS;
+			spin_unlock(&list->lock);
+			goto uncharge_drop;
+		}
+
+		sk->sk_forward_alloc += delta;
+	}
+
+	sk->sk_forward_alloc -= size;
+
+	/* the skb owner is now the udp socket */
+	skb->sk = sk;
+	skb->destructor = udp_rmem_free;
+	skb->dev = NULL;
+	sock_skb_set_dropcount(sk, skb);
+
+	__skb_queue_tail(list, skb);
+	spin_unlock(&list->lock);
+
+	if (!sock_flag(sk, SOCK_DEAD))
+		sk->sk_data_ready(sk);
+
+	return 0;
+
+uncharge_drop:
+	atomic_sub(skb->truesize, &sk->sk_rmem_alloc);
+
+drop:
+	atomic_inc(&sk->sk_drops);
+	return err;
+}
+EXPORT_SYMBOL_GPL(__udp_enqueue_schedule_skb);
+
+static void udp_destruct_sock(struct sock *sk)
+{
+	/* completely reclaim the forward-allocated memory */
+	__skb_queue_purge(&sk->sk_receive_queue);
+	udp_rmem_release(sk, 0, 0);
+	inet_sock_destruct(sk);
+}
+
+int udp_init_sock(struct sock *sk)
+{
+	sk->sk_destruct = udp_destruct_sock;
+	return 0;
+}
+EXPORT_SYMBOL_GPL(udp_init_sock);
+
+void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len)
+{
+	if (unlikely(READ_ONCE(sk->sk_peek_off) >= 0)) {
+		bool slow = lock_sock_fast(sk);
+
+		sk_peek_offset_bwd(sk, len);
+		unlock_sock_fast(sk, slow);
+	}
+	consume_skb(skb);
+}
+EXPORT_SYMBOL_GPL(skb_consume_udp);
+
 /**
  *	first_packet_length	- return length of first packet in receive queue
  *	@sk: socket
@@ -1201,13 +1307,7 @@
 	res = skb ? skb->len : -1;
 	spin_unlock_bh(&rcvq->lock);
 
-	if (!skb_queue_empty(&list_kill)) {
-		bool slow = lock_sock_fast(sk);
-
-		__skb_queue_purge(&list_kill);
-		sk_mem_reclaim_partial(sk);
-		unlock_sock_fast(sk, slow);
-	}
+	__skb_queue_purge(&list_kill);
 	return res;
 }
 
@@ -1256,7 +1356,6 @@
 	int err;
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
-	bool slow;
 
 	if (flags & MSG_ERRQUEUE)
 		return ip_recv_error(sk, msg, len, addr_len);
@@ -1297,13 +1396,12 @@
 	}
 
 	if (unlikely(err)) {
-		trace_kfree_skb(skb, udp_recvmsg);
 		if (!peeked) {
 			atomic_inc(&sk->sk_drops);
 			UDP_INC_STATS(sock_net(sk),
 				      UDP_MIB_INERRORS, is_udplite);
 		}
-		skb_free_datagram_locked(sk, skb);
+		kfree_skb(skb);
 		return err;
 	}
 
@@ -1328,16 +1426,15 @@
 	if (flags & MSG_TRUNC)
 		err = ulen;
 
-	__skb_free_datagram_locked(sk, skb, peeking ? -err : err);
+	skb_consume_udp(sk, skb, peeking ? -err : err);
 	return err;
 
 csum_copy_err:
-	slow = lock_sock_fast(sk);
-	if (!skb_kill_datagram(sk, skb, flags)) {
+	if (!__sk_queue_drop_skb(sk, skb, flags)) {
 		UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite);
 		UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite);
 	}
-	unlock_sock_fast(sk, slow);
+	kfree_skb(skb);
 
 	/* starting over for a new packet, but check if we need to yield */
 	cond_resched();
@@ -1456,7 +1553,7 @@
 		sk_incoming_cpu_update(sk);
 	}
 
-	rc = __sock_queue_rcv_skb(sk, skb);
+	rc = __udp_enqueue_schedule_skb(sk, skb);
 	if (rc < 0) {
 		int is_udplite = IS_UDPLITE(sk);
 
@@ -1471,7 +1568,6 @@
 	}
 
 	return 0;
-
 }
 
 static struct static_key udp_encap_needed __read_mostly;
@@ -1493,7 +1589,6 @@
 int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
 	struct udp_sock *up = udp_sk(sk);
-	int rc;
 	int is_udplite = IS_UDPLITE(sk);
 
 	/*
@@ -1580,25 +1675,9 @@
 		goto drop;
 
 	udp_csum_pull_header(skb);
-	if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) {
-		__UDP_INC_STATS(sock_net(sk), UDP_MIB_RCVBUFERRORS,
-				is_udplite);
-		goto drop;
-	}
-
-	rc = 0;
 
 	ipv4_pktinfo_prepare(sk, skb);
-	bh_lock_sock(sk);
-	if (!sock_owned_by_user(sk))
-		rc = __udp_queue_rcv_skb(sk, skb);
-	else if (sk_add_backlog(sk, skb, sk->sk_rcvbuf)) {
-		bh_unlock_sock(sk);
-		goto drop;
-	}
-	bh_unlock_sock(sk);
-
-	return rc;
+	return __udp_queue_rcv_skb(sk, skb);
 
 csum_error:
 	__UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite);
@@ -2208,13 +2287,13 @@
 	.connect	   = ip4_datagram_connect,
 	.disconnect	   = udp_disconnect,
 	.ioctl		   = udp_ioctl,
+	.init		   = udp_init_sock,
 	.destroy	   = udp_destroy_sock,
 	.setsockopt	   = udp_setsockopt,
 	.getsockopt	   = udp_getsockopt,
 	.sendmsg	   = udp_sendmsg,
 	.recvmsg	   = udp_recvmsg,
 	.sendpage	   = udp_sendpage,
-	.backlog_rcv	   = __udp_queue_rcv_skb,
 	.release_cb	   = ip4_datagram_release_cb,
 	.hash		   = udp_lib_hash,
 	.unhash		   = udp_lib_unhash,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 9aa7c1c..71963b2 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -334,7 +334,6 @@
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
 	int is_udp4;
-	bool slow;
 
 	if (flags & MSG_ERRQUEUE)
 		return ipv6_recv_error(sk, msg, len, addr_len);
@@ -378,7 +377,6 @@
 			goto csum_copy_err;
 	}
 	if (unlikely(err)) {
-		trace_kfree_skb(skb, udpv6_recvmsg);
 		if (!peeked) {
 			atomic_inc(&sk->sk_drops);
 			if (is_udp4)
@@ -388,7 +386,7 @@
 				UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS,
 					       is_udplite);
 		}
-		skb_free_datagram_locked(sk, skb);
+		kfree_skb(skb);
 		return err;
 	}
 	if (!peeked) {
@@ -437,12 +435,11 @@
 	if (flags & MSG_TRUNC)
 		err = ulen;
 
-	__skb_free_datagram_locked(sk, skb, peeking ? -err : err);
+	skb_consume_udp(sk, skb, peeking ? -err : err);
 	return err;
 
 csum_copy_err:
-	slow = lock_sock_fast(sk);
-	if (!skb_kill_datagram(sk, skb, flags)) {
+	if (!__sk_queue_drop_skb(sk, skb, flags)) {
 		if (is_udp4) {
 			UDP_INC_STATS(sock_net(sk),
 				      UDP_MIB_CSUMERRORS, is_udplite);
@@ -455,7 +452,7 @@
 				       UDP_MIB_INERRORS, is_udplite);
 		}
 	}
-	unlock_sock_fast(sk, slow);
+	kfree_skb(skb);
 
 	/* starting over for a new packet, but check if we need to yield */
 	cond_resched();
@@ -523,7 +520,7 @@
 		sk_incoming_cpu_update(sk);
 	}
 
-	rc = __sock_queue_rcv_skb(sk, skb);
+	rc = __udp_enqueue_schedule_skb(sk, skb);
 	if (rc < 0) {
 		int is_udplite = IS_UDPLITE(sk);
 
@@ -535,6 +532,7 @@
 		kfree_skb(skb);
 		return -1;
 	}
+
 	return 0;
 }
 
@@ -556,7 +554,6 @@
 int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
 	struct udp_sock *up = udp_sk(sk);
-	int rc;
 	int is_udplite = IS_UDPLITE(sk);
 
 	if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb))
@@ -622,25 +619,10 @@
 		goto drop;
 
 	udp_csum_pull_header(skb);
-	if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) {
-		__UDP6_INC_STATS(sock_net(sk),
-				 UDP_MIB_RCVBUFERRORS, is_udplite);
-		goto drop;
-	}
 
 	skb_dst_drop(skb);
 
-	bh_lock_sock(sk);
-	rc = 0;
-	if (!sock_owned_by_user(sk))
-		rc = __udpv6_queue_rcv_skb(sk, skb);
-	else if (sk_add_backlog(sk, skb, sk->sk_rcvbuf)) {
-		bh_unlock_sock(sk);
-		goto drop;
-	}
-	bh_unlock_sock(sk);
-
-	return rc;
+	return __udpv6_queue_rcv_skb(sk, skb);
 
 csum_error:
 	__UDP6_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite);
@@ -1433,12 +1415,12 @@
 	.connect	   = ip6_datagram_connect,
 	.disconnect	   = udp_disconnect,
 	.ioctl		   = udp_ioctl,
+	.init		   = udp_init_sock,
 	.destroy	   = udpv6_destroy_sock,
 	.setsockopt	   = udpv6_setsockopt,
 	.getsockopt	   = udpv6_getsockopt,
 	.sendmsg	   = udpv6_sendmsg,
 	.recvmsg	   = udpv6_recvmsg,
-	.backlog_rcv	   = __udpv6_queue_rcv_skb,
 	.release_cb	   = ip6_datagram_release_cb,
 	.hash		   = udp_lib_hash,
 	.unhash		   = udp_lib_unhash,
diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
index 57625f6..e2a55dc 100644
--- a/net/sunrpc/svcsock.c
+++ b/net/sunrpc/svcsock.c
@@ -39,6 +39,7 @@
 #include <net/checksum.h>
 #include <net/ip.h>
 #include <net/ipv6.h>
+#include <net/udp.h>
 #include <net/tcp.h>
 #include <net/tcp_states.h>
 #include <asm/uaccess.h>
@@ -129,6 +130,18 @@
 	}
 }
 
+static void svc_release_udp_skb(struct svc_rqst *rqstp)
+{
+	struct sk_buff *skb = rqstp->rq_xprt_ctxt;
+
+	if (skb) {
+		rqstp->rq_xprt_ctxt = NULL;
+
+		dprintk("svc: service %p, releasing skb %p\n", rqstp, skb);
+		consume_skb(skb);
+	}
+}
+
 union svc_pktinfo_u {
 	struct in_pktinfo pkti;
 	struct in6_pktinfo pkti6;
@@ -575,7 +588,7 @@
 			goto out_free;
 		}
 		local_bh_enable();
-		skb_free_datagram_locked(svsk->sk_sk, skb);
+		consume_skb(skb);
 	} else {
 		/* we can use it in-place */
 		rqstp->rq_arg.head[0].iov_base = skb->data;
@@ -602,8 +615,7 @@
 
 	return len;
 out_free:
-	trace_kfree_skb(skb, svc_udp_recvfrom);
-	skb_free_datagram_locked(svsk->sk_sk, skb);
+	kfree_skb(skb);
 	return 0;
 }
 
@@ -660,7 +672,7 @@
 	.xpo_create = svc_udp_create,
 	.xpo_recvfrom = svc_udp_recvfrom,
 	.xpo_sendto = svc_udp_sendto,
-	.xpo_release_rqst = svc_release_skb,
+	.xpo_release_rqst = svc_release_udp_skb,
 	.xpo_detach = svc_sock_detach,
 	.xpo_free = svc_sock_free,
 	.xpo_prep_reply_hdr = svc_udp_prep_reply_hdr,
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index 0137af1..1758665 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -1083,7 +1083,7 @@
 		skb = skb_recv_datagram(sk, 0, 1, &err);
 		if (skb != NULL) {
 			xs_udp_data_read_skb(&transport->xprt, sk, skb);
-			skb_free_datagram_locked(sk, skb);
+			consume_skb(skb);
 			continue;
 		}
 		if (!test_and_clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))