website/docs/tensor-shapes-tutorial-architectures.mdx
Tutorial 2 covered shape-preserving loops. This tutorial tackles architectures where shapes change systematically — encoder-decoder networks with skip connections, and recursive chains where dimensions grow or shrink exponentially.
Encoder-decoder architectures (UNet, Demucs, Super SloMo) encode the input to a bottleneck and then decode back, with skip connections between corresponding encoder and decoder layers.
Each encode step transforms (B, C, H, W) to (B, 2C, H', W') — doubling
channels and shrinking spatial dimensions. Decoding reverses this, using the
skip connection to restore the original shape. The key insight is that each
encode-recurse-decode cycle preserves the input shape:
encode: (B, C, H, W) → (B, 2C, H', W')
recurse: preserves (B, 2C, H', W')
decode + skip: (B, 2C, H', W') + (B, C, H, W) → (B, C, H, W)
This gives a recursive signature where recurse takes and returns the same
shape:
class UNet[NChannels, NClasses](nn.Module):
    def _encode[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1]:
        idx = len(self.downs) - depth
        down: Down[C, 2 * C] = self.downs[idx]
        return down(x)

    def _decode[B, C, H, W](
        self,
        skip: Tensor[B, C, H, W],
        deep: Tensor[B, 2 * C, (H - 2) // 2 + 1, (W - 2) // 2 + 1],
        depth: int,
    ) -> Tensor[B, C, H, W]:
        idx = len(self.ups) - depth
        up: Up[2 * C, C] = self.ups[idx]
        return up(deep, skip)

    def recurse[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C, H, W]:
        if depth == 0:
            return x
        skip = x
        encoded = self._encode(x, depth)
        middle = self.recurse(encoded, depth - 1)
        decoded = self._decode(skip, middle, depth)
        return decoded
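The shape-preservation claim can be checked with plain integer arithmetic. The sketch below is not the tutorial's typed code: it mirrors `_encode`, `_decode`, and `recurse` on `(B, C, H, W)` tuples, and the function names and the starting shape are assumptions chosen for illustration.

```python
# Plain-Python shape arithmetic only; no tensors or type checker involved.
Shape = tuple[int, int, int, int]  # (B, C, H, W)


def encode(shape: Shape) -> Shape:
    """Mirror _encode: double the channels, halve each spatial dimension."""
    b, c, h, w = shape
    return (b, 2 * c, (h - 2) // 2 + 1, (w - 2) // 2 + 1)


def decode(skip: Shape, deep: Shape) -> Shape:
    """Mirror _decode: the result takes the shape of the skip tensor."""
    # The deep tensor must be exactly the encoding of the skip tensor,
    # matching the parameter types of _decode.
    assert deep == encode(skip), f"{deep} is not the encoding of {skip}"
    return skip


def recurse(shape: Shape, depth: int) -> Shape:
    """Mirror recurse: encode, go one level deeper, decode with the skip."""
    if depth == 0:
        return shape
    middle = recurse(encode(shape), depth - 1)
    return decode(shape, middle)


start: Shape = (1, 64, 128, 128)  # hypothetical input shape
print(encode(start))            # (1, 128, 64, 64)
print(recurse(start, depth=4))  # (1, 64, 128, 128)
```

Every level hands `decode` a deep tensor whose shape is the encoding of its skip, so the recursion returns the input shape at any depth.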
Python has no way to express "element i of this list has type
Stage[C * 2**i]". The workaround:
- Declare the list elements with Any: list[Down[Any, Any]]
- Add a narrowing annotation at the point of use: down: Down[C, 2 * C] = self.downs[idx]
The Any erases element-level type info, and the annotation re-introduces
it for each use.
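The same erase-then-narrow idea can be shown with standard `typing` generics. This is only an analogy: standard Python type parameters range over types rather than integer dimensions, and `Stage`, `Small`, and `Large` are hypothetical names that do not appear in the tutorial's code.

```python
from typing import Any, Generic, TypeVar

CIn = TypeVar("CIn")
COut = TypeVar("COut")


class Small:
    """Hypothetical marker type standing in for a channel count C."""


class Large:
    """Hypothetical marker type standing in for a channel count 2 * C."""


class Stage(Generic[CIn, COut]):
    """Hypothetical stand-in for a layer such as Down[C, 2 * C]."""

    def __init__(self, c_in: int, c_out: int) -> None:
        self.c_in = c_in
        self.c_out = c_out


# Step 1: the container erases per-element parameters with Any.
stages: list[Stage[Any, Any]] = [Stage(64, 128), Stage(128, 256)]

# Step 2: a narrowing annotation re-introduces the parameters at the use site.
first: Stage[Small, Large] = stages[0]

print(first.c_in, first.c_out)  # 64 128
```

Because `Any` is compatible with every type, a checker accepts the annotation without verifying it. A wrong annotation would pass silently, so the index arithmetic that selects the element deserves careful review.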
Some algebraic equivalences can't be proven automatically. For example,
((H - 2) // 2 + 1) * 2 does not simplify back to H, because the two are equal only when H is even. When you hit this,
use type: ignore with a comment explaining the gap:
return up(deep, skip) # type: ignore[bad-argument-type] # ((H-2)//2+1)*2 = H for even H
Keep these to an absolute minimum and document each one.
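A quick numeric check shows why this particular equivalence cannot be proven in general. The helper name `round_trip` is an assumption for illustration.

```python
def round_trip(h: int) -> int:
    """Halve with (h - 2) // 2 + 1, then double the result."""
    return ((h - 2) // 2 + 1) * 2


# The identity holds for every even H in the tested range...
even_ok = all(round_trip(h) == h for h in range(2, 101, 2))
print(even_ok)  # True

# ...but floor division loses one unit for odd H.
odd_results = {h: round_trip(h) for h in (3, 5, 7)}
print(odd_results)  # {3: 2, 5: 4, 7: 6}
```

The ignore comment is therefore sound only if the spatial dimensions are even at every level, which is worth stating in the comment itself.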
When each stage doubles or halves a dimension, the result after I stages
involves 2**I. This appears in DCGAN (generator and discriminator), ResNet,
and DenseNet.
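The closed forms can be confirmed against stage-by-stage arithmetic. This is a sketch in plain Python with hypothetical helper names. It checks that halving I times with floor division equals `c // 2**I` and that doubling I times equals `h * 2**I`.

```python
def halve_repeatedly(c: int, stages: int) -> int:
    """Apply floor-halving once per stage."""
    for _ in range(stages):
        c //= 2
    return c


def double_repeatedly(h: int, stages: int) -> int:
    """Apply doubling once per stage."""
    for _ in range(stages):
        h *= 2
    return h


# Compare the iterated result with the closed form over a range of inputs.
closed_form_holds = all(
    halve_repeatedly(n, i) == n // 2**i and double_repeatedly(n, i) == n * 2**i
    for n in range(1, 1025)
    for i in range(0, 7)
)
print(closed_form_holds)  # True
print(halve_repeatedly(512, 3), double_repeatedly(4, 3))  # 64 32
```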
@overload pattern

Use @overload to separate the base case from the recursive case:
class Generator(nn.Module):
    def _apply_stage[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: int
    ) -> Tensor[B, C // 2, (H - 1) * 2 + 2, (W - 1) * 2 + 2]:
        idx = len(self.up_stages) - depth
        stage: GenUpStage[C] = self.up_stages[idx]
        return stage(x)

    @overload
    def _chain[B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[1]
    ) -> Tensor[B, C // 2, H * 2, W * 2]: ...

    @overload
    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]: ...

    def _chain[I, B, C, H, W](
        self, x: Tensor[B, C, H, W], depth: Dim[I]
    ) -> (Tensor[B, C // 2, H * 2, W * 2]
          | Tensor[B, C // 2 ** I, H * 2 ** I, W * 2 ** I]):
        y = self._apply_stage(x, depth)
        if depth == 1:
            return y
        return self._chain(y, depth - 1)
The base-case overload (depth: Dim[1]) handles the single-stage case
where the formula simplifies concretely. The recursive overload uses 2**I
to express the exponential relationship.
This _apply_stage + _chain pattern separates concerns:
- _apply_stage: applies a single stage from the ModuleList, using a
  narrowing annotation to type the list element.
- _chain: recursively applies _apply_stage with overloaded return
  types.

The caller invokes _chain with a concrete depth:
def forward[B](self, input: Tensor[B, 100, 1, 1]) -> Tensor[B, 3, 64, 64]:
    h0 = F.relu(self.project_bn(self.project(input)))
    assert_type(h0, Tensor[B, 512, 4, 4])
    h1 = self._chain(h0, 3) # 512->64, 4->32
    assert_type(h1, Tensor[B, 64, 32, 32])
    return torch.tanh(self.output(h1))
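The `assert_type` annotations in `forward` can be cross-checked by mirroring `_apply_stage` and `_chain` on shape tuples. This is a sketch under assumptions: the function names are hypothetical, the batch size of 8 is arbitrary, and nothing here exercises the type checker.

```python
Shape = tuple[int, int, int, int]  # (B, C, H, W)


def apply_stage(shape: Shape) -> Shape:
    """Mirror _apply_stage: halve the channels, double each spatial dim."""
    b, c, h, w = shape
    return (b, c // 2, (h - 1) * 2 + 2, (w - 1) * 2 + 2)


def chain(shape: Shape, depth: int) -> Shape:
    """Mirror _chain: apply one stage, then recurse until depth reaches 1."""
    y = apply_stage(shape)
    if depth == 1:
        return y
    return chain(y, depth - 1)


h0: Shape = (8, 512, 4, 4)  # batch size 8 is an arbitrary choice
h1 = chain(h0, 3)
print(h1)  # (8, 64, 32, 32)

# The recursive result agrees with the closed form from the overload.
b, c, h, w = h0
print(h1 == (b, c // 2**3, h * 2**3, w * 2**3))  # True
```

The per-stage expression `(h - 1) * 2 + 2` equals `2 * h`, which is why three stages take a spatial size of 4 to 32.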
- Narrowing annotations restore the element types erased by Any-typed ModuleLists.
- @overload separates base and recursive cases for exponential shape
  chains.
- type: ignore is a last resort for algebraic gaps the checker can't
  prove. Always document the specific equivalence.

In Tutorial 4, you'll see how to handle config classes with type parameters, dynamic construction patterns, and other advanced techniques.