[XFRM]: Optimize MTU calculation

Replace the probing-based MTU estimation, which usually takes 2-3 iterations
to find a fitting value and may underestimate the MTU, with an exact
calculation.

Also fix underestimation of the XFRM trailer_len, which causes unnecessary
reallocations.

Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 5a00aa8..e144a25 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -279,7 +279,7 @@
 	xfrm_address_t		*(*local_addr)(struct xfrm_state *, xfrm_address_t *);
 	xfrm_address_t		*(*remote_addr)(struct xfrm_state *, xfrm_address_t *);
 	/* Estimate maximal size of result of transformation of a dgram */
-	u32			(*get_max_size)(struct xfrm_state *, int size);
+	u32			(*get_mtu)(struct xfrm_state *, int size);
 };
 
 extern int xfrm_register_type(struct xfrm_type *type, unsigned short family);
diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c
index bdc65d8..a315d5d 100644
--- a/net/ipv4/esp4.c
+++ b/net/ipv4/esp4.c
@@ -272,32 +272,34 @@
 	return -EINVAL;
 }
 
-static u32 esp4_get_max_size(struct xfrm_state *x, int mtu)
+static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
 {
 	struct esp_data *esp = x->data;
 	u32 blksize = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
-	int enclen = 0;
+	u32 align = max_t(u32, blksize, esp->conf.padlen);
+	u32 rem;
+
+	mtu -= x->props.header_len + esp->auth.icv_trunc_len;
+	rem = mtu & (align - 1);
+	mtu &= ~(align - 1);
 
 	switch (x->props.mode) {
 	case XFRM_MODE_TUNNEL:
-		mtu = ALIGN(mtu +2, blksize);
 		break;
 	default:
 	case XFRM_MODE_TRANSPORT:
 		/* The worst case */
-		mtu = ALIGN(mtu + 2, 4) + blksize - 4;
+		mtu -= blksize - 4;
+		mtu += min_t(u32, blksize - 4, rem);
 		break;
 	case XFRM_MODE_BEET:
 		/* The worst case. */
-		enclen = IPV4_BEET_PHMAXLEN;
-		mtu = ALIGN(mtu + enclen + 2, blksize);
+		mtu -= IPV4_BEET_PHMAXLEN;
+		mtu += min_t(u32, IPV4_BEET_PHMAXLEN, rem);
 		break;
 	}
 
-	if (esp->conf.padlen)
-		mtu = ALIGN(mtu, esp->conf.padlen);
-
-	return mtu + x->props.header_len + esp->auth.icv_trunc_len - enclen;
+	return mtu - 2;
 }
 
 static void esp4_err(struct sk_buff *skb, u32 info)
@@ -340,6 +342,7 @@
 {
 	struct esp_data *esp = NULL;
 	struct crypto_blkcipher *tfm;
+	u32 align;
 
 	/* null auth and encryption can have zero length keys */
 	if (x->aalg) {
@@ -421,7 +424,10 @@
 		}
 	}
 	x->data = esp;
-	x->props.trailer_len = esp4_get_max_size(x, 0) - x->props.header_len;
+	align = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
+	if (esp->conf.padlen)
+		align = max_t(u32, align, esp->conf.padlen);
+	x->props.trailer_len = align + 1 + esp->auth.icv_trunc_len;
 	return 0;
 
 error:
@@ -438,7 +444,7 @@
 	.proto	     	= IPPROTO_ESP,
 	.init_state	= esp_init_state,
 	.destructor	= esp_destroy,
-	.get_max_size	= esp4_get_max_size,
+	.get_mtu	= esp4_get_mtu,
 	.input		= esp_input,
 	.output		= esp_output
 };
diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c
index 6b76c4c..7107bb7 100644
--- a/net/ipv6/esp6.c
+++ b/net/ipv6/esp6.c
@@ -235,22 +235,24 @@
 	return ret;
 }
 
-static u32 esp6_get_max_size(struct xfrm_state *x, int mtu)
+static u32 esp6_get_mtu(struct xfrm_state *x, int mtu)
 {
 	struct esp_data *esp = x->data;
 	u32 blksize = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
+	u32 align = max_t(u32, blksize, esp->conf.padlen);
+	u32 rem;
 
-	if (x->props.mode == XFRM_MODE_TUNNEL) {
-		mtu = ALIGN(mtu + 2, blksize);
-	} else {
-		/* The worst case. */
+	mtu -= x->props.header_len + esp->auth.icv_trunc_len;
+	rem = mtu & (align - 1);
+	mtu &= ~(align - 1);
+
+	if (x->props.mode != XFRM_MODE_TUNNEL) {
 		u32 padsize = ((blksize - 1) & 7) + 1;
-		mtu = ALIGN(mtu + 2, padsize) + blksize - padsize;
+		mtu -= blksize - padsize;
+		mtu += min_t(u32, blksize - padsize, rem);
 	}
-	if (esp->conf.padlen)
-		mtu = ALIGN(mtu, esp->conf.padlen);
 
-	return mtu + x->props.header_len + esp->auth.icv_trunc_len;
+	return mtu - 2;
 }
 
 static void esp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
@@ -380,7 +382,7 @@
 	.proto	     	= IPPROTO_ESP,
 	.init_state	= esp6_init_state,
 	.destructor	= esp6_destroy,
-	.get_max_size	= esp6_get_max_size,
+	.get_mtu	= esp6_get_mtu,
 	.input		= esp6_input,
 	.output		= esp6_output,
 	.hdr_offset	= xfrm6_find_1stfragopt,
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 63a20e8..69a3600 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -1667,37 +1667,17 @@
 }
 EXPORT_SYMBOL(xfrm_state_delete_tunnel);
 
-/*
- * This function is NOT optimal.  For example, with ESP it will give an
- * MTU that's usually two bytes short of being optimal.  However, it will
- * usually give an answer that's a multiple of 4 provided the input is
- * also a multiple of 4.
- */
 int xfrm_state_mtu(struct xfrm_state *x, int mtu)
 {
-	int res = mtu;
+	int res;
 
-	res -= x->props.header_len;
-
-	for (;;) {
-		int m = res;
-
-		if (m < 68)
-			return 68;
-
-		spin_lock_bh(&x->lock);
-		if (x->km.state == XFRM_STATE_VALID &&
-		    x->type && x->type->get_max_size)
-			m = x->type->get_max_size(x, m);
-		else
-			m += x->props.header_len;
-		spin_unlock_bh(&x->lock);
-
-		if (m <= mtu)
-			break;
-		res -= (m - mtu);
-	}
-
+	spin_lock_bh(&x->lock);
+	if (x->km.state == XFRM_STATE_VALID &&
+	    x->type && x->type->get_mtu)
+		res = x->type->get_mtu(x, mtu);
+	else
+		res = mtu;
+	spin_unlock_bh(&x->lock);
 	return res;
 }