(scan)=

# Control Flow - Scan

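`torch.scan` is a structured control flow operator that performs an inclusive scan with a combine function. At each step, `combine_fn` receives the current carry and one slice of `xs` taken along `dim`. It returns the next carry together with an output for that step, and the per-step outputs are stacked into the second return value. It is commonly used for cumulative operations such as cumsum and cumprod, and for more general recurrences. Logically, it behaves like the following loop: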

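```python
from typing import Any, Callable

import torch

PyTree = Any  # stand-in for a nested structure (tuple/list/dict) of tensors

def scan(
    combine_fn: Callable[[PyTree, PyTree], tuple[PyTree, PyTree]],
    init: PyTree,
    xs: PyTree,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[PyTree, PyTree]:
    # Reference semantics for a single-tensor xs; the real operator also
    # accepts pytrees of tensors for init, xs, and the outputs.
    carry = init
    ys = []
    indices = range(xs.size(dim))
    if reverse:
        indices = reversed(indices)  # walk the scanned dim from the end
    for i in indices:
        x_slice = xs.select(dim, i)
        carry, y = combine_fn(carry, x_slice)
        ys.append(y)
    if reverse:
        ys.reverse()  # keep ys[i] aligned with the i-th slice of xs
    # Per-step outputs are stacked along a new leading dimension.
    return carry, torch.stack(ys)
```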
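To make the split between the carry and the per-step outputs concrete, here is a plain-Python sketch of the same kind of loop over lists (no PyTorch required). The helper `list_scan` is hypothetical and only illustrates the semantics. It assumes that `reverse=True` visits elements from last to first while still writing each output at the position of the input that produced it.

```python
from typing import Callable, List, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # element type
Y = TypeVar("Y")  # per-step output type


def list_scan(
    combine_fn: Callable[[C, X], Tuple[C, Y]],
    init: C,
    xs: List[X],
    reverse: bool = False,
) -> Tuple[C, List[Y]]:
    """Hypothetical list-based analogue of scan, for illustration only."""
    carry = init
    ys = [None] * len(xs)
    # Visit indices forward, or backward when reverse=True (assumed semantics).
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    for i in order:
        # Each output is stored at the index of the element that produced it.
        carry, ys[i] = combine_fn(carry, xs[i])
    return carry, ys


def add(carry: int, x: int) -> Tuple[int, int]:
    total = carry + x
    return total, total  # cumsum: carry and output hold the same value


def mul(carry: int, x: int) -> Tuple[int, int]:
    product = carry * x
    return product, product  # cumprod


print(list_scan(add, 0, [1, 2, 3, 4]))                # (10, [1, 3, 6, 10])
print(list_scan(add, 0, [1, 2, 3, 4], reverse=True))  # (10, [10, 9, 7, 4])
print(list_scan(mul, 1, [1, 2, 3, 4]))                # (24, [1, 2, 6, 24])
```

The final carry is the last state reached, while the output list records every intermediate state, which is what makes the scan "inclusive".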
```{warning}
`torch.scan` is a prototype feature in PyTorch. You may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
```

## Examples

Below is an example that uses scan to compute a cumulative sum:

```{code-cell}
import torch
from torch._higher_order_ops import scan

def add(carry: torch.Tensor, x: torch.Tensor):
    next_carry = carry + x
    y = next_carry.clone()  # clone to avoid output-output aliasing
    return next_carry, y

init = torch.zeros(1)
xs = torch.arange(5, dtype=torch.float32)

final_carry, cumsum = scan(add, init=init, xs=xs)
print(final_carry)
print(cumsum)
```

A module that uses `scan` can be exported with `torch.export` for further transformation and deployment. This example marks dimension 0 of `xs` as dynamic so that the exported program accepts sequences of varying length:

```{code-cell}
class ScanModule(torch.nn.Module):
    def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        def combine_fn(carry, x):
            next_carry = carry + x
            return next_carry, next_carry.clone()

        init = torch.zeros_like(xs[0])
        return scan(combine_fn, init=init, xs=xs)

mod = ScanModule()
inp = torch.randn(5, 3)
ep = torch.export.export(mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}})
print(ep)
```

Notice that the combine function becomes a sub-graph attribute of the top-level graph module.

## Restrictions

- The `next_carry` returned by `combine_fn` must have the same pytree structure as `init`, and each of its tensors must have the same metadata (shape, dtype) as the corresponding tensor in `init`.

- `combine_fn` must not mutate its inputs in place. Clone an input before mutating it.

- `combine_fn` must not mutate Python variables (e.g., a list or dict) created outside the function.

- The outputs of `combine_fn` must not alias any of its inputs or each other. Return a clone instead (e.g., `next_carry.clone()`).
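One way to catch violations of these rules early is to call `combine_fn` once eagerly on sample inputs. The helper below, `check_combine_fn`, is hypothetical (it is not part of PyTorch) and handles only a single-tensor carry and output. It approximates aliasing by comparing storage pointers, and it does not detect mutation of outer Python containers.

```python
import torch


def _storage_ptr(t: torch.Tensor) -> int:
    # Views of a tensor share its storage. Note: empty tensors may all
    # report pointer 0, so this check is only meaningful for non-empty inputs.
    return t.untyped_storage().data_ptr()


def check_combine_fn(combine_fn, init: torch.Tensor, x: torch.Tensor) -> list[str]:
    """Hypothetical eager sanity check; returns the violations it finds."""
    problems = []
    # Work on copies so the caller's tensors are never mutated.
    carry_in, x_in = init.clone(), x.clone()
    next_carry, y = combine_fn(carry_in, x_in)

    if next_carry.shape != init.shape or next_carry.dtype != init.dtype:
        problems.append("next_carry metadata differs from init")
    if not (torch.equal(carry_in, init) and torch.equal(x_in, x)):
        problems.append("an input was mutated in place")
    input_ptrs = {_storage_ptr(carry_in), _storage_ptr(x_in)}
    for name, out in (("next_carry", next_carry), ("y", y)):
        if _storage_ptr(out) in input_ptrs:
            problems.append(f"{name} aliases an input")
    if _storage_ptr(next_carry) == _storage_ptr(y):
        problems.append("next_carry and y alias each other")
    return problems


def good_combine(carry, x):
    next_carry = carry + x
    return next_carry, next_carry.clone()


def bad_combine(carry, x):
    carry.add_(x)        # in-place mutation of an input
    return carry, carry  # outputs alias the input and each other


print(check_combine_fn(good_combine, torch.zeros(3), torch.ones(3)))
print(check_combine_fn(bad_combine, torch.zeros(3), torch.ones(3)))
```

`good_combine` follows the pattern used in the examples on this page, while `bad_combine` breaks the mutation and aliasing rules at once.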

## API Reference

```{eval-rst}
.. autofunction:: torch._higher_order_ops.scan.scan
```