packages/native/plugins/qjl-cpu/README.md
Standalone C library for QJL (1-bit JL transform) K-cache compression on
CPU. Mirrors the upstream CUDA reference at
packages/training/scripts/quantization/qjl/csrc/. Designed to be the
algorithmic source of truth for the future llama.cpp block_qjl1_256
GGML quant type and GGML_OP_ATTN_SCORE_QJL op (see the porting plan
at docs/porting/on-device-quantization-porting-plan.md section
"QJL NEON kernel + GGML K-cache hook").
This is not a llama.cpp integration. It is a self-contained library.
For each cached key vector `k` of length 128 (per attention head, per
token):

1. compute `||k||` and store as bf16 (2 bytes),
2. project with a fixed `Π ∈ R^{128 × 256}` drawn from N(0, 1), giving sketch
   `s = k @ Π`,
3. take `sign(s)` and pack 8 signs/byte LSB-first into 32 bytes.

Block layout (`block_qjl1_256`): 32 bytes packed signs + 2 bytes bf16
norm = 34 bytes for what was 256 bytes of bf16 K. Compression
ratio: 7.53× over bf16 K-cache at head_dim=128.
The companion attention-score path takes a pre-projected query sketch
(length 256 per head) and reconstructs the inner product as

`||k|| * sqrt(pi/2) / proj_dim * sum_j sign[t,j] * q_sketch[h_q,j]`,

which is the cosine-similarity estimator the QJL paper proves to be
unbiased and minimally distorted at 1 bit.
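
A scalar sketch of that reconstruction for a single (query head, token) pair is shown below. The names `bf16_to_float` and `sketch_score` are hypothetical and are not the signature of `qjl_score_qk()`; the bit convention (set bit means +1, clear bit means -1, LSB-first) is an assumption. The `sqrt(pi/2)` factor compensates for `E|g| = sqrt(2/pi)` when `g` is a standard normal.

```c
#include <stdint.h>
#include <string.h>

enum { EST_PROJ_DIM = 256 };

/* bf16 is the top 16 bits of an IEEE-754 float, so widening is a shift. */
static float bf16_to_float(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* score = ||k|| * sqrt(pi/2) / proj_dim * sum_j sign_j * q_sketch_j */
static float sketch_score(const uint8_t *signs, uint16_t norm_bf16,
                          const float *q_sketch) {
    float acc = 0.0f;
    for (int j = 0; j < EST_PROJ_DIM; j++) {
        /* Assumption: set bit = +1, clear bit = -1, LSB-first within a byte. */
        int bit = (signs[j / 8] >> (j % 8)) & 1;
        acc += bit ? q_sketch[j] : -q_sketch[j];
    }
    const float scale = 1.2533141f / (float)EST_PROJ_DIM; /* sqrt(pi/2) / proj_dim */
    return bf16_to_float(norm_bf16) * scale * acc;
}
```

Because the sum is accumulated in a fixed scalar order here, a vectorized version that reduces in a different order can differ in the last few bits, which is consistent with a small relative-error tolerance on scores.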
GQA sharing: a query head h_q reads from h_kv = h_q / (n_heads / n_kv_heads) of the K cache, matching the upstream
cuda_qjl_gqa_score kernel.
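
The head mapping is plain integer division. A minimal sketch, using a hypothetical helper name and assuming `n_heads` is a positive multiple of `n_kv_heads`:

```c
#include <assert.h>

/* Map a query head to the KV head whose cache it reads under GQA. */
static int kv_head_for_query_head(int h_q, int n_heads, int n_kv_heads) {
    assert(n_kv_heads > 0 && n_heads % n_kv_heads == 0);
    int group = n_heads / n_kv_heads; /* query heads sharing one KV head */
    return h_q / group;               /* integer division truncates */
}
```

With 32 query heads and 8 KV heads, query heads 0 to 3 read KV head 0, heads 4 to 7 read KV head 1, and so on. When `n_heads == n_kv_heads` the mapping is the identity.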
- `include/qjl/qjl.h`: Public API.
- `src/qjl_block.h`: Block-format static asserts.
- `src/qjl_projection.c`: Deterministic Π builder (MT + Box-Muller).
  NOT bit-compatible with `torch.randn`; for fixture parity, ship Π in
  the sidecar.
- `src/qjl_quantize_ref.c`: Scalar reference quantize/dequantize.
- `src/qjl_quantize_avx2.c`: AVX2 quantize/dequantize for x86_64.
- `src/qjl_quantize_neon.c`: NEON quantize/dequantize for AArch64.
- `src/qjl_score_ref.c`: Scalar reference GQA score.
- `src/qjl_score_avx2.c`: AVX2 GQA score.
- `src/qjl_score_neon.c`: NEON GQA score.
- `src/qjl_dispatch.c`: Compile-time best-impl dispatch + helpers.
- `test/qjl_bench.c`: `--parity` + `--throughput` modes.
- `test/fixtures/`: Generated by `scripts/gen_fixtures.py`.
- `scripts/gen_fixtures.py`: Emits a binary fixture from torch's
  `qjl_pure_pytorch_quantize` reference.
Native:

    cmake -B build -S packages/native/plugins/qjl-cpu
    cmake --build build -j

Cross-compile for arm64 (linux-musl, via zig):

    printf '#!/bin/sh\nexec zig cc --target=aarch64-linux-musl "$@"\n' > /tmp/zigcc-aarch64
    chmod +x /tmp/zigcc-aarch64
    cmake -B build-arm64 -S packages/native/plugins/qjl-cpu \
        -DCMAKE_C_COMPILER=/tmp/zigcc-aarch64 \
        -DCMAKE_C_COMPILER_LAUNCHER= \
        -DCMAKE_SYSTEM_NAME=Linux \
        -DCMAKE_SYSTEM_PROCESSOR=aarch64 \
        -DCMAKE_BUILD_TYPE=Release
    cmake --build build-arm64 -j

Requires zig 0.13+. The same --target=aarch64-linux-android shape
also works for an Android NDK-style binary; the musl form is the one
the rest of this repo uses (see scripts/distro-android/compile-libllama.mjs).
Generate a fixture from the Python reference (requires torch):

    python3 packages/native/plugins/qjl-cpu/scripts/gen_fixtures.py \
        --out packages/native/plugins/qjl-cpu/test/fixtures/qjl_fixtures.bin \
        --n 100

Run parity (compares scalar + AVX2 against the recorded Python output):

    build/qjl_bench --parity packages/native/plugins/qjl-cpu/test/fixtures/qjl_fixtures.bin

Run throughput (uses the bundled MT-projection; does not need a fixture):

    build/qjl_bench --throughput

Recorded on the development host (Intel x86_64, AVX2 + FMA, gcc 13.3,
-O3):
| Op | Scalar ref | AVX2 | Speedup |
|---|---|---|---|
| `quantize_row_*` | ~13–15 µs / vec | ~0.85–1.05 µs / vec | ~14–16× |
| `score_qk` per (q-head, token) | ~675 ns | ~20 ns | ~33–35× |
Parity status: 100/100 vectors match for both ref and avx2
against the qjl_pure_pytorch_quantize Python reference (signs
bit-exact, bf16 norms bit-exact). Score reconstruction matches the
Python reference to relative error < 3e-6 (the residual is
floating-point reduction-order noise; the score path is documented to
allow that).
NEON throughput on real arm64 hardware is TBD — see "Cross-arm64 gap" below.
The arm64 binary cross-compiles cleanly via `zig cc --target=aarch64-linux-musl`, producing a standalone aarch64 ELF that
links against musl. Running it on the available cuttlefish image was
not possible because:

- the image reports `ro.product.cpu.abi=x86_64` and
  has no ARM binary translation (`libndk_translation.so` /
  `libhoudini.so`) installed;
- no `qemu-aarch64` is present in the dev environment.

Options when an arm64 surface becomes available:

1. boot an arm64 cuttlefish image
   (`launch_cvd --cpu_arch=arm64`, requires KVM with arm64 support or a
   true arm64 host),
2. install `qemu-user-static` on the host and run the arm64 binary via
   `qemu-aarch64` user-mode emulation,
3. run the binary on real arm64 hardware.

Options 1 and 2 are documentation-quality; option 3 is the production
target and pairs naturally with the next porting step (the llama.cpp
`block_qjl1_256` GGML quant type).
The Python fixture (scripts/gen_fixtures.py) records:
The C parity test reads Π from the fixture rather than constructing it
in C, because the bundled MT-Box-Muller generator
(qjl_make_projection_mt) is not bit-compatible with torch.randn.
On a real production deployment the projection matrix must travel with
the QJL sidecar (see qjl_apply.py's rand_prj field), and the C
side reads it from there too — the design assumption is that Π is
small (128 * 256 * 4 = 128 KB per layer) and pinned at quantization
time.
Before this can be wired into the llama.cpp fork as a `block_qjl1_256`
GGML quant type:

- `block_qjl1_256` definition in `ggml-common.h` plus
  `ggml_quantize_chunk` + `ggml_dequantize_row_t` table entries. The
  glue is mechanical once this library is upstream.
- `GGML_OP_ATTN_SCORE_QJL` custom op in `ggml-cpu.c` that calls
  `qjl_score_qk()` from this library on the existing K-cache stride
  pattern.
- `LLAMA_QJL_KCACHE` build flag on the llama.cpp side that flips the
  K-cache type to `qjl1_256` and rewires the attention-score path to
  the new op.
- `ELIZA_LLAMA_CACHE_TYPE_K=qjl1_256` env-var plumbing in the eliza
  shim. This is already structurally supported by
  `eliza_llama_context_params_set_type_k` and only needs the new enum
  value.

The QJL-on-K + TurboQuant-on-V combination is the documented endgame (see porting plan section 3): same model, different per-side cache type, ~10× KV reduction at long context. This library is the K-side half of that.
Apache 2.0 — same as the upstream QJL implementation (see
packages/training/scripts/quantization/qjl/LICENSE).