# Automatic Mixed Precision

## Background

Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.

Key components include:

- Autocast: Automatically casts operations to lower precision (e.g., `float16` or `bfloat16`) to improve performance while maintaining accuracy.
- Gradient Scaling: Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision (both components appear in the training-loop sketch below).
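
A typical training loop combines the two components roughly as follows. This is a minimal sketch using the public `torch.autocast` and `torch.amp.GradScaler` APIs on a CUDA device, which stands in here for whatever accelerator backend is being integrated; the model and data are placeholders.

```python
import torch

model = torch.nn.Linear(16, 4).cuda()                     # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler("cuda")                      # gradient scaling for float16

# Placeholder data: one batch of random inputs/targets.
data_loader = [(torch.randn(8, 16, device="cuda"),
                torch.randn(8, 4, device="cuda"))]

for inputs, targets in data_loader:
    optimizer.zero_grad()
    # Autocast: eligible ops inside this region run in float16.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    # Gradient scaling: scale the loss, step, then update the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```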

## Design

### Casting Strategy

The `CastPolicy` enum defines the type-conversion rules. Each value describes the conversion applied to a group of operators, so that operations prioritizing precision and operations prioritizing performance are handled consistently.

| Policy | Explanation |
| --- | --- |
| `lower_precision_fp` | Cast all inputs to `lower_precision_fp` before executing the op. |
| `fp32` | Cast all inputs to `at::kFloat` before executing the op. |
| `fp32_set_opt_dtype` | Execute in `at::kFloat`, respecting a user-specified output dtype if one is provided. |
| `fp32_append_dtype` | Append `at::kFloat` to the args and redispatch to the type-aware overload. |
| `promote` | Promote all inputs to the "widest" dtype before execution. |
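
The effect of these policies is visible from Python. The snippet below is an illustrative sketch on a CUDA device (used only as a stand-in for an accelerator backend); `mm`, `acos`, `softmax`, and `cat` are picked from the built-in CUDA operator lists for each policy, and another backend may classify ops differently.

```python
import torch

a = torch.randn(8, 8, device="cuda")   # float32 inputs
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # lower_precision_fp: matmul inputs are cast down to float16.
    print(torch.mm(a, b).dtype)             # torch.float16
    # fp32: numerically sensitive ops stay in float32.
    print(torch.acos(a).dtype)              # torch.float32
    # fp32_set_opt_dtype: runs in float32 unless an output dtype is given.
    print(torch.softmax(a, dim=-1).dtype)   # torch.float32
    # promote: mixed-precision inputs are promoted to the widest dtype.
    print(torch.cat([a, b.half()]).dtype)   # torch.float32
```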

### Operators Lists

PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.

| Policy | Operators List |
| --- | --- |
| `lower_precision_fp` | List Link |
| `fp32` | List Link |
| `fp32_set_opt_dtype` | List Link |
| `fp32_append_dtype` | List Link |
| `promote` | List Link |

## Implementation

### Python Integration

Implement the `get_amp_supported_dtype` method to return the data types that the new accelerator supports in the AMP context.

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
    :language: python
    :start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
    :end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
    :linenos:
```
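
In practice this hook is a short function. A minimal sketch is shown below; the returned dtype list is an assumption about what the accelerator's hardware supports and should match the backend's real capabilities.

```python
import torch

def get_amp_supported_dtype() -> list[torch.dtype]:
    # Dtypes this backend can autocast to inside an AMP region.
    # bfloat16 is included here only as an assumption for the sketch.
    return [torch.float16, torch.bfloat16]
```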

### C++ Integration

This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key:

- Register a fallback so that unhandled ops fall through to their normal implementations.
- Register specific `aten` kernels under `AutocastPrivateUse1` with the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the implementation for its desired precision (selected via the `at::autocast::CastPolicy` enum); see the listings below and the simplified sketch that follows them.

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
    :language: c++
    :start-after: LITERALINCLUDE START: AMP FALLTHROUTH
    :end-before: LITERALINCLUDE END: AMP FALLTHROUTH
    :linenos:

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
    :language: c++
    :start-after: LITERALINCLUDE START: AMP IMPL
    :end-before: LITERALINCLUDE END: AMP IMPL
    :emphasize-lines: 3,6,8-10
    :linenos:
```
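
Conceptually, the registration boils down to the pattern below. This is a simplified sketch rather than the actual `torch_openreg` sources; the two ops shown (`mm` and `acos`) are arbitrary examples of the `lower_precision_fp` and `fp32` policies.

```cpp
#include <ATen/autocast_mode.h>
#include <torch/library.h>

// Unhandled ops under AutocastPrivateUse1 fall through to their regular kernels.
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

// Map selected aten ops to a CastPolicy for the PrivateUse1 backend.
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)   // run matmuls in the lower-precision dtype
  KERNEL_PRIVATEUSEONE(acos, fp32)               // keep numerically sensitive ops in float32
}
```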