Merge branch 'switchdev-locking'

Jiri Pirko says:

====================
switchdev: change locking

This is something which I'm currently struggling with.
Callers of attr_set and obj_add/del often hold not only RTNL but also a
spinlock (bridge). In that case, the driver implementing the op cannot sleep.

The way rocker deals with this now is to just invoke the driver operation
and return, without any checking or reporting of the operation status.

Since it would be nice to at least emit a warning in case the operation fails,
it makes sense to do this in deferred work directly in the switchdev core
instead of implementing it separately in each driver. That is what this
patchset introduces.

So from now on, the locking of switchdev mod ops is consistent. The caller
either holds the rtnl mutex or, if it does not, sets the defer flag
(SWITCHDEV_F_DEFER), telling the switchdev core to process the op later from
the deferred queue.

The op caller can call switchdev_deferred_process() (with rtnl held) at an
appropriate location, for example after it releases a spin lock, to force the
switchdev core to process pending deferred ops.

v1->v2:
- rebased on current net-next head (including Scott's ageing patchset)
v2->v3:
- fixed comment s/of/or/ typo suggested by Nik
v3->v4:
- the actual patchset is sent instead of the different branch I sent in v3 :/
v4->v5:
- added a patch to constify the attr param
- reworked deferred ops infrastructure (mainly patch number 1 and
  internal users (patch 3 and 5)) - resolves the issue pointed out
  by John
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/drivers/net/ethernet/rocker/rocker.c b/drivers/net/ethernet/rocker/rocker.c
index eafa907..32a80d2 100644
--- a/drivers/net/ethernet/rocker/rocker.c
+++ b/drivers/net/ethernet/rocker/rocker.c
@@ -3672,7 +3672,7 @@
 	    rocker_port->stp_state == BR_STATE_FORWARDING)
 		return 0;
 
-	flags |= ROCKER_OP_FLAG_REMOVE;
+	flags |= ROCKER_OP_FLAG_NOWAIT | ROCKER_OP_FLAG_REMOVE;
 
 	spin_lock_irqsave(&rocker->fdb_tbl_lock, lock_flags);
 
