Back to Burn

Tensor Operations in Burn

examples/notebook/basic-tensor-op.ipynb

0.20.1 · 8.5 KB
Original Source

Tensor Operations in Burn

This notebook demonstrates basic tensor operations in Burn, a deep learning framework written in Rust.

rust

// Dependency declarations for the notebook.
// The syntax is similar to Cargo.toml. Just prefix with :dep

:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}
rust
// Import packages
use burn::prelude::*;
use burn_ndarray::NdArray;

// Type alias for the backend (using CPU/NdArray)
type B = NdArray<f32>;

1. Tensor Creation

rust
// The NdArray backend runs on the CPU, so the default device is all we need.
let device = <B as Backend>::Device::default();

// `empty` allocates storage without initializing it, so only the shape is meaningful here.
let empty = Tensor::<B, 3>::empty([2, 3, 4], &device);
println!("Empty tensor shape: {:?}", empty.shape());

// Constant-filled tensors: all zeros, all ones, and an arbitrary fill value.
let zeros = Tensor::<B, 2>::zeros([3, 3], &device);
println!("Zeros tensor: {}", zeros);

let ones = Tensor::<B, 2>::ones([2, 4], &device);
println!("Ones tensor: {}", ones);

let full = Tensor::<B, 2>::full([2, 3], 7.0, &device);
println!("Full tensor (7.0): {}", full);
rust
// `Distribution` selects the sampling law used by `Tensor::random` below.
use burn::tensor::Distribution;

// Build a 1-D tensor from a fixed-size array of values, then view it as a 2x3 matrix.
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let from_slice = Tensor::<B, 1>::from_floats(data, &device).reshape([2, 3]);
println!("From slice:\n{}", from_slice);

// Five samples from the default distribution.
let random = Tensor::<B, 1>::random([5], Distribution::Default, &device);
println!("Random tensor: {}", random);

// Five samples from a normal distribution with mean 0.0 and standard deviation 1.0.
let normal = Tensor::<B, 1>::random([5], Distribution::Normal(0.0, 1.0), &device);
println!("Normal distribution: {}", normal);

// Five samples drawn uniformly from the half-open range [0, 10).
let uniform = Tensor::<B, 1>::random([5], Distribution::Uniform(0.0, 10.0), &device);
println!("Uniform [0, 10): {}", uniform);

2. Shape Operations

rust
// Start from six values laid out as a 2x3 matrix.
let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let tensor = Tensor::<B, 1>::from_floats(values, &device).reshape([2, 3]);
println!("Original (2x3):\n{}", tensor);

// Reshape keeps the element order and only regroups it into new dimensions;
// here a leading axis of size 1 is added. Clone so `tensor` survives for flattening.
let reshaped: Tensor<B, 3> = tensor.clone().reshape([1, 2, 3]);
println!("Reshaped (1x2x3): {}", reshaped);

// Flatten merges dimensions 0 through 1 (inclusive) into one, yielding a 1-D tensor.
let flat = tensor.flatten::<1>(0, 1);
println!("Flattened: {}", flat);
rust
// A 2x2 matrix to transpose.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Original:\n{}", tensor);

// `transpose` consumes its receiver, so clone to keep `tensor` for the next example.
println!("Transposed:\n{}", tensor.clone().transpose());

// `.t()` is the short alias for the same transpose.
println!("Using .t():\n{}", tensor.t());
rust
// Squeeze - drop dimensions of size 1. The const generic gives the resulting rank:
// squeezing the [1, 1, 2] tensor down to rank 1 leaves shape [2].
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).reshape([1, 1, 2]);
println!("Before squeeze [1,1,2]: shape = {:?}", tensor.shape());

let squeezed = tensor.squeeze::<1>();
println!("After squeeze: shape = {:?}", squeezed.shape());

// Unsqueeze - prepend leading dimensions of size 1 until the tensor reaches the rank
// given by the const generic, so [2, 2] becomes [1, 2, 2] here. `unsqueeze` takes no
// position; to insert a size-1 dimension at a specific index use `unsqueeze_dim::<D2>(dim)`.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Before unsqueeze [2,2]: shape = {:?}", tensor.shape());

let unsqueezed = tensor.unsqueeze::<3>();
println!("After unsqueeze: shape = {:?}", unsqueezed.shape());

3. Indexing and Slicing

rust
// Twelve consecutive values arranged as a 3x4 matrix (3 rows, 4 columns);
// the slicing examples below all read from this `tensor`.
let values = [
    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
];
let tensor = Tensor::<B, 1>::from_floats(values, &device).reshape([3, 4]);
println!("Original tensor:\n{}", tensor);
rust
// `slice` takes one range per dimension. Ranges are end-exclusive, exactly like Rust ranges,
// and the result keeps the same rank as the input.

// Rows 1..3 (rows 1 and 2) and columns 1..4 (columns 1, 2 and 3).
let sliced = tensor.clone().slice([1..3, 1..4]);
println!("Sliced [1..3, 1..4]:\n{}", sliced);

// One row (1..2) across every column: a 1x4 tensor.
let row = tensor.clone().slice([1..2, 0..4]);
println!("Row 1: {}", row);

// Every row, one column (2..3): a 3x1 tensor. This last use consumes `tensor`.
let col = tensor.slice([0..3, 2..3]);
println!("Column 2: {}", col);

4. Basic Math Operations

rust
// Two 2x2 operands for element-wise arithmetic.
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);

println!("a = {}", a);
println!("b = {}", b);

// The arithmetic operators take their operands by value, so clone to keep `a` and `b`
// available for every example.
let sum = a.clone() + b.clone();
println!("a + b = {}", sum);

