Fix bug with labels; support multiple expressions (tuple keys) in hash probes
diff --git a/tools/argdist.py b/tools/argdist.py
index 76e68fd..fd53cf3 100755
--- a/tools/argdist.py
+++ b/tools/argdist.py
@@ -25,8 +25,8 @@
 {
         PREFIX
         PID_FILTER
-        KEY_EXPR
         if (!(FILTER)) return 0;
+        KEY_EXPR
         COLLECT
         return 0;
 }
@@ -92,11 +92,13 @@
                 # when entering the function.
                 self.args_to_probe = set()
                 regex = r"\$entry\((\w+)\)"
-                for arg in re.finditer(regex, self.expr):
-                        self.args_to_probe.add(arg.group(1))
+                for expr in self.exprs:
+                        for arg in re.finditer(regex, expr):
+                                self.args_to_probe.add(arg.group(1))
                 for arg in re.finditer(regex, self.filter):
                         self.args_to_probe.add(arg.group(1))
-                if "$latency" in self.expr or "$latency" in self.filter:
+                if any(map(lambda expr: "$latency" in expr, self.exprs)) or \
+                   "$latency" in self.filter:
                         self.args_to_probe.add("__latency")
                         self.param_types["__latency"] = "u64"    # nanoseconds
                 for pname in self.args_to_probe:
@@ -139,7 +141,9 @@
                         else:
                                 entry_expr = "$entry(%s)" % pname
                                 val_expr = "(*%s)" % vname
-                        self.expr = self.expr.replace(entry_expr, val_expr)
+                        for i in range(0, len(self.exprs)):
+                                self.exprs[i] = self.exprs[i].replace(
+                                                entry_expr, val_expr)
                         self.filter = self.filter.replace(entry_expr,
                                                           val_expr)
 
@@ -171,7 +175,17 @@
                                    "but got '%s'" % parts[0])
                 if re.match(r"\w+\(.*\)", parts[2]) is None:
                         self._bail(("function signature '%s' has an invalid " +
-                                   "format") % parts[2])
+                                    "format") % parts[2])
+
+        def _parse_expr_types(self, expr_types):
+                if len(expr_types) == 0:
+                        self._bail("no expr types specified")
+                self.expr_types = expr_types.split(',')
+
+        def _parse_exprs(self, exprs):
+                if len(exprs) == 0:
+                        self._bail("no exprs specified")
+                self.exprs = exprs.split(',')
 
         def __init__(self, type, specifier, pid):
                 self.raw_spec = specifier
@@ -195,26 +209,31 @@
                 # the retval in a ret probe, or simply the value "1" otherwise.
                 self.is_default_expr = len(parts) < 5
                 if not self.is_default_expr:
-                        self.expr_type = parts[3]
-                        self.expr = parts[4]
+                        self._parse_expr_types(parts[3])
+                        self._parse_exprs(parts[4])
+                        if len(self.exprs) != len(self.expr_types):
+                                self._bail("mismatched # of exprs and types")
+                        if self.type == "hist" and len(self.expr_types) > 1:
+                                self._bail("histograms can only have 1 expr")
                 else:
                         if not self.is_ret_probe and self.type == "hist":
-                                raise ValueError("dist probes must have expr")
-                        self.expr_type = \
-                                "u64" if not self.is_ret_probe else "int"
-                        self.expr = "1" if not self.is_ret_probe else "$retval"
+                                self._bail("histograms must have expr")
+                        self.expr_types = \
+                                ["u64" if not self.is_ret_probe else "int"]
+                        self.exprs = \
+                                ["1" if not self.is_ret_probe else "$retval"]
                 self.filter = "" if len(parts) != 6 else parts[5]
                 self._substitute_exprs()
 
                 # Do we need to attach an entry probe so that we can collect an 
                 # argument that is required for an exit (return) probe?
+                def check(expr):
+                        keywords = ["$entry", "$latency"]
+                        return any(map(lambda kw: kw in expr, keywords))
                 self.entry_probe_required = self.is_ret_probe and \
-                       ("$entry" in self.expr or "$entry" in self.filter or
-                        "$latency" in self.expr or "$latency" in self.filter)
+                        (any(map(check, self.exprs)) or check(self.filter))
 
                 self.pid = pid
-                # Generating unique names for probes means we can attach
-                # many times to the same function.
                 self.probe_func_name = "%s_probe%d" % \
                         (self.function, Specifier.next_probe_index)
                 self.probe_hash_name = "%s_hash%d" % \
@@ -222,17 +241,71 @@
                 Specifier.next_probe_index += 1
 
         def _substitute_exprs(self):
-                self.expr = self.expr.replace("$retval",
-                                              "(%s)ctx->ax" % self.expr_type)
-                self.filter = self.filter.replace("$retval",
-                                              "(%s)ctx->ax" % self.expr_type)
-                self.expr = self._substitute_aliases(self.expr)
-                self.filter = self._substitute_aliases(self.filter)
+                def repl(expr):
+                        expr = self._substitute_aliases(expr)
+                        return expr.replace("$retval", "ctx->ax")
+                for i in range(0, len(self.exprs)):
+                        self.exprs[i] = repl(self.exprs[i])
+                self.filter = repl(self.filter)
 
