# Source: Hugging Face Space file view of app.py
# author: AlexTransformer; last commit: b9a15d9 "Update app.py" (verified); file size: 14.2 kB
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import pandas as pd
import math
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
# Checkpoints to compare: display name -> Hugging Face Hub repository id.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
}


def _load_per_model(loader):
    """Return ``{display name: loader(repo_id)}`` for every entry of ``model_ids``."""
    return {name: loader(repo_id) for name, repo_id in model_ids.items()}


# Every tokenizer is loaded first, then every model (each put into eval mode).
tokenizers = _load_per_model(AutoTokenizer.from_pretrained)
models = _load_per_model(
    lambda repo_id: AutoModelForCausalLM.from_pretrained(repo_id).eval()
)
# Helper: render a probability as a percentage string
def format_prob(prob):
    """Return ``prob`` (a fraction) as a percentage with one decimal place.

    Example: ``0.1234`` -> ``"12.3%"``.
    """
    percent = prob * 100
    return format(percent, ".1f") + "%"
# Helper: render a log probability as a fixed-point string
def format_log_prob(log_prob):
    """Return ``log_prob`` rounded to three decimal places, e.g. ``-1.23456`` -> ``"-1.235"``."""
    return format(log_prob, ".3f")
# Helper: map a probability to a confidence band
def get_confidence_level(prob):
    """Return a ``(label, indicator)`` pair describing ``prob``.

    Bands: strictly above 0.8 is "High", strictly above 0.5 is "Medium",
    anything else (including NaN) is "Low".
    """
    bands = (
        (0.8, "High", "🟒"),
        (0.5, "Medium", "🟑"),
    )
    for threshold, label, indicator in bands:
        if prob > threshold:
            return label, indicator
    return "Low", "πŸ”΄"
# Main entry point: score the same text with every model in `model_ids` and
# build the table, charts and summary shown in the UI.

# Number of predicted positions shown in the table and in the token-level charts.
_MAX_DISPLAY_TOKENS = 20

# Display order and plot colour of each compared model.
_MODEL_COLORS = {
    "ERNIE-4.5-PT": "royalblue",
    "ERNIE-4.5-Base-PT": "lightseagreen",
}

# Horizontal legend placed just above the plotting area.
_TOP_LEGEND = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)


