# burn-flex
A fast, memory-efficient CPU backend for Burn with multi-threading, SIMD, and optimized matrix multiplication. Runs on std, no_std, and WebAssembly. Supports f16/bf16, zero-copy data loading, and is thread-safe by design.
A detailed comparison with burn-ndarray covers the full architecture, feature coverage, operation-by-operation analysis, and the migration path.
- Zero-copy views: `transpose`, `permute`, `flip`, `narrow`, `slice`, and `unfold` (sliding windows as a strided view instead of materialization); `expand` broadcasts via zero strides.
- Strided matmul: expressions like `q.matmul(k.swap_dims(-2, -1))` run at contiguous speed with no copy.
- Works in `no_std`.
- Native quantization: direct `scale * x_q` dequantization instead of reparsing packed bytes.
- Reuses core types (`Bytes`, `Shape`, `TensorData`, and the `Element` trait) from `burn-backend` and `burn-std`.

Default features: `std`, `simd`, `rayon`.
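The zero-copy views above all come down to one idea: a view is a shape plus per-dimension strides over a single flat buffer. The sketch below is illustrative plain Rust (the `View` type and its methods are hypothetical, not burn-flex's actual API); it shows why transpose is a stride swap and expand is a zero stride, with no data movement in either case.

```rust
// Minimal strided-view sketch (illustrative; not burn-flex's real types).
struct View {
    shape: Vec<usize>,
    strides: Vec<usize>, // elements to skip per step along each dim
}

impl View {
    // Row-major contiguous strides, e.g. shape [2, 3] -> strides [3, 1].
    fn contiguous(shape: &[usize]) -> Self {
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        View { shape: shape.to_vec(), strides }
    }

    // O(1) transpose: swap two dims' extents and strides; no data moves.
    fn swap_dims(mut self, a: usize, b: usize) -> Self {
        self.shape.swap(a, b);
        self.strides.swap(a, b);
        self
    }

    // O(1) broadcast: a size-1 dim expands by setting its stride to 0.
    fn expand(mut self, dim: usize, to: usize) -> Self {
        assert_eq!(self.shape[dim], 1);
        self.shape[dim] = to;
        self.strides[dim] = 0;
        self
    }

    // Flat offset of a logical index into the shared buffer.
    fn offset(&self, index: &[usize]) -> usize {
        index.iter().zip(&self.strides).map(|(i, s)| i * s).sum()
    }
}

fn main() {
    let t = View::contiguous(&[2, 3]).swap_dims(0, 1); // logical 3x2
    assert_eq!(t.offset(&[2, 1]), 5); // t[2][1] aliases flat element 5

    let b = View::contiguous(&[1, 3]).expand(0, 4); // logical 4x3, no copy
    assert_eq!(b.offset(&[3, 2]), 2); // every row aliases the same 3 elements
    println!("ok");
}
```

A strided kernel such as matmul can then read `k.swap_dims(-2, -1)` through these offsets directly, which is why it runs at contiguous speed without a copy.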
| Flag | Default | Description |
|---|---|---|
| `std` | Yes | Standard library support |
| `simd` | Yes | Portable SIMD via macerator; also enables `gemm/wasm-simd128-enable` |
| `rayon` | Yes | Parallel execution for large tensors (forwards `gemm/rayon`) |
| `x86-v4` | No | AVX-512 kernels in gemm for x86_64 (Sapphire Rapids, Zen 4/5) |
| `apple-amx` | No | Apple Silicon AMX matrix coprocessor in gemm (experimental) |
| `tracing` | No | Propagates tracing instrumentation |
| `critical-section` | No | Support for no_std targets without atomic CAS |
Enable opt-in features by passing them to Cargo, either directly on `burn-flex` or through the top-level `burn` crate:
```toml
# Direct
burn-flex = { version = "0.21", features = ["apple-amx"] }

# Via burn
burn = { version = "0.21", features = ["flex", "apple-amx"] }
```
See ARCHITECTURE.md#feature-flags for per-case benchmark impact.
burn-ndarray depends on the ndarray crate, which has been slow to accept contributions and evolve.
burn-flex was built as a from-scratch replacement that addresses the gaps while maintaining full compatibility with Burn's backend test suite.
burn-ndarray now uses macerator SIMD for f32 elementwise ops (tracel-ai/burn#2851), so contiguous f32 binary/unary ops are at parity. Flex advantages come from gemm, integer ops (i32 vs i64), structural zero-copy, and fused kernels.
| Category | Speedup | Highlights |
|---|---|---|
| Binary ops (f32) | ~1x | Both use macerator SIMD for f32 |
| Binary ops (i32) | 1.8-5.3x | Flex uses i32, NdArray uses i64 |
| Matmul (square) | 1.4-3.1x | gemm at small/large; tied at mid-sizes |
| Matmul (batched) | 1.3-2.2x | Multi-head attention shapes |
| Matmul (int) | 3.7-6.5x | gemm vs matrixmultiply for integers |
| Conv2d (3x3) | 1.1-3.7x | Larger kernels and batches benefit most |
| Conv1d | 4.3-9.8x | |
| Conv transpose | 9.2-84x | Direct scatter vs im2col |
| Attention | 1.2-3.0x | Fused softmax, 2-8x lower peak memory |
| Pooling | 1.1-3.1x | |
| Interpolation | 1.1-6.3x | Nearest 4-6x, bilinear 1.7-2.8x |
| Reductions | 1.3-5.4x | Near-zero allocation for scalar results |
| Cumulative ops | 2.1-95x | 1D cumsum: 95x faster |
| Gather/scatter | 1.2-6.4x | |
| Unary (tanh, sin) | 1.3-2.0x | tanh 2x, sin/cos 1.3-1.5x |
| Sort | 2.3-29x | 2D sort up to 29x |
| Repeat dim | 8.9-12x | Single alloc + memcpy vs N slice_assign |
| Tensor creation | 16-33x | zeros/ones/full |
| Embedding | 4.8-5.8x | |
| Quantize | 1.3-1.5x | Fused 2-pass implementation |
| FFT (rfft/irfft) | n/a (new op) | Native implementation, works in no_std |
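The "single alloc + memcpy" note on repeat dim above can be sketched in a few lines. This is illustrative plain Rust (the `repeat` helper is hypothetical, not the crate's code): allocate the output once, then fill it with contiguous block copies rather than issuing N separate slice-assign writes.

```rust
// Illustrative sketch of the repeat-dim strategy: one allocation, then
// memcpy-style block copies of the source (not burn-flex's actual code).
fn repeat(src: &[f32], times: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; src.len() * times];
    for chunk in out.chunks_exact_mut(src.len()) {
        chunk.copy_from_slice(src); // one contiguous copy per repeat
    }
    out
}

fn main() {
    assert_eq!(repeat(&[1.0, 2.0], 3), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
    println!("ok");
}
```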
These reflect better operation representation, not faster computation. burn-ndarray eagerly materializes data for these operations; burn-flex avoids the work entirely through zero-copy views and separated storage layouts.
| Category | Improvement | What changed |
|---|---|---|
| Dequantize | 122-238x | Direct scale * x_q vs reparsing QuantizedBytes each call |
| Quantized ops | 6.1-125x | Dominated by fast dequantize path above |
| Slice/narrow | 2.1-2,400x | Zero-copy strided view vs data copy |
| Unfold | 920-130,000x | O(1) strided view vs O(n) full materialization |
| Expand | 620-2,800x | Zero-copy broadcast (stride=0) vs data copy |
| Int cast | 6.3-26,000x | Zero-copy reinterpret vs element-wise conversion |
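The unfold row above is worth unpacking: a sliding window is expressible purely in stride arithmetic, which is why the view costs O(1) regardless of tensor size. The sketch below is illustrative plain Rust under assumed 1D semantics (the `unfold_view` helper is hypothetical): windows of size `w` at step `s` over a length-`n` buffer are just the view shape `[(n - w) / s + 1, w]` with strides `[s, 1]` over the same data.

```rust
// Why unfold is O(1) as a strided view (illustrative, not the crate's API):
// no elements are copied; only a (shape, strides) pair is computed.
fn unfold_view(n: usize, w: usize, s: usize) -> ((usize, usize), (usize, usize)) {
    assert!(w <= n && s > 0);
    (((n - w) / s + 1, w), (s, 1))
}

fn main() {
    let data: Vec<i32> = (0..6).collect(); // [0, 1, 2, 3, 4, 5]
    let ((rows, cols), (rs, cs)) = unfold_view(data.len(), 3, 2);
    assert_eq!((rows, cols), (2, 3)); // windows start at offsets 0 and 2

    // Reading window 1 through the strides touches the original buffer:
    let win1: Vec<i32> = (0..cols).map(|c| data[1 * rs + c * cs]).collect();
    assert_eq!(win1, vec![2, 3, 4]);
    println!("ok");
}
```

Materializing the same windows, as an eager implementation must, costs O(rows × w) memory and time, which is where the large ratios in the table come from.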
Note on quantization: burn-ndarray simulates quantization by dequantizing to f32 for most operations. The quantized speedups reflect the difference between simulated and native execution, not equivalent algorithms running at different speeds.
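The direct dequantize path mentioned above can be sketched as a single fused pass. This is a hedged illustration, assuming a per-tensor affine i8 scheme with `(scale, zero_point)` stored alongside the data (the function and scheme details are illustrative, not burn-flex's implementation):

```rust
// Direct dequantize sketch: scale * (x_q - zero_point) in one pass, with the
// quantization parameters kept next to the data -- no reparsing of a packed
// byte container on every call. Scheme details are assumed for illustration.
fn dequantize(q: &[i8], scale: f32, zero_point: i32) -> Vec<f32> {
    q.iter()
        .map(|&x| scale * (x as i32 - zero_point) as f32)
        .collect()
}

fn main() {
    let out = dequantize(&[-128i8, 0, 127], 0.5, 0);
    assert_eq!(out, vec![-64.0, 0.0, 63.5]);
    println!("ok");
}
```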
See BENCHMARKS.md for the full breakdown.
Per-op comparison against candle-core, pure-Rust on both sides. Across 11 bench files covering every flex op that intersects with candle's CPU API, flex is as fast or faster on every operation category.
| Category | Representative ratio | Notes |
|---|---|---|
| Batched matmul | 8-11x | Strided gemm, no copy |
| Conv1d (wav2vec2) | 1.4-7.6x | Direct conv path |
| Conv2d (ResNet) | 1.3-4.0x | 1x1 pointwise 4x |
| Conv transpose | 1.5-1.9x | |
| Max/min reductions | 3.8-5.1x | SIMD + zero-alloc |
| Pooling (k=3 s=2) | 1.8-2.5x | |
| Layer norm (fused) | 1.6-3.4x | Two-pass Welford kernel |
| Softmax (fused) | 1.4-1.7x | Three-pass row kernel |
| Cat, gather, select | 1.3-2.5x | |
| Nearest2d interpolation | 1.3-1.4x | |
| Elementwise, matmul, gelu, view ops | tied | Both at memory bandwidth ceiling |
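The fused softmax in the table above is described as a three-pass row kernel. The plain-Rust sketch below shows the standard numerically stable shape of such a kernel (illustrative only, not the crate's implementation): pass 1 finds the row max for stability, pass 2 exponentiates and sums, pass 3 normalizes.

```rust
// Numerically stable three-pass row softmax (illustrative sketch).
fn softmax_row(row: &mut [f32]) {
    // Pass 1: row max, so exp() never overflows.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: exponentiate in place and accumulate the sum.
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    // Pass 3: normalize by the reciprocal (one divide, N multiplies).
    let inv = 1.0 / sum;
    for x in row.iter_mut() {
        *x *= inv;
    }
}

fn main() {
    let mut row = [1.0f32, 2.0, 3.0];
    softmax_row(&mut row);
    let total: f32 = row.iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
    assert!(row[2] > row[1] && row[1] > row[0]);
    println!("ok");
}
```

Fusing the passes over each row keeps intermediates in cache, which is also why the attention results earlier report a lower peak-memory footprint.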
burn-backend-tests pass across all feature flag combinations:

- `no-default-features` (no_std, no SIMD, no rayon)
- `no-default-features` + `simd` (no_std with SIMD)
- `std`
- `std` + `simd`
- `std` + `rayon`
- `std` + `simd` + `rayon` (default)

Additional coverage:

- The burn-no-std-tests integration suite passes (MNIST model inference in `#![no_std]`)
- `thumbv6m-none-eabi` (ARM Cortex-M0+, no atomic pointers)
- `thumbv7m-none-eabi` (ARM Cortex-M3)
- `wasm32-unknown-unknown`
- burn-onnx tests pass