|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import math |
|
|
import plotly.graph_objects as go |
|
|
import plotly.express as px |
|
|
from plotly.subplots import make_subplots |
|
|
|
|
|
|
|
|
# Display name -> Hugging Face Hub repository id of every model being compared.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
}


def _load_per_model(loader):
    """Call ``loader(repo_id)`` for every entry of ``model_ids``, keyed by display name."""
    return {display_name: loader(repo_id) for display_name, repo_id in model_ids.items()}


# Every tokenizer is loaded first, then every model; models are switched to
# inference mode with ``.eval()``.
tokenizers = _load_per_model(AutoTokenizer.from_pretrained)
models = _load_per_model(
    lambda repo_id: AutoModelForCausalLM.from_pretrained(repo_id).eval()
)
|
|
|
|
|
|
|
|
def format_prob(prob):
    """Render a probability as a percentage with one decimal, e.g. 0.1234 -> "12.3%"."""
    # The "%" presentation type multiplies by 100 and appends the percent sign.
    return f"{prob:.1%}"
|
|
|
|
|
|
|
|
def format_log_prob(log_prob):
    """Render a log-probability with three decimal places, e.g. -1.23456 -> "-1.235"."""
    return format(log_prob, ".3f")
|
|
|
|
|
|
|
|
def get_confidence_level(prob):
    """Map a token probability to a coarse confidence level.

    Args:
        prob: Probability the model assigned to a token (expected in [0, 1]).

    Returns:
        A ``(label, indicator)`` tuple:
        ``("High", "🟢")`` when ``prob > 0.8``,
        ``("Medium", "🟡")`` when ``0.5 < prob <= 0.8``,
        ``("Low", "🔴")`` otherwise (including exactly 0.5).
    """
    if prob > 0.8:
        return "High", "🟢"
    if prob > 0.5:
        return "Medium", "🟡"
    return "Low", "🔴"
|
|
|
|
|
|
|
|
# Display order and chart colour of each compared model. The names must be
# keys of the module-level ``tokenizers`` and ``models`` dicts.
_MODEL_COLORS = {
    "ERNIE-4.5-PT": "royalblue",
    "ERNIE-4.5-Base-PT": "lightseagreen",
}

# Model whose tokens label the shared x-axis and the table's "Token" column.
_REFERENCE_MODEL = "ERNIE-4.5-PT"

# Horizontal legend placed above the plot area, shared by the token charts.
_LEGEND = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)


