tcp: Fix MD5 signatures for non-linear skbs

Currently, the MD5 code assumes that SKBs are linear. When they aren't,
it hashes past the end of the SKB's linear data and into unrelated
memory.

Reported by Stephen Hemminger in [1]. Thanks to Stephen and Evgeniy
Polyakov for their advice. Also includes a couple of missed
sk_route_caps updates (clearing NETIF_F_GSO_MASK) from Stephen's patch
in [2].

[1] http://marc.info/?l=linux-netdev&m=121445989106145&w=2
[2] http://marc.info/?l=linux-netdev&m=121459157816964&w=2

Signed-off-by: Adam Langley <agl@imperialviolet.org>
Acked-by: Stephen Hemminger <shemminger@vyatta.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 29adc66..5400d75 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -87,9 +87,8 @@
 #ifdef CONFIG_TCP_MD5SIG
 static struct tcp_md5sig_key *tcp_v4_md5_do_lookup(struct sock *sk,
 						   __be32 addr);
-static int tcp_v4_do_calc_md5_hash(char *md5_hash, struct tcp_md5sig_key *key,
-				   __be32 saddr, __be32 daddr,
-				   struct tcphdr *th, unsigned int tcplen);
+static int tcp_v4_md5_hash_hdr(char *md5_hash, struct tcp_md5sig_key *key,
+			       __be32 daddr, __be32 saddr, struct tcphdr *th);
 #else
 static inline
 struct tcp_md5sig_key *tcp_v4_md5_do_lookup(struct sock *sk, __be32 addr)
@@ -583,11 +582,9 @@
 		arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED;
 		rep.th.doff = arg.iov[0].iov_len / 4;
 
-		tcp_v4_do_calc_md5_hash((__u8 *)&rep.opt[1],
-					key,
-					ip_hdr(skb)->daddr,
-					ip_hdr(skb)->saddr,
-					&rep.th, arg.iov[0].iov_len);
+		tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[1],
+				     key, ip_hdr(skb)->daddr,
+				     ip_hdr(skb)->saddr, &rep.th);
 	}
 #endif
 	arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr,
@@ -657,11 +654,9 @@
 		arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED;
 		rep.th.doff = arg.iov[0].iov_len/4;
 
-		tcp_v4_do_calc_md5_hash((__u8 *)&rep.opt[offset],
-					key,
-					ip_hdr(skb)->daddr,
-					ip_hdr(skb)->saddr,
-					&rep.th, arg.iov[0].iov_len);
+		tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[offset],
+				    key, ip_hdr(skb)->daddr,
+				    ip_hdr(skb)->saddr, &rep.th);
 	}
 #endif
 	arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr,
@@ -989,28 +984,16 @@
 				 newkey, cmd.tcpm_keylen);
 }
 
