netfilter: nf_tables: unbind set in rule from commit path

Anonymous sets that are bound to rules from the same transaction trigger
a kernel splat from the abort path due to double set list removal and
double free.

This patch updates the logic to search for the transaction that is
responsible for creating the set and disable the set list removal and
release, given the rule is now responsible for this. Lookup is reverse
since the transaction that adds the set is likely to be at the tail of
the list.

Moreover, this patch adds the unbind step to deliver the event from the
commit path.  This should not be done from the worker thread, since we
have no guarantees of in-order delivery to the listener.

This patch removes the assumption that both activate and deactivate
callbacks need to be provided.

Fixes: cd5125d8f518 ("netfilter: nf_tables: split set destruction in deactivate and destroy phase")
Reported-by: Mikhail Morfikov <mmorfikov@gmail.com>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
diff --git a/include/net/netfilter/nf_tables.h b/include/net/netfilter/nf_tables.h
index 2c33958..50c101e 100644
--- a/include/net/netfilter/nf_tables.h
+++ b/include/net/netfilter/nf_tables.h
@@ -469,9 +469,7 @@
 int nf_tables_bind_set(const struct nft_ctx *ctx, struct nft_set *set,
 		       struct nft_set_binding *binding);
 void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set,
-			  struct nft_set_binding *binding);
-void nf_tables_rebind_set(const struct nft_ctx *ctx, struct nft_set *set,
-			  struct nft_set_binding *binding);
+			  struct nft_set_binding *binding, bool commit);
 void nf_tables_destroy_set(const struct nft_ctx *ctx, struct nft_set *set);
 
 /**
@@ -721,6 +719,13 @@
 #define NFT_EXPR_STATEFUL		0x1
 #define NFT_EXPR_GC			0x2
 
+enum nft_trans_phase {
+	NFT_TRANS_PREPARE,
+	NFT_TRANS_ABORT,
+	NFT_TRANS_COMMIT,
+	NFT_TRANS_RELEASE
+};
+
 /**
  *	struct nft_expr_ops - nf_tables expression operations
  *
@@ -750,7 +755,8 @@
 	void				(*activate)(const struct nft_ctx *ctx,
 						    const struct nft_expr *expr);
 	void				(*deactivate)(const struct nft_ctx *ctx,
-						      const struct nft_expr *expr);
+						      const struct nft_expr *expr,
+						      enum nft_trans_phase phase);
 	void				(*destroy)(const struct nft_ctx *ctx,
 						   const struct nft_expr *expr);
 	void				(*destroy_clone)(const struct nft_ctx *ctx,
@@ -1321,12 +1327,15 @@
 struct nft_trans_set {
 	struct nft_set			*set;
 	u32				set_id;
+	bool				bound;
 };
 
 #define nft_trans_set(trans)	\
 	(((struct nft_trans_set *)trans->data)->set)
 #define nft_trans_set_id(trans)	\
 	(((struct nft_trans_set *)trans->data)->set_id)
+#define nft_trans_set_bound(trans)	\
+	(((struct nft_trans_set *)trans->data)->bound)
 
 struct nft_trans_chain {
 	bool				update;
diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c
index dd2b28a..9f4d37b 100644
--- a/net/netfilter/nf_tables_api.c
+++ b/net/netfilter/nf_tables_api.c
@@ -112,6 +112,23 @@
 	kfree(trans);
 }
 
+static void nft_set_trans_bind(const struct nft_ctx *ctx, struct nft_set *set)
+{
+	struct net *net = ctx->net;
+	struct nft_trans *trans;
+
+	if (!nft_set_is_anonymous(set))
+		return;
+
+	list_for_each_entry_reverse(trans, &net->nft.commit_list, list) {
+		if (trans->msg_type == NFT_MSG_NEWSET &&
+		    nft_trans_set(trans) == set) {
+			nft_trans_set_bound(trans) = true;
+			break;
+		}
+	}
+}
+
 static int nf_tables_register_hook(struct net *net,
 				   const struct nft_table *table,
 				   struct nft_chain *chain)
@@ -207,18 +224,6 @@
 	return err;
 }
 
-/* either expr ops provide both activate/deactivate, or neither */
-static bool nft_expr_check_ops(const struct nft_expr_ops *ops)
-{
-	if (!ops)
-		return true;
-
-	if (WARN_ON_ONCE((!ops->activate ^ !ops->deactivate)))
-		return false;
-
-	return true;
-}
-
 static void nft_rule_expr_activate(const struct nft_ctx *ctx,
 				   struct nft_rule *rule)
 {
@@ -234,14 +239,15 @@
 }
 
 static void nft_rule_expr_deactivate(const struct nft_ctx *ctx,
-				     struct nft_rule *rule)
+				     struct nft_rule *rule,
+				     enum nft_trans_phase phase)
 {
 	struct nft_expr *expr;
 
 	expr = nft_expr_first(rule);
 	while (expr != nft_expr_last(rule) && expr->ops) {
 		if (expr->ops->deactivate)
-			expr->ops->deactivate(ctx, expr);
+			expr->ops->deactivate(ctx, expr, phase);
 
 		expr = nft_expr_next(expr);
 	}
@@ -292,7 +298,7 @@
 		nft_trans_destroy(trans);
 		return err;
 	}
