import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForCausalLM import gradio as gr import math import pandas as pd import plotly.graph_objects as go from plotly.subplots import make_subplots # Load model and tokenizer model_ids = { "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT", "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT" } tokenizers = { name: AutoTokenizer.from_pretrained(path) for name, path in model_ids.items() } models = { name: AutoModelForCausalLM.from_pretrained(path).eval() for name, path in model_ids.items() } def calculate_token_log_probabilities(text, model_name): """Calculate log probability for each token and total log probability.""" tokenizer = tokenizers[model_name] model = models[model_name] # Tokenize input inputs = tokenizer(text, return_tensors="pt") input_ids = inputs["input_ids"] # Get model output logits outputs = model(**inputs) shift_logits = outputs.logits[:, :-1, :] # Align prediction with target shift_labels = input_ids[:, 1:] # Shift labels to match predictions # Compute log probabilities log_probs = F.log_softmax(shift_logits, dim=-1) token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1) # Convert to list and get corresponding tokens token_log_probs = token_log_probs[0].tolist() tokens = tokenizer.convert_ids_to_tokens(shift_labels[0]) # Calculate total log probability total_log_prob = sum(token_log_probs) return token_log_probs, tokens, total_log_prob def create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name): """Create visualization components for token analysis.""" # Create DataFrame for token analysis df_data = [] for token, log_prob in zip(tokens, token_log_probs): prob = math.exp(log_prob) df_data.append({ "Token": token, "Log Probability": f"{log_prob:.4f}", "Probability": f"{prob:.4f}", "Probability (%)": f"{prob*100:.2f}%" }) df = pd.DataFrame(df_data) # Create bar chart for token probabilities fig = go.Figure() fig.add_trace(go.Bar( x=tokens, y=[math.exp(lp) for lp in token_log_probs], text=[f"{math.exp(lp):.3f}" for lp in token_log_probs], textposition='auto', marker_color='royalblue', name=model_name )) fig.update_layout( title=f"Token Probability Distribution - {model_name}", xaxis_title="Token", yaxis_title="Probability", yaxis=dict(tickformat='.0%', range=[0, 1.05]), height=400 ) return df, fig def analyze_text_both_models(text): """Analyze text with both models and return visualization components.""" if not text.strip(): # Return empty components for both models return (None, None, "⚠️ Please enter some text to analyze", None, None, "⚠️ Please enter some text to analyze", None, "⚠️ Please enter some text to analyze") results = {} # Analyze with both models for model_name in model_ids: token_log_probs, tokens, total_log_prob = calculate_token_log_probabilities(text, model_name) # Create visualization components df, fig = create_analysis_visualization(tokens, token_log_probs, total_log_prob, model_name) # Create summary text avg_prob = math.exp(total_log_prob / len(token_log_probs)) if token_log_probs else 0 geo_mean_prob = math.exp(total_log_prob / len(token_log_probs)) if token_log_probs else 0 summary = ( f"## Analysis Summary - {model_name}\n\n" f"**Total Log Probability**: {total_log_prob:.4f}\n" f"**Sum of Individual Log Probs**: {sum(token_log_probs):.4f}\n" f"**Verification**: {'✓ Match' if abs(total_log_prob - sum(token_log_probs)) < 1e-10 else '✗ Mismatch'}\n\n" f"**Average Token Probability**: {avg_prob:.4f} ({avg_prob*100:.2f}%)\n" f"**Geometric Mean of Probabilities**: {geo_mean_prob:.4f} ({geo_mean_prob*100:.2f}%)\n\n" f"### Interpretation\n" f"- Total Log Probability is the sum of individual token log probabilities\n" f"- Higher values (closer to 0) indicate higher model confidence\n" f"- The first token has no prediction (no preceding context)\n" f"- Each token's probability shows how confident the model was in predicting it" ) results[model_name] = { "df": df, "fig": fig, "summary": summary, "total_log_prob": total_log_prob, "token_log_probs": token_log_probs, "tokens": tokens } # Create comparison chart comparison_fig = go.Figure() # Add bars for both models base_model = "ERNIE-4.5-Base-PT" pt_model = "ERNIE-4.5-PT" # Ensure both models have the same tokens for comparison if results[base_model]["tokens"] == results[pt_model]["tokens"]: tokens = results[base_model]["tokens"] comparison_fig.add_trace(go.Bar( name=base_model, x=tokens, y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]], text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]], textposition='auto', marker_color='royalblue' )) comparison_fig.add_trace(go.Bar( name=pt_model, x=tokens, y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]], text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]], textposition='auto', marker_color='lightseagreen' )) comparison_fig.update_layout( title="Model Comparison: Token Probability Distribution", xaxis_title="Token", yaxis_title="Probability", yaxis=dict(tickformat='.0%', range=[0, 1.05]), barmode='group', height=400 ) else: # If tokens are different, create separate subplots comparison_fig = make_subplots( rows=2, cols=1, subplot_titles=(base_model, pt_model), vertical_spacing=0.1 ) # Add Base-PT model comparison_fig.add_trace( go.Bar( x=results[base_model]["tokens"], y=[math.exp(lp) for lp in results[base_model]["token_log_probs"]], text=[f"{math.exp(lp):.3f}" for lp in results[base_model]["token_log_probs"]], textposition='auto', marker_color='royalblue', name=base_model ), row=1, col=1 ) # Add PT model comparison_fig.add_trace( go.Bar( x=results[pt_model]["tokens"], y=[math.exp(lp) for lp in results[pt_model]["token_log_probs"]], text=[f"{math.exp(lp):.3f}" for lp in results[pt_model]["token_log_probs"]], textposition='auto', marker_color='lightseagreen', name=pt_model ), row=2, col=1 ) comparison_fig.update_layout( title="Model Comparison: Token Probability Distribution", height=600, showlegend=False ) # Create comparison summary base_total = results[base_model]["total_log_prob"] pt_total = results[pt_model]["total_log_prob"] better_model = base_model if base_total > pt_total else pt_model difference = abs(base_total - pt_total) comparison_summary = ( f"## Model Comparison Summary\n\n" f"**Total Log Probability**:\n" f"- {base_model}: {base_total:.4f}\n" f"- {pt_model}: {pt_total:.4f}\n\n" f"**Higher Confidence Model**: {better_model}\n" f"Difference: {difference:.4f}\n\n" f"### Interpretation\n" f"- The model with the higher Total Log Probability is more confident in predicting the input text\n" f"- Log probability closer to 0 (less negative) indicates higher model confidence\n" f"- {base_model} is the base model while {pt_model} is instruction-tuned" ) return ( results[base_model]["df"], results[base_model]["fig"], results[base_model]["summary"], results[pt_model]["df"], results[pt_model]["fig"], results[pt_model]["summary"], comparison_fig, comparison_summary ) # Create Gradio interface with side-by-side comparison with gr.Blocks(title="Token Log Probability Analyzer - Model Comparison") as demo: gr.Markdown( """ # 🔍 Token Log Probability Analyzer - Model Comparison Compare how two ERNIE models predict each token in your text with detailed log probability breakdown. """ ) with gr.Row(): text_input = gr.Textbox( label="Input Text", placeholder="Enter text to analyze (e.g., 'Hello, World!')", value="Hello, World!" ) with gr.Row(): analyze_btn = gr.Button("Analyze Both Models", variant="primary", size="lg") # Model comparison section with gr.Row(): with gr.Column(): comparison_summary_output = gr.Markdown(label="Model Comparison Summary") with gr.Row(): comparison_chart_output = gr.Plot(label="Model Comparison Chart") # Side-by-side model results with gr.Row(): # Left column: ERNIE-4.5-Base-PT with gr.Column(): gr.Markdown("### ERNIE-4.5-Base-PT") base_summary_output = gr.Markdown(label="Base Model Summary") base_table_output = gr.DataFrame( label="Token Analysis", interactive=False, wrap=True ) base_chart_output = gr.Plot(label="Token Probability Chart") # Right column: ERNIE-4.5-PT with gr.Column(): gr.Markdown("### ERNIE-4.5-PT") pt_summary_output = gr.Markdown(label="PT Model Summary") pt_table_output = gr.DataFrame( label="Token Analysis", interactive=False, wrap=True ) pt_chart_output = gr.Plot(label="Token Probability Chart") # Examples section gr.Examples( examples=[ ["Hello, World!"], ["The quick brown fox jumps over the lazy dog."], ["Artificial intelligence will transform our society."], ["What is the meaning of life?"] ], inputs=[text_input], label="Try these examples:" ) # Footer with explanation gr.Markdown( """ ## How to Interpret Results This interface compares two ERNIE models side by side: 1. **ERNIE-4.5-Base-PT** (left): Base model, better at general language patterns 2. **ERNIE-4.5-PT** (right): Instruction-tuned model, better at following complex instructions ### Analysis Components For each model, you'll see: - **Summary**: Key metrics including Total Log Probability and average token probability - **Token Analysis Table**: Detailed breakdown of each token's log probability and probability - **Token Probability Chart**: Visual representation of each token's prediction probability ### Model Comparison - **Model Comparison Summary**: Shows which model has higher overall confidence - **Model Comparison Chart**: Side-by-side visualization of token probabilities ### Key Concepts - **Log Probability**: - Ranges from -∞ to 0 - Closer to 0 = higher model confidence - Used instead of raw probability to avoid numerical underflow - **Total Log Probability**: - Sum of individual token log probabilities - Measures overall model confidence in the entire sequence - Allows comparison between different models - **Why Compare Models?**: - Base models may be better at general language - Instruction-tuned models may be better at specific tasks - Different models have different strengths for different types of text """ ) # Set up event handler analyze_btn.click( fn=analyze_text_both_models, inputs=[text_input], outputs=[ base_table_output, base_chart_output, base_summary_output, pt_table_output, pt_chart_output, pt_summary_output, comparison_chart_output, comparison_summary_output ] ) if __name__ == "__main__": demo.launch()