import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import math
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Load both models and their tokenizers once at startup
model_ids = {
"ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
"ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT"
}
tokenizers = {
name: AutoTokenizer.from_pretrained(path)
for name, path in model_ids.items()
}
models = {
name: AutoModelForCausalLM.from_pretrained(path).eval()
for name, path in model_ids.items()
}
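# Note: both 0.3B checkpoints are loaded on CPU and kept in eval mode. If a GPU is
# available you could optionally move them there (e.g. model.to("cuda")) and move the
# tokenized inputs to the same device before running the forward pass.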
def calculate_token_log_probabilities(text, model_name):
"""Calculate log probability for each token and total log probability."""
tokenizer = tokenizers[model_name]
model = models[model_name]
# Tokenize input
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
    # Run the forward pass without tracking gradients (inference only)
    with torch.no_grad():
        outputs = model(**inputs)
    # Align predictions with targets: the logits at position i predict token i + 1,
    # so drop the last logit and the first label
    shift_logits = outputs.logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
# Compute log probabilities
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
    # Convert to a Python list and recover the corresponding token strings
    token_log_probs = token_log_probs[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(shift_labels[0].tolist())
# Calculate total log probability
total_log_prob = sum(token_log_probs)
return token_log_probs, tokens, total_log_prob
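# Illustrative note (token strings and numbers are hypothetical): for an input that
# tokenizes to ["Hello", ",", "World", "!"], the returned lists cover only
# [",", "World", "!"] because the first token has no preceding context to condition on,
# and total_log_prob is simply their sum, e.g. (-0.9) + (-2.1) + (-0.4) = -3.4.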
def create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name):
"""Create visualization components for token analysis."""
# Create DataFrame for token analysis
df_data = []
for token, log_prob in zip(tokens, token_log_probs):
prob = math.exp(log_prob)
df_data.append({
"Token": token,
"Log Probability": f"{log_prob:.4f}",
"Probability": f"{prob:.4f}",
"Probability (%)": f"{prob*100:.2f}%"
})
df = pd.DataFrame(df_data)
# Create bar chart for token probabilities
fig = go.Figure()
fig.add_trace(go.Bar(
x=tokens,
y=[math.exp(lp) for lp in token_log_probs],
text=[f"{math.exp(lp):.3f}" for lp in token_log_probs],
textposition='auto',
marker_color='royalblue',
name=model_name
))
fig.update_layout(
title=f"Token Probability Distribution - {model_name}",
xaxis_title="Token",
yaxis_title="Probability",
yaxis=dict(tickformat='.0%', range=[0, 1.05]),
height=400
)
return df, fig
def analyze_text_both_models(text):
"""Analyze text with both models and return visualization components."""
if not text.strip():
# Return empty components for both models
return (None, None, "⚠️ Please enter some text to analyze",
None, None, "⚠️ Please enter some text to analyze",
None, "⚠️ Please enter some text to analyze")
results = {}
# Analyze with both models
for model_name in model_ids:
token_log_probs, tokens, total_log_prob = calculate_token_log_probabilities(text, model_name)
# Create visualization components
df, fig = create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name)
# Create summary text
        # Arithmetic mean of per-token probabilities, plus the geometric mean
        # (exp of the mean log probability) as a length-normalised confidence score
        n_tokens = len(token_log_probs)
        avg_prob = sum(math.exp(lp) for lp in token_log_probs) / n_tokens if n_tokens else 0
        geo_mean_prob = math.exp(total_log_prob / n_tokens) if n_tokens else 0
summary = (
f"## Analysis Summary - {model_name}\n\n"
f"**Total Log Probability**: {total_log_prob:.4f}\n"
f"**Sum of Individual Log Probs**: {sum(token_log_probs):.4f}\n"
f"**Verification**: {'✓ Match' if abs(total_log_prob - sum(token_log_probs)) < 1e-10 else '✗ Mismatch'}\n\n"
f"**Average Token Probability**: {avg_prob:.4f} ({avg_prob*100:.2f}%)\n"
f"**Geometric Mean of Probabilities**: {geo_mean_prob:.4f} ({geo_mean_prob*100:.2f}%)\n\n"
f"### Interpretation\n"
f"- Total Log Probability is the sum of individual token log probabilities\n"
f"- Higher values (closer to 0) indicate higher model confidence\n"
f"- The first token has no prediction (no preceding context)\n"
f"- Each token's probability shows how confident the model was in predicting it"
)
results[model_name] = {
"df": df,
"fig": fig,
"summary": summary,
"total_log_prob": total_log_prob,
"token_log_probs": token_log_probs,
"tokens": tokens
}
# Create comparison chart
comparison_fig = go.Figure()
# Add bars for both models
base_model = "ERNIE-4.5-Base-PT"
pt_model = "ERNIE-4.5-PT"
    # The two checkpoints may tokenize the text differently; only overlay the bars when
    # the token sequences match exactly, otherwise fall back to stacked subplots
if results[base_model]["tokens"] == results[pt_model]["tokens"]:
tokens = results[base_model]["tokens"]
comparison_fig.add_trace(go.Bar(
name=base_model,
x=tokens,
y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
textposition='auto',
marker_color='royalblue'
))
comparison_fig.add_trace(go.Bar(
name=pt_model,
x=tokens,
y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
textposition='auto',
marker_color='lightseagreen'
))
comparison_fig.update_layout(
title="Model Comparison: Token Probability Distribution",
xaxis_title="Token",
yaxis_title="Probability",
yaxis=dict(tickformat='.0%', range=[0, 1.05]),
barmode='group',
height=400
)
else:
# If tokens are different, create separate subplots
comparison_fig = make_subplots(
rows=2, cols=1,
subplot_titles=(base_model, pt_model),
vertical_spacing=0.1
)
# Add Base-PT model
comparison_fig.add_trace(
go.Bar(
x=results[base_model]["tokens"],
y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]],
text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]],
textposition='auto',
marker_color='royalblue',
name=base_model
),
row=1, col=1
)
# Add PT model
comparison_fig.add_trace(
go.Bar(
x=results[pt_model]["tokens"],
y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]],
text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]],
textposition='auto',
marker_color='lightseagreen',
name=pt_model
),
row=2, col=1
)
comparison_fig.update_layout(
title="Model Comparison: Token Probability Distribution",
height=600,
showlegend=False
)
# Create comparison summary
base_total = results[base_model]["total_log_prob"]
pt_total = results[pt_model]["total_log_prob"]
better_model = base_model if base_total > pt_total else pt_model
difference = abs(base_total - pt_total)
comparison_summary = (
f"## Model Comparison Summary\n\n"
f"**Total Log Probability**:\n"
f"- {base_model}: {base_total:.4f}\n"
f"- {pt_model}: {pt_total:.4f}\n\n"
f"**Higher Confidence Model**: {better_model}\n"
f"Difference: {difference:.4f}\n\n"
f"### Interpretation\n"
f"- The model with the higher Total Log Probability is more confident in predicting the input text\n"
f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
f"- {base_model} is the base model while {pt_model} is instruction-tuned"
)
return (
results[base_model]["df"],
results[base_model]["fig"],
results[base_model]["summary"],
results[pt_model]["df"],
results[pt_model]["fig"],
results[pt_model]["summary"],
comparison_fig,
comparison_summary
)
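# The tuple returned above is ordered to match the `outputs` list wired to
# analyze_btn.click() below: base table/chart/summary, then PT table/chart/summary,
# then the comparison chart and the comparison summary.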
# Create Gradio interface with side-by-side comparison
with gr.Blocks(title="Token Log Probability Analyzer - Model Comparison") as demo:
gr.Markdown(
"""
# 🔍 Token Log Probability Analyzer - Model Comparison
Compare how two ERNIE models predict each token in your text with detailed log probability breakdown.
"""
)
with gr.Row():
text_input = gr.Textbox(
label="Input Text",
placeholder="Enter text to analyze (e.g., 'Hello, World!')",
value="Hello, World!"
)
with gr.Row():
analyze_btn = gr.Button("Analyze Both Models", variant="primary", size="lg")
# Model comparison section
with gr.Row():
with gr.Column():
comparison_summary_output = gr.Markdown(label="Model Comparison Summary")
with gr.Row():
comparison_chart_output = gr.Plot(label="Model Comparison Chart")
# Side-by-side model results
with gr.Row():
# Left column: ERNIE-4.5-Base-PT
with gr.Column():
gr.Markdown("### ERNIE-4.5-Base-PT")
base_summary_output = gr.Markdown(label="Base Model Summary")
base_table_output = gr.DataFrame(
label="Token Analysis",
interactive=False,
wrap=True
)
base_chart_output = gr.Plot(label="Token Probability Chart")
# Right column: ERNIE-4.5-PT
with gr.Column():
gr.Markdown("### ERNIE-4.5-PT")
pt_summary_output = gr.Markdown(label="PT Model Summary")
pt_table_output = gr.DataFrame(
label="Token Analysis",
interactive=False,
wrap=True
)
pt_chart_output = gr.Plot(label="Token Probability Chart")
# Examples section
gr.Examples(
examples=[
["Hello, World!"],
["The quick brown fox jumps over the lazy dog."],
["Artificial intelligence will transform our society."],
["What is the meaning of life?"]
],
inputs=[text_input],
label="Try these examples:"
)
# Footer with explanation
gr.Markdown(
"""
## How to Interpret Results
This interface compares two ERNIE models side by side:
        1. **ERNIE-4.5-Base-PT** (left): the pretrained base model, which reflects general next-token statistics
        2. **ERNIE-4.5-PT** (right): the instruction-tuned (post-trained) variant, adapted to follow instructions
### Analysis Components
For each model, you'll see:
- **Summary**: Key metrics including Total Log Probability and average token probability
- **Token Analysis Table**: Detailed breakdown of each token's log probability and probability
- **Token Probability Chart**: Visual representation of each token's prediction probability
### Model Comparison
- **Model Comparison Summary**: Shows which model has higher overall confidence
- **Model Comparison Chart**: Side-by-side visualization of token probabilities
### Key Concepts
- **Log Probability**:
- Ranges from -∞ to 0
- Closer to 0 = higher model confidence
- Used instead of raw probability to avoid numerical underflow
- **Total Log Probability**:
- Sum of individual token log probabilities
- Measures overall model confidence in the entire sequence
- Allows comparison between different models
- **Why Compare Models?**:
- Base models may be better at general language
- Instruction-tuned models may be better at specific tasks
- Different models have different strengths for different types of text
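        ### Worked Example
        Suppose a three-token continuation receives per-token probabilities 0.5, 0.8 and 0.1 (illustrative numbers, not model output).
        The corresponding log probabilities are ln(0.5) ≈ -0.69, ln(0.8) ≈ -0.22 and ln(0.1) ≈ -2.30,
        so the Total Log Probability is about -3.21 and the geometric mean probability is exp(-3.21 / 3) ≈ 0.34.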
"""
)
# Set up event handler
analyze_btn.click(
fn=analyze_text_both_models,
inputs=[text_input],
outputs=[
base_table_output, base_chart_output, base_summary_output,
pt_table_output, pt_chart_output, pt_summary_output,
comparison_chart_output, comparison_summary_output
]
)
if __name__ == "__main__":
demo.launch()