inet: refactor inet[6]_lookup functions to take skb

This is a preliminary step to allow fast socket lookup of SO_REUSEPORT
groups.  Doing so with a BPF filter will require access to the
skb in question.  This change plumbs the skb (and offset to payload
data) through the call stack to the listening socket lookup
implementations where it will be used in a following patch.

Signed-off-by: Craig Gallek <kraig@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/addrconf.h b/include/net/addrconf.h
index 47f52d3..730d856 100644
--- a/include/net/addrconf.h
+++ b/include/net/addrconf.h
@@ -87,6 +87,8 @@
 		      u32 banned_flags);
 int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr,
 		    u32 banned_flags);
+int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
+			 bool match_wildcard);
 int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
 			 bool match_wildcard);
 void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr);
diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index b3c28a9..28332bd 100644
--- a/include/net/inet6_hashtables.h
+++ b/include/net/inet6_hashtables.h
@@ -53,6 +53,7 @@
 
 struct sock *inet6_lookup_listener(struct net *net,
 				   struct inet_hashinfo *hashinfo,
+				   struct sk_buff *skb, int doff,
 				   const struct in6_addr *saddr,
 				   const __be16 sport,
 				   const struct in6_addr *daddr,
@@ -60,6 +61,7 @@
 
 static inline struct sock *__inet6_lookup(struct net *net,
 					  struct inet_hashinfo *hashinfo,
+					  struct sk_buff *skb, int doff,
 					  const struct in6_addr *saddr,
 					  const __be16 sport,
 					  const struct in6_addr *daddr,
@@ -71,12 +73,12 @@
 	if (sk)
 		return sk;
 
-	return inet6_lookup_listener(net, hashinfo, saddr, sport,
+	return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
 				     daddr, hnum, dif);
 }
 
 static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
-					      struct sk_buff *skb,
+					      struct sk_buff *skb, int doff,
 					      const __be16 sport,
 					      const __be16 dport,
 					      int iif)
@@ -86,13 +88,14 @@
 	if (sk)
 		return sk;
 
-	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo,
-			      &ipv6_hdr(skb)->saddr, sport,
+	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+			      doff, &ipv6_hdr(skb)->saddr, sport,
 			      &ipv6_hdr(skb)->daddr, ntohs(dport),
 			      iif);
 }
 
 struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
+			  struct sk_buff *skb, int doff,
 			  const struct in6_addr *saddr, const __be16 sport,
 			  const struct in6_addr *daddr, const __be16 dport,
 			  const int dif);
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index 554440e..8240339 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -213,6 +213,7 @@
 
 struct sock *__inet_lookup_listener(struct net *net,
 				    struct inet_hashinfo *hashinfo,
+				    struct sk_buff *skb, int doff,
 				    const __be32 saddr, const __be16 sport,
 				    const __be32 daddr,
 				    const unsigned short hnum,
@@ -220,10 +221,11 @@
 
 static inline struct sock *inet_lookup_listener(struct net *net,
 		struct inet_hashinfo *hashinfo,
+		struct sk_buff *skb, int doff,
 		__be32 saddr, __be16 sport,
 		__be32 daddr, __be16 dport, int dif)
 {
-	return __inet_lookup_listener(net, hashinfo, saddr, sport,
+	return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
 				      daddr, ntohs(dport), dif);
 }
 
@@ -299,6 +301,7 @@
 
 static inline struct sock *__inet_lookup(struct net *net,
 					 struct inet_hashinfo *hashinfo,
+					 struct sk_buff *skb, int doff,
 					 const __be32 saddr, const __be16 sport,
 					 const __be32 daddr, const __be16 dport,
 					 const int dif)
@@ -307,12 +310,13 @@
 	struct sock *sk = __inet_lookup_established(net, hashinfo,
 				saddr, sport, daddr, hnum, dif);
 
-	return sk ? : __inet_lookup_listener(net, hashinfo, saddr, sport,
-					     daddr, hnum, dif);
+	return sk ? : __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
+					     sport, daddr, hnum, dif);
 }
 
 static inline struct sock *inet_lookup(struct net *net,
 				       struct inet_hashinfo *hashinfo,
+				       struct sk_buff *skb, int doff,
 				       const __be32 saddr, const __be16 sport,
 				       const __be32 daddr, const __be16 dport,
 				       const int dif)
