functorch/docs/source/index.rst
:github_url: https://github.com/pytorch/functorch
.. currentmodule:: functorch
functorch is `JAX-like <https://github.com/google/jax>`_ composable function
transforms for PyTorch.
.. warning::

   We've integrated functorch into PyTorch. As the final step of the
   integration, the functorch APIs are deprecated as of PyTorch 2.0.
   Please use the ``torch.func`` APIs instead and see the
   `migration guide <https://pytorch.org/docs/main/func.migrating.html>`_
   and `docs <https://pytorch.org/docs/main/func.html>`_
   for more details.
A "function transform" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.
functorch has auto-differentiation transforms (grad(f) returns a function that
computes the gradient of f), a vectorization/batching transform (vmap(f)
returns a function that computes f over batches of inputs), and others.
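For a quick sense of these two transforms, here is a minimal sketch using the
(deprecated) functorch imports; ``torch.func`` exposes the same functions:

.. code-block:: python

   import torch
   from functorch import grad, vmap

   def f(x):
       return torch.sin(x).sum()

   x = torch.randn(3)

   # grad(f) returns a new function that computes the gradient of f.
   print(grad(f)(x))  # equals torch.cos(x)

   # vmap(f) maps a function over a batch dimension (dim 0 by default).
   def g(x):
       return x ** 2

   batch = torch.randn(5, 3)
   print(vmap(g)(batch))  # same result as stacking g(b) for each row b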
These function transforms can compose with each other arbitrarily. For
example, composing ``vmap(grad(f))`` computes per-sample gradients, a
quantity that stock PyTorch cannot compute efficiently today.
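As a sketch of that composition, assume a toy linear model and a squared-error
loss (both hypothetical; see the per-sample-gradients tutorial linked below
for a full treatment):

.. code-block:: python

   import torch
   from functorch import grad, vmap

   weights = torch.randn(3)

   # A toy per-sample loss: scalar output for one (sample, target) pair.
   def loss_fn(weights, sample, target):
       return ((sample @ weights) - target) ** 2

   samples = torch.randn(8, 3)
   targets = torch.randn(8)

   # grad differentiates with respect to the first argument (weights);
   # vmap maps over the batch dimension of samples and targets while
   # leaving weights un-batched (in_dims=None).
   per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(
       weights, samples, targets
   )
   print(per_sample_grads.shape)  # torch.Size([8, 3])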
There are a number of use cases that are tricky to do in PyTorch today:

- model ensembling
- efficiently computing jacobians and hessians
- computing per-sample-gradients (or other per-sample quantities)

Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to
express the above without designing a separate subsystem for each.
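As one illustration of composing :func:`vmap` with :func:`vjp`, a reverse-mode
Jacobian can be assembled from these primitives. This is only a sketch of the
pattern; the jacobians tutorial below covers it in depth:

.. code-block:: python

   import torch
   from functorch import vjp, vmap

   def f(x):
       return x.sin()

   x = torch.randn(4)
   _, vjp_fn = vjp(f, x)

   # Pushing a batch of one-hot cotangents through vjp recovers the
   # Jacobian one row at a time; vmap batches over the basis vectors.
   (jacobian,) = vmap(vjp_fn)(torch.eye(4))
   print(torch.allclose(jacobian, torch.diag(x.cos())))  # True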
This idea of composable function transforms comes from the
`JAX framework <https://github.com/google/jax>`_.
Check out our `whirlwind tour <whirlwind_tour>`_ or some of our tutorials
mentioned below.
.. toctree::
   :maxdepth: 2
   :caption: functorch: Getting Started

   install
   tutorials/whirlwind_tour.ipynb
   ux_limitations
.. toctree::
   :maxdepth: 2
   :caption: functorch API Reference and Notes

   functorch
   experimental
   aot_autograd
.. toctree::
   :maxdepth: 1
   :caption: functorch Tutorials

   tutorials/jacobians_hessians.ipynb
   tutorials/ensembling.ipynb
   tutorials/per_sample_grads.ipynb
   tutorials/neural_tangent_kernels.ipynb
   tutorials/aot_autograd_optimizations.ipynb
   tutorials/minifier.ipynb