esp: Fix ESN generation under UDP encapsulation

Blair Steven noticed that ESN in conjunction with UDP encapsulation
is broken because we set the temporary ESP header to the wrong spot.

This patch fixes this by first of all using the right spot, i.e.,
4 bytes off the real ESP header, and then saving this information
so that after encryption we can restore it properly.

Fixes: 7021b2e1cddd ("esp4: Switch to new AEAD interface")
Reported-by: Blair Steven <Blair.Steven@alliedtelesis.co.nz>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Acked-by: Steffen Klassert <steffen.klassert@secunet.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c
index 4779374..d95631d 100644
--- a/net/ipv4/esp4.c
+++ b/net/ipv4/esp4.c
@@ -23,6 +23,11 @@
 	void *tmp;
 };
 
+struct esp_output_extra {
+	__be32 seqhi;
+	u32 esphoff;
+};
+
 #define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
 
 static u32 esp4_get_mtu(struct xfrm_state *x, int mtu);
@@ -35,11 +40,11 @@
  *
  * TODO: Use spare space in skb for this where possible.
  */
-static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int seqhilen)
+static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int extralen)
 {
 	unsigned int len;
 
-	len = seqhilen;
+	len = extralen;
 
 	len += crypto_aead_ivsize(aead);
 
@@ -57,15 +62,16 @@
 	return kmalloc(len, GFP_ATOMIC);
 }
 
-static inline __be32 *esp_tmp_seqhi(void *tmp)
+static inline void *esp_tmp_extra(void *tmp)
 {
-	return PTR_ALIGN((__be32 *)tmp, __alignof__(__be32));
+	return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra));
 }
-static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int seqhilen)
+
+static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int extralen)
 {
 	return crypto_aead_ivsize(aead) ?
-	       PTR_ALIGN((u8 *)tmp + seqhilen,
-			 crypto_aead_alignmask(aead) + 1) : tmp + seqhilen;
+	       PTR_ALIGN((u8 *)tmp + extralen,
+			 crypto_aead_alignmask(aead) + 1) : tmp + extralen;
 }
 
 static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv)
@@ -99,7 +105,7 @@
 {
 	struct ip_esp_hdr *esph = (void *)(skb->data + offset);
 	void *tmp = ESP_SKB_CB(skb)->tmp;
-	__be32 *seqhi = esp_tmp_seqhi(tmp);
+	__be32 *seqhi = esp_tmp_extra(tmp);
 
 	esph->seq_no = esph->spi;
 	esph->spi = *seqhi;
@@ -107,7 +113,11 @@
 
 static void esp_output_restore_header(struct sk_buff *skb)
 {
-	esp_restore_header(skb, skb_transport_offset(skb) - sizeof(__be32));
+	void *tmp = ESP_SKB_CB(skb)->tmp;
+	struct esp_output_extra *extra = esp_tmp_extra(tmp);
+
+	esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff -
+				sizeof(__be32));
 }
 
 static void esp_output_done_esn(struct crypto_async_request *base, int err)
@@ -121,6 +131,7 @@
 static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
 {
 	int err;
+	struct esp_output_extra *extra;
 	struct ip_esp_hdr *esph;
 	struct crypto_aead *aead;
 	struct aead_request *req;
@@ -137,8 +148,7 @@
 	int tfclen;
 	int nfrags;
 	int assoclen;
-	int seqhilen;
-	__be32 *seqhi;
+	int extralen;
 	__be64 seqno;
 
 	/* skb is pure payload to encrypt */
@@ -166,21 +176,21 @@
 	nfrags = err;
 
 	assoclen = sizeof(*esph);
-	seqhilen = 0;
+	extralen = 0;
 
 	if (x->props.flags & XFRM_STATE_ESN) {
-		seqhilen += sizeof(__be32);
-		assoclen += seqhilen;
+		extralen += sizeof(*extra);
+		assoclen += sizeof(__be32);
 	}
 
-	tmp = esp_alloc_tmp(aead, nfrags, seqhilen);
+	tmp = esp_alloc_tmp(aead, nfrags, extralen);
 	if (!tmp) {
 		err = -ENOMEM;
 		goto error;
 	}
 
-	seqhi = esp_tmp_seqhi(tmp);
-	iv = esp_tmp_iv(aead, tmp, seqhilen);
+	extra = esp_tmp_extra(tmp);
+	iv = esp_tmp_iv(aead, tmp, extralen);
 	req = esp_tmp_req(aead, iv);
 	sg = esp_req_sg(aead, req);
 
@@ -247,8 +257,10 @@
 	 * encryption.
 	 */
 	if ((x->props.flags & XFRM_STATE_ESN)) {
-		esph = (void *)(skb_transport_header(skb) - sizeof(__be32));
-		*seqhi = esph->spi;
+		extra->esphoff = (unsigned char *)esph -
+				 skb_transport_header(skb);
+		esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4);
+		extra->seqhi = esph->spi;
 		esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.hi);
 		aead_request_set_callback(req, 0, esp_output_done_esn, skb);
 	}
@@ -445,7 +457,7 @@
 		goto out;
 
 	ESP_SKB_CB(skb)->tmp = tmp;
-	seqhi = esp_tmp_seqhi(tmp);
+	seqhi = esp_tmp_extra(tmp);
 	iv = esp_tmp_iv(aead, tmp, seqhilen);
 	req = esp_tmp_req(aead, iv);
 	sg = esp_req_sg(aead, req);