bridge: mdb: Marking port-group as offloaded

There is a race condition when updating the mdb offload flag without
holding the multicast_lock. This reverts commit 9e8430f8d60d98 ("bridge:
mdb: Passing the port-group pointer to br_mdb module").

This patch marks an offloaded MDB entry as "offload" by setting
MDB_PG_FLAGS_OFFLOAD in the port-group flags.

When a switchdev PORT_MDB operation succeeds in adding a multicast group,
the completion callback "br_mdb_complete" is invoked. The completion
function takes the multicast_lock, finds the matching
net_bridge_port_group and marks it as offloaded.

Fixes: 9e8430f8d60d98 ("bridge: mdb: Passing the port-group pointer to br_mdb module")
Reported-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: Elad Raz <eladr@mellanox.com>
Signed-off-by: Jiri Pirko <jiri@mellanox.com>
Reviewed-by: Ido Schimmel <idosch@mellanox.com>
Acked-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/bridge/br_mdb.c b/net/bridge/br_mdb.c
index b258120..7dbc80d 100644
--- a/net/bridge/br_mdb.c
+++ b/net/bridge/br_mdb.c
@@ -256,9 +256,45 @@
 		+ nla_total_size(sizeof(struct br_mdb_entry));
 }
 
-static void __br_mdb_notify(struct net_device *dev, struct br_mdb_entry *entry,
-			    int type, struct net_bridge_port_group *pg)
+struct br_mdb_complete_info {
+	struct net_bridge_port *port;
+	struct br_ip ip;
+};
+
+static void br_mdb_complete(struct net_device *dev, int err, void *priv)
 {
+	struct br_mdb_complete_info *data = priv;
+	struct net_bridge_port_group __rcu **pp;
+	struct net_bridge_port_group *p;
+	struct net_bridge_mdb_htable *mdb;
+	struct net_bridge_mdb_entry *mp;
+	struct net_bridge_port *port = data->port;
+	struct net_bridge *br = port->br;
+
+	if (err)
+		goto err;
+
+	spin_lock_bh(&br->multicast_lock);
+	mdb = mlock_dereference(br->mdb, br);
+	mp = br_mdb_ip_get(mdb, &data->ip);
+	if (!mp)
+		goto out;
+	for (pp = &mp->ports; (p = mlock_dereference(*pp, br)) != NULL;
+	     pp = &p->next) {
+		if (p->port != port)
+			continue;
+		p->flags |= MDB_PG_FLAGS_OFFLOAD;
+	}
+out:
+	spin_unlock_bh(&br->multicast_lock);
+err:
+	kfree(priv);
+}
+
+static void __br_mdb_notify(struct net_device *dev, struct net_bridge_port *p,
+			    struct br_mdb_entry *entry, int type)
+{
+	struct br_mdb_complete_info *complete_info;
 	struct switchdev_obj_port_mdb mdb = {
 		.obj = {
 			.id = SWITCHDEV_OBJ_ID_PORT_MDB,
@@ -281,9 +317,14 @@
 
 	mdb.obj.orig_dev = port_dev;
 	if (port_dev && type == RTM_NEWMDB) {
-		err = switchdev_port_obj_add(port_dev, &mdb.obj);
-		if (!err && pg)
-			pg->flags |= MDB_PG_FLAGS_OFFLOAD;
+		complete_info = kmalloc(sizeof(*complete_info), GFP_ATOMIC);
+		if (complete_info) {
+			complete_info->port = p;
+			__mdb_entry_to_br_ip(entry, &complete_info->ip);
+			mdb.obj.complete_priv = complete_info;
+			mdb.obj.complete = br_mdb_complete;
+			switchdev_port_obj_add(port_dev, &mdb.obj);
+		}
 	} else if (port_dev && type == RTM_DELMDB) {
 		switchdev_port_obj_del(port_dev, &mdb.obj);
 	}
@@ -304,21 +345,21 @@
 	rtnl_set_sk_err(net, RTNLGRP_MDB, err);
 }
 
-void br_mdb_notify(struct net_device *dev, struct net_bridge_port_group *pg,
-		   int type)
+void br_mdb_notify(struct net_device *dev, struct net_bridge_port *port,
+		   struct br_ip *group, int type, u8 flags)
 {
 	struct br_mdb_entry entry;
 
 	memset(&entry, 0, sizeof(entry));
-	entry.ifindex = pg->port->dev->ifindex;
-	entry.addr.proto = pg->addr.proto;
-	entry.addr.u.ip4 = pg->addr.u.ip4;
+	entry.ifindex = port->dev->ifindex;
+	entry.addr.proto = group->proto;
+	entry.addr.u.ip4 = group->u.ip4;
 #if IS_ENABLED(CONFIG_IPV6)
-	entry.addr.u.ip6 = pg->addr.u.ip6;
+	entry.addr.u.ip6 = group->u.ip6;
 #endif
-	entry.vid = pg->addr.vid;
-	__mdb_entry_fill_flags(&entry, pg->flags);
-	__br_mdb_notify(dev, &entry, type, pg);
+	entry.vid = group->vid;
+	__mdb_entry_fill_flags(&entry, flags);
+	__br_mdb_notify(dev, port, &entry, type);
 }
 
 static int nlmsg_populate_rtr_fill(struct sk_buff *skb,
@@ -463,8 +504,7 @@
 }
 
 static int br_mdb_add_group(struct net_bridge *br, struct net_bridge_port *port,
-			    struct br_ip *group, unsigned char state,
-			    struct net_bridge_port_group **pg)
+			    struct br_ip *group, unsigned char state)
 {
 	struct net_bridge_mdb_entry *mp;
 	struct net_bridge_port_group *p;
@@ -495,7 +535,6 @@
 	if (unlikely(!p))
 		return -ENOMEM;
 	rcu_assign_pointer(*pp, p);
-	*pg = p;
 	if (state == MDB_TEMPORARY)
 		mod_timer(&p->timer, now + br->multicast_membership_interval);
 
@@ -503,8 +542,7 @@
 }
 
 static int __br_mdb_add(struct net *net, struct net_bridge *br,
-			struct br_mdb_entry *entry,
-			struct net_bridge_port_group **pg)
+			struct br_mdb_entry *entry)
 {
 	struct br_ip ip;
 	struct net_device *dev;
@@ -525,7 +563,7 @@
 	__mdb_entry_to_br_ip(entry, &ip);
 
 	spin_lock_bh(&br->multicast_lock);
-	ret = br_mdb_add_group(br, p, &ip, entry->state, pg);
+	ret = br_mdb_add_group(br, p, &ip, entry->state);
 	spin_unlock_bh(&br->multicast_lock);
 	return ret;
 }