@torch.no_grad()
def _score_model(model_name, text, top_k):
    """Score ``text`` with one model and collect per-token statistics.

    Args:
        model_name: Key into the module-level ``tokenizers`` and ``models`` dicts.
        text: Non-empty input text.
        top_k: Number of alternative tokens listed for each position (int >= 1).

    Returns:
        A dict with display data for the first ``_MAX_DISPLAY_TOKENS`` predicted
        tokens. The first input token is never predicted, so it has no entry.
        ``total_log_prob`` is the summed log probability of *all* predicted tokens.
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    logits = model(**inputs).logits

    # The logit at position i predicts token i + 1: drop the last logit and the first label.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)[0]
    target_ids = input_ids[0, 1:]
    token_log_probs = log_probs.gather(1, target_ids.unsqueeze(-1)).squeeze(-1)
    total_log_prob = token_log_probs.sum().item()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]

    n_shown = min(_MAX_DISPLAY_TOKENS, token_log_probs.shape[0])
    topk_predictions = []
    if n_shown:
        # torch.topk raises if k exceeds the vocabulary size.
        k = min(top_k, log_probs.shape[-1])
        top = torch.topk(log_probs[:n_shown], k=k, dim=-1)
        for ids, scores in zip(top.indices.tolist(), top.values.tolist()):
            alternatives = tokenizer.convert_ids_to_tokens(ids)
            topk_predictions.append(", ".join(
                f"{tok} ({format_prob(math.exp(score))})"
                for tok, score in zip(alternatives, scores)
            ))

    shown_log_probs = token_log_probs[:n_shown].tolist()
    confidences = [math.exp(lp) for lp in shown_log_probs]
    return {
        "tokens": tokens[:n_shown],
        "log_probs": [format_log_prob(lp) for lp in shown_log_probs],
        "confidences": [format_prob(p) for p in confidences],
        "levels": [get_confidence_level(p)[1] for p in confidences],
        "topk_predictions": topk_predictions,
        "total_log_prob": total_log_prob,
        "confidence_values": confidences,  # raw floats for plotting
        "log_prob_values": shown_log_probs,  # raw floats for plotting
    }


def _build_table(results, top_k):
    """Build the token-level table: one row per predicted token, one column group per model."""
    # NOTE(review): assumes every tokenizer splits the text identically;
    # pd.DataFrame raises if the per-model row counts differ.
    reference = next(iter(_MODEL_COLORS))
    table = {"Token": results[reference]["tokens"]}
    for name in _MODEL_COLORS:
        scored = results[name]
        table[f"{name} LogProb"] = scored["log_probs"]
        table[f"{name} Confidence"] = scored["confidences"]
        table[f"{name} Level"] = scored["levels"]
        table[f"{name} Top-{top_k}"] = scored["topk_predictions"]
    return pd.DataFrame(table)


def _label_ticks_with_tokens(fig, tokens):
    """Label integer x positions 0..len(tokens)-1 with their tokens.

    Positions are plotted numerically because a categorical axis would merge
    repeated tokens (e.g. two occurrences of "the") into a single slot.
    """
    fig.update_xaxes(tickmode="array", tickvals=list(range(len(tokens))), ticktext=tokens)


def _confidence_figure(results):
    """Build a grouped bar chart of each model's probability for every actual token."""
    fig = go.Figure()
    for name, color in _MODEL_COLORS.items():
        scored = results[name]
        values = scored["confidence_values"]
        fig.add_trace(go.Bar(
            name=name,
            x=list(range(len(values))),
            y=values,
            customdata=scored["tokens"],
            hovertemplate="Token: %{customdata}<br>Confidence: %{y:.1%}",
            marker_color=color,
            text=[format_prob(v) for v in values],
            textposition='auto',
            textfont=dict(size=10),
        ))
    fig.update_layout(
        title='Token-Level Confidence Comparison',
        xaxis_title='Token',
        yaxis_title='Confidence (Probability)',
        barmode='group',
        yaxis=dict(tickformat='.0%', range=[0, 1.05]),
        legend=_TOP_LEGEND,
        height=500,
    )
    _label_ticks_with_tokens(fig, results[next(iter(_MODEL_COLORS))]["tokens"])
    return fig


def _logprob_figure(results):
    """Build a line chart of each model's per-token log probability, with a y=0 reference line."""
    fig = go.Figure()
    for name, color in _MODEL_COLORS.items():
        scored = results[name]
        values = scored["log_prob_values"]
        fig.add_trace(go.Scatter(
            name=name,
            x=list(range(len(values))),
            y=values,
            mode='lines+markers',
            line=dict(color=color, width=3),
            marker=dict(size=8),
            text=[f"LogProb: {lp:.3f}<br>Token: {tok}" for lp, tok in zip(values, scored["tokens"])],
            hoverinfo='text',
        ))
    fig.add_hline(y=0, line_dash="dash", line_color="red", annotation_text="Zero Reference")
    fig.update_layout(
        title='Token-Level Log Probability Trend',
        xaxis_title='Token',
        yaxis_title='Log Probability',
        hovermode='closest',
        legend=_TOP_LEGEND,
        height=400,
    )
    _label_ticks_with_tokens(fig, results[next(iter(_MODEL_COLORS))]["tokens"])
    return fig


