mirror of
https://github.com/codecrafters-io/build-your-own-x
synced 2026-07-02 16:59:25 +00:00
Implements a character-level GPT-style Transformer: - model.py: CausalSelfAttention, FeedForward, TransformerBlock, LLM - tokenizer.py: CharTokenizer (char -> int mapping) - train.py: training loop with AdamW, gradient clipping, checkpointing, sampling - generate.py: load checkpoint and generate text from a prompt Verified working on a built-in Shakespeare excerpt (805k param model). https://claude.ai/code/session_01SWXLQb3nFTiygbp74dpjVa
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
"""
|
|
Generate text from a saved checkpoint.
|
|
|
|
Usage:
|
|
python generate.py --checkpoint ckpt.pt --prompt "First Citizen:" --tokens 300
|
|
"""
|
|
|
|
import argparse
|
|
import torch
|
|
from model import LLM
|
|
|
|
|
|
def generate(args):
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
|
|
config = ckpt["config"]
|
|
tokenizer = ckpt["tokenizer"]
|
|
|
|
model = LLM(config).to(device)
|
|
model.load_state_dict(ckpt["model"])
|
|
model.eval()
|
|
|
|
prompt = args.prompt or ""
|
|
if prompt:
|
|
idx = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
|
|
else:
|
|
# Start from a single newline token
|
|
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
|
|
|
|
out = model.generate(idx, max_new_tokens=args.tokens,
|
|
temperature=args.temperature, top_k=args.top_k)
|
|
print(tokenizer.decode(out[0].tolist()))
|
|
|
|
|
|
def parse_args():
|
|
p = argparse.ArgumentParser(description="Generate text from a trained LLM checkpoint")
|
|
p.add_argument("--checkpoint", type=str, default="ckpt.pt", help="Path to checkpoint")
|
|
p.add_argument("--prompt", type=str, default="", help="Seed text")
|
|
p.add_argument("--tokens", type=int, default=300, help="Tokens to generate")
|
|
p.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
|
|
p.add_argument("--top_k", type=int, default=20, help="Top-k sampling (0 = off)")
|
|
return p.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
if args.top_k == 0:
|
|
args.top_k = None
|
|
generate(args)
|