torch/utils/data/standard_pipes.ipynb

Standard flow control and data processing DataPipes

python
from torch.utils.data import IterDataPipe
python
# Example IterDataPipe
class ExampleIterPipe(IterDataPipe):
    def __init__(self, range=20) -> None:
        self.range = range

    def __iter__(self):
        # `range` below is the builtin; the constructor argument only
        # shadows it inside __init__
        yield from range(self.range)

Batch

Function: batch

Description:

Alternatives:

Arguments:

  • batch_size: int — desired batch size
  • unbatch_level: int = 0 — if specified, calls unbatch(unbatch_level=unbatch_level) on the source datapipe before batching (see unbatch)
  • drop_last: bool = False — whether to drop the incomplete final batch

Example:

Classic batching produces a partial final batch by default:

python
dp = ExampleIterPipe(10).batch(3)
for i in dp:
    print(i)

To drop the incomplete final batch, pass the drop_last argument:

python
dp = ExampleIterPipe(10).batch(3, drop_last = True)
for i in dp:
    print(i)
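The batching behavior above can be sketched in a few lines of plain Python (a simplified stand-in for the DataPipe, not its actual implementation):

```python
def batch(source, batch_size, drop_last=False):
    # Accumulate items into lists of size batch_size.
    buf = []
    for item in source:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf
            buf = []
    # A partial final batch is kept unless drop_last is set.
    if buf and not drop_last:
        yield buf

print(list(batch(range(10), 3)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch(range(10), 3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```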

Calling batch sequentially produces nested batches:

python
dp = ExampleIterPipe(30).batch(3).batch(2)
for i in dp:
    print(i)

It is possible to unbatch the source data before applying the new batching rule using the unbatch_level argument:

python
dp = ExampleIterPipe(30).batch(3).batch(2).batch(10, unbatch_level=-1)
for i in dp:
    print(i)

Unbatch

Function: unbatch

Description:

Alternatives:

Arguments: unbatch_level: int = 1

Example:

python
dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()
for i in dp:
    print(i)

By default, unbatching is applied only to the first layer; to unbatch deeper, use the unbatch_level argument:

python
dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = 2)
for i in dp:
    print(i)

Setting unbatch_level to -1 will unbatch down to the lowest level:

python
dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = -1)
for i in dp:
    print(i)
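The unbatch_level semantics can be illustrated with a plain-Python sketch (`unbatch` here is a hypothetical stand-alone helper, not the DataPipe method):

```python
def unbatch(data, unbatch_level=1):
    # unbatch_level=1 flattens one layer of nesting; -1 flattens
    # recursively down to individual items.
    for item in data:
        if isinstance(item, list) and unbatch_level != 0:
            next_level = -1 if unbatch_level == -1 else unbatch_level - 1
            yield from unbatch(item, next_level)
        else:
            yield item

nested = [[[0, 1], [2, 3]], [[4, 5]]]
print(list(unbatch(nested)))
# [[0, 1], [2, 3], [4, 5]]  -- one layer removed
print(list(unbatch(nested, unbatch_level=-1)))
# [0, 1, 2, 3, 4, 5]  -- fully flattened
```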

Map

Function: map

Description:

Alternatives:

Arguments:

  • nesting_level: int = 0

Example:

python
dp = ExampleIterPipe(10).map(lambda x: x * 2)
for i in dp:
    print(i)

By default, map applies the function to each mini-batch as a whole:

python
dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)
for i in dp:
    print(i)

To apply the function to individual items of the mini-batch, use the nesting_level argument:

python
dp = ExampleIterPipe(10).batch(3).batch(2).map(lambda x: x * 2, nesting_level = 2)
for i in dp:
    print(i)

Setting nesting_level to -1 will apply the map function at the lowest level possible:

python
dp = ExampleIterPipe(10).batch(3).batch(2).batch(2).map(lambda x: x * 2, nesting_level = -1)
for i in dp:
    print(i)
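The nesting_level behavior can be sketched in plain Python (`map_nested` is a hypothetical helper mimicking the assumed semantics, not the DataPipe method):

```python
def map_nested(data, fn, nesting_level=0):
    # nesting_level=0 applies fn to each element as-is (even if the
    # element is a whole mini-batch); -1 descends into nested lists
    # all the way down to individual items.
    for item in data:
        if nesting_level != 0 and isinstance(item, list):
            next_level = -1 if nesting_level == -1 else nesting_level - 1
            yield list(map_nested(item, fn, next_level))
        else:
            yield fn(item)

print(list(map_nested([[0, 1, 2]], lambda x: x * 2)))
# [[0, 1, 2, 0, 1, 2]]  -- fn saw the whole batch, so the list doubled
print(list(map_nested([[0, 1, 2]], lambda x: x * 2, nesting_level=1)))
# [[0, 2, 4]]  -- fn applied to each item inside the batch
```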

Filter

Function: filter

Description:

Alternatives:

Arguments:

  • nesting_level: int = 0
  • drop_empty_batches: bool = True — whether mini-batches left empty after filtering are dropped

Example:

python
dp = ExampleIterPipe(10).filter(lambda x: x % 2 == 0)
for i in dp:
    print(i)

By default, filter applies the filter function to each mini-batch as a whole:

python
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: len(x) > 2)
for i in dp:
    print(i)

You can apply the filter function to individual elements by setting the nesting_level argument:

python
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = 1)
for i in dp:
    print(i)

If a mini-batch is left with zero elements after filtering, the default behaviour is to drop it from the output. You can override this behaviour using the drop_empty_batches argument.

python
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = -1, drop_empty_batches = False)
for i in dp:
    print(i)
python
dp = ExampleIterPipe(20)
dp = dp.batch(3).batch(2).batch(2).filter(lambda x: x < 4 or x > 9 , nesting_level = -1, drop_empty_batches = False)
for i in dp:
    print(i)
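The interaction of nesting_level and drop_empty_batches can be sketched in plain Python (`filter_nested` is a hypothetical helper illustrating the assumed semantics):

```python
def filter_nested(data, predicate, nesting_level=0, drop_empty_batches=True):
    for item in data:
        if nesting_level != 0 and isinstance(item, list):
            next_level = -1 if nesting_level == -1 else nesting_level - 1
            sub = list(filter_nested(item, predicate, next_level,
                                     drop_empty_batches))
            # Mini-batches emptied by filtering are dropped
            # unless drop_empty_batches=False.
            if sub or not drop_empty_batches:
                yield sub
        elif predicate(item):
            yield item

batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(filter_nested(batches, lambda x: x > 4, nesting_level=1)))
# [[5], [6, 7, 8], [9]]  -- the emptied first batch disappeared
print(list(filter_nested(batches, lambda x: x > 4, nesting_level=1,
                         drop_empty_batches=False)))
# [[], [5], [6, 7, 8], [9]]
```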

Shuffle

Function: shuffle

Description:

Alternatives:

Arguments:

  • unbatch_level: int = 0 — if specified, calls unbatch(unbatch_level=unbatch_level) on the source datapipe before shuffling (see unbatch)
  • buffer_size: int = 10000

Example:

python
dp = ExampleIterPipe(10).shuffle()
for i in dp:
    print(i)

shuffle treats input mini-batches the same way as individual items:

python
dp = ExampleIterPipe(10).batch(3).shuffle()
for i in dp:
    print(i)

To shuffle elements across batches, use the shuffle(unbatch_level=...) followed by batch pattern:

python
dp = ExampleIterPipe(10).batch(3).shuffle(unbatch_level = -1).batch(3)
for i in dp:
    print(i)
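The buffer_size argument bounds how global the shuffle is. The buffered-shuffle idea can be sketched in plain Python (a simplified sketch of the assumed behavior, not the actual implementation):

```python
import random

def buffered_shuffle(source, buffer_size=10000, seed=None):
    # Hold up to buffer_size items in memory; once the buffer is full,
    # emit a randomly chosen buffered item as each new item arrives.
    # Shuffling is therefore only as "global" as the buffer allows.
    rng = random.Random(seed)
    buf = []
    for item in source:
        buf.append(item)
        if len(buf) >= buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    # Flush the remaining buffered items in random order.
    rng.shuffle(buf)
    yield from buf

out = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
print(out)  # a permutation of 0..19, but never moving items very far
```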

Collate

Function: collate

Description:

Alternatives:

Arguments:

Example:

python
dp = ExampleIterPipe(10).batch(3).collate()
for i in dp:
    print(i)
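With no arguments, collate applies the default collate function, which converts each batch of samples into a torch.Tensor, recursing into containers such as tuples and dicts. The shape of that recursion can be sketched in plain Python for dict samples (a simplified analogue that gathers plain lists instead of building tensors):

```python
def collate_dicts(batch):
    # Turn a list of per-sample dicts into one dict of per-key lists.
    # default_collate does the analogous restructuring, producing a
    # tensor per key instead of a list.
    return {key: [sample[key] for sample in batch] for key in batch[0]}

print(collate_dicts([{"x": 1, "y": 2}, {"x": 3, "y": 4}]))
# {'x': [1, 3], 'y': [2, 4]}
```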

GroupBy

Function: groupby

Usage: dp.groupby(lambda x: x[0])

Description: Batches items by combining items with the same key into the same group

Arguments:

  • group_key_fn — function mapping each item to its group key
  • group_size — yield the resulting group as soon as group_size elements have accumulated
  • guaranteed_group_size: int = None
  • buffer_size — maximum number of elements held in the internal buffer before the biggest group is evicted
  • drop_remaining: bool = False — drop groups smaller than guaranteed_group_size instead of raising an error
  • unbatch_level: int = 0 — if specified, calls unbatch(unbatch_level=unbatch_level) on the source datapipe before grouping (see unbatch)

Attention

As the data stream can be arbitrarily large, grouping is done on a best-effort basis, and there is no guarantee that the same key will never appear in different groups. You can think of it as a local groupby, where the locality is a single DataPipe process/thread.

python
dp = ExampleIterPipe(10).shuffle().groupby(lambda x: x % 3)
for i in dp:
    print(i)

By default, the group key function is applied to the entire input (the mini-batch):

python
dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x))
for i in dp:
    print(i)

It is possible to un-nest items from the mini-batches using the unbatch_level argument:

python
dp = ExampleIterPipe(10).batch(3).groupby(lambda x: x % 3, unbatch_level = 1)
for i in dp:
    print(i)

When the internal buffer (defined by buffer_size) overfills, groupby will yield the biggest group available:

python
dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, buffer_size = 5)
for i in dp:
    print(i)

groupby will produce batches of group_size as soon as possible:

python
dp = ExampleIterPipe(18).shuffle().groupby(lambda x: x % 3, group_size = 3)
for i in dp:
    print(i)

Remaining groups must contain at least guaranteed_group_size elements.

python
dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, group_size = 3, guaranteed_group_size = 2)
for i in dp:
    print(i)

Without a defined group_size, the function will try to accumulate at least guaranteed_group_size elements before yielding the resulting group:

python
dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2)
for i in dp:
    print(i)

This behaviour becomes noticeable when the data is bigger than the buffer and some groups get evicted before gathering all their potential items:

python
dp = ExampleIterPipe(15).groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)
for i in dp:
    print(i)

With randomness involved, you might end up with incomplete groups, so the next example is expected to fail in most cases:

python
dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)
for i in dp:
    print(i)

To avoid this error and drop incomplete groups, use the drop_remaining argument:

python
dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6, drop_remaining = True)
for i in dp:
    print(i)
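The best-effort, buffer-bounded nature of grouping can be sketched in plain Python (`groupby_buffered` is a hypothetical helper modeling only the eviction idea, not group_size or guaranteed_group_size):

```python
def groupby_buffered(source, group_key_fn, buffer_size=10000):
    # Best-effort grouping: when the buffer fills up, the largest
    # group is evicted and yielded, so the same key may end up in
    # several output groups.
    groups = {}
    buffered = 0
    for item in source:
        groups.setdefault(group_key_fn(item), []).append(item)
        buffered += 1
        if buffered >= buffer_size:
            biggest = max(groups, key=lambda k: len(groups[k]))
            evicted = groups.pop(biggest)
            buffered -= len(evicted)
            yield evicted
    # Flush whatever groups remain at end of stream.
    yield from groups.values()

# With a large buffer, every key ends up in exactly one group.
print(list(groupby_buffered(range(9), lambda x: x % 3, buffer_size=100)))
# With a small buffer, keys may be split across several groups.
print(list(groupby_buffered(range(9), lambda x: x % 3, buffer_size=4)))
```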

Zip

Function: zip

Description:

Alternatives:

Arguments:

Example:

python
_dp = ExampleIterPipe(5).shuffle()
dp = ExampleIterPipe(5).zip(_dp)
for i in dp:
    print(i)

Fork

Function: fork

Description:

Alternatives:

Arguments:

Example:

python
dp = ExampleIterPipe(2)
dp1, dp2, dp3 = dp.fork(3)
for i in dp1 + dp2 + dp3:
    print(i)
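Conceptually, fork duplicates one datapipe into n independent branches with internal buffering, much like the standard library's itertools.tee (an analogy, not the actual implementation):

```python
from itertools import tee

# Split one iterator into three; each branch yields the
# full sequence independently.
dp1, dp2, dp3 = tee(iter(range(2)), 3)
print([*dp1, *dp2, *dp3])
# [0, 1, 0, 1, 0, 1]
```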

Demultiplexer

Function: demux

Description:

Alternatives:

Arguments:

Example:

python
dp = ExampleIterPipe(10)
dp1, dp2, dp3 = dp.demux(3, lambda x: x % 3)
for i in dp2:
    print(i)
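The routing idea behind demux can be sketched in plain Python. The real demux streams lazily with per-branch buffers; this eager, hypothetical `demux` helper materializes the outputs for clarity:

```python
def demux(source, num_instances, classifier_fn):
    # Route each item to the output whose index classifier_fn returns.
    outputs = [[] for _ in range(num_instances)]
    for item in source:
        outputs[classifier_fn(item)].append(item)
    return outputs

dp1, dp2, dp3 = demux(range(10), 3, lambda x: x % 3)
print(dp2)
# [1, 4, 7]
```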

Multiplexer

Function: mux

Description:

Alternatives:

Arguments:

Example:

python
dp1 = ExampleIterPipe(3)
dp2 = ExampleIterPipe(3).map(lambda x: x * 10)
dp3 = ExampleIterPipe(3).map(lambda x: x * 100)

dp = dp1.mux(dp2, dp3)
for i in dp:
    print(i)
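mux interleaves its inputs round-robin, taking one element from each per cycle. For equal-length inputs this is just zip plus flatten (a simplified sketch; behavior on unequal lengths is not modeled here):

```python
from itertools import chain

def mux(*pipes):
    # One element from each input per cycle, assuming equal lengths.
    return chain.from_iterable(zip(*pipes))

print(list(mux([0, 1, 2], [0, 10, 20], [0, 100, 200])))
# [0, 0, 0, 1, 10, 100, 2, 20, 200]
```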

Concat

Function: concat

Description: Returns a DataPipe with elements from the first datapipe followed by elements from the subsequent datapipes

Alternatives:

`dp = dp.concat(dp2, dp3)`
`dp = dp.concat(*datapipes_list)`

Example:

python
dp = ExampleIterPipe(4)
dp2 = ExampleIterPipe(3)
dp = dp.concat(dp2)
for i in dp:
    print(i)
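Conceptually, concat is the streaming equivalent of itertools.chain: exhaust the first input, then the second, and so on (an analogy, not the actual implementation):

```python
from itertools import chain

# 0..3 from the first source, then 0..2 from the second.
print(list(chain(range(4), range(3))))
# [0, 1, 2, 3, 0, 1, 2]
```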