def _summary_figure(results, better_model):
    """Build a bar chart of each model's total log probability, annotating ``better_model``."""
    names = list(_MODEL_COLORS)
    totals = [results[name]["total_log_prob"] for name in names]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        name='Total Log Probability',
        x=names,
        y=totals,
        marker_color=list(_MODEL_COLORS.values()),
        text=[f"{total:.3f}" for total in totals],
        textposition='auto',
        textfont=dict(size=14),
    ))
    fig.update_layout(
        title='Model Summary Comparison',
        yaxis_title='Total Log Probability',
        xaxis_title='Model',
        height=300,
        showlegend=False,
    )
    # On a categorical axis the annotation's x is the category's index.
    fig.add_annotation(
        x=names.index(better_model),
        y=max(totals) + 0.5,
        text=f"πŸ† {better_model}",
        showarrow=True,
        arrowhead=1,
        ax=0,
        ay=-30,
        font=dict(size=16, color="green"),
    )
    return fig


@torch.no_grad()
def compare_models(text, top_k=5):
    """Compare how every model in ``model_ids`` predicts ``text``.

    Args:
        text: Text to score. Empty, whitespace-only or ``None`` input is rejected.
        top_k: Number of alternatives listed per token. It is coerced to an
            int >= 1 because Gradio sliders may deliver floats and
            ``torch.topk`` needs an int.

    Returns:
        ``(comparison_df, summary_markdown, fig_confidence, fig_logprob,
        fig_summary)``, in the order of the Gradio outputs. If the text cannot
        be scored, returns ``(None, warning_message, None, None, None)`` so the
        output count still matches.
    """
    if not text or not text.strip():
        return None, "⚠️ Please enter some text to analyze", None, None, None

    top_k = max(1, int(top_k))
    results = {name: _score_model(name, text, top_k) for name in model_ids}
    # A single-token input has nothing to predict; averaging below would divide by zero.
    if not all(scored["tokens"] for scored in results.values()):
        return (
            None,
            "⚠️ Please enter a longer text: at least two tokens are needed to score predictions",
            None,
            None,
            None,
        )

    pt_logprob = results['ERNIE-4.5-PT']['total_log_prob']
    base_logprob = results['ERNIE-4.5-Base-PT']['total_log_prob']
    # A higher (less negative) total wins; ties go to the base model.
    better_model = "ERNIE-4.5-PT" if pt_logprob > base_logprob else "ERNIE-4.5-Base-PT"
    difference = abs(pt_logprob - base_logprob)

    # Averages cover only the displayed tokens; the totals cover the whole text.
    pt_conf = results['ERNIE-4.5-PT']['confidence_values']
    base_conf = results['ERNIE-4.5-Base-PT']['confidence_values']
    pt_avg_conf = sum(pt_conf) / len(pt_conf)
    base_avg_conf = sum(base_conf) / len(base_conf)

    summary = (
        f"πŸ“Š **Model Comparison Summary**\n\n"
        f"**Total Log Probability**:\n"
        f"- ERNIE-4.5-PT: {pt_logprob:.3f}\n"
        f"- ERNIE-4.5-Base-PT: {base_logprob:.3f}\n\n"
        f"**Average Confidence**:\n"
        f"- ERNIE-4.5-PT: {format_prob(pt_avg_conf)}\n"
        f"- ERNIE-4.5-Base-PT: {format_prob(base_avg_conf)}\n\n"
        f"πŸ† **Higher Confidence Model**: {better_model}\n"
        f"Difference: {difference:.3f}\n\n"
        f"**What this means**:\n"
        f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
        f"- The {better_model} model is more confident in predicting your input text\n"
        f"- Confidence indicators: 🟒 High (>80%), 🟑 Medium (50-80%), πŸ”΄ Low (<50%)\n\n"
        f"**Interpretation Guide**:\n"
        f"- **LogProb**: How confident the model is in predicting each token (closer to 0 is better)\n"
        f"- **Confidence**: Probability percentage for each token prediction\n"
        f"- **Level**: Visual indicator of confidence (πŸŸ’πŸŸ‘πŸ”΄)\n"
        f"- **Top-k**: What other tokens the model considered likely"
    )
    return (
        _build_table(results, top_k),
        summary,
        _confidence_figure(results),
        _logprob_figure(results),
        _summary_figure(results, better_model),
    )