-	nft_rule_expr_deactivate(ctx, rule);
+	nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_PREPARE);
 
 	return 0;
 }
@@ -1926,9 +1932,6 @@
  */
 int nft_register_expr(struct nft_expr_type *type)
 {
-	if (!nft_expr_check_ops(type->ops))
-		return -EINVAL;
-
 	nfnl_lock(NFNL_SUBSYS_NFTABLES);
 	if (type->family == NFPROTO_UNSPEC)
 		list_add_tail_rcu(&type->list, &nf_tables_expressions);
@@ -2076,10 +2079,6 @@
 			err = PTR_ERR(ops);
 			goto err1;
 		}
-		if (!nft_expr_check_ops(ops)) {
-			err = -EINVAL;
-			goto err1;
-		}
 	} else
 		ops = type->ops;
 
@@ -2477,7 +2476,7 @@
 static void nf_tables_rule_release(const struct nft_ctx *ctx,
 				   struct nft_rule *rule)
 {
-	nft_rule_expr_deactivate(ctx, rule);
+	nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_RELEASE);
 	nf_tables_rule_destroy(ctx, rule);
 }
 
@@ -3677,39 +3676,30 @@
 bind:
 	binding->chain = ctx->chain;
 	list_add_tail_rcu(&binding->list, &set->bindings);
+	nft_set_trans_bind(ctx, set);
+
 	return 0;
 }
 EXPORT_SYMBOL_GPL(nf_tables_bind_set);
 
-void nf_tables_rebind_set(const struct nft_ctx *ctx, struct nft_set *set,
-			  struct nft_set_binding *binding)
-{
-	if (list_empty(&set->bindings) && nft_set_is_anonymous(set) &&
-	    nft_is_active(ctx->net, set))
-		list_add_tail_rcu(&set->list, &ctx->table->sets);
-
-	list_add_tail_rcu(&binding->list, &set->bindings);
-}
-EXPORT_SYMBOL_GPL(nf_tables_rebind_set);
-
 void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set,
-		          struct nft_set_binding *binding)
+			  struct nft_set_binding *binding, bool event)
 {
 	list_del_rcu(&binding->list);
 
-	if (list_empty(&set->bindings) && nft_set_is_anonymous(set) &&
-	    nft_is_active(ctx->net, set))
+	if (list_empty(&set->bindings) && nft_set_is_anonymous(set)) {
 		list_del_rcu(&set->list);
+		if (event)
+			nf_tables_set_notify(ctx, set, NFT_MSG_DELSET,
+					     GFP_KERNEL);
+	}
 }
 EXPORT_SYMBOL_GPL(nf_tables_unbind_set);
 
 void nf_tables_destroy_set(const struct nft_ctx *ctx, struct nft_set *set)
 {
-	if (list_empty(&set->bindings) && nft_set_is_anonymous(set) &&
-	    nft_is_active(ctx->net, set)) {
-		nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, GFP_ATOMIC);
+	if (list_empty(&set->bindings) && nft_set_is_anonymous(set))
 		nft_set_destroy(set);
-	}
 }
 EXPORT_SYMBOL_GPL(nf_tables_destroy_set);
 
