docs/experiments/arithmetic_dataset.html
This is based on code by Georges Harik (@gharik).
import random
import string
from typing import List

import torch
from labml.logger import Text
from torch.utils.data import DataLoader, Dataset

from labml import monit, logger, tracker
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.
It uses character-level tokenization.
class ArithmeticDataset(Dataset):
seq_len is the sequence length of generated math problems. We fill as many problems as possible up to this length. max_digits is the maximum number of digits in the operand integers. n_sequences is the number of sequences per epoch.

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}
Generates an integer with n_digits digits
    @staticmethod
    def make_int(n_digits: int):
        res = 0
        for i in range(n_digits):
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res
Generates the workings for x + y. For example, for 11+29 it generates 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1.
    @staticmethod
    def get_add_explanation(x: int, y: int):
        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)
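As a quick sanity check, the workings routine can be reproduced as a standalone function (a sketch mirroring the method above, outside the class):

```python
def get_add_explanation(x: int, y: int) -> str:
    # Walk the digits from least to most significant, tracking the carry
    carry = 0
    e = 0
    explanation = []
    while x > 0 or y > 0 or carry > 0:
        rx, ry = x % 10, y % 10
        total = rx + ry + carry
        explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
        x, y, carry = x // 10, y // 10, total // 10
        e += 1
    return ' '.join(explanation)

print(get_add_explanation(11, 29))  # 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1
```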
Creates an arithmetic addition problem with workings and the answer.
    def make_add_problem(self):
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"
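For illustration, with operands 11 and 29 (and the workings string produced by get_add_explanation), the formatted problem looks like this:

```python
x, y = 11, 29
# Workings string as produced by get_add_explanation(11, 29)
explanation = "1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1"
# Same format string as make_add_problem
problem = f"x={x}+{y}; {explanation} x=={x + y}\n"
print(problem)
```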
Get arithmetic problem and answer. This is used for evaluation.
    def get_qa(self):
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        return f'x={x}+{y};', f'{x + y}'
Generate multiple problems and pack them into a sequence.
    def get_packed_math_input(self):
        s_enc = []
        while len(s_enc) <= self.seq_len:
            s_part = self.make_add_problem()
            s_part_enc = self.encode('?' + s_part)
            s_enc = s_enc + s_part_enc
        return s_enc
Encode a given string
    def encode(self, s: str):
        return [self.stoi[c] for c in s]
Decode a list of token ids
    def decode(self, arr: List[int]):
        return ''.join([self.itos[c] for c in arr])
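A minimal sketch of the vocabulary and the encode/decode round trip, assuming the same itos table built in the constructor:

```python
import string

# Token id to string, as built in ArithmeticDataset.__init__
itos = list(string.digits + 'xe =\n?+;')
stoi = {c: i for i, c in enumerate(itos)}

def encode(s: str):
    return [stoi[c] for c in s]

def decode(arr):
    return ''.join(itos[c] for c in arr)

s = 'x=12+7;'
assert decode(encode(s)) == s
print(len(itos))  # 18 tokens: 10 digits plus 8 symbols
```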
Get an input and target pair for auto-regressive modelling
    def __getitem__(self, idx: int):
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]
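The one-token offset between input and target can be illustrated with plain lists standing in for the encoded tensor (a sketch, not the actual torch code):

```python
seq_len = 8
s = list(range(10))  # stand-in for an encoded packed sequence
inp, target = s[:seq_len], s[1:seq_len + 1]
# The target is the input shifted left by one token
assert inp[1:] == target[:-1]
print(inp)
print(target)
```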
Number of sequences per epoch
    def __len__(self):
        return self.n_sequences
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
Maximum number of digits per operand integer
    max_digits: int = 4
Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
Training data loader
    train_loader: DataLoader = 'arithmetic_train_loader'
Number of problems in evaluation
    n_tests: int = 64
No validation dataset is needed
    validator = None
Number of times to run evaluations per epoch
    inner_iterations = 4
Number of tokens in the vocabulary
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
We use the sampling function to evaluate the model on a set of problems
    @torch.no_grad()
    def sample(self):
Skip in the first epoch
        if self.training_loop.idx < 1:
            return
Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
Collect the problems only
        questions = [p[0] for p in qa]
Create a tensor with only the initial token
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
Move to device
        data = data.to(self.device)
Number of sequences that have completed
        finished = torch.zeros((len(questions),)).bool().to(self.device)
Token id of the new line character - this marks the end of the answer
        new_line = dataset.stoi['\n']
Sampled results
        results = [p[0] for p in questions]
Sample up to the sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
If all the sequences have completed we skip this
            if finished.sum() == len(finished):
                continue
Get the model output
            output, *_ = self.model(data)
Get the model prediction (greedy)
            output = output[-1].argmax(dim=-1)
Find which sequences have finished
            finished = finished | (output == new_line)
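The completion-mask update is an element-wise OR; it can be sketched with plain lists in place of tensors (the token id below is hypothetical):

```python
new_line = 13                   # hypothetical token id for '\n'
finished = [False, False, True]  # sequences that already emitted '\n'
output = [5, 13, 7]              # greedy predictions for three sequences
# Element-wise OR, mirroring `finished | (output == new_line)` on tensors
finished = [f or (o == new_line) for f, o in zip(finished, output)]
print(finished)
```

Once a sequence has emitted the new-line token it stays finished, regardless of later predictions.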
Skip if all have finished
            if finished.sum() == len(finished):
                continue
Override the prediction with the question text while we are still within the question
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]
Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)
Get the sampled results
            for j, c in enumerate(output):
                results[j] += dataset.itos[c]
Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]
Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
Get the answers
        results = [r.split('x==')[-1] for r in results]
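The answer extraction keeps whatever follows the final x== marker; for example, given a hypothetical sampled result:

```python
# Hypothetical sampled result for 11 + 29, truncated at the first '\n'
result = "x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40"
# The answer is whatever follows the final 'x=='
answer = result.split('x==')[-1]
print(answer)
```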
Count the number of correct answers
        correct = 0
        for r, _qa in zip(results, qa):
            if r == _qa[1]:
                correct += 1
Log the score
        tracker.save('score', correct / len(results))
Training data loader
@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      num_workers=4)
Code to test generated problems
def _test():
    dataset = ArithmeticDataset(256, 8, 10)

    print(dataset.decode(dataset.get_packed_math_input()))


if __name__ == '__main__':
    _test()