syscall_filter: Implement flag set inclusion.
When filtering syscalls that take flags as an argument, we usually want
to allow a small set of "safe" flags. This is hard to express with the
current policy language.
Implement this by adding a "flag set inclusion" mode using the 'in'
keyword. This works by allowing the syscall as long as the passed
flags, when viewed as a set, are included in the set of flags described
by the policy.
Also, clang-format all of bpf.c.
Bug: 31997910
Test: syscall_filter_unittest
Change-Id: I121af56b176bd3260904d367fd92d47a16bb3dcb
diff --git a/bpf.c b/bpf.c
index ce754a7..b618866 100644
--- a/bpf.c
+++ b/bpf.c
@@ -15,9 +15,9 @@
size_t bpf_validate_arch(struct sock_filter *filter)
{
struct sock_filter *curr_block = filter;
- set_bpf_stmt(curr_block++, BPF_LD+BPF_W+BPF_ABS, arch_nr);
- set_bpf_jump(curr_block++,
- BPF_JMP+BPF_JEQ+BPF_K, ARCH_NR, SKIP, NEXT);
+ set_bpf_stmt(curr_block++, BPF_LD + BPF_W + BPF_ABS, arch_nr);
+ set_bpf_jump(curr_block++, BPF_JMP + BPF_JEQ + BPF_K, ARCH_NR, SKIP,
+ NEXT);
set_bpf_ret_kill(curr_block++);
return curr_block - filter;
}
@@ -26,16 +26,16 @@
size_t bpf_allow_syscall(struct sock_filter *filter, int nr)
{
struct sock_filter *curr_block = filter;
- set_bpf_jump(curr_block++, BPF_JMP+BPF_JEQ+BPF_K, nr, NEXT, SKIP);
- set_bpf_stmt(curr_block++, BPF_RET+BPF_K, SECCOMP_RET_ALLOW);
+ set_bpf_jump(curr_block++, BPF_JMP + BPF_JEQ + BPF_K, nr, NEXT, SKIP);
+ set_bpf_stmt(curr_block++, BPF_RET + BPF_K, SECCOMP_RET_ALLOW);
return curr_block - filter;
}
-size_t bpf_allow_syscall_args(struct sock_filter *filter,
- int nr, unsigned int id)
+size_t bpf_allow_syscall_args(struct sock_filter *filter, int nr,
+ unsigned int id)
{
struct sock_filter *curr_block = filter;
- set_bpf_jump(curr_block++, BPF_JMP+BPF_JEQ+BPF_K, nr, NEXT, SKIP);
+ set_bpf_jump(curr_block++, BPF_JMP + BPF_JEQ + BPF_K, nr, NEXT, SKIP);
set_bpf_jump_lbl(curr_block++, id);
return curr_block - filter;
}
@@ -44,16 +44,16 @@
#if defined(BITS32)
size_t bpf_load_arg(struct sock_filter *filter, int argidx)
{
- set_bpf_stmt(filter, BPF_LD+BPF_W+BPF_ABS, LO_ARG(argidx));
+ set_bpf_stmt(filter, BPF_LD + BPF_W + BPF_ABS, LO_ARG(argidx));
return 1U;
}
#elif defined(BITS64)
size_t bpf_load_arg(struct sock_filter *filter, int argidx)
{
struct sock_filter *curr_block = filter;
- set_bpf_stmt(curr_block++, BPF_LD+BPF_W+BPF_ABS, LO_ARG(argidx));
+ set_bpf_stmt(curr_block++, BPF_LD + BPF_W + BPF_ABS, LO_ARG(argidx));
set_bpf_stmt(curr_block++, BPF_ST, 0); /* lo -> M[0] */
- set_bpf_stmt(curr_block++, BPF_LD+BPF_W+BPF_ABS, HI_ARG(argidx));
+ set_bpf_stmt(curr_block++, BPF_LD + BPF_W + BPF_ABS, HI_ARG(argidx));
set_bpf_stmt(curr_block++, BPF_ST, 1); /* hi -> M[1] */
return curr_block - filter;
}
@@ -61,10 +61,10 @@
/* Size-aware equality comparison. */
size_t bpf_comp_jeq32(struct sock_filter *filter, unsigned long c,
- unsigned char jt, unsigned char jf)
+ unsigned char jt, unsigned char jf)
{
unsigned int lo = (unsigned int)(c & 0xFFFFFFFF);
- set_bpf_jump(filter, BPF_JMP+BPF_JEQ+BPF_K, lo, jt, jf);
+ set_bpf_jump(filter, BPF_JMP + BPF_JEQ + BPF_K, lo, jt, jf);
return 1U;
}
@@ -73,8 +73,8 @@
* We jump true when *both* comparisons are true.
*/
#if defined(BITS64)
-size_t bpf_comp_jeq64(struct sock_filter *filter, uint64_t c,
- unsigned char jt, unsigned char jf)
+size_t bpf_comp_jeq64(struct sock_filter *filter, uint64_t c, unsigned char jt,
+ unsigned char jf)
{
unsigned int lo = (unsigned int)(c & 0xFFFFFFFF);
unsigned int hi = (unsigned int)(c >> 32);
@@ -83,7 +83,7 @@
/* bpf_load_arg leaves |hi| in A */
curr_block += bpf_comp_jeq32(curr_block, hi, NEXT, SKIPN(2) + jf);
- set_bpf_stmt(curr_block++, BPF_LD+BPF_MEM, 0); /* swap in |lo| */
+ set_bpf_stmt(curr_block++, BPF_LD + BPF_MEM, 0); /* swap in |lo| */
curr_block += bpf_comp_jeq32(curr_block, lo, jt, jf);
return curr_block - filter;
@@ -92,10 +92,10 @@
/* Size-aware bitwise AND. */
size_t bpf_comp_jset32(struct sock_filter *filter, unsigned long mask,
- unsigned char jt, unsigned char jf)
+ unsigned char jt, unsigned char jf)
{
unsigned int mask_lo = (unsigned int)(mask & 0xFFFFFFFF);
- set_bpf_jump(filter, BPF_JMP+BPF_JSET+BPF_K, mask_lo, jt, jf);
+ set_bpf_jump(filter, BPF_JMP + BPF_JSET + BPF_K, mask_lo, jt, jf);
return 1U;
}
@@ -105,7 +105,7 @@
*/
#if defined(BITS64)
size_t bpf_comp_jset64(struct sock_filter *filter, uint64_t mask,
- unsigned char jt, unsigned char jf)
+ unsigned char jt, unsigned char jf)
{
unsigned int mask_lo = (unsigned int)(mask & 0xFFFFFFFF);
unsigned int mask_hi = (unsigned int)(mask >> 32);
@@ -114,20 +114,32 @@
/* bpf_load_arg leaves |hi| in A */
curr_block += bpf_comp_jset32(curr_block, mask_hi, SKIPN(2) + jt, NEXT);
- set_bpf_stmt(curr_block++, BPF_LD+BPF_MEM, 0); /* swap in |lo| */
+ set_bpf_stmt(curr_block++, BPF_LD + BPF_MEM, 0); /* swap in |lo| */
curr_block += bpf_comp_jset32(curr_block, mask_lo, jt, jf);
return curr_block - filter;
}
#endif
-size_t bpf_arg_comp(struct sock_filter **pfilter,
- int op, int argidx, unsigned long c, unsigned int label_id)
+size_t bpf_comp_jin(struct sock_filter *filter, unsigned long mask,
+ unsigned char jt, unsigned char jf)
{
- struct sock_filter *filter = calloc(BPF_ARG_COMP_LEN + 1,
- sizeof(struct sock_filter));
+ unsigned long negative_mask = ~mask;
+ /*
+ * The mask is negated, so the comparison will be true when the argument
+ * includes a flag that wasn't listed in the original (non-negated)
+ * mask. This would be the failure case, so we switch |jt| and |jf|.
+ */
+ return bpf_comp_jset(filter, negative_mask, jf, jt);
+}
+
+size_t bpf_arg_comp(struct sock_filter **pfilter, int op, int argidx,
+ unsigned long c, unsigned int label_id)
+{
+ struct sock_filter *filter =
+ calloc(BPF_ARG_COMP_LEN + 1, sizeof(struct sock_filter));
struct sock_filter *curr_block = filter;
- size_t (*comp_function)(struct sock_filter *filter, unsigned long k,
+ size_t (*comp_function)(struct sock_filter * filter, unsigned long k,
unsigned char jt, unsigned char jf);
int flip = 0;
@@ -148,6 +160,10 @@
comp_function = bpf_comp_jset;
flip = 0;
break;
+ case IN:
+ comp_function = bpf_comp_jin;
+ flip = 0;
+ break;
default:
*pfilter = NULL;
return 0;
@@ -279,7 +295,7 @@
end = begin + labels->count;
for (; begin < end; ++begin) {
if (begin->label)
- free((void*)(begin->label));
+ free((void *)(begin->label));
}
labels->count = 0;