AlexTransformer commited on
Commit
b9a15d9
Β·
verified Β·
1 Parent(s): 9387d94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -91
app.py CHANGED
@@ -4,7 +4,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
  import pandas as pd
6
  import math
7
- from plotly import graph_objects as go
 
 
8
 
9
  # Load model and tokenizer
10
  model_ids = {
@@ -29,14 +31,24 @@ def format_prob(prob):
29
 
30
  # Helper function to format log probability
31
  def format_log_prob(log_prob):
32
- """Format log probability with color coding"""
33
  return f"{log_prob:.3f}"
34
 
 
 
 
 
 
 
 
 
 
 
35
  # Main function: compute token-wise log probabilities and top-k predictions
36
  @torch.no_grad()
37
  def compare_models(text, top_k=5):
38
  if not text.strip():
39
- return None, "⚠️ Please enter some text to analyze"
40
 
41
  results = {}
42
 
@@ -63,6 +75,8 @@ def compare_models(text, top_k=5):
63
  # Generate top-k predictions for each position (up to first 20 tokens)
64
  topk_list = []
65
  confidence_list = []
 
 
66
  for i in range(min(20, shift_logits.shape[1])):
67
  topk = torch.topk(log_probs[0, i], k=top_k)
68
  topk_ids = topk.indices.tolist()
@@ -77,71 +91,124 @@ def compare_models(text, top_k=5):
77
  # Calculate confidence (probability of actual token)
78
  actual_token_prob = math.exp(token_log_probs[0, i].item())
79
  confidence_list.append(actual_token_prob)
 
 
 
 
80
 
81
- # Prepare dataframe for display
82
- df = pd.DataFrame({
83
- "Token": tokens[:20],
84
- "LogProb": [format_log_prob(float(x)) for x in token_log_probs[0][:20]],
85
- "Confidence": [format_prob(x) for x in confidence_list[:20]],
86
- f"Top-{top_k} Predictions": topk_list
87
- })
88
-
89
  results[model_name] = {
90
- "df": df,
91
- "total_log_prob": total_log_prob,
92
  "tokens": tokens[:20],
93
- "confidences": confidence_list[:20]
 
 
 
 
 
94
  }
95
 
96
- # Create comparison dataframe
97
- comparison_df = pd.DataFrame({
98
- "Token": results["ERNIE-4.5-PT"]["df"]["Token"],
99
- "ERNIE-4.5-PT": {
100
- "LogProb": results["ERNIE-4.5-PT"]["df"]["LogProb"],
101
- "Confidence": results["ERNIE-4.5-PT"]["df"]["Confidence"],
102
- "Top-k": results["ERNIE-4.5-PT"]["df"][f"Top-{top_k} Predictions"]
103
- },
104
- "ERNIE-4.5-Base-PT": {
105
- "LogProb": results["ERNIE-4.5-Base-PT"]["df"]["LogProb"],
106
- "Confidence": results["ERNIE-4.5-Base-PT"]["df"]["Confidence"],
107
- "Top-k": results["ERNIE-4.5-Base-PT"]["df"][f"Top-{top_k} Predictions"]
108
- }
109
- })
110
 
111
- # Create visualization
112
- fig = go.Figure()
 
113
 
114
- # Add confidence bars for both models
115
- fig.add_trace(go.Bar(
116
  name='ERNIE-4.5-PT',
117
  x=results["ERNIE-4.5-PT"]["tokens"],
118
- y=results["ERNIE-4.5-PT"]["confidences"],
119
- marker_color='royalblue'
 
 
 
120
  ))
121
 
122
- fig.add_trace(go.Bar(
123
  name='ERNIE-4.5-Base-PT',
124
  x=results["ERNIE-4.5-Base-PT"]["tokens"],
125
- y=results["ERNIE-4.5-Base-PT"]["confidences"],
126
- marker_color='lightseagreen'
 
 
 
127
  ))
128
 
129
- fig.update_layout(
130
- title='Model Confidence Comparison',
131
  xaxis_title='Token',
132
  yaxis_title='Confidence (Probability)',
133
  barmode='group',
134
- yaxis=dict(tickformat='.0%', range=[0, 1]),
135
  legend=dict(
136
  orientation="h",
137
  yanchor="bottom",
138
  y=1.02,
139
  xanchor="right",
140
  x=1
141
- )
 
142
  )
143
 
144
- # Create summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  pt_logprob = results['ERNIE-4.5-PT']['total_log_prob']
146
  base_logprob = results['ERNIE-4.5-Base-PT']['total_log_prob']
147
 
@@ -153,34 +220,76 @@ def compare_models(text, top_k=5):
153
  better_model = "ERNIE-4.5-Base-PT"
154
  difference = base_logprob - pt_logprob
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  summary = (
157
  f"πŸ“Š **Model Comparison Summary**\n\n"
158
  f"**Total Log Probability**:\n"
159
  f"- ERNIE-4.5-PT: {pt_logprob:.3f}\n"
160
  f"- ERNIE-4.5-Base-PT: {base_logprob:.3f}\n\n"
 
 
 
161
  f"πŸ† **Higher Confidence Model**: {better_model}\n"
162
- f"Difference: {difference:.3f} ({'+' if better_model == 'ERNIE-4.5-PT' else '-'}{difference:.3f})\n\n"
163
  f"**What this means**:\n"
164
  f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
165
  f"- The {better_model} model is more confident in predicting your input text\n"
166
- f"- Confidence per token is shown in the table and chart below"
 
 
 
 
 
167
  )
168
 
169
- return comparison_df, summary, fig
170
 
171
  # Create custom CSS for better styling
172
  css = """
173
  .main-container {
174
- max-width: 1200px;
175
  margin: 0 auto;
176
  }
177
  .dataframe-container {
178
  margin: 20px 0;
179
  }
180
- .confidence-chart {
181
- margin: 20px 0;
182
- height: 400px;
183
- }
184
  .summary-box {
185
  background-color: #f8f9fa;
186
  border-left: 4px solid #4285f4;
@@ -188,26 +297,12 @@ css = """
188
  border-radius: 4px;
189
  margin: 20px 0;
190
  }
