PeterPinetree committed
Commit a097df3 · 1 Parent(s): bf08f52

Refactor to use local model inference for next token prediction and enhance token display functionality

Files changed (2):
  1. app.py +182 -149
  2. requirements.txt +5 -1
app.py CHANGED
--- a/app.py
@@ -1,19 +1,24 @@
 import gradio as gr
-import requests
 import json
 import os
 import time
 from typing import List, Dict, Tuple
 from dotenv import load_dotenv
 
 # Load environment variables from .env file
 load_dotenv()
 
 # Configuration
-API_BASE = "https://router.huggingface.co/v1/"
 MODEL_ID = "Qwen/Qwen3-0.6B"
 HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
 
 def show_token(token: str) -> str:
     """Format token for display"""
     if token == "\n":
@@ -22,152 +27,96 @@ def show_token(token: str) -> str:
         return f"␣{'' if len(token) == 1 else '×' + str(len(token))}"
     return token
 
-def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = False) -> Tuple[str, str]:
-    """Predict next tokens using HF Serverless API"""
-
-    if not HF_TOKEN:
-        return "❌ No HF_NEXT_TOKEN_PREDICTOR_TOKEN found in environment variables", ""
 
     if not text.strip():
-        return "Please enter some text to predict from", ""
 
     start_time = time.time()
 
     try:
-        # Call Hugging Face Inference Providers API (Text Generation format)
-        url = f"{API_BASE}text-generation"
-        headers = {
-            'Authorization': f'Bearer {HF_TOKEN}',
-            'Content-Type': 'application/json',
-        }
-        payload = {
-            'model': MODEL_ID,
-            'inputs': text,
-            'parameters': {
-                'max_new_tokens': 1,
-                'temperature': 0.0,
-                'do_sample': False,
-                'return_full_text': False
-            }
-        }
 
-        response = requests.post(url, headers=headers, json=payload, timeout=30)
 
-        # Debug logging
-        print(f"API URL: {url}")
-        print(f"Response status: {response.status_code}")
-        if not response.ok:
-            print(f"Response text: {response.text}")
 
-        if not response.ok:
-            # Try a different Qwen model as fallback if the main model fails
-            if MODEL_ID == "Qwen/Qwen3-0.6B":
-                print(f"Main model failed, trying Qwen2.5-0.5B fallback...")
-                fallback_url = f"{API_BASE}Qwen/Qwen2.5-0.5B-Instruct"
-                fallback_response = requests.post(fallback_url, headers=headers, json=payload, timeout=30)
-                print(f"Fallback response status: {fallback_response.status_code}")
-                if fallback_response.ok:
-                    response = fallback_response
-                    print("✅ Fallback successful!")
-                else:
-                    print(f"Fallback also failed: {fallback_response.text[:100]}")
-
-            # If still not ok after fallback attempt
-            if not response.ok:
-                error_msg = f"API Error: {response.status_code} for model {MODEL_ID}"
-                try:
-                    error_detail = response.json()
-                    if 'error' in error_detail:
-                        error_msg += f" - {error_detail['error']}"
-                except:
-                    error_msg += f" - {response.text[:200]}"
-                return error_msg, ""
 
-        result = response.json()
-        prediction_time = int((time.time() - start_time) * 1000)
 
-        # Parse chat completion response - it returns a single message, not probabilities
-        try:
-            predicted_text = result['choices'][0]['message']['content'].strip()
-            # Extract just the next word (in case model returns more)
-            next_word = predicted_text.split()[0] if predicted_text else "?"
-
-            # Create simple display since we don't have probabilities
-            tokens_html = create_simple_token_display(next_word)
-
-        except (KeyError, IndexError) as e:
-            return f"❌ Error parsing response: {str(e)}", ""
 
-        return tokens_html, f"Prediction time: {prediction_time}ms"
 
-    except requests.exceptions.Timeout:
-        return "❌ API request timed out. The model might be loading - try again in a moment.", ""
-    except requests.exceptions.RequestException as e:
-        return f"❌ Network error: {str(e)}", ""
     except Exception as e:
-        return f"❌ Error: {str(e)}", ""
 
-def create_simple_token_display(predicted_word: str) -> str:
-    """Create HTML display for a single predicted token (chat completions format)"""
-
-    # Create HTML for single token
-    html = """
-    <div style="font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; background: #0e162b; border: 1px solid #1c2945; border-radius: 14px; padding: 12px;">
-    """
-
-    token_display = show_token(predicted_word)
-
-    html += f"""
-    <div style="display: grid; grid-template-columns: 1fr auto; gap: 8px; align-items: center; padding: 8px 10px; margin: 4px 0; border-radius: 10px; background: #0f1930; border: 1px solid #22365e; cursor: pointer;">
-        <div style="color: #e6f1ff; font-size: 14px;">{token_display}</div>
-        <div style="color: #9ab0d0; font-size: 12px;">Predicted</div>
-    </div>
-    """
-
-    html += "</div>"
-    return html
-
-def create_token_display(api_result: dict, top_k: int, hide_punctuation: bool) -> str:
-    """Create HTML display for predicted tokens"""
-
-    # For demo purposes, create some example predictions
-    # In a real implementation, you'd parse the API response properly
-    demo_tokens = [
-        {"token": "star", "prob": 0.35},
-        {"token": "light", "prob": 0.25},
-        {"token": "night", "prob": 0.15},
-        {"token": "sky", "prob": 0.10},
-        {"token": "bright", "prob": 0.08},
-        {"token": "moon", "prob": 0.04},
-        {"token": "sun", "prob": 0.03}
-    ]
-
-    # Filter punctuation if requested
-    if hide_punctuation:
-        import re
-        PUNC_ONLY = re.compile(r'^[\s.,;:!?—-]+$')
-        demo_tokens = [t for t in demo_tokens if not PUNC_ONLY.match(t['token'])]
-
-    # Take only top_k
-    tokens = demo_tokens[:top_k]
 
-    # Create HTML
     html = """
-    <div style="font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; background: #0e162b; border: 1px solid #1c2945; border-radius: 14px; padding: 12px;">
     """
 
-    for token in tokens:
-        token_display = show_token(token['token'])
-        percentage = f"{token['prob'] * 100:.2f}%"
 
         html += f"""
-        <div style="display: grid; grid-template-columns: 1fr auto; gap: 8px; align-items: center; padding: 8px 10px; margin: 4px 0; border-radius: 10px; background: #0f1930; border: 1px solid #22365e; cursor: pointer;">
             <div style="color: #e6f1ff; font-size: 14px;">{token_display}</div>
             <div style="color: #9ab0d0; font-size: 12px;">{percentage}</div>
         </div>
         """
 
-    html += "</div>"
     return html
 
 # Custom CSS to match the original design
@@ -187,6 +136,37 @@ custom_css = """
     background: #0e1629 !important;
     border: 1px solid #1c2945 !important;
 }
 """
 
 # Create Gradio interface
@@ -194,17 +174,10 @@ with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
     gr.HTML("""
     <div style="text-align: center; padding: 20px; background: #0e1629; border-bottom: 1px solid #1c2945;">
         <h1 style="color: #e6f1ff; margin: 0; font-size: 24px;">🤗 Next-Token Predictor</h1>
-        <p style="color: #9ab0d0; margin: 10px 0 0 0;">Explore how AI predicts the next word! Predictions update automatically as you type.</p>
     </div>
     """)
 
-    if not HF_TOKEN:
-        gr.HTML("""
-        <div style="background: #ffb4c0; color: #000; padding: 10px; border-radius: 8px; margin: 10px;">
-            ⚠️ <strong>HF_NEXT_TOKEN_PREDICTOR_TOKEN not found!</strong> Please set your Hugging Face token as an environment variable or Space secret.
-        </div>
-        """)
-
     with gr.Row():
         with gr.Column(scale=1):
             text_input = gr.Textbox(
@@ -218,41 +191,101 @@ with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
             with gr.Row():
                 top_k = gr.Slider(
                     minimum=5,
-                    maximum=30,
                     value=10,
                     step=1,
-                    label="Top-K predictions",
-                    info="How many predictions to show"
                 )
-                hide_punct = gr.Checkbox(
-                    label="Hide punctuation-only tokens",
-                    value=False,
-                    info="Focus on meaningful words"
                 )
 
             timing_info = gr.HTML(value="<div style='color: #9ab0d0; font-size: 12px;'>✨ Predictions update as you type!</div>")
 
         with gr.Column(scale=1):
-            predictions_html = gr.HTML(label="🔮 Next Token Predictions")
 
-    # Event handlers - auto-prediction on any change
-    def update_predictions(text, k, hide_p):
-        result_html, timing = predict_next_token(text, int(k), hide_p)
-        return result_html, timing
 
     # Auto-predict on any input change
-    for component in [text_input, top_k, hide_punct]:
         component.change(
-            update_predictions,
-            inputs=[text_input, top_k, hide_punct],
-            outputs=[predictions_html, timing_info]
         )
 
     # Load initial predictions on app start
     app.load(
-        lambda: update_predictions("Twinkle, twinkle, little ", 10, False),
-        outputs=[predictions_html, timing_info]
     )
 
 if __name__ == "__main__":
-    app.launch(share=False, server_port=7860)
+++ b/app.py
@@ -1,19 +1,24 @@
 import gradio as gr
 import json
 import os
 import time
+import torch
 from typing import List, Dict, Tuple
 from dotenv import load_dotenv
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Load environment variables from .env file
 load_dotenv()
 
 # Configuration
 MODEL_ID = "Qwen/Qwen3-0.6B"
 HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
 
+# Initialize model and tokenizer (local inference like the working app)
+print("Loading model and tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
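+# Loads once at process start and stays in memory for all requests; this runs on
+# CPU by default (pass torch_dtype/device_map to from_pretrained to target a GPU).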
+
 def show_token(token: str) -> str:
     """Format token for display"""
     if token == "\n":
@@ -22,152 +27,96 @@ def show_token(token: str) -> str:
         return f"␣{'' if len(token) == 1 else '×' + str(len(token))}"
     return token
 
+def predict_next_token(text: str, top_k: int = 10, temperature: float = 1.0, top_p: float = 0.9) -> Tuple[List[Dict], str]:
+    """Predict next tokens using local model with temperature and top-p filtering"""
 
     if not text.strip():
+        return [], "Please enter some text to predict from"
 
     start_time = time.time()
 
     try:
+        # Use local model inference
+        tokens = tokenizer(text, return_tensors="pt", padding=False)
+        out = model.generate(
+            **tokens,
+            max_new_tokens=1,
+            output_scores=True,
+            return_dict_in_generate=True,
+            pad_token_id=tokenizer.eos_token_id,
+            do_sample=False,
+        )
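+        # out.scores is a tuple with one tensor per generated step; with
+        # max_new_tokens=1 it holds a single (1, vocab_size) logits tensor.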
 
+        # Get raw logits and apply temperature scaling
+        logits = out.scores[0]
+        scaled_logits = logits / temperature
+        scores = torch.softmax(scaled_logits, dim=-1)
 
+        # Apply top-p filtering (nucleus sampling)
+        sorted_probs, sorted_indices = torch.sort(scores, descending=True)
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
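+        # cumulative_probs keeps the (1, vocab_size) shape, so torch.where below
+        # returns (row_indices, col_indices); indexing [1] takes the column positions.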
 
+        # Find the cutoff point for top-p
+        cutoff_index = torch.where(cumulative_probs >= top_p)[1]
+        if len(cutoff_index) > 0:
+            cutoff = cutoff_index[0].item() + 1
+            top_p_indices = sorted_indices[0, :cutoff]
+            top_p_probs = sorted_probs[0, :cutoff]
+        else:
+            # Fallback if top_p is very low
+            top_p_indices = sorted_indices[0, :min(50, len(sorted_indices[0]))]
+            top_p_probs = sorted_probs[0, :min(50, len(sorted_probs[0]))]
 
+        # Apply top-k to the top-p filtered results
+        final_k = min(top_k, len(top_p_indices))
+        final_indices = top_p_indices[:final_k]
+        final_probs = top_p_probs[:final_k]
 
+        # Convert to tokens
+        token_ids = [int(idx) for idx in final_indices]
+        probs = [float(prob) for prob in final_probs]
+        tokens_text = [tokenizer.decode([tid]) for tid in token_ids]
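+        # Decoding ids one at a time yields each candidate's surface text,
+        # including any leading space encoded in the BPE token itself.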
+
+        # Create token data structure
+        tokens_data = []
+        for i in range(len(token_ids)):
+            tokens_data.append({
+                "token": tokens_text[i],
+                "prob": probs[i]
+            })
+
+        prediction_time = int((time.time() - start_time) * 1000)
 
+        return tokens_data, f"Prediction time: {prediction_time}ms"
 
     except Exception as e:
+        return [], f"❌ Error: {str(e)}"
 
+def create_clickable_token_display(tokens_data: List[Dict]) -> str:
+    """Create HTML display with clickable tokens - simplified without JavaScript"""
 
     html = """
+    <div id="token-predictions" style="font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; background: #0e162b; border: 1px solid #1c2945; border-radius: 14px; padding: 12px;">
     """
 
+    for i, token_data in enumerate(tokens_data):
+        token_display = show_token(token_data['token'])
+        percentage = f"{token_data['prob'] * 100:.2f}%"
 
         html += f"""
+        <div class="token-prediction" data-token="{token_data['token']}"
+             style="display: grid; grid-template-columns: 1fr auto; gap: 8px; align-items: center; padding: 8px 10px; margin: 4px 0; border-radius: 10px; background: #0f1930; border: 1px solid #22365e; cursor: pointer; transition: background 0.2s;"
+             onmouseover="this.style.background='#1a2b4a'"
+             onmouseout="this.style.background='#0f1930'">
             <div style="color: #e6f1ff; font-size: 14px;">{token_display}</div>
             <div style="color: #9ab0d0; font-size: 12px;">{percentage}</div>
         </div>
         """
 
+    html += """
+    </div>
+    """
+
     return html
 
 # Custom CSS to match the original design
@@ -187,6 +136,37 @@ custom_css = """
     background: #0e1629 !important;
     border: 1px solid #1c2945 !important;
 }
+
+.token-button {
+    background: #0f1930 !important;
+    border: 1px solid #22365e !important;
+    color: #e6f1ff !important;
+    border-radius: 6px !important;
+    margin: 0px !important;
+    padding: 2px 6px !important;
+    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
+    transition: background 0.2s !important;
+    font-size: 12px !important;
+}
+
+.token-button:hover {
+    background: #1a2b4a !important;
+}
+
+/* Remove Gradio's default spacing between buttons */
+.token-button + .token-button {
+    margin-top: 0px !important;
+}
+
+/* Remove gaps in the column containing buttons */
+div:has(> .token-button) {
+    gap: 0px !important;
+}
+
+/* Target Gradio's automatic spacing */
+.block > div > div {
+    gap: 0px !important;
+}
 """
 
 # Create Gradio interface
@@ -194,17 +174,10 @@ with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
     gr.HTML("""
     <div style="text-align: center; padding: 20px; background: #0e1629; border-bottom: 1px solid #1c2945;">
         <h1 style="color: #e6f1ff; margin: 0; font-size: 24px;">🤗 Next-Token Predictor</h1>
+        <p style="color: #9ab0d0; margin: 10px 0 0 0;">Explore how AI predicts the next word! Click on predictions to append them.</p>
     </div>
     """)
 
     with gr.Row():
         with gr.Column(scale=1):
             text_input = gr.Textbox(
@@ -218,41 +191,101 @@ with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
             with gr.Row():
                 top_k = gr.Slider(
                     minimum=5,
+                    maximum=15,
                     value=10,
                     step=1,
+                    label="Top-K",
+                    info="How many top predictions to show",
+                    show_label=True,
+                    interactive=True
+                )
+                temperature = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.1,
+                    label="Temperature",
+                    info="Creativity: Low=predictable, High=surprising",
+                    show_label=True,
+                    interactive=True
                 )
+                top_p = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.05,
+                    label="Top-P",
+                    info="Consider words making up this % of probability",
+                    show_label=True,
+                    interactive=True
                 )
 
             timing_info = gr.HTML(value="<div style='color: #9ab0d0; font-size: 12px;'>✨ Predictions update as you type!</div>")
 
         with gr.Column(scale=1):
+            # Create a column for token buttons
+            with gr.Column():
+                gr.HTML("<h4 style='color: #e6f1ff; margin: 0;'>🔮 Next Token Predictions</h4>")
+
+                # Create buttons for each possible token (we'll show/hide as needed)
+                token_buttons = []
+                for i in range(15):  # Support up to 15 tokens
+                    btn = gr.Button(
+                        value="",
+                        visible=False,
+                        elem_classes=["token-button"],
+                        size="sm"
+                    )
+                    token_buttons.append(btn)
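+                # A fixed pool of 15 hidden buttons is built up front because Gradio
+                # cannot add components after launch; updates only toggle the
+                # visibility and labels of these existing buttons.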
 
+    # Store current tokens data
+    current_tokens = gr.State([])
+
+    def update_predictions_and_buttons(text, k, temp, p):
+        tokens_data, timing = predict_next_token(text, int(k), float(temp), float(p))
+
+        # Update button states
+        button_updates = []
+        for i in range(15):
+            if i < len(tokens_data):
+                token = tokens_data[i]['token']
+                prob = tokens_data[i]['prob']
+                display_token = show_token(token)
+                button_label = f"{display_token} ({prob*100:.1f}%)"
+                button_updates.append(gr.Button(value=button_label, visible=True))
+            else:
+                button_updates.append(gr.Button(visible=False))
+
+        return [timing, tokens_data] + button_updates
+
+    def append_token_to_input(current_text, tokens_data, button_index):
+        if tokens_data and 0 <= button_index < len(tokens_data):
+            token = tokens_data[button_index]['token']
+            return current_text + token
+        return current_text
 
     # Auto-predict on any input change
+    outputs = [timing_info, current_tokens] + token_buttons
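+    # This output order must match the [timing, tokens_data] + button_updates
+    # list returned by update_predictions_and_buttons.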
+    for component in [text_input, top_k, temperature, top_p]:
         component.change(
+            update_predictions_and_buttons,
+            inputs=[text_input, top_k, temperature, top_p],
+            outputs=outputs
+        )
+
+    # Set up click handlers for each token button
+    for i, btn in enumerate(token_buttons):
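+        # idx=i freezes the current loop index via a default argument; closing over
+        # i directly would make every button append the last prediction.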
+        btn.click(
+            lambda text, tokens, idx=i: append_token_to_input(text, tokens, idx),
+            inputs=[text_input, current_tokens],
+            outputs=[text_input]
         )
 
     # Load initial predictions on app start
     app.load(
+        lambda: update_predictions_and_buttons("Twinkle, twinkle, little ", 10, 1.0, 0.9),
+        outputs=outputs
     )
 
 if __name__ == "__main__":
+    app.launch(share=False)
requirements.txt CHANGED
--- a/requirements.txt
@@ -1,3 +1,7 @@
 gradio==4.44.1
 requests==2.31.0
-python-dotenv==1.0.0
+++ b/requirements.txt
@@ -1,3 +1,7 @@
 gradio==4.44.1
 requests==2.31.0
+python-dotenv==1.0.0
+transformers
+torch
+anywidget
+traitlets