.. currentmodule:: torch.func
torch.func, previously known as "functorch", provides JAX-like composable function transforms for PyTorch.
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
What this means is that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change based on user feedback and we don't have full coverage of PyTorch operations.
If you have suggestions on the API or use-cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
A "function transform" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.
{mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that
computes the gradient of `f`), a vectorization/batching transform (`vmap(f)`
returns a function that computes `f` over batches of inputs), and others.
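
As a minimal sketch of the two transforms just described, the snippet below applies `grad` and `vmap` to small hand-written functions. The functions `f` and `g` and the input tensors are illustrative assumptions, not part of the library.

```python
import torch
from torch.func import grad, vmap


def f(x):
    # Scalar-valued function of a tensor: sum of sin(x).
    return torch.sin(x).sum()


x = torch.tensor([0.0, 1.0, 2.0])

# grad(f) is a new function that computes the gradient of f at its input.
df = grad(f)
assert torch.allclose(df(x), torch.cos(x))


def g(v):
    # Written for a single 1-D vector v; returns its squared norm.
    return torch.dot(v, v)


batch = torch.arange(6.0).reshape(3, 2)  # three vectors of length 2

# vmap(g) is a new function that applies g to each row of `batch`.
batched_g = vmap(g)
out = batched_g(batch)
print(out)  # tensor([ 1., 13., 41.])
```

Note that `g` is written as if it only ever sees one example; `vmap` adds the batch dimension without any change to the function body.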
These function transforms can compose with each other arbitrarily. For example,
composing `vmap(grad(f))` computes a quantity called per-sample-gradients that
stock PyTorch cannot efficiently compute today.
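
The composition above can be sketched as follows. The linear model, the `loss` function, and the random data are hypothetical choices made for illustration; only `grad` and `vmap` come from {mod}`torch.func`.

```python
import torch
from torch.func import grad, vmap


def loss(w, x, y):
    # Squared error of a linear model on ONE example (x: [D], y: scalar).
    return (torch.dot(w, x) - y) ** 2


torch.manual_seed(0)
w = torch.randn(3)       # shared parameters
xs = torch.randn(5, 3)   # batch of 5 examples
ys = torch.randn(5)      # one target per example

# grad(loss) differentiates with respect to the first argument (w) by default.
# in_dims=(None, 0, 0) keeps w shared and maps over dim 0 of xs and ys.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, xs, ys)

# Reference computation: one grad call per example in a Python loop.
expected = torch.stack([grad(loss)(w, x, y) for x, y in zip(xs, ys)])

assert per_sample_grads.shape == (5, 3)  # one gradient of w per example
assert torch.allclose(per_sample_grads, expected, atol=1e-6)
```

The loop and the `vmap` version should produce the same values; the difference is that `vmap` evaluates all examples in a single batched call.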
There are a number of use cases that are tricky to do in PyTorch today, such as computing per-sample-gradients.
Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express such use cases without designing a separate subsystem for each.
This idea of composable function transforms comes from the JAX framework.
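
As one more illustration of composing transforms, the sketch below builds a full Jacobian by mapping a `vjp` function over a basis of unit vectors with `vmap`. The function `f` and its input are assumptions chosen for the example; the result is compared against `torch.func.jacrev` as a sanity check.

```python
import torch
from torch.func import jacrev, vjp, vmap


def f(x):
    # Vector-valued function whose outputs each depend on every input.
    return torch.sin(x) * x.sum()


torch.manual_seed(0)
x = torch.randn(3)

# vjp returns the output of f and a function computing v^T J for a given v.
out, vjp_fn = vjp(f, x)

# Each row of the identity matrix selects one row of the Jacobian.
basis = torch.eye(out.numel())

# vjp_fn returns a tuple with one entry per input of f, hence the unpacking.
(jac,) = vmap(vjp_fn)(basis)

# Cross-check against the built-in reverse-mode Jacobian transform.
assert jac.shape == (3, 3)
assert torch.allclose(jac, jacrev(f)(x), atol=1e-6)
```

Here no Jacobian-specific machinery is written by hand: the result falls out of combining a vector-Jacobian product with batching.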
.. toctree::
   :maxdepth: 2

   func.whirlwind_tour
   func.api
   func.ux_limitations
   func.migrating