Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.
Key components include:
The CastPolicy enum defines the type conversion rules. Each enum value represents the conversion requirements for a group of operators, so that operators that prioritize precision and operators that prioritize performance are each handled consistently.
| Policy | Explanation |
|---|---|
| lower_precision_fp | Cast all inputs to lower_precision_fp before executing the op. |
| fp32 | Cast all inputs to at::kFloat before running the op. |
| fp32_set_opt_dtype | Execute in at::kFloat, while respecting a user-specified output dtype if provided. |
| fp32_append_dtype | Append at::kFloat to the args and redispatch to the type-aware overload. |
| promote | Promote all inputs to the “widest” dtype before execution. |
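
The table above can be read as a small decision procedure. The following is a minimal sketch of that reading in plain Python, not PyTorch's implementation: dtypes are modeled as strings, and `plan_cast`, `WIDTH`, and the `autocast_dtype` and `user_dtype` parameters are hypothetical names introduced only for illustration. It also ignores details such as which tensors are eligible for casting.

```python
from enum import Enum, auto

# Toy dtype names with their bit widths; these are strings, not PyTorch dtypes.
WIDTH = {"float16": 16, "float32": 32, "float64": 64}


class CastPolicy(Enum):
    lower_precision_fp = auto()
    fp32 = auto()
    fp32_set_opt_dtype = auto()
    fp32_append_dtype = auto()
    promote = auto()


def plan_cast(policy, input_dtypes, autocast_dtype="float16", user_dtype=None):
    """Return (dtypes the inputs are cast to, dtype argument given to the op).

    A dtype argument of None means the op receives no explicit dtype.
    """
    inputs = list(input_dtypes)
    if policy is CastPolicy.lower_precision_fp:
        # Every input is cast down to the active autocast dtype.
        return [autocast_dtype] * len(inputs), None
    if policy is CastPolicy.fp32:
        # Every input is cast up to float32.
        return ["float32"] * len(inputs), None
    if policy is CastPolicy.fp32_set_opt_dtype:
        # Assumption: inputs are left alone and the op's optional dtype
        # argument is set to float32 unless the caller already chose one.
        return inputs, user_dtype or "float32"
    if policy is CastPolicy.fp32_append_dtype:
        # float32 is appended as a dtype argument for a dtype-aware overload.
        return inputs, "float32"
    if policy is CastPolicy.promote:
        # All inputs are cast to the widest dtype among them.
        widest = max(inputs, key=WIDTH.__getitem__)
        return [widest] * len(inputs), None
    raise ValueError(f"unknown policy: {policy}")


print(plan_cast(CastPolicy.promote, ["float16", "float32"]))
```

Returning the plan as data, instead of performing any cast, keeps the sketch independent of a tensor library while still showing how the five policies differ.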
PyTorch defines a general list of operators for each of the casting policies above, as a reference for developers of new accelerators.
| Policy | Operators List |
|---|---|
| lower_precision_fp | List Link |
| fp32 | List Link |
| fp32_set_opt_dtype | List Link |
| fp32_append_dtype | List Link |
| promote | List Link |
Implement the get_amp_supported_dtype method to return the data types supported by the new accelerator in the AMP context.

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
    :language: python
    :start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
    :end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
    :linenos:

This section shows how AMP registers autocast kernels for the AutocastPrivateUse1 dispatch key. Kernels are registered under AutocastPrivateUse1 using the KERNEL_PRIVATEUSEONE helper macro, which maps an op to the desired precision implementation (selected by a value of the enum at::autocast::CastPolicy).

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
    :language: c++
    :start-after: LITERALINCLUDE START: AMP FALLTHROUTH
    :end-before: LITERALINCLUDE END: AMP FALLTHROUTH
    :linenos:

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
    :language: c++
    :start-after: LITERALINCLUDE START: AMP IMPL
    :end-before: LITERALINCLUDE END: AMP IMPL
    :emphasize-lines: 3,6,8-10
    :linenos:
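
As a rough mental model of the two registrations above (a fallthrough for ops without an autocast kernel, plus per-op kernels that cast and then redispatch), the sketch below uses a plain dictionary in place of the dispatcher. It is an illustration under stated assumptions, not the PyTorch dispatcher: `autocast_kernels`, `register_autocast`, `dispatch`, and `received_dtypes` are hypothetical names, dtypes are strings, and the ops `mm`, `asin`, and `relu` are chosen only as examples.

```python
# Hypothetical stand-in for the autocast dispatch table:
# op name -> dtype that inputs are cast to before the backend kernel runs.
autocast_kernels = {}


def register_autocast(op_name, target_dtype):
    """Loosely plays the role of one per-op registration line."""
    autocast_kernels[op_name] = target_dtype


def dispatch(op_name, backend_impl, *input_dtypes):
    """Route a call through the toy autocast layer to the backend kernel."""
    target = autocast_kernels.get(op_name)
    if target is None:
        # Fallthrough: no autocast kernel is registered for this op,
        # so the inputs reach the backend unchanged.
        return backend_impl(*input_dtypes)
    # Registered op: cast every input, then redispatch to the backend.
    return backend_impl(*[target for _ in input_dtypes])


def received_dtypes(*dtypes):
    """Toy backend kernel that reports the dtypes it was called with."""
    return dtypes


register_autocast("mm", "float16")    # treated as a lower-precision op here
register_autocast("asin", "float32")  # treated as an fp32 op here

print(dispatch("mm", received_dtypes, "float32", "float32"))
print(dispatch("relu", received_dtypes, "float32"))
```

The point of the fallthrough branch is that ops without an explicit autocast kernel still work under autocast: they simply skip the casting layer.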