website/docs/tensor-shapes-tutorial-basics.mdx
In this tutorial, you'll add tensor shape annotations to a simple multi-layer perceptron (MLP) model. By the end, you'll understand `Dim`, `Tensor[...]`, class-level type parameters, and method-level type parameters.
Here's a simple actor network from a reinforcement learning setup, with three `nn.Linear` layers in sequence:
```python
class BaselineActor(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward(self, state):
        h1 = F.relu(self.fc1(state))
        h2 = F.relu(self.fc2(h1))
        return torch.tanh(self.out(h2))
```
Without shape annotations, every intermediate value is just `Tensor`. You can't tell from reading the code what shape `h1` has, or whether the layer dimensions are consistent.
The constructor takes two parameters that determine tensor dimensions:
- `state_size`: the input dimension (flows into `nn.Linear`)
- `action_size`: the output dimension (flows into `nn.Linear`)

Both flow to sub-module constructors, so both must be `Dim`, not `int`. (`Dim[X]` is a type that bridges a runtime integer value to a type-level symbol `X`; see Getting Started for details.)
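As a small illustration of that bridging, here's a sketch of `Dim` outside the class (the helper `make_hidden_layer` is hypothetical, not part of `torch_shapes`):

```python
import torch.nn as nn
from torch_shapes import Dim

# Hypothetical helper: `width` is an ordinary int at runtime, but its
# Dim[W] annotation gives the checker a symbol W to bind and propagate.
def make_hidden_layer[W](width: Dim[W]) -> nn.Linear[400, W]:
    return nn.Linear(400, width)

layer = make_hidden_layer(256)  # binds W = 256, so layer is Linear[400, 256]
```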
There's also one fixed constant: `400`, the hidden dimension. It's a literal value, not a parameter, so it doesn't need a type param.
Make the dimension parameters `Dim[...]` and add class-level type parameters:
```python
class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)
```
Now when someone writes `BaselineActor(24, 4)`, the type checker binds `S = 24` and `A = 4`, inferring the type `BaselineActor[24, 4]`. The sub-modules are automatically typed: `self.fc1` is `Linear[24, 400]`, and `self.out` is `Linear[400, 4]`.
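You can spot-check those inferences with `assert_type` from `typing`; a sketch:

```python
from typing import assert_type

actor = BaselineActor(24, 4)
assert_type(actor, BaselineActor[24, 4])    # S = 24, A = 4
assert_type(actor.fc1, nn.Linear[24, 400])  # state_size flowed into fc1
assert_type(actor.out, nn.Linear[400, 4])   # action_size flowed into out
```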
The `forward` method has one dynamic dimension, the batch size, which varies across calls. Make it a method-level type parameter:
```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    h2 = F.relu(self.fc2(h1))
    return torch.tanh(self.out(h2))
```
`S` and `A` are class-level params (fixed at construction). `B` is a method-level param (bound per call).
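Concretely, `S` and `A` are bound once per instance, while `B` is re-bound on every call; a sketch:

```python
actor = BaselineActor(24, 4)      # binds S = 24, A = 4 for this instance
a1 = actor(torch.randn(8, 24))    # binds B = 8   -> pyrefly infers Tensor[8, 4]
a2 = actor(torch.randn(128, 24))  # binds B = 128 -> pyrefly infers Tensor[128, 4]
```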
Add `assert_type` after each intermediate to verify what pyrefly infers:
```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    assert_type(h1, Tensor[B, 400])
    h2 = F.relu(self.fc2(h1))
    assert_type(h2, Tensor[B, 400])
    act = torch.tanh(self.out(h2))
    assert_type(act, Tensor[B, A])
    return act
```
Run `pyrefly check`. If any `assert_type` fails, the shape you expected doesn't match what pyrefly inferred; investigate the mismatch.
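For example, if you had expected a hidden width of 200, a sketch of what the mismatch looks like (the exact diagnostic wording is pyrefly's):

```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    assert_type(h1, Tensor[B, 200])  # flagged: pyrefly inferred Tensor[B, 400]
    ...
```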
Once all shapes check out, remove the `assert_type` calls. Each one corresponds to an inlay type hint that your IDE shows permanently. Pyrefly catches shape errors through your function signatures and return types regardless; you don't need `assert_type` in the final code.
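For instance, a sketch of a bug the return annotation alone would catch, with no `assert_type` in sight:

```python
def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
    h1 = F.relu(self.fc1(state))
    h2 = F.relu(self.fc2(h1))
    return h2  # flagged: Tensor[B, 400] is not Tensor[B, A]
```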
Smoke tests exercise the model at concrete dimensions. They verify that the shape annotations are consistent end-to-end:
```python
def test_baseline_actor():
    actor = BaselineActor(24, 4)
    state = torch.randn(8, 24)
    # pyrefly infers: Tensor[8, 24]
    act = actor(state)
    # pyrefly infers: Tensor[8, 4]
```
Use concrete dimensions in tests (`Tensor[8, 24]`, not generic `Tensor[B, S]`) so the type checker verifies the full shape calculation.
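If you'd rather spell the expectations out for the checker than leave them as comments, a sketch of the same test with `assert_type`:

```python
from typing import assert_type

def test_baseline_actor() -> None:
    actor = BaselineActor(24, 4)
    state = torch.randn(8, 24)
    assert_type(state, Tensor[8, 24])  # concrete dims, not generic B and S
    act = actor(state)
    assert_type(act, Tensor[8, 4])     # checks the full 24 -> 400 -> 400 -> 4 pipeline
```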
Here's the complete model, with and without shape annotations:

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
<Tabs>
<TabItem value="after" label="With tensor shapes" default>

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_shapes import Dim


class BaselineActor[S, A](nn.Module):
    def __init__(self, state_size: Dim[S], action_size: Dim[A]) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward[B](self, state: Tensor[B, S]) -> Tensor[B, A]:
        h1 = F.relu(self.fc1(state))
        # pyrefly infers: Tensor[B, 400]
        h2 = F.relu(self.fc2(h1))
        # pyrefly infers: Tensor[B, 400]
        act = torch.tanh(self.out(h2))
        # pyrefly infers: Tensor[B, A]
        return act
```

</TabItem>
<TabItem value="before" label="Without tensor shapes">

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineActor(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 400)
        self.fc2 = nn.Linear(400, 400)
        self.out = nn.Linear(400, action_size)

    def forward(self, state):
        h1 = F.relu(self.fc1(state))
        # what shape is h1? Tensor, and that's all you know
        h2 = F.relu(self.fc2(h1))
        return torch.tanh(self.out(h2))
```

</TabItem>
</Tabs>
:::note
This example uses `from __future__ import annotations` with new-style generics, the simplest setup. Pyrefly shows inferred shapes as inlay type hints in your IDE. If you want `assert_type` for runtime regression guards, see Getting Started for the `torch_shapes.TypeVar` import style, which requires old-style generics.
:::
To recap:

- `Dim[X]` bridges runtime integer values to type-level symbols. Constructor parameters that determine tensor dimensions should be `Dim[X]`, not `int`.
- Class-level type parameters (`class Foo[S, A]`) represent dimensions fixed at construction time.
- Method-level type parameters (`def forward[B]`) represent dimensions that vary per call (batch size, sequence length).
- Use `reveal_type` during development to inspect shapes in checker output.

This model had a simple linear pipeline: each layer feeds into the next with known shapes. In Tutorial 2, you'll see what happens when layers are stacked in loops, as in Transformer architectures.