Merge branch 'mlxsw-layer2-multicast'

Jiri Pirko says:

====================
mlxsw: Adding layer 2 multicast

Elad says:

This patchset adds hardware offload of L2 multicast (bridge MDB) entries via
switchdev and adds multicast support to mlxsw. For every bridge MDB entry
insertion or removal, whether from IGMP snooping or static configuration, a
switchdev operation is called.
In mlxsw, a new multicast group (MID) is created and ports are assigned to it.
When all ports are removed, the multicast group is deleted.

---
v1->v2:
- GFP_ATOMIC->GFP_KERNEL change in patch 7/8
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/Documentation/networking/switchdev.txt b/Documentation/networking/switchdev.txt
index 9199413..fad6313 100644
--- a/Documentation/networking/switchdev.txt
+++ b/Documentation/networking/switchdev.txt
@@ -304,8 +304,12 @@
 IGMP Snooping
 ^^^^^^^^^^^^^
 
-XXX: complete this section
-
+In order to support IGMP snooping, the port netdevs should trap all IGMP join
+and leave messages to the bridge driver.
+The bridge multicast module will notify port netdevs of every multicast group
+change, whether statically configured or dynamically joined or left.
+The hardware implementation should forward traffic for each registered
+multicast group only to the ports configured for that group.
 
 L3 Routing Offload
 ------------------
diff --git a/drivers/net/ethernet/mellanox/mlxsw/reg.h b/drivers/net/ethernet/mellanox/mlxsw/reg.h
index 66d851d..0c52372 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/reg.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/reg.h
@@ -99,6 +99,55 @@
  */
 MLXSW_ITEM_BUF(reg, spad, base_mac, 0x02, 6);
 
