AlexTransformer committed on
Commit
674515d
·
verified ·
1 Parent(s): 964e0e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -32
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
  import pandas as pd
6
  import math
 
7
 
8
  # Load model and tokenizer
9
  model_ids = {
@@ -21,9 +22,22 @@ models = {
21
  for name, path in model_ids.items()
22
  }
23
 
 
 
 
 
 
 
 
 
 
 
24
  # Main function: compute token-wise log probabilities and top-k predictions
25
  @torch.no_grad()
26
  def compare_models(text, top_k=5):
 
 
 
27
  results = {}
28
 
29
  for model_name in model_ids:
@@ -48,59 +62,248 @@ def compare_models(text, top_k=5):
48
 
49
  # Generate top-k predictions for each position (up to first 20 tokens)
50
  topk_list = []
 
51
  for i in range(min(20, shift_logits.shape[1])):
52
  topk = torch.topk(log_probs[0, i], k=top_k)
53
  topk_ids = topk.indices.tolist()
54
  topk_scores = topk.values.tolist()
55
  topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids)
56
- topk_probs = [round(math.exp(s), 4) for s in topk_scores]
57
- pair_list = [f"{tok} ({prob})" for tok, prob in zip(topk_tokens, topk_probs)]
58
- topk_list.append(", ".join(pair_list))
 
 
 
 
 
 
59
 
60
  # Prepare dataframe for display
61
  df = pd.DataFrame({
62
  "Token": tokens[:20],
63
- "LogProb": [round(float(x), 4) for x in token_log_probs[0][:20]],
 
64
  f"Top-{top_k} Predictions": topk_list
65
  })
66
 
67
  results[model_name] = {
68
  "df": df,
69
- "total_log_prob": total_log_prob
 
 
70
  }
71
 
72
- # Merge two model results into one table
73
- merged = pd.DataFrame({
74
  "Token": results["ERNIE-4.5-PT"]["df"]["Token"],
75
- "ERNIE-4.5-PT LogProb": results["ERNIE-4.5-PT"]["df"]["LogProb"],
76
- "ERNIE-4.5-PT Top-k": results["ERNIE-4.5-PT"]["df"][f"Top-{top_k} Predictions"],
77
- "ERNIE-4.5-Base-PT LogProb": results["ERNIE-4.5-Base-PT"]["df"]["LogProb"],
78
- "ERNIE-4.5-Base-PT Top-k": results["ERNIE-4.5-Base-PT"]["df"][f"Top-{top_k} Predictions"],
 
 
 
 
 
 
79
  })
80
 
81
- # Summarize total log probability for each model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  summary = (
83
- f"🧠 Total Log Prob:\n"
84
- f"- ERNIE-4.5-PT: {results['ERNIE-4.5-PT']['total_log_prob']:.2f}\n"
85
- f"- ERNIE-4.5-Base-PT: {results['ERNIE-4.5-Base-PT']['total_log_prob']:.2f}"
 
 
 
 
 
 
 
86
  )
87
 
88
- return merged, summary
89
-
90
- # Gradio interface
91
- demo = gr.Interface(
92
- fn=compare_models,
93
- inputs=[
94
- gr.Textbox(lines=2, placeholder="Type a sentence here...", label="Input Sentence"),
95
- gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k Predictions")
96
- ],
97
- outputs=[
98
- gr.Dataframe(label="Token LogProbs and Top-k Predictions"),
99
- gr.Textbox(label="Sentence Total Log Probability", lines=3)
100
- ],
101
- title="🧪 ERNIE 4.5 Model Comparison with Top-k Predictions",
102
- description="Compare ERNIE-4.5-0.3B Instruct and Base model by computing token logprobs and Top-k predictions"
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  if __name__ == "__main__":
106
- demo.launch()
 
4
  import gradio as gr
5
  import pandas as pd
6
  import math
7
+ from plotly import graph_objects as go
8
 
9
  # Load model and tokenizer
10
  model_ids = {
 
22
  for name, path in model_ids.items()
23
  }
24
 
25
+ # Helper function to format probability
26
+ def format_prob(prob):
27
+ """Format probability as percentage with 1 decimal place"""
28
+ return f"{prob*100:.1f}%"
29
+
30
+ # Helper function to format log probability
31
+ def format_log_prob(log_prob):
32
+ """Format log probability with color coding"""
33
+ return f"{log_prob:.3f}"
34
+
35
  # Main function: compute token-wise log probabilities and top-k predictions
36
  @torch.no_grad()
37
  def compare_models(text, top_k=5):
38
+ if not text.strip():
39
+ return None, "⚠️ Please enter some text to analyze"
40
+
41
  results = {}
42
 
43
  for model_name in model_ids:
 
62
 
63
  # Generate top-k predictions for each position (up to first 20 tokens)
64
  topk_list = []
65
+ confidence_list = []
66
  for i in range(min(20, shift_logits.shape[1])):
67
  topk = torch.topk(log_probs[0, i], k=top_k)
68
  topk_ids = topk.indices.tolist()
69
  topk_scores = topk.values.tolist()
70
  topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids)
