blob: 2f7d2f1deb2572f84afc386555d9fafa7061dfd8 [file] [log] [blame]
#!/usr/bin/env python3
#
# This script generates a BPF program with structure inspired by trace.py. The
# generated program operates on PID-indexed stacks. Generally speaking,
# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
# the goal of "fail if and only if this call chain and these predicates hold".
#
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
#
# Intermediate functions (between should_fail_whatever and the top level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
#
# At the bottom level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
#
# Note: presently there are a few hacks to get around various rewriter/verifier
# issues.
#
# Note: this tool requires:
# - CONFIG_BPF_KPROBE_OVERRIDE
#
# USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
#
# Copyright (c) 2018 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 16-Mar-2018 Howard McLauchlan Created this.
34
35import argparse
Howard McLauchlanfb3c0a72018-04-13 14:00:15 -070036import re
Howard McLauchlanef4154b2018-03-16 16:50:26 -070037from bcc import BPF
38
39
class Probe:
    """A single kprobe or kretprobe of the generated fault-injection program.

    ``preds`` holds one ``(predicate, place)`` tuple per occurrence of the
    function in the call chain. ``place`` is the position in the chain: 0 is
    the top-level function and ``length - 1`` is the should_fail_* function.
    """

    # Error code injected through bpf_override_return, keyed by mode.
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode, probability):
        """Set the mode and failure probability shared by every probe."""
        cls.mode = mode
        cls.probability = probability

    def __init__(self, func, preds, length, entry):
        """Create a probe.

        Args:
            func: function signature as written in the spec, e.g.
                "btrfs_mount()" or "should_fail_bio(struct bio *bio)".
            preds: list of (predicate, place) tuples for this function,
                in increasing place order.
            length: total length of the call chain.
            entry: True for the kprobe, False for the kretprobe.
        """
        # length of call chain
        self.length = length
        self.func = func
        # _prepare_pred() rewrites predicates in place; copy so the caller's
        # list (shared by this function's entry and exit probes) is untouched.
        self.preds = list(preds)
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError naming this probe's function."""
        raise ValueError("error in probe '%s': %s" % (self.func, err))

    def _get_err(self):
        """Return the errno expression injected for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_skip_pred(self):
        """Return a C condition that is true when this call must NOT fail.

        With probability 1 the condition is the constant "false"; otherwise a
        random u32 is compared against the probability scaled by 2**32.
        """
        if Probe.probability == 1:
            return "false"
        return "bpf_get_prandom_u32() > %s" % str(
            int((1 << 32) * Probe.probability))

    def _get_if_top(self):
        """Return map init (entry) or map cleanup (exit) code for the top.

        Returns "" when this probe's function is not the top of the chain.
        """
        # preds are in increasing place order, so this function is the top of
        # the chain iff its FIRST tuple has place 0
        if self.preds[0][1] != 0:
            return ""

        if not self.is_entry:
            # kill the entry
            return """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
"""

        # init the map
        # dont do an early exit here so the singular case works automatically
        # have an early exit for probability option
        return """
        /*
         * Early exit for probability case
         */
        if (%s)
                return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
""" % self._get_skip_pred()

    def _get_heading(self):
        """Return the C function header and record event/func_name.

        The BPF function is named "<event>_entry" or "<event>_exit"; only the
        entry probe receives the traced function's parameters after ctx.
        """
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible (attach uses them)
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume theres something in there, no guarantee its well formed
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return C code that pushes this call when its next condition holds."""
        # Predicates are parenthesized so operators such as || inside them
        # cannot bind looser than the surrounding &&.
        branch = """if (p->conds_met == %s && (%s)) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }"""
        # there is at least one tup(pred, place) for this function
        pred, place = self.preds[0]
        text = """

        if (p->conds_met >= %s)
                return 0;
        """ % self.length
        text += branch % (place, pred, place)

        # for each additional pred
        for pred, place in self.preds[1:]:
            text += "\n        else " + branch % (place, pred, place) + "\n"
        return text

    def _generate_entry(self):
        """Return the kprobe body for an intermediate or top-level function."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
        %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""

        prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
        return prog

    def _get_exit_logic(self):
        """Return C code that pops this call if it is on top of the stack."""
        # Only the top of the stack needs checking. This is a guarded block
        # rather than an early return so that the top-level map cleanup
        # emitted after it always runs (otherwise a failed top-level
        # predicate leaked the pid's map entry).
        text = """
        if (p->conds_met >= 1 && p->conds_met < %s) {
                if (p->stack[p->conds_met - 1] == p->curr_call)
                        p->conds_met--;
        }
