SURIAPRAKASH1 committed on
Commit
4318cac
·
1 Parent(s): 42ecf43

model and gradio ui implemented

Browse files
Files changed (1) hide show
  1. app.py +185 -20
app.py CHANGED
@@ -1,24 +1,189 @@
1
- import gradio as gr
2
- from typing import List, Dict
3
  import json
 
 
 
 
 
4
 
5
- def query_tool_similarity(query: str, tools_json: str, top_k: int):
6
- """Calculates similarity between query and each tool.
7
- Returns list of int as index for each tool in based on similarity."""
8
- try:
9
- tools = json.loads(tools_json)
10
- assert isinstance(tools, list) and all(isinstance(d, dict) for d in tools)
11
- except Exception as e:
12
- return f"Invalid JSON for tools: {e}"
13
- return f"{query}, {tools}, {top_k}"
14
-
15
- with gr.Blocks() as demo:
16
- gr.Interface(fn = query_tool_similarity,
17
- inputs= [gr.Textbox(label="Query"),
18
- gr.Textbox(label="Tools (List of Dicts as JSON)", lines=6, placeholder='e.g. [{"name": "foo", "desc": "bar"}]'),
19
- gr.Number(label="Top K", precision=0)],
20
- outputs= gr.Textbox(label="Result")
21
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  if __name__ == "__main__":
24
- demo.launch()
 
 
 
 
 
1
+ import gradio as gr
 
2
  import json
3
+ from typing import Any, List, Dict, Union
4
+ import torch
5
+ from sentence_transformers import SentenceTransformer
6
+ from huggingface_hub import login
7
+ import os
8
 
9
# Pick the compute device: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
11
+
12
# Settings consumed by SimilarityModel
class Config:
    """Static settings for the app: model id, prompt names, defaults and runtime environment."""

    # Model id handed to SentenceTransformer when the model is loaded.
    EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"

    # prompt_name values passed to model.encode() for the query and for tool descriptions.
    QUERY_PROMPT_NAME = "query"
    TOOL_PROMPT_NAME = "document"

    # Number of tools returned when the caller does not supply a usable top_k.
    TOP_K = 3

    # Hugging Face access token from the environment; None when HF_TOKEN is unset.
    HF_TOKEN = os.environ.get("HF_TOKEN")

    # Device string ("cuda" or "cpu") resolved once at import time.
    DEVICE = device
21
+
22
+ # Encapsulated Similarity Model
23
class SimilarityModel:
    """
    Finds the tools most similar to a query using Sentence Transformer embeddings.

    The model named by ``config.EMBEDDING_MODEL_ID`` is loaded once in ``__init__``.
    Embeddings of tool descriptions are kept in ``tool_embeddings_cache``, keyed by
    the JSON dump of the tool list, so repeated queries against the same tools only
    need to embed the query.
    """

    # Upper bound on cached tool lists. The cache key is arbitrary user input, so
    # without a bound a long-running app would grow the cache indefinitely.
    MAX_CACHED_TOOLSETS = 64

    def __init__(self, config: "Config"):
        """Store the config, log in to the Hub when a token is set, and load the model."""
        self.config = config
        self._login_to_hf()
        self.model = self._load_model()
        # Maps json.dumps(tools, sort_keys=True) -> embeddings of the tool descriptions.
        self.tool_embeddings_cache = {}

    def _login_to_hf(self):
        """Logs into Hugging Face Hub if a token is provided."""
        if self.config.HF_TOKEN:
            print("Logging into Hugging Face Hub...")
            login(token=self.config.HF_TOKEN)
        else:
            print("HF_TOKEN not found. Proceeding without login.")
            print("Note: This may fail if the model is gated.")

    def _load_model(self) -> "SentenceTransformer":
        """Loads the Sentence Transformer model and moves it to the configured device.

        Raises:
            Exception: Whatever the model constructor raises; it is printed and re-raised.
        """
        print(f"Initializing embedding model: {self.config.EMBEDDING_MODEL_ID}...")
        try:
            return SentenceTransformer(self.config.EMBEDDING_MODEL_ID).to(self.config.DEVICE)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _validate_query_tools(self, query: Union[str, Any], tools_list: Union[List[Dict], Any]) -> Union[str, List[Dict]]:
        """
        Validates the query and the tools data and normalises the tools to a list of dicts.

        Args:
            query: The user query string.
            tools_list: Either a list of dicts or a JSON string encoding one, where each
                dict represents a tool declaration with "name" and "description".
        Returns:
            The tools as a list of dicts when everything is valid, otherwise a string
            describing what is wrong with the query or the tools data.
        """
        if not isinstance(query, str) or not query.strip():
            return "Invalid query. It should be a non-empty string."

        tools_data = tools_list
        if not isinstance(tools_data, list):
            # Anything that is not already a list is expected to be JSON text.
            # json.loads raises TypeError for non-text input (None, numbers, ...)
            # and ValueError (JSONDecodeError) for malformed text.
            try:
                tools_data = json.loads(tools_list)
            except (TypeError, ValueError):
                return "Invalid JSON format for tools data."

        if not isinstance(tools_data, list) or not all(isinstance(tool, dict) for tool in tools_data):
            return "Invalid tools data. It should be a list of dictionaries."

        # "description" is what gets embedded and "name" is what gets displayed,
        # so reject tools lacking either instead of failing later with KeyError.
        for tool in tools_data:
            if "name" not in tool or not isinstance(tool.get("description"), str):
                return "Invalid tools data. Each tool needs a 'name' and a string 'description'."

        return tools_data

    def cache_tool_embeddings(self, tools_data: List[Dict], tools_cache_key: str, cache_tool: bool = True) -> Any:
        """
        Returns the cached tool embeddings, computing (and optionally caching) them on a miss.

        Args:
            tools_data: List of dicts, where each dict represents a tool declaration.
            tools_cache_key: Unique key for caching based on the tools data.
            cache_tool: Whether to cache the tools embeddings or not.
        Returns:
            The embeddings of the tool descriptions, as produced by ``self.model.encode``.
        """
        if tools_cache_key in self.tool_embeddings_cache:
            return self.tool_embeddings_cache[tools_cache_key]

        tool_descriptions = [tool["description"] for tool in tools_data]
        tool_description_embeddings = self.model.encode(
            tool_descriptions,
            normalize_embeddings=True,
            prompt_name=self.config.TOOL_PROMPT_NAME,
        )
        if cache_tool:
            if len(self.tool_embeddings_cache) >= self.MAX_CACHED_TOOLSETS:
                # dicts preserve insertion order, so the first key is the oldest entry.
                oldest_key = next(iter(self.tool_embeddings_cache))
                del self.tool_embeddings_cache[oldest_key]
            self.tool_embeddings_cache[tools_cache_key] = tool_description_embeddings

        return tool_description_embeddings

    @staticmethod
    def _resolve_top_k(top_k: Any, default: int, available: int) -> int:
        """Turns the requested top_k into a usable count between 0 and ``available``.

        Missing, non-numeric, zero or negative requests fall back to ``default``.
        """
        try:
            requested = int(top_k)
        except (TypeError, ValueError, OverflowError):
            requested = 0
        if requested <= 0:
            requested = default
        return max(0, min(requested, available))

    @staticmethod
    def _format_tools(top_tools: List[Dict]) -> str:
        """Formats the ranked tools as numbered text for the Gradio text output."""
        if not top_tools:
            return "No tools found."
        entries = [
            f"{rank}. Name: {tool['name']}\n"
            f" Description: {tool['description']}\n"
            for rank, tool in enumerate(top_tools, start=1)
        ]
        # Separator goes between tools only, never after the last one.
        return f"Top {len(top_tools)} most similar tools:\n\n" + "---\n".join(entries)

    def find_similar_tools(self, query: str, tools_list: List[Dict], top_k: int, cache_tool_embs: bool = True):
        """
        Finds the top_k most similar tools to a given query using Sentence Transformer embeddings.

        Args:
            query: The user query string.
            tools_list: Tool declarations, either a list of dicts or a JSON string encoding one.
                Each tool needs a "name" and a "description".
            top_k: The number of top similar tools to return. Missing, zero or negative
                values fall back to the configured default.
            cache_tool_embs: Whether to cache the embeddings computed for this tool list.

        Returns:
            A tuple ``(text, tools)``: ``text`` lists the names and descriptions of the most
            similar tools (or an error message when the input is invalid), and ``tools`` is
            the matching list of tool dicts, most similar first (empty on error).
        """
        # Validate: query and tools_list. A string result is an error message; it is
        # returned together with an empty list so both UI outputs always get a value.
        validated = self._validate_query_tools(query, tools_list)
        if isinstance(validated, str):
            return validated, []
        tools_data = validated

        # Nothing to rank: skip the model entirely.
        if not tools_data:
            return self._format_tools([]), []

        # Create a unique key for caching based on the tools data
        tools_cache_key = json.dumps(tools_data, sort_keys=True)

        # Compute tools embedding or get cached embeddings
        tool_description_embeddings = self.cache_tool_embeddings(tools_data, tools_cache_key, cache_tool=cache_tool_embs)

        # The query differs on every call, so its embedding is always computed fresh.
        query_embedding = self.model.encode(query, normalize_embeddings=True, prompt_name=self.config.QUERY_PROMPT_NAME)

        # Similarity scores between the user query and the tool embeddings
        similarity_scores = self.model.similarity(query_embedding, tool_description_embeddings).cpu()

        # Ensure top_k is positive and does not exceed the number of available tools
        actual_top_k = self._resolve_top_k(top_k, self.config.TOP_K, len(tools_data))

        # Highest score first, truncated to the requested count
        top_tool_indices = similarity_scores.flatten().argsort(descending=True)[:actual_top_k].tolist()
        top_tools = [tools_data[int(i)] for i in top_tool_indices]

        return self._format_tools(top_tools), top_tools
153
+
154
def create_ui(model: SimilarityModel):
    """Build the Gradio UI through which a user queries the similarity model.

    Args:
        model: The similarity model; its ``find_similar_tools`` method is the
            function behind the interface.

    Returns:
        The ``gr.Blocks`` container holding the interface. The caller launches it.
    """
    with gr.Blocks() as demo:
        gr.Interface(
            fn=model.find_similar_tools,
            # Inputs are passed positionally: query, tools_list, top_k, cache_tool_embs.
            inputs=[
                gr.Textbox(label="Query"),
                gr.Textbox(
                    lines=10,
                    label="Define tool declaration here",
                    # The model reads the "name" and "description" keys of each tool,
                    # so the hint names exactly those keys.
                    info='Please enter a valid JSON string, e.g. a list of dicts, each with "name" and "description" keys 👍.',
                    placeholder=(
                        '[\n'
                        '{\n'
                        '"name": "get_current_weather",\n'
                        '"description": "Get the current weather in a given location"\n'
                        '}\n'
                        ']'
                    ),
                ),
                gr.Number(label="Top K", value=3, precision=0),
                gr.Checkbox(label="Cache Tool Embeddings", value=True),
            ],
            # Two outputs: the formatted text and the matching tool dicts.
            outputs=[
                gr.TextArea(label="Similar Tools (Name and Description)", lines=5),
                gr.JSON(label="Similar Tools JSON-format"),
            ],
            title="Tool Similarity Finder using Embedding Gemma 300M",
            description="Enter a query and a list of tools to find the most similar tools based on embeddings.",
        )
    return demo
183
 
if __name__ == "__main__":
    # Load the embedding model once, then hand it to the UI builder.
    similarity_model = SimilarityModel(config=Config())
    demo = create_ui(similarity_model)
    # Launch the app with Gradio's mcp_server option switched on.
    demo.launch(mcp_server=True)