nl80211: add VHT support for set_bitrate_mask

Add VHT MCS/NSS set support to nl80211_set_tx_bitrate_mask().
This is intended mainly for testing purposes, to check
different VHT MCS/NSS combinations.

Signed-off-by: Janusz Dziedzic <janusz.dziedzic@tieto.com>
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c
index 2d0c19c..04681a4 100644
--- a/net/wireless/nl80211.c
+++ b/net/wireless/nl80211.c
@@ -7328,11 +7328,72 @@
 	return true;
 }
 
+/*
+ * Convert a single 2-bit VHT MCS map field into a bitmask of MCS indices:
+ * bit n set in the result means MCS n is selected. The "not supported"
+ * value and any unrecognised value both produce an empty mask.
+ */
+static u16 vht_mcs_map_to_mcs_mask(u8 vht_mcs_map)
+{
+	switch (vht_mcs_map) {
+	case IEEE80211_VHT_MCS_SUPPORT_0_7:
+		/* bits 0..7 */
+		return 0x00FF;
+	case IEEE80211_VHT_MCS_SUPPORT_0_8:
+		/* bits 0..8 */
+		return 0x01FF;
+	case IEEE80211_VHT_MCS_SUPPORT_0_9:
+		/* bits 0..9 */
+		return 0x03FF;
+	case IEEE80211_VHT_MCS_NOT_SUPPORTED:
+	default:
+		return 0;
+	}
+}
+
+/* Expand a 16-bit VHT MCS map (2 bits per entry) into per-NSS MCS masks. */
+static void vht_build_mcs_mask(u16 vht_mcs_map,
+			       u16 vht_mcs_mask[NL80211_VHT_NSS_MAX])
+{
+	u8 idx;
+
+	/* consume the map two bits at a time, lowest bits first */
+	for (idx = 0; idx < NL80211_VHT_NSS_MAX; idx++, vht_mcs_map >>= 2)
+		vht_mcs_mask[idx] = vht_mcs_map_to_mcs_mask(vht_mcs_map & 0x03);
+}
+
+/*
+ * Copy the per-NSS VHT MCS masks requested in @txrate into @mcs; false if the
+ * band lacks VHT or an MCS outside its TX MCS map is requested.
+ */
+static bool vht_set_mcs_mask(struct ieee80211_supported_band *sband,
+			     struct nl80211_txrate_vht *txrate,
+			     u16 mcs[NL80211_VHT_NSS_MAX])
+{
+	u16 supported[NL80211_VHT_NSS_MAX] = {};
+	u8 nss;
+
+	if (!sband->vht_cap.vht_supported)
+		return false;
+
+	memset(mcs, 0, sizeof(u16) * NL80211_VHT_NSS_MAX);
+	vht_build_mcs_mask(le16_to_cpu(sband->vht_cap.vht_mcs.tx_mcs_map),
+			   supported);
+
+	for (nss = 0; nss < NL80211_VHT_NSS_MAX; nss++) {
+		if (txrate->mcs[nss] & ~supported[nss])
+			return false;
+		mcs[nss] = txrate->mcs[nss];
+	}
+	return true;
+}
+
 static const struct nla_policy nl80211_txattr_policy[NL80211_TXRATE_MAX + 1] = {
 	[NL80211_TXRATE_LEGACY] = { .type = NLA_BINARY,
 				    .len = NL80211_MAX_SUPP_RATES },
 	[NL80211_TXRATE_HT] = { .type = NLA_BINARY,
 				.len = NL80211_MAX_SUPP_HT_RATES },
+	[NL80211_TXRATE_VHT] = { .len = sizeof(struct nl80211_txrate_vht)},
 };
 
 static int nl80211_set_tx_bitrate_mask(struct sk_buff *skb,
@@ -7345,6 +7406,7 @@
 	struct net_device *dev = info->user_ptr[1];
 	struct nlattr *tx_rates;
 	struct ieee80211_supported_band *sband;
+	u16 vht_tx_mcs_map;
 
 	if (!rdev->ops->set_bitrate_mask)
 		return -EOPNOTSUPP;
@@ -7361,6 +7423,12 @@
 		memcpy(mask.control[i].ht_mcs,
 		       sband->ht_cap.mcs.rx_mask,
 		       sizeof(mask.control[i].ht_mcs));
+
+		if (!sband->vht_cap.vht_supported)
+			continue;
+
+		vht_tx_mcs_map = le16_to_cpu(sband->vht_cap.vht_mcs.tx_mcs_map);
+		vht_build_mcs_mask(vht_tx_mcs_map, mask.control[i].vht_mcs);
 	}
 
 	/* if no rates are given set it back to the defaults */
@@ -7399,20 +7467,32 @@
 					mask.control[band].ht_mcs))
 				return -EINVAL;
 		}
+		if (tb[NL80211_TXRATE_VHT]) {
+			if (!vht_set_mcs_mask(
+					sband,
+					nla_data(tb[NL80211_TXRATE_VHT]),
+					mask.control[band].vht_mcs))
+				return -EINVAL;
+		}
 
 		if (mask.control[band].legacy == 0) {
-			/* don't allow empty legacy rates if HT
-			 * is not even supported. */
-			if (!rdev->wiphy.bands[band]->ht_cap.ht_supported)
+			/* don't allow empty legacy rates if HT or VHT
+			 * are not even supported.
+			 */
+			if (!(rdev->wiphy.bands[band]->ht_cap.ht_supported ||
+			      rdev->wiphy.bands[band]->vht_cap.vht_supported))
 				return -EINVAL;
 
 			for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++)
 				if (mask.control[band].ht_mcs[i])
-					break;
+					goto out;
+
+			for (i = 0; i < NL80211_VHT_NSS_MAX; i++)
+				if (mask.control[band].vht_mcs[i])
+					goto out;
 
 			/* legacy and mcs rates may not be both empty */
-			if (i == IEEE80211_HT_MCS_MASK_LEN)
-				return -EINVAL;
+			return -EINVAL;
 		}
 	}