Add support for checking flags in syscall arguments in Minijail.

Also, extract some code into functions as well, to make the code more readable.

BUG=chromium-os:36848
TEST=syscall_filter_unittest, security_Minijail_seccomp

Change-Id: Iedf8ecbf1814340fd8b3e4ec687b303c9c024d0a
Reviewed-on: https://gerrit.chromium.org/gerrit/39128
Tested-by: Jorge Lucangeli Obes <jorgelo@chromium.org>
Reviewed-by: Kees Cook <keescook@chromium.org>
Commit-Ready: Jorge Lucangeli Obes <jorgelo@chromium.org>
diff --git a/Makefile b/Makefile
index cd6b51c..c86a647 100644
--- a/Makefile
+++ b/Makefile
@@ -49,7 +49,7 @@
 libsyscalls.gen.o : libsyscalls.gen.c libsyscalls.h
 
 syscall_filter_unittest : syscall_filter_unittest.o syscall_filter.o \
-		bpf.o util.o libsyscalls.gen.o test_harness.h
+		bpf.o util.o libsyscalls.gen.o
 	$(CC) $(CFLAGS) -o $@ $^
 
 syscall_filter_unittest.o : syscall_filter_unittest.c test_harness.h
diff --git a/syscall_filter.c b/syscall_filter.c
index a1f65d6..5597da2 100644
--- a/syscall_filter.c
+++ b/syscall_filter.c
@@ -23,6 +23,8 @@
 		return EQ;
 	} else if (!strcmp(op_str, "!=")) {
 		return NE;
+	} else if (!strcmp(op_str, "&")) {
+		return SET;
 	} else {
 		return 0;
 	}
@@ -139,6 +141,85 @@
 	return get_label_id(labels, lbl_str);
 }
 
