
(while_loop)=

# Control Flow - While Loop

`torch.while_loop` is a structured control flow operator that runs a body function while a condition is true. Logically, it behaves like the following Python implementation:

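```python
from typing import Callable

def while_loop(
    cond_fn: Callable[..., bool],
    body_fn: Callable[..., tuple],
    carried_inputs: tuple,
):
    # Start from the initial carried values.
    val = carried_inputs
    # Re-evaluate the condition on the current values before every iteration.
    while cond_fn(*val):
        # The body's outputs become the next iteration's inputs.
        val = body_fn(*val)
    return val
```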
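The sketch below runs those semantics on plain Python integers, which makes the data-dependent trip count easy to see. It uses Euclid's GCD algorithm purely as an illustration: the loop carries `(a, b)` and stops once `b` reaches zero, so the number of iterations depends on the input values. `reference_while_loop` is a hypothetical stand-in for the logical definition above, not the PyTorch operator.

```python
from typing import Callable

def reference_while_loop(
    cond_fn: Callable[..., bool],
    body_fn: Callable[..., tuple],
    carried_inputs: tuple,
) -> tuple:
    # Hypothetical reference: thread the carried tuple through body_fn
    # until cond_fn returns False.
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val

# Euclid's algorithm: carry (a, b); the trip count depends on the values.
result = reference_while_loop(
    lambda a, b: b != 0,      # keep looping while b is nonzero
    lambda a, b: (b, a % b),  # one Euclid step
    (48, 18),
)
print(result)  # (6, 0)
```

Each carried value has the same type before and after an iteration, which is the shape of loop the real operator expects.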
```{warning}
`torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types.
Expect a more stable implementation in a future version of PyTorch.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
```

## Examples

Below is a basic example that uses `while_loop` to iterate until a condition is met:

```{code-cell}
import torch
from torch._higher_order_ops import while_loop

class M(torch.nn.Module):

    def cond_fn(self, iter_count, x):
        # Keep looping while the counter is positive.
        return iter_count.sum() > 0

    def body_fn(self, iter_count, x):
        # Decrement the counter and double x on each iteration.
        return iter_count - 1, x * 2

    def forward(self, init_iter, init_x):
        final_iter, final_x = while_loop(self.cond_fn, self.body_fn, (init_iter, init_x))
        return final_iter, final_x

m = M()
```

We can run the model eagerly. The number of iterations, and therefore the result, depends on the value of `init_iter`:

```{code-cell}
_, final_x = m(torch.tensor([3]), torch.ones(3))
assert torch.equal(final_x, torch.ones(3) * 2**3)

_, final_x = m(torch.tensor([10]), torch.ones(3))
assert torch.equal(final_x, torch.ones(3) * 2**10)
```
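For the single-element counters used above, the body runs once per unit of `init_iter`, so `final_x` equals `init_x * 2**init_iter`. As a minimal cross-check under that reading, the sketch below replays the same condition and body with an ordinary Python `while` loop (it does not call `while_loop`) and counts the iterations. `reference_run` is a hypothetical helper written only for this illustration.

```python
import torch

def cond_fn(iter_count, x):
    # Same condition as M.cond_fn: continue while the counter is positive.
    return iter_count.sum() > 0

def body_fn(iter_count, x):
    # Same body as M.body_fn: decrement the counter and double x.
    return iter_count - 1, x * 2

def reference_run(init_iter, init_x):
    # Hypothetical plain-Python replay of the loop that also counts iterations.
    val, steps = (init_iter, init_x), 0
    while cond_fn(*val):  # a 0-d bool tensor converts to a Python bool here
        val = body_fn(*val)
        steps += 1
    return val, steps

for n in (3, 10):
    (final_iter, final_x), steps = reference_run(torch.tensor([n]), torch.ones(3))
    assert steps == n
    assert torch.equal(final_x, torch.ones(3) * 2**n)
    assert torch.equal(final_iter, torch.tensor([0]))
```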

We can export the model for further transformations and deployment. The resulting exported program preserves the `while_loop` structure:

```{code-cell}
ep = torch.export.export(M(), (torch.tensor([10]), torch.ones(3)))
print(ep)
```

Notice that both the condition and body functions become sub-graph attributes of the top-level graph module.

## Restrictions

- `body_fn` must return tensors or integers with the same metadata (shape, dtype) as its inputs.
- `body_fn` and `cond_fn` must not mutate `carried_inputs` in place. Clone a tensor before mutating it.
- `body_fn` and `cond_fn` must not mutate Python variables (e.g., lists or dicts) created outside the function.
- The outputs of `body_fn` and `cond_fn` must not alias any of their inputs. Return a clone instead.
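The sketch below checks a subset of these rules for a single iteration of a `body_fn`. It assumes a hypothetical helper, `check_body_output` (not part of PyTorch), that compares each output with the matching carried input by shape and dtype and flags outputs that share storage with an input. The storage comparison is a simple heuristic, not the check PyTorch itself performs. `good_body` follows the clone-before-mutate rule, while `aliasing_body` and `reshaping_body` each break one rule.

```python
import torch

def check_body_output(inputs, outputs):
    # Hypothetical checker (not a PyTorch API) mirroring the documented rules.
    if len(inputs) != len(outputs):
        raise ValueError("body_fn must return one value per carried input")
    for inp, out in zip(inputs, outputs):
        if isinstance(inp, torch.Tensor):
            if not isinstance(out, torch.Tensor):
                raise TypeError("a tensor input must map to a tensor output")
            if out.shape != inp.shape or out.dtype != inp.dtype:
                raise ValueError(
                    f"metadata mismatch: {tuple(inp.shape)}/{inp.dtype} "
                    f"vs {tuple(out.shape)}/{out.dtype}"
                )
            # Heuristic alias check: sharing the same underlying storage.
            if out.untyped_storage().data_ptr() == inp.untyped_storage().data_ptr():
                raise ValueError("output aliases an input; return a clone instead")
        elif isinstance(inp, int) and not isinstance(out, int):
            raise TypeError("an int input must map to an int output")

def good_body(iter_count, x):
    x = x.clone()  # clone first so the in-place update leaves the input untouched
    x.mul_(2)
    return iter_count - 1, x

def aliasing_body(iter_count, x):
    return iter_count - 1, x  # returns the input tensor itself, so it aliases x

def reshaping_body(iter_count, x):
    return iter_count - 1, x.sum()  # shape changes from (3,) to ()

inputs = (torch.tensor([3]), torch.ones(3))
check_body_output(inputs, good_body(*inputs))  # passes silently
for bad in (aliasing_body, reshaping_body):
    try:
        check_body_output(inputs, bad(*inputs))
    except ValueError as err:
        print(f"{bad.__name__}: {err}")
```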

## API Reference

```{eval-rst}
.. autofunction:: torch._higher_order_ops.while_loop.while_loop
```