Back to Nanochat

Scaling Laws Analysis

dev/scaling_analysis.ipynb

latest · 11.0 KB
Original Source

Scaling Laws Analysis

Analyze the results of scaling_laws.sh to find the compute-optimal param:data ratio for nanochat, measured as training tokens per (effective) parameter.

python
%matplotlib inline
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Locate and load the results CSV written by scaling_laws.sh for this run tag.
# NANOCHAT_BASE_DIR, when set, overrides the default cache location.
tag = "jan26"
default_base_dir = os.path.expanduser('~/.cache/nanochat')
base_dir = os.environ.get('NANOCHAT_BASE_DIR', default_base_dir)
results_dir = os.path.join(base_dir, f'scaling_laws_results_{tag}')
results_path = os.path.join(results_dir, 'results.csv')

df = pd.read_csv(results_path)
flops_budgets = sorted(df['flops_budget'].unique())
run_count, budget_count = len(df), len(flops_budgets)
print(f"Loaded {run_count} runs across {budget_count} FLOPs budgets")
print(f"Columns: {list(df.columns)}")
df  # last expression: the notebook renders the raw table
python
# =============================================================================
# FILTERING: Remove incomplete or problematic runs
# =============================================================================

print(f"Before filtering: {len(df)} runs")

# Keep only runs with a finite, positive val_bpb. np.isfinite rejects NaN
# (incomplete runs) AND +/-inf; a plain `> 0` check would let +inf through,
# and a single infinite loss poisons the np.polyfit fits further down.
df = df[np.isfinite(df['val_bpb']) & (df['val_bpb'] > 0)]

# Optional: exclude specific flops budgets that aren't done yet
# exclude_flops = [1e19]  # <-- adjust as runs complete
# df = df[~df['flops_budget'].isin(exclude_flops)]

# Optional: exclude specific depths
# exclude_depths = [18, 20]
# df = df[~df['depth'].isin(exclude_depths)]

print(f"After filtering: {len(df)} runs")
print(f"FLOPs budgets: {sorted(df['flops_budget'].unique())}")
print(f"Depths: {sorted(df['depth'].unique())}")

# Update flops_budgets list after filtering
flops_budgets = sorted(df['flops_budget'].unique())

Effective Parameter Count

Different scaling law papers use different conventions for counting parameters:

  • Kaplan et al. excluded embedding parameters (they reported that this gave cleaner scaling laws)
  • Chinchilla included all parameters (and noted Kaplan had a bug)

Our CSV now has granular counts:

  • params_wte - token embedding (lookup table)
  • params_value_embeds - value embeddings (lookup table)
  • params_lm_head - unembedding projection (matmul)
  • params_transformer - attention + MLP matrices (matmuls)
  • params_scalars - resid/x0/bigram lambdas (tiny)

Experiment below with different combinations to see which gives the cleanest scaling laws.

python
# =============================================================================
# EXPERIMENT HERE: Define which parameters to count for scaling laws
# =============================================================================

def compute_effective_params(row, convention='kaplan'):
    """
    Compute the 'effective' parameter count for scaling law analysis.

    Args:
        row: one run's record, i.e. a row passed by df.apply(..., axis=1) or any
            mapping that exposes the params_* columns used below.
        convention: which parameters to count. One of:
            - 'chinchilla':  all params (params_total)
            - 'kaplan':      exclude the embedding lookup tables and scalars,
                             i.e. params_transformer + params_lm_head. This is
                             also the "matmul-only" count (the actual compute).
            - 'transformer': attention + MLP only (also drops params_lm_head)
            The default 'kaplan' is the convention this notebook has used so far.
            To switch, e.g.: df.apply(compute_effective_params, axis=1,
            convention='chinchilla').

    Returns:
        The effective parameter count for the chosen convention.

    Raises:
        ValueError: if `convention` is not one of the names above.
    """
    if convention == 'chinchilla':
        return row['params_total']
    if convention == 'kaplan':
        return row['params_transformer'] + row['params_lm_head']
    if convention == 'transformer':
        return row['params_transformer']
    raise ValueError(
        f"unknown convention {convention!r}; expected 'chinchilla', 'kaplan' or 'transformer'"
    )


# Derive per-run columns from the chosen parameter-counting convention.
# Copy first: the boolean-mask filter above may have left df as a slice, and
# adding columns to a slice triggers pandas' SettingWithCopyWarning.
df = df.copy()
df['effective_params'] = df.apply(compute_effective_params, axis=1)
df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']