+/* SMID - Switch Multicast ID
+ * --------------------------
+ * The MID record maps from a MID (Multicast ID), which is a unique identifier
+ * of the multicast group within the stacking domain, into a list of local
+ * ports into which the packet is replicated.
+ */
+#define MLXSW_REG_SMID_ID 0x2007
+#define MLXSW_REG_SMID_LEN 0x240
+
+static const struct mlxsw_reg_info mlxsw_reg_smid = {
+	.id = MLXSW_REG_SMID_ID,
+	.len = MLXSW_REG_SMID_LEN,
+};
+
+/* reg_smid_swid
+ * Switch partition ID.
+ * Access: Index
+ */
+MLXSW_ITEM32(reg, smid, swid, 0x00, 24, 8);
+
+/* reg_smid_mid
+ * Multicast identifier - global identifier that represents the multicast group
+ * across all devices.
+ * Access: Index
+ */
+MLXSW_ITEM32(reg, smid, mid, 0x00, 0, 16);
+
+/* reg_smid_port
+ * Local port membership (1 bit per port).
+ * Access: RW
+ */
+MLXSW_ITEM_BIT_ARRAY(reg, smid, port, 0x20, 0x20, 1);
+
+/* reg_smid_port_mask
+ * Local port mask (1 bit per port).
+ * Access: W
+ */
+MLXSW_ITEM_BIT_ARRAY(reg, smid, port_mask, 0x220, 0x20, 1);
+
+static inline void mlxsw_reg_smid_pack(char *payload, u16 mid,
+				       u8 port, bool set)
+{
+	MLXSW_REG_ZERO(smid, payload);
+	mlxsw_reg_smid_swid_set(payload, 0);
+	mlxsw_reg_smid_mid_set(payload, mid);
+	mlxsw_reg_smid_port_set(payload, port, set);
+	mlxsw_reg_smid_port_mask_set(payload, port, 1);
+}
+
 /* SSPR - Switch System Port Record Register
  * -----------------------------------------
  * Configures the system port to local port mapping.
@@ -287,6 +336,7 @@
 enum mlxsw_reg_sfd_rec_type {
 	MLXSW_REG_SFD_REC_TYPE_UNICAST = 0x0,
 	MLXSW_REG_SFD_REC_TYPE_UNICAST_LAG = 0x1,
+	MLXSW_REG_SFD_REC_TYPE_MULTICAST = 0x2,
 };
 
 /* reg_sfd_rec_type
@@ -379,7 +429,6 @@
 
 static inline void mlxsw_reg_sfd_rec_pack(char *payload, int rec_index,
 					  enum mlxsw_reg_sfd_rec_type rec_type,
-					  enum mlxsw_reg_sfd_rec_policy policy,
 					  const char *mac,
 					  enum mlxsw_reg_sfd_rec_action action)
 {
@@ -389,7 +438,6 @@
 		mlxsw_reg_sfd_num_rec_set(payload, rec_index + 1);
 	mlxsw_reg_sfd_rec_swid_set(payload, rec_index, 0);
 	mlxsw_reg_sfd_rec_type_set(payload, rec_index, rec_type);
-	mlxsw_reg_sfd_rec_policy_set(payload, rec_index, policy);
 	mlxsw_reg_sfd_rec_mac_memcpy_to(payload, rec_index, mac);
 	mlxsw_reg_sfd_rec_action_set(payload, rec_index, action);
 }
@@ -401,8 +449,8 @@
 					 u8 local_port)
 {
 	mlxsw_reg_sfd_rec_pack(payload, rec_index,
-			       MLXSW_REG_SFD_REC_TYPE_UNICAST,
-			       policy, mac, action);
+			       MLXSW_REG_SFD_REC_TYPE_UNICAST, mac, action);
+	mlxsw_reg_sfd_rec_policy_set(payload, rec_index, policy);
 	mlxsw_reg_sfd_uc_sub_port_set(payload, rec_index, 0);
 	mlxsw_reg_sfd_uc_fid_vid_set(payload, rec_index, fid_vid);
 	mlxsw_reg_sfd_uc_system_port_set(payload, rec_index, local_port);
@@ -461,7 +509,8 @@
 {
 	mlxsw_reg_sfd_rec_pack(payload, rec_index,
 			       MLXSW_REG_SFD_REC_TYPE_UNICAST_LAG,
-			       policy, mac, action);
+			       mac, action);
+	mlxsw_reg_sfd_rec_policy_set(payload, rec_index, policy);
 	mlxsw_reg_sfd_uc_lag_sub_port_set(payload, rec_index, 0);
 	mlxsw_reg_sfd_uc_lag_fid_vid_set(payload, rec_index, fid_vid);
 	mlxsw_reg_sfd_uc_lag_lag_vid_set(payload, rec_index, lag_vid);
@@ -477,6 +526,45 @@
 	*p_lag_id = mlxsw_reg_sfd_uc_lag_lag_id_get(payload, rec_index);
 }
 
+/* reg_sfd_mc_pgi
+ *
+ * Multicast port group index - index into the port group table.
+ * Value 0x1FFF indicates the pgi should point to the MID entry.
+ * For Spectrum this value must be set to 0x1FFF.
+ * Access: RW
+ */
+MLXSW_ITEM32_INDEXED(reg, sfd, mc_pgi, MLXSW_REG_SFD_BASE_LEN, 16, 13,
+		     MLXSW_REG_SFD_REC_LEN, 0x08, false);
+
+/* reg_sfd_mc_fid_vid
+ *
+ * Filtering ID or VLAN ID
+ * Access: Index
+ */
+MLXSW_ITEM32_INDEXED(reg, sfd, mc_fid_vid, MLXSW_REG_SFD_BASE_LEN, 0, 16,
+		     MLXSW_REG_SFD_REC_LEN, 0x08, false);
+
+/* reg_sfd_mc_mid
+ *
+ * Multicast identifier - global identifier that represents the multicast
+ * group across all devices.
+ * Access: RW
+ */
+MLXSW_ITEM32_INDEXED(reg, sfd, mc_mid, MLXSW_REG_SFD_BASE_LEN, 0, 16,
+		     MLXSW_REG_SFD_REC_LEN, 0x0C, false);
+
+static inline void
+mlxsw_reg_sfd_mc_pack(char *payload, int rec_index,
+		      const char *mac, u16 fid_vid,
+		      enum mlxsw_reg_sfd_rec_action action, u16 mid)
+{
+	mlxsw_reg_sfd_rec_pack(payload, rec_index,
+			       MLXSW_REG_SFD_REC_TYPE_MULTICAST, mac, action);
+	mlxsw_reg_sfd_mc_pgi_set(payload, rec_index, 0x1FFF);
+	mlxsw_reg_sfd_mc_fid_vid_set(payload, rec_index, fid_vid);
+	mlxsw_reg_sfd_mc_mid_set(payload, rec_index, mid);
+}
+
 /* SFN - Switch FDB Notification Register
  * -------------------------------------------
  * The switch provides notifications on newly learned FDB entries and
@@ -3013,6 +3101,8 @@
 		return "SGCR";
 	case MLXSW_REG_SPAD_ID:
 		return "SPAD";
+	case MLXSW_REG_SMID_ID:
+		return "SMID";
 	case MLXSW_REG_SSPR_ID:
 		return "SSPR";
 	case MLXSW_REG_SFDAT_ID:
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
index b6f3650..ce6845d 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
@@ -1858,6 +1858,7 @@
 	mlxsw_sp->bus_info = mlxsw_bus_info;
 	INIT_LIST_HEAD(&mlxsw_sp->port_vfids.list);
 	INIT_LIST_HEAD(&mlxsw_sp->br_vfids.list);
+	INIT_LIST_HEAD(&mlxsw_sp->br_mids.list);
 
 	err = mlxsw_sp_base_mac_get(mlxsw_sp);
 	if (err) {
@@ -1939,7 +1940,7 @@
 	.used_max_port_per_lag		= 1,
 	.max_port_per_lag		= MLXSW_SP_PORT_PER_LAG_MAX,
 	.used_max_mid			= 1,
-	.max_mid			= 7000,
+	.max_mid			= MLXSW_SP_MID_MAX,
 	.used_max_pgt			= 1,
 	.max_pgt			= 0,
 	.used_max_system_port		= 1,
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
index 7601789..199f91a 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
@@ -44,6 +44,7 @@
 #include <linux/list.h>
 #include <net/switchdev.h>
 
+#include "port.h"
 #include "core.h"
 
 #define MLXSW_SP_VFID_BASE VLAN_N_VID
@@ -54,6 +55,8 @@
 #define MLXSW_SP_LAG_MAX 64
 #define MLXSW_SP_PORT_PER_LAG_MAX 16
 
+#define MLXSW_SP_MID_MAX 7000
+
 struct mlxsw_sp_port;
 
 struct mlxsw_sp_upper {
@@ -69,6 +72,14 @@
 	u16 vid;
 };
 
+struct mlxsw_sp_mid {
+	struct list_head list;
+	unsigned char addr[ETH_ALEN];	/* group MAC address */
+	u16 vid;			/* VID from the switchdev MDB object */
+	u16 mid;			/* MID, bit in br_mids.mapped */
+	unsigned int ref_count;		/* number of member ports */
+};
+
 static inline u16 mlxsw_sp_vfid_to_fid(u16 vfid)
 {
 	return MLXSW_SP_VFID_BASE + vfid;
@@ -93,6 +104,10 @@
 		struct list_head list;
 		unsigned long mapped[BITS_TO_LONGS(MLXSW_SP_VFID_BR_MAX)];
 	} br_vfids;
