Parity Task

This module creates data for the parity task from the paper Adaptive Computation Time for Recurrent Neural Networks.

The input for the parity task is a vector whose entries are 0, 1, or −1. The output is the parity of the number of 1s: one if there is an odd number of 1s and zero otherwise. The input is generated by setting a random number of elements of the vector to 1 or −1 and leaving the rest at 0.
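As a small worked example, here is a plain-Python sketch (no PyTorch) of how a label follows from an input vector. The helper name `parity_label` is hypothetical. Only entries equal to 1 are counted, so −1s and 0s never affect the label.

```python
def parity_label(vector: list[int]) -> int:
    """Return 1 if `vector` contains an odd number of 1s, else 0.

    Entries equal to -1 or 0 are ignored; only exact 1s are counted.
    """
    ones = sum(1 for v in vector if v == 1)
    return ones % 2


# Three 1s (odd) -> label 1; the -1 and the 0s do not count.
print(parity_label([1, -1, 0, 1, 0, 1]))  # 1
# Two 1s (even) -> label 0, even though four entries are non-zero.
print(parity_label([1, -1, -1, 0, 1]))  # 0
```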

    from typing import Tuple

    import torch
    from torch.utils.data import Dataset

Parity dataset

    class ParityDataset(Dataset):

  • n_samples is the number of samples
  • n_elems is the number of elements in the input vector

        def __init__(self, n_samples: int, n_elems: int = 64):

            self.n_samples = n_samples
            self.n_elems = n_elems

Size of the dataset

        def __len__(self):

            return self.n_samples

Generate a sample

        def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

Start with an all-zero vector

            x = torch.zeros((self.n_elems,))

Number of non-zero elements: a random integer between 1 and the total number of elements, inclusive

            n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
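`torch.randint(low, high, size)` samples from the half-open range [low, high), which is why the code passes `self.n_elems + 1` as the upper bound. Python's standard-library `random.randint(a, b)` includes both endpoints, so a plain-Python equivalent drops the `+ 1`. This is a sketch; the helper `sample_n_non_zero` is hypothetical.

```python
import random


def sample_n_non_zero(n_elems: int, rng: random.Random) -> int:
    """Draw how many entries will be non-zero: uniform over 1..n_elems inclusive.

    random.randint includes both endpoints, unlike torch.randint, whose
    upper bound is exclusive (hence `n_elems + 1` in the PyTorch code).
    """
    return rng.randint(1, n_elems)


rng = random.Random(0)
draws = [sample_n_non_zero(4, rng) for _ in range(1000)]
print(min(draws), max(draws))
```

Because the lower bound is 1, a generated vector is never all zeros.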

Fill the first n_non_zero elements with random 1s and −1s

            x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
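`torch.randint(0, 2, ...)` yields 0s and 1s; multiplying by 2 and subtracting 1 maps 0 to −1 and 1 to 1, so each non-zero entry is an independent fair coin flip between −1 and 1. A plain-Python sketch of the same step follows; the helper name `fill_signs` is hypothetical.

```python
import random


def fill_signs(n_elems: int, n_non_zero: int, rng: random.Random) -> list[int]:
    """Zero vector whose first `n_non_zero` entries are random -1s or 1s."""
    x = [0] * n_elems
    for i in range(n_non_zero):
        bit = rng.randint(0, 1)  # 0 or 1 (inclusive bounds)
        x[i] = bit * 2 - 1       # 0 -> -1, 1 -> 1
    return x


print(fill_signs(8, 3, random.Random(1)))  # first 3 entries are ±1, the rest are 0
```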

Randomly permute the elements

            x = x[torch.randperm(self.n_elems)]
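Without this step every non-zero entry would sit at the front of the vector. A permutation only rearranges entries, so the number of 1s, and therefore the label, is unchanged. The plain-Python sketch below mirrors `x[torch.randperm(n)]` by shuffling a list of indices with `random.shuffle` and then gathering.

```python
import random

rng = random.Random(2)
x = [1, -1, 1, 0, 0, 0, 0, 0]   # non-zeros packed at the front
ones_before = x.count(1)

# Mirror x[torch.randperm(n)]: build a random permutation of indices, then gather.
perm = list(range(len(x)))
rng.shuffle(perm)
x = [x[i] for i in perm]

print(x.count(1) == ones_before)  # True: permuting never changes the count of 1s
print(x.count(1) % 2)             # 0: two 1s, so the label is 0
```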

The parity: 1 if the number of 1s is odd, 0 otherwise

            y = (x == 1.).sum() % 2
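`(x == 1.)` is a boolean mask of the 1s, `.sum()` counts them, and `% 2` reduces the count to its parity. Equivalently, the label is the XOR of one indicator bit per entry. The plain-Python check below compares the two formulations on random vectors; both helper names are hypothetical.

```python
import operator
import random
from functools import reduce


def parity_by_count(x: list[int]) -> int:
    """Count entries equal to 1, then take the count modulo 2 (as in the PyTorch code)."""
    return sum(v == 1 for v in x) % 2


def parity_by_xor(x: list[int]) -> int:
    """XOR together one bit per entry: 1 if the entry is 1, else 0."""
    return reduce(operator.xor, (int(v == 1) for v in x), 0)


rng = random.Random(3)
for _ in range(1000):
    vec = [rng.choice([-1, 0, 1]) for _ in range(16)]
    assert parity_by_count(vec) == parity_by_xor(vec)
print("count % 2 and XOR agree")
```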

            return x, y
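Note that `__getitem__` never uses `idx`: every access draws a fresh random sample, so `n_samples` only sets how many samples `__len__` reports (and hence how many a default `DataLoader` yields per epoch), and the same index returns different data on different calls. If reproducible samples are wanted, one option is to seed a generator from the index. The class below is a hypothetical plain-Python variation (`DeterministicParitySketch` is not part of the original code). It follows the same four steps and implements the `__len__`/`__getitem__` protocol that a map-style PyTorch `Dataset` relies on.

```python
import random


class DeterministicParitySketch:
    """Plain-Python variant of ParityDataset whose samples depend only on (seed, idx).

    Hypothetical illustration; the original draws new random data on every call.
    """

    def __init__(self, n_samples: int, n_elems: int = 64, seed: int = 0):
        self.n_samples = n_samples
        self.n_elems = n_elems
        self.seed = seed

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> tuple[list[int], int]:
        # A dedicated generator per index makes repeated lookups identical.
        rng = random.Random(self.seed * 1_000_003 + idx)
        x = [0] * self.n_elems
        n_non_zero = rng.randint(1, self.n_elems)   # inclusive: 1..n_elems
        for i in range(n_non_zero):
            x[i] = rng.randint(0, 1) * 2 - 1        # 0 -> -1, 1 -> 1
        rng.shuffle(x)                              # random permutation
        y = x.count(1) % 2                          # parity of the number of 1s
        return x, y


ds = DeterministicParitySketch(n_samples=4, n_elems=8)
print(len(ds))          # 4
print(ds[2] == ds[2])   # True: same index, same sample
```

The trade-off is that a fixed set of samples is revisited every epoch, whereas the original effectively sees new data on every access.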

labml.ai