import gradio as gr
import spaces
from huggingface_hub import HfApi, hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
import os

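# Hugging Face access token (needed to download the gated Gemma weights)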
hf_token = os.getenv("HF_TOKEN")


def launch_app():

    @spaces.GPU
    def process_user_input(prompt, concept):
        yield "Processing..."

        # Check that both the prompt and the concept were provided
        if not prompt or not concept:
            yield "<h3>Please provide both a prompt and a concept</h3>"
            return

        # Convert prompt to tokens
        tokens, token_ids = get_tokens(tokenizer, prompt)

        # Get concept IDs and names
        concept_ids, concept_df = select_concepts(all_concepts, concept)
        if len(concept_ids) == 0:
            concepts_html = f"<h3>No relevant concepts found for '{concept}' in LLM thoughts dictionary. Try another concept.</h3>"
        else:
            concepts_html = f"<h3>using the following in the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
            styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
            concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'

        # Get activations
        if len(concept_ids) > 0:
            acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
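            # Per-token score: sum of the activations of the matched concept dimensions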
            vals = acts[0, :, concept_ids].sum(-1).cpu()

            # Get highlighted tokens
            highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
        else:
            highlighted_tokens_html = ""

        # Get LLM response
        response = get_response(pipe, prompt)
        response_html = f"""<h3>LLM response to your prompt:</h3>
            {response}
        """

        # Write documentation
        documentation_html = f"""<h3>How does this work?</h3>
            <ul>
            <li>The LLM model is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
            <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google), is trained on the residual stream at layer 20 of the LLM. The choices of layer 20 and the residual stream are arbitrary.
            <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; it is possible for a concept you input to not be present in this dictionary.
            <li>Each token is highlighted according to how much information about your concept it carries.
            <li>Do you find the results surprising? Any feedback? Any ideas on how I can make this app more useful? Please let me know! Contact: Sarah Tan.
            </ul>
        """

        # Combine HTMLs
        output_html = highlighted_tokens_html + concepts_html + "<p>&nbsp;</p>" + response_html + "<p>&nbsp;</p>" + documentation_html

        yield output_html 

    # Set model, interpreter, dictionary choices
    model_name = "google/gemma-2-2b-it"  # must match the model the interpreter was trained on
    interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
    interpreter_path = "l20/weight.pt"
    interpreter_component = "model.layers[20].output"
    dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"

    # Interpreter class
    class Encoder(pv.CollectIntervention):
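        """Collect-only intervention: linearly projects the residual stream into the concept dictionary space, producing one non-negative (ReLU'd) score per concept."""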
        def __init__(self, **kwargs):
            super().__init__(**kwargs, keep_last_dim=True)
            self.proj = torch.nn.Linear(
                    self.embed_dim, kwargs["latent_dim"], bias=False)
        def forward(self, base, source=None, subspaces=None):
            return torch.relu(self.proj(base))

    # Load tokenizer and model on the available device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(device)

    # Build a generation pipeline that reuses the already-loaded model and tokenizer
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
    )

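    # Download the interpreter's projection weights (l20/weight.pt) from the Hub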
    path_to_params = hf_hub_download(
        repo_id=interpreter_name,
        filename=interpreter_path,
        force_download=False,
    )
    params = torch.load(path_to_params, map_location=device)
    encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to(device)
    encoder.proj.weight.data = params.float()
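    # Attach the encoder to the chosen component (layer 20's output) as a collect-only intervention,
    # so each forward pass also returns the concept activations.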
    pv_model = pv.IntervenableModel({
        "component": interpreter_component,
        "intervention": encoder}, model=model).to(device)

    # Load dictionary
    all_concepts = get_concepts_dictionary(dictionary_url)

    description_text = """
    ## Does an LLM Think Like You?
    Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.
    Examples:
    - **Prompt**: What is 2+2? **Concept**: math
    - **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
    """

    with gr.Blocks() as demo:
        gr.Markdown(description_text)
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.")
            concept_input = gr.Textbox(label="Enter a concept that you think is most relevant for your prompt", value="food")
        process_button = gr.Button("See if an LLM thinks like you!")
        output_html = gr.HTML()

        process_button.click(
            process_user_input,
            inputs=[prompt_input, concept_input],
            outputs=output_html
        )
    
    demo.launch(debug=True)

if __name__ == "__main__":
    launch_app()