built a language model, one layer at a time.
roughly 10m parameters, 312m training tokens, on a single 4090 over two weekends. tokenizer, attention blocks, adamw, sampling loop: all handwritten. nothing imported from transformers. probably buggy in ways i haven't caught yet.
[run 03 · weekend two · checkpoint every 500]

step 04200/20000  loss 2.874  lr 3.0e-4  tok/s 38204
step 04250/20000  loss 2.861  lr 3.0e-4  tok/s 38190
step 04300/20000  loss 2.847  lr 3.0e-4  tok/s 38217
├─ sample @ T=0.8 ──────────────
│ the quick brown fox jumped over
│ the lazy dog and began to think
│ about what it means to be alive
└────────────────────────────────
step 04350/20000  loss 2.839  lr 2.9e-4  tok/s 38211
step 04400/20000  loss 2.821  lr 2.9e-4  tok/s 38196
the training curve.
train / val loss (every 200 steps, step 0 → 20000)

loss
4.20 │█
3.80 │ ▓▒
3.40 │   ▓░
3.00 │     ▓▓░░
2.80 │        ▓▓▓░░░
2.60 │           ▓▓▓▓░░░░
2.47 │              ▓▓▓▓▓▓▓░░░░░░░
     └─────────────────────────────────────────
      0        5k        10k        15k      20k  steps

█ train   ░ val   gap stable at ~0.11 nats
# model architecture, at a glance

nanolm/
  tokenizer.py   # bpe, 32,000 merges, from scratch
  model.py       # transformer, rope, gqa=4, ~260 loc
  optim.py       # adamw, decoupled decay, cosine+warmup
  flash.py       # triton kernel, tiled q·k·v softmax
  train.py       # main loop, grad accum, amp, ckpts
  sample.py      # top-k, top-p, temperature, repetition
  dpo.py         # 500 preference pairs, 1 epoch

shape of a forward pass:

x (B=64, T=512)
   │ embed + rope
   ▼
h (64, 512, 256)
   │ × 6 blocks: gqa attn → swiglu mlp
   ▼
h (64, 512, 256)
   │ rms norm → lm_head (tied)
   ▼
logits (64, 512, 32000)
every layer, on purpose.
bpe tokenizer from scratch
32k vocab, trained on a 400mb slice of the pile. merge table built pair-by-pair in pure python, then cached for speed. no sentencepiece, no tiktoken.
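the merge loop is small enough to fit in your head. a toy sketch of byte-level bpe training in pure python (my own function names, not the real tokenizer.py):

```python
from collections import Counter

def most_frequent_pair(ids):
    # count adjacent token pairs and return the most common one
    return Counter(zip(ids, ids[1:])).most_common(1)[0][0]

def merge(ids, pair, new_id):
    # replace every left-to-right occurrence of `pair` with `new_id`
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, num_merges):
    ids = list(text.encode("utf-8"))  # start from raw bytes, ids 0..255
    merges = {}                       # (a, b) -> new token id
    for k in range(num_merges):
        pair = most_frequent_pair(ids)
        merges[pair] = 256 + k
        ids = merge(ids, pair, 256 + k)
    return merges
```

recounting every pair on every merge is the slow-but-obvious version; caching the finished merge table afterwards is what makes encoding tolerable.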
rotary attention + gqa
rope applied to q and k before the dot product. grouped-query attention with gqa=4, so inference is cheap and the kv-cache shrinks 4x.
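what "applied to q and k before the dot product" means mechanically: each (even, odd) pair of dims gets rotated by an angle proportional to position, with per-pair frequencies. a pure-python sketch (toy names, not the actual model.py):

```python
import math

def rope(vec, pos, base=10000.0):
    # rotate each consecutive (even, odd) dim pair of `vec` by an
    # angle pos * base**(-i/d); applied to q and k before q·k
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i]     = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))
```

the payoff: after rotation, q·k depends only on the relative offset between the two positions, not on where in the sequence they sit.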
adamw, cosine, warmup
decoupled weight decay, 200-step linear warmup, cosine decay to 10% of peak. no deepspeed, no accelerate, no fsdp. one gpu, one process.
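the schedule in one function, using the numbers above (200-step warmup, cosine decay to 10% of peak over 20k steps; the function name is mine):

```python
import math

def lr_at(step, peak=3e-4, warmup=200, total=20000, floor=0.1):
    # linear warmup to `peak`, then cosine decay down to floor * peak
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)
    cos = 0.5 * (1 + math.cos(math.pi * progress))
    return peak * (floor + (1 - floor) * cos)
```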
flash-attention v2, triton
tiled softmax kernel in triton. about 3.1x over the naive pytorch path at seq 512 on my box, and it actually fits the bf16 budget on a 4090.
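the piece that makes tiling possible is the online softmax recurrence: keep a running max, a running normalizer, and a running weighted sum, and rescale them whenever a new tile raises the max. a scalar sketch for one query (pure python, nothing from the actual flash.py):

```python
import math

def online_softmax_weighted_sum(scores, values, tile=4):
    # process scores/values tile by tile; m is the running max,
    # l the running normalizer, acc the running weighted sum —
    # the same recurrence a tiled flash-attention kernel uses
    m, l, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), tile):
        s = scores[start:start + tile]
        v = values[start:start + tile]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # exp(-inf) = 0 on the first tile
        l = l * scale + sum(math.exp(x - m_new) for x in s)
        acc = acc * scale + sum(math.exp(x - m_new) * y for x, y in zip(s, v))
        m = m_new
    return acc / l
```

note that the slicing handles a tail shorter than the tile size for free; in the kernel that tail is exactly the edge you have to guard by hand.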
tiny rlhf pass
500 handwritten preference pairs, dpo for one epoch, beta 0.1. not enough to change the world. enough to see the log-ratio term do something.
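the objective itself is one line once you have summed log-probs for each response under the policy and under the frozen reference model (a sketch with made-up argument names, beta 0.1 as above):

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # each argument is a summed log-prob of a full response;
    # margin is the difference of the two policy/reference log-ratios
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    return math.log1p(math.exp(-beta * margin))  # -log sigmoid(beta * margin)
```

at margin zero the loss is log 2; pushing the chosen log-ratio above the rejected one is the "something" the term does.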
~420 loc, no magic
the whole training loop fits on two printouts. you can read it sunday morning with coffee and understand what's being optimized.
the honest part.
most of those 37 hours were not training. they were me staring at a nan at step 1,842 trying to figure out which layer ate it. the first two runs diverged silently around step 6k because i forgot to scale the attention logits by 1/sqrt(d_k). the third run diverged because i did, but i'd also normalized twice. the flash kernel segfaulted for a day and a half until i noticed i was indexing past the end of the tile on sequence lengths that weren't multiples of 64. every paper says "we use adamw with warmup and cosine decay" in one sentence. getting all three right is the project.
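for the record, the scaling those two runs got wrong, applied exactly once and nowhere else (a toy sketch, my naming):

```python
import math

def attention_logits(q, k_rows):
    # one query against T keys: divide by sqrt(d_k) exactly once,
    # keeping logit variance roughly independent of head dim
    d_k = len(q)
    return [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
            for k in k_rows]
```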
not public yet.
the repo isn't up yet; i'm still cleaning it up. when it ships it'll want pytorch 2.3+, triton 2.2+, and one cuda gpu with at least 12gb. email bennett@frkhd.com if you want to poke at the checkpoints early.