Back to Annotated Deep Learning Paper Implementations

schedule.py

docs/helpers/schedule.html

latest · 2.5 KB
Original Source

home › helpers

View code on Github

#

from typing import List, Optional, Tuple

#

class Schedule:
    """
    Base class for schedules.

    Calling an instance with ``x`` yields the scheduled value at ``x``.
    This base class provides no value itself; subclasses override
    ``__call__``.
    """

    def __call__(self, x):
        """Return the value at ``x``; must be overridden by subclasses."""
        raise NotImplementedError

#

class Flat(Schedule):
    """Constant schedule: every call returns the ``value`` given at construction."""

    def __init__(self, value):
        """
        :param value: the value returned for every ``x``.
        """
        self.__value = value

    def __call__(self, x):
        # ``x`` is not consulted: the schedule never changes.
        return self.__value

    def __str__(self):
        # Rendered as ``Schedule(<value>)``.
        return f"Schedule({self.__value})"

#

class Dynamic(Schedule):
    """
    Schedule whose value is set from outside via ``update``.

    The value returned does not depend on ``x``; it is whatever was most
    recently passed to the constructor or to ``update``.
    """

    def __init__(self, value):
        """
        :param value: initial value returned by calls.
        """
        self.__value = value

    def update(self, value):
        """Replace the value returned by subsequent calls."""
        self.__value = value

    def __call__(self, x):
        # ``x`` is unused; only the stored value matters.
        return self.__value

    def __str__(self):
        return "Dynamic"

#

Piecewise schedule

class Piecewise(Schedule):
    """
    Piecewise-linear schedule.

    Defined by a list of ``(x, y)`` endpoints. Values between consecutive
    endpoints are linearly interpolated; for ``x`` outside the range covered
    by the endpoints, ``outside_value`` is returned.
    """

    def __init__(self, endpoints: List[Tuple[float, float]], outside_value: Optional[float] = None):
        """
        Initialize.

        :param endpoints: ``(x, y)`` pairs sorted by ``x``. Consecutive pairs
            with equal ``x`` are allowed and produce a step.
        :param outside_value: value returned for ``x`` outside the range
            covered by ``endpoints``.
        """
        # Take a private copy: this accepts any iterable (a generator would
        # otherwise be exhausted by the check below) and stops later caller
        # mutation from silently breaking the sorted invariant.
        endpoints = list(endpoints)

        # (x, y) pairs should be sorted by x
        indexes = [e[0] for e in endpoints]
        assert indexes == sorted(indexes), "endpoints must be sorted by x"

        self._outside_value = outside_value
        self._endpoints = endpoints

    def __call__(self, x):
        """Find y for given x."""
        # iterate through each half-open segment [x1, x2)
        for (x1, y1), (x2, y2) in zip(self._endpoints[:-1], self._endpoints[1:]):
            # interpolate if x is within the segment
            if x1 <= x < x2:
                dx = float(x - x1) / (x2 - x1)
                return y1 + dx * (y2 - y1)

        # Segments exclude their right end, so the final endpoint is never
        # matched above even though it lies inside the covered range (this
        # also covers a single-endpoint schedule).
        if self._endpoints and x == self._endpoints[-1][0]:
            return self._endpoints[-1][1]

        # return outside value otherwise
        return self._outside_value

    def __str__(self):
        endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
        return f"Schedule[{endpoints}, {self._outside_value}]"

#

class RelativePiecewise(Piecewise):
    """
    Piecewise schedule whose endpoint ``x`` values are fractions of ``total_steps``.

    Each fraction is multiplied by ``total_steps`` and truncated with ``int``
    to get an absolute step index. The resulting schedule's outside value is
    the ``y`` of the last endpoint.
    """

    def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int):
        """
        :param relative_endpoits: ``(fraction, y)`` pairs. The parameter
            name's spelling is kept so keyword callers keep working.
        :param total_steps: number of steps the fractions are relative to.
        """
        absolute_endpoints = []
        for pair in relative_endpoits:
            step = int(total_steps * pair[0])
            assert step >= 0
            absolute_endpoints.append((step, pair[1]))

        last_y = relative_endpoits[-1][1]
        super().__init__(absolute_endpoints, outside_value=last_y)

labml.ai