Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -60,13 +60,21 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
| 60 |
global masked_indices, masked_tokens, original_text
|
| 61 |
|
| 62 |
tokens = tokenizer.tokenize(text)
|
|
|
|
|
|
|
| 63 |
# Only mask whole words, not special tokens or punctuation
|
| 64 |
maskable_indices = [i for i, token in enumerate(tokens)
|
| 65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
| 66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
|
|
|
| 69 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
|
|
|
|
|
|
| 70 |
# Randomly select indices to mask
|
| 71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
| 72 |
# Sort indices to ensure they're in order
|
|
@@ -101,15 +109,20 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
| 101 |
# Tokenize text to ensure reasonable cutting
|
| 102 |
tokens = tokenizer.tokenize(text)
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
# Ensure we have enough tokens
|
| 105 |
if len(tokens) < 5:
|
| 106 |
return text, "" # Return original if too short
|
| 107 |
|
| 108 |
-
# Calculate cutoff point
|
| 109 |
-
# But make sure we have at least 3 tokens visible and 1 token hidden
|
| 110 |
cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
|
| 111 |
cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
|
| 112 |
|
|
|
|
|
|
|
| 113 |
# Get the visible part
|
| 114 |
visible_tokens = tokens[:cutoff]
|
| 115 |
|
|
@@ -120,15 +133,24 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
| 120 |
visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
|
| 121 |
hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
|
| 122 |
|
|
|
|
|
|
|
|
|
|
| 123 |
return visible_text, hidden_text
|
| 124 |
|
| 125 |
def get_new_sample(task, mask_ratio=0.15):
|
| 126 |
"""Get a new text sample based on the task."""
|
| 127 |
-
global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# Select a random sample
|
| 130 |
current_sample = random.choice(data_samples)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
| 132 |
if task == "mlm":
|
| 133 |
# Prepare MLM sample
|
| 134 |
masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
|
|
@@ -373,7 +395,7 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 373 |
)
|
| 374 |
|
| 375 |
with gr.Row():
|
| 376 |
-
new_button = gr.Button("New Sample")
|
| 377 |
reset_button = gr.Button("Reset Stats")
|
| 378 |
|
| 379 |
# Consolidated input area - only one visible at a time
|
|
@@ -433,8 +455,21 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 433 |
outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
|
| 434 |
)
|
| 435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
# Update the sample text and also update the mask count
|
| 437 |
def new_sample_with_count(mask_ratio_pct, task):
|
|
|
|
| 438 |
ratio = float(mask_ratio_pct) / 100.0
|
| 439 |
sample = get_new_sample(task, ratio)
|
| 440 |
mask_count_text = ""
|
|
@@ -442,8 +477,10 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 442 |
if task == "mlm":
|
| 443 |
count = len(masked_tokens)
|
| 444 |
mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
|
|
|
|
| 445 |
else:
|
| 446 |
mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
|
|
|
|
| 447 |
|
| 448 |
return sample, mask_count_text, ""
|
| 449 |
|
|
|
|
| 60 |
global masked_indices, masked_tokens, original_text
|
| 61 |
|
| 62 |
tokens = tokenizer.tokenize(text)
|
| 63 |
+
print(f"Text length: {len(text)} characters, {len(tokens)} tokens")
|
| 64 |
+
|
| 65 |
# Only mask whole words, not special tokens or punctuation
|
| 66 |
maskable_indices = [i for i, token in enumerate(tokens)
|
| 67 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
| 68 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
| 69 |
|
| 70 |
+
print(f"Maskable indices count: {len(maskable_indices)}")
|
| 71 |
+
print(f"Mask ratio: {mask_ratio}")
|
| 72 |
+
|
| 73 |
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
| 74 |
+
# Use the maskable_indices length with the ratio
|
| 75 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
| 76 |
+
print(f"Number of tokens to mask: {num_to_mask}")
|
| 77 |
+
|
| 78 |
# Randomly select indices to mask
|
| 79 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
| 80 |
# Sort indices to ensure they're in order
|
|
|
|
| 109 |
# Tokenize text to ensure reasonable cutting
|
| 110 |
tokens = tokenizer.tokenize(text)
|
| 111 |
|
| 112 |
+
# Print debug info
|
| 113 |
+
print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
|
| 114 |
+
print(f"Cut ratio: {cut_ratio}")
|
| 115 |
+
|
| 116 |
# Ensure we have enough tokens
|
| 117 |
if len(tokens) < 5:
|
| 118 |
return text, "" # Return original if too short
|
| 119 |
|
| 120 |
+
# Calculate cutoff point based on the cut ratio
|
|
|
|
| 121 |
cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
|
| 122 |
cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
|
| 123 |
|
| 124 |
+
print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")
|
| 125 |
+
|
| 126 |
# Get the visible part
|
| 127 |
visible_tokens = tokens[:cutoff]
|
| 128 |
|
|
|
|
| 133 |
visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
|
| 134 |
hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
|
| 135 |
|
| 136 |
+
print(f"Visible text length: {len(visible_text)} chars")
|
| 137 |
+
print(f"Hidden text length: {len(hidden_text)} chars")
|
| 138 |
+
|
| 139 |
return visible_text, hidden_text
|
| 140 |
|
| 141 |
def get_new_sample(task, mask_ratio=0.15):
|
| 142 |
"""Get a new text sample based on the task."""
|
| 143 |
+
global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task
|
| 144 |
+
|
| 145 |
+
# Update current task
|
| 146 |
+
current_task = task
|
| 147 |
|
| 148 |
# Select a random sample
|
| 149 |
current_sample = random.choice(data_samples)
|
| 150 |
|
| 151 |
+
# Print debugging info
|
| 152 |
+
print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")
|
| 153 |
+
|
| 154 |
if task == "mlm":
|
| 155 |
# Prepare MLM sample
|
| 156 |
masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
|
|
|
|
| 395 |
)
|
| 396 |
|
| 397 |
with gr.Row():
|
| 398 |
+
new_button = gr.Button("New Sample", variant="primary")
|
| 399 |
reset_button = gr.Button("Reset Stats")
|
| 400 |
|
| 401 |
# Consolidated input area - only one visible at a time
|
|
|
|
| 455 |
outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
|
| 456 |
)
|
| 457 |
|
| 458 |
+
# Show a notice when the mask ratio changes; the sample itself is only regenerated by clicking "New Sample"
|
| 459 |
+
def update_on_ratio_change(mask_ratio_pct, task):
|
| 460 |
+
print(f"Ratio changed to {mask_ratio_pct}%")
|
| 461 |
+
# Don't generate a new sample here; just report the new ratio and remind the user that it applies to the next sample
|
| 462 |
+
return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."
|
| 463 |
+
|
| 464 |
+
mask_ratio.change(
|
| 465 |
+
update_on_ratio_change,
|
| 466 |
+
inputs=[mask_ratio, task_radio],
|
| 467 |
+
outputs=[result]
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
# Update the sample text and also update the mask count
|
| 471 |
def new_sample_with_count(mask_ratio_pct, task):
|
| 472 |
+
print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
|
| 473 |
ratio = float(mask_ratio_pct) / 100.0
|
| 474 |
sample = get_new_sample(task, ratio)
|
| 475 |
mask_count_text = ""
|
|
|
|
| 477 |
if task == "mlm":
|
| 478 |
count = len(masked_tokens)
|
| 479 |
mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
|
| 480 |
+
print(f"Generated MLM sample with {count} masks at ratio {ratio}")
|
| 481 |
else:
|
| 482 |
mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
|
| 483 |
+
print(f"Generated NTP sample with cut ratio {ratio}")
|
| 484 |
|
| 485 |
return sample, mask_count_text, ""
|
| 486 |
|