@torch.no_grad()
def _score_model(model_name, text, top_k, max_tokens):
    """Run one model over ``text`` and collect per-token statistics.

    Every token after the first is scored by the log-probability the model
    assigned to it given the preceding tokens. ``total_log_prob`` and
    ``avg_confidence`` cover *all* scored tokens; the per-token lists only
    hold the first ``max_tokens`` entries.

    Returns:
        dict with keys ``tokens``, ``log_probs`` / ``confidences`` (formatted
        strings), ``levels`` (indicator emoji), ``topk_predictions``,
        ``log_prob_values`` / ``confidence_values`` (raw floats),
        ``total_log_prob``, ``avg_confidence`` (``None`` if nothing was
        scored) and ``num_scored``.
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    logits = model(**inputs).logits

    # Logits at position i predict token i + 1: drop the last logit row and
    # the first label so that each prediction lines up with its target.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    target_ids = input_ids[:, 1:]
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)

    all_log_probs = token_log_probs[0].tolist()
    all_probs = [math.exp(lp) for lp in all_log_probs]
    n_show = min(max_tokens, len(all_log_probs))
    shown_log_probs = all_log_probs[:n_show]
    shown_probs = all_probs[:n_show]

    # Top-k alternatives for every displayed position, computed in one call.
    topk_predictions = []
    if n_show > 0:
        k = min(top_k, log_probs.shape[-1])
        best = torch.topk(log_probs[0, :n_show], k=k, dim=-1)
        for ids, scores in zip(best.indices.tolist(), best.values.tolist()):
            candidates = tokenizer.convert_ids_to_tokens(ids)
            topk_predictions.append(", ".join(
                f"{tok} ({format_prob(math.exp(score))})"
                for tok, score in zip(candidates, scores)
            ))

    return {
        "tokens": tokenizer.convert_ids_to_tokens(target_ids[0, :n_show].tolist()),
        "log_probs": [format_log_prob(lp) for lp in shown_log_probs],
        "confidences": [format_prob(p) for p in shown_probs],
        "levels": [get_confidence_level(p)[1] for p in shown_probs],
        "topk_predictions": topk_predictions,
        "log_prob_values": shown_log_probs,
        "confidence_values": shown_probs,
        "total_log_prob": token_log_probs.sum().item(),
        "avg_confidence": sum(all_probs) / len(all_probs) if all_probs else None,
        "num_scored": len(all_log_probs),
    }


def _warning_result(message):
    """Build the 5-tuple returned to the UI when no analysis can be shown."""
    return None, f"⚠️ {message}", None, None, None


def _label_token_axis(fig, tokens):
    """Label integer x positions with token text.

    Plotting tokens directly as categorical x values would merge repeated
    tokens (e.g. two "," in one sentence) into a single category.
    """
    fig.update_xaxes(
        title_text="Token",
        tickmode="array",
        tickvals=list(range(len(tokens))),
        ticktext=tokens,
    )


def _build_table(results, top_k):
    """Assemble the per-token comparison table (columns padded if lengths differ)."""
    columns = {"Token": pd.Series(results[_REFERENCE_MODEL]["tokens"], dtype=object)}
    for name in _MODEL_COLORS:
        res = results[name]
        columns[f"{name} LogProb"] = pd.Series(res["log_probs"], dtype=object)
        columns[f"{name} Confidence"] = pd.Series(res["confidences"], dtype=object)
        columns[f"{name} Level"] = pd.Series(res["levels"], dtype=object)
        columns[f"{name} Top-{top_k}"] = pd.Series(res["topk_predictions"], dtype=object)
    return pd.DataFrame(columns)


def _build_confidence_figure(results):
    """Grouped bar chart of each model's probability for every displayed token."""
    fig = go.Figure()
    for name, color in _MODEL_COLORS.items():
        res = results[name]
        fig.add_trace(go.Bar(
            name=name,
            x=list(range(len(res["tokens"]))),
            y=res["confidence_values"],
            customdata=res["tokens"],
            hovertemplate="Token: %{customdata}<br>Confidence: %{y:.1%}",
            marker_color=color,
            text=res["confidences"],
            textposition="auto",
            textfont=dict(size=10),
        ))
    fig.update_layout(
        title="Token-Level Confidence Comparison",
        yaxis_title="Confidence (Probability)",
        barmode="group",
        yaxis=dict(tickformat=".0%", range=[0, 1.05]),
        legend=_LEGEND,
        height=500,
    )
    _label_token_axis(fig, results[_REFERENCE_MODEL]["tokens"])
    return fig


def _build_logprob_figure(results):
    """Line chart of each model's per-token log-probability."""
    fig = go.Figure()
    for name, color in _MODEL_COLORS.items():
        res = results[name]
        fig.add_trace(go.Scatter(
            name=name,
            x=list(range(len(res["tokens"]))),
            y=res["log_prob_values"],
            mode="lines+markers",
            line=dict(color=color, width=3),
            marker=dict(size=8),
            text=[
                f"LogProb: {lp}<br>Token: {tok}"
                for lp, tok in zip(res["log_probs"], res["tokens"])
            ],
            hoverinfo="text",
        ))
    # Log-probabilities are <= 0; the dashed line marks the certainty ceiling.
    fig.add_hline(y=0, line_dash="dash", line_color="red", annotation_text="Zero Reference")
    fig.update_layout(
        title="Token-Level Log Probability Trend",
        yaxis_title="Log Probability",
        hovermode="closest",
        legend=_LEGEND,
        height=400,
    )
    _label_token_axis(fig, results[_REFERENCE_MODEL]["tokens"])
    return fig


