[IPSEC]: Fix inter address family IPsec tunnel handling.

Signed-off-by: Kazunori MIYAZAWA <kazunori@miyazawa.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/xfrm4_mode_tunnel.c b/net/ipv4/xfrm4_mode_tunnel.c
index 8dee617..584e6d7 100644
--- a/net/ipv4/xfrm4_mode_tunnel.c
+++ b/net/ipv4/xfrm4_mode_tunnel.c
@@ -41,7 +41,7 @@
 	top_iph->ihl = 5;
 	top_iph->version = 4;
 
-	top_iph->protocol = x->inner_mode->afinfo->proto;
+	top_iph->protocol = xfrm_af2proto(skb->dst->ops->family);
 
 	/* DS disclosed */
 	top_iph->tos = INET_ECN_encapsulate(XFRM_MODE_SKB_CB(skb)->tos,
diff --git a/net/ipv4/xfrm4_output.c b/net/ipv4/xfrm4_output.c
index d5a58a8..8c3180a 100644
--- a/net/ipv4/xfrm4_output.c
+++ b/net/ipv4/xfrm4_output.c
@@ -56,7 +56,7 @@
 {
 	int err;
 
-	err = x->inner_mode->afinfo->extract_output(x, skb);
+	err = xfrm_inner_extract_output(x, skb);
 	if (err)
 		return err;
 
diff --git a/net/ipv6/xfrm6_mode_tunnel.c b/net/ipv6/xfrm6_mode_tunnel.c
index 0c742fa..e20529b 100644
--- a/net/ipv6/xfrm6_mode_tunnel.c
+++ b/net/ipv6/xfrm6_mode_tunnel.c
@@ -45,7 +45,7 @@
 
 	memcpy(top_iph->flow_lbl, XFRM_MODE_SKB_CB(skb)->flow_lbl,
 	       sizeof(top_iph->flow_lbl));
-	top_iph->nexthdr = x->inner_mode->afinfo->proto;
+	top_iph->nexthdr = xfrm_af2proto(skb->dst->ops->family);
 
 	dsfield = XFRM_MODE_SKB_CB(skb)->tos;
 	dsfield = INET_ECN_encapsulate(dsfield, dsfield);
diff --git a/net/ipv6/xfrm6_output.c b/net/ipv6/xfrm6_output.c
index 79ccfb0..0af823cf 100644
--- a/net/ipv6/xfrm6_output.c
+++ b/net/ipv6/xfrm6_output.c
@@ -62,7 +62,7 @@
 {
 	int err;
 
-	err = x->inner_mode->afinfo->extract_output(x, skb);
+	err = xfrm_inner_extract_output(x, skb);
 	if (err)
 		return err;
 
diff --git a/net/key/af_key.c b/net/key/af_key.c
index 8b5f486..e9ef9af 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -1219,7 +1219,7 @@
 		x->sel.prefixlen_s = addr->sadb_address_prefixlen;
 	}
 
-	if (!x->sel.family)
+	if (x->props.mode == XFRM_MODE_TRANSPORT)
 		x->sel.family = x->props.family;
 
 	if (ext_hdrs[SADB_X_EXT_NAT_T_TYPE-1]) {
diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c
index 62188c6..7527940 100644
--- a/net/xfrm/xfrm_input.c
+++ b/net/xfrm/xfrm_input.c
@@ -84,14 +84,21 @@
 
 int xfrm_prepare_input(struct xfrm_state *x, struct sk_buff *skb)
 {
+	struct xfrm_mode *inner_mode = x->inner_mode;
 	int err;
 
 	err = x->outer_mode->afinfo->extract_input(x, skb);
 	if (err)
 		return err;
 
-	skb->protocol = x->inner_mode->afinfo->eth_proto;
-	return x->inner_mode->input2(x, skb);
+	if (x->sel.family == AF_UNSPEC) {
+		inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol);
+		if (inner_mode == NULL)
+			return -EAFNOSUPPORT;
+	}
+
+	skb->protocol = inner_mode->afinfo->eth_proto;
+	return inner_mode->input2(x, skb);
 }
 EXPORT_SYMBOL(xfrm_prepare_input);
 
@@ -101,6 +108,7 @@
 	__be32 seq;
 	struct xfrm_state *x;
 	xfrm_address_t *daddr;
+	struct xfrm_mode *inner_mode;
 	unsigned int family;
 	int decaps = 0;
 	int async = 0;
@@ -207,7 +215,15 @@
 
 		XFRM_MODE_SKB_CB(skb)->protocol = nexthdr;
 
-		if (x->inner_mode->input(x, skb)) {
+		inner_mode = x->inner_mode;
+
+		if (x->sel.family == AF_UNSPEC) {
+			inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol);
+			if (inner_mode == NULL)
+				goto drop;
+		}
+
+		if (inner_mode->input(x, skb)) {
 			XFRM_INC_STATS(LINUX_MIB_XFRMINSTATEMODEERROR);
 			goto drop;
 		}
diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c
index 569d377..2519129 100644
--- a/net/xfrm/xfrm_output.c
+++ b/net/xfrm/xfrm_output.c
@@ -124,7 +124,7 @@
 		if (!x)
 			return dst_output(skb);
 
-		err = nf_hook(x->inner_mode->afinfo->family,
+		err = nf_hook(skb->dst->ops->family,
 			      NF_INET_POST_ROUTING, skb,
 			      NULL, skb->dst->dev, xfrm_output2);
 		if (unlikely(err != 1))
@@ -193,4 +193,26 @@

 	return xfrm_output2(skb);
 }
+
+/*
+ * Run the extract_output hook of the inner mode that applies to skb.
+ * When x->sel.family is AF_UNSPEC the mode is looked up from
+ * skb->dst->ops->family; otherwise x->inner_mode is used.
+ * Returns -EAFNOSUPPORT if no inner mode is available.
+ */
+int xfrm_inner_extract_output(struct xfrm_state *x, struct sk_buff *skb)
+{
+	struct xfrm_mode *mode = x->inner_mode;
+
+	if (x->sel.family == AF_UNSPEC)
+		mode = xfrm_ip2inner_mode(x,
+				xfrm_af2proto(skb->dst->ops->family));
+
+	if (mode == NULL)
+		return -EAFNOSUPPORT;
+
+	return mode->afinfo->extract_output(x, skb);
+}
+
 EXPORT_SYMBOL_GPL(xfrm_output);
+EXPORT_SYMBOL_GPL(xfrm_inner_extract_output);
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 7ba65e8..58f1f93 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -388,6 +388,8 @@
 	kfree(x->coaddr);
 	if (x->inner_mode)
 		xfrm_put_mode(x->inner_mode);
+	if (x->inner_mode_iaf)
+		xfrm_put_mode(x->inner_mode_iaf);
 	if (x->outer_mode)
 		xfrm_put_mode(x->outer_mode);
 	if (x->type) {
@@ -523,6 +525,8 @@
 		x->lft.hard_packet_limit = XFRM_INF;
 		x->replay_maxage = 0;
 		x->replay_maxdiff = 0;
+		x->inner_mode = NULL;
+		x->inner_mode_iaf = NULL;
 		spin_lock_init(&x->lock);
 	}
 	return x;
@@ -796,7 +800,7 @@
 			      selector.
 			 */
 			if (x->km.state == XFRM_STATE_VALID) {
-				if (!xfrm_selector_match(&x->sel, fl, x->sel.family) ||
+				if ((x->sel.family && !xfrm_selector_match(&x->sel, fl, x->sel.family)) ||
 				    !security_xfrm_state_pol_flow_match(x, pol, fl))
 					continue;
 				if (!best ||
@@ -1944,6 +1948,7 @@
 int xfrm_init_state(struct xfrm_state *x)
 {
 	struct xfrm_state_afinfo *afinfo;
+	struct xfrm_mode *inner_mode;
 	int family = x->props.family;
 	int err;
 
@@ -1962,13 +1967,59 @@
 		goto error;

 	err = -EPROTONOSUPPORT;
-	x->inner_mode = xfrm_get_mode(x->props.mode, x->sel.family);
-	if (x->inner_mode == NULL)
-		goto error;

-	if (!(x->inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) &&
-	    family != x->sel.family)
-		goto error;
+	if (x->sel.family != AF_UNSPEC) {
+		/* Selector fixes the inner family: one mode is enough. */
+		inner_mode = xfrm_get_mode(x->props.mode, x->sel.family);
+		if (inner_mode == NULL)
+			goto error;
+
+		if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) &&
+		    family != x->sel.family) {
+			xfrm_put_mode(inner_mode);
+			goto error;
+		}
+
+		x->inner_mode = inner_mode;
+	} else {
+		struct xfrm_mode *inner_mode_iaf;
+
+		/*
+		 * Selector family is unspecified: look up both the AF_INET
+		 * and AF_INET6 modes.  Neither reference is stored in x
+		 * until both lookups succeed, so each failure path below
+		 * has to drop whatever it already holds.
+		 */
+		inner_mode = xfrm_get_mode(x->props.mode, AF_INET);
+		if (inner_mode == NULL)
+			goto error;
+
+		if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL)) {
+			xfrm_put_mode(inner_mode);
+			goto error;
+		}
+
+		inner_mode_iaf = xfrm_get_mode(x->props.mode, AF_INET6);
+		if (inner_mode_iaf == NULL) {
+			xfrm_put_mode(inner_mode);
+			goto error;
+		}
+
+		if (!(inner_mode_iaf->flags & XFRM_MODE_FLAG_TUNNEL)) {
+			xfrm_put_mode(inner_mode_iaf);
+			xfrm_put_mode(inner_mode);
+			goto error;
+		}
+
+		/* inner_mode follows x->props.family; _iaf is the other. */
+		if (x->props.family == AF_INET) {
+			x->inner_mode = inner_mode;
+			x->inner_mode_iaf = inner_mode_iaf;
+		} else {
+			x->inner_mode = inner_mode_iaf;
+			x->inner_mode_iaf = inner_mode;
+		}
+	}

 	x->type = xfrm_get_type(x->id.proto, family);
 	if (x->type == NULL)
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index f971ca5..5d96f27 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -288,12 +288,9 @@
 	memcpy(&x->props.saddr, &p->saddr, sizeof(x->props.saddr));
 	x->props.flags = p->flags;
 
-	/*
-	 * Set inner address family if the KM left it as zero.
-	 * See comment in validate_tmpl.
-	 */
-	if (!x->sel.family)
+	if (x->props.mode == XFRM_MODE_TRANSPORT)
 		x->sel.family = p->family;
+
 }
 
 /*