iwlwifi: mvm: check PN for CCMP/GCMP in the driver

With multi-queue RX, we want to check the PN on each RX queue in
parallel, so that RX processing does not have to be serialised.

Checking the PN in parallel may seem insecure, but it is safe: the RX
queue is selected from data in the frame (IP/TCP headers), and that data
cannot be manipulated by an attacker because it is encrypted and must
first have been decrypted successfully.

There are some corner cases, in particular when the peer starts using
fragmentation, which redirects the packets to the default queue. However,
this redirection is remembered (per STA and TID) and thus cannot be
exploited by an attacker either.

Leave PN checking on the default queue (queue 0) to mac80211: that is
where fragmented packets arrive, and they are subject to stricter checks
during reassembly.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: Sara Sharon <sara.sharon@intel.com>
Signed-off-by: Emmanuel Grumbach <emmanuel.grumbach@intel.com>
diff --git a/drivers/net/wireless/intel/iwlwifi/mvm/d3.c b/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
index 6ac4072..2cd9052 100644
--- a/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
+++ b/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
@@ -137,6 +137,28 @@
 		out[i] = cpu_to_le16(p1k[i]);
 }
 
+static const u8 *iwl_mvm_find_max_pn(struct ieee80211_key_conf *key,
+				     struct iwl_mvm_key_pn *ptk_pn,
+				     struct ieee80211_key_seq *seq,
+				     int tid, int queues)
+{
+	const u8 *max_pn;
+	int q;
+
+	/* queue 0 is checked by mac80211, so start from its stored PN */
+	ieee80211_get_key_rx_seq(key, tid, seq);
+	max_pn = seq->ccmp.pn;
+
+	/* the other RX queues track their own PN - keep the highest */
+	for (q = 1; q < queues; q++) {
+		const u8 *cand = ptk_pn->q[q].pn[tid];
+
+		if (memcmp(max_pn, cand, IEEE80211_CCMP_PN_LEN) <= 0)
+			max_pn = cand;
+	}
+	return max_pn;
+}
+
 struct wowlan_key_data {
 	struct iwl_wowlan_rsc_tsc_params_cmd *rsc_tsc;
 	struct iwl_wowlan_tkip_params_cmd *tkip;
@@ -294,18 +316,42 @@
 
 		/*
 		 * For non-QoS this relies on the fact that both the uCode and
-		 * mac80211 use TID 0 for checking the IV in the frames.
+		 * mac80211/our RX code use TID 0 for checking the PN.
 		 */
-		for (i = 0; i < IWL_NUM_RSC; i++) {
-			u8 *pn = seq.ccmp.pn;
+		if (sta && iwl_mvm_has_new_rx_api(mvm)) {
+			struct iwl_mvm_sta *mvmsta;
+			struct iwl_mvm_key_pn *ptk_pn;
+			const u8 *pn;
 
-			ieee80211_get_key_rx_seq(key, i, &seq);
-			aes_sc[i].pn = cpu_to_le64((u64)pn[5] |
-						   ((u64)pn[4] << 8) |
-						   ((u64)pn[3] << 16) |
-						   ((u64)pn[2] << 24) |
-						   ((u64)pn[1] << 32) |
-						   ((u64)pn[0] << 40));
+			mvmsta = iwl_mvm_sta_from_mac80211(sta);
+			ptk_pn = rcu_dereference_protected(
+						mvmsta->ptk_pn[key->keyidx],
+						lockdep_is_held(&mvm->mutex));
+			if (WARN_ON(!ptk_pn))
+				break;
+
+			for (i = 0; i < IWL_MAX_TID_COUNT; i++) {
+				pn = iwl_mvm_find_max_pn(key, ptk_pn, &seq, i,
+						mvm->trans->num_rx_queues);
+				aes_sc[i].pn = cpu_to_le64((u64)pn[5] |
+							   ((u64)pn[4] << 8) |
+							   ((u64)pn[3] << 16) |
+							   ((u64)pn[2] << 24) |
+							   ((u64)pn[1] << 32) |
+							   ((u64)pn[0] << 40));
+			}
+		} else {
+			for (i = 0; i < IWL_NUM_RSC; i++) {
+				u8 *pn = seq.ccmp.pn;
+
+				ieee80211_get_key_rx_seq(key, i, &seq);
+				aes_sc[i].pn = cpu_to_le64((u64)pn[5] |
+							   ((u64)pn[4] << 8) |
+							   ((u64)pn[3] << 16) |
+							   ((u64)pn[2] << 24) |
+							   ((u64)pn[1] << 32) |
+							   ((u64)pn[0] << 40));
+			}
 		}
 		data->use_rsc_tsc = true;
 		break;
@@ -1426,18 +1472,42 @@
 	seq->tkip.iv16 = le16_to_cpu(sc->iv16);
 }
 
-static void iwl_mvm_set_aes_rx_seq(struct aes_sc *scs,
+static void iwl_mvm_set_aes_rx_seq(struct iwl_mvm *mvm, struct aes_sc *scs,
+				   struct ieee80211_sta *sta,
 				   struct ieee80211_key_conf *key)
 {
 	int tid;
 
 	BUILD_BUG_ON(IWL_NUM_RSC != IEEE80211_NUM_TIDS);
 
-	for (tid = 0; tid < IWL_NUM_RSC; tid++) {
-		struct ieee80211_key_seq seq = {};
+	if (sta && iwl_mvm_has_new_rx_api(mvm)) {
+		struct iwl_mvm_sta *mvmsta;
+		struct iwl_mvm_key_pn *ptk_pn;
 
-		iwl_mvm_aes_sc_to_seq(&scs[tid], &seq);
-		ieee80211_set_key_rx_seq(key, tid, &seq);
+		mvmsta = iwl_mvm_sta_from_mac80211(sta);
+		/* queue 0's PN is in mac80211, queues 1..N-1 are ours */
+		ptk_pn = rcu_dereference_protected(mvmsta->ptk_pn[key->keyidx],
+						   lockdep_is_held(&mvm->mutex));
+		if (WARN_ON(!ptk_pn))
+			return;
+		/* NOTE: only TIDs < IWL_MAX_TID_COUNT are restored here */
+		for (tid = 0; tid < IWL_MAX_TID_COUNT; tid++) {
+			struct ieee80211_key_seq seq = {};
+			int i;
+
+			iwl_mvm_aes_sc_to_seq(&scs[tid], &seq);
+			ieee80211_set_key_rx_seq(key, tid, &seq);
+			for (i = 1; i < mvm->trans->num_rx_queues; i++)
+				memcpy(ptk_pn->q[i].pn[tid],
+				       seq.ccmp.pn, IEEE80211_CCMP_PN_LEN);
+		}
+	} else {
+		for (tid = 0; tid < IWL_NUM_RSC; tid++) {
+			struct ieee80211_key_seq seq = {};
+
+			iwl_mvm_aes_sc_to_seq(&scs[tid], &seq);
+			ieee80211_set_key_rx_seq(key, tid, &seq);
+		}
 	}
 }
 
@@ -1456,14 +1526,15 @@
 	}
 }
 
-static void iwl_mvm_set_key_rx_seq(struct ieee80211_key_conf *key,
+static void iwl_mvm_set_key_rx_seq(struct iwl_mvm *mvm,
+				   struct ieee80211_key_conf *key,
 				   struct iwl_wowlan_status *status)
 {
 	union iwl_all_tsc_rsc *rsc = &status->gtk.rsc.all_tsc_rsc;
 
 	switch (key->cipher) {
 	case WLAN_CIPHER_SUITE_CCMP:
-		iwl_mvm_set_aes_rx_seq(rsc->aes.multicast_rsc, key);
+		iwl_mvm_set_aes_rx_seq(mvm, rsc->aes.multicast_rsc, NULL, key);
 		break;
 	case WLAN_CIPHER_SUITE_TKIP:
 		iwl_mvm_set_tkip_rx_seq(rsc->tkip.multicast_rsc, key);
@@ -1474,6 +1545,7 @@
 }
 
 struct iwl_mvm_d3_gtk_iter_data {
+	struct iwl_mvm *mvm;
 	struct iwl_wowlan_status *status;
 	void *last_gtk;
 	u32 cipher;
@@ -1522,7 +1594,8 @@
 
 		switch (key->cipher) {
 		case WLAN_CIPHER_SUITE_CCMP:
-			iwl_mvm_set_aes_rx_seq(sc->aes.unicast_rsc, key);
+			iwl_mvm_set_aes_rx_seq(data->mvm, sc->aes.unicast_rsc,
+					       sta, key);
 			atomic64_set(&key->tx_pn, le64_to_cpu(sc->aes.tsc.pn));
 			break;
 		case WLAN_CIPHER_SUITE_TKIP:
@@ -1545,7 +1618,7 @@
 	if (data->status->num_of_gtk_rekeys)
 		ieee80211_remove_key(key);
 	else if (data->last_gtk == key)
-		iwl_mvm_set_key_rx_seq(key, data->status);
+		iwl_mvm_set_key_rx_seq(data->mvm, key, data->status);
 }
 
 static bool iwl_mvm_setup_connection_keep(struct iwl_mvm *mvm,
@@ -1554,6 +1627,7 @@
 {
 	struct iwl_mvm_vif *mvmvif = iwl_mvm_vif_from_mac80211(vif);
 	struct iwl_mvm_d3_gtk_iter_data gtkdata = {
+		.mvm = mvm,
 		.status = status,
 	};
 	u32 disconnection_reasons =
@@ -1615,7 +1689,7 @@
 		key = ieee80211_gtk_rekey_add(vif, &conf.conf);
 		if (IS_ERR(key))
 			return false;
-		iwl_mvm_set_key_rx_seq(key, status);
+		iwl_mvm_set_key_rx_seq(mvm, key, status);
 	}
 
 	if (status->num_of_gtk_rekeys) {