examples/notebook/autodiff.ipynb
This notebook demonstrates how to use automatic differentiation in Burn to compute gradients and implement gradient descent.
// Dependency declarations
:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}
:dep burn-autodiff = {path = "../../crates/burn-autodiff"}
// Import packages
use burn::prelude::*;
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;
// Type alias: wrapping the NdArray CPU backend in the Autodiff decorator
// enables gradient tracking and backpropagation for every tensor created
// with this backend type.
type B = Autodiff<NdArray<f32>>;
In Burn, tensors can be marked for gradient tracking using .require_grad(). This tells the framework to track operations on this tensor so gradients can be computed later.
// Acquire the default device for the Autodiff<NdArray> backend.
let device = <B as Backend>::Device::default();

// A plain tensor: operations on it are not tracked for differentiation.
let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
println!("Regular tensor x: {}", x);

// Marking the tensor with require_grad() records subsequent operations so
// gradients can be computed later.
let y: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();
println!("Tensor y with require_grad: {}", y);

// Apply a couple of tracked operations to y.
let doubled = y.clone().mul_scalar(2.0);
let result = doubled.sum();
println!("result = sum(y * 2) = {}", result);
The .backward() method computes the gradients of all tensors that have require_grad() set. It returns a gradients object that holds the computed gradients.
// Worked example:
//   y = [1, 2, 3, 4]
//   z = y * 2 = [2, 4, 6, 8]
//   result = sum(z) = 20
//
//   d(result)/d(y) = d(result)/dz * dz/dy = 1 * 2 = [2, 2, 2, 2]
let device = <B as Backend>::Device::default();
let y: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();
let result = y.clone().mul_scalar(2.0).sum();

// backward() traverses the recorded graph and returns a gradients container.
let grads = result.backward();

// Look up the gradient of `result` with respect to y; grad() returns None
// for tensors that were not marked with require_grad().
let y_grad = y.grad(&grads).unwrap();
println!("y = {}", y);
println!("d(result)/dy = {}", y_grad);
Let's compute the gradient of a more complex function: f(x) = x²
The derivative is: f'(x) = 2x
// f(x) = x^2  =>  f'(x) = 2x
let device = <B as Backend>::Device::default();
let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();

// Forward: square each element.
let y = x.clone().powf_scalar(2.0);

// Backward from the scalar sum; the gradient of each element is then 2x.
let grads = y.clone().sum().backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {}", x);
println!("x^2 = {}", y);
println!("d(x^2)/dx = {}", x_grad);
// Verify: d(x^2)/dx should be [2, 4, 6, 8]
println!("Expected: [2, 4, 6, 8]");
Let's verify the chain rule: f(g(x))' = f'(g(x)) * g'(x)
Example: y = sin(x²), we want dy/dx
Let u = x², so y = sin(u). Then dy/du = cos(u) and du/dx = 2x, hence dy/dx = cos(x²) · 2x.
// Chain rule check: y = sin(x^2)  =>  dy/dx = cos(x^2) * 2x
let device = <B as Backend>::Device::default();
let x: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device).require_grad();

// Forward pass: u = x^2, then y = sin(u).
let u = x.clone().powf_scalar(2.0);
let y = u.sin();

// Backward pass propagates through both operations automatically.
let grads = y.clone().sum().backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {}", x);
println!("y = sin(x^2) = {}", y);
println!("dy/dx = {}", x_grad);

// Manual check: compute cos(x^2) * 2x without autodiff and compare.
let expected_grad = x.clone().powf_scalar(2.0).cos() * (x.clone().mul_scalar(2.0));
println!("Expected (cos(x^2) * 2x): {}", expected_grad);
Now let's implement the classic gradient descent algorithm to find the minimum of a function.
We'll minimize: f(x) = (x - 3)²
The minimum is at x = 3, where f(x) = 0
// Target: minimize f(x) = (x - 3)^2
// This has minimum at x = 3
/// Objective for gradient descent: f(x) = (x - 3)^2, minimized at x = 3.
fn loss<B: Backend>(x: &Tensor<B, 1>) -> Tensor<B, 1> {
    // Shift so the minimum sits at x = 3, then square.
    let shifted = x.clone().sub_scalar(3.0);
    shifted.powf_scalar(2.0)
}
let device = <B as Backend>::Device::default();

// Descent hyper-parameters: start at x = 0 with a fixed step size.
let mut x_val: f32 = 0.0;
let learning_rate: f32 = 0.1;

println!("Starting gradient descent to minimize (x - 3)^2");
println!("Expected minimum: x = 3");
println!("---");

for step in 0..20 {
    // Rebuild a tracked tensor from the current scalar value each iteration.
    let x = Tensor::<B, 1>::from_floats([x_val], &device).require_grad();

    // Forward pass through the objective.
    let loss_value = loss(&x);
    let current_loss: f32 = loss_value.clone().into_scalar().elem::<f32>();
    println!("Iteration {}: x = {:.4}, loss = {:.4}", step, x_val, current_loss);

    // Backward pass: gradient of the loss with respect to x.
    let grads = loss_value.backward();
    let grad_val: f32 = x.grad(&grads).unwrap().into_scalar().elem::<f32>();

    // Gradient-descent update rule: x <- x - lr * df/dx.
    x_val -= learning_rate * grad_val;
}

println!("---");
println!("Final x = {:.4}", x_val);
Let's use gradient descent to fit a simple linear regression model: y = wx + b
We'll generate synthetic data where the true relationship is y = 2x + 1
use burn::tensor::{Distribution, TensorData};

let device = <B as Backend>::Device::default();

// Generate synthetic data following y = 2x + 1 + noise.
let num_samples = 100;
// x values 0.0, 0.1, ..., 9.9 arranged as a [num_samples, 1] column.
let x_data = TensorData::new((0..num_samples).map(|i| i as f32 / 10.0).collect(), [num_samples, 1]);

// Uniform noise in [-0.25, 0.25) so the fitted line is close but not exact.
let noise = Tensor::<B, 2>::random([num_samples, 1], Distribution::Uniform(-0.25, 0.25), &device);

// Build the tensor on the explicitly created device (from_data takes the
// device, unlike the From conversion which uses the implicit default).
let x = Tensor::<B, 2>::from_data(x_data, &device);
// Fix: the tensor must be on the left of the scalar ops with float literals —
// `2 * x` does not compile because i32 has no Mul<Tensor> implementation;
// Burn only provides Tensor * scalar / Tensor + scalar.
let y: Tensor<B, 2> = x.clone() * 2.0 + 1.0 + noise;

println!("Generated {} data points", num_samples);
println!("True relationship: y = 2x + 1");
println!("First 5 x values: {}", x.clone().slice([0..5, 0..1]));
println!("First 5 y values: {}", y.clone().slice([0..5, 0..1]));
// Initialize parameters with fixed starting values.
let device = <B as Backend>::Device::default();
let mut w_val: f32 = 0.5; // Start with reasonable initial values
let mut b_val: f32 = 0.5;
let learning_rate: f32 = 0.01;
let num_epochs = 100;

println!("Training linear regression with gradient descent...");
println!("Initial w = {:.4}, b = {:.4}", w_val, b_val);

for epoch in 0..num_epochs {
    // Recreate tracked 1x1 parameter tensors from the current scalar values.
    let w = Tensor::<B, 2>::from_floats([[w_val]], &device).require_grad();
    let b = Tensor::<B, 2>::from_floats([[b_val]], &device).require_grad();

    // Forward pass: y_pred = x @ w + b (b broadcasts over the batch).
    let y_pred = x.clone().matmul(w.clone()) + b.clone();

    // Mean squared error over all samples.
    let loss = (y_pred.clone() - y.clone()).powf_scalar(2.0).mean();

    // Backward pass yields gradients for both parameters.
    let grads = loss.backward();
    let w_grad_val: f32 = w.grad(&grads).unwrap().into_scalar().elem::<f32>();
    let b_grad_val: f32 = b.grad(&grads).unwrap().into_scalar().elem::<f32>();

    // Plain gradient-descent step for each parameter.
    w_val -= learning_rate * w_grad_val;
    b_val -= learning_rate * b_grad_val;

    // Report progress every 20 epochs.
    if epoch % 20 == 0 {
        let loss_val: f32 = loss.clone().into_scalar().elem::<f32>();
        println!("Epoch {:3}: loss = {:.4}, w = {:.4}, b = {:.4}", epoch, loss_val, w_val, b_val);
    }
}

println!("---");
println!("Final: w = {:.4}, b = {:.4}", w_val, b_val);
println!("True: w = 2.0, b = 1.0");
In this notebook, we covered: marking tensors for gradient tracking with `.require_grad()`, computing gradients with `.backward()` and retrieving them with `.grad()`, verifying the power rule and the chain rule against manually computed derivatives, and implementing gradient descent — first to minimize a simple function, then to fit a linear regression model.
These concepts are the foundation of neural network training in Burn!