+	struct {
+		struct list_head list;
+		unsigned long mapped[BITS_TO_LONGS(MLXSW_SP_MID_MAX)];
+	} br_mids;
 	unsigned long active_fids[BITS_TO_LONGS(VLAN_N_VID)];
 	struct mlxsw_sp_port **ports;
 	struct mlxsw_core *core;
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
index d64559e..4cdc18e 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
@@ -51,6 +51,23 @@
 #include "core.h"
 #include "reg.h"
 
+static u16 mlxsw_sp_port_vid_to_fid_get(struct mlxsw_sp_port *mlxsw_sp_port,
+					u16 vid)
+{
+	u16 fid = vid;
+
+	if (mlxsw_sp_port_is_vport(mlxsw_sp_port)) {
+		u16 vfid = mlxsw_sp_vport_vfid_get(mlxsw_sp_port);
+
+		fid = mlxsw_sp_vfid_to_fid(vfid);
+	}
+
+	if (!fid)
+		fid = mlxsw_sp_port->pvid;	/* VID 0 -> port PVID */
+
+	return fid;
+}
+
 static struct mlxsw_sp_port *
 mlxsw_sp_port_orig_get(struct net_device *dev,
 		       struct mlxsw_sp_port *mlxsw_sp_port)
@@ -641,22 +658,16 @@
 			     const struct switchdev_obj_port_fdb *fdb,
 			     struct switchdev_trans *trans)
 {
-	u16 fid = fdb->vid;
+	u16 fid = mlxsw_sp_port_vid_to_fid_get(mlxsw_sp_port, fdb->vid);
 	u16 lag_vid = 0;
 
 	if (switchdev_trans_ph_prepare(trans))
 		return 0;
 
 	if (mlxsw_sp_port_is_vport(mlxsw_sp_port)) {
-		u16 vfid = mlxsw_sp_vport_vfid_get(mlxsw_sp_port);
-
-		fid = mlxsw_sp_vfid_to_fid(vfid);
 		lag_vid = mlxsw_sp_vport_vid_get(mlxsw_sp_port);
 	}
 
-	if (!fid)
-		fid = mlxsw_sp_port->pvid;
-
 	if (!mlxsw_sp_port->lagged)
 		return mlxsw_sp_port_fdb_uc_op(mlxsw_sp_port->mlxsw_sp,
 					       mlxsw_sp_port->local_port,
@@ -668,6 +679,143 @@
 						   true, false);
 }
 
