| #!/usr/bin/env python |
| # |
| # This script generates a BPF program with structure inspired by trace.py. The |
| # generated program operates on PID-indexed stacks. Generally speaking, |
| # bookkeeping is done at every intermediate function kprobe/kretprobe to enforce |
| # the goal of "fail iff this call chain and these predicates". |
| # |
| # Top level functions(the ones at the end of the call chain) are responsible for |
| # creating the pid_struct and deleting it from the map in kprobe and kretprobe |
| # respectively. |
| # |
| # Intermediate functions(between should_fail_whatever and the top level |
| # functions) are responsible for updating the stack to indicate "I have been |
| # called and one of my predicate(s) passed" in their entry probes. In their exit |
| # probes, they do the opposite, popping their stack to maintain correctness. |
| # This implementation aims to ensure correctness in edge cases like recursive |
| # calls, so there's some additional information stored in pid_struct for that. |
| # |
| # At the bottom level function(should_fail_whatever), we do a simple check to |
| # ensure all necessary calls/predicates have passed before error injection. |
| # |
| # Note: presently there are a few hacks to get around various rewriter/verifier |
| # issues. |
| # |
| # Note: this tool requires: |
| # - CONFIG_BPF_KPROBE_OVERRIDE |
| # |
| # USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec |
| # |
| # Copyright (c) 2018 Facebook, Inc. |
| # Licensed under the Apache License, Version 2.0 (the "License") |
| # |
| # 16-Mar-2018 Howard McLauchlan Created this. |
| |
| import argparse |
| import re |
| from bcc import BPF |
| |
| |
class Probe:
    """Generate and attach the BPF C handler for one function of the chain.

    A Probe is either an entry (kprobe) or an exit (kretprobe) handler for a
    single function.  Each function carries a list of (predicate, place)
    tuples, where place is the function's position in the call chain: 0 is
    the top level function and length - 1 is the error injection point.
    """

    # value passed to bpf_override_return(), keyed by injection mode
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode, probability):
        """Set the injection mode and failure probability for all probes."""
        cls.mode = mode
        cls.probability = probability

    def __init__(self, func, preds, length, entry):
        """Create a probe.

        func   -- function text, "name(args)", as written in the spec
        preds  -- list of (predicate, place) tuples for this function
        length -- length of the whole call chain
        entry  -- True for an entry probe, False for an exit probe
        """
        # length of call chain
        self.length = length
        self.func = func
        # private copy: _prepare_pred() rewrites the predicates, and the
        # caller may hand the same list to more than one probe
        self.preds = list(preds)
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError describing a problem with this probe."""
        raise ValueError("error in probe '%s': %s" % (self.func, err))

    def _get_err(self):
        """Return the errno expression for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_if_top(self):
        """Return the extra C code a top level function needs, else ""."""
        # relies on the tuples being listed in increasing place order, so a
        # top level function (place 0) has that tuple first
        if self.preds[0][1] != 0:
            return ""

        if not self.is_entry:
            # kill the entry
            return """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
        """

        if Probe.probability == 1:
            early_pred = "false"
        else:
            threshold = int((1 << 32) * Probe.probability)
            early_pred = "bpf_get_prandom_u32() > %s" % str(threshold)

        # init the map
        # don't do an early exit here so the singular case works automatically
        # have an early exit for probability option
        return """
        /*
         * Early exit for probability case
         */
        if (%s)
               return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
        """ % early_pred

    def _get_heading(self):
        """Return the C function signature; also sets event and func_name."""
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        suffix = "_entry" if self.is_entry else "_exit"
        self.func_name = self.event + suffix
        func_sig = "struct pt_regs *ctx"

        # assume there's something in there, no guarantee it's well formed
        # (arguments are only passed to entry probes)
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return the C code that pushes onto the stack when a pred passes."""
        # there is at least one tup(pred, place) for this function
        first_pred, first_place = self.preds[0]
        pieces = ["""

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }""" % (self.length, first_place, first_pred, first_place)]

        # for each additional pred
        for pred, place in self.preds[1:]:
            pieces.append("""
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
        """ % (place, pred, place))
        return "".join(pieces)

    def _generate_entry(self):
        """Return the C source of the entry handler."""
        body = """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""
        # only the template is formatted, so text taken from the spec is
        # never interpreted as a format string
        body = body % (self._get_if_top(), self.prep,
                       self._get_entry_logic())
        return self._get_heading() + body

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return the C code that pops the stack if this call pushed."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the C source of the exit handler."""
        body = """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""
        body = body % (self._get_exit_logic(), self._get_if_top())
        return self._get_heading() + body

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C source of the handler that overrides the return."""
        pred = self.preds[0][0]
        err = self._get_err()
        body = """
{
        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s) {
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s)
                bpf_override_return(ctx, %s);
        return 0;
}"""
        body = body % (self.prep, self.length, pred, err,
                       self.length - 1, pred, err)
        return self._get_heading() + body

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Replace every STRCMP(ptr, literal) by a precomputed C boolean.

        The declarations and per-character comparisons are collected in
        self.prep; self.preds is replaced by the rewritten predicates.  The
        literal is used verbatim, character by character.  Raises ValueError
        when a STRCMP has no closing parenthesis or no comma.
        """
        marker = "STRCMP("
        self.prep = ""
        rewritten = []
        for pred, place in self.preds:
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find(marker, start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]

                # The arguments run from just after this marker to the next
                # ")".  Search from the marker itself: parentheses earlier in
                # the predicate must not be taken for the end of STRCMP.
                args_start = ind + len(marker)
                close = pred.find(")", args_start)
                if close == -1:
                    self._bail("unterminated STRCMP in '%s'" % pred)

                # split on the first comma only, the literal may contain more
                ptr, comma, literal = pred[args_start:close].partition(",")
                if not comma:
                    self._bail("STRCMP needs two arguments in '%s'" % pred)
                literal = literal.strip()
                start = close + 1

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % ch
                # the string must end right after the literal
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            rewritten.append((new_pred, place))
        self.preds = rewritten

    def generate_program(self):
        """Return the C source of this probe's handler."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        if self.is_entry:
            return self._generate_entry()
        return self._generate_exit()

    def attach(self, bpf):
        """Attach the handler to its function as a kprobe or kretprobe."""
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                    fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                    fn_name=self.func_name)
| |
| |
class Tool:
    """Command line driver.

    Parses the arguments and the call chain spec, creates one entry probe
    (and, except for the error injection function, one exit probe) per
    function, generates the BPF program, attaches it and polls until
    interrupted.
    """

    examples = """
EXAMPLES:
# ./inject.py kmalloc -v 'SyS_mount()'
    Fails all calls to syscall mount
# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    Explicit rewriting of above
# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    Fails btrfs mounts only
# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    qstr *name)(STRCMP(name->name, 'bananas'))'
    Fails dentry allocations of files named 'bananas'
# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    Fails calls to syscall mount with 1% probability
    """
    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        """Parse the command line and initialise empty program state."""
        parser = argparse.ArgumentParser(description="Fail specified kernel" +
                " functionality when call chain and predicates are met",
                formatter_class=argparse.RawDescriptionHelpFormatter,
                epilog=Tool.examples)
        parser.add_argument(dest="mode", choices=['kmalloc', 'bio'],
                help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                metavar="header",
                help="additional header files to include in the BPF program")
        parser.add_argument("-P", "--probability", default=1,
                metavar="probability", type=float,
                help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                help="print BPF program")
        self.args = parser.parse_args()

        # the probability is scaled to a 32 bit threshold in the generated
        # program, which is only meaningful inside [0, 1]
        if not 0 <= self.args.probability <= 1:
            parser.error("probability must be between 0 and 1")

        self.program = ""
        self.spec = self.args.spec
        # function text -> list of (predicate, place) tuples
        self.map = {}
        self.probes = []
        # the error injection function for the chosen mode
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Parse the spec and create the probes for every function in it."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability)

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            # the error injection function gets no exit probe
            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split the spec into a list of (function, predicate) frames.

        Frames are separated by "=>".  A frame is "func(args)(pred)",
        "func(args)" (predicate defaults to "(true)") or "(pred)" (function
        defaults to the error injection function).  Raises Exception on
        unbalanced parentheses, on an empty frame, on a frame with more than
        two parenthesised groups and on text after the last frame.
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0
        last_frame_added = 0

        while i < len(data):
            # improper input
            if count < 0:
                raise Exception("Check your parentheses")
            c = data[i]
            count += c == '('
            count -= c == ')'
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # This block is closing a chunk. This means cur_frame must
                    # have something in it.
                    if not cur_frame:
                        raise Exception("Cannot parse spec, missing parens")
                    # anything past function and predicate would be dropped
                    # silently, so refuse it instead
                    if len(cur_frame) > 2:
                        raise Exception("Cannot parse spec, too many "
                                        "parenthesized groups in a frame")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    # step over the second character of "=>"
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
                    last_frame_added = start
            i += 1

        # We only permit spaces after the last frame
        if self.spec[last_frame_added:].strip():
            raise Exception("Invalid characters found after last frame")
        # improper input
        if count:
            raise Exception("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill self.map and self.length from the spec.

        Frames are numbered from the last one written (place 0) to the
        first.  When the error injection function is not named in the spec
        it is added with predicate "(true)" at the highest place.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for func, pred in frames:
            if not self._validate_predicate(pred):
                raise Exception("Invalid predicate")
            if not self._validate_identifier(func):
                raise Exception("Invalid function identifier")

            self.map.setdefault(func, []).append((pred, absolute_order))
            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_identifier(self, func):
        """Return True when the text before "(" is a valid C identifier."""
        # We've already established paren balancing. We will only look for
        # identifier validity here.
        potential_id = func.partition("(")[0]
        # the upper case range must be A-Z: A-z would also admit the
        # punctuation that sits between 'Z' and 'a' in ASCII
        pattern = r'[_a-zA-Z][_a-zA-Z0-9]*$'
        return re.match(pattern, potential_id) is not None

    def _validate_predicate(self, pred):
        """Return False when pred starts with "(" but is unbalanced."""
        if pred.startswith("("):
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
            if depth != 0:
                # not well formed
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of the per-pid bookkeeping struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile the generated program and attach every probe."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Append the complete BPF C source to self.program."""
        # leave out auto includes for now
        parts = ['#include <linux/mm.h>\n']
        for include in (self.args.include or []):
            parts.append("#include <%s>\n" % include)

        parts.append(self._def_pid_struct())
        parts.append("BPF_HASH(m, u32, struct pid_struct);\n")
        for p in self.probes:
            parts.append(p.generate_program() + "\n")
        self.program += "".join(parts)

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Keep the probes attached until the user interrupts the tool."""
        while True:
            try:
                self.bpf.perf_buffer_poll()
            except KeyboardInterrupt:
                # Ctrl-C is the normal way to stop, leave without a traceback
                return

    def run(self):
        """Create the probes, generate and attach the program, then poll."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()
| |
| |
if __name__ == "__main__":
    # Executed as a script: build the tool from the command line, then
    # generate, attach and poll.
    tool = Tool()
    tool.run()