mac80211: RCU-ify STA info structure access

This makes access to the STA hash table/list use RCU to protect
against freeing of items. However, it's not true RCU: the copy
step is missing, so whenever somebody changes a STA item it is
simply modified in place. This is an existing race condition
that is now at least easier to understand.

This patch also fixes the race between key freeing and STA
destruction by making sure that sta_info_destroy() is always
called under RTNL and frees the key.

Signed-off-by: Johannes Berg <johannes@sipsolutions.net>
Signed-off-by: John W. Linville <linville@tuxdriver.com>
diff --git a/net/mac80211/mesh_plink.c b/net/mac80211/mesh_plink.c
index b5fbe97..c2b8050 100644
--- a/net/mac80211/mesh_plink.c
+++ b/net/mac80211/mesh_plink.c
@@ -65,14 +65,14 @@
 void mesh_plink_inc_estab_count(struct ieee80211_sub_if_data *sdata)
 {
 	atomic_inc(&sdata->u.sta.mshstats.estab_plinks);
-	mesh_accept_plinks_update(sdata->dev);
+	mesh_accept_plinks_update(sdata);
 }
 
 static inline
 void mesh_plink_dec_estab_count(struct ieee80211_sub_if_data *sdata)
 {
 	atomic_dec(&sdata->u.sta.mshstats.estab_plinks);
-	mesh_accept_plinks_update(sdata->dev);
+	mesh_accept_plinks_update(sdata);
 }
 
 /**
@@ -99,12 +99,13 @@
  *
  * Returns: non-NULL on success, ERR_PTR() on error.
  */
-struct sta_info *mesh_plink_add(u8 *hw_addr, u64 rates, struct net_device *dev)
+struct sta_info *mesh_plink_add(u8 *hw_addr, u64 rates,
+				struct ieee80211_sub_if_data *sdata)
 {
-	struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr);
+	struct ieee80211_local *local = sdata->local;
 	struct sta_info *sta;
 
-	if (memcmp(hw_addr, dev->dev_addr, ETH_ALEN) == 0)
+	if (compare_ether_addr(hw_addr, sdata->dev->dev_addr) == 0)
 		/* never add ourselves as neighbours */
 		return ERR_PTR(-EINVAL);
 
@@ -114,7 +115,7 @@
 	if (local->num_sta >= MESH_MAX_PLINKS)
 		return ERR_PTR(-ENOSPC);
 
-	sta = sta_info_add(local, dev, hw_addr, GFP_KERNEL);
+	sta = sta_info_add(sdata, hw_addr);
 	if (IS_ERR(sta))
 		return sta;
 
@@ -125,7 +126,7 @@
 	sta->supp_rates[local->hw.conf.channel->band] = rates;
 	rate_control_rate_init(sta, local);
 
-	mesh_accept_plinks_update(dev);
+	mesh_accept_plinks_update(sdata);
 
 	return sta;
 }
@@ -141,7 +142,8 @@
  */
 static void __mesh_plink_deactivate(struct sta_info *sta)
 {
-	struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(sta->dev);
+	struct ieee80211_sub_if_data *sdata = sta->sdata;
+
 	if (sta->plink_state == ESTAB)
 		mesh_plink_dec_estab_count(sdata);
 	sta->plink_state = BLOCKED;
@@ -246,11 +248,15 @@
 	struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr);
 	struct sta_info *sta;
 
+	rcu_read_lock();
+
 	sta = sta_info_get(local, hw_addr);
 	if (!sta) {
-		sta = mesh_plink_add(hw_addr, rates, dev);
-		if (IS_ERR(sta))
+		sta = mesh_plink_add(hw_addr, rates, sdata);
+		if (IS_ERR(sta)) {
+			rcu_read_unlock();
 			return;
+		}
 	}
 
 	sta->last_rx = jiffies;
@@ -260,7 +266,7 @@
 			sdata->u.sta.mshcfg.auto_open_plinks)
 		mesh_plink_open(sta);
 
-	sta_info_put(sta);
+	rcu_read_unlock();
 }
 
 static void mesh_plink_timer(unsigned long data)
@@ -273,6 +279,11 @@
 	DECLARE_MAC_BUF(mac);
 #endif
 
+	/*
+	 * This STA is valid because sta_info_destroy() will
+	 * del_timer_sync() this timer after having made sure
+	 * it cannot be re-added (by deleting the plink).
+	 */
 	sta = (struct sta_info *) data;
 
 	spin_lock_bh(&sta->plink_lock);
@@ -286,8 +297,8 @@
 	reason = 0;
 	llid = sta->llid;
 	plid = sta->plid;
