#!/usr/bin/env python

from __future__ import print_function
from bcc import BPF
from time import sleep, strftime, time
import argparse
import signal

# Epilog text shown at the end of --help; explains the layout of each
# slash-separated output field.
description = "\n".join((
    "The fdbcstat utility displays FDB C API statistics on terminal",
    "that include calls-per-second, average latency and maximum latency",
    "within the given time interval.",
    "",
    "Each field in the output represents the following elements",
    "in a slash-separated format:",
    "- Operation type",
    "- Number of calls per second",
    "- Average latency in microseconds (us)",
    "- Maximum latency in microseconds (us)",
    "",
))

# supported APIs
# note: the array index is important here.
#       it's used in BPF as the function identifier.
# 0: get
# 1: get_range
# 2: get_read_version
# 3: set
# 4: clear
# 5: clear_range
# 6: commit
# Table of traceable fdb_transaction_* calls. The position of each entry is
# the numeric function id used by the BPF program, so the order must not
# change. "waitfuture" is True for calls whose return probe is placed on
# fdb_future_block_until_ready instead of on the call itself; every entry
# starts out enabled.
fdbfuncs = [
    {"name": name, "waitfuture": waits_on_future, "enabled": True}
    for name, waits_on_future in (
        ("get", True),               # 0
        ("get_range", True),         # 1
        ("get_read_version", True),  # 2
        ("set", False),              # 3
        ("clear", False),            # 4
        ("clear_range", False),      # 5
        ("commit", True),            # 6
    )
]

# command-line interface
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="FoundationDB client statistics collector",
    epilog=description,
)

# optional switches, registered in the order they should appear in --help
for short_flag, long_flag, value_type, help_text in (
    ("-p", "--pid", int,
     "Capture for this PID only"),
    ("-i", "--interval", int,
     "Print interval in seconds (Default: 1 second)"),
    ("-d", "--duration", int,
     "Duration in seconds (Default: unset)"),
    ("-f", "--functions", str,
     '''Capture for specific functions (comma-separated) (Default: unset)
Supported functions: get, get_range, get_read_version,
                     set, clear, clear_range, commit'''),
):
    parser.add_argument(short_flag, long_flag, type=value_type, help=help_text)

# the single positional argument
parser.add_argument("libpath", help="Full path to libfdb_c.so")

args = parser.parse_args()

# Fall back to a one-second print interval when -i is absent (or given as 0).
if not args.interval:
    args.interval = 1

# When -f is given, trace only the listed functions.
if args.functions:
    # Tolerate spaces around the commas and ignore empty items ("get, set,").
    requested = set(
        item.strip() for item in args.functions.split(',') if item.strip())

    # Reject names that are not in the table: silently dropping a misspelt
    # name would make the tool trace less than the user asked for.
    unknown = sorted(requested - set(entry['name'] for entry in fdbfuncs))
    if unknown:
        parser.error("unsupported function(s) for -f/--functions: %s"
                     % ", ".join(unknown))

    # Enable exactly the requested functions and disable all the others.
    for entry in fdbfuncs:
        entry['enabled'] = entry['name'] in requested

# locate libfdb_c.so: try the library lookup first and fall back to the
# executable lookup when that yields nothing
libpath = BPF.find_library(args.libpath)
if not libpath:
    libpath = BPF.find_exe(args.libpath)

# neither lookup produced a path, so report it and stop
if libpath is None:
    print("Error: Can't find %s" % args.libpath)
    exit(1)

# main BPF program
# we do not rely on PT_REGS_IP() and BPF.sym() to retrieve the symbol name
# because some "backward-compatible" symbols do not get resolved through BPF.sym().
#
# Layout of the program text below:
# - trace_<name>_entry(): one entry handler per supported function. Each one
#   passes its fdbfuncs index to trace_common_entry(), which stores that index
#   and the start timestamp (ns) in the startfunc/starttime maps, keyed by
#   thread ID.
# - trace_func_return(): the shared return handler. It looks up the start
#   timestamp for the current thread, computes the elapsed time and folds it
#   into the count/total/max record of the "stats" map, keyed by
#   (thread ID, function index).
# - FILTERPID is a placeholder, not C code; it is replaced from Python before
#   the text is handed to BPF().
bpf_text = """
#include <uapi/linux/ptrace.h>

typedef struct _stats_key_t {
    u32 pid;
    u32 func;
} stats_key_t;

typedef struct _stats_val_t {
    u64 cnt;
    u64 total;
    u64 max;
} stats_val_t;

BPF_HASH(starttime, u32, u64);
BPF_HASH(startfunc, u32, u32);
BPF_HASH(stats, stats_key_t, stats_val_t);

static int trace_common_entry(struct pt_regs *ctx, u32 func)
{
    u64 pid_tgid = bpf_get_current_pid_tgid();
    u32 pid = pid_tgid;        /* lower 32-bit = Process ID (Thread ID) */
    u32 tgid = pid_tgid >> 32; /* upper 32-bit = Thread Group ID (Process ID) */

    /* if PID is specified, we'll filter by tgid here */
    FILTERPID

    /* start time in ns */
    u64 ts = bpf_ktime_get_ns();

    /* function type */
    u32 f = func;
    startfunc.update(&pid, &f);

    /* update start time */
    starttime.update(&pid, &ts);

    return 0;
}

int trace_get_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 0);
}

int trace_get_range_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 1);
}

int trace_get_read_version_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 2);
}

int trace_set_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 3);
}

int trace_clear_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 4);
}

int trace_clear_range_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 5);
}

int trace_commit_entry(struct pt_regs *ctx)
{
    return trace_common_entry(ctx, 6);
}

int trace_func_return(struct pt_regs *ctx)
{
    u64 *st; /* start time */
    u64 duration;
    u64 pid_tgid = bpf_get_current_pid_tgid();
    u32 pid = pid_tgid;
    u32 tgid = pid_tgid >> 32;

    /* if PID is specified, we'll filter by tgid here */
    FILTERPID

    /* calculate duration in ns */
    st = starttime.lookup(&pid);
    if (!st || st == 0) {
        return 0; /* missed start */
    }
    /* duration in ns */
    duration = bpf_ktime_get_ns() - *st;
    starttime.delete(&pid);

    /* update stats */
    u32 func, *funcp = startfunc.lookup(&pid);
    if (funcp) {
        func = *funcp;
        stats_key_t key;
        stats_val_t *prev;
        stats_val_t cur;
        key.pid = pid; /* pid here is the thread ID in user space */
        key.func = func;
        prev = stats.lookup(&key);
        if (prev) {
            cur.cnt = prev->cnt + 1;
            cur.total = prev->total + duration;
            cur.max = (duration > prev->max) ? duration : prev->max;
            stats.update(&key, &cur);
        } else {
            cur.cnt = 1;
            cur.total = duration;
            cur.max = duration;
            stats.insert(&key, &cur);
        }
        startfunc.delete(&pid);
    }
    return 0;
}
"""

