tools: improve sslsniff (send buffer & filtering)

This makes a few improvements:
    * Much larger data payloads can now be captured, and the new
      --max-buffer-size CLI option allows changing the capture size.
    * Fixes handling of non-ASCII protocols: previously the buffer was
      defined as an array of chars, which made Python ctypes treat it as a
      NUL-terminated string and prevented displaying any data past the
      first NUL byte (which is very common for HTTP/2).
    * Adds more filtering and display options (-x/--extra to print the
      UID and TID columns, -u/--uid <uid> to filter by UID).

This also correctly handles the rare cases when bpf_probe_read_user fails
(the buffer is then treated as empty and is not displayed).
diff --git a/tools/sslsniff.py b/tools/sslsniff.py
index 02b7360..8bc61ce 100755
--- a/tools/sslsniff.py
+++ b/tools/sslsniff.py
@@ -4,7 +4,8 @@
 #           GnuTLS and NSS
 #           For Linux, uses BCC, eBPF.
 #
-# USAGE: sslsniff.py [-h] [-p PID] [-c COMM] [-o] [-g] [-d]
+# USAGE: sslsniff.py [-h] [-p PID] [-u UID] [-x] [-c COMM] [-o] [-g] [-n] [-d]
+#                    [--hexdump] [--max-buffer-size SIZE]
 #
 # Licensed under the Apache License, Version 2.0 (the "License")
 #
@@ -23,17 +24,23 @@
 examples = """examples:
     ./sslsniff              # sniff OpenSSL and GnuTLS functions
     ./sslsniff -p 181       # sniff PID 181 only
+    ./sslsniff -u 1000      # sniff only UID 1000
     ./sslsniff -c curl      # sniff curl command only
     ./sslsniff --no-openssl # don't show OpenSSL calls
     ./sslsniff --no-gnutls  # don't show GnuTLS calls
     ./sslsniff --no-nss     # don't show NSS calls
     ./sslsniff --hexdump    # show data as hex instead of trying to decode it as UTF-8
+    ./sslsniff -x           # show process UID and TID
 """
 parser = argparse.ArgumentParser(
     description="Sniff SSL data",
     formatter_class=argparse.RawDescriptionHelpFormatter,
     epilog=examples)
 parser.add_argument("-p", "--pid", type=int, help="sniff this PID only.")
+parser.add_argument("-u", "--uid", type=int, default=None,
+                    help="sniff this UID only.")
+parser.add_argument("-x", "--extra", action="store_true",
+                    help="show extra fields (UID, TID)")
 parser.add_argument("-c", "--comm",
                     help="sniff only commands matching string.")
 parser.add_argument("-o", "--no-openssl", action="store_false", dest="openssl",
@@ -48,6 +55,8 @@
                     help=argparse.SUPPRESS)
 parser.add_argument("--hexdump", action="store_true", dest="hexdump",
                     help="show data as hexdump instead of trying to decode it as UTF-8")
+parser.add_argument("--max-buffer-size", type=int, default=8192,
+                    help="maximum number of data bytes captured per event (default: 8192)")
 args = parser.parse_args()


@@ -55,34 +64,58 @@
 #include <linux/ptrace.h>
 #include <linux/sched.h>        /* For TASK_COMM_LEN */

+#define MAX_BUF_SIZE __MAX_BUF_SIZE__   /* set from --max-buffer-size */
+
 struct probe_SSL_data_t {
         u64 timestamp_ns;
         u32 pid;
-        char comm[TASK_COMM_LEN];
-        char v0[464];
+        u32 tid;
+        u32 uid;
         u32 len;
+        int buf_filled;         /* 1 only if buf holds the copied payload */
+        char comm[TASK_COMM_LEN];
+        u8 buf[MAX_BUF_SIZE];   /* u8, not char: payload may contain NUL bytes */
 };

+#define BASE_EVENT_SIZE ((size_t)(&((struct probe_SSL_data_t*)0)->buf)) /* offset of buf */
+#define EVENT_SIZE(X) (BASE_EVENT_SIZE + ((size_t)(X)))
+
+
+BPF_PERCPU_ARRAY(ssl_data, struct probe_SSL_data_t, 1); /* scratch event, too big for the BPF stack */
 BPF_PERF_OUTPUT(perf_SSL_write);

 int probe_SSL_write(struct pt_regs *ctx, void *ssl, void *buf, int num) {
+        int ret = -1;   /* stays non-zero when no user buffer is read */
+        u32 zero = 0;
         u64 pid_tgid = bpf_get_current_pid_tgid();
         u32 pid = pid_tgid >> 32;
+        u32 tid = pid_tgid;
+        u32 uid = bpf_get_current_uid_gid();

-        FILTER
+        PID_FILTER
+        UID_FILTER
+        struct probe_SSL_data_t *data = ssl_data.lookup(&zero);
+        if (!data)
+                return 0;

-        struct probe_SSL_data_t __data = {0};
-        __data.timestamp_ns = bpf_ktime_get_ns();
-        __data.pid = pid;
-        __data.len = num;
+        data->timestamp_ns = bpf_ktime_get_ns();
+        data->pid = pid;
+        data->tid = tid;
+        data->uid = uid;
+        data->len = num;
+        data->buf_filled = 0;
+        bpf_get_current_comm(&data->comm, sizeof(data->comm));
+        u32 buf_copy_size = min((size_t)MAX_BUF_SIZE, (size_t)num); /* never overflow buf */

-        bpf_get_current_comm(&__data.comm, sizeof(__data.comm));
+        if (buf != 0)
+                ret = bpf_probe_read_user(data->buf, buf_copy_size, buf);

-        if ( buf != 0) {
-                bpf_probe_read_user(&__data.v0, sizeof(__data.v0), buf);
-        }
+        if (!ret)
+                data->buf_filled = 1;
+        else
+                buf_copy_size = 0;

-        perf_SSL_write.perf_submit(ctx, &__data, sizeof(__data));
+        perf_SSL_write.perf_submit(ctx, data, EVENT_SIZE(buf_copy_size));
         return 0;
 }

@@ -94,47 +127,74 @@
         u64 pid_tgid = bpf_get_current_pid_tgid();
         u32 pid = pid_tgid >> 32;
         u32 tid = (u32)pid_tgid;
+        u32 uid = bpf_get_current_uid_gid();

