File size: 1,366 Bytes
9a570a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Simple test script for the trained model
"""
import os
import torch
import tiktoken
from model import GPTConfig, GPT

def test_model(ckpt_path="out-srs/ckpt_000600.pt",
               prompt="Hello, how are you?",
               max_new_tokens=50,
               device=None):
    """Load a trained GPT checkpoint and run a short sample generation.

    Acts as a smoke test: loads the checkpoint, restores the weights,
    and prints one sampled continuation of *prompt*.

    Args:
        ckpt_path: Path to the training checkpoint (.pt file).
        prompt: Text used to condition the generation.
        max_new_tokens: Number of tokens to sample after the prompt.
        device: Torch device string. When None, auto-detects "mps" and
            falls back to "cpu", so the script also runs off Apple
            hardware instead of hard-coding "mps".

    Returns:
        True when generation completed (kept for callers that treat the
        return value as a pass signal).
    """
    # Resolve the device once instead of repeating a literal in four places.
    if device is None:
        device = "mps" if torch.backends.mps.is_available() else "cpu"

    print(f"Loading {ckpt_path}...")
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Strip the torch.compile() wrapper prefix so the raw (uncompiled)
    # module accepts the saved weights.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    print(f"Model loaded! (iteration {checkpoint['iter_num']})")

    # GPT-2 BPE tokenizer; permit the EOT special token inside prompts.
    enc = tiktoken.get_encoding("gpt2")

    print(f"\nPrompt: {prompt}")

    start_ids = enc.encode(prompt, allowed_special={"<|endoftext|>"})
    # Shape (1, T): generate() expects a batch dimension.
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    with torch.no_grad():
        y = model.generate(x, max_new_tokens, temperature=0.8, top_k=200)
        result = enc.decode(y[0].tolist())

    print(f"Generated: {result}")
    return True

# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_model()