trace: add pid/tid filtering, fix symbolizing, misc nits (#798)

* support filtering by process ID (-p) or thread ID (-t); previously -p
  actually filtered on thread ID (aka "pid" in kernel-speak)
* include process and thread ID in output
* flip order of user and kernel stacks to flow more naturally
* resolve symbols using process ID instead of thread ID so only one symbol
  cache is instantiated per process
* misc aesthetic fixes: re-wrap the usage comment, close unterminated quotes
  in two examples, and drop the stack-trace headings
diff --git a/tools/trace.py b/tools/trace.py
index eae625c..0466b47 100755
--- a/tools/trace.py
+++ b/tools/trace.py
@@ -3,8 +3,9 @@
 # trace         Trace a function and print a trace message based on its
 #               parameters, with an optional filter.
 #
-# USAGE: trace [-h] [-p PID] [-v] [-Z STRING_SIZE] [-S] [-M MAX_EVENTS] [-o]
-#              [-K] [-U] [-I header] probe [probe ...]
+# usage: trace [-h] [-p PID] [-t TID] [-v] [-Z STRING_SIZE] [-S]
+#              [-M MAX_EVENTS] [-o] [-K] [-U] [-I header]
+#              probe [probe ...]
 #
 # Licensed under the Apache License, Version 2.0 (the "License")
 # Copyright (C) 2016 Sasha Goldshtein.
@@ -51,6 +52,7 @@
         event_count = 0
         first_ts = 0
         use_localtime = True
+        tgid = -1
         pid = -1
 
         @classmethod
@@ -58,6 +60,7 @@
                 cls.max_events = args.max_events
                 cls.use_localtime = not args.offset
                 cls.first_ts = Time.monotonic_time()
+                cls.tgid = args.tgid or -1
                 cls.pid = args.pid or -1
 
         def __init__(self, probe, string_size, kernel_stack, user_stack):
@@ -154,7 +157,8 @@
                         self.function = parts[2]
 
         def _find_usdt_probe(self):
-                self.usdt = USDT(path=self.library, pid=Probe.pid)
+                target = Probe.pid if Probe.pid != -1 else Probe.tgid
+                self.usdt = USDT(path=self.library, pid=target)  # -1 if unset
                 for probe in self.usdt.enumerate_probes():
                         if probe.name == self.usdt_name:
                                 return  # Found it, will enable later
@@ -258,6 +262,7 @@
                                 (self._display_function(), self.probe_num)
                 fields = [
                         ("timestamp_ns", ct.c_ulonglong),
+                        ("tgid", ct.c_uint),
                         ("pid", ct.c_uint),
                         ("comm", ct.c_char * 16)       # TASK_COMM_LEN
                 ]
@@ -309,6 +314,7 @@
 struct %s
 {
         u64 timestamp_ns;
+        u32 tgid;
         u32 pid;
         char comm[TASK_COMM_LEN];
 %s
@@ -370,13 +376,15 @@
                 # it to the function body:
                 if len(self.library) == 0 and Probe.pid != -1:
                         pid_filter = """
-        u32 __pid = bpf_get_current_pid_tgid();
         if (__pid != %d) { return 0; }
                 """ % Probe.pid
+                elif len(self.library) == 0 and Probe.tgid != -1:
+                        pid_filter = """
+        if (__tgid != %d) { return 0; }
+                """ % Probe.tgid
                 elif not include_self:
                         pid_filter = """
-        u32 __pid = bpf_get_current_pid_tgid();
-        if (__pid == %d) { return 0; }
+        if (__tgid == %d) { return 0; }
                 """ % os.getpid()
                 else:
                         pid_filter = ""
@@ -410,6 +418,9 @@
 
                 text = heading + """
 {
+        u64 __pid_tgid = bpf_get_current_pid_tgid();
+        u32 __tgid = __pid_tgid >> 32;
+        u32 __pid = __pid_tgid; // implicit cast to u32 for bottom half
         %s
         %s
         %s
@@ -417,7 +428,8 @@
 
         struct %s __data = {0};
         __data.timestamp_ns = bpf_ktime_get_ns();
-        __data.pid = bpf_get_current_pid_tgid();
+        __data.tgid = __tgid;
+        __data.pid = __pid;
         bpf_get_current_comm(&__data.comm, sizeof(__data.comm));
 %s
 %s
@@ -444,17 +456,17 @@
                 else:   # self.probe_type == 't'
                         return self.tp_event
 
-        def print_stack(self, bpf, stack_id, pid):
+        def print_stack(self, bpf, stack_id, tgid):
             if stack_id < 0:
                     print("        %d" % stack_id)
                     return
 
             stack = list(bpf.get_table(self.stacks_name).walk(stack_id))
             for addr in stack:
-                    print("        %016x %s" % (addr, bpf.sym(addr, pid)))
+                    print("        %016x %s" % (addr, bpf.sym(addr, tgid)))
 
-        def _format_message(self, bpf, pid, values):
-                # Replace each %K with kernel sym and %U with user sym in pid
+        def _format_message(self, bpf, tgid, values):
+                # Replace each %K with kernel sym and %U with user sym in tgid
                 kernel_placeholders = [i for i in xrange(0, len(self.types))
                                        if self.types[i] == 'K']
                 user_placeholders = [i for i in xrange(0, len(self.types))
@@ -462,7 +474,7 @@
                 for kp in kernel_placeholders:
                         values[kp] = bpf.ksymaddr(values[kp])
                 for up in user_placeholders:
-                        values[up] = bpf.symaddr(values[up], pid)
+                        values[up] = bpf.symaddr(values[up], tgid)
                 return self.python_format % tuple(values)
 
         def print_event(self, bpf, cpu, data, size):
@@ -471,19 +483,17 @@
                 event = ct.cast(data, ct.POINTER(self.python_struct)).contents
                 values = map(lambda i: getattr(event, "v%d" % i),
                              range(0, len(self.values)))
-                msg = self._format_message(bpf, event.pid, values)
+                msg = self._format_message(bpf, event.tgid, values)
                 time = strftime("%H:%M:%S") if Probe.use_localtime else \
                        Probe._time_off_str(event.timestamp_ns)
-                print("%-8s %-6d %-12s %-16s %s" %
-                    (time[:8], event.pid, event.comm[:12],
+                print("%-8s %-6d %-6d %-12s %-16s %s" %
+                    (time[:8], event.tgid, event.pid, event.comm,
                      self._display_function(), msg))
 
-                if self.user_stack:
-                        print("    User Stack Trace:")
-                        self.print_stack(bpf, event.user_stack_id, event.pid)
                 if self.kernel_stack:
-                        print("    Kernel Stack Trace:")
                         self.print_stack(bpf, event.kernel_stack_id, -1)
+                if self.user_stack:
+                        self.print_stack(bpf, event.user_stack_id, event.tgid)
                 if self.user_stack or self.kernel_stack:
                         print("")
 
@@ -549,9 +559,9 @@
         Trace malloc calls and print the size being allocated
 trace 'p:c:write (arg1 == 1) "writing %d bytes to STDOUT", arg3'
         Trace the write() call from libc to monitor writes to STDOUT
-trace 'r::__kmalloc (retval == 0) "kmalloc failed!"
+trace 'r::__kmalloc (retval == 0) "kmalloc failed!"'
         Trace returns from __kmalloc which returned a null pointer
-trace 'r:c:malloc (retval) "allocated = %x", retval
+trace 'r:c:malloc (retval) "allocated = %x", retval'
         Trace returns from malloc and print non-NULL allocated buffers
 trace 't:block:block_rq_complete "sectors=%d", args->nr_sector'
         Trace the block_rq_complete kernel tracepoint and print # of tx sectors
@@ -564,8 +574,12 @@
                   "functions and print trace messages.",
                   formatter_class=argparse.RawDescriptionHelpFormatter,
                   epilog=Tool.examples)
-                parser.add_argument("-p", "--pid", type=int,
-                  help="id of the process to trace (optional)")
+                # we'll refer to the userspace concepts of "pid" and "tid" by
+                # their kernel names -- tgid and pid -- inside the script
+                parser.add_argument("-p", "--pid", type=int, metavar="PID",
+                  dest="tgid", help="id of the process to trace (optional)")
+                parser.add_argument("-t", "--tid", type=int, metavar="TID",
+                  dest="pid", help="id of the thread to trace (optional)")
                 parser.add_argument("-v", "--verbose", action="store_true",
                   help="print resulting BPF program code before executing")
                 parser.add_argument("-Z", "--string-size", type=int,
@@ -587,6 +601,8 @@
                   metavar="header",
                   help="additional header files to include in the BPF program")
                 self.args = parser.parse_args()
+                if self.args.tgid and self.args.pid:
+                        parser.error("only one of -p and -t may be specified")
 
         def _create_probes(self):
                 Probe.configure(self.args)
@@ -636,8 +652,8 @@
                                              self.probes))
 
                 # Print header
-                print("%-8s %-6s %-12s %-16s %s" %
-                      ("TIME", "PID", "COMM", "FUNC",
+                print("%-8s %-6s %-6s %-12s %-16s %s" %
+                      ("TIME", "PID", "TID", "COMM", "FUNC",
                       "-" if not all_probes_trivial else ""))
 
                 while True: