legolasyiu committed
Commit c971648 · verified · 1 Parent(s): f728324

Update app.py

Files changed (1):
  1. app.py +88 -12
app.py CHANGED
@@ -1,21 +1,97 @@
-
  from transformers import AutoModelForCausalLM, AutoTokenizer
  import gradio as gr
+ import torch
+ import time

+ # --- Model / tokenizer load (your checkpoint) ---
  checkpoint = "EpistemeAI/metatune-gpt20b-R0"
- device = "cuda" # "cuda" or "cpu"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- model = AutoModelForCausalLM.from_pretrained(checkpoint,torch_dtype="auto").to(device)
+ model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto").to(device)
+ model.eval()

- def predict(message, history):
-     history.append({"role": "user", "content": message})
-     input_text = tokenizer.apply_chat_template(history, tokenize=False)
-     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-     outputs = model.generate(inputs, max_new_tokens=3200, temperature=0.2, top_p=0.9, do_sample=True)
-     decoded = tokenizer.decode(outputs[0])
-     response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
-     return response
+ # --- Helper: convert gradio display history (tuples) -> model/chat history (dicts) ---
+ def display_to_model_history(display_history):
+     """
+     Convert gradio chatbot history (list of (role, text)) into a list of dicts
+     used by your tokenizer.apply_chat_template. Adjust roles to 'user'/'assistant'.
+     """
+     model_history = []
+     if not display_history:
+         return model_history
+     for role, text in display_history:
+         role_key = "user" if role.lower().startswith("user") else "assistant"
+         model_history.append({"role": role_key, "content": text})
+     return model_history

- demo = gr.ChatInterface(predict, type="messages")
+ # --- Prediction (generator) that shows thinking and then final output ---
+ def predict(user_message, chat_history):
+     """
+     Args:
+         user_message: string typed by user
+         chat_history: list of tuples [(role, text), ...] from the gradio Chatbot
+
+     Yields:
+         chat_history list (so gradio updates UI). First yield shows "Thinking...",
+         second yields the final assistant response.
+     """
+     # Ensure history is a list
+     chat_history = chat_history or []
+
+     # 1) Append user message to both display and model history
+     chat_history.append(("User", user_message))
+     # Convert to model history for tokenizer
+     model_history = display_to_model_history(chat_history)
+
+     # 2) Append "Thinking..." placeholder in UI and yield (so user sees it)
+     chat_history.append(("Assistant", "Thinking..."))
+     yield chat_history
+
+     # 3) Build the prompt for the model using your tokenizer helper
+     input_text = tokenizer.apply_chat_template(model_history, tokenize=False)
+
+     # 4) Tokenize and run generation
+     inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(device)
+     # Generate (tune args as you prefer)
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs,
+             max_new_tokens=512,
+             temperature=0.9,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+     # 5) Extract assistant response (match your original splitting logic)
+     # Keep the same delimiters you used previously (adjust if different)
+     try:
+         response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
+     except Exception:
+         # Fallback: use last part of decoded text
+         response = decoded.strip()
+
+     # 6) Replace the "Thinking..." placeholder with final response
+     # The placeholder was last element, so update it
+     if chat_history and chat_history[-1][0].lower().startswith("assistant"):
+         chat_history[-1] = ("Assistant", response)
+     else:
+         chat_history.append(("Assistant", response))
+
+     # 7) Final yield with assistant output
+     yield chat_history

+ # --- Gradio UI ---
+ with gr.Blocks() as demo:
+     gr.Markdown("## Episteme Chat — shows 'Thinking...' then final assistant output")
+     chatbot = gr.Chatbot(height=600)
+     txt = gr.Textbox(show_label=False, placeholder="Type your message and hit Enter")
+     clear = gr.Button("Clear")
+
+     # Bind the generator to textbox submit
+     txt.submit(predict, inputs=[txt, chatbot], outputs=chatbot)
+     clear.click(lambda: None, None, chatbot, queue=False)  # clears chat (returns None)
+
  demo.launch()
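
For context, below is a minimal, self-contained sketch of the yield-twice pattern this commit introduces: a Gradio event handler written as a generator can yield once to show a "Thinking..." placeholder and yield again with the final reply, and Blocks streams each yield to the Chatbot. Everything in the sketch is illustrative only (the respond function, the time.sleep stand-in for model.generate, and the echoed reply are not taken from this Space); it assumes Gradio's default (user_message, bot_message) tuple pairs for Chatbot history.

# Illustrative sketch only -- not the Space's code. A generator handler that
# first yields a placeholder and then yields the final answer.
import time
import gradio as gr

def respond(message, history):
    history = history or []
    # First yield: show the user's message with a placeholder reply.
    history.append((message, "Thinking..."))
    yield history
    # Stand-in for the slow model.generate() call in the real app.
    time.sleep(1)
    # Second yield: replace the placeholder with the final answer.
    history[-1] = (message, f"You said: {message}")
    yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    txt = gr.Textbox(placeholder="Type a message and press Enter")
    txt.submit(respond, inputs=[txt, chatbot], outputs=chatbot)

demo.queue()
demo.launch()

Streaming intermediate yields relies on Gradio's queue; it is enabled by default in recent Gradio releases, and the explicit demo.queue() call above keeps the sketch working on older 3.x versions as well.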