[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sketch_rnn/__init__.py)
This is an annotated PyTorch implementation of the paper [A Neural Representation of Sketch Drawings](https://arxiv.org/abs/1704.03477).
Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural networks. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of Gaussians.
Download data from the [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset). There is a link to download the `npz` files in the *Sketch-RNN QuickDraw Dataset* section of the readme. Place the downloaded `npz` file(s) in the `data/sketch` folder. This code is configured to use the `bicycle` dataset; you can change this in the configurations.
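For example, here is a minimal sketch of inspecting a downloaded file (assuming it has been saved as `data/sketch/bicycle.npz`; the split names match what the training code below uses):

```python
import numpy as np

# Load the QuickDraw npz file (path assumed for this example)
dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)

# The file typically contains 'train', 'valid' and 'test' splits
print(dataset.files)
# Each split is a list of sequences of shape [seq_len, 3]: (Δx, Δy, pen state)
print(dataset['train'][0][:5])
```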
This implementation took help from the PyTorch Sketch-RNN project by Alexis David Jacq.
```python
import math
from typing import Optional, Tuple, Any

import einops
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from labml import lab, experiment, tracker, monit
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from torch import optim
from torch.utils.data import Dataset, DataLoader
```
This class loads and pre-processes the data.
```python
class StrokesDataset(Dataset):
```
`dataset` is a list of numpy arrays of shape `[seq_len, 3]`. Each array is a sequence of strokes, and each stroke is represented by 3 integers: the first two are the displacements along x and y (Δx, Δy), and the last is the state of the pen, 1 if it's touching the paper and 0 otherwise.
```python
    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
```
```python
        data = []
```
We iterate through each of the sequences and filter them by length.
```python
        for seq in dataset:
```
Keep the sequence only if the length of the sequence of strokes is within our range
```python
            if 10 < len(seq) <= max_seq_length:
```
Clamp Δx, Δy to [−1000,1000]
```python
                seq = np.minimum(seq, 1000)
                seq = np.maximum(seq, -1000)
```
Convert to a floating point array and add to data
```python
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
```
We then calculate the scaling factor, which is the standard deviation of (Δx, Δy) combined. The paper notes that the mean is not adjusted for simplicity, since it is anyway close to 0.
```python
        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale
```
Get the longest sequence length among all sequences
```python
        longest_seq_len = max([len(seq) for seq in data])
```
We initialize PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector (Δx,Δy,p1,p2,p3). Only one of p1,p2,p3 is 1 and the others are 0. They represent pen down, pen up and end-of-sequence in that order. p1 is 1 if the pen touches the paper in the next step. p2 is 1 if the pen doesn't touch the paper in the next step. p3 is 1 if it is the end of the drawing.
```python
        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
```
The mask array needs only one extra step since it is for the outputs of the decoder, which takes in `data[:-1]` and predicts the next step.
```python
        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)
```
Scale and set Δx,Δy
```python
            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
```
p1
```python
            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
```
p2
```python
            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
```
p3
```python
            self.data[i, len_seq + 1:, 4] = 1
```
Mask is on until end of sequence
```python
            self.mask[i, :len_seq + 1] = 1
```
Start-of-sequence is (0,0,1,0,0)
```python
        self.data[:, 0, 2] = 1
```
Size of the dataset
```python
    def __len__(self):
        return len(self.data)
```
Get a sample
```python
    def __getitem__(self, idx: int):
        return self.data[idx], self.mask[idx]
```
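Here is a small, hypothetical usage sketch (the toy sequence below is made up, not from the real dataset) showing how a stroke-3 sequence becomes padded stroke-5 data plus a mask:

```python
# A made-up stroke-3 sequence (Δx, Δy, pen state) of length 12
toy_seq = np.array([[1, 0, 0], [2, 1, 0], [0, 3, 1],
                    [1, 1, 0], [2, 0, 0], [1, 2, 1],
                    [3, 1, 0], [0, 2, 0], [1, 1, 0],
                    [2, 2, 1], [1, 0, 0], [0, 1, 1]])

# The class only iterates over the sequences, so a plain list works here
dataset = StrokesDataset([toy_seq], max_seq_length=200)
data, mask = dataset[0]
print(data.shape, mask.shape)  # torch.Size([14, 5]) torch.Size([13])
print(data[0])                 # start-of-sequence step (0, 0, 1, 0, 0)
```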
The mixture is represented by Π and N(μx,μy,σx,σy,ρxy). This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.
```python
class BivariateGaussianMixture:
```
```python
    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
        self.pi_logits = pi_logits
        self.mu_x = mu_x
        self.mu_y = mu_y
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.rho_xy = rho_xy
```
Number of distributions in the mixture, M
```python
    @property
    def n_distributions(self):
        return self.pi_logits.shape[-1]
```
Adjust by temperature τ
```python
    def set_temperature(self, temperature: float):
```
$\hat{\Pi}_k \leftarrow \frac{\hat{\Pi}_k}{\tau}$
```python
        self.pi_logits /= temperature
```
$\sigma_x^2 \leftarrow \sigma_x^2 \tau$
```python
        self.sigma_x *= math.sqrt(temperature)
```
$\sigma_y^2 \leftarrow \sigma_y^2 \tau$
```python
        self.sigma_y *= math.sqrt(temperature)
```
```python
    def get_distribution(self):
```
Clamp σx, σy and ρxy to avoid getting NaNs
```python
        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
```
Get means
```python
        mean = torch.stack([self.mu_x, self.mu_y], -1)
```
Get covariance matrix
```python
        cov = torch.stack([
            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
        ], -1)
        cov = cov.view(*sigma_y.shape, 2, 2)
```
Create bi-variate normal distribution.
📝 It would be more efficient to use a `scale_tril` matrix $[[a, 0], [b, c]]$ where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho_{xy}^2}$, but for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bi-variate distributions, their covariance matrix, and probability density function.
```python
        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
```
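As a rough sketch of the `scale_tril` alternative mentioned in the note above (an illustration, not part of the original code):

```python
# Lower-triangular Cholesky factor [[a, 0], [b, c]] of the covariance matrix,
# with a = σx, b = ρxy·σy, c = σy·sqrt(1 − ρxy²)
scale_tril = torch.stack([
    sigma_x, torch.zeros_like(sigma_x),
    rho_xy * sigma_y, sigma_y * torch.sqrt(1 - rho_xy ** 2)
], -1).view(*sigma_y.shape, 2, 2)
multi_dist = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)
```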
Create categorical distribution Π from logits
```python
        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
```
```python
        return cat_dist, multi_dist
```
This consists of a bidirectional LSTM
```python
class EncoderRNN(nn.Module):
```
```python
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()
```
Create a bidirectional LSTM taking a sequence of (Δx,Δy,p1,p2,p3) as input.
```python
        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
```
Head to get μ
```python
        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
```
Head to get σ^
```python
        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
```
```python
    def forward(self, inputs: torch.Tensor, state=None):
```
The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and the first token in the reverse direction, which is what we want: $\overrightarrow{h} = \overrightarrow{\text{encode}}(S)$, $\overleftarrow{h} = \overleftarrow{\text{encode}}(S_{\text{reverse}})$, $h = [\overrightarrow{h}; \overleftarrow{h}]$
```python
        _, (hidden, cell) = self.lstm(inputs.float(), state)
```
The state has shape `[2, batch_size, hidden_size]`, where the first dimension is the direction. We rearrange it to get $h = [\overrightarrow{h}; \overleftarrow{h}]$
```python
        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
```
μ
```python
        mu = self.mu_head(hidden)
```
σ^
```python
        sigma_hat = self.sigma_head(hidden)
```
$\sigma = \exp\big(\frac{\hat{\sigma}}{2}\big)$
```python
        sigma = torch.exp(sigma_hat / 2.)
```
Sample z=μ+σ⋅N(0,I)
```python
        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
```
```python
        return z, mu, sigma_hat
```
This consists of an LSTM
```python
class DecoderRNN(nn.Module):
```
```python
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()
```
LSTM takes [(Δx,Δy,p1,p2,p3);z] as input
```python
        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
```
Initial state of the LSTM is [h0;c0]=tanh(Wzz+bz). init_state is the linear transformation for this
```python
        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
```
This layer produces outputs for each of the `n_distributions`. Each distribution needs six parameters $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$
```python
        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
```
This head is for the logits (q1^,q2^,q3^)
```python
        self.q_head = nn.Linear(dec_hidden_size, 3)
```
This is to calculate $\log(q_k)$ where $q_k = \operatorname{softmax}(\hat{q})_k = \frac{\exp(\hat{q_k})}{\sum_{j=1}^{3} \exp(\hat{q_j})}$
```python
        self.q_log_softmax = nn.LogSoftmax(-1)
```
These parameters are stored for future reference
```python
        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size
```
```python
    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
```
Calculate the initial state
```python
        if state is None:
```
[h0;c0]=tanh(Wzz+bz)
```python
            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
```
`h` and `c` have shapes `[batch_size, lstm_size]`. We want to reshape them to `[1, batch_size, lstm_size]` because that's the shape the LSTM uses.
```python
            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
```
Run the LSTM
```python
        outputs, state = self.lstm(x, state)
```
Get log(q)
```python
        q_logits = self.q_log_softmax(self.q_head(outputs))
```
Get $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$. `torch.split` splits the output into 6 tensors of size `self.n_distributions` across dimension `2`.
```python
        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_distributions, 2)
```
Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$ where $\sigma_{x,i} = \exp(\hat{\sigma_{x,i}})$, $\sigma_{y,i} = \exp(\hat{\sigma_{y,i}})$, $\rho_{xy,i} = \tanh(\hat{\rho_{xy,i}})$ and $\Pi_i = \operatorname{softmax}(\hat{\Pi})_i = \frac{\exp(\hat{\Pi_i})}{\sum_{j=1}^{M} \exp(\hat{\Pi_j})}$

$\Pi$ gives the categorical probabilities of choosing a distribution out of the mixture $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$.
```python
        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
```
```python
        return dist, q_logits, state
```
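To make the tensor shapes concrete, here is a small, hypothetical smoke test of the encoder and decoder together (the sizes are just illustrative):

```python
seq_len, batch_size, d_z = 50, 4, 128
encoder = EncoderRNN(d_z=d_z, enc_hidden_size=256)
decoder = DecoderRNN(d_z=d_z, dec_hidden_size=512, n_distributions=20)

# Dummy stroke-5 input of shape [seq_len, batch_size, 5]
strokes = torch.zeros(seq_len, batch_size, 5)
z, mu, sigma_hat = encoder(strokes)              # z: [batch_size, d_z]

# Decoder gets the strokes (shifted by one) concatenated with z
z_stack = z.unsqueeze(0).expand(seq_len - 1, -1, -1)
inputs = torch.cat([strokes[:-1], z_stack], 2)
dist, q_logits, _ = decoder(inputs, z, None)
print(z.shape, q_logits.shape)                   # [4, 128] and [49, 4, 3]
```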
```python
class ReconstructionLoss(nn.Module):
```
```python
    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
```
Get Π and N(μx,μy,σx,σy,ρxy)
```python
        pi, mix = dist.get_distribution()
```
`target` has shape `[seq_len, batch_size, 5]` where the last dimension is the features (Δx, Δy, p1, p2, p3). We want to get Δx, Δy and get the probabilities from each of the distributions in the mixture N(μx,μy,σx,σy,ρxy).
xy will have shape [seq_len, batch_size, n_distributions, 2]
```python
        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
```
Calculate the probabilities $p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \mathcal{N}\big(\Delta x, \Delta y \mid \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}\big)$
```python
        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
```
$L_s = -\frac{1}{N_{\max}} \sum_{i=1}^{N_s} \log\big(p(\Delta x, \Delta y)\big)$ Although `probs` has $N_{\max}$ (`longest_seq_len`) elements, the sum is only taken up to $N_s$ because the rest is masked out.
It might feel like we should be taking the sum and dividing by $N_s$ rather than $N_{\max}$, but this would give higher weight to the individual predictions in shorter sequences. We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{\max}$.
```python
        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
```
$L_p = -\frac{1}{N_{\max}} \sum_{i=1}^{N_{\max}} \sum_{k=1}^{3} p_{k,i} \log(q_{k,i})$
```python
        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
```
LR=Ls+Lp
```python
        return loss_stroke + loss_pen
```
This calculates the KL divergence between a given normal distribution and N(0,1)
```python
class KLDivLoss(nn.Module):
```
```python
    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
```
$L_{KL} = -\frac{1}{2 N_z} \big(1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma})\big)$
```python
        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
```
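For reference, this is the standard closed-form KL divergence between $\mathcal{N}(\mu, \sigma^2)$ and $\mathcal{N}(0, 1)$, rewritten with $\hat{\sigma} = \log \sigma^2$ as predicted by the encoder; the line above averages this over the latent dimensions:

$$D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \frac{1}{2}\big(\mu^2 + \sigma^2 - \log \sigma^2 - 1\big) = -\frac{1}{2}\big(1 + \hat{\sigma} - \mu^2 - e^{\hat{\sigma}}\big)$$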
This samples a sketch from the decoder and plots it
```python
class Sampler:
```
```python
    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
        self.decoder = decoder
        self.encoder = encoder
```
```python
    def sample(self, data: torch.Tensor, temperature: float):
```
Nmax
```python
        longest_seq_len = len(data)
```
Get z from the encoder
```python
        z, _, _ = self.encoder(data)
```
Start-of-sequence stroke is (0,0,1,0,0)
```python
        s = data.new_tensor([0, 0, 1, 0, 0])
        seq = [s]
```
The initial decoder state is `None`. The decoder will initialize it to $[h_0; c_0] = \tanh(W_z z + b_z)$
```python
        state = None
```
We don't need gradients
```python
        with torch.no_grad():
```
Sample Nmax strokes
```python
            for i in range(longest_seq_len):
```
[(Δx,Δy,p1,p2,p3);z] is the input to the decoder
```python
                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
```
Get Π, N(μx,μy,σx,σy,ρxy), q and the next state from the decoder
```python
                dist, q_logits, state = self.decoder(data, z, state)
```
Sample a stroke
```python
                s = self._sample_step(dist, q_logits, temperature)
```
Add the new stroke to the sequence of strokes
```python
                seq.append(s)
```
Stop sampling if p3=1. This indicates that sketching has stopped
```python
                if s[4] == 1:
                    break
```
Create a PyTorch tensor of the sequence of strokes
```python
        seq = torch.stack(seq)
```
Plot the sequence of strokes
```python
        self.plot(seq)
```
```python
    @staticmethod
    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
```
Set temperature τ for sampling. This is implemented in class BivariateGaussianMixture .
```python
        dist.set_temperature(temperature)
```
Get temperature adjusted Π and N(μx,μy,σx,σy,ρxy)
```python
        pi, mix = dist.get_distribution()
```
Sample from Π the index of the distribution to use from the mixture
```python
        idx = pi.sample()[0, 0]
```
Create categorical distribution q with log-probabilities q_logits or q^
```python
        q = torch.distributions.Categorical(logits=q_logits / temperature)
```
Sample from q
```python
        q_idx = q.sample()[0, 0]
```
Sample from the normal distributions in the mixture and pick the one indexed by idx
```python
        xy = mix.sample()[0, 0, idx]
```
Create an empty stroke (Δx,Δy,q1,q2,q3)
```python
        stroke = q_logits.new_zeros(5)
```
Set Δx,Δy
```python
        stroke[:2] = xy
```
Set q1,q2,q3
```python
        stroke[q_idx + 2] = 1
```
```python
        return stroke
```
```python
    @staticmethod
    def plot(seq: torch.Tensor):
```
Take the cumulative sums of (Δx,Δy) to get (x,y)
```python
        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
```
Create a new numpy array of the form (x,y,q2)
```python
        seq[:, 2] = seq[:, 3]
        seq = seq[:, 0:3].detach().cpu().numpy()
```
Split the array at the points where $q_2$ is 1, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of stroke sequences.
```python
        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
```
Plot each sequence of strokes
```python
        for s in strokes:
            plt.plot(s[:, 0], -s[:, 1])
```
Don't show axes
```python
        plt.axis('off')
```
Show the plot
```python
        plt.show()
```
These are the default configurations, which can later be adjusted by passing a `dict`.
```python
class Configs(TrainValidConfigs):
```
Device configurations to pick the device to run the experiment
```python
    device: torch.device = DeviceConfigs()
```
```python
    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset
```
Encoder and decoder sizes
```python
    enc_hidden_size = 256
    dec_hidden_size = 512
```
Batch size
```python
    batch_size = 100
```
Number of features in z
```python
    d_z = 128
```
Number of distributions in the mixture, M
```python
    n_distributions = 20
```
Weight of KL divergence loss, wKL
```python
    kl_div_loss_weight = 0.5
```
Gradient clipping
```python
    grad_clip = 1.
```
Temperature τ for sampling
```python
    temperature = 0.4
```
Filter out stroke sequences longer than 200
```python
    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()
```
```python
    def init(self):
```
Initialize encoder & decoder
```python
        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
```
Set optimizer. Things like type of optimizer and learning rate are configurable
```python
        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optimizer
```
Create sampler
```python
        self.sampler = Sampler(self.encoder, self.decoder)
```
npz file path is data/sketch/[DATASET NAME].npz
```python
        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
```
Load the numpy file
```python
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
```
Create training dataset
```python
        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
```
Create validation dataset
```python
        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
```
Create training data loader
```python
        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
```
Create validation data loader
```python
        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
```
Configure the tracker to print the total train/validation loss
```python
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []
```
```python
    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)
```
Move `data` and `mask` to the device and swap the sequence and batch dimensions. `data` will have shape `[seq_len, batch_size, 5]` and `mask` will have shape `[seq_len, batch_size]`.
```python
        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)
```
Increment step in training mode
```python
        if self.mode.is_train:
            tracker.add_global_step(len(data))
```
Encode the sequence of strokes
```python
        with monit.section("encoder"):
```
Get z, μ, and σ^
```python
            z, mu, sigma_hat = self.encoder(data)
```
Decode the mixture of distributions and q^
```python
        with monit.section("decoder"):
```
Concatenate [(Δx,Δy,p1,p2,p3);z]
```python
            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
```
Get mixture of distributions and q^
```python
            dist, q_logits, _ = self.decoder(inputs, z, None)
```
Compute the loss
```python
        with monit.section('loss'):
```
LKL
```python
            kl_loss = self.kl_div_loss(sigma_hat, mu)
```
LR
```python
            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
```
Loss=LR+wKLLKL
```python
            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
```
Track losses
575tracker.add("loss.kl.",kl\_loss)576tracker.add("loss.reconstruction.",reconstruction\_loss)577tracker.add("loss.total.",loss)
Only if we are in training state
```python
        if self.mode.is_train:
```
Run optimizer
```python
            with monit.section('optimize'):
```
Set grad to zero
```python
                self.optimizer.zero_grad()
```
Compute gradients
```python
                loss.backward()
```
Log model parameters and gradients
```python
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
```
Clip gradients
```python
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
```
Optimize
```python
                self.optimizer.step()

        tracker.save()
```
```python
    def sample(self):
```
Randomly pick a sample from the validation dataset to encode
```python
        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
```
Add batch dimension and move it to device
```python
        data = data.unsqueeze(1).to(self.device)
```
Sample
```python
        self.sampler.sample(data, self.temperature)
```
```python
def main():
    configs = Configs()
    experiment.create(name="sketch_rnn")
```
Pass a dictionary of configurations
```python
    experiment.configs(configs, {
        'optimizer.optimizer': 'Adam',
```
We use a learning rate of `1e-3` because we can see results faster. The paper suggested `1e-4`.
```python
        'optimizer.learning_rate': 1e-3,
```
Name of the dataset
```python
        'dataset_name': 'bicycle',
```
Number of inner iterations within an epoch to switch between training, validation and sampling.
```python
        'inner_iterations': 10
    })

    with experiment.start():
```
Run the experiment
```python
        configs.run()


if __name__ == "__main__":
    main()
```
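As noted at the top, the dataset can be switched through the configurations; for example, a hypothetical variation of the dictionary above (assuming the corresponding file, e.g. `data/sketch/cat.npz`, has been downloaded):

```python
experiment.configs(configs, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 1e-3,
    # Hypothetical: any downloaded QuickDraw category can be used,
    # as long as data/sketch/<name>.npz exists
    'dataset_name': 'cat',
    'inner_iterations': 10
})
```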