191
- .model-header {
192
- font-weight: bold;
193
- color: #1a73e8;
194
- margin-top: 10px;
195
- }
196
- .token-cell {
197
- font-family: monospace;
198
- background-color: #f1f3f4;
199
- padding: 4px 8px;
200
- border-radius: 3px;
201
- }
202
- .confidence-high {
203
- color: #0f9d58;
204
- font-weight: bold;
205
- }
206
- .confidence-medium {
207
- color: #f4b400;
208
- }
209
- .confidence-low {
210
- color: #db4437;
211
  }
212
  """
213
 
@@ -217,13 +312,7 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
217
  """
218
  # πŸ” ERNIE 4.5 Model Comparison Tool
219
 
220
- Compare how different ERNIE models process your text with detailed token-level analysis.
221
-
222
- ## What this tool shows:
223
- - **Token Log Probability**: How confident the model is in predicting each token (closer to 0 is better)
224
- - **Confidence**: Probability percentage for each token prediction
225
- - **Top-k Predictions**: What other tokens the model considered likely
226
- - **Visual Comparison**: Bar chart showing confidence differences between models
227
  """
228
  )
229
 
@@ -233,7 +322,7 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
233
  lines=3,
234
  placeholder="Enter text to analyze (e.g., 'Hello, World!')",
235
  label="Input Text",
236
- value="Hello, World!"
237
  )
238
  with gr.Column(scale=1):
239
  top_k = gr.Slider(
@@ -245,7 +334,7 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
245
  )
246
 
247
  with gr.Row():
248
- compare_btn = gr.Button("Compare Models", variant="primary")
249
 
250
  with gr.Row():
251
  with gr.Column():
@@ -256,18 +345,32 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
256
 
257
  with gr.Row():
258
  with gr.Column():
259
- comparison_table = gr.Dataframe(
260
- label="Token-Level Analysis",
261
- elem_classes=["dataframe-container"],
262
- interactive=False,
263
- wrap=True
264
  )
265
 
266
  with gr.Row():
267
  with gr.Column():
268
  confidence_chart = gr.Plot(
269
- label="Model Confidence Comparison",
270
- elem_classes=["confidence-chart"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  )
272
 
273
  # Examples section
@@ -287,10 +390,10 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
287
  """
288
  ## How to Interpret Results
289
 
290
- 1. **Log Probability**: Negative values where closer to 0 means higher model confidence
291
- 2. **Confidence**: Percentage showing how certain the model was about each token
292
- 3. **Top-k Predictions**: Alternative tokens the model considered likely
293
- 4. **Visual Chart**: Bar heights represent model confidence for each token
294
 
295
  **Model Differences**:
296
  - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
@@ -302,7 +405,7 @@ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
302
  compare_btn.click(
303
  fn=compare_models,
304
  inputs=[input_text, top_k],
305
- outputs=[comparison_table, summary_box, confidence_chart]
306
  )
307
 
308
  if __name__ == "__main__":
 
4
  import gradio as gr
5
  import pandas as pd
6
  import math
