Spaces: Runtime error
Upload app.py with huggingface_hub
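A plausible sketch of the client call behind a commit like this one, assuming the stock huggingface_hub API with HF_TOKEN-based auth; the Space id below is hypothetical:

from huggingface_hub import HfApi

api = HfApi()  # picks up the HF_TOKEN environment variable if set
api.upload_file(
    path_or_fileobj="app.py",
    path_in_repo="app.py",
    repo_id="loocorez/nanochat-base-demo",  # hypothetical Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)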
app.py ADDED
@@ -0,0 +1,72 @@
import os

# Redirect all caches to /tmp, the only writable location on a Space without persistent storage.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
os.environ.setdefault("NANOCHAT_BASE_DIR", "/tmp/nanochat")

import contextlib

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from nanochat.checkpoint_manager import load_model_from_dir
from nanochat.engine import Engine

MODEL_REPO = os.getenv("NANOCHAT_MODEL_REPO", "loocorez/nanochat-base-d20-step21400")
STEP = os.getenv("NANOCHAT_STEP", "021400")
DEPTH = os.getenv("NANOCHAT_DEPTH", "20")

# hf_hub_download preserves the repo-relative path under local_dir, so
# "tokenizer/tokenizer.pkl" downloaded into /tmp/nanochat lands at
# /tmp/nanochat/tokenizer/tokenizer.pkl, which is where nanochat looks for it
# ($NANOCHAT_BASE_DIR/tokenizer/tokenizer.pkl).
hf_hub_download(MODEL_REPO, "tokenizer/tokenizer.pkl", local_dir="/tmp/nanochat")

# Base checkpoint: the repo-relative path base_checkpoints/d{DEPTH}/... is preserved
# the same way, so point ckpt_dir at the directory the files actually land in.
ckpt_dir = f"/tmp/ckpt/base_checkpoints/d{DEPTH}"
hf_hub_download(MODEL_REPO, f"base_checkpoints/d{DEPTH}/model_{STEP}.pt", local_dir="/tmp/ckpt")
hf_hub_download(MODEL_REPO, f"base_checkpoints/d{DEPTH}/meta_{STEP}.json", local_dir="/tmp/ckpt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, _ = load_model_from_dir(ckpt_dir, device, phase="eval")
engine = Engine(model, tokenizer)

# bfloat16 autocast on GPU; a no-op context on CPU (CPU autocast rejects float32).
autocast_ctx = (
    torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    if device.type == "cuda"
    else contextlib.nullcontext()
)

def chat_fn(history, temperature=0.8, top_k=50, max_new_tokens=256):
    # Render the conversation with nanochat's chat special tokens.
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    tokens = [bos]
    for message in history:
        if message["role"] == "user":
            tokens += [user_start] + tokenizer.encode(message["content"]) + [user_end]
        else:
            tokens += [assistant_start] + tokenizer.encode(message["content"]) + [assistant_end]
    tokens += [assistant_start]

    # Engine.generate is a generator that yields one token column per decoding step
    # (a list with one entry per sample), so collect the reply token by token and
    # stop when the model closes the assistant turn.
    new_tokens = []
    with autocast_ctx:
        for token_column, _ in engine.generate(tokens, num_samples=1, max_tokens=max_new_tokens, temperature=temperature, top_k=top_k):
            token = token_column[0]
            if token == assistant_end:
                break
            new_tokens.append(token)
    return tokenizer.decode(new_tokens)

with gr.Blocks() as demo:
    gr.Markdown("# NanoChat BASE")
    # "messages" mode takes a list of {"role": ..., "content": ...} dicts,
    # the same format chat_fn consumes.
    chat = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
    topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
    max_toks = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")

    def respond(user_message, chat_history, temperature, top_k, max_new_tokens):
        chat_history = chat_history + [{"role": "user", "content": user_message}]
        reply = chat_fn(chat_history, temperature, top_k, max_new_tokens)
        chat_history = chat_history + [{"role": "assistant", "content": reply}]
        return "", chat_history

    msg.submit(respond, [msg, chat, temp, topk, max_toks], [msg, chat])

demo.launch()
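For the Space to boot past dependency install, every import in app.py has to resolve; a minimal requirements.txt sketch, assuming nanochat is installed straight from its GitHub repository (it is not published on PyPI, so the Git URL is an assumption):

torch
gradio
huggingface_hub
git+https://github.com/karpathy/nanochat.git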