# llm-thinking / utils.py
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as clrs
import requests
import json
import pandas as pd
import torch
import spaces
# Function to get tokens given text
@spaces.GPU
def get_tokens(tokenizer, text):
token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda" if torch.cuda.is_available() else "cpu")
tokens = tokenizer.convert_ids_to_tokens(token_ids[0])
return tokens, token_ids
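
# Example usage (a sketch; assumes `tokenizer` is a Hugging Face tokenizer,
# e.g. transformers.AutoTokenizer.from_pretrained(...)):
#   tokens, token_ids = get_tokens(tokenizer, "The cat sat on the mat")
#   # tokens    -> list of token strings for the encoded text
#   # token_ids -> tensor of shape (1, num_tokens), moved to GPU when available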
# Function to apply chat template to prompt
@spaces.GPU
def decorate_prompt(tokenizer, prompt):
chat = [
{"role": "user", "content": prompt},
{"role": "assistant", "content": ""},
]
text = tokenizer.apply_chat_template(chat, tokenize=False)
token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda" if torch.cuda.is_available() else "cpu")
return token_ids
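
# Example usage (a sketch; assumes the tokenizer's chat template accepts
# user/assistant turns, e.g. an instruction-tuned model's tokenizer):
#   token_ids = decorate_prompt(tokenizer, "Explain photosynthesis briefly.")
#   # token_ids holds the chat-templated prompt, ready to pass to generation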
# Function to get response to prompt
def get_response(model_pipe, prompt):
response = model_pipe(prompt)[0]['generated_text']
return response
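
# Example usage (a sketch; assumes `model_pipe` is a transformers text-generation
# pipeline, e.g. transformers.pipeline("text-generation", model=..., tokenizer=...)):
#   response = get_response(model_pipe, "Write a haiku about autumn.")
#   # response is the generated text string returned by the pipeline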
# Function to highlight tokens based on given values
def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
if len(tokens) != len(values):
raise ValueError("The number of tokens and values must be the same.")
# Set color map
    cmap = plt.get_cmap(cmap_name)  # cm.get_cmap was removed in newer matplotlib releases
    norm = clrs.Normalize(vmin=vmin if vmin is not None else values.detach().min().item(),
                          vmax=vmax if vmax is not None else values.detach().max().item())
html_output = f"<h3>How much information about the concept '{concept}' is carried in each token:</h3>"
    for token, value in zip(tokens, values.detach().cpu().numpy()):
rgba_color = cmap(norm(value))
hex_color = '#%02x%02x%02x' % (int(rgba_color[0]*255), int(rgba_color[1]*255), int(rgba_color[2]*255))
html_output += f'<span style="background-color: {hex_color};" title="{value:.4f}">{token}</span> '
return html_output
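
# Example usage (a sketch; the token-level `values` below are hypothetical and
# would normally come from a concept-probing model elsewhere in this Space):
#   tokens = ["The", "cat", "sat"]
#   values = torch.tensor([0.1, 0.8, 0.3])
#   html = plot_tokens_with_highlights(tokens, values, concept="animals")
#   # `html` can be rendered directly in an HTML output component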
# Function to get concepts dictionary
def get_concepts_dictionary(dictionary_url):
response = requests.get(dictionary_url, stream=True)
response.raise_for_status()
data_dict = {}
for line in response.iter_lines(decode_unicode=True):
if line:
obj = json.loads(line)
concept_id = obj.get("concept_id")
concept = obj.get("concept")
            if concept_id is not None and concept:
data_dict[concept_id] = concept.capitalize()
return data_dict
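
# Example usage (a sketch; the URL is a placeholder for a JSON-lines file with
# one {"concept_id": ..., "concept": ...} object per line):
#   all_concepts = get_concepts_dictionary("https://example.com/concepts.jsonl")
#   # all_concepts maps concept_id -> capitalized concept name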
# Function to get matching concepts
def select_concepts(all_concepts, desired_concept):
concept_ids = []
for k, v in all_concepts.items():
if desired_concept.lower() in v.lower():
concept_ids.append(k)
concept_data = []
for concept_id in concept_ids:
concept_name = all_concepts.get(concept_id, "Unknown Concept")
concept_data.append({"Concept ID": concept_id, "Concept Name": concept_name})
concept_df = pd.DataFrame(concept_data)
return torch.tensor(concept_ids), concept_df
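
# Example usage (a sketch; assumes `all_concepts` was built by get_concepts_dictionary
# and that concept IDs are integers, so they can be packed into a tensor):
#   concept_ids, concept_df = select_concepts(all_concepts, "emotion")
#   # concept_ids -> tensor of matching IDs; concept_df -> DataFrame for display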