import gradio as gr
import spaces
from huggingface_hub import HfApi, hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
import os

hf_token = os.getenv("HF_TOKEN")


def launch_app():

    @spaces.GPU
    def process_user_input(prompt, concept):
        # The callback is a generator, so intermediate yields stream to the UI
        # while the analysis runs.
        yield "Processing..."

        # Check if prompt or concept is empty. In a generator callback the
        # message must be yielded (not returned) for Gradio to display it.
        if not prompt or not concept:
            yield "Please enter both a prompt and a concept."
            return
" + documentation_html yield output_html # Set model, interpreter, dictionary choices model_name = "google/gemma-2-2b-it" interpreter_name = "pyvene/gemma-reft-r1-2b-it-res" interpreter_path = "l20/weight.pt" interpreter_component = "model.layers[20].output" dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl" # Interpreter class class Encoder(pv.CollectIntervention): def __init__(self, **kwargs): super().__init__(**kwargs, keep_last_dim=True) self.proj = torch.nn.Linear( self.embed_dim, kwargs["latent_dim"], bias=False) def forward(self, base, source=None, subspaces=None): return torch.relu(self.proj(base)) with gr.Blocks() as demo: # Load tokenizer and model tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token) model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', token=hf_token).to("cuda" if torch.cuda.is_available() else "cpu") # Load fast model inference pipeline pipe = pipeline( task="text-generation", model=model_name, use_fast=True, token=hf_token ) path_to_params = hf_hub_download( repo_id=interpreter_name, filename=interpreter_path, force_download=False, ) params = torch.load(path_to_params, map_location="cuda" if torch.cuda.is_available() else "cpu") encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to("cuda" if torch.cuda.is_available() else "cpu") encoder.proj.weight.data = params.float() pv_model = pv.IntervenableModel({ "component": interpreter_component, "intervention": encoder}, model=model).to("cuda" if torch.cuda.is_available() else "cpu") # Load dictionary all_concepts = get_concepts_dictionary(dictionary_url) description_text = """ ## Does an LLM Think Like You? Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt. Examples: - **Prompt**: What is 2+2? **Concept**: math - **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food """ gr.Markdown(description_text) with gr.Row(): prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.") concept_input = gr.Textbox(label="Enter a concept that you think is most relevant for your prompt", value="food") process_button = gr.Button("See if an LLM thinks like you!") output_html = gr.HTML() process_button.click( process_user_input, inputs=[prompt_input, concept_input], outputs=output_html ) demo.launch(debug=True) if __name__ == "__main__": launch_app()