+static int mlxsw_sp_port_mdb_op(struct mlxsw_sp *mlxsw_sp, const char *addr,
+				u16 fid, u16 mid, bool adding)
+{
+	char *sfd_pl;
+	int err;
+
+	sfd_pl = kmalloc(MLXSW_REG_SFD_LEN, GFP_KERNEL);
+	if (!sfd_pl)
+		return -ENOMEM;
+
+	mlxsw_reg_sfd_pack(sfd_pl, mlxsw_sp_sfd_op(adding), 0);
+	mlxsw_reg_sfd_mc_pack(sfd_pl, 0, addr, fid,
+			      MLXSW_REG_SFD_REC_ACTION_NOP, mid);
+	err = mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(sfd), sfd_pl);
+	kfree(sfd_pl);
+	return err;
+}
+
+static int mlxsw_sp_port_smid_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 mid,
+				  bool add, bool clear_all_ports)
+{
+	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+	char *smid_pl;
+	int err, i;
+
+	smid_pl = kmalloc(MLXSW_REG_SMID_LEN, GFP_KERNEL);
+	if (!smid_pl)
+		return -ENOMEM;
+
+	mlxsw_reg_smid_pack(smid_pl, mid, mlxsw_sp_port->local_port, add);
+	if (clear_all_ports) {	/* zero every other port's bit */
+		for (i = 1; i < MLXSW_PORT_MAX_PORTS; i++)
+			if (mlxsw_sp->ports[i])
+				mlxsw_reg_smid_port_mask_set(smid_pl, i, 1);
+	}
+	err = mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(smid), smid_pl);
+	kfree(smid_pl);
+	return err;
+}
+
+static struct mlxsw_sp_mid *__mlxsw_sp_mc_get(struct mlxsw_sp *mlxsw_sp,
+					      const unsigned char *addr,
+					      u16 vid)
+{
+	struct mlxsw_sp_mid *mid;
+
+	list_for_each_entry(mid, &mlxsw_sp->br_mids.list, list) {
+		if (ether_addr_equal(mid->addr, addr) && mid->vid == vid)
+			return mid;
+	}
+	return NULL;
+}
+
+static struct mlxsw_sp_mid *__mlxsw_sp_mc_alloc(struct mlxsw_sp *mlxsw_sp,
+						const unsigned char *addr,
+						u16 vid)
+{
+	struct mlxsw_sp_mid *mid;
+	u16 mid_idx;
+
+	mid_idx = find_first_zero_bit(mlxsw_sp->br_mids.mapped,
+				      MLXSW_SP_MID_MAX);
+	if (mid_idx == MLXSW_SP_MID_MAX)	/* MID space exhausted */
+		return NULL;
+
+	mid = kzalloc(sizeof(*mid), GFP_KERNEL);
+	if (!mid)
+		return NULL;
+
+	set_bit(mid_idx, mlxsw_sp->br_mids.mapped);
+	ether_addr_copy(mid->addr, addr);
+	mid->vid = vid;
+	mid->mid = mid_idx;
+	mid->ref_count = 0;
+	list_add_tail(&mid->list, &mlxsw_sp->br_mids.list);
+
+	return mid;
+}
+
+static int __mlxsw_sp_mc_dec_ref(struct mlxsw_sp *mlxsw_sp,
+				 struct mlxsw_sp_mid *mid)
+{
+	if (--mid->ref_count == 0) {
+		list_del(&mid->list);
+		clear_bit(mid->mid, mlxsw_sp->br_mids.mapped);
+		kfree(mid);
+		return 1;	/* last reference dropped, mid freed */
+	}
+	return 0;
+}
+
+static int mlxsw_sp_port_mdb_add(struct mlxsw_sp_port *mlxsw_sp_port,
+				 const struct switchdev_obj_port_mdb *mdb,
+				 struct switchdev_trans *trans)
+{
+	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+	struct net_device *dev = mlxsw_sp_port->dev;
+	struct mlxsw_sp_mid *mid;
+	u16 fid = mlxsw_sp_port_vid_to_fid_get(mlxsw_sp_port, mdb->vid);
+	int err = 0;
+
+	if (switchdev_trans_ph_prepare(trans))
+		return 0;
+
+	mid = __mlxsw_sp_mc_get(mlxsw_sp, mdb->addr, mdb->vid);
+	if (!mid) {
+		mid = __mlxsw_sp_mc_alloc(mlxsw_sp, mdb->addr, mdb->vid);
+		if (!mid) {
+			netdev_err(dev, "Unable to allocate MC group\n");
+			return -ENOMEM;
+		}
+	}
+	mid->ref_count++;	/* one reference per member port */
+
+	err = mlxsw_sp_port_smid_set(mlxsw_sp_port, mid->mid, true,
+				     mid->ref_count == 1);
+	if (err) {
+		netdev_err(dev, "Unable to set SMID\n");
+		goto err_out;
+	}
+
+	if (mid->ref_count == 1) {	/* new group: add its SFD record */
+		err = mlxsw_sp_port_mdb_op(mlxsw_sp, mdb->addr, fid, mid->mid,
+					   true);
+		if (err) {
+			netdev_err(dev, "Unable to set MC SFD\n");
+			goto err_out;
+		}
+	}
+
+	return 0;
+
+err_out:
+	__mlxsw_sp_mc_dec_ref(mlxsw_sp, mid);
+	return err;
+}
+
 static int mlxsw_sp_port_obj_add(struct net_device *dev,
 				 const struct switchdev_obj *obj,
 				 struct switchdev_trans *trans)
