# AlexTransformer's picture
# Update app.py
# 674515d verified
# raw
# history blame
# 10 kB
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import pandas as pd
import math
from plotly import graph_objects as go
# Load models and tokenizers
# Display name -> Hugging Face Hub repo id for every checkpoint being compared.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT"
}
# Both lookup tables share model_ids' keys and iteration order. Every
# tokenizer is loaded before any model is loaded.
tokenizers = dict(zip(model_ids, map(AutoTokenizer.from_pretrained, model_ids.values())))
# .eval() puts each model in inference mode (e.g. dropout disabled).
models = dict(zip(
    model_ids,
    (AutoModelForCausalLM.from_pretrained(repo).eval() for repo in model_ids.values()),
))
# Helper function to format probability
def format_prob(prob):
    """Render a probability fraction as a percentage string with one decimal (0.5 -> '50.0%')."""
    percent = prob * 100
    return "{:.1f}%".format(percent)
# Helper function to format log probability
def format_log_prob(log_prob):
    """Format a (natural-log) probability with three decimal places.

    Args:
        log_prob: A float log probability, typically <= 0.

    Returns:
        The value as a plain string, e.g. ``-1.23456 -> '-1.235'``. No colour
        or other markup is applied.
    """
    return f"{log_prob:.3f}"
# Main function: compute token-wise log probabilities and top-k predictions

# Bar colours for the known checkpoints; any other model name falls back to
# Plotly's default colour sequence.
_MODEL_COLORS = {
    "ERNIE-4.5-PT": "royalblue",
    "ERNIE-4.5-Base-PT": "lightseagreen",
}


