Spaces: Running on Zero
Changed initialization to MASK tokens instead of EOS tokens
app.py CHANGED
```diff
@@ -16,8 +16,8 @@ hf_token = os.getenv("HF_TOKEN")
 # --- Load tokenizer ---
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
 vocab_size = len(tokenizer)
-
-
+eos_token_id = tokenizer.eos_token_id
+mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
 assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
 
 # def load_model():
```
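The two new module-level ids are worth unpacking: 'MASK' is not a reserved special token in the Llama-3.2 vocabulary, so `tokenizer.encode('MASK', add_special_tokens=False)[0]` just takes the id of the first sub-token of the literal string "MASK". A minimal standalone sketch (not part of the Space's code) of what the two assignments resolve to, assuming access to the same gated tokenizer:

```python
# Standalone check, assuming access to the gated meta-llama/Llama-3.2-3B repo.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True)

eos_token_id = tokenizer.eos_token_id                          # built-in end-of-text id
mask_ids = tokenizer.encode("MASK", add_special_tokens=False)  # ids of the plain string "MASK"
mask_token_id = mask_ids[0]                                    # keep only the first sub-token

print(eos_token_id, mask_ids, tokenizer.decode([mask_token_id]))
```

If "MASK" splits into more than one sub-token, only the first is used as the placeholder, so a still-masked region decodes to that fragment repeated rather than to the full string "MASK".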
```diff
@@ -114,7 +114,6 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clippi
     noised = input_ids.copy()
     answer_len = len(input_ids) - answer_start
     num_to_noise = int(threshold * answer_len * noise_start)
-    mask_token_id = tokenizer.encode('MASK', add_special_tokens = False)[0]
 
     if num_to_noise == 0:
         return noised
```
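Only a few lines of `confidence_guided_noising` appear in this hunk; the commit simply hoists `mask_token_id` to module scope instead of re-encoding it on every call. For context, a hypothetical reconstruction of such a function, assuming it re-masks the lowest-confidence answer positions (the signature and the selection rule here are guesses, not the Space's actual code):

```python
import numpy as np

def confidence_guided_noising(input_ids, answer_start, confidences,
                              mask_token_id, threshold=1.0, noise_start=1.0):
    noised = list(input_ids)
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise == 0:
        return noised
    # Assumed rule: re-mask the least confident answer positions so the
    # next denoising iteration predicts them again.
    worst = np.argsort(np.asarray(confidences[answer_start:]))[:num_to_noise]
    for j in worst:
        noised[answer_start + int(j)] = mask_token_id
    return noised
```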
```diff
@@ -176,7 +175,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
         return
 
     if len(input_ids) < 256:
-        input_ids += [
+        input_ids += [mask_token_id] * (256 - len(input_ids))
     else:
         input_ids = input_ids[:256]
 
```
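This is the change the commit title refers to: the fixed 256-token context is now padded with MASK placeholders rather than EOS tokens, so the region to be generated starts out explicitly masked. A small sketch of the pad-or-truncate step (the helper name is hypothetical; the diff inlines this logic):

```python
CONTEXT_LEN = 256  # the diff uses the literal 256

def pad_or_truncate(input_ids, pad_id):
    # Pad the prompt out to a fixed length with pad_id (mask_token_id after
    # this commit, an EOS-based fill before it), or truncate if too long.
    if len(input_ids) < CONTEXT_LEN:
        return input_ids + [pad_id] * (CONTEXT_LEN - len(input_ids))
    return input_ids[:CONTEXT_LEN]
```

Under this sketch the call site would be `pad_or_truncate(input_ids, mask_token_id)`.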
```diff
@@ -203,7 +202,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
         highlighted = []
         for j, tok in enumerate(decoded_tokens):
             tok_id = tokenizer.convert_tokens_to_ids(tok)
-            if tok_id ==
+            if tok_id == eos_token_id:
                 continue
             token_str = tokenizer.convert_tokens_to_string([tok])
             if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
```
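After the change, the display loop skips the module-level `eos_token_id`, so EOS ids never reach the streamed HTML output. A sketch of the full loop, with an assumed else-branch and markup (only the skip and the changed-token test come from the diff):

```python
def render_tokens(decoded_tokens, prev_decoded_tokens, tokenizer, eos_token_id):
    highlighted = []
    for j, tok in enumerate(decoded_tokens):
        tok_id = tokenizer.convert_tokens_to_ids(tok)
        if tok_id == eos_token_id:
            continue  # hide EOS padding in the displayed text
        token_str = tokenizer.convert_tokens_to_string([tok])
        if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
            # Assumed markup: emphasize tokens that changed since the last iteration.
            highlighted.append(f"<b>{token_str}</b>")
        else:
            highlighted.append(token_str)
    return "".join(highlighted)
```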
```diff
@@ -245,7 +244,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
         highlighted = []
         for j, tok in enumerate(decoded_tokens):
             tok_id = tokenizer.convert_tokens_to_ids(tok)
-            if tok_id ==
+            if tok_id == eos_token_id:
                 continue
             token_str = tokenizer.convert_tokens_to_string([tok])
             abs_idx = answer_start + j
```
```diff
@@ -259,7 +258,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
 
 
     final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
-    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) !=
+    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eos_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
     print(final_output)
     yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
```
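The final step still filters on EOS: any EOS ids left in the answer region are dropped before the tokens are joined back into text. As a self-contained sketch (names follow the diff; the function wrapper is added for illustration):

```python
def final_text(current_tokens, answer_start, tokenizer, eos_token_id):
    # Drop EOS ids from the generated region, then join the rest into a string.
    toks = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    toks = [t for t in toks if tokenizer.convert_tokens_to_ids(t) != eos_token_id]
    return tokenizer.convert_tokens_to_string(toks)
```

Note that only EOS ids are filtered here; any MASK placeholders that survive the last iteration would still appear in the final output.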