-static int tcp_v4_do_calc_md5_hash(char *md5_hash, struct tcp_md5sig_key *key,
-				   __be32 saddr, __be32 daddr,
-				   struct tcphdr *th,
-				   unsigned int tcplen)
+static int tcp_v4_md5_hash_pseudoheader(struct tcp_md5sig_pool *hp,
+					__be32 daddr, __be32 saddr, int nbytes)
 {
-	struct tcp_md5sig_pool *hp;
 	struct tcp4_pseudohdr *bp;
-	int err;
-
-	/*
-	 * Okay, so RFC2385 is turned on for this connection,
-	 * so we need to generate the MD5 hash for the packet now.
-	 */
-
-	hp = tcp_get_md5sig_pool();
-	if (!hp)
-		goto clear_hash_noput;
+	struct scatterlist sg;
 
 	bp = &hp->md5_blk.ip4;
 
 	/*
-	 * The TCP pseudo-header (in the order: source IP address,
+	 * 1. the TCP pseudo-header (in the order: source IP address,
 	 * destination IP address, zero-padded protocol number, and
 	 * segment length)
 	 */
@@ -1018,48 +1001,95 @@
 	bp->daddr = daddr;
 	bp->pad = 0;
 	bp->protocol = IPPROTO_TCP;
-	bp->len = htons(tcplen);
+	bp->len = cpu_to_be16(nbytes);
 
-	err = tcp_calc_md5_hash(md5_hash, key, sizeof(*bp),
-				th, tcplen, hp);
-	if (err)
+	sg_init_one(&sg, bp, sizeof(*bp));
+	return crypto_hash_update(&hp->md5_desc, &sg, sizeof(*bp));
+}
+
+static int tcp_v4_md5_hash_hdr(char *md5_hash, struct tcp_md5sig_key *key,
+			       __be32 daddr, __be32 saddr, struct tcphdr *th)
+{
+	struct tcp_md5sig_pool *hp;
+	struct hash_desc *desc;
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	desc = &hp->md5_desc;
+
+	if (crypto_hash_init(desc))
+		goto clear_hash;
+	if (tcp_v4_md5_hash_pseudoheader(hp, daddr, saddr, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_header(hp, th))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	if (crypto_hash_final(desc, md5_hash))
 		goto clear_hash;
 
-	/* Free up the crypto pool */
 	tcp_put_md5sig_pool();
-out:
 	return 0;
+
 clear_hash:
 	tcp_put_md5sig_pool();
 clear_hash_noput:
 	memset(md5_hash, 0, 16);
-	goto out;
+	return 1;
 }
 
-int tcp_v4_calc_md5_hash(char *md5_hash, struct tcp_md5sig_key *key,
-			 struct sock *sk,
-			 struct dst_entry *dst,
-			 struct request_sock *req,
-			 struct tcphdr *th,
-			 unsigned int tcplen)
+int tcp_v4_md5_hash_skb(char *md5_hash, struct tcp_md5sig_key *key,
+			struct sock *sk, struct request_sock *req,
+			struct sk_buff *skb)
 {
+	struct tcp_md5sig_pool *hp;
+	struct hash_desc *desc;
+	struct tcphdr *th = tcp_hdr(skb);
 	__be32 saddr, daddr;
 
 	if (sk) {
 		saddr = inet_sk(sk)->saddr;
 		daddr = inet_sk(sk)->daddr;
+	} else if (req) {
+		saddr = inet_rsk(req)->loc_addr;
+		daddr = inet_rsk(req)->rmt_addr;
 	} else {
-		struct rtable *rt = (struct rtable *)dst;
-		BUG_ON(!rt);
-		saddr = rt->rt_src;
-		daddr = rt->rt_dst;
+		const struct iphdr *iph = ip_hdr(skb);
+		saddr = iph->saddr;
+		daddr = iph->daddr;
 	}
-	return tcp_v4_do_calc_md5_hash(md5_hash, key,
-				       saddr, daddr,
-				       th, tcplen);
+
+	hp = tcp_get_md5sig_pool();
+	if (!hp)
+		goto clear_hash_noput;
+	desc = &hp->md5_desc;
+
+	if (crypto_hash_init(desc))
+		goto clear_hash;
+
+	if (tcp_v4_md5_hash_pseudoheader(hp, daddr, saddr, skb->len))
+		goto clear_hash;
+	if (tcp_md5_hash_header(hp, th))
+		goto clear_hash;
+	if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
+		goto clear_hash;
+	if (tcp_md5_hash_key(hp, key))
+		goto clear_hash;
+	if (crypto_hash_final(desc, md5_hash))
+		goto clear_hash;
+
+	tcp_put_md5sig_pool();
+	return 0;
+
+clear_hash:
+	tcp_put_md5sig_pool();
+clear_hash_noput:
+	memset(md5_hash, 0, 16);
+	return 1;
 }
 
-EXPORT_SYMBOL(tcp_v4_calc_md5_hash);
+EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 
 static int tcp_v4_inbound_md5_hash(struct sock *sk, struct sk_buff *skb)
 {
@@ -1104,10 +1134,9 @@
 	/* Okay, so this is hash_expected and hash_location -
 	 * so we need to calculate the checksum.
 	 */
-	genhash = tcp_v4_do_calc_md5_hash(newhash,
-					  hash_expected,
-					  iph->saddr, iph->daddr,
-					  th, skb->len);
+	genhash = tcp_v4_md5_hash_skb(newhash,
+				      hash_expected,
+				      NULL, NULL, skb);
 
 	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
 		if (net_ratelimit()) {
@@ -1356,6 +1385,7 @@
 		if (newkey != NULL)
 			tcp_v4_md5_do_add(newsk, inet_sk(sk)->daddr,
 					  newkey, key->keylen);
+		newsk->sk_route_caps &= ~NETIF_F_GSO_MASK;
 	}
 #endif
 
@@ -1719,7 +1749,7 @@
 #ifdef CONFIG_TCP_MD5SIG
 static struct tcp_sock_af_ops tcp_sock_ipv4_specific = {
 	.md5_lookup		= tcp_v4_md5_lookup,
-	.calc_md5_hash		= tcp_v4_calc_md5_hash,
+	.calc_md5_hash		= tcp_v4_md5_hash_skb,
 	.md5_add		= tcp_v4_md5_add_func,
 	.md5_parse		= tcp_v4_parse_md5_keys,
 };