File size: 7,971 Bytes
4318cac
42ecf43
4318cac
 
 
 
 
42ecf43
4318cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42ecf43
 
4318cac
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
import json
from typing import Any, List, Dict, Union
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
import os

# Choose the compute device once, at import time: "cuda" when PyTorch reports
# CUDA as available, otherwise fall back to "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Application-wide settings read by SimilarityModel.
class Config:
    """Static, class-level configuration for the tool-similarity app."""

    # Hugging Face Hub id of the embedding model that gets loaded.
    EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"

    # Prompt names forwarded as `prompt_name=` when encoding queries and tool descriptions.
    QUERY_PROMPT_NAME = "query"
    TOOL_PROMPT_NAME = "document"

    # Number of tools returned when the caller does not supply a usable top_k.
    TOP_K = 3

    # Read once when the class body executes; None when the variable is unset.
    HF_TOKEN = os.environ.get("HF_TOKEN")

    # Compute device chosen at module import ("cuda" or "cpu").
    DEVICE = device

# Encapsulated Similarity Model
class SimilarityModel:
    """
    Find the tools most similar to a user query using Sentence Transformer embeddings.

    Tool descriptions are encoded with the configured tool prompt and the query
    with the configured query prompt. Tools are then ranked by the model's
    similarity score. Description embeddings can be cached per distinct tools
    list, so repeated queries against the same tools skip re-encoding.
    """

    # Upper bound on how many distinct tools lists keep cached embeddings. The
    # cache key comes from user-supplied JSON, so without a bound the cache
    # would grow for as long as the app runs. The oldest entry is evicted first.
    _MAX_CACHED_TOOLSETS = 128

    def __init__(self, config: "Config"):
        self.config = config
        self._login_to_hf()
        self.model = self._load_model()
        # Canonical JSON dump of a tools list -> embeddings of its descriptions.
        self.tool_embeddings_cache = {}

    def _login_to_hf(self):
        """Log into the Hugging Face Hub when ``config.HF_TOKEN`` is set."""
        if self.config.HF_TOKEN:
            print("Logging into Hugging Face Hub...")
            login(token=self.config.HF_TOKEN)
        else:
            print("HF_TOKEN not found. Proceeding without login.")
            print("Note: This may fail if the model is gated.")

    def _load_model(self) -> "SentenceTransformer":
        """
        Load the configured Sentence Transformer and move it to ``config.DEVICE``.

        Raises:
            Exception: Whatever model construction or ``.to()`` raises; it is
                printed first and then re-raised unchanged.
        """
        print(f"Initializing embedding model: {self.config.EMBEDDING_MODEL_ID}...")
        try:
            return SentenceTransformer(self.config.EMBEDDING_MODEL_ID).to(self.config.DEVICE)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _validate_query_tools(self, query: Union[str, Any], tools_list: Union[str, List[Dict], Any]) -> Union[str, List[Dict]]:
        """
        Validate the query and normalise the tools data to a list of dicts.

        Args:
            query: The user query; must be a non-empty (non-whitespace) string.
            tools_list: Either an already-parsed list of dicts or a JSON string
                encoding one. Each dict is a tool declaration that needs a
                "name" key and a string "description".

        Returns:
            The tools as a list of dicts on success. On failure, a string with a
            user-facing error message.
        """
        if not (isinstance(query, str) and query.strip()):
            return "Invalid query. It should be a non-empty string."

        if isinstance(tools_list, list):
            # Already parsed (e.g. a direct Python caller); validated below.
            tools_data = tools_list
        else:
            # Usually the raw Textbox string. TypeError covers non-str inputs
            # such as None, which json.loads cannot parse at all.
            try:
                tools_data = json.loads(tools_list)
            except (json.JSONDecodeError, TypeError):
                return "Invalid JSON format for tools data."

        if not (isinstance(tools_data, list) and all(isinstance(d, dict) for d in tools_data)):
            return "Invalid tools data. It should be a list of dictionaries."

        # The ranking encodes "description" and the output prints "name", so
        # reject declarations missing them instead of crashing later with KeyError.
        for index, tool in enumerate(tools_data):
            if "name" not in tool or not isinstance(tool.get("description"), str):
                return (
                    f"Invalid tools data. Tool at index {index} needs a 'name' "
                    "and a string 'description'."
                )

        return tools_data

    def cache_tool_embeddings(self, tools_data: List[Dict], tools_cache_key: str, cache_tool: bool = True) -> Any:
        """
        Return embeddings of the tools' descriptions, reusing a cached copy when present.

        Args:
            tools_data: List of tool declarations; each must have a "description".
            tools_cache_key: Unique key for the tools list (a canonical JSON dump).
            cache_tool: Whether to store freshly computed embeddings in the cache.

        Returns:
            Whatever ``self.model.encode`` returns for the descriptions (a NumPy
            array with sentence-transformers' default settings).
        """
        if tools_cache_key in self.tool_embeddings_cache:
            return self.tool_embeddings_cache[tools_cache_key]

        tool_descriptions = [tool["description"] for tool in tools_data]
        tool_description_embeddings = self.model.encode(
            tool_descriptions,
            normalize_embeddings=True,
            prompt_name=self.config.TOOL_PROMPT_NAME,
        )
        if cache_tool:
            # Dicts keep insertion order, so the first key is the oldest entry.
            if self.tool_embeddings_cache and len(self.tool_embeddings_cache) >= self._MAX_CACHED_TOOLSETS:
                del self.tool_embeddings_cache[next(iter(self.tool_embeddings_cache))]
            self.tool_embeddings_cache[tools_cache_key] = tool_description_embeddings

        return tool_description_embeddings

    def find_similar_tools(self, query: str, tools_list: Union[str, List[Dict]], top_k: int, cache_tool_embs: bool = True):
        """
        Finds the top_k most similar tools to a given query using Sentence Transformer embeddings.

        Args:
            query: The user query string.
            tools_list: JSON string (or already-parsed list) of dicts, where each dict is a tool declaration with "name" and "description" keys.
            top_k: The number of top similar tools to return. Empty, zero or negative values fall back to the configured default.
            cache_tool_embs: Whether to cache the tool-description embeddings for reuse across queries.

        Returns:
            A tuple (text, tools): a readable summary of the top_k most similar tools, and the matching tool dicts ordered from most to least similar. On invalid input the text is an error message and the list is empty.
        """
        tools_data = self._validate_query_tools(query, tools_list)
        if isinstance(tools_data, str):
            # Validation failed and the string is the error message. Use the same
            # (text, tools) shape as success so both UI outputs receive a value.
            return tools_data, []

        if not tools_data:
            return "No tools found.", []

        # Create a unique key for caching based on the tools data
        tools_cache_key = json.dumps(tools_data, sort_keys=True)

        # Compute tools embedding or get cached embeddings
        tool_description_embeddings = self.cache_tool_embeddings(tools_data, tools_cache_key, cache_tool=cache_tool_embs)

        # The query is always re-encoded because it changes with every call.
        query_embedding = self.model.encode(query, normalize_embeddings=True, prompt_name=self.config.QUERY_PROMPT_NAME)

        # One similarity score per tool.
        similarity_scores = self.model.similarity(query_embedding, tool_description_embeddings).cpu().flatten()

        # Fall back to the default for missing or non-positive top_k. int() also
        # accepts float inputs. Never ask for more tools than exist.
        requested_top_k = int(top_k) if top_k else 0
        if requested_top_k <= 0:
            requested_top_k = self.config.TOP_K
        actual_top_k = min(requested_top_k, len(tools_data))

        # topk yields indices ordered from most to least similar.
        top_tool_indices = similarity_scores.topk(actual_top_k).indices.tolist()
        top_tools = [tools_data[int(i)] for i in top_tool_indices]

        # Format the output for the Gradio Textbox, with "---" between tools.
        entries = [
            f"{rank}. Name: {tool['name']}\n   Description: {tool['description']}\n"
            for rank, tool in enumerate(top_tools, start=1)
        ]
        output_text = f"Top {actual_top_k} most similar tools:\n\n" + "---\n".join(entries)

        return output_text, top_tools