-	dev = sta->dev;
-	sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+	sdata = sta->sdata;
+	dev = sdata->dev;
 
 	switch (sta->plink_state) {
 	case OPN_RCVD:
@@ -302,8 +313,7 @@
 			sta->plink_timeout = sta->plink_timeout +
 					     rand % sta->plink_timeout;
 			++sta->plink_retries;
-			if (!mod_plink_timer(sta, sta->plink_timeout))
-				__sta_info_get(sta);
+			mod_plink_timer(sta, sta->plink_timeout);
 			spin_unlock_bh(&sta->plink_lock);
 			mesh_plink_frame_tx(dev, PLINK_OPEN, sta->addr, llid,
 					    0, 0);
@@ -316,16 +326,14 @@
 		if (!reason)
 			reason = cpu_to_le16(MESH_CONFIRM_TIMEOUT);
 		sta->plink_state = HOLDING;
-		if (!mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata)))
-			__sta_info_get(sta);
+		mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata));
 		spin_unlock_bh(&sta->plink_lock);
 		mesh_plink_frame_tx(dev, PLINK_CLOSE, sta->addr, llid, plid,
 				    reason);
 		break;
 	case HOLDING:
 		/* holding timer */
-		if (del_timer(&sta->plink_timer))
-			sta_info_put(sta);
+		del_timer(&sta->plink_timer);
 		mesh_plink_fsm_restart(sta);
 		spin_unlock_bh(&sta->plink_lock);
 		break;
@@ -333,8 +341,6 @@
 		spin_unlock_bh(&sta->plink_lock);
 		break;
 	}
-
-	sta_info_put(sta);
 }
 
 static inline void mesh_plink_timer_set(struct sta_info *sta, int timeout)
@@ -343,14 +349,13 @@
 	sta->plink_timer.data = (unsigned long) sta;
 	sta->plink_timer.function = mesh_plink_timer;
 	sta->plink_timeout = timeout;
-	__sta_info_get(sta);
 	add_timer(&sta->plink_timer);
 }
 
 int mesh_plink_open(struct sta_info *sta)
 {
 	__le16 llid;
-	struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(sta->dev);
+	struct ieee80211_sub_if_data *sdata = sta->sdata;
 #ifdef CONFIG_MAC80211_VERBOSE_MPL_DEBUG
 	DECLARE_MAC_BUF(mac);
 #endif
@@ -360,7 +365,6 @@
 	sta->llid = llid;
 	if (sta->plink_state != LISTEN) {
 		spin_unlock_bh(&sta->plink_lock);
-		sta_info_put(sta);
 		return -EBUSY;
 	}
 	sta->plink_state = OPN_SNT;
@@ -369,7 +373,8 @@
 	mpl_dbg("Mesh plink: starting establishment with %s\n",
 		print_mac(mac, sta->addr));
 
-	return mesh_plink_frame_tx(sta->dev, PLINK_OPEN, sta->addr, llid, 0, 0);
+	return mesh_plink_frame_tx(sdata->dev, PLINK_OPEN,
+				   sta->addr, llid, 0, 0);
 }
 
 void mesh_plink_block(struct sta_info *sta)
@@ -386,7 +391,7 @@
 
 int mesh_plink_close(struct sta_info *sta)
 {
-	struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(sta->dev);
+	struct ieee80211_sub_if_data *sdata = sta->sdata;
 	int llid, plid, reason;
 #ifdef CONFIG_MAC80211_VERBOSE_MPL_DEBUG
 	DECLARE_MAC_BUF(mac);
@@ -401,13 +406,11 @@
 	if (sta->plink_state == LISTEN || sta->plink_state == BLOCKED) {
 		mesh_plink_fsm_restart(sta);
 		spin_unlock_bh(&sta->plink_lock);
-		sta_info_put(sta);
 		return 0;
 	} else if (sta->plink_state == ESTAB) {
 		__mesh_plink_deactivate(sta);
 		/* The timer should not be running */
-		if (!mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata)))
-			__sta_info_get(sta);
+		mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata));
 	} else if (!mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata)))
 		sta->ignore_plink_timer = true;
 
@@ -415,15 +418,16 @@
 	llid = sta->llid;
 	plid = sta->plid;
 	spin_unlock_bh(&sta->plink_lock);
-	mesh_plink_frame_tx(sta->dev, PLINK_CLOSE, sta->addr, llid, plid,
-			    reason);
+	mesh_plink_frame_tx(sta->sdata->dev, PLINK_CLOSE, sta->addr, llid,
+			    plid, reason);
 	return 0;
 }
 
 void mesh_rx_plink_frame(struct net_device *dev, struct ieee80211_mgmt *mgmt,
 			 size_t len, struct ieee80211_rx_status *rx_status)
 {
-	struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr);
+	struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+	struct ieee80211_local *local = sdata->local;
 	struct ieee802_11_elems elems;
 	struct sta_info *sta;
 	enum plink_event event;
