Transforms

Transforms apply preprocessing to data samples, such as normalization or augmentation. They can be chained using the .map() method on datasets.

Transform (Base Class)

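Base class for transforms that operate on individual examples. Transform is itself a BatchTransform: subclasses implement apply() for a single example, and the inherited apply_batch() calls it on every example in a batch. Subclass this to create custom per-example transforms.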

{doxygenclass}
:members:
:undoc-members:

BatchTransform (Base Class)

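Base class for transforms that operate on entire batches. Subclasses implement apply_batch(), which receives the whole batch at once; use it when a transform needs to see several examples together, as collation does.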

{doxygenclass}
:members:
:undoc-members:

TensorTransform

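Base class for transforms that modify only the data tensor of each example. Subclasses implement operator()(Tensor); the provided apply() calls it on example.data and leaves example.target unchanged.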

{doxygenclass}
:members:
:undoc-members:

Normalize

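Normalizes the data tensor of each example with a given mean and standard deviation, computing (input - mean) / stddev. Both arguments accept a single value or one value per channel; per-channel values are broadcast over the height and width of a [C, H, W] tensor.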

{doxygenstruct}
:members:
:undoc-members:

Stack

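Collates a batch of examples into a single example. Stack<> (that is, Stack<Example<>>) uses torch::stack to combine the data tensors and the target tensors separately, each along a new leading batch dimension.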

{doxygenstruct}
:members:
:undoc-members:

Example:

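```cpp
#include <torch/torch.h>
// Normalize each MNIST image (mean 0.5, stddev 0.5), then stack each batch.
auto dataset = torch::data::datasets::MNIST("./data")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());
```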

Lambda

A Transform that applies a user-supplied function to each example.

{doxygenclass}
:members:
:undoc-members:

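TensorLambda

A TensorTransform that applies a user-supplied Tensor(Tensor) function to the data tensor of each example; the target is left unchanged.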

{doxygenclass}
:members:
:undoc-members:

BatchLambda

A BatchTransform that applies a user-supplied function to an entire batch.

{doxygenclass}
:members:
:undoc-members:

Chaining Transforms

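Transforms can be chained together using .map(). Each call wraps the dataset in a new MapDataset, so transforms run in the order in which the .map() calls appear: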

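```cpp
#include <torch/torch.h>
// 0.1307 and 0.3081 are the commonly used mean and stddev of MNIST images.
auto dataset = torch::data::datasets::MNIST("./data")
    .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
    .map(torch::data::transforms::Stack<>());
```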