sched, cls: check if we could overwrite actions when changing a filter

When actions are attached to a filter, they are a part of the filter
itself, so when changing a filter we should allow the actions inside
it to be overwritten as well.

In my specific case, when I tried to _append_ a new action to an existing
filter which already had an action, I got EEXIST because the kernel
refused to overwrite the existing one.

This patch checks whether we are changing the filter by checking the
NLM_F_CREATE flag (sigh, filters don't use NLM_F_REPLACE...) and then
passes the resulting boolean down to the actions. This fixes the problem
above.

Cc: Jamal Hadi Salim <jhs@mojatatu.com>
Cc: David S. Miller <davem@davemloft.net>
Signed-off-by: Cong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: Cong Wang <cwang@twopensource.com>
Signed-off-by: Jamal Hadi Salim <jhs@mojatatu.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c
index 29a30a1..6786130 100644
--- a/net/sched/cls_api.c
+++ b/net/sched/cls_api.c
@@ -317,7 +317,8 @@
 		}
 	}
 
-	err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh);
+	err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh,
+			      n->nlmsg_flags & NLM_F_CREATE ? TCA_ACT_NOREPLACE : TCA_ACT_REPLACE);
 	if (err == 0) {
 		if (tp_created) {
 			spin_lock_bh(root_lock);
@@ -504,7 +505,7 @@
 EXPORT_SYMBOL(tcf_exts_destroy);
 
 int tcf_exts_validate(struct net *net, struct tcf_proto *tp, struct nlattr **tb,
-		  struct nlattr *rate_tlv, struct tcf_exts *exts)
+		  struct nlattr *rate_tlv, struct tcf_exts *exts, bool ovr)
 {
 #ifdef CONFIG_NET_CLS_ACT
 	{
@@ -513,7 +514,7 @@
 		INIT_LIST_HEAD(&exts->actions);
 		if (exts->police && tb[exts->police]) {
 			act = tcf_action_init_1(net, tb[exts->police], rate_tlv,
-						"police", TCA_ACT_NOREPLACE,
+						"police", ovr,
 						TCA_ACT_BIND);
 			if (IS_ERR(act))
 				return PTR_ERR(act);
@@ -523,7 +524,7 @@
 		} else if (exts->action && tb[exts->action]) {
 			int err;
 			err = tcf_action_init(net, tb[exts->action], rate_tlv,
-					      NULL, TCA_ACT_NOREPLACE,
+					      NULL, ovr,
 					      TCA_ACT_BIND, &exts->actions);
 			if (err)
 				return err;
diff --git a/net/sched/cls_basic.c b/net/sched/cls_basic.c
index e98ca99..0ae1813 100644
--- a/net/sched/cls_basic.c
+++ b/net/sched/cls_basic.c
@@ -130,14 +130,14 @@
 static int basic_set_parms(struct net *net, struct tcf_proto *tp,
 			   struct basic_filter *f, unsigned long base,
 			   struct nlattr **tb,
-			   struct nlattr *est)
+			   struct nlattr *est, bool ovr)
 {
 	int err;
 	struct tcf_exts e;
 	struct tcf_ematch_tree t;
 
 	tcf_exts_init(&e, TCA_BASIC_ACT, TCA_BASIC_POLICE);
-	err = tcf_exts_validate(net, tp, tb, est, &e);
+	err = tcf_exts_validate(net, tp, tb, est, &e, ovr);
 	if (err < 0)
 		return err;
 
@@ -161,7 +161,7 @@
 
 static int basic_change(struct net *net, struct sk_buff *in_skb,
 			struct tcf_proto *tp, unsigned long base, u32 handle,
-			struct nlattr **tca, unsigned long *arg)
+			struct nlattr **tca, unsigned long *arg, bool ovr)
 {
 	int err;
 	struct basic_head *head = tp->root;
@@ -179,7 +179,7 @@
 	if (f != NULL) {
 		if (handle && f->handle != handle)
 			return -EINVAL;
-		return basic_set_parms(net, tp, f, base, tb, tca[TCA_RATE]);
+		return basic_set_parms(net, tp, f, base, tb, tca[TCA_RATE], ovr);
 	}
 
 	err = -ENOBUFS;
@@ -206,7 +206,7 @@
 		f->handle = head->hgenerator;
 	}
 
-	err = basic_set_parms(net, tp, f, base, tb, tca[TCA_RATE]);
+	err = basic_set_parms(net, tp, f, base, tb, tca[TCA_RATE], ovr);
 	if (err < 0)
 		goto errout;
 
diff --git a/net/sched/cls_bpf.c b/net/sched/cls_bpf.c
index 8e3cf49..1618696 100644
--- a/net/sched/cls_bpf.c
+++ b/net/sched/cls_bpf.c
@@ -156,7 +156,7 @@
 static int cls_bpf_modify_existing(struct net *net, struct tcf_proto *tp,
 				   struct cls_bpf_prog *prog,
 				   unsigned long base, struct nlattr **tb,
-				   struct nlattr *est)
+				   struct nlattr *est, bool ovr)
 {
 	struct sock_filter *bpf_ops, *bpf_old;
 	struct tcf_exts exts;
@@ -170,7 +170,7 @@
 		return -EINVAL;
 
 	tcf_exts_init(&exts, TCA_BPF_ACT, TCA_BPF_POLICE);
-	ret = tcf_exts_validate(net, tp, tb, est, &exts);
+	ret = tcf_exts_validate(net, tp, tb, est, &exts, ovr);
 	if (ret < 0)
 		return ret;
 
@@ -242,7 +242,7 @@
 static int cls_bpf_change(struct net *net, struct sk_buff *in_skb,
 			  struct tcf_proto *tp, unsigned long base,
 			  u32 handle, struct nlattr **tca,
-			  unsigned long *arg)
+			  unsigned long *arg, bool ovr)
 {
 	struct cls_bpf_head *head = tp->root;
 	struct cls_bpf_prog *prog = (struct cls_bpf_prog *) *arg;
@@ -260,7 +260,7 @@
 		if (handle && prog->handle != handle)
 			return -EINVAL;
 		return cls_bpf_modify_existing(net, tp, prog, base, tb,
-					       tca[TCA_RATE]);
+					       tca[TCA_RATE], ovr);
 	}
 
 	prog = kzalloc(sizeof(*prog), GFP_KERNEL);
@@ -277,7 +277,7 @@
 		goto errout;
 	}
 
