
Sketch RNN


This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.

Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural network models. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke's offsets as a mixture of Gaussians and its pen state as a categorical distribution.

Getting data

Download data from the Quick, Draw! Dataset. There is a link to download npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset; you can change this in the configurations.

Acknowledgements

We took help from the PyTorch Sketch RNN project by Alexis David Jacq.

import math
from typing import Optional, Tuple, Any

import einops
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from labml import lab, experiment, tracker, monit
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from torch import optim
from torch.utils.data import Dataset, DataLoader

Dataset

This class loads and pre-processes the data.

class StrokesDataset(Dataset):

dataset is a list of numpy arrays of shape [seq_len, 3]. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y (Δx, Δy). The last integer represents the state of the pen: 1 if the pen is lifted from the paper after this point and 0 otherwise.

    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):

        data = []

We iterate through each of the sequences and filter them.

        for seq in dataset:

Keep the sequence only if its number of strokes is within our range.

            if 10 < len(seq) <= max_seq_length:

Clamp Δx, Δy to [−1000, 1000].

                seq = np.minimum(seq, 1000)
                seq = np.maximum(seq, -1000)

Convert to a floating point array and add to data.

                seq = np.array(seq, dtype=np.float32)
                data.append(seq)

We then calculate the scaling factor, which is the standard deviation of (Δx, Δy) combined. The paper notes that the mean is not adjusted for simplicity, since it is close to 0 anyway.

        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale

Get the longest sequence length among all sequences.

        longest_seq_len = max([len(seq) for seq in data])

We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Exactly one of $p_1, p_2, p_3$ is 1 and the others are 0. They represent pen down, pen up and end-of-sequence, in that order. $p_1$ is 1 if the pen touches the paper in the next step. $p_2$ is 1 if the pen doesn't touch the paper in the next step. $p_3$ is 1 if it is the end of the drawing.

        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

The mask array needs only one extra step. It is for the outputs of the decoder, which takes in data[:-1] and predicts the next step.

        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)

Scale and set Δx, Δy.

            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

$p_1$

            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

$p_2$

            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

$p_3$

            self.data[i, len_seq + 1:, 4] = 1

Mask is on until the end of the sequence.

            self.mask[i, :len_seq + 1] = 1

Start-of-sequence is $(0, 0, 1, 0, 0)$.

        self.data[:, 0, 2] = 1
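The stroke-3 to stroke-5 conversion above can be illustrated with a minimal NumPy sketch for a single sequence. The helper name to_stroke5 and the toy input are illustrative assumptions, not part of the implementation. It mirrors the indexing used in StrokesDataset, but it pads to a length the caller supplies.

```python
import numpy as np


def to_stroke5(seq: np.ndarray, max_len: int, scale: float):
    """Hypothetical helper: convert one stroke-3 sequence into padded stroke-5 rows and a mask."""
    n = len(seq)
    # One extra step at the front (sos) and at least one at the back (eos)
    out = np.zeros((max_len + 2, 5), dtype=np.float32)
    # The decoder predicts steps 1..max_len + 1, so the mask has max_len + 1 entries
    mask = np.zeros(max_len + 1, dtype=np.float32)
    # Scaled offsets (dx, dy)
    out[1:n + 1, :2] = seq[:, :2] / scale
    # p1 = pen stays down, p2 = pen is lifted after this point
    out[1:n + 1, 2] = 1 - seq[:, 2]
    out[1:n + 1, 3] = seq[:, 2]
    # p3 = end-of-sequence for every step after the last stroke
    out[n + 1:, 4] = 1
    # Mask covers the n strokes plus the first end-of-sequence step
    mask[:n + 1] = 1
    # Start-of-sequence token (0, 0, 1, 0, 0)
    out[0, 2] = 1
    return out, mask


# Three strokes; the pen is lifted after the second one
seq = np.array([[10, 0, 0], [0, 10, 1], [5, 5, 0]], dtype=np.float32)
data, mask = to_stroke5(seq, max_len=5, scale=10.0)
print(data)
print(mask)
```

In StrokesDataset the padded length comes from the longest sequence that survives filtering, not from max_seq_length.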

Size of the dataset

    def __len__(self):
        return len(self.data)

Get a sample

    def __getitem__(self, idx: int):
        return self.data[idx], self.mask[idx]

Bi-variate Gaussian mixture

The mixture is represented by $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$. This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.

class BivariateGaussianMixture:
    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
        self.pi_logits = pi_logits
        self.mu_x = mu_x
        self.mu_y = mu_y
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.rho_xy = rho_xy

Number of distributions in the mixture, $M$

    @property
    def n_distributions(self):
        return self.pi_logits.shape[-1]

Adjust by temperature $\tau$

    def set_temperature(self, temperature: float):

$\hat{\Pi}_k \leftarrow \frac{\hat{\Pi}_k}{\tau}$

        self.pi_logits /= temperature

$\sigma^2_x \leftarrow \sigma^2_x \tau$

        self.sigma_x *= math.sqrt(temperature)

$\sigma^2_y \leftarrow \sigma^2_y \tau$

        self.sigma_y *= math.sqrt(temperature)
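The following pure-Python sketch illustrates how set_temperature changes the mixture weights and spreads. The function names are hypothetical. Here softmax stands in for the normalization that Categorical(logits=...) performs in get_distribution.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def apply_temperature(pi_logits, sigmas, tau):
    """Mirror of set_temperature: logits / tau and sigma * sqrt(tau), so variance scales by tau."""
    return [v / tau for v in pi_logits], [s * math.sqrt(tau) for s in sigmas]


pi_logits = [2.0, 1.0, 0.0]
sigmas = [1.0, 2.0, 0.5]
cold_logits, cold_sigmas = apply_temperature(pi_logits, sigmas, tau=0.4)

# With tau < 1 the most likely component becomes even more likely
print(max(softmax(pi_logits)), max(softmax(cold_logits)))
# Every variance is multiplied by tau
print([s ** 2 for s in cold_sigmas])
```

Dividing the logits by τ keeps the ranking of the components unchanged and only sharpens (τ < 1) or flattens (τ > 1) the distribution.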

    def get_distribution(self):

Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs.

        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

Get means

        mean = torch.stack([self.mu_x, self.mu_y], -1)

Get covariance matrix

        cov = torch.stack([
            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
        ], -1)
        cov = cov.view(*sigma_y.shape, 2, 2)

Create bi-variate normal distribution.

📝 It would be more efficient to use a scale_tril matrix $\begin{bmatrix} a & 0 \\ b & c \end{bmatrix}$ where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho^2_{xy}}$. But for simplicity we use the co-variance matrix. This is a good resource if you want to read up more about bi-variate distributions, their co-variance matrix, and probability density function.

        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
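The scale_tril note can be checked numerically. This NumPy sketch builds both matrices for one component and confirms that $LL^\top$ equals the covariance matrix. It also confirms that $L$ matches the Cholesky factor NumPy computes. The function names are illustrative.

```python
import numpy as np


def covariance(sigma_x, sigma_y, rho):
    """Covariance matrix of one bi-variate component, as built in get_distribution."""
    return np.array([[sigma_x ** 2, rho * sigma_x * sigma_y],
                     [rho * sigma_x * sigma_y, sigma_y ** 2]])


def scale_tril(sigma_x, sigma_y, rho):
    """Lower-triangular factor [[a, 0], [b, c]] from the note above."""
    return np.array([[sigma_x, 0.0],
                     [rho * sigma_y, sigma_y * np.sqrt(1 - rho ** 2)]])


cov = covariance(0.5, 2.0, -0.3)
tril = scale_tril(0.5, 2.0, -0.3)
print(np.allclose(tril @ tril.T, cov))
print(np.allclose(tril, np.linalg.cholesky(cov)))
```

In PyTorch such a factor could be passed as MultivariateNormal(mean, scale_tril=...) instead of covariance_matrix=..., which avoids the Cholesky decomposition that PyTorch otherwise performs internally.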

Create categorical distribution $\Pi$ from logits

        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

        return cat_dist, multi_dist

Encoder module

This consists of a bidirectional LSTM.

class EncoderRNN(nn.Module):
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()

Create a bidirectional LSTM taking a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.

        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

Head to get $\mu$

        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

Head to get $\hat{\sigma}$

        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

    def forward(self, inputs: torch.Tensor, state=None):

The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and the first token in the reverse direction, which is what we want. $h_{\rightarrow} = encode_{\rightarrow}(S)$, $h_{\leftarrow} = encode_{\leftarrow}(S_{reverse})$, $h = [h_{\rightarrow}; h_{\leftarrow}]$

        _, (hidden, cell) = self.lstm(inputs.float(), state)

The state has shape [2, batch_size, hidden_size], where the first dimension is the direction. We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$.

        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
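The einops pattern 'fb b h -> b (fb h)' places the forward and backward states side by side for each batch element. Here is a NumPy sketch of the same rearrangement without einops, assuming a toy state of shape [2, batch_size, hidden_size]:

```python
import numpy as np

directions, batch_size, hidden_size = 2, 3, 4
# Stand-in for the LSTM's final hidden state: [direction, batch, hidden]
hidden = np.arange(directions * batch_size * hidden_size).reshape(directions, batch_size, hidden_size)

# Equivalent of einops.rearrange(hidden, 'fb b h -> b (fb h)'):
# move the direction axis next to the hidden axis, then merge the two
merged = hidden.transpose(1, 0, 2).reshape(batch_size, directions * hidden_size)

# Row i is [forward state of i ; backward state of i]
expected = np.concatenate([hidden[0], hidden[1]], axis=1)
print(np.array_equal(merged, expected))
```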

$\mu$

        mu = self.mu_head(hidden)

$\hat{\sigma}$

        sigma_hat = self.sigma_head(hidden)

$\sigma = \exp(\frac{\hat{\sigma}}{2})$

        sigma = torch.exp(sigma_hat / 2.)

Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$

        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
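The encoder predicts $\hat{\sigma} = \log \sigma^2$, so $\sigma = \exp(\hat{\sigma}/2)$, and it samples $z$ with the reparameterization trick. The following NumPy sketch uses toy values and illustrative names. It checks that the samples have the intended mean and standard deviation.

```python
import numpy as np


def sample_z(mu, sigma_hat, n_samples, rng):
    """Draw z = mu + sigma * eps with sigma = exp(sigma_hat / 2) and eps ~ N(0, I)."""
    sigma = np.exp(sigma_hat / 2.0)
    eps = rng.standard_normal((n_samples, mu.shape[-1]))
    return mu + sigma * eps


rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
# sigma_hat is the log-variance; variances 0.25 and 4 give standard deviations 0.5 and 2
sigma_hat = np.log(np.array([0.25, 4.0]))
z = sample_z(mu, sigma_hat, 20000, rng)
print(z.mean(axis=0), z.std(axis=0))
```

The randomness is confined to eps, so gradients can flow through mu and sigma_hat. In the PyTorch code, the torch.normal(...) call plays the role of eps.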

        return z, mu, sigma_hat

Decoder module

This consists of an LSTM.

class DecoderRNN(nn.Module):
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()

The LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input.

        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

The initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$. init_state is the linear transformation for this.

        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

This layer produces outputs for each of the n_distributions. Each distribution needs six parameters $(\hat{\Pi}_i, \mu_{x_i}, \mu_{y_i}, \hat{\sigma}_{x_i}, \hat{\sigma}_{y_i}, \hat{\rho}_{xy_i})$.

        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

This head is for the logits $(\hat{q}_1, \hat{q}_2, \hat{q}_3)$.

        self.q_head = nn.Linear(dec_hidden_size, 3)

This is to calculate $\log(q_k)$ where $q_k = \operatorname{softmax}(\hat{q})_k = \frac{\exp(\hat{q}_k)}{\sum_{j=1}^3 \exp(\hat{q}_j)}$.

        self.q_log_softmax = nn.LogSoftmax(-1)

These parameters are stored for future reference.

        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size

    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

Calculate the initial state

        if state is None:

$[h_0; c_0] = \tanh(W_{z}z + b_z)$

            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h and c have shapes [batch_size, lstm_size]. We want to reshape them to [1, batch_size, lstm_size] because that's the shape used in the LSTM.

            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

Run the LSTM

        outputs, state = self.lstm(x, state)

Get $\log(q)$

        q_logits = self.q_log_softmax(self.q_head(outputs))

Get $(\hat{\Pi}_i, \mu_{x,i}, \mu_{y,i}, \hat{\sigma}_{x,i}, \hat{\sigma}_{y,i}, \hat{\rho}_{xy,i})$. torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.

        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_distributions, 2)

Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$ where $\sigma_{x,i} = \exp(\hat{\sigma}_{x,i})$, $\sigma_{y,i} = \exp(\hat{\sigma}_{y,i})$, $\rho_{xy,i} = \tanh(\hat{\rho}_{xy,i})$ and $\Pi_i = \operatorname{softmax}(\hat{\Pi})_i = \frac{\exp(\hat{\Pi}_i)}{\sum_{j=1}^{M} \exp(\hat{\Pi}_j)}$.

$\Pi$ holds the categorical probabilities of choosing a distribution out of the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.

        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

        return dist, q_logits, state

Reconstruction Loss

class ReconstructionLoss(nn.Module):
    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$

        pi, mix = dist.get_distribution()

target has shape [seq_len, batch_size, 5], where the last dimension holds the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to get Δx, Δy and get the probabilities from each of the distributions in the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.

xy will have shape [seq_len, batch_size, n_distributions, 2].

        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

Calculate the probabilities $p(\Delta x, \Delta y) = \sum_{j=1}^M \Pi_j \mathcal{N} \big( \Delta x, \Delta y \vert \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j} \big)$

        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
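The mixture density can be written out explicitly using the standard bi-variate normal density. This pure-Python sketch evaluates $p(\Delta x, \Delta y)$ for a single point. The function names are illustrative and do not come from the implementation. The PyTorch code computes the same quantity for a whole batch through mix.log_prob.

```python
import math


def bivariate_normal_pdf(x, y, mu_x, mu_y, sigma_x, sigma_y, rho):
    """Density of a bi-variate normal with correlation rho at (x, y)."""
    dx = (x - mu_x) / sigma_x
    dy = (y - mu_y) / sigma_y
    z = dx * dx + dy * dy - 2 * rho * dx * dy
    norm = 2 * math.pi * sigma_x * sigma_y * math.sqrt(1 - rho * rho)
    return math.exp(-z / (2 * (1 - rho * rho))) / norm


def mixture_pdf(x, y, pi, components):
    """p(x, y) = sum_j pi_j * N(x, y | component_j); components are (mu_x, mu_y, sigma_x, sigma_y, rho)."""
    return sum(p * bivariate_normal_pdf(x, y, *c) for p, c in zip(pi, components))


pi = [0.3, 0.7]
components = [(0.0, 0.0, 1.0, 1.0, 0.0), (1.0, -1.0, 0.5, 2.0, 0.6)]
print(mixture_pdf(0.2, -0.5, pi, components))
```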

$L_s = - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \big( p(\Delta x, \Delta y) \big)$. Although probs has $N_{max}$ (longest_seq_len) elements, the sum is only taken up to $N_s$ because the rest is masked out.

It might feel like we should take the sum and divide by $N_s$ instead of $N_{max}$. However, that would give a higher weight to individual predictions in shorter sequences. Dividing by $N_{max}$ gives equal weight to each prediction $p(\Delta x, \Delta y)$.

        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
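This small NumPy sketch illustrates why dividing by $N_{max}$ rather than $N_s$ matters. It compares the weight each unmasked prediction receives under the two normalizations for a toy mask. The per-sequence alternative is shown only for contrast; the implementation does not use it.

```python
import numpy as np

# Mask of shape [N_max, batch]: the first sequence has N_s = 2 predictions, the second has N_s = 3
mask = np.array([[1, 1],
                 [1, 1],
                 [0, 1]], dtype=np.float64)
n_max, batch = mask.shape

# torch.mean(mask * log_p) over [N_max, batch]: every unmasked prediction gets 1 / (N_max * batch)
weights_n_max = mask / (n_max * batch)

# Alternative: average within each sequence first (divide by N_s), then over the batch
weights_n_s = mask / mask.sum(axis=0) / batch

print(weights_n_max)
print(weights_n_s)
```

Under $N_s$ normalization, each prediction of the shorter sequence weighs 1/4 instead of 1/6. That is the bias the $N_{max}$ normalization avoids.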

$L_p = - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log(q_{k,i})$

        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

$L_R = L_s + L_p$

        return loss_stroke + loss_pen

KL-Divergence loss

This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, I)$.

class KLDivLoss(nn.Module):
    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

$L_{KL} = - \frac{1}{2 N_z} \sum_{i=1}^{N_z} \big(1 + \hat{\sigma}_i - \mu_i^2 - \exp(\hat{\sigma}_i) \big)$

        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
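As a sanity check on $L_{KL}$, this pure-Python sketch compares the per-dimension term used in KLDivLoss with the textbook KL divergence between two 1-D Gaussians, taking $\hat{\sigma} = \log \sigma^2$. The function names are illustrative.

```python
import math


def kl_term(mu, sigma_hat):
    """Per-dimension term in KLDivLoss, with sigma_hat = log(sigma^2)."""
    return -0.5 * (1 + sigma_hat - mu ** 2 - math.exp(sigma_hat))


def kl_gaussians(mu1, sigma1, mu2, sigma2):
    """KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) for 1-D Gaussians."""
    return math.log(sigma2 / sigma1) + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2) - 0.5


mu, sigma = 0.7, 1.3
print(kl_term(mu, math.log(sigma ** 2)), kl_gaussians(mu, sigma, 0.0, 1.0))
```

KLDivLoss applies torch.mean over all elements, so it averages this term over both the $N_z$ latent dimensions and the batch.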

Sampler

This samples a sketch from the decoder and plots it.

class Sampler:
    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
        self.decoder = decoder
        self.encoder = encoder

    def sample(self, data: torch.Tensor, temperature: float):

$N_{max}$

        longest_seq_len = len(data)

Get $z$ from the encoder

        z, _, _ = self.encoder(data)

Start-of-sequence stroke is $(0, 0, 1, 0, 0)$

        s = data.new_tensor([0, 0, 1, 0, 0])
        seq = [s]

The initial decoder state is None. The decoder will initialize it to $[h_0; c_0] = \tanh(W_{z}z + b_z)$.

        state = None

We don't need gradients.

        with torch.no_grad():

Sample $N_{max}$ strokes

            for i in range(longest_seq_len):

$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder.

                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

Get $\Pi$, $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$, $q$ and the next state from the decoder.

                dist, q_logits, state = self.decoder(data, z, state)

Sample a stroke

                s = self._sample_step(dist, q_logits, temperature)

Add the new stroke to the sequence of strokes.

                seq.append(s)

Stop sampling if $p_3 = 1$. This indicates that sketching has stopped.

                if s[4] == 1:
                    break

Create a PyTorch tensor of the sequence of strokes.

        seq = torch.stack(seq)

Plot the sequence of strokes.

        self.plot(seq)

    @staticmethod
    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

Set temperature $\tau$ for sampling. This is implemented in class BivariateGaussianMixture.

        dist.set_temperature(temperature)

Get temperature-adjusted $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$

        pi, mix = dist.get_distribution()

Sample from $\Pi$ the index of the distribution to use from the mixture.

        idx = pi.sample()[0, 0]

Create a categorical distribution $q$ with log-probabilities q_logits or $\hat{q}$.

        q = torch.distributions.Categorical(logits=q_logits / temperature)

Sample from $q$

        q_idx = q.sample()[0, 0]

Sample from the normal distributions in the mixture and pick the one indexed by idx.

        xy = mix.sample()[0, 0, idx]

Create an empty stroke $(\Delta x, \Delta y, q_1, q_2, q_3)$

        stroke = q_logits.new_zeros(5)

Set $\Delta x, \Delta y$

        stroke[:2] = xy

Set $q_1, q_2, q_3$

        stroke[q_idx + 2] = 1

        return stroke
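Here is a pure-Python sketch of drawing one stroke from a tiny mixture, in the spirit of _sample_step. It is a hypothetical stand-alone version, not the implementation. It picks the component first and then samples only that component, which is equivalent in distribution to sampling every component and indexing with idx. It draws the correlated offsets through the lower-triangular factor from the scale_tril note. As in set_temperature, weights use logits divided by $\tau$ and standard deviations are scaled by $\sqrt{\tau}$.

```python
import math
import random


def softmax(logits, tau=1.0):
    """Temperature-scaled, numerically stable softmax."""
    scaled = [v / tau for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_stroke(pi_logits, mus, sigmas, rhos, q_logits, tau, rng):
    """Hypothetical helper: sample one stroke-5 vector (dx, dy, q1, q2, q3)."""
    # Choose a mixture component with temperature-adjusted weights
    k = rng.choices(range(len(pi_logits)), weights=softmax(pi_logits, tau))[0]
    mu_x, mu_y = mus[k]
    # Variance scales by tau, so the standard deviation scales by sqrt(tau)
    s_x, s_y = (s * math.sqrt(tau) for s in sigmas[k])
    rho = rhos[k]
    e1, e2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    # Correlated offsets via the factor [[s_x, 0], [rho * s_y, s_y * sqrt(1 - rho^2)]]
    dx = mu_x + s_x * e1
    dy = mu_y + s_y * (rho * e1 + math.sqrt(1 - rho ** 2) * e2)
    # Choose the pen state and write it as a one-hot (q1, q2, q3)
    pen = rng.choices(range(3), weights=softmax(q_logits, tau))[0]
    stroke = [dx, dy, 0.0, 0.0, 0.0]
    stroke[2 + pen] = 1.0
    return stroke


rng = random.Random(0)
stroke = sample_stroke(pi_logits=[0.0, 5.0],
                       mus=[(0.0, 0.0), (10.0, 10.0)],
                       sigmas=[(1.0, 1.0), (0.1, 0.1)],
                       rhos=[0.0, 0.5],
                       q_logits=[0.0, 0.0, 10.0],
                       tau=0.4,
                       rng=rng)
print(stroke)
```

The implementation divides the log-softmax output q_logits by τ rather than raw logits. This gives the same distribution, because the constant log-normalizer becomes a constant shift that the softmax cancels.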

    @staticmethod
    def plot(seq: torch.Tensor):

Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$.

        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

Create a new numpy array of the form $(x, y, q_2)$.

        seq[:, 2] = seq[:, 3]
        seq = seq[:, 0:3].detach().cpu().numpy()

Split the array at points where $q_2$ is 1, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of sequences of strokes.

        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)

Plot each sequence of strokes.

        for s in strokes:
            plt.plot(s[:, 0], -s[:, 1])

Don't show axes.

        plt.axis('off')

Show the plot.

        plt.show()
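The plotting logic reduces to a cumulative sum followed by a split at pen-up steps. This NumPy sketch applies the same steps to a toy stroke-5 sequence. The helper name is hypothetical, and the sketch returns the polylines instead of drawing them with matplotlib.

```python
import numpy as np


def to_polylines(strokes5: np.ndarray):
    """Hypothetical helper: turn stroke-5 rows into a list of (x, y, pen_up) polylines."""
    # Absolute positions from cumulative offsets
    xy = np.cumsum(strokes5[:, 0:2], axis=0)
    # Keep q2 (pen lifted after this point) as the third column
    points = np.concatenate([xy, strokes5[:, 3:4]], axis=1)
    # Start a new polyline right after every pen-up point
    return np.split(points, np.where(points[:, 2] > 0)[0] + 1)


strokes5 = np.array([[0, 0, 1, 0, 0],   # start-of-sequence
                     [1, 0, 1, 0, 0],
                     [1, 0, 0, 1, 0],   # pen is lifted after this point
                     [0, 1, 1, 0, 0],
                     [0, 1, 0, 0, 1]],  # end of drawing
                    dtype=np.float64)
for line in to_polylines(strokes5):
    print(line[:, :2].tolist())
```

Unlike this sketch, the implementation modifies the sampled tensor in place and negates the y coordinate when plotting.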

Configurations

These are default configurations, which can later be adjusted by passing a dict.

class Configs(TrainValidConfigs):

Device configurations to pick the device to run the experiment

    device: torch.device = DeviceConfigs()

    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset

Encoder and decoder sizes

    enc_hidden_size = 256
    dec_hidden_size = 512

Batch size

    batch_size = 100

Number of features in $z$

    d_z = 128

Number of distributions in the mixture, $M$

    n_distributions = 20

Weight of KL divergence loss, $w_{KL}$

    kl_div_loss_weight = 0.5

Gradient clipping

    grad_clip = 1.

Temperature $\tau$ for sampling

    temperature = 0.4

Filter out stroke sequences longer than 200

    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()

    def init(self):

Initialize encoder & decoder

        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

Set the optimizer. Things like the type of optimizer and the learning rate are configurable.

        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optimizer

Create sampler

        self.sampler = Sampler(self.encoder, self.decoder)

The npz file path is data/sketch/[DATASET NAME].npz

        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

Load the numpy file

        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

Create training dataset

        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

Create validation dataset

        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

Create training data loader

        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

Create validation data loader

        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

Configure the tracker to print the total train/validation loss

        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)

Move data and mask to the device and swap the sequence and batch dimensions. data will have shape [seq_len, batch_size, 5] and mask will have shape [seq_len, batch_size].

        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)

Increment step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Encode the sequence of strokes

        with monit.section("encoder"):

Get $z$, $\mu$, and $\hat{\sigma}$

            z, mu, sigma_hat = self.encoder(data)

Decode the mixture of distributions and $\hat{q}$

        with monit.section("decoder"):

Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$

            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
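The decoder is trained with teacher forcing. At every step it sees data[:-1], which begins with the start-of-sequence step, concatenated with $z$, and it is scored against data[1:]. Here is a NumPy shape sketch with toy sizes; all values are placeholders.

```python
import numpy as np

seq_len, batch_size, d_z = 7, 4, 3  # seq_len includes the start-of-sequence step
data = np.zeros((seq_len, batch_size, 5))       # [seq_len, batch_size, 5]
mask = np.ones((seq_len - 1, batch_size))       # one entry per decoder output
z = np.random.default_rng(0).standard_normal((batch_size, d_z))

# Repeat z along the time axis, like z.unsqueeze(0).expand(seq_len - 1, -1, -1)
z_stack = np.broadcast_to(z[None], (seq_len - 1, batch_size, d_z))
# Decoder input at step t is [(dx, dy, p1, p2, p3)_t ; z]
inputs = np.concatenate([data[:-1], z_stack], axis=2)
# Target at step t is the next stroke
targets = data[1:]
print(inputs.shape, targets.shape, mask.shape)
```

The targets and the mask line up step for step, which is why StrokesDataset allocates one fewer mask entry than data steps.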

Get the mixture of distributions and $\hat{q}$

            dist, q_logits, _ = self.decoder(inputs, z, None)

Compute the loss

        with monit.section('loss'):

$L_{KL}$

            kl_loss = self.kl_div_loss(sigma_hat, mu)

$L_R$

            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

$Loss = L_R + w_{KL} L_{KL}$

            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

Track losses

            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)

Only if we are in training state

        if self.mode.is_train:

Run optimizer

            with monit.section('optimize'):

Set grad to zero

                self.optimizer.zero_grad()

Compute gradients

                loss.backward()

Log model parameters and gradients

                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)

Clip gradients

                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

Optimize

                self.optimizer.step()

        tracker.save()

    def sample(self):

Randomly pick a sample from the validation dataset to encode

        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

Add the batch dimension and move it to the device

        data = data.unsqueeze(1).to(self.device)

Sample

        self.sampler.sample(data, self.temperature)

def main():
    configs = Configs()
    experiment.create(name="sketch_rnn")

Pass a dictionary of configurations

    experiment.configs(configs, {
        'optimizer.optimizer': 'Adam',

We use a learning rate of 1e-3 because it shows results faster. The paper suggested 1e-4.

        'optimizer.learning_rate': 1e-3,

Name of the dataset

        'dataset_name': 'bicycle',

Number of inner iterations within an epoch to switch between training, validation and sampling.

        'inner_iterations': 10
    })

    with experiment.start():

Run the experiment

        configs.run()


if __name__ == "__main__":
    main()
