net: diag: allow socket bytecode filters to match socket marks

This allows a privileged process to filter by socket mark when
dumping sockets via SOCK_DIAG_BY_FAMILY. This is useful on
systems that use mark-based routing such as Android.

The ability to filter socket marks requires CAP_NET_ADMIN, which
is consistent with other privileged operations allowed by the
SOCK_DIAG interface such as the ability to destroy sockets and
the ability to inspect BPF filters attached to packet sockets.

Tested: https://android-review.googlesource.com/261350
Signed-off-by: Lorenzo Colitti <lorenzo@google.com>
Acked-by: David Ahern <dsa@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index a89b68c..abfbe49 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -45,6 +45,7 @@
 	u16 family;
 	u16 userlocks;
 	u32 ifindex;
+	u32 mark;
 };
 
 static DEFINE_MUTEX(inet_diag_table_mutex);
@@ -580,6 +581,14 @@
 				yes = 0;
 			break;
 		}
+		case INET_DIAG_BC_MARK_COND: {
+			struct inet_diag_markcond *cond;
+
+			cond = (struct inet_diag_markcond *)(op + 1);
+			if ((entry->mark & cond->mask) != cond->mark)
+				yes = 0;
+			break;
+		}
 		}
 
 		if (yes) {
@@ -624,6 +633,12 @@
 	entry.dport = ntohs(inet->inet_dport);
 	entry.ifindex = sk->sk_bound_dev_if;
 	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
+	if (sk_fullsock(sk))
+		entry.mark = sk->sk_mark;
+	else if (sk->sk_state == TCP_NEW_SYN_RECV)
+		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
+	else
+		entry.mark = 0;
 
 	return inet_diag_bc_run(bc, &entry);
 }
@@ -706,8 +721,17 @@
 	return true;
 }
 
-static int inet_diag_bc_audit(const struct nlattr *attr)
+static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
+			   int *min_len)
 {
+	*min_len += sizeof(struct inet_diag_markcond);
+	return len >= *min_len;
+}
+
+static int inet_diag_bc_audit(const struct nlattr *attr,
+			      const struct sk_buff *skb)
+{
+	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
 	const void *bytecode, *bc;
 	int bytecode_len, len;
 
@@ -738,6 +762,12 @@
 			if (!valid_port_comparison(bc, len, &min_len))
 				return -EINVAL;
 			break;
+		case INET_DIAG_BC_MARK_COND:
+			if (!net_admin)
+				return -EPERM;
+			if (!valid_markcond(bc, len, &min_len))
+				return -EINVAL;
+			break;
 		case INET_DIAG_BC_AUTO:
 		case INET_DIAG_BC_JMP:
 		case INET_DIAG_BC_NOP:
@@ -1030,7 +1060,7 @@
 
 			attr = nlmsg_find_attr(nlh, hdrlen,
 					       INET_DIAG_REQ_BYTECODE);
-			err = inet_diag_bc_audit(attr);
+			err = inet_diag_bc_audit(attr, skb);
 			if (err)
 				return err;
 		}
@@ -1061,7 +1091,7 @@
 
 			attr = nlmsg_find_attr(h, hdrlen,
 					       INET_DIAG_REQ_BYTECODE);
-			err = inet_diag_bc_audit(attr);
+			err = inet_diag_bc_audit(attr, skb);
 			if (err)
 				return err;
 		}