Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| import matplotlib.colors as clrs | |
| import requests | |
| import json | |
| import pandas as pd | |
| import torch | |
| import spaces | |
def get_tokens(tokenizer, text):
    """Tokenize ``text`` (no special tokens) and return tokens plus their ids.

    The id tensor is placed on CUDA when it is available, otherwise on CPU.

    Args:
        tokenizer: Object exposing ``encode`` and ``convert_ids_to_tokens``
            (presumably a Hugging Face tokenizer — confirm with callers).
        text: String to tokenize.

    Returns:
        ``(tokens, token_ids)``: ``tokens`` is the token-string list for the
        first row of ``token_ids``; ``token_ids`` is the tensor returned by
        ``tokenizer.encode(..., return_tensors="pt")`` after the device move
        (indexed with ``[0]`` here, so presumably batch-first with one row).
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    encoded = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
    token_ids = encoded.to(target_device)

    first_row = token_ids[0]
    tokens = tokenizer.convert_ids_to_tokens(first_row)
    return tokens, token_ids
def decorate_prompt(tokenizer, prompt):
    """Wrap ``prompt`` in the tokenizer's chat template and return token ids.

    The conversation passed to the template is a user turn containing
    ``prompt`` followed by an assistant turn with empty content. The rendered
    text is then encoded with ``add_special_tokens=False``. Presumably this is
    because the rendered template already carries the special tokens the model
    expects; confirm for the tokenizer in use.

    NOTE(review): some chat templates close the empty assistant turn with an
    end-of-turn marker; ``add_generation_prompt=True`` is the usual way to get
    an open assistant turn. This is left unchanged because downstream code may
    depend on the exact token positions produced here.

    Args:
        tokenizer: Object exposing ``apply_chat_template`` and ``encode``.
        prompt: User message text.

    Returns:
        The id tensor from ``tokenizer.encode(..., return_tensors="pt")``,
        moved to CUDA when available, otherwise to CPU.
    """
    conversation = [
        dict(role="user", content=prompt),
        dict(role="assistant", content=""),
    ]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoded = tokenizer.encode(rendered, return_tensors="pt", add_special_tokens=False)
    return encoded.to(device)
def get_response(model_pipe, prompt):
    """Call ``model_pipe`` on ``prompt`` and return the first result's text.

    Args:
        model_pipe: Callable whose result is indexable, with element ``[0]``
            being a mapping that has a ``'generated_text'`` key (presumably a
            Hugging Face text-generation pipeline; whether the text includes
            the prompt depends on the pipeline's settings — confirm).
        prompt: Input passed straight through to ``model_pipe``.

    Returns:
        ``model_pipe(prompt)[0]['generated_text']``.
    """
    outputs = model_pipe(prompt)
    first_output = outputs[0]
    return first_output["generated_text"]
def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
    """Render tokens as HTML spans whose background colour encodes a score.

    Args:
        tokens: Sequence of token strings, in display order.
        values: Tensor of per-token scores, one per token. It is detached,
            copied to host memory and converted to float64 before use, so it
            may live on any device, carry autograd history, or be a half
            precision type.
        concept: Concept name shown in the heading.
        cmap_name: Matplotlib colormap name used for the highlight colours.
        vmin, vmax: Colour-scale limits; each defaults to the min / max of
            ``values`` when ``None``.

    Returns:
        An HTML string: an ``<h3>`` heading followed by one ``<span>`` per
        token, with the score (4 decimal places) in the span's ``title``.
        Token and concept text are HTML-escaped so that special tokens such
        as ``<s>`` are displayed instead of being parsed as tags.

    Raises:
        ValueError: if ``tokens`` and ``values`` differ in length.
    """
    import html

    if len(tokens) != len(values):
        raise ValueError("The number of tokens and values must be the same.")

    parts = [
        f"<h3>How much information about the concept '{html.escape(str(concept))}' is carried in each token:</h3>"
    ]
    if len(tokens) == 0:
        # Nothing to colour; avoids min()/max() on an empty tensor.
        return parts[0]

    # .cpu(): .numpy() rejects CUDA tensors. .double(): NumPy has no bfloat16,
    # and float64 holds any float32/float16 score exactly.
    scores = values.detach().cpu().double().numpy()

    norm = clrs.Normalize(
        vmin=vmin if vmin is not None else float(scores.min()),
        vmax=vmax if vmax is not None else float(scores.max()),
    )
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, removed in Matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name)

    for token, value in zip(tokens, scores):
        red, green, blue, _alpha = cmap(norm(value))
        hex_color = '#%02x%02x%02x' % (int(red * 255), int(green * 255), int(blue * 255))
        parts.append(
            f'<span style="background-color: {hex_color};" title="{value:.4f}">'
            f'{html.escape(str(token))}</span> '
        )
    return "".join(parts)
def get_concepts_dictionary(dictionary_url, timeout=60):
    """Download a JSON Lines concept file and map concept IDs to names.

    Each non-blank line of the response body is parsed as a JSON object. Its
    ``"concept_id"`` and ``"concept"`` fields become one dictionary entry,
    with the name passed through ``str.capitalize``. Lines whose
    ``concept_id`` is missing/null, or whose ``concept`` is missing or empty,
    are skipped.

    Args:
        dictionary_url: URL of the JSONL file.
        timeout: Timeout in seconds forwarded to ``requests.get``, so a
            stalled server cannot hang the caller indefinitely. Pass ``None``
            to wait forever, which was the previous behaviour.

    Returns:
        dict mapping each concept ID to its capitalized concept name.

    Raises:
        requests.HTTPError: if the server returns an error status.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data_dict = {}
    # The context manager closes the streamed connection even if
    # raise_for_status() or json.loads() raises.
    with requests.get(dictionary_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Iterate raw bytes: json.loads detects UTF-8/16/32 from bytes itself,
        # whereas decode_unicode=True would apply the charset requests guessed
        # from the headers (ISO-8859-1 for text/* without a charset).
        for line in response.iter_lines():
            if not line.strip():
                continue
            obj = json.loads(line)
            concept_id = obj.get("concept_id")
            concept = obj.get("concept")
            # Compare against None so that a legitimate ID of 0 is kept.
            if concept_id is not None and concept:
                data_dict[concept_id] = concept.capitalize()
    return data_dict
def select_concepts(all_concepts, desired_concept):
    """Find concepts whose name contains ``desired_concept`` (case-insensitive).

    Args:
        all_concepts: Mapping of concept ID -> concept name. IDs must be
            integers, because they are packed into an int64 tensor.
        desired_concept: Substring to search for. Both sides are lowercased
            before comparison, so an empty string matches every concept.

    Returns:
        ``(ids, df)``: ``ids`` is an int64 tensor of matching concept IDs in
        ``all_concepts`` iteration order. ``df`` is a DataFrame with columns
        ``"Concept ID"`` and ``"Concept Name"``, one row per match. When
        nothing matches, both are empty but keep the same dtype and columns.
    """
    needle = desired_concept.lower()  # hoisted: constant across the loop
    concept_ids = []
    rows = []
    for concept_id, concept_name in all_concepts.items():
        if needle in concept_name.lower():
            concept_ids.append(concept_id)
            rows.append({"Concept ID": concept_id, "Concept Name": concept_name})

    # Explicit columns/dtype keep the empty result shaped like a non-empty one
    # (otherwise: a column-less DataFrame and a float32 tensor).
    concept_df = pd.DataFrame(rows, columns=["Concept ID", "Concept Name"])
    return torch.tensor(concept_ids, dtype=torch.long), concept_df