
LLM.int8() on GPT-NeoX


#

LLM.int8() on GPT-NeoX

This implements a utility function to transform an nn.Linear layer into an LLM.int8() linear layer.

The LLM.int8() paper shows that you can use int8 quantization, while handling outliers separately, to reduce the memory footprint of large language models without degrading performance. Weights and inputs are converted to scaled 8-bit integers, and the matrix multiplication produces int32 results, which are then converted back to float16 and rescaled. The paper shows that in large language models, some features can take extreme values (outliers) that dominate the model's output. These features get clamped in 8-bit integer space, which causes model performance to degrade. As a solution, they pick out these outliers (values greater than a specified threshold) and compute their multiplications separately in float16 space. Since the percentage of outliers is around 0.01%, this doesn't meaningfully increase memory usage, and it prevents the performance degradation.
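The mixed-precision decomposition described above can be sketched in plain PyTorch. This is an illustrative simplification, not the bitsandbytes kernels: the function name `int8_matmul_with_outliers` and the per-row/per-column absmax scaling scheme are assumptions made for the sketch.

```python
import torch

def int8_matmul_with_outliers(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0):
    # Illustrative sketch of LLM.int8() mixed-precision matmul (not the real kernels).
    # Columns of x with any value exceeding the threshold are treated as outliers
    outlier_cols = (x.abs() > threshold).any(dim=0)
    # Outlier features: multiply in float space, as in the paper
    out_fp = x[:, outlier_cols] @ w[outlier_cols, :]
    # Regular features: absmax-quantize to int8 (per row of x, per column of w)
    x_r = x[:, ~outlier_cols]
    w_r = w[~outlier_cols, :]
    sx = x_r.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    sw = w_r.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0
    xq = torch.round(x_r / sx).to(torch.int8)
    wq = torch.round(w_r / sw).to(torch.int8)
    # int8 x int8 accumulated in int32, then rescaled back to float
    out_int = (xq.to(torch.int32) @ wq.to(torch.int32)).to(torch.float32) * sx * sw
    return out_int + out_fp

torch.manual_seed(0)
x = torch.randn(4, 8)
x[0, 0] = 10.0              # plant an outlier feature
w = torch.randn(8, 3) * 0.1
out = int8_matmul_with_outliers(x, w)
```

Because the outlier column is handled in float, the result stays close to the full-precision `x @ w` even though `x[0, 0]` would have been badly clamped in int8.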

The code to transform GPT-NeoX layers is defined in model.py.

Here are example uses of GPT-NeoX with int8 quantization.


#

Import bitsandbytes package

```python
try:
    from bitsandbytes.nn import Linear8bitLt, Int8Params
except ImportError:
    raise ImportError('''Please install `bitsandbytes` with `pip install bitsandbytes -U`''')

import torch
from torch import nn
```

#

Transform a nn.Linear layer to LLM.int8() linear layer

  • linear_module is the nn.Linear layer to transform
  • device is the device of the model
  • threshold is the threshold α to use for outlier detection
```python
def make_llm_int8_linear(linear_module: nn.Linear, device: torch.device, threshold: float = 6.0):
```

#

```python
    assert isinstance(linear_module, nn.Linear)
```

#

Create an empty Linear8bitLt module

```python
    int8_lin = Linear8bitLt(
        linear_module.in_features,
        linear_module.out_features,
        linear_module.bias is not None,
        has_fp16_weights=False,
        threshold=threshold,
    )
```

#

Quantize the weights

```python
    int8_lin._parameters['weight'] = Int8Params(linear_module.weight.data.cpu(),
                                                requires_grad=False,
                                                has_fp16_weights=False).to(device)
```

#

Set the bias in float16 space

```python
    if linear_module.bias is not None:
        int8_lin._parameters['bias'] = nn.Parameter(linear_module.bias.data,
                                                    requires_grad=False)
```

#

```python
    return int8_lin
```
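A transform like the one above would typically be applied to every linear layer in a model, as model.py does for GPT-NeoX. Here is a generic sketch of such a replacement pass; `replace_linear_layers` is not part of this codebase, and an identity stand-in transform is used so the snippet runs without bitsandbytes (with bitsandbytes installed you would pass `make_llm_int8_linear` instead).

```python
import torch
from torch import nn

def replace_linear_layers(module: nn.Module, make_linear, device: torch.device):
    # Recursively replace every nn.Linear in `module` using `make_linear`
    # (a hypothetical helper, shown here for illustration only)
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_linear(child, device))
        else:
            replace_linear_layers(child, make_linear, device)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Identity stand-in; swap in make_llm_int8_linear to quantize for real
replace_linear_layers(model, lambda lin, dev: lin.to(dev), torch.device('cpu'))
```

Replacing layers in place like this keeps the rest of the model untouched, so only the matrix multiplications change precision.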

labml.ai