@@ -6462,6 +6452,9 @@
 			nf_tables_rule_notify(&trans->ctx,
 					      nft_trans_rule(trans),
 					      NFT_MSG_DELRULE);
+			nft_rule_expr_deactivate(&trans->ctx,
+						 nft_trans_rule(trans),
+						 NFT_TRANS_COMMIT);
 			break;
 		case NFT_MSG_NEWSET:
 			nft_clear(net, nft_trans_set(trans));
@@ -6549,7 +6542,8 @@
 		nf_tables_rule_destroy(&trans->ctx, nft_trans_rule(trans));
 		break;
 	case NFT_MSG_NEWSET:
-		nft_set_destroy(nft_trans_set(trans));
+		if (!nft_trans_set_bound(trans))
+			nft_set_destroy(nft_trans_set(trans));
 		break;
 	case NFT_MSG_NEWSETELEM:
 		nft_set_elem_destroy(nft_trans_elem_set(trans),
@@ -6610,7 +6604,9 @@
 		case NFT_MSG_NEWRULE:
 			trans->ctx.chain->use--;
 			list_del_rcu(&nft_trans_rule(trans)->list);
-			nft_rule_expr_deactivate(&trans->ctx, nft_trans_rule(trans));
+			nft_rule_expr_deactivate(&trans->ctx,
+						 nft_trans_rule(trans),
+						 NFT_TRANS_ABORT);
 			break;
 		case NFT_MSG_DELRULE:
 			trans->ctx.chain->use++;
@@ -6620,7 +6616,8 @@
 			break;
 		case NFT_MSG_NEWSET:
 			trans->ctx.table->use--;
-			list_del_rcu(&nft_trans_set(trans)->list);
+			if (!nft_trans_set_bound(trans))
+				list_del_rcu(&nft_trans_set(trans)->list);
 			break;
 		case NFT_MSG_DELSET:
 			trans->ctx.table->use++;
diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c
index 432139b..12a95ee 100644
--- a/net/netfilter/nft_compat.c
+++ b/net/netfilter/nft_compat.c
@@ -569,10 +569,14 @@
 }
 
 static void nft_compat_deactivate(const struct nft_ctx *ctx,
-				  const struct nft_expr *expr)
+				  const struct nft_expr *expr,
+				  enum nft_trans_phase phase)
 {
 	struct nft_xt *xt = container_of(expr->ops, struct nft_xt, ops);
 
+	if (phase == NFT_TRANS_COMMIT)
+		return;
+
 	if (--xt->listcnt == 0)
 		list_del_init(&xt->head);
 }
diff --git a/net/netfilter/nft_dynset.c b/net/netfilter/nft_dynset.c
index 07d4efd..f1172f9 100644
--- a/net/netfilter/nft_dynset.c
+++ b/net/netfilter/nft_dynset.c
@@ -235,20 +235,17 @@
 	return err;
 }
 
-static void nft_dynset_activate(const struct nft_ctx *ctx,
-				const struct nft_expr *expr)
-{
-	struct nft_dynset *priv = nft_expr_priv(expr);
-
-	nf_tables_rebind_set(ctx, priv->set, &priv->binding);
-}
-
 static void nft_dynset_deactivate(const struct nft_ctx *ctx,
-				  const struct nft_expr *expr)
+				  const struct nft_expr *expr,
+				  enum nft_trans_phase phase)
 {
 	struct nft_dynset *priv = nft_expr_priv(expr);
 
-	nf_tables_unbind_set(ctx, priv->set, &priv->binding);
+	if (phase == NFT_TRANS_PREPARE)
+		return;
+
+	nf_tables_unbind_set(ctx, priv->set, &priv->binding,
+			     phase == NFT_TRANS_COMMIT);
 }
 
 static void nft_dynset_destroy(const struct nft_ctx *ctx,
@@ -296,7 +293,6 @@
 	.eval		= nft_dynset_eval,
 	.init		= nft_dynset_init,
 	.destroy	= nft_dynset_destroy,
-	.activate	= nft_dynset_activate,
 	.deactivate	= nft_dynset_deactivate,
 	.dump		= nft_dynset_dump,
 };
