[NETFILTER]: xt_conntrack: add port and direction matching

Extend the xt_conntrack match revision 1 by port matching (all four
{orig,repl}{src,dst}) and by packet direction matching.

Signed-off-by: Jan Engelhardt <jengelh@computergmbh.de>
Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/linux/netfilter/xt_conntrack.h b/include/linux/netfilter/xt_conntrack.h
index d2492a3..f3fd83e 100644
--- a/include/linux/netfilter/xt_conntrack.h
+++ b/include/linux/netfilter/xt_conntrack.h
@@ -6,9 +6,6 @@
 #define _XT_CONNTRACK_H
 
 #include <linux/netfilter/nf_conntrack_tuple_common.h>
-#ifdef __KERNEL__
-#	include <linux/in.h>
-#endif
 
 #define XT_CONNTRACK_STATE_BIT(ctinfo) (1 << ((ctinfo)%IP_CT_IS_REPLY+1))
 #define XT_CONNTRACK_STATE_INVALID (1 << 0)
@@ -18,14 +15,21 @@
 #define XT_CONNTRACK_STATE_UNTRACKED (1 << (IP_CT_NUMBER + 3))
 
 /* flags, invflags: */
-#define XT_CONNTRACK_STATE	0x01
-#define XT_CONNTRACK_PROTO	0x02
-#define XT_CONNTRACK_ORIGSRC	0x04
-#define XT_CONNTRACK_ORIGDST	0x08
-#define XT_CONNTRACK_REPLSRC	0x10
-#define XT_CONNTRACK_REPLDST	0x20
-#define XT_CONNTRACK_STATUS	0x40
-#define XT_CONNTRACK_EXPIRES	0x80
+enum {
+	XT_CONNTRACK_STATE        = 1 << 0,
+	XT_CONNTRACK_PROTO        = 1 << 1,
+	XT_CONNTRACK_ORIGSRC      = 1 << 2,
+	XT_CONNTRACK_ORIGDST      = 1 << 3,
+	XT_CONNTRACK_REPLSRC      = 1 << 4,
+	XT_CONNTRACK_REPLDST      = 1 << 5,
+	XT_CONNTRACK_STATUS       = 1 << 6,
+	XT_CONNTRACK_EXPIRES      = 1 << 7,
+	XT_CONNTRACK_ORIGSRC_PORT = 1 << 8,
+	XT_CONNTRACK_ORIGDST_PORT = 1 << 9,
+	XT_CONNTRACK_REPLSRC_PORT = 1 << 10,
+	XT_CONNTRACK_REPLDST_PORT = 1 << 11,
+	XT_CONNTRACK_DIRECTION    = 1 << 12,
+};
 
 /* This is exposed to userspace, so remains frozen in time. */
 struct ip_conntrack_old_tuple
@@ -70,8 +74,10 @@
 	union nf_inet_addr repldst_addr, repldst_mask;
 	u_int32_t expires_min, expires_max;
 	u_int16_t l4proto;
+	__be16 origsrc_port, origdst_port;
+	__be16 replsrc_port, repldst_port;
+	u_int16_t match_flags, invert_flags;
 	u_int8_t state_mask, status_mask;
-	u_int8_t match_flags, invert_flags;
 };
 
 #endif /*_XT_CONNTRACK_H*/
diff --git a/net/netfilter/xt_conntrack.c b/net/netfilter/xt_conntrack.c
index e92190e..8533085 100644
--- a/net/netfilter/xt_conntrack.c
+++ b/net/netfilter/xt_conntrack.c
@@ -4,7 +4,6 @@
  *
  *	(C) 2001  Marc Boucher (marc@mbsi.ca).
  *	Copyright © CC Computer Consultants GmbH, 2007 - 2008
- *	Jan Engelhardt <jengelh@computergmbh.de>
  *
  *	This program is free software; you can redistribute it and/or modify
  *	it under the terms of the GNU General Public License version 2 as
@@ -20,6 +19,7 @@
 
 MODULE_LICENSE("GPL");
 MODULE_AUTHOR("Marc Boucher <marc@mbsi.ca>");
+MODULE_AUTHOR("Jan Engelhardt <jengelh@computergmbh.de>");
 MODULE_DESCRIPTION("Xtables: connection tracking state match");
 MODULE_ALIAS("ipt_conntrack");
 MODULE_ALIAS("ip6t_conntrack");
@@ -166,6 +166,44 @@
 	       &info->repldst_addr, &info->repldst_mask, family);
 }
 
+static inline bool
+ct_proto_port_check(const struct xt_conntrack_mtinfo1 *info,
+                    const struct nf_conn *ct)
+{
+	const struct nf_conntrack_tuple *tuple;
+
+	tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
+	if ((info->match_flags & XT_CONNTRACK_PROTO) &&
+	    (tuple->dst.protonum == info->l4proto) ^
+	    !(info->invert_flags & XT_CONNTRACK_PROTO))
+		return false;
+
+	/* Shortcut to match all recognized protocols by using ->src.all. */
+	if ((info->match_flags & XT_CONNTRACK_ORIGSRC_PORT) &&
+	    (tuple->src.u.all == info->origsrc_port) ^
+	    !(info->invert_flags & XT_CONNTRACK_ORIGSRC_PORT))
+		return false;
+
+	if ((info->match_flags & XT_CONNTRACK_ORIGDST_PORT) &&
+	    (tuple->dst.u.all == info->origdst_port) ^
+	    !(info->invert_flags & XT_CONNTRACK_ORIGDST_PORT))
+		return false;
+
+	tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple;
+
+	if ((info->match_flags & XT_CONNTRACK_REPLSRC_PORT) &&
+	    (tuple->src.u.all == info->replsrc_port) ^
+	    !(info->invert_flags & XT_CONNTRACK_REPLSRC_PORT))
+		return false;
+
+	if ((info->match_flags & XT_CONNTRACK_REPLDST_PORT) &&
+	    (tuple->dst.u.all == info->repldst_port) ^
+	    !(info->invert_flags & XT_CONNTRACK_REPLDST_PORT))
+		return false;
+
+	return true;
+}
+
 static bool
 conntrack_mt(const struct sk_buff *skb, const struct net_device *in,
              const struct net_device *out, const struct xt_match *match,
@@ -200,10 +238,9 @@
 
 	if (ct == NULL)
 		return info->match_flags & XT_CONNTRACK_STATE;
-
-	if ((info->match_flags & XT_CONNTRACK_PROTO) &&
-	    ((ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum ==
-	    info->l4proto) ^ !(info->invert_flags & XT_CONNTRACK_PROTO)))
+	if ((info->match_flags & XT_CONNTRACK_DIRECTION) &&
+	    (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) ^
+	    !!(info->invert_flags & XT_CONNTRACK_DIRECTION))
 		return false;
 
 	if (info->match_flags & XT_CONNTRACK_ORIGSRC)
@@ -226,6 +263,9 @@
 		    !(info->invert_flags & XT_CONNTRACK_REPLDST))
 			return false;
 
+	if (!ct_proto_port_check(info, ct))
+		return false;
+
 	if ((info->match_flags & XT_CONNTRACK_STATUS) &&
 	    (!!(info->status_mask & ct->status) ^
 	    !(info->invert_flags & XT_CONNTRACK_STATUS)))