tcp: use dctcp if enabled on the route to the initiator

Currently, the following case doesn't use DCTCP, even though it should:
a responder has, e.g., Cubic as its system-wide default, but DCTCP is
set via RTAX_CC_ALGO on a specific route to the initiating host. The
initiating host uses DCTCP as its congestion control as well, but since
the initiator sets ECT(0) on the SYN, tcp_ecn_create_request() doesn't
set ecn_ok. As DCTCP requires ECN, we then have to fall back to Reno
after the 3WHS completes.

We thought about how to solve this in a minimal, non-intrusive way
without needlessly bloating tcp_ecn_create_request(): let's cache the
congestion control's "needs ECN" flag in RTAX_FEATURES. In other words,
when ECT(0) is set on the SYN packet, set ecn_ok=1 iff the route's
RTAX_FEATURES contains the internal-only DST_FEATURE_ECN_CA flag, which
is never exposed to user space (it is masked out when metrics are
dumped). This allows us to do only a single metric feature lookup
inside tcp_ecn_create_request().

Joint work with Florian Westphal.

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c
index 788ceed..a466821 100644
--- a/net/core/rtnetlink.c
+++ b/net/core/rtnetlink.c
@@ -678,6 +678,12 @@
 					continue;
 				if (nla_put_string(skb, i + 1, name))
 					goto nla_put_failure;
+			} else if (i == RTAX_FEATURES - 1) {
+				u32 user_features = metrics[i] & RTAX_FEATURE_MASK;
+
+				BUILD_BUG_ON(RTAX_FEATURE_MASK & DST_FEATURE_MASK);
+				if (nla_put_u32(skb, i + 1, user_features))
+					goto nla_put_failure;
 			} else {
 				if (nla_put_u32(skb, i + 1, metrics[i]))
 					goto nla_put_failure;
diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c
index 115a08e..992a959 100644
--- a/net/ipv4/fib_semantics.c
+++ b/net/ipv4/fib_semantics.c
@@ -879,6 +879,7 @@
 static int
 fib_convert_metrics(struct fib_info *fi, const struct fib_config *cfg)
 {
+	bool ecn_ca = false;
 	struct nlattr *nla;
 	int remaining;
 
@@ -898,7 +899,7 @@
 			char tmp[TCP_CA_NAME_MAX];
 
 			nla_strlcpy(tmp, nla, sizeof(tmp));
-			val = tcp_ca_get_key_by_name(tmp);
+			val = tcp_ca_get_key_by_name(tmp, &ecn_ca);
 			if (val == TCP_CA_UNSPEC)
 				return -EINVAL;
 		} else {
@@ -913,6 +914,9 @@
 		fi->fib_metrics[type - 1] = val;
 	}
 
+	if (ecn_ca)
+		fi->fib_metrics[RTAX_FEATURES - 1] |= DST_FEATURE_ECN_CA;
+
 	return 0;
 }
 
diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c
index a2ed23c..93c4dc3 100644
--- a/net/ipv4/tcp_cong.c
+++ b/net/ipv4/tcp_cong.c
@@ -114,16 +114,19 @@
 }
 EXPORT_SYMBOL_GPL(tcp_unregister_congestion_control);
 
-u32 tcp_ca_get_key_by_name(const char *name)
+u32 tcp_ca_get_key_by_name(const char *name, bool *ecn_ca)
 {
 	const struct tcp_congestion_ops *ca;
-	u32 key;
+	u32 key = TCP_CA_UNSPEC;
 
 	might_sleep();
 
 	rcu_read_lock();
 	ca = __tcp_ca_find_autoload(name);
-	key = ca ? ca->key : TCP_CA_UNSPEC;
+	if (ca) {
+		key = ca->key;
+		*ecn_ca = ca->flags & TCP_CONG_NEEDS_ECN;
+	}
 	rcu_read_unlock();
 
 	return key;
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index dc08e23..a8f515b 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -6003,14 +6003,17 @@
 	const struct net *net = sock_net(listen_sk);
 	bool th_ecn = th->ece && th->cwr;
 	bool ect, ecn_ok;
+	u32 ecn_ok_dst;
 
 	if (!th_ecn)
 		return;
 
 	ect = !INET_ECN_is_not_ect(TCP_SKB_CB(skb)->ip_dsfield);
-	ecn_ok = net->ipv4.sysctl_tcp_ecn || dst_feature(dst, RTAX_FEATURE_ECN);
+	ecn_ok_dst = dst_feature(dst, DST_FEATURE_ECN_MASK);
+	ecn_ok = net->ipv4.sysctl_tcp_ecn || ecn_ok_dst;
 
-	if ((!ect && ecn_ok) || tcp_ca_needs_ecn(listen_sk))
+	if ((!ect && ecn_ok) || tcp_ca_needs_ecn(listen_sk) ||
+	    (ecn_ok_dst & DST_FEATURE_ECN_CA))
 		inet_rsk(req)->ecn_ok = 1;
 }
 
diff --git a/net/ipv6/route.c b/net/ipv6/route.c
index 8771530..f45cac6 100644
--- a/net/ipv6/route.c
+++ b/net/ipv6/route.c
@@ -1698,6 +1698,7 @@
 static int ip6_convert_metrics(struct mx6_config *mxc,
 			       const struct fib6_config *cfg)
 {
+	bool ecn_ca = false;
 	struct nlattr *nla;
 	int remaining;
 	u32 *mp;
@@ -1722,7 +1723,7 @@
 			char tmp[TCP_CA_NAME_MAX];
 
 			nla_strlcpy(tmp, nla, sizeof(tmp));
-			val = tcp_ca_get_key_by_name(tmp);
+			val = tcp_ca_get_key_by_name(tmp, &ecn_ca);
 			if (val == TCP_CA_UNSPEC)
 				goto err;
 		} else {
@@ -1735,8 +1736,12 @@
 		__set_bit(type - 1, mxc->mx_valid);
 	}
 
-	mxc->mx = mp;
+	if (ecn_ca) {
+		__set_bit(RTAX_FEATURES - 1, mxc->mx_valid);
+		mp[RTAX_FEATURES - 1] |= DST_FEATURE_ECN_CA;
+	}
 
+	mxc->mx = mp;
 	return 0;
  err:
 	kfree(mp);