import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import math
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Load both models and their tokenizers
model_ids = {
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT"
}

tokenizers = {
    name: AutoTokenizer.from_pretrained(path)
    for name, path in model_ids.items()
}

models = {
    name: AutoModelForCausalLM.from_pretrained(path).eval()
    for name, path in model_ids.items()
}
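
# Note: from_pretrained downloads and caches each checkpoint on first use and loads it on
# CPU by default; .eval() disables dropout so repeated analyses return the same scores.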

def calculate_token_log_probabilities(text, model_name):
    """Calculate log probability for each token and total log probability."""
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    
    # Get model output logits (no gradient tracking needed for scoring)
    with torch.no_grad():
        outputs = model(**inputs)
    shift_logits = outputs.logits[:, :-1, :]          # Align prediction with target
    shift_labels = input_ids[:, 1:]                   # Shift labels to match predictions
    
    # Compute log probabilities
    log_probs = F.log_softmax(shift_logits, dim=-1)
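    # gather picks, at each position, the log probability the model assigned to the token
    # that actually appears next in the input (result shape: [batch, seq_len - 1])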
    token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
    
    # Convert to a Python list and recover the corresponding target tokens
    token_log_probs = token_log_probs[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(shift_labels[0].tolist())
    
    # Calculate total log probability
    total_log_prob = sum(token_log_probs)
    
    return token_log_probs, tokens, total_log_prob
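
# Minimal usage sketch (illustrative only, not called by the app; the input string is just
# the UI's default example):
#   log_probs, toks, total = calculate_token_log_probabilities("Hello, World!", "ERNIE-4.5-PT")
#   total equals sum(log_probs), and math.exp(total / len(log_probs)) is the geometric mean
#   token probability reported in the analysis summary below.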

def create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name):
    """Create visualization components for token analysis."""
    # Create DataFrame for token analysis
    df_data = []
    for token, log_prob in zip(tokens, token_log_probs):
        prob = math.exp(log_prob)
        df_data.append({
            "Token": token,
            "Log Probability": f"{log_prob:.4f}",
            "Probability": f"{prob:.4f}",
            "Probability (%)": f"{prob*100:.2f}%"
        })
    
    df = pd.DataFrame(df_data)
    
    # Create bar chart for token probabilities
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=tokens,
        y=[math.exp(lp) for lp in token_log_probs],
        text=[f"{math.exp(lp):.3f}" for lp in token_log_probs],
        textposition='auto',
        marker_color='royalblue',
        name=model_name
    ))
    
    fig.update_layout(
        title=f"Token Probability Distribution - {model_name}",
        xaxis_title="Token",
        yaxis_title="Probability",
        yaxis=dict(tickformat='.0%', range=[0, 1.05]),
        height=400
    )
    
    return df, fig

def analyze_text_both_models(text):
    """Analyze text with both models and return visualization components."""
    if not text.strip():
        # Return empty components for both models
        return (None, None, "⚠️ Please enter some text to analyze", 
                None, None, "⚠️ Please enter some text to analyze",
                None, "⚠️ Please enter some text to analyze")
    
    results = {}
    
    # Analyze with both models
    for model_name in model_ids:
        token_log_probs, tokens, total_log_prob = calculate_token_log_probabilities(text, model_name)
        
        # Create visualization components
        df, fig = create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name)
        
        # Create summary text (arithmetic mean vs. geometric mean of the token probabilities)
        avg_prob = (sum(math.exp(lp) for lp in token_log_probs) / len(token_log_probs)) if token_log_probs else 0
        geo_mean_prob = math.exp(total_log_prob / len(token_log_probs)) if token_log_probs else 0
        
        summary = (
            f"## Analysis Summary - {model_name}\n\n"
            f"**Total Log Probability**: {total_log_prob:.4f}\n"
            f"**Sum of Individual Log Probs**: {sum(token_log_probs):.4f}\n"
            f"**Verification**: {'✓ Match' if abs(total_log_prob - sum(token_log_probs)) < 1e-10 else '✗ Mismatch'}\n\n"
            f"**Average Token Probability**: {avg_prob:.4f} ({avg_prob*100:.2f}%)\n"
            f"**Geometric Mean of Probabilities**: {geo_mean_prob:.4f} ({geo_mean_prob*100:.2f}%)\n\n"
            f"### Interpretation\n"
            f"- Total Log Probability is the sum of individual token log probabilities\n"
            f"- Higher values (closer to 0) indicate higher model confidence\n"
            f"- The first token has no prediction (no preceding context)\n"
            f"- Each token's probability shows how confident the model was in predicting it"
        )
        
        results[model_name] = {
            "df": df,
            "fig": fig,
            "summary": summary,
            "total_log_prob": total_log_prob,
            "token_log_probs": token_log_probs,
            "tokens": tokens
        }
    
    # Create comparison chart
    comparison_fig = go.Figure()
    
    # Add bars for both models
    base_model = "ERNIE-4.5-Base-PT"
    pt_model = "ERNIE-4.5-PT"
    
    # Ensure both models have the same tokens for comparison
    if results[base_model]["tokens"] == results[pt_model]["tokens"]:
        tokens = results[base_model]["tokens"]
        
        comparison_fig.add_trace(go.Bar(
            name=base_model,
            x=tokens,
            y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
            text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
            textposition='auto',
            marker_color='royalblue'
        ))
        
        comparison_fig.add_trace(go.Bar(
            name=pt_model,
            x=tokens,
            y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
            text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
            textposition='auto',
            marker_color='lightseagreen'
        ))
        
        comparison_fig.update_layout(
            title="Model Comparison: Token Probability Distribution",
            xaxis_title="Token",
            yaxis_title="Probability",
            yaxis=dict(tickformat='.0%', range=[0, 1.05]),
            barmode='group',
            height=400
        )
    else:
        # If tokens are different, create separate subplots
        comparison_fig = make_subplots(
            rows=2, cols=1,
            subplot_titles=(base_model, pt_model),
            vertical_spacing=0.1
        )
        
        # Add Base-PT model
        comparison_fig.add_trace(
            go.Bar(
                x=results[base_model]["tokens"],
                y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
                text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
                textposition='auto',
                marker_color='royalblue',
                name=base_model
            ),
            row=1, col=1
        )
        
        # Add PT model
        comparison_fig.add_trace(
            go.Bar(
                x=results[pt_model]["tokens"],
                y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
                text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
                textposition='auto',
                marker_color='lightseagreen',
                name=pt_model
            ),
            row=2, col=1
        )
        
        comparison_fig.update_layout(
            title="Model Comparison: Token Probability Distribution",
            height=600,
            showlegend=False
        )
    
    # Create comparison summary
    base_total = results[base_model]["total_log_prob"]
    pt_total = results[pt_model]["total_log_prob"]
    
    better_model = base_model if base_total > pt_total else pt_model
    difference = abs(base_total - pt_total)
    
    comparison_summary = (
        f"## Model Comparison Summary\n\n"
        f"**Total Log Probability**:\n"
        f"- {base_model}: {base_total:.4f}\n"
        f"- {pt_model}: {pt_total:.4f}\n\n"
        f"**Higher Confidence Model**: {better_model}\n"
        f"Difference: {difference:.4f}\n\n"
        f"### Interpretation\n"
        f"- The model with the higher Total Log Probability is more confident in predicting the input text\n"
        f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
        f"- {base_model} is the base model while {pt_model} is instruction-tuned"
    )
    
    return (
        results[base_model]["df"], 
        results[base_model]["fig"], 
        results[base_model]["summary"],
        results[pt_model]["df"], 
        results[pt_model]["fig"], 
        results[pt_model]["summary"],
        comparison_fig,
        comparison_summary
    )

# Create Gradio interface with side-by-side comparison
with gr.Blocks(title="Token Log Probability Analyzer - Model Comparison") as demo:
    gr.Markdown(
        """
        # 🔍 Token Log Probability Analyzer - Model Comparison
        
        Compare how two ERNIE models predict each token in your text, with a detailed log-probability breakdown.
        """
    )
    
    with gr.Row():
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to analyze (e.g., 'Hello, World!')",
            value="Hello, World!"
        )
    
    with gr.Row():
        analyze_btn = gr.Button("Analyze Both Models", variant="primary", size="lg")
    
    # Model comparison section
    with gr.Row():
        with gr.Column():
            comparison_summary_output = gr.Markdown(label="Model Comparison Summary")
    
    with gr.Row():
        comparison_chart_output = gr.Plot(label="Model Comparison Chart")
    
    # Side-by-side model results
    with gr.Row():
        # Left column: ERNIE-4.5-Base-PT
        with gr.Column():
            gr.Markdown("### ERNIE-4.5-Base-PT")
            base_summary_output = gr.Markdown(label="Base Model Summary")
            base_table_output = gr.DataFrame(
                label="Token Analysis",
                interactive=False,
                wrap=True
            )
            base_chart_output = gr.Plot(label="Token Probability Chart")
        
        # Right column: ERNIE-4.5-PT
        with gr.Column():
            gr.Markdown("### ERNIE-4.5-PT")
            pt_summary_output = gr.Markdown(label="PT Model Summary")
            pt_table_output = gr.DataFrame(
                label="Token Analysis",
                interactive=False,
                wrap=True
            )
            pt_chart_output = gr.Plot(label="Token Probability Chart")
    
    # Examples section
    gr.Examples(
        examples=[
            ["Hello, World!"],
            ["The quick brown fox jumps over the lazy dog."],
            ["Artificial intelligence will transform our society."],
            ["What is the meaning of life?"]
        ],
        inputs=[text_input],
        label="Try these examples:"
    )
    
    # Footer with explanation
    gr.Markdown(
        """
        ## How to Interpret Results
        
        This interface compares two ERNIE models side by side:
        
        1. **ERNIE-4.5-Base-PT** (left): Base model, better at general language patterns
        2. **ERNIE-4.5-PT** (right): Instruction-tuned model, better at following complex instructions
        
        ### Analysis Components
        
        For each model, you'll see:
        - **Summary**: Key metrics including Total Log Probability and average token probability
        - **Token Analysis Table**: Detailed breakdown of each token's log probability and probability
        - **Token Probability Chart**: Visual representation of each token's prediction probability
        
        ### Model Comparison
        
        - **Model Comparison Summary**: Shows which model has higher overall confidence
        - **Model Comparison Chart**: Side-by-side visualization of token probabilities
        
        ### Key Concepts
        
        - **Log Probability**: 
          - Ranges from -∞ to 0
          - Closer to 0 = higher model confidence
          - Used instead of raw probability to avoid numerical underflow
        
        - **Total Log Probability**:
          - Sum of individual token log probabilities
          - Measures overall model confidence in the entire sequence
          - Allows comparison between different models (see the worked example below)
        
        - **Why Compare Models?**:
          - Base models may be better at general language
          - Instruction-tuned models may be better at specific tasks
          - Different models have different strengths for different types of text
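        
        ### Worked Example
        
        To make the numbers concrete: if a model assigned probability 0.5 to each of three
        scored tokens, each token's log probability would be ln(0.5) ≈ -0.69, the Total Log
        Probability would be about 3 × (-0.69) ≈ -2.08, and the geometric mean probability,
        exp(-2.08 / 3), would recover 0.5.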
        """
    )
    
    # Set up event handler
    analyze_btn.click(
        fn=analyze_text_both_models,
        inputs=[text_input],
        outputs=[
            base_table_output, base_chart_output, base_summary_output,
            pt_table_output, pt_chart_output, pt_summary_output,
            comparison_chart_output, comparison_summary_output
        ]
    )

if __name__ == "__main__":
    demo.launch()