Back to Faiss

Bench Fw Notebook

benchs/bench_fw_notebook.ipynb

1.14.114.4 KB
Original Source
python
import matplotlib.pyplot as plt
import itertools
from faiss.contrib.evaluation import OperatingPoints
from enum import Enum
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO as BIO
from faiss.benchs.bench_fw.utils import filter_results, ParetoMode, ParetoMetric
from copy import copy
import numpy as np
import datetime
import glob
import io
import json
from zipfile import ZipFile
import tabulate
python
# Load the sift1M benchmark results from the current user's results directory.
import getpass
username = getpass.getuser()
root = f"/home/{username}/simsearch/data/ivf/results/sift1M"
results = BIO(root).read_json("result.json")
# top-level keys of the result file (e.g. experiments, indices)
results.keys()
python
# Peek at the raw experiment records loaded above.
results['experiments']
python
def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):
    """Scatter-plot accuracy vs. cost for a set of experiments, one series per index.

    Parameters:
        experiments: iterable of (accuracy, space, time, key, result-dict)
            tuples, as produced by filter_results.
        accuracy_title: x-axis label.
        cost_title: y-axis label.
        plot_space: if True plot the space cost on y, otherwise the time cost.
        plot: optional matplotlib Axes to draw into; a fresh subplot is
            created when None.
    """
    if plot is None:
        plot = plt.subplot()
    x = {}
    y = {}
    for accuracy, space, time, k, v in experiments:
        # Group series by index name; suffix "snap" marks snapped search params.
        # Bug fix: use .get() so experiments whose search_params lack a "snap"
        # key no longer raise KeyError.
        snapped = v.get('search_params', {}).get("snap") == 1
        idx_name = v['index'] + ("snap" if snapped else "")
        if idx_name not in x:
            x[idx_name] = []
            y[idx_name] = []
        x[idx_name].append(accuracy)
        y[idx_name].append(space if plot_space else time)

    plot.set_xlabel(accuracy_title)
    plot.set_ylabel(cost_title)
    # costs span orders of magnitude across indices, so use a log scale
    plot.set_yscale("log")
    marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
    for index in x.keys():
        # linewidth=0 -> markers only (scatter style)
        plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)
    plot.legend(bbox_to_anchor=(1, 1), loc='upper left')
python
# index local optima
# Per-index time-Pareto frontier of knn experiments at >= 0.95 knn intersection.
accuracy_metric = "knn_intersection"
fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)
plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 32 cores)", plot_space=False)
python
# global optima
# Global time-Pareto frontier across all non-Flat indices (min accuracy 0.25).
accuracy_metric = "knn_intersection"
fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.25, name_filter=lambda n: not n.startswith("Flat"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
#fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith("Flat"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 32 cores)", plot_space=False)
python
def pretty_params(p):
    """Return a copy of the params dict with a default-valued (0) 'snap' entry dropped."""
    cleaned = copy(p)
    if cleaned.get('snap') == 0:
        del cleaned['snap']
    return cleaned
    
# Render the filtered experiments as an HTML table; construction_params[1]
# holds the coarse quantizer's configuration.
tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params'])) 
                for accuracy, space, time, k, v in fr],
                tablefmt="html",
                headers=["accuracy","space", "time", "factory", "quantizer cfg", "search cfg"])
python
# index local optima @ precision 0.8
# NOTE(review): range_search_recall_at_precision is not defined or imported
# anywhere in this notebook -- presumably provided by a helper module or an
# earlier session; verify before running.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
python
# index local optima @ precision 0.2
# Same per-index Pareto frontier as above but at a looser precision target.
precision = 0.2
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
python
# global optima @ precision 0.8
# Global (cross-index) time-Pareto frontier at precision 0.8.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
python
def plot_range_search_pr_curves(experiments):
    """Plot the precision/recall curve of each weighted range-search experiment.

    Parameters:
        experiments: iterable of (accuracy, space, time, key, result-dict)
            tuples, as produced by filter_results; only entries whose key
            contains ".weighted" are plotted.
    """
    x = {}
    y = {}
    # Bug fix: iterate over the `experiments` argument rather than the
    # notebook-global `fr`, so the function actually uses its input.
    # (The unused `show` whitelist dict was removed.)
    for _, _, _, k, v in experiments:
        if ".weighted" in k:
            x[k] = v['range_search_pr']['recall']
            y[k] = v['range_search_pr']['precision']

    plt.title("range search recall")
    plt.xlabel("recall")
    plt.ylabel("precision")
    for index in x.keys():
        plt.plot(x[index], y[index], '.', label=index)
    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
