import gradio as gr
import spaces
from huggingface_hub import HfApi, hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
import os
hf_token = os.getenv("HF_TOKEN")
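# Note: HF_TOKEN is read from the environment (e.g. a Space secret). It is
# needed because the Gemma checkpoints on the Hub are gated downloads.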
@spaces.GPU
def launch_app():
    @spaces.GPU
    def process_user_input(prompt, concept):
        yield "Processing..."
        # Check that both a prompt and a concept were provided
        if not prompt or not concept:
            # Yield (rather than return a value) so Gradio actually displays the message
            yield "<h3>Please provide both a prompt and a concept.</h3>"
            return
        # Convert the prompt to tokens
        tokens, token_ids = get_tokens(tokenizer, prompt)
        # Get IDs and names of dictionary concepts related to the user's concept
        concept_ids, concept_df = select_concepts(all_concepts, concept)
        if len(concept_ids) == 0:
            concepts_html = f"<h3>No relevant concepts found for '{concept}' in the LLM thoughts dictionary. Try another concept.</h3>"
        else:
            concepts_html = f"<h3>using the following concepts from the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
            styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
            concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
        # Get activations
        if len(concept_ids) > 0:
            acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
            vals = acts[0, :, concept_ids].sum(-1).cpu()
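            # Reading the code above: the collected activations presumably have shape
            # (batch, seq_len, n_concepts), so `vals` is a per-token score obtained by
            # summing the activations of the selected concepts. Higher values should
            # mean the token carries more information about the concept.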
            # Get highlighted tokens
            highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
        else:
            highlighted_tokens_html = ""
        # Get LLM response
        response = get_response(pipe, prompt)
        response_html = f"""<h3>LLM response to your prompt:</h3>
{response}
"""
        # Write documentation
        documentation_html = f"""<h3>How does this work?</h3>
<ul>
<li>The LLM is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
<li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google), is trained on the LLM's layer-20 residual stream. The choice of layer 20 and of the residual stream is arbitrary.
<li>The LLM interpreter decomposes the layer-20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; the concept you enter may not be present in it.
<li>Each token is highlighted according to how much information about the given concept it carries.
<li>Do you find the results surprising? Any feedback? Any ideas on how to make this app more useful? Please let me know! Contact: Sarah Tan.
</ul>
"""
        # Combine HTMLs
        output_html = highlighted_tokens_html + concepts_html + "<p>&nbsp;</p>" + response_html + "<p>&nbsp;</p>" + documentation_html
        yield output_html
    # Set model, interpreter, dictionary choices
    # The interpreter below was trained on gemma-2-2b-it's layer-20 residual stream,
    # so the base model must be gemma-2-2b-it for the dimensions and layer index to match.
    model_name = "google/gemma-2-2b-it"
    interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
    interpreter_path = "l20/weight.pt"
    interpreter_component = "model.layers[20].output"
    dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"
    # Interpreter class
    class Encoder(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs, keep_last_dim=True)
            self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)

        def forward(self, base, source=None, subspaces=None):
            return torch.relu(self.proj(base))
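    # Rough sketch of what the Encoder does (comment only, not executed):
    # pv.CollectIntervention collects activations at the hooked component rather
    # than overwriting them. For each layer-20 residual-stream hidden state h
    # (size embed_dim), the Encoder returns relu(proj(h)), a non-negative score
    # for each of the latent_dim dictionary concepts, conceptually:
    #     concept_scores = torch.relu(encoder.proj(hidden_states))  # (..., latent_dim)
    # process_user_input() then sums these scores over the selected concept_ids
    # to decide how strongly to highlight each token.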
    # Pick the compute device once and reuse it below
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(device)

    # Load fast model inference pipeline
    pipe = pipeline(
        task="text-generation",
        model=model_name,
        use_fast=True,
        token=hf_token
    )
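    # Note: the pipeline above reloads the checkpoint by name, so generation uses
    # its own copy of the weights, while `model` below is wrapped by pyvene purely
    # to collect interpreter activations. Passing model=model and tokenizer=tokenizer
    # instead would avoid the second load, at the cost of sharing one instance.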
    path_to_params = hf_hub_download(
        repo_id=interpreter_name,
        filename=interpreter_path,
        force_download=False,
    )
    params = torch.load(path_to_params, map_location=device)
    encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).to(device)
    encoder.proj.weight.data = params.float()
    pv_model = pv.IntervenableModel({
        "component": interpreter_component,
        "intervention": encoder}, model=model).to(device)
    # Load dictionary
    all_concepts = get_concepts_dictionary(dictionary_url)
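    # `all_concepts` is presumably the parsed metadata.jsonl: one natural-language
    # description per dictionary feature, indexed by the same positions as the
    # columns of the encoder output, so that select_concepts() can map the user's
    # concept to column indices.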
    description_text = """
## Does an LLM Think Like You?
Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.
Examples:
- **Prompt**: What is 2+2? **Concept**: math
- **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
"""
    with gr.Blocks() as demo:
        gr.Markdown(description_text)
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.")
            concept_input = gr.Textbox(label="Enter a concept that you think is most relevant for your prompt", value="food")
        process_button = gr.Button("See if an LLM thinks like you!")
        output_html = gr.HTML()
        process_button.click(
            process_user_input,
            inputs=[prompt_input, concept_input],
            outputs=output_html
        )
    demo.launch(debug=True)
if __name__ == "__main__":
    launch_app()