71
+ topk_probs = [math.exp(s) for s in topk_scores]
72
+
73
+ # Format top-k predictions with probabilities
74
+ topk_formatted = [f"{tok} ({format_prob(p)})" for tok, p in zip(topk_tokens, topk_probs)]
75
+ topk_list.append(", ".join(topk_formatted))
76
+
77
+ # Calculate confidence (probability of actual token)
78
+ actual_token_prob = math.exp(token_log_probs[0, i].item())
79
+ confidence_list.append(actual_token_prob)
80
 
81
  # Prepare dataframe for display
82
  df = pd.DataFrame({
83
  "Token": tokens[:20],
84
+ "LogProb": [format_log_prob(float(x)) for x in token_log_probs[0][:20]],
85
+ "Confidence": [format_prob(x) for x in confidence_list[:20]],
86
  f"Top-{top_k} Predictions": topk_list
87
  })
88
 
89
  results[model_name] = {
90
  "df": df,
91
+ "total_log_prob": total_log_prob,
92
+ "tokens": tokens[:20],
93
+ "confidences": confidence_list[:20]
94
  }
95
 
96
+ # Create comparison dataframe
97
+ comparison_df = pd.DataFrame({
98
  "Token": results["ERNIE-4.5-PT"]["df"]["Token"],
99
+ "ERNIE-4.5-PT": {
100
+ "LogProb": results["ERNIE-4.5-PT"]["df"]["LogProb"],
101
+ "Confidence": results["ERNIE-4.5-PT"]["df"]["Confidence"],
102
+ "Top-k": results["ERNIE-4.5-PT"]["df"][f"Top-{top_k} Predictions"]
103
+ },
104
+ "ERNIE-4.5-Base-PT": {
105
+ "LogProb": results["ERNIE-4.5-Base-PT"]["df"]["LogProb"],
106
+ "Confidence": results["ERNIE-4.5-Base-PT"]["df"]["Confidence"],
107
+ "Top-k": results["ERNIE-4.5-Base-PT"]["df"][f"Top-{top_k} Predictions"]
108
+ }
109
  })
110
 
111
+ # Create visualization
112
+ fig = go.Figure()
113
+
114
+ # Add confidence bars for both models
115
+ fig.add_trace(go.Bar(
116
+ name='ERNIE-4.5-PT',
117
+ x=results["ERNIE-4.5-PT"]["tokens"],
118
+ y=results["ERNIE-4.5-PT"]["confidences"],
119
+ marker_color='royalblue'
120
+ ))
121
+
122
+ fig.add_trace(go.Bar(
123
+ name='ERNIE-4.5-Base-PT',
124
+ x=results["ERNIE-4.5-Base-PT"]["tokens"],
125
+ y=results["ERNIE-4.5-Base-PT"]["confidences"],
126
+ marker_color='lightseagreen'
127
+ ))
128
+
129
+ fig.update_layout(
130
+ title='Model Confidence Comparison',
131
+ xaxis_title='Token',
132
+ yaxis_title='Confidence (Probability)',
133
+ barmode='group',
134
+ yaxis=dict(tickformat='.0%', range=[0, 1]),
135
+ legend=dict(
136
+ orientation="h",
137
+ yanchor="bottom",
138
+ y=1.02,
139
+ xanchor="right",
140
+ x=1
141
+ )
142
+ )
143
+
144
+ # Create summary
145
+ pt_logprob = results['ERNIE-4.5-PT']['total_log_prob']
146
+ base_logprob = results['ERNIE-4.5-Base-PT']['total_log_prob']
147
+
148
+ # Determine which model has higher confidence
149
+ if pt_logprob > base_logprob:
150
+ better_model = "ERNIE-4.5-PT"
151
+ difference = pt_logprob - base_logprob
152
+ else:
153
+ better_model = "ERNIE-4.5-Base-PT"
154
+ difference = base_logprob - pt_logprob
155
+
156
  summary = (
157
+ f"📊 **Model Comparison Summary**\n\n"
158
+ f"**Total Log Probability**:\n"
159
+ f"- ERNIE-4.5-PT: {pt_logprob:.3f}\n"
160
+ f"- ERNIE-4.5-Base-PT: {base_logprob:.3f}\n\n"
161
+ f"🏆 **Higher Confidence Model**: {better_model}\n"
162
+ f"Difference: {difference:.3f} ({'+' if better_model == 'ERNIE-4.5-PT' else '-'}{difference:.3f})\n\n"
163
+ f"**What this means**:\n"
164
+ f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
165
+ f"- The {better_model} model is more confident in predicting your input text\n"
166
+ f"- Confidence per token is shown in the table and chart below"
167
  )
