import os

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv

from utils import (
    get_tokens,
    select_concepts,
    get_concepts_dictionary,
    get_response,
    plot_tokens_with_highlights,
)

hf_token = os.getenv("HF_TOKEN")


def launch_app():

    @spaces.GPU
    def process_user_input(prompt, concept):
        yield "Processing..."

        # Check if prompt or concept are empty. A bare `return` ends the
        # generator; Gradio ignores generator return values, so the message
        # must be yielded rather than returned.
        if not prompt or not concept:
            yield "<p>Please provide both a prompt and a concept.</p>"
            return
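        # Gradio runs a generator handler as a streaming update: each `yield`
        # replaces the contents of the output component, so "Processing..."
        # stays visible while the heavier steps below run.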
" # Convert prompt to tokens tokens, token_ids = get_tokens(tokenizer, prompt) # Get concept IDs and names concept_ids, concept_df = select_concepts(all_concepts, concept) if len(concept_ids) == 0: concepts_html = f"

No relevant concepts found for '{concept}' in LLM thoughts dictionary. Try another concept.

" else: concepts_html = f"

using the following in the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):

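            # Assumed dictionary layout (implied by the indexing below): each
            # entry labels one latent dimension of the encoder, so concept_ids
            # doubles as a column index into the collected activations.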
" styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html() concepts_html += f'
{styled_table}

        # Get activations
        if len(concept_ids) > 0:
            acts = pv_model.forward(
                {"input_ids": token_ids}, return_dict=True
            ).collected_activations[0]
            vals = acts[0, :, concept_ids].sum(-1).cpu()

            # Get highlighted tokens
            highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
        else:
            highlighted_tokens_html = ""
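        # `acts` is assumed to be (batch, seq_len, latent_dim), as collected
        # by the Encoder below with keep_last_dim=True; summing the selected
        # concept columns yields one relevance score per prompt token.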

        # Get LLM response
        response = get_response(pipe, prompt)
        response_html = f"""
<p><b>LLM response to your prompt:</b></p>
<p>{response}</p>
"""

        # Write documentation
        documentation_html = """
<p><b>How does this work?</b></p>
"""
""" # Combine HTMLs output_html = highlighted_tokens_html + concepts_html + "

 

" + response_html + "

 

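    # The "interpreter" configured below is a concept encoder from the pyvene
    # ReFT-r1 release for Gemma-2-2b: a learned linear projection from the
    # layer-20 residual stream into a dictionary of labelled concept
    # directions. By the file names, metadata.jsonl is assumed to hold the
    # concept labels and weight.pt the projection matrix.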
" + documentation_html yield output_html # Set model, interpreter, dictionary choices model_name = "google/gemma-2-2b-it" interpreter_name = "pyvene/gemma-reft-r1-2b-it-res" interpreter_path = "l20/weight.pt" interpreter_component = "model.layers[20].output" dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl" # Interpreter class class Encoder(pv.CollectIntervention): def __init__(self, **kwargs): super().__init__(**kwargs, keep_last_dim=True) self.proj = torch.nn.Linear( self.embed_dim, kwargs["latent_dim"], bias=False) def forward(self, base, source=None, subspaces=None): return torch.relu(self.proj(base)) with gr.Blocks() as demo: # Load tokenizer and model tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token) model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', token=hf_token).to("cuda" if torch.cuda.is_available() else "cpu") # Load fast model inference pipeline pipe = pipeline( task="text-generation", model=model_name, use_fast=True, token=hf_token ) path_to_params = hf_hub_download( repo_id=interpreter_name, filename=interpreter_path, force_download=False, ) params = torch.load(path_to_params, map_location="cuda" if torch.cuda.is_available() else "cpu") encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to("cuda" if torch.cuda.is_available() else "cpu") encoder.proj.weight.data = params.float() pv_model = pv.IntervenableModel({ "component": interpreter_component, "intervention": encoder}, model=model).to("cuda" if torch.cuda.is_available() else "cpu") # Load dictionary all_concepts = get_concepts_dictionary(dictionary_url) description_text = """ ## Does an LLM Think Like You? Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt. Examples: - **Prompt**: What is 2+2? **Concept**: math - **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food """ gr.Markdown(description_text) with gr.Row(): prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.") concept_input = gr.Textbox(label="Enter a concept that you think is most relevant for your prompt", value="food") process_button = gr.Button("See if an LLM thinks like you!") output_html = gr.HTML() process_button.click( process_user_input, inputs=[prompt_input, concept_input], outputs=output_html ) demo.launch(debug=True) if __name__ == "__main__": launch_app()