legolasyiu committed
Commit bd75abd · verified · 1 Parent(s): 8b979ce

Update app.py

Files changed (1)
  1. app.py +61 -74
app.py CHANGED
@@ -1,17 +1,16 @@
 # save as app.py
 """
-Gradio streaming chat that hides user/system text.
-Left-side messages are always "thinking..." (literal).
-Right-side shows the assistant output streamed as it is generated.
+Gradio streaming chat where:
+- user messages are visible in the UI,
+- system messages are hidden (kept for context),
+- assistant output is streamed and updates in-place.
 
 Requirements:
 - transformers
-- accelerate (recommended)
 - gradio
 - torch
 """
 import threading
-import time
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
@@ -21,8 +20,7 @@ MODEL_ID = "EpistemeAI/gpt-oss-20b-RL"
 print("Loading tokenizer and model (this may take a while)...")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
-# ALWAYS use auto for dtype & device_map as requested
-# This will let HF/accelerate place weights across available devices
+# Use auto dtype & device mapping as requested
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     torch_dtype="auto",
@@ -31,15 +29,14 @@ model = AutoModelForCausalLM.from_pretrained(
 model.eval()
 print("Model loaded. Example param device:", next(model.parameters()).device)
 
-# Global history (real content). Stored as list of {"role": "user"|"assistant"|"system", "content": "..."}
-GLOBAL_HISTORY = []
+# Thread-safe global history
+GLOBAL_HISTORY = []  # list of {"role": "system"|"user"|"assistant", "content": "..."}
 HISTORY_LOCK = threading.Lock()
 
 
 def build_prompt(system_message: str, history: list, user_message: str) -> str:
     """
-    Build the model prompt in your preferred format.
-    Adjust this function if your model expects a different conversation format.
+    Build prompt in the model's expected format. Adjust as needed.
     """
     pieces = []
     if system_message:
@@ -54,15 +51,13 @@ def build_prompt(system_message: str, history: list, user_message: str) -> str:
 
 def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float):
     """
-    Yields partial strings as the model generates tokens using TextIteratorStreamer.
+    Stream partial strings via TextIteratorStreamer.
     """
-    # Tokenize (we avoid forcing a single-device .to(...) in case of HF sharded device_map)
     inputs = tokenizer(prompt, return_tensors="pt")
-    # Move input_ids to same device as a model parameter (works with many configs)
+    # Move input ids to model param device where possible (works with many accelerate setups)
     try:
         input_ids = inputs["input_ids"].to(next(model.parameters()).device)
     except Exception:
-        # Fallback: do not move if that fails (accelerate may handle placement)
         input_ids = inputs["input_ids"]
 
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -85,101 +80,94 @@ def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float):
         yield partial
 
 
-def make_display_messages_from_history(real_history: list, current_partial: str | None):
+def visible_messages_from_history(real_history: list, streaming_partial: str | None):
     """
-    Convert the internal 'real_history' (which contains real user & assistant content)
-    into the list of openai-style message dicts that Gradio Chatbot (type="messages")
-    expects. Every non-assistant message is replaced with a literal "thinking...".
-    For each assistant exchange we produce:
-        {"role":"user", "content":"thinking..."}
-        {"role":"assistant", "content": "<assistant content or partial stream>"}
-    The UI will therefore show the left side text "thinking..." and right side the assistant.
+    Convert internal history into the list of OpenAI-style messages for Gradio UI.
+
+    - Show user messages verbatim (visible).
+    - Show assistant messages (streamed or final).
+    - Omit system messages (kept only for model context).
     """
     msgs = []
-    # Walk the real history and whenever we hit an assistant turn, pair it with a thinking user
-    i = 0
-    while i < len(real_history):
-        item = real_history[i]
-        if item["role"] == "assistant":
-            # the user message that caused this assistant reply is typically just before it,
-            # but we hide the real user content and show "thinking..." instead.
-            msgs.append({"role": "user", "content": "thinking..."})
-            msgs.append({"role": "assistant", "content": item.get("content", "") or "thinking..."})
-        i += 1
-
-    # If a current_partial exists (we're streaming a new assistant response),
-    # ensure it's reflected as the last assistant message (with a preceding "thinking...")
-    if current_partial is not None:
-        # If the last two entries are already the streaming pair, replace them; otherwise append new
+    for entry in real_history:
+        role = entry.get("role")
+        content = entry.get("content", "")
+        if role == "system":
+            # hide system from UI
+            continue
+        # For assistant messages, we'll use content (may be empty)
+        msgs.append({"role": role, "content": content or ("thinking..." if role == "assistant" else "")})
+
+    # If we're currently streaming an assistant response, ensure it's reflected as the last assistant msg
+    if streaming_partial is not None:
+        # If last message is assistant, replace its content, otherwise append a new (user, assistant) pair
         if msgs and msgs[-1]["role"] == "assistant":
-            msgs[-1]["content"] = current_partial
+            msgs[-1]["content"] = streaming_partial
         else:
-            msgs.append({"role": "user", "content": "thinking..."})
-            msgs.append({"role": "assistant", "content": current_partial})
+            # The user message that started this assistant reply should already be in history and visible.
+            # Append assistant partial as the reply
+            msgs.append({"role": "assistant", "content": streaming_partial})
+
     return msgs
 
 
 def respond_stream(user_message, system_message, max_tokens, temperature, top_p):
     """
-    Gradio streaming function that yields successive message-lists (OpenAI-style dicts).
-    It mutates GLOBAL_HISTORY to store the true conversation, but the UI only ever sees
-    'thinking...' in non-assistant slots and the assistant's streamed content on the right.
+    Gradio streaming handler:
+    - Append real user message + assistant placeholder to GLOBAL_HISTORY
+    - Yield visible message lists as the assistant generates tokens
     """
-    # Append the real user turn and an empty assistant placeholder to GLOBAL_HISTORY
+    # Add the user message and an assistant placeholder into the real history
    with HISTORY_LOCK:
+        if system_message:
+            # include system message in real history for model context (but it won't be shown)
+            GLOBAL_HISTORY.append({"role": "system", "content": system_message})
         GLOBAL_HISTORY.append({"role": "user", "content": user_message})
-        GLOBAL_HISTORY.append({"role": "assistant", "content": ""})  # placeholder for streaming
-        # create a shallow copy for local read
-        local_history = list(GLOBAL_HISTORY)
+        GLOBAL_HISTORY.append({"role": "assistant", "content": ""})  # placeholder
+        snapshot = list(GLOBAL_HISTORY)
 
-    # initial UI placeholder: show existing assistant rows and the new placeholder
-    displayed = make_display_messages_from_history(local_history, current_partial="thinking...")
-    yield displayed
+    # Immediately show user message and assistant placeholder ("thinking...")
+    initial_display = visible_messages_from_history(snapshot, streaming_partial="thinking...")
+    yield initial_display
 
-    # Build model prompt from the real history (exclude the last assistant placeholder content)
-    # We pass the actual global history (safe to read under lock copy)
+    # Build prompt using the real history but exclude the last assistant placeholder's empty content
     with HISTORY_LOCK:
-        # Send a snapshot (exclude the last assistant placeholder since it's empty)
-        prompt_history = [h for h in GLOBAL_HISTORY[:-1] if h.get("role")]
+        prompt_history = [h for h in GLOBAL_HISTORY[:-1]]  # all except the placeholder assistant
     prompt = build_prompt(system_message or "", prompt_history, user_message or "")
 
-    # Stream generation
+    # Stream generation and update the last assistant entry
    for partial in generate_stream(prompt, max_tokens, temperature, top_p):
-        # Update the global assistant placeholder with the partial so future turns keep context
         with HISTORY_LOCK:
-            # Update the last assistant placeholder
+            # update global last assistant content
             if GLOBAL_HISTORY and GLOBAL_HISTORY[-1]["role"] == "assistant":
                 GLOBAL_HISTORY[-1]["content"] = partial
-            current_snapshot = list(GLOBAL_HISTORY)
-
-        displayed = make_display_messages_from_history(current_snapshot, current_partial=partial)
-        yield displayed
+            snapshot = list(GLOBAL_HISTORY)
+        display = visible_messages_from_history(snapshot, streaming_partial=partial)
+        yield display
 
-    # final sync: ensure the assistant content is finalized in GLOBAL_HISTORY (already done)
+    # Finalize: ensure assistant final content is shown
    with HISTORY_LOCK:
         final_snapshot = list(GLOBAL_HISTORY)
-    displayed = make_display_messages_from_history(final_snapshot, current_partial=final_snapshot[-1].get("content", ""))
-    yield displayed
+    final_display = visible_messages_from_history(final_snapshot, streaming_partial=final_snapshot[-1].get("content", ""))
+    yield final_display
 
 
 # --- Gradio UI ---
 with gr.Blocks() as demo:
-    gr.Markdown(f"**Model:** {MODEL_ID} — (UI hides user/system messages; left column shows 'thinking...')")
+    gr.Markdown(f"**Model:** {MODEL_ID} — (system messages hidden; user visible)")
 
-    # Chatbot expects a list of {"role":.., "content":..} dicts when type="messages"
-    chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant (user/system hidden)", type="messages", height=560)
+    chatbot = gr.Chatbot(elem_id="chatbot", label="Chat", type="messages", height=560)
 
     with gr.Row():
         with gr.Column(scale=4):
-            user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
+            user_input = gr.Textbox(placeholder="Type a message and press Send", label="Your message")
         with gr.Column(scale=2):
-            system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
+            system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message (hidden from UI)")
             max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
             temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
     send_btn = gr.Button("Send")
 
-    # Hook the streaming respond function. Gradio will accept a generator that yields message lists.
     send_btn.click(
         fn=respond_stream,
         inputs=[user_input, system_input, max_tokens, temperature, top_p],
@@ -187,7 +175,6 @@ with gr.Blocks() as demo:
         queue=True,
     )
 
-    # Optional controls
     clear_btn = gr.Button("Reset conversation")
     def reset_all():
         with HISTORY_LOCK:
@@ -195,8 +182,8 @@ with gr.Blocks() as demo:
         return []
     clear_btn.click(fn=reset_all, inputs=None, outputs=[chatbot])
 
-    gr.Markdown("Note: model loading uses `device_map='auto'` and `torch_dtype='auto'`. "
-                "If you run into out-of-memory problems on small GPUs, consider running on a machine with more memory or using model parallel tools.")
+    gr.Markdown("Notes: model loading uses `device_map='auto'` and `torch_dtype='auto'`. "
+                "If running multi-worker (gunicorn) you will need an external history store (Redis/DB).")
 
 if __name__ == "__main__":
     demo.launch()
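
The hunks above skip the unchanged middle of generate_stream (the part that actually launches generation, between the streamer setup and "yield partial"). That elided portion presumably follows the standard TextIteratorStreamer pattern: run model.generate in a background thread and iterate the streamer as it decodes tokens. A minimal sketch under that assumption, not the commit's exact code:

# Sketch only: assumes the file's `tokenizer`, `model`, and the
# TextIteratorStreamer import are in scope, as in app.py above.
from threading import Thread

def generate_stream_sketch(prompt, max_tokens, temperature, top_p):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )
    # model.generate blocks until finished, so it runs in a worker thread
    # while this generator consumes decoded text from the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # the diff resumes here at "yield partial"
    thread.join()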
 
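The closing note flags the real limitation of this design: GLOBAL_HISTORY lives in process memory, so several workers (e.g. under gunicorn) would each see a different conversation. A minimal sketch of the external store the note suggests, using redis-py; the helper names and key scheme are hypothetical:

# Hypothetical Redis-backed replacement for GLOBAL_HISTORY (redis-py assumed
# installed); every worker reads and writes the same shared list.
import json
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

def append_message(session_id: str, role: str, content: str) -> None:
    # RPUSH keeps messages in conversation order under one key per session
    r.rpush(f"history:{session_id}", json.dumps({"role": role, "content": content}))

def load_history(session_id: str) -> list:
    return [json.loads(m) for m in r.lrange(f"history:{session_id}", 0, -1)]

def reset_history(session_id: str) -> None:
    r.delete(f"history:{session_id}")

Keying by a per-session id would also lift the current behavior where every visitor shares one module-level conversation.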