Before the training can start on edge devices, the training artifacts need to be generated in an offline step.
These artifacts include:

- The training onnx model (``training_model.onnx``)
- The eval onnx model (``eval_model.onnx``)
- The optimizer onnx model (``optimizer_model.onnx``)
- The checkpoint file (``checkpoint``)
It is assumed that a forward-only onnx model is already available. If using PyTorch, this model can be generated by exporting the PyTorch model with the :func:`torch.onnx.export` API.
.. note::

   If using PyTorch to export the model, please use the following export arguments so that training artifact generation succeeds:

   - ``export_params``: ``True``
   - ``do_constant_folding``: ``False``
   - ``training``: ``torch.onnx.TrainingMode.TRAINING``
Once the forward-only onnx model is available, the training artifacts can be generated using the :func:`onnxruntime.training.artifacts.generate_artifacts` API.

Sample usage:

.. code-block:: python

   import onnx

   from onnxruntime.training import artifacts

   # Load the forward only onnx model
   model = onnx.load(path_to_forward_only_onnx_model)

   # Generate the training artifacts
   artifacts.generate_artifacts(
       model,
       requires_grad=["parameters", "needing", "gradients"],
       frozen_params=["parameters", "not", "needing", "gradients"],
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory=path_to_output_artifact_directory,
   )
.. autoclass:: onnxruntime.training.artifacts.LossType
   :members:
   :member-order: bysource
   :undoc-members:

.. autoclass:: onnxruntime.training.artifacts.OptimType
   :members:
   :member-order: bysource
   :undoc-members:

.. autofunction:: onnxruntime.training.artifacts.generate_artifacts
Custom Loss
+++++++++++
If a custom loss is needed, the user can provide a custom loss function to the :func:`onnxruntime.training.artifacts.generate_artifacts` API.
This is done by inheriting from the :class:`onnxruntime.training.onnxblock.Block` class and implementing the ``build`` method.

The following example shows how to implement a custom loss function.
Let's assume we want to use a custom loss function with a model that generates two outputs. The custom loss function must apply a loss function to each output and then compute a weighted average of the two losses. Mathematically,
.. code-block:: python

   loss = 0.4 * mse_loss1(output1, target1) + 0.6 * mse_loss2(output2, target2)
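To make the arithmetic concrete, here is a plain-Python sketch of that weighted combination. The mean-squared-error helper and the sample predictions and targets are made up for illustration (this is not the onnxblock ``MSELoss``):

```python
# Illustrative only: a hand-rolled MSE and the 0.4/0.6 weighting from above.
def mse(pred, target):
    # Mean of squared element-wise differences.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# Made-up model outputs and targets for the two heads.
output1, target1 = [1.0, 2.0], [1.0, 1.0]
output2, target2 = [0.0, 4.0], [2.0, 2.0]

loss = 0.4 * mse(output1, target1) + 0.6 * mse(output2, target2)
# mse1 = 0.5 and mse2 = 4.0, so loss = 0.4 * 0.5 + 0.6 * 4.0 = 2.6
```

The onnxblock version below expresses this same computation as a graph of ONNX nodes rather than eager Python arithmetic.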
Since this is a custom loss function, it is not exposed by the ``LossType`` enum.
Instead, we make use of onnxblock to build it:
.. code-block:: python

   import onnx

   import onnxruntime.training.onnxblock as onnxblock
   from onnxruntime.training import artifacts

   # Define a custom loss block that takes in two inputs
   # and performs a weighted average of the losses from these
   # two inputs.
   class WeightedAverageLoss(onnxblock.Block):
       def __init__(self):
           super().__init__()
           self._loss1 = onnxblock.loss.MSELoss()
           self._loss2 = onnxblock.loss.MSELoss()
           self._w1 = onnxblock.blocks.Constant(0.4)
           self._w2 = onnxblock.blocks.Constant(0.6)
           self._add = onnxblock.blocks.Add()
           self._mul = onnxblock.blocks.Mul()

       def build(self, loss_input_name1, loss_input_name2):
           # The build method defines how the block should be stacked on top of
           # loss_input_name1 and loss_input_name2.
           # Return the weighted average of the two losses.
           return self._add(
               self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
               self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")),
           )

   my_custom_loss = WeightedAverageLoss()

   # Load the onnx model
   model_path = "model.onnx"
   base_model = onnx.load(model_path)

   # Define the parameters that need their gradient computed
   requires_grad = ["weight1", "bias1", "weight2", "bias2"]
   frozen_params = ["weight3", "bias3"]

   # Now, we can invoke generate_artifacts with this custom loss function
   artifacts.generate_artifacts(
       base_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=my_custom_loss,
       optimizer=artifacts.OptimType.AdamW,
   )

   # Successful completion of the above call will generate 4 files in the
   # current working directory, one for each of the artifacts mentioned above
   # (training_model.onnx, eval_model.onnx, checkpoint, optimizer_model.onnx)
.. autoclass:: onnxruntime.training.onnxblock.Block
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:
Advanced Usage
++++++++++++++
onnxblock is a library that can be used to build complex onnx models by stacking simple blocks on top of each other. One example of this is the custom loss function built above.

onnxblock also provides a way to build a custom forward-only or training (forward + backward) onnx model through the :class:`onnxruntime.training.onnxblock.ForwardBlock` and :class:`onnxruntime.training.onnxblock.TrainingBlock` classes respectively. These blocks inherit from the base :class:`onnxruntime.training.onnxblock.Block` class and provide additional functionality to build inference and training models.
.. autoclass:: onnxruntime.training.onnxblock.ForwardBlock
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:

.. autoclass:: onnxruntime.training.onnxblock.TrainingBlock
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members: