Add new targeted error injection tool

bpf_override_return is a very powerful mechanism for error injection,
with the caveat that it requires whitelisting of the functions to be
overridden.

inject.py will take a call chain and optional set of predicates as
input, and inject the appropriate error when both the call chain and all
predicates are satisfied.

Signed-off-by: Howard McLauchlan <hmclauchlan@fb.com>
diff --git a/tools/inject.py b/tools/inject.py
new file mode 100755
index 0000000..7ca8896
--- /dev/null
+++ b/tools/inject.py
@@ -0,0 +1,452 @@
+#!/usr/bin/env python3
+#
+# This script generates a BPF program with structure inspired by trace.py. The
+# generated program operates on PID-indexed stacks. Generally speaking,
+# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
+# the goal of "fail iff this call chain and these predicates".
+#
+# Top level functions(the ones at the end of the call chain) are responsible for
+# creating the pid_struct and deleting it from the map in kprobe and kretprobe
+# respectively.
+#
+# Intermediate functions(between should_fail_whatever and the top level
+# functions) are responsible for updating the stack to indicate "I have been
+# called and one of my predicate(s) passed" in their entry probes. In their exit
+# probes, they do the opposite, popping their stack to maintain correctness.
+# This implementation aims to ensure correctness in edge cases like recursive
+# calls, so there's some additional information stored in pid_struct for that.
+#
+# At the bottom level function(should_fail_whatever), we do a simple check to
+# ensure all necessary calls/predicates have passed before error injection.
+#
+# Note: presently there are a few hacks to get around various rewriter/verifier
+# issues.
+#
+# Note: this tool requires(as of v4.16-rc5):
+# - commit f7174d08a5fc ("mm: make should_failslab always available for fault
+# injection")
+# - CONFIG_BPF_KPROBE_OVERRIDE
+#
+# USAGE: inject [-h] [-I header] [-v] mode spec
+#
+# Copyright (c) 2018 Facebook, Inc.
+# Licensed under the Apache License, Version 2.0 (the "License")
+#
+# 16-Mar-2018   Howard McLauchlan   Created this.
+
+import argparse
+from bcc import BPF
+
+
+class Probe:
+    """One kprobe or kretprobe handler of the generated BPF program."""
+
+    errno_mapping = {
+        "kmalloc": "-ENOMEM",
+        "bio": "-EIO",
+    }
+
+    @classmethod
+    def configure(cls, mode):
+        cls.mode = mode
+
+    def __init__(self, func, preds, length, entry):
+        # length of call chain
+        self.length = length
+        # C-style signature, e.g. "should_fail_bio(struct bio *bio)"
+        self.func = func
+        # list of (predicate, place in call chain) tuples, ascending by place
+        self.preds = preds
+        # True emits the kprobe handler, False the kretprobe handler
+        self.is_entry = entry
+
+    def _bail(self, err):
+        # a Probe has no spec attribute, so identify it by its signature
+        raise ValueError("error in probe '%s': %s" % (self.func, err))
+
+    def _get_err(self):
+        return Probe.errno_mapping[Probe.mode]
+
+    def _get_if_top(self):
+        # preds ascend by place, so a top level function has place 0 up front
+        if self.preds[0][1] != 0:
+            return ""
+
+        # init the map
+        # dont do an early exit here so the singular case works automatically
+        enter = """
+        /*
+         * Top level function init map
+         */
+        struct pid_struct p_struct = {0, 0};
+        m.insert(&pid, &p_struct);
+        """
+
+        # kill the entry
+        leave = """
+        /*
+         * Top level function clean up map
+         */
+        m.delete(&pid);
+        """
+
+        return enter if self.is_entry else leave
+
+    def _get_heading(self):
+        # we need to insert identifier and ctx into self.func
+        # gonna make a lot of formatting assumptions to make this work
+        left = self.func.find("(")
+        right = self.func.rfind(")")
+
+        # self.event and self.func_name are read again later by attach()
+        self.event = self.func[0:left]
+        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
+        func_sig = "struct pt_regs *ctx"
+
+        # assume theres something in there, no guarantee its well formed
+        if right > left + 1 and self.is_entry:
+            func_sig += ", " + self.func[left + 1:right]
+
+        return "int %s(%s)" % (self.func_name, func_sig)
+
+    def _get_entry_logic(self):
+        # there is at least one tup(pred, place) for this function
+        text = """
+
+        if (p->conds_met >= %s)
+                return 0;
+        if (p->conds_met == %s && %s) {
+                p->stack[%s] = p->curr_call;
+                p->conds_met++;
+        }"""
+        text = text % (self.length, self.preds[0][1], self.preds[0][0],
+                self.preds[0][1])
+
+        # for each additional pred
+        for tup in self.preds[1:]:
+            text += """
+        else if (p->conds_met == %s && %s) {
+                p->stack[%s] = p->curr_call;
+                p->conds_met++;
+        }
+            """ % (tup[1], tup[0], tup[1])
+        return text
+
+    def _generate_entry(self):
+        prog = self._get_heading() + """
+{
+        u32 pid = bpf_get_current_pid_tgid();
+        %s
+
+        struct pid_struct *p = m.lookup(&pid);
+
+        if (!p)
+                return 0;
+
+        /*
+         * preparation for predicate, if necessary
+         */
+         %s
+        /*
+         * Generate entry logic
+         */
+        %s
+
+        p->curr_call++;
+
+        return 0;
+}"""
+
+        prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
+        return prog
+
+    # only need to check top of stack
+    def _get_exit_logic(self):
+        text = """
+        if (p->conds_met < 1 || p->conds_met >= %s)
+                return 0;
+
+        if (p->stack[p->conds_met - 1] == p->curr_call)
+                p->conds_met--;
+        """
+        return text % str(self.length + 1)
+
+    def _generate_exit(self):
+        prog = self._get_heading() + """
+{
+        u32 pid = bpf_get_current_pid_tgid();
+
+        struct pid_struct *p = m.lookup(&pid);
+
+        if (!p)
+                return 0;
+
+        p->curr_call--;
+
+        /*
+         * Generate exit logic
+         */
+        %s
+        %s
+        return 0;
+}"""
+
+        prog = prog % (self._get_exit_logic(), self._get_if_top())
+
+        return prog
+
+    # Special case for should_fail_whatever
+    def _generate_bottom(self):
+        pred = self.preds[0][0]
+        text = self._get_heading() + """
+{
+        /*
+         * preparation for predicate, if necessary
+         */
+         %s
+        /*
+         * If this is the only call in the chain and predicate passes
+         */
+        if (%s == 1 && %s) {
+                bpf_override_return(ctx, %s);
+                return 0;
+        }
+        u32 pid = bpf_get_current_pid_tgid();
+
+        struct pid_struct *p = m.lookup(&pid);
+
+        if (!p)
+                return 0;
+
+        /*
+         * If all conds have been met and predicate passes
+         */
+        if (p->conds_met == %s && %s)
+                bpf_override_return(ctx, %s);
+        return 0;
+}""" % (self.prep, self.length, pred, self._get_err(), self.length - 1, pred,
+            self._get_err())
+        return text
+
+    # presently parses and replaces STRCMP
+    # STRCMP exists because string comparison is inconvenient and somewhat buggy
+    # https://github.com/iovisor/bcc/issues/1617
+    def _prepare_pred(self):
+        self.prep = ""
+        for i, (pred, place) in enumerate(self.preds):
+            new_pred = ""
+            start = 0
+            while start < len(pred):
+                ind = pred.find("STRCMP(", start)
+                if ind == -1:
+                    break
+                new_pred += pred[start:ind]
+                # 7 is len("STRCMP("); search from ind, not from start, or an
+                # earlier ")" in the predicate gets mistaken for the closing one
+                start = pred.find(")", ind + 7) + 1
+
+                # then ind ... start is STRCMP(...)
+                ptr, literal = pred[ind + 7:start - 1].split(",")
+                literal = literal.strip()
+
+                # x->y->z, some string literal
+                # we make unique id with place_ind
+                uuid = "%s_%s" % (place, ind)
+                unique_bool = "is_true_%s" % uuid
+                self.prep += """
+        char *str_%s = %s;
+        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)
+
+                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)
+
+                for ch in literal:
+                    self.prep += check % ch
+                self.prep += check % r'\0'
+                new_pred += unique_bool
+
+            new_pred += pred[start:]
+            self.preds[i] = (new_pred, place)
+
+    def generate_program(self):
+        """Return the C source of this probe's handler function."""
+        # generate code to work around various rewriter issues
+        self._prepare_pred()
+
+        # special case for bottom
+        if self.preds[-1][1] == self.length - 1:
+            return self._generate_bottom()
+
+        return self._generate_entry() if self.is_entry else self._generate_exit()
+
+    def attach(self, bpf):
+        """Attach the handler to bpf; generate_program() must have run first."""
+        attach = bpf.attach_kprobe if self.is_entry else bpf.attach_kretprobe
+        attach(event=self.event, fn_name=self.func_name)
+
+
+class Tool:
+    """Parse the command line, generate the BPF program and attach it."""
+
+    # add cases as necessary
+    error_injection_mapping = {
+        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
+        "bio": "should_fail_bio(struct bio *bio)",
+    }
+
+    def __init__(self):
+        parser = argparse.ArgumentParser(description="Fail specified kernel" +
+                " functionality when call chain and predicates are met",
+                formatter_class=argparse.RawDescriptionHelpFormatter)
+        parser.add_argument(metavar="mode", dest="mode",
+                choices=sorted(Tool.error_injection_mapping),
+                help="indicate which base kernel function to fail")
+        parser.add_argument(metavar="spec", dest="spec",
+                help="specify call chain")
+        parser.add_argument("-I", "--include", action="append",
+                metavar="header",
+                help="additional header files to include in the BPF program")
+        parser.add_argument("-v", "--verbose", action="store_true",
+            help="print BPF program")
+        self.args = parser.parse_args()
+
+        self.program = ""
+        self.spec = self.args.spec
+        self.map = {}
+        self.probes = []
+        self.key = Tool.error_injection_mapping[self.args.mode]
+
+    # create_probes and associated stuff
+    def _create_probes(self):
+        self._parse_spec()
+        Probe.configure(self.args.mode)
+        # self, func, preds, total, entry
+
+        # create all the pair probes
+        for fx, preds in self.map.items():
+
+            # do the enter
+            self.probes.append(Probe(fx, preds, self.length, True))
+
+            if self.key == fx:
+                continue
+
+            # do the exit
+            self.probes.append(Probe(fx, preds, self.length, False))
+
+    def _parse_frames(self):
+        # sentinel
+        data = self.spec + '\0'
+        start, count = 0, 0
+
+        frames = []
+        cur_frame = []
+        i = 0
+
+        while i < len(data):
+            # improper input
+            if count < 0:
+                raise Exception("Check your parentheses")
+            c = data[i]
+            count += c == '('
+            count -= c == ')'
+            if not count:
+                if c == '\0' or (c == '<' and data[i + 1] == '-'):
+                    # closing a frame, so something must have been collected
+                    if not cur_frame:
+                        raise Exception("Cannot parse spec, missing parens")
+                    if len(cur_frame) == 2:
+                        frame = tuple(cur_frame)
+                    elif cur_frame[0][0] == '(':
+                        frame = self.key, cur_frame[0]
+                    else:
+                        frame = cur_frame[0], '(true)'
+                    frames.append(frame)
+                    del cur_frame[:]
+                    i += 1
+                    start = i + 1
+                elif c == ')':
+                    cur_frame.append(data[start:i + 1].strip())
+                    start = i + 1
+            i += 1
+        # improper input
+        if count:
+            raise Exception("Check your parentheses")
+        return frames
+
+    def _parse_spec(self):
+        """Build self.map (func -> [(pred, place)]) and self.length from spec."""
+        frames = self._parse_frames()
+        frames.reverse()
+
+        absolute_order = 0
+        for f in frames:
+            # default case
+            func, pred = f[0], f[1]
+
+            if not self._validate_predicate(pred):
+                raise Exception("Invalid predicate: %s" % pred)
+            tup = (pred, absolute_order)
+
+            self.map.setdefault(func, []).append(tup)
+
+            absolute_order += 1
+
+        if self.key not in self.map:
+            self.map[self.key] = [('(true)', absolute_order)]
+            absolute_order += 1
+
+        self.length = absolute_order
+
+    def _validate_predicate(self, pred):
+        """Return False if a parenthesised pred has unbalanced parens."""
+        if not pred.startswith("("):
+            return True
+        return pred.count("(") == pred.count(")")
+
+    def _def_pid_struct(self):
+        text = """
+struct pid_struct {
+    u64 curr_call; /* book keeping to handle recursion */
+    u64 conds_met; /* stack pointer */
+    u64 stack[%s];
+};
+""" % self.length
+        return text
+
+    def _attach_probes(self):
+        self.bpf = BPF(text=self.program)
+        for p in self.probes:
+            p.attach(self.bpf)
+
+    def _generate_program(self):
+        # leave out auto includes for now
+
+        for include in (self.args.include or []):
+            self.program += "#include <%s>\n" % include
+
+        self.program += self._def_pid_struct()
+        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
+        for p in self.probes:
+            self.program += p.generate_program() + "\n"
+
+        if self.args.verbose:
+            print(self.program)
+
+    def _main_loop(self):
+        # block here until Ctrl-C so the tool exits without a traceback
+        try:
+            while True:
+                self.bpf.perf_buffer_poll()
+        except KeyboardInterrupt:
+            return
+
+    def run(self):
+        """Create the probes, build and load the program, then block."""
+        self._create_probes()
+        self._generate_program()
+        self._attach_probes()
+        self._main_loop()
+
+
+if __name__ == "__main__":  # only when executed as a script, not on import
+    Tool().run()  # parse argv, generate and attach the BPF program, then poll