-        FILTER
+        PID_FILTER
+        UID_FILTER

         bufs.update(&tid, (u64*)&buf);
         return 0;
 }

 int probe_SSL_read_exit(struct pt_regs *ctx, void *ssl, void *buf, int num) {
+        u32 zero = 0;
         u64 pid_tgid = bpf_get_current_pid_tgid();
         u32 pid = pid_tgid >> 32;
         u32 tid = (u32)pid_tgid;
+        u32 uid = bpf_get_current_uid_gid();
+        int ret = -1;   /* stays non-zero unless the user buffer is read */

-        FILTER
+        PID_FILTER
+        UID_FILTER

         u64 *bufp = bufs.lookup(&tid);
-        if (bufp == 0) {
+        if (bufp == 0)
                 return 0;
-        }

-        struct probe_SSL_data_t __data = {0};
-        __data.timestamp_ns = bpf_ktime_get_ns();
-        __data.pid = pid;
-        __data.len = PT_REGS_RC(ctx);
+        int len = PT_REGS_RC(ctx);
+        if (len <= 0) { // read failed: drop the saved pointer, report nothing
+                bufs.delete(&tid); return 0; }

-        bpf_get_current_comm(&__data.comm, sizeof(__data.comm));
+        struct probe_SSL_data_t *data = ssl_data.lookup(&zero);
+        if (!data)
+                return 0;

-        if (bufp != 0) {
-                bpf_probe_read_user(&__data.v0, sizeof(__data.v0), (char *)*bufp);
-        }
+        data->timestamp_ns = bpf_ktime_get_ns();
+        data->pid = pid;
+        data->tid = tid;
+        data->uid = uid;
+        data->len = (u32)len;
+        data->buf_filled = 0;
+        u32 buf_copy_size = min((size_t)MAX_BUF_SIZE, (size_t)len); /* never overflow buf */
+
+        bpf_get_current_comm(&data->comm, sizeof(data->comm));
+
+        if (bufp != 0)
+                ret = bpf_probe_read_user(&data->buf, buf_copy_size, (char *)*bufp);

         bufs.delete(&tid);

-        perf_SSL_read.perf_submit(ctx, &__data, sizeof(__data));
+        if (!ret)
+                data->buf_filled = 1;
+        else
+                buf_copy_size = 0;
+
+        perf_SSL_read.perf_submit(ctx, data, EVENT_SIZE(buf_copy_size));
         return 0;
 }
 """

 if args.pid:
-    prog = prog.replace('FILTER', 'if (pid != %d) { return 0; }' % args.pid)
+    prog = prog.replace('PID_FILTER', 'if (pid != %d) { return 0; }' % args.pid)
 else:
-    prog = prog.replace('FILTER', '')
+    prog = prog.replace('PID_FILTER', '')
+
+if args.uid is not None:
+    prog = prog.replace('UID_FILTER', 'if (uid != %d) { return 0; }' % args.uid)
+else:
+    prog = prog.replace('UID_FILTER', '')
+
+prog = prog.replace('__MAX_BUF_SIZE__', str(args.max_buffer_size))  # size of buf[] in the BPF program

 if args.debug or args.ebpf:
     print(prog)
@@ -179,14 +239,15 @@
                        fn_name="probe_SSL_read_exit", pid=args.pid or -1)
 
 # define output data structure in Python
-TASK_COMM_LEN = 16  # linux/sched.h
-MAX_BUF_SIZE = 464  # Limited by the BPF stack
 
 
 # header
-print("%-12s %-18s %-16s %-7s %-6s" % ("FUNC", "TIME(s)", "COMM", "PID",
-                                       "LEN"))
+header = "%-12s %-18s %-16s %-7s %-6s" % ("FUNC", "TIME(s)", "COMM", "PID", "LEN")
 
+if args.extra:
+    header += " %-7s %-7s" % ("UID", "TID")
+
+print(header)
 # process event
 start = 0
 
@@ -202,6 +263,16 @@
 def print_event(cpu, data, size, rw, evt):
     global start
     event = b[evt].event(data)
+    # The BPF program copies at most min(MAX_BUF_SIZE, len) bytes, where
+    # MAX_BUF_SIZE is taken from --max-buffer-size.
+    buf_size = min(event.len, args.max_buffer_size)
+
+    # buf_filled is set only when the BPF-side copy succeeded.
+    if event.buf_filled == 1:
+        buf = bytearray(event.buf[:buf_size])
+    else:
+        buf_size = 0  # nothing captured: the whole length is reported as lost
+        buf = bytearray()

     # Filter events by command
     if args.comm:
@@ -216,19 +287,38 @@

     e_mark = "-" * 5 + " END DATA " + "-" * 5

-    truncated_bytes = event.len - MAX_BUF_SIZE
+    truncated_bytes = event.len - buf_size
     if truncated_bytes > 0:
         e_mark = "-" * 5 + " END DATA (TRUNCATED, " + str(truncated_bytes) + \
                 " bytes lost) " + "-" * 5

-    fmt = "%-12s %-18.9f %-16s %-7d %-6d\n%s\n%s\n%s\n\n"
+    base_fmt = "%(func)-12s %(time)-18.9f %(comm)-16s %(pid)-7d %(len)-6d"
+
+    if args.extra:  # columns must match the header printed at startup
+        base_fmt += " %(uid)-7d %(tid)-7d"
+
+    fmt = base_fmt + "\n%(begin)s\n%(data)s\n%(end)s\n\n"
     if args.hexdump:
-        unwrapped_data = binascii.hexlify(event.v0)
-        data = textwrap.fill(unwrapped_data.decode('utf-8', 'replace'),width=32)
+        unwrapped_data = binascii.hexlify(buf)
+        data = textwrap.fill(unwrapped_data.decode('utf-8', 'replace'), width=32)
     else:
-        data = event.v0.decode('utf-8', 'replace')
-    print(fmt % (rw, time_s, event.comm.decode('utf-8', 'replace'),
-                 event.pid, event.len, s_mark, data, e_mark))
+        data = buf.decode('utf-8', 'replace')
+
+    fmt_data = {
+        'func': rw,
+        'time': time_s,
+        'comm': event.comm.decode('utf-8', 'replace'),
+        'pid': event.pid,
+        'tid': event.tid,
+        'uid': event.uid,
+        'len': event.len,
+        'begin': s_mark,
+        'end': e_mark,
+        'data': data,
+    }
+
+    print(fmt % fmt_data)
+

 b["perf_SSL_write"].open_perf_buffer(print_event_write)
 b["perf_SSL_read"].open_perf_buffer(print_event_read)