let difference = a.clone() - b.clone();
println!("a - b = {}", difference);

// `*` and `/` act element by element, not as matrix operations (see `matmul` for that).
let product = a.clone() * b.clone();
println!("a * b = {}", product);

let quotient = a.clone() / b.clone();
println!("a / b = {}", quotient);
rust
// Combining a tensor with a plain number applies the number to every element.
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);

println!("a = {}", a);

// Shift every element by 10.
println!("a + 10 = {}", a.clone() + 10.0);

// Scale every element by 2.
println!("a * 2 = {}", a.clone() * 2.0);
rust
// Matrix product of two 2x2 matrices.
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);

println!("a = {}", a);
println!("b = {}", b);

// `matmul` consumes both operands; neither is needed afterwards, so no clones.
let result = a.matmul(b);
println!("a @ b (matmul) = {}", result);

// Expected result, worked by hand (each entry is a row of `a` dotted with a column of `b`):
//   [[1*5 + 2*7, 1*6 + 2*8],     [[19, 22],
//    [3*5 + 4*7, 3*6 + 4*8]]  =   [43, 50]]

5. Element-wise Math Functions

rust
// A small 1-D input for the element-wise math functions.
let a: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);

println!("a = {}", a);

// Exponential: e^x for every element.
println!("exp(a) = {}", a.clone().exp());

// Natural logarithm. The input is shifted by 1 first because log(0) is -inf.
println!("log(a + 1) = {}", (a.clone() + 1.0).log());

// Power with a scalar exponent via `powf_scalar`. The labels name that method explicitly:
// Burn's `powf` is a different method that takes a tensor of exponents.
println!("a.powf_scalar(2) = {}", a.clone().powf_scalar(2.0));
println!("a.powf_scalar(0.5) = {}", a.clone().powf_scalar(0.5));
rust
// Evaluate the trigonometric functions at 0, π/4 and π/2.
let pi = std::f32::consts::PI;
let angles: Tensor<B, 1> = Tensor::from_floats([0.0, pi / 4.0, pi / 2.0], &device);

println!("angles = {}", angles);

// Each call consumes its receiver, hence the clones. Note that the f32 value of π/2 is not
// exactly π/2, so tan() there is a very large finite number rather than infinity.
println!("sin(angles) = {}", angles.clone().sin());
println!("cos(angles) = {}", angles.clone().cos());
println!("tan(angles) = {}", angles.clone().tan());

6. Reduction Operations

rust
// Full reductions collapse every element of the tensor into a one-element tensor.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);

// For the values 1..=6, expect: sum 21, mean 3.5, product 720, max 6, min 1.
// Every reduction consumes its receiver, so each works on a clone.
println!("Sum: {}", tensor.clone().sum());
println!("Mean: {}", tensor.clone().mean());
println!("Product: {}", tensor.clone().prod());
println!("Max: {}", tensor.clone().max());
println!("Min: {}", tensor.clone().min());
rust
// Per-dimension reductions collapse only the chosen dimension. It is kept with size 1,
// so the result has the same rank as the input.
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);

// Collapse dimension 0 (the rows): one sum per column, shape [1, 3] -> [[5, 7, 9]].
println!("Sum dim 0: {}", tensor.clone().sum_dim(0));

// Collapse dimension 1 (the columns): one sum per row, shape [2, 1] -> [[6], [15]].
println!("Sum dim 1: {}", tensor.clone().sum_dim(1));

// Column means, shape [1, 3] -> [[2.5, 3.5, 4.5]].
println!("Mean dim 0: {}", tensor.clone().mean_dim(0));

7. Comparison and Selection

rust
// Two vectors to compare position by position.
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);
let b: Tensor<B, 1> = Tensor::from_floats([4.0, 2.0, 6.0, 7.0], &device);

println!("a = {}", a);
println!("b = {}", b);

// Each comparison yields a boolean tensor of the same shape. For these inputs:
//   a > b  -> [false, true, false, true]
//   a < b  -> [true, false, true, false]
//   a == b -> all false
let greater = a.clone().greater(b.clone());
let less = a.clone().lower(b.clone());
let equal = a.clone().equal(b.clone());

println!("a > b: {}", greater);
println!("a < b: {}", less);
println!("a == b: {}", equal);
rust
// Conditional selection
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);

// Boolean mask marking the elements strictly greater than 4 (here: 5.0 and 8.0).
let condition = a.clone().greater_elem(4.0);

// mask_where: where the mask is true, take the element from the replacement tensor;
// otherwise keep the original element. `zeros_like` builds the replacement with exactly
// `a`'s shape, so this stays correct if `a` changes size.
let result = a.clone().mask_where(condition.clone(), a.zeros_like());
println!("Original: {}", a);
println!("Where > 4, replace with 0: {}", result);

// mask_fill: simpler - replace masked elements with a single scalar; no tensor needed.
// The same mask is reused rather than recomputed.
let result = a.clone().mask_fill(condition, -1.0);
println!("Where > 4, replace with -1: {}", result);

Summary

In this notebook, we covered:

  • Tensor Creation: empty, zeros, ones, full, from_floats, random
  • Shape Operations: reshape, transpose, flatten, squeeze, unsqueeze
  • Indexing and Slicing: slice operation with ranges
  • Math Operations: element-wise add, sub, mul, div (with tensors or scalars), matmul
  • Element-wise Functions: exp, log, powf_scalar, sin, cos, tan
  • Reduction Operations: sum, mean, prod, max, min, plus per-dimension sum_dim and mean_dim
  • Comparison and Selection: greater, lower, equal, greater_elem, mask_where, mask_fill