|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import math |
|
|
from plotly import graph_objects as go |
|
|
|
|
|
|
|
|
# Display name -> Hugging Face Hub repository of each checkpoint being compared.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
}


def _load_per_model(loader):
    """Call ``loader`` on every repository in ``model_ids``, keyed by display name."""
    return {label: loader(repo) for label, repo in model_ids.items()}


# Everything is loaded once, at import time: first every tokenizer, then every model.
tokenizers = _load_per_model(AutoTokenizer.from_pretrained)

# Models are switched to eval mode because they are only used for inference/scoring.
models = _load_per_model(lambda repo: AutoModelForCausalLM.from_pretrained(repo).eval())
|
|
|
|
|
|
|
|
def format_prob(prob):
    """Render a probability as a percentage string with one decimal place.

    Example: ``0.5`` -> ``"50.0%"``. The value is not validated or clamped.
    """
    percent = prob * 100
    return "{:.1f}%".format(percent)
|
|
|
|
|
|
|
|
def format_log_prob(log_prob):
    """Render a log-probability as plain text with three decimal places.

    No colour coding or markup is applied; e.g. ``-1.23456`` -> ``"-1.235"``.
    """
    return "{:.3f}".format(log_prob)
|
|
|
|
|
|
|
|
# Number of leading scored tokens shown in the table and chart.
_MAX_DISPLAY_TOKENS = 20

# Bar colour per model in the confidence chart (unknown names use Plotly's default).
_MODEL_COLORS = {
    "ERNIE-4.5-PT": "royalblue",
    "ERNIE-4.5-Base-PT": "lightseagreen",
}


@torch.no_grad()
def _score_model(text, tokenizer, model, top_k, max_tokens=_MAX_DISPLAY_TOKENS):
    """Score ``text`` under one causal language model.

    Logit position ``t`` predicts input token ``t + 1``, so the first input
    token is never scored and the last logit position is unused.

    Args:
        text: Input string to tokenize and score.
        tokenizer: Tokenizer belonging to ``model``.
        model: Causal LM whose output ``.logits`` is indexed (batch, position, vocab).
        top_k: Number of alternative predictions to list per position (>= 1).
        max_tokens: Maximum number of leading positions kept for display.

    Returns:
        dict with keys
            ``total_log_prob``: sum of log-probabilities over *all* scored tokens;
            ``num_scored``: number of scored tokens (input length minus one);
            ``tokens``: the first ``max_tokens`` scored tokens;
            ``log_probs``: their log-probabilities (floats);
            ``confidences``: their probabilities (``exp`` of ``log_probs``);
            ``top_k``: per position, a ``"tok (p%), ..."`` string of the top-k predictions.
    """
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Upcast so log_softmax runs in float32 even if the weights are half precision.
    logits = model(**inputs).logits.float()
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    labels = input_ids[:, 1:]
    # Log-probability the model gave to the token that actually came next.
    token_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)[0]

    num_scored = token_log_probs.shape[0]
    shown = min(max_tokens, num_scored)
    shown_log_probs = token_log_probs[:shown].tolist()

    # One batched top-k over all displayed positions instead of a per-position loop.
    k = min(top_k, log_probs.shape[-1])
    top = torch.topk(log_probs[0, :shown], k=k, dim=-1)
    top_strings = []
    for ids_row, scores_row in zip(top.indices.tolist(), top.values.tolist()):
        row_tokens = tokenizer.convert_ids_to_tokens(ids_row)
        top_strings.append(
            ", ".join(
                f"{tok} ({format_prob(math.exp(score))})"
                for tok, score in zip(row_tokens, scores_row)
            )
        )

    return {
        "total_log_prob": token_log_probs.sum().item(),
        "num_scored": num_scored,
        "tokens": tokenizer.convert_ids_to_tokens(labels[0, :shown].tolist()),
        "log_probs": shown_log_probs,
        "confidences": [math.exp(lp) for lp in shown_log_probs],
        "top_k": top_strings,
    }