def _build_summary_figure(results, better_model):
    """Bar chart of each model's total log-probability with the winner annotated."""
    names = list(_MODEL_COLORS)
    totals = [results[name]["total_log_prob"] for name in names]
    fig = go.Figure(go.Bar(
        name="Total Log Probability",
        x=names,
        y=totals,
        marker_color=list(_MODEL_COLORS.values()),
        text=[f"{total:.3f}" for total in totals],
        textposition="auto",
        textfont=dict(size=14),
    ))
    fig.update_layout(
        title="Model Summary Comparison",
        yaxis_title="Total Log Probability",
        xaxis_title="Model",
        height=300,
        showlegend=False,
    )
    fig.add_annotation(
        x=names.index(better_model),  # category index on the x-axis
        y=max(totals) + 0.5,
        text=f"🏆 {better_model}",
        showarrow=True,
        arrowhead=1,
        ax=0,
        ay=-30,
        font=dict(size=16, color="green"),
    )
    return fig


def _build_summary_text(results, better_model, difference, max_tokens):
    """Markdown summary of the comparison plus a short interpretation guide."""
    pt = results["ERNIE-4.5-PT"]
    base = results["ERNIE-4.5-Base-PT"]
    longest = max(res["num_scored"] for res in results.values())
    truncation_note = (
        f"_Charts and table show the first {max_tokens} of {longest} scored tokens._\n\n"
        if longest > max_tokens
        else ""
    )
    return (
        "📊 **Model Comparison Summary**\n\n"
        "**Total Log Probability**:\n"
        f"- ERNIE-4.5-PT: {pt['total_log_prob']:.3f}\n"
        f"- ERNIE-4.5-Base-PT: {base['total_log_prob']:.3f}\n\n"
        "**Average Confidence** (all scored tokens):\n"
        f"- ERNIE-4.5-PT: {format_prob(pt['avg_confidence'])}\n"
        f"- ERNIE-4.5-Base-PT: {format_prob(base['avg_confidence'])}\n\n"
        f"🏆 **Higher Confidence Model**: {better_model}\n"
        f"Difference: {difference:.3f}\n\n"
        f"{truncation_note}"
        "**What this means**:\n"
        "- Log probability closer to 0 (less negative) indicates higher model confidence\n"
        f"- The {better_model} model is more confident in predicting your input text\n"
        "- Confidence indicators: 🟢 High (>80%), 🟡 Medium (>50% to 80%), 🔴 Low (≤50%)\n\n"
        "**Interpretation Guide**:\n"
        "- **LogProb**: How confident the model is in predicting each token (closer to 0 is better)\n"
        "- **Confidence**: Probability percentage for each token prediction\n"
        "- **Level**: Visual indicator of confidence (🟢🟡🔴)\n"
        "- **Top-k**: What other tokens the model considered likely"
    )


@torch.no_grad()
def compare_models(text, top_k=5, max_tokens=20):
    """Compare how the ERNIE models score ``text`` token by token.

    Args:
        text: Input text to score.
        top_k: Number of alternative predictions listed per token. Cast to
            ``int`` because UI sliders may deliver a float.
        max_tokens: Maximum number of tokens shown in the table and charts.
            Totals and averages in the summary always cover every token.

    Returns:
        ``(table, summary_markdown, confidence_fig, logprob_fig, summary_fig)``,
        matching the five outputs wired up in the UI. When the input cannot be
        analysed, the table and figures are ``None`` and the summary holds a
        warning message.
    """
    if not text or not text.strip():
        return _warning_result("Please enter some text to analyze")

    top_k = int(top_k)
    results = {
        name: _score_model(name, text, top_k, max_tokens) for name in _MODEL_COLORS
    }

    # A single-token input leaves nothing to predict (and nothing to average).
    if any(res["num_scored"] == 0 for res in results.values()):
        return _warning_result(
            "The text is too short to analyze: at least two tokens are needed"
        )

    pt_logprob = results["ERNIE-4.5-PT"]["total_log_prob"]
    base_logprob = results["ERNIE-4.5-Base-PT"]["total_log_prob"]
    better_model = "ERNIE-4.5-PT" if pt_logprob > base_logprob else "ERNIE-4.5-Base-PT"
    difference = abs(pt_logprob - base_logprob)

    return (
        _build_table(results, top_k),
        _build_summary_text(results, better_model, difference, max_tokens),
        _build_confidence_figure(results),
        _build_logprob_figure(results),
        _build_summary_figure(results, better_model),
    )
