docs/source/higher_order_ops/index.md

(higher_order_ops)=

Control Flow Operators

Control flow operators are structured PyTorch operators that express control flow in a form that torch.compile and torch.export can capture. Unlike regular Python control flow, which is specialized away during tracing, these operators keep their semantics in the captured program, so traced programs can contain data-dependent control flow.

{warning}
Control flow operators are prototype features in PyTorch. They may have limited support for certain
input/output types and some may not fully support training. Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Why Use Control Flow Operators?

PyTorch lets you write models using native Python, including control flow like if statements, for loops, and while loops. This is great for flexibility, but it creates challenges for compilation.

Consider this simple example:

```python
if mod.static_config == 0:
    return f(x)
return g(x)
```

The two branches might have completely different operations. If we tried to compile both branches every time we hit an if statement, the number of code paths could double with each one and quickly become intractable.

To deal with this, torch.compile uses specialization and guards. When tracing your model, the compiler picks one code path based on the current value of the predicate (specialization), then adds a guard to check that assumption at runtime. If the guard fails, it recompiles.
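The specialize-then-guard cycle can be sketched with a toy model in plain Python. This is only an illustration of the idea, not how torch.compile is implemented: `ToySpecializer`, its cache, and `compile_count` are hypothetical names, and `f` and `g` are stand-ins for the two branches in the example above.

```python
from typing import Callable, Dict


def f(x: float) -> float:
    return x + 1.0


def g(x: float) -> float:
    return x * 2.0


class ToySpecializer:
    """Hypothetical stand-in for a tracing compiler (not torch.compile)."""

    def __init__(self) -> None:
        # Maps a guarded predicate value to the code compiled for it.
        self.cache: Dict[int, Callable[[float], float]] = {}
        self.compile_count = 0

    def _trace(self, static_config: int) -> Callable[[float], float]:
        # Specialization: only the branch taken for this value is kept.
        self.compile_count += 1
        return f if static_config == 0 else g

    def run(self, static_config: int, x: float) -> float:
        # Guard: cached code is reused only if the value it was
        # specialized on matches; otherwise we "recompile".
        if static_config not in self.cache:
            self.cache[static_config] = self._trace(static_config)
        return self.cache[static_config](x)


spec = ToySpecializer()
first = spec.run(0, 1.0)   # traces the f branch
second = spec.run(0, 2.0)  # guard passes, cached code is reused
third = spec.run(1, 3.0)   # guard fails, retrace picks the g branch
```

In this toy, a change in `static_config` costs one extra trace, which is acceptable when the value rarely changes.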

The same applies to loops: the compiler unrolls them and guards on the number of iterations. This produces a straight-line computational graph that's easy to optimize.
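A toy tracer makes the unrolling concrete. The sketch below is plain Python and purely illustrative: `trace_unrolled` and the string-based "graph" are hypothetical, chosen only to show that the loop disappears from the traced result and that the iteration count becomes a guard.

```python
from typing import List, Tuple


def trace_unrolled(num_iters: int) -> Tuple[List[str], str]:
    """Toy tracer: returns a flat op list and the guard protecting it."""
    graph: List[str] = []
    for i in range(num_iters):
        # Each iteration is recorded as its own op; no loop node remains.
        graph.append(f"x{i + 1} = body(x{i})")
    # The traced graph is only valid for this exact iteration count.
    guard = f"num_iters == {num_iters}"
    return graph, guard


graph, guard = trace_unrolled(3)
big_graph, _ = trace_unrolled(1000)  # graph size grows with the count
```

Here `graph` holds three straight-line ops and `guard` records the assumption `num_iters == 3`; tracing with a different count would produce a different graph.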

This approach works well for static control flow, but breaks down in several cases:

  • Data-dependent control flow: When the predicate depends on the value of a tensor, the compiler can't pick a branch at compile time because the value isn't known yet. Similarly, it can't unroll a while loop if the iteration count depends on tensor values. The compiler handles this by graph breaking and falling back to Python, which also makes it impossible to run the model without a Python runtime (e.g., on edge devices).

  • Dynamic shape-dependent control flow: When the number of loop iterations or the branch predicate depends on a dynamic tensor size, specialization means the compiled code only works for that specific size. The compiler has to recompile whenever the size changes.

  • Large computational graphs: Even with a static iteration count, unrolling a large loop creates a graph that grows linearly with the number of iterations, even though each iteration does the same thing. This leads to long compile times and high memory usage.

Control flow operators solve these problems by representing control flow as explicit operators that the compiler understands. Instead of specializing away the control flow, these operators preserve it in the compiled graph.

Available Operators

{toctree}
:maxdepth: 1

cond
while_loop
scan
associative_scan
map

Quick Comparison

| Operator | Use Case | Example |
| --- | --- | --- |
| `cond` | If `pred` is True, returns `true_fn(*operands)`, otherwise returns `false_fn(*operands)`. | `cond(pred, true_fn, false_fn, operands)` |
| `while_loop` | While `cond_fn(*operands)` is True, executes `body_fn(*operands)`, which returns the operands for the next iteration. | `while_loop(cond_fn, body_fn, operands)` |
| `scan` | Applies cumulative operations to `xs` with carried state. | `scan(combine_fn, init, xs)` |
| `associative_scan` | Similar to `scan`, but requiring an associative `combine_fn` to allow for more optimizations. | `associative_scan(add, xs, dim=0)` |
| `map` | Computes `fn` on each slice of `xs` and returns the stacked output. | `map(fn, xs)` |
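The behavior described in the comparison can be written down as plain-Python reference functions over lists. These are a sketch of the intended semantics, not the PyTorch implementations: the `ref_` names are hypothetical, lists stand in for tensors, and the `scan` convention used here (`combine_fn(carry, x)` returns `(next_carry, y)`) is an assumption to be checked against each operator's own page.

```python
from typing import Any, Callable, List, Sequence, Tuple


def ref_cond(pred: bool, true_fn: Callable, false_fn: Callable,
             operands: Tuple) -> Any:
    return true_fn(*operands) if pred else false_fn(*operands)


def ref_while_loop(cond_fn: Callable, body_fn: Callable,
                   operands: Tuple) -> Tuple:
    # body_fn returns the operands for the next iteration.
    while cond_fn(*operands):
        operands = tuple(body_fn(*operands))
    return operands


def ref_scan(combine_fn: Callable, init: Any,
             xs: Sequence) -> Tuple[Any, List]:
    # Assumed convention: combine_fn(carry, x) -> (next_carry, y).
    carry, ys = init, []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, ys


def ref_associative_scan(combine_fn: Callable, xs: Sequence) -> List:
    # Inclusive prefix scan; associativity is what would let a real
    # implementation regroup the work, which this loop does not do.
    out: List = []
    for x in xs:
        out.append(x if not out else combine_fn(out[-1], x))
    return out


def ref_map(fn: Callable, xs: Sequence) -> List:
    return [fn(x) for x in xs]


branch = ref_cond(True, lambda x: x + 1, lambda x: x - 1, (1,))
looped = ref_while_loop(lambda i, acc: i < 3,
                        lambda i, acc: (i + 1, acc + i), (0, 0))
carry, ys = ref_scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
prefix = ref_associative_scan(lambda a, b: a + b, [1, 2, 3, 4])
mapped = ref_map(lambda x: x * 2, [1, 2, 3])
```

The list-based versions make the differences easy to see: `ref_scan` threads an explicit carry and also emits per-step outputs, while `ref_associative_scan` only needs a binary combine function.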