build-your-own-x/llm/model.py
Claude 1d3ce8cff7
Add basic LLM implementation from scratch
Implements a character-level GPT-style Transformer:
- model.py: CausalSelfAttention, FeedForward, TransformerBlock, LLM
- tokenizer.py: CharTokenizer (char -> int mapping)
- train.py: training loop with AdamW, gradient clipping, checkpointing, sampling
- generate.py: load checkpoint and generate text from a prompt

Verified working on a built-in Shakespeare excerpt (805k param model).

https://claude.ai/code/session_01SWXLQb3nFTiygbp74dpjVa
2026-03-22 22:51:49 +00:00

171 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Basic LLM: a character-level GPT-style Transformer built from scratch.
Architecture:
- Token + positional embeddings
- N Transformer blocks (masked multi-head self-attention + feed-forward)
- Layer norm + linear head
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (look-ahead) mask.

    Each position attends only to itself and to earlier positions.

    Args:
        config: object exposing ``n_embd``, ``n_heads``, ``block_size`` and
            ``dropout`` attributes.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O`` and a bad config would then only fail inside ``view``.
        if config.n_embd % config.n_heads != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.n_embd // config.n_heads
        self.n_embd = config.n_embd
        # Fused Q/K/V projection: one matmul produces all three, split in forward().
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Causal mask: True strictly above the diagonal marks future positions to
        # hide. Registered as a buffer, so it follows .to(device) and is kept in
        # state_dict, which keeps existing checkpoints loadable.
        mask = torch.triu(torch.ones(config.block_size, config.block_size), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (B, T, n_embd).

        Raises:
            ValueError: if T exceeds the block_size the mask was built for.
        """
        B, T, C = x.shape
        max_len = self.mask.size(0)
        if T > max_len:
            # Without this check, the sliced mask would be too small and
            # masked_fill would fail with an opaque shape error.
            raise ValueError(f"Sequence length {T} exceeds block_size {max_len}")
        q, k, v = self.qkv(x).split(self.n_embd, dim=2)

        # Reshape to (B, n_heads, T, head_dim)
        def reshape(t):
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = reshape(q), reshape(k), reshape(v)
        # Scaled dot-product attention
        scale = math.sqrt(self.head_dim)
        scores = (q @ k.transpose(-2, -1)) / scale  # (B, nh, T, T)
        scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        out = weights @ v  # (B, nh, T, hd)
        # Merge heads back: (B, nh, T, hd) -> (B, T, nh*hd) == (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)
class FeedForward(nn.Module):
    """Position-wise MLP applied independently to every token.

    Widens ``n_embd`` to ``4 * n_embd``, applies GELU, projects back down to
    ``n_embd`` and finishes with dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # Layer order fixes the state_dict keys (net.0.*, net.2.*); keep it stable
        # so saved checkpoints still load.
        layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through expand -> GELU -> contract -> dropout."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
class TransformerBlock(nn.Module):
    """One pre-LayerNorm Transformer layer.

    The input is normalised before each sub-layer (attention first, then the
    feed-forward MLP). Each sub-layer's output is added back onto its input as
    a residual.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names define the state_dict keys; construction order is kept
        # so parameter initialisation consumes the RNG identically.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ff = FeedForward(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
class LLM(nn.Module):
    """
    Tiny GPT-style language model.

    config fields:
        vocab_size  number of tokens
        block_size  maximum context length
        n_embd      embedding dimension
        n_heads     number of attention heads
        n_layers    number of Transformer blocks
        dropout     dropout probability
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # Learned absolute positional embeddings, one row per position < block_size.
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: share token embedding and output projection weights
        self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        The tied head/embedding matrix is visited twice by ``apply``. Both passes
        use the same distribution, so the tying is unaffected.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer token ids, T <= config.block_size.
            targets: optional (B, T) integer token ids to predict.

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is a scalar
            tensor when ``targets`` is given, otherwise None.

        Raises:
            ValueError: if T exceeds config.block_size.
        """
        _, T = idx.shape
        # Explicit raise (not assert) so the check survives ``python -O``.
        # Otherwise pos_emb would be indexed out of range.
        if T > self.config.block_size:
            raise ValueError("Sequence longer than block_size")
        positions = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets (sliced/transposed) work.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively generate tokens given a conditioning sequence.

        Args:
            idx             (B, T) tensor of starting token ids
            max_new_tokens  number of tokens to generate
            temperature     >1 = more random, <1 = more focused; 0 = greedy (argmax)
            top_k           if a positive int, restrict sampling to the top-k logits;
                            None or <= 0 disables the filter

        Returns:
            (B, T + max_new_tokens) tensor: ``idx`` followed by the generated ids.

        The model is put in eval mode for the duration of the call, so dropout
        does not perturb sampling. Its previous train/eval mode is restored
        afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop to the last block_size tokens: later positions have no embedding.
                idx_cond = idx[:, -self.config.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]  # last time-step, (B, vocab_size)
                if temperature == 0:
                    # Greedy decoding; dividing by 0 would yield NaN probabilities.
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        # Mask everything below each row's k-th largest logit.
                        logits[logits < v[:, [-1]]] = float("-inf")
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_token], dim=1)
        finally:
            self.train(was_training)
        return idx

    def num_params(self):
        """Return the total parameter count (the tied head/embedding weight counts once)."""
        return sum(p.numel() for p in self.parameters())