functorch/docs/source/tutorials/minifier.ipynb
We have a pretty convenient test-case minifier with the following interface:
def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes an FX graph with the given inputs, such that the resulting
    FX graph still returns True for module_fails.

    Uses 2 main strategies:
    1. Truncates suffix: removes some suffix from the graph and sets a new output.
    2. Delta debugging: tries replacing half of the graph with inputs; if that
       fails, tries replacing a quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    Note: module_fails returns True if it fails.
    """
    ...
Specifically, it takes your FX graph and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for module_fails), until it can't minify it any more.
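Conceptually, the driver keeps applying shrinking strategies until it reaches a fixed point, accepting a change only when the smaller graph still fails. Here is a minimal sketch of that loop (toy code, not the actual functorch implementation; for illustration, "graphs" are just strings of op names):

```python
def minify_sketch(graph, inps, module_fails, strategies):
    """Shrink `graph` while module_fails(graph, inps) stays True."""
    changed = True
    while changed:
        changed = False
        for strategy in strategies:
            candidate = strategy(graph)
            if candidate is not None and module_fails(candidate, inps):
                # The smaller candidate still fails: keep it and retry.
                graph, changed = candidate, True
    return graph

# Toy demo: a "graph" is a string of ops; it "fails" if it contains "mul".
drop_last_op = lambda g: ";".join(g.split(";")[:-1]) or None
has_mul = lambda g, inps: "mul" in g
print(minify_sketch("add;mul;sub;div", None, has_mul, [drop_last_op]))
# prints "add;mul"
```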
For example, given this graph:

def f(x):
    b = x * 2
    c = b + 3
    d = c / 4
    return d
It might try truncating the suffix to get:
def f(x):
    b = x * 2
    c = b + 3
    return c
It tries this in a binary-search manner, trying to remove the last 1/2, then 3/4, 1/4, then 7/8, 5/8, 3/8, and so on.
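That coarse-to-fine search order can be generated as follows; this is only an illustration of the order, not the minifier's actual code:

```python
from fractions import Fraction

def suffix_fractions(max_depth):
    # Coarse-to-fine: at depth d, try removing the last k/2**d of the
    # graph for every odd k, largest cut first.
    for d in range(1, max_depth + 1):
        for k in range(2 ** d - 1, 0, -2):
            yield Fraction(k, 2 ** d)

print([str(f) for f in suffix_fractions(3)])
# ['1/2', '3/4', '1/4', '7/8', '5/8', '3/8', '1/8']
```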
The second strategy is delta debugging. Consider the original graph again:

def f(x):
    b = x * 2
    c = b + 3
    d = c / 4
    return d
We might remove a middle node (say, c in this case) and turn it into a graph input:
def f(x, c):
    b = x * 2
    d = c / 4
    return d
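On a toy IR (a hypothetical list of (name, op, args) tuples, not FX's real data structures), this replace-a-node-with-an-input step looks like:

```python
def node_to_input(nodes, params, victim):
    # Delete `victim`'s defining node and promote it to a graph input,
    # so whatever computed it no longer needs to stay in the graph.
    new_nodes = [n for n in nodes if n[0] != victim]
    return new_nodes, params + [victim]

nodes = [("b", "mul", ("x",)), ("c", "add", ("b",)), ("d", "div", ("c",))]
new_nodes, params = node_to_input(nodes, ["x"], "c")
print(params)     # ['x', 'c']
print(new_nodes)  # b's result is now dead code
```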
Finally, there are 2 auxiliary strategies: eliminating dead code and removing unused inputs. These are somewhat self-explanatory.
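Both auxiliary passes can be sketched on the same toy IR of (name, op, args) tuples (again illustrative, not FX's implementation; FX itself provides Graph.eliminate_dead_code for the first):

```python
def eliminate_dead_code(nodes, output):
    # Walk backwards, keeping only nodes whose results are (transitively)
    # needed to compute the output.
    live, kept = {output}, []
    for name, op, args in reversed(nodes):
        if name in live:
            kept.append((name, op, args))
            live.update(args)
    return list(reversed(kept))

def remove_unused_inputs(params, nodes):
    # Drop graph inputs that no remaining node reads.
    used = set()
    for _, _, args in nodes:
        used.update(args)
    return [p for p in params if p in used]

nodes = [("b", "mul", ("x",)), ("d", "div", ("y",))]
nodes = eliminate_dead_code(nodes, "d")         # b is dead
print(nodes)                                    # [('d', 'div', ('y',))]
print(remove_unused_inputs(["x", "y"], nodes))  # ['y']
```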
So, let's take a look at a toy example. Let's pretend that our graph fails if it has a "multiply" in it. Let's create a failing graph.
import torch
import torch.fx as fx
from functorch.compile import minifier

def failing_f(x, y):
    y = torch.ops.aten.div(x, y)
    x = torch.ops.aten.add(x, 3)
    x = torch.ops.aten.mul(x, y)
    return torch.ops.aten.sub(x, y)

inps = [torch.randn(3), torch.randn(3)]

def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in {i.target for i in fx_g.graph.nodes})

min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)
Tada! Our graph is now a minimal example that still fails.
Since the primary use case of this minifier (for now) is for NVFuser repros, we print out a string for convenience that creates a self-contained repro to run the minified graph with NVFuser.
Note that in practice, we provide 2 main "graph checkers": check_nvfuser_subprocess and check_nvfuser_correctness_subprocess. These are used to check for errors and correctness (i.e., do the results match eager), respectively. They can be used like this:
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
minifier(failing_graph, inps, check_nvfuser_subprocess)
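The "correctness" flavor of checker can be sketched generically: compile the function, run both versions, and report failure when results diverge from eager. Everything below is hypothetical scaffolding in that spirit, not the actual check_nvfuser_correctness_subprocess:

```python
def make_correctness_checker(compile_fn, atol=1e-6):
    # Returns a module_fails-style predicate: True means "this graph
    # reproduces the bug" (compiled output diverges from eager).
    def module_fails(f, inps):
        expected = f(*inps)
        actual = compile_fn(f)(*inps)
        return abs(expected - actual) > atol
    return module_fails

# A deliberately broken "compiler" that is off by one:
buggy_compile = lambda f: (lambda *args: f(*args) + 1)
checker = make_correctness_checker(buggy_compile)
print(checker(lambda x: x * 2, [3]))  # True: results diverge
```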
However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph in the first place to pass to the minifier? One possible way is simply to use print_compile.
from functorch.compile import aot_function
from functorch.compile import print_compile

# Or equivalently...
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g

def foo(x):
    return x.cos().cos()

inp = torch.randn(3, requires_grad=True)
aot_function(foo, print_compile)(inp)
However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve this, we have another "compiler" called debug_compile. It simply prints out a string that can be copy-pasted and run from another file. It leverages FX's to_folder feature to serialize the graph to disk, along with any constants.
You can apply it to either the fw_compiler to dump the forwards graph or bw_compiler to dump the backwards graph.
from functorch.compile import memory_efficient_fusion, debug_compile
memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)
So, let's copy-paste it and see how it works. Note that I made a couple of minor modifications to run on CPU and to use the previous "graph fails if there's a multiply in it" checker.
import torch
import torch.fx as fx
from functorch.compile import minifier

# The "graph fails if there's a multiply in it" checker from before:
def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in {i.target for i in fx_g.graph.nodes})

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]

from foo import FxModule
mod = FxModule()

minifier(fx.symbolic_trace(mod), inps, pass_checker)
Hopefully that was useful :)