loocorez committed
Commit 880bf35 · verified · 1 parent: ce291ed

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +72 -0
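
The commit message matches huggingface_hub's default message for programmatic uploads, so the file was most likely pushed with HfApi.upload_file. A minimal sketch of such a call (the target repo id and repo_type below are hypothetical; the actual destination is not shown on this page):

    from huggingface_hub import HfApi

    api = HfApi()  # authenticates via the HF_TOKEN environment variable
    api.upload_file(
        path_or_fileobj="app.py",
        path_in_repo="app.py",
        repo_id="loocorez/nanochat-base-demo",  # hypothetical Space id
        repo_type="space",
    )

Left to its default, the commit message becomes "Upload app.py with huggingface_hub", which is exactly what appears above.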
app.py ADDED
@@ -0,0 +1,72 @@
+ import os
+
+ # Keep every cache under /tmp, which is writable on a Hugging Face Space.
+ # These must be set before nanochat is imported so that NANOCHAT_BASE_DIR
+ # is picked up when the package resolves its base directory.
+ os.environ.setdefault("HF_HOME", "/tmp/hf")
+ os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
+ os.environ.setdefault("NANOCHAT_BASE_DIR", "/tmp/nanochat")
+
+ import contextlib
+
+ import gradio as gr
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from nanochat.checkpoint_manager import load_model_from_dir
+ from nanochat.engine import Engine
+
+ MODEL_REPO = os.getenv("NANOCHAT_MODEL_REPO", "loocorez/nanochat-base-d20-step21400")
+ STEP = os.getenv("NANOCHAT_STEP", "021400")
+ DEPTH = os.getenv("NANOCHAT_DEPTH", "20")
+
+ # Tokenizer: nanochat expects it at $NANOCHAT_BASE_DIR/tokenizer/tokenizer.pkl.
+ # hf_hub_download recreates the repo-relative path ("tokenizer/tokenizer.pkl")
+ # under local_dir, so downloading into the base dir lands it where expected.
+ hf_hub_download(MODEL_REPO, "tokenizer/tokenizer.pkl", local_dir=os.environ["NANOCHAT_BASE_DIR"])
+
+ # Base checkpoint: fetch the weights and metadata, then point ckpt_dir at the
+ # directory the files actually landed in (local_dir downloads keep the
+ # repo-relative subfolders).
+ model_path = hf_hub_download(MODEL_REPO, f"base_checkpoints/d{DEPTH}/model_{STEP}.pt", local_dir="/tmp/ckpt")
+ hf_hub_download(MODEL_REPO, f"base_checkpoints/d{DEPTH}/meta_{STEP}.json", local_dir="/tmp/ckpt")
+ ckpt_dir = os.path.dirname(model_path)
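+ # With the defaults above, model_path resolves to
+ # /tmp/ckpt/base_checkpoints/d20/model_021400.pt, and ckpt_dir to its parent.
+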
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model, tokenizer, _ = load_model_from_dir(ckpt_dir, device, phase="eval")
+ engine = Engine(model, tokenizer)
+
+ def chat_fn(history, temperature=0.8, top_k=50, max_new_tokens=256):
+     # Render the conversation in nanochat's chat format:
+     # <|bos|><|user_start|>...<|user_end|><|assistant_start|>...<|assistant_end|>
+     bos = tokenizer.get_bos_token_id()
+     user_start = tokenizer.encode_special("<|user_start|>")
+     user_end = tokenizer.encode_special("<|user_end|>")
+     assistant_start = tokenizer.encode_special("<|assistant_start|>")
+     assistant_end = tokenizer.encode_special("<|assistant_end|>")
+
+     tokens = [bos]
+     for message in history:
+         if message["role"] == "user":
+             tokens += [user_start] + tokenizer.encode(message["content"]) + [user_end]
+         else:
+             tokens += [assistant_start] + tokenizer.encode(message["content"]) + [assistant_end]
+     tokens += [assistant_start]
+
+     # bfloat16 autocast on GPU; plain full precision on CPU (CPU autocast
+     # does not accept float32 as a target dtype).
+     autocast_ctx = (
+         torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+         if device.type == "cuda"
+         else contextlib.nullcontext()
+     )
+
+     # engine.generate is a generator yielding one token column per step
+     # (one entry per sample); collect sample 0 until the assistant turn ends.
+     # Gradio sliders may deliver floats, so cast the integer parameters.
+     new_tokens = []
+     with autocast_ctx:
+         for token_column, _ in engine.generate(tokens, num_samples=1, max_tokens=int(max_new_tokens), temperature=temperature, top_k=int(top_k)):
+             token = token_column[0]
+             if token == assistant_end:
+                 break
+             new_tokens.append(token)
+     return tokenizer.decode(new_tokens)
+
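+ # Example: a single-turn history [{"role": "user", "content": "Hi"}] becomes
+ #   [bos, user_start, *encode("Hi"), user_end, assistant_start]
+ # and chat_fn returns the decoded completion of that final assistant turn.
+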
+ with gr.Blocks() as demo:
+     gr.Markdown("# NanoChat BASE")
+     # "messages" histories are lists of {"role", "content"} dicts, matching
+     # what chat_fn consumes ("tuple" is not a valid Chatbot type).
+     chat = gr.Chatbot(type="messages")
+     msg = gr.Textbox()
+     temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
+     topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
+     max_toks = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")
+
+     def respond(user_message, chat_history, temperature, top_k, max_new_tokens):
+         chat_history = (chat_history or []) + [{"role": "user", "content": user_message}]
+         reply = chat_fn(chat_history, temperature, top_k, max_new_tokens)
+         chat_history = chat_history + [{"role": "assistant", "content": reply}]
+         return "", chat_history
+
+     msg.submit(respond, [msg, chat, temp, topk, max_toks], [msg, chat])
+
+ demo.launch()
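
Since the repo, step, and depth are read from environment variables, the app can be pointed at a different checkpoint without editing the file. An illustrative override, shown here with the same values the code falls back to (these would need to be in the environment, e.g. as Space variables, before app.py runs):

    import os

    # Illustrative values; any repo with the same base_checkpoints/ layout works.
    os.environ["NANOCHAT_MODEL_REPO"] = "loocorez/nanochat-base-d20-step21400"
    os.environ["NANOCHAT_STEP"] = "021400"
    os.environ["NANOCHAT_DEPTH"] = "20"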