+int compile_atom(struct filter_block *head, char *atom,
+		struct bpf_labels *labels, int nr, int group_idx)
+{
+	/* Splits the atom. */
+	char *atom_ptr;
+	char *argidx_str = strtok_r(atom, " ", &atom_ptr);
+	char *operator_str = strtok_r(NULL, " ", &atom_ptr);
+	char *constant_str = strtok_r(NULL, " ", &atom_ptr);
+
+	if (argidx_str == NULL || operator_str == NULL || constant_str == NULL)
+		return -1;
+
+	int op = str_to_op(operator_str);
+	if (op < MIN_OPERATOR)
+		return -1;
+
+	if (strncmp(argidx_str, "arg", 3)) {
+		return -1;
+	}
+
+	char *argidx_ptr;
+	long int argidx = strtol(argidx_str + 3, &argidx_ptr, 10);
+	/*
+	 * Checks to see if an actual argument index
+	 * was parsed.
+	 */
+	if (argidx_ptr == argidx_str + 3)
+		return -1;
+
+	long int c = strtol(constant_str, NULL, 0);
+	/*
+	 * Looks up the label for the end of the AND statement
+	 * this atom belongs to.
+	 */
+	unsigned int id = group_end_lbl(labels, nr, group_idx);
+
+	/*
+	 * Builds a BPF comparison between a syscall argument
+	 * and a constant.
+	 * The comparison lives inside an AND statement.
+	 * If the comparison succeeds, we continue
+	 * to the next comparison.
+	 * If this comparison fails, the whole AND statement
+	 * will fail, so we jump to the end of this AND statement.
+	 */
+	struct sock_filter *comp_block;
+	size_t len = bpf_arg_comp(&comp_block, op, argidx, c, id);
+	if (len == 0)
+		return -1;
+
+	append_filter_block(head, comp_block, len);
+	return 0;
+}
+
+int compile_errno(struct filter_block *head, char *ret_errno) {
+	char *errno_ptr;
+
+	/* Splits the 'return' keyword and the actual errno value. */
+	char *ret_str = strtok_r(ret_errno, " ", &errno_ptr);
+	if (strncmp(ret_str, "return", strlen("return")))
+		return -1;
+
+	char *errno_val_str = strtok_r(NULL, " ", &errno_ptr);
+
+	if (errno_val_str) {
+		char *errno_val_ptr;
+		int errno_val = strtol(
+				errno_val_str, &errno_val_ptr, 0);
+		/* Checks to see if we parsed an actual errno. */
+		if (errno_val_ptr == errno_val_str)
+			return -1;
+
+		append_ret_errno(head, errno_val);
+	} else {
+		append_ret_kill(head);
+	}
+	return 0;
+}
+
 struct filter_block *compile_section(int nr, const char *policy_line,
 		unsigned int entry_lbl_id, struct bpf_labels *labels)
 {
@@ -153,8 +234,8 @@
 	 * Atoms are of the form "arg{DNUM} OP NUM"
 	 * where:
 	 *   - DNUM is a decimal number.
-	 *   - OP is a comparison operator (== or != for now).
-	 *   - NUM is a decimal or hexadecimal number.
+	 *   - OP is an operator: ==, !=, or & (flags set).
+	 *   - NUM is an octal, decimal, or hexadecimal number.
 	 *
 	 * When the syscall arguments make the expression true,
 	 * the syscall is allowed. If not, the process is killed.
@@ -181,8 +262,9 @@
 		return NULL;
 
 	/* Splits the optional "return <errno>" part. */
-	char *arg_filter = strtok(line, ";");
-	char *ret_errno = strtok(NULL, ";");
+	char *line_ptr;
+	char *arg_filter = strtok_r(line, ";", &line_ptr);
+	char *ret_errno = strtok_r(NULL, ";", &line_ptr);
 
 	/*
 	 * We build the argument filter as a collection of smaller
@@ -207,73 +289,15 @@
 	 * Splits the policy line by '||' into conjunctions and each conjunction
 	 * by '&&' into atoms.
 	 */
-	char *arg_filter_str;
-	char *arg_filter_ptr;
-	for (arg_filter_str = arg_filter; ; arg_filter_str = NULL) {
-		char *group = strtok_r(arg_filter_str, "||", &arg_filter_ptr);
-
-		if (group == NULL)
-			break;
-
-		char *group_str;
-		char *group_ptr;
-		for (group_str = group; ; group_str = NULL) {
-			char *comp = strtok_r(group_str, "&&", &group_ptr);
-
-			if (comp == NULL)
-				break;
-
-			/* Splits each atom. */
-			char *comp_ptr;
-			char *argidx_str = strtok_r(comp, " ", &comp_ptr);
-			char *operator_str = strtok_r(NULL, " ", &comp_ptr);
-			char *constant_str = strtok_r(NULL, " ", &comp_ptr);
-
-			if (argidx_str == NULL ||
-			    operator_str == NULL ||
-			    constant_str == NULL)
+	char *arg_filter_str = arg_filter;
+	char *group;
+	while ((group = tokenize(&arg_filter_str, "||")) != NULL) {
+		char *group_str = group;
+		char *comp;
+		while ((comp = tokenize(&group_str, "&&")) != NULL) {
+			/* Compiles each atom into a BPF block. */
+			if (compile_atom(head, comp, labels, nr, group_idx) < 0)
 				return NULL;
-
-			int op = str_to_op(operator_str);
-
-			if (op < MIN_OPERATOR)
-				return NULL;
-
-			if (strncmp(argidx_str, "arg", 3)) {
-				return NULL;
-			}
-
-			char *argidx_ptr;
-			long int argidx = strtol(
-					argidx_str + 3, &argidx_ptr, 10);
-			/*
-			 * Checks to see if an actual argument index
-			 * was parsed.
-			 */
-			if (argidx_ptr == argidx_str + 3) {
-				return NULL;
-			}
-
-			long int c = strtol(constant_str, NULL, 0);
-			unsigned int id = group_end_lbl(
-					labels, nr, group_idx);
-
-			/*
-			 * Builds a BPF comparison between a syscall argument
-			 * and a constant.
-			 * The comparison lives inside an AND statement.
-			 * If the comparison succeeds, we continue
-			 * to the next comparison.
-			 * If this comparison fails, the whole AND statement
-			 * will fail, so we jump to the end of this AND statement.
-			 */
-			struct sock_filter *comp_block;
-			len = bpf_arg_comp(&comp_block,
-					op, argidx, c, id);
-			if (len == 0)
-				return NULL;
-
-			append_filter_block(head, comp_block, len);
 		}
 		/*
 		 * If the AND statement succeeds, we're done,
@@ -298,26 +322,8 @@
 	 * otherwise just kill the task.
 	 */
 	if (ret_errno) {
-		char *errno_ptr;
-
-		char *ret_str = strtok_r(ret_errno, " ", &errno_ptr);
-		if (strncmp(ret_str, "return", strlen("return")))
+		if (compile_errno(head, ret_errno) < 0)
 			return NULL;
-
-		char *errno_val_str = strtok_r(NULL, " ", &errno_ptr);
-
-		if (errno_val_str) {
-			char *errno_val_ptr;
-			int errno_val = strtol(
-					errno_val_str, &errno_val_ptr, 0);
-			/* Checks to see if we parsed an actual errno. */
-			if (errno_val_ptr == errno_val_str)
-				return NULL;
-
-			append_ret_errno(head, errno_val);
-		} else {
-			append_ret_kill(head);
-		}
 	} else {
 		append_ret_kill(head);
 	}
diff --git a/syscall_filter_unittest.c b/syscall_filter_unittest.c
index 0068fca..a4911c8 100644
--- a/syscall_filter_unittest.c
+++ b/syscall_filter_unittest.c
@@ -289,6 +289,48 @@
 	free_label_strings(&self->labels);
 }
 
+TEST_F(arg_filter, arg0_mask) {
+	const char *fragment = "arg1 & 02";	/* O_RDWR */
+	int nr = 1;
+	unsigned int id = 0;
+	struct filter_block *block =
+		compile_section(nr, fragment, id, &self->labels);
+
+	ASSERT_NE(block, NULL);
+	size_t exp_total_len = 1 + (BPF_ARG_COMP_LEN + 1) + 2 + 1 + 2;
+	EXPECT_EQ(block->total_len, exp_total_len);
+
+	/* First block is a label. */
+	struct filter_block *curr_block = block;
+	ASSERT_NE(curr_block, NULL);
+	EXPECT_EQ(block->len, 1U);
+	EXPECT_LBL(curr_block->instrs);
+
+	/* Second block is a comparison. */
+	curr_block = block->next;
+	EXPECT_COMP(curr_block);
+
+	/* Third block is a jump and a label (end of AND group). */
+	curr_block = curr_block->next;
+	EXPECT_NE(curr_block, NULL);
+	EXPECT_GROUP_END(curr_block);
+
+	/* Fourth block is SECCOMP_RET_KILL */
+	curr_block = curr_block->next;
+	EXPECT_NE(curr_block, NULL);
+	EXPECT_KILL(curr_block);
+
+	/* Fifth block is "SUCCESS" label and SECCOMP_RET_ALLOW */
+	curr_block = curr_block->next;
+	EXPECT_NE(curr_block, NULL);
+	EXPECT_ALLOW(curr_block);
+
+	EXPECT_EQ(curr_block->next, NULL);
+
+	free_block_list(block);
+	free_label_strings(&self->labels);
+}
+
 TEST_F(arg_filter, and_or) {
 	const char *fragment = "arg0 == 0 && arg1 == 0 || arg0 == 1";
 	int nr = 1;
@@ -535,7 +577,8 @@
 	EXPECT_ALLOW_SYSCALL(actual.filter + index + 2, __NR_write);
 	EXPECT_ALLOW_SYSCALL(actual.filter + index + 4, __NR_rt_sigreturn);
 	EXPECT_ALLOW_SYSCALL(actual.filter + index + 6, __NR_exit);
-	EXPECT_EQ_STMT(actual.filter + index + 8, BPF_RET+BPF_K, SECCOMP_RET_TRAP);
+	EXPECT_EQ_STMT(actual.filter + index + 8, BPF_RET+BPF_K,
+			SECCOMP_RET_TRAP);
 
 	free(actual.filter);
 	fclose(policy);
diff --git a/util.c b/util.c
index d16f3bb..0923ef9 100644
--- a/util.c
+++ b/util.c
@@ -51,3 +51,52 @@
 	*(end + 1) = '\0';
 	return s;
 }
+
+char *tokenize(char **stringp, const char *delim) {
+	char *ret = NULL;
+
+	/* If the string is NULL or empty, there are no tokens to be found. */
+	if (stringp == NULL || *stringp == NULL || **stringp == '\0')
+		return NULL;
+
+	/*
+	 * If the delimiter is NULL or empty,
+	 * the full string makes up the only token.
+	 */
+	if (delim == NULL || *delim == '\0') {
+		ret = *stringp;
+		*stringp = NULL;
+		return ret;
+	}
+
+	char *found;
+	while (**stringp != '\0') {
+		found = strstr(*stringp, delim);
+
+		if (!found) {
+			/*
+			 * The delimiter was not found, so the full string
+			 * makes up the only token, and we're done.
+			 */
+			ret = *stringp;
+			*stringp = NULL;
+			break;
+		}
+
+		if (found != *stringp) {
+			/* There's a non-empty token before the delimiter. */
+			*found = '\0';
+			ret = *stringp;
+			*stringp = found + strlen(delim);
+			break;
+		}
+
+		/*
+		 * The delimiter was found at the start of the string,
+		 * skip it and keep looking for a non-empty token.
+		 */
+		*stringp += strlen(delim);
+	}
+
+	return ret;
+}
diff --git a/util.h b/util.h
index 09a67dc..2740ab3 100644
--- a/util.h
+++ b/util.h
@@ -32,5 +32,6 @@
 int lookup_syscall(const char *name);
 const char *lookup_syscall_name(int nr);
 char *strip(char *s);
+char *tokenize(char **stringp, const char *delim);
 
 #endif /* _UTIL_H_ */