Back to Pytorch

Saving and Loading

docs/cpp/source/api/serialize/save_load.md

2.12.0 · 1.7 KB
Original Source

Saving and Loading

The primary serialization interface is the pair of functions `torch::save` and `torch::load`, which save and load tensors, modules, and optimizers.

Save Functions

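```{doxygenfunction} torch::save(const Value &value, SaveToArgs&&... args)
```

```{doxygenfunction} torch::save(const std::vector<torch::Tensor> &tensor_vec, SaveToArgs&&... args)
```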

Load Functions

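```{doxygenfunction} torch::load(Value &value, LoadFromArgs&&... args)
```

```{doxygenfunction} torch::load(std::vector<torch::Tensor> &tensor_vec, LoadFromArgs&&... args)
```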

Saving and Loading Tensors

cpp
// Save a tensor
torch::Tensor tensor = torch::randn({2, 3});
torch::save(tensor, "tensor.pt");

// Load a tensor
torch::Tensor loaded;
torch::load(loaded, "tensor.pt");

Saving and Loading Modules

cpp
// Define a model
struct Net : torch::nn::Module {
  Net() {
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x));
    return fc2->forward(x);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};

// Create a model and write it to "model.pt"
auto model = std::make_shared<Net>();
torch::save(model, "model.pt");

// Read it back into a second, independently constructed Net
auto restored_model = std::make_shared<Net>();
torch::load(restored_model, "model.pt");

Saving and Loading Optimizer State

cpp
auto model = std::make_shared<Net>();
torch::optim::Adam optimizer(model->parameters(), /*lr=*/0.001);

// Train...

// Save the model and the optimizer to separate files
torch::save(model, "model.pt");
torch::save(optimizer, "optimizer.pt");

// Load both into freshly constructed objects, as you would when resuming
// training later (e.g. in a new process). The optimizer must first be
// constructed over the restored model's parameters; only then is its saved
// state loaded into it.
auto restored_model = std::make_shared<Net>();
torch::optim::Adam restored_optimizer(restored_model->parameters(), /*lr=*/0.001);
torch::load(restored_model, "model.pt");
torch::load(restored_optimizer, "optimizer.pt");