blob: 46b3825808f6dbcb820451d6a01cc6152a09264a [file] [log] [blame]
#!/usr/bin/env python3
#
# This script generates a BPF program with structure inspired by trace.py. The
# generated program operates on PID-indexed stacks. Generally speaking,
# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
# the goal of "fail iff this call chain and these predicates".
#
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
#
# Intermediate functions (between should_fail_whatever and the top level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
#
# At the bottom level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
#
# Note: presently there are a few hacks to get around various rewriter/verifier
# issues.
#
# Note: this tool requires:
# - CONFIG_BPF_KPROBE_OVERRIDE
#
# USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
#
# Copyright (c) 2018 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 16-Mar-2018   Howard McLauchlan   Created this.
34
35import argparse
36from bcc import BPF
37
38
class Probe:
    """Generate and attach the BPF handler for one function of a call chain.

    A probe is built from a C-style function signature (``func``), a list of
    ``(predicate, position)`` tuples (``preds``), the total call-chain length
    and a flag saying whether it is an entry (kprobe) or exit (kretprobe)
    handler. ``generate_program()`` must be called before ``attach()``:
    building the handler heading is what sets ``self.event`` and
    ``self.func_name``.
    """

    # errno expression handed to bpf_override_return() for each mode
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode, probability):
        """Store the mode and probability shared by every Probe instance."""
        cls.mode = mode
        cls.probability = probability

    def __init__(self, func, preds, length, entry):
        # length of call chain
        self.length = length
        self.func = func
        self.preds = preds
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError describing a problem with this probe.

        The message names ``self.func``; a Probe has no ``spec`` attribute,
        so referencing one here would mask the real error with an
        AttributeError.
        """
        raise ValueError("error in probe '%s': %s" % (self.func, err))

    def _get_err(self):
        """Return the errno expression for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_if_top(self):
        """Return the map set-up/tear-down C snippet for a top-level probe.

        Returns "" unless this probe's first predicate sits at position 0.
        For an entry probe the snippet optionally exits early (probability
        option) and then inserts a zeroed pid_struct into map ``m``; for an
        exit probe it deletes that map entry.
        """
        if self.preds[0][1] != 0:
            return ""

        if Probe.probability == 1:
            early_pred = "false"
        else:
            threshold = int((1 << 32) * Probe.probability)
            early_pred = "bpf_get_prandom_u32() > %s" % str(threshold)

        # init the map
        # dont do an early exit here so the singular case works automatically
        # have an early exit for probability option
        enter_text = """
        /*
         * Early exit for probability case
         */
        if (%s)
                return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
        """ % early_pred

        # kill the entry
        exit_text = """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
        """

        return enter_text if self.is_entry else exit_text

    def _get_heading(self):
        """Return the C signature of the handler.

        Side effect: sets ``self.event`` (text of ``self.func`` before the
        first "(") and ``self.func_name`` (event plus "_entry"/"_exit"),
        which ``attach()`` reads.
        """
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume theres something in there, no guarantee its well formed
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return the C code that pushes onto the per-pid stack on entry."""
        # there is at least one tup(pred, place) for this function
        first_pred, first_place = self.preds[0][0], self.preds[0][1]
        text = """

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }""" % (self.length, first_place, first_pred, first_place)

        # for each additional pred
        for tup in self.preds[1:]:
            text += """
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
        """ % (tup[1], tup[0], tup[1])
        return text

    def _generate_entry(self):
        """Return the full C source of an entry (kprobe) handler."""
        body = """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""
        # Only the template is %-formatted; the heading is prepended after so
        # a stray '%' in the user-supplied signature cannot break formatting.
        body = body % (self._get_if_top(), self.prep, self._get_entry_logic())
        return self._get_heading() + body

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return the C code that pops the per-pid stack on exit."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the full C source of an exit (kretprobe) handler."""
        body = """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""
        body = body % (self._get_exit_logic(), self._get_if_top())
        return self._get_heading() + body

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C source of the handler that performs the override."""
        pred = self.preds[0][0]
        body = """
{
        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s) {
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s)
                bpf_override_return(ctx, %s);
        return 0;
}"""
        body = body % (self.prep, self.length, pred, self._get_err(),
                       self.length - 1, pred, self._get_err())
        return self._get_heading() + body

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Expand every ``STRCMP(ptr, literal)`` in the predicates.

        Each occurrence is replaced in the predicate by a unique boolean
        name, and C code computing that boolean character by character is
        accumulated in ``self.prep``. ``self.preds`` is updated in place.

        Raises:
            ValueError: if a STRCMP has no closing parenthesis or does not
                have exactly two comma-separated arguments.
        """
        marker = "STRCMP("
        self.prep = ""
        for i, (pred, place) in enumerate(self.preds):
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find(marker, start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]

                # The closing paren must be searched for from the STRCMP
                # itself, not from the previous cursor: a ')' earlier in the
                # predicate would otherwise be mistaken for the end.
                args_begin = ind + len(marker)
                close = pred.find(")", args_begin)
                if close == -1:
                    self._bail("unterminated STRCMP in predicate '%s'" % pred)

                # then ind ... close is STRCMP(...)
                pieces = pred[args_begin:close].split(",")
                if len(pieces) != 2:
                    self._bail("STRCMP takes exactly two arguments in "
                               "predicate '%s'" % pred)
                ptr, literal = pieces[0].strip(), pieces[1].strip()
                start = close + 1

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr, unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % ch
                # the string must also end where the literal ends
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            self.preds[i] = (new_pred, place)

    def generate_program(self):
        """Return the C source for this probe's handler."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        return self._generate_entry() if self.is_entry else self._generate_exit()

    def attach(self, bpf):
        """Attach the handler to ``bpf`` as a kprobe or kretprobe.

        Relies on ``self.event`` and ``self.func_name``, which are set while
        the program text is generated.
        """
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)
294
295
class Tool:
    """Command-line driver for the error injection tool.

    Parses the arguments and the call-chain spec, creates one or two Probe
    objects per function, assembles the BPF program text, loads it and then
    polls forever.
    """

    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        parser = argparse.ArgumentParser(
            description="Fail specified kernel" +
            " functionality when call chain and predicates are met",
            formatter_class=argparse.RawDescriptionHelpFormatter)
        # Restricting the choices turns an unknown mode into a normal usage
        # error instead of a KeyError traceback at the lookup below.
        parser.add_argument(metavar="mode", dest="mode",
                            choices=sorted(self.error_injection_mapping),
                            help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                            help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                            metavar="header",
                            help="additional header files to include in the BPF program")
        parser.add_argument("-P", "--probability", default=1,
                            metavar="probability", type=float,
                            help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                            help="print BPF program")
        self.args = parser.parse_args()

        self.program = ""
        self.spec = self.args.spec
        # function signature -> list of (predicate, position) tuples
        self.map = {}
        self.probes = []
        # signature of the function whose return value gets overridden
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Parse the spec and create the Probe objects for every function."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            # the overridden function only gets an entry probe
            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split ``self.spec`` into ``(function, predicate)`` frames.

        Frames are separated by "=>" found outside any parentheses. A frame
        made of a predicate only is attributed to ``self.key``; a frame made
        of a function only gets the predicate "(true)".

        Returns:
            list of (function, predicate) tuples, in spec order.

        Raises:
            Exception: on unbalanced parentheses, or when a frame contains
                no parenthesised group at all (e.g. an empty spec or a
                trailing "=>").
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0

        while i < len(data):
            # improper input
            if count < 0:
                raise Exception("Check your parentheses")
            c = data[i]
            count += c == '('
            count -= c == ')'
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # Closing a frame: it must hold at least one group,
                    # otherwise cur_frame[0] below would raise IndexError.
                    if not cur_frame:
                        raise Exception(
                            "Cannot parse spec, missing parentheses")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    # skip the '>' of "=>" (harmless for the sentinel)
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
            i += 1
        # improper input
        if count:
            raise Exception("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill ``self.map`` and ``self.length`` from the spec.

        Frames are numbered starting from the last frame of the spec. If the
        spec never mentions ``self.key`` it is appended with "(true)".

        Raises:
            Exception: if a predicate fails validation.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for f in frames:
            # default case
            func, pred = f[0], f[1]

            if not self._validate_predicate(pred):
                raise Exception("Invalid predicate: %s" % pred)
            tup = (pred, absolute_order)

            if func not in self.map:
                self.map[func] = [tup]
            else:
                self.map[func].append(tup)

            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_predicate(self, pred):
        """Return False if ``pred`` starts with "(" but is unbalanced.

        Predicates that are empty or do not start with "(" are accepted
        unchecked. A closing parenthesis with no matching opener is rejected
        as soon as it is seen, not only when the final depth is non-zero.
        """
        if len(pred) > 0 and pred[0] == "(":
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth < 0:
                        return False
            if depth != 0:
                # not well formed, break
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of pid_struct sized to the call chain."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Load the generated program and attach every probe to it."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Assemble the BPF program text; print it when --verbose is set."""
        # leave out auto includes for now
        self.program += '#include <linux/mm.h>\n'
        for include in (self.args.include or []):
            self.program += "#include <%s>\n" % include

        self.program += self._def_pid_struct()
        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
        for p in self.probes:
            self.program += p.generate_program() + "\n"

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Poll until interrupted; Ctrl-C ends the loop without a traceback."""
        while True:
            try:
                self.bpf.perf_buffer_poll()
            except KeyboardInterrupt:
                return

    def run(self):
        """Create the probes, build and load the program, then poll."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()
461
462
# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    Tool().run()