"""
Simple test script for the trained model
"""
import torch
import tiktoken

from model import GPTConfig, GPT


def test_model():
    # Pick a device: prefer Apple's MPS backend, fall back to CPU elsewhere
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # Load the checkpoint and rebuild the model from its saved config
    ckpt_path = "out-srs/ckpt_000600.pt"
    print(f"Loading {ckpt_path}...")
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Load weights, stripping the '_orig_mod.' prefix that torch.compile
    # adds to parameter names when a compiled model is checkpointed
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print(f"Model loaded! (iteration {checkpoint['iter_num']})")
    # Test generation with the GPT-2 BPE tokenizer
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)
    prompt = "Hello, how are you?"
    print(f"\nPrompt: {prompt}")
    start_ids = encode(prompt)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]  # shape (1, T)
    # Sample 50 new tokens; top_k=200 restricts sampling to the 200 most
    # likely tokens at each step
    with torch.no_grad():
        y = model.generate(x, 50, temperature=0.8, top_k=200)
    result = decode(y[0].tolist())
    print(f"Generated: {result}")
    return True
if __name__ == "__main__":
test_model()
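
# Sketch of drawing several samples at once by batching the prompt (assumes
# the same nanoGPT-style generate() signature used above); left commented out
# so the script's behavior is unchanged:
#
#   x = torch.tensor(start_ids, dtype=torch.long, device=device)
#   x = x.unsqueeze(0).repeat(3, 1)  # batch of 3 identical prompts
#   y = model.generate(x, 50, temperature=0.8, top_k=200)
#   for row in y:
#       print(decode(row.tolist()))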