import gradio as gr
import spaces
from huggingface_hub import HfApi, hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
import os

hf_token = os.getenv("HF_TOKEN")
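# NOTE: the Gemma checkpoints on the Hub are gated, so HF_TOKEN must be set
# (e.g. as a Space secret) for the model downloads below to succeed.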


def launch_app():
    # Gradio treats generator functions as streaming outputs: the first yield
    # shows a placeholder while the real work runs.
    def process_user_input(prompt, concept):
        yield "Processing..."
        # Check that both a prompt and a concept were provided
        if not prompt or not concept:
            # In a generator, `return value` never reaches the UI; yield the
            # message, then stop.
            yield "<h3>Please provide both a prompt and a concept</h3>"
            return
        # Convert the prompt to tokens
        tokens, token_ids = get_tokens(tokenizer, prompt)
        # Get the IDs and names of dictionary concepts matching the user's concept
        concept_ids, concept_df = select_concepts(all_concepts, concept)
        if len(concept_ids) == 0:
            concepts_html = f"<h3>No relevant concepts found for '{concept}' in the LLM thoughts dictionary. Try another concept.</h3>"
        else:
            concepts_html = f"<h3>Using the following concepts from the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
            styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
            concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
        # Collect layer-20 activations and score each token by how strongly
        # the selected concepts fire on it
        if len(concept_ids) > 0:
            acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
            # acts has shape (batch, seq_len, latent_dim); summing over the
            # selected concept dimensions gives one relevance value per token
            vals = acts[0, :, concept_ids].sum(-1).cpu()
            # Render tokens highlighted by their relevance values
            highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
        else:
            highlighted_tokens_html = ""
        # Get the LLM's response to the prompt
        response = get_response(pipe, prompt)
        response_html = f"""<h3>LLM response to your prompt:</h3>
        {response}
        """
        # Write the documentation shown under the results
        documentation_html = f"""<h3>How does this work?</h3>
        <ul>
        <li>The LLM is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
        <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google), is trained on the LLM's layer 20 residual stream. The choices of layer 20 and the residual stream are arbitrary.
        <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; a concept you input may not be present in it.
        <li>Each token is highlighted according to how much information about the given concept it carries.
        <li>Do you find the results surprising? Any feedback? Any ideas on how I can make this app more useful? Please let me know! Contact: Sarah Tan.
        </ul>
        """
        # Combine the HTML fragments
        output_html = highlighted_tokens_html + concepts_html + "<p> </p>" + response_html + "<p> </p>" + documentation_html
        yield output_html

    # Set model, interpreter, and dictionary choices. The interpreter weights
    # below were trained on gemma-2-2b-it, so the base model must match it
    # (same hidden size and layer count).
    model_name = "google/gemma-2-2b-it"
    interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
    interpreter_path = "l20/weight.pt"
    # pyvene path to the output (residual stream) of transformer block 20
    interpreter_component = "model.layers[20].output"
    dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"

    # Interpreter class: a linear encoder whose outputs are collected rather
    # than written back into the forward pass
    class Encoder(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs, keep_last_dim=True)
            self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)

        def forward(self, base, source=None, subspaces=None):
            return torch.relu(self.proj(base))
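
    # For a residual-stream vector of size embed_dim, relu(proj(...)) yields
    # one non-negative activation per dictionary concept; process_user_input
    # sums the activations of the selected concept IDs to score each token.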

    # Pick the device once; reused for the encoder and pv_model below
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the tokenizer and model; device_map='auto' already places the
    # model on the GPU when one is available
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', token=hf_token)

    # Build a text-generation pipeline on top of the already-loaded model and
    # tokenizer (passing model_name instead would load a second copy)
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    # Download the interpreter weights and load them into the encoder
    path_to_params = hf_hub_download(
        repo_id=interpreter_name,
        filename=interpreter_path,
        force_download=False,
    )
    params = torch.load(path_to_params, map_location=device)
    encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to(device)
    encoder.proj.weight.data = params.float()

    # Attach the encoder to layer 20's output; pv.CollectIntervention records
    # the encoder's outputs during the forward pass without modifying the model
    pv_model = pv.IntervenableModel({
        "component": interpreter_component,
        "intervention": encoder}, model=model).to(device)

    # Load the concept dictionary
    all_concepts = get_concepts_dictionary(dictionary_url)
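    # Assumption: each line of metadata.jsonl describes one dictionary latent
    # (its index plus a human-readable concept label), which is what
    # select_concepts matches the user's concept against.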

    description_text = """
## Does an LLM Think Like You?

Input a prompt and a concept that you think is most relevant to your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.

Examples:
- **Prompt**: What is 2+2? **Concept**: math
- **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
"""

    with gr.Blocks() as demo:
        gr.Markdown(description_text)
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.")
            concept_input = gr.Textbox(label="Enter a concept that you think is most relevant to your prompt", value="food")
        process_button = gr.Button("See if an LLM thinks like you!")
        output_html = gr.HTML()
        process_button.click(
            process_user_input,
            inputs=[prompt_input, concept_input],
            outputs=output_html,
        )
    demo.launch(debug=True)


if __name__ == "__main__":
    launch_app()