7
+ import plotly.graph_objects as go
8
+ import plotly.express as px
9
+ from plotly.subplots import make_subplots
10
 
11
  # Load model and tokenizer
12
  model_ids = {
 
31
 
32
  # Helper function to format log probability
33
  def format_log_prob(log_prob):
34
+ """Format log probability"""
35
  return f"{log_prob:.3f}"
36
 
37
+ # Helper function to get confidence level
38
+ def get_confidence_level(prob):
39
+ """Get confidence level description based on probability"""
40
+ if prob > 0.8:
41
+ return "High", "🟒"
42
+ elif prob > 0.5:
43
+ return "Medium", "🟑"
44
+ else:
45
+ return "Low", "πŸ”΄"
46
+
47
  # Main function: compute token-wise log probabilities and top-k predictions
48
  @torch.no_grad()
49
  def compare_models(text, top_k=5):
50
  if not text.strip():
51
+ return None, "⚠️ Please enter some text to analyze", None
52
 
53
  results = {}
54
 
 
75
  # Generate top-k predictions for each position (up to first 20 tokens)
76
  topk_list = []
77
  confidence_list = []
78
+ confidence_indicators = []
79
+
80
  for i in range(min(20, shift_logits.shape[1])):
81
  topk = torch.topk(log_probs[0, i], k=top_k)
82
  topk_ids = topk.indices.tolist()
 
91
  # Calculate confidence (probability of actual token)
92
  actual_token_prob = math.exp(token_log_probs[0, i].item())
93
  confidence_list.append(actual_token_prob)
94
+
95
+ # Get confidence level and indicator
96
+ level, indicator = get_confidence_level(actual_token_prob)
97
+ confidence_indicators.append(indicator)
98
 
99
+ # Store results for this model
 
 
 
 
 
 
 
100
  results[model_name] = {
 
 
101
  "tokens": tokens[:20],
102
+ "log_probs": [format_log_prob(float(x)) for x in token_log_probs[0][:20]],
103
+ "confidences": [format_prob(x) for x in confidence_list[:20]],
104
+ "levels": confidence_indicators[:20],
105
+ "topk_predictions": topk_list,
106
+ "total_log_prob": total_log_prob,
107
+ "confidence_values": confidence_list[:20] # Keep raw values for plotting
108
  }
109
 
110
+ # Create a properly structured dataframe
111
+ df_data = {"Token": results["ERNIE-4.5-PT"]["tokens"]}
112
+
113
+ # Add columns for each model
114
+ for model_name in ["ERNIE-4.5-PT", "ERNIE-4.5-Base-PT"]:
115
+ df_data[f"{model_name} LogProb"] = results[model_name]["log_probs"]
116
+ df_data[f"{model_name} Confidence"] = results[model_name]["confidences"]
117
+ df_data[f"{model_name} Level"] = results[model_name]["levels"]
118
+ df_data[f"{model_name} Top-{top_k}"] = results[model_name]["topk_predictions"]
119
+
120
+ # Create the dataframe
121
+ comparison_df = pd.DataFrame(df_data)
 
 
122
 
123
+ # Create visualizations
124
+ # 1. Token-level confidence comparison
125
+ fig_confidence = go.Figure()
126
 
127
+ # Add bars for both models
128
+ fig_confidence.add_trace(go.Bar(
129
  name='ERNIE-4.5-PT',
130
  x=results["ERNIE-4.5-PT"]["tokens"],
131
+ y=results["ERNIE-4.5-PT"]["confidence_values"],
132
+ marker_color='royalblue',
133
+ text=[format_prob(x) for x in results["ERNIE-4.5-PT"]["confidence_values"]],
134
+ textposition='auto',
135
+ textfont=dict(size=10)
136
  ))
137
 
138
+ fig_confidence.add_trace(go.Bar(
139
  name='ERNIE-4.5-Base-PT',
140
  x=results["ERNIE-4.5-Base-PT"]["tokens"],
141
+ y=results["ERNIE-4.5-Base-PT"]["confidence_values"],
142
+ marker_color='lightseagreen',
143
+ text=[format_prob(x) for x in results["ERNIE-4.5-Base-PT"]["confidence_values"]],
144
+ textposition='auto',
145
+ textfont=dict(size=10)
146
  ))
147
 
148
+ fig_confidence.update_layout(
149
+ title='Token-Level Confidence Comparison',
150
  xaxis_title='Token',
151
  yaxis_title='Confidence (Probability)',
152
  barmode='group',
153
+ yaxis=dict(tickformat='.0%', range=[0, 1.05]),
154
  legend=dict(
155
  orientation="h",
156
  yanchor="bottom",
157
  y=1.02,
158
  xanchor="right",
159
  x=1
160
+ ),
161
+ height=500
162
  )
163
 
164
+ # 2. Log probability trend comparison
165
+ fig_logprob = go.Figure()
166
+
167
+ # Convert log probabilities back to float for plotting
168
+ pt_logprobs = [float(x) for x in results["ERNIE-4.5-PT"]["log_probs"]]
169
+ base_logprobs = [float(x) for x in results["ERNIE-4.5-Base-PT"]["log_probs"]]
170
+
171
+ fig_logprob.add_trace(go.Scatter(
172
+ name='ERNIE-4.5-PT',
173
+ x=results["ERNIE-4.5-PT"]["tokens"],
174
+ y=pt_logprobs,
175
+ mode='lines+markers',
176
+ line=dict(color='royalblue', width=3),
177
+ marker=dict(size=8),
178
+ text=[f"LogProb: {x}<br>Token: {t}" for x, t in zip(pt_logprobs, results["ERNIE-4.5-PT"]["tokens"])],
179
+ hoverinfo='text'
180
+ ))
181
+
182
+ fig_logprob.add_trace(go.Scatter(
183
+ name='ERNIE-4.5-Base-PT',
184
+ x=results["ERNIE-4.5-Base-PT"]["tokens"],
185
+ y=base_logprobs,
186
+ mode='lines+markers',
187
+ line=dict(color='lightseagreen', width=3),
188
+ marker=dict(size=8),
189
+ text=[f"LogProb: {x}<br>Token: {t}" for x, t in zip(base_logprobs, results["ERNIE-4.5-Base-PT"]["tokens"])],
190
+ hoverinfo='text'
191
+ ))
192
+
193
+ # Add a horizontal line at y=0 for reference
194
+ fig_logprob.add_hline(y=0, line_dash="dash", line_color="red", annotation_text="Zero Reference")
195
+
196
+ fig_logprob.update_layout(
197
+ title='Token-Level Log Probability Trend',
198
+ xaxis_title='Token',
199
+ yaxis_title='Log Probability',
200
+ hovermode='closest',
201
+ legend=dict(
202
+ orientation="h",
203
+ yanchor="bottom",
204
+ y=1.02,
205
+ xanchor="right",
206
+ x=1
207
+ ),
208
+ height=400
209
+ )
210
+
211
+ # 3. Model summary comparison
212
  pt_logprob = results['ERNIE-4.5-PT']['total_log_prob']
213
  base_logprob = results['ERNIE-4.5-Base-PT']['total_log_prob']
214
 
 
220
  better_model = "ERNIE-4.5-Base-PT"
221
  difference = base_logprob - pt_logprob
222
 
223
+ # Calculate average confidence for each model
224
+ pt_avg_conf = sum(results['ERNIE-4.5-PT']['confidence_values']) / len(results['ERNIE-4.5-PT']['confidence_values'])
225
+ base_avg_conf = sum(results['ERNIE-4.5-Base-PT']['confidence_values']) / len(results['ERNIE-4.5-Base-PT']['confidence_values'])
226
+
227
+ # Create summary chart
228
+ fig_summary = go.Figure()
229
+
230
+ fig_summary.add_trace(go.Bar(
231
+ name='Total Log Probability',
232
+ x=['ERNIE-4.5-PT', 'ERNIE-4.5-Base-PT'],
233
+ y=[pt_logprob, base_logprob],
234
+ marker_color=['royalblue', 'lightseagreen'],
235
+ text=[f"{pt_logprob:.3f}", f"{base_logprob:.3f}"],
236
+ textposition='auto',
237
+ textfont=dict(size=14)
238
+ ))
239
+
240
+ fig_summary.update_layout(
241
+ title='Model Summary Comparison',
242
+ yaxis_title='Total Log Probability',
243
+ xaxis_title='Model',
244
+ height=300,
245
+ showlegend=False
246
+ )
247
+
248
+ # Add annotation for the better model
249
+ fig_summary.add_annotation(
250
+ x=0 if better_model == "ERNIE-4.5-PT" else 1,
251
+ y=max(pt_logprob, base_logprob) + 0.5,
252
+ text=f"πŸ† {better_model}",
253
+ showarrow=True,
254
+ arrowhead=1,
255
+ ax=0,
256
+ ay=-30,
257
+ font=dict(size=16, color="green")
258
+ )
259
+
260
+ # Create summary text
261
  summary = (
262
  f"πŸ“Š **Model Comparison Summary**\n\n"
263
  f"**Total Log Probability**:\n"
264
  f"- ERNIE-4.5-PT: {pt_logprob:.3f}\n"
265
  f"- ERNIE-4.5-Base-PT: {base_logprob:.3f}\n\n"
266
+ f"**Average Confidence**:\n"
267
+ f"- ERNIE-4.5-PT: {format_prob(pt_avg_conf)}\n"
268
+ f"- ERNIE-4.5-Base-PT: {format_prob(base_avg_conf)}\n\n"
269
  f"πŸ† **Higher Confidence Model**: {better_model}\n"
270
+ f"Difference: {difference:.3f}\n\n"
271
  f"**What this means**:\n"
272
  f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
273
  f"- The {better_model} model is more confident in predicting your input text\n"
274
+ f"- Confidence indicators: 🟒 High (>80%), 🟑 Medium (50-80%), πŸ”΄ Low (<50%)\n\n"
275
+ f"**Interpretation Guide**:\n"
276
+ f"- **LogProb**: How confident the model is in predicting each token (closer to 0 is better)\n"
277
+ f"- **Confidence**: Probability percentage for each token prediction\n"
278
+ f"- **Level**: Visual indicator of confidence (πŸŸ’πŸŸ‘πŸ”΄)\n"
279
+ f"- **Top-k**: What other tokens the model considered likely"
280
  )
281
 
282
+ return comparison_df, summary, fig_confidence, fig_logprob, fig_summary
283
 
284
  # Create custom CSS for better styling
285
  css = """
286
  .main-container {
287
+ max-width: 1400px;
288
  margin: 0 auto;
289
  }
290
  .dataframe-container {
291
  margin: 20px 0;
292
  }
 
 
 
 
293
  .summary-box {
294
  background-color: #f8f9fa;
295
  border-left: 4px solid #4285f4;
 
297
  border-radius: 4px;
298
  margin: 20px 0;
299
  }
300
+ .chart-container {
301
+ margin: 20px 0;
302
+ border: 1px solid #e0e0e0;
303
+ border-radius: 8px;
304
+ padding: 15px;
305
+ background-color: #ffffff;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  }
307
  """
308
 
 
312
  """
313
  # πŸ” ERNIE 4.5 Model Comparison Tool
314
 
315
+ Compare how different ERNIE models process your text with detailed token-level analysis and visualizations.
 
 
 
 
 
 
316
  """
317
  )
