torch/utils/data/typing.ipynb
The DataPipe typing system is introduced to make graphs of DataPipes more reliable and to provide type inference for users. The typing system gives users the flexibility to choose at which level(s) type enforcement is applied, trading stricter checking against the risk of false-positive errors.
from torch.utils.data import IterDataPipe
from typing import Any, TypeVar, Union
from collections.abc import Iterator
import sys
# Covariant, unconstrained type variable shared by the generic DataPipe examples.
T_co = TypeVar("T_co", covariant=True)
# Hide the traceback of errors raised by the demo cells below: the shell's
# `showtraceback` is replaced (see the override that follows) so that only the
# exception summary is shown instead of the full stack.
import functools
# NOTE(review): `get_ipython` is not defined in this file; it is presumably
# provided by the IPython runtime, so this only works inside a notebook/kernel.
ipython = get_ipython()
def showtraceback(self, exc_tuple=None, filename=None, tb_offset=None,
                  exception_only=False, running_compiled_code=False) -> None:
    """Replacement for the IPython shell's ``showtraceback`` that hides frames.

    The exception is resolved with ``self._get_exc_info(exc_tuple)``. Only its
    summary is rendered, via ``self.InteractiveTB.get_exception_only``, and
    handed to ``self._showtraceback``. No stack frames are shown.

    ``filename``, ``tb_offset``, ``exception_only`` and
    ``running_compiled_code`` are accepted only so the signature matches the
    method being replaced. They are ignored.

    If the exception cannot be resolved (``ValueError``), a short notice is
    printed to stderr. If rendering is interrupted (``KeyboardInterrupt``),
    ``self.get_exception_only()`` is printed to stderr instead.
    """
    try:
        try:
            exc_type, exc_value, _ = self._get_exc_info(exc_tuple)
        except ValueError:
            # Nothing usable to report.
            print('No traceback available to show.', file=sys.stderr)
        else:
            # Render only the "ExcType: message" part; the traceback is dropped.
            summary = self.InteractiveTB.get_exception_only(exc_type, exc_value)
            self._showtraceback(exc_type, exc_value, summary)
    except KeyboardInterrupt:
        print('\n' + self.get_exception_only(), file=sys.stderr)
# Install the override on this shell instance. functools.partial pre-binds the
# shell as `self`, because a plain function stored as an instance attribute is
# not bound automatically.
ipython.showtraceback = functools.partial(showtraceback, ipython)
Compile-time typing is enabled by default for now. It generates a `type` attribute for each DataPipe. If no type hint is specified, the DataPipe's type defaults to `Any`.
# Invalid typing: the return type hint of `__iter__` is not an `Iterator`.
# The typing system is expected to reject this class with a TypeError as soon
# as it is defined (intentional demonstration).
class InvalidDP1(IterDataPipe[int]):
    def __iter__(self) -> str:
        pass
# Invalid typing: the return type hint of `__iter__` (Iterator[str]) neither
# matches nor is a subtype of the declared type hint (int).
# Expected to raise a TypeError when the class is defined (intentional).
class InvalidDP2(IterDataPipe[int]):
    def __iter__(self) -> Iterator[str]:
        pass
class DP(IterDataPipe[tuple]):
def __iter__(self) -> Iterator[tuple[int, str]]:
pass
print(DP.type)
class DP(IterDataPipe):
def __iter__(self) -> Iterator[int]:
pass
print(DP.type)
__iter__class DP(IterDataPipe):
def __iter__(self):
pass
print(DP.type)
class DP(IterDataPipe):
def __iter__(self) -> Iterator:
pass
print(DP.type)
class DP(IterDataPipe):
    """``__iter__`` is hinted with the covariant type variable ``T_co``."""

    def __iter__(self) -> Iterator[T_co]:
        """Body intentionally empty: only the inferred ``type`` is inspected."""


print(DP.type)
class DP(IterDataPipe[tuple[T_co, str]]):
    """Generic DataPipe: the declared and ``__iter__`` hints both use ``tuple[T_co, str]``."""

    def __iter__(self) -> Iterator[tuple[T_co, str]]:
        """Body intentionally empty: only the inferred ``type`` is inspected."""


print(DP.type)
T = TypeVar('T', int, str) # equals to Union[int, str]
class DP(IterDataPipe[tuple[T, str]]):
def __iter__(self) -> Iterator[tuple[int | str, str]]:
pass
print(DP.type)
Attribute `type`: the `type` attribute is added to each DataPipe class.
def print_helper(cls, obj) -> None:
    """Print the DataPipe ``type`` declared on *cls* and the one carried by *obj*.

    Output format::

        DataPipe[<cls.type>]
        Instance type: <obj.type>
    """
    header = f"DataPipe[{cls.type}]"
    detail = f"Instance type: {obj.type}"
    print("\n".join((header, detail)))
class DP(IterDataPipe[list[int]]):
    """The declared and ``__iter__`` hints agree on ``list[int]``."""

    def __iter__(self) -> Iterator[list[int]]:
        """Body intentionally empty: only the ``type`` attributes are inspected."""


print_helper(DP, DP())
class DP(IterDataPipe[Any]):
    """Explicit ``Any`` on both the class and the ``__iter__`` hint."""

    def __iter__(self) -> Iterator[Any]:
        """Body intentionally empty: only the ``type`` attributes are inspected."""


print_helper(DP, DP())
class DP(IterDataPipe[tuple]):
    """Bare ``tuple`` on both the class and the ``__iter__`` hint."""

    def __iter__(self) -> Iterator[tuple]:
        """Body intentionally empty: only the ``type`` attributes are inspected."""


print_helper(DP, DP())
Construct-time type checking can be enabled with the decorator `argument_validation`. Users opt in by attaching the decorator to the `__init__` function; the types of the input DataPipe(s) are then checked against the annotations when the DataPipe is constructed.
from torch.utils.data import argument_validation
class DP(IterDataPipe):
@argument_validation
def __init__(self, dp: IterDataPipe[int | tuple]) -> None:
self.dp = dp
def __iter__(self):
yield from self.dp
# `range` is not an IterDataPipe, so argument validation is expected to reject
# this call with a TypeError (intentional demonstration).
dp = DP(range(10))
# For an input IterDataPipe with detailed typing hints, the type of the input
# instance must be a subtype of the hint. `str` is not a subtype of DP's `dp`
# hint (int or tuple), so the construction below is expected to raise a
# TypeError (intentional demonstration).
class Temp(IterDataPipe[str]):
    def __iter__(self):
        pass


dp = DP(Temp())
# Compatible with the typing hint of DP: an input DataPipe typed
# tuple[int, T_co] is a tuple, i.e. a subtype of DP's `dp` hint (int or tuple),
# so this construction is intended to pass validation.
class Temp(IterDataPipe[tuple[int, T_co]]):
    def __iter__(self):
        pass


dp = DP(Temp())
Runtime type checking is enabled with the decorator `runtime_validation`. Users opt in by attaching the decorator to `__iter__`, which then checks that each output item is an instance of a subtype of the DataPipe's `type` attribute.
Note: for now, this decorator may only be attached to `__iter__`. It could be extended to `__getitem__` and other non-blocking functions.
`runtime_validation_disabled` is a context manager that turns off type validation at runtime. It is useful for DataLoader, which can disable runtime validation after the first epoch has finished for better performance. Note: runtime validation is enabled by default.
from torch.utils.data import runtime_validation, runtime_validation_disabled
class DP(IterDataPipe[tuple[int, T_co]]):
    """Pass-through DataPipe over ``datasource`` with runtime type checking.

    ``runtime_validation`` checks each item yielded by ``__iter__`` against the
    DataPipe's ``type`` (declared here as ``tuple[int, T_co]``).
    """

    def __init__(self, datasource) -> None:
        self.ds = datasource

    @runtime_validation
    def __iter__(self):
        # `yield from` delegates iteration directly to the data source.
        yield from self.ds
A `RuntimeError` is raised when an output item is not an instance of a subtype of the DataPipe's type.
# `str` is not a subtype of `int`: iterating prints the first two items and is
# expected to raise a RuntimeError at ('3', 3) (intentional demonstration).
dp = DP([(1, 1), (2, 2), ('3', 3)])
for d in dp:
    print(d)

# With runtime validation disabled, the same data passes through unchecked.
# NOTE: in a plain script the RuntimeError above stops execution before this
# point; these lines presumably ran as a separate notebook cell.
with runtime_validation_disabled():
    print(list(dp))
# A list is not a subtype of tuple: iterating prints the first two items and is
# expected to raise a RuntimeError at [3, 3] (intentional demonstration).
dp = DP([(1, 1), (2, 2), [3, 3]])
for d in dp:
    print(d)

# With runtime validation disabled, the same data passes through unchecked.
# NOTE: in a plain script the RuntimeError above stops execution before this
# point; these lines presumably ran as a separate notebook cell.
with runtime_validation_disabled():
    print(list(dp))
# The second tuple element is typed with the unconstrained T_co, so values of
# any type (int, str, float) are accepted there.
dp = DP([(1, 1), (2, '2'), (3, 3.0)])
print([*dp])
# Constrained type variable: T may only be int or str.
T = TypeVar("T", int, str)
# Sample data source: the ints 0..9.
ds = [*range(10)]
# If the DataPipe class is not decorated with `runtime_validation` and the
# DataPipe instance calls `reinforce_type`, a warning will be raised.
class DP(IterDataPipe[T]):
    """Pass-through DataPipe over ``ds``; its ``__iter__`` is *not* runtime-validated."""

    def __init__(self, ds) -> None:
        self.ds = ds

    def __iter__(self):
        yield from self.ds


# Expected to warn: `__iter__` above lacks `runtime_validation`.
dp = DP(ds).reinforce_type(int)
class DP(IterDataPipe[T]):
    """Pass-through DataPipe over ``ds`` whose items are checked at runtime.

    ``runtime_validation`` checks each item yielded by ``__iter__`` against the
    DataPipe's ``type`` (``T`` as declared, or the type set by
    ``reinforce_type`` on the instance).
    """

    def __init__(self, ds) -> None:
        self.ds = ds

    @runtime_validation
    def __iter__(self):
        yield from self.ds
# `float` is not one of T's constraints (int, str), so this is expected to be
# rejected with an error (intentional demonstration).
dp = DP(ds).reinforce_type(float)
# `str` is an allowed type for T, but `ds` holds ints, so iterating under
# runtime validation is expected to raise a RuntimeError.
dp = DP(ds).reinforce_type(str)
list(dp)
# With validation disabled, the mismatching ints pass through unchecked.
with runtime_validation_disabled():
    print(list(dp))
# `int` matches the data in `ds`, so validated iteration succeeds.
dp = DP(ds).reinforce_type(int)
print(list(dp))
class DP(IterDataPipe[Union[int, str]]):
def __init__(self, label) -> None:
if label == 'int':
self.reinforce_type(int)
elif label == 'str':
self.reinforce_type(str)
# Instance-level types after __init__: narrowed to int, narrowed to str, and
# (for an unrecognised label) presumably the unchanged class-level type.
print((dp := DP('int')).type)
print((dp := DP('str')).type)
print((dp := DP('')).type)