-	ret = cls_bpf_modify_existing(net, tp, prog, base, tb, tca[TCA_RATE]);
+	ret = cls_bpf_modify_existing(net, tp, prog, base, tb, tca[TCA_RATE], ovr);
 	if (ret < 0)
 		goto errout;
 
diff --git a/net/sched/cls_cgroup.c b/net/sched/cls_cgroup.c
index 8e2158a..cacf01b 100644
--- a/net/sched/cls_cgroup.c
+++ b/net/sched/cls_cgroup.c
@@ -83,7 +83,7 @@
 static int cls_cgroup_change(struct net *net, struct sk_buff *in_skb,
 			     struct tcf_proto *tp, unsigned long base,
 			     u32 handle, struct nlattr **tca,
-			     unsigned long *arg)
+			     unsigned long *arg, bool ovr)
 {
 	struct nlattr *tb[TCA_CGROUP_MAX + 1];
 	struct cls_cgroup_head *head = tp->root;
@@ -119,7 +119,7 @@
 		return err;
 
 	tcf_exts_init(&e, TCA_CGROUP_ACT, TCA_CGROUP_POLICE);
-	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e);
+	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr);
 	if (err < 0)
 		return err;
 
diff --git a/net/sched/cls_flow.c b/net/sched/cls_flow.c
index 257029c..35be16f 100644
--- a/net/sched/cls_flow.c
+++ b/net/sched/cls_flow.c
@@ -349,7 +349,7 @@
 static int flow_change(struct net *net, struct sk_buff *in_skb,
 		       struct tcf_proto *tp, unsigned long base,
 		       u32 handle, struct nlattr **tca,
-		       unsigned long *arg)
+		       unsigned long *arg, bool ovr)
 {
 	struct flow_head *head = tp->root;
 	struct flow_filter *f;
@@ -393,7 +393,7 @@
 	}
 
 	tcf_exts_init(&e, TCA_FLOW_ACT, TCA_FLOW_POLICE);
-	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e);
+	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr);
 	if (err < 0)
 		return err;
 
