dev/LOG.md
A running summary documenting some experiments and findings. Started ~Jan 7 2026.
Replaced torch.amp.autocast throughout the codebase with explicit dtype management via a single COMPUTE_DTYPE global. Also added fp16 training support with GradScaler.
autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was nn.Linear weights from fp32 to bf16 for matmuls. F.rms_norm, F.cross_entropy, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.
Core mechanism (nanochat/common.py, nanochat/gpt.py):
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
- A `Linear(nn.Linear)` class that casts weights to match the input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
- Parameters are created directly in `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
- Activations are cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
- Remaining hardcoded bf16 references now use `COMPUTE_DTYPE` instead.
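A minimal sketch of that Linear subclass (illustrative; the real implementation lives in the files above):

```python
import torch.nn as nn
import torch.nn.functional as F

class Linear(nn.Linear):
    def forward(self, x):
        # cast weight (and bias, if any) to the activation dtype at matmul time;
        # this is a no-op whenever the dtypes already match
        w = self.weight.to(dtype=x.dtype)
        b = self.bias.to(dtype=x.dtype) if self.bias is not None else None
        return F.linear(x, w, b)
```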
Autocast removal (11 files):
- Removed the `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: base_train.py, chat_sft.py, chat_rl.py, chat_cli.py, chat_eval.py, chat_web.py, base_eval.py, engine.py, bench_train_toks.py, test_e2e_pipeline.py.

fp16 + GradScaler (base_train.py, chat_sft.py):
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
- `scaler.scale(loss).backward()` vs plain `loss.backward()`
- `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with ReduceOp.MAX → `scaler.step(optimizer)` → `scaler.update()`
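A minimal sketch of the fp16 step described above (assumes torch.distributed is already initialized; COMPUTE_DTYPE and the optimizer come from elsewhere):

```python
import torch
import torch.distributed as dist

scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None

def backward_and_step(loss, optimizer):
    if scaler is None:                        # bf16 / fp32 path
        loss.backward()
        optimizer.step()
        return
    scaler.scale(loss).backward()             # scaled backward
    scaler.unscale_(optimizer)                # grads back to true scale, inf/nan flags recorded
    for found_inf in scaler._found_inf_per_device(optimizer).values():
        dist.all_reduce(found_inf, op=dist.ReduceOp.MAX)   # all ranks must agree on skipping
    scaler.step(optimizer)                    # skipped on every rank if any rank saw inf/nan
    scaler.update()
```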
FP8 fix (nanochat/fp8.py, base_train.py):
- `Float8Linear.forward` explicitly casts input to COMPUTE_DTYPE (previously relied on autocast).
- The `disable_fp8` context manager now creates our custom Linear (not vanilla nn.Linear) when swapping out Float8Linear during eval.

Flash Attention (flash_attention.py):
- `USE_FA3` (a module-level constant, resolved once at import) returns False when FA3 can't be used, falling back to SDPA.

Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from 2 hours 46 minutes to 2 hours 1 minute — a 27% reduction.
ClimbMix 400B is a curated 400B-token pretraining mixture from NVIDIA, hosted at karpathy/climbmix-400b-shuffle on HuggingFace. It is a blend of high-quality web text, code, math, and other sources, designed to be a better general-purpose pretraining dataset than FineWeb-EDU alone.
- karpathy/fineweb-edu-100b-shuffle → karpathy/climbmix-400b-shuffle (up to 6543 shards available vs the previous 1823 data shards, allowing for longer training in the future)
- base_data/ → base_data_climbmix/ (clean separation from legacy data)
- `list_parquet_files()` detects the old base_data/ directory and falls back gracefully, so existing users see clear upgrade instructions on git pull

This is the sixth attempt at beating FineWeb-EDU on CORE score — the previous five all failed (see entries on 2026-02-17, 2026-02-10, 2026-01-12 below). ClimbMix is the first dataset to convincingly surpass it, and the margin is large enough to also shrink the model from d26 to d24.
Quick experiment to tune the logit softcap at d24 scale. Tried values from 5 to 30: 5 was terrible, and the rest were all about equal except 20, which was the best. Minor but solid improvement: val loss improved by ~1e-3 (0.716 -> 0.715). Setting 20 as the default.
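For reference, the softcap squashes the final logits with a tanh; a minimal sketch with the new default of 20:

```python
import torch

def softcap_logits(logits: torch.Tensor, cap: float = 20.0) -> torch.Tensor:
    # smoothly bounds logits to (-cap, cap); approximately the identity for |logits| << cap
    return cap * torch.tanh(logits / cap)
```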
Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability).
Follows DeepSeekV3, using torchtitan as a reference:
- `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP (see the sketch below)
- `torch._grouped_mm` for dispatching tokens to experts in a single kernel (instead of a Python for-loop)
- Routed expert weights are 3D tensors of shape (num_experts, hidden, dim) — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized
- Active parameter counting includes only the experts a token actually uses (top_k + shared experts, not all 8)
- Shared expert: a plain nn.Linear MLP that runs on all tokens alongside the routed path.
- Adjusted the second_momentum_buffer shape to preserve leading dims.
- `torch._grouped_mm` quirks: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error.
- `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our Float8Linear). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see dev/moe_fp8.md) but did not implement — would require either depending on torchao.prototype (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's nn.Linear layers do get converted, but the routed experts (3D nn.Parameter) stay in bf16.
- One possible path: `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for the weight gradient (avoiding the per-group column-wise Triton kernels).

What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute.
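As a concrete instance of the sizing rule in the first bullet (the dim/top_k/num_shared values below are illustrative, not the exact nanochat config):

```python
# FLOP-matched expert hidden dim, rounded to a multiple of 128
dim, top_k, num_shared = 1280, 2, 1                                    # assumed example values
expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128
print(expert_hidden_dim)  # 1664
# per token, (top_k + num_shared) * expert_hidden_dim = 3 * 1664 ≈ 4 * dim,
# so the active expert FLOPs roughly match the dense 4x MLP
```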
MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths.
Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results:
This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score.
Tried hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT, a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested:
This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score.
Brought chat_sft.py up to parity with base_train.py and tuned settings based on SFT sweeps.
Tuning:
- New flag (`--load-optimizer=1`, default on): loads pretrained momentum buffers via the new load_optimizer_state() in checkpoint_manager.py. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better, but not by much.
- LR schedule flags now mirror base_train.py (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, a warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8.
- New `--mmlu-epochs` / `--gsm8k-epochs` flags. Might remove these in the future though.

Quality of life, footguns, minor fixes:
- Added total_batch_size to base_train.py checkpoint metadata.
- Periodic ChatCORE evaluation (`--chatcore-every=200`) across all 6 tasks, logged to wandb.
- Use get_peak_flops() for the actual GPU instead of the hardcoded H100 value.
- Added `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save.

So far, the --total-batch-size was hardcoded to be 2**19 = 524,288 ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimum was closer to 2**20 = 1,048,576 ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of --depth get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras (arXiv:2505.13738) for a lot of related experimentation. In particular, they found that Bopt ∝ D^0.383 (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via --target-param-data-ratio), this means deeper models naturally want larger batches.
Added --total-batch-size=-1 (now the default) to auto-compute optimal batch:
```python
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
if args.total_batch_size == -1:
    D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
    B_REF = 2**19
    args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
```
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different --aspect-ratio), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
With this formula, we currently get:
| Depth | Scaling Params | Target Tokens | Auto Batch |
|---|---|---|---|
| d=8 | 42M | 0.44B | 2^18 = 262K |
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
Also refactored model initialization to use build_model_meta(depth) helper and dataclasses.asdict() for cleaner config handling.
Tried batch size ramping. The simplest implementation I could think of "tricks" the existing training loop by slicing each micro-batch into smaller pieces and calling optimizer.step() more frequently early in training (1/8 → 1/4 → 1/2 → full batch over the first x% of training, with sqrt LR scaling). Also required a torch.compile warmup phase to pre-compile all slice sizes and avoid recompilation spikes during training. While the idea is sound and small gains were observed, they weren't sufficient to justify the code complexity introduced (conditional slicing logic, warmup with state save/restore, etc.). Not merged for now.
Replaced ReLU² MLP activation with SwiGLU (inspired by twitter). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
```python
# Old ReLU²: 2 matrices, 4x expansion
# params: 2 × n × 4n = 8n²
# flops: 2 × 2n × 4n = 16n² per token
self.c_fc = Linear(n_embd, 4 * n_embd)
self.c_proj = Linear(4 * n_embd, n_embd)
x = c_proj(relu(c_fc(x)).square())

# New SwiGLU: 3 matrices, 8/3x expansion
# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches
# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches
hidden_dim = (8 * n_embd) // 3
self.w1 = Linear(n_embd, hidden_dim) # gate
self.w2 = Linear(n_embd, hidden_dim) # up
self.w3 = Linear(hidden_dim, n_embd) # down
x = w3(silu(w1(x)) * w2(x))
```
Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. Not adopted.
Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like c_fc) to boosting wide matrices (output projections like c_proj). The original code applies max(1, rows/cols)^0.5, giving ~2x LR to c_fc. The flipped version gives ~2x LR to c_proj instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in PR #492 and showed improvements in modded-nanogpt.
Result: a quick d12 experiment was slightly worse. Not adopted.
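For reference, a sketch of the two heuristics, assuming weights stored as (out_features, in_features):

```python
def lr_scale_original(rows, cols):
    # boosts tall matrices (rows > cols), e.g. c_fc with shape (4n, n) gets ~2x
    return max(1, rows / cols) ** 0.5

def lr_scale_flipped(rows, cols):
    # boosts wide matrices (cols > rows), e.g. c_proj with shape (n, 4n) gets ~2x
    return max(1, cols / rows) ** 0.5
```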
Inspired by modded-nanogpt, tried stepping AdamW only on odd iterations while Muon steps every iteration. The idea is that small AdamW params (embeddings, scalars, gates) don't need updates as frequently as the large weight matrices, and skipping saves both compute and communication.
Added skip_adamw parameter to MuonAdamW.step() and DistMuonAdamW.step() plus a matching zero_grad(skip_adamw=...) to let AdamW gradients accumulate over 2 steps. Used lr *= 2**-0.5 (sqrt scaling) to compensate for the 2x effective batch size on AdamW params.
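A minimal sketch of the alternating step (illustrative; the real loop also applies the sqrt LR compensation mentioned above):

```python
for step, (x, y) in enumerate(loader):
    loss = model(x, y)
    loss.backward()
    skip_adamw = (step % 2 == 0)                   # AdamW only steps on odd iterations
    optimizer.step(skip_adamw=skip_adamw)          # Muon steps every iteration
    optimizer.zero_grad(skip_adamw=skip_adamw)     # keep AdamW grads so they accumulate over 2 steps
```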
Result: for nanochat d12, we see ~2% faster tok/s, but each step is slightly worse in loss. On net, when plotting against wall clock time, it's slightly worse. Not adopted.
Integrated FP8 training using torchao.float8 to accelerate Linear layer matmuls on H100 GPUs.
FP8 (8-bit floating point) uses H100's FP8 tensor cores for ~2x theoretical matmul throughput. The tradeoff is quantization overhead: computing scales and casting tensors to/from FP8. Still, as an example torchtitan (Meta's distributed training framework) reports 25-28% speedups with FP8 for some of their experiments.
Previous attempt (Jan 2026): FP8 on just lm_head following modded-nanogpt with custom ops → 1% speedup, +2GB memory. Failed due to fragile torch.compile interaction. But that experiment was also done at ~d12 scale, rather than the bigger ~d24 model that reaches approximately GPT-2 capability.
This attempt: Use torchao's convert_to_float8_training() on ALL Linear layers, increase model size to d24. The core snippet is:
```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=config)
```
But in practice it's more involved (see base_train.py).
Microbenchmark (d26 MLP, 65536x1664 @ 1664x6656):
| Method | Forward | Fwd+Bwd | Speedup |
|---|---|---|---|
| BF16 + compile | 2.00ms | 4.79ms | 1.00x |
| FP8 rowwise + compile | 1.84ms | 4.55ms | 1.08x |
| FP8 tensorwise + compile | 1.45ms | 4.06ms | 1.38x |
| FP8 rowwise (no compile) | 2.89ms | 21.86ms | 0.23x ❌ |
torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops.
Full training (d26):
| Config | tok/sec | vs baseline |
|---|---|---|
| BF16 baseline | 630K | 1.00x |
| FP8 rowwise | 564K | 0.90x ❌ |
| FP8 tensorwise | 740K | 1.17x ✓ |
Memory usage also decreases quite a bit, by ~9GB (activations stored as FP8 instead of BF16).
Seeing a 17% speedup is encouraging, but we're not done yet: each step is now in lower precision and individually less powerful, so to make up for the precision drop we have to train longer. Empirically, running some sweeps overnight at d24 scale, I saw that the actual speedup (when you match performance) is closer to 5%. It's possible that our LLMs at ~d24 scale are still too small to confidently enjoy the speedups that fp8 gives bigger models.
For nanochat at approximate scale of interest (~GPT-2 capability, ~d24):
Added --fp8 flag to base_train.py, default recipe is "tensorwise", example of turning on:
torchrun --nproc_per_node=8 -m scripts.base_train --depth=24 --fp8
Uses tensorwise by default. Requires torchao==0.15.0 (compatible with torch 2.9.1), which was added to dependencies.
TLDR: turning on fp8 for GPT-2 capability nanochat model gives approx +5% capability-matched speedup.
Explored Hyperball optimization from this post (saved to knowledge/muonh.md). Constrains weights to sphere of radius R (initial norm): W_{t+1} = R · Normalize(W_t - η·R · Normalize(u_t)). Had to change a number of details in a branch, e.g. not use zero init for our projections (or the initial norm would be zero), keep track of the initial norm, adjust Muon -> MuonH for the update.
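A minimal sketch of the update rule as I read it from the post (R is the weight's norm at init):

```python
import torch

def muonh_step(W: torch.Tensor, u: torch.Tensor, lr: float, R: float) -> torch.Tensor:
    # W_{t+1} = R * Normalize(W_t - lr * R * Normalize(u_t))
    normalize = lambda t: t / (t.norm() + 1e-12)
    return R * normalize(W - lr * R * normalize(u))
```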
Experiments on d12:
| Experiment | Result |
|---|---|
| MuonH for matrix params | Worse than baseline |
| MuonH + LR sweep (2.5e-3 to 1e-2) | Still worse |
| Added learnable RMSNorm scales (paper says γ preserves expressivity) | Still worse |
| Various RMSNorm init tweaks, e.g. 0 at init to residual | Still worse |
| AdamH for lm_head (paper recommends this) | Broken - loss plateaus (see below) |
| AdamH + learnable output scales | Still worse |
Could not outperform the baseline implementation. The article doesn't go into too much detail on how AdamH is applied to lm_head exactly. The classifier layer has to be able to increase in magnitude to make more confident predictions over time. Tried a sensible version with an added 0-D learnable scalar, and also with RMSNorms with per-channel learnable scalars both pre and post the resnet blocks.
Result: This was not an out-of-the-box win for nanochat, even with a mild attempt at tuning and debugging over a few hours. The idea itself is intuitively appealing. Might come back around later to try harder.
Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.
Explored N-gram memory modules inspired by the DeepSeek Engram paper and modded-nanogpt PR #201.
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
1. Full Engram module with context-aware gating (paper design)
```
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h)              # hidden state as query
k = RMSNorm(W_k @ e)        # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d)     # scalar gate per position
output = α * v
```
2. Early-layer only injection
3. Trigrams
4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)
- Hash: (36313 * curr) XOR (27191 * prev) mod table_size
- Injection: x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram

TLDR: The winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable λ) before every single block:
```python
class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)

    def forward(self, idx):
        # hash each (prev, curr) token pair into the embedding table
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```
As for optimal hyperparameters:
- Table size: vocab_size * 5 (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- Per-layer learnable bigram_lambdas (init 0.1 was better than 0.0).
- Tried applying norm() to the embeddings (mirroring the token embeddings); this was slightly worse.

Gating didn't help at our scale. The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
Uniform injection beats early-only. Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
Bigrams are sufficient. Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
Scale matters. The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
For d12 model with table_multiplier=5:
If you're keeping track, we now have a lot of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
```
Parameter counts:
wte                  :  25,165,824
bigram_embed         : 125,829,120
value_embeds         : 150,994,944
lm_head              :  25,165,824
transformer_matrices :  84,935,808
scalars              :          36
total                : 412,091,556
```
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.
Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
| FLOPs | Eff Params | Tokens | Ratio | Val BPB |
|---|---|---|---|---|
| 1e+18 | 110,678,115 | 1,241,505,403 | 11.2 | 0.8972 |
| 2e+18 | 167,797,457 | 1,785,336,422 | 10.7 | 0.8616 |
| 5e+18 | 250,650,865 | 2,642,234,152 | 10.8 | 0.8293 |
| 1e+19 | 381,758,347 | 3,806,871,243 | 10.3 | 0.7999 |

N ∝ C^0.54, D ∝ C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
| FLOPs | Eff Params | Tokens | Ratio | Val BPB |
|---|---|---|---|---|
| 1e+18 | 416,320,605 | 1,232,157,011 | 3.0 | 0.8974 |
| 2e+18 | 560,239,841 | 1,763,669,281 | 3.2 | 0.8616 |
| 5e+18 | 741,495,903 | 2,629,909,368 | 3.6 | 0.8291 |
| 1e+19 | 988,644,331 | 3,884,841,895 | 4.0 | 0.7999 |

N ∝ C^0.37, D ∝ C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
| FLOPs | Eff Params | Tokens | Ratio | Val BPB |
|---|---|---|---|---|
| 1e+18 | 80,259,665 | 1,315,639,547 | 17.2 | 0.8966 |
| 2e+18 | 131,488,566 | 1,864,134,141 | 14.5 | 0.8622 |
| 5e+18 | 220,985,474 | 2,595,328,843 | 12.1 | 0.8302 |
| 1e+19 | 401,213,504 | 3,328,704,512 | 8.5 | 0.7994 |

N ∝ C^0.70, D ∝ C^0.41
Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to setup_optimizers() — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
At d12, found two independent improvement routes:
Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.
At d16, Route B became competitive with Route A. The routes still conflicted when combined.
At d20 (target scale), everything changed:
- x0_beta1=0.96 alone captured nearly all the gains

| x0_beta1 | val/bpb | Δ vs baseline |
|---|---|---|
| 0.96 | 0.7971 | -0.0007 |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |
Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.
Hyperparameters are scale-dependent. What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.
Improvement magnitude shrinks with scale. ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.
Sharp cliffs exist. x0_beta1=0.98 is catastrophic while 0.96 is optimal.
Don't over-tune on small proxies. Validate at target scale before shipping.
For production d20 runs, add one flag:
--x0-lambdas-beta1=0.96
Skip everything else discovered at smaller scales.
Also tried sa_lambdas that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss); abandoned to control complexity.

Modded-nanogpt uses Value Embeddings (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
Long story short is that the models love Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero FLOPs cost, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fails. The model wants many of them, at full capacity, and doing so wins across all x-axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way lower than Chinchilla's 20 at this point.
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
Keeping all of this work on a private branch for now but hope to push shortly.
Continued testing ideas from modded-nanogpt.
| Idea | Result | Notes |
|---|---|---|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
Added automatic fallback from Flash Attention 3 to PyTorch's scaled_dot_product_attention (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
Created nanochat/flash_attention.py - a unified interface that:
- Exposes a `flash_attn` object matching FA3's API exactly (flash_attn.flash_attn_func, flash_attn.flash_attn_with_kvcache)
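A rough sketch of the shape of that fallback (illustrative, not the actual nanochat/flash_attention.py code):

```python
import torch.nn.functional as F

def flash_attn_func(q, k, v, causal=True):
    # FA3 takes (B, T, H, D); SDPA wants (B, H, T, D), so transpose in and out
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)
```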
Changes to existing code were intentionally kept extremely minimal.
gpt.py: Only the import line changed and a comment
engine.py: Zero changes needed
base_train.py: Added status print and warnings:
- Warns if `--window-pattern` is not "L"

Tests are split into two classes due to dtype/device constraints:
TestFA3VsSDPA: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).
TestSDPAOnly: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.
Added _override_impl mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.
- Use `--window-pattern L` (full context) when using the SDPA fallback

Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. None of these helped:
| Idea | Result | Notes |
|---|---|---|
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
| Asymmetric softcap | Slightly worse | 23 * sigmoid((x+5)/7.5) vs our symmetric 15 * tanh(x/15). May only help with FP8. |
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |
Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.
I attempted to train on the Olmo 3 pretraining dataset allenai/dolma3_mix-6T instead of FineWeb-edu. I ran into a number of errors and issues trying to both download and process the dataset and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5".). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length) and tried to process the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.
I am still looking to try the DCLM dataset, which according to the paper should be better than FineWeb-Edu. I do have some concerns that the same group both prepared the DCLM dataset and introduced the CORE score, so I'm a bit hesitant in case there was some overfitting to a CORE-score-adjacent data distribution.
Classifying as negative result and reverting back to FineWeb-edu for now.
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's flash_attn_varlen_func, similar to modded-nanogpt's approach.
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
- cu_seqlens computed from BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
- cu_seqlens caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- nonzero() inside compiled model hit recompile limit - fixed by moving computation outside compiled region
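A sketch of the cu_seqlens construction (illustrative; the fixed-size padding described above is omitted):

```python
import torch

def bos_cu_seqlens(inputs: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    flat = inputs.view(-1)
    bos = (flat == bos_token_id).nonzero().flatten()
    # flash_attn_varlen_func wants int32 cumulative boundaries from 0 to the total token count
    end = torch.tensor([flat.numel()], device=flat.device)
    cu = torch.cat([bos, end]).to(torch.int32)
    if cu[0] != 0:
        cu = torch.cat([torch.zeros(1, dtype=torch.int32, device=flat.device), cu])
    return cu
```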
| Metric | Baseline | Varlen |
|---|---|---|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |
Essentially identical. The 0.0002 bpb improvement is almost noise.
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
Each row is built independently:
Measured token waste empirically on real data (T=2048):
| Algorithm | Util% | Crop% | Pad% | Notes |
|---|---|---|---|---|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| BestFit-Crop | 100% | 34.6% | 0% | Smart cropping, no padding |
A middle ground that maintains 100% utilization while reducing cropping:
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
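A rough sketch of the packing idea under those constraints (illustrative, not the actual dataloader code):

```python
def pack_row(buffer, T):
    # buffer: list of tokenized documents (each starting with BOS); returns one row of exactly T tokens
    row = []
    while len(row) < T:
        remaining = T - len(row)
        fitting = [doc for doc in buffer if len(doc) <= remaining]
        if fitting:
            doc = max(fitting, key=len)       # best fit: the largest document that still fits whole
            buffer.remove(doc)
            row.extend(doc)
        else:
            doc = buffer.pop(0)               # nothing fits: crop a document's tail to fill the row exactly
            row.extend(doc[:remaining])
    return row
```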
Results (T=2048):
Keep the original implementation which is very simple, efficient and has 100% token utilization in the batch (no padding with ignore tokens), but creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.
_bos_bestfit (BestFit-Crop, new default): Slightly more complex; still keeps 100% token utilization in the batch (no padding), but at the cost of discarding tokens when documents don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute loss value, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back, find the BOS token, and have the full context of that document to make predictions. Therefore, the loss appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons would come out the same before and after this change.
Validated the \p{N}{1,2} pattern in SPLIT_PATTERN (tokenizer.py line 30), which I had only guessed earlier and had a TODO to validate. GPT-4 uses \p{N}{1,3} to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
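To make the difference concrete (uses the `regex` package, which supports `\p{N}`):

```python
import regex
for pat in (r"\p{N}{1,1}", r"\p{N}{1,2}", r"\p{N}{1,3}"):
    print(pat, regex.findall(pat, "123456789"))
# \p{N}{1,1} ['1', '2', '3', '4', '5', '6', '7', '8', '9']
# \p{N}{1,2} ['12', '34', '56', '78', '9']
# \p{N}{1,3} ['123', '456', '789']
```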
Results (d12, vocab=32K):
| Pattern | val_bpb |
|---|---|
| \p{N}{1,1} | 0.969 |
| \p{N}{1,2} | 0.965 |
| \p{N}{1,3} | 0.972 |
Conclusion: {1,2} is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping {1,2} as default.
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
1. Dynamic Scaling (failed)
- Computes x.abs().max() and w.abs().max() each forward to determine scales
- .item() calls cause graph breaks with torch.compile
- Tried the @torch._dynamo.allow_in_graph pattern (like torchao.float8) - worked but no speedup
- Tried torch.library.custom_op with float scales - caused NaN gradients after first optimizer step

2. Static Scaling (partial success)
- Fixed scales: x_scale=10/448, w_scale=0.1/448
- grad_scale computed dynamically from batch size (safe since it's just 1/(B*T)/57344 due to the gradient expression of cross entropy). modded-nanogpt has a bug here probably because they set grad_scale = 0.75/448, but grads are in E5M2 so this should probably be 1/57344, 1 being the amax of any individual element of cross entropy loss, and no normalization by B,T because they use sum reduction not mean reduction.
- torch.library.custom_op with @torch.compile on inner kernels

| Metric | BF16 Baseline | FP8 lm_head |
|---|---|---|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |
FP8 should save memory since we store x_f8 (1 byte) instead of x (2 bytes) for backward. But we see 2GB increase. Suspected causes:
- torch.compile on inner kernels creating extra buffers/specializations
- torch._scaled_mm internal workspace allocations

Tried saving the original weight w (just a reference to the parameter) instead of w_f8 in backward, then re-quantizing on the spot during backward - didn't help. Still saw the bump.
Raw microbenchmark showed promise:
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase and the added code complexity and the need to tune scale factors for both x and w.
See the branch fp8_attempt_fail for:
- nanochat/fp8_static.py - Static scaling implementation (working)
- nanochat/fp8_dynamic.py - Dynamic scaling implementation (torchao-style, working but slow)
- gpt.py imports fp8_static.LinearFP8 and simply swaps it in for lm_head.

Conclusion: Negative result for now. The implementation works correctly but provides marginal speedup with increased memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
- Instead of looping the loss computation n_predict times, uses a fancy batched computation using unfold + gather + cross-entropy decomposition (CE = logsumexp - logits[target])
- Weight schedules tried: [1.0, 0.5, 0.25→0] (3rd token fades), [1.0, 0.5→0] (2nd token fades), [1.0] (standard next-token)
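A sketch of what that batched loss can look like (illustrative, not the actual implementation; in practice the weights also anneal over training per the schedules above):

```python
import torch

def mtp_loss(logits, idx, weights):
    # logits: (B, T, V) predictions at each position; idx: (B, T) token ids; weights: one per future offset
    n = len(weights)
    tgt = idx.unfold(1, n, 1)[:, 1:, :]                  # (B, T-n, n): tokens t+1 .. t+n for each position t
    lg = logits[:, : tgt.size(1), :]                     # keep only positions that have all n future targets
    lse = torch.logsumexp(lg, dim=-1, keepdim=True)      # CE = logsumexp(logits) - logits[target]
    picked = torch.gather(lg, -1, tgt)
    ce = lse - picked                                    # (B, T-n, n)
    w = torch.tensor(weights, device=logits.device)
    return (ce * w).sum(-1).mean()
```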
| Metric | Baseline | MTP |
|---|---|---|
| GPU Memory | 34 GB | 47 GB |
| MFU | 41% | 40% |
| val/bpb (per step) | baseline | same/slightly worse |
| val/bpb (wall clock) | baseline | noticeably worse |
Conclusion: Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off, in fact the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
Pattern string configuration:
- Added a `--window_pattern` CLI arg and GPTConfig.window_pattern field
- The pattern is tiled across layers (e.g. SSSL for 20 layers → SSSLSSSLSSSLSSSLSSSL)
- S layers use a sliding window of sequence_len // 2
- L layers use sequence_len (full context)
- Old checkpoints default to L; checkpoint loading is modified accordingly to fill in this param for old models, see _patch_missing_config_keys

Quick experiments showed SSSL (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
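A minimal sketch of expanding the pattern string into per-layer window sizes (illustrative):

```python
def layer_windows(pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    tiled = (pattern * n_layer)[:n_layer]     # e.g. "SSSL" tiled across all layers
    return [sequence_len if c == "L" else sequence_len // 2 for c in tiled]

print(layer_windows("SSSL", 8, 2048))  # [1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048]
```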
Replaced PyTorch's scaled_dot_product_attention (FA2) with Flash Attention 3 for training and inference.
1. FA3 via kernels package
- Loaded via the `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`

2. Simplified attention code
- FA3 takes the (B, T, H, D) layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call

3. Rewrote KVCache for FA3
- Previously: a single (num_layers, 2, B, H, T, D) combined tensor
- Now: separate k_cache and v_cache of shape (num_layers, B, T, H, D), as expected by flash_attn_with_kvcache
- Positions are tracked with a cache_seqlens tensor (int32, per batch element)
- Methods: get_layer_cache(), advance(), reset(), prefill()
- FA3 also supports sliding window attention via window_size=(left, 0), which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
1. x0_lambdas (x0 residual connections)
- x0 is captured after norm(wte(idx))
- Before each block: x = resid_lambdas[i] * x + x0_lambdas[i] * x0

2. resid_lambdas (residual stream scaling)
3. DistAdamW small parameter handling
- Small (scalar) params are synced with all_reduce instead of reduce_scatter/all_gather

The two scalar types need very different learning rates:
Implementation: resid_params gets scalar_lr * 0.01, x0_params gets full scalar_lr.
Swept --scalar_lr (controlling x0_lambdas) at multiple depths:
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|---|---|---|---|---|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
Observations:
Important lesson: __init__ runs in meta device context, so any tensor values set there are fake. Must initialize actual values in init_weights(). Added docstring warning to __init__.
Added --scalar_lr (default 0.5) controlling learnable per-layer scalars. The formula x = resid_lambdas[i] * x + x0_lambdas[i] * x0 gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
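A minimal sketch of the forward pass with these scalars (names illustrative):

```python
def forward_with_lambdas(idx, wte, norm, blocks, resid_lambdas, x0_lambdas):
    x = norm(wte(idx))
    x0 = x                                               # shortcut back to the initial embedding
    for i, block in enumerate(blocks):
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0    # learnable per-layer mixing
        x = block(x)
    return x
```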
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
1. Polar Express Orthogonalization
- Swapped in zeropower_via_polar_express (vs zeropower_via_newtonschulz5)

2. NorMuon Variance Reduction
- Adds a second_momentum_buffer with shape [rows, 1] or [1, cols] (whichever is smaller)

3. Cautious Weight Decay
- Weight decay is applied only where update * weight >= 0 (same sign), from arxiv.org/abs/2411.16085
- Muon's step is a @torch.compile function. Passing changing float values (like weight_decay during scheduling) as function arguments triggers recompilation. Reading from group["weight_decay"] inside the step avoids this.
- Muon gains a weight_decay param. AdamW still has no weight decay and is hardcoded to 0 weight decay; might try to re-tune this later.
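A minimal sketch of the masking idea (not the exact nanochat update; decoupled decay assumed):

```python
import torch

def cautious_decay_step_(p: torch.Tensor, update: torch.Tensor, lr: float, wd: float):
    mask = ((update * p) >= 0).to(p.dtype)   # decay only where update and weight have the same sign
    p.mul_(1 - lr * wd * mask)               # selectively applied decoupled weight decay
    p.add_(update, alpha=-lr)                # then the (orthogonalized) Muon update
```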
4. Weight decay schedule

Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
Optimal Values Found:
| Depth | Width (channels) | Optimal WD |
|---|---|---|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
Scaling Law:
- Fit WD = k / channels^α in log-log space; the fit gives α ≈ 2

Practical Formula:
WD_target = WD_reference × (d_reference / d_target)²
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
Reference: Moonlight paper uses fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changed with depth, so we go along with the empirical scaling law.
Muon was changed to use Polar Express, with NorMuon variance reduction and cautious weight decay added, plus a schedule that ramps the weight decay linearly to zero. All of these changes follow the modded-nanogpt repo, but each was also validated piece by piece to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. This is on by default and configurable with --weight_decay, using simply 0.2 and ∝ 1/width² scaling. The meaning of the --weight_decay kwarg therefore changes as of this commit: it used to configure standard weight decay in AdamW, and now it is used exclusively by Muon (AdamW is hardcoded to 0.0) and is scaled based on depth.
Hypothesis: Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
Results:
Bug Found: Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
Observation: modded-nanogpt does not appear to clip either right now.
Summary: Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves MFU a bit because we don't have to calculate and sync grad norms.