Add tracepoint support to argdist and trace, and add a new tplist tool for printing tracepoints
diff --git a/tools/trace.py b/tools/trace.py
index 186fdfd..8f1ce89 100755
--- a/tools/trace.py
+++ b/tools/trace.py
@@ -5,6 +5,7 @@
 #
 # USAGE: trace [-h] [-p PID] [-v] [-Z STRING_SIZE] [-S] [-M MAX_EVENTS] [-o]
 #              probe [probe ...]
+#
 # Licensed under the Apache License, Version 2.0 (the "License")
 # Copyright (C) 2016 Sasha Goldshtein.
 
@@ -13,6 +14,7 @@
 import argparse
 import re
 import ctypes as ct
+import multiprocessing
 import os
 import traceback
 import sys
@@ -42,6 +44,122 @@
                         raise OSError(errno_, os.strerror(errno_))
                 return t.tv_sec * 1e9 + t.tv_nsec
 
+class Perf(object):     # ctypes helpers that open perf tracepoint events
+        class perf_event_attr(ct.Structure):    # syscall argument layout
+                _fields_ = [
+                        ('type', ct.c_uint),
+                        ('size', ct.c_uint),    # NOTE(review): never assigned
+                        ('config', ct.c_ulong),  # receives the tracepoint id
+                        ('sample_period', ct.c_ulong),
+                        ('sample_type', ct.c_ulong),
+                        ('IGNORE1', ct.c_ulong),  # IGNOREn: never assigned
+                        ('IGNORE2', ct.c_ulong),
+                        ('wakeup_events', ct.c_uint),
+                        ('IGNORE3', ct.c_uint),
+                        ('IGNORE4', ct.c_ulong),
+                        ('IGNORE5', ct.c_ulong),
+                        ('IGNORE6', ct.c_ulong),
+                        ('IGNORE7', ct.c_uint),
+                        ('IGNORE8', ct.c_int),
+                        ('IGNORE9', ct.c_ulong),
+                        ('IGNORE10', ct.c_uint),
+                        ('IGNORE11', ct.c_uint)
+                ]
+
+        NR_PERF_EVENT_OPEN = 298        # syscall nr; presumably x86_64 only
+        PERF_TYPE_TRACEPOINT = 2        # stored in attr.type
+        PERF_SAMPLE_RAW = 1024          # stored in attr.sample_type
+        PERF_FLAG_FD_CLOEXEC = 8        # flags argument of the syscall
+        PERF_EVENT_IOC_SET_FILTER = 1074275334  # 0x40082406, ioctl request
+        PERF_EVENT_IOC_ENABLE = 9216            # 0x2400, ioctl request
+
+        libc = ct.CDLL('libc.so.6', use_errno=True)
+        syscall = libc.syscall          # not declaring vararg types
+        ioctl = libc.ioctl              # not declaring vararg types
+
+        @staticmethod   # NOTE(review): fd is neither returned nor closed
+        def _open_for_cpu(cpu, attr):   # raises OSError on any failed call
+                pfd = Perf.syscall(Perf.NR_PERF_EVENT_OPEN, ct.byref(attr),
+                                   -1, cpu, -1, Perf.PERF_FLAG_FD_CLOEXEC)
+                if pfd < 0:             # (args above: pid=-1, cpu, group_fd=-1)
+                        errno_ = ct.get_errno()
+                        raise OSError(errno_, os.strerror(errno_))
+                if Perf.ioctl(pfd, Perf.PERF_EVENT_IOC_SET_FILTER,
+                              "common_pid == -17") < 0:  # likely never true
+                        errno_ = ct.get_errno()
+                        raise OSError(errno_, os.strerror(errno_))
+                if Perf.ioctl(pfd, Perf.PERF_EVENT_IOC_ENABLE, 0) < 0:
+                        errno_ = ct.get_errno()
+                        raise OSError(errno_, os.strerror(errno_))
+
+        @staticmethod
+        def perf_event_open(tpoint_id):  # one event fd per CPU; returns None
+                attr = Perf.perf_event_attr()   # unset fields keep defaults
+                attr.config = tpoint_id
+                attr.type = Perf.PERF_TYPE_TRACEPOINT
+                attr.sample_type = Perf.PERF_SAMPLE_RAW
+                attr.sample_period = 1
+                attr.wakeup_events = 1
+                for cpu in range(0, multiprocessing.cpu_count()):
+                        Perf._open_for_cpu(cpu, attr)
+
+class Tracepoint(object):       # shared tracepoint state + static helpers
+        tracepoints_enabled = 0         # bumped by enable_tracepoint()
+        trace_root = "/sys/kernel/debug/tracing"
+        event_root = os.path.join(trace_root, "events")
+
+        @staticmethod
+        def generate_decl():            # -> BPF map decl, "" if none enabled
+                if Tracepoint.tracepoints_enabled == 0:
+                        return ""
+                return "\nBPF_HASH(__trace_di, u64, u64);\n"  # tid -> saved di
+
+        @staticmethod
+        def generate_entry_probe():     # -> C kprobe source, "" if none
+                if Tracepoint.tracepoints_enabled == 0:
+                        return ""
+                return """
+int __trace_entry_update(struct pt_regs *ctx)
+{
+        u64 tid = bpf_get_current_pid_tgid();
+        u64 val = ctx->di;
+        __trace_di.update(&tid, &val);
+        return 0;
+}
+"""
+
+        @staticmethod
+        def enable_tracepoint(category, event):  # ValueError if unknown
+                tp_id = Tracepoint.get_tpoint_id(category, event)
+                if tp_id == -1:
+                        raise ValueError("no such tracepoint found: %s:%s" %
+                                         (category, event))
+                Perf.perf_event_open(tp_id)     # may raise OSError
+                Tracepoint.tracepoints_enabled += 1     # only on success
+
+        @staticmethod
+        def get_tpoint_id(category, event):     # -> int id, -1 if unreadable
+                evt_dir = os.path.join(Tracepoint.event_root, category, event)
+                try:
+                        return int(
+                          open(os.path.join(evt_dir, "id")).read().strip())
+                except (IOError, OSError, ValueError):  # not Ctrl-C/exit
+                        return -1
+
+        @staticmethod
+        def get_tpoint_format(category, event):  # -> lines, "" on error
+                evt_dir = os.path.join(Tracepoint.event_root, category, event)
+                try:
+                        return open(os.path.join(evt_dir, "format")).readlines()
+                except (IOError, OSError):      # not Ctrl-C/exit
+                        return ""
+
+        @staticmethod
+        def attach(bpf):        # hooks the shared kprobe only when needed
+                if Tracepoint.tracepoints_enabled > 0:
+                        bpf.attach_kprobe(event="tracing_generic_entry_update",
+                                          fn_name="__trace_entry_update")
+
 class Probe(object):
         probe_count = 0
         max_events = None
