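"""Gradio demo for nanochat checkpoints.

Downloads nanochat GPT checkpoints from the Hugging Face Hub (built via the
local gpt_infer module) and serves two tabs: single-model generation, and a
side-by-side comparison of the SFT, MID, and BASE checkpoints.
"""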
import os
import json
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as HFTokenizer
from gpt_infer import GPT, GPTConfig
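
# gpt_infer is assumed to be a local module shipped alongside this app
# (it is not a published package).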

DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
ALL_MODELS = [
    'loocorez/nanochat-sft-d20-step650',
    'loocorez/nanochat-mid-d20-step765',
    'loocorez/nanochat-base-d20-step21400',
]
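# Assumption: the suffixes above follow nanochat's training stages (base =
# pretrained, mid = midtrained, sft = supervised fine-tuned); d20 is the model
# depth and stepNNN the checkpoint step.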


@torch.inference_mode()
def load_model(repo_id: str):
    """Download a nanochat checkpoint from the Hub and build the model and tokenizer."""
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    # Prefer safetensors weights; fall back to a pickled PyTorch state dict.
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        sd = load_file(weights_path)
    except Exception:
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        sd = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(sd, strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok
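
# Minimal usage sketch (assumes the checkpoints above are reachable):
#   model, tok = load_model(DEFAULT_MODEL)
#   n_params = sum(p.numel() for p in model.parameters())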


# Simple in-process cache so each checkpoint is loaded at most once.
model_cache = {}


def get_model(repo_id: str):
    if repo_id not in model_cache:
        model_cache[repo_id] = load_model(repo_id)
    return model_cache[repo_id]


@torch.inference_mode()
def generate(repo_id: str, system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    model, tok = get_model(repo_id)
    bos_id = tok.token_to_id('<|bos|>')
    # Combine system + user prompt; prepend BOS if the tokenizer defines one.
    text = prompt if not system_prompt else f"{system_prompt.strip()}\n{prompt}"
    ids = tok.encode(text).ids
    if bos_id is not None:
        ids = [bos_id] + ids
    out_tokens = []
    for token in model.generate(ids, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None):
        out_tokens.append(token)
    return tok.decode(out_tokens, skip_special_tokens=False)
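
# Direct-call sketch (the same arguments the UI passes through):
#   out = generate(DEFAULT_MODEL, 'You are a helpful assistant.',
#                  'Hello!', max_tokens=64, temperature=0.8, top_k=40)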


@torch.inference_mode()
def compare_three(system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    # Run the same prompt through every checkpoint in ALL_MODELS (SFT, MID, BASE order).
    outputs = []
    for repo_id in ALL_MODELS:
        outputs.append(generate(repo_id, system_prompt, prompt, max_tokens, temperature, top_k))
    return tuple(outputs)


with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Run a single model or compare SFT/MID/BASE side by side.')
    with gr.Tabs():
        with gr.Tab('Single'):
            repo = gr.Dropdown(choices=ALL_MODELS, value=DEFAULT_MODEL, label='Model Repo')
            system = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
            prompt = gr.Textbox(label='User prompt', lines=6)
            with gr.Row():
                max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
                temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
                top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
            btn = gr.Button('Generate')
            output = gr.Textbox(label='Output', lines=10)
            btn.click(
                fn=lambda r, s, p, m, t, k: generate(r, s, p, int(m), float(t), int(k) if int(k) > 0 else None),
                inputs=[repo, system, prompt, max_tokens, temperature, top_k],
                outputs=output,
            )
        with gr.Tab('Compare 3'):
            system_c = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
            prompt_c = gr.Textbox(label='User prompt', lines=6)
            with gr.Row():
                max_tokens_c = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
                temperature_c = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
                top_k_c = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
            btn_c = gr.Button('Run on all three')
            with gr.Row():
                out_sft = gr.Textbox(label='SFT', lines=10)
                out_mid = gr.Textbox(label='MID', lines=10)
                out_base = gr.Textbox(label='BASE', lines=10)
            btn_c.click(
                fn=lambda s, p, m, t, k: compare_three(s, p, int(m), float(t), int(k) if int(k) > 0 else None),
                inputs=[system_c, prompt_c, max_tokens_c, temperature_c, top_k_c],
                outputs=[out_sft, out_mid, out_base],
            )
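
# Note: despite the 'ZeroGPU' title, nothing here moves weights to a GPU;
# generation runs on CPU unless gpt_infer handles device placement internally.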

if __name__ == '__main__':
    demo.launch()