#!/usr/bin/env python3
2#
3# This script generates a BPF program with structure inspired by trace.py. The
4# generated program operates on PID-indexed stacks. Generally speaking,
5# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
6# the goal of "fail iff this call chain and these predicates".
7#
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
11#
12# Intermediate functions(between should_fail_whatever and the top level
13# functions) are responsible for updating the stack to indicate "I have been
14# called and one of my predicate(s) passed" in their entry probes. In their exit
15# probes, they do the opposite, popping their stack to maintain correctness.
16# This implementation aims to ensure correctness in edge cases like recursive
17# calls, so there's some additional information stored in pid_struct for that.
18#
19# At the bottom level function(should_fail_whatever), we do a simple check to
20# ensure all necessary calls/predicates have passed before error injection.
21#
22# Note: presently there are a few hacks to get around various rewriter/verifier
23# issues.
24#
25# Note: this tool requires(as of v4.16-rc5):
26# - commit f7174d08a5fc ("mm: make should_failslab always available for fault
27# injection")
28# - CONFIG_BPF_KPROBE_OVERRIDE
29#
# USAGE: inject [-h] [-I header] [-v] mode spec
31#
32# Copyright (c) 2018 Facebook, Inc.
33# Licensed under the Apache License, Version 2.0 (the "License")
34#
35# 16-Mar-2018 Howard McLauchlan Created this.
36
37import argparse
38from bcc import BPF
39
40
class Probe:
    """Generate and attach the BPF C handler for one kprobe or kretprobe.

    A Probe is built from a function signature string such as
    ``"foo(int a)"``, the list of ``(predicate, order)`` tuples attached to
    that function, the total call-chain length and a flag selecting the entry
    (kprobe) or exit (kretprobe) handler.

    ``generate_program()`` must be called before ``attach()``: it is what
    sets ``self.event`` and ``self.func_name``.
    """

    # Value handed to bpf_override_return() for each injection mode.
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode):
        """Store the injection mode (a key of ``errno_mapping``) on the class."""
        cls.mode = mode

    def __init__(self, func, preds, length, entry):
        """Record the probe description.

        func   -- function signature text, e.g. ``"foo(int a)"``
        preds  -- list of ``(predicate, order)`` tuples for this function
        length -- length of the whole call chain
        entry  -- True for the kprobe handler, False for the kretprobe one
        """
        # length of call chain
        self.length = length
        self.func = func
        # _prepare_pred() rewrites entries in place; work on a private copy
        # so the caller's list (shared with the sibling entry/exit probe) is
        # left untouched.
        self.preds = list(preds)
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError naming the function this probe was built for."""
        raise ValueError("error in probe '%s': %s" %
                         (self.func, err))

    def _get_err(self):
        """Return the errno expression for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_if_top(self):
        """Return the map setup/teardown C snippet for a top level function.

        Returns "" when the first predicate tuple does not have order 0.
        """
        # tuples are appended in increasing order, so only the first one can
        # carry order 0
        if self.preds[0][1] != 0:
            return ""

        if self.is_entry:
            # init the map
            # dont do an early exit here so the singular case works
            # automatically
            return """
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
        """

        # kill the entry
        return """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
        """

    def _get_heading(self):
        """Return the C function header and set ``event``/``func_name``.

        Raises ValueError when ``self.func`` contains no "(".
        """
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")
        if left == -1:
            # without this check the slice below would silently drop the
            # last character of the name
            self._bail("expected a signature of the form name(args)")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume theres something in there, no guarantee its well formed
        # (only entry handlers receive the probed function's arguments)
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return the C if/else-if chain that pushes onto the pid stack."""
        # there is at least one tup(pred, place) for this function
        first_pred, first_place = self.preds[0]
        text = """

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }"""
        text = text % (self.length, first_place, first_pred, first_place)

        # for each additional pred
        for pred, place in self.preds[1:]:
            text += """
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
        """ % (place, pred, place)
        return text

    def _generate_entry(self):
        """Return the full C source of an intermediate/top kprobe handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""

        prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
        return prog

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return the C snippet that pops the pid stack on function return."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the full C source of a kretprobe handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""

        prog = prog % (self._get_exit_logic(), self._get_if_top())

        return prog

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C handler that actually overrides the return value."""
        pred = self.preds[0][0]
        text = self._get_heading() + """
{
        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s) {
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s)
                bpf_override_return(ctx, %s);
        return 0;
}""" % (self.prep, self.length, pred, self._get_err(),
        self.length - 1, pred, self._get_err())
        return text

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Expand every ``STRCMP(ptr, literal)`` in the predicates.

        Each occurrence is replaced in the predicate by a uniquely named
        boolean, and the C statements computing that boolean one character
        at a time are accumulated in ``self.prep``.

        Raises ValueError when a STRCMP has no closing parenthesis or does
        not have exactly two comma separated arguments.
        """
        marker = "STRCMP("
        self.prep = ""
        for i, (pred, place) in enumerate(self.preds):
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find(marker, start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]

                # Look for the closing paren after *this* STRCMP(; searching
                # from the previous scan position could pick up an unrelated
                # ")" that sits before the marker.
                args_start = ind + len(marker)
                close = pred.find(")", args_start)
                if close == -1:
                    self._bail("unterminated STRCMP in predicate '%s'" % pred)
                start = close + 1

                # then ind ... start is STRCMP(...)
                parts = pred[args_start:close].split(",")
                if len(parts) != 2:
                    self._bail("STRCMP takes a pointer and a literal, got "
                               "'%s'" % pred[ind:start])
                ptr, literal = parts
                literal = literal.strip()

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % ch
                # the string must also end where the literal ends
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            self.preds[i] = (new_pred, place)

    def generate_program(self):
        """Return the C source for this probe's handler."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        return self._generate_entry() if self.is_entry else self._generate_exit()

    def attach(self, bpf):
        """Attach the generated handler to ``bpf`` as a kprobe or kretprobe.

        Relies on ``event`` and ``func_name`` set by ``generate_program()``.
        """
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)
285
286
class Tool:
    """Command line driver: parse the spec, build probes, load and attach.

    The constructor reads ``sys.argv`` through argparse. ``run()`` then
    parses the call-chain spec, generates the BPF C program, attaches every
    probe and blocks in a poll loop.
    """

    # add cases as necessary
    # Maps each mode to the signature of the kernel function that is probed
    # to inject the failure.
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        """Parse the command line and initialise empty program state."""
        parser = argparse.ArgumentParser(description="Fail specified kernel" +
                " functionality when call chain and predicates are met",
                formatter_class=argparse.RawDescriptionHelpFormatter)
        # Restrict mode to the supported keys so a typo produces a usage
        # error instead of a KeyError traceback below.
        parser.add_argument(metavar="mode", dest="mode",
                choices=sorted(Tool.error_injection_mapping),
                help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                metavar="header",
                help="additional header files to include in the BPF program")
        parser.add_argument("-v", "--verbose", action="store_true",
                help="print BPF program")
        self.args = parser.parse_args()

        self.program = ""
        self.spec = self.args.spec
        # function signature -> list of (predicate, order) tuples
        self.map = {}
        self.probes = []
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Build an entry probe per function, plus an exit probe for all
        functions except the error-injection one."""
        self._parse_spec()
        Probe.configure(self.args.mode)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split ``self.spec`` into ``(function, predicate)`` tuples.

        Frames are separated by "=>" at parenthesis depth zero. A frame made
        of a single parenthesised group is a predicate on the injection
        function; a frame with only a function gets the predicate "(true)".

        Raises ValueError on unbalanced parentheses, on an empty frame and
        on a frame with more than two parenthesised groups.
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0

        while i < len(data):
            c = data[i]
            if c == '(':
                count += 1
            elif c == ')':
                count -= 1
            # improper input
            if count < 0:
                raise ValueError("Check your parentheses")
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    if not cur_frame:
                        raise ValueError("empty frame in spec: expected "
                                         "'function(args) (predicate)'")
                    if len(cur_frame) > 2:
                        # previously the extra groups were silently dropped
                        # and the predicate replaced by (true)
                        raise ValueError("too many groups in frame %r: wrap "
                                         "the whole predicate in one pair of "
                                         "parentheses" % " ".join(cur_frame))
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    # skip the '>' of the separator
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
            i += 1
        # improper input
        if count:
            raise ValueError("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill ``self.map`` and ``self.length`` from the parsed frames.

        Frames are numbered from the last one in the spec (order 0) to the
        first. If the injection function is not named in the spec it is
        added last with the predicate "(true)".

        Raises ValueError when a predicate has unbalanced parentheses.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for func, pred in frames:
            if not self._validate_predicate(pred):
                raise ValueError("malformed predicate: %s" % pred)
            tup = (pred, absolute_order)

            self.map.setdefault(func, []).append(tup)

            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_predicate(self, pred):
        """Return False if a predicate starting with "(" is unbalanced.

        Predicates that are empty or do not start with "(" are accepted
        without inspection.
        """
        if len(pred) > 0 and pred[0] == "(":
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth < 0:
                        # a close with nothing open, e.g. "())("
                        return False
            if depth != 0:
                # not well formed, break
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of the per-pid bookkeeping struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile ``self.program`` with BPF and attach every probe."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Assemble the full BPF C program into ``self.program``."""
        # leave out auto includes for now

        for include in (self.args.include or []):
            self.program += "#include <%s>\n" % include

        self.program += self._def_pid_struct()
        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
        for p in self.probes:
            self.program += p.generate_program() + "\n"

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Poll until interrupted; Ctrl-C ends the loop without a traceback."""
        while True:
            try:
                self.bpf.perf_buffer_poll()
            except KeyboardInterrupt:
                return

    def run(self):
        """Create probes, generate the program, attach and poll."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()
449
450
# Script entry point: build the tool from the command line, then run it.
if __name__ == "__main__":
    tool = Tool()
    tool.run()