research notebook · pytorch · triton · mit

built a language model, one layer at a time.

roughly 10m parameters, 312m training tokens, on a single 4090 over two weekends. tokenizer, attention blocks, adamw, sampling loop: all handwritten. nothing imported from transformers. probably buggy in ways i haven't caught yet.

[run 03 · weekend two · checkpoint every 500]

step 04200/20000  loss 2.874  lr 3.0e-4  tok/s 38204
step 04250/20000  loss 2.861  lr 3.0e-4  tok/s 38190
step 04300/20000  loss 2.847  lr 3.0e-4  tok/s 38217
  ├─ sample @ T=0.8 ──────────────
    the quick brown fox jumped over
    the lazy dog and began to think
    about what it means to be alive
  └────────────────────────────────

step 04350/20000  loss 2.839  lr 2.9e-4  tok/s 38211
step 04400/20000  loss 2.821  lr 2.9e-4  tok/s 38196
loss is real. tok/s is real. the fox is overfitting.
10.4m   params
312m    train tokens
2.47    final val loss
~37h    wall-clock
~420    loc core

the training curve.

train / val loss   (every 200 steps, step 0 → 20000)

  loss
  4.20 │
  3.80 │ 
  3.40 │  
  3.00 │    ▓▓░░
  2.80 │       ▓▓▓░░░
  2.60 │           ▓▓▓▓░░░░
  2.47 │                ▓▓▓▓▓▓▓░░░░░░░
       └─────────────────────────────────────────
        0      5k      10k      15k      20k   steps

   ▓ train    ░ val       gap stable at ~0.11 nats

model architecture, at a glance.

nanolm/
  tokenizer.py       # bpe, 32,000 merges, from scratch
  model.py           # transformer, rope, gqa=4, ~260 loc
  optim.py           # adamw, decoupled decay, cosine+warmup
  flash.py           # triton kernel, tiled q·k·v softmax
  train.py           # main loop, grad accum, amp, ckpts
  sample.py          # top-k, top-p, temperature, repetition
  dpo.py             # 500 preference pairs, 1 epoch

shape of a forward pass:

  x        (B=64, T=512)
    │ embed + rope
    ▼
  h        (64, 512, 256)
    │ × 6 blocks: gqa attn → swiglu mlp
    ▼
  h        (64, 512, 256)
    │ rms norm → lm_head (tied)
    ▼
  logits   (64, 512, 32000)

every layer, on purpose.

bpe tokenizer from scratch

32k vocab, trained on a 400mb slice of the pile. merge table built pair-by-pair in pure python, then cached for speed. no sentencepiece, no tiktoken.
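
the merge loop at the core of it, sketched in pure python. `most_frequent_pair`, `merge`, and `train_bpe` are illustrative names, not the repo's; the real tokenizer.py is assumed to also handle caching and the full byte-level base vocab.

```python
from collections import Counter

def most_frequent_pair(ids):
    """count adjacent token-id pairs and return the most common one."""
    return Counter(zip(ids, ids[1:])).most_common(1)[0][0]

def merge(ids, pair, new_id):
    """replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, num_merges):
    """learn `num_merges` merges over the utf-8 bytes of `text`.
    new token ids start at 256, after the byte base vocab."""
    ids = list(text.encode("utf-8"))
    merges = {}
    for step in range(num_merges):
        pair = most_frequent_pair(ids)
        new_id = 256 + step
        ids = merge(ids, pair, new_id)
        merges[pair] = new_id
    return merges, ids
```

the real thing does the same loop 32,000 times over 400mb, which is exactly why the merge table gets cached.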

rotary attention + gqa

rope applied to q and k before the dot product. grouped-query attention with gqa=4, so inference is cheap and the kv-cache shrinks 4x.
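
the rotation itself, sketched on one head in pure python. dims are paired (0,1), (2,3), … and each pair is rotated by a position-dependent angle; the property that matters is that the dot product of rotated q and k depends only on relative position. a sketch of the idea, not the repo's exact layout.

```python
import math

def rope(vec, pos, base=10000.0):
    """apply a rotary embedding to `vec` at position `pos`: pair i
    (dims 2i, 2i+1) rotates by angle pos * base**(-2i/d)."""
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out[i] = x * c - y * s
        out[i + 1] = x * s + y * c
    return out
```

gqa is orthogonal to this: with 4 query heads sharing each kv head, only the kv heads land in the cache, which is where the 4x shrink comes from.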

adamw, cosine, warmup

decoupled weight decay, 200-step linear warmup, cosine decay to 10% of peak. no deepspeed, no accelerate, no fsdp. one gpu, one process.
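
the schedule in one function, using the numbers above (200-step linear warmup, cosine decay to 10% of a 3e-4 peak over 20k steps). a sketch; the real train.py may differ in off-by-one details.

```python
import math

def lr_at(step, total_steps=20000, peak=3e-4, warmup=200, floor_frac=0.10):
    """linear warmup to `peak`, then cosine decay to floor_frac * peak."""
    if step < warmup:
        return peak * (step + 1) / warmup
    progress = (step - warmup) / (total_steps - warmup)  # 0 → 1
    floor = peak * floor_frac
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))
```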

flash-attention v2, triton

tiled softmax kernel in triton. about 3.1x over the naive pytorch path at seq 512 on my box, and it actually fits the bf16 budget on a 4090.
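
the triton kernel isn't reproduced here, but the recurrence that makes tiling possible — online softmax with a running max, running normalizer, and running weighted sum — fits in a few lines of pure python:

```python
import math

def online_softmax_weighted_sum(scores, values, tile=4):
    """compute softmax(scores) @ values by streaming over tiles,
    never materializing the full softmax row. `m` is the running max,
    `l` the running normalizer, `acc` the running weighted sum; when a
    tile raises the max, the old state is rescaled by exp(m - m_new)."""
    m = float("-inf")
    l = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), tile):
        s = scores[start:start + tile]
        v = values[start:start + tile]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new) if m != float("-inf") else 0.0
        l = l * scale + sum(math.exp(x - m_new) for x in s)
        acc = [a * scale for a in acc]
        for sj, vj in zip(s, v):
            w = math.exp(sj - m_new)
            acc = [a + w * x for a, x in zip(acc, vj)]
        m = m_new
    return [a / l for a in acc]
```

same math the kernel does per q tile, just without the shared-memory choreography.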

tiny rlhf pass

500 handwritten preference pairs, dpo for one epoch, beta 0.1. not enough to change the world. enough to see the log-ratio term do something.
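
the per-pair loss, for the beta=0.1 setup above. argument names are illustrative; the log-probs are summed over the completion tokens under the policy and the frozen reference.

```python
import math

def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
    """dpo loss for one preference pair: -log sigmoid of beta times the
    difference of policy-vs-reference log-ratios."""
    logits = beta * ((logp_chosen - ref_chosen) - (logp_rejected - ref_rejected))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))
```

at initialization the policy equals the reference, both log-ratios are zero, and the loss sits at log 2 ≈ 0.693 — watching it drop below that is the "do something" part.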

~420 loc, no magic

the whole training loop fits on two printouts. you can read it sunday morning with coffee and understand what's being optimized.
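
the skeleton of that loop, with the model swapped for a toy quadratic so it runs anywhere. gradient accumulation is the part worth seeing; amp and checkpointing are omitted, and nothing here is copied from the repo.

```python
def train(steps=100, accum=4, lr=0.1):
    """accumulate gradients over `accum` micro-batches, then take one
    optimizer step. the 'model' is f(w) = (w - 3)^2 standing in for
    the transformer; plain sgd standing in for adamw."""
    w = 0.0
    for step in range(steps):
        grad = 0.0
        for micro in range(accum):
            # per-micro-batch gradient, averaged so the effective
            # batch size changes but the gradient scale doesn't
            grad += 2.0 * (w - 3.0) / accum
        w -= lr * grad
    return w
```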

the honest part.

most of those 37 hours were not training. they were me staring at a nan at step 1,842 trying to figure out which layer ate it. the first two runs diverged silently around step 6k because i forgot to scale the attention logits by 1/sqrt(d_k). the third run diverged because i did, but i'd also normalized twice. the flash kernel segfaulted for a day and a half until i noticed i was indexing past the end of the tile on sequence lengths that weren't multiples of 64. every paper says "we use adamw with warmup and cosine decay" in one sentence. getting all three right is the project.
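
why the missing 1/sqrt(d_k) kills a run: q·k dot products of unit-variance vectors grow like sqrt(d_k), and unscaled logits of that magnitude saturate the softmax to near one-hot, so almost no gradient flows to the other positions. illustrative numbers, not from the run:

```python
import math

def softmax(xs):
    """numerically stable softmax over a list of floats."""
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

d_k = 64
raw = [8.0, -3.0, 2.0, -5.0]                   # magnitudes ~sqrt(d_k)
scaled = [x / math.sqrt(d_k) for x in raw]     # what the fix restores
```

the raw row softmaxes to >0.99 on one position; the scaled row stays spread out.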

not public yet.

the repo isn't up yet; i'm still cleaning it up. when it ships it'll want pytorch 2.3+, triton 2.2+, and one cuda gpu with at least 12gb. email bennett@frkhd.com if you want to poke at the checkpoints early.