Back to Burn

Autodifferentiation and Gradient Descent in Burn

examples/notebook/autodiff.ipynb

0.20.1 · 7.7 KB
Original Source

Autodifferentiation and Gradient Descent in Burn

This notebook demonstrates how to use automatic differentiation in Burn to compute gradients and implement gradient descent.

rust
// Dependency declarations
:dep burn = {path = "../../crates/burn"}
:dep burn-ndarray = {path = "../../crates/burn-ndarray"}
:dep burn-autodiff = {path = "../../crates/burn-autodiff"}

rust
// Import packages
use burn::prelude::*;
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;

// Backend used by every cell below: the NdArray backend with f32 elements, wrapped in
// `Autodiff` so that operations on tensors marked with `require_grad()` are recorded
// and `backward()` can compute gradients for them.
type B = Autodiff<NdArray<f32>>;

1. Understanding require_grad()

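In Burn, tensors can be marked for gradient tracking using `.require_grad()`. This tells the framework to record the operations applied to the tensor so that gradients with respect to it can be computed later. Tracking only works on an autodiff backend, which is why the type alias `B` above wraps `NdArray` in `Autodiff`.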

rust
let device = <B as Backend>::Device::default();
let values = [1.0, 2.0, 3.0, 4.0];

// A plain tensor: operations on it are not recorded for differentiation.
let x: Tensor<B, 1> = Tensor::from_floats(values, &device);
println!("Regular tensor x: {x}");

// Same values, but marked as a leaf whose gradient we will want later.
let y: Tensor<B, 1> = Tensor::from_floats(values, &device).require_grad();
println!("Tensor y with require_grad: {y}");

// Both the scaling and the reduction below are recorded because y is tracked.
let result = (y.clone() * 2.0).sum();
println!("result = sum(y * 2) = {result}");

2. Computing Gradients with backward()

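The `.backward()` method runs reverse-mode differentiation starting from the tensor it is called on. It computes that tensor's gradient with respect to every tensor marked with `require_grad()` that contributed to it. The result is a gradients object from which each tracked tensor's gradient can be retrieved with `.grad(&grads)`.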

rust
// Worked example, starting from y = [1, 2, 3, 4]:
//   z      = y * 2  = [2, 4, 6, 8]
//   result = sum(z) = 20
// Chain rule: d(result)/dy = d(result)/dz * dz/dy = 1 * 2 = [2, 2, 2, 2]

let device = <B as Backend>::Device::default();
let y: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();

// Build the graph: scale, then reduce to a single value.
let z = y.clone() * 2.0;
let result = z.sum();

// Reverse pass from `result`; the returned object holds gradients of tracked tensors.
let grads = result.backward();

// Safe to unwrap: y is tracked and feeds into `result`.
let y_grad = y.grad(&grads).unwrap();
println!("y = {y}");
println!("d(result)/dy = {y_grad}");

3. More Complex Example: Quadratic Function

Let's compute the gradient of a more complex function: f(x) = x²

The derivative is: f'(x) = 2x

rust
// Square each element and sum: result = sum_i x_i^2, so d(result)/dx_i = 2 * x_i.

let device = <B as Backend>::Device::default();
let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();

let y = x.clone().powf_scalar(2.0);
let grads = y.clone().sum().backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {x}");
println!("x^2 = {y}");
println!("d(x^2)/dx = {x_grad}");

// For x = [1, 2, 3, 4] the analytic derivative 2x is [2, 4, 6, 8].
println!("Expected: [2, 4, 6, 8]");

4. Chain Rule Example

Let's verify the chain rule: $(f \circ g)'(x) = f'(g(x)) \cdot g'(x)$

Example: y = sin(x²), we want dy/dx

Let $u = x^2$ and $y = \sin(u)$. Then $\frac{dy}{du} = \cos(u)$ and $\frac{du}{dx} = 2x$, so $\frac{dy}{dx} = \cos(x^2) \cdot 2x$.

rust
// Chain-rule check for y = sin(u) with u = x^2:
//   dy/dx = dy/du * du/dx = cos(x^2) * 2x

let device = <B as Backend>::Device::default();
let x: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device).require_grad();

// Forward: square, then sine. Summing gives backward() a single value to differentiate;
// each y_i depends only on x_i, so the per-element derivatives are preserved.
let y = x.clone().powf_scalar(2.0).sin();
let result = y.clone().sum();

// Backward: autodiff propagates through sin and powf_scalar.
let grads = result.backward();
let x_grad = x.grad(&grads).unwrap();

println!("x = {x}");
println!("y = sin(x^2) = {y}");
println!("dy/dx = {x_grad}");

// The same derivative, evaluated by hand from the closed form.
let two_x = x.clone() * 2.0;
let expected_grad = x.clone().powf_scalar(2.0).cos() * two_x;
println!("Expected (cos(x^2) * 2x): {expected_grad}");

5. Gradient Descent from Scratch

Now let's implement the classic gradient descent algorithm to find the minimum of a function.

We'll minimize: f(x) = (x - 3)²

The minimum is at x = 3, where f(x) = 0

rust
// Objective to minimize: f(x) = (x - 3)^2, whose minimum value 0 is reached at x = 3.

/// Element-wise squared distance from 3, i.e. `(x - 3)^2`.
///
/// Borrows `x` (cloning internally) so the caller keeps ownership of the tensor and can
/// look up its gradient after calling `backward()` on the result.
fn loss<Be: Backend>(x: &Tensor<Be, 1>) -> Tensor<Be, 1> {
    let shifted = x.clone().sub_scalar(3.0);
    shifted.powf_scalar(2.0)
}

let device = <B as Backend>::Device::default();

// Step size and starting point (x = 0, three units away from the optimum).
let learning_rate: f32 = 0.1;
let mut x_val: f32 = 0.0;

println!("Starting gradient descent to minimize (x - 3)^2");
println!("Expected minimum: x = 3");
println!("---");

for step in 0..20 {
    // Fresh tracked leaf built from the current scalar, so the gradient is always
    // taken at the current x_val.
    let x = Tensor::<B, 1>::from_floats([x_val], &device).require_grad();

    let loss_value = loss(&x);

    // into_scalar consumes its tensor, hence the clone: loss_value is still needed below.
    let current_loss = loss_value.clone().into_scalar().elem::<f32>();
    println!("Iteration {}: x = {:.4}, loss = {:.4}", step, x_val, current_loss);

    let grads = loss_value.backward();
    let slope = x.grad(&grads).unwrap().into_scalar().elem::<f32>();

    // Gradient-descent update: move against the slope.
    x_val -= slope * learning_rate;
}

println!("---");
println!("Final x = {:.4}", x_val);

6. Linear Regression with Gradient Descent

Let's use gradient descent to fit a simple linear regression model: y = wx + b

We'll generate synthetic data where the true relationship is y = 2x + 1, plus small uniform noise between -0.25 and 0.25.

rust
use burn::tensor::{Distribution, TensorData};

let device = <B as Backend>::Device::default();

// Synthetic data set: y = 2x + 1 + noise, with x = 0.0, 0.1, ..., 9.9 as a column
// of shape [num_samples, 1].
let num_samples: usize = 100;
let x_data = TensorData::new((0..num_samples).map(|i| i as f32 / 10.0).collect(), [num_samples, 1]);

// Uniform noise in [-0.25, 0.25], drawn with Burn's random tensor API.
let noise = Tensor::<B, 2>::random([num_samples, 1], Distribution::Uniform(-0.25, 0.25), &device);

// Place x explicitly on `device`, like `noise`. `Tensor::from(TensorData)` takes no
// device argument, so x and noise could otherwise end up on different devices.
let x = Tensor::<B, 2>::from_data(x_data, &device);

// Tensor-on-the-left float scalar ops, the same style as the earlier cells.
let y: Tensor<B, 2> = x.clone() * 2.0 + 1.0 + noise;

println!("Generated {} data points", num_samples);
println!("True relationship: y = 2x + 1");
println!("First 5 x values: {}", x.clone().slice([0..5, 0..1]));
println!("First 5 y values: {}", y.clone().slice([0..5, 0..1]));

rust
// Fixed (non-random) starting parameters, deliberately away from the true w = 2, b = 1.
let device = <B as Backend>::Device::default();
let mut w_val: f32 = 0.5;
let mut b_val: f32 = 0.5;

let learning_rate: f32 = 0.01;
let num_epochs = 100;

println!("Training linear regression with gradient descent...");
println!("Initial w = {:.4}, b = {:.4}", w_val, b_val);

for epoch in 0..num_epochs {
    // Fresh tracked leaves for the current parameters, so backward() yields
    // d(loss)/dw and d(loss)/db at w_val / b_val.
    let w = Tensor::<B, 2>::from_floats([[w_val]], &device).require_grad();
    let b = Tensor::<B, 2>::from_floats([[b_val]], &device).require_grad();

    // Forward pass: x is [num_samples, 1] and w is [1, 1], so matmul gives
    // [num_samples, 1]; the [1, 1] bias is broadcast over the rows.
    let y_pred = x.clone().matmul(w.clone()) + b.clone();

    // Mean squared error: (1/n) * sum((y_pred - y)^2).
    let loss = (y_pred - y.clone()).powf_scalar(2.0).mean();

    // Backward pass (borrows `loss`, so it can still be read for logging below).
    let grads = loss.backward();
    let w_grad_val: f32 = w
        .grad(&grads)
        .expect("w is tracked and used to compute loss")
        .into_scalar()
        .elem::<f32>();
    let b_grad_val: f32 = b
        .grad(&grads)
        .expect("b is tracked and used to compute loss")
        .into_scalar()
        .elem::<f32>();

    // Log BEFORE updating, so the printed loss matches the printed parameters that
    // produced it. The final epoch is always logged as well.
    if epoch % 20 == 0 || epoch == num_epochs - 1 {
        let loss_val: f32 = loss.into_scalar().elem::<f32>();
        println!("Epoch {:3}: loss = {:.4}, w = {:.4}, b = {:.4}", epoch, loss_val, w_val, b_val);
    }

    // Gradient-descent update.
    w_val -= w_grad_val * learning_rate;
    b_val -= b_grad_val * learning_rate;
}

println!("---");
println!("Final: w = {:.4}, b = {:.4}", w_val, b_val);
println!("True: w = 2.0, b = 1.0");

Summary

In this notebook, we covered:

  • require_grad(): Mark tensors for gradient tracking
  • backward(): Compute gradients automatically using reverse-mode autodiff
  • grad(): Retrieve computed gradients
  • Gradient Descent: Implemented from scratch to minimize a quadratic function
  • Linear Regression: Used gradient descent to fit a linear model to data

These concepts are the foundation of neural network training in Burn!