import gradio as gr
import random
import re
from datasets import load_dataset
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
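# A quick orientation to the tokenizer behavior this app relies on (an
# illustrative sketch, not executed at runtime): bert-base-uncased uses
# WordPiece, so rare words split into subword pieces prefixed with "##",
# and all text is lowercased. Roughly:
#
#   tokenizer.tokenize("Masked tokenization")
#   # -> ['masked', 'token', '##ization']
#   tokenizer.convert_tokens_to_string(['token', '##ization'])
#   # -> 'tokenization'
#
# The "##" prefix matters below: subword pieces are excluded from masking
# and stripped before answers are compared.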
# Initialize variables to track stats
user_stats = {
    "mlm": {"correct": 0, "total": 0},
    "ntp": {"correct": 0, "total": 0}
}
# Function to load and sample from the requested dataset
def load_sample_data(sample_size=100):
    try:
        # Try to load the requested dataset
        dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
        dataset_field = "text"  # Assuming the field name is "text"
    except Exception as e:
        print(f"Error loading requested dataset: {e}")
        # Fall back to cc_news if there's an issue
        dataset = load_dataset("vblagoje/cc_news", streaming=True)
        dataset_field = "text"

    # Sample from the dataset; with streaming, load errors often surface only
    # here during iteration, so guard this loop as well
    samples = []
    try:
        for i, example in enumerate(dataset["train"]):
            if i >= sample_size:
                break
            # Get text from the appropriate field
            if dataset_field in example and example[dataset_field]:
                # Clean text by collapsing runs of whitespace
                text = re.sub(r'\s+', ' ', example[dataset_field]).strip()
                # Only include longer texts to make the task meaningful
                if len(text.split()) > 20:
                    # Keep only the first two sentences
                    sentences = re.split(r'(?<=[.!?])\s+', text)
                    if len(sentences) >= 2:
                        two_sentence_text = ' '.join(sentences[:2])
                        samples.append(two_sentence_text)
    except Exception as e:
        print(f"Error while streaming dataset: {e}")

    # Never return an empty list: random.choice([]) would crash the app at
    # startup. The fallback below is a placeholder sentence, not dataset data.
    if not samples:
        samples = ["The quick brown fox jumps over the lazy dog. "
                   "It then curled up in the sun and fell asleep."]
    return samples
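# For reference, a sketch of how the sentence splitter above behaves (not
# executed; the lookbehind keeps terminal punctuation with each sentence):
#
#   re.split(r'(?<=[.!?])\s+', "A first. B second! C third?")
#   # -> ['A first.', 'B second!', 'C third?']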
# Load data at startup
data_samples = load_sample_data(100)
current_sample = None
masked_text = ""
original_text = ""
masked_indices = []
masked_tokens = []
current_task = "mlm"

# State for token-by-token NTP prediction (defined up front because
# get_new_sample assigns to it)
ntp_state = {
    "full_text": "",
    "revealed_text": "",
    "next_token_idx": 0,
    "tokens": []
}
def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens."""
    global masked_indices, masked_tokens, original_text

    tokens = tokenizer.tokenize(text)
    print(f"Text length: {len(text)} characters, {len(tokens)} tokens")

    # Only mask whole words, not subword pieces, special tokens, or punctuation
    maskable_indices = [i for i, token in enumerate(tokens)
                        if not token.startswith("##")
                        and not token.startswith("[") and not token.endswith("]")
                        and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
    print(f"Maskable indices count: {len(maskable_indices)}")
    print(f"Mask ratio: {mask_ratio}")

    # The number of tokens to mask follows the ratio directly (no arbitrary
    # cap), but always mask at least one token
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
    print(f"Number of tokens to mask: {num_to_mask}")

    # Randomly select indices to mask, then sort them so they're in order
    indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
    indices_to_mask.sort()

    # Replace the selected tokens with [MASK], remembering the originals
    masked_tokens_list = tokens.copy()
    original_tokens = []
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"

    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text

    # Convert back to text with masks
    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)

    print(f"Original tokens: {original_tokens}")
    print(f"Masked indices: {indices_to_mask}")
    print(f"Number of masks: {len(original_tokens)}")
    return masked_text, indices_to_mask, original_tokens
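# An illustrative call (hypothetical input, not executed): with the default
# 15% ratio on a ~10-word sentence, a single token is masked, e.g.
#
#   prepare_mlm_sample("The cat sat on the mat because it was tired.")
#   # -> ("the cat [MASK] on the mat because it was tired.", [2], ['sat'])
#
# The returned text is lowercased because the uncased tokenizer round-trips
# everything through its normalized vocabulary, and the masked position
# varies per call since it is drawn with random.sample.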
def prepare_ntp_sample(text, cut_ratio=0.3):
    """Prepare a text sample for NTP by cutting off the end."""
    # Tokenize text so the cut lands on a token boundary
    tokens = tokenizer.tokenize(text)
    print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
    print(f"Cut ratio: {cut_ratio}")

    # Return the text unchanged if it is too short to cut
    if len(tokens) < 5:
        return text, ""

    # Calculate the cutoff point from the cut ratio, keeping at least 3
    # visible tokens and leaving at least 1 token to predict
    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
    cutoff = min(cutoff, len(tokens) - 1)
    print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")

    # Split into the visible part and the hidden part (to be predicted)
    visible_tokens = tokens[:cutoff]
    hidden_tokens = tokens[cutoff:]

    # Convert back to text
    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
    print(f"Visible text length: {len(visible_text)} chars")
    print(f"Hidden text length: {len(hidden_text)} chars")
    return visible_text, hidden_text
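# An illustrative call (hypothetical input, not executed): with cut_ratio=0.3,
# roughly the last 30% of tokens become the hidden continuation, e.g.
#
#   prepare_ntp_sample("The cat sat on the mat because it was tired.")
#   # -> ("the cat sat on the mat because", "it was tired.")
#
# Cuts land on token boundaries, and the output is lowercased by the
# uncased tokenizer.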
def get_new_sample(task, mask_ratio=0.15):
    """Get a new text sample based on the task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task

    # Update current task and select a random sample
    current_task = task
    current_sample = random.choice(data_samples)
    print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")

    if task == "mlm":
        # Prepare MLM sample
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text
    else:  # NTP
        # Prepare NTP sample
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
        # Store original and visible for comparison
        original_text = current_sample
        masked_text = visible_text
        # Reset the NTP state, keeping the hidden text itself: recovering it
        # later by slicing original_text at len(masked_text) is unreliable
        # because the uncased tokenizer lowercases and renormalizes spacing
        ntp_state = {
            "full_text": hidden_text,
            "revealed_text": "",
            "next_token_idx": 0,
            "tokens": []
        }
        # Prepare for token-by-token prediction
        prepare_next_token_prediction()
        return visible_text
def check_mlm_answer(user_answers):
    """Check user MLM answers against the masked tokens."""
    global user_stats
    print(f"Original user input: '{user_answers}'")

    # Handle empty input
    if not user_answers or user_answers.isspace():
        return "Please provide your answers. No input was detected."

    # Basic cleanup - trim and lowercase
    user_answers = user_answers.strip().lower()
    print(f"After basic cleanup: '{user_answers}'")

    # Split on commas when present (dropping empty entries), otherwise on whitespace
    if ',' in user_answers:
        user_tokens = [token.strip() for token in user_answers.split(',')]
        user_tokens = [token for token in user_tokens if token]
    else:
        user_tokens = [token for token in user_answers.split() if token]

    print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
    print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")

    # Ensure we have the same number of answers as masks
    if len(user_tokens) != len(masked_tokens):
        return (f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). "
                f"You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3")

    # Compare each answer
    correct = 0
    feedback = []
    for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
        orig_token = orig_token.lower()
        # Strip the ## prefix from subword tokens for comparison
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]
        if user_token == orig_token:
            correct += 1
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")

    # Update stats
    user_stats["mlm"]["correct"] += correct
    user_stats["mlm"]["total"] += len(masked_tokens)

    # Per-round accuracy, then running overall accuracy
    accuracy = correct / len(masked_tokens) if masked_tokens else 0
    feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy * 100:.1f}%)")
    overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy * 100:.1f}%)")
    return "\n".join(feedback)
def prepare_next_token_prediction():
    """Prepare for the next token prediction."""
    global ntp_state, masked_text, original_text

    # Use the hidden text stored when the sample was prepared
    full_hidden = ntp_state["full_text"]
    # Tokenize the hidden part
    hidden_tokens = tokenizer.tokenize(full_hidden)

    print("NTP State setup:")
    print(f"  Full text: '{original_text}'")
    print(f"  Visible text: '{masked_text}'")
    print(f"  Hidden text: '{full_hidden}'")
    print(f"  Hidden tokens: {hidden_tokens}")

    # Set up the NTP state
    ntp_state["tokens"] = hidden_tokens
    ntp_state["revealed_text"] = ""
    ntp_state["next_token_idx"] = 0

    # Make sure we have tokens to predict
    if not ntp_state["tokens"]:
        print("Warning: no tokens to predict, drawing another sample")
        # get_new_sample already re-runs this preparation with a higher cut
        # ratio, so just return afterwards
        get_new_sample("ntp", 0.4)
        return
def check_ntp_answer(user_continuation):
    """Check the user's NTP answer for the next token only."""
    global user_stats, ntp_state, masked_text

    # If the NTP state hasn't been set up yet, do it now
    if not ntp_state["tokens"]:
        prepare_next_token_prediction()

    print("Current NTP state:")
    print(f"  Next token index: {ntp_state['next_token_idx']}")
    print(f"  Total tokens: {len(ntp_state['tokens'])}")
    print(f"  User input: '{user_continuation}'")

    # No more tokens to predict
    if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
        return "You've completed this prediction! Click 'New Sample' for another."

    # Get the next token to predict
    next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
    print(f"  Expected next token: '{next_token}'")

    # Tokenize the user's prediction and keep only their first token
    user_text = user_continuation.strip()
    user_tokens = tokenizer.tokenize(user_text)
    user_token = user_tokens[0].lower() if user_tokens else ""
    print(f"  User's tokenized input: {user_tokens}")

    # Strip subword prefixes for comparison
    next_token_clean = next_token.lower()
    if next_token_clean.startswith("##"):
        next_token_clean = next_token_clean[2:]
    if user_token.startswith("##"):
        user_token = user_token[2:]

    # Check if correct
    is_correct = (user_token == next_token_clean)
    print(f"  Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")

    # Update stats
    if is_correct:
        user_stats["ntp"]["correct"] += 1
    user_stats["ntp"]["total"] += 1

    # Reveal this token and advance to the next one
    ntp_state["revealed_text"] += tokenizer.convert_tokens_to_string([next_token])
    ntp_state["next_token_idx"] += 1

    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0

    feedback = []
    if is_correct:
        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
    else:
        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")

    # Show progress
    feedback.append(f"\nText so far: {masked_text}{ntp_state['revealed_text']}")

    # If there are more tokens, prompt for the next one
    if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
        feedback.append("\nPredict the next token...")
    else:
        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")

    # Show overall stats
    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy * 100:.1f}%)")
    return "\n".join(feedback)
# Standalone helpers (the Blocks UI below wires its own unified handlers,
# so these are kept for programmatic use)
def switch_task(task):
    """Switch between MLM and NTP tasks."""
    global current_task
    current_task = task
    return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))

def generate_new_sample(mask_ratio):
    """Generate a new sample based on the current task."""
    ratio = float(mask_ratio) / 100.0  # Convert percentage to ratio
    sample = get_new_sample(current_task, ratio)
    return sample, ""

def check_answer(user_input, task):
    """Check the user's answer based on the given task."""
    if task == "mlm":
        return check_mlm_answer(user_input)
    else:  # NTP
        return check_ntp_answer(user_input)
def reset_stats():
    """Reset user statistics."""
    global user_stats
    user_stats = {
        "mlm": {"correct": 0, "total": 0},
        "ntp": {"correct": 0, "total": 0}
    }
    return "Statistics have been reset."
# Generate the initial sample before building the UI so the startup mask
# count matches the sample shown (otherwise it would always read 0)
initial_sample = get_new_sample("mlm", 0.15)

# Set up Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    # Count of visible [MASK] tokens for user reference
    mask_count = gr.Markdown(f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**")

    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=initial_sample,
        lines=10,
        interactive=False
    )
    with gr.Row():
        new_button = gr.Button("New Sample", variant="primary")
        reset_button = gr.Button("Reset Stats")

    # Consolidated input area - only one set of instructions visible at a time
    input_area = gr.Group()
    with input_area:
        # Task-specific input instructions
        mlm_instructions = gr.Markdown("""
### MLM Instructions
1. For each [MASK] token, provide your guess for the original word.
2. Separate your answers with commas.
3. Make sure you provide exactly the same number of answers as [MASK] tokens.

**Example format:** `word1, word2, word3` or `word1,word2,word3`
""", visible=True)

        ntp_instructions = gr.Markdown("""
### NTP Instructions
Predict the next word or token that would follow the text.
Type a single word or token for each prediction.
""", visible=False)

        # Unified input box
        answer_input = gr.Textbox(
            label="Your answer",
            placeholder="For MLM: word1, word2, word3 | For NTP: single word",
            lines=1
        )

    with gr.Row():
        check_button = gr.Button("Check Answer", variant="primary")

    result = gr.Textbox(label="Result", lines=6)
    # Switch instructions, placeholder, and mask count when the task changes
    def switch_task_unified(task):
        if task == "mlm":
            mask_text = f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**"
            return (
                gr.update(visible=True),   # mlm_instructions
                gr.update(visible=False),  # ntp_instructions
                gr.update(placeholder="comma-separated answers (e.g., word1, word2, word3)"),
                mask_text
            )
        else:  # ntp
            return (
                gr.update(visible=False),  # mlm_instructions
                gr.update(visible=True),   # ntp_instructions
                gr.update(placeholder="Type the next word/token you predict"),
                "**Next Token Prediction mode - guess one token at a time**"
            )

    # Set up event handlers
    task_radio.change(
        switch_task_unified,
        inputs=[task_radio],
        outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
    )
    # When the ratio slider moves, don't generate a new sample; just tell the
    # user to click 'New Sample' to apply the new ratio
    def update_on_ratio_change(mask_ratio_pct, task):
        print(f"Ratio changed to {mask_ratio_pct}%")
        return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."

    mask_ratio.change(
        update_on_ratio_change,
        inputs=[mask_ratio, task_radio],
        outputs=[result]
    )
    # Generate a new sample and refresh the mask count
    def new_sample_with_count(mask_ratio_pct, task):
        print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
        ratio = float(mask_ratio_pct) / 100.0
        sample = get_new_sample(task, ratio)
        if task == "mlm":
            count = len(masked_tokens)
            mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
            print(f"Generated MLM sample with {count} masks at ratio {ratio}")
        else:
            mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
            print(f"Generated NTP sample with cut ratio {ratio}")
        return sample, mask_count_text, ""

    new_button.click(
        new_sample_with_count,
        inputs=[mask_ratio, task_radio],
        outputs=[sample_text, mask_count, result]
    )

    reset_button.click(reset_stats, inputs=None, outputs=[result])
    # Route the answer to the checker for the active task
    def unified_check_answer(user_input, task):
        if task == "mlm":
            return check_mlm_answer(user_input)
        else:  # ntp
            return check_ntp_answer(user_input)

    check_button.click(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )
    answer_input.submit(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )

demo.launch()