318
 
 
322
  lines=3,
323
  placeholder="Enter text to analyze (e.g., 'Hello, World!')",
324
  label="Input Text",
325
+ value="What is the meaning of life?"
326
  )
327
  with gr.Column(scale=1):
328
  top_k = gr.Slider(
 
334
  )
335
 
336
  with gr.Row():
337
+ compare_btn = gr.Button("Compare Models", variant="primary", size="lg")
338
 
339
  with gr.Row():
340
  with gr.Column():
 
345
 
346
  with gr.Row():
347
  with gr.Column():
348
+ summary_chart = gr.Plot(
349
+ label="Model Summary",
350
+ elem_classes=["chart-container"]
 
 
351
  )
352
 
353
  with gr.Row():
354
  with gr.Column():
355
  confidence_chart = gr.Plot(
356
+ label="Token-Level Confidence Comparison",
357
+ elem_classes=["chart-container"]
358
+ )
359
+
360
+ with gr.Row():
361
+ with gr.Column():
362
+ logprob_chart = gr.Plot(
363
+ label="Token-Level Log Probability Trend",
364
+ elem_classes=["chart-container"]
365
+ )
366
+
367
+ with gr.Row():
368
+ with gr.Column():
369
+ comparison_table = gr.Dataframe(
370
+ label="Token-Level Analysis",
371
+ elem_classes=["dataframe-container"],
372
+ interactive=False,
373
+ wrap=True
374
  )
375
 
376
  # Examples section
 
390
  """
391
  ## How to Interpret Results
392
 
393
+ 1. **Model Summary Chart**: Shows which model has higher overall confidence for your input text
394
+ 2. **Token-Level Confidence Chart**: Compares how confident each model is for each token in your text
395
+ 3. **Log Probability Trend Chart**: Shows how log probability changes across tokens (closer to 0 is better)
396
+ 4. **Token-Level Analysis Table**: Detailed breakdown of predictions for each token
397
 
398
  **Model Differences**:
399
  - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
 
405
  compare_btn.click(
406
  fn=compare_models,
407
  inputs=[input_text, top_k],
408
+ outputs=[comparison_table, summary_box, confidence_chart, logprob_chart, summary_chart]
409
  )
410
 
411
  if __name__ == "__main__":