This is an implementation of a 2-dimensional convolution layer with Weight Standardization.
import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_nn.normalization.weight_standardization import weight_standardization
This extends the standard 2D convolution layer and standardizes the weights before the convolution step.
class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 eps: float = 1e-5):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=bias,
                                     padding_mode=padding_mode)
        self.eps = eps
    def forward(self, x: torch.Tensor):
        return F.conv2d(x, weight_standardization(self.weight, self.eps), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
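The `weight_standardization` helper imported at the top is not shown on this page. As a rough sketch only, assuming it normalizes each output channel's weights to zero mean and unit variance across the input-channel and kernel dimensions, it could look something like the following. The name `standardize_weight_sketch` is hypothetical and is not part of `labml_nn`.

```python
import torch


def standardize_weight_sketch(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical stand-in for the imported helper; the real one may differ.
    # Assumes `weight` has shape [out_channels, in_channels, kernel_h, kernel_w].
    c_out = weight.shape[0]
    # Flatten everything except the output-channel dimension
    flat = weight.reshape(c_out, -1)
    # Per-output-channel variance and mean over the flattened dimensions
    var, mean = torch.var_mean(flat, dim=1, keepdim=True)
    # Standardize; eps guards against division by zero
    flat = (flat - mean) / torch.sqrt(var + eps)
    # Restore the original weight shape
    return flat.reshape(weight.shape)


w = torch.randn(20, 10, 5, 5)
w_hat = standardize_weight_sketch(w)
per_channel_mean = w_hat.reshape(20, -1).mean(dim=1)
per_channel_var = w_hat.reshape(20, -1).var(dim=1)
```

Under this assumption, the stored parameter `self.weight` is left untouched and the standardized copy is recomputed on every forward pass, so gradients flow through the standardization.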
A simple test to verify the tensor sizes
def _test():
    conv2d = Conv2d(10, 20, 5)
    from labml.logger import inspect
    inspect(conv2d.weight)
    import torch
    inspect(conv2d(torch.zeros(10, 10, 100, 100)))


if __name__ == '__main__':
    _test()
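To make the expected sizes in this test concrete, here is a small sketch of the standard output-size formula for a 2D convolution along one spatial dimension. The helper name `conv2d_output_size` is introduced here for illustration only and does not appear in the original code.

```python
def conv2d_output_size(size: int, kernel_size: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    # Output length along one spatial dimension of a 2D convolution
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The test uses a 5x5 kernel on 100x100 inputs with default stride and padding
out_h = conv2d_output_size(100, 5)
out_w = conv2d_output_size(100, 5)
# Batch size 10 and 20 output channels come from the test's input and layer
expected_shape = (10, 20, out_h, out_w)
```

With these settings the weight tensor has shape `[20, 10, 5, 5]` and the output should have shape `[10, 20, 96, 96]`, since weight standardization changes the weight values but not their shape.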