File size: 3,002 Bytes
af4a860
 
 
 
 
 
 
0f93598
af4a860
 
0f93598
af4a860
0f93598
af4a860
 
 
 
 
0f93598
af4a860
 
 
 
 
 
0f93598
af4a860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import html
import json

import matplotlib.cm as cm
import matplotlib.colors as clrs
import matplotlib.pyplot as plt
import pandas as pd
import requests
import spaces
import torch

@spaces.GPU
def get_tokens(tokenizer, text):
    """Tokenize ``text`` without adding special tokens.

    Args:
        tokenizer: Tokenizer exposing ``encode`` and ``convert_ids_to_tokens``.
        text: String to tokenize.

    Returns:
        tuple: ``(tokens, token_ids)``. ``token_ids`` is the tensor returned by
        ``tokenizer.encode(..., return_tensors="pt")``, moved to CUDA when it
        is available and to CPU otherwise. ``tokens`` holds the token strings
        for row 0 of that tensor.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    encoded = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
    ids_on_device = encoded.to(target_device)
    # Row 0 is treated as the (single) encoded sequence.
    token_strings = tokenizer.convert_ids_to_tokens(ids_on_device[0])
    return token_strings, ids_on_device

@spaces.GPU
def decorate_prompt(tokenizer, prompt):
    """Wrap ``prompt`` in the tokenizer's chat template and encode it.

    The conversation rendered is one user turn containing ``prompt``, followed
    by an assistant turn with empty content. The rendered text is encoded with
    ``add_special_tokens=False``, presumably because the template output
    already contains any BOS/special tokens (confirm per model).

    NOTE(review): depending on the model's template, an empty assistant
    message may render as a *closed* assistant turn (e.g. with an end-of-turn
    token). If an open generation prompt is intended instead,
    ``add_generation_prompt=True`` may be what is wanted — confirm with callers.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``encode``.
        prompt: User message text.

    Returns:
        The id tensor from ``tokenizer.encode(..., return_tensors="pt")``,
        moved to CUDA when available and to CPU otherwise.
    """
    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": ""},
    ]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    encoded = tokenizer.encode(rendered, return_tensors="pt", add_special_tokens=False)
    return encoded.to(target_device)

def get_response(model_pipe, prompt):
    """Run ``model_pipe`` on ``prompt`` and return the first result's text.

    Assumes ``model_pipe`` is a callable returning a sequence of dicts with a
    ``'generated_text'`` key (e.g. a text-generation pipeline — confirm).
    """
    outputs = model_pipe(prompt)
    first_output = outputs[0]
    return first_output['generated_text']

def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
    """Render ``tokens`` as HTML spans whose background colour encodes ``values``.

    Args:
        tokens: Token strings to display, in order.
        values: One score per token. This is either a 1-D torch tensor, on any
            device and of any numeric dtype, or a sequence of numbers.
        concept: Concept name shown in the heading.
        cmap_name: Name of a matplotlib colormap.
        vmin: Lower bound of the colour scale; defaults to ``min(values)``.
        vmax: Upper bound of the colour scale; defaults to ``max(values)``.

    Returns:
        str: An HTML fragment. It is an ``<h3>`` heading followed by one
        ``<span>`` per token, each followed by a space. Each span's ``title``
        shows its value to 4 decimals. Token and concept text is HTML-escaped.

    Raises:
        ValueError: If ``tokens`` and ``values`` differ in length.
    """
    if len(tokens) != len(values):
        raise ValueError("The number of tokens and values must be the same.")

    # Escape model/user text. Tokens such as "<bos>" or "<start_of_turn>"
    # would otherwise be parsed as HTML tags and vanish from the output.
    html_output = f"<h3>How much information about the concept '{html.escape(str(concept))}' is carried in each token:</h3>"
    if len(tokens) == 0:
        # min()/max() of an empty array would raise, and there is nothing to colour.
        return html_output

    # Move to CPU before converting: .numpy() fails on CUDA tensors and on
    # bfloat16. Upcasting to float64 preserves the exact values.
    scores = torch.as_tensor(values).detach().cpu().double().numpy()

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # the supported lookup by name.
    cmap = plt.get_cmap(cmap_name)
    norm = clrs.Normalize(
        vmin=vmin if vmin is not None else float(scores.min()),
        vmax=vmax if vmax is not None else float(scores.max()),
    )

    spans = []
    for token, value in zip(tokens, scores):
        red, green, blue, _alpha = cmap(norm(value))
        hex_color = '#%02x%02x%02x' % (int(red * 255), int(green * 255), int(blue * 255))
        spans.append(
            f'<span style="background-color: {hex_color};" title="{value:.4f}">{html.escape(str(token))}</span> '
        )

    return html_output + "".join(spans)

def get_concepts_dictionary(dictionary_url, timeout=60):
    """Download a JSON Lines concept dictionary and map concept IDs to names.

    Each non-blank line of the response is parsed as a JSON object. Objects
    with a non-null ``"concept_id"`` and a non-empty ``"concept"`` contribute
    one entry.

    Args:
        dictionary_url: URL of the JSON Lines file.
        timeout: Seconds passed to ``requests.get`` as its timeout. This keeps
            a stalled server from hanging the caller indefinitely.

    Returns:
        dict: ``{concept_id: concept.capitalize()}``. A later line with the
        same ID overwrites an earlier one.

    Raises:
        requests.HTTPError: On a non-2xx response.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_dict = {}
    # The context manager releases the connection even if parsing fails
    # midway. With stream=True it otherwise stays checked out until the body
    # is fully consumed.
    with requests.get(dictionary_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Iterate raw bytes. json.loads detects UTF-8/16/32 itself, whereas
        # decode_unicode would apply requests' guessed encoding (ISO-8859-1
        # for text/* without a charset) and garble non-ASCII names.
        for line in response.iter_lines():
            if not line.strip():
                continue
            obj = json.loads(line)
            concept_id = obj.get("concept_id")
            concept = obj.get("concept")
            # `is not None` keeps a legitimate ID of 0, which a plain
            # truthiness test would silently drop.
            if concept_id is not None and concept:
                data_dict[concept_id] = concept.capitalize()
    return data_dict

def select_concepts(all_concepts, desired_concept):
    """Find concepts whose name contains ``desired_concept``, ignoring case.

    Args:
        all_concepts: Mapping of concept ID to concept name.
        desired_concept: Substring to search for. An empty string matches
            every concept.

    Returns:
        tuple: ``(ids, concept_df)``.

        * ``ids`` is a tensor of the matching IDs, in the mapping's iteration
          order.
        * ``concept_df`` is a DataFrame with columns ``"Concept ID"`` and
          ``"Concept Name"``.

        With no matches, ``ids`` is an empty int64 tensor and ``concept_df``
        is empty but still has both columns.
    """
    needle = desired_concept.lower()  # hoisted: loop-invariant
    matches = [(cid, name) for cid, name in all_concepts.items() if needle in name.lower()]

    concept_df = pd.DataFrame(
        [{"Concept ID": cid, "Concept Name": name} for cid, name in matches],
        # Explicit columns so an empty result still has the expected schema.
        columns=["Concept ID", "Concept Name"],
    )

    if not matches:
        # torch.tensor([]) would be float32, which cannot be used as an index tensor.
        return torch.empty(0, dtype=torch.long), concept_df
    return torch.tensor([cid for cid, _ in matches]), concept_df