|
|
|
|
|
|
|
|
# Extra CSS injected into the Gradio page. ``.summary-box``,
# ``.chart-container`` and ``.dataframe-container`` are attached to components
# below via ``elem_classes``. NOTE(review): ``.main-container`` is not
# referenced by any component in this file.
css = """
.main-container {
    max-width: 1400px;
    margin: 0 auto;
}
.dataframe-container {
    margin: 20px 0;
}
.summary-box {
    background-color: #f8f9fa;
    border-left: 4px solid #4285f4;
    padding: 15px;
    border-radius: 4px;
    margin: 20px 0;
}
.chart-container {
    margin: 20px 0;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    background-color: #ffffff;
}
"""


# UI layout: input row, action button, then summary, charts and the detailed
# table. The button feeds (text, top_k) into ``compare_models``, whose 5-tuple
# result fills the five output components in ``outputs=[...]`` order.
with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
    gr.Markdown(
        """
        # 🔍 ERNIE 4.5 Model Comparison Tool

        Compare how different ERNIE models process your text with detailed token-level analysis and visualizations.
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                lines=3,
                placeholder="Enter text to analyze (e.g., 'Hello, World!')",
                label="Input Text",
                value="What is the meaning of life?",
            )
        with gr.Column(scale=1):
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions",
            )

    with gr.Row():
        compare_btn = gr.Button("Compare Models", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            summary_box = gr.Markdown(
                elem_classes=["summary-box"],
                label="Model Comparison Summary",
            )

    with gr.Row():
        with gr.Column():
            summary_chart = gr.Plot(
                label="Model Summary",
                elem_classes=["chart-container"],
            )

    with gr.Row():
        with gr.Column():
            confidence_chart = gr.Plot(
                label="Token-Level Confidence Comparison",
                elem_classes=["chart-container"],
            )

    with gr.Row():
        with gr.Column():
            logprob_chart = gr.Plot(
                label="Token-Level Log Probability Trend",
                elem_classes=["chart-container"],
            )

    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                label="Token-Level Analysis",
                elem_classes=["dataframe-container"],
                interactive=False,
                wrap=True,
            )

    # Clicking an example only fills the inputs; the user still presses the button.
    gr.Examples(
        examples=[
            ["Hello, World!", 3],
            ["The quick brown fox jumps over the lazy dog.", 5],
            ["Artificial intelligence will transform our society.", 3],
            ["What is the meaning of life?", 4],
        ],
        inputs=[input_text, top_k],
        label="Try these examples:",
    )

    gr.Markdown(
        """
        ## How to Interpret Results

        1. **Model Summary Chart**: Shows which model has higher overall confidence for your input text
        2. **Token-Level Confidence Chart**: Compares how confident each model is for each token in your text
        3. **Log Probability Trend Chart**: Shows how log probability changes across tokens (closer to 0 is better)
        4. **Token-Level Analysis Table**: Detailed breakdown of predictions for each token

        **Model Differences**:
        - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
        - **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
        """
    )

    compare_btn.click(
        fn=compare_models,
        inputs=[input_text, top_k],
        outputs=[comparison_table, summary_box, confidence_chart, logprob_chart, summary_chart],
    )


if __name__ == "__main__":
    demo.launch()