@@ -533,7 +571,6 @@
 static int br_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
 	struct net *net = sock_net(skb->sk);
-	struct net_bridge_port_group *pg;
 	struct net_bridge_vlan_group *vg;
 	struct net_device *dev, *pdev;
 	struct br_mdb_entry *entry;
@@ -563,15 +600,15 @@
 	if (br_vlan_enabled(br) && vg && entry->vid == 0) {
 		list_for_each_entry(v, &vg->vlan_list, vlist) {
 			entry->vid = v->vid;
-			err = __br_mdb_add(net, br, entry, &pg);
+			err = __br_mdb_add(net, br, entry);
 			if (err)
 				break;
-			__br_mdb_notify(dev, entry, RTM_NEWMDB, pg);
+			__br_mdb_notify(dev, p, entry, RTM_NEWMDB);
 		}
 	} else {
-		err = __br_mdb_add(net, br, entry, &pg);
+		err = __br_mdb_add(net, br, entry);
 		if (!err)
-			__br_mdb_notify(dev, entry, RTM_NEWMDB, pg);
+			__br_mdb_notify(dev, p, entry, RTM_NEWMDB);
 	}
 
 	return err;
@@ -659,12 +696,12 @@
 			entry->vid = v->vid;
 			err = __br_mdb_del(br, entry);
 			if (!err)
-				__br_mdb_notify(dev, entry, RTM_DELMDB, NULL);
+				__br_mdb_notify(dev, p, entry, RTM_DELMDB);
 		}
 	} else {
 		err = __br_mdb_del(br, entry);
 		if (!err)
-			__br_mdb_notify(dev, entry, RTM_DELMDB, NULL);
+			__br_mdb_notify(dev, p, entry, RTM_DELMDB);
 	}
 
 	return err;