Autoregressive Transformer Decoder in JAX from scratch

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/jax_transformer/__init__.py)


from functools import partial
from typing import Dict, NamedTuple, Tuple, Any, Callable
from typing import List, TypeVar, Generic
from typing import Union, Optional

import jax
import jax.numpy as jnp
import numpy as np

from labml import lab, monit, experiment, tracker
from labml import logger
from labml.logger import Text
from labml.utils.download import download_file

#

Module

This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.

You can skip these modules to get into the models directly.

The module stores parameters and sub-modules separately. When we want to transform any method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the class.

This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
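As a rough sketch of how this pattern is used (the ToyModule class and its field names are made up for illustration; they rely only on the Module class defined below):

class ToyModule(Module):
    def __init__(self):
        super().__init__()
        # A jnp array assigned to an attribute is routed into _params by __setattr__
        self.scale = jnp.ones((3,))

    def loss(self, x: jnp.ndarray):
        return (self.scale * x).sum()

toy = ToyModule()
params = toy.get_params()            # {'scale': Array([1., 1., 1.], ...)}
pure_loss = toy.purify(toy.loss)     # pure function of (params, x)
grads = jax.grad(pure_loss)(params, jnp.arange(3.))   # gradients w.r.t. the parameter dict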

class Module:

#

Store all parameters and sub-modules in dictionaries

    _submodules: Dict[str, 'Module']
    _params: Dict[str, jnp.ndarray]

#

Initialize

    def __init__(self):

#

        self._params = {}
        self._submodules = {}

#

Get attribute

We override attribute access, so when you reference an attribute with model.attribute this function gets called.

Read this guide if you are not familiar with Python magic methods.

    def __getattr__(self, attr_name: str):

#

If the attribute is a parameter

        if attr_name in self._params:
            return self._params[attr_name]

#

If the attribute is a sub-module

        elif attr_name in self._submodules:
            return self._submodules[attr_name]

#

Otherwise fallback to normal attributes. The attributes are stored in __dict__ by Python.

        else:
            return self.__dict__[attr_name]

#

Set attribute

We override attribute assignment, so when you assign an attribute with model.attribute this function gets called.

    def __setattr__(self, key: str, value: Any):

#

If the value is also a module

        if isinstance(value, Module):
            self._submodules[key] = value

#

If the value is a JAX array

        elif isinstance(value, jnp.ndarray):
            self._params[key] = value

#

Otherwise add it to __dict__

        else:
            self.__dict__[key] = value

#

Clear parameters

This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.

    def _clear_params(self):

#

Clear parameters of the module

        self._params = {}

#

Recursively clear parameters of submodules

        for sm in self._submodules.values():
            sm._clear_params()

#

Collect all the parameters

This recursively collects all the parameters of the module and sub-modules into a dictionary.

    def get_params(self) -> Dict[str, jnp.ndarray]:

#

Parameters of the model

        params = self._params.copy()

#

Parameters of the submodules

        for sm_name, sm in self._submodules.items():
            for name, value in sm.get_params().items():

#

The dictionary keys are of the form module_name/module_name/param_name

                params[sm_name + "/" + name] = value

#

        return params

#

Set all the parameters

    def _set_params(self, params: Dict[str, jnp.ndarray]):

#

Iterate through parameters. Their names have the form module_name/module_name/param_name

        for name, value in params.items():

#

Split to get module names and parameter name

            self._set_param(name.split("/"), value)

#

Set a single parameter

This is called by _set_params

    def _set_param(self, param_path: List[str], value: jnp.ndarray):

#

No module names; i.e. a parameter of this module

        if len(param_path) == 1:
            self._params[param_path[0]] = value

#

Parameter of a submodule

        else:
            self._submodules[param_path[0]]._set_param(param_path[1:], value)

#

Transform a member method to a pure function

This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.

For example,

params = model.get_params()
pure_function = model.purify(model.calculate_loss)
output = pure_function(params, data)
    def purify(self, method: Callable) -> Callable:

#

        def pure_method(params: Dict[str, jnp.array], *args):

#

Clear parameters in the object

            self._clear_params()

#

Assign the passed parameters

            self._set_params(params)

#

Invoke the method

            result = method(*args)

#

Return the result

            return result

#

        return pure_method

#

Type for generics in the module list class

M = TypeVar('M', bound=Module)

#

Module list

This stores a list of modules. We need it in the transformer decoder to hold the list of transformer layers.

class ModuleList(Module, Generic[M]):

#

For list of modules

    _submodules: List[M]

#

Initialize with a list of modules.

    def __init__(self, modules: List[M]):

#

        super().__init__()
        self._submodules = modules

#

Get the idx-th module

    def __getitem__(self, idx: int) -> M:

#

        return self._submodules[idx]

#

This is not supported

    def __setitem__(self, key, value):

#

        raise NotImplementedError

#

Number of modules

    def __len__(self):

#

        return len(self._submodules)

#

Override __getattr__ of Module

    def __getattr__(self, item):

#

        return self.__dict__[item]

#

Override __setattr__ of Module

    def __setattr__(self, key, value):

#

        self.__dict__[key] = value

#

Clear all parameters

    def _clear_params(self):

#

        self._params = {}
        for sm in self._submodules:
            sm._clear_params()

#

Get all parameters

    def get_params(self):

#

        params = self._params
        for i, sm in enumerate(self._submodules):
            for name, value in sm.get_params().items():
                params[f'{i}/{name}'] = value
        return params

#

Set a parameter

    def _set_param(self, param_path: List[str], value: jnp.ndarray):

#

        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)

#

Embedding layer

This maintains embeddings by id.

class Embedding(Module):

#

  • rnd_key is the PRNG state
  • n_embeddings is the number of embeddings
  • n_dim is the size of an embedding
    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):

#

        super().__init__()

#

Embeddings are initialized from N(0,1)

        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))

#

Return the embeddings for the given ids

    def __call__(self, ids: jnp.ndarray):

#

        return self.embeddings[ids, :]

#

Embed tokens and add parameterized positional encodings

This is based on our PyTorch implementation.

class EmbeddingsWithLearnedPositionalEncoding(Module):

#

  • rnd_key is the PRNG state
  • n_vocab is the vocabulary size
  • d_model is the embedding size
  • max_len is the maximum sequence length (to initialize positional encodings)
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):

#

        super().__init__()

#

Embeddings

        self.embeddings = Embedding(rnd_key, n_vocab, d_model)

#

Positional encodings coefficient $\frac{1}{\sqrt{d_{model}}}$

        self.pe_coef = 1 / d_model ** 0.5

#

Positional encodings initialized to zeros

        self.positional_encodings = jnp.zeros((max_len, d_model))

#

    def __call__(self, x: jnp.ndarray):

#

Get positional encodings

        pe = self.positional_encodings[:x.shape[0]]

#

Get embeddings and add positional encodings

        return self.embeddings(x) * self.pe_coef + pe

#

Linear Layer

This is a simple linear layer with a weight matrix and a bias vector

class Linear(Module):

#

  • rnd_key is the PRNG state
  • in_features is the number of features in the input
  • out_features is the number of features in the output
    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):

#

        super().__init__()

#

Initialize weights to $\mathcal{U}\Big(-\frac{1}{\sqrt{d_{in}}}, \frac{1}{\sqrt{d_{in}}}\Big)$

        rnd_range = 1 / in_features ** 0.5
        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
                                         minval=-rnd_range, maxval=rnd_range)

#

Initialize the biases to 0

        self.bias = jnp.zeros((out_features,))

#

    def __call__(self, x: jnp.ndarray):

#

Multiply by weights and add the bias

        return jnp.matmul(x, self.weight) + self.bias

#

Layer Normalization

This implements layer normalization from the paper Layer Normalization.

When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, we have $\gamma \in \mathbb{R}^{C}$, $\beta \in \mathbb{R}^{C}$ and

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\mathrm{Var}}[X] + \epsilon}} + \beta$$

This is based on our PyTorch implementation.

class LayerNorm(Module):

#

  • normalized_shape $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}$
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
  • elementwise_affine is whether to scale and shift the normalized value

    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):

#

        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_shape = tuple(normalized_shape)

#

Create parameters for γ and β for gain and bias

        if elementwise_affine:
            self.gain = jnp.ones(normalized_shape)
            self.bias = jnp.zeros(normalized_shape)

#

    def __call__(self, x: jnp.ndarray):

#

Sanity check to make sure the shapes match

        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

#

The axes to calculate the mean and variance on

        axes = [-(i + 1) for i in range(len(self.normalized_shape))]

#

Calculate the mean of all elements; i.e. the means for each element $\mathbb{E}[X]$

        mean = x.mean(axis=axes, keepdims=True)

#

Calculate the squared mean of all elements; i.e. the means for each element $\mathbb{E}[X^2]$

        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)

#

Variance of all elements $\mathrm{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

        var = mean_2 - mean ** 2

#

Normalize $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}}$

        x_norm = (x - mean) / (var + self.eps) ** 0.5

#

Scale and shift $\text{LN}(x) = \gamma \hat{X} + \beta$

        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

#

        return x_norm
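A quick sanity check of the normalization, using the class above (the shapes and values here are arbitrary):

ln = LayerNorm([8])
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 5.0 + 2.0
y = ln(x)
# Each row of y should now have roughly zero mean and unit variance
print(y.mean(axis=-1), y.var(axis=-1))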

#

Multi-Head Attention Module

This computes scaled multi-headed attention from the paper Attention Is All You Need for given query, key and value vectors.

$$\mathrm{Attention}(Q, K, V) = \underset{seq}{\mathrm{softmax}}\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)V$$

In simple terms, it finds keys that match the query, and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time) for the keys.

This is based on our PyTorch implementation.

class MultiHeadAttention(Module):

#

  • rnd_key is the PRNG state
  • heads is the number of heads.
  • d_model is the number of features in the query , key and value vectors.
    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):

#

        super().__init__()

#

Split the PRNG state

        _, *rnd_keys = jax.random.split(rnd_key, 5)

#

Number of features per head

        self.d_k = d_model // heads

#

Number of heads

        self.heads = heads

#

These transform the query , key and value vectors for multi-headed attention.

        self.query = Linear(rnd_keys[0], d_model, d_model)
        self.key = Linear(rnd_keys[1], d_model, d_model)
        self.value = Linear(rnd_keys[2], d_model, d_model)

#

Output layer

        self.output = Linear(rnd_keys[3], d_model, d_model)

#

Scaling factor before the softmax

        self.scale = 1 / self.d_k ** 0.5

#

query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, d_model].

mask has shape [seq_len, seq_len] and mask[i, j] indicates whether the query at position i can see the key-value at position j.

    def __call__(self, *,
                 query: jnp.ndarray,
                 key: jnp.ndarray,
                 value: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None):

#

Get sequence length

        seq_len = len(query)

        if mask is not None:

#

Check mask shape

            assert mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]

#

Same mask applied to all heads.

            mask = mask[:, :, None]

#

Apply linear transformations

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

#

Reshape to split into heads. The input has shape [seq_len, d_model]; we split the last dimension into heads and d_k.

        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)

#

Compute attention scores $QK^\top$. This gives a tensor of shape [seq_len, seq_len, heads], with $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$

        scores = jnp.einsum('ihd,jhd->ijh', query, key)

#

Scale scores $\frac{QK^\top}{\sqrt{d_k}}$

        scores *= self.scale

#

Apply mask

        if mask is not None:
            scores = scores + (mask == 0) * float('-inf')

#

softmax attention along the key sequence dimension $\underset{seq}{\mathrm{softmax}}\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)$

        attn = jax.nn.softmax(scores, axis=1)

#

Multiply by values $\underset{seq}{\mathrm{softmax}}\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)V$

        x = jnp.einsum("ijh,jhd->ihd", attn, value)

#

Concatenate multiple heads

        x = x.reshape(seq_len, -1)

#

Output layer

        return self.output(x)
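A small usage sketch of this module with a causal mask, mirroring how the full model below uses it (the sizes are arbitrary):

rnd_key = jax.random.PRNGKey(0)
mha = MultiHeadAttention(rnd_key, heads=4, d_model=32)
x = jax.random.normal(rnd_key, (10, 32))       # [seq_len, d_model]
mask = jnp.tril(jnp.ones((10, 10), bool))      # mask[i, j] is True when position i may attend to j
out = mha(query=x, key=x, value=x, mask=mask)  # [seq_len, d_model]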

#

Position-wise Feed-Forward layer

This is based on our PyTorch implementation.

class FeedForward(Module):

#

  • rnd_key is the PRNG state
  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • activation is the activation function f
    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
                 activation=jax.nn.relu):

#

        super().__init__()

#

Split the PRNG state

        _, *rnd_keys = jax.random.split(rnd_key, 5)

#

Layer one parameterized by weight $W_1$ and bias $b_1$

        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)

#

Layer two parameterized by weight $W_2$ and bias $b_2$

        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)

#

Activation function f

        self.activation = activation

#

    def __call__(self, x: jnp.ndarray):

#

$f(xW_1 + b_1)$

        x = self.activation(self.layer1(x))

#

$f(xW_1 + b_1)W_2 + b_2$

        return self.layer2(x)

#

Transformer Layer

This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization.

class TransformerLayer(Module):

#

  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
    def __init__(self,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward):

#

        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm_self_attn = LayerNorm([d_model])
        self.norm_ff = LayerNorm([d_model])

#

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):

#

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

#

Run through self attention, i.e. keys and values are from self

        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self_attn

#

Normalize for feed-forward

        z = self.norm_ff(x)

#

Pass through the feed-forward network

        ff = self.feed_forward(z)

#

Add the feed-forward results

        x = x + ff

#

        return x
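In equations, the pre-norm residual structure of this layer is

$$x \leftarrow x + \mathrm{SelfAttn}(\mathrm{LN}(x)), \qquad x \leftarrow x + \mathrm{FFN}(\mathrm{LN}(x))$$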

#

Cross Entropy Loss

class CrossEntropyLoss(Module):

#

    def __init__(self):
        super().__init__()

#

Use jax.vmap to vectorize the loss function

        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))

#

    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):

#

$-\sum_k y_k \log \hat{y}_k$

        return -jax.nn.log_softmax(output)[target]

#

  • output is the model outputs of shape [seq_len, n_vocab]
  • target is the target of shape [seq_len]
    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):

#

Use the vectorized loss function and calculate the mean.

We could have used a for loop to calculate the losses but using vmap is about 10X faster

        return self._loss_vmap(output, target).mean()
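To make the comparison concrete, here is a minimal sketch of the same per-position loss written once with an explicit Python loop and once with jax.vmap (the logits and targets are made-up data):

logits = jax.random.normal(jax.random.PRNGKey(0), (32, 65))    # [seq_len, n_vocab]
targets = jnp.zeros(32, dtype=jnp.int32)                       # [seq_len]

# Explicit loop: one small computation per position
loop_loss = jnp.stack([-jax.nn.log_softmax(logits[i])[targets[i]] for i in range(32)]).mean()

# Vectorized: a single batched computation
vmap_loss = jax.vmap(lambda o, t: -jax.nn.log_softmax(o)[t])(logits, targets).mean()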

#

Autoregressive Transformer

This is the transformer decoder with embedding and output layers.

class AutoregressiveTransformer(Module):

#

    layers: ModuleList[TransformerLayer]

#

  • rnd_key is the PRNG state
  • n_vocab is the vocabulary size
  • d_model is the number of features in a token embedding
  • n_layers is the number of transformer layers
  • heads is the number of attention heads
  • d_ff is the number of features in the hidden layer of the FFN
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):

#

        super().__init__()
        self.n_vocab = n_vocab
        self.d_model = d_model
        self.loss_func = CrossEntropyLoss()

#

For transformer layers

        layers = []
        for i in range(n_layers):

#

Split PRNG state

            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)

#

Create a transformer layer

            attn = MultiHeadAttention(mha_key, heads, d_model)
            ffn = FeedForward(ffn_key, d_model, d_ff)
            layers.append(TransformerLayer(d_model, attn, ffn))

#

Make a module list

        self.layers = ModuleList(layers)

#

Split PRNG state

        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)

#

Create embedding layer

        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)

#

Final normalization and output layer

        self.norm = LayerNorm([d_model])
        self.output = Linear(out_key, d_model, n_vocab)

#

    def __call__(self, x: jnp.ndarray):

#

Get sequence length

        seq_len = len(x)

#

A mask for attention so that a token can only see tokens before that

        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))

#

Get embeddings with positional encodings

        x = self.embeddings(x)

#

Apply the transformer layers

        for i in range(len(self.layers)):
            x = self.layers[i](x, mask)

#

Final normalization and linear transformation to get the logits

        return self.output(self.norm(x))

#

Calculate the loss

    def get_loss(self, x: jnp.ndarray):

#

Get model outputs

        output = self(x)

#

Cross entropy loss

        return self.loss_func(output[:-1], x[1:])
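The slicing shifts the targets by one position: the prediction at position i is compared against the token at position i + 1. A tiny illustration (the token ids are made up):

x = jnp.array([5, 7, 2, 9])   # one training sample
# output[:-1] are the predictions made at positions 0, 1 and 2;
# x[1:] == [7, 2, 9] are the next-token targets they are compared against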

#

Sample

The starting sequence is given by seq and we greedily sample length tokens

    def sample(self, seq: jnp.ndarray, length: int = 20):

#

        for i in range(length):

#

Sample the highest probability token

            idx = jnp.argmax(self(seq)[-1])

#

Add it to the sequence

            seq = jnp.concatenate((seq, idx[None]))

#

Return the sampled sequence

        return seq

#

This is a named tuple for storing Adam optimizer state for a parameter

class AdamState(NamedTuple):

#

    m: jnp.ndarray
    v: jnp.ndarray

#

Adam Optimizer

This is from paper Adam: A Method for Stochastic Optimization.

For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is,

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters; $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are their bias-corrected versions. $\epsilon$ is used as a fix for division by zero, but it also acts as a hyper-parameter that works against variance in the gradients.

class Adam:

#

  • params is the tree-map of parameters
  • lr is the learning rate α
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$

    def __init__(self, params: Dict,
                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16):

#

        super().__init__()
        self.lr = lr
        self.betas = betas
        self.eps = eps

#

States for each parameter

        self.states = jax.tree.map(self._init_state, params)

#

Optimized step function

        self._step_jit = jax.jit(self._step)

#

Number of steps taken t

        self._n_steps = 0

#

Optimized update state function

        self._update_state_jit = jax.jit(self._update_state)

#

Initialize the state for a given parameter

    def _init_state(self, param: jnp.ndarray):

#

        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))

#

Step function

  • params is a tree-map of parameters
  • grads is a tree-map of gradients
    def step(self, params: Dict, grads: Dict):

#

Increment step t

        self._n_steps += 1

#

Update states for each parameter

        self.states = jax.tree.map(self._update_state_jit, grads, self.states)

#

Return updated parameters θt​

        return jax.tree.map(partial(self._step_jit, self._n_steps), params, self.states)

#

Update parameters

This performs an Adam update on the given parameter

    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):

#

Bias corrections for $\hat{m}_t$: $1 - \beta_1^t$, and for $\hat{v}_t$: $1 - \beta_2^t$

        bias_correction = [1 - beta ** n_steps for beta in self.betas]

#

Uncorrected first and second moments $m_t$ and $v_t$

        m, v = state

#

$\alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]

#

$\sqrt{v_t} + \hat{\epsilon}$

        den = (v ** 0.5) + self.eps

#

$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

        return param - step_size * m / den

#

Update state

This updates the uncorrected first and second moments $m_t$ and $v_t$

    def _update_state(self, grad, state: AdamState):

#

Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$

        m, v = state

#

Clip gradients

        grad = jnp.clip(grad, -1, 1)

#

$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

        m = self.betas[0] * m + grad * (1 - self.betas[0])

#

$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$

        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])

#

Return the new state

        return AdamState(m, v)
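As a quick sanity check of this optimizer outside the transformer (the quadratic objective and the parameter name 'w' are made up for illustration):

def objective(params):
    return ((params['w'] - 3.0) ** 2).sum()

params = {'w': jnp.zeros(4)}
optimizer = Adam(params, lr=0.1)
grad_fn = jax.jit(jax.grad(objective))

for _ in range(100):
    params = optimizer.step(params, grad_fn(params))
# params['w'] should now be close to 3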

#

Tiny Shakespeare dataset

class TinyShakespeare:

#

  • rnd_key is the PRNG state
  • seq_len is the sequence length of a sample
  • batch_size is the batch size
    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):

#

        self.batch_size = batch_size

#

PRNG key for shuffling the samples

        _, self.rnd_key = jax.random.split(rnd_key)

#

Local path of the text file

        path = lab.get_data_path() / 'tiny_shakespeare.txt'

#

Download if it doesn't exist

        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        if not path.exists():
            download_file(url, path)

#

Read the file

        with open(str(path), 'r') as f:
            self.text = f.read()

#

Get the characters/tokens

        tokens = sorted(list(set(self.text)))

#

Number of tokens

        self.n_tokens = len(tokens)

#

Map tokens to ids

        self.stoi = {t: i for i, t in enumerate(tokens)}

#

Id to token/character

        self.itos = tokens

#

As a list of ids

        data = jnp.array([self.stoi[s] for s in list(self.text)])

#

Number of batches

        self.n_batches = len(data) // (seq_len * batch_size)

#

Truncate

        data = data[:self.n_batches * seq_len * batch_size]

#

Reshape into samples (it would be better to use random offsets, but let's ignore that here)

        self.data = data.reshape((-1, seq_len))

#

List of sample indexes

        self.idx = jnp.arange(len(self.data))

#

Setup for iteration

    def __iter__(self):

#

Iteration step

        self._iter_idx = 0

#

Split PRNG key

        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)

#

Shuffle sample indexes

        self.idx = jax.random.permutation(rnd_key, self.idx)

#

        return self

#

Number of batches

    def __len__(self):

#

        return self.n_batches

#

Get next batch

    def __next__(self):

#

Stop iteration after iterating through all batches

        if self._iter_idx >= self.n_batches:
            raise StopIteration()

#

Sample indexes for the batch

        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]

#

Increment iteration step

        self._iter_idx += 1

#

Return samples

        return self.data[idx]

#

Run the experiment

def main():

#

Create experiment

    experiment.create(name='jax')

#

Create PRNG key

    rnd_key = jax.random.PRNGKey(0)

#

Create dataset

    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)

#

Create the model

    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
                                      d_model=128, n_layers=3, heads=8, d_ff=512)

#

Get model parameters

    params = model.get_params()

#

JAX compiled pure sampling function

    pure_sample_fn = jax.jit(model.purify(model.sample))

#

JAX compiled pure function to get logits for a batch. First we transform model.__call__ to a pure function which accepts two arguments: the parameters and the input sequence. Next we vectorize the function to process a batch of samples. in_axes specifies which arguments to parallelize and along which axis; (None, 0) means we have the same parameters but parallelize the inputs across the first axis. out_axes specifies along which axis to merge the results.

    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
                                       in_axes=(None, 0), out_axes=0))
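As a small standalone illustration of the in_axes semantics (not part of the training code; scale_fn is made up):

def scale_fn(w, x):
    return w * x

# w is shared across the batch (None); x is split along axis 0
batched_scale = jax.vmap(scale_fn, in_axes=(None, 0), out_axes=0)
out = batched_scale(2.0, jnp.arange(6.).reshape(3, 2))   # shape [3, 2]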

#

Similarly we vectorize loss computation

    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
                                    in_axes=(None, 0), out_axes=0))

#

A function to get mean loss

    def get_loss(params, seq):
        return pure_loss_fn(params, seq).mean()

#

A function to compute gradients for the first argument (parameters)

    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))

#

Create optimizer

    optimizer = Adam(params)

#

Start the experiment

    with experiment.start():

#

Iterate for 32 epochs

        for epoch in monit.loop(32):

#

Iterate through batches

            for data in monit.iterate('Train', dataset):

#

Compute and log the loss

                loss = get_loss(params, data)
                tracker.save('loss', np.asarray(loss))

#

Get the gradients

                grads = grad_loss_fn(params, data)

#

Update parameters

                params = optimizer.step(params, grads)

#

            tracker.new_line()

#

Log a sample after each epoch

            prompt = [dataset.stoi[c] for c in 'It ']
            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
            sampled = ''.join([dataset.itos[i] for i in sampled])
            sampled = sampled.replace('\n', '\\n')
            logger.log(('It ', Text.meta), (sampled, Text.value))

#

if __name__ == '__main__':
    main()
