mac80211: Handle the CSA counters correctly

Make the beacon CSA counters part of ieee80211_mutable_offsets and don't
decrement CSA counters when generating a beacon template. This permits the
driver to offload the CSA counters handling. Since mac80211 updates the probe
responses with the correct counter, the driver should sync the counter's value
with mac80211 using ieee80211_csa_update_counter function.

Signed-off-by: Andrei Otcheretianski <andrei.otcheretianski@intel.com>
Signed-off-by: Luciano Coelho <luciano.coelho@intel.com>
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/include/net/mac80211.h b/include/net/mac80211.h
index e652126..982d2cd 100644
--- a/include/net/mac80211.h
+++ b/include/net/mac80211.h
@@ -3411,14 +3411,20 @@
  */
 void ieee80211_report_low_ack(struct ieee80211_sta *sta, u32 num_packets);
 
+#define IEEE80211_MAX_CSA_COUNTERS_NUM 2
+
 /**
  * struct ieee80211_mutable_offsets - mutable beacon offsets
  * @tim_offset: position of TIM element
  * @tim_length: size of TIM element
+ * @csa_offs: array of IEEE80211_MAX_CSA_COUNTERS_NUM offsets to CSA counters.
+ *	This array can contain zero values which should be ignored.
  */
 struct ieee80211_mutable_offsets {
 	u16 tim_offset;
 	u16 tim_length;
+
+	u16 csa_counter_offs[IEEE80211_MAX_CSA_COUNTERS_NUM];
 };
 
 /**
@@ -3433,7 +3439,8 @@
  *
  * This function should be used if the beacon frames are generated by the
  * device, and then the driver must use the returned beacon as the template
- * The driver is responsible to update the DTIM count.
+ * The driver or the device are responsible to update the DTIM and, when
+ * applicable, the CSA count.
  *
  * The driver is responsible for freeing the returned skb.
  *
@@ -3486,6 +3493,20 @@
 }
 
 /**
+ * ieee80211_csa_update_counter - request mac80211 to decrement the csa counter
+ * @vif: &struct ieee80211_vif pointer from the add_interface callback.
+ *
+ * The csa counter should be updated after each beacon transmission.
+ * This function is called implicitly when
+ * ieee80211_beacon_get/ieee80211_beacon_get_tim are called, however if the
+ * beacon frames are generated by the device, the driver should call this
+ * function after each beacon transmission to sync mac80211's csa counters.
+ *
+ * Return: new csa counter value
+ */
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif);
+
+/**
  * ieee80211_csa_finish - notify mac80211 about channel switch
  * @vif: &struct ieee80211_vif pointer from the add_interface callback.
  *
diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c
index d44dca5..bfd2534 100644
--- a/net/mac80211/cfg.c
+++ b/net/mac80211/cfg.c
@@ -3502,10 +3502,10 @@
 	     sdata->vif.type == NL80211_IFTYPE_ADHOC) &&
 	    params->n_csa_offsets) {
 		int i;
+		u8 c = sdata->csa_current_counter;
 
 		for (i = 0; i < params->n_csa_offsets; i++)
-			data[params->csa_offsets[i]] =
-					sdata->csa_current_counter;
+			data[params->csa_offsets[i]] = c;
 	}
 
 	IEEE80211_SKB_CB(skb)->flags = flags;
diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index 05ed592..57e0b26 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -70,8 +70,6 @@
 
 #define IEEE80211_DEAUTH_FRAME_LEN	(24 /* hdr */ + 2 /* reason */)
 
-#define IEEE80211_MAX_CSA_COUNTERS_NUM 2
-
 struct ieee80211_fragment_entry {
 	unsigned long first_frag_time;
 	unsigned int seq;
diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c
index 509456e..5214686 100644
--- a/net/mac80211/tx.c
+++ b/net/mac80211/tx.c
@@ -2416,13 +2416,14 @@
 	return 0;
 }
 
