legolasyiu committed
Commit f35ca45 · verified · 1 Parent(s): a4fb6dd

Update app.py

Files changed (1): app.py +42 -42
app.py CHANGED
@@ -4,12 +4,12 @@ Gradio streaming chat where:
  - user messages are visible in the UI,
  - system messages are hidden (kept for context),
  - assistant output is streamed and updates in-place.
+ - full back-and-forth memory between turns.

  Requirements:
- - transformers
- - gradio
- - torch
+ pip install torch transformers gradio
  """
+
  import threading
  import gradio as gr
  import torch
@@ -20,7 +20,7 @@ MODEL_ID = "EpistemeAI/metatune-gpt20b-R0"
  print("Loading tokenizer and model (this may take a while)...")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

- # Use auto dtype & device mapping as requested
+ # Use auto dtype & device mapping
  model = AutoModelForCausalLM.from_pretrained(
      MODEL_ID,
      torch_dtype="auto",
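With `device_map="auto"`, Accelerate decides where each submodule lives, which is why the code later asks `next(model.parameters()).device` before moving inputs. A minimal sketch for inspecting that placement, assuming the model was loaded as above (`hf_device_map` is only set when Accelerate dispatches the weights):

```python
# Print where each mapped module ended up after device_map="auto".
# hf_device_map exists only when accelerate sharded/dispatched the model.
if hasattr(model, "hf_device_map"):
    for module_name, device in model.hf_device_map.items():
        print(f"{module_name or '<root>'} -> {device}")
else:
    # Single-device load: every parameter shares one device.
    print("model device:", next(model.parameters()).device)
```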
@@ -54,7 +54,6 @@ def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float):
      Stream partial strings via TextIteratorStreamer.
      """
      inputs = tokenizer(prompt, return_tensors="pt")
-     # Move input ids to model param device where possible (works with many accelerate setups)
      try:
          input_ids = inputs["input_ids"].to(next(model.parameters()).device)
      except Exception:
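The rest of `generate_stream` sits outside this hunk; its docstring names `TextIteratorStreamer`, so it presumably follows the standard thread-plus-streamer pattern. A minimal, self-contained sketch of that pattern (the body below is an assumption, not the file's actual code):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream_sketch(prompt: str, max_tokens: int, temperature: float, top_p: float):
    # Tokenize and move tensors to the model's (first) parameter device.
    inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )
    # model.generate() blocks until done, so run it in a background
    # thread and consume partial strings from the streamer here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # accumulated partial output, as the callers expect
```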
@@ -82,74 +81,76 @@ def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float):

  def visible_messages_from_history(real_history: list, streaming_partial: str | None):
      """
-     Convert internal history into the list of OpenAI-style messages for Gradio UI.
-
-     - Show user messages verbatim (visible).
-     - Show assistant messages (streamed or final).
-     - Omit system messages (kept only for model context).
+     Convert internal history into Gradio-visible messages.
+     - Show user messages.
+     - Show assistant messages (partial or final).
+     - Hide system messages.
      """
      msgs = []
      for entry in real_history:
          role = entry.get("role")
          content = entry.get("content", "")
          if role == "system":
-             # hide system from UI
              continue
-         # For assistant messages, we'll use content (may be empty)
          msgs.append({"role": role, "content": content or ("thinking..." if role == "assistant" else "")})

-     # If we're currently streaming an assistant response, ensure it's reflected as the last assistant msg
      if streaming_partial is not None:
-         # If last message is assistant, replace its content, otherwise append a new (user, assistant) pair
          if msgs and msgs[-1]["role"] == "assistant":
              msgs[-1]["content"] = streaming_partial
          else:
-             # The user message that started this assistant reply should already be in history and visible.
-             # Append assistant partial as the reply
              msgs.append({"role": "assistant", "content": streaming_partial})

      return msgs


- def respond_stream(user_message, system_message, max_tokens, temperature, top_p):
+ def respond_stream(user_message, system_message, max_tokens, temperature, top_p, history_state):
      """
-     Gradio streaming handler:
-     - Append real user message + assistant placeholder to GLOBAL_HISTORY
-     - Yield visible message lists as the assistant generates tokens
+     Gradio streaming handler with persistent memory.
      """
-     # Add the user message and an assistant placeholder into the real history
+     if history_state is None:
+         history_state = []
+
+     # Sync local and global histories (optional global memory)
+     with HISTORY_LOCK:
+         GLOBAL_HISTORY[:] = history_state
+
+     # Add the new user message and placeholder assistant
      with HISTORY_LOCK:
          if system_message:
-             # include system message in real history for model context (but it won't be shown)
              GLOBAL_HISTORY.append({"role": "system", "content": system_message})
          GLOBAL_HISTORY.append({"role": "user", "content": user_message})
-         GLOBAL_HISTORY.append({"role": "assistant", "content": ""})  # placeholder
+         GLOBAL_HISTORY.append({"role": "assistant", "content": ""})
          snapshot = list(GLOBAL_HISTORY)

-     # Immediately show user message and assistant placeholder ("thinking...")
+     # Show initial "thinking..." state
      initial_display = visible_messages_from_history(snapshot, streaming_partial="thinking...")
-     yield initial_display
+     yield initial_display, snapshot

-     # Build prompt using the real history but exclude the last assistant placeholder's empty content
+     # Build prompt excluding assistant placeholder
      with HISTORY_LOCK:
-         prompt_history = [h for h in GLOBAL_HISTORY[:-1]]  # all except the placeholder assistant
+         prompt_history = [h for h in GLOBAL_HISTORY[:-1]]
          prompt = build_prompt(system_message or "", prompt_history, user_message or "")

-     # Stream generation and update the last assistant entry
+     # Stream generation and update assistant output
      for partial in generate_stream(prompt, max_tokens, temperature, top_p):
          with HISTORY_LOCK:
-             # update global last assistant content
              if GLOBAL_HISTORY and GLOBAL_HISTORY[-1]["role"] == "assistant":
                  GLOBAL_HISTORY[-1]["content"] = partial
              snapshot = list(GLOBAL_HISTORY)
          display = visible_messages_from_history(snapshot, streaming_partial=partial)
-         yield display
+         yield display, snapshot

-     # Finalize: ensure assistant final content is shown
+     # Final display
      with HISTORY_LOCK:
          final_snapshot = list(GLOBAL_HISTORY)
      final_display = visible_messages_from_history(final_snapshot, streaming_partial=final_snapshot[-1].get("content", ""))
-     yield final_display
+     yield final_display, final_snapshot
+
+
+ def reset_all():
+     with HISTORY_LOCK:
+         GLOBAL_HISTORY.clear()
+     return [], []


  # --- Gradio UI ---
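`respond_stream` delegates prompt construction to `build_prompt`, which this commit never touches, so its body is not visible here. Purely as an illustration of what it could look like, a hypothetical version built on the tokenizer's chat template (the real helper may differ):

```python
# Hypothetical build_prompt, for illustration only; the actual helper
# is defined elsewhere in app.py and is not part of this diff.
def build_prompt(system_message: str, history: list, user_message: str) -> str:
    # In respond_stream, `history` is GLOBAL_HISTORY minus the assistant
    # placeholder, so it already ends with the new user message (and
    # contains the system message when one was provided).
    messages = [m for m in history if m.get("content")]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
```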
@@ -157,12 +158,13 @@ with gr.Blocks() as demo:
      gr.Markdown(f"**Model:** {MODEL_ID} — (system messages hidden; user visible)")

      chatbot = gr.Chatbot(elem_id="chatbot", label="Chat", type="messages", height=560)
+     history_state = gr.State([])

      with gr.Row():
          with gr.Column(scale=4):
              user_input = gr.Textbox(placeholder="Type a message and press Send", label="Your message")
          with gr.Column(scale=2):
-             system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message (hidden from UI)")
+             system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message (hidden)")
              max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
              temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
              top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
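The new `gr.State([])` is what gives each browser session its own memory: Gradio feeds the stored value in as an extra input and writes whatever the handler returns (or yields) for that output back into the session. A stripped-down illustration of the round trip, independent of this app:

```python
import gradio as gr

# Minimal gr.State round trip: the state value arrives as an argument
# and the updated value is returned for the state output, per session.
with gr.Blocks() as state_demo:
    count = gr.State(0)
    label = gr.Textbox(label="Clicks")
    button = gr.Button("Click me")

    def bump(n):
        n += 1
        return f"{n} clicks", n  # (visible textbox, updated state)

    button.click(fn=bump, inputs=[count], outputs=[label, count])
```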
@@ -170,20 +172,18 @@ with gr.Blocks() as demo:

      send_btn.click(
          fn=respond_stream,
-         inputs=[user_input, system_input, max_tokens, temperature, top_p],
-         outputs=[chatbot],
+         inputs=[user_input, system_input, max_tokens, temperature, top_p, history_state],
+         outputs=[chatbot, history_state],
          queue=True,
      )

      clear_btn = gr.Button("Reset conversation")
-     def reset_all():
-         with HISTORY_LOCK:
-             GLOBAL_HISTORY.clear()
-         return []
-     clear_btn.click(fn=reset_all, inputs=None, outputs=[chatbot])
+     clear_btn.click(fn=reset_all, inputs=None, outputs=[chatbot, history_state])

-     gr.Markdown("Notes: model loading uses `device_map='auto'` and `torch_dtype='auto'`. "
-                 "If running multi-worker (gunicorn) you will need an external history store (Redis/DB).")
+     gr.Markdown(
+         "Notes: model loading uses `device_map='auto'` and `torch_dtype='auto'`. "
+         "If running multi-worker (gunicorn) you will need an external history store (Redis/DB)."
+     )

  if __name__ == "__main__":
      demo.launch()
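The closing note is worth taking seriously: `GLOBAL_HISTORY` lives in one process, so with several gunicorn workers each worker would see a different history. A rough sketch of the Redis-backed store the note suggests, assuming a reachable Redis instance and the `redis` package (key names here are illustrative):

```python
import json
import redis

# Illustrative external history store for multi-worker deployments.
r = redis.Redis(host="localhost", port=6379, decode_responses=True)

def load_history(session_id: str) -> list:
    """Fetch a session's message list, or an empty one if unseen."""
    raw = r.get(f"chat:history:{session_id}")
    return json.loads(raw) if raw else []

def save_history(session_id: str, history: list) -> None:
    """Persist the session's message list as JSON."""
    r.set(f"chat:history:{session_id}", json.dumps(history))
```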
 