cfg80211: Allow user space to specify non-IE fields for SAE Authentication

SAE extends Authentication frames with fields that are not information
elements. NL80211_ATTR_IE is not suitable for these, so introduce a new
attribute that can be used to specify the fields needed for SAE in
station mode.

Signed-off-by: Jouni Malinen <j@w1.fi>
[change to verify that SAE is only used with authenticate command]
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/net/wireless/core.h b/net/wireless/core.h
index a343be4..b8eb743 100644
--- a/net/wireless/core.h
+++ b/net/wireless/core.h
@@ -320,13 +320,15 @@
 			 const u8 *bssid,
 			 const u8 *ssid, int ssid_len,
 			 const u8 *ie, int ie_len,
-			 const u8 *key, int key_len, int key_idx);
+			 const u8 *key, int key_len, int key_idx,
+			 const u8 *sae_data, int sae_data_len);
 int cfg80211_mlme_auth(struct cfg80211_registered_device *rdev,
 		       struct net_device *dev, struct ieee80211_channel *chan,
 		       enum nl80211_auth_type auth_type, const u8 *bssid,
 		       const u8 *ssid, int ssid_len,
 		       const u8 *ie, int ie_len,
-		       const u8 *key, int key_len, int key_idx);
+		       const u8 *key, int key_len, int key_idx,
+		       const u8 *sae_data, int sae_data_len);
 int __cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev,
 			  struct net_device *dev,
 			  struct ieee80211_channel *chan,
diff --git a/net/wireless/mlme.c b/net/wireless/mlme.c
index 8016fee..460d493 100644
--- a/net/wireless/mlme.c
+++ b/net/wireless/mlme.c
@@ -273,7 +273,8 @@
 			 const u8 *bssid,
 			 const u8 *ssid, int ssid_len,
 			 const u8 *ie, int ie_len,
-			 const u8 *key, int key_len, int key_idx)
+			 const u8 *key, int key_len, int key_idx,
+			 const u8 *sae_data, int sae_data_len)
 {
 	struct wireless_dev *wdev = dev->ieee80211_ptr;
 	struct cfg80211_auth_request req;
@@ -293,6 +294,8 @@
 
 	req.ie = ie;
 	req.ie_len = ie_len;
+	req.sae_data = sae_data;
+	req.sae_data_len = sae_data_len;
 	req.auth_type = auth_type;
 	req.bss = cfg80211_get_bss(&rdev->wiphy, chan, bssid, ssid, ssid_len,
 				   WLAN_CAPABILITY_ESS, WLAN_CAPABILITY_ESS);
@@ -319,7 +322,8 @@
 		       enum nl80211_auth_type auth_type, const u8 *bssid,
 		       const u8 *ssid, int ssid_len,
 		       const u8 *ie, int ie_len,
-		       const u8 *key, int key_len, int key_idx)
+		       const u8 *key, int key_len, int key_idx,
+		       const u8 *sae_data, int sae_data_len)
 {
 	int err;
 
@@ -327,7 +331,8 @@
 	wdev_lock(dev->ieee80211_ptr);
 	err = __cfg80211_mlme_auth(rdev, dev, chan, auth_type, bssid,
 				   ssid, ssid_len, ie, ie_len,
-				   key, key_len, key_idx);
+				   key, key_len, key_idx,
+				   sae_data, sae_data_len);
 	wdev_unlock(dev->ieee80211_ptr);
 	mutex_unlock(&rdev->devlist_mtx);
 
diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c
index 0418a6d..74d8123 100644
--- a/net/wireless/nl80211.c
+++ b/net/wireless/nl80211.c
@@ -23,7 +23,6 @@
 #include "nl80211.h"
 #include "reg.h"
 
-static bool nl80211_valid_auth_type(enum nl80211_auth_type auth_type);
 static int nl80211_crypto_settings(struct cfg80211_registered_device *rdev,
 				   struct genl_info *info,
 				   struct cfg80211_crypto_settings *settings,
@@ -355,6 +354,7 @@
 	[NL80211_ATTR_BG_SCAN_PERIOD] = { .type = NLA_U16 },
 	[NL80211_ATTR_WDEV] = { .type = NLA_U64 },
 	[NL80211_ATTR_USER_REG_HINT_TYPE] = { .type = NLA_U32 },
+	[NL80211_ATTR_SAE_DATA] = { .type = NLA_BINARY, },
 };
 
 /* policy for the key attributes */
@@ -2490,6 +2490,30 @@
 	return ret;
 }
 
