Cache for Intermediate Activations

During inference the model generates its output token by token. We use this simple cache to store the keys and values of the attention layers, so that we don't have to recompute them for previous tokens.
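
As a rough illustration of the saving, the toy sketch below counts how often a per-token computation runs with and without reuse. The helper `project` is a hypothetical stand-in for the key/value projection of an attention layer and is not part of the source; real keys and values are tensors, not integers.

```python
# Toy illustration only: `project` is a hypothetical stand-in for the
# key/value projection of an attention layer.
calls = 0


def project(token: int) -> int:
    """Pretend key/value computation; counts how often it runs."""
    global calls
    calls += 1
    return token * 2


def decode_without_cache(tokens):
    # Every step recomputes the projections of all tokens seen so far.
    keys = []
    for step in range(1, len(tokens) + 1):
        keys = [project(t) for t in tokens[:step]]
    return keys


def decode_with_cache(tokens):
    # Every step computes only the newest token and reuses the rest.
    cached_keys = []
    for t in tokens:
        cached_keys.append(project(t))
    return cached_keys


tokens = [3, 1, 4, 1]

calls = 0
keys_a = decode_without_cache(tokens)
calls_without = calls  # 1 + 2 + 3 + 4 projections

calls = 0
keys_b = decode_with_cache(tokens)
calls_with = calls  # one projection per token
```

Both loops end with the same keys; the cached version simply avoids repeating work for earlier tokens.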

from typing import Any

Cache

This maintains a key-value cache and a set of named queues. Values pushed to a queue are popped in the same order (first in, first out). The queues are useful since we have multiple attention layers.
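
A minimal sketch of why first-in, first-out order suits multiple layers, assuming every layer shares one queue name (the name `'kv'` and the string payloads are made up for illustration):

```python
from collections import defaultdict

# Hypothetical stand-in for the queue part of the cache described above.
queues = defaultdict(list)

n_layers = 3

# Step 1: each layer, in order, pushes the state it computed.
for layer in range(n_layers):
    queues['kv'].append(f'state of layer {layer}')

# Step 2: each layer, in the same order, pops from the front of the queue
# and so receives exactly what it pushed in the previous step.
received = [queues['kv'].pop(0) for _ in range(n_layers)]
```

Because the layers always run in the same order, no per-layer key is needed to match a layer with its own state.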

class Cache:

    def __init__(self):
        self._cache = {}

Clear cache

    def clear_all(self):
        self._cache = {}

Push a value to a queue

  • name is the name of the queue
  • value is the value to be pushed

    def push(self, name: str, value: Any):

Create an empty queue if it's not present

        if name not in self._cache:
            self._cache[name] = []

Push to the queue

        self._cache[name].append(value)

Return the size of the queue

  • name is the name of the queue

Returns the size of the queue if it exists and is a list, else None

    def q_size(self, name):
        if name not in self._cache:
            return None

        if type(self._cache[name]) != list:
            return None

        return len(self._cache[name])
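
The three possible outcomes can be seen in the sketch below. It restates the check as a free function over a plain dict (an assumption made to keep the example self-contained); the names `'attn_kv'` and `'state_ids'` are hypothetical.

```python
def q_size(cache: dict, name: str):
    """Same checks as the method above, applied to a plain dict."""
    if name not in cache:
        return None
    if type(cache[name]) != list:
        return None
    return len(cache[name])


# Queues and plain cached values share one dict, so a name may hold a non-list.
cache = {'attn_kv': [], 'state_ids': 7}

missing = q_size(cache, 'unknown')        # no such name
not_a_queue = q_size(cache, 'state_ids')  # stored with set(), not push()
empty = q_size(cache, 'attn_kv')          # an existing but empty queue
```

Since an empty queue returns 0, which is falsy like None, callers that need to tell the cases apart should compare with `is None` rather than rely on truthiness.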

Pop from a queue

  • name is the name of the queue

Returns the value at the front of the queue

    def pop(self, name: str):
        return self._cache[name].pop(0)
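
Note that `list.pop(0)` shifts every remaining element, so it takes time linear in the queue length. With one entry per attention layer this is unlikely to matter. If it did, a variant backed by `collections.deque` could look like the sketch below; `DequeQueues` is a hypothetical class, not the source's implementation, and `'attn_kv'` is a made-up queue name.

```python
from collections import deque
from typing import Any, Dict


class DequeQueues:
    """Hypothetical queue-only variant backed by collections.deque."""

    def __init__(self):
        self._queues: Dict[str, deque] = {}

    def push(self, name: str, value: Any):
        # setdefault creates the queue on first use
        self._queues.setdefault(name, deque()).append(value)

    def pop(self, name: str):
        # popleft removes from the front in O(1); list.pop(0) is O(n)
        return self._queues[name].popleft()


queues = DequeQueues()
queues.push('attn_kv', 'first')
queues.push('attn_kv', 'second')
order = [queues.pop('attn_kv'), queues.pop('attn_kv')]

# Popping from an empty queue raises IndexError, as it does with a list.
try:
    queues.pop('attn_kv')
    raised = False
except IndexError:
    raised = True
```

In both versions, popping from a name that was never pushed to raises `KeyError`, so callers may want to check the queue size first.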

Cache a value

  • key is the name of the value to be cached
  • value is the value

    def set(self, key: str, value: Any):
        self._cache[key] = value

Retrieve a value from cache

  • key is the name used when caching
  • default is the value to return if key is not in the cache

Returns the cached value

    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)

Clear a cache value

  • key is the name used when caching

    def clear(self, key: str):
        del self._cache[key]
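
The set, get and clear operations are thin wrappers around a dict, so their behaviour can be sketched with a plain dict. The key `'seq_len'` is a hypothetical example name.

```python
store = {}

store['seq_len'] = 5             # what set('seq_len', 5) does
hit = store.get('seq_len', 0)    # key present: the cached value

del store['seq_len']             # what clear('seq_len') does
miss = store.get('seq_len', 0)   # key absent: the default

# Clearing a key that is not cached raises KeyError.
try:
    del store['seq_len']
    raised = False
except KeyError:
    raised = True
```

A caller that cannot be sure the key exists could guard the clear with a prior get, or catch the `KeyError`.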

Singleton for cache

_INSTANCE = None

Get the cache instance

Returns the cache instance

def get_cache() -> Cache:
    global _INSTANCE

    if _INSTANCE is None:
        _INSTANCE = Cache()

    return _INSTANCE
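
The effect of the singleton is that every caller sees the same state. The sketch below uses a trimmed copy of the class (only the methods needed here) so that it runs on its own; the key `'use_cache'` is a hypothetical example.

```python
from typing import Any


class Cache:
    """Trimmed copy of the class above, for a self-contained example."""

    def __init__(self):
        self._cache = {}

    def clear_all(self):
        self._cache = {}

    def set(self, key: str, value: Any):
        self._cache[key] = value

    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)


_INSTANCE = None


def get_cache() -> Cache:
    global _INSTANCE

    # Create the instance lazily on first use
    if _INSTANCE is None:
        _INSTANCE = Cache()

    return _INSTANCE


a = get_cache()
b = get_cache()

a.set('use_cache', True)
seen_by_b = b.get('use_cache')    # b is the same object as a

a.clear_all()
after_clear = b.get('use_cache')  # state is gone for every caller
```

Because the state is shared, it is probably wise to call `clear_all` before starting an unrelated generation, so that stale entries are not picked up.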
