# Autoregressive Transformer in JAX
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/jax_transformer/__init__.py)
```python
from functools import partial
from typing import Dict, NamedTuple, Tuple, Any, Callable
from typing import List, TypeVar, Generic
from typing import Union, Optional

import jax
import jax.numpy as jnp
import numpy as np

from labml import lab, monit, experiment, tracker
from labml import logger
from labml.logger import Text
from labml.utils.download import download_file
```
This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.
You can skip these modules to get into the models directly.
The module stores parameters and sub-modules separately. When we want to transform any method to a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the class.
This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
```python
class Module:
```
Store all parameters and sub-modules in dictionaries
```python
    _submodules: Dict[str, 'Module']
    _params: Dict[str, jnp.ndarray]
```
Initialize
```python
    def __init__(self):
        self._params = {}
        self._submodules = {}
```
We override the get-attribute operation, so when you reference an attribute with `model.attribute`, this function gets called.
Read this guide if you are not familiar with Python magic methods.
```python
    def __getattr__(self, attr_name: str):
```
If the attribute is a parameter
```python
        if attr_name in self._params:
            return self._params[attr_name]
```
If the attribute is a sub-module
```python
        elif attr_name in self._submodules:
            return self._submodules[attr_name]
```
Otherwise fall back to normal attributes, which Python stores in `__dict__`.
```python
        else:
            return self.__dict__[attr_name]
```
We override the set-attribute operation, so when you assign to an attribute with `model.attribute = ...`, this function gets called.
```python
    def __setattr__(self, key: str, value: Any):
```
If the value is also a module
```python
        if isinstance(value, Module):
            self._submodules[key] = value
```
If the value is a JAX array
```python
        elif isinstance(value, jnp.ndarray):
            self._params[key] = value
```
Otherwise add it to `__dict__`
```python
        else:
            self.__dict__[key] = value
```
This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.
```python
    def _clear_params(self):
```
Clear parameters of the module
```python
        self._params = {}
```
Recursively clear parameters of submodules
```python
        for sm in self._submodules.values():
            sm._clear_params()
```
This recursively collects all the parameters of the module and sub-modules into a dictionary.
```python
    def get_params(self) -> Dict[str, jnp.ndarray]:
```
Parameters of the model
```python
        params = self._params.copy()
```
Parameters of the submodules
```python
        for sm_name, sm in self._submodules.items():
            for name, value in sm.get_params().items():
```
The dictionary keys are of the form `module_name/module_name/param_name`
```python
                params[sm_name + "/" + name] = value

        return params
```
```python
    def _set_params(self, params: Dict[str, jnp.ndarray]):
```
Iterate through the parameters. Their names have the form `module_name/module_name/param_name`
```python
        for name, value in params.items():
```
Split to get module names and parameter name
```python
            self._set_param(name.split("/"), value)
```
This is called by `_set_params`
```python
    def _set_param(self, param_path: List[str], value: jnp.ndarray):
```
No module names; i.e. a parameter of this module
```python
        if len(param_path) == 1:
            self._params[param_path[0]] = value
```
Parameter of a submodule
```python
        else:
            self._submodules[param_path[0]]._set_param(param_path[1:], value)
```
This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.
For example,
```python
params = model.get_params()
pure_function = model.purify(model.calculate_loss)
output = pure_function(params, data)
```
```python
    def purify(self, method: Callable) -> Callable:
        def pure_method(params: Dict[str, jnp.array], *args):
```
Clear parameters in the object
```python
            self._clear_params()
```
Assign the passed parameters
```python
            self._set_params(params)
```
Invoke the method
```python
            result = method(*args)
```
Return the result
```python
            return result

        return pure_method
```
Type for generics in the module list class
```python
M = TypeVar('M', bound=Module)
```
This stores a list of modules. We need it for the transformer decoder to hold the list of transformer layers.
```python
class ModuleList(Module, Generic[M]):
```
For list of modules
```python
    _submodules: List[M]
```
Initialize with a list of modules.
```python
    def __init__(self, modules: List[M]):
        super().__init__()
        self._submodules = modules
```
Get the `idx`-th module

```python
    def __getitem__(self, idx: int) -> M:
```
```python
        return self._submodules[idx]
```
This is not supported
```python
    def __setitem__(self, key, value):
        raise NotImplementedError

    def __len__(self):
        return len(self._submodules)
```
Override `__getattr__` of `Module`
```python
    def __getattr__(self, item):
        return self.__dict__[item]
```
Override `__setattr__` of `Module`
```python
    def __setattr__(self, key, value):
        self.__dict__[key] = value
```
```python
    def _clear_params(self):
        self._params = {}
        for sm in self._submodules:
            sm._clear_params()
```
```python
    def get_params(self):
        params = self._params
        for i, sm in enumerate(self._submodules):
            for name, value in sm.get_params().items():
                params[f'{i}/{name}'] = value
        return params
```
```python
    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)
```
This maintains embeddings by id.
```python
class Embedding(Module):
```
* `rnd_key` is the PRNG state
* `n_embeddings` is the number of embeddings
* `n_dim` is the size of an embedding

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
```
```python
        super().__init__()
```
Embeddings are initialized from $\mathcal{N}(0, 1)$
```python
        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))
```
Return the embeddings for the given ids
```python
    def __call__(self, ids: jnp.ndarray):
        return self.embeddings[ids, :]
```
This is based on our PyTorch implementation.
```python
class EmbeddingsWithLearnedPositionalEncoding(Module):
```
* `rnd_key` is the PRNG state
* `n_vocab` is the vocabulary size
* `d_model` is the embedding size
* `max_len` is the maximum sequence length (to initialize positional encodings)

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
```
```python
        super().__init__()
```
Embeddings
```python
        self.embeddings = Embedding(rnd_key, n_vocab, d_model)
```
Positional encodings coefficient $\frac{1}{\sqrt{d_{model}}}$
```python
        self.pe_coef = 1 / d_model ** 0.5
```
Positional encodings initialized to zeros
```python
        self.positional_encodings = jnp.zeros((max_len, d_model))
```
```python
    def __call__(self, x: jnp.ndarray):
```
Get positional encodings
```python
        pe = self.positional_encodings[:x.shape[0]]
```
Get embeddings and add positional encodings
```python
        return self.embeddings(x) * self.pe_coef + pe
```
This is a simple linear layer with a weight matrix and a bias vector
```python
class Linear(Module):
```
* `rnd_key` is the PRNG state
* `in_features` is the number of features in the input
* `out_features` is the number of features in the output

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
```
```python
        super().__init__()
```
Initialize weights to $\mathcal{U}\left(-\frac{1}{\sqrt{d_{in}}}, \frac{1}{\sqrt{d_{in}}}\right)$
```python
        rnd_range = 1 / in_features ** 0.5
        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
                                         minval=-rnd_range, maxval=rnd_range)
```
Initialize the biases to 0
```python
        self.bias = jnp.zeros((out_features,))
```
```python
    def __call__(self, x: jnp.ndarray):
```
Multiply by weights and add the bias
```python
        return jnp.matmul(x, self.weight) + self.bias
```
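A short usage sketch (assuming the `Linear` class above):

```python
rnd_key = jax.random.PRNGKey(0)
layer = Linear(rnd_key, in_features=4, out_features=8)
x = jnp.ones((3, 4))               # 3 tokens with 4 features each
print(layer(x).shape)              # (3, 8)
print(sorted(layer.get_params()))  # ['bias', 'weight']
```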
This implements the layer normalization from the paper Layer Normalization.

When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\text{Var}}[X] + \epsilon}} + \beta$$
This is based on our PyTorch implementation.
```python
class LayerNorm(Module):
```
* `normalized_shape` $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}$
* `eps` is $\epsilon$, used in $\sqrt{\text{Var}[X] + \epsilon}$ for numerical stability
* `elementwise_affine` is whether to scale and shift the normalized value

```python
    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):
```
```python
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_shape = tuple(normalized_shape)
```
Create parameters for $\gamma$ and $\beta$ for gain and bias
```python
        if elementwise_affine:
            self.gain = jnp.ones(normalized_shape)
            self.bias = jnp.zeros(normalized_shape)
```
```python
    def __call__(self, x: jnp.ndarray):
```
Sanity check to make sure the shapes match
```python
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
```
The axes to calculate the mean and variance on
```python
        axes = [-(i + 1) for i in range(len(self.normalized_shape))]
```
Calculate the mean of all elements; i.e. the means for each element $\mathbb{E}[X]$
```python
        mean = x.mean(axis=axes, keepdims=True)
```
Calculate the squared mean of all elements; i.e. the means for each element $\mathbb{E}[X^2]$
```python
        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)
```
Variance of all elements $\text{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
```python
        var = mean_2 - mean ** 2
```
Normalize $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\text{Var}[X] + \epsilon}}$
```python
        x_norm = (x - mean) / (var + self.eps) ** 0.5
```
Scale and shift $\text{LN}(x) = \gamma \hat{X} + \beta$
```python
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias
```
```python
        return x_norm
```
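A quick numerical check (a sketch, assuming the `LayerNorm` class above): after normalization, each position should have mean close to 0 and variance close to 1 along the normalized axes:

```python
ln = LayerNorm([8])
x = jax.random.normal(jax.random.PRNGKey(0), (5, 8)) * 3.0 + 2.0
y = ln(x)
print(y.mean(axis=-1))   # ~0 for every position
print(y.var(axis=-1))    # ~1 for every position
```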
This computes scaled multi-headed attention from the paper Attention Is All You Need for given `query`, `key` and `value` vectors.

$$\text{Attention}(Q, K, V) = \underset{seq}{\text{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

In simple terms, it finds keys that match the query and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time) for keys.
This is based on our PyTorch implementation.
```python
class MultiHeadAttention(Module):
```
* `rnd_key` is the PRNG state
* `heads` is the number of heads
* `d_model` is the number of features in the `query`, `key` and `value` vectors

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
```
```python
        super().__init__()
```
Split the PRNG state
```python
        _, *rnd_keys = jax.random.split(rnd_key, 5)
```
Number of features per head
```python
        self.d_k = d_model // heads
```
Number of heads
```python
        self.heads = heads
```
These transform the `query`, `key` and `value` vectors for multi-headed attention.
```python
        self.query = Linear(rnd_keys[0], d_model, d_model)
        self.key = Linear(rnd_keys[1], d_model, d_model)
        self.value = Linear(rnd_keys[2], d_model, d_model)
```
Output layer
```python
        self.output = Linear(rnd_keys[3], d_model, d_model)
```
Scaling factor before the softmax
```python
        self.scale = 1 / self.d_k ** 0.5
```
`query`, `key` and `value` are the tensors that store a collection of query, key and value vectors. They have shape `[seq_len, d_model]`.

`mask` has shape `[seq_len, seq_len]` and `mask[i, j]` indicates whether query at position `i` can see key-value at position `j`.
```python
    def __call__(self, *,
                 query: jnp.ndarray,
                 key: jnp.ndarray,
                 value: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None):
```
Get sequence length
```python
        seq_len = len(query)

        if mask is not None:
```
Check mask shape
```python
            assert mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
```
Same mask applied to all heads.
```python
            mask = mask[:, :, None]
```
Apply linear transformations
```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```
Reshape to split into heads. The input has shape `[seq_len, d_model]`; we split the last dimension into `heads` and `d_k`.
```python
        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
```
Compute attention scores $QK^\top$. This gives a tensor of shape `[seq_len, seq_len, heads]`. $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
```python
        scores = jnp.einsum('ihd,jhd->ijh', query, key)
```
Scale scores $\frac{QK^\top}{\sqrt{d_k}}$
```python
        scores *= self.scale
```
Apply mask
```python
        if mask is not None:
            # Fill masked-out positions with -inf before the softmax.
            # Multiplying the boolean mask by -inf would give 0 * -inf = NaN
            # at the positions that are allowed to attend.
            scores = jnp.where(mask == 0, float('-inf'), scores)
```
Softmax attention along the key sequence dimension $\underset{seq}{\text{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$
```python
        attn = jax.nn.softmax(scores, axis=1)
```
Multiply by values $\underset{seq}{\text{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$
```python
        x = jnp.einsum("ijh,jhd->ihd", attn, value)
```
Concatenate multiple heads
```python
        x = x.reshape(seq_len, -1)
```
Output layer
```python
        return self.output(x)
```
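A short usage sketch with a causal mask (assuming the classes above), mirroring how `AutoregressiveTransformer` calls this later:

```python
rnd_key = jax.random.PRNGKey(0)
mha = MultiHeadAttention(rnd_key, heads=4, d_model=16)
x = jax.random.normal(rnd_key, (10, 16))   # [seq_len, d_model]
# Causal mask: position i can only attend to positions j <= i
mask = jnp.tril(jnp.ones((10, 10), bool))
out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)                            # (10, 16)
```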
This is based on our PyTorch implementation.
```python
class FeedForward(Module):
```
* `rnd_key` is the PRNG state
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN
* `activation` is the activation function $f$

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
                 activation=jax.nn.relu):
```
```python
        super().__init__()
```
Split the PRNG state
```python
        _, *rnd_keys = jax.random.split(rnd_key, 5)
```
Layer one parameterized by weight $W_1$ and bias $b_1$
```python
        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)
```
Layer two parameterized by weight $W_2$ and bias $b_2$
```python
        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)
```
Activation function $f$
```python
        self.activation = activation
```
```python
    def __call__(self, x: jnp.ndarray):
```
$f(xW_1 + b_1)$
```python
        x = self.activation(self.layer1(x))
```
$f(xW_1 + b_1)W_2 + b_2$
```python
        return self.layer2(x)
```
This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization.
```python
class TransformerLayer(Module):
```
* `d_model` is the token embedding size
* `self_attn` is the self attention module
* `feed_forward` is the feed forward module

```python
    def __init__(self,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward):
```
```python
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm_self_attn = LayerNorm([d_model])
        self.norm_ff = LayerNorm([d_model])
```
```python
    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
```
Normalize the vectors before doing self attention
```python
        z = self.norm_self_attn(x)
```
Run through self attention, i.e. keys and values are from self
```python
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self_attn
```
Normalize for feed-forward
```python
        z = self.norm_ff(x)
```
Pass through the feed-forward network
```python
        ff = self.feed_forward(z)
```
Add the feed-forward results
```python
        x = x + ff

        return x
```
```python
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
```
Use `jax.vmap` to vectorize the loss function
```python
        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))

    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
```
$-\sum_k y_k \log \hat{y}_k$
```python
        return -jax.nn.log_softmax(output)[target]
```
* `output` is the model outputs of shape `[seq_len, n_vocab]`
* `target` is the target of shape `[seq_len]`

```python
    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
```
Use the vectorized loss function and calculate the mean.
We could have used a for loop to calculate the losses, but using `vmap` is about 10X faster
```python
        return self._loss_vmap(output, target).mean()
```
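As a standalone illustration (a sketch with made-up logits, not from the original), the vectorized loss matches an explicit per-token loop:

```python
loss_fn = CrossEntropyLoss()
logits = jnp.array([[2.0, 0.0, 0.0],
                    [0.0, 3.0, 0.0]])   # [seq_len=2, n_vocab=3]
targets = jnp.array([0, 1])
print(loss_fn(logits, targets))

# The equivalent explicit loop (much slower for long sequences)
per_token = jnp.stack([-jax.nn.log_softmax(o)[t] for o, t in zip(logits, targets)])
print(per_token.mean())                  # same value
```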
This is the transformer decoder with embedding and output layers.
```python
class AutoregressiveTransformer(Module):
    layers: ModuleList[TransformerLayer]
```
* `rnd_key` is the PRNG state
* `n_vocab` is the vocabulary size
* `d_model` is the number of features in a token embedding
* `n_layers` is the number of transformer layers
* `heads` is the number of attention heads
* `d_ff` is the number of features in the hidden layer of the FFN

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
```
```python
        super().__init__()
        self.n_vocab = n_vocab
        self.d_model = d_model
        self.loss_func = CrossEntropyLoss()
```
For transformer layers
```python
        layers = []
        for i in range(n_layers):
```
Split PRNG state
```python
            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)
```
Create a transformer layer
```python
            attn = MultiHeadAttention(mha_key, heads, d_model)
            ffn = FeedForward(ffn_key, d_model, d_ff)
            layers.append(TransformerLayer(d_model, attn, ffn))
```
Make a module list
```python
        self.layers = ModuleList(layers)
```
Split PRNG state
```python
        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
```
Create embedding layer
```python
        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
```
Final normalization and output layer
```python
        self.norm = LayerNorm([d_model])
        self.output = Linear(out_key, d_model, n_vocab)
```
```python
    def __call__(self, x: jnp.ndarray):
```
Get sequence length
```python
        seq_len = len(x)
```
A mask for attention so that a token can only see tokens before it
```python
        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))
```
Get embeddings with positional encodings
```python
        x = self.embeddings(x)
```
Apply the transformer layers
```python
        for i in range(len(self.layers)):
            x = self.layers[i](x, mask)
```
Final normalization and linear transformation to get the logits
```python
        return self.output(self.norm(x))

    def get_loss(self, x: jnp.ndarray):
```
Get model outputs
```python
        output = self(x)
```
Cross entropy loss
```python
        return self.loss_func(output[:-1], x[1:])
```
The starting sequence is given by `seq` and we greedily sample `length` tokens
```python
    def sample(self, seq: jnp.ndarray, length: int = 20):
        for i in range(length):
```
Sample the highest probability token
```python
            idx = jnp.argmax(self(seq)[-1])
```
Add it to the sequence
```python
            seq = jnp.concatenate((seq, idx[None]))
```
Return the sampled sequence
```python
        return seq
```
This is a named tuple for storing Adam optimizer state for a parameter
```python
class AdamState(NamedTuple):
    m: jnp.ndarray
    v: jnp.ndarray
```
This is from the paper Adam: A Method for Stochastic Optimization.
For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is,

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are first and second order moments. $\hat{m}_t$ and $\hat{v}_t$ are bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a hyper-parameter that works against variance in gradients.
```python
class Adam:
```
* `params` is the tree-map of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of $(\beta_1, \beta_2)$
* `eps` is $\hat{\epsilon}$

```python
    def __init__(self, params: Dict,
                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16):
```
```python
        super().__init__()
        self.lr = lr
        self.betas = betas
        self.eps = eps
```
States for each parameter
```python
        self.states = jax.tree.map(self._init_state, params)
```
Optimized step function
```python
        self._step_jit = jax.jit(self._step)
```
Number of steps taken $t$
```python
        self._n_steps = 0
```
Optimized update state function
```python
        self._update_state_jit = jax.jit(self._update_state)
```
Initialize the state for a given parameter
```python
    def _init_state(self, param: jnp.ndarray):
        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))
```
* `params` is a tree-map of parameters
* `grads` is a tree-map of gradients

```python
    def step(self, params: Dict, grads: Dict):
```
Increment step $t$
```python
        self._n_steps += 1
```
Update states for each parameter
```python
        self.states = jax.tree.map(self._update_state_jit, grads, self.states)
```
Return updated parameters $\theta_t$
```python
        return jax.tree.map(partial(self._step_jit, self._n_steps), params, self.states)
```
This performs an Adam update on the given parameter
```python
    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
```
Bias corrections for $\hat{m}_t$: $1 - \beta_1^t$, and for $\hat{v}_t$: $1 - \beta_2^t$
```python
        bias_correction = [1 - beta ** n_steps for beta in self.betas]
```
Uncorrected first and second moments $m_t$ and $v_t$
```python
        m, v = state
```
$\alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$
```python
        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]
```
$\sqrt{v_t} + \hat{\epsilon}$
```python
        den = (v ** 0.5) + self.eps
```
$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
```python
        return param - step_size * m / den
```
This updates the uncorrected first and second moments $m_t$ and $v_t$
```python
    def _update_state(self, grad, state: AdamState):
```
Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$
```python
        m, v = state
```
Clip gradients
```python
        grad = jnp.clip(grad, -1, 1)
```
$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$
```python
        m = self.betas[0] * m + grad * (1 - self.betas[0])
```
$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$
```python
        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])
```
Return the new state
```python
        return AdamState(m, v)
```
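A minimal sketch of the optimizer on a hypothetical quadratic objective (the `objective` function below is made up for illustration); it uses the same tree-map interface as the training loop later:

```python
def objective(params: Dict[str, jnp.ndarray]):
    # Minimize (w - 3)^2 element-wise
    return ((params['w'] - 3.0) ** 2).sum()

params = {'w': jnp.zeros(2)}
optimizer = Adam(params, lr=0.1)
grad_fn = jax.jit(jax.grad(objective))

for _ in range(100):
    params = optimizer.step(params, grad_fn(params))
print(params['w'])   # close to [3. 3.]
```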
```python
class TinyShakespeare:
```
* `rnd_key` is the PRNG state
* `seq_len` is the sequence length of a sample
* `batch_size` is the batch size

```python
    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
```
```python
        self.batch_size = batch_size
```
PRNG key for shuffling the samples
```python
        _, self.rnd_key = jax.random.split(rnd_key)
```
Local path of the text file
```python
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
```
Download if it doesn't exist
```python
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        if not path.exists():
            download_file(url, path)
```
Read the file
```python
        with open(str(path), 'r') as f:
            self.text = f.read()
```
Get the characters/tokens
```python
        tokens = sorted(list(set(self.text)))
```
Number of tokens
```python
        self.n_tokens = len(tokens)
```
Map tokens to ids
```python
        self.stoi = {t: i for i, t in enumerate(tokens)}
```
Id to token/character
```python
        self.itos = tokens
```
As a list of ids
```python
        data = jnp.array([self.stoi[s] for s in list(self.text)])
```
Number of batches
```python
        self.n_batches = len(data) // (seq_len * batch_size)
```
Truncate
```python
        data = data[:self.n_batches * seq_len * batch_size]
```
Reshape into samples (better to use random offsets, but let's ignore that here)
```python
        self.data = data.reshape((-1, seq_len))
```
List of sample indexes
```python
        self.idx = jnp.arange(len(self.data))
```
Setup for iteration
```python
    def __iter__(self):
```
Iteration step
```python
        self._iter_idx = 0
```
Split PRNG key
```python
        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
```
Shuffle sample indexes
```python
        self.idx = jax.random.permutation(rnd_key, self.idx)

        return self
```
Number of batches
```python
    def __len__(self):
        return self.n_batches
```
Get next batch
```python
    def __next__(self):
```
Stop iteration after iterating through all batches
```python
        if self._iter_idx >= self.n_batches:
            raise StopIteration()
```
Sample indexes for the batch
```python
        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]
```
Increment iteration step
```python
        self._iter_idx += 1
```
Return samples
```python
        return self.data[idx]
```
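For example, a quick sketch of iterating the dataset (this downloads the text file on first use, assuming labml can resolve a data path):

```python
dataset = TinyShakespeare(jax.random.PRNGKey(0), seq_len=32, batch_size=128)
print(dataset.n_tokens)      # number of distinct characters
for batch in dataset:        # batches are reshuffled every epoch
    print(batch.shape)       # (128, 32)
    break
```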
```python
def main():
```
Create experiment
```python
    experiment.create(name='jax')
```
Create PRNG key
```python
    rnd_key = jax.random.PRNGKey(0)
```
Create dataset
```python
    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)
```
Create the model
```python
    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
                                      d_model=128, n_layers=3, heads=8, d_ff=512)
```
Get model parameters
```python
    params = model.get_params()
```
JAX compiled pure sampling function
```python
    pure_sample_fn = jax.jit(model.purify(model.sample))
```
JAX compiled pure function to get logits for a batch. First we transform `model.__call__` to a pure function which accepts two arguments: parameters and the input sequence. Next we vectorize the function to process a batch of samples. `in_axes` specifies which arguments to parallelize and along which axis; `(None, 0)` means we have the same parameters but parallelize the inputs across the first axis. `out_axes` specifies along which axis to merge the results.
```python
    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
                                       in_axes=(None, 0), out_axes=0))
```
Similarly we vectorize loss computation
```python
    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
                                    in_axes=(None, 0), out_axes=0))
```
A function to get mean loss
```python
    def get_loss(params, seq):
        return pure_loss_fn(params, seq).mean()
```
A function to compute gradients for the first argument (parameters)
```python
    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))
```
Create optimizer
```python
    optimizer = Adam(params)
```
Start the experiment
```python
    with experiment.start():
```
Iterate for 32 epochs
```python
        for epoch in monit.loop(32):
```
Iterate through batches
```python
            for data in monit.iterate('Train', dataset):
```
Compute and log the loss
```python
                loss = get_loss(params, data)
                tracker.save('loss', np.asarray(loss))
```
Get the gradients
```python
                grads = grad_loss_fn(params, data)
```
Update parameters
```python
                params = optimizer.step(params, grads)

            tracker.new_line()
```
Log a sample after each epoch
```python
            prompt = [dataset.stoi[c] for c in 'It ']
            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
            sampled = ''.join([dataset.itos[i] for i in sampled])
            sampled = sampled.replace('\n', '\\n')
            logger.log(('It ', Text.meta), (sampled, Text.value))
```
```python
if __name__ == '__main__':
    main()
```