diff --git a/net/sched/cls_fw.c b/net/sched/cls_fw.c
index 63a3ce7..861b03c 100644
--- a/net/sched/cls_fw.c
+++ b/net/sched/cls_fw.c
@@ -169,7 +169,7 @@
 
 static int
 fw_change_attrs(struct net *net, struct tcf_proto *tp, struct fw_filter *f,
-	struct nlattr **tb, struct nlattr **tca, unsigned long base)
+	struct nlattr **tb, struct nlattr **tca, unsigned long base, bool ovr)
 {
 	struct fw_head *head = tp->root;
 	struct tcf_exts e;
@@ -177,7 +177,7 @@
 	int err;
 
 	tcf_exts_init(&e, TCA_FW_ACT, TCA_FW_POLICE);
-	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e);
+	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr);
 	if (err < 0)
 		return err;
 
@@ -218,7 +218,7 @@
 		     struct tcf_proto *tp, unsigned long base,
 		     u32 handle,
 		     struct nlattr **tca,
-		     unsigned long *arg)
+		     unsigned long *arg, bool ovr)
 {
 	struct fw_head *head = tp->root;
 	struct fw_filter *f = (struct fw_filter *) *arg;
@@ -236,7 +236,7 @@
 	if (f != NULL) {
 		if (f->id != handle && handle)
 			return -EINVAL;
-		return fw_change_attrs(net, tp, f, tb, tca, base);
+		return fw_change_attrs(net, tp, f, tb, tca, base, ovr);
 	}
 
 	if (!handle)
@@ -264,7 +264,7 @@
 	tcf_exts_init(&f->exts, TCA_FW_ACT, TCA_FW_POLICE);
 	f->id = handle;
 
-	err = fw_change_attrs(net, tp, f, tb, tca, base);
+	err = fw_change_attrs(net, tp, f, tb, tca, base, ovr);
 	if (err < 0)
 		goto errout;
 
diff --git a/net/sched/cls_route.c b/net/sched/cls_route.c
index 1ad3068..dd9fc25 100644
--- a/net/sched/cls_route.c
+++ b/net/sched/cls_route.c
@@ -333,7 +333,8 @@
 static int route4_set_parms(struct net *net, struct tcf_proto *tp,
 			    unsigned long base, struct route4_filter *f,
 			    u32 handle, struct route4_head *head,
-			    struct nlattr **tb, struct nlattr *est, int new)
+			    struct nlattr **tb, struct nlattr *est, int new,
+			    bool ovr)
 {
 	int err;
 	u32 id = 0, to = 0, nhandle = 0x8000;
@@ -343,7 +344,7 @@
 	struct tcf_exts e;
 
 	tcf_exts_init(&e, TCA_ROUTE4_ACT, TCA_ROUTE4_POLICE);
-	err = tcf_exts_validate(net, tp, tb, est, &e);
+	err = tcf_exts_validate(net, tp, tb, est, &e, ovr);
 	if (err < 0)
 		return err;
 
@@ -428,7 +429,7 @@
 		       struct tcf_proto *tp, unsigned long base,
 		       u32 handle,
 		       struct nlattr **tca,
-		       unsigned long *arg)
+		       unsigned long *arg, bool ovr)
 {
 	struct route4_head *head = tp->root;
 	struct route4_filter *f, *f1, **fp;
@@ -455,7 +456,7 @@
 			old_handle = f->handle;
 
 		err = route4_set_parms(net, tp, base, f, handle, head, tb,
-			tca[TCA_RATE], 0);
+			tca[TCA_RATE], 0, ovr);
 		if (err < 0)
 			return err;
 
@@ -479,7 +480,7 @@
 
 	tcf_exts_init(&f->exts, TCA_ROUTE4_ACT, TCA_ROUTE4_POLICE);
 	err = route4_set_parms(net, tp, base, f, handle, head, tb,
-		tca[TCA_RATE], 1);
+		tca[TCA_RATE], 1, ovr);
 	if (err < 0)
 		goto errout;
 
diff --git a/net/sched/cls_rsvp.h b/net/sched/cls_rsvp.h
index 19f8e5d..1020e23 100644
--- a/net/sched/cls_rsvp.h
+++ b/net/sched/cls_rsvp.h
@@ -415,7 +415,7 @@
 		       struct tcf_proto *tp, unsigned long base,
 		       u32 handle,
 		       struct nlattr **tca,
