crates/burn-store/MIGRATION.md
This guide helps you migrate from the deprecated burn-import recorders (PyTorchFileRecorder,
SafetensorsFileRecorder) to the new burn-store API (PytorchStore, SafetensorsStore).
The new burn-store API provides a simpler, more flexible interface: it loads weights directly into an initialized model, handles precision automatically, supports partial and filtered loading, and reports detailed load results.
Before (burn-import):
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
// Load into a record, then create model from record
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("model.pt".into(), &device)
.expect("Failed to load");
let model = Model::init(&device).load_record(record);
After (burn-store):
use burn_store::{ModuleSnapshot, PytorchStore};
// Initialize model, then load weights directly
let mut model = Model::init(&device);
let mut store = PytorchStore::from_file("model.pt");
model.load_from(&mut store).expect("Failed to load");
Before (burn-import):
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
.load("model.safetensors".into(), &device)
.expect("Failed to load");
let model = Model::init(&device).load_record(record);
After (burn-store):
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
let mut model = Model::init(&device);
// For SafeTensors exported from PyTorch, use the adapter
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store).expect("Failed to load");
// For native Burn SafeTensors, no adapter needed
let mut store = SafetensorsStore::from_file("model.safetensors");
model.load_from(&mut store).expect("Failed to load");
| burn-import | burn-store |
|---|---|
| LoadArgs::new(path) | PytorchStore::from_file(path) |
| .with_key_remap(pattern, replacement) | .with_key_remapping(pattern, replacement) |
| .with_top_level_key(key) | .with_top_level_key(key) |
| .with_debug_print() | (use tracing/logging instead) |
| PyTorchFileRecorder::<FullPrecisionSettings> | (precision handled automatically) |
| burn-import | burn-store |
|---|---|
| LoadArgs::new(path) | SafetensorsStore::from_file(path) |
| .with_key_remap(pattern, replacement) | .with_key_remapping(pattern, replacement) |
| .with_adapter_type(AdapterType::PyTorch) | .with_from_adapter(PyTorchToBurnAdapter) |
| .with_adapter_type(AdapterType::NoAdapter) | (default, no adapter) |
| .with_debug_print() | (use tracing/logging instead) |
| SafetensorsFileRecorder::<FullPrecisionSettings> | (precision handled automatically) |
Before:
let args = LoadArgs::new("model.pt".into())
.with_key_remap("conv\\.(.*)", "$1")
.with_key_remap("^old_prefix\\.", "new_prefix.");
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
After:
let mut store = PytorchStore::from_file("model.pt")
.with_key_remapping("conv\\.(.*)", "$1")
.with_key_remapping("^old_prefix\\.", "new_prefix.");
model.load_from(&mut store)?;
Before:
let args = LoadArgs::new("checkpoint.pt".into())
.with_top_level_key("state_dict");
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
After:
let mut store = PytorchStore::from_file("checkpoint.pt")
.with_top_level_key("state_dict");
model.load_from(&mut store)?;
Before:
use burn_import::safetensors::{AdapterType, LoadArgs};
let args = LoadArgs::new("pytorch_model.safetensors".into())
.with_adapter_type(AdapterType::PyTorch);
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
After:
use burn_store::{PyTorchToBurnAdapter, SafetensorsStore};
let mut store = SafetensorsStore::from_file("pytorch_model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store)?;
Handle missing tensors gracefully:
let mut store = PytorchStore::from_file("model.pt")
.allow_partial(true);
let result = model.load_from(&mut store)?;
println!("Loaded: {:?}", result.applied);
println!("Missing: {:?}", result.missing);
Load only specific tensors:
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_regex(r"^encoder\..*") // Only encoder layers
.allow_partial(true);
model.load_from(&mut store)?;
Save models (not supported by old recorders):
// Save to SafeTensors
let mut store = SafetensorsStore::from_file("output.safetensors")
.metadata("version", "1.0");
model.save_into(&mut store)?;
// Save to Burnpack (native format)
let mut store = BurnpackStore::from_file("output.bpk");
model.save_into(&mut store)?;
Get detailed information about loading:
let result = model.load_from(&mut store)?;
// Print the full result for debugging - shows applied, skipped, missing, and errors
println!("{}", result);
// Or access individual fields
println!("Applied: {} tensors", result.applied.len());
println!("Skipped: {} tensors", result.skipped.len());
println!("Missing: {:?}", result.missing);
println!("Errors: {:?}", result.errors);
// Check if fully successful
if result.is_success() {
println!("All tensors loaded successfully");
}
The LoadResult implements Display, so printing it shows a formatted summary with suggestions for
common issues (e.g., using allow_partial(true) for missing tensors).
Before:
[dependencies]
burn-import = { version = "0.x", features = ["pytorch", "safetensors"] }
After:
[dependencies]
burn-store = { version = "0.x", features = ["pytorch", "safetensors"] }
The new API loads directly into models, not records. Update your model initialization:
// Before: Create record, then model from record
let record = recorder.load(...)?;
let model = Model::init(&device).load_record(record);
// After: Create model, then load into it
let mut model = Model::init(&device);
model.load_from(&mut store)?;
If you had functions that took ModelRecord, update them to take Model:
// Before
fn infer(record: ModelRecord<B>) {
let model = Model::init(&device).load_record(record);
// ...
}
// After
fn infer(model: Model<B>) {
// Model already has weights loaded
// ...
}
The old API required explicit precision settings. The new API handles this automatically:
// Before: Had to specify FullPrecisionSettings or HalfPrecisionSettings
PyTorchFileRecorder::<FullPrecisionSettings>::default()
// After: Precision handled automatically based on tensor dtype
PytorchStore::from_file("model.pt")
The new API provides richer error information:
// Before: Simple Result
let record = recorder.load(args, &device)?;
// After: LoadResult with detailed info
let result = model.load_from(&mut store)?;
// Print the result to see a helpful summary with suggestions
println!("{}", result);
// Or handle specific issues programmatically
if !result.errors.is_empty() {
for (path, error) in &result.errors {
eprintln!("Error loading {}: {}", path, error);
}
}