wl12xx: implement sta_state callback

Implement sta_state callback instead of the
sta_add/remove callbacks.

Update the fw regarding peer state and ht caps
only after the station was authorized. Otherwise,
the fw might try establishing BA session before
the sta is authorized.

Signed-off-by: Eliad Peller <eliad@wizery.com>
Signed-off-by: Luciano Coelho <coelho@ti.com>
diff --git a/drivers/net/wireless/wl12xx/main.c b/drivers/net/wireless/wl12xx/main.c
index b771106..8569ac3 100644
--- a/drivers/net/wireless/wl12xx/main.c
+++ b/drivers/net/wireless/wl12xx/main.c
@@ -4233,100 +4233,128 @@
 	wl->active_sta_count--;
 }
 
-static int wl1271_op_sta_add(struct ieee80211_hw *hw,
-			     struct ieee80211_vif *vif,
-			     struct ieee80211_sta *sta)
+static int wl12xx_sta_add(struct wl1271 *wl,
+			  struct wl12xx_vif *wlvif,
+			  struct ieee80211_sta *sta)
 {
-	struct wl1271 *wl = hw->priv;
-	struct wl12xx_vif *wlvif = wl12xx_vif_to_data(vif);
 	struct wl1271_station *wl_sta;
 	int ret = 0;
 	u8 hlid;
 
-	mutex_lock(&wl->mutex);
-
-	if (unlikely(wl->state == WL1271_STATE_OFF))
-		goto out;
-
-	if (wlvif->bss_type != BSS_TYPE_AP_BSS)
-		goto out;
-
 	wl1271_debug(DEBUG_MAC80211, "mac80211 add sta %d", (int)sta->aid);
 
 	ret = wl1271_allocate_sta(wl, wlvif, sta);
 	if (ret < 0)
-		goto out;
+		return ret;
 
 	wl_sta = (struct wl1271_station *)sta->drv_priv;
 	hlid = wl_sta->hlid;
 
-	ret = wl1271_ps_elp_wakeup(wl);
-	if (ret < 0)
-		goto out_free_sta;
-
 	ret = wl12xx_cmd_add_peer(wl, wlvif, sta, hlid);
 	if (ret < 0)
-		goto out_sleep;
-
-	ret = wl12xx_cmd_set_peer_state(wl, hlid);
-	if (ret < 0)
-		goto out_sleep;
-
-	ret = wl1271_acx_set_ht_capabilities(wl, &sta->ht_cap, true, hlid);
-	if (ret < 0)
-		goto out_sleep;
-
-out_sleep:
-	wl1271_ps_elp_sleep(wl);
-
-out_free_sta:
-	if (ret < 0)
 		wl1271_free_sta(wl, wlvif, hlid);
 
-out:
-	mutex_unlock(&wl->mutex);
 	return ret;
 }
 
