inet_diag: allow sk_diag_fill() to handle request socks

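inet_diag_fill_req() is renamed to inet_req_diag_fill()
and moved up, so that it can be called from sk_diag_fill().
As it no longer has access to the listener, it now reports
idiag_uid as 0 for request socks instead of the listener's uid.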

inet_diag_bc_sk() is now ready to handle request socks: it only
reads sk_userlocks from full sockets (sk_fullsock()).

inet_twsk_diag_dump() is no longer needed: the dump loop now runs
the bytecode filter itself and calls sk_diag_fill() directly.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index ac7b5c9..e7ba590 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -113,14 +113,13 @@
 		return -EMSGSIZE;
 
 	r = nlmsg_data(nlh);
-	BUG_ON((1 << sk->sk_state) & (TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV));
+	BUG_ON(!sk_fullsock(sk));
 
 	inet_diag_msg_common_fill(r, sk);
 	r->idiag_state = sk->sk_state;
 	r->idiag_timer = 0;
 	r->idiag_retrans = 0;
 
-
 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 		goto errout;
 
@@ -229,7 +228,6 @@
 
 static int inet_twsk_diag_fill(struct sock *sk,
 			       struct sk_buff *skb,
-			       const struct inet_diag_req_v2 *req,
 			       u32 portid, u32 seq, u16 nlmsg_flags,
 			       const struct nlmsghdr *unlh)
 {
@@ -265,6 +263,39 @@
 	return 0;
 }
 
+static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
+			      u32 portid, u32 seq, u16 nlmsg_flags,
+			      const struct nlmsghdr *unlh)
+{
+	struct inet_diag_msg *r;
+	struct nlmsghdr *nlh;
+	long tmo;
+
+	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
+			nlmsg_flags);
+	if (!nlh)
+		return -EMSGSIZE;
+	/* A request sock is always reported as SYN_RECV with timer 1. */
+	r = nlmsg_data(nlh);
+	inet_diag_msg_common_fill(r, sk);
+	r->idiag_state = TCP_SYN_RECV;
+	r->idiag_timer = 1;
+	r->idiag_retrans = inet_reqsk(sk)->num_retrans;
+	/* sk is a request sock; its ir_cookie must alias sk_cookie. */
+	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
+		     offsetof(struct sock, sk_cookie));
+	/* Time left until the request expires, in ms; 0 if already due. */
+	tmo = inet_reqsk(sk)->expires - jiffies;
+	r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
+	r->idiag_rqueue	= 0;
+	r->idiag_wqueue	= 0;
+	r->idiag_uid	= 0;	/* NOTE(review): old code used the listener's uid */
+	r->idiag_inode	= 0;
+
+	nlmsg_end(skb, nlh);
+	return 0;
+}
+
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 			const struct inet_diag_req_v2 *r,
 			struct user_namespace *user_ns,
@@ -272,9 +303,13 @@
 			const struct nlmsghdr *unlh)
 {
 	if (sk->sk_state == TCP_TIME_WAIT)
-		return inet_twsk_diag_fill(sk, skb, r, portid, seq,
+		return inet_twsk_diag_fill(sk, skb, portid, seq,
 					   nlmsg_flags, unlh);
 
+	if (sk->sk_state == TCP_NEW_SYN_RECV)
+		return inet_req_diag_fill(sk, skb, portid, seq,
+					  nlmsg_flags, unlh);
+	/* Only full sockets are left for inet_csk_diag_fill(). */
 	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
 				  nlmsg_flags, unlh);
 }
@@ -502,7 +537,7 @@
 	entry_fill_addrs(&entry, sk);
 	entry.sport = inet->inet_num;
 	entry.dport = ntohs(inet->inet_dport);
-	entry.userlocks = (sk->sk_state != TCP_TIME_WAIT) ? sk->sk_userlocks : 0;
+	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
 
 	return inet_diag_bc_run(bc, &entry);
 }
@@ -661,61 +696,6 @@
 #endif
 }
 
-static int inet_twsk_diag_dump(struct sock *sk,
-			       struct sk_buff *skb,
-			       struct netlink_callback *cb,
-			       const struct inet_diag_req_v2 *r,
-			       const struct nlattr *bc)
-{
-	twsk_build_assert();
-
-	if (!inet_diag_bc_sk(bc, sk))
-		return 0;
-
-	return inet_twsk_diag_fill(sk, skb, r,
-				   NETLINK_CB(cb->skb).portid,
-				   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
-}
-
-static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
-			      struct request_sock *req,
-			      struct user_namespace *user_ns,
-			      u32 portid, u32 seq,
-			      const struct nlmsghdr *unlh)
-{
-	const struct inet_request_sock *ireq = inet_rsk(req);
-	struct inet_diag_msg *r;
-	struct nlmsghdr *nlh;
-	long tmo;
-
-	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-			NLM_F_MULTI);
-	if (!nlh)
-		return -EMSGSIZE;
-
-	r = nlmsg_data(nlh);
-	inet_diag_msg_common_fill(r, (struct sock *)ireq);
-	r->idiag_state = TCP_SYN_RECV;
-	r->idiag_timer = 1;
-	r->idiag_retrans = req->num_retrans;
-
-	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
-		     offsetof(struct sock, sk_cookie));
-
-	tmo = req->expires - jiffies;
-	if (tmo < 0)
-		tmo = 0;
-
-	r->idiag_expires = jiffies_to_msecs(tmo);
-	r->idiag_rqueue = 0;
-	r->idiag_wqueue = 0;
-	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
-	r->idiag_inode = 0;
-
-	nlmsg_end(skb, nlh);
-	return 0;
-}
-
 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 			       struct netlink_callback *cb,
 			       const struct inet_diag_req_v2 *r,
@@ -769,10 +749,10 @@
 					continue;
 			}
 
-			err = inet_diag_fill_req(skb, sk, req,
-						 sk_user_ns(NETLINK_CB(cb->skb).sk),
+			err = inet_req_diag_fill((struct sock *)req, skb,
 						 NETLINK_CB(cb->skb).portid,
-						 cb->nlh->nlmsg_seq, cb->nlh);
+						 cb->nlh->nlmsg_seq,
+						 NLM_F_MULTI, cb->nlh);
 			if (err < 0) {
 				cb->args[3] = j + 1;
 				cb->args[4] = reqnum;
@@ -903,10 +883,16 @@
 			if (r->id.idiag_dport != sk->sk_dport &&
 			    r->id.idiag_dport)
 				goto next_normal;
-			if (sk->sk_state == TCP_TIME_WAIT)
-				res = inet_twsk_diag_dump(sk, skb, cb, r, bc);
-			else
-				res = inet_csk_diag_dump(sk, skb, cb, r, bc);
+			twsk_build_assert();
+
+			if (!inet_diag_bc_sk(bc, sk))
+				goto next_normal;
+
+			res = sk_diag_fill(sk, skb, r,
+					   sk_user_ns(NETLINK_CB(cb->skb).sk),
+					   NETLINK_CB(cb->skb).portid,
+					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
+					   cb->nlh);
 			if (res < 0) {
 				spin_unlock_bh(lock);
 				goto done;