Back to Pytorch

DataPipe Typing System

torch/utils/data/typing.ipynb

2.11.0 · 7.2 KB
Original Source

DataPipe Typing System

The DataPipe typing system is introduced to make the graph of DataPipes more reliable and to provide type inference for users. The typing system provides the flexibility for users to choose the level(s) at which types are enforced, accepting the risk of false-positive errors at those levels.

python
from torch.utils.data import IterDataPipe
from typing import Any, TypeVar, Union
from collections.abc import Iterator
import sys

T_co = TypeVar('T_co', covariant=True)
python
# Hide traceback of Error
import functools
ipython = get_ipython()
def showtraceback(self, exc_tuple=None, filename=None, tb_offset=None,
                  exception_only=False, running_compiled_code=False) -> None:
    """Show only the exception summary for the current error, without the traceback.

    Intended to stand in for the interactive shell's own ``showtraceback``;
    ``self`` is the shell instance. ``filename``, ``tb_offset``,
    ``exception_only`` and ``running_compiled_code`` are accepted so the
    signature stays call-compatible, but they are not used here.

    Args:
        exc_tuple: Optional exception info forwarded to ``self._get_exc_info``.
    """
    try:
        try:
            # The traceback object is unpacked but deliberately ignored.
            exc_type, exc_value, _unused_tb = self._get_exc_info(exc_tuple)
        except ValueError:
            print('No traceback available to show.', file=sys.stderr)
        else:
            # Hide traceback: format just the "ExcType: message" lines.
            summary_lines = self.InteractiveTB.get_exception_only(exc_type, exc_value)
            self._showtraceback(exc_type, exc_value, summary_lines)
    except KeyboardInterrupt:
        print('\n' + self.get_exception_only(), file=sys.stderr)
# Replace the shell's showtraceback with the traceback-hiding version above,
# binding the shell instance as its `self` argument.
ipython.showtraceback = functools.partial(showtraceback, ipython)

Compile-time

Compile-time typing is currently enabled by default. It generates a `type` attribute for each DataPipe. If no type hint is specified, the DataPipe's type defaults to Any.

Invalid Typing

  • Return type hint of __iter__ is not Iterator
python
# Intentionally invalid: the return hint of __iter__ is `str`, not an Iterator.
class InvalidDP1(IterDataPipe[int]):
    def __iter__(self) -> str:
        pass
  • Return type hint of __iter__ is neither the same as nor a subtype of the declared type hint
python
# Intentionally invalid: the class is declared as IterDataPipe[int], but
# __iter__ is hinted as Iterator[str].
class InvalidDP2(IterDataPipe[int]):
    def __iter__(self) -> Iterator[str]:
        pass

Valid Typing

  • The return type is allowed to be a subtype of the class type annotation
python
# Valid: the class is declared with plain `tuple`, and __iter__ narrows it to
# tuple[int, str].
class DP(IterDataPipe[tuple]):
    def __iter__(self) -> Iterator[tuple[int, str]]:
        pass
print(DP.type)
python
# The class itself carries no type parameter; only __iter__ has a hint
# (Iterator[int]).
class DP(IterDataPipe):
    def __iter__(self) -> Iterator[int]:
        pass
print(DP.type)
  • Default Typing (Any) with/without return hint for __iter__
python
# Three variants shown as falling back to the default type:
# 1) no return hint on __iter__ at all
class DP(IterDataPipe):
    def __iter__(self):
        pass
print(DP.type)
# 2) a bare, unparameterized Iterator return hint
class DP(IterDataPipe):
    def __iter__(self) -> Iterator:
        pass
print(DP.type)
# 3) an Iterator parameterized only by the TypeVar T_co
class DP(IterDataPipe):
    def __iter__(self) -> Iterator[T_co]:
        pass
print(DP.type)
  • Matched type hints (including equal but not same types)
python
# The class declaration and the __iter__ hint use the identical type.
class DP(IterDataPipe[tuple[T_co, str]]):
    def __iter__(self) -> Iterator[tuple[T_co, str]]:
        pass
print(DP.type)

# Equal but not identical: the class uses the constrained TypeVar T, while
# __iter__ spells the same choices out as `int | str`.
T = TypeVar('T', int, str)  # equivalent to Union[int, str]
class DP(IterDataPipe[tuple[T, str]]):
    def __iter__(self) -> Iterator[tuple[int | str, str]]:
        pass
print(DP.type)

Attribute type

The attribute type is added into each DataPipe class.

python
def print_helper(cls, obj) -> None:
    """Print the ``type`` attribute of a DataPipe class and of one instance.

    Args:
        cls: Object whose ``type`` attribute is shown on the first line.
        obj: Object whose ``type`` attribute is shown on the second line.
    """
    report = f"DataPipe[{cls.type}]\nInstance type: {obj.type}"
    print(report)
python
# Parameterized container type: show `type` on both the class and an instance.
class DP(IterDataPipe[list[int]]):
    def __iter__(self) -> Iterator[list[int]]:
        pass
print_helper(DP, DP())
python
# Explicit Any: show `type` on both the class and an instance.
class DP(IterDataPipe[Any]):
    def __iter__(self) -> Iterator[Any]:
        pass
print_helper(DP, DP())
python
# Unparameterized tuple: show `type` on both the class and an instance.
class DP(IterDataPipe[tuple]):
    def __iter__(self) -> Iterator[tuple]:
        pass
print_helper(DP, DP())

Construct-time

Construct-time type checking can be enabled with the decorator argument_validation. Users can opt in by attaching the decorator to the __init__ function; they can then run operations with type inference on the input DataPipe(s).

python
from torch.utils.data import argument_validation

