# benchs/bench_fw_notebook.ipynb
import matplotlib.pyplot as plt
import itertools
from faiss.contrib.evaluation import OperatingPoints
from enum import Enum
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO as BIO
from faiss.benchs.bench_fw.utils import filter_results, ParetoMode, ParetoMetric
from copy import copy
import numpy as np
import datetime
import glob
import io
import json
from zipfile import ZipFile
import tabulate
import getpass
# Load the SIFT1M benchmark results from the current user's home directory.
username = getpass.getuser()
root = f"/home/{username}/simsearch/data/ivf/results/sift1M"
results = BIO(root).read_json("result.json")
results.keys()  # notebook cell: inspect the top-level keys of the result dict
results['experiments']  # notebook cell: inspect the raw experiment entries
def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):
    """Scatter-plot accuracy vs. cost for a set of benchmark experiments.

    Args:
        experiments: iterable of (accuracy, space, time, key, experiment-dict)
            tuples, as produced by filter_results.
        accuracy_title: x-axis label.
        cost_title: y-axis label.
        plot_space: if True plot the space cost on y, otherwise the time cost.
        plot: optional matplotlib Axes to draw into; a fresh subplot is
            created when None.
    """
    if plot is None:
        plot = plt.subplot()
    # Group points per index name; runs evaluated with snap=1 get a distinct
    # "snap" suffix so they show up as a separate series.
    x = {}
    y = {}
    for accuracy, space, time, _, v in experiments:
        # .get() avoids a KeyError when 'snap' is absent from search_params
        # (the original indexed it directly).
        snapped = v.get('search_params', {}).get("snap") == 1
        idx_name = v['index'] + ("snap" if snapped else "")
        x.setdefault(idx_name, []).append(accuracy)
        y.setdefault(idx_name, []).append(space if plot_space else time)
    plot.set_xlabel(accuracy_title)
    plot.set_ylabel(cost_title)
    plot.set_yscale("log")  # costs span orders of magnitude
    marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
    for index in x.keys():
        # linewidth=0: scatter style, markers only
        plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)
    plot.legend(bbox_to_anchor=(1, 1), loc='upper left')
# index local optima
# Per-index Pareto frontier (time vs. knn accuracy), restricted to
# configurations with knn intersection >= 0.95.
accuracy_metric = "knn_intersection"
fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)
plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 32 cores)", plot_space=False)
# global optima
# Single global frontier across all non-Flat indices, with a lower accuracy floor.
accuracy_metric = "knn_intersection"
fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.25, name_filter=lambda n: not n.startswith("Flat"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
#fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith("Flat"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 32 cores)", plot_space=False)
def pretty_params(p):
    """Return a copy of the params dict *p* with a zero-valued 'snap' key dropped."""
    cleaned = copy(p)
    if cleaned.get('snap') == 0:
        del cleaned['snap']
    return cleaned
# Render the filtered experiments as an HTML table (notebook display).
# construction_params[1] holds the quantizer configuration for IVF indices.
tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params']))
                   for accuracy, space, time, k, v in fr],
                  tablefmt="html",
                  headers=["accuracy","space", "time", "factory", "quantizer cfg", "search cfg"])
# index local optima @ precision 0.8
# NOTE(review): range_search_recall_at_precision is not defined or imported
# anywhere in this file -- presumably it comes from another notebook cell or
# from faiss.benchs.bench_fw; confirm before running these cells.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
# index local optima @ precision 0.2
precision = 0.2
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
# global optima @ precision 0.8
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
plot_metric(fr, accuracy_title=f"range recall @ precision {precision}", cost_title="time (seconds, 16 cores)")
def plot_range_search_pr_curves(experiments):
    """Plot precision/recall curves for the "weighted" range-search experiments.

    Args:
        experiments: iterable of (accuracy, space, time, key, experiment-dict)
            tuples, as produced by filter_results.
    """
    x = {}
    y = {}
    # Bug fix: iterate over the 'experiments' parameter; the original looped
    # over the global 'fr', silently ignoring its argument. Also dropped the
    # unused 'show' local that only served commented-out filtering.
    for _, _, _, k, v in experiments:
        if ".weighted" in k:
            x[k] = v['range_search_pr']['recall']
            y[k] = v['range_search_pr']['precision']
    plt.title("range search recall")
    plt.xlabel("recall")
    plt.ylabel("precision")
    for index in x.keys():
        plt.plot(x[index], y[index], '.', label=index)
    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