# Show the parameter breakdown for the first run recorded in each FLOPs budget.
print("Parameter breakdown (first row per flops budget):")
param_cols = [
    'depth',
    'params_wte', 'params_value_embeds', 'params_lm_head',
    'params_transformer', 'params_scalars',
    'params_total', 'effective_params',
]
first_per_budget = df.groupby('flops_budget').first()
first_per_budget[param_cols]

IsoFLOP Curves (à la Chinchilla)

For each compute budget, plot validation loss against model size. We are looking for a U-shaped valley whose minimum marks the optimal model size for that FLOPs budget.

python
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)
ax = axes[0]
colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))
optimal_by_bpb = []

for flops, color in zip(flops_budgets, colors):
    subset = df[df['flops_budget'] == flops].sort_values('effective_params')
    ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)

    # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c.
    # A quadratic has 3 coefficients, so it needs at least 3 distinct model
    # sizes; with fewer, polyfit is rank-deficient and its "minimum" is noise.
    log_params = np.log10(subset['effective_params'].to_numpy(dtype=float))
    a = None
    if np.unique(log_params).size >= 3:
        a, b, c = np.polyfit(log_params, subset['val_bpb'], 2)

        # Plot fitted curve (dashed)
        log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)
        fit_y = a * log_fit_x**2 + b * log_fit_x + c
        ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)

    # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)
    if a is not None and a > 0:  # parabola opens upward (has a minimum)
        log_opt = -b / (2 * a)
        opt_params = 10**log_opt
        opt_bpb = a * log_opt**2 + b * log_opt + c
        if not log_params.min() <= log_opt <= log_params.max():
            # The valley was not bracketed by the runs; np.interp below clamps to
            # the nearest run outside the sampled range, so tokens/ratio are rough.
            print(f"Warning: {flops:.0e} fitted optimum lies outside the sampled model sizes; "
                  "tokens/ratio are clamped to the nearest run")
        # Mark the fitted optimal
        ax.scatter([opt_params], [opt_bpb], s=150, color=color,
                   zorder=5, edgecolors='black', linewidths=2, marker='*')
        # Interpolate tokens from actual data (don't use C≈6ND approximation).
        # Along an isoFLOP curve tokens fall roughly as a power of N, so
        # interpolate log(tokens) linearly in log(N). Interpolating raw tokens
        # (and especially the tokens/param ratio, ~N^-2) overestimates between runs.
        log_tokens = np.log10(subset['tokens_trained'].to_numpy(dtype=float))
        opt_tokens = 10**np.interp(log_opt, log_params, log_tokens)
        opt_ratio = opt_tokens / opt_params  # keeps ratio consistent with the reported N and D
        optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})
    else:
        # Fallback to raw minimum if there is no usable quadratic fit
        # (too few model sizes, or the parabola has no minimum)
        best_idx = subset['val_bpb'].idxmin()
        best = subset.loc[best_idx]
        ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,
                   zorder=5, edgecolors='black', linewidths=2)
        optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],
                              'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})

ax.set_xscale('log')
ax.set_xlabel('Effective Parameters')
ax.set_ylabel('Validation Loss (bpb)')
ax.set_title('IsoFLOP Curves')
ax.legend(title='FLOPs', loc='upper right')
ax.grid(True, alpha=0.3)

opt_df = pd.DataFrame(optimal_by_bpb)

# Plot 2: Optimal model size vs compute (power law)
ax = axes[1]
ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')
ax.set_xlabel('FLOPs')
ax.set_ylabel('Optimal Parameters')
ax.set_title('Optimal Model Size')
ax.grid(True, alpha=0.3)

# Fit and show power law
if len(opt_df) >= 2:
    log_f = np.log10(opt_df['flops'])
    log_p = np.log10(opt_df['params'])
    slope, intercept = np.polyfit(log_f, log_p, 1)
    fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)
    fit_p = 10**(intercept + slope * np.log10(fit_f))
    ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')
    ax.legend()

# Plot 3: Optimal tokens vs compute (power law)
ax = axes[2]
ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')
ax.set_xlabel('FLOPs')
ax.set_ylabel('Optimal Tokens')
ax.set_title('Optimal Training Tokens')
ax.grid(True, alpha=0.3)

