This module creates data for the parity task from the paper Adaptive Computation Time for Recurrent Neural Networks.
The input to the parity task is a vector of 0s, 1s and −1s. The output is the parity of the 1s: one if the vector contains an odd number of 1s and zero otherwise. The input is generated by setting a random number of the vector's elements to either 1 or −1 and leaving the rest as 0.
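As a minimal sketch of the labelling rule just described, the hypothetical helper `parity` below counts the +1 entries of a tensor and reduces that count modulo 2. The −1 and 0 entries never change the result; they only act as distractors.

```python
import torch


def parity(x: torch.Tensor) -> int:
    # Count only the +1 entries; -1s and 0s do not affect the parity
    return int((x == 1).sum().item() % 2)


# Two 1s (even count) -> 0
print(parity(torch.tensor([1., -1., 0., 1.])))  # 0
# Three 1s (odd count) -> 1
print(parity(torch.tensor([1., 1., -1., 1., 0.])))  # 1
```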
```python
from typing import Tuple

import torch
from torch.utils.data import Dataset
```
```python
class ParityDataset(Dataset):
```
`n_samples` is the number of samples and `n_elems` is the number of elements in the input vector.

```python
    def __init__(self, n_samples: int, n_elems: int = 64):
        self.n_samples = n_samples
        self.n_elems = n_elems
```
Size of the dataset
```python
    def __len__(self):
        return self.n_samples
```
Generate a sample
```python
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
```
Empty vector
```python
        x = torch.zeros((self.n_elems,))
```
Number of non-zero elements: a random integer between 1 and the total number of elements, inclusive
```python
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
```
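The `+ 1` in the call above matters because the upper bound of `torch.randint` is exclusive. The short check below, with an arbitrarily chosen small `n_elems` and a fixed seed (both illustrative assumptions), draws many values the same way and confirms they cover 1 through `n_elems` and nothing else.

```python
import torch

n_elems = 4
torch.manual_seed(0)

# torch.randint(low, high, size) samples from [low, high), so high = n_elems + 1
# makes n_elems itself reachable while 0 (an all-zero input) never occurs
samples = torch.randint(1, n_elems + 1, (10_000,))

print(samples.min().item(), samples.max().item())
print(sorted(samples.unique().tolist()))
```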
Fill the first n_non_zero elements with 1s and −1s, each sign chosen with equal probability
```python
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
```
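The expression `torch.randint(0, 2, ...) * 2 - 1` maps random bits {0, 1} onto signs {−1, +1}. The sketch below uses a fixed, hand-picked bit pattern (an illustrative assumption) to show the mapping, and also shows that assigning the resulting integer tensor into the float vector of zeros casts the values to the vector's dtype.

```python
import torch

bits = torch.tensor([0, 1, 1, 0])
# 0 * 2 - 1 = -1 and 1 * 2 - 1 = +1
signs = bits * 2 - 1
print(signs.tolist())  # [-1, 1, 1, -1]

x = torch.zeros(6)
# The int64 values are cast to x's float32 dtype on assignment;
# the remaining two positions stay 0
x[:4] = signs
print(x.tolist())  # [-1.0, 1.0, 1.0, -1.0, 0.0, 0.0]
```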
Randomly permute the elements
```python
        x = x[torch.randperm(self.n_elems)]
```
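Without the shuffle, the non-zero entries would always form a prefix of the vector. A permutation only reorders entries, so it leaves the count of 1s, and therefore the label, unchanged. The quick check below uses a hand-written vector and a fixed seed, both chosen only for illustration.

```python
import torch

torch.manual_seed(0)
x = torch.tensor([1., -1., 1., 1., 0., 0.])

# Apply a random permutation of the indices, as the dataset does
perm = torch.randperm(x.numel())
shuffled = x[perm]

# Same multiset of values, so the same number of 1s and the same parity
print((x == 1).sum().item(), (shuffled == 1).sum().item())  # 3 3
```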
The parity: the number of 1s modulo 2
```python
        y = (x == 1.).sum() % 2

        return x, y
```
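`ParityDataset` builds one sample per `__getitem__` call. As an alternative (not part of the original code), the same recipe can be vectorised to produce a whole batch at once. The function name `parity_batch` and its `generator` parameter are hypothetical. Each row independently gets between 1 and `n_elems` random ±1 entries, is shuffled with its own random permutation, and is labelled with the parity of its 1s.

```python
from typing import Optional, Tuple

import torch


def parity_batch(batch_size: int, n_elems: int = 64,
                 generator: Optional[torch.Generator] = None
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical vectorised variant of ParityDataset.__getitem__
    # Number of non-zero elements per row, uniform over 1..n_elems
    n_non_zero = torch.randint(1, n_elems + 1, (batch_size, 1), generator=generator)
    # mask[i, j] is True for the first n_non_zero[i] positions of row i
    positions = torch.arange(n_elems).unsqueeze(0)  # shape (1, n_elems)
    mask = positions < n_non_zero  # broadcasts to (batch_size, n_elems)
    # Random signs: +1 or -1 with equal probability
    signs = torch.randint(0, 2, (batch_size, n_elems), generator=generator) * 2 - 1
    x = torch.where(mask, signs, torch.zeros_like(signs)).float()
    # Shuffle each row independently: argsort of uniform noise is a random permutation
    perm = torch.rand(batch_size, n_elems, generator=generator).argsort(dim=1)
    x = torch.gather(x, 1, perm)
    # Parity of the number of 1s in each row
    y = (x == 1.).sum(dim=1) % 2
    return x, y


g = torch.Generator().manual_seed(0)
x, y = parity_batch(4, n_elems=8, generator=g)
print(x.shape, y.shape)  # torch.Size([4, 8]) torch.Size([4])
```

Generating a batch in one call avoids a Python-level loop over samples. The trade-off is that it bypasses the `Dataset` interface, so it would not plug directly into a `DataLoader` the way `ParityDataset` does.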