Correct context windows
app.py CHANGED
@@ -13,9 +13,10 @@ attnlrp.register(model)
 
 
 def really_clean_tokens(tokens):
+    tokens = clean_tokens(tokens)
     cleaned_tokens = []
     for token in tokens:
-        token = token.replace("_", " ").replace("▁", " ").replace("<s>", "")
+        token = token.replace("_", " ").replace("▁", " ").replace("<s>", " ")
         if token.startswith("<0x") and token.endswith(">"):
             # Convert hex to character
             char_code = int(token[3:-1], 16)
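For reference, a self-contained sketch of what the patched cleaner does. The imported clean_tokens (from lxt) is stubbed out since its behavior is not shown in this diff, and the function tail past the displayed lines is an assumption:

    # Sketch of the patched token cleaner; clean_tokens_stub stands in for the
    # imported clean_tokens, whose implementation is not visible in this diff.
    def clean_tokens_stub(tokens):
        return tokens

    def really_clean_tokens(tokens):
        tokens = clean_tokens_stub(tokens)
        cleaned_tokens = []
        for token in tokens:
            token = token.replace("_", " ").replace("▁", " ").replace("<s>", " ")
            if token.startswith("<0x") and token.endswith(">"):
                # Byte tokens like "<0x0A>" encode one character in hex
                token = chr(int(token[3:-1], 16))
            cleaned_tokens.append(token)  # assumed tail, not shown in the hunk
        return cleaned_tokens

    print(really_clean_tokens(["<s>", "▁Hello", "▁world", "<0x0A>"]))
    # [' ', ' Hello', ' world', '\n']

Mapping "<s>" to " " instead of "" is the visible change here, presumably so the BOS position still renders as a visible blank in the output.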
@@ -44,12 +45,11 @@ def generate_and_visualize(prompt, num_tokens=10):
 
         input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
         input_embeds = model.get_input_embeddings()(input_ids)
-
-    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
-    input_tokens = really_clean_tokens(input_tokens)
+    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
 
     return input_tokens, all_relevances, generated_tokens
+
 def process_relevances(input_tokens, all_relevances, generated_tokens):
     attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
 
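The surrounding loop is not fully shown; a minimal sketch of the decoding pattern these two lines sit in, with model and tokenizer as generic Hugging Face objects, relevance tracking omitted, and really_clean_tokens taken from the sketch above:

    # Assumed shape of the loop; the real generate_and_visualize also computes
    # per-step relevances from input_embeds for the LRP backward pass.
    import torch

    def greedy_generate(model, tokenizer, prompt, num_tokens=10):
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated_tokens_ids = []
        for _ in range(num_tokens):
            logits = model(input_ids).logits
            next_token = logits[0, -1].argmax(dim=-1, keepdim=True)  # shape (1,)
            generated_tokens_ids.append(next_token.item())
            # Same pattern as the hunk: grow the sequence one token at a time
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        input_tokens = really_clean_tokens(
            tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))
        generated_tokens = really_clean_tokens(
            tokenizer.convert_ids_to_tokens(generated_tokens_ids))
        return input_tokens, generated_tokens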
@@ -103,11 +103,11 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     for i, (token, coords) in enumerate(output_with_notes):
         if coords is not None:
             best_width, best_patch_end = coords
-
-
-
-
-
+            significant_start = max(0, best_patch_end - best_width)
+            significant_end = best_patch_end + kernel_width
+            context_start = max(0, significant_start - context_width)
+            context_end = min(len(input_tokens), significant_end + context_width)
+            context = input_tokens[context_start:context_end]
             output_with_notes[i] = (token, (context, significant_start, significant_end))
 
     return output_with_notes
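These are the lines the commit message refers to. A worked example of the corrected window arithmetic, with made-up values for kernel_width and context_width (in the app both come from the enclosing scope of process_relevances):

    # The numbers are invented for illustration, not taken from the app.
    input_tokens = [f"tok{i}" for i in range(40)]
    kernel_width, context_width = 5, 10
    best_width, best_patch_end = 8, 20   # example coords for one generated token

    significant_start = max(0, best_patch_end - best_width)                # 12
    significant_end = best_patch_end + kernel_width                        # 25
    context_start = max(0, significant_start - context_width)              # 2
    context_end = min(len(input_tokens), significant_end + context_width)  # 35
    context = input_tokens[context_start:context_end]                      # 33 tokens

    assert (significant_start, significant_end) == (12, 25)
    assert (context_start, context_end) == (2, 35)

The max/min clamps keep both the significant span and the surrounding context window inside the input token list.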
@@ -123,7 +123,7 @@ def create_html_with_hover(output_with_notes):
                     formatted_context.append(f'<strong>{token}</strong>')
                 else:
                     formatted_context.append(token)
-            formatted_note = "
+            formatted_note = "".join(formatted_context)
             html += f'<span class="hoverable" data-note-id="note-{i}">{text}<sup>[{i+1}]</sup>'
             html += f'<span class="hover-note">{formatted_note}</span></span>'
         else:
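The removed line is truncated in this view, but the replacement joins the context tokens with an empty string. That works because really_clean_tokens already gives each token its own leading space, as a quick check shows:

    # Cleaned tokens carry their own leading spaces, so an empty-string join
    # reconstructs the text; joining on " " would double the spacing.
    cleaned = [" The", " cat", " sat"]
    assert "".join(cleaned) == " The cat sat"
    assert " ".join(cleaned) == " The  cat  sat"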
@@ -144,7 +144,6 @@ css = """
 .hover-note {
     display: none;
     position: absolute;
-    background-color: #f0f0f0;
     padding: 5px;
     border-radius: 5px;
     bottom: 100%;
@@ -153,8 +152,9 @@ css = """
     white-space: normal;
     background-color: rgba(240, 240, 240, 1);
     max-width: 600px;
+    width:500px;
     word-wrap: break-word;
-    z-index:
+    z-index: 10;
 }
 .hoverable:hover .hover-note { display: block; }
 """