Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from tokenizers import Tokenizer as HFTokenizer | |
| from gpt_infer import GPT, GPTConfig | |
# Hub repositories offered in the UI. Order matters: compare_three() walks this
# list in sequence, and its results fill the SFT / MID / BASE boxes in turn.
ALL_MODELS = [
    'loocorez/nanochat-sft-d20-step650',
    'loocorez/nanochat-mid-d20-step765',
    'loocorez/nanochat-base-d20-step21400',
]

# Repository preselected in the "Single" tab; override with the
# NANOCHAT_DEFAULT_MODEL environment variable.
DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
def _load_state_dict(repo_id: str):
    """Fetch and load the checkpoint weights of *repo_id* (safetensors preferred)."""
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        return load_file(weights_path)
    except Exception:
        # Deliberate best-effort fallback: safetensors not installed, no
        # model.safetensors in the repo, or it failed to load. The fallback runs
        # inside the except block so that, if it fails too, the original error
        # is chained as __context__ instead of being lost.
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run code on load. A
        # strict load_state_dict needs a plain tensor mapping anyway.
        return torch.load(weights_path, map_location='cpu', weights_only=True)


def load_model(repo_id: str):
    """Download a nanochat checkpoint from the Hugging Face Hub and build it.

    Fetches ``config.json``, the weights (``model.safetensors``, falling back
    to ``pytorch_model.bin``) and ``tokenizer.json`` from *repo_id*.

    Args:
        repo_id: Hub repository id, e.g. ``'loocorez/nanochat-sft-d20-step650'``.

    Returns:
        ``(model, tok)``: a ``GPT`` in eval mode whose parameters were assigned
        from the checkpoint (strict key match), and the ``tokenizers.Tokenizer``
        loaded from the repo's ``tokenizer.json``.

    Raises:
        KeyError: if ``config.json`` lacks ``vocab_size``, ``n_layer``,
            ``n_head`` or ``n_embd``.
        Exception: whatever the Hub download / weight loading raises when
            neither weight file can be fetched and loaded.
    """
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        # No explicit KV-head count in the config: use one KV head per query head.
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    # assign=True adopts the loaded tensors directly instead of copying them
    # into the freshly initialised parameters.
    model.load_state_dict(_load_state_dict(repo_id), strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok
# Process-wide cache: repo_id -> (model, tokenizer) as returned by load_model().
model_cache = {}


def get_model(repo_id: str):
    """Return ``(model, tokenizer)`` for *repo_id*, loading it on first use only."""
    if repo_id in model_cache:
        return model_cache[repo_id]
    model_cache[repo_id] = load_model(repo_id)
    return model_cache[repo_id]
def generate(repo_id: str, system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
    """Sample a completion for *prompt* from the model stored in *repo_id*.

    The optional system prompt is placed before the user prompt, separated by
    a blank line. The ``<|bos|>`` token is prepended when the tokenizer defines
    one. The text is fed raw; no chat template is applied.

    Args:
        repo_id: Hub repo of the model to use (loaded once, then cached).
        system_prompt: Optional text placed before the user prompt. Empty or
            whitespace-only values are ignored.
        prompt: User prompt text.
        max_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff; ``None`` or ``0`` disables it.

    Returns:
        The decoded text of the tokens yielded by ``model.generate``, with
        special tokens kept. Returns ``''`` when there is no input token at all.
        Presumably only the new tokens are returned, not the prompt; this
        depends on ``gpt_infer.GPT.generate``, so confirm.
    """
    model, tok = get_model(repo_id)
    bos_id = tok.token_to_id('<|bos|>')
    system_prompt = (system_prompt or '').strip()
    prompt = prompt or ''
    # Escaped "\n\n": a raw line break inside a single-quoted f-string is a
    # SyntaxError, which prevents this module from being imported at all.
    text = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
    ids = tok.encode(text).ids
    if bos_id is not None:
        ids = [bos_id] + ids
    if not ids:
        # Empty prompt and no BOS token: nothing to condition the model on.
        return ''
    # Sampling needs no autograd graph; inference_mode avoids tracking it.
    with torch.inference_mode():
        out_tokens = list(
            model.generate(ids, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None)
        )
    return tok.decode(out_tokens, skip_special_tokens=False)
def compare_three(system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
    """Run one prompt through every repo in ALL_MODELS, sequentially and in order.

    Returns:
        A tuple with one ``generate`` result per entry of ALL_MODELS, in the
        same order (the order of the SFT / MID / BASE output boxes).
    """
    return tuple(
        generate(model_repo, system_prompt, prompt, max_tokens, temperature, top_k)
        for model_repo in ALL_MODELS
    )
def _top_k_from_slider(k):
    """Map the Top-k slider value to generate()'s ``top_k``; 0 means disabled (None)."""
    k = int(k)
    return k if k > 0 else None


def _prompt_controls():
    """Create the system/user prompt boxes and the row of sampling sliders.

    Returns:
        ``(system, prompt, max_tokens, temperature, top_k)`` components, created
        inside whatever Gradio layout context is active at call time.
    """
    system_box = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
    prompt_box = gr.Textbox(label='User prompt', lines=6)
    with gr.Row():
        max_tokens_slider = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
        temperature_slider = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
        top_k_slider = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
    return system_box, prompt_box, max_tokens_slider, temperature_slider, top_k_slider


with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Run a single model or compare SFT/MID/BASE side by side.')
    with gr.Tabs():
        with gr.Tab('Single'):
            repo = gr.Dropdown(choices=ALL_MODELS, value=DEFAULT_MODEL, label='Model Repo')
            system, prompt, max_tokens, temperature, top_k = _prompt_controls()
            btn = gr.Button('Generate')
            output = gr.Textbox(label='Output', lines=10)
            # Handlers stay lambdas so Gradio derives the same endpoint names.
            # Slider values are coerced to the types generate() expects.
            btn.click(
                fn=lambda r, s, p, m, t, k: generate(r, s, p, int(m), float(t), _top_k_from_slider(k)),
                inputs=[repo, system, prompt, max_tokens, temperature, top_k],
                outputs=output,
            )
        with gr.Tab('Compare 3'):
            system_c, prompt_c, max_tokens_c, temperature_c, top_k_c = _prompt_controls()
            btn_c = gr.Button('Run on all three')
            with gr.Row():
                out_sft = gr.Textbox(label='SFT', lines=10)
                out_mid = gr.Textbox(label='MID', lines=10)
                out_base = gr.Textbox(label='BASE', lines=10)
            # Output order must match the order of ALL_MODELS.
            btn_c.click(
                fn=lambda s, p, m, t, k: compare_three(s, p, int(m), float(t), _top_k_from_slider(k)),
                inputs=[system_c, prompt_c, max_tokens_c, temperature_c, top_k_c],
                outputs=[out_sft, out_mid, out_base],
            )
def main():
    """Launch the Gradio app ``demo`` defined in this module."""
    demo.launch()


if __name__ == '__main__':
    main()