# babbleGPT / gradio_demo.py
"""
Gradio interface for testing the trained nanoGPT model
"""
import os
import gradio as gr
import torch
import tiktoken
from model import GPTConfig, GPT
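
# GPTConfig and GPT are nanoGPT's model classes; model.py from the training
# code is expected to sit alongside this script.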
# Configuration
MODEL_DIR = "out-srs" # Change this to your model directory
DEVICE = "cpu" # Hugging Face Spaces uses CPU
MAX_TOKENS = 100
TEMPERATURE = 0.8
TOP_K = 200
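
# A possible tweak when running outside Spaces (not used here): pick the device
# dynamically, e.g. DEVICE = "cuda" if torch.cuda.is_available() else "cpu"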

def load_model():
    """Load a known-good checkpoint from the model directory"""
    print(f"Loading model from {MODEL_DIR}...")

    # Use a specific checkpoint that we know is complete
    ckpt_path = os.path.join(MODEL_DIR, 'ckpt_001000.pt')
    print(f"Loading checkpoint: {ckpt_path}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)

    # Create model
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Load weights, stripping the '_orig_mod.' prefix that torch.compile adds
    # to state dict keys when a compiled model is checkpointed
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

    model.eval()
    model.to(DEVICE)
    print(f"Model loaded successfully! (iteration {checkpoint['iter_num']})")
    return model
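
# One possible alternative to the hardcoded checkpoint above (not wired in):
# pick the newest checkpoint by name, assuming the zero-padded ckpt_NNNNNN.pt
# naming used by this training run. The helper name is illustrative only.
def latest_checkpoint_path(model_dir=MODEL_DIR):
    import glob
    ckpts = sorted(glob.glob(os.path.join(model_dir, 'ckpt_*.pt')))
    return ckpts[-1] if ckpts else os.path.join(model_dir, 'ckpt.pt')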

def load_tokenizer():
    """Load the tokenizer"""
    # Check if there's a meta.pkl file for a custom (character-level) tokenizer
    meta_path = os.path.join('data', 'srs', 'meta.pkl')
    if os.path.exists(meta_path):
        import pickle
        print(f"Loading tokenizer from {meta_path}")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        # stoi/itos are the character-to-id and id-to-character maps
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[c] for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
    else:
        print("Using GPT-2 tokenizer")
        enc = tiktoken.get_encoding("gpt2")
        encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
        decode = lambda l: enc.decode(l)
    return encode, decode

# Load model and tokenizer once at startup
print("Initializing model...")
model = load_model()
encode, decode = load_tokenizer()
print("Ready!")

def generate_text(prompt, max_tokens, temperature, top_k):
    """Generate text from the model"""
    try:
        # Encode the prompt
        start_ids = encode(prompt)
        x = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]

        # Generate (sliders may deliver floats, so cast the integer arguments)
        with torch.no_grad():
            y = model.generate(x, int(max_tokens), temperature=temperature, top_k=int(top_k))
        generated = decode(y[0].tolist())
        return generated
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="SRS Conversational Model") as demo:
    gr.Markdown("# SRS Conversational Model")
    gr.Markdown("This model was trained on conversational data. Enter a prompt to see how it continues the conversation!")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'Hello, how are you?')",
                lines=3
            )

            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10, maximum=200, value=MAX_TOKENS, step=10,
                    label="Max tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=TEMPERATURE, step=0.1,
                    label="Temperature (creativity)"
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=500, value=TOP_K, step=10,
                    label="Top-k (diversity)"
                )

            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=15
            )

    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you?", 100, 0.8, 200],
            ["I think the wedding", 80, 0.7, 150],
            ["So anyway, let's talk about", 120, 0.9, 200],
            ["You know what's interesting", 100, 0.8, 200]
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider]
    )

    # Connect the generate button
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_text
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    print("Will be available at http://localhost:7860")
    print("Use share=True for public link")

    # Launch for Hugging Face Spaces
    demo.launch()
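
# When running locally and a public URL is wanted, Gradio supports
# demo.launch(share=True); on Hugging Face Spaces the default launch() is fine.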