# DataPipe that opts in to construct-time checking of its input.
class DP(IterDataPipe):
    # The decorator validates `dp` against its IterDataPipe[int | tuple]
    # annotation when the instance is constructed.
    @argument_validation
    def __init__(self, dp: IterDataPipe[int | tuple]) -> None:
        self.dp = dp

    def __iter__(self):
        # Pass the wrapped DataPipe's items through unchanged.
        yield from self.dp
python
# NOTE(review): range(10) is not an IterDataPipe, so this call presumably
# fails argument validation — confirm against the notebook output.
dp = DP(range(10))
  • When any input is annotated as an IterDataPipe with detailed type hints, the type of the input instance must be a subtype of the hint.
python
# Temp is typed as IterDataPipe[str]; `str` is not a subtype of the
# `int | tuple` required by DP.__init__, so this input is expected to be
# rejected at construction.
class Temp(IterDataPipe[str]):
    def __iter__(self):
        pass
dp = DP(Temp())
  • Example of valid input DataPipe
python
# Temp is typed as IterDataPipe[tuple[int, T_co]], a tuple type, so it
# satisfies the `int | tuple` annotation on DP.__init__.
class Temp(IterDataPipe[tuple[int, T_co]]):
    def __iter__(self):
        pass
dp = DP(Temp())

Runtime

Decorator

Runtime type checking is enabled by the decorator runtime_validation. Users can opt in by attaching the decorator to __iter__ to check that each output item is an instance of a subtype of the DataPipe's type attribute.

Note: This decorator is only allowed to be attached to __iter__ for now. It may later be extended to __getitem__ and other nonblocking functions.

runtime_validation_disabled is a context manager to turn off the type validation during runtime. It's useful for DataLoader to disable the runtime validation after the first epoch is finished for better performance. Note: the runtime validation is enabled by default.

python
from torch.utils.data import runtime_validation, runtime_validation_disabled

# DataPipe declared to yield tuple[int, T_co], with runtime checking enabled.
class DP(IterDataPipe[tuple[int, T_co]]):
    def __init__(self, datasource) -> None:
        self.ds = datasource

    # The decorator checks each yielded item against the DataPipe's `type`.
    @runtime_validation
    def __iter__(self):
        yield from self.ds

Raises RuntimeError when the data is not a subtype of the expected type

  • str is not subtype of int
python
# The third item starts with the str '3' where tuple[int, T_co] requires an
# int, so iteration is expected to raise RuntimeError on it.
dp = DP([(1, 1), (2, 2), ('3', 3)])
for d in dp:
    print(d)
  • Context manager to disable the runtime validation
python
# With validation turned off inside the block, the same data is listed
# without the per-item type check.
with runtime_validation_disabled():
    print(list(dp))
  • List is not subtype of Tuple
python
# The third item is a list rather than a tuple, so iteration is expected to
# raise RuntimeError on it.
dp = DP([(1, 1), (2, 2), [3, 3]])
for d in dp:
    print(d)
  • Context manager to disable the runtime validation
python
# With validation turned off inside the block, the list item above is
# yielded without the per-item type check.
with runtime_validation_disabled():
    print(list(dp))
  • No error will be raised when all data pass the validation
python
# Every item is a tuple whose first element is an int; the second slot is the
# TypeVar T_co, so the int, str and float values there are all accepted.
dp = DP([(1, 1), (2, '2'), (3, 3.)])
print(list(dp))

Reinforce type for DataPipe instance

python
# T is constrained to int or str; ds is plain integer data used by the
# reinforce_type examples that follow.
T = TypeVar('T', int, str)
ds = list(range(10))
  • If the DataPipe class is not decorated with runtime_validation and the DataPipe instance calls reinforce_type, a warning will be raised.
python
# __iter__ is NOT decorated with runtime_validation here, so calling
# reinforce_type on the instance is expected to produce a warning.
class DP(IterDataPipe[T]):
    def __init__(self, ds) -> None:
        self.ds = ds

    def __iter__(self):
        yield from self.ds
dp = DP(ds).reinforce_type(int)
python
# Same DataPipe, now with runtime_validation on __iter__ so that a reinforced
# type is actually checked during iteration.
class DP(IterDataPipe[T]):
    def __init__(self, ds) -> None:
        self.ds = ds

    @runtime_validation
    def __iter__(self):
        yield from self.ds
  • expected type must be a subtype of the original type hint
python
# float is neither of T's constraints (int, str), so it is not a subtype of
# the original hint and this call is expected to be rejected.
dp = DP(ds).reinforce_type(float)
  • Integer data is not subtype of str
python
# str is an allowed reinforcement, but ds holds ints, so the failure is
# expected when the data is iterated by list(dp).
dp = DP(ds).reinforce_type(str)
list(dp)
  • Compatible with context manager to disable validation
python
# Assuming the cells run in order, dp is still the str-reinforced instance;
# with validation turned off its integer data is listed without the check.
with runtime_validation_disabled():
    print(list(dp))
  • Valid type enforcement
python
# int is one of T's constraints and ds holds ints, so validation passes.
dp = DP(ds).reinforce_type(int)
print(list(dp))
  • Different type based on the logic of class initialization
python
# Declared as Union[int, str]; the constructor narrows the instance's type
# according to `label`.
class DP(IterDataPipe[Union[int, str]]):
    def __init__(self, label) -> None:
        if label == 'int':
            self.reinforce_type(int)
        elif label == 'str':
            self.reinforce_type(str)
        # Any other label falls through and calls no reinforce_type.
python
# 'int' takes the first branch, which calls reinforce_type(int).
dp = DP('int')
print(dp.type)
python
# 'str' takes the second branch, which calls reinforce_type(str).
dp = DP('str')
print(dp.type)
python
# '' matches neither branch, so reinforce_type is never called on this instance.
dp = DP('')
print(dp.type)