@@ -435,7 +439,6 @@
 #ifdef CONFIG_MAC80211_VERBOSE_MPL_DEBUG
 	DECLARE_MAC_BUF(mac);
 #endif
-	struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
 
 	if (is_multicast_ether_addr(mgmt->da)) {
 		mpl_dbg("Mesh plink: ignore frame from multicast address");
@@ -474,14 +477,17 @@
 	if (ftype == PLINK_CONFIRM || (ftype == PLINK_CLOSE && ie_len == 7))
 		memcpy(&llid, PLINK_GET_PLID(elems.peer_link), 2);
 
+	rcu_read_lock();
+
 	sta = sta_info_get(local, mgmt->sa);
 	if (!sta && ftype != PLINK_OPEN) {
 		mpl_dbg("Mesh plink: cls or cnf from unknown peer\n");
+		rcu_read_unlock();
 		return;
 	}
 
 	if (sta && sta->plink_state == BLOCKED) {
-		sta_info_put(sta);
+		rcu_read_unlock();
 		return;
 	}
 
@@ -505,13 +511,15 @@
 		u64 rates;
 		if (!mesh_plink_free_count(sdata)) {
 			mpl_dbg("Mesh plink error: no more free plinks\n");
+			rcu_read_unlock();
 			return;
 		}
 
 		rates = ieee80211_sta_get_rates(local, &elems, rx_status->band);
-		sta = mesh_plink_add(mgmt->sa, rates, dev);
+		sta = mesh_plink_add(mgmt->sa, rates, sdata);
 		if (IS_ERR(sta)) {
 			mpl_dbg("Mesh plink error: plink table full\n");
+			rcu_read_unlock();
 			return;
 		}
 		event = OPN_ACPT;
@@ -521,14 +529,14 @@
 		switch (ftype) {
 		case PLINK_OPEN:
 			if (!mesh_plink_free_count(sdata) ||
-					(sta->plid && sta->plid != plid))
+			    (sta->plid && sta->plid != plid))
 				event = OPN_IGNR;
 			else
 				event = OPN_ACPT;
 			break;
 		case PLINK_CONFIRM:
 			if (!mesh_plink_free_count(sdata) ||
-				(sta->llid != llid || sta->plid != plid))
+			    (sta->llid != llid || sta->plid != plid))
 				event = CNF_IGNR;
 			else
 				event = CNF_ACPT;
@@ -555,7 +563,7 @@
 		default:
 			mpl_dbg("Mesh plink: unknown frame subtype\n");
 			spin_unlock_bh(&sta->plink_lock);
-			sta_info_put(sta);
+			rcu_read_unlock();
 			return;
 		}
 	}
@@ -659,8 +667,7 @@
 					    plid, 0);
 			break;
 		case CNF_ACPT:
-			if (del_timer(&sta->plink_timer))
-				sta_info_put(sta);
+			del_timer(&sta->plink_timer);
 			sta->plink_state = ESTAB;
 			mesh_plink_inc_estab_count(sdata);
 			spin_unlock_bh(&sta->plink_lock);
@@ -693,8 +700,7 @@
 					    plid, reason);
 			break;
 		case OPN_ACPT:
-			if (del_timer(&sta->plink_timer))
-				sta_info_put(sta);
+			del_timer(&sta->plink_timer);
 			sta->plink_state = ESTAB;
 			mesh_plink_inc_estab_count(sdata);
 			spin_unlock_bh(&sta->plink_lock);
@@ -717,9 +723,7 @@
 			__mesh_plink_deactivate(sta);
 			sta->plink_state = HOLDING;
 			llid = sta->llid;
-			if (!mod_plink_timer(sta,
-					dot11MeshHoldingTimeout(sdata)))
-				__sta_info_get(sta);
+			mod_plink_timer(sta, dot11MeshHoldingTimeout(sdata));
 			spin_unlock_bh(&sta->plink_lock);
 			mesh_plink_frame_tx(dev, PLINK_CLOSE, sta->addr, llid,
 					    plid, reason);
@@ -738,10 +742,8 @@
 	case HOLDING:
 		switch (event) {
 		case CLS_ACPT:
-			if (del_timer(&sta->plink_timer)) {
+			if (del_timer(&sta->plink_timer))
 				sta->ignore_plink_timer = 1;
-				sta_info_put(sta);
-			}
 			mesh_plink_fsm_restart(sta);
 			spin_unlock_bh(&sta->plink_lock);
 			break;
@@ -766,5 +768,6 @@
 		spin_unlock_bh(&sta->plink_lock);
 		break;
 	}
-	sta_info_put(sta);
+
+	rcu_read_unlock();
 }