examples/notebook/basic-tensor-op.ipynb
This notebook demonstrates basic tensor operations in Burn, a deep learning framework written in Rust.
// Dependency declarations for the notebook.
// The syntax mirrors a Cargo.toml dependency entry; each line is just prefixed with :dep
:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}
// Import packages
use burn::prelude::*;
use burn_ndarray::NdArray;
// Type alias for the backend (using CPU/NdArray)
type B = NdArray<f32>;
let device = <B as Backend>::Device::default();
// Create an empty tensor (uninitialized values)
let empty: Tensor<B, 3> = Tensor::empty([2, 3, 4], &device);
println!("Empty tensor shape: {:?}", empty.shape());
// Create a tensor filled with zeros
let zeros: Tensor<B, 2> = Tensor::zeros([3, 3], &device);
println!("Zeros tensor: {}", zeros);
// Create a tensor filled with ones
let ones: Tensor<B, 2> = Tensor::ones([2, 4], &device);
println!("Ones tensor: {}", ones);
// Create a tensor filled with a specific value
let full: Tensor<B, 2> = Tensor::full([2, 3], 7.0, &device);
println!("Full tensor (7.0): {}", full);
// Create a tensor from a slice of values
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let from_slice = Tensor::<B, 1>::from_floats(data, &device).reshape([2, 3]);
println!("From slice:\n{}", from_slice);
// Create a random tensor
use burn::tensor::Distribution;
let random: Tensor<B, 1> = Tensor::random([5], Distribution::Default, &device);
println!("Random tensor: {}", random);
// Create a tensor with normal distribution
let normal: Tensor<B, 1> = Tensor::random([5], Distribution::Normal(0.0, 1.0), &device);
println!("Normal distribution: {}", normal);
// Create a tensor with uniform distribution in range [0, 10)
let uniform: Tensor<B, 1> = Tensor::random([5], Distribution::Uniform(0.0, 10.0), &device);
println!("Uniform [0, 10): {}", uniform);
// Reshape tensor - change the dimensions without changing the data
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Original (2x3):\n{}", tensor);
let reshaped: Tensor<B, 3> = tensor.clone().reshape([1, 2, 3]);
println!("Reshaped (1x2x3): {}", reshaped);
// Flatten - reshape to 1D
let flat: Tensor<B, 1> = tensor.flatten(0, 1);
println!("Flattened: {}", flat);
// Transpose - swap dimensions
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Original:\n{}", tensor);
let transposed = tensor.clone().transpose();
println!("Transposed:\n{}", transposed);
// Equivalently, the two dimensions can be swapped explicitly
let t = tensor.swap_dims(0, 1);
println!("Using swap_dims(0, 1):\n{}", t);
// Squeeze - remove dimensions of size 1
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).reshape([1, 1, 2]);
println!("Before squeeze [1,1,2]: shape = {:?}", tensor.shape());
// squeeze(dim) removes one size-1 dimension; squeeze_dims removes several at once
let squeezed = tensor.squeeze_dims::<1>(&[0, 1]);
println!("After squeeze: shape = {:?}", squeezed.shape());
// Unsqueeze - add leading dimensions of size 1 (unsqueeze_dim, shown below, inserts at a chosen position)
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("Before unsqueeze [2,2]: shape = {:?}", tensor.shape());
let unsqueezed = tensor.unsqueeze::<3>();
println!("After unsqueeze: shape = {:?}", unsqueezed.shape());
// Create a tensor for indexing examples
let tensor = Tensor::<B, 1>::from_floats(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
&device
).reshape([3, 4]);
println!("Original tensor:\n{}", tensor);
// Slice tensor - select a portion using ranges
// Take rows 1..3 (indices 1 and 2) and columns 1..4 (indices 1, 2 and 3)
let sliced = tensor.clone().slice([1..3, 1..4]);
println!("Sliced [1..3, 1..4]:\n{}", sliced);
// Get single row
let row = tensor.clone().slice([1..2, 0..4]);
println!("Row 1: {}", row);
// Get single column
let col = tensor.slice([0..3, 2..3]);
println!("Column 2: {}", col);
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);
println!("a = {}", a);
println!("b = {}", b);
// Addition
let c = a.clone() + b.clone();
println!("a + b = {}", c);
// Subtraction
let c = a.clone() - b.clone();
println!("a - b = {}", c);
// Multiplication (element-wise)
let c = a.clone() * b.clone();
println!("a * b = {}", c);
// Division (element-wise)
let c = a.clone() / b.clone();
println!("a / b = {}", c);
// Scalar operations
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
println!("a = {}", a);
// Add scalar
let c = a.clone() + 10.0;
println!("a + 10 = {}", c);
// Multiply scalar
let c = a.clone() * 2.0;
println!("a * 2 = {}", c);
// Matrix multiplication
let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);
let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);
println!("a = {}", a);
println!("b = {}", b);
let result = a.matmul(b);
println!("a @ b (matmul) = {}", result);
// Verify (rows of a · columns of b): row1 [1,2] · col1 [5,7] = 1*5+2*7 = 19, row1 [1,2] · col2 [6,8] = 1*6+2*8 = 22
// row2 [3,4] · col1 [5,7] = 3*5+4*7 = 43, row2 [3,4] · col2 [6,8] = 3*6+4*8 = 50
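// Additional example: matmul is not limited to square matrices, [2, 3] x [3, 2] -> [2, 2]
let m = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
let n = Tensor::<B, 1>::from_floats([1.0, 0.0, 0.0, 1.0, 1.0, 1.0], &device).reshape([3, 2]);
// Expected result: [[1+0+3, 0+2+3], [4+0+6, 0+5+6]] = [[4, 5], [10, 11]]
println!("[2,3] @ [3,2] = {}", m.matmul(n));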
let a: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);
println!("a = {}", a);
// Exponential
println!("exp(a) = {}", a.clone().exp());
// Natural logarithm
println!("log(a + 1) = {}", (a.clone() + 1.0).log());
// Power
println!("a.powf(2) = {}", a.clone().powf_scalar(2.0));
println!("a.powf(0.5) = {}", a.clone().powf_scalar(0.5));
// Trigonometric functions
let angles: Tensor<B, 1> = Tensor::from_floats([0.0, std::f32::consts::PI / 4.0, std::f32::consts::PI / 2.0], &device);
println!("angles = {}", angles);
println!("sin(angles) = {}", angles.clone().sin());
println!("cos(angles) = {}", angles.clone().cos());
println!("tan(angles) = {}", angles.clone().tan());
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);
// Sum all elements
println!("Sum: {}", tensor.clone().sum());
// Mean of all elements
println!("Mean: {}", tensor.clone().mean());
// Product of all elements
println!("Product: {}", tensor.clone().prod());
// Maximum and minimum
println!("Max: {}", tensor.clone().max());
println!("Min: {}", tensor.clone().min());
// Reduce along specific dimensions
let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);
println!("Tensor:\n{}", tensor);
// Sum along dimension 0 (collapses the rows, giving one sum per column)
println!("Sum dim 0: {}", tensor.clone().sum_dim(0));
// Sum along dimension 1 (collapses the columns, giving one sum per row)
println!("Sum dim 1: {}", tensor.clone().sum_dim(1));
// Mean along dimension 0
println!("Mean dim 0: {}", tensor.clone().mean_dim(0));
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);
let b: Tensor<B, 1> = Tensor::from_floats([4.0, 2.0, 6.0, 7.0], &device);
println!("a = {}", a);
println!("b = {}", b);
// Element-wise comparison returns a boolean tensor
let greater = a.clone().greater(b.clone());
println!("a > b: {}", greater);
let less = a.clone().lower(b.clone());
println!("a < b: {}", less);
let equal = a.clone().equal(b.clone());
println!("a == b: {}", equal);
// Conditional selection
let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);
// mask_where: where condition is true, use replacement value, else keep original value
let condition = a.clone().greater_elem(4.0);
let result = a.clone().mask_where(condition, Tensor::zeros([4], &device));
println!("Original: {}", a);
println!("Where > 4, replace with 0: {}", result);
// mask_fill: same idea, but fills the masked positions with a single scalar value
let result = a.clone().mask_fill(a.clone().greater_elem(4.0), -1.0);
println!("Where > 4, replace with -1: {}", result);
In this notebook, we covered:
- Creating tensors (empty, zeros, ones, full, from slices, and random distributions)
- Shape manipulation (reshape, flatten, transpose, squeeze, unsqueeze)
- Indexing and slicing with ranges
- Element-wise and scalar arithmetic
- Matrix multiplication
- Mathematical functions (exponentials, logarithms, powers, trigonometry)
- Reductions over all elements and along dimensions
- Comparisons and conditional selection with masks