-static void ieee80211_update_csa(struct ieee80211_sub_if_data *sdata,
-				 struct beacon_data *beacon)
+static void ieee80211_set_csa(struct ieee80211_sub_if_data *sdata,
+			      struct beacon_data *beacon)
 {
 	struct probe_resp *resp;
 	u8 *beacon_data;
 	size_t beacon_data_len;
 	int i;
+	u8 count = sdata->csa_current_counter;
 
 	switch (sdata->vif.type) {
 	case NL80211_IFTYPE_AP:
@@ -2450,16 +2451,7 @@
 			if (WARN_ON(counter_offset_beacon >= beacon_data_len))
 				return;
 
-			/* Warn if the driver did not check for/react to csa
-			 * completeness.  A beacon with CSA counter set to 0
-			 * should never occur, because a counter of 1 means
-			 * switch just before the next beacon.
-			 */
-			if (WARN_ON(beacon_data[counter_offset_beacon] == 1))
-				return;
-
-			beacon_data[counter_offset_beacon] =
-				sdata->csa_current_counter - 1;
+			beacon_data[counter_offset_beacon] = count;
 		}
 
 		if (sdata->vif.type == NL80211_IFTYPE_AP &&
@@ -2474,14 +2466,24 @@
 				rcu_read_unlock();
 				return;
 			}
-			resp->data[counter_offset_presp] =
-				sdata->csa_current_counter - 1;
+			resp->data[counter_offset_presp] = count;
 			rcu_read_unlock();
 		}
 	}
+}
+
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif)
+{
+	struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif);
 
 	sdata->csa_current_counter--;
+
+	/* the counter should never reach 0 */
+	WARN_ON(!sdata->csa_current_counter);
+
+	return sdata->csa_current_counter;
 }
+EXPORT_SYMBOL(ieee80211_csa_update_counter);
 
 bool ieee80211_csa_is_complete(struct ieee80211_vif *vif)
 {
@@ -2552,6 +2554,7 @@
 	enum ieee80211_band band;
 	struct ieee80211_tx_rate_control txrc;
 	struct ieee80211_chanctx_conf *chanctx_conf;
+	int csa_off_base = 0;
 
 	rcu_read_lock();
 
@@ -2569,8 +2572,12 @@
 		struct beacon_data *beacon = rcu_dereference(ap->beacon);
 
 		if (beacon) {
-			if (sdata->vif.csa_active)
-				ieee80211_update_csa(sdata, beacon);
+			if (sdata->vif.csa_active) {
+				if (!is_template)
+					ieee80211_csa_update_counter(vif);
+
+				ieee80211_set_csa(sdata, beacon);
+			}
 
 			/*
 			 * headroom, head length,
@@ -2593,6 +2600,9 @@
 			if (offs) {
 				offs->tim_offset = beacon->head_len;
 				offs->tim_length = skb->len - beacon->head_len;
+
+				/* for AP the csa offsets are from tail */
+				csa_off_base = skb->len;
 			}
 
 			if (beacon->tail)
@@ -2608,9 +2618,12 @@
 		if (!presp)
 			goto out;
 
-		if (sdata->vif.csa_active)
-			ieee80211_update_csa(sdata, presp);
+		if (sdata->vif.csa_active) {
+			if (!is_template)
+				ieee80211_csa_update_counter(vif);
 
+			ieee80211_set_csa(sdata, presp);
+		}
 
 		skb = dev_alloc_skb(local->tx_headroom + presp->head_len +
 				    local->hw.extra_beacon_tailroom);
@@ -2630,8 +2643,17 @@
 		if (!bcn)
 			goto out;
 
-		if (sdata->vif.csa_active)
-			ieee80211_update_csa(sdata, bcn);
+		if (sdata->vif.csa_active) {
+			if (!is_template)
+				/* TODO: For mesh csa_counter is in TU, so
+				 * decrementing it by one isn't correct, but
+				 * for now we leave it consistent with overall
+				 * mac80211's behavior.
+				 */
+				ieee80211_csa_update_counter(vif);
+
+			ieee80211_set_csa(sdata, bcn);
+		}
 
 		if (ifmsh->sync_ops)
 			ifmsh->sync_ops->adjust_tbtt(sdata, bcn);
@@ -2658,6 +2680,20 @@
 		goto out;
 	}
 
+	/* CSA offsets */
+	if (offs) {
+		int i;
+
+		for (i = 0; i < IEEE80211_MAX_CSA_COUNTERS_NUM; i++) {
+			u16 csa_off = sdata->csa_counter_offset_beacon[i];
+
+			if (!csa_off)
+				continue;
+
+			offs->csa_counter_offs[i] = csa_off_base + csa_off;
+		}
+	}
+
 	band = chanctx_conf->def.chan->band;
 
 	info = IEEE80211_SKB_CB(skb);