# burn-flex vs burn-ndarray: Comprehensive Comparison

This document compares burn-flex (the proposed replacement) against burn-ndarray (the current CPU backend), demonstrating full operation coverage and laying out the architectural differences between the two.

## Executive Summary

burn-flex is a from-scratch CPU backend built to replace burn-ndarray. The ndarray crate has been slow to evolve: it lacks f16/bf16 support, is limited to 6 dimensions, uses unsigned-only strides (preventing zero-copy flip), and offers no native quantized storage, so quantization must be simulated rather than executed natively. burn-flex addresses all of these while passing the full burn-backend-tests suite, all ONNX model checks, and real model inference (ALBERT, MiniLM).

Performance improvements fall into two categories:

- **Compute gains (1.1-9.7x):** Better algorithms and libraries (gemm over matrixmultiply, Arc COW for buffer reuse, SIMD reductions).
- **Structural improvements (up to 166,000x):** Operations that burn-ndarray eagerly materializes (unfold, expand, slice, dequantize) are represented as zero-copy views or direct lookups in burn-flex, avoiding the work entirely.

burn-flex uses significantly less memory, supports f16/bf16 natively, runs on no_std/WASM/embedded, and has no dimension limit.


## 1. Architecture

### Tensor Representation

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Storage | `Arc<Bytes>` (type-erased bytes) | `enum NdArrayTensor { F64(NdArrayStorage<f64>), F32(...), ... }` |
| Dtype | Runtime `DType` field on `FlexTensor` | Compile-time via enum variant |
| Dispatch | `match` on dtype at op entry, cast once | `execute_with_dtype!` macro expands a match for every op |
| Clone cost | O(1) `Arc` refcount increment | O(1) `ArcArray` refcount increment |
| COW | `Arc::make_mut` / `is_unique()` | `ArcArray::is_unique()`; `NdArrayStorage::Borrowed` always returns false |
| Metadata | `Layout { shape, strides: Vec<isize>, start_offset }` | ndarray's internal strides (`usize` only) |
| Stride sign | Signed (`isize`) for zero-copy flip | Unsigned (`usize`); flip requires a data copy |

`FlexTensor` (44 bytes, not counting the shape vector):

```rust
struct FlexTensor {
    data: Arc<Bytes>,    // 8 bytes (pointer)
    layout: Layout,      // shape + strides + offset
    dtype: DType,        // 1 byte enum
}
```

`NdArrayTensor` (an enum with 11 typed variants):

```rust
enum NdArrayTensor {
    F64(NdArrayStorage<f64>),
    F32(NdArrayStorage<f32>),
    // ... 9 more variants
}
```

**Key insight:** Flex uses one struct for all dtypes with runtime dispatch. NdArray uses a typed enum with macro-based dispatch. Flex's approach is simpler (no macros, no generics plumbing) and enables operations to handle all dtypes uniformly.
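The dispatch pattern can be sketched as follows. This is an illustrative, self-contained approximation, not the actual burn-flex code; `FlexTensor`, `binary_op`, and the two-variant `DType` here are simplified stand-ins:

```rust
use std::sync::Arc;

#[derive(Clone, Copy)]
enum DType { F32, F64 } // the real enum covers many more dtypes

// Simplified stand-in for a type-erased tensor (real storage is Arc<Bytes>).
struct FlexTensor {
    data: Arc<Vec<u8>>,
    dtype: DType,
}

// Cast the erased bytes once, then run a monomorphized kernel.
// NOTE: bytemuck::cast_slice checks size and alignment at runtime;
// the real backend's byte container is responsible for alignment.
fn binary_op<E: bytemuck::Pod>(
    lhs: &FlexTensor,
    rhs: &FlexTensor,
    f: impl Fn(E, E) -> E,
) -> Vec<E> {
    let a: &[E] = bytemuck::cast_slice(&lhs.data);
    let b: &[E] = bytemuck::cast_slice(&rhs.data);
    a.iter().zip(b).map(|(&x, &y)| f(x, y)).collect()
}

// One `match` at op entry instead of a macro-expanded match per op body.
fn add(lhs: &FlexTensor, rhs: &FlexTensor) -> Vec<u8> {
    match lhs.dtype {
        DType::F32 => bytemuck::cast_slice(&binary_op::<f32>(lhs, rhs, |x, y| x + y)).to_vec(),
        DType::F64 => bytemuck::cast_slice(&binary_op::<f64>(lhs, rhs, |x, y| x + y)).to_vec(),
    }
}
```

The point of the pattern is that the `match` happens once per operation call; each arm then runs a fully monomorphized kernel over typed slices.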

### Backend Type

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Type | `struct Flex;` (unit struct) | `struct NdArray<E = f32, I = i64, Q = i8>` (3 generic params) |
| Float element | Runtime (f32/f64/f16/bf16) | Compile-time `E: FloatNdArrayElement` (f32 or f64 only) |
| Int element | Runtime (i8-i64, u8-u64) | Compile-time `I: IntNdArrayElement` |
| Quant element | Runtime | Compile-time `Q: QuantElement` |

Flex eliminates generic parameters entirely: users write `Flex` instead of `NdArray<f32, i64, i8>`, and dtype selection happens at runtime via `DType`.


## 2. Feature Coverage

### Float Dtypes

| Dtype | burn-flex | burn-ndarray |
| --- | --- | --- |
| f32 | Full support (native) | Full support (native) |
| f64 | Full support (native) | Full support (native) |
| f16 | Full support (native) | Not supported |
| bf16 | Full support (via f32 conversion for compute-heavy ops) | Not supported |
| Flex32 | Not applicable | Maps to f32 |

burn-flex's f16 support is native for all operations. For matmul and convolution, the gemm crate has native f16 kernels (since v0.15). bf16 converts to f32 for compute-heavy ops (matmul, conv) because gemm lacks native bf16 support.
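As a sketch of that widen/compute/narrow path for bf16, using the `half` crate's `bf16` type (the helper is illustrative, not burn-flex's actual conversion code):

```rust
use half::bf16;

// Compute-heavy bf16 path: widen to f32, run the f32 kernel, narrow back.
fn bf16_via_f32(input: &[bf16], f32_kernel: impl Fn(&[f32]) -> Vec<f32>) -> Vec<bf16> {
    let widened: Vec<f32> = input.iter().map(|x| x.to_f32()).collect();
    f32_kernel(&widened).into_iter().map(bf16::from_f32).collect()
}
```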

### Integer Dtypes

| Dtype | burn-flex | burn-ndarray |
| --- | --- | --- |
| i64 | Full support | Full support |
| i32 | Full support | Full support |
| i16 | Full support | Full support |
| i8 | Full support | Full support |
| u64 | Full support | Full support |
| u32 | Full support | Full support |
| u16 | Full support | Full support |
| u8 | Full support | Full support |

Both backends support the same integer dtypes.

### Bool

| Feature | burn-flex | burn-ndarray |
| --- | --- | --- |
| Storage | u8 (1 byte per element) | bool (1 byte per element via ndarray) |
| Operations | All BoolTensorOps | All BoolTensorOps |

### Quantization

| Feature | burn-flex | burn-ndarray |
| --- | --- | --- |
| Quantize | Per-tensor and per-block symmetric | Per-tensor and per-block symmetric |
| Dequantize | `scale * x_q` (direct multiply, 135-232x faster) | Reparses `QuantizedBytes` on every call |
| Scale storage | `Vec<f32>` stored separately | `QParams<f32>` in `NdArrayQTensor` |
| Quantized layout ops | Zero-copy (permute, flip, expand, slice, select) | Copies entire tensor |
| Quantized ordering ops | Skip dequantization (argmax, argmin, gather on i8 directly) | Dequantize to f32, then operate |
| QuantStore | Native | Native |
| QuantValue | Q8F, Q8S | Q8F, Q8S (+ Q4/Q2 for export_tests) |

The fundamental difference is scale storage. Flex stores scales separately, so dequantization is a simple `scale * x_q` multiply. NdArray stores everything in `QuantizedBytes`, which must be reparsed on every access, making that parsing the bottleneck for all quantized operations.
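The difference is easy to see in code. A minimal sketch of the separated-scale scheme (names are illustrative; per-block shown alongside per-tensor):

```rust
// Per-tensor symmetric dequantize: x = scale * x_q.
// With the scale stored separately, this is one pass over the i8
// payload -- no per-call parsing of a packed byte format.
fn dequantize(values: &[i8], scale: f32) -> Vec<f32> {
    values.iter().map(|&q| scale * q as f32).collect()
}

// Per-block variant: one scale per fixed-size block of values.
fn dequantize_blocked(values: &[i8], scales: &[f32], block: usize) -> Vec<f32> {
    values
        .chunks(block)
        .zip(scales)
        .flat_map(|(chunk, &s)| chunk.iter().map(move |&q| s * q as f32))
        .collect()
}
```

With the scale already in hand, the per-element cost is one multiply; there is no format parsing on the hot path.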


## 3. Operation Coverage

### Tensor Operations (FloatTensorOps)

All operations listed below are implemented by both backends unless marked otherwise.

| Operation | burn-flex | burn-ndarray | Notes |
| --- | --- | --- | --- |
| from_data | Yes | Yes | |
| into_data | Yes | Yes | |
| random | Yes | Yes | |
| empty / zeros / ones | Yes | Yes | |
| full | Yes | Yes | |
| add / sub / mul / div | Yes | Yes | |
| add_scalar / sub_scalar / mul_scalar / div_scalar | Yes | Yes | |
| remainder | Yes | Yes | |
| remainder_scalar | Yes | Yes | |
| matmul | Yes | Yes | Flex uses gemm, NdArray uses matrixmultiply |
| neg | Yes | Yes | |
| recip | Yes | Yes | |
| swap_dims / permute | Yes | Yes | Both zero-copy |
| reshape | Yes | Yes | Both zero-copy when contiguous |
| gather / scatter_add | Yes | Yes | |
| select / select_add | Yes | Yes | |
| slice / slice_assign | Yes | Yes | Flex: zero-copy view; NdArray: may copy |
| mask_fill / mask_where | Yes | Yes | |
| equal / not_equal / greater / lower / greater_equal / lower_equal | Yes | Yes | |
| equal_elem / not_equal_elem / greater_elem / lower_elem | Yes | Yes | |
| sum / sum_dim / mean / mean_dim / prod / prod_dim | Yes | Yes | |
| max / max_dim / max_dim_with_indices | Yes | Yes | |
| min / min_dim / min_dim_with_indices | Yes | Yes | |
| argmax / argmin | Yes | Yes | |
| any / any_dim / all / all_dim | Yes | Yes | |
| exp / log / log1p | Yes | Yes | |
| powf / powf_scalar / powi / powi_scalar | Yes | Yes | |
| sqrt / abs / sign | Yes | Yes | |
| cos / sin / tanh | Yes | Yes | |
| erf | Yes | Yes | |
| cat | Yes | Yes | |
| into_int / into_bool | Yes | Yes | |
| clamp / clamp_min / clamp_max | Yes | Yes | |
| expand | Yes | Yes | Flex: zero-copy; NdArray: copies |
| flip | Yes | Yes | Flex: zero-copy (signed strides); NdArray: copies |
| repeat_dim | Yes | Yes | |
| sort / sort_with_indices / argsort | Yes | Yes | |
| cumsum / cumprod / cummin / cummax | Yes | Yes | |
| narrow | Yes | Yes | Flex: zero-copy; NdArray: may copy |
| chunk | Yes | Yes | |
| cross | Yes | Yes | |
| unfold | Yes | Yes | Flex: zero-copy (strided view); NdArray: materializes |
| round / floor / ceil | Yes | Yes | |
| cast | Yes | Yes | |
| grid_sample_2d | Yes | Yes | |
| bool_select | Yes | Yes | |
| int_powi | Yes | Yes | |

### Module Operations (ModuleOps)

| Operation | burn-flex | burn-ndarray | Notes |
| --- | --- | --- | --- |
| conv1d | Yes | Yes | Flex: delegates to conv3d |
| conv2d | Yes | Yes | Flex: delegates to conv3d |
| conv3d | Yes | Yes | Flex: unified implementation |
| conv_transpose1d | Yes | Yes | Flex: delegates to conv_transpose3d |
| conv_transpose2d | Yes | Yes | Flex: delegates to conv_transpose3d |
| conv_transpose3d | Yes | Yes | Flex: unified implementation |
| deform_conv2d | Yes | Yes | |
| deform_conv2d_backward | Yes | Yes | |
| avg_pool2d | Yes | Yes | Flex: delegates to pool3d |
| avg_pool2d_backward | Yes | Yes | |
| max_pool2d | Yes | Yes | Flex: delegates to pool3d |
| max_pool2d_with_indices | Yes | Yes | |
| max_pool2d_with_indices_backward | Yes | Yes | |
| adaptive_avg_pool2d | Yes | Yes | |
| adaptive_avg_pool2d_backward | Yes | Yes | |
| interpolate | Yes | Yes | Nearest, bilinear, bicubic |
| attention (SDPA) | Yes | Yes | Flex: auto-selects naive or flash by score matrix size; NdArray: matmul + softmax |
| rfft | Yes | No | Flex: Cooley-Tukey with complex packing, radix-4, SIMD, compile-time twiddles; no_std |
| irfft | Yes | No | Flex: inverse packing trick, SIMD via conjugate-forward-conjugate; no_std |

### Int and Bool Operations

Both backends implement all IntTensorOps and BoolTensorOps. The operations mirror float ops where applicable (arithmetic, comparison, reduction, gather/scatter, slice, etc.) plus type-specific operations (int_random uniform, bool_not, bool_and, bool_or, bool_xor).

### Quantized Operations (QTensorOps)

Both backends implement all QTensorOps. Most operations follow a dequantize-op-requantize pattern. Flex optimizes by:

- Storing scales separately for O(1) dequantization access
- Zero-copy layout ops on quantized tensors (permute, flip, expand, slice, select)
- Skipping dequantization for ordering ops (argmax, argmin, gather with tensor-level quant); see the sketch below
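The ordering-op shortcut works because symmetric quantization uses a positive scale, so `scale * x_q` is monotonic in `x_q`: the argmax of the quantized values is the argmax of the dequantized values. A sketch (illustrative, not the backend's code):

```rust
// Symmetric quantization has scale > 0, so ordering is preserved:
// scale * a <= scale * b  <=>  a <= b. Argmax can run on the raw i8
// payload with no dequantization pass.
fn argmax_quantized(values: &[i8]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .max_by_key(|&(_, &v)| v) // ties resolve to the last occurrence here
        .map(|(i, _)| i)
}
```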

### Activation Operations (ActivationOps)

Both backends implement all ActivationOps via the default trait implementations (relu, gelu, etc.).

### Transaction Operations

Both backends implement TransactionOps for batched tensor operations.


## 4. Dimension Limits

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Max dimensions | Unlimited (arbitrary rank) | 6 (hardcoded in the reshape macro) |
| Enforcement | Dynamic `Vec<isize>` for strides | Static `Dim<[usize; N]>` requires a match on 1-6 |

burn-ndarray's dimension limit comes from its `reshape!` macro, which matches on dimensions 1-6:

```rust
match $D {
    1 => reshape!(ty $ty, n 1, ...),
    // ...
    6 => reshape!(ty $ty, n 6, ...),
    _ => panic!("NdArray supports arrays up to 6 dimensions"),
}
```

burn-flex uses IxDyn-equivalent dynamic shapes with no upper bound.
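For illustration, contiguous (row-major) strides for an arbitrary-rank shape are a simple backward product scan, so nothing in this representation caps the rank (a sketch, not burn-flex's actual layout code):

```rust
// Row-major strides for any rank: stride[i] = product of dims after i.
fn contiguous_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1] as isize;
    }
    strides
}

fn main() {
    // Rank 8 -- beyond ndarray's 6-dimension ceiling.
    let shape = [2usize, 3, 4, 5, 2, 3, 4, 5];
    assert_eq!(contiguous_strides(&shape).len(), 8);
}
```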


## 5. Zero-Copy Operations

| Operation | burn-flex | burn-ndarray |
| --- | --- | --- |
| transpose | Zero-copy (swap strides) | Zero-copy (ndarray view) |
| permute | Zero-copy (reorder strides) | Zero-copy (ndarray view) |
| reshape | Zero-copy if contiguous | Zero-copy if standard layout |
| slice / narrow | Zero-copy (offset + strides) | May allocate depending on path |
| flip | Zero-copy (negate stride) | Copies data |
| unfold | Zero-copy (O(1) strided view) | O(n) full materialization |
| expand | Zero-copy (set stride to 0) | Copies data |

Flex's signed strides (`isize`) enable zero-copy flip, which is impossible with ndarray's unsigned strides. The unfold case is especially dramatic: Flex returns a strided view in ~50 ns regardless of tensor size, while NdArray copies all window data (milliseconds for large tensors).
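All three of these tricks are pure metadata edits on a (shape, strides, offset) triple. A self-contained sketch of the stride arithmetic (the `View` type is illustrative; burn-flex's Layout differs in detail):

```rust
// A view is (shape, signed strides, element offset) over shared storage.
struct View {
    shape: Vec<usize>,
    strides: Vec<isize>,
    offset: isize,
}

impl View {
    // flip: negate the stride and move the offset to the old last element.
    fn flip(mut self, dim: usize) -> Self {
        self.offset += self.strides[dim] * (self.shape[dim] as isize - 1);
        self.strides[dim] = -self.strides[dim];
        self
    }

    // expand: broadcast a size-1 dim by setting its stride to 0.
    fn expand(mut self, dim: usize, size: usize) -> Self {
        assert_eq!(self.shape[dim], 1, "can only expand size-1 dims");
        self.shape[dim] = size;
        self.strides[dim] = 0;
        self
    }

    // unfold: add a window dimension that re-reads existing elements.
    fn unfold(mut self, dim: usize, size: usize, step: usize) -> Self {
        let windows = (self.shape[dim] - size) / step + 1;
        let s = self.strides[dim];
        self.shape[dim] = windows;
        self.strides[dim] = s * step as isize;
        self.shape.push(size);
        self.strides.push(s);
        self
    }
}
```

Because `unfold` only appends a window dimension that revisits existing elements, the cost is O(1) however many windows the result describes, which is where the 166,000x figure in Section 15 comes from.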


## 6. Memory Strategy

### In-Place Mutation

| Strategy | burn-flex | burn-ndarray |
| --- | --- | --- |
| Unique check | `Arc::strong_count(&data) == 1` | `ArcArray::is_unique()` |
| In-place threshold | Contiguous at offset 0 AND unique | Unique (via SIMD ops, not all ops) |
| Binary op reuse | Reuses lhs buffer when contiguous | Allocates new for most ops |
| Allocation savings | 3x less for binary ops (4.2 MB vs 12.6 MB for 1M f32) | Standard ndarray allocation |
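A sketch of the uniqueness-gated reuse described above, using `Arc::get_mut` (which succeeds only when the reference count is 1); illustrative, not the actual burn-flex implementation:

```rust
use std::sync::Arc;

// Add `rhs` into `lhs`'s buffer when `lhs` is the sole owner;
// otherwise fall back to allocating a fresh output (copy-on-write).
fn add_cow(lhs: &mut Arc<Vec<f32>>, rhs: &[f32]) -> Arc<Vec<f32>> {
    if let Some(buf) = Arc::get_mut(lhs) {
        // Unique: mutate in place, zero extra allocation.
        for (a, b) in buf.iter_mut().zip(rhs) {
            *a += *b;
        }
        Arc::clone(lhs)
    } else {
        // Shared: allocate the result.
        Arc::new(lhs.iter().zip(rhs).map(|(a, b)| a + b).collect())
    }
}
```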

### Zero-Copy Loading

Both backends support zero-copy loading from external sources (burnpack files, mmap'd data):

| Feature | burn-flex | burn-ndarray |
| --- | --- | --- |
| Mechanism | `Arc<Bytes>` wraps borrowed data directly | `NdArrayStorage::Borrowed` holds `Bytes` + shape |
| COW trigger | `Arc::make_mut` clones on shared mutation | `into_owned()` copies borrowed data to `ArcArray` |
| View access | `storage::<E>()` via bytemuck cast | `view()` via unsafe `ArrayView` from raw pointer |
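The bytemuck-based view access reduces to a checked reinterpretation of the byte buffer. A sketch with `bytemuck::cast_slice`, which verifies length and alignment at runtime (the backend's byte container is responsible for guaranteeing alignment):

```rust
// Reinterpret a raw byte buffer as a typed slice without copying.
// bytemuck checks length and alignment at runtime; no `unsafe` needed.
fn typed_view<E: bytemuck::Pod>(bytes: &[u8]) -> &[E] {
    bytemuck::cast_slice(bytes)
}

fn main() {
    let floats = [1.0f32, 2.0, 3.0];
    let bytes: &[u8] = bytemuck::cast_slice(&floats);
    assert_eq!(typed_view::<f32>(bytes), floats.as_slice());
}
```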

## 7. SIMD

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Library | macerator (required with simd feature) | macerator (optional with simd feature) |
| Dispatch | `Arch::new().dispatch(kernel)` | Same macerator dispatch |
| ISAs | NEON, AVX2, AVX512, SSE, SIMD128, scalar fallback | NEON, AVX2, SSE, SIMD128, scalar fallback |
| Coverage | Binary ops, comparisons, boolean ops, reductions, unary ops | Binary ops, comparisons, unary ops, conv, pool |
| Without SIMD | Scalar fallback module (simd/scalar.rs) | Falls back to ndarray operations |

Both use macerator for portable SIMD. NdArray additionally has SIMD-optimized conv and pool kernels. Flex relies on the gemm crate's built-in SIMD for matmul/conv performance.


## 8. Matrix Multiplication

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Library | gemm crate (v0.18) | matrixmultiply crate (via ndarray) |
| f32 | Native gemm kernel | matrixmultiply |
| f64 | Native gemm kernel | matrixmultiply |
| f16 | Native gemm kernel (since v0.15) | Not supported |
| bf16 | Convert to f32, gemm, convert back | Not supported |
| i32 matmul | Manual nested loop | Manual nested loop |
| Parallelism | Rayon via gemm (threshold: 192^3) | Rayon via `iter_range_par` macro |
| Batched | Parallel over batches + per-batch gemm | Parallel over batches + ndarray `general_mat_mul` |
| Broadcast | Handles batch broadcast natively | Handles batch broadcast via stride mapping |
| BLAS option | No (pure Rust only) | Yes (Accelerate, OpenBLAS, Netlib via feature flags) |

burn-ndarray offers optional BLAS acceleration (Accelerate on macOS, OpenBLAS, Netlib) through feature flags. burn-flex uses only the gemm crate, which is pure Rust but highly optimized with its own SIMD kernels. The gemm crate consistently outperforms matrixmultiply by 1.3-3.4x on Apple M3 Max.


## 9. Convolutions

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Algorithm | im2col + gemm (unified 3D) | Direct computation (per-dimension implementations) |
| conv1d | Delegates to conv3d | Separate implementation |
| conv2d | Delegates to conv3d | Separate implementation |
| conv3d | Single unified implementation | Separate implementation |
| f16 support | Native gemm | Not supported |
| bf16 support | Via f32 conversion | Not supported |
| Parallelism | Rayon over batches and groups | `iter_range_par` over batches |
| SIMD conv | Via gemm SIMD kernels | macerator-based SIMD conv kernel |

Flex's unified 3D approach means one implementation covers all dimensionalities. The tradeoff is that 1D/2D convolutions are first expanded to 3D shapes, a negligible overhead since the gemm call dominates.

NdArray has dedicated SIMD conv/pool kernels via macerator, which can be faster for specific patterns. Flex relies on the gemm crate's SIMD for all compute-heavy paths.
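To make the im2col + gemm strategy concrete, here is a minimal single-channel 2D sketch (stride 1, no padding, plain dot products standing in for the gemm call; burn-flex's unified 3D implementation generalizes this):

```rust
// im2col: lay out each kernel-sized window as a column, so convolution
// becomes one matrix multiply: a (1 x k*k) kernel row against a
// (k*k x out_h*out_w) column matrix.
fn im2col(input: &[f32], h: usize, w: usize, k: usize) -> (Vec<f32>, usize, usize) {
    let (out_h, out_w) = (h - k + 1, w - k + 1);
    let cols = out_h * out_w;
    let mut m = vec![0.0f32; k * k * cols];
    for oy in 0..out_h {
        for ox in 0..out_w {
            let col = oy * out_w + ox;
            for ky in 0..k {
                for kx in 0..k {
                    m[(ky * k + kx) * cols + col] = input[(oy + ky) * w + (ox + kx)];
                }
            }
        }
    }
    (m, k * k, cols)
}

// conv = kernel-row (1 x k*k) times im2col matrix (k*k x cols).
fn conv2d_via_im2col(input: &[f32], h: usize, w: usize, kernel: &[f32], k: usize) -> Vec<f32> {
    let (m, rows, cols) = im2col(input, h, w, k);
    (0..cols)
        .map(|c| (0..rows).map(|r| kernel[r] * m[r * cols + c]).sum())
        .collect()
}
```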


## 10. Parallelism

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| Library | rayon (optional) | rayon (optional, called "multi-threads") |
| Feature flag | rayon | multi-threads |
| Threshold | 4M elements for memory-bound ops | Via `run_par!` / `iter_range_par!` macros |
| Scope | Large tensors, batch dims, pool, conv | Matmul batches, ops via macros |
| gemm parallelism | Rayon via `Parallelism::Rayon(0)` | matrixmultiply threading |
| Without feature | Single-threaded (all ops work) | Single-threaded (all ops work) |
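The element-count threshold amounts to a guard like the following (a sketch; the constant comes from the table above, the helper name is illustrative):

```rust
use rayon::prelude::*;

// Memory-bound ops only gain from threads past a size cutoff
// (the table above cites 4M elements).
const PAR_THRESHOLD: usize = 4_000_000;

fn map_maybe_parallel(data: &mut [f32], f: impl Fn(f32) -> f32 + Sync) {
    if data.len() >= PAR_THRESHOLD {
        data.par_iter_mut().for_each(|x| *x = f(*x));
    } else {
        data.iter_mut().for_each(|x| *x = f(*x));
    }
}
```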

## 11. Platform Support

| Target | burn-flex | burn-ndarray |
| --- | --- | --- |
| x86_64 (std) | Yes | Yes |
| aarch64 (std) | Yes (primary target) | Yes |
| wasm32-unknown-unknown | Yes (verified) | Yes (claimed, not verified) |
| thumbv6m-none-eabi (Cortex-M0+) | Yes (verified, no atomic ptrs) | Not verified |
| thumbv7m-none-eabi (Cortex-M3) | Yes (verified) | Not verified |
| no_std | Yes (tested, MNIST inference) | Yes (supported) |

burn-flex has been explicitly tested on embedded targets with Burn's burn-no-std-tests integration suite (MNIST model inference).


## 12. Dependencies

### burn-flex

| Dependency | Purpose | Required |
| --- | --- | --- |
| burn-backend | Backend traits, types | Always |
| burn-ir | BackendIr trait | Always |
| burn-std | Bytes, Shape, platform abstractions | Always |
| half | f16/bf16 types | Always |
| bytemuck | Zero-copy type casting | Always |
| num-traits | Numeric traits (libm for no_std) | Always |
| gemm | Matrix multiplication | Always |
| macerator | Portable SIMD | Optional (simd) |
| aligned-vec | SIMD-aligned allocation | Optional (simd) |
| rayon | Parallelism | Optional (rayon) |

Total: 7 required + 3 optional

### burn-ndarray

| Dependency | Purpose | Required |
| --- | --- | --- |
| burn-backend | Backend traits, types | Always |
| burn-std | Platform abstractions | Always |
| burn-autodiff | Autodiff support | Optional (std) |
| burn-ir | IR types | Always |
| ndarray | N-dimensional array library | Always |
| matrixmultiply | Matrix multiplication | Always |
| atomic_float | Atomic f32/f64 | Always |
| const-random | Compile-time random | Always |
| libm | Math functions for no_std | Always |
| num-traits | Numeric traits | Always |
| paste | Macro utilities | Always |
| rand | Random number generation | Always |
| macerator | Portable SIMD | Optional (simd) |
| bytemuck | Type casting | Optional (simd) |
| itertools | Iterator utilities | Optional (simd) |
| seq-macro | Sequence macros | Optional (simd) |
| rayon | Parallelism | Optional (multi-threads) |
| blas-src | BLAS bindings | Optional (blas-*) |
| openblas-src | OpenBLAS | Optional (blas-openblas) |
| portable-atomic | Atomics for no-atomic-ptr targets | Conditional |
| portable-atomic-util | Atomic utilities | Conditional |

Total: 12 required + 9 optional + 2 conditional

burn-flex has significantly fewer dependencies, with no dependency on ndarray itself, no macro utility crates, and no BLAS bindings.


## 13. Codebase Size

| Metric | burn-flex | burn-ndarray |
| --- | --- | --- |
| Source files | 38 | 37 |
| Total lines | ~23,500 | ~11,400 |
| ops/ directory | ~19,700 lines | ~8,200 lines |
| SIMD module | ~1,200 lines | ~2,100 lines |

burn-flex has roughly 2x the code. This is because:

  1. Flex implements all ops from scratch (ndarray delegates to the ndarray crate's built-in ops)
  2. Flex has dedicated optimized implementations (pool, conv, reduce, cumulative, gather/scatter)
  3. Flex has more comprehensive dtype handling (f16/bf16 paths for every op)
  4. Flex has explicit contiguous/non-contiguous fast paths throughout

## 14. Testing

| Aspect | burn-flex | burn-ndarray |
| --- | --- | --- |
| burn-backend-tests | All pass (6 feature flag combos) | All pass |
| burn-no-std-tests | Pass (MNIST inference) | Not explicitly verified |
| ONNX model checks | All pass | All pass |
| Real model inference | ALBERT, MiniLM | Not documented |
| Feature combos tested | no-default, simd, std, std+simd, std+rayon, std+simd+rayon | Default |
| Edge-case robustness | Integer overflow, rounding, zero-size, invalid params | Standard |
| Embedded builds | thumbv6m, thumbv7m, wasm32 | wasm32 |

## 15. Performance Summary

All benchmarks on Apple M3 Max, default features enabled.

### Compute Performance

Genuine algorithmic and library improvements:

| Category | Flex vs NdArray | Why |
| --- | --- | --- |
| Binary ops (f32) | 2.4-3.9x faster | Arc COW avoids allocation; 3x less memory |
| Binary ops (i64) | 1.5-6.4x faster | Same COW benefits |
| Matmul (square) | 1.1-3.4x faster | gemm > matrixmultiply |
| Matmul (batched) | 1.8-3.2x faster | Better batch parallelism |
| Attention | 1.2-2.4x faster | Flash attention, 2-8.5x lower peak memory |
| Conv2d | 1.2-4.0x faster | im2col+gemm vs direct |
| Conv1d | 4.3-9.6x faster | Unified 3D avoids overhead |
| Pooling | 1.2-3.1x faster | Unified 3D, better parallelism |
| Interpolation | 1.2-3.6x faster | Direct computation vs intermediates |
| Reductions | 1.6-5.1x faster | Zero-alloc SIMD single-pass |
| Cumulative | 3.1-97x faster | Blocked scan, scalar accumulator |
| Gather/scatter | 1.6-9.8x faster | Direct indexing |
| Unary | 1.1-2.7x faster | In-place mutation when possible |
| Comparisons | 2.1-3.9x faster | SIMD + compact u8 output |
| Int cast | 5.0-7.6x faster | Direct byte reinterpretation |
| Quantize | 1.6x faster | Fused 2-pass implementation |
| Concatenation | 3.6-16.3x faster | Direct memcpy vs slice_assign |

### Structural Improvements

These reflect changes in how operations are represented and executed, not pure compute speedups. burn-ndarray eagerly materializes data where burn-flex uses zero-copy views or separated storage.

| Category | Improvement | What changed |
| --- | --- | --- |
| Dequantize | 135-232x | Direct `scale * x_q` vs reparsing `QuantizedBytes` each call |
| Quantized ops | 2.9-117x | Dominated by fast dequantize path |
| Slice/narrow | 2.1-2,100x | Zero-copy strided view vs potential data copy |
| Unfold | 1,200-166,000x | O(1) strided view vs O(n) full materialization |
| Expand | 550-2,600x | Zero-copy broadcast (stride=0) vs data copy |

**Note on quantization:** burn-ndarray simulates quantization by dequantizing to f32 for most operations. The quantized speedups reflect the difference between simulated and native execution, not equivalent algorithms running at different speeds.

### Where NdArray Wins

| Category | NdArray advantage | Reason |
| --- | --- | --- |
| bool_not / bool_and | ~20% faster | ndarray's vectorized `mapv` is well-optimized |
| int_powf_scalar | ~10% faster | ndarray's vectorized internals |
| Transposed i64 add (large) | ~7% faster | ndarray handles non-contiguous data well |
| Deform conv (medium) | ~30% faster | NdArray has an optimized deform conv path |
| Max pool 5x5 | ~17% faster | Specific kernel-size advantage |

These are specific edge cases where NdArray's ndarray-based internals have an advantage.


## 16. Why Replace burn-ndarray?

The ndarray crate has been slow to accept contributions and evolve. Burn's CPU backend inherits these constraints:

- **No f16/bf16:** Models using half-precision weights must convert to f32. An f16 PR has been open for a long time with no clear timeline.
- **6-dimension limit:** Hard-coded in reshape macros; cannot be fixed without upstream changes.
- **Unsigned strides:** usize-only strides make zero-copy flip impossible.
- **Simulated quantization:** No native quantized storage; dequantize/requantize on every op.
- **COW limitations:** `NdArrayStorage::Borrowed` always returns false for `is_unique()`, preventing in-place mutation of externally loaded data.

burn-flex was built to address these gaps without waiting on upstream. It is not intended to compete with CubeCL CPU, which targets high-performance computation through operator fusion and just-in-time compilation via LLVM. The goal is to provide a lightweight, portable replacement for burn-ndarray that works today on platforms CubeCL CPU cannot target (no_std, WASM, embedded).

## 17. What burn-flex Adds

  1. f16/bf16 support: Native arithmetic on half-precision types. Enables running models that use f16 weights without conversion.

  2. No dimension limit: Arbitrary tensor rank (ndarray is limited to 6).

  3. Zero-copy flip/unfold/expand: Signed strides enable O(1) flip. Unfold returns a strided view instead of materializing all windows.

  4. Unified 3D conv/pool: Single implementation covers 1D/2D/3D, reducing code paths and potential for inconsistencies.

  5. Native quantization: Stores scales separately for direct scale * x_q dequantization instead of reparsing packed bytes on every access. Zero-copy layout ops on quantized tensors.

  6. Fewer dependencies: 7 required deps vs 12. No ndarray, no matrixmultiply, no paste, no const-random, no BLAS bindings.

  7. Simpler type system: Flex vs NdArray<E, I, Q>. No generic parameters, no element trait hierarchy (FloatNdArrayElement, IntNdArrayElement, NdArrayElement, ExpElement).

  8. Real FFT: Forward (rfft) and inverse (irfft) real FFT with complex packing, SIMD butterflies, and compile-time twiddle tables. Works in no_std (rustfft/realfft require std). NdArray implements neither.
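For reference, rfft computes the DFT of a real signal and keeps only bins 0..=N/2, since the remaining bins are determined by conjugate symmetry. A naive O(N²) sketch of those semantics, not the radix-4 Cooley-Tukey implementation described above:

```rust
// Naive reference DFT of a real signal: returns (re, im) pairs for
// bins 0..=n/2; the remaining bins are redundant by conjugate symmetry.
fn rfft_reference(signal: &[f64]) -> Vec<(f64, f64)> {
    let n = signal.len();
    (0..=n / 2)
        .map(|k| {
            let mut re = 0.0;
            let mut im = 0.0;
            for (t, &x) in signal.iter().enumerate() {
                let angle = -2.0 * std::f64::consts::PI * (k * t) as f64 / n as f64;
                re += x * angle.cos();
                im += x * angle.sin();
            }
            (re, im)
        })
        .collect()
}
```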


## 18. What burn-ndarray Has That burn-flex Does Not

  1. BLAS acceleration: Feature flags for Accelerate (macOS), OpenBLAS, and Netlib BLAS. These can outperform gemm for very large matmuls on specific hardware. burn-flex relies solely on the gemm crate.

  2. SIMD conv/pool kernels: burn-ndarray has dedicated macerator-based SIMD kernels for convolution and pooling. burn-flex delegates to gemm's SIMD.

  3. export_tests feature: burn-ndarray serves as a reference implementation for some burn-cubecl kernels via export_tests.


## 19. Migration Path

For Burn users switching from burn-ndarray to burn-flex:

| Change | Details |
| --- | --- |
| Type parameter | `NdArray<f32>` becomes `Flex` |
| Device | `NdArrayDevice::Cpu` becomes `FlexDevice` |
| Feature flags | `multi-threads` becomes `rayon` |
| BLAS features | No equivalent (gemm handles matmul) |
| Autodiff | Use `burn_autodiff::Autodiff<Flex>` (same pattern) |
| f16/bf16 | Works out of the box (new capability) |
| Quantization | Same API, faster execution |
| Tests | Same burn-backend-tests suite passes |
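In user code the switch is mostly a type swap. A hedged sketch based on the table above (the exact module paths and the `FlexDevice::default()` constructor are assumptions, hence shown as comments):

```rust
// Before: burn-ndarray with compile-time element types.
// type Backend = burn_autodiff::Autodiff<burn_ndarray::NdArray<f32>>;
// let device = burn_ndarray::NdArrayDevice::Cpu;

// After: burn-flex's unit struct; dtypes are selected at runtime.
// type Backend = burn_autodiff::Autodiff<burn_flex::Flex>;
// let device = burn_flex::FlexDevice::default();
```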

## 20. Conclusion

burn-flex is a from-scratch replacement for burn-ndarray, motivated by ndarray's lack of f16/bf16 support, 6-dimension limit, simulated quantization, and slow pace of upstream development. It implements all required Backend traits (FloatTensorOps, IntTensorOps, BoolTensorOps, QTensorOps, ModuleOps, ActivationOps, TransactionOps) and passes the same test suite.

Performance gains come in two forms: compute improvements (1.1-9.7x) from better libraries and algorithms, and structural improvements (up to 166,000x) from representing operations as zero-copy views instead of eagerly materializing data. Memory usage is significantly reduced through Arc-based COW and in-place mutation.

The only capabilities lost are optional BLAS acceleration (replaced by the gemm crate, which is faster in most benchmarks) and the export_tests reference implementation feature.