Add stack trace to trace.py (#712)

diff --git a/tools/trace.py b/tools/trace.py
index 2e168b0..696ca61 100755
--- a/tools/trace.py
+++ b/tools/trace.py
@@ -10,6 +10,7 @@
 # Copyright (C) 2016 Sasha Goldshtein.
 
 from bcc import BPF, Tracepoint, Perf, USDT
+from functools import partial
 from time import sleep, strftime
 import argparse
 import re
@@ -58,10 +59,12 @@
                 cls.first_ts = Time.monotonic_time()
                 cls.pid = args.pid or -1
 
-        def __init__(self, probe, string_size):
+        def __init__(self, probe, string_size, kernel_stack, user_stack):
                 self.usdt = None
                 self.raw_probe = probe
                 self.string_size = string_size
+                self.kernel_stack = kernel_stack
+                self.user_stack = user_stack
                 Probe.probe_count += 1
                 self._parse_probe()
                 self.probe_num = Probe.probe_count
@@ -236,6 +239,10 @@
                 ]
                 for i in range(0, len(self.types)):
                         self._generate_python_field_decl(i, fields)
+                if self.kernel_stack:
+                        fields.append(("kernel_stack_id", ct.c_int))
+                if self.user_stack:
+                        fields.append(("user_stack_id", ct.c_int))
                 return type(self.python_struct_name, (ct.Structure,),
                             dict(_fields_=fields))
 
@@ -260,12 +267,19 @@
                 # construct the final display string.
                 self.events_name = "%s_events" % self.probe_name
                 self.struct_name = "%s_data_t" % self.probe_name
-
+                self.stacks_name = "%s_stacks" % self.probe_name
+                stack_table = "BPF_STACK_TRACE(%s, 1024);" % self.stacks_name \
+                              if (self.kernel_stack or self.user_stack) else ""
                 data_fields = ""
                 for i, field_type in enumerate(self.types):
                         data_fields += "        " + \
                                        self._generate_field_decl(i)
 
+                kernel_stack_str = "       int kernel_stack_id;" \
+                                   if self.kernel_stack else ""
+                user_stack_str = "       int user_stack_id;" \
+                                 if self.user_stack else ""
+
                 text = """
 struct %s
 {
@@ -273,11 +287,16 @@
         u32 pid;
         char comm[TASK_COMM_LEN];
 %s
+%s
+%s
 };
 
 BPF_PERF_OUTPUT(%s);
+%s
 """
-                return text % (self.struct_name, data_fields, self.events_name)
+                return text % (self.struct_name, data_fields,
+                               kernel_stack_str, user_stack_str,
+                               self.events_name, stack_table)
 
         def _generate_field_assign(self, idx):
                 field_type = self.types[idx]
@@ -346,6 +365,18 @@
                 for i, expr in enumerate(self.values):
                         data_fields += self._generate_field_assign(i)
 
+                stack_trace = ""
+                if self.user_stack:
+                        stack_trace += """
+        __data.user_stack_id = %s.get_stackid(
+          ctx, BPF_F_REUSE_STACKID | BPF_F_USER_STACK
+        );""" % self.stacks_name
+                if self.kernel_stack:
+                        stack_trace += """
+        __data.kernel_stack_id = %s.get_stackid(
+          ctx, BPF_F_REUSE_STACKID
+        );""" % self.stacks_name
+
                 text = """
 int %s(%s)
 {
@@ -359,6 +390,7 @@
         __data.pid = bpf_get_current_pid_tgid();
         bpf_get_current_comm(&__data.comm, sizeof(__data.comm));
 %s
+%s
         %s.perf_submit(ctx, &__data, sizeof(__data));
         return 0;
 }
@@ -366,7 +398,8 @@
                 text = text % (self.probe_name, signature,
                                pid_filter, prefix,
                                self._generate_usdt_filter_read(), self.filter,
-                               self.struct_name, data_fields, self.events_name)
+                               self.struct_name, data_fields,
+                               stack_trace, self.events_name)
 
                 return data_decl + "\n" + text
 
@@ -382,7 +415,16 @@
                 else:   # self.probe_type == 't'
                         return self.tp_event
 
-        def print_event(self, cpu, data, size):
+        def print_stack(self, bpf, stack_id, pid):
+                # Negative id: no stack captured (presumably an error code).
+                if stack_id < 0:
+                        print("        %d" % stack_id)
+                        return
+
+                for addr in bpf.get_table(self.stacks_name).walk(stack_id):
+                        print("        %016x %s" % (addr, bpf.sym(addr, pid)))
+
+        def print_event(self, bpf, cpu, data, size):
                 # Cast as the generated structure type and display
                 # according to the format string in the probe.
                 event = ct.cast(data, ct.POINTER(self.python_struct)).contents
@@ -395,6 +437,15 @@
                     (time[:8], event.pid, event.comm[:12],
                      self._display_function(), msg))
 
+                if self.user_stack:
+                    print("    User Stack Trace:")
+                    self.print_stack(bpf, event.user_stack_id, event.pid)
+                if self.kernel_stack:
+                    print("    Kernel Stack Trace:")
+                    self.print_stack(bpf, event.kernel_stack_id, -1)
+                if self.user_stack or self.kernel_stack:
+                    print("")
+
                 Probe.event_count += 1
                 if Probe.max_events is not None and \
                    Probe.event_count >= Probe.max_events:
@@ -406,7 +457,8 @@
                 else:
                         self._attach_u(bpf)
                 self.python_struct = self._generate_python_data_decl()
-                bpf[self.events_name].open_perf_buffer(self.print_event)
+                callback = partial(self.print_event, bpf)
+                bpf[self.events_name].open_perf_buffer(callback)
 
         def _attach_k(self, bpf):
                 if self.probe_type == "r":
@@ -482,6 +534,10 @@
                   help="number of events to print before quitting")
                 parser.add_argument("-o", "--offset", action="store_true",
                   help="use relative time from first traced message")
+                parser.add_argument("-K", "--kernel-stack", action="store_true",
+                  help="output kernel stack trace")
+                parser.add_argument("-U", "--user-stack", "--user_stack",
+                  action="store_true", help="output user stack trace")
                 parser.add_argument(metavar="probe", dest="probes", nargs="+",
                   help="probe specifier (see examples)")
                 self.args = parser.parse_args()
@@ -491,7 +547,8 @@
                 self.probes = []
                 for probe_spec in self.args.probes:
                         self.probes.append(Probe(
-                                probe_spec, self.args.string_size))
+                                probe_spec, self.args.string_size,
+                                self.args.kernel_stack, self.args.user_stack))
 
         def _generate_program(self):
                 self.program = """