build-your-own-x/llm/generate.py
Claude 1d3ce8cff7
Add basic LLM implementation from scratch
Implements a character-level GPT-style Transformer:
- model.py: CausalSelfAttention, FeedForward, TransformerBlock, LLM
- tokenizer.py: CharTokenizer (char -> int mapping)
- train.py: training loop with AdamW, gradient clipping, checkpointing, sampling
- generate.py: load checkpoint and generate text from a prompt

Verified working on a built-in Shakespeare excerpt (805k param model).

https://claude.ai/code/session_01SWXLQb3nFTiygbp74dpjVa
2026-03-22 22:51:49 +00:00

50 lines
1.6 KiB
Python

"""
Generate text from a saved checkpoint.
Usage:
python generate.py --checkpoint ckpt.pt --prompt "First Citizen:" --tokens 300
"""
import argparse
import torch
from model import LLM
def generate(args):
    """Load a checkpoint, sample from the model, and print the decoded text.

    Args:
        args: Namespace-like object with the attributes
            ``checkpoint`` (path passed to ``torch.load``),
            ``prompt`` (seed text; empty or ``None`` means no seed),
            ``tokens`` (forwarded as ``max_new_tokens``),
            ``temperature`` and ``top_k`` (``0`` or ``None`` disables
            top-k filtering).

    Returns:
        None. The first row of the tensor returned by ``model.generate`` is
        decoded with the checkpoint's tokenizer and written to stdout.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. It appears to be required here because
    # the checkpoint also stores a tokenizer object (read below), so only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    config = ckpt["config"]
    tokenizer = ckpt["tokenizer"]

    model = LLM(config).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    if args.prompt:
        # Shape (1, len(prompt tokens)): a batch containing the one prompt.
        idx = torch.tensor(
            tokenizer.encode(args.prompt), dtype=torch.long, device=device
        ).unsqueeze(0)
    else:
        # No seed text: start from token id 0.
        # NOTE(review): presumably id 0 is the newline character in the
        # character vocabulary -- confirm against the tokenizer.
        idx = torch.zeros((1, 1), dtype=torch.long, device=device)

    # The CLI documents 0 as "top-k off"; normalise it here as well so that
    # programmatic callers get the same behaviour as the command line.
    top_k = args.top_k or None

    # Inference only: skip autograd bookkeeping while sampling.
    with torch.no_grad():
        out = model.generate(
            idx,
            max_new_tokens=args.tokens,
            temperature=args.temperature,
            top_k=top_k,
        )
    print(tokenizer.decode(out[0].tolist()))
def parse_args(argv=None):
    """Parse the command-line options for text generation.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with ``checkpoint``, ``prompt``, ``tokens``,
        ``temperature`` and ``top_k``.
    """
    p = argparse.ArgumentParser(description="Generate text from a trained LLM checkpoint")
    p.add_argument("--checkpoint", type=str, default="ckpt.pt", help="Path to checkpoint")
    p.add_argument("--prompt", type=str, default="", help="Seed text")
    p.add_argument("--tokens", type=int, default=300, help="Tokens to generate")
    p.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=20, help="Top-k sampling (0 = off)")
    return p.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then run generation.
    args = parse_args()
    # 0 is the command-line spelling of "top-k disabled"; generate() is
    # handed None in that case.
    args.top_k = None if args.top_k == 0 else args.top_k
    generate(args)