168
 
169
+ return comparison_df, summary, fig
170
+
171
+ # Create custom CSS for better styling
172
+ css = """
173
+ .main-container {
174
+ max-width: 1200px;
175
+ margin: 0 auto;
176
+ }
177
+ .dataframe-container {
178
+ margin: 20px 0;
179
+ }
180
+ .confidence-chart {
181
+ margin: 20px 0;
182
+ height: 400px;
183
+ }
184
+ .summary-box {
185
+ background-color: #f8f9fa;
186
+ border-left: 4px solid #4285f4;
187
+ padding: 15px;
188
+ border-radius: 4px;
189
+ margin: 20px 0;
190
+ }
191
+ .model-header {
192
+ font-weight: bold;
193
+ color: #1a73e8;
194
+ margin-top: 10px;
195
+ }
196
+ .token-cell {
197
+ font-family: monospace;
198
+ background-color: #f1f3f4;
199
+ padding: 4px 8px;
200
+ border-radius: 3px;
201
+ }
202
+ .confidence-high {
203
+ color: #0f9d58;
204
+ font-weight: bold;
205
+ }
206
+ .confidence-medium {
207
+ color: #f4b400;
208
+ }
209
+ .confidence-low {
210
+ color: #db4437;
211
+ }
212
+ """
213
+
214
+ # Gradio interface with improved layout
215
+ with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
216
+ gr.Markdown(
217
+ """
218
+ # 🔍 ERNIE 4.5 Model Comparison Tool
219
+
220
+ Compare how different ERNIE models process your text with detailed token-level analysis.
221
+
222
+ ## What this tool shows:
223
+ - **Token Log Probability**: How confident the model is in predicting each token (closer to 0 is better)
224
+ - **Confidence**: Probability percentage for each token prediction
225
+ - **Top-k Predictions**: What other tokens the model considered likely
226
+ - **Visual Comparison**: Bar chart showing confidence differences between models
227
+ """
228
+ )
229
+
230
+ with gr.Row():
231
+ with gr.Column(scale=3):
232
+ input_text = gr.Textbox(
233
+ lines=3,
234
+ placeholder="Enter text to analyze (e.g., 'Hello, World!')",
235
+ label="Input Text",
236
+ value="Hello, World!"
237
+ )
238
+ with gr.Column(scale=1):
239
+ top_k = gr.Slider(
240
+ minimum=1,
241
+ maximum=10,
242
+ value=3,
243
+ step=1,
244
+ label="Top-k Predictions"
245
+ )
246
+
247
+ with gr.Row():
248
+ compare_btn = gr.Button("Compare Models", variant="primary")
249
+
250
+ with gr.Row():
251
+ with gr.Column():
252
+ summary_box = gr.Markdown(
253
+ elem_classes=["summary-box"],
254
+ label="Model Comparison Summary"
255
+ )
256
+
257
+ with gr.Row():
258
+ with gr.Column():
259
+ comparison_table = gr.Dataframe(
260
+ label="Token-Level Analysis",
261
+ elem_classes=["dataframe-container"],
262
+ interactive=False,
263
+ wrap=True
264
+ )
265
+
266
+ with gr.Row():
267
+ with gr.Column():
268
+ confidence_chart = gr.Plot(
269
+ label="Model Confidence Comparison",
270
+ elem_classes=["confidence-chart"]
271
+ )
272
+
273
+ # Examples section
274
+ gr.Examples(
275
+ examples=[
276
+ ["Hello, World!", 3],
277
+ ["The quick brown fox jumps over the lazy dog.", 5],
278
+ ["Artificial intelligence will transform our society.", 3],
279
+ ["What is the meaning of life?", 4]
280
+ ],
281
+ inputs=[input_text, top_k],
282
+ label="Try these examples:"
283
+ )
284
+
285
+ # Footer with explanation
286
+ gr.Markdown(
287
+ """
288
+ ## How to Interpret Results
289
+
290
+ 1. **Log Probability**: Negative values where closer to 0 means higher model confidence
291
+ 2. **Confidence**: Percentage showing how certain the model was about each token
292
+ 3. **Top-k Predictions**: Alternative tokens the model considered likely
293
+ 4. **Visual Chart**: Bar heights represent model confidence for each token
294
+
295
+ **Model Differences**:
296
+ - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
297
+ - **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
298
+ """
299
+ )
300
+
301
+ # Set up event handler
302
+ compare_btn.click(
303
+ fn=compare_models,
304
+ inputs=[input_text, top_k],
305
+ outputs=[comparison_table, summary_box, confidence_chart]
306
+ )
307
 
308
  if __name__ == "__main__":
309
+ demo.launch()