# Fit and show power law
if len(opt_df) >= 2:
    log_f = np.log10(opt_df['flops'])
    log_t = np.log10(opt_df['tokens'])
    slope, intercept = np.polyfit(log_f, log_t, 1)
    fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)
    fit_t = 10**(intercept + slope * np.log10(fit_f))
    ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')
    ax.legend()

plt.tight_layout()
plt.show()

# Print the optimal points (from quadratic fits)
print("\nOptimal configurations (from quadratic fits):")
print(f"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}")
print("-" * 65)
for _, row in opt_df.iterrows():
    print(f"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}")
python
# =============================================================================
# Optimal Ratio Summary (from power law fits)
# =============================================================================

# Each power law N ∝ C^a and D ∝ C^b is a straight line in log-log space, so the
# optimal ratio D/N scales as C^(b-a); if a ≈ b the ratio is roughly constant.

if len(opt_df) < 2:
    print("Need at least 2 flops budgets to compute power law fits")
else:
    log_f = np.log10(opt_df['flops'])
    log_p, log_t = (np.log10(opt_df[col]) for col in ('params', 'tokens'))

    # Straight-line fits of log N and log D against log C
    (slope_n, intercept_n), (slope_d, intercept_d) = (
        np.polyfit(log_f, log_y, 1) for log_y in (log_p, log_t)
    )

    # Evaluate both fits at a reference compute: the geometric mean of the
    # smallest and largest budgets, then take D/N there.
    ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())
    log_ref = np.log10(ref_flops)
    pred_log_n = intercept_n + slope_n * log_ref
    pred_log_d = intercept_d + slope_d * log_ref
    optimal_ratio = 10**(pred_log_d - pred_log_n)

    # Direct statistics over the per-budget fitted optima
    ratios = opt_df['ratio']
    mean_ratio, std_ratio = ratios.mean(), ratios.std()

    banner = "=" * 60
    report = [
        banner,
        "OPTIMAL RATIO SUMMARY",
        banner,
        "\nPower law exponents:",
        f"  N ∝ C^{slope_n:.3f}",
        f"  D ∝ C^{slope_d:.3f}",
        f"  Ratio exponent (b-a): {slope_d - slope_n:.3f}  (should be ~0 if ratio is constant)",
        "\nOptimal ratio (tokens per effective param):",
        f"  From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}",
        f"  Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}",
        "  Chinchilla reference: 20",
        f"\nPer-budget ratios: {[f'{r:.1f}' for r in ratios.values]}",
    ]
    print("\n".join(report))

Val BPB vs Depth and Ratio

python
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One color per FLOPs budget, using the same viridis palette as the IsoFLOP
# figure so a given budget is drawn in the same color in every plot.
budget_colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))


def _plot_bpb_by_budget(ax, data, x_col, budgets, colors):
    """Plot val_bpb vs `x_col`, one line per FLOPs budget, ringing each budget's best run.

    The best-run marker is given its line's color explicitly; without it,
    scatter() picks its own color from the property cycle, which need not
    match the line it annotates.
    """
    for flops, color in zip(budgets, colors):
        subset = data[data['flops_budget'] == flops].sort_values(x_col)
        ax.plot(subset[x_col], subset['val_bpb'], 'o-', color=color, label=f'{flops:.0e}')
        # Mark the best (lowest)
        best = subset.loc[subset['val_bpb'].idxmin()]
        ax.scatter([best[x_col]], [best['val_bpb']], s=100, color=color,
                   zorder=5, edgecolors='black', linewidths=2)


# Plot 1: Val BPB vs Depth
ax = axes[0]
_plot_bpb_by_budget(ax, df, 'depth', flops_budgets, budget_colors)
ax.set_xlabel('Depth')
ax.set_ylabel('Val BPB (lower is better)')
ax.set_title('Validation BPB vs Model Depth')
ax.legend(title='FLOPs')
ax.grid(True, alpha=0.3)

# Plot 2: Val BPB vs Param:Data Ratio
ax = axes[1]
_plot_bpb_by_budget(ax, df, 'param_data_ratio', flops_budgets, budget_colors)
ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')
ax.set_xlabel('Param:Data Ratio (tokens/param)')
ax.set_ylabel('Val BPB (lower is better)')
ax.set_title('Val BPB vs Param:Data Ratio')
ax.legend(title='FLOPs')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()