@torch.no_grad()
def _score_text(model_name, text, top_k, max_tokens):
    """Run one model over ``text`` and collect per-token statistics.

    Args:
        model_name: Key into the module-level ``tokenizers`` / ``models`` dicts.
        text: Raw input string.
        top_k: Number of alternative predictions listed per position (int >= 1).
        max_tokens: Maximum number of scored positions kept for display (int >= 0).

    Returns:
        A dict with ``total_log_prob`` (sum over *all* scored tokens),
        ``num_scored`` (number of scored tokens), and the display lists
        ``tokens``, ``log_probs``, ``confidences`` and ``topk``, each truncated
        to at most ``max_tokens`` entries.
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    logits = model(**inputs).logits
    # Logits at position t predict token t + 1: drop the last prediction and
    # the first label so that predictions and targets line up.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = F.log_softmax(shift_logits, dim=-1)[0]  # (seq - 1, vocab), batch of one
    token_log_probs = log_probs.gather(1, shift_labels[0].unsqueeze(-1)).squeeze(-1)

    # The first input token is never predicted, so it is dropped to stay
    # aligned with shift_labels (it is a BOS token only if the tokenizer adds one).
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[1:]
    shown = min(max_tokens, token_log_probs.shape[0])
    k = min(top_k, log_probs.shape[-1])  # topk raises if k exceeds the vocab size

    topk_rows = []
    for pos in range(shown):
        top = torch.topk(log_probs[pos], k=k)
        top_tokens = tokenizer.convert_ids_to_tokens(top.indices.tolist())
        topk_rows.append(", ".join(
            f"{tok} ({format_prob(math.exp(score))})"
            for tok, score in zip(top_tokens, top.values.tolist())
        ))

    shown_log_probs = token_log_probs[:shown].tolist()
    return {
        "total_log_prob": token_log_probs.sum().item(),
        "num_scored": token_log_probs.shape[0],
        "tokens": tokens[:shown],
        "log_probs": shown_log_probs,
        # Confidence = probability the model assigned to the actual next token.
        "confidences": [math.exp(lp) for lp in shown_log_probs],
        "topk": topk_rows,
    }


def _shared_tokens(results):
    """Return the common displayed token list if all models agree on it, else ``None``."""
    token_lists = [r["tokens"] for r in results.values()]
    if token_lists and all(t == token_lists[0] for t in token_lists):
        return token_lists[0]
    return None


def _build_comparison_df(results, top_k):
    """Merge per-model statistics into one flat table with one row per token position.

    When all models tokenized the text identically a single ``Token`` column is
    shown; otherwise each model gets its own ``<model> Token`` column.
    """
    shared = _shared_tokens(results)
    frames = [pd.DataFrame({"Token": shared})] if shared is not None else []
    for name, r in results.items():
        columns = {}
        if shared is None:
            columns[f"{name} Token"] = r["tokens"]
        columns[f"{name} LogProb"] = [format_log_prob(x) for x in r["log_probs"]]
        columns[f"{name} Confidence"] = [format_prob(x) for x in r["confidences"]]
        columns[f"{name} Top-{top_k}"] = r["topk"]
        frames.append(pd.DataFrame(columns))
    # Align by position; if tokenizations differ in length the shorter side is
    # padded with NaN rather than raising.
    return pd.concat(frames, axis=1)


def _build_confidence_figure(results):
    """Grouped bar chart of per-token confidence, one trace per model.

    Bars are placed at integer token positions, with the token text used for
    tick labels and hover. A token that occurs twice therefore keeps two
    separate bars, whereas raw token strings as categories would merge them.
    """
    shared = _shared_tokens(results)
    longest = max((len(r["tokens"]) for r in results.values()), default=0)
    positions = list(range(longest))
    fig = go.Figure()
    for name, r in results.items():
        fig.add_trace(go.Bar(
            name=name,
            x=positions[:len(r["tokens"])],
            y=r["confidences"],
            hovertext=r["tokens"],
            marker_color=_MODEL_COLORS.get(name),
        ))
    fig.update_layout(
        title='Model Confidence Comparison',
        xaxis=dict(
            title='Token' if shared is not None else 'Token position',
            tickmode='array',
            tickvals=positions,
            ticktext=shared if shared is not None else [str(p) for p in positions],
        ),
        yaxis_title='Confidence (Probability)',
        barmode='group',
        yaxis=dict(tickformat='.0%', range=[0, 1]),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig


def _build_summary(results, max_tokens):
    """Build a Markdown summary ranking the models by total log probability of the text."""
    totals = {name: r["total_log_prob"] for name, r in results.items()}
    # Stable sort: equal totals keep model_ids order.
    ranked = sorted(totals, key=totals.get, reverse=True)
    best = ranked[0]
    lines = ["📊 **Model Comparison Summary**", "", "**Total Log Probability**:"]
    lines += [
        f"- {name}: {total:.3f} ({results[name]['num_scored']} scored tokens)"
        for name, total in totals.items()
    ]
    lines.append("")
    if len(ranked) > 1 and totals[best] == totals[ranked[1]]:
        lines += [f"🤝 **Tie**: {best} and {ranked[1]} are equally confident", ""]
        verdict = "- Neither model is more confident in predicting your input text"
    else:
        lines += [f"🏆 **Higher Confidence Model**: {best}", ""]
        if len(ranked) > 1:
            difference = totals[best] - totals[ranked[1]]
            lines += [f"Difference vs. {ranked[1]}: +{difference:.3f}", ""]
        verdict = f"- The {best} model is more confident in predicting your input text"
    lines += [
        "**What this means**:",
        "- Log probability closer to 0 (less negative) indicates higher model confidence",
        verdict,
        f"- Confidence per token is shown in the table and chart below (first {max_tokens} tokens)",
    ]
    return "\n".join(lines)


@torch.no_grad()
def compare_models(text, top_k=5, max_tokens=20):
    """Score ``text`` with every model in ``model_ids`` and build the three UI outputs.

    Args:
        text: Input text to analyse.
        top_k: Alternatives listed per position. Cast to ``int`` (and clamped
            to >= 1) because the UI slider value may not arrive as an int.
        max_tokens: Number of scored positions shown in the table and chart.
            The totals in the summary always cover the whole text.

    Returns:
        ``(comparison_df, summary_markdown, figure)``. On unusable input the
        table and figure are ``None``, so the tuple always has three elements
        to match the three Gradio outputs.
    """
    if not text or not text.strip():
        return None, "⚠️ Please enter some text to analyze", None
    top_k = max(1, int(top_k))
    max_tokens = max(0, int(max_tokens))

    results = {name: _score_text(name, text, top_k, max_tokens) for name in model_ids}
    if all(r["num_scored"] == 0 for r in results.values()):
        # A single token has nothing before it to predict from, so nothing is scored.
        return None, "⚠️ Please enter a longer text: at least two tokens are needed to score predictions", None

    return (
        _build_comparison_df(results, top_k),
        _build_summary(results, max_tokens),
        _build_confidence_figure(results),
    )
# Create custom CSS for better styling
# Stylesheet passed to gr.Blocks(css=...). In the code visible here, the layout
# attaches only summary-box, dataframe-container and confidence-chart (via
# elem_classes). main-container, model-header, token-cell and confidence-*
# are not referenced by any component. NOTE(review): wire them up or remove them.
css = """
.main-container {
max-width: 1200px;
margin: 0 auto;
}
.dataframe-container {
margin: 20px 0;
}
.confidence-chart {
margin: 20px 0;
height: 400px;
}
.summary-box {
background-color: #f8f9fa;
border-left: 4px solid #4285f4;
padding: 15px;
border-radius: 4px;
margin: 20px 0;
}
.model-header {
font-weight: bold;
color: #1a73e8;
margin-top: 10px;
}
.token-cell {
font-family: monospace;
background-color: #f1f3f4;
padding: 4px 8px;
border-radius: 3px;
}
.confidence-high {
color: #0f9d58;
font-weight: bold;
}
.confidence-medium {
color: #f4b400;
}
.confidence-low {
color: #db4437;
}
"""
# Gradio interface with improved layout
# UI, top to bottom: intro Markdown, an input row (text box + top-k slider),
# the compare button, the summary Markdown, the token-level table, the
# confidence chart, clickable examples, and an interpretation footer.
with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
    gr.Markdown(
        """
