docs/source/higher_order_ops/scan.md
(scan)=
torch.scan is a structured control flow operator that performs an inclusive scan with a combine function.
It is commonly used for cumulative operations like cumsum, cumprod, or more general recurrences.
Logically, it behaves like the following Python function:
def scan(
    combine_fn: Callable[[PyTree, PyTree], tuple[PyTree, PyTree]],
    init: PyTree,
    xs: PyTree,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[PyTree, PyTree]:
    carry = init
    ys = []
    for i in range(xs.size(dim)):
        x_slice = xs.select(dim, i)
        carry, y = combine_fn(carry, x_slice)
        ys.append(y)
    return carry, torch.stack(ys)
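The loop above is written over pytrees and does not show how `reverse` is applied. As a minimal sketch for the single-tensor case, the following eager helper mirrors the same loop. `reference_scan` is a hypothetical name and not part of PyTorch, and the handling of `reverse` (walk the scan dimension from last to first, keep `ys` aligned with the original order of `xs`) is an assumption made for illustration.

```python
from typing import Callable

import torch


def reference_scan(
    combine_fn: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
    init: torch.Tensor,
    xs: torch.Tensor,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager, single-tensor sketch of the loop above (hypothetical helper)."""
    steps = list(range(xs.size(dim)))
    if reverse:
        # Assumption: reverse=True visits slices from last to first.
        steps.reverse()
    carry = init
    ys = []
    for i in steps:
        carry, y = combine_fn(carry, xs.select(dim, i))
        ys.append(y)
    if reverse:
        # Assumption: ys stays aligned with the original order of xs.
        ys.reverse()
    return carry, torch.stack(ys)


def add(carry: torch.Tensor, x: torch.Tensor):
    next_carry = carry + x
    return next_carry, next_carry.clone()


xs = torch.arange(5, dtype=torch.float32)
final_carry, ys = reference_scan(add, torch.zeros(()), xs)
rev_carry, rev_ys = reference_scan(add, torch.zeros(()), xs, reverse=True)
```

With this sketch, `ys` should match `torch.cumsum(xs, dim=0)`, and `rev_ys` should match a cumulative sum accumulated from the end of `xs` toward the start.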
`torch.scan` is a prototype feature in PyTorch. You may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Below is an example that uses scan to compute a cumulative sum:
import torch
from torch._higher_order_ops import scan


def add(carry: torch.Tensor, x: torch.Tensor):
    next_carry = carry + x
    y = next_carry.clone()  # clone to avoid output-output aliasing
    return next_carry, y


init = torch.zeros(1)
xs = torch.arange(5, dtype=torch.float32)
final_carry, cumsum = scan(add, init=init, xs=xs)
print(final_carry)
print(cumsum)
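The same carry/output contract also covers more general recurrences. As an illustrative sketch, the combine function below implements the linear recurrence `h_t = decay * h_{t-1} + x_t`. To keep the snippet independent of the prototype operator, it is driven by a plain Python loop; `make_decay_step` and `run_steps` are hypothetical helpers introduced here, not PyTorch APIs.

```python
import torch


def make_decay_step(decay: float):
    """Build a combine function for h_t = decay * h_{t-1} + x_t (illustrative)."""

    def step(carry: torch.Tensor, x: torch.Tensor):
        next_carry = decay * carry + x
        # Return a clone so the two outputs do not alias each other.
        return next_carry, next_carry.clone()

    return step


def run_steps(combine_fn, init: torch.Tensor, xs: torch.Tensor):
    # Plain-Python stand-in for a scan along dim 0 (hypothetical helper).
    carry, ys = init, []
    for x in xs.unbind(0):
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)


xs = torch.ones(4)
final_carry, hs = run_steps(make_decay_step(0.5), torch.zeros(()), xs)
print(hs)  # values: 1.0, 1.5, 1.75, 1.875
```

The combine function only reads the captured `decay` value and never modifies it, which keeps it compatible with the rule against mutating outside Python state.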
We can export the model with scan for further transformations and deployment. This example uses dynamic shapes to allow variable sequence length:
class ScanModule(torch.nn.Module):
    def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        def combine_fn(carry, x):
            next_carry = carry + x
            return next_carry, next_carry.clone()

        init = torch.zeros_like(xs[0])
        return scan(combine_fn, init=init, xs=xs)


mod = ScanModule()
inp = torch.randn(5, 3)
ep = torch.export.export(
    mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}}
)
print(ep)
Notice that the combine function becomes a sub-graph attribute of the top-level graph module.
`combine_fn` must return a `next_carry` whose tensor metadata (shape, dtype) matches that of `init`.
`combine_fn` must not mutate its inputs in place. Clone an input before mutating it.
`combine_fn` must not mutate Python variables (e.g., a list or dict) created outside the function.
The outputs of `combine_fn` must not alias any of its inputs. Clone a tensor before returning it if it would otherwise share storage with an input.
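The restrictions above can be illustrated with plain tensor operations. This is a sketch only: `bad_step`, `good_step`, and `same_metadata` are hypothetical names, and the checks use `data_ptr()` as a simple proxy for aliasing rather than whatever validation the operator performs internally.

```python
import torch


def bad_step(carry: torch.Tensor, x: torch.Tensor):
    carry.add_(x)        # mutates an input in place
    return carry, carry  # outputs alias the input and each other


def good_step(carry: torch.Tensor, x: torch.Tensor):
    next_carry = carry + x                 # out-of-place, inputs untouched
    return next_carry, next_carry.clone()  # y gets its own storage


def same_metadata(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical check mirroring the shape/dtype rule for next_carry vs. init.
    return a.shape == b.shape and a.dtype == b.dtype


x = torch.ones(3)

# Compliant: init is left unchanged and no output shares storage with it.
init = torch.zeros(3)
next_carry, y = good_step(init, x)
good_ok = (
    same_metadata(next_carry, init)
    and next_carry.data_ptr() != init.data_ptr()
    and y.data_ptr() != next_carry.data_ptr()
)

# Non-compliant: the returned carry is the mutated input itself.
bad_init = torch.zeros(3)
bad_carry, bad_y = bad_step(bad_init, x)
bad_aliases_input = bad_carry.data_ptr() == bad_init.data_ptr()

# Non-compliant: an integer init is promoted to float by the addition,
# so next_carry no longer has the same dtype as init.
int_init = torch.zeros(3, dtype=torch.int64)
promoted, _ = good_step(int_init, x)
dtype_matches = same_metadata(promoted, int_init)
```

Here `good_ok` evaluates to `True`, while `bad_aliases_input` is `True` and `dtype_matches` is `False`, corresponding to the aliasing, mutation, and metadata rules respectively.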
.. autofunction:: torch._higher_order_ops.scan.scan