scaling_laws.ipynb
Reproducing some scaling-law results from Chinchilla. I can't get the numbers to match exactly, but this can still be used as a rough guide for choosing compute-optimal models. The notebook also contains related utilities for calculating FLOPs and parameter counts.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline
First some parameter calculations:
def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):
    """Given a GPT config, calculate the total number of parameters.

    Counts every transformer block (fused qkv projection, attention output
    projection, MLP up/down projections, all with biases, plus two
    layernorms), the final layernorm, and the bias-free output projection
    ``dense`` of size ``d_model * vocab_size``.

    Separate embedding tables are NOT added on top: neither a token table
    (``vocab_size * d_model``) nor the learned position table
    (``seq_len * d_model``). With weight tying, as in GPT-2, the token
    embedding shares the ``dense`` matrix and is therefore counted exactly
    once. The position embeddings are excluded.

    Args:
        seq_len: context length. It only affects the excluded position table,
            so it does not change the result.
        vocab_size: number of tokens in the vocabulary.
        d_model: width of the residual stream.
        num_heads: number of attention heads. Unused, because splitting the qkv
            projection into heads does not change the parameter count.
        num_layers: number of transformer blocks.

    Returns:
        Total parameter count (an int for int inputs).
    """
    ffw_size = 4 * d_model  # in GPT the number of intermediate features is always 4*d_model
    # per transformer block, weights + biases
    attention = 3 * d_model**2 + 3 * d_model  # fused query/key/value projection
    attproj = d_model**2 + d_model  # attention output projection
    ffw = d_model * ffw_size + ffw_size  # MLP up-projection
    ffwproj = ffw_size * d_model + d_model  # MLP down-projection
    layernorms = 2 * 2 * d_model  # two layernorms, each with a gain and a bias vector
    block = attention + attproj + ffw + ffwproj + layernorms
    # head: final layernorm + output projection
    ln_f = 2 * d_model
    dense = d_model * vocab_size  # note: no bias here
    return num_layers * block + ln_f + dense
# GPT-2 (small) hyperparameters; the result below is in millions of parameters
gpt2 = {
    'seq_len': 1024,
    'vocab_size': 50257,
    'd_model': 768,
    'num_heads': 12,
    'num_layers': 12,
}
gpt_params(**gpt2) / 1e6
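OpenAI reports gpt2 (small) as having 124M params, so this is a match. The count above (~123.65M) leaves out the 1024×768 position-embedding table, which accounts for the remaining ~0.79M up to the full ~124.44M. Also, loading the OpenAI weights into nanoGPT and counting `model.parameters()` minus the position embeddings (as nanoGPT's `get_num_params` does) exactly matches the above number, which verifies the implementation. Now Chinchilla parameters: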
def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):
    """Parameters in the Chinchilla models.

    Unlike GPT, Chinchilla uses relative positional embeddings, so each block
    carries extra relative-position parameters and there is no absolute
    position table.

    The count covers every transformer block (qkv projection, relative-position
    parameters, attention output projection, feed-forward up/down projections,
    two layernorms), the final layernorm, and the bias-free output projection
    ``dense`` of size ``d_model * vocab_size``. A separate input token
    embedding table (another ``d_model * vocab_size``) is NOT added.

    Args:
        seq_len: unused (relative positions need no per-position table).
        vocab_size: number of tokens in the vocabulary.
        d_model: width of the residual stream.
        num_heads: unused; splitting into heads does not change the count.
        num_layers: number of transformer blocks.
        ffw_size: number of hidden units in the feed-forward block.

    Returns:
        Total parameter count (an int for int inputs).
    """
    # per transformer block, weights + biases
    attention = 3 * d_model**2 + 3 * d_model  # query/key/value projections
    relative_pos = d_model**2 + 2 * d_model  # relative keys, content bias, relative bias
    attproj = d_model**2 + d_model  # attention output projection
    ffw = d_model * ffw_size + ffw_size  # feed-forward up-projection
    ffwproj = ffw_size * d_model + d_model  # feed-forward down-projection
    layernorms = 2 * 2 * d_model  # two layernorms, each with a gain and a bias vector
    block = attention + relative_pos + attproj + ffw + ffwproj + layernorms
    # head: final layernorm + output projection
    ln_f = 2 * d_model
    dense = d_model * vocab_size  # note: no bias here
    return num_layers * block + ln_f + dense
# Load in all the 50 Chinchilla models on the last page of the paper (Table A9).
# Each entry is [params, d_model, ffw_size, kv_size, n_heads, n_layers].
# The JSON text is built by implicit string concatenation so that no single-quoted
# literal has to span a line break (which would be a SyntaxError).
import json
chinchilla_models_txt = (
    '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], '
    '[74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], '
    '[106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], '
    '[140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], '
    '[175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], '
    '[217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], '
    '[278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], '
    '[425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], '
    '[509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], '
    '[587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], '
    '[664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], '
    '[816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], '
    '[1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], '
    '[1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], '
    '[1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], '
    '[1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], '
    '[1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], '
    '[2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], '
    '[2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], '
    '[3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], '
    '[4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], '
    '[6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], '
    '[11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], '
    '[12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], '
    '[14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'
)
chilchilla_models = json.loads(chinchilla_models_txt)  # all 50 models
assert len(chilchilla_models) == 50, len(chilchilla_models)
chilchilla_models[0]  # tuples of params, d_model, ffw_size, kv_size, n_heads, n_layers from Table A9
for m in chilchilla_models[-5:]:  # only print last 5 models of the table
    p, d, f, k, h, l = m
    nparams = chinchilla_params(seq_len=1024, vocab_size=32000, d_model=d, num_heads=h, num_layers=l, ffw_size=f)
    print(f"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}")
We are almost able to reproduce the parameter counts for the Chinchilla models.
Now turning to FLOPs:
def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size,
                     *, include_embeddings_and_logits=False):
    """
    Calculate the total (forward + backward) FLOPs for one sequence of
    ``seq_len`` tokens. See Chinchilla paper Appendix F as reference:
    https://arxiv.org/pdf/2203.15556.pdf

    Args:
        seq_len: number of tokens in the sequence.
        vocab_size: vocabulary size. Only used when embeddings/logits are counted.
        d_model: width of the residual stream.
        num_heads: number of attention heads. key_size is derived as
            d_model // num_heads.
        num_layers: number of transformer blocks.
        ffw_size: number of hidden units in the feed-forward block.
        include_embeddings_and_logits: if True, also count the embedding and
            output-logit FLOPs, as Appendix F is written. Defaults to False
            because, per author correspondence, there is a typo in the paper:
            Table A4 was produced without those terms, and the default
            reproduces that table.

    Returns:
        Total FLOPs, with the backward pass counted as 2x the forward pass.
    """
    key_size = d_model // num_heads
    attn_width = key_size * num_heads  # combined width of all heads
    # attention
    attention = 2 * 3 * seq_len * d_model * attn_width  # key, query, value projections
    attlogits = 2 * seq_len * seq_len * attn_width  # key @ query logits
    attsoftmax = 3 * num_heads * seq_len * seq_len  # 3* is for subtract (max), exp, divide (?)
    attvalue = 2 * seq_len * seq_len * attn_width  # softmax @ value reductions
    attlinear = 2 * seq_len * attn_width * d_model  # final linear
    att = attention + attlogits + attsoftmax + attvalue + attlinear
    # feed forward
    dense = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)
    forward_flops = num_layers * (att + dense)
    if include_embeddings_and_logits:
        # what you'd expect from Appendix F: embeddings + layers + logits
        embeddings = 2 * seq_len * vocab_size * d_model
        logits = 2 * seq_len * d_model * vocab_size
        forward_flops = embeddings + forward_flops + logits
    backward_flops = 2 * forward_flops  # as in Kaplan et al. 2020
    return forward_flops + backward_flops
# Now try to reproduce Table A4 from the Chinchilla paper appendix,
# comparing the accurate flops above to the approximation F = 6*N*D.
# Note: Chinchilla mentions using vocab_size = 32K.
# Columns: num_layers, d_model, ffw_size, num_heads, kv_size
chilchilla_models_table4 = [
    [10, 640, 2560, 10, 64],
    [20, 1024, 4096, 16, 64],
    [24, 1280, 5120, 10, 128],
    [26, 1792, 7168, 14, 128],
    [28, 2048, 8192, 16, 128],
    [40, 3584, 14336, 28, 128],
]
rows = []
for num_layers, d_model, ffw_size, num_heads, kv_size in chilchilla_models_table4:
    # chinchilla_flops derives key_size = d_model // num_heads rather than taking the
    # table's kv_size column, so check that the two agree for every row
    assert d_model // num_heads == kv_size, (d_model, num_heads, kv_size)
    args = dict(seq_len=2048, vocab_size=32000, d_model=d_model,
                num_heads=num_heads, num_layers=num_layers, ffw_size=ffw_size)
    D = args['seq_len']  # dataset size (cancels anyway, for the purposes of the ratio calculation below)
    N = chinchilla_params(**args)
    F = chinchilla_flops(**args)  # flops for one sequence of seq_len tokens
    approx_flops = 6 * D * N  # approximate flops
    chinch_flops = F * (float(D) / args['seq_len'])  # exact flops according to Chinchilla paper calculations
    # the config columns first, then the calculated values
    rows.append({
        **args,
        'N': N,
        'F': F,
        'approx_flops': approx_flops,
        'chinch_flops': chinch_flops,
        'ratio': chinch_flops / approx_flops,
    })
# make a pandas dataframe from rows
df = pd.DataFrame(rows)
df
Pretty good match! Except the param counts are still not perfectly accurate.
In their "Approach 3", the Chinchilla paper fits a parametric function $L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$ to approximate the final loss, given the model size $N$ (parameters) and the data size $D$ (tokens). Here is the final fit:
def L(N, D):
    """
    Predict the final training loss for a model with N parameters trained on
    D tokens, using the Chinchilla "Approach 3" parametric fit:

        L(N, D) = A / N**alpha + B / D**beta + E

    N and D may be scalars or numpy arrays; the arithmetic is elementwise.
    """
    # fitted constants reported in the Chinchilla paper
    E = 1.69  # entropy of natural language: loss of an infinite model on infinite data
    A, alpha = 406.4, 0.34  # finite-model-size penalty
    B, beta = 410.7, 0.28  # finite-data penalty
    model_term = A / (N ** alpha)
    data_term = B / (D ** beta)
    return model_term + data_term + E
ns = 10 ** np.arange(7, 11, step=2**-4)  # model sizes from 10M up to (just under) 100B
ds = 10 ** np.arange(9, 12, step=2**-4)  # dataset sizes from 1B up to (just under) 1T
# Evaluate on the full (model size x dataset size) grid via broadcasting:
# rows index model size (y axis), columns index dataset size (x axis).
loss2d = np.log10(L(ns[:, None], ds[None, :]))
# compute is a deterministic function of the grid point: flops = 6*N*D
compute2d = np.log10(6 * ns[:, None] * ds[None, :])
heatmap_extent = [9, 12, 7, 11]  # [x_min, x_max, y_min, y_max] in log10 units
plt.figure(figsize=(12, 5))
for panel, (grid, title) in enumerate([(loss2d, 'log10 loss'), (compute2d, 'log10 flops')], start=1):
    plt.subplot(1, 2, panel)
    # 2D heatmap with contour lines on top
    plt.imshow(grid, extent=heatmap_extent, origin='lower', alpha=0.5)
    plt.contour(grid, levels=30, extent=heatmap_extent, origin='lower')
    plt.xlabel('log10(dataset size)')
    plt.ylabel('log10(model size)')
    plt.title(title)
    plt.colorbar()
Ok, so given any $N, D$ we can estimate both 1) the loss and 2) the total flops. Now we want to solve the following problem: given a specific flop budget $C$, find $N_{opt}, D_{opt} = \arg\min_{N, D \,:\, \mathrm{FLOPs}(N, D) = C} L(N, D)$, i.e. how big of a model should we train, and for how many tokens?
# Target compute budget. You usually know this because you know how many GPUs
# you have and for how long they can go brrr. This number is row 1 of Table A3.
c = 2.21e19
# Candidate model sizes on a log grid from 10M to 100B. For each one, pick the
# dataset size that spends exactly the budget: C = 6*N*D  =>  D = C / (6*N).
ns = 10 ** np.arange(7, 11, step=2**-4)
ds = c / (6 * ns)
losses = L(ns, ds)
best = losses.argmin()  # position of the lowest predicted loss on this iso-compute curve
print(f"best model size: {ns[best]/1e6:.2f}M")
print(f"best dataset size: {ds[best]/1e9:.2f}B")
# loss vs model size on a log x-axis, with the optimum marked by a red vertical line
plt.figure(figsize=(3,3))
plt.plot(ns, losses)
plt.xscale('log')
plt.axvline(ns[best], color='red')
plt.xlabel('model size')
plt.ylabel('loss')
In the plot above, basically the models on the left of best are too small and trained for too long. The models on the right of best are way too large and trained for too little. The model at the red line is just right.
Now, the Chinchilla paper says that best model size for this flop budget is 400M params and 9.2B tokens (instead of 316M params and 11.65B tokens) so there is some unresolved disagreement here too...
# Calculate the Chinchilla-optimal models for a range of compute budgets.
# Sweep over compute budgets from 1e17 up to (just under) 1e26; np.arange excludes the stop value.
cs = 10 ** np.arange(17, 26, step=2**-8)
# Candidate model sizes. The grid does not depend on the budget, so build it once
# outside the loop.
ns = 10 ** np.arange(7, 14, step=2**-8)
models = []
for c in cs:
    # the dataset sizes that would maintain the given compute budget (C = 6*N*D)
    ds = c / (6 * ns)
    # losses at each point
    losses = L(ns, ds)
    # n,d for the best model
    best = np.argmin(losses)
    models.append((c, ns[best], ds[best]))  # c, n, d tuple log
len(models)
query_model_size = 400e6
# unpack the (compute, model size, dataset size) log into parallel arrays
ns = np.array([entry[1] for entry in models])
ds = np.array([entry[2] for entry in models])
# the compute-optimal model whose size is nearest the query
ix = np.abs(ns - query_model_size).argmin()
# report its size, data, flops (6*N*D) and predicted loss
print("closest model found:")
print(f"model size: {ns[ix]/1e6:.2f}M")
print(f"dataset size: {ds[ix]/1e9:.2f}B")
print(f"flops: {6*ns[ix]*ds[ix]:e}")
print(f"loss: {L(ns[ix], ds[ix]):.2f}")
The dataset size here should have come out as 9.2B tokens according to Table A3 in the Chinchilla paper, per my understanding of it.
Approach 2 is probably my favorite one because it fixes a flop budget, trains a number of model/dataset sizes, measures the loss, fits a parabola, and takes its minimum. So it's a fairly direct measurement of what we're after. The best way to then calculate, for example, the compute-optimal number of tokens for any given model size is via simple interpolation. Below, this is done with a linear fit in log-log space.
# Approach 1 numbers
# # parameters, tokens
# raw = [
# [400e6, 8e9],
# [1e9, 20.2e9],
# [10e9, 205.1e9],
# [67e9, 1.5e12],
# [175e9, 3.7e12],
# [280e9, 5.9e12],
# [520e9, 11e12],
# [1e12, 21.2e12],
# [10e12, 216.2e12],
# ]
# Approach 2 numbers
# parameters, tokens
raw = [
    [400e6, 7.7e9],
    [1e9, 20.0e9],
    [10e9, 219.5e9],
    [67e9, 1.7e12],
    [175e9, 4.3e12],
    [280e9, 7.1e12],
    [520e9, 13.4e12],
    [1e12, 26.5e12],
    [10e12, 292.0e12],
]
# Fit a line by linear regression to the raw data in log-log space:
# log10(tokens) = m * log10(params) + c
import numpy as np
raw_params = [q[0] for q in raw]
raw_tokens = [q[1] for q in raw]
x = np.log10(raw_params)
y = np.log10(raw_tokens)
A = np.vstack([x, np.ones(len(x))]).T
m, c = np.linalg.lstsq(A, y, rcond=None)[0]
print(f"y = {m}x + {c}")
plt.figure(figsize=(3, 3))
# plot the fitted line
plt.plot(raw_params, [10**(m*np.log10(p) + c) for p in raw_params], label='linear regression', color='r')
# plot the raw data
plt.scatter(raw_params, raw_tokens, label='raw data')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('parameters')
plt.ylabel('tokens')
plt.title('compute optimal models')
plt.legend()  # the labels above are only shown if a legend is drawn
plt.grid()
xquery = 124e6  # query model size here (e.g. GPT-2 small is 124M)
yquery = 10**(m*np.log10(xquery) + c)  # predicted compute-optimal number of tokens
# xquery is a parameter count and yquery a token count
print(f"predicted tokens for {xquery:e} parameters: {yquery:e}")