import os

# The HF cache env vars must be set before transformers is imported, or they
# have no effect; Spaces only guarantees /tmp is writable.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")

import pickle

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

MODEL_ID = "loocorez/nanochat-base-d20-test"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model via the Auto* API; the repo ships its own modeling code,
# hence trust_remote_code=True.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
model.eval()
# Load tokenizer.pkl directly (avoids AutoTokenizer mapping issues).
tok_path = hf_hub_download(MODEL_ID, filename="tokenizer.pkl")

class PklTokenizer:
    """Thin wrapper around the pickled tokenizer (a tiktoken-style Encoding)."""

    def __init__(self, pkl_file):
        with open(pkl_file, "rb") as f:
            self.enc = pickle.load(f)
        self._bos = self.enc.encode_single_token("<|bos|>")

    def get_bos_token_id(self):
        return self._bos

    def encode(self, text, prepend=None):
        ids = self.enc.encode_ordinary(text)
        if prepend is not None:
            ids = [prepend] + ids
        return ids

    def decode(self, ids):
        return self.enc.decode(ids)

tokenizer = PklTokenizer(tok_path)
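# Optional sanity check (an addition, assuming the pickled object follows
# tiktoken's Encoding API as the method names above suggest): encode/decode
# should round-trip ordinary text.
assert tokenizer.decode(tokenizer.encode("hello world")) == "hello world"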
def complete(prompt, max_new_tokens=64):
    """Greedy decoding: repeatedly append the argmax token to the prompt."""
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.inference_mode():
        for _ in range(int(max_new_tokens)):  # Gradio sliders can pass floats
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())
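# Variation (a sketch, not part of the original Space): greedy argmax always
# produces the same completion; temperature sampling is a two-line swap. Note
# that either loop re-runs the full prefix on every step (no KV cache), so
# generation cost grows quadratically with sequence length.
def complete_sampled(prompt, max_new_tokens=64, temperature=0.8):
    ids = torch.tensor(
        [tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())],
        dtype=torch.long, device=device,
    )
    with torch.inference_mode():
        for _ in range(int(max_new_tokens)):
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # sample, not argmax
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())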
with gr.Blocks() as demo:
    gr.Markdown("# NanoChat Transformers Demo (BASE d20)")
    inp = gr.Textbox(value="The capital of Belgium is ", label="Prompt")
    max_toks = gr.Slider(1, 256, value=64, step=1, label="Max new tokens")
    out = gr.Textbox(label="Completion")
    btn = gr.Button("Generate")
    btn.click(complete, [inp, max_toks], [out])

demo.launch()
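# Running outside Spaces (the dependency list is an assumption; tiktoken is
# likely needed to unpickle the tokenizer):
#   pip install torch transformers gradio huggingface_hub tiktoken
#   python app.py
# Gradio serves on http://127.0.0.1:7860 by default.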