@@ -72,6 +190,47 @@
         def is_default_action(self):
                 return self.python_format == ""
 
+        def _generate_tpoint_entry_struct_fields(self):  # -> C member decls
+                format_lines = Tracepoint.get_tpoint_format(self.tp_category,
+                                                            self.tp_event)
+                text = ""
+                for line in format_lines:
+                        # size:N only anchors the match; its value is unused
+                        match = re.search(r'field:([^;]*);.*size:(\d+);', line)
+                        if match is None:
+                                continue
+                        parts = match.group(1).split()
+                        field_name = parts[-1]
+                        field_type = " ".join(parts[:-1])
+                        # skip __data_loc and common_* fields
+                        if "__data_loc" in field_type or \
+                           field_name.startswith("common_"):
+                                continue
+                        text += "        %s %s;\n" % (field_type, field_name)
+                return text
+
+        def _generate_tpoint_entry_struct(self):  # -> C struct source
+                text = """
+struct %s {
+        u64 __do_not_use__;
+%s
+};
+                """  # __do_not_use__: presumably the common_* header
+                self.tp_entry_struct_name = self.probe_name + \
+                                            "_trace_entry"  # kept on self
+                fields = self._generate_tpoint_entry_struct_fields()
+                return text % (self.tp_entry_struct_name, fields)
+
+        def _generate_tpoint_entry_prefix(self):  # -> C prologue text
+                text = """
+        u64 tid = bpf_get_current_pid_tgid();
+        u64 *di = __trace_di.lookup(&tid);
+        if (di == 0) { return 0; }
+        struct %s tp = {};
+        bpf_probe_read(&tp, sizeof(tp), (void *)*di);
+                """ % self.tp_entry_struct_name  # set in ..._entry_struct()
+                return text     # C: copies what *di points at into `tp`
+
         def _bail(self, error):
                 raise ValueError("error parsing probe '%s': %s" %
                                  (self.raw_probe, error))
@@ -123,13 +282,21 @@
                         parts = ["p", parts[0], parts[1]]
                 if len(parts[0]) == 0:
                         self.probe_type = "p"
-                elif parts[0] in ["p", "r"]:
+                elif parts[0] in ["p", "r", "t"]:
                         self.probe_type = parts[0]
                 else:
-                        self._bail("expected '', 'p', or 'r', got '%s'" %
+                        self._bail("expected '', 'p', 't', or 'r', got '%s'" %
                                    parts[0])
-                self.library = parts[1]
-                self.function = parts[2]
+                if self.probe_type == "t":
+                        self.tp_category = parts[1]
+                        self.tp_event = parts[2]
+                        Tracepoint.enable_tracepoint(self.tp_category,
+                                                     self.tp_event)
+                        self.library = ""       # kernel
+                        self.function = "perf_trace_%s" % self.tp_event
+                else:
+                        self.library = parts[1]
+                        self.function = parts[2]
 
         def _parse_filter(self, filt):
                 self.filter = self._replace_args(filt)
@@ -149,12 +316,17 @@
                 if len(action) == 0:
                         return
 
-                parts = action.split(',')
-                self.raw_format = parts[0]
+                action = action.strip()
+                match = re.search(r'(\".*\"),?(.*)', action)
+                if match is None:
+                        self._bail("expected format string in \"s")
+
+                self.raw_format = match.group(1)
                 self._parse_types(self.raw_format)
-                for part in parts[1:]:
+                for part in match.group(2).split(','):
                         part = self._replace_args(part)
-                        self.values.append(part)
+                        if len(part) > 0:
+                                self.values.append(part)
 
         aliases = {
                 "retval": "ctx->ax",
@@ -283,10 +455,16 @@
                 for i, expr in enumerate(self.values):
                         data_fields += self._generate_field_assign(i)
 
+                prefix = ""
+                if self.probe_type == "t":
+                        data_decl += self._generate_tpoint_entry_struct()
+                        prefix = self._generate_tpoint_entry_prefix()
+
                 text = """
 int %s(struct pt_regs *ctx)
 {
         %s
+        %s
         if (!(%s)) return 0;
 
         struct %s __data = {0};
@@ -298,7 +476,7 @@
         return 0;
 }
 """
-                text = text % (self.probe_name, pid_filter,
+                text = text % (self.probe_name, pid_filter, prefix,
                                self.filter, self.struct_name,
                                data_fields, self.events_name)
 
@@ -308,6 +486,31 @@
         def _time_off_str(cls, timestamp_ns):
                 return "%.6f" % (1e-9 * (timestamp_ns - cls.first_ts))
 
+        auto_includes = {       # header -> probe-text keywords that add it
+                "linux/time.h"      : ["time"],
+                "linux/fs.h"        : ["fs", "file"],
+                "linux/blkdev.h"    : ["bio", "request"],
+                "linux/slab.h"      : ["alloc"],
+                "linux/netdevice.h" : ["sk_buff"]
+        }
+
+        @classmethod
+        def generate_auto_includes(cls, probes):
+                # Return "#include <hdr>\n" once per header whose keyword is a
+                # substring of any probe text; probes may be any iterable.
+                probes = list(probes)   # re-scanned for every header
+                headers = ""
+                for header, keywords in cls.auto_includes.items():
+                        if any(kw in p for kw in keywords for p in probes):
+                                headers += "#include <%s>\n" % header
+                return headers
+
+        def _display_function(self):
+                # Name for the output column: 't' probes drop "perf_trace_".
+                if self.probe_type == 't':
+                        return self.function.replace("perf_trace_", "")
+                return self.function
+
         def print_event(self, cpu, data, size):
                 # Cast as the generated structure type and display
                 # according to the format string in the probe.
@@ -318,7 +521,8 @@
                 time = strftime("%H:%M:%S") if Probe.use_localtime else \
                        Probe._time_off_str(event.timestamp_ns)
                 print("%-8s %-6d %-12s %-16s %s" % \
-                    (time[:8], event.pid, event.comm[:12], self.function, msg))
+                    (time[:8], event.pid, event.comm[:12],
+                     self._display_function(), msg))
 
                 Probe.event_count += 1
                 if Probe.max_events is not None and \
@@ -337,7 +541,7 @@
                 if self.probe_type == "r":
                         bpf.attach_kretprobe(event=self.function,
                                              fn_name=self.probe_name)
-                elif self.probe_type == "p":
+                elif self.probe_type == "p" or self.probe_type == "t":
                         bpf.attach_kprobe(event=self.function,
                                           fn_name=self.probe_name)
 
@@ -384,6 +588,8 @@
         Trace returns from __kmalloc which returned a null pointer
 trace 'r:c:malloc (retval) "allocated = %p", retval
         Trace returns from malloc and print non-NULL allocated buffers
+trace 't:block:block_rq_complete "sectors=%d", tp.nr_sector'
+        Trace the block_rq_complete kernel tracepoint and print # of tx sectors
 """
 
         def __init__(self):
@@ -420,6 +626,10 @@
 #include <linux/sched.h>        /* For TASK_COMM_LEN */
 
 """
+                self.program += Probe.generate_auto_includes(
+                        map(lambda p: p.raw_probe, self.probes))
+                self.program += Tracepoint.generate_decl()
+                self.program += Tracepoint.generate_entry_probe()
                 for probe in self.probes:
                         self.program += probe.generate_program(
                                 self.args.pid or -1, self.args.include_self)
@@ -429,6 +639,7 @@
 
         def _attach_probes(self):
                 self.bpf = BPF(text=self.program)
+                Tracepoint.attach(self.bpf)
                 for probe in self.probes:
                         if self.args.verbose:
                                 print(probe)
@@ -455,7 +666,7 @@
                 except:
                         if self.args.verbose:
                                 traceback.print_exc()
-                        else:
+                        elif sys.exc_type is not SystemExit:
                                 print(sys.exc_value)
 
 if __name__ == "__main__":