def create_ui(model: SimilarityModel):
    """
    Build the Gradio app around ``model.find_similar_tools``.

    The interface takes a query, a JSON list of tool declarations, a Top K value
    and a cache toggle. It shows a readable summary and the JSON of the matching tools.

    Args:
        model: A loaded SimilarityModel whose ``find_similar_tools`` backs the UI.

    Returns:
        The ``gr.Blocks`` app that wraps the interface.
    """
    with gr.Blocks() as demo:
        gr.Interface(
            fn=model.find_similar_tools,
            inputs=[
                gr.Textbox(label="Query"),
                gr.Textbox(
                    lines=10,
                    label="Define tool declaration here",
                    info="Please enter a valid JSON string. For e.g, a list of dict's (name & desc 👍).",
                    placeholder='''[
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location"
        }
    ]'''),
                # Take the default from the model's config so it cannot drift from
                # the fallback find_similar_tools uses.
                gr.Number(label="Top K", value=model.config.TOP_K, precision=0),
                gr.Checkbox(label="Cache Tool Embeddings", value=True),
            ],
            outputs=[
                gr.TextArea(label="Similar Tools (Name and Description)", lines=5),
                gr.JSON(label="Similar Tools JSON-format"),
            ],
            title="Tool Similarity Finder using Embedding Gemma 300M",
            description="Enter a query and a list of tools to find the most similar tools based on embeddings.",
        )
    return demo

if __name__ == "__main__":
    # Load the embedding model once at startup and build the UI around it.
    # Launch with mcp_server=True so Gradio also exposes the app as an MCP server.
    similarity_model = SimilarityModel(config=Config())
    demo = create_ui(similarity_model)
    demo.launch(mcp_server=True)