TENSOR_SHAPES_CONTRIBUTING.md
Pyrefly's tensor shape tracking is designed so that coverage of the PyTorch library can be extended without understanding Pyrefly's internals. This page explains the three mechanisms for specifying shape transforms and how to add new ones.
Shape tracking uses three complementary mechanisms:

1. **Fixture stubs**: `.pyi` files with shape-generic type signatures. Covers modules like `nn.Linear` and `nn.Conv2d`, and functions like `torch.mm`.
2. **DSL functions**: entries in `tensor_ops_registry.rs`. Covers operations with complex shape logic like `reshape`, `cat`, `transpose`, and `F.interpolate`.
3. **Special handlers**: built-in support for `nn.Sequential` chaining, `.shape` attribute access, and `.size()`.

Most contributions involve fixture stubs or DSL functions. Special handlers require changes to Pyrefly's Rust source.
Fixture stubs live in:

```
test/tensor_shapes/fixtures/torch/
├── __init__.pyi
├── nn/
│   ├── __init__.pyi     # nn.Linear, nn.Conv2d, nn.LSTM, etc.
│   └── functional.pyi   # F.relu, F.softmax, F.conv2d, etc.
├── distributions/
│   └── ...              # torch.distributions
└── ...
```
The `search_path` config option tells Pyrefly to look here for type information, overriding the real torch stubs.
A fixture stub provides a shape-generic type signature. For example, `nn.Linear`:
```python
class Linear[N, M](Module):
    def __init__(self, in_features: Dim[N], out_features: Dim[M],
                 bias: bool = True) -> None: ...
    def forward[*Xs](self, input: Tensor[*Xs, N]) -> Tensor[*Xs, M]: ...
```
The constructor captures the input and output dimensions as type parameters. The `forward` method uses those parameters plus a variadic `*Xs` for batch dimensions.
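To illustrate the effect, here is a rough usage sketch. It assumes the fixture stubs are on the search path and that `torch.randn` also has a shape-generic stub; the revealed shape notation is illustrative:

```python
import torch
from torch import nn

layer = nn.Linear(128, 64)   # captures N=128, M=64
x = torch.randn(32, 128)     # Tensor[32, 128]
y = layer(x)                 # Tensor[*Xs, N] -> Tensor[*Xs, M], with *Xs = (32,)
reveal_type(y)               # expected: Tensor[32, 64]
```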
To add a fixture stub:

1. Use `Dim[X]` for parameters that determine tensor dimensions. Non-shape parameters (`bias`, `dropout`) stay as their original types.
2. Write a `forward` signature expressing the shape transform. Use `*Xs` or `*Bs` for batch dimensions that pass through unchanged.
3. Put the `.pyi` file in the fixtures directory.

Suppose you want to add `nn.GroupNorm`, which preserves spatial dimensions:
```python
class GroupNorm[NumGroups, NumChannels](Module):
    def __init__(
        self,
        num_groups: Dim[NumGroups],
        num_channels: Dim[NumChannels],
        eps: float = 1e-5,
        affine: bool = True,
    ) -> None: ...
    def forward[*S](self, input: Tensor[*S]) -> Tensor[*S]: ...
```
Since `GroupNorm` doesn't change the shape, the `forward` signature is simply `Tensor[*S] -> Tensor[*S]`.
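As a quick check (same assumptions as the `Linear` sketch above):

```python
norm = nn.GroupNorm(8, 64)        # NumGroups=8, NumChannels=64
x = torch.randn(16, 64, 32, 32)   # Tensor[16, 64, 32, 32]
reveal_type(norm(x))              # expected: Tensor[16, 64, 32, 32]
```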
DSL functions are registered in `tensor_ops_registry.rs`. Each entry maps a qualified PyTorch function name to a shape transform specification written in a tiny Python subset.
The DSL supports:

- Arithmetic on dimensions (`+`, `-`, `*`, `//`)
- `zip`, `len`, indexing
- `Tensor(shape=[...])` to construct result shapes
- `self.shape` to access input shapes
- Conditionals (`if`/`else`)

For example, `torch.repeat`:

```python
def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor:
    return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)])
```
This says: the output shape is the element-wise product of the input shape
and the sizes argument.
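Concretely, assuming the entry above is registered (the revealed shape notation is illustrative):

```python
x = torch.randn(2, 3)   # Tensor[2, 3]
y = x.repeat(4, 2)      # element-wise product: [2 * 4, 3 * 2]
reveal_type(y)          # expected: Tensor[8, 6]
```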
Another example, `torch.cat`:

```python
def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor:
    shapes = [t.shape for t in tensors]
    result = list(shapes[0])
    for s in shapes[1:]:
        result[dim] = result[dim] + s[dim]
    return Tensor(shape=result)
```
This sums the sizes along the concatenation dimension and preserves all other dimensions.
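For example (again with illustrative shape notation):

```python
a = torch.randn(2, 5)
b = torch.randn(3, 5)
c = torch.cat([a, b], dim=0)   # dim 0 sums: 2 + 3; dim 1 unchanged
reveal_type(c)                 # expected: Tensor[5, 5]
```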
To add a DSL function:

1. Register the transform in `tensor_ops_registry.rs` with the qualified PyTorch name (e.g., `"torch.nn.functional.adaptive_avg_pool2d"`); a sketch follows below.
2. Add a test verifying that `reveal_type` produces the expected shape.
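For instance, a hypothetical transform for `adaptive_avg_pool2d`, written in the same style as `repeat_ir` and `cat_ir` above (the function name, signature, use of negative indexing, and handling of `output_size` are illustrative assumptions, not taken from the registry):

```python
def adaptive_avg_pool2d_ir(self: Tensor, output_size: list[int]) -> Tensor:
    # Batch/channel dims pass through; the two trailing spatial dims
    # are replaced by the requested output size.
    result = list(self.shape)
    result[-2] = output_size[0]
    result[-1] = output_size[1]
    return Tensor(shape=result)
```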
Ported models live in `test/tensor_shapes/models/`. Each file is a fully annotated port of a real-world PyTorch model with `assert_type` checkpoints and smoke tests.
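A minimal sketch of what such a file might look like (the model, shapes, and file contents are illustrative, not an existing port):

```python
from typing import assert_type

import torch
from torch import Tensor, nn


class TinyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x: Tensor[32, 784]) -> Tensor[32, 10]:
        h = self.fc1(x)
        assert_type(h, Tensor[32, 256])  # checkpoint after the shape change
        return self.fc2(h)


def smoke_test() -> None:
    # Smoke test: the runtime shape should agree with the static one.
    out = TinyMLP()(torch.randn(32, 784))
    assert out.shape == (32, 10)
```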
When porting a model:

- Add `assert_type` after every shape-changing operation.
- Run `verify_port.sh` to check for issues.

The `verify_port.sh` script checks a ported model for common issues:
```sh
.claude/skills/port-model/verify_port.sh test/tensor_shapes/models/<model>.py
```
It reports:
| Metric | Description |
|---|---|
ig | type: ignore count |
bs | Bare Tensor in signatures |
bv | Bare Tensor in variable annotations |
sh | Shaped assert_type count |
ba | Bare assert_type count |
sm | Smoke test count |
After adding stubs, DSL functions, or ported models, run the test suite:
```sh
# Run a specific test
buck test pyrefly:pyrefly_library -- tensor_shape

# Run all tests
buck test pyrefly:pyrefly_library
```
For external contributors using cargo:
```sh
cargo test tensor_shape
```