"""
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the kretprobe body for an intermediate or top function."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""

        prog = prog % (self._get_exit_logic(), self._get_if_top())

        return prog

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the probe on the should_fail_* function that injects errno."""
        # the bottom's own tuple is the one with place length - 1 (the last)
        pred = self.preds[-1][0]
        text = self._get_heading() + """
{
        /*
         * preparation for predicate, if necessary
         */
        %s
        /*
         * If this is the only call in the chain, the predicate passes and
         * the probability check does not skip this call
         */
        if (%s == 1 && (%s) && !(%s)) {
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && (%s))
                bpf_override_return(ctx, %s);
        return 0;
}"""
        return text % (self.prep, self.length, pred, self._get_skip_pred(),
                       self._get_err(), self.length - 1, pred,
                       self._get_err())

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Rewrite STRCMP(ptr, literal) calls into precomputed booleans.

        For each STRCMP, C code comparing the string at ``ptr`` with
        ``literal`` character by character (including the terminating NUL)
        is appended to ``self.prep``, and the call in the predicate is
        replaced by the name of the resulting bool.

        Raises:
            ValueError: if a STRCMP is unterminated or lacks two arguments.
        """
        self.prep = ""
        for i, (pred, place) in enumerate(self.preds):
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find("STRCMP(", start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]
                args_start = ind + len("STRCMP(")
                # search from this STRCMP's arguments, not from the end of the
                # previous match, so an earlier ')' is never taken as ours
                close = pred.find(")", args_start)
                if close == -1:
                    self._bail("unterminated STRCMP in '%s'" % pred)
                start = close + 1

                # then ind ... start is STRCMP(...)
                args = pred[args_start:close].split(",", 1)
                if len(args) != 2:
                    self._bail("STRCMP expects 2 arguments in '%s'" % pred)
                ptr, literal = args
                literal = literal.strip()

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    # escape characters that would terminate the C char literal
                    if ch in "\\'":
                        ch = "\\" + ch
                    self.prep += check % ch
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            self.preds[i] = (new_pred, place)

    def generate_program(self):
        """Return the C source of this probe."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        return self._generate_entry() if self.is_entry else self._generate_exit()

    def attach(self, bpf):
        """Attach this probe; generate_program() must have been called first."""
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)
295
296
class Tool:
    """Command-line front end: parse the call-chain spec, then generate,
    attach and keep alive the fault-injection BPF program."""

    examples = """
EXAMPLES:
# ./inject.py kmalloc -v 'SyS_mount()'
    Fails all calls to syscall mount
# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    Explicit rewriting of above
# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    Fails btrfs mounts only
# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    qstr *name)(STRCMP(name->name, 'bananas'))'
    Fails dentry allocations of files named 'bananas'
# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    Fails calls to syscall mount with 1% probability
    """
    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        """Parse the command line.

        Exits via argparse (SystemExit) on invalid arguments, including a
        probability outside [0, 1].
        """
        parser = argparse.ArgumentParser(description="Fail specified kernel" +
                " functionality when call chain and predicates are met",
                formatter_class=argparse.RawDescriptionHelpFormatter,
                epilog=Tool.examples)
        # choices derive from the mapping so adding a mode needs one edit
        parser.add_argument(dest="mode",
                choices=list(Tool.error_injection_mapping),
                help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                metavar="header",
                help="additional header files to include in the BPF program")
        parser.add_argument("-P", "--probability", default=1,
                metavar="probability", type=float,
                help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                help="print BPF program")
        self.args = parser.parse_args()

        # values outside [0, 1] (or NaN) give a meaningless threshold for the
        # bpf_get_prandom_u32() comparison in the generated program
        if not 0 <= self.args.probability <= 1:
            parser.error("probability must be between 0 and 1")

        self.program = ""
        self.spec = self.args.spec
        self.map = {}
        self.probes = []
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Build an entry probe per function, plus an exit probe for all but
        the should_fail_* function."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split self.spec into (function, predicate) frames, in spec order.

        Frames are separated by "=>" and look like "func(args)",
        "func(args)(pred)" or "(pred)". A bare "(pred)" applies to the
        should_fail_* function (self.key); a bare function gets "(true)".

        Raises:
            Exception: on unbalanced parentheses, empty frames, frames with
                more than two parenthesized groups, or stray characters
                between or after frames.
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0
        last_frame_added = 0

        while i < len(data):
            # improper input
            if count < 0:
                raise Exception("Check your parentheses")
            c = data[i]
            count += c == '('
            count -= c == ')'
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # This block is closing a chunk. This means cur_frame must
                    # have something in it.
                    if not cur_frame:
                        raise Exception("Cannot parse spec, missing parens")
                    # e.g. "f()(a) || (b)": the extra groups used to be
                    # silently replaced by "(true)"
                    if len(cur_frame) > 2:
                        raise Exception("Cannot parse spec, too many "
                                        "parenthesized groups in a frame")
                    # text between the last ')' and "=>" used to be ignored
                    if c == '=' and data[start:i].strip():
                        raise Exception("Invalid characters found between frames")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
                    last_frame_added = start
            i += 1

        # We only permit spaces after the last frame
        if self.spec[last_frame_added:].strip():
            raise Exception("Invalid characters found after last frame")
        # improper input
        if count:
            raise Exception("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Populate self.map and self.length from self.spec.

        Frames are numbered from the top-level function (place 0) down to the
        should_fail_* function, which is appended with "(true)" when the spec
        does not mention it.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for func, pred in frames:
            if not self._validate_predicate(pred):
                raise Exception("Invalid predicate")
            if not self._validate_identifier(func):
                raise Exception("Invalid function identifier")
            self.map.setdefault(func, []).append((pred, absolute_order))
            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_identifier(self, func):
        """Return True if the text before the first '(' is a C identifier."""
        # We've already established paren balancing. We will only look for
        # identifier validity here.
        paren_index = func.find("(")
        potential_id = func if paren_index == -1 else func[:paren_index]
        return re.match(r'[_a-zA-Z][_a-zA-Z0-9]*$', potential_id) is not None

    def _validate_predicate(self, pred):
        """Return True if pred's parentheses balance without going negative."""
        depth = 0
        for ch in pred:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _def_pid_struct(self):
        """Return the C definition of the per-pid bookkeeping struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile self.program and attach every probe."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Assemble the full C program into self.program (print if -v)."""
        # leave out auto includes for now
        self.program += '#include <linux/mm.h>\n'
        for include in (self.args.include or []):
            self.program += "#include <%s>\n" % include

        self.program += self._def_pid_struct()
        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
        for p in self.probes:
            self.program += p.generate_program() + "\n"

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Keep the process, and with it the attached probes, alive.

        Returns normally when the user presses Ctrl-C instead of dumping a
        KeyboardInterrupt traceback.
        """
        while True:
            try:
                self.bpf.perf_buffer_poll()
            except KeyboardInterrupt:
                return

    def run(self):
        """Parse, generate, attach, then block until interrupted."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()
500
501
if __name__ == "__main__":
    # Build the tool from the command line, then generate, attach and poll
    # until interrupted.
    tool = Tool()
    tool.run()