contributor-book/src/guides/adding-a-new-operation-to-burn.md
Let's discuss how one might go about adding new operators to Burn, using the example of the pow operator added in this PR.
burn-tensor is the crate that defines all tensor operations that need to be implemented by the
various backends. The core of this lies in
crates/burn-backend/src/tensor/ops/numeric.rs,
which is home to the numeric trait. The numeric trait contains all tensor operations that are
numeric in nature and that are shared by the Int and Float tensor types. It is implemented in
crates/burn-backend/src/tensor/ops/int.rs
for the Int type and in
crates/burn-backend/src/tensor/ops/float.rs
for the Float type. More information on the relationship between the tensor modules can be found under
the section for Tensor Architecture.
Here is where pow was added to crates/burn-tensor/src/tensor/api/numeric.rs:

- the `Tensor<Backend, Dimension, Kind>` struct

`Tensor` is a struct that has a single member, `primitive` (defined here), whose concrete type is
determined by its `Kind`: one of `Bool`, `Float`, or `Int` (those linked in 3). These call the ops
for that data type defined in the `Backend` supertrait[^1]. This is the trait that is then
implemented by the different burn backends (such as burn-ndarray and burn-wgpu), which must
implement the functions if no default is provided.
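To make the dispatch concrete, here is a minimal self-contained sketch of that pattern (simplified types, not Burn's actual definitions): the user-facing `Tensor` wraps a single primitive and its methods simply forward to the ops declared on the `Backend` trait.

```rust
// A minimal sketch of the dispatch pattern (simplified types, not Burn's actual
// definitions): the public struct holds a single `primitive` member and forwards
// its methods to the ops declared on the `Backend` trait.
use std::marker::PhantomData;

trait Backend {
    type FloatPrimitive;
    fn float_powf(lhs: Self::FloatPrimitive, rhs: Self::FloatPrimitive) -> Self::FloatPrimitive;
}

struct Float; // stand-in for the Kind marker types (Bool, Int, Float)

struct Tensor<B: Backend, K = Float> {
    primitive: B::FloatPrimitive, // in Burn, the primitive type depends on the Kind
    _kind: PhantomData<K>,
}

impl<B: Backend> Tensor<B, Float> {
    // The user-facing op simply dispatches to the backend's implementation.
    fn powf(self, rhs: Self) -> Self {
        Tensor {
            primitive: B::float_powf(self.primitive, rhs.primitive),
            _kind: PhantomData,
        }
    }
}
```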
In this case, we don't need to worry about Bool Tensors. Float ops are implemented under
crates/burn-backend/src/backend/ops/tensor.rs,
and Int ops under
crates/burn-backend/src/backend/ops/int_tensor.rs.
The current convention is that ops of each type, if not unique to that type, are prefixed with the type.
So `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for
`FloatTensorOps`. If an op is unique to a type, then it should be implemented under
burn-tensor/src/api/{type}.rs. For example, here is an implementation for
`sin` under crates/burn-tensor/src/api/float.rs,
which obviously doesn't make sense for Int or Bool tensors.
The Int tensor function uses the ones defined for Float with 2 extra casts (LHS to a Float
tensor, output to an Int). Given that, the rest of this guide will only look at the float
implementations.
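Here is a hedged sketch of what that convention and the cast-based delegation look like (trait and method names are simplified stand-ins, not Burn's exact signatures):

```rust
// A hedged sketch of the naming convention and the cast-based delegation
// (trait and method names are simplified stand-ins, not Burn's exact signatures).
trait FloatTensorOps {
    type FloatTensor;
    fn float_powf(lhs: Self::FloatTensor, rhs: Self::FloatTensor) -> Self::FloatTensor;
}

trait IntTensorOps: FloatTensorOps {
    type IntTensor;

    // Hypothetical cast helpers standing in for the backend's conversion ops.
    fn int_into_float(tensor: Self::IntTensor) -> Self::FloatTensor;
    fn float_into_int(tensor: Self::FloatTensor) -> Self::IntTensor;

    // The int variant reuses the float one: cast the LHS in, compute, cast the output back.
    fn int_powf(lhs: Self::IntTensor, rhs: Self::FloatTensor) -> Self::IntTensor {
        let lhs = Self::int_into_float(lhs);
        let out = Self::float_powf(lhs, rhs);
        Self::float_into_int(out)
    }
}
```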
With the addition of quantized float tensors, the Float tensor primitive is represented by the
TensorPrimitive
enum. This allows us to handle both float and quantized float operations in the Tensor
implementation, correctly dispatching to the corresponding op (float or quantized) based on the
variant. Following the same convention, the equivalent
quantized tensor ops
are prefixed with q_* (e.g., q_reshape instead of float_reshape). Most ops have a default
implementation that simply dequantizes the input into its floating-point representation, performs
the operation on the float tensor, and quantizes the output. Backends can override specific
implementations when required/desired.
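As a toy, self-contained illustration of that default pattern (the quantization types below are invented for the example; they are not Burn's API):

```rust
// A toy illustration of the default q_* pattern: dequantize the inputs, run the
// float op, then re-quantize the result. `QParams` is made up for this example.
#[derive(Clone, Copy)]
struct QParams {
    scale: f32,
    zero_point: i8,
}

fn dequantize(values: &[i8], q: QParams) -> Vec<f32> {
    values
        .iter()
        .map(|&v| (v as f32 - q.zero_point as f32) * q.scale)
        .collect()
}

fn quantize(values: &[f32], q: QParams) -> Vec<i8> {
    values
        .iter()
        .map(|&v| ((v / q.scale).round() + q.zero_point as f32).clamp(-128.0, 127.0) as i8)
        .collect()
}

// The shape of a default q_* implementation: fall back to the float op.
fn q_powf(lhs: &[i8], rhs: &[i8], q: QParams) -> Vec<i8> {
    let lhs_f = dequantize(lhs, q);
    let rhs_f = dequantize(rhs, q);
    let out_f: Vec<f32> = lhs_f.iter().zip(&rhs_f).map(|(l, r)| l.powf(*r)).collect();
    quantize(&out_f, q)
}
```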
Additional tests should be added to burn-backend-tests under
crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/{op_name}.rs,
and the module name should be inserted into
crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/mod.rs.
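For illustration only, here is the usual shape of such a test; the harness helpers below (`TestTensor`, the approximate-equality assertion) are placeholders, so copy the exact names from an existing test file next to where yours will live, and register the module with a `mod {op_name};` line in mod.rs.

```rust
// A hedged sketch; `TestTensor` and `assert_tensor_approx_eq` are placeholders for
// whatever the neighbouring test files actually use.
#[test]
fn should_support_powf_ops() {
    // Build inputs with known values.
    let lhs = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
    let rhs = TestTensor::<2>::from([[2.0, 2.0], [2.0, 2.0]]);

    // Run the op under test.
    let output = lhs.powf(rhs);

    // Compare against hand-computed expected values, with a float tolerance.
    let expected = [[1.0, 4.0], [9.0, 16.0]];
    assert_tensor_approx_eq(output, expected); // placeholder assertion helper
}
```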
If it makes sense for a floating point operation to support quantization, the
QTensorOps
counterpart is usually added at the same time with a default implementation (as mentioned in the
previous section). Tests for q_* ops follow a similar procedure: the test is added under
crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/{op_name}.rs,
and the module name is inserted into
crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs.
If you take a look at any of the existing tests for an operation on a quantized tensor, you will see
that the inputs and expected outputs are always defined with floating point values. While it assumes
that the quantization and dequantization are correct, it makes the tests much more readable and
easier to understand w.r.t. what is being tested. Effectively, the tests are there to ensure that a
tensor operation is invariant to quantization (up to some quantization error, of course).
Note: the tests try to use tensors with floating point values which can be de/quantized without introducing too much quantization error, but the result always depends on the operation (e.g., tensor product of values can grow larger and significantly increase the output tensor range, leading to more de/quantization error on the results).
Since this backend is probably the hardest and the least straightforward, we'll cover it
separately. burn-autodiff enables other backends to use autodifferentiation[^2]. Ops for
float types are implemented in
crates/burn-autodiff/src/ops/tensor.rs
and need to:
1. Define a unit struct[^3] that implements a backward (pass) function (see backward.rs under the same directory); the last 2 arguments are two closures that define the left and right partial derivatives.
2. Register the op as either compute-bound (`.computeBound()`) or memory-bound (`.memoryBound()`) for gradient checkpointing. Compute-bound operations are heavy to compute (for instance matmul or convolution), which means that even with checkpointing they will save their output for the backward pass and not recompute it. Memory-bound operations are more trivial (like `powf`, which only performs one small operation per tensor entry), so it can be beneficial to recompute them during the backward pass instead of saving their whole forward output to memory. Operations registered as memory-bound need to know their parents (`.parents()` method) and how to recompute their forward pass during the backward pass (with a struct that implements `RetroForward`), using their parents' outputs.

The above steps are mostly boilerplate, so you can often just copy the contents of another similar op, change the names of the structs, and ensure that both sides have the data they need (if they need a copy of the opposite-side tensor, clone its contents).
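As a mental model, here is a self-contained sketch (not burn-autodiff's actual types or method names) of the part you really have to think about: supplying the two gradient closures for a binary op.

```rust
// Self-contained sketch of the "two closures" idea: a binary op's backward step
// maps the incoming gradient to one gradient per operand.
struct BinaryBackward<L, R>
where
    L: Fn(f32) -> f32,
    R: Fn(f32) -> f32,
{
    left: L,  // gradient w.r.t. the left operand
    right: R, // gradient w.r.t. the right operand
}

impl<L, R> BinaryBackward<L, R>
where
    L: Fn(f32) -> f32,
    R: Fn(f32) -> f32,
{
    fn backward(&self, grad_output: f32) -> (f32, f32) {
        ((self.left)(grad_output), (self.right)(grad_output))
    }
}

fn main() {
    let (x, y) = (2.0_f32, 3.0_f32);
    // For f(x, y) = x^y, the closures use the partial derivatives derived below.
    let op = BinaryBackward {
        left: move |grad| grad * y * x.powf(y - 1.0),
        right: move |grad| grad * x.powf(y) * x.ln(),
    };
    let (grad_x, grad_y) = op.backward(1.0);
    println!("grad_x = {grad_x}, grad_y = {grad_y}");
}
```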
For those that need it, here is a quick refresher on the necessary calculus. If you are familiar with how to calculate partial derivatives, you can skip this section.
Since pow is a binary operation, the left and right functions are the partial derivatives with
respect to the left- and right-hand tensors.

Let's define the operator as a function \(f(x,y)=x^{y}\), where \(x\) is the left-hand tensor and \(y\) is the right-hand tensor. The two closures define the partial derivatives of \(f\) with respect to \(x\) and \(y\), treating the other variable as a constant:

$$\frac{\partial}{\partial x} (x^{y}) = y \cdot x^{y-1}$$

is the left-hand closure, and

$$\frac{\partial}{\partial y} (x^{y}) = x^{y} \cdot \ln(x)$$

is the right. If you aren't sure how to calculate these by hand, it is recommended to use Symbolab: plug in your operator in terms of \(x\) and \(y\), then swap out the variable (\(x\) or \(y\)) in the partial derivative to get the other side.
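As a quick, self-contained sanity check (plain Rust, independent of Burn), you can compare those closed-form partials against central finite differences:

```rust
fn main() {
    let (x, y) = (1.5_f64, 2.5_f64);
    let eps = 1e-6;

    // Closed-form partial derivatives of f(x, y) = x^y.
    let df_dx = y * x.powf(y - 1.0); // ∂f/∂x = y * x^(y-1)
    let df_dy = x.powf(y) * x.ln();  // ∂f/∂y = x^y * ln(x)

    // Central finite-difference approximations.
    let fd_dx = ((x + eps).powf(y) - (x - eps).powf(y)) / (2.0 * eps);
    let fd_dy = (x.powf(y + eps) - x.powf(y - eps)) / (2.0 * eps);

    assert!((df_dx - fd_dx).abs() < 1e-4);
    assert!((df_dy - fd_dy).abs() < 1e-4);
    println!("df/dx = {df_dx:.6}, df/dy = {df_dy:.6}");
}
```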
For testing the autodiff operations, please refer to
this section.
Most of these are fairly straightforward implementations. For reference here's pow's float implementation for torch and ndarray backends:
This is where any calculation happens currently. Playing a guessing game with method names and seeing what completions are suggested will take you far. If you are having trouble figuring out how to do it from the docs for that backend, try searching github for relevant function calls.
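Purely as a conceptual stand-in (this is not the torch or ndarray backend code), an element-wise float powf boils down to mapping `powf` over paired elements of the two tensors' data:

```rust
// Conceptual element-wise powf over flat slices; real backends operate on their own
// tensor/array types and handle broadcasting, strides, devices, etc.
fn powf_elementwise(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    assert_eq!(lhs.len(), rhs.len(), "element-wise ops need matching shapes");
    lhs.iter().zip(rhs).map(|(l, r)| l.powf(*r)).collect()
}

fn main() {
    let out = powf_elementwise(&[2.0, 3.0, 4.0], &[2.0, 2.0, 0.5]);
    assert_eq!(out, vec![4.0, 9.0, 2.0]);
}
```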
Adding an operator to these backends can be fairly straightforward, though, due to what these backends are for, it involves a bit more indirection. Fusion and JIT, like autodiff, are not target backends so much as backends that enable certain functionality for other backends, in this case kernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation; you'll just be describing how the generated code should look. Most of this can be copy/pasted/adjusted from other functions.
Here's how powf was added to burn-fusion:
- the NumericOperationIr enum under crates/burn-ir/src/operation.rs
- the NumericOperationIr enum under burn/crates/burn-fusion/src/stream/context.rs

The way cubecl handles tensor-scalar operations is by transforming both into a sequence of
vectorized scalar operations. Since powf already existed in cubecl, it was pretty easy to reuse
the existing implementation for the situation where both sides of the operation were tensors. The
cubecl crate is primarily concerned with how the operation is compiled and executed by the GPU.
The actual implementation is defined in burn-cubecl.
Here is where code was added for powf in burn-cubecl and cubecl:

- FloatTensorOps under burn/crates/burn-cubecl/src/ops/tensor.rs
- burn/crates/burn-cubecl/src/ops/numeric.rs
- cubecl/crates/cubecl-ir/src/arithmetic.rs
- burn/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs
- cubecl/crates/cubecl-cpp/src/shared/base.rs, cubecl/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs and cubecl/crates/cubecl-spirv/src/arithmetic.rs
- cubecl/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs, and the actual instruction in wgsl here
- for CPP, in the enum here cubecl/crates/cubecl-cpp/src/shared/instruction.rs and the actual instruction here cubecl/crates/cubecl-cpp/src/shared/binary.rs

We needed to generate some custom WGSL code for powf, primarily due to issues with proper case handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an even power being positive. We reused as much of the existing logic as possible, and then branched at the last point based on the var type of the rhs.
See here.
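To make the special cases concrete, here is a plain-Rust sketch of the kind of branching the generated WGSL has to encode (illustrative only, not the actual shader code):

```rust
// Illustrative case handling for powf: 0^0 is 1, and negative bases are supported
// when the exponent is an integer (even powers are positive, odd powers negative).
fn powf_with_cases(base: f32, exp: f32) -> f32 {
    if exp == 0.0 {
        1.0 // anything to the 0th power, including 0^0
    } else if base < 0.0 && exp.fract() == 0.0 {
        // Compute on |base| and restore the sign only for odd integer exponents.
        let magnitude = (-base).powf(exp);
        if (exp as i64) % 2 == 0 {
            magnitude
        } else {
            -magnitude
        }
    } else {
        base.powf(exp)
    }
}

fn main() {
    assert_eq!(powf_with_cases(0.0, 0.0), 1.0);
    assert_eq!(powf_with_cases(-2.0, 2.0), 4.0);
    assert_eq!(powf_with_cases(-2.0, 3.0), -8.0);
}
```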
For most operations, you shouldn't need to add to cubecl-wgpu/src/compiler/wgsl/extension.rs
unless the operation isn't native to WGSL.
For functions that need a complex kernel without a direct mapping to a base instruction, simply use
the cube macro (see
the cubecl book).
And you're done! Congrats, you just fully added a new operation to Burn, and we are all one step closer to the answer to "Are we learning yet?" being "Yes, and it's freaking fast!". Buy yourself a coffee.
[^1]: For more on supertraits, see the advanced trait section of the Rust Book.

[^2]: Wikipedia link for automatic differentiation.

[^3]: For more information on unit structs, see the defining and instantiating structs section of the Rust Book.