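Codon's simd module and its Vec[T, N] type provide direct, ergonomic access to SIMD instructions, including vector construction and loads, lane-wise arithmetic, comparisons and masks, reductions, and integer-specific operations.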
simd.Vec[T, N] represents an LLVM vector <N x T>:

- T is a scalar numeric type (e.g. float32, int, u32)
- N is an integer literal

Conceptually:

- Vec[float32, 8] ≈ "8-wide float32 SIMD register"
- Vec[u8, 16] ≈ "16 byte-wide lanes"

You would typically use Vec in hot loops.
The simplest way to create a vector is to broadcast a scalar into all lanes:
from simd import Vec
# 8-lane vector of all ones
v = Vec[float, 8](1.0)
# 4-lane vector of all -3
w = Vec[int, 4](-3)
Vec can also load data from a pointer, such as the underlying buffer of a
NumPy array. Here is an example that implements hand-vectorized addition of
two arrays:
import numpy as np
from simd import Vec

def add_arrays(a: np.ndarray[float32, 1],
               b: np.ndarray[float32, 1],
               out: np.ndarray[float32, 1]):
    W: Literal[int] = 8  # vector width
    i = 0
    while i + W <= len(a):
        va = Vec[float32, W](a.data + i)  # load 8 elements from a[i..i+7]
        vb = Vec[float32, W](b.data + i)  # load 8 elements from b[i..i+7]
        vc = va + vb                      # SIMD add
        Ptr[Vec[float32, W]](out.data + i)[0] = vc  # store back
        i += W
    # handle remaining tail elements (scalar)
    while i < len(a):
        out[i] = a[i] + b[i]
        i += 1
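
To see the loop structure without any SIMD types, here is a minimal pure-Python sketch of the same chunk-then-tail pattern. It is a reference model only, not Codon's simd API: the helper names vector_and_tail_ranges and add_arrays_model are hypothetical, and plain lists stand in for the array buffers.

```python
def vector_and_tail_ranges(n, width):
    """Return (chunk_starts, tail_indices) for a buffer of length n."""
    chunk_starts = []
    i = 0
    # Same condition as the SIMD loop: take a chunk only if all
    # `width` lanes are in bounds.
    while i + width <= n:
        chunk_starts.append(i)
        i += width
    # Whatever is left over is processed one element at a time.
    tail_indices = list(range(i, n))
    return chunk_starts, tail_indices


def add_arrays_model(a, b, width=8):
    """Element-wise a + b, organised as full-width chunks plus a scalar tail."""
    out = [0.0] * len(a)
    chunk_starts, tail_indices = vector_and_tail_ranges(len(a), width)
    for start in chunk_starts:
        # One "vector" step: every lane of the chunk is added.
        for lane in range(width):
            out[start + lane] = a[start + lane] + b[start + lane]
    for i in tail_indices:
        out[i] = a[i] + b[i]
    return out
```

For n = 19 and a width of 8, the vector loop covers indices 0 to 15 in two steps and the scalar tail handles 16, 17 and 18. Like add_arrays, the sketch takes its bounds from a alone, so b is assumed to be at least as long.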
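Note that, in add_arrays:

- Vec[T, N](ptr) treats ptr as a Ptr[Vec[T, N]] and loads one SIMD vector
- the store goes the other way: the destination pointer is cast to Ptr[Vec[T, N]] and the vector is assigned to element 0

You can also construct vectors from lists: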
data = [1.0, 2.0, 3.0, 4.0]
v = Vec[float, 4](data) # load from data[0..3]
For byte-sized element types (i8, u8, byte), you can read directly from a string buffer:
# Load 16 bytes from a string
s = "ABCDEFGHIJKLMNOP"
v = Vec[u8, 16](s)
# Skip first 4 bytes
v2 = Vec[u8, 16](s, 4)
All basic arithmetic operations are lane-wise:

- +, -, * on integer and float vectors
- / on float vectors (true division)
- // and % on integer vectors

Example: fused multiply-add style accumulation for a dot product:
import numpy as np
from simd import Vec

def dot(a: Ptr[float32], b: Ptr[float32], n: int) -> float32:
    W: Literal[int] = 8
    i = 0
    acc = Vec[float32, W](0.0f32)
    while i + W <= n:
        va = Vec[float32, W](a + i)
        vb = Vec[float32, W](b + i)
        acc = acc + va * vb  # lane-wise multiply + add
        i += W
    # horizontal reduce SIMD accumulator
    total = acc.sum()
    # handle tail scalars
    while i < n:
        total += a[i] * b[i]
        i += 1
    return total
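
The accumulator in dot keeps one partial sum per lane and only combines them at the end, so the additions happen in a different order than in a plain scalar loop. The pure-Python sketch below models that order with a list of partial sums. It is an illustration under stated assumptions, not Codon code: dot_lanes_model and dot_scalar are hypothetical names, Python floats are 64-bit rather than float32, and the final sum(acc) only stands in for the horizontal reduction.

```python
def dot_lanes_model(a, b, width=8):
    """Dot product with `width` independent partial sums, one per lane."""
    n = len(a)
    acc = [0.0] * width
    i = 0
    while i + width <= n:
        for lane in range(width):
            # Lane k only ever sees elements k, k + width, k + 2*width, ...
            acc[lane] += a[i + lane] * b[i + lane]
        i += width
    total = sum(acc)  # horizontal reduction of the lane sums
    while i < n:      # scalar tail
        total += a[i] * b[i]
        i += 1
    return total


def dot_scalar(a, b):
    """Reference: accumulate strictly left to right."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total
```

Because floating-point addition is not associative, the two functions can differ in the last few bits for general inputs, and they agree exactly when every intermediate value is exactly representable. When testing a hand-vectorized kernel against a scalar reference, comparing with a small tolerance is therefore usually more appropriate than exact equality.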
Note that, in dot:

- va * vb multiplies lanes
- acc.sum() adds all lanes, resulting in a scalar

Comparisons between vectors (or between a vector and a scalar) produce mask vectors:
v: Vec[float, 8] = ...
mask_negative = v < Vec[float, 8](0.0) # Vec[u1, 8]
You can then use mask to select between two vectors, without branches:
def relu_vec(v: Vec[float, 8]) -> Vec[float, 8]:
    zero = Vec[float, 8](0.0)
    m = v > zero            # positive lanes
    return v.mask(zero, m)  # if m: v else zero
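
The selection semantics used by relu_vec (take the lane from self where the mask is set, otherwise from other) can be written out lane by lane. The following pure-Python sketch models this with lists. The function names compare_gt, mask_select and relu_model are hypothetical and are not part of Codon's simd module.

```python
def compare_gt(v, other):
    """Lane-wise v > other; returns one boolean per lane (the mask)."""
    return [x > y for x, y in zip(v, other)]


def mask_select(self_lanes, other_lanes, mask):
    """Model of self.mask(other, mask): mask ? self : other, lane by lane."""
    return [s if m else o for s, o, m in zip(self_lanes, other_lanes, mask)]


def relu_model(v):
    """ReLU as compare-then-select, with no per-element branch in the caller."""
    zero = [0.0] * len(v)
    m = compare_gt(v, zero)          # True in positive lanes
    return mask_select(v, zero, m)   # keep v where positive, else 0.0


print(relu_model([-1.0, 2.0, 0.0, -3.5]))  # prints [0.0, 2.0, 0.0, 0.0]
```

Both candidate values are computed for every lane and the mask only decides which one is kept. This is what lets the vector version avoid branching.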
The general pattern is:

- build a mask with a comparison (<, <=, >, >=, ==, !=)
- use mask(self, other, mask) to do: mask ? self : other lane-wise

This is useful to turn control-flow into data-flow, which is conducive to SIMD programming:

- branch-free selection (min/max/conditionals)
- thresholding (x > t ? x : 0)

Vec includes a reduction over addition, both for integers and floats:
v = Vec[float32, 8](...)
s = v.sum() # returns float32
Internally this uses LLVM's vector reduction intrinsics and is typically much faster than manually scattering and summing.
You can combine this with loops to build aggregate operations. Below is an example that implements L2 norm:
import numpy as np
from simd import Vec

def l2_norm(xs: Ptr[float32], n: int) -> float32:
    W: Literal[int] = 8
    i = 0
    acc = Vec[float32, W](0.0f32)
    while i + W <= n:
        v = Vec[float32, W](xs + i)
        acc = acc + v * v  # accumulate squares lane-wise
        i += W
    total = acc.sum()      # horizontal reduce
    # handle tail scalars
    while i < n:
        total += xs[i] * xs[i]
        i += 1
    return np.sqrt(total)
Integer types support additional operations, such as:

- bitwise operators: &, |, ^
- shifts: <<, >>
- min, max

Suppose you want to add two u8 images but clamp to 255 on overflow. You can use the overflow-aware addition and a mask:
def saturating_add_u8(a: Ptr[u8], b: Ptr[u8],
                      out: Ptr[u8], n: int):
    W: Literal[int] = 16
    i = 0
    max_val = Vec[u8, W](255u8)
    while i + W <= n:
        va = Vec[u8, W](a + i)
        vb = Vec[u8, W](b + i)
        (sum_vec, overflow) = va.add(vb, overflow=True)
        # where overflow[i] == 1, clamp to 255
        clamped = max_val.mask(sum_vec, overflow)
        Ptr[Vec[u8, W]](out + i)[0] = clamped
        i += W
    while i < n:
        s = int(a[i]) + int(b[i])
        out[i] = u8(255 if s > 255 else s)
        i += 1
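
To make the per-lane behaviour concrete, here is a small pure-Python model of an unsigned 8-bit add that reports overflow, followed by the clamp. It is a sketch under stated assumptions rather than Codon's implementation: add_u8_with_overflow and saturating_add_u8_model are hypothetical names, lanes are plain Python integers in the range 0 to 255, and wrap-around is modelled explicitly with a bit mask.

```python
def add_u8_with_overflow(a, b):
    """Lane-wise u8 add; returns (wrapped_sums, overflow_flags)."""
    # The stored result is the true sum modulo 256.
    wrapped = [(x + y) & 0xFF for x, y in zip(a, b)]
    # A lane overflows when the true sum does not fit in 8 bits.
    overflow = [x + y > 0xFF for x, y in zip(a, b)]
    return wrapped, overflow


def saturating_add_u8_model(a, b):
    """Clamp to 255 in the lanes that overflowed, keep the raw sum elsewhere."""
    wrapped, overflow = add_u8_with_overflow(a, b)
    return [255 if o else w for w, o in zip(wrapped, overflow)]


print(add_u8_with_overflow([200], [100]))  # prints ([44], [True])
print(saturating_add_u8_model([200, 10, 255, 0], [100, 20, 1, 0]))
# prints [255, 30, 255, 0]
```

For unsigned lanes the overflow flag can also be recovered from the wrapped result alone, because a lane overflowed exactly when its wrapped sum is smaller than either operand.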
Note, for saturating_add_u8:

- va.add(vb, overflow=True) returns (result, overflow_mask)
- use mask to blend between a "safe" value and the raw result

Vectors can be transformed into lists:
v = Vec[int, 4](...)
lst = v.scatter() # List[int] of length 4
print(lst) # e.g. [1, 1, 1, 1]
Vectors can be printed directly:
print(v) # e.g. "<1,1,1,1>"