-        def _is_string_probe(self):
-                return self.expr_type == "char*" or self.expr_type == "char *"
+        def _is_string(self, expr_type):
+                return expr_type == "char*" or expr_type == "char *"
 
-        def generate_text(self, string_size):
+        def _generate_hash_field(self, i):
+                if self._is_string(self.expr_types[i]):
+                        return "struct __string_t v%d;\n" % i
+                else:
+                        return "%s v%d;\n" % (self.expr_types[i], i)
+
+        def _generate_field_assignment(self, i):
+                if self._is_string(self.expr_types[i]):
+                        return "bpf_probe_read(" + \
+                               "&__key.v%d.s, sizeof(__key.v%d.s), %s);\n" % \
+                                (i, i, self.exprs[i])
+                else:
+                        return "__key.v%d = %s;\n" % (i, self.exprs[i])
+
+        def _generate_hash_decl(self):
+                if self.type == "hist":
+                        return "BPF_HISTOGRAM(%s, %s);" % \
+                               (self.probe_hash_name, self.expr_types[0])
+                else:
+                        text = "struct %s_key_t {\n" % self.probe_hash_name
+                        for i in range(0, len(self.expr_types)):
+                                text += self._generate_hash_field(i)
+                        text += "};\n"
+                        text += "BPF_HASH(%s, struct %s_key_t, u64);\n" % \
+                                (self.probe_hash_name, self.probe_hash_name)
+                        return text
+
+        def _generate_key_assignment(self):
+                if self.type == "hist":
+                        return "%s __key = %s;\n" % \
+                                (self.expr_types[0], self.exprs[0])
+                else:
+                        text = "struct %s_key_t __key = {};\n" % \
+                                self.probe_hash_name
+                        for i in range(len(self.exprs)):
+                                text += self._generate_field_assignment(i)
+                        return text
+
+        def _generate_hash_update(self):
+                if self.type == "hist":
+                        return "%s.increment(bpf_log2l(__key));" % \
+                                self.probe_hash_name
+                else:
+                        return "%s.increment(__key);" % self.probe_hash_name
+
+        def _generate_pid_filter(self):
+                # Kernel probes must filter pid in BPF (attach can't); the
+                # upper 32 bits of pid_tgid are the tgid, i.e. the process id
+                if self.pid is not None and not self.is_user:
+                        return "u32 pid = bpf_get_current_pid_tgid() >> 32;\n" + \
+                               "if (pid != %d) { return 0; }" % self.pid
+                else:
+                        return ""
+
+        def generate_text(self):
                 # We don't like tools writing tools (Brendan Gregg), but this
                 # is an exception because we're letting the user fully
                 # customize the values we probe. As a rule of thumb though,
@@ -246,6 +319,8 @@
                 if self.entry_probe_required:
                         program = self._generate_entry_probe()
                         prefix = self._generate_retprobe_prefix()                                         