# Custom CSS passed to gr.Blocks(css=...) to style the layout.
# The .summary-box, .chart-container and .dataframe-container classes are
# attached to components through their `elem_classes` arguments.
# NOTE(review): .main-container is not referenced by any `elem_classes` in
# this file; possibly unused.
css = """
.main-container {
max-width: 1400px;
margin: 0 auto;
}
.dataframe-container {
margin: 20px 0;
}
.summary-box {
background-color: #f8f9fa;
border-left: 4px solid #4285f4;
padding: 15px;
border-radius: 4px;
margin: 20px 0;
}
.chart-container {
margin: 20px 0;
border: 1px solid #e0e0e0;
border-radius: 8px;
padding: 15px;
background-color: #ffffff;
}
"""
# Gradio interface with improved layout.
# Components render in creation order: header, inputs, compare button,
# summary, three charts, token table, examples, then an explanatory footer.
with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
    gr.Markdown(
        """
# πŸ” ERNIE 4.5 Model Comparison Tool
Compare how different ERNIE models process your text with detailed token-level analysis and visualizations.
"""
    )
    # Input row: text box and top-k slider side by side (relative widths 3:1).
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                lines=3,
                placeholder="Enter text to analyze (e.g., 'Hello, World!')",
                label="Input Text",
                value="What is the meaning of life?"
            )
        with gr.Column(scale=1):
            # Passed to compare_models as its `top_k` argument.
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions"
            )
    with gr.Row():
        compare_btn = gr.Button("Compare Models", variant="primary", size="lg")
    # Output sections below are filled by compare_models via the click handler at the end.
    with gr.Row():
        with gr.Column():
            summary_box = gr.Markdown(
                elem_classes=["summary-box"],
                label="Model Comparison Summary"
            )
    with gr.Row():
        with gr.Column():
            summary_chart = gr.Plot(
                label="Model Summary",
                elem_classes=["chart-container"]
            )
    with gr.Row():
        with gr.Column():
            confidence_chart = gr.Plot(
                label="Token-Level Confidence Comparison",
                elem_classes=["chart-container"]
            )
    with gr.Row():
        with gr.Column():
            logprob_chart = gr.Plot(
                label="Token-Level Log Probability Trend",
                elem_classes=["chart-container"]
            )
    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                label="Token-Level Analysis",
                elem_classes=["dataframe-container"],
                interactive=False,
                wrap=True
            )
    # Examples section: each row fills [input_text, top_k]; no fn is attached,
    # so clicking an example only populates the inputs.
    gr.Examples(
        examples=[
            ["Hello, World!", 3],
            ["The quick brown fox jumps over the lazy dog.", 5],
            ["Artificial intelligence will transform our society.", 3],
            ["What is the meaning of life?", 4]
        ],
        inputs=[input_text, top_k],
        label="Try these examples:"
    )
    # Footer with explanation
    gr.Markdown(
        """
## How to Interpret Results
1. **Model Summary Chart**: Shows which model has higher overall confidence for your input text
2. **Token-Level Confidence Chart**: Compares how confident each model is for each token in your text
3. **Log Probability Trend Chart**: Shows how log probability changes across tokens (closer to 0 is better)
4. **Token-Level Analysis Table**: Detailed breakdown of predictions for each token
**Model Differences**:
- **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
- **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
"""
    )
    # Set up event handler. The outputs list must match, in order, the 5-tuple
    # returned by compare_models (table, summary, confidence, logprob, summary chart).
    compare_btn.click(
        fn=compare_models,
        inputs=[input_text, top_k],
        outputs=[comparison_table, summary_box, confidence_chart, logprob_chart, summary_chart]
    )
if __name__ == "__main__":
    demo.launch()