-static int wl1271_op_sta_remove(struct ieee80211_hw *hw,
-				struct ieee80211_vif *vif,
-				struct ieee80211_sta *sta)
+static int wl12xx_sta_remove(struct wl1271 *wl,
+			     struct wl12xx_vif *wlvif,
+			     struct ieee80211_sta *sta)
 {
-	struct wl1271 *wl = hw->priv;
-	struct wl12xx_vif *wlvif = wl12xx_vif_to_data(vif);
 	struct wl1271_station *wl_sta;
 	int ret = 0, id;
 
-	mutex_lock(&wl->mutex);
-
-	if (unlikely(wl->state == WL1271_STATE_OFF))
-		goto out;
-
-	if (wlvif->bss_type != BSS_TYPE_AP_BSS)
-		goto out;
-
 	wl1271_debug(DEBUG_MAC80211, "mac80211 remove sta %d", (int)sta->aid);
 
 	wl_sta = (struct wl1271_station *)sta->drv_priv;
 	id = wl_sta->hlid;
 	if (WARN_ON(!test_bit(id, wlvif->ap.sta_hlid_map)))
+		return -EINVAL;
+
+	ret = wl12xx_cmd_remove_peer(wl, wl_sta->hlid);
+	if (ret < 0)
+		return ret;
+
+	wl1271_free_sta(wl, wlvif, wl_sta->hlid);
+	return ret;
+}
+
+static int wl12xx_update_sta_state(struct wl1271 *wl,
+				   struct wl12xx_vif *wlvif,
+				   struct ieee80211_sta *sta,
+				   enum ieee80211_sta_state old_state,
+				   enum ieee80211_sta_state new_state)
+{
+	struct wl1271_station *wl_sta;
+	u8 hlid;
+	bool is_ap = wlvif->bss_type == BSS_TYPE_AP_BSS;
+	bool is_sta = wlvif->bss_type == BSS_TYPE_STA_BSS;
+	int ret;
+
+	wl_sta = (struct wl1271_station *)sta->drv_priv;
+	hlid = wl_sta->hlid;
+
+	/* Add station (AP mode) */
+	if (is_ap &&
+	    old_state == IEEE80211_STA_NOTEXIST &&
+	    new_state == IEEE80211_STA_NONE)
+		return wl12xx_sta_add(wl, wlvif, sta);
+
+	/* Remove station (AP mode) */
+	if (is_ap &&
+	    old_state == IEEE80211_STA_NONE &&
+	    new_state == IEEE80211_STA_NOTEXIST) {
+		/* must not fail */
+		wl12xx_sta_remove(wl, wlvif, sta);
+		return 0;
+	}
+
+	/* Authorize station (AP mode) */
+	if (is_ap &&
+	    new_state == IEEE80211_STA_AUTHORIZED) {
+		ret = wl12xx_cmd_set_peer_state(wl, hlid);
+		if (ret < 0)
+			return ret;
+
+		ret = wl1271_acx_set_ht_capabilities(wl, &sta->ht_cap, true,
+						     hlid);
+		return ret;
+	}
+
+	return 0;
+}
+
+static int wl12xx_op_sta_state(struct ieee80211_hw *hw,
+			       struct ieee80211_vif *vif,
+			       struct ieee80211_sta *sta,
+			       enum ieee80211_sta_state old_state,
+			       enum ieee80211_sta_state new_state)
+{
+	struct wl1271 *wl = hw->priv;
+	struct wl12xx_vif *wlvif = wl12xx_vif_to_data(vif);
+	int ret;
+
+	wl1271_debug(DEBUG_MAC80211, "mac80211 sta %d state=%d->%d",
+		     sta->aid, old_state, new_state);
+
+	mutex_lock(&wl->mutex);
+
+	if (unlikely(wl->state == WL1271_STATE_OFF)) {
+		ret = -EBUSY;
 		goto out;
+	}
 
 	ret = wl1271_ps_elp_wakeup(wl);
 	if (ret < 0)
 		goto out;
 
-	ret = wl12xx_cmd_remove_peer(wl, wl_sta->hlid);
-	if (ret < 0)
-		goto out_sleep;
+	ret = wl12xx_update_sta_state(wl, wlvif, sta, old_state, new_state);
 
-	wl1271_free_sta(wl, wlvif, wl_sta->hlid);
-
-out_sleep:
 	wl1271_ps_elp_sleep(wl);
-
 out:
 	mutex_unlock(&wl->mutex);
+	if (new_state < old_state)
+		return 0;
 	return ret;
 }
 
@@ -4795,8 +4823,7 @@
 	.conf_tx = wl1271_op_conf_tx,
 	.get_tsf = wl1271_op_get_tsf,
 	.get_survey = wl1271_op_get_survey,
-	.sta_add = wl1271_op_sta_add,
-	.sta_remove = wl1271_op_sta_remove,
+	.sta_state = wl12xx_op_sta_state,
 	.ampdu_action = wl1271_op_ampdu_action,
 	.tx_frames_pending = wl1271_tx_frames_pending,
 	.set_bitrate_mask = wl12xx_set_bitrate_mask,