+                        # Replace $entry(paramname) with a reference to the
+                        # value we collected when entering the function:
                         self._replace_entry_exprs()
 
                 program += self.probe_text.replace("PROBENAME",
@@ -254,39 +329,12 @@
                                   or self.is_ret_probe \
                                else ", " + self.signature
                 program = program.replace("SIGNATURE", signature)
-                if self.pid is not None and not self.is_user:
-                        # Kernel probes need to explicitly filter pid
-                        program = program.replace("PID_FILTER",
-                                "u32 pid = bpf_get_current_pid_tgid();\n" + \
-                                "if (pid != %d) { return 0; }" % self.pid)
-                else:
-                        program = program.replace("PID_FILTER", "")
-                if self._is_string_probe():
-                        decl = """
-struct %s_key_t { char key[%d]; };
-BPF_HASH(%s, struct %s_key_t, u64);
-""" \
-                        % (self.function, string_size,
-                           self.probe_hash_name, self.function)
-                        collect = "%s.increment(__key);" % self.probe_hash_name
-                        key_expr = """
-struct %s_key_t __key = {0};
-bpf_probe_read(&__key.key, sizeof(__key.key), %s);
-""" \
-                        % (self.function, self.expr)
-                elif self.type == "freq":
-                        decl = "BPF_HASH(%s, %s, u64);" % \
-                                (self.probe_hash_name, self.expr_type)
-                        collect = "%s.increment(__key);" % self.probe_hash_name
-                        key_expr = "%s __key = %s;" % \
-                                   (self.expr_type, self.expr)
-                elif self.type == "hist":
-                        decl = "BPF_HISTOGRAM(%s, %s);" % \
-                                (self.probe_hash_name, self.expr_type)
-                        collect = "%s.increment(bpf_log2l(__key));" % \
-                                  self.probe_hash_name 
-                        key_expr = "%s __key = %s;" % \
-                                   (self.expr_type, self.expr)
+                program = program.replace("PID_FILTER",
+                                          self._generate_pid_filter())
+
+                decl = self._generate_hash_decl()
+                key_expr = self._generate_key_assignment()
+                collect = self._generate_hash_update()
                 program = program.replace("DATA_DECL", decl)
                 program = program.replace("KEY_EXPR", key_expr) 
                 program = program.replace("FILTER",
@@ -318,6 +366,40 @@
                 if self.entry_probe_required:
                         self._attach_entry_probe()
 
+        def _v2s(self, v):
+                # Most fields can be converted with plain str(), but strings
+                # are wrapped in a __string_t which has an .s field
+                if "__string_t" in type(v).__name__:
+                        return str(v.s)
+                return str(v)
+
+        def _display_expr(self, i):
+                # Replace ugly latency calculation with $latency
+                expr = self.exprs[i].replace(
+                        "(bpf_ktime_get_ns() - *____latency_val)", "$latency")
+                # Replace alias values back with the alias name
+                for alias, subst in Specifier.aliases.items():
+                        expr = expr.replace(subst, alias)
+                # Replace retval expression with $retval
+                expr = expr.replace("ctx->ax", "$retval")
+                # Replace ugly (*__param_val) expressions with param name
+                return re.sub(r"\(\*__(\w+)_val\)", r"\1", expr)
+
+        def _display_key(self, key):
+                if self.is_default_expr:
+                        if not self.is_ret_probe:
+                                return "total calls"
+                        else:
+                                return "retval = %s" % str(key.v0)
+                else:
+                        # The key object has v0, ..., vk fields containing
+                        # the values of the expressions from self.exprs
+                        def str_i(i):
+                                key_i = self._v2s(getattr(key, "v%d" % i))
+                                return "%s = %s" % \
+                                        (self._display_expr(i), key_i)
+                        return ", ".join(map(str_i, range(0, len(self.exprs))))
+
         def display(self, top):
                 data = self.bpf.get_table(self.probe_hash_name)
                 if self.type == "freq":
@@ -327,8 +409,6 @@
                         if top is not None:
                                 data = data[-top:]
                         for key, value in data: 
-                                key_val = key.key if self._is_string_probe() \
-                                                  else str(key.value)
                                 # Print some nice values if the user didn't
                                 # specify an expression to probe
                                 if self.is_default_expr:
@@ -336,21 +416,19 @@
                                                 key_str = "total calls"
                                         else:
                                                 key_str = "retval = %s" % \
-                                                          key_val
+                                                          self._v2s(key.v0)
                                 else:
-                                        key_str = "%s = %s" % \
-                                                  (self.expr, key_val)
+                                        key_str = self._display_key(key) 
                                 print("\t%-10s %s" % \
                                       (str(value.value), key_str))
                 elif self.type == "hist":
-                        label = self.label or \
-                                (self.expr if not self.is_default_expr \
-                                           else "retval")
+                        label = self.label or (self._display_expr(0)
+                                if not self.is_default_expr  else "retval")
                         data.print_log2_hist(val_type=label)
 
 examples = """
 Probe specifier syntax:
-        {p,r}:[library]:function(signature)[:type:expr[:filter]][#label]
+        {p,r}:[library]:function(signature)[:type[,type...]:expr[,expr...][:filter]][#label]
 Where:
         p,r        -- probe at function entry or at function exit
                       in exit probes: can use $retval, $entry(param), $latency
@@ -358,8 +436,8 @@
                       (leave empty for kernel functions)
         function   -- the function name to trace
         signature  -- the function's parameters, as in the C header
-        type       -- the type of the expression to collect
-        expr       -- the expression to collect
+        type       -- the type of the expression to collect (supports multiple)
+        expr       -- the expression to collect (supports multiple)
         filter     -- the filter that is applied to collected values
         label      -- the label for this probe in the resulting output
 
@@ -372,7 +450,7 @@
         Print a frequency count of how many times process 1005 called malloc
         with an allocation size of 16 bytes
 
-argdist.py -C 'r:c:gets():char*:$retval#snooped strings'
+argdist.py -C 'r:c:gets():char*:(char*)$retval#snooped strings'
         Snoop on all strings returned by gets()
 
 argdist.py -H 'r::__kmalloc(size_t size):u64:$latency/$entry(size)#ns per byte'
@@ -388,7 +466,7 @@
         the top 5 busiest fds
 
 argdist.py -p 1005 -H 'r:c:read()'
-        Print a histogram of error codes returned by read() in process 1005
+        Print a histogram of results (sizes) returned by read() in process 1005
 
 argdist.py -C 'r::__vfs_read():u32:$PID:$latency > 100000'
         Print frequency of reads by process where the latency was >0.1ms
@@ -451,11 +529,15 @@
         print("at least one specifier is required")
         exit(1)
 
-bpf_source = "#include <uapi/linux/ptrace.h>\n"
+bpf_source = """
+struct __string_t { char s[%d]; };
+
+#include <uapi/linux/ptrace.h>
+""" % args.string_size
 for include in (args.include or []):
         bpf_source += "#include <%s>\n" % include
 for specifier in specifiers:
-        bpf_source += specifier.generate_text(args.string_size)
+        bpf_source += specifier.generate_text()
 
 if args.verbose:
         print(bpf_source)