Back to Pytorch

Control Flow - Map

docs/source/higher_order_ops/map.md

2.11.02.0 KB
Original Source

(map)=

Control Flow - Map

torch.map is a structured control flow operator that applies a function over the leading dimension of input tensors. It can logically be seen as implemented as follows:

python
def map(
    f: Callable[[PyTree, ...], PyTree],
    xs: Union[PyTree, torch.Tensor],
    *args,
):
    out = []
    for idx in range(xs.size(0)):
        xs_sliced = xs.select(0, idx)
        out.append(f(xs_sliced, *args))
    return torch.stack(out)
{warning}
`torch._higher_order_ops.map` is a prototype feature in PyTorch. You may run into miscompiles.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples

Below is an example that uses map to apply a function over a batch:

{code-cell}
import torch
from torch._higher_order_ops import map

def f(x):
    return x.sin() + x.cos()

xs = torch.randn(3, 4, 5)  # batch of 3 tensors, each 4x5
# Applies f to each of the 3 slices
result = map(f, xs)  # returns tensor of shape [3, 4, 5]
print(result)

We can export the model with map for further transformations and deployment. This example uses dynamic shapes to allow variable batch size:

{code-cell}
class MapModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        def body_fn(x):
            return x.sin() + x.cos()

        return map(body_fn, xs)

mod = MapModule()
inp = torch.randn(3, 4)
ep = torch.export.export(mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}})
print(ep)

Notice that torch.map is lowered to torch.ops.higher_order.map_impl, and the body function becomes a sub-graph attribute of the top-level graph module.

Restrictions

  • Mapped xs can only consist of tensors.

  • Leading dimensions of all tensors in xs must be consistent and non-zero.

  • The body function must not mutate inputs.

API Reference

{eval-rst}
.. autofunction:: torch._higher_order_ops.map.map