Train the Model on the Device
=============================


Once the training artifacts are generated, the model can be trained on the device using the ONNX Runtime training Python API (``onnxruntime.training.api``).

The expected training artifacts are:

1. The training ONNX model
2. The checkpoint state
3. The optimizer ONNX model
4. The eval ONNX model (optional)
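
The artifacts are usually generated together into a single directory. As a minimal sketch, assuming the default file names written by ``onnxruntime.training.artifacts.generate_artifacts`` (``training_model.onnx``, ``checkpoint``, ``optimizer_model.onnx`` and ``eval_model.onnx``; verify these against your own output), the helper below collects the artifact paths and fails early if a required one is missing. The function ``find_training_artifacts`` is hypothetical and not part of ONNX Runtime.

.. code-block:: python

    from pathlib import Path
    from typing import Dict, Optional, Union

    # Assumed default file names; adjust them if your artifacts are named differently.
    REQUIRED_ARTIFACTS = {
        "training_model": "training_model.onnx",
        "checkpoint": "checkpoint",
        "optimizer_model": "optimizer_model.onnx",
    }
    OPTIONAL_ARTIFACTS = {
        "eval_model": "eval_model.onnx",
    }


    def find_training_artifacts(artifact_dir: Union[str, Path]) -> Dict[str, Optional[Path]]:
        """Return the paths of the training artifacts found in artifact_dir.

        Raises FileNotFoundError if a required artifact is missing. The optional
        eval model maps to None when it is absent.
        """
        artifact_dir = Path(artifact_dir)
        paths: Dict[str, Optional[Path]] = {}
        missing = []
        for key, name in REQUIRED_ARTIFACTS.items():
            path = artifact_dir / name
            if not path.exists():
                missing.append(name)
            paths[key] = path
        if missing:
            raise FileNotFoundError(
                f"missing training artifacts in {artifact_dir}: {', '.join(missing)}"
            )
        for key, name in OPTIONAL_ARTIFACTS.items():
            path = artifact_dir / name
            paths[key] = path if path.exists() else None
        return paths

The returned paths can then be converted with ``str()`` and passed to ``CheckpointState.load_checkpoint``, ``Module`` and ``Optimizer``.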

Sample usage:

.. code-block:: python

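    from onnxruntime.training.api import CheckpointState, Module, Optimizer

    # The path_to_* names below are placeholders for the paths of your artifacts.

    # Load the checkpoint state
    state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

    # Create the module from the training model, the checkpoint state and the
    # (optional) eval model
    module = Module(path_to_the_training_model,
                    state,
                    path_to_the_eval_model,
                    device="cpu")

    # Create the optimizer from the optimizer model and the module
    optimizer = Optimizer(path_to_the_optimizer_model, module)

    # Training loop. train_batches is a placeholder for your own iterable of
    # input arrays; this sketch assumes the training model takes the model
    # inputs followed by the labels.
    for inputs, labels in train_batches:
        module.train()
        training_loss = module(inputs, labels)
        optimizer.step()
        module.lazy_reset_grad()

    # Eval. eval_inputs and eval_labels are placeholders as well.
    module.eval()
    eval_loss = module(eval_inputs, eval_labels)

    # Save the checkpoint
    CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)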
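
The training and evaluation steps can be wrapped in a reusable function. The sketch below uses a hypothetical helper, ``run_epochs`` (not part of ONNX Runtime), that relies only on the calls used in the sample (``module.train()``, ``module.eval()``, calling the module, ``optimizer.step()`` and ``module.lazy_reset_grad()``), so it can also be exercised with stand-in objects. It assumes each batch is a tuple of positional inputs for the model, and it simply collects whatever the module returns.

.. code-block:: python

    from typing import Any, List, Sequence, Tuple


    def run_epochs(module: Any, optimizer: Any,
                   train_batches: Sequence[Tuple[Any, ...]],
                   eval_batches: Sequence[Tuple[Any, ...]],
                   num_epochs: int) -> Tuple[List[Any], List[Any]]:
        """Train for num_epochs epochs and evaluate after each one.

        Returns the collected training and evaluation outputs.
        """
        train_losses: List[Any] = []
        eval_losses: List[Any] = []
        for _ in range(num_epochs):
            # Switch back to training mode (the eval pass below leaves it).
            module.train()
            for batch in train_batches:
                # Run the training model on this batch and keep its output.
                train_losses.append(module(*batch))
                # Update the parameters with the gradients just computed.
                optimizer.step()
                # Schedule the gradients to be reset before the next training call.
                module.lazy_reset_grad()
            # Evaluation mode runs the eval model instead of the training model.
            module.eval()
            for batch in eval_batches:
                eval_losses.append(module(*batch))
        return train_losses, eval_losses

With ONNX Runtime, ``module`` and ``optimizer`` would be the ``Module`` and ``Optimizer`` objects created in the sample, and each batch would hold the input arrays the models expect. Sequences (rather than one-shot iterators) are used so the batches can be replayed every epoch.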

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameter
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __repr__

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameters
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

.. autoclass:: onnxruntime.training.api.checkpoint_state.Properties
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

.. autoclass:: onnxruntime.training.api.CheckpointState
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:

.. autoclass:: onnxruntime.training.api.Module
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __call__

.. autoclass:: onnxruntime.training.api.Optimizer
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:

.. autoclass:: onnxruntime.training.api.LinearLRScheduler
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
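
``LinearLRScheduler`` follows a linear warm-up and decay shape: the learning rate ramps up linearly during a warm-up phase and then decays linearly to 0 (see the class reference above for the exact semantics). The library-independent sketch below computes that shape for a given step so the schedule can be inspected before training. It is not ONNX Runtime's implementation; the function ``linear_lr`` is hypothetical, its parameter names are assumed to mirror the scheduler's constructor arguments, and the behaviour exactly at the first and last steps is an assumption.

.. code-block:: python

    def linear_lr(step: int, warmup_step_count: int, total_step_count: int,
                  initial_lr: float) -> float:
        """Learning rate after `step` scheduler steps under linear warm-up and decay."""
        if step < warmup_step_count:
            # Warm-up: ramp linearly from 0 up to initial_lr.
            return initial_lr * step / warmup_step_count
        # Decay: ramp linearly from initial_lr down to 0 at total_step_count,
        # and stay at 0 afterwards.
        remaining = max(0, total_step_count - step)
        return initial_lr * remaining / max(1, total_step_count - warmup_step_count)

In a training loop, ``scheduler.step()`` would typically be called once per iteration, right after ``optimizer.step()``.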