xfrm: fix RCU bugs

This patch reverts commit 56892261ed1a (xfrm: Use rcu_dereference_bh to
deference pointer protected by rcu_read_lock_bh), and fixes bugs
introduced in commit 418a99ac6ad ( Replace rwlock on xfrm_policy_afinfo
with rcu )

1) We properly use RCU variant in this file, not a mix of RCU/RCU_BH

2) We must defer some writes after the synchronize_rcu() call or a reader
 can crash dereferencing NULL pointer.

3) Now we use the xfrm_policy_afinfo_lock spinlock only from process
context, we no longer need to block BH in xfrm_policy_register_afinfo()
and xfrm_policy_unregister_afinfo()

4) Can use RCU_INIT_POINTER() instead of rcu_assign_pointer() in
xfrm_policy_unregister_afinfo()

5) Remove a forward inline declaration (xfrm_policy_put_afinfo()),
  and also move xfrm_policy_get_afinfo() declaration.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Cc: Fan Du <fan.du@windriver.com>
Cc: Priyanka Jain <Priyanka.Jain@freescale.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c
index 2ed698c..741a32a 100644
--- a/net/xfrm/xfrm_policy.c
+++ b/net/xfrm/xfrm_policy.c
@@ -48,8 +48,6 @@
 
 static struct kmem_cache *xfrm_dst_cache __read_mostly;
 
-static struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family);
-static inline void xfrm_policy_put_afinfo(struct xfrm_policy_afinfo *afinfo);
 static void xfrm_init_pmtu(struct dst_entry *dst);
 static int stale_bundle(struct dst_entry *dst);
 static int xfrm_bundle_ok(struct xfrm_dst *xdst);
@@ -96,6 +94,24 @@
 	return false;
 }
 
+static struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
+{
+	struct xfrm_policy_afinfo *afinfo;
+
+	if (unlikely(family >= NPROTO))
+		return NULL;
+	rcu_read_lock();
+	afinfo = rcu_dereference(xfrm_policy_afinfo[family]);
+	if (unlikely(!afinfo))
+		rcu_read_unlock();
+	return afinfo;
+}
+
+static void xfrm_policy_put_afinfo(struct xfrm_policy_afinfo *afinfo)
+{
+	rcu_read_unlock();
+}
+
 static inline struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos,
 						  const xfrm_address_t *saddr,
 						  const xfrm_address_t *daddr,
@@ -2421,7 +2437,7 @@
 		return -EINVAL;
 	if (unlikely(afinfo->family >= NPROTO))
 		return -EAFNOSUPPORT;
-	spin_lock_bh(&xfrm_policy_afinfo_lock);
+	spin_lock(&xfrm_policy_afinfo_lock);
 	if (unlikely(xfrm_policy_afinfo[afinfo->family] != NULL))
 		err = -ENOBUFS;
 	else {
@@ -2444,7 +2460,7 @@
 			afinfo->garbage_collect = xfrm_garbage_collect_deferred;
 		rcu_assign_pointer(xfrm_policy_afinfo[afinfo->family], afinfo);
 	}
-	spin_unlock_bh(&xfrm_policy_afinfo_lock);
+	spin_unlock(&xfrm_policy_afinfo_lock);
 
 	rtnl_lock();
 	for_each_net(net) {
@@ -2477,23 +2493,26 @@
 		return -EINVAL;
 	if (unlikely(afinfo->family >= NPROTO))
 		return -EAFNOSUPPORT;
-	spin_lock_bh(&xfrm_policy_afinfo_lock);
+	spin_lock(&xfrm_policy_afinfo_lock);
 	if (likely(xfrm_policy_afinfo[afinfo->family] != NULL)) {
 		if (unlikely(xfrm_policy_afinfo[afinfo->family] != afinfo))
 			err = -EINVAL;
-		else {
-			struct dst_ops *dst_ops = afinfo->dst_ops;
-			rcu_assign_pointer(xfrm_policy_afinfo[afinfo->family],
-									NULL);
-			dst_ops->kmem_cachep = NULL;
-			dst_ops->check = NULL;
-			dst_ops->negative_advice = NULL;
-			dst_ops->link_failure = NULL;
-			afinfo->garbage_collect = NULL;
-		}
+		else
+			RCU_INIT_POINTER(xfrm_policy_afinfo[afinfo->family],
+					 NULL);
 	}
-	spin_unlock_bh(&xfrm_policy_afinfo_lock);
-	synchronize_rcu();
+	spin_unlock(&xfrm_policy_afinfo_lock);
+	if (!err) {
+		struct dst_ops *dst_ops = afinfo->dst_ops;
+
+		synchronize_rcu();
+
+		dst_ops->kmem_cachep = NULL;
+		dst_ops->check = NULL;
+		dst_ops->negative_advice = NULL;
+		dst_ops->link_failure = NULL;
+		afinfo->garbage_collect = NULL;
+	}
 	return err;
 }
 EXPORT_SYMBOL(xfrm_policy_unregister_afinfo);
@@ -2502,32 +2521,15 @@
 {
 	struct xfrm_policy_afinfo *afinfo;
 
-	rcu_read_lock_bh();
-	afinfo = rcu_dereference_bh(xfrm_policy_afinfo[AF_INET]);
+	rcu_read_lock();
+	afinfo = rcu_dereference(xfrm_policy_afinfo[AF_INET]);
 	if (afinfo)
 		net->xfrm.xfrm4_dst_ops = *afinfo->dst_ops;
 #if IS_ENABLED(CONFIG_IPV6)
-	afinfo = rcu_dereference_bh(xfrm_policy_afinfo[AF_INET6]);
+	afinfo = rcu_dereference(xfrm_policy_afinfo[AF_INET6]);
 	if (afinfo)
 		net->xfrm.xfrm6_dst_ops = *afinfo->dst_ops;
 #endif
-	rcu_read_unlock_bh();
-}
-
-static struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
-{
-	struct xfrm_policy_afinfo *afinfo;
-	if (unlikely(family >= NPROTO))
-		return NULL;
-	rcu_read_lock();
-	afinfo = rcu_dereference(xfrm_policy_afinfo[family]);
-	if (unlikely(!afinfo))
-		rcu_read_unlock();
-	return afinfo;
-}
-
-static inline void xfrm_policy_put_afinfo(struct xfrm_policy_afinfo *afinfo)
-{
 	rcu_read_unlock();
 }