+static bool nl80211_valid_auth_type(struct cfg80211_registered_device *rdev,
+				    enum nl80211_auth_type auth_type,
+				    enum nl80211_commands cmd)
+{	/* returns true if auth_type is accepted for cmd on rdev's wiphy */
+	if (auth_type > NL80211_AUTHTYPE_MAX)	/* out of range: rejected for every cmd */
+		return false;
+
+	switch (cmd) {
+	case NL80211_CMD_AUTHENTICATE:
+		if (!(rdev->wiphy.features & NL80211_FEATURE_SAE) &&
+		    auth_type == NL80211_AUTHTYPE_SAE)
+			return false;	/* SAE requested but NL80211_FEATURE_SAE is not set */
+		return true;
+	case NL80211_CMD_CONNECT:
+	case NL80211_CMD_START_AP:
+		/* SAE not supported yet */
+		if (auth_type == NL80211_AUTHTYPE_SAE)
+			return false;	/* rejected regardless of the wiphy feature flags */
+		return true;
+	default:
+		return false;	/* any command not listed above is rejected */
+	}
+}
+
 static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info)
 {
 	struct cfg80211_registered_device *rdev = info->user_ptr[0];
@@ -2559,7 +2583,8 @@
 	if (info->attrs[NL80211_ATTR_AUTH_TYPE]) {
 		params.auth_type = nla_get_u32(
 			info->attrs[NL80211_ATTR_AUTH_TYPE]);
-		if (!nl80211_valid_auth_type(params.auth_type))
+		if (!nl80211_valid_auth_type(rdev, params.auth_type,
+					     NL80211_CMD_START_AP))
 			return -EINVAL;
 	} else
 		params.auth_type = NL80211_AUTHTYPE_AUTOMATIC;
@@ -4852,11 +4877,6 @@
 	return res;
 }
 
-static bool nl80211_valid_auth_type(enum nl80211_auth_type auth_type)
-{
-	return auth_type <= NL80211_AUTHTYPE_MAX;
-}
-
 static bool nl80211_valid_wpa_versions(u32 wpa_versions)
 {
 	return !(wpa_versions & ~(NL80211_WPA_VERSION_1 |
@@ -4868,8 +4888,8 @@
 	struct cfg80211_registered_device *rdev = info->user_ptr[0];
 	struct net_device *dev = info->user_ptr[1];
 	struct ieee80211_channel *chan;
-	const u8 *bssid, *ssid, *ie = NULL;
-	int err, ssid_len, ie_len = 0;
+	const u8 *bssid, *ssid, *ie = NULL, *sae_data = NULL;
+	int err, ssid_len, ie_len = 0, sae_data_len = 0;
 	enum nl80211_auth_type auth_type;
 	struct key_parse key;
 	bool local_state_change;
@@ -4945,9 +4965,23 @@
 	}
 
 	auth_type = nla_get_u32(info->attrs[NL80211_ATTR_AUTH_TYPE]);
-	if (!nl80211_valid_auth_type(auth_type))
+	if (!nl80211_valid_auth_type(rdev, auth_type, NL80211_CMD_AUTHENTICATE))
 		return -EINVAL;
 
+	if (auth_type == NL80211_AUTHTYPE_SAE &&
+	    !info->attrs[NL80211_ATTR_SAE_DATA])
+		return -EINVAL;
+
+	if (info->attrs[NL80211_ATTR_SAE_DATA]) {
+		if (auth_type != NL80211_AUTHTYPE_SAE)
+			return -EINVAL;
+		sae_data = nla_data(info->attrs[NL80211_ATTR_SAE_DATA]);
+		sae_data_len = nla_len(info->attrs[NL80211_ATTR_SAE_DATA]);
+		/* need to include at least Auth Transaction and Status Code */
+		if (sae_data_len < 4)
+			return -EINVAL;
+	}
+
 	local_state_change = !!info->attrs[NL80211_ATTR_LOCAL_STATE_CHANGE];
 
 	/*
@@ -4959,7 +4993,8 @@
 
 	return cfg80211_mlme_auth(rdev, dev, chan, auth_type, bssid,
 				  ssid, ssid_len, ie, ie_len,
-				  key.p.key, key.p.key_len, key.idx);
+				  key.p.key, key.p.key_len, key.idx,
+				  sae_data, sae_data_len);
 }
 
 static int nl80211_crypto_settings(struct cfg80211_registered_device *rdev,
@@ -5596,7 +5631,8 @@
 	if (info->attrs[NL80211_ATTR_AUTH_TYPE]) {
 		connect.auth_type =
 			nla_get_u32(info->attrs[NL80211_ATTR_AUTH_TYPE]);
-		if (!nl80211_valid_auth_type(connect.auth_type))
+		if (!nl80211_valid_auth_type(rdev, connect.auth_type,
+					     NL80211_CMD_CONNECT))
 			return -EINVAL;
 	} else
 		connect.auth_type = NL80211_AUTHTYPE_AUTOMATIC;
diff --git a/net/wireless/sme.c b/net/wireless/sme.c
index 6f39cb8..055d596 100644
--- a/net/wireless/sme.c
+++ b/net/wireless/sme.c
@@ -179,7 +179,7 @@
 					    params->ssid, params->ssid_len,
 					    NULL, 0,
 					    params->key, params->key_len,
-					    params->key_idx);
+					    params->key_idx, NULL, 0);
 	case CFG80211_CONN_ASSOCIATE_NEXT:
 		BUG_ON(!rdev->ops->assoc);
 		wdev->conn->state = CFG80211_CONN_ASSOCIATING;