@@ -693,6 +841,11 @@
 						   SWITCHDEV_OBJ_PORT_FDB(obj),
 						   trans);
 		break;
+	case SWITCHDEV_OBJ_ID_PORT_MDB:
+		err = mlxsw_sp_port_mdb_add(mlxsw_sp_port,
+					    SWITCHDEV_OBJ_PORT_MDB(obj),
+					    trans);
+		break;
 	default:
 		err = -EOPNOTSUPP;
 		break;
@@ -787,13 +940,10 @@
 mlxsw_sp_port_fdb_static_del(struct mlxsw_sp_port *mlxsw_sp_port,
 			     const struct switchdev_obj_port_fdb *fdb)
 {
-	u16 fid = fdb->vid;
+	u16 fid = mlxsw_sp_port_vid_to_fid_get(mlxsw_sp_port, fdb->vid);
 	u16 lag_vid = 0;
 
 	if (mlxsw_sp_port_is_vport(mlxsw_sp_port)) {
-		u16 vfid = mlxsw_sp_vport_vfid_get(mlxsw_sp_port);
-
-		fid = mlxsw_sp_vfid_to_fid(vfid);
 		lag_vid = mlxsw_sp_vport_vid_get(mlxsw_sp_port);
 	}
 
@@ -809,6 +959,37 @@
 						   false, false);
 }
 