# Substitute the FILTERPID placeholder: a tgid check when -p was given,
# otherwise nothing at all.
pid_filter = ('if (tgid != %d) { return 0; }' % args.pid) if args.pid else ''
bpf_text = bpf_text.replace('FILTERPID', pid_filter)

# no-op signal handler
def signal_ignore(signal, frame):
    """Accept a signal number and stack frame and do nothing with them."""
    return None

# compile and load the BPF program
b = BPF(text=bpf_text)

# every probe is attached with the requested PID, or -1 when none was given
probe_pid = args.pid or -1

# attach probes; waitfuture records whether any enabled function needs the
# shared return probe on fdb_future_block_until_ready
waitfuture = False
for f in fdbfuncs:
    if f['enabled']:
        symbol = 'fdb_transaction_' + f['name']

        # entry probe: one dedicated handler per function
        b.attach_uprobe(name=libpath, sym=symbol,
                        fn_name='trace_' + f['name'] + '_entry',
                        pid=probe_pid)

        if not f['waitfuture']:
            # the call itself marks the end of the operation
            b.attach_uretprobe(name=libpath, sym=symbol,
                               fn_name="trace_func_return", pid=probe_pid)
        else:
            waitfuture = True

if waitfuture:
    b.attach_uretprobe(name=libpath, sym='fdb_future_block_until_ready',
                       fn_name="trace_func_return", pid=probe_pid)

# number of uprobes that were actually opened
matched = b.num_open_uprobes()

if matched == 0:
    print("0 functions matched... Exiting.")
    exit()

# user-space handle on the BPF "stats" map
stats = b.get_table("stats")

# aggregated stats dictionary, keyed by function name
agg = {}

exiting = 0
seconds = 0
prev = 0.0
now = 0.0

# main loop: sleep for one interval, aggregate the per-thread samples by
# function, print one line, reset, repeat
while True:
    try:
        sleep(args.interval)
        seconds += args.interval
    except KeyboardInterrupt:
        # report what has been collected so far, then leave; any further
        # Ctrl-C is ignored so that the final report is not cut short
        exiting = 1
        signal.signal(signal.SIGINT, signal_ignore)

    # Refresh the timestamps outside the try block: an interrupted sleep must
    # still get a fresh "now", otherwise the rate below is computed against a
    # stale interval (or against 0 - 0 when the very first sleep is cut).
    prev = now
    now = time()
    if prev == 0:
        # the first pass only establishes the time baseline; its samples are
        # dropped because there is no interval length to divide them by
        stats.clear()
        if exiting:
            exit()
        continue

    if args.duration and seconds >= args.duration:
        exiting = 1

    # length of the interval that the samples cover
    elapsed = now - prev
    if elapsed <= 0:
        # the wall clock did not advance (or stepped backwards); fall back to
        # the nominal interval rather than dividing by zero
        elapsed = args.interval

    # walk through the stats and aggregate by the functions
    for k, v in stats.items():
        name = fdbfuncs[k.func]['name']
        entry = agg.setdefault(name, {'cnt': 0, 'total': 0, 'max': 0})
        entry['cnt'] += v.cnt
        entry['total'] += v.total
        # keep the largest per-thread maximum latency (not the call count)
        if v.max > entry['max']:
            entry['max'] = v.max

    # print out aggregated stats as one line: name/calls-per-sec/avg/max
    fields = []
    for name in sorted(agg):
        entry = agg[name]
        fields.append("%s/%d/%d/%d " % (name,
                                        entry['cnt'] / elapsed,
                                        entry['total'] / entry['cnt'] / 1000,  # us
                                        entry['max'] / 1000))  # us
    # emit and flush the whole line at once so that piped output does not
    # show a timestamp without its fields
    print(("%-8s " % strftime("%H:%M:%S")) + "".join(fields), flush=True)

    stats.clear()
    agg.clear()

    if exiting:
        exit()