@@ -320,7 +324,8 @@
 	struct sock *sk;
 
 	local_bh_disable();
-	sk = __inet_lookup(net, hashinfo, saddr, sport, daddr, dport, dif);
+	sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
+			   dport, dif);
 	local_bh_enable();
 
 	return sk;
@@ -328,6 +333,7 @@
 
 static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 					     struct sk_buff *skb,
+					     int doff,
 					     const __be16 sport,
 					     const __be16 dport)
 {
@@ -337,8 +343,8 @@
 	if (sk)
 		return sk;
 	else
-		return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo,
-				     iph->saddr, sport,
+		return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+				     doff, iph->saddr, sport,
 				     iph->daddr, dport, inet_iif(skb));
 }
 
diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
index 5684e14..1e0c600 100644
--- a/net/dccp/ipv4.c
+++ b/net/dccp/ipv4.c
@@ -802,7 +802,7 @@
 	}
 
 lookup:
-	sk = __inet_lookup_skb(&dccp_hashinfo, skb,
+	sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
 			       dh->dccph_sport, dh->dccph_dport);
 	if (!sk) {
 		dccp_pr_debug("failed to look up flow ID in table and "
diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
index 90a8269..45cbe85 100644
--- a/net/dccp/ipv6.c
+++ b/net/dccp/ipv6.c
@@ -668,7 +668,7 @@
 		DCCP_SKB_CB(skb)->dccpd_ack_seq = dccp_hdr_ack_seq(skb);
 
 lookup:
-	sk = __inet6_lookup_skb(&dccp_hashinfo, skb,
+	sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
 			        dh->dccph_sport, dh->dccph_dport,
 				inet6_iif(skb));
 	if (!sk) {
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 6029157..50c0d96 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -357,18 +357,18 @@
 	struct sock *sk;
 
 	if (req->sdiag_family == AF_INET)
-		sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
+		sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
 				 req->id.idiag_dport, req->id.idiag_src[0],
 				 req->id.idiag_sport, req->id.idiag_if);
 #if IS_ENABLED(CONFIG_IPV6)
 	else if (req->sdiag_family == AF_INET6) {
 		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
 		    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
-			sk = inet_lookup(net, hashinfo, req->id.idiag_dst[3],
+			sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
 					 req->id.idiag_dport, req->id.idiag_src[3],
 					 req->id.idiag_sport, req->id.idiag_if);
 		else
-			sk = inet6_lookup(net, hashinfo,
+			sk = inet6_lookup(net, hashinfo, NULL, 0,
 					  (struct in6_addr *)req->id.idiag_dst,
 					  req->id.idiag_dport,
 					  (struct in6_addr *)req->id.idiag_src,
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index b6023b7..5e4290b 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -205,6 +205,7 @@
 
 struct sock *__inet_lookup_listener(struct net *net,
 				    struct inet_hashinfo *hashinfo,
+				    struct sk_buff *skb, int doff,
 				    const __be32 saddr, __be16 sport,
 				    const __be32 daddr, const unsigned short hnum,
 				    const int dif)
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 0d381fa..3f872a6 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -637,8 +637,8 @@
 		 * Incoming packet is checked with md5 hash with finding key,
 		 * no RST generated if md5 hash doesn't match.
 		 */
-		sk1 = __inet_lookup_listener(net,
-					     &tcp_hashinfo, ip_hdr(skb)->saddr,
+		sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
+					     ip_hdr(skb)->saddr,
 					     th->source, ip_hdr(skb)->daddr,
 					     ntohs(th->source), inet_iif(skb));
 		/* don't send rst if it can't find key */
@@ -1581,7 +1581,8 @@
 	TCP_SKB_CB(skb)->sacked	 = 0;
 
 lookup:
-	sk = __inet_lookup_skb(&tcp_hashinfo, skb, th->source, th->dest);
+	sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source,
+			       th->dest);
 	if (!sk)
 		goto no_tcp_socket;
 
@@ -1695,7 +1696,8 @@
 	switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) {
 	case TCP_TW_SYN: {
 		struct sock *sk2 = inet_lookup_listener(dev_net(skb->dev),
-							&tcp_hashinfo,
+							&tcp_hashinfo, skb,
+							__tcp_hdrlen(th),
 							iph->saddr, th->source,
 							iph->daddr, th->dest,
 							inet_iif(skb));
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index 072653d..004345d 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -121,7 +121,9 @@
 }
 
 struct sock *inet6_lookup_listener(struct net *net,
-		struct inet_hashinfo *hashinfo, const struct in6_addr *saddr,
+		struct inet_hashinfo *hashinfo,
+		struct sk_buff *skb, int doff,
+		const struct in6_addr *saddr,
 		const __be16 sport, const struct in6_addr *daddr,
 		const unsigned short hnum, const int dif)
 {
@@ -177,6 +179,7 @@
 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
 
 struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
+			  struct sk_buff *skb, int doff,
 			  const struct in6_addr *saddr, const __be16 sport,
 			  const struct in6_addr *daddr, const __be16 dport,
 			  const int dif)
@@ -184,7 +187,8 @@
 	struct sock *sk;
 
 	local_bh_disable();
-	sk = __inet6_lookup(net, hashinfo, saddr, sport, daddr, ntohs(dport), dif);
+	sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
+			    ntohs(dport), dif);
 	local_bh_enable();
 
 	return sk;
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index d72bcfb..9977b6f 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -866,7 +866,8 @@
 		 * no RST generated if md5 hash doesn't match.
 		 */
 		sk1 = inet6_lookup_listener(dev_net(skb_dst(skb)->dev),
-					   &tcp_hashinfo, &ipv6h->saddr,
+					   &tcp_hashinfo, NULL, 0,
+					   &ipv6h->saddr,
 					   th->source, &ipv6h->daddr,
 					   ntohs(th->source), tcp_v6_iif(skb));
 		if (!sk1)
@@ -1375,8 +1376,8 @@
 	hdr = ipv6_hdr(skb);
 
 lookup:
-	sk = __inet6_lookup_skb(&tcp_hashinfo, skb, th->source, th->dest,
-				inet6_iif(skb));
+	sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
+				th->source, th->dest, inet6_iif(skb));
 	if (!sk)
 		goto no_tcp_socket;
 
@@ -1500,6 +1501,7 @@
 		struct sock *sk2;
 
 		sk2 = inet6_lookup_listener(dev_net(skb->dev), &tcp_hashinfo,
+					    skb, __tcp_hdrlen(th),
 					    &ipv6_hdr(skb)->saddr, th->source,
 					    &ipv6_hdr(skb)->daddr,
 					    ntohs(th->dest), tcp_v6_iif(skb));
diff --git a/net/netfilter/xt_TPROXY.c b/net/netfilter/xt_TPROXY.c
index 3ab591e..7f4414d 100644
--- a/net/netfilter/xt_TPROXY.c
+++ b/net/netfilter/xt_TPROXY.c
@@ -105,19 +105,24 @@
  * belonging to established connections going through that one.
  */
 static inline struct sock *
-nf_tproxy_get_sock_v4(struct net *net, const u8 protocol,
+nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, void *hp,
+		      const u8 protocol,
 		      const __be32 saddr, const __be32 daddr,
 		      const __be16 sport, const __be16 dport,
 		      const struct net_device *in,
 		      const enum nf_tproxy_lookup_t lookup_type)
 {
 	struct sock *sk;
+	struct tcphdr *tcph;
 
 	switch (protocol) {
 	case IPPROTO_TCP:
 		switch (lookup_type) {
 		case NFT_LOOKUP_LISTENER:
-			sk = inet_lookup_listener(net, &tcp_hashinfo,
+			tcph = hp;
+			sk = inet_lookup_listener(net, &tcp_hashinfo, skb,
+						    ip_hdrlen(skb) +
+						      __tcp_hdrlen(tcph),
 						    saddr, sport,
 						    daddr, dport,
 						    in->ifindex);
@@ -169,19 +174,23 @@
 
 #ifdef XT_TPROXY_HAVE_IPV6
 static inline struct sock *
-nf_tproxy_get_sock_v6(struct net *net, const u8 protocol,
+nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
+		      const u8 protocol,
 		      const struct in6_addr *saddr, const struct in6_addr *daddr,
 		      const __be16 sport, const __be16 dport,
 		      const struct net_device *in,
 		      const enum nf_tproxy_lookup_t lookup_type)
 {
 	struct sock *sk;
+	struct tcphdr *tcph;
 
 	switch (protocol) {
 	case IPPROTO_TCP:
 		switch (lookup_type) {
 		case NFT_LOOKUP_LISTENER:
-			sk = inet6_lookup_listener(net, &tcp_hashinfo,
+			tcph = hp;
+			sk = inet6_lookup_listener(net, &tcp_hashinfo, skb,
+						   thoff + __tcp_hdrlen(tcph),
 						   saddr, sport,
 						   daddr, ntohs(dport),
 						   in->ifindex);
@@ -267,7 +276,7 @@
 		 * to a listener socket if there's one */
 		struct sock *sk2;
 
-		sk2 = nf_tproxy_get_sock_v4(net, iph->protocol,
+		sk2 = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
 					    iph->saddr, laddr ? laddr : iph->daddr,
 					    hp->source, lport ? lport : hp->dest,
 					    skb->dev, NFT_LOOKUP_LISTENER);
@@ -305,7 +314,7 @@
 	 * addresses, this happens if the redirect already happened
 	 * and the current packet belongs to an already established
 	 * connection */
-	sk = nf_tproxy_get_sock_v4(net, iph->protocol,
+	sk = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
 				   iph->saddr, iph->daddr,
 				   hp->source, hp->dest,
 				   skb->dev, NFT_LOOKUP_ESTABLISHED);
@@ -321,7 +330,7 @@
 	else if (!sk)
 		/* no, there's no established connection, check if
 		 * there's a listener on the redirected addr/port */
-		sk = nf_tproxy_get_sock_v4(net, iph->protocol,
+		sk = nf_tproxy_get_sock_v4(net, skb, hp, iph->protocol,
 					   iph->saddr, laddr,
 					   hp->source, lport,
 					   skb->dev, NFT_LOOKUP_LISTENER);
@@ -429,7 +438,7 @@
 		 * to a listener socket if there's one */
 		struct sock *sk2;
 
-		sk2 = nf_tproxy_get_sock_v6(par->net, tproto,
+		sk2 = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp, tproto,
 					    &iph->saddr,
 					    tproxy_laddr6(skb, &tgi->laddr.in6, &iph->daddr),
 					    hp->source,
@@ -472,7 +481,7 @@
 	 * addresses, this happens if the redirect already happened
 	 * and the current packet belongs to an already established
 	 * connection */
-	sk = nf_tproxy_get_sock_v6(par->net, tproto,
+	sk = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp, tproto,
 				   &iph->saddr, &iph->daddr,
 				   hp->source, hp->dest,
 				   par->in, NFT_LOOKUP_ESTABLISHED);
@@ -487,8 +496,8 @@
 	else if (!sk)
 		/* no there's no established connection, check if
 		 * there's a listener on the redirected addr/port */
-		sk = nf_tproxy_get_sock_v6(par->net, tproto,
-					   &iph->saddr, laddr,
+		sk = nf_tproxy_get_sock_v6(par->net, skb, thoff, hp,
+					   tproto, &iph->saddr, laddr,
 					   hp->source, lport,
 					   par->in, NFT_LOOKUP_LISTENER);
 
diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c
index 2ec08f0..49d14ec 100644
--- a/net/netfilter/xt_socket.c
+++ b/net/netfilter/xt_socket.c
@@ -112,14 +112,15 @@
  *     box.
  */
 static struct sock *
-xt_socket_get_sock_v4(struct net *net, const u8 protocol,
+xt_socket_get_sock_v4(struct net *net, struct sk_buff *skb, const int doff,
+		      const u8 protocol,
 		      const __be32 saddr, const __be32 daddr,
 		      const __be16 sport, const __be16 dport,
 		      const struct net_device *in)
 {
 	switch (protocol) {
 	case IPPROTO_TCP:
-		return __inet_lookup(net, &tcp_hashinfo,
+		return __inet_lookup(net, &tcp_hashinfo, skb, doff,
 				     saddr, sport, daddr, dport,
 				     in->ifindex);
 	case IPPROTO_UDP:
@@ -148,6 +149,8 @@
 					     const struct net_device *indev)
 {
 	const struct iphdr *iph = ip_hdr(skb);
+	struct sk_buff *data_skb = NULL;
+	int doff = 0;
 	__be32 uninitialized_var(daddr), uninitialized_var(saddr);
 	__be16 uninitialized_var(dport), uninitialized_var(sport);
 	u8 uninitialized_var(protocol);
@@ -169,6 +172,10 @@
 		sport = hp->source;
 		daddr = iph->daddr;
 		dport = hp->dest;
+		data_skb = (struct sk_buff *)skb;
+		doff = iph->protocol == IPPROTO_TCP ?
+			ip_hdrlen(skb) + __tcp_hdrlen((struct tcphdr *)hp) :
+			ip_hdrlen(skb) + sizeof(*hp);
 
 	} else if (iph->protocol == IPPROTO_ICMP) {
 		if (extract_icmp4_fields(skb, &protocol, &saddr, &daddr,
@@ -198,8 +205,8 @@
 	}
 #endif
 
-	return xt_socket_get_sock_v4(net, protocol, saddr, daddr,
-				     sport, dport, indev);
+	return xt_socket_get_sock_v4(net, data_skb, doff, protocol, saddr,
+				     daddr, sport, dport, indev);
 }
 
 static bool
@@ -318,14 +325,15 @@
 }
 
 static struct sock *
-xt_socket_get_sock_v6(struct net *net, const u8 protocol,
+xt_socket_get_sock_v6(struct net *net, struct sk_buff *skb, int doff,
+		      const u8 protocol,
 		      const struct in6_addr *saddr, const struct in6_addr *daddr,
 		      const __be16 sport, const __be16 dport,
 		      const struct net_device *in)
 {
 	switch (protocol) {
 	case IPPROTO_TCP:
-		return inet6_lookup(net, &tcp_hashinfo,
+		return inet6_lookup(net, &tcp_hashinfo, skb, doff,
 				    saddr, sport, daddr, dport,
 				    in->ifindex);
 	case IPPROTO_UDP:
@@ -343,6 +351,8 @@
 	__be16 uninitialized_var(dport), uninitialized_var(sport);
 	const struct in6_addr *daddr = NULL, *saddr = NULL;
 	struct ipv6hdr *iph = ipv6_hdr(skb);
+	struct sk_buff *data_skb = NULL;
+	int doff = 0;
 	int thoff = 0, tproto;
 
 	tproto = ipv6_find_hdr(skb, &thoff, -1, NULL, NULL);
@@ -362,6 +372,10 @@
 		sport = hp->source;
 		daddr = &iph->daddr;
 		dport = hp->dest;
+		data_skb = (struct sk_buff *)skb;
+		doff = tproto == IPPROTO_TCP ?
+			thoff + __tcp_hdrlen((struct tcphdr *)hp) :
+			thoff + sizeof(*hp);
 
 	} else if (tproto == IPPROTO_ICMPV6) {
 		struct ipv6hdr ipv6_var;
@@ -373,7 +387,7 @@
 		return NULL;
 	}
 
-	return xt_socket_get_sock_v6(net, tproto, saddr, daddr,
+	return xt_socket_get_sock_v6(net, data_skb, doff, tproto, saddr, daddr,
 				     sport, dport, indev);
 }