python
# Filter to the global time+space Pareto frontier at precision 0.8 and draw
# the precision/recall curves for the surviving experiments.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)
plot_range_search_pr_curves(fr)
python
# Compare knn Pareto frontiers across bigann database scales (1M..50M),
# one stacked subplot per scale, sharing the accuracy axis.
# NOTE: zip(..., strict=True) requires Python >= 3.10.
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))
fig.tight_layout()
for plot, scale in zip(plots, scales, strict=True):
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 64 cores)", plot=plot)
python
# Scaling study: for each bigann database size, take the first (fastest)
# qualifying run per factory at >= 90% recall and plot search time against
# database size on a log-log chart.
x = {}
y = {}
accuracy=0.9
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))
#fig.tight_layout()
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    # convert from millions to number of vectors for the x axis
    scale *= 1_000_000
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    # keep only the first qualifying run per factory at this scale
    seen = set()
    print(scale)
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        # "HNSW" in fact or 
        # skip exact / near-exact baselines
        if fact in seen or fact in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            continue
        seen.add(fact)
        if fact not in x:
            x[fact] = []
            y[fact] = []
        x[fact].append(scale)
        # total search cost includes the coarse quantizer's time
        y[fact].append(exp["time"] + exp["quantizer"]["time"])
        if (exp["knn_intersection"] > 0.92):
            print(fact)
            print(exp["search_params"])
            print(exp["knn_intersection"])

        #plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 64 cores)", plot=plot)
    
plt.title(f"recall @ 1 = {accuracy*100}%")
plt.xlabel("database size")
plt.ylabel("time")
plt.xscale("log")
plt.yscale("log")

marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))    
for index in x.keys():
    # dashed lines distinguish HNSW-based indices from IVF ones
    if "HNSW" in index:
        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle="dashed")
    else:
        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))
plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
python
# global optima
# Space-Pareto frontier of reconstruction ("rec") experiments at
# >= 0.9 sym_recall, costed by encode time.
accuracy_metric = "sym_recall"
fr = filter_results(results, evaluation="rec", accuracy_metric=accuracy_metric, time_metric=lambda e: e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)
# Bug fix: the accuracy axis shows sym_recall here, not knn intersection.
plot_metric(fr, accuracy_title="sym recall", cost_title="space", plot_space=True)
python
def pretty_time(s):
    """Format a duration in seconds as a compact 'Dd Hh Mm S.SSSs' string.

    Returns the string "None" when s is None. Units that are zero are
    omitted; the seconds part is shown whenever it is non-zero or the
    whole duration is zero.
    """
    if s is None:
        return "None"
    # truncate to millisecond resolution
    s = int(s * 1000) / 1000
    minutes, s = divmod(s, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    parts = []
    if days > 0:
        parts.append(f"{int(days)}d ")
    if hours > 0:
        parts.append(f"{int(hours)}h ")
    if minutes > 0:
        parts.append(f"{int(minutes)}m ")
    if s > 0 or not parts:
        parts.append(f"{s:.3f}s")
    return "".join(parts)

def pretty_size(s):
    """Format a byte count with MB/KB units, trimming trailing zeros.

    Values not strictly greater than 1024 are returned as a bare number.
    """
    for limit, unit in ((1024 * 1024, "MB"), (1024, "KB")):
        if s > limit:
            value = f"{s / limit:.1f}".rstrip('0').rstrip('.')
            return value + unit
    return f"{s}"

def pretty_mse(m):
    """Format a mean-squared-error value to six decimals ("None" when absent)."""
    return "None" if m is None else f"{m:.6f}"
python
# Collect per-factory operating points (nprobe, recall, total time) for each
# bigann scale into a nested dict: "<scale>M" -> factory -> list of points.
data = {}
root = "/checkpoint/gsz/bench_fw/bigann"
scales = [1, 2, 5, 10, 20, 50]
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    d = {}
    data[f"{scale}M"] = d
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        # "HNSW" in fact or 
        # skip exact / near-exact baselines
        if fact in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            continue
        if fact not in d:
            d[fact] = []
        d[fact].append({
            "nprobe": exp["search_params"]["nprobe"],
            "recall": exp["knn_intersection"],
            # total time includes the coarse quantizer's contribution
            "time": exp["time"] + exp["quantizer"]["time"],
        })
data
# with open("/checkpoint/gsz/bench_fw/codecs.json", "w") as f:
#    json.dump(data, f)
python
# Codec (reconstruction) report for one dataset: print a markdown table of
# size/accuracy/timing trade-offs per codec and dump the same records to JSON.
ds = "deep1b"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
results = BIO(root).read_json("result.json")
for k, e in results["experiments"].items():
    # reconstruction experiments only, skipping the uncompressed Flat baseline
    if "rec" in k and e['factory'] != 'Flat':
        code_size = results['indices'][e['codec']]['sa_code_size']
        codec_size = results['indices'][e['codec']]['codec_size']
        training_time = results['indices'][e['codec']]['training_time']
        # Bug fix: training_size was referenced below but its assignment was
        # commented out, raising NameError. Read it defensively since the key
        # may be absent from older result files.
        training_size = results['indices'][e['codec']].get('training_size')
        cpu = e['cpu'] if 'cpu' in e else ""
        ps = ', '.join([f"{k}={v}" for k, v in e['construction_params'][0].items()]) if e['construction_params'] else " "
        eps = ', '.join([f"{k}={v}" for k, v in e['reconstruct_params'].items() if k != "snap"]) if e['reconstruct_params'] else " "
        # sort key is the code size; the second element is a pre-rendered
        # markdown table row
        data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
        jss.append({
            'factory': e['factory'],
            'parameters': e['construction_params'][0] if e['construction_params'] else "",
            'evaluation_params': e['reconstruct_params'],
            'code_size': code_size,
            'codec_size': codec_size,
            'training_time': training_time,
            'training_size': training_size,
            'mse': e['mse'],
            'sym_recall': e['sym_recall'],
            'asym_recall': e['asym_recall'],
            'encode_time': e['encode_time'],
            'decode_time': e['decode_time'],
            'cpu': cpu,
        })

print("|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|")
# Bug fix: the separator row must have one cell per header column (13).
print("|-|-|-|-|-|-|-|-|-|-|-|-|-|")
data.sort()
for d in data:
    print(d[1])

with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json", "w") as f:
    json.dump(jss, f)
python
def read_file(filename: str, keys):
    """Load the named entries from a result zip archive.

    Entries "D", "I", "R" and "lims" are deserialized with numpy; "P" is
    parsed as JSON. Any other key raises AssertionError. Results are
    returned as a list in the same order as `keys`.
    """
    loaded = []
    with ZipFile(filename, "r") as archive:
        for key in keys:
            with archive.open(key, "r") as entry:
                if key in ("D", "I", "R", "lims"):
                    loaded.append(np.load(entry))
                elif key == "P":
                    loaded.append(json.load(io.TextIOWrapper(entry)))
                else:
                    raise AssertionError()
    return loaded
python
# Same codec report as above, but assembled from the per-experiment zip files
# ('P' entry holds the experiment's JSON metadata) instead of result.json.
ds = "contriever"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
for lf in glob.glob(root + '/*rec*.zip'):
    e, = read_file(lf, ['P'])
    # skip the uncompressed Flat baseline
    if e['factory'] != 'Flat':
        code_size = e['codec_meta']['sa_code_size']
        codec_size = e['codec_meta']['codec_size']
        training_time = e['codec_meta']['training_time']
        training_size = None  # 'training_size' is not recorded in codec_meta
        cpu = e['cpu'] if 'cpu' in e else ""
        ps = ', '.join([f"{k}={v}" for k, v in e['construction_params'][0].items()]) if e['construction_params'] else " "
        eps = ', '.join([f"{k}={v}" for k, v in e['reconstruct_params'].items() if k != "snap"]) if e['reconstruct_params'] else " "
        # suppress evaluation params that merely repeat the construction params
        if eps in ps and eps != "encode_ils_iters=16" and eps != "max_beam_size=32":
            eps = " "
        data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
        eps = e['reconstruct_params']
        # Bug fix: `del eps['snap']` raised KeyError when 'snap' was absent;
        # pop() removes it only if present. NOTE: this intentionally mutates
        # e['reconstruct_params'], which is also stored in jss below.
        eps.pop('snap', None)
        # merged construction + reconstruction parameters
        params = copy(e['construction_params'][0]) if e['construction_params'] else {}
        for k, v in e['reconstruct_params'].items():
            params[k] = v
        jss.append({
            'factory': e['factory'],
            'params': params,
            'construction_params': e['construction_params'][0] if e['construction_params'] else {},
            'evaluation_params': e['reconstruct_params'],
            'code_size': code_size,
            'codec_size': codec_size,
            'training_time': training_time,
            # 'training_size': training_size,
            'mse': e['mse'],
            'sym_recall': e['sym_recall'],
            'asym_recall': e['asym_recall'],
            'encode_time': e['encode_time'],
            'decode_time': e['decode_time'],
            'cpu': cpu,
        })

print("|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|")
# Bug fix: the separator row must have one cell per header column (12).
print("|-|-|-|-|-|-|-|-|-|-|-|-|")
data.sort()
# for d in data:
#   print(d[1])

print(len(data))

with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json", "w") as f:
    json.dump(jss, f)