+static int mlxsw_sp_port_mdb_del(struct mlxsw_sp_port *mlxsw_sp_port,
+				 const struct switchdev_obj_port_mdb *mdb)
+{
+	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+	struct net_device *dev = mlxsw_sp_port->dev;
+	struct mlxsw_sp_mid *mid;
+	u16 fid = mlxsw_sp_port_vid_to_fid_get(mlxsw_sp_port, mdb->vid);
+	u16 mid_idx;
+	int err = 0;
+
+	mid = __mlxsw_sp_mc_get(mlxsw_sp, mdb->addr, mdb->vid);
+	if (!mid) {
+		netdev_err(dev, "Unable to remove port from MC DB\n");
+		return -EINVAL;
+	}
+
+	err = mlxsw_sp_port_smid_set(mlxsw_sp_port, mid->mid, false, false);
+	if (err)
+		netdev_err(dev, "Unable to remove port from SMID\n");
+
+	mid_idx = mid->mid;	/* saved: dec_ref may free mid */
+	if (__mlxsw_sp_mc_dec_ref(mlxsw_sp, mid)) {
+		err = mlxsw_sp_port_mdb_op(mlxsw_sp, mdb->addr, fid, mid_idx,
+					   false);
+		if (err)
+			netdev_err(dev, "Unable to remove MC SFD\n");
+	}
+
+	return err;	/* NOTE(review): last-ref path overwrites SMID err */
+}
+
 static int mlxsw_sp_port_obj_del(struct net_device *dev,
 				 const struct switchdev_obj *obj)
 {
@@ -831,6 +1012,10 @@
 		err = mlxsw_sp_port_fdb_static_del(mlxsw_sp_port,
 						   SWITCHDEV_OBJ_PORT_FDB(obj));
 		break;
+	case SWITCHDEV_OBJ_ID_PORT_MDB:
+		err = mlxsw_sp_port_mdb_del(mlxsw_sp_port,
+					    SWITCHDEV_OBJ_PORT_MDB(obj));
+		break;
 	default:
 		err = -EOPNOTSUPP;
 		break;
diff --git a/include/net/switchdev.h b/include/net/switchdev.h
index 603ae2f..d451122 100644
--- a/include/net/switchdev.h
+++ b/include/net/switchdev.h
@@ -68,6 +68,7 @@
 	SWITCHDEV_OBJ_ID_PORT_VLAN,
 	SWITCHDEV_OBJ_ID_IPV4_FIB,
 	SWITCHDEV_OBJ_ID_PORT_FDB,
+	SWITCHDEV_OBJ_ID_PORT_MDB,
 };
 
 struct switchdev_obj {
@@ -113,6 +114,16 @@
 #define SWITCHDEV_OBJ_PORT_FDB(obj) \
 	container_of(obj, struct switchdev_obj_port_fdb, obj)
 
+/* SWITCHDEV_OBJ_ID_PORT_MDB */
+struct switchdev_obj_port_mdb {
+	struct switchdev_obj obj;
+	unsigned char addr[ETH_ALEN];	/* group MAC address */
+	u16 vid;
+};
+
+#define SWITCHDEV_OBJ_PORT_MDB(obj) \
+	container_of(obj, struct switchdev_obj_port_mdb, obj)
+
 void switchdev_trans_item_enqueue(struct switchdev_trans *trans,
 				  void *data, void (*destructor)(void const *),
 				  struct switchdev_trans_item *tritem);
diff --git a/net/bridge/br_mdb.c b/net/bridge/br_mdb.c
index cd8deea..30e105f 100644
--- a/net/bridge/br_mdb.c
+++ b/net/bridge/br_mdb.c
@@ -7,6 +7,7 @@
 #include <linux/if_ether.h>
 #include <net/ip.h>
 #include <net/netlink.h>
+#include <net/switchdev.h>
 #if IS_ENABLED(CONFIG_IPV6)
 #include <net/ipv6.h>
 #include <net/addrconf.h>
@@ -210,10 +211,32 @@
 static void __br_mdb_notify(struct net_device *dev, struct br_mdb_entry *entry,
 			    int type)
 {
+	struct switchdev_obj_port_mdb mdb = {
+		.obj = {
+			.id = SWITCHDEV_OBJ_ID_PORT_MDB,
+			.flags = SWITCHDEV_F_DEFER,
+		},
+		.vid = entry->vid,
+	};
+	struct net_device *port_dev;
 	struct net *net = dev_net(dev);
 	struct sk_buff *skb;
 	int err = -ENOBUFS;
 
+	port_dev = __dev_get_by_index(net, entry->ifindex);
+	if (entry->addr.proto == htons(ETH_P_IP))
+		ip_eth_mc_map(entry->addr.u.ip4, mdb.addr);
+#if IS_ENABLED(CONFIG_IPV6)
+	else
+		ipv6_eth_mc_map(&entry->addr.u.ip6, mdb.addr);
+#endif
+
+	mdb.obj.orig_dev = port_dev;
+	if (port_dev && type == RTM_NEWMDB)
+		switchdev_port_obj_add(port_dev, &mdb.obj);
+	else if (port_dev && type == RTM_DELMDB)
+		switchdev_port_obj_del(port_dev, &mdb.obj);
+
 	skb = nlmsg_new(rtnl_mdb_nlmsg_size(), GFP_ATOMIC);
 	if (!skb)
 		goto errout;
diff --git a/net/switchdev/switchdev.c b/net/switchdev/switchdev.c
index df790d3..ebc661d 100644
--- a/net/switchdev/switchdev.c
+++ b/net/switchdev/switchdev.c
@@ -345,6 +345,8 @@
 		return sizeof(struct switchdev_obj_ipv4_fib);
 	case SWITCHDEV_OBJ_ID_PORT_FDB:
 		return sizeof(struct switchdev_obj_port_fdb);
+	case SWITCHDEV_OBJ_ID_PORT_MDB:
+		return sizeof(struct switchdev_obj_port_mdb);
 	default:
 		BUG();
 	}