Back to Mvision

TWN 三值系数网络

CNN/Deep_Compression/quantization/TWN/readme.md

latest5.6 KB
Original Source

TWN 三值系数网络

权值三值化的核心:
    首先,认为多权值相对比于二值化具有更好的网络泛化能力。
    其次,认为权值的分布接近于一个正态分布和一个均匀分布的组合。
    最后,使用一个 scale 参数去最小化三值化前的权值和三值化之后的权值的 L2 距离。   
    

caffe-代码

论文 Ternary weight networks

论文翻译参考

参考2

算法

这个算法的核心是只在前向和后向过程中使用使用权值简化,但是在update是仍然是使用连续的权值。

简单的说就是先利用公式计算出三值网络中的阈值:

也就是说,将每一层的权值绝对值求平均值乘以0.7算出一个deta作为三值网络离散权值的阈值,
具体的离散过程如下:

其实就是简单的选取一个阈值(Δ),
大于这个阈值的权值变成 1,小于-阈值的权值变成 -1,其他变成 0。
当然这个阈值其实是根据权值的分布的先验知识算出来的。
本文最核心的部分其实就是阈值和 scale 参数 alpha 的推导过程。

在参数三值化之后,作者使用了一个 scale 参数去让三值化之后的参数更接近于三值化之前的参数。
根据一个误差函数 推导出 alpha 再推导出 阈值(Δ)

这样,我们就可以把连续的权值变成离散的(1,0,-1),

那么,接下来我们还需要一个alpha参数,具体干什么用后面会说(增强表达能力)
这个参数的计算方式如下:

|I(deta)|这个参数指的是权值的绝对值大于deta的权值个数,计算出这个参数我们就可以简化前向计算了,
具体简化过程如下:

可以看到,在把alpha乘到前面以后,我们把复杂的乘法运算变成了简单的加法运算,从而加快了整个的训练速度。

主要思想就是三值化参数(激活量与梯度精度),参照BWN使用了缩放因子。
由于相同大小的filter,
三值化比二值化能蕴含更多的信息,
因此相比于BWN准确率有所提高。

this blog is useful for implementation

Experimental Results

Three data sets are used in this paper, including MNIST, CIFAR-10, ImageNet. To different data sets, the authors conducted experiments using LeNet-5 (32-C5 + MP2 + 64-C5 + MP2 + 512FC + SVM), VGG-inspired network (2$$\times$$(128-C3) + MP2 + 2$$\times$$(256-C3) + MP2 + 2$$\times$$(512-C3) + MP2 + 1024-FC + Softmax), ResNet-18, respectively. Network architecture and parameters setting for different data sets are shown as follows:

MNISTCIFAR-10ImageNet
network architectureLeNet-5VGG-7ResNet-18 (B)
weight decay1e-41e-41e-4
mini-batch size of BN5010064($$\times$$4 GPUs)
initial learning rate0.010.10.1
learning rate decay (divided by 10) epochs15, 2580, 12030, 40, 50
momentum0.90.90.9

Comparison of the proposed method and the previous methods are shown as follows:

MethodMINISTCIFAR-10ImageNet Top1 (ResNet-18 / ResNet-18B)ImageNet Top5 (ResNet-18 / ResNet-18B)
TWN99.3592.5661.8 / 65.384.2 / 86.2
BPWN99.0590.1857.5 / 61.681.2 / 83.9
FPWN (full precision)99.4192.8865.4 / 67.686.76 / 88.0
Binary Connect98.8291.73--
Binarized Neural Networks88.689.85--
Binary Weight Networks--60.883.0
XNOR-Net--51.273.2