@@ -4374,7 +4374,7 @@
 }
 
 static int rocker_port_attr_set(struct net_device *dev,
-				struct switchdev_attr *attr,
+				const struct switchdev_attr *attr,
 				struct switchdev_trans *trans)
 {
 	struct rocker_port *rocker_port = netdev_priv(dev);
@@ -4382,8 +4382,7 @@
 
 	switch (attr->id) {
 	case SWITCHDEV_ATTR_ID_PORT_STP_STATE:
-		err = rocker_port_stp_update(rocker_port, trans,
-					     ROCKER_OP_FLAG_NOWAIT,
+		err = rocker_port_stp_update(rocker_port, trans, 0,
 					     attr->u.stp_state);
 		break;
 	case SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS:
@@ -4469,7 +4468,7 @@
 		fib4 = SWITCHDEV_OBJ_IPV4_FIB(obj);
 		err = rocker_port_fib_ipv4(rocker_port, trans,
 					   htonl(fib4->dst), fib4->dst_len,
-					   fib4->fi, fib4->tb_id, 0);
+					   &fib4->fi, fib4->tb_id, 0);
 		break;
 	case SWITCHDEV_OBJ_ID_PORT_FDB:
 		err = rocker_port_fdb_add(rocker_port, trans,
@@ -4517,7 +4516,7 @@
 			       const struct switchdev_obj_port_fdb *fdb)
 {
 	__be16 vlan_id = rocker_port_vid_to_vlan(rocker_port, fdb->vid, NULL);
-	int flags = ROCKER_OP_FLAG_NOWAIT | ROCKER_OP_FLAG_REMOVE;
+	int flags = ROCKER_OP_FLAG_REMOVE;
 
 	if (!rocker_port_is_bridged(rocker_port))
 		return -EINVAL;
@@ -4541,7 +4540,7 @@
 		fib4 = SWITCHDEV_OBJ_IPV4_FIB(obj);
 		err = rocker_port_fib_ipv4(rocker_port, NULL,
 					   htonl(fib4->dst), fib4->dst_len,
-					   fib4->fi, fib4->tb_id,
+					   &fib4->fi, fib4->tb_id,
 					   ROCKER_OP_FLAG_REMOVE);
 		break;
 	case SWITCHDEV_OBJ_ID_PORT_FDB:
@@ -4571,7 +4570,7 @@
 	hash_for_each_safe(rocker->fdb_tbl, bkt, tmp, found, entry) {
 		if (found->key.rocker_port != rocker_port)
 			continue;
-		fdb->addr = found->key.addr;
+		ether_addr_copy(fdb->addr, found->key.addr);
 		fdb->ndm_state = NUD_REACHABLE;
 		fdb->vid = rocker_port_vlan_to_vid(rocker_port,
 						   found->key.vlan_id);
diff --git a/include/net/switchdev.h b/include/net/switchdev.h
index 1ce7083..bc865e2 100644
--- a/include/net/switchdev.h
+++ b/include/net/switchdev.h
@@ -14,9 +14,11 @@
 #include <linux/netdevice.h>
 #include <linux/notifier.h>
 #include <linux/list.h>
+#include <net/ip_fib.h>
 
 #define SWITCHDEV_F_NO_RECURSE		BIT(0)
 #define SWITCHDEV_F_SKIP_EOPNOTSUPP	BIT(1)
+#define SWITCHDEV_F_DEFER		BIT(2)
 
 struct switchdev_trans_item {
 	struct list_head list;
@@ -58,8 +60,6 @@
 	} u;
 };
 
-struct fib_info;
-
 enum switchdev_obj_id {
 	SWITCHDEV_OBJ_ID_UNDEFINED,
 	SWITCHDEV_OBJ_ID_PORT_VLAN,
@@ -69,6 +69,7 @@
 
 struct switchdev_obj {
 	enum switchdev_obj_id id;
+	u32 flags;
 };
 
 /* SWITCHDEV_OBJ_ID_PORT_VLAN */
@@ -87,7 +88,7 @@
 	struct switchdev_obj obj;
 	u32 dst;
 	int dst_len;
-	struct fib_info *fi;
+	struct fib_info fi;
 	u8 tos;
 	u8 type;
 	u32 nlflags;
@@ -100,7 +101,7 @@
 /* SWITCHDEV_OBJ_ID_PORT_FDB */
 struct switchdev_obj_port_fdb {
 	struct switchdev_obj obj;
-	const unsigned char *addr;
+	unsigned char addr[ETH_ALEN];
 	u16 vid;
 	u16 ndm_state;
 };
@@ -132,7 +133,7 @@
 	int	(*switchdev_port_attr_get)(struct net_device *dev,
 					   struct switchdev_attr *attr);
 	int	(*switchdev_port_attr_set)(struct net_device *dev,
-					   struct switchdev_attr *attr,
+					   const struct switchdev_attr *attr,
 					   struct switchdev_trans *trans);
 	int	(*switchdev_port_obj_add)(struct net_device *dev,
 					  const struct switchdev_obj *obj,
@@ -167,10 +168,11 @@
 
 #ifdef CONFIG_NET_SWITCHDEV
 
+void switchdev_deferred_process(void);
 int switchdev_port_attr_get(struct net_device *dev,
 			    struct switchdev_attr *attr);
 int switchdev_port_attr_set(struct net_device *dev,
-			    struct switchdev_attr *attr);
+			    const struct switchdev_attr *attr);
 int switchdev_port_obj_add(struct net_device *dev,
 			   const struct switchdev_obj *obj);
 int switchdev_port_obj_del(struct net_device *dev,
@@ -208,6 +210,10 @@
 
 #else
 
+static inline void switchdev_deferred_process(void)
+{
+}
+
 static inline int switchdev_port_attr_get(struct net_device *dev,
 					  struct switchdev_attr *attr)
 {
@@ -215,7 +221,7 @@
 }
 
 static inline int switchdev_port_attr_set(struct net_device *dev,
-					  struct switchdev_attr *attr)
+					  const struct switchdev_attr *attr)
 {
 	return -EOPNOTSUPP;
 }
diff --git a/net/bridge/br_fdb.c b/net/bridge/br_fdb.c
index f43ce05..c88bd8e 100644
--- a/net/bridge/br_fdb.c
+++ b/net/bridge/br_fdb.c
@@ -134,11 +134,14 @@
 static void fdb_del_external_learn(struct net_bridge_fdb_entry *f)
 {
 	struct switchdev_obj_port_fdb fdb = {
-		.obj.id = SWITCHDEV_OBJ_ID_PORT_FDB,
-		.addr = f->addr.addr,
+		.obj = {
+			.id = SWITCHDEV_OBJ_ID_PORT_FDB,
+			.flags = SWITCHDEV_F_DEFER,
+		},
 		.vid = f->vlan_id,
 	};
 
+	ether_addr_copy(fdb.addr, f->addr.addr);
 	switchdev_port_obj_del(f->dst->dev, &fdb.obj);
 }
 
diff --git a/net/bridge/br_if.c b/net/bridge/br_if.c
index 45e4757..ec02f58 100644
--- a/net/bridge/br_if.c
+++ b/net/bridge/br_if.c
@@ -24,6 +24,7 @@
 #include <linux/slab.h>
 #include <net/sock.h>
 #include <linux/if_vlan.h>
+#include <net/switchdev.h>
 
 #include "br_private.h"
 
@@ -250,6 +251,8 @@
 
 	nbp_vlan_flush(p);
 	br_fdb_delete_by_port(br, p, 0, 1);
+	switchdev_deferred_process();
+
 	nbp_update_port_count(br);
 
 	netdev_upper_dev_unlink(dev, br->dev);
diff --git a/net/bridge/br_stp.c b/net/bridge/br_stp.c
index db6d243de..80c34d7 100644
--- a/net/bridge/br_stp.c
+++ b/net/bridge/br_stp.c
@@ -41,13 +41,14 @@
 {
 	struct switchdev_attr attr = {
 		.id = SWITCHDEV_ATTR_ID_PORT_STP_STATE,
+		.flags = SWITCHDEV_F_DEFER,
 		.u.stp_state = state,
 	};
 	int err;
 
 	p->state = state;
 	err = switchdev_port_attr_set(p->dev, &attr);
-	if (err && err != -EOPNOTSUPP)
+	if (err)
 		br_warn(p->br, "error setting offload STP state on port %u(%s)\n",
 				(unsigned int) p->port_no, p->dev->name);
 }
diff --git a/net/dsa/slave.c b/net/dsa/slave.c
index 43d7342..b0b8da0 100644
--- a/net/dsa/slave.c
+++ b/net/dsa/slave.c
@@ -393,7 +393,7 @@
 		if (ret < 0)
 			break;
 
-		fdb->addr = addr;
+		ether_addr_copy(fdb->addr, addr);
 		fdb->vid = vid;
 		fdb->ndm_state = is_static ? NUD_NOARP : NUD_REACHABLE;
 
@@ -453,7 +453,7 @@
 }
 
 static int dsa_slave_port_attr_set(struct net_device *dev,
-				   struct switchdev_attr *attr,
+				   const struct switchdev_attr *attr,
 				   struct switchdev_trans *trans)
 {
 	struct dsa_slave_priv *p = netdev_priv(dev);
diff --git a/net/switchdev/switchdev.c b/net/switchdev/switchdev.c
index b8aaf820..73e3895 100644
--- a/net/switchdev/switchdev.c
+++ b/net/switchdev/switchdev.c
@@ -15,8 +15,10 @@
 #include <linux/mutex.h>
 #include <linux/notifier.h>
 #include <linux/netdevice.h>
+#include <linux/etherdevice.h>
 #include <linux/if_bridge.h>
 #include <linux/list.h>
+#include <linux/workqueue.h>
 #include <net/ip_fib.h>
 #include <net/switchdev.h>
 
@@ -92,6 +94,85 @@
 	switchdev_trans_items_destroy(trans);
 }
 
+static LIST_HEAD(deferred);
+static DEFINE_SPINLOCK(deferred_lock);
+
+typedef void switchdev_deferred_func_t(struct net_device *dev,
+				       const void *data);
+
+struct switchdev_deferred_item {
+	struct list_head list;
+	struct net_device *dev;
+	switchdev_deferred_func_t *func;
+	unsigned long data[0];
+};
+
+static struct switchdev_deferred_item *switchdev_deferred_dequeue(void)
+{
+	struct switchdev_deferred_item *dfitem;
+
+	spin_lock_bh(&deferred_lock);
+	if (list_empty(&deferred)) {
+		dfitem = NULL;
+		goto unlock;
+	}
+	dfitem = list_first_entry(&deferred,
+				  struct switchdev_deferred_item, list);
+	list_del(&dfitem->list);
+unlock:
+	spin_unlock_bh(&deferred_lock);
+	return dfitem;
+}
+
+/**
+ *	switchdev_deferred_process - Process ops in deferred queue
+ *
+ *	Called to flush the ops currently queued in deferred ops queue.
+ *	rtnl_lock must be held.
+ */
+void switchdev_deferred_process(void)
+{
+	struct switchdev_deferred_item *dfitem;
+
+	ASSERT_RTNL();
+
+	while ((dfitem = switchdev_deferred_dequeue())) {
+		dfitem->func(dfitem->dev, dfitem->data);
+		dev_put(dfitem->dev);
+		kfree(dfitem);
+	}
+}
+EXPORT_SYMBOL_GPL(switchdev_deferred_process);
+
+static void switchdev_deferred_process_work(struct work_struct *work)
+{
+	rtnl_lock();
+	switchdev_deferred_process();
+	rtnl_unlock();
+}
+
+static DECLARE_WORK(deferred_process_work, switchdev_deferred_process_work);
+
+static int switchdev_deferred_enqueue(struct net_device *dev,
+				      const void *data, size_t data_len,
+				      switchdev_deferred_func_t *func)
+{
+	struct switchdev_deferred_item *dfitem;
+
+	dfitem = kmalloc(sizeof(*dfitem) + data_len, GFP_ATOMIC);
+	if (!dfitem)
+		return -ENOMEM;
+	dfitem->dev = dev;
+	dfitem->func = func;
+	memcpy(dfitem->data, data, data_len);
+	dev_hold(dev);
+	spin_lock_bh(&deferred_lock);
+	list_add_tail(&dfitem->list, &deferred);
+	spin_unlock_bh(&deferred_lock);
+	schedule_work(&deferred_process_work);
+	return 0;
+}
+
 /**
  *	switchdev_port_attr_get - Get port attribute
  *
@@ -135,7 +216,7 @@
 EXPORT_SYMBOL_GPL(switchdev_port_attr_get);
 
 static int __switchdev_port_attr_set(struct net_device *dev,
-				     struct switchdev_attr *attr,
+				     const struct switchdev_attr *attr,
 				     struct switchdev_trans *trans)
 {
 	const struct switchdev_ops *ops = dev->switchdev_ops;
@@ -170,74 +251,12 @@
 	return err;
 }
 
-struct switchdev_attr_set_work {
-	struct work_struct work;
-	struct net_device *dev;
-	struct switchdev_attr attr;
-};
-
-static void switchdev_port_attr_set_work(struct work_struct *work)
-{
-	struct switchdev_attr_set_work *asw =
-		container_of(work, struct switchdev_attr_set_work, work);
-	int err;
-
-	rtnl_lock();
-	err = switchdev_port_attr_set(asw->dev, &asw->attr);
-	if (err && err != -EOPNOTSUPP)
-		netdev_err(asw->dev, "failed (err=%d) to set attribute (id=%d)\n",
-			   err, asw->attr.id);
-	rtnl_unlock();
-
-	dev_put(asw->dev);
-	kfree(work);
-}
-
-static int switchdev_port_attr_set_defer(struct net_device *dev,
-					 struct switchdev_attr *attr)
-{
-	struct switchdev_attr_set_work *asw;
-
-	asw = kmalloc(sizeof(*asw), GFP_ATOMIC);
-	if (!asw)
-		return -ENOMEM;
-
-	INIT_WORK(&asw->work, switchdev_port_attr_set_work);
-
-	dev_hold(dev);
-	asw->dev = dev;
-	memcpy(&asw->attr, attr, sizeof(asw->attr));
-
-	schedule_work(&asw->work);
-
-	return 0;
-}
-
-/**
- *	switchdev_port_attr_set - Set port attribute
- *
- *	@dev: port device
- *	@attr: attribute to set
- *
- *	Use a 2-phase prepare-commit transaction model to ensure
- *	system is not left in a partially updated state due to
- *	failure from driver/device.
- */
-int switchdev_port_attr_set(struct net_device *dev, struct switchdev_attr *attr)
+static int switchdev_port_attr_set_now(struct net_device *dev,
+				       const struct switchdev_attr *attr)
 {
 	struct switchdev_trans trans;
 	int err;
 
-	if (!rtnl_is_locked()) {
-		/* Running prepare-commit transaction across stacked
-		 * devices requires nothing moves, so if rtnl_lock is
-		 * not held, schedule a worker thread to hold rtnl_lock
-		 * while setting attr.
-		 */
-
-		return switchdev_port_attr_set_defer(dev, attr);
-	}
-
 	switchdev_trans_init(&trans);
 
 	/* Phase I: prepare for attr set. Driver/device should fail
@@ -274,6 +293,47 @@
 
 	return err;
 }
+
+static void switchdev_port_attr_set_deferred(struct net_device *dev,
+					     const void *data)
+{
+	const struct switchdev_attr *attr = data;
+	int err;
+
+	err = switchdev_port_attr_set_now(dev, attr);
+	if (err && err != -EOPNOTSUPP)
+		netdev_err(dev, "failed (err=%d) to set attribute (id=%d)\n",
+			   err, attr->id);
+}
+
+static int switchdev_port_attr_set_defer(struct net_device *dev,
+					 const struct switchdev_attr *attr)
+{
+	return switchdev_deferred_enqueue(dev, attr, sizeof(*attr),
+					  switchdev_port_attr_set_deferred);
+}
+
+/**
+ *	switchdev_port_attr_set - Set port attribute
+ *
+ *	@dev: port device
+ *	@attr: attribute to set
+ *
+ *	Use a 2-phase prepare-commit transaction model to ensure
+ *	system is not left in a partially updated state due to
+ *	failure from driver/device.
+ *
+ *	rtnl_lock must be held and must not be in atomic section,
+ *	in case SWITCHDEV_F_DEFER flag is not set.
+ */
+int switchdev_port_attr_set(struct net_device *dev,
+			    const struct switchdev_attr *attr)
+{
+	if (attr->flags & SWITCHDEV_F_DEFER)
+		return switchdev_port_attr_set_defer(dev, attr);
+	ASSERT_RTNL();
+	return switchdev_port_attr_set_now(dev, attr);
+}
 EXPORT_SYMBOL_GPL(switchdev_port_attr_set);
 
 static int __switchdev_port_obj_add(struct net_device *dev,
@@ -302,21 +362,8 @@
 	return err;
 }
 
-/**
- *	switchdev_port_obj_add - Add port object
- *
- *	@dev: port device
- *	@id: object ID
- *	@obj: object to add
- *
- *	Use a 2-phase prepare-commit transaction model to ensure
- *	system is not left in a partially updated state due to
- *	failure from driver/device.
- *
- *	rtnl_lock must be held.
- */
-int switchdev_port_obj_add(struct net_device *dev,
-			   const struct switchdev_obj *obj)
+static int switchdev_port_obj_add_now(struct net_device *dev,
+				      const struct switchdev_obj *obj)
 {
 	struct switchdev_trans trans;
 	int err;
@@ -358,18 +405,53 @@
 
 	return err;
 }
-EXPORT_SYMBOL_GPL(switchdev_port_obj_add);
+
+static void switchdev_port_obj_add_deferred(struct net_device *dev,
+					    const void *data)
+{
+	const struct switchdev_obj *obj = data;
+	int err;
+
+	err = switchdev_port_obj_add_now(dev, obj);
+	if (err && err != -EOPNOTSUPP)
+		netdev_err(dev, "failed (err=%d) to add object (id=%d)\n",
+			   err, obj->id);
+}
+
+static int switchdev_port_obj_add_defer(struct net_device *dev,
+					const struct switchdev_obj *obj)
+{
+	return switchdev_deferred_enqueue(dev, obj, sizeof(*obj),
+					  switchdev_port_obj_add_deferred);
+}
 
 /**
- *	switchdev_port_obj_del - Delete port object
+ *	switchdev_port_obj_add - Add port object
  *
  *	@dev: port device
  *	@id: object ID
- *	@obj: object to delete
+ *	@obj: object to add
+ *
+ *	Use a 2-phase prepare-commit transaction model to ensure
+ *	system is not left in a partially updated state due to
+ *	failure from driver/device.
+ *
+ *	rtnl_lock must be held and must not be in atomic section,
+ *	in case SWITCHDEV_F_DEFER flag is not set.
  */
-int switchdev_port_obj_del(struct net_device *dev,
+int switchdev_port_obj_add(struct net_device *dev,
 			   const struct switchdev_obj *obj)
 {
+	if (obj->flags & SWITCHDEV_F_DEFER)
+		return switchdev_port_obj_add_defer(dev, obj);
+	ASSERT_RTNL();
+	return switchdev_port_obj_add_now(dev, obj);
+}
+EXPORT_SYMBOL_GPL(switchdev_port_obj_add);
+
+static int switchdev_port_obj_del_now(struct net_device *dev,
+				      const struct switchdev_obj *obj)
+{
 	const struct switchdev_ops *ops = dev->switchdev_ops;
 	struct net_device *lower_dev;
 	struct list_head *iter;
@@ -384,13 +466,51 @@
 	 */
 
 	netdev_for_each_lower_dev(dev, lower_dev, iter) {
-		err = switchdev_port_obj_del(lower_dev, obj);
+		err = switchdev_port_obj_del_now(lower_dev, obj);
 		if (err)
 			break;
 	}
 
 	return err;
 }
+
+static void switchdev_port_obj_del_deferred(struct net_device *dev,
+					    const void *data)
+{
+	const struct switchdev_obj *obj = data;
+	int err;
+
+	err = switchdev_port_obj_del_now(dev, obj);
+	if (err && err != -EOPNOTSUPP)
+		netdev_err(dev, "failed (err=%d) to del object (id=%d)\n",
+			   err, obj->id);
+}
+
+static int switchdev_port_obj_del_defer(struct net_device *dev,
+					const struct switchdev_obj *obj)
+{
+	return switchdev_deferred_enqueue(dev, obj, sizeof(*obj),
+					  switchdev_port_obj_del_deferred);
+}
+
+/**
+ *	switchdev_port_obj_del - Delete port object
+ *
+ *	@dev: port device
+ *	@id: object ID
+ *	@obj: object to delete
+ *
+ *	rtnl_lock must be held and must not be in atomic section,
+ *	in case SWITCHDEV_F_DEFER flag is not set.
+ */
+int switchdev_port_obj_del(struct net_device *dev,
+			   const struct switchdev_obj *obj)
+{
+	if (obj->flags & SWITCHDEV_F_DEFER)
+		return switchdev_port_obj_del_defer(dev, obj);
+	ASSERT_RTNL();
+	return switchdev_port_obj_del_now(dev, obj);
+}
 EXPORT_SYMBOL_GPL(switchdev_port_obj_del);
 
 /**
@@ -400,6 +520,8 @@
  *	@id: object ID
  *	@obj: object to dump
  *	@cb: function to call with a filled object
+ *
+ *	rtnl_lock must be held.
  */
 int switchdev_port_obj_dump(struct net_device *dev, struct switchdev_obj *obj,
 			    switchdev_obj_dump_cb_t *cb)
@@ -409,6 +531,8 @@
 	struct list_head *iter;
 	int err = -EOPNOTSUPP;
 
+	ASSERT_RTNL();
+
 	if (ops && ops->switchdev_port_obj_dump)
 		return ops->switchdev_port_obj_dump(dev, obj, cb);
 
@@ -832,10 +956,10 @@
 {
 	struct switchdev_obj_port_fdb fdb = {
 		.obj.id = SWITCHDEV_OBJ_ID_PORT_FDB,
-		.addr = addr,
 		.vid = vid,
 	};
 
+	ether_addr_copy(fdb.addr, addr);
 	return switchdev_port_obj_add(dev, &fdb.obj);
 }
 EXPORT_SYMBOL_GPL(switchdev_port_fdb_add);
@@ -857,10 +981,10 @@
 {
 	struct switchdev_obj_port_fdb fdb = {
 		.obj.id = SWITCHDEV_OBJ_ID_PORT_FDB,
-		.addr = addr,
 		.vid = vid,
 	};
 
+	ether_addr_copy(fdb.addr, addr);
 	return switchdev_port_obj_del(dev, &fdb.obj);
 }
 EXPORT_SYMBOL_GPL(switchdev_port_fdb_del);
@@ -977,6 +1101,8 @@
 	struct net_device *dev = NULL;
 	int nhsel;
 
+	ASSERT_RTNL();
+
 	/* For this route, all nexthop devs must be on the same switch. */
 
 	for (nhsel = 0; nhsel < fi->fib_nhs; nhsel++) {
@@ -1022,7 +1148,6 @@
 		.obj.id = SWITCHDEV_OBJ_ID_IPV4_FIB,
 		.dst = dst,
 		.dst_len = dst_len,
-		.fi = fi,
 		.tos = tos,
 		.type = type,
 		.nlflags = nlflags,
@@ -1031,6 +1156,8 @@
 	struct net_device *dev;
 	int err = 0;
 
+	memcpy(&ipv4_fib.fi, fi, sizeof(ipv4_fib.fi));
+
 	/* Don't offload route if using custom ip rules or if
 	 * IPv4 FIB offloading has been disabled completely.
 	 */
@@ -1074,7 +1201,6 @@
 		.obj.id = SWITCHDEV_OBJ_ID_IPV4_FIB,
 		.dst = dst,
 		.dst_len = dst_len,
-		.fi = fi,
 		.tos = tos,
 		.type = type,
 		.nlflags = 0,
@@ -1083,6 +1209,8 @@
 	struct net_device *dev;
 	int err = 0;
 
+	memcpy(&ipv4_fib.fi, fi, sizeof(ipv4_fib.fi));
+
 	if (!(fi->fib_flags & RTNH_F_OFFLOAD))
 		return 0;
 
@@ -1205,10 +1333,11 @@
 	u32 mark = dev->ifindex;
 	u32 reset_mark = 0;
 
-	if (group_dev && joining) {
-		mark = switchdev_port_fwd_mark_get(dev, group_dev);
-	} else if (group_dev && !joining) {
-		if (dev->offload_fwd_mark == mark)
+	if (group_dev) {
+		ASSERT_RTNL();
+		if (joining)
+			mark = switchdev_port_fwd_mark_get(dev, group_dev);
+		else if (dev->offload_fwd_mark == mark)
 			/* Ohoh, this port was the mark reference port,
 			 * but it's leaving the group, so reset the
 			 * mark for the remaining ports in the group.