mirror of
https://github.com/codecrafters-io/build-your-own-x
synced 2026-07-02 16:59:25 +00:00
Implements a character-level GPT-style Transformer:
- model.py: CausalSelfAttention, FeedForward, TransformerBlock, LLM
- tokenizer.py: CharTokenizer (char -> int mapping)
- train.py: training loop with AdamW, gradient clipping, checkpointing, sampling
- generate.py: load checkpoint and generate text from a prompt

Verified working on a built-in Shakespeare excerpt (805k-parameter model).
https://claude.ai/code/session_01SWXLQb3nFTiygbp74dpjVa
171 lines
5.8 KiB
Python
171 lines
5.8 KiB
Python
"""
|
||
Basic LLM: a character-level GPT-style Transformer built from scratch.
|
||
|
||
Architecture:
|
||
- Token + positional embeddings
|
||
- N Transformer blocks (masked multi-head self-attention + feed-forward)
|
||
- Layer norm + linear head
|
||
"""
|
||
|
||
import math
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class CausalSelfAttention(nn.Module):
    """Masked multi-head self-attention: each position attends only to itself and earlier positions."""

    def __init__(self, config):
        """Build the projections and the causal mask.

        Reads ``n_embd``, ``n_heads``, ``block_size`` and ``dropout`` from
        ``config``; ``n_embd`` must be divisible by ``n_heads``.
        """
        super().__init__()
        embd, heads = config.n_embd, config.n_heads
        assert embd % heads == 0
        self.n_heads = heads
        self.head_dim = embd // heads
        self.n_embd = embd

        # One matmul produces queries, keys and values side by side.
        self.qkv = nn.Linear(embd, 3 * embd, bias=False)
        self.proj = nn.Linear(embd, embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

        # True strictly above the diagonal, i.e. at every (query, key) pair
        # where the key lies in the future. Stored as a buffer so it follows
        # the module across devices without being a trainable parameter.
        all_pairs = torch.ones(config.block_size, config.block_size)
        self.register_buffer("mask", torch.triu(all_pairs, diagonal=1).bool())

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C) and return (B, T, C)."""
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.n_heads, self.head_dim)

        # Split the fused projection, then move the head axis forward:
        # (B, T, C) -> (B, n_heads, T, head_dim)
        query, key, value = (
            part.view(per_head).transpose(1, 2)
            for part in self.qkv(x).split(self.n_embd, dim=2)
        )

        # Scaled dot-product scores, (B, nh, T, T); future keys get -inf so
        # softmax assigns them zero weight.
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:seq_len, :seq_len], float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads back: (B, nh, T, hd) -> (B, T, C)
        merged = (attn @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)
|
||
|
||
|
||
class FeedForward(nn.Module):
    """Per-position MLP: widen to 4x the embedding size, apply GELU, narrow back, then dropout."""

    def __init__(self, config):
        """Build the MLP from ``config.n_embd`` and ``config.dropout``."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        stages = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension is ``n_embd`` on input and output."""
        out = self.net(x)
        return out
|
||
|
||
|
||
class TransformerBlock(nn.Module):
    """One Transformer layer: attention and feed-forward sub-layers, each applied
    to a layer-normalised input and added back onto the residual stream."""

    def __init__(self, config):
        """Create both sub-layers and their layer norms from ``config``."""
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.ff = FeedForward(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # Normalise first, then add the sub-layer output to the un-normalised input.
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))
|
||
|
||
|
||
class LLM(nn.Module):
    """
    Tiny GPT-style language model.

    Token and learned positional embeddings are summed, passed through
    ``n_layers`` Transformer blocks and a final layer norm, then projected
    to vocabulary logits by a linear head whose weight is shared with the
    token embedding.

    config fields:
        vocab_size  – number of tokens
        block_size  – maximum context length
        n_embd      – embedding dimension
        n_heads     – number of attention heads
        n_layers    – number of Transformer blocks
        dropout     – dropout probability
    """

    def __init__(self, config):
        """Build all sub-modules from ``config`` and initialise their weights."""
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: share token embedding and output projection weights
        self.head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Called once per sub-module via ``nn.Module.apply``; other module
        types keep their default initialisation.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Compute next-token logits and, optionally, the training loss.

        Args:
            idx      – (B, T) tensor of token ids, with T <= block_size
            targets  – optional tensor of target token ids with one entry
                       per position of ``idx``

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is the
            cross-entropy over all positions, or None when ``targets`` is
            not given.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, "Sequence longer than block_size"

        # The (T, n_embd) positional embeddings broadcast over the batch dimension.
        positions = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively generate tokens given a conditioning sequence.

        This method does not change the module's train/eval mode; call
        ``eval()`` first if dropout should be disabled while sampling.

        Args:
            idx             – (B, T) tensor of starting token ids
            max_new_tokens  – number of tokens to generate
            temperature     – >1 = more random, <1 = more focused;
                              0 selects the most likely token at every step
                              (greedy decoding, no sampling)
            top_k           – if set, restrict sampling to the top-k logits
                              (has no effect on greedy decoding)

        Returns:
            (B, T + max_new_tokens) tensor: ``idx`` followed by the new tokens.
        """
        block_size = self.config.block_size
        # Dividing logits by a zero temperature would yield inf/NaN
        # probabilities, so temperature 0 is handled as its limit: argmax.
        greedy = temperature == 0

        for _ in range(max_new_tokens):
            # Crop to block_size if needed
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # last time-step

            if greedy:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Anything below the k-th best logit gets zero probability.
                    logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            idx = torch.cat([idx, next_token], dim=1)

        return idx

    def num_params(self):
        """Return the total number of elements across all parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters())
|