# Saving and Loading Models
Saving your trained machine learning model is quite easy, no matter the output format you choose. As
mentioned in the Record section, different formats are supported to serialize/deserialize models. By
default, we use the `NamedMpkFileRecorder`, which uses the MessagePack binary serialization format
with the help of `rmp_serde`.
```rust
// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model");
```
Note that the recorder automatically appends the file extension for the chosen format, so only the file path and base name should be provided.
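For instance, a small sketch (the `/tmp/my_model` base path is purely illustrative): with the `NamedMpkFileRecorder` above, the saved file ends up with the `.mpk` extension.

```rust
// Illustrative only: the base path carries no extension; the recorder
// appends ".mpk", producing "/tmp/my_model.mpk" on disk.
model
    .save_file("/tmp/my_model", &recorder)
    .expect("Should be able to save the model");
```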
Now that you have a trained model saved to disk, you can load it in a similar fashion.
```rust
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model = model
    .load_file(model_path, &recorder, &device)
    .expect("Should be able to load the model weights from the provided file");
```
Note: models can be saved in different output formats; just make sure you use the corresponding recorder type when loading the saved model. Type conversion between different precision settings is handled automatically, but formats are not interchangeable. A model can be loaded from one format and saved to another, as long as you load it back with the matching recorder type afterwards.
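As a minimal sketch of such a conversion (assuming a hypothetical `json_path` base path; `PrettyJsonFileRecorder` is one of the other recorders Burn provides), one can load the MessagePack record and re-save it as JSON:

```rust
// Load the model from its MessagePack record, then re-save it as pretty JSON.
// From this point on, the JSON file must be loaded with a PrettyJsonFileRecorder.
let mpk_recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let json_recorder = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();

let model = Model::<MyBackend>::init(&device)
    .load_file(model_path, &mpk_recorder, &device)
    .expect("Should be able to load the model weights");
model
    .save_file(json_path, &json_recorder)
    .expect("Should be able to save the model as JSON");
```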
The most straightforward way to load weights for a module is simply by using the generated method
`load_record`. Note that parameter initialization is lazy, so no tensor allocations or GPU/CPU
kernels are executed before the module is used. This means that you can use `init(device)` followed
by `load_record(record)` without any meaningful performance cost.
```rust
// Create a dummy initialized model to save
let device = Default::default();
let model = Model::<MyBackend>::init(&device);

// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model");
```
Afterwards, the model can just as easily be loaded from the record saved on disk.
```rust
// Load model record on the backend's default device
let record: ModelRecord<MyBackend> =
    NamedMpkFileRecorder::<FullPrecisionSettings>::new()
        .load(model_path.into(), &device)
        .expect("Could not load model weights");

// Initialize a new model with the loaded record/weights
let model = Model::init(&device).load_record(record);
```
While the `Recorder` API works well for basic saving and loading, `burn-store` was introduced to
address its limitations around memory efficiency and flexibility. It provides zero-copy
memory-mapped loading, cross-framework interoperability (PyTorch and SafeTensors), key remapping,
partial loading, and filtering. The `burn-store` crate is intended to eventually replace the
`Recorder` API, but since it was recently released, both APIs are available for now. `burn-store`
handles the following formats:
| Format | Extension | Description |
|---|---|---|
| Burnpack | .bpk | Burn's native format with fast loading, zero-copy support, and training state persistence |
| SafeTensors | .safetensors | Industry-standard format from Hugging Face for secure tensor serialization |
| PyTorch | .pt, .pth | Direct loading of PyTorch model weights (read-only) |
Saving a model with `burn-store` looks like this:

```rust
use burn_store::{ModuleSnapshot, BurnpackStore};

// Save to Burnpack (recommended)
let mut store = BurnpackStore::from_file("model.bpk");
model.save_into(&mut store)?;

// Or save to SafeTensors
use burn_store::SafetensorsStore;
let mut store = SafetensorsStore::from_file("model.safetensors");
model.save_into(&mut store)?;
```
Loading is just as simple:

```rust
use burn_store::{ModuleSnapshot, BurnpackStore};

let device = Default::default();
let mut model = MyModel::init(&device);

// Load from Burnpack
let mut store = BurnpackStore::from_file("model.bpk");
model.load_from(&mut store)?;
```
You can load weights directly from PyTorch `.pt` files:
```rust
use burn_store::{ModuleSnapshot, PytorchStore};

let mut model = MyModel::init(&device);
let mut store = PytorchStore::from_file("pytorch_model.pt");
model.load_from(&mut store)?;
```
On the PyTorch side, save only the model weights (`state_dict`), not the entire model:
```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        return self.conv2(self.conv1(x))

model = Net()
torch.save(model.state_dict(), "model.pt")  # Correct: save state_dict
# torch.save(model, "model.pt")  # Wrong: saves entire model
```
Some PyTorch checkpoints nest the `state_dict` under a key:
let mut store = PytorchStore::from_file("checkpoint.pt")
.with_top_level_key("state_dict");
model.load_from(&mut store)?;
For SafeTensors files exported from PyTorch, use the `PyTorchToBurnAdapter` for proper weight transformation:
```rust
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};

let mut model = MyModel::init(&device);
let mut store = SafetensorsStore::from_file("model.safetensors")
    .with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store)?;
```
For SafeTensors files created by Burn, no adapter is needed:
```rust
let mut store = SafetensorsStore::from_file("model.safetensors");
model.load_from(&mut store)?;
```
To produce such a SafeTensors file from PyTorch in the first place:

```python
from safetensors.torch import save_file

model = Net()
save_file(model.state_dict(), "model.safetensors")
```
Use the `BurnToPyTorchAdapter` when saving for PyTorch consumption:
```rust
use burn_store::{BurnToPyTorchAdapter, SafetensorsStore};

let mut store = SafetensorsStore::from_file("for_pytorch.safetensors")
    .with_to_adapter(BurnToPyTorchAdapter)
    // Omit enum variant names from tensor paths for PyTorch-style keys
    .skip_enum_variants(true);
model.save_into(&mut store)?;
```
The `load_from` method returns detailed information about the loading process:
```rust
let result = model.load_from(&mut store)?;

// Print a formatted summary with suggestions
println!("{}", result);

// Or inspect individual fields
println!("Applied: {} tensors", result.applied.len());
println!("Missing: {:?}", result.missing);
println!("Errors: {:?}", result.errors);

if result.is_success() {
    println!("All tensors loaded successfully");
}
```
Burnpack and SafeTensors support custom metadata:
let mut store = BurnpackStore::from_file("model.bpk")
.metadata("version", "1.0")
.metadata("description", "My trained model")
.metadata("epochs", "100");
model.save_into(&mut store)?;
Remap parameter names using regex patterns when model structures don't match:
let mut store = PytorchStore::from_file("model.pt")
// Remove prefix: "model.conv1.weight" -> "conv1.weight"
.with_key_remapping(r"^model\.", "")
// Rename: "layer1" -> "encoder.layer1"
.with_key_remapping(r"^layer", "encoder.layer");
model.load_from(&mut store)?;
For more complex remapping, use `KeyRemapper`:
```rust
use burn_store::KeyRemapper;

let remapper = KeyRemapper::new()
    .add_pattern(r"^transformer\.h\.(\d+)\.", "transformer.layer$1.")?
    .add_pattern(r"\.attn\.", ".attention.")?;

let mut store = SafetensorsStore::from_file("model.safetensors")
    .remap(remapper);
```
Load weights even when some tensors are missing:
let mut store = PytorchStore::from_file("pretrained.pt")
.allow_partial(true);
let result = model.load_from(&mut store)?;
println!("Missing (initialized randomly): {:?}", result.missing);
Load or save only specific layers:
```rust
// Load only encoder layers
let mut store = SafetensorsStore::from_file("model.safetensors")
    .with_regex(r"^encoder\..*")
    .allow_partial(true);

// Save only encoder layers
let mut store = SafetensorsStore::from_file("encoder.safetensors")
    .with_regex(r"^encoder\..*");
model.save_into(&mut store)?;

// Multiple patterns (OR logic)
let mut store = SafetensorsStore::from_file("model.safetensors")
    .with_regex(r"^encoder\..*")      // encoder tensors
    .with_regex(r".*\.bias$")         // OR any bias tensors
    .with_full_path("decoder.scale"); // OR specific tensor
```
A PyTorch `nn.Sequential` that mixes parameterized and parameter-free layers (such as `ReLU`)
produces non-contiguous indices. `PytorchStore` automatically remaps these:
```text
PyTorch: fc.0.weight, fc.2.weight, fc.4.weight  (gaps from ReLU layers)
Burn:    fc.0.weight, fc.1.weight, fc.2.weight  (contiguous)
```
This is enabled by default. Disable if needed:
let mut store = PytorchStore::from_file("model.pt")
.map_indices_contiguous(false);
For embedded models or large files, use zero-copy loading to avoid memory copies:
```rust
// Embedded model (compile-time)
static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
let mut store = BurnpackStore::from_static(MODEL_DATA);
model.load_from(&mut store)?;

// Large file (memory-mapped)
let mut store = BurnpackStore::from_file("large_model.bpk")
    .zero_copy(true);
model.load_from(&mut store)?;
```
Save models at half precision (F16) to reduce file size by ~50%, then load back at full precision:
```rust
use burn_store::{ModuleSnapshot, BurnpackStore, HalfPrecisionAdapter};

let adapter = HalfPrecisionAdapter::new();

// Save: F32 -> F16 (same adapter for both directions)
let mut store = BurnpackStore::from_file("model_f16.bpk")
    .with_to_adapter(adapter.clone());
model.save_into(&mut store)?;

// Load: F16 -> F32
let mut store = BurnpackStore::from_file("model_f16.bpk")
    .with_from_adapter(adapter);
model.load_from(&mut store)?;
```
By default, weights in `Linear`, `Embedding`, `Conv*`, `LayerNorm`, `GroupNorm`, `InstanceNorm`,
`RmsNorm`, and `PRelu` modules are converted. `BatchNorm` is excluded because its running variance
can underflow in F16. Customize with `with_module()` and `without_module()`:
```rust
// Keep LayerNorm at full precision
let adapter = HalfPrecisionAdapter::new()
    .without_module("LayerNorm");

// Add a custom module to the conversion set
let adapter = HalfPrecisionAdapter::new()
    .with_module("CustomLayer");
```
Inspect tensors without loading into a model:
```rust
use burn_store::ModuleStore;

let mut store = PytorchStore::from_file("model.pt");

// List all tensor names
let names = store.keys()?;

// Get specific tensor
if let Some(snapshot) = store.get_snapshot("encoder.layer0.weight")? {
    println!("Shape: {:?}, DType: {:?}", snapshot.shape, snapshot.dtype);
}
```
Transfer weights between models:
```rust
use burn_store::{ModuleSnapshot, PathFilter};

// Transfer all weights
let snapshots = model1.collect(None, None, false);
model2.apply(snapshots, None, None, false);

// Transfer only encoder weights
let filter = PathFilter::new().with_regex(r"^encoder\..*");
let snapshots = model1.collect(Some(filter.clone()), None, false);
model2.apply(snapshots, Some(filter), None, false);
```
Store configuration methods at a glance:

| Category | Method | Description |
|---|---|---|
| Filtering | `with_regex(pattern)` | Filter by regex pattern |
| | `with_full_path(path)` | Include specific tensor |
| | `with_predicate(fn)` | Custom filter logic |
| Remapping | `with_key_remapping(from, to)` | Regex-based renaming |
| | `remap(KeyRemapper)` | Complex remapping rules |
| Adapters | `with_from_adapter(adapter)` | Loading transformations |
| | `with_to_adapter(adapter)` | Saving transformations |
| | `HalfPrecisionAdapter::new()` | F32/F16 mixed precision |
| Config | `allow_partial(bool)` | Continue on missing tensors |
| | `with_top_level_key(key)` | Access nested dict (PyTorch) |
| | `skip_enum_variants(bool)` | Skip enum variants in paths |
| | `map_indices_contiguous(bool)` | Remap non-contiguous indices |
| | `metadata(key, value)` | Add custom metadata |
| | `zero_copy(bool)` | Enable zero-copy loading |
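As an illustration of composing several of these options (the `checkpoint.pt` file and its `module.` key prefix are hypothetical), a single store can filter, remap, and tolerate missing tensors at once:

```rust
// Hypothetical checkpoint: weights nested under "state_dict", keys prefixed
// with "module.", and we only want the encoder weights.
let mut store = PytorchStore::from_file("checkpoint.pt")
    .with_top_level_key("state_dict")
    .with_key_remapping(r"^module\.", "")
    .with_regex(r"^encoder\..*")
    .allow_partial(true);

let result = model.load_from(&mut store)?;
println!("{}", result);
```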
Methods for direct tensor access:

| Method | Description |
|---|---|
| `keys()` | Get ordered list of tensor names |
| `get_all_snapshots()` | Get all tensors as `BTreeMap` |
| `get_snapshot(name)` | Get specific tensor by name |
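For example, a small sketch listing every tensor in a file via `get_all_snapshots()` (reusing the `shape` and `dtype` fields shown earlier):

```rust
use burn_store::ModuleStore;

// Iterate over every tensor in the file, in key order.
let mut store = PytorchStore::from_file("model.pt");
for (name, snapshot) in store.get_all_snapshots()? {
    println!("{name}: shape {:?}, dtype {:?}", snapshot.shape, snapshot.dtype);
}
```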
"Missing source values" error: You saved the entire PyTorch model instead of the state_dict.
Re-export with torch.save(model.state_dict(), "model.pt").
Shape mismatch: Your Burn model doesn't match the source architecture. Verify layer configurations (channels, kernel sizes, bias settings).
Key not found: Parameter names don't match. Use with_key_remapping() or inspect keys:
let store = PytorchStore::from_file("model.pt");
println!("Available keys: {:?}", store.keys()?);
Use Netron to visualize `.pt` and `.safetensors` files.
For Burnpack files:
```sh
cargo run --example burnpack_inspect model.bpk
```