# PR curves at the global time+space Pareto frontier (precision 0.8).
# NOTE(review): range_search_recall_at_precision is not defined in this file;
# presumably provided elsewhere -- confirm before running.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(results, evaluation="weighted", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)
plot_range_search_pr_curves(fr)
# bigann scaling study: one subplot per database scale (scale is in millions).
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))
fig.tight_layout()
for plot, scale in zip(plots, scales, strict=True):  # strict= requires Python 3.10+
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 64 cores)", plot=plot)
# Time vs. database size per factory at a fixed recall floor (accuracy >= 0.9).
x = {}
y = {}
accuracy=0.9
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))
#fig.tight_layout()
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    scale *= 1_000_000  # convert from millions to number of database vectors
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    seen = set()  # keep only the first (best) configuration per factory
    print(scale)
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        # "HNSW" in fact or
        if fact in seen or fact in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            continue
        seen.add(fact)
        if fact not in x:
            x[fact] = []
            y[fact] = []
        x[fact].append(scale)
        # total query time includes the coarse quantizer's time
        y[fact].append(exp["time"] + exp["quantizer"]["time"])
        if (exp["knn_intersection"] > 0.92):
            print(fact)
            print(exp["search_params"])
            print(exp["knn_intersection"])
#plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 64 cores)", plot=plot)
plt.title(f"recall @ 1 = {accuracy*100}%")
plt.xlabel("database size")
plt.ylabel("time")
plt.xscale("log")
plt.yscale("log")
marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
for index in x.keys():
    # dashed lines visually distinguish the HNSW family
    if "HNSW" in index:
        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle="dashed")
    else:
        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))
plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
# global optima
# Reconstruction-quality frontier: minimize space at sym_recall >= 0.9,
# using encode time as the time metric.
accuracy_metric = "sym_recall"
fr = filter_results(results, evaluation="rec", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)
# NOTE(review): accuracy_title says "knn intersection" but the metric plotted
# here is sym_recall -- the label is presumably stale; confirm.
plot_metric(fr, accuracy_title="knn intersection", cost_title="space", plot_space=True)
def pretty_time(s):
    """Format a duration in seconds as a compact "1d 2h 3m 4.500s" string.

    Returns the string "None" when *s* is None. Zero-valued leading units
    are omitted; a bare zero renders as "0.000s".
    """
    if s is None:
        return "None"
    # Truncate to millisecond resolution before splitting into units.
    s = int(s * 1000) / 1000
    minutes, s = divmod(s, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    parts = ""
    if days > 0:
        parts += f"{int(days)}d "
    if hours > 0:
        parts += f"{int(hours)}h "
    if minutes > 0:
        parts += f"{int(minutes)}m "
    if s > 0 or not parts:
        parts += f"{s:.3f}s"
    return parts
def pretty_size(s):
    """Format a byte count as "xMB" / "xKB" (power-of-1024, one decimal,
    trailing zeros stripped) or the plain number below 1 KiB."""
    for limit, unit in ((1024 * 1024, "MB"), (1024, "KB")):
        if s > limit:
            value = f"{s / limit:.1f}".rstrip('0').rstrip('.')
            return value + unit
    return f"{s}"
def pretty_mse(m):
    """Format a mean-squared-error value with 6 decimals; None -> "None"."""
    return "None" if m is None else f"{m:.6f}"
# Collect per-scale (nprobe, recall, time) operating points for each factory
# into a dict keyed by "<scale>M", for later export.
data = {}
root = "/checkpoint/gsz/bench_fw/bigann"
scales = [1, 2, 5, 10, 20, 50]
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(results, evaluation="knn", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)
    d = {}
    data[f"{scale}M"] = d
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        # "HNSW" in fact or
        if fact in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            continue
        if fact not in d:
            d[fact] = []
        # time includes the coarse quantizer's contribution
        d[fact].append({
            "nprobe": exp["search_params"]["nprobe"],
            "recall": exp["knn_intersection"],
            "time": exp["time"] + exp["quantizer"]["time"],
        })
data  # notebook cell: display the collected operating points
# with open("/checkpoint/gsz/bench_fw/codecs.json", "w") as f:
#     json.dump(data, f)
ds = "deep1b"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
results = BIO(root).read_json(f"result.json")
for k, e in results["experiments"].items():
if "rec" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and "PRQ" in e['factory'] and e['sym_recall'] > 0.0:
code_size = results['indices'][e['codec']]['sa_code_size']
codec_size = results['indices'][e['codec']]['codec_size']
training_time = results['indices'][e['codec']]['training_time']
# training_size = results['indices'][e['codec']]['training_size']
cpu = e['cpu'] if 'cpu' in e else ""
ps = ', '.join([f"{k}={v}" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else " "
eps = ', '.join([f"{k}={v}" for k,v in e['reconstruct_params'].items() if k != "snap"]) if e['reconstruct_params'] else " "
data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
jss.append({
'factory': e['factory'],
'parameters': e['construction_params'][0] if e['construction_params'] else "",
'evaluation_params': e['reconstruct_params'],
'code_size': code_size,
'codec_size': codec_size,
'training_time': training_time,
'training_size': training_size,
'mse': e['mse'],
'sym_recall': e['sym_recall'],
'asym_recall': e['asym_recall'],
'encode_time': e['encode_time'],
'decode_time': e['decode_time'],
'cpu': cpu,
})
print("|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|")
print("|-|-|-|-|-|-|-|-|-|")
data.sort()
for d in data:
print(d[1])
with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json", "w") as f:
json.dump(jss, f)
def read_file(filename: str, keys):
    """Load named entries from a benchmark result .zip file.

    Entries "D", "I", "R" and "lims" are loaded with np.load; "P" is parsed
    as JSON. Any other key raises AssertionError. The loaded objects are
    returned in the order of *keys*.
    """
    loaded = []
    with ZipFile(filename, "r") as archive:
        for key in keys:
            with archive.open(key, "r") as entry:
                if key in ("D", "I", "R", "lims"):
                    loaded.append(np.load(entry))
                elif key == "P":
                    loaded.append(json.load(io.TextIOWrapper(entry)))
                else:
                    raise AssertionError()
    return loaded
ds = "contriever"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
for lf in glob.glob(root + '/*rec*.zip'):
e, = read_file(lf, ['P'])
if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and "PRQ" in e['factory'] and e['sym_recall'] > 0.0:
code_size = e['codec_meta']['sa_code_size']
codec_size = e['codec_meta']['codec_size']
training_time = e['codec_meta']['training_time']
training_size = None # e['codec_meta']['training_size']
cpu = e['cpu'] if 'cpu' in e else ""
ps = ', '.join([f"{k}={v}" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else " "
eps = ', '.join([f"{k}={v}" for k,v in e['reconstruct_params'].items() if k != "snap"]) if e['reconstruct_params'] else " "
if eps in ps and eps != "encode_ils_iters=16" and eps != "max_beam_size=32":
eps = " "
data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
eps = e['reconstruct_params']
del eps['snap']
params = copy(e['construction_params'][0]) if e['construction_params'] else {}
for k, v in e['reconstruct_params'].items():
params[k] = v
jss.append({
'factory': e['factory'],
'params': params,
'construction_params': e['construction_params'][0] if e['construction_params'] else {},
'evaluation_params': e['reconstruct_params'],
'code_size': code_size,
'codec_size': codec_size,
'training_time': training_time,
# 'training_size': training_size,
'mse': e['mse'],
'sym_recall': e['sym_recall'],
'asym_recall': e['asym_recall'],
'encode_time': e['encode_time'],
'decode_time': e['decode_time'],
'cpu': cpu,
})
print("|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|")
print("|-|-|-|-|-|-|-|-|-|")
data.sort()
# for d in data:
# print(d[1])
print(len(data))
with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json", "w") as f:
json.dump(jss, f)