-		       unsigned long *arg)
+		       unsigned long *arg, bool ovr)
 {
 	struct rsvp_head *data = tp->root;
 	struct rsvp_filter *f, **fp;
@@ -436,7 +436,7 @@
 		return err;
 
 	tcf_exts_init(&e, TCA_RSVP_ACT, TCA_RSVP_POLICE);
-	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e);
+	err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr);
 	if (err < 0)
 		return err;
 
diff --git a/net/sched/cls_tcindex.c b/net/sched/cls_tcindex.c
index eed8404..d11d0a4 100644
--- a/net/sched/cls_tcindex.c
+++ b/net/sched/cls_tcindex.c
@@ -192,7 +192,7 @@
 tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
 		  u32 handle, struct tcindex_data *p,
 		  struct tcindex_filter_result *r, struct nlattr **tb,
-		 struct nlattr *est)
+		  struct nlattr *est, bool ovr)
 {
 	int err, balloc = 0;
 	struct tcindex_filter_result new_filter_result, *old_r = r;
@@ -202,7 +202,7 @@
 	struct tcf_exts e;
 
 	tcf_exts_init(&e, TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE);
-	err = tcf_exts_validate(net, tp, tb, est, &e);
+	err = tcf_exts_validate(net, tp, tb, est, &e, ovr);
 	if (err < 0)
 		return err;
 
@@ -331,7 +331,7 @@
 static int
 tcindex_change(struct net *net, struct sk_buff *in_skb,
 	       struct tcf_proto *tp, unsigned long base, u32 handle,
-	       struct nlattr **tca, unsigned long *arg)
+	       struct nlattr **tca, unsigned long *arg, bool ovr)
 {
 	struct nlattr *opt = tca[TCA_OPTIONS];
 	struct nlattr *tb[TCA_TCINDEX_MAX + 1];
@@ -351,7 +351,7 @@
 		return err;
 
 	return tcindex_set_parms(net, tp, base, handle, p, r, tb,
-				 tca[TCA_RATE]);
+				 tca[TCA_RATE], ovr);
 }
 
 
diff --git a/net/sched/cls_u32.c b/net/sched/cls_u32.c
index 84c28da..c39b583 100644
--- a/net/sched/cls_u32.c
+++ b/net/sched/cls_u32.c
@@ -486,13 +486,13 @@
 static int u32_set_parms(struct net *net, struct tcf_proto *tp,
 			 unsigned long base, struct tc_u_hnode *ht,
 			 struct tc_u_knode *n, struct nlattr **tb,
-			 struct nlattr *est)
+			 struct nlattr *est, bool ovr)
 {
 	int err;
 	struct tcf_exts e;
 
 	tcf_exts_init(&e, TCA_U32_ACT, TCA_U32_POLICE);
-	err = tcf_exts_validate(net, tp, tb, est, &e);
+	err = tcf_exts_validate(net, tp, tb, est, &e, ovr);
 	if (err < 0)
 		return err;
 
@@ -545,7 +545,7 @@
 static int u32_change(struct net *net, struct sk_buff *in_skb,
 		      struct tcf_proto *tp, unsigned long base, u32 handle,
 		      struct nlattr **tca,
-		      unsigned long *arg)
+		      unsigned long *arg, bool ovr)
 {
 	struct tc_u_common *tp_c = tp->data;
 	struct tc_u_hnode *ht;
@@ -569,7 +569,7 @@
 			return -EINVAL;
 
 		return u32_set_parms(net, tp, base, n->ht_up, n, tb,
-				     tca[TCA_RATE]);
+				     tca[TCA_RATE], ovr);
 	}
 
 	if (tb[TCA_U32_DIVISOR]) {
@@ -656,7 +656,7 @@
 	}
 #endif
 
-	err = u32_set_parms(net, tp, base, ht, n, tb, tca[TCA_RATE]);
+	err = u32_set_parms(net, tp, base, ht, n, tb, tca[TCA_RATE], ovr);
 	if (err == 0) {
 		struct tc_u_knode **ins;
 		for (ins = &ht->ht[TC_U32_HASH(handle)]; *ins; ins = &(*ins)->next)