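"""Gradio demo: per-token log-probability analysis and side-by-side comparison
of two ERNIE-4.5 0.3B checkpoints (base vs. instruction-tuned)."""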

import math

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub IDs for the two ERNIE-4.5 0.3B checkpoints under comparison.
model_ids = {
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
}

# Load a tokenizer for each checkpoint.
tokenizers = {
    name: AutoTokenizer.from_pretrained(path)
    for name, path in model_ids.items()
}

# Load both models once at startup and switch them to inference mode.
models = {
    name: AutoModelForCausalLM.from_pretrained(path).eval()
    for name, path in model_ids.items()
}
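
# The models above load on CPU in float32 by default. A minimal sketch for
# moving them to a GPU when one is available (assumes both 0.3B checkpoints
# fit in VRAM; the inputs in calculate_token_log_probabilities would then
# need the same `.to(device)` treatment):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   models = {name: m.to(device) for name, m in models.items()}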


def calculate_token_log_probabilities(text, model_name):
    """Calculate the log probability of each token and the total log probability."""
    tokenizer = tokenizers[model_name]
    model = models[model_name]

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Next-token prediction: the logits at position i score the token at
    # position i + 1, so drop the last logit and the first label to align them.
    with torch.no_grad():
        outputs = model(**inputs)
    shift_logits = outputs.logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]

    # Log-softmax over the vocabulary, then gather each target token's score.
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

    token_log_probs = token_log_probs[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(shift_labels[0].tolist())

    total_log_prob = sum(token_log_probs)

    return token_log_probs, tokens, total_log_prob
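

# Usage sketch: the total log probability relates directly to perplexity.
# For example (assumes the models above have finished loading):
#
#   token_lps, toks, total_lp = calculate_token_log_probabilities(
#       "Hello, World!", "ERNIE-4.5-PT"
#   )
#   perplexity = math.exp(-total_lp / len(token_lps))  # lower = more confident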


def create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name):
    """Create the per-token table and probability bar chart for one model."""
    # Tabulate each token with its log probability and probability.
    df_data = []
    for token, log_prob in zip(tokens, token_log_probs):
        prob = math.exp(log_prob)
        df_data.append({
            "Token": token,
            "Log Probability": f"{log_prob:.4f}",
            "Probability": f"{prob:.4f}",
            "Probability (%)": f"{prob*100:.2f}%"
        })

    df = pd.DataFrame(df_data)

    # Plot against positional x values so repeated tokens (e.g. two occurrences
    # of "the") get separate bars instead of being merged into one category.
    positions = list(range(len(tokens)))
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=positions,
        y=[math.exp(lp) for lp in token_log_probs],
        text=[f"{math.exp(lp):.3f}" for lp in token_log_probs],
        textposition='auto',
        marker_color='royalblue',
        name=model_name
    ))

    fig.update_layout(
        title=f"Token Probability Distribution - {model_name}",
        xaxis=dict(title="Token", tickmode='array', tickvals=positions, ticktext=tokens),
        yaxis=dict(title="Probability", tickformat='.0%', range=[0, 1.05]),
        height=400
    )

    return df, fig


def analyze_text_both_models(text):
    """Analyze text with both models and return all visualization components."""
    if not text.strip():
        warning = "⚠️ Please enter some text to analyze"
        return (None, None, warning,
                None, None, warning,
                None, warning)

    results = {}

    # Run the full per-token analysis for each model.
    for model_name in model_ids:
        token_log_probs, tokens, total_log_prob = calculate_token_log_probabilities(text, model_name)

        df, fig = create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name)

        n_tokens = len(token_log_probs)
        # Arithmetic mean of the per-token probabilities.
        avg_prob = sum(math.exp(lp) for lp in token_log_probs) / n_tokens if n_tokens else 0
        # Geometric mean: exp of the mean log probability (the reciprocal of perplexity).
        geo_mean_prob = math.exp(total_log_prob / n_tokens) if n_tokens else 0

        summary = (
            f"## Analysis Summary - {model_name}\n\n"
            f"**Total Log Probability**: {total_log_prob:.4f}\n"
            f"**Sum of Individual Log Probs**: {sum(token_log_probs):.4f}\n"
            f"**Verification**: {'✓ Match' if abs(total_log_prob - sum(token_log_probs)) < 1e-10 else '✗ Mismatch'}\n\n"
            f"**Average Token Probability**: {avg_prob:.4f} ({avg_prob*100:.2f}%)\n"
            f"**Geometric Mean of Probabilities**: {geo_mean_prob:.4f} ({geo_mean_prob*100:.2f}%)\n\n"
            f"### Interpretation\n"
            f"- Total Log Probability is the sum of the individual token log probabilities\n"
            f"- Higher values (closer to 0) indicate higher model confidence\n"
            f"- The first token is never scored (it has no preceding context)\n"
            f"- Each token's probability shows how confident the model was in predicting it"
        )

        results[model_name] = {
            "df": df,
            "fig": fig,
            "summary": summary,
            "total_log_prob": total_log_prob,
            "token_log_probs": token_log_probs,
            "tokens": tokens
        }

    # Build the comparison chart. If both tokenizers produced the identical
    # token sequence, use a grouped bar chart; otherwise give each model its
    # own subplot, since the x-axes would not line up.
    base_model = "ERNIE-4.5-Base-PT"
    pt_model = "ERNIE-4.5-PT"

    if results[base_model]["tokens"] == results[pt_model]["tokens"]:
        tokens = results[base_model]["tokens"]
        # Positional x values keep repeated tokens from collapsing into one bar.
        positions = list(range(len(tokens)))

        comparison_fig = go.Figure()
        comparison_fig.add_trace(go.Bar(
            name=base_model,
            x=positions,
            y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
            text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
            textposition='auto',
            marker_color='royalblue'
        ))

        comparison_fig.add_trace(go.Bar(
            name=pt_model,
            x=positions,
            y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
            text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
            textposition='auto',
            marker_color='lightseagreen'
        ))

        comparison_fig.update_layout(
            title="Model Comparison: Token Probability Distribution",
            xaxis=dict(title="Token", tickmode='array', tickvals=positions, ticktext=tokens),
            yaxis=dict(title="Probability", tickformat='.0%', range=[0, 1.05]),
            barmode='group',
            height=400
        )
    else:
        comparison_fig = make_subplots(
            rows=2, cols=1,
            subplot_titles=(base_model, pt_model),
            vertical_spacing=0.1
        )

        comparison_fig.add_trace(
            go.Bar(
                x=results[base_model]["tokens"],
                y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
                text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
                textposition='auto',
                marker_color='royalblue',
                name=base_model
            ),
            row=1, col=1
        )

        comparison_fig.add_trace(
            go.Bar(
                x=results[pt_model]["tokens"],
                y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
                text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
                textposition='auto',
                marker_color='lightseagreen',
                name=pt_model
            ),
            row=2, col=1
        )

        comparison_fig.update_layout(
            title="Model Comparison: Token Probability Distribution",
            height=600,
            showlegend=False
        )

    # Headline comparison: the higher (less negative) total log probability
    # means that model assigned the input text a higher likelihood.
    base_total = results[base_model]["total_log_prob"]
    pt_total = results[pt_model]["total_log_prob"]

    better_model = base_model if base_total > pt_total else pt_model
    difference = abs(base_total - pt_total)

    comparison_summary = (
        f"## Model Comparison Summary\n\n"
        f"**Total Log Probability**:\n"
        f"- {base_model}: {base_total:.4f}\n"
        f"- {pt_model}: {pt_total:.4f}\n\n"
        f"**Higher Confidence Model**: {better_model}\n"
        f"Difference: {difference:.4f}\n\n"
        f"### Interpretation\n"
        f"- The model with the higher Total Log Probability is more confident in predicting the input text\n"
        f"- A log probability closer to 0 (less negative) indicates higher model confidence\n"
        f"- {base_model} is the base model, while {pt_model} is instruction-tuned"
    )

    return (
        results[base_model]["df"],
        results[base_model]["fig"],
        results[base_model]["summary"],
        results[pt_model]["df"],
        results[pt_model]["fig"],
        results[pt_model]["summary"],
        comparison_fig,
        comparison_summary
    )
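

# Caveat: total log probability is length-sensitive. The two checkpoints here
# likely share a tokenizer, but if the token counts ever differed, a fairer
# comparison would be the mean log probability per token, e.g.:
#
#   mean_lp = total_log_prob / len(token_log_probs)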


with gr.Blocks(title="Token Log Probability Analyzer - Model Comparison") as demo:
    gr.Markdown(
        """
        # 🔍 Token Log Probability Analyzer - Model Comparison

        Compare how two ERNIE models predict each token in your text, with a detailed log-probability breakdown.
        """
    )

    with gr.Row():
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to analyze (e.g., 'Hello, World!')",
            value="Hello, World!"
        )

    with gr.Row():
        analyze_btn = gr.Button("Analyze Both Models", variant="primary", size="lg")

    # Comparison summary and chart, shown above the per-model panels.
    with gr.Row():
        with gr.Column():
            comparison_summary_output = gr.Markdown(label="Model Comparison Summary")

    with gr.Row():
        comparison_chart_output = gr.Plot(label="Model Comparison Chart")

    # Side-by-side panels: base model on the left, instruction-tuned on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ERNIE-4.5-Base-PT")
            base_summary_output = gr.Markdown(label="Base Model Summary")
            base_table_output = gr.DataFrame(
                label="Token Analysis",
                interactive=False,
                wrap=True
            )
            base_chart_output = gr.Plot(label="Token Probability Chart")

        with gr.Column():
            gr.Markdown("### ERNIE-4.5-PT")
            pt_summary_output = gr.Markdown(label="PT Model Summary")
            pt_table_output = gr.DataFrame(
                label="Token Analysis",
                interactive=False,
                wrap=True
            )
            pt_chart_output = gr.Plot(label="Token Probability Chart")

    gr.Examples(
        examples=[
            ["Hello, World!"],
            ["The quick brown fox jumps over the lazy dog."],
            ["Artificial intelligence will transform our society."],
            ["What is the meaning of life?"]
        ],
        inputs=[text_input],
        label="Try these examples:"
    )

    gr.Markdown(
        """
        ## How to Interpret Results

        This interface compares two ERNIE models side by side:

        1. **ERNIE-4.5-Base-PT** (left): Base model, better at general language patterns
        2. **ERNIE-4.5-PT** (right): Instruction-tuned model, better at following complex instructions

        ### Analysis Components

        For each model, you'll see:
        - **Summary**: Key metrics including Total Log Probability and average token probability
        - **Token Analysis Table**: Detailed breakdown of each token's log probability and probability
        - **Token Probability Chart**: Visual representation of each token's prediction probability

        ### Model Comparison

        - **Model Comparison Summary**: Shows which model has higher overall confidence
        - **Model Comparison Chart**: Side-by-side visualization of token probabilities

        ### Key Concepts

        - **Log Probability**:
          - Ranges from -∞ to 0
          - Closer to 0 = higher model confidence
          - Used instead of raw probability to avoid numerical underflow

        - **Total Log Probability**:
          - Sum of individual token log probabilities
          - Measures overall model confidence in the entire sequence
          - Allows comparison between different models

        - **Why Compare Models?**:
          - Base models may be better at general language
          - Instruction-tuned models may be better at specific tasks
          - Different models have different strengths for different types of text
        """
    )
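
    # Numerical footnote to the "Key Concepts" text above: multiplying many raw
    # probabilities underflows to 0.0 in floating point, which is why log
    # probabilities are summed instead. A minimal illustration (hypothetical
    # values, not model output):
    #
    #   probs = [0.01] * 200
    #   math.prod(probs)                 # 0.0 -- underflow
    #   sum(math.log(p) for p in probs)  # ~ -921.03, still representable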

    # Wire the button: one callback populates all eight outputs.
    analyze_btn.click(
        fn=analyze_text_both_models,
        inputs=[text_input],
        outputs=[
            base_table_output, base_chart_output, base_summary_output,
            pt_table_output, pt_chart_output, pt_summary_output,
            comparison_chart_output, comparison_summary_output
        ]
    )


if __name__ == "__main__":
    demo.launch()