def _build_comparison_table(results, top_k):
    """Lay the per-model scores side by side, one row per scored token.

    Columns: ``Token`` (taken from the first model), then ``<model> LogProb``,
    ``<model> Confidence`` and ``<model> Top-<k>`` for every model. Each column
    is wrapped in a Series so that models with different token counts still
    align by position (missing cells become NaN) instead of raising.
    """
    first = next(iter(results.values()))
    columns = {"Token": first["tokens"]}
    for name, res in results.items():
        columns[f"{name} LogProb"] = [format_log_prob(lp) for lp in res["log_probs"]]
        columns[f"{name} Confidence"] = [format_prob(p) for p in res["confidences"]]
        columns[f"{name} Top-{top_k}"] = res["top_k"]
    return pd.DataFrame(
        {col: pd.Series(values, dtype=object) for col, values in columns.items()}
    )


def _build_confidence_chart(results):
    """Grouped bar chart of per-token confidence for every model.

    Bars sit at integer positions and are labelled with the token text via
    tick labels. Using token strings directly as categorical x values would
    merge repeated tokens (e.g. two commas) into a single slot.
    """
    fig = go.Figure()
    for name, res in results.items():
        fig.add_trace(go.Bar(
            name=name,
            x=list(range(len(res["confidences"]))),
            y=res["confidences"],
            marker_color=_MODEL_COLORS.get(name),
        ))

    label_tokens = next(iter(results.values()))["tokens"]
    fig.update_layout(
        title='Model Confidence Comparison',
        xaxis_title='Token',
        yaxis_title='Confidence (Probability)',
        barmode='group',
        xaxis=dict(
            tickmode="array",
            tickvals=list(range(len(label_tokens))),
            ticktext=label_tokens,
        ),
        yaxis=dict(tickformat='.0%', range=[0, 1]),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig


def _build_summary(results, max_tokens=_MAX_DISPLAY_TOKENS):
    """Markdown summary ranking the models by total log-probability."""
    totals = {name: res["total_log_prob"] for name, res in results.items()}
    # sorted() is stable, so equal totals keep model_ids order.
    ranked = sorted(totals, key=totals.get, reverse=True)
    best = ranked[0]
    runner_up = ranked[1] if len(ranked) > 1 else None

    lines = [
        "📊 **Model Comparison Summary**",
        "",
        "**Total Log Probability**:",
        *(f"- {name}: {total:.3f}" for name, total in totals.items()),
        "",
    ]
    # Gaps below 0.0005 print as 0.000, so treat them as a tie rather than a win.
    if runner_up is not None and abs(totals[best] - totals[runner_up]) < 5e-4:
        lines += ["🤝 **Higher Confidence Model**: tie", ""]
        verdict = "- The models are equally confident in predicting your input text"
    else:
        lines.append(f"🏆 **Higher Confidence Model**: {best}")
        if runner_up is not None:
            gap = totals[best] - totals[runner_up]
            lines += ["", f"Difference vs {runner_up}: +{gap:.3f}"]
        lines.append("")
        verdict = f"- The {best} model is more confident in predicting your input text"

    lines += [
        "**What this means**:",
        "- Log probability closer to 0 (less negative) indicates higher model confidence",
        verdict,
        f"- Confidence per token is shown in the table and chart below (first {max_tokens} tokens)",
    ]
    return "\n".join(lines)


@torch.no_grad()
def compare_models(text, top_k=5):
    """Score ``text`` with every model in ``model_ids`` and compare the results.

    Args:
        text: Input string from the UI textbox.
        top_k: Number of alternative predictions listed per token. Cast to int
            and clamped to >= 1 because Gradio sliders may deliver floats.

    Returns:
        A 3-tuple matching the click handler's outputs:
        ``(comparison table, markdown summary, plotly figure)``. For blank or
        too-short input the table and figure are ``None`` and the summary is a
        warning, so the tuple length is the same on every path.
    """
    if not text or not text.strip():
        return None, "⚠️ Please enter some text to analyze", None

    top_k = max(1, int(top_k))

    results = {
        name: _score_model(text, tokenizers[name], models[name], top_k)
        for name in model_ids
    }

    if all(res["num_scored"] == 0 for res in results.values()):
        return None, "⚠️ Input is too short: no token predictions could be scored", None

    return (
        _build_comparison_table(results, top_k),
        _build_summary(results),
        _build_confidence_chart(results),
    )
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...).
# Only .summary-box, .dataframe-container and .confidence-chart are attached via
# elem_classes in the layout visible in this file.
# NOTE(review): .main-container, .model-header, .token-cell and the
# .confidence-* classes appear unused here; remove them if nothing else uses them.
css = """
.main-container {
    max-width: 1200px;
    margin: 0 auto;
}
.dataframe-container {
    margin: 20px 0;
}
.confidence-chart {
    margin: 20px 0;
    height: 400px;
}
.summary-box {
    background-color: #f8f9fa;
    border-left: 4px solid #4285f4;
    padding: 15px;
    border-radius: 4px;
    margin: 20px 0;
}
.model-header {
    font-weight: bold;
    color: #1a73e8;
    margin-top: 10px;
}
.token-cell {
    font-family: monospace;
    background-color: #f1f3f4;
    padding: 4px 8px;
    border-radius: 3px;
}
.confidence-high {
    color: #0f9d58;
    font-weight: bold;
}
.confidence-medium {
    color: #f4b400;
}
.confidence-low {
    color: #db4437;
}
"""
|
|
|
|
|
|
|
|
# Gradio UI. Components render in the order they are created inside these
# context managers; a single button click runs compare_models and fills the
# summary, the token-level table and the confidence chart.
with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
    # Page header and a short explanation of each output.
    gr.Markdown(
        """
        # 🔍 ERNIE 4.5 Model Comparison Tool

        Compare how different ERNIE models process your text with detailed token-level analysis.

        ## What this tool shows:
        - **Token Log Probability**: How confident the model is in predicting each token (closer to 0 is better)
        - **Confidence**: Probability percentage for each token prediction
        - **Top-k Predictions**: What other tokens the model considered likely
        - **Visual Comparison**: Bar chart showing confidence differences between models
        """
    )

    # Inputs: the text to score (wide column) and the number of alternatives per token.
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                lines=3,
                placeholder="Enter text to analyze (e.g., 'Hello, World!')",
                label="Input Text",
                value="Hello, World!"
            )
        with gr.Column(scale=1):
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions"
            )

    with gr.Row():
        compare_btn = gr.Button("Compare Models", variant="primary")

    # Output area: markdown summary, then the table, then the chart.
    with gr.Row():
        with gr.Column():
            summary_box = gr.Markdown(
                elem_classes=["summary-box"],
                label="Model Comparison Summary"
            )

    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                label="Token-Level Analysis",
                elem_classes=["dataframe-container"],
                interactive=False,
                wrap=True
            )

    with gr.Row():
        with gr.Column():
            confidence_chart = gr.Plot(
                label="Model Confidence Comparison",
                elem_classes=["confidence-chart"]
            )

    # Sample inputs as (text, top_k) pairs. No fn is given, so clicking an
    # example only fills the inputs; the user still presses "Compare Models".
    gr.Examples(
        examples=[
            ["Hello, World!", 3],
            ["The quick brown fox jumps over the lazy dog.", 5],
            ["Artificial intelligence will transform our society.", 3],
            ["What is the meaning of life?", 4]
        ],
        inputs=[input_text, top_k],
        label="Try these examples:"
    )

    # Static help text shown below the results.
    gr.Markdown(
        """
        ## How to Interpret Results

        1. **Log Probability**: Negative values where closer to 0 means higher model confidence
        2. **Confidence**: Percentage showing how certain the model was about each token
        3. **Top-k Predictions**: Alternative tokens the model considered likely
        4. **Visual Chart**: Bar heights represent model confidence for each token

        **Model Differences**:
        - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
        - **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
        """
    )

    # compare_models must return exactly three values, in this order:
    # (table, summary markdown, figure).
    compare_btn.click(
        fn=compare_models,
        inputs=[input_text, top_k],
        outputs=[comparison_table, summary_box, confidence_chart]
    )


# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()