Add tracepoint support to argdist and trace, and add a new tplist tool for printing tracepoints
diff --git a/tools/argdist.py b/tools/argdist.py
index 99198ba..a92c975 100755
--- a/tools/argdist.py
+++ b/tools/argdist.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 #
 # argdist   Trace a function and display a distribution of its
-#              parameter values as a histogram or frequency count.
+#           parameter values as a histogram or frequency count.
 #
 # USAGE: argdist [-h] [-p PID] [-z STRING_SIZE] [-i INTERVAL]
-#                   [-n COUNT] [-v] [-T TOP]
-#                   [-C specifier [specifier ...]]
-#                   [-H specifier [specifier ...]]
-#                   [-I header [header ...]]
+#                [-n COUNT] [-v] [-T TOP]
+#                [-C specifier [specifier ...]]
+#                [-H specifier [specifier ...]]
+#                [-I header [header ...]]
 #
 # Licensed under the Apache License, Version 2.0 (the "License")
 # Copyright (C) 2016 Sasha Goldshtein.
@@ -15,10 +15,129 @@
 from bcc import BPF
 from time import sleep, strftime
 import argparse
+import ctypes as ct
 import re
 import traceback
+import os
+import multiprocessing
 import sys
 
+class Perf(object):     # minimal ctypes binding for the perf_event_open syscall
+        class perf_event_attr(ct.Structure):
+                _fields_ = [    # the IGNOREn members are never assigned here
+                        ('type', ct.c_uint),
+                        ('size', ct.c_uint),
+                        ('config', ct.c_ulong),
+                        ('sample_period', ct.c_ulong),
+                        ('sample_type', ct.c_ulong),
+                        ('IGNORE1', ct.c_ulong),
+                        ('IGNORE2', ct.c_ulong),
+                        ('wakeup_events', ct.c_uint),
+                        ('IGNORE3', ct.c_uint),
+                        ('IGNORE4', ct.c_ulong),
+                        ('IGNORE5', ct.c_ulong),
+                        ('IGNORE6', ct.c_ulong),
+                        ('IGNORE7', ct.c_uint),
+                        ('IGNORE8', ct.c_int),
+                        ('IGNORE9', ct.c_ulong),
+                        ('IGNORE10', ct.c_uint),
+                        ('IGNORE11', ct.c_uint)
+                ]
+
+        NR_PERF_EVENT_OPEN = 298        # NOTE(review): x86_64 value, arch-specific
+        PERF_TYPE_TRACEPOINT = 2
+        PERF_SAMPLE_RAW = 1024
+        PERF_FLAG_FD_CLOEXEC = 8
+        PERF_EVENT_IOC_SET_FILTER = 1074275334
+        PERF_EVENT_IOC_ENABLE = 9216
+
+        libc = ct.CDLL('libc.so.6', use_errno=True)
+        syscall = libc.syscall          # not declaring vararg types
+        ioctl = libc.ioctl              # not declaring vararg types
+
+        @staticmethod
+        def _open_for_cpu(cpu, attr):   # raises OSError on any failure
+                pfd = Perf.syscall(Perf.NR_PERF_EVENT_OPEN, ct.byref(attr),
+                                   -1, cpu, -1, Perf.PERF_FLAG_FD_CLOEXEC)
+                if pfd < 0:
+                        errno_ = ct.get_errno()
+                        raise OSError(errno_, os.strerror(errno_))
+                # Set the filter, then enable; close pfd on failure (no leak)
+                if Perf.ioctl(pfd, Perf.PERF_EVENT_IOC_SET_FILTER,
+                              b"common_pid == -17") < 0 or \
+                   Perf.ioctl(pfd, Perf.PERF_EVENT_IOC_ENABLE, 0) < 0:
+                        errno_ = ct.get_errno()
+                        os.close(pfd)
+                        raise OSError(errno_, os.strerror(errno_))
+
+        @staticmethod
+        def perf_event_open(tpoint_id): # opens one event per CPU index
+                attr = Perf.perf_event_attr()
+                attr.config = tpoint_id
+                attr.type = Perf.PERF_TYPE_TRACEPOINT
+                attr.sample_type = Perf.PERF_SAMPLE_RAW
+                attr.sample_period = 1
+                attr.wakeup_events = 1
+                for cpu in range(0, multiprocessing.cpu_count()):
+                        Perf._open_for_cpu(cpu, attr)
+
+class Tracepoint(object):       # tracepoint lookup and shared BPF glue code
+        tracepoints_enabled = 0         # incremented by enable_tracepoint()
+        trace_root = "/sys/kernel/debug/tracing"
+        event_root = os.path.join(trace_root, "events")
+
+        @staticmethod
+        def generate_decl():            # BPF map text; "" if none enabled
+                decl = "\nBPF_HASH(__trace_di, u64, u64);\n"
+                return decl if Tracepoint.tracepoints_enabled != 0 else ""
+
+        @staticmethod
+        def generate_entry_probe():     # BPF probe text; "" if none enabled
+                if Tracepoint.tracepoints_enabled == 0:
+                        return ""
+                return """
+int __trace_entry_update(struct pt_regs *ctx)
+{
+        u64 tid = bpf_get_current_pid_tgid();
+        u64 val = ctx->di;
+        __trace_di.update(&tid, &val);
+        return 0;
+}
+"""
+
+        @staticmethod
+        def enable_tracepoint(category, event):  # ValueError if unknown
+                tp_id = Tracepoint.get_tpoint_id(category, event)
+                if tp_id == -1:
+                        raise ValueError("no such tracepoint found: %s:%s" %
+                                         (category, event))
+                Perf.perf_event_open(tp_id)
+                Tracepoint.tracepoints_enabled += 1
+
+        @staticmethod
+        def get_tpoint_id(category, event):     # -1 if unreadable/malformed
+                evt_dir = os.path.join(Tracepoint.event_root, category, event)
+                try:
+                        with open(os.path.join(evt_dir, "id")) as id_file:
+                                return int(id_file.read().strip())
+                except (IOError, OSError, ValueError):
+                        return -1
+
+        @staticmethod
+        def get_tpoint_format(category, event): # list of lines; [] on error
+                evt_dir = os.path.join(Tracepoint.event_root, category, event)
+                try:
+                        with open(os.path.join(evt_dir, "format")) as fmt_file:
+                                return fmt_file.readlines()
+                except (IOError, OSError):
+                        return []
+
+        @staticmethod
+        def attach(bpf):                # no-op unless a tracepoint is enabled
+                if Tracepoint.tracepoints_enabled > 0:
+                        bpf.attach_kprobe(event="tracing_generic_entry_update",
+                                          fn_name="__trace_entry_update")
+
 class Specifier(object):
         probe_text = """
 DATA_DECL
@@ -36,10 +155,11 @@
         next_probe_index = 0
         aliases = { "$PID": "bpf_get_current_pid_tgid()" }
         auto_includes = {
-                "linux/time.h"    : ["time"],
-                "linux/fs.h"      : ["fs", "file"],
-                "linux/blkdev.h"  : ["bio", "request"],
-                "linux/slab.h"    : ["alloc"]
+                "linux/time.h"      : ["time"],
+                "linux/fs.h"        : ["fs", "file"],
+                "linux/blkdev.h"    : ["bio", "request"],
+                "linux/slab.h"      : ["alloc"],
+                "linux/netdevice.h" : ["sk_buff", "net_device"]
         }
 
         @staticmethod
@@ -189,8 +309,8 @@
                                    "function signature must be specified")
                 if len(parts) > 6:
                         self._bail("extraneous ':'-separated parts detected")
-                if parts[0] not in ["r", "p"]:
-                        self._bail("probe type must be either 'p' or 'r', " +
+                if parts[0] not in ["r", "p", "t"]:
+                        self._bail("probe type must be 'p', 'r', or 't', " +
                                    "but got '%s'" % parts[0])
                 if re.match(r"\w+\(.*\)", parts[2]) is None:
                         self._bail(("function signature '%s' has an invalid " +
@@ -216,11 +336,19 @@
 
                 parts = spec_and_label[0].strip().split(':')
                 self.type = type    # hist or freq
-                self.is_ret_probe = parts[0] == "r"
-                self.library = parts[1]
-                self.is_user = len(self.library) > 0
+                self.probe_type = parts[0]
                 fparts = parts[2].split('(')
                 self.function = fparts[0].strip()
+                if self.probe_type == "t":
+                        self.library = ""       # kernel
+                        self.tp_category = parts[1]
+                        self.tp_event = self.function
+                        Tracepoint.enable_tracepoint(
+                                        self.tp_category, self.tp_event)
+                        self.function = "perf_trace_" + self.function
+                else:
+                        self.library = parts[1]
+                self.is_user = len(self.library) > 0
                 self.signature = fparts[1].strip()[:-1]
                 self._parse_signature()
 
@@ -235,12 +363,12 @@
                         if self.type == "hist" and len(self.expr_types) > 1:
                                 self._bail("histograms can only have 1 expr")
                 else:
-                        if not self.is_ret_probe and self.type == "hist":
+                        if not self.probe_type == "r" and self.type == "hist":
                                 self._bail("histograms must have expr")
                         self.expr_types = \
-                                ["u64" if not self.is_ret_probe else "int"]
+                          ["u64" if not self.probe_type == "r" else "int"]
                         self.exprs = \
-                                ["1" if not self.is_ret_probe else "$retval"]
+                          ["1" if not self.probe_type == "r" else "$retval"]
                 self.filter = "" if len(parts) != 6 else parts[5]
                 self._substitute_exprs()
 
@@ -249,7 +377,7 @@
                 def check(expr):
                         keywords = ["$entry", "$latency"]
                         return any(map(lambda kw: kw in expr, keywords))
-                self.entry_probe_required = self.is_ret_probe and \
+                self.entry_probe_required = self.probe_type == "r" and \
                         (any(map(check, self.exprs)) or check(self.filter))
 
                 self.pid = pid
@@ -278,11 +406,11 @@
 
         def _generate_field_assignment(self, i):
                 if self._is_string(self.expr_types[i]):
-                        return "bpf_probe_read(" + \
+                        return "        bpf_probe_read(" + \
                                "&__key.v%d.s, sizeof(__key.v%d.s), %s);\n" % \
                                 (i, i, self.exprs[i])
                 else:
-                        return "__key.v%d = %s;\n" % (i, self.exprs[i])
+                        return "        __key.v%d = %s;\n" % (i, self.exprs[i])
 
         def _generate_hash_decl(self):
                 if self.type == "hist":
@@ -324,28 +452,71 @@
                 else:
                         return ""
 
-        def generate_text(self):
-                # We don't like tools writing tools (Brendan Gregg), but this
-                # is an exception because we're letting the user fully
-                # customize the values we probe. As a rule of thumb though,
-                # try to build a custom tool for a specific purpose.
+        def _generate_tpoint_entry_struct_fields(self):
+                # One C "type name;" line per tracepoint-specific format field
+                format_lines = Tracepoint.get_tpoint_format(self.tp_category,
+                                                            self.tp_event)
+                text = ""
+                for line in format_lines:
+                        match = re.search(r'field:([^;]*);.*size:(\d+);', line)
+                        if match is None:
+                                continue
+                        parts = match.group(1).split()
+                        field_name = parts[-1]
+                        field_type = " ".join(parts[:-1])
+                        # Skip __data_loc fields and the common_* fields
+                        if "__data_loc" in field_type or \
+                           field_name.startswith("common_"):
+                                continue
+                        text += "        %s %s;\n" % (field_type, field_name)
+                return text
 
+        def _generate_tpoint_entry_struct(self):
+                # Records the struct name on self before building the C text
+                self.tp_entry_struct_name = (self.probe_func_name +
+                                             "_trace_entry")
+                members = self._generate_tpoint_entry_struct_fields()
+                return """
+struct %s {
+        u64 __do_not_use__;
+%s
+};
+                """ % (self.tp_entry_struct_name, members)
+
+        def _generate_tpoint_entry_prefix(self):
+                # C prologue: read the record saved in __trace_di into 'tp'
+                return """
+        u64 tid = bpf_get_current_pid_tgid();
+        u64 *di = __trace_di.lookup(&tid);
+        if (di == 0) { return 0; }
+        struct %s tp = {};
+        bpf_probe_read(&tp, sizeof(tp), (void *)*di);
+                """ % self.tp_entry_struct_name
+
+        def generate_text(self):
                 program = ""
 
                 # If any entry arguments are probed in a ret probe, we need
                 # to generate an entry probe to collect them
                 prefix = ""
                 if self.entry_probe_required:
-                        program = self._generate_entry_probe()
-                        prefix = self._generate_retprobe_prefix()
+                        program += self._generate_entry_probe()
+                        prefix += self._generate_retprobe_prefix()
                         # Replace $entry(paramname) with a reference to the
                         # value we collected when entering the function:
                         self._replace_entry_exprs()
 
+                # If this is a tracepoint probe, generate a local variable
+                # that enables access to the tracepoint structure and also
+                # the structure definition itself
+                if self.probe_type == "t":
+                        program += self._generate_tpoint_entry_struct()
+                        prefix += self._generate_tpoint_entry_prefix()
+
                 program += self.probe_text.replace("PROBENAME",
                                                    self.probe_func_name)
                 signature = "" if len(self.signature) == 0 \
-                                  or self.is_ret_probe \
+                                  or self.probe_type == "r" \
                                else ", " + self.signature
                 program = program.replace("SIGNATURE", signature)
                 program = program.replace("PID_FILTER",
@@ -364,8 +535,10 @@
 
         def attach(self, bpf):
                 self.bpf = bpf
+                uprobes_start = len(BPF.open_uprobes())
+                kprobes_start = len(BPF.open_kprobes())
                 if self.is_user:
-                        if self.is_ret_probe:
+                        if self.probe_type == "r":
                                 bpf.attach_uretprobe(name=self.library,
                                                   sym=self.function,
                                                   fn_name=self.probe_func_name,
@@ -375,13 +548,17 @@
                                                   sym=self.function,
                                                   fn_name=self.probe_func_name,
                                                   pid=self.pid or -1)
+                        if len(BPF.open_uprobes()) != uprobes_start + 1:
+                                self._bail("error attaching probe")
                 else:
-                        if self.is_ret_probe:
+                        if self.probe_type == "r" or self.probe_type == "t":
                                 bpf.attach_kretprobe(event=self.function,
                                                   fn_name=self.probe_func_name)
                         else:
                                 bpf.attach_kprobe(event=self.function,
                                                   fn_name=self.probe_func_name)
+                        if len(BPF.open_kprobes()) != kprobes_start + 1:
+                                self._bail("error attaching probe")
                 if self.entry_probe_required:
                         self._attach_entry_probe()
 
@@ -406,7 +583,7 @@
 
         def _display_key(self, key):
                 if self.is_default_expr:
-                        if not self.is_ret_probe:
+                        if not self.probe_type == "r":
                                 return "total calls"
                         else:
                                 return "retval = %s" % str(key.v0)
@@ -431,7 +608,7 @@
                                 # Print some nice values if the user didn't
                                 # specify an expression to probe
                                 if self.is_default_expr:
-                                        if not self.is_ret_probe:
+                                        if not self.probe_type == "r":
                                                 key_str = "total calls"
                                         else:
                                                 key_str = "retval = %s" % \
@@ -448,13 +625,14 @@
 class Tool(object):
         examples = """
 Probe specifier syntax:
-        {p,r}:[library]:function(signature)[:type[,type...]:expr[,expr...][:filter]][#label]
+        {p,r,t}:{[library],category}:function(signature)[:type[,type...]:expr[,expr...][:filter]][#label]
 Where:
-        p,r        -- probe at function entry or at function exit
+        p,r,t      -- probe at function entry, function exit, or kernel tracepoint
                       in exit probes: can use $retval, $entry(param), $latency
         library    -- the library that contains the function
                       (leave empty for kernel functions)
-        function   -- the function name to trace
+        category   -- the category of the kernel tracepoint (e.g. net, sched)
+        function   -- the function name to trace (or tracepoint name)
         signature  -- the function's parameters, as in the C header
         type       -- the type of the expression to collect (supports multiple)
         expr       -- the expression to collect (supports multiple)
@@ -502,6 +680,12 @@
         Count fork() calls in libc across all processes
         Can also use funccount.py, which is easier and more flexible
 
+argdist -H 't:block:block_rq_complete():u32:tp.nr_sector'
+        Print histogram of number of sectors in completing block I/O requests
+
+argdist -C 't:irq:irq_handler_entry():int:tp.irq'
+        Aggregate interrupts by interrupt request (IRQ)
+
 argdist  -H \\
         'p:c:sleep(u32 seconds):u32:seconds' \\
         'p:c:nanosleep(struct timespec *req):long:req->tv_nsec'
@@ -555,7 +739,7 @@
                                 Specifier("hist", histspecifier, self.args.pid))
                 if len(self.specifiers) == 0:
                         print("at least one specifier is required")
-                        exit(1)
+                        exit()
 
         def _generate_program(self):
                 bpf_source = """
@@ -567,6 +751,8 @@
                         bpf_source += "#include <%s>\n" % include
                 bpf_source += Specifier.generate_auto_includes(
                                 map(lambda s: s.raw_spec, self.specifiers))
+                bpf_source += Tracepoint.generate_decl()
+                bpf_source += Tracepoint.generate_entry_probe()
                 for specifier in self.specifiers:
                         bpf_source += specifier.generate_text()
                 if self.args.verbose:
@@ -574,8 +760,12 @@
                 self.bpf = BPF(text=bpf_source)
 
         def _attach(self):
+                Tracepoint.attach(self.bpf)
                 for specifier in self.specifiers:
                         specifier.attach(self.bpf)
+                if self.args.verbose:
+                        print("open uprobes: %s" % BPF.open_uprobes())
+                        print("open kprobes: %s" % BPF.open_kprobes())
 
         def _main_loop(self):
                 count_so_far = 0
@@ -601,7 +791,7 @@
                 except:
                         if self.args.verbose:
                                 traceback.print_exc()
-                        else:
+                        elif sys.exc_type is not SystemExit:
                                 print(sys.exc_value)
 
 if __name__ == "__main__":