diff --git a/net/netfilter/nft_immediate.c b/net/netfilter/nft_immediate.c
index 0777a93..3f6d1d2 100644
--- a/net/netfilter/nft_immediate.c
+++ b/net/netfilter/nft_immediate.c
@@ -72,10 +72,14 @@
 }
 
 static void nft_immediate_deactivate(const struct nft_ctx *ctx,
-				     const struct nft_expr *expr)
+				     const struct nft_expr *expr,
+				     enum nft_trans_phase phase)
 {
 	const struct nft_immediate_expr *priv = nft_expr_priv(expr);
 
+	if (phase == NFT_TRANS_COMMIT)
+		return;
+
 	return nft_data_release(&priv->data, nft_dreg_to_type(priv->dreg));
 }
 
diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c
index 227b2b1..14496da 100644
--- a/net/netfilter/nft_lookup.c
+++ b/net/netfilter/nft_lookup.c
@@ -121,20 +121,17 @@
 	return 0;
 }
 
-static void nft_lookup_activate(const struct nft_ctx *ctx,
-				const struct nft_expr *expr)
-{
-	struct nft_lookup *priv = nft_expr_priv(expr);
-
-	nf_tables_rebind_set(ctx, priv->set, &priv->binding);
-}
-
 static void nft_lookup_deactivate(const struct nft_ctx *ctx,
-				  const struct nft_expr *expr)
+				  const struct nft_expr *expr,
+				  enum nft_trans_phase phase)
 {
 	struct nft_lookup *priv = nft_expr_priv(expr);
 
-	nf_tables_unbind_set(ctx, priv->set, &priv->binding);
+	if (phase == NFT_TRANS_PREPARE)
+		return;
+
+	nf_tables_unbind_set(ctx, priv->set, &priv->binding,
+			     phase == NFT_TRANS_COMMIT);
 }
 
 static void nft_lookup_destroy(const struct nft_ctx *ctx,
@@ -225,7 +222,6 @@
 	.size		= NFT_EXPR_SIZE(sizeof(struct nft_lookup)),
 	.eval		= nft_lookup_eval,
 	.init		= nft_lookup_init,
-	.activate	= nft_lookup_activate,
 	.deactivate	= nft_lookup_deactivate,
 	.destroy	= nft_lookup_destroy,
 	.dump		= nft_lookup_dump,
diff --git a/net/netfilter/nft_objref.c b/net/netfilter/nft_objref.c
index a3185ca..ae178e9 100644
--- a/net/netfilter/nft_objref.c
+++ b/net/netfilter/nft_objref.c
@@ -155,20 +155,17 @@
 	return -1;
 }
 
-static void nft_objref_map_activate(const struct nft_ctx *ctx,
-				    const struct nft_expr *expr)
-{
-	struct nft_objref_map *priv = nft_expr_priv(expr);
-
-	nf_tables_rebind_set(ctx, priv->set, &priv->binding);
-}
-
 static void nft_objref_map_deactivate(const struct nft_ctx *ctx,
-				      const struct nft_expr *expr)
+				      const struct nft_expr *expr,
+				      enum nft_trans_phase phase)
 {
 	struct nft_objref_map *priv = nft_expr_priv(expr);
 
-	nf_tables_unbind_set(ctx, priv->set, &priv->binding);
+	if (phase == NFT_TRANS_PREPARE)
+		return;
+
+	nf_tables_unbind_set(ctx, priv->set, &priv->binding,
+			     phase == NFT_TRANS_COMMIT);
 }
 
 static void nft_objref_map_destroy(const struct nft_ctx *ctx,
@@ -185,7 +182,6 @@
 	.size		= NFT_EXPR_SIZE(sizeof(struct nft_objref_map)),
 	.eval		= nft_objref_map_eval,
 	.init		= nft_objref_map_init,
-	.activate	= nft_objref_map_activate,
 	.deactivate	= nft_objref_map_deactivate,
 	.destroy	= nft_objref_map_destroy,
 	.dump		= nft_objref_map_dump,