Spaces:
Running
Running
adding application
Browse files
app.py
CHANGED
|
@@ -226,58 +226,55 @@ def find_last_token_index(full_ids, word_ids):
|
|
| 226 |
|
| 227 |
|
| 228 |
def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
"<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
|
| 232 |
-
"<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
model, tokenizer = get_model_and_tokenizer(model_name)
|
| 236 |
|
| 237 |
-
|
| 238 |
-
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
# ----------------------------- GRADIO UI -------------------------------
|
|
@@ -311,8 +308,4 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 311 |
)
|
| 312 |
|
| 313 |
if __name__ == "__main__":
|
| 314 |
-
|
| 315 |
-
demo.launch()
|
| 316 |
-
except Exception as e:
|
| 317 |
-
print(f"Error launching Gradio app: {e}")
|
| 318 |
-
raise
|
|
|
|
| 226 |
|
| 227 |
|
| 228 |
def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
    """Probe every layer of *model_name* with Patchscopes and report which layers
    "know" *word*, rendered as an HTML summary + table for the Gradio UI.

    Parameters
    ----------
    model_name : str
        Model identifier forwarded to ``get_model_and_tokenizer``.
    extraction_template : str
        Accepted for UI compatibility but not used here; the extraction prompt
        is hard-coded to ``"X"``.  # NOTE(review): confirm this is intentional.
    word : str
        Word whose hidden-state representations are extracted and patched.
    patchscopes_template : str
        Patching prompt template containing the ``"X"`` target placeholder.

    Returns
    -------
    str
        HTML: sub-word tokens, matched-layer count, and a per-layer table with
        matching rows highlighted. On any failure, an error paragraph is
        returned instead (errors are surfaced to the UI, never raised).
    """
    try:
        model, tokenizer = get_model_and_tokenizer(model_name)

        # Build extraction prompt (where hidden states will be collected)
        extraction_prompt = "X"

        # Sub-word tokenisation of *word*; shown in the report header below.
        word_token_ids = tokenizer.encode(word, add_special_tokens=False)

        # Instantiate Patchscopes retriever
        patch_retriever = PatchscopesRetriever(
            model,
            tokenizer,
            extraction_prompt,
            patchscopes_template,
            prompt_target_placeholder="X",
        )

        # Run retrieval for the word across all layers (one pass)
        retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
            word,
            num_tokens_to_generate=len(tokenizer.tokenize(word)),
        )[0]

        # Build a table summarising which layers reproduce the word exactly
        # (comparison ignores leading/trailing spaces from detokenisation).
        records = []
        matches = 0
        for layer_idx, ret_word in enumerate(retrieved_words):
            match = ret_word.strip(" ") == word.strip(" ")
            if match:
                matches += 1
            records.append(
                {"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""}
            )

        df = pd.DataFrame(records)

        def _style(row):
            # Highlight rows where the retrieved word matched.
            color = "background-color: lightgreen" if row["Match?"] else ""
            return [color] * len(row)

        html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)

        sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
        top = (
            f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
            f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
        )
        return top + html_table
    except Exception as e:
        # Boundary handler: surface any failure to the Gradio UI as red text
        # instead of crashing the app.
        return f"<p style='color:red'>❌ Error: {str(e)}</p>"
# ----------------------------- GRADIO UI -------------------------------
|
|
|
|
| 308 |
)
|
| 309 |
|
| 310 |
if __name__ == "__main__":
    # Script entry point: start the Gradio app defined above.
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|