legolasyiu committed
Commit 8b979ce · verified · 1 Parent(s): 30a476d

Update app.py

Files changed (1)
  1. app.py +137 -48
app.py CHANGED
@@ -1,29 +1,46 @@
  # save as app.py
+ """
+ Gradio streaming chat that hides user/system text.
+ Left-side messages are always "thinking..." (literal).
+ Right-side shows the assistant output streamed as it is generated.
+
+ Requirements:
+ - transformers
+ - accelerate (recommended)
+ - gradio
+ - torch
+ """
  import threading
+ import time
  import gradio as gr
  import torch
- from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     TextIteratorStreamer,
- )
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

  MODEL_ID = "EpistemeAI/gpt-oss-20b-RL"

  print("Loading tokenizer and model (this may take a while)...")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

- # Always use auto mapping / dtype
+ # ALWAYS use auto for dtype & device_map as requested
+ # This will let HF/accelerate place weights across available devices
  model = AutoModelForCausalLM.from_pretrained(
      MODEL_ID,
      torch_dtype="auto",
      device_map="auto",
  )
  model.eval()
- print("Model loaded. Device:", next(model.parameters()).device)
+ print("Model loaded. Example param device:", next(model.parameters()).device)
+
+ # Global history (real content). Stored as list of {"role": "user"|"assistant"|"system", "content": "..."}
+ GLOBAL_HISTORY = []
+ HISTORY_LOCK = threading.Lock()
+

- # --------- Helper: build prompt ----------
- def build_prompt(system_message: str, history: list[dict], user_message: str) -> str:
+ def build_prompt(system_message: str, history: list, user_message: str) -> str:
+     """
+     Build the model prompt in your preferred format.
+     Adjust this function if your model expects a different conversation format.
+     """
      pieces = []
      if system_message:
          pieces.append(f"<|system|>\n{system_message}\n")
@@ -34,10 +51,19 @@ def build_prompt(system_message: str, history: list[dict], user_message: str) ->
      pieces.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
      return "\n".join(pieces)

- # --------- Streaming generator ----------
- def generate_stream(prompt, max_tokens, temperature, top_p):
+
+ def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float):
+     """
+     Yields partial strings as the model generates tokens using TextIteratorStreamer.
+     """
+     # Tokenize (we avoid forcing a single-device .to(...) in case of HF sharded device_map)
      inputs = tokenizer(prompt, return_tensors="pt")
-     input_ids = inputs["input_ids"].to(model.device)
+     # Move input_ids to same device as a model parameter (works with many configs)
+     try:
+         input_ids = inputs["input_ids"].to(next(model.parameters()).device)
+     except Exception:
+         # Fallback: do not move if that fails (accelerate may handle placement)
+         input_ids = inputs["input_ids"]

      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

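The generation call itself lives in the unchanged region between these two hunks (old lines 44-57) and is therefore not shown on this page. In both versions it presumably follows the standard TextIteratorStreamer pattern: run model.generate on a worker thread and yield decoded text as it arrives. A minimal, hypothetical sketch of that pattern (not the elided code; the function and argument names here are illustrative):

# Hypothetical sketch of the elided streaming loop; not part of this commit.
import threading
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, input_ids, max_tokens, temperature, top_p):
    # generate() pushes decoded text into the streamer; this generator consumes it.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )
    # generate() blocks, so it runs on a background thread while partial text is streamed out.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for token_str in streamer:
        partial += token_str
        yield partial
    thread.join()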
 
@@ -58,56 +84,119 @@ def generate_stream(prompt, max_tokens, temperature, top_p):
          partial += token_str
          yield partial

- # --------- Gradio app logic ----------
- def respond_stream(user_message, chat_history, system_message, max_tokens, temperature, top_p):
-     history = chat_history or []
-     prompt = build_prompt(system_message or "", history, user_message or "")
-
-     history.append({"role": "user", "content": user_message})
-     history.append({"role": "assistant", "content": ""})

-     def history_to_chatbot_rows(hist):
-         rows = []
-         for item in hist:
-             if item["role"] == "assistant":
-                 rows.append(("thinking...", item["content"] or "thinking..."))
-         return rows or []
+ def make_display_messages_from_history(real_history: list, current_partial: str | None):
+     """
+     Convert the internal 'real_history' (which contains real user & assistant content)
+     into the list of openai-style message dicts that Gradio Chatbot (type="messages")
+     expects. Every non-assistant message is replaced with a literal "thinking...".
+     For each assistant exchange we produce:
+         {"role":"user", "content":"thinking..."}
+         {"role":"assistant", "content": "<assistant content or partial stream>"}
+     The UI will therefore show the left side text "thinking..." and right side the assistant.
+     """
+     msgs = []
+     # Walk the real history and whenever we hit an assistant turn, pair it with a thinking user
+     i = 0
+     while i < len(real_history):
+         item = real_history[i]
+         if item["role"] == "assistant":
+             # the user message that caused this assistant reply is typically just before it,
+             # but we hide the real user content and show "thinking..." instead.
+             msgs.append({"role": "user", "content": "thinking..."})
+             msgs.append({"role": "assistant", "content": item.get("content", "") or "thinking..."})
+         i += 1
+
+     # If a current_partial exists (we're streaming a new assistant response),
+     # ensure it's reflected as the last assistant message (with a preceding "thinking...")
+     if current_partial is not None:
+         # If the last two entries are already the streaming pair, replace them; otherwise append new
+         if msgs and msgs[-1]["role"] == "assistant":
+             msgs[-1]["content"] = current_partial
+         else:
+             msgs.append({"role": "user", "content": "thinking..."})
+             msgs.append({"role": "assistant", "content": current_partial})
+     return msgs
+
+
+ def respond_stream(user_message, system_message, max_tokens, temperature, top_p):
+     """
+     Gradio streaming function that yields successive message-lists (OpenAI-style dicts).
+     It mutates GLOBAL_HISTORY to store the true conversation, but the UI only ever sees
+     'thinking...' in non-assistant slots and the assistant's streamed content on the right.
+     """
+     # Append the real user turn and an empty assistant placeholder to GLOBAL_HISTORY
+     with HISTORY_LOCK:
+         GLOBAL_HISTORY.append({"role": "user", "content": user_message})
+         GLOBAL_HISTORY.append({"role": "assistant", "content": ""})  # placeholder for streaming
+         # create a shallow copy for local read
+         local_history = list(GLOBAL_HISTORY)
+
+     # initial UI placeholder: show existing assistant rows and the new placeholder
+     displayed = make_display_messages_from_history(local_history, current_partial="thinking...")
+     yield displayed
+
+     # Build model prompt from the real history (exclude the last assistant placeholder content)
+     # We pass the actual global history (safe to read under lock copy)
+     with HISTORY_LOCK:
+         # Send a snapshot (exclude the last assistant placeholder since it's empty)
+         prompt_history = [h for h in GLOBAL_HISTORY[:-1] if h.get("role")]
+     prompt = build_prompt(system_message or "", prompt_history, user_message or "")
+
+     # Stream generation
+     for partial in generate_stream(prompt, max_tokens, temperature, top_p):
+         # Update the global assistant placeholder with the partial so future turns keep context
+         with HISTORY_LOCK:
+             # Update the last assistant placeholder
+             if GLOBAL_HISTORY and GLOBAL_HISTORY[-1]["role"] == "assistant":
+                 GLOBAL_HISTORY[-1]["content"] = partial
+             current_snapshot = list(GLOBAL_HISTORY)

-     chatbot_rows = history_to_chatbot_rows(history[:-1])
-     chatbot_rows.append(("thinking...", "thinking..."))
-     yield chatbot_rows  # placeholder row
+         displayed = make_display_messages_from_history(current_snapshot, current_partial=partial)
+         yield displayed

-     for partial in generate_stream(prompt, max_tokens, temperature, top_p):
-         chatbot_rows[-1] = ("thinking...", partial)
-         history[-1]["content"] = partial
-         yield chatbot_rows
+     # final sync: ensure the assistant content is finalized in GLOBAL_HISTORY (already done)
+     with HISTORY_LOCK:
+         final_snapshot = list(GLOBAL_HISTORY)
+     displayed = make_display_messages_from_history(final_snapshot, current_partial=final_snapshot[-1].get("content", ""))
+     yield displayed

-     chatbot_rows[-1] = ("thinking...", history[-1]["content"])
-     yield chatbot_rows

- # --------- Build Gradio UI ----------
+ # --- Gradio UI ---
  with gr.Blocks() as demo:
-     gr.Markdown(f"**Model:** {MODEL_ID}")
-     with gr.Row():
-         chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant Output (user/system hidden)").style(height=500)
+     gr.Markdown(f"**Model:** {MODEL_ID} — (UI hides user/system messages; left column shows 'thinking...')")

-     history_state = gr.State(value=[])
-     system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
-     user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
-     max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
-     temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
-     top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
-     send_btn = gr.Button("Send")
+     # Chatbot expects a list of {"role":.., "content":..} dicts when type="messages"
+     chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant (user/system hidden)", type="messages", height=560)

+     with gr.Row():
+         with gr.Column(scale=4):
+             user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
+         with gr.Column(scale=2):
+             system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
+             max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
+             temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
+             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
+     send_btn = gr.Button("Send")
+
+     # Hook the streaming respond function. Gradio will accept a generator that yields message lists.
      send_btn.click(
          fn=respond_stream,
-         inputs=[user_input, history_state, system_input, max_tokens, temperature, top_p],
+         inputs=[user_input, system_input, max_tokens, temperature, top_p],
          outputs=[chatbot],
          queue=True,
      )

-     send_btn.click(lambda u, s: s, inputs=[user_input, history_state], outputs=[history_state])
-     send_btn.click(lambda: "", None, user_input)
+     # Optional controls
+     clear_btn = gr.Button("Reset conversation")
+     def reset_all():
+         with HISTORY_LOCK:
+             GLOBAL_HISTORY.clear()
+         return []
+     clear_btn.click(fn=reset_all, inputs=None, outputs=[chatbot])
+
+     gr.Markdown("Note: model loading uses `device_map='auto'` and `torch_dtype='auto'`. "
+                 "If you run into out-of-memory problems on small GPUs, consider running on a machine with more memory or using model parallel tools.")

  if __name__ == "__main__":
      demo.launch()
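The closing gr.Markdown note warns about out-of-memory failures on small GPUs. One option that stays within the commit's device_map="auto" approach (a hedged sketch, not part of this commit; the memory limits are illustrative placeholders) is to cap GPU usage and let accelerate offload the remainder to CPU RAM or disk:

# Sketch only: assumes accelerate is installed; memory figures are placeholders, not recommendations.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EpistemeAI/gpt-oss-20b-RL",
    torch_dtype="auto",
    device_map="auto",
    max_memory={0: "16GiB", "cpu": "48GiB"},  # cap GPU 0; overflow goes to CPU RAM
    offload_folder="offload",                 # further overflow is spilled to disk
)

Offloaded layers run much more slowly, so this trades generation speed for being able to fit the 20B model at all.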