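"""Does an LLM Think Like You?

A Gradio demo that highlights how much each token of a user's prompt carries
information about a user-chosen concept, using a pyvene interpreter trained on
Gemma's layer 20 residual stream.
"""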
import gradio as gr
import spaces
from huggingface_hub import HfApi, hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
import os
hf_token = os.getenv("HF_TOKEN")
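# Gemma checkpoints are gated on the Hugging Face Hub, so HF_TOKEN should be
# set (e.g., as a Space secret) to download the model.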
def launch_app():
    @spaces.GPU
    def process_user_input(prompt, concept):
        yield "Processing..."
        # Check that both a prompt and a concept were provided
        if not prompt or not concept:
            yield "<h3>Please provide both a prompt and a concept.</h3>"
            return
        # Convert the prompt to tokens
        tokens, token_ids = get_tokens(tokenizer, prompt)
        # Get the IDs and names of dictionary concepts matching the user's concept
        concept_ids, concept_df = select_concepts(all_concepts, concept)
        if len(concept_ids) == 0:
            concepts_html = f"<h3>No relevant concepts found for '{concept}' in the LLM thoughts dictionary. Try another concept.</h3>"
        else:
            concepts_html = f"<h3>Using the following concepts from the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
            styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
            concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
        # Collect the encoder's concept activations; acts has shape
        # (batch, seq_len, num_concepts)
        if len(concept_ids) > 0:
            acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
            # Sum over the selected concept dimensions for a per-token score
            vals = acts[0, :, concept_ids].sum(-1).cpu()
            # Highlight each token by its concept score
            highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
        else:
            highlighted_tokens_html = ""
        # Get the LLM's response to the prompt
        response = get_response(pipe, prompt)
        response_html = f"""<h3>LLM response to your prompt:</h3>
        {response}
        """
        # Write documentation
        documentation_html = f"""<h3>How does this work?</h3>
        <ul>
        <li>The LLM is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.</li>
        <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google), is trained on the LLM's layer 20 residual stream. The choices of layer 20 and the residual stream are arbitrary.</li>
        <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; a concept you enter may not be in it.</li>
        <li>Each token is highlighted according to how much information about the given concept it carries.</li>
        <li>Do you find the results surprising? Any feedback? Any ideas on how to make this app more useful? Please let me know! Contact: Sarah Tan.</li>
        </ul>
        """
        # Combine the HTML fragments into the final output
        output_html = highlighted_tokens_html + concepts_html + "<p> </p>" + response_html + "<p> </p>" + documentation_html
        yield output_html
    # Set model, interpreter, and dictionary choices. The interpreter was
    # trained on gemma-2-2b-it's layer 20 residual stream, so the model must match.
    model_name = "google/gemma-2-2b-it"
    interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
    interpreter_path = "l20/weight.pt"
    interpreter_component = "model.layers[20].output"
    dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"
    # Interpreter class: a pyvene CollectIntervention that projects residual
    # stream activations onto the concept dictionary
    class Encoder(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs, keep_last_dim=True)
            self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)

        def forward(self, base, source=None, subspaces=None):
            return torch.relu(self.proj(base))
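    # In effect, the encoder maps each token's layer 20 residual stream vector
    # (embed_dim) to latent_dim non-negative scores: one score per dictionary
    # concept, per token.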
    # Load tokenizer and model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(device)
    # Build a text-generation pipeline that reuses the loaded model and tokenizer
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    # Download the interpreter weights and load them into the encoder
    path_to_params = hf_hub_download(
        repo_id=interpreter_name,
        filename=interpreter_path,
        force_download=False,
    )
    params = torch.load(path_to_params, map_location=device)
    encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to(device)
    encoder.proj.weight.data = params.float()
    # Hook the encoder onto layer 20's output
    pv_model = pv.IntervenableModel({
        "component": interpreter_component,
        "intervention": encoder}, model=model).to(device)
    # Load the concept dictionary
    all_concepts = get_concepts_dictionary(dictionary_url)
    description_text = """
## Does an LLM Think Like You?
Input a prompt and a concept that you think is most relevant to your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.

Examples:
- **Prompt**: What is 2+2? **Concept**: math
- **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
"""
    # Build the Gradio UI
    with gr.Blocks() as demo:
        gr.Markdown(description_text)
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.")
            concept_input = gr.Textbox(label="Enter a concept that you think is most relevant to your prompt", value="food")
        process_button = gr.Button("See if an LLM thinks like you!")
        output_html = gr.HTML()
        process_button.click(
            process_user_input,
            inputs=[prompt_input, concept_input],
            outputs=output_html,
        )
    demo.launch(debug=True)
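# On a Hugging Face Space, app.py is executed directly; on ZeroGPU hardware,
# each call to the @spaces.GPU-decorated process_user_input requests a GPU worker.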
if __name__ == "__main__":
    launch_app()