# 🔍 ERNIE 4.5 Model Comparison Tool
Compare how different ERNIE models process your text with detailed token-level analysis.
## What this tool shows:
- **Token Log Probability**: How confident the model is in predicting each token (closer to 0 is better)
- **Confidence**: Probability percentage for each token prediction
- **Top-k Predictions**: What other tokens the model considered likely
- **Visual Comparison**: Bar chart showing confidence differences between models
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                lines=3,
                placeholder="Enter text to analyze (e.g., 'Hello, World!')",
                label="Input Text",
                value="Hello, World!"
            )
        with gr.Column(scale=1):
            # Passed to compare_models as its top_k argument.
            # NOTE(review): the slider value may arrive as a float, while
            # torch.topk requires an int k. Confirm, or cast in the handler.
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions"
            )
    with gr.Row():
        compare_btn = gr.Button("Compare Models", variant="primary")
    with gr.Row():
        with gr.Column():
            summary_box = gr.Markdown(
                elem_classes=["summary-box"],
                label="Model Comparison Summary"
            )
    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                label="Token-Level Analysis",
                elem_classes=["dataframe-container"],
                interactive=False,
                wrap=True
            )
    with gr.Row():
        with gr.Column():
            confidence_chart = gr.Plot(
                label="Model Confidence Comparison",
                elem_classes=["confidence-chart"]
            )
    # Examples section
    # Each example row is [text, top_k], matching `inputs`. No fn/outputs are
    # given, so presumably clicking an example only fills the inputs.
    gr.Examples(
        examples=[
            ["Hello, World!", 3],
            ["The quick brown fox jumps over the lazy dog.", 5],
            ["Artificial intelligence will transform our society.", 3],
            ["What is the meaning of life?", 4]
        ],
        inputs=[input_text, top_k],
        label="Try these examples:"
    )
    # Footer with explanation
    gr.Markdown(
        """
## How to Interpret Results
1. **Log Probability**: Negative values where closer to 0 means higher model confidence
2. **Confidence**: Percentage showing how certain the model was about each token
3. **Top-k Predictions**: Alternative tokens the model considered likely
4. **Visual Chart**: Bar heights represent model confidence for each token
**Model Differences**:
- **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
- **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
"""
    )
    # Set up event handler
    # compare_models must return exactly three values, in this order:
    # (table, summary Markdown, Plotly figure).
    compare_btn.click(
        fn=compare_models,
        inputs=[input_text, top_k],
        outputs=[comparison_table, summary_box, confidence_chart]
    )
# Start the Gradio server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()