import json
import os
from typing import Any, Dict, List, Union

import gradio as gr
import torch
from huggingface_hub import login
from sentence_transformers import SentenceTransformer

# Pick the compute device once at import time: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# SimilarityModel configuration
class Config:
    """Configuration settings for the application."""

    EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
    # Prompt names passed to SentenceTransformer.encode(prompt_name=...) for
    # user queries and for tool descriptions respectively.
    QUERY_PROMPT_NAME = "query"
    TOOL_PROMPT_NAME = "document"
    # Number of results used when the caller gives no usable (positive) top_k;
    # also the default shown in the UI.
    TOP_K = 3
    # Upper bound on distinct tool lists whose embeddings are kept in memory;
    # the oldest entry is evicted first. None or <= 0 disables the bound.
    MAX_CACHED_TOOL_SETS = 128
    # Optional Hugging Face token from the environment (None when unset).
    HF_TOKEN = os.getenv('HF_TOKEN')
    DEVICE = device


# Encapsulated Similarity Model
class SimilarityModel:
    """
    Finds the tools most similar to a query using Sentence Transformer embeddings.

    Embeddings of tool descriptions can be cached per distinct tool list, so
    repeated queries against the same tools only need to embed the query.
    """

    # Keys every tool declaration must provide: "description" is embedded and
    # both keys are printed in the formatted result.
    REQUIRED_TOOL_KEYS = ("name", "description")

    def __init__(self, config: Config):
        self.config = config
        self._login_to_hf()
        self.model = self._load_model()
        # Maps a canonical JSON dump of a tool list to its description embeddings.
        self.tool_embeddings_cache: Dict[str, Any] = {}

    def _login_to_hf(self) -> None:
        """Logs into Hugging Face Hub if a token is provided."""
        if self.config.HF_TOKEN:
            print("Logging into Hugging Face Hub...")
            login(token=self.config.HF_TOKEN)
        else:
            print("HF_TOKEN not found. Proceeding without login.")
            print("Note: This may fail if the model is gated.")

    def _load_model(self) -> SentenceTransformer:
        """Loads the Sentence Transformer model onto the configured device.

        Raises:
            Exception: Re-raises whatever the model constructor raised, after
                printing it.
        """
        print(f"Initializing embedding model: {self.config.EMBEDDING_MODEL_ID}...")
        try:
            return SentenceTransformer(self.config.EMBEDDING_MODEL_ID).to(self.config.DEVICE)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _validate_query_tools(self, query: Union[str, Any], tools_list: Union[List[Dict], Any]) -> Union[str, List[Dict]]:
        """
        Validates the query and normalises the tools input into a list of dicts.

        Args:
            query: The user query string.
            tools_list: Either an already-parsed list of dicts, or a JSON
                string/bytes encoding one. Each dict is a tool declaration with
                a "name" and a string "description".

        Returns:
            The tools as a list of dicts when both inputs are valid; otherwise
            an error string describing the first problem found.
        """
        is_valid_query = isinstance(query, str) and len(query.strip()) > 0
        if not is_valid_query:
            return "Invalid query. It should be a non-empty string."

        if isinstance(tools_list, (str, bytes, bytearray)):
            # Text input (e.g. from the UI textbox) must be parsed first.
            try:
                tools_data = json.loads(tools_list)
            except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
                return "Invalid JSON format for tools data."
        else:
            # Already-parsed input; anything that is not a list of dicts is
            # rejected below instead of crashing inside json.loads.
            tools_data = tools_list

        is_valid_tools = isinstance(tools_data, list) and all(isinstance(d, dict) for d in tools_data)
        if not is_valid_tools:
            return "Invalid tools data. It should be a list of dictionaries."

        # Reject incomplete declarations here rather than hitting a KeyError
        # while embedding or formatting the results.
        for index, tool in enumerate(tools_data):
            missing = [key for key in self.REQUIRED_TOOL_KEYS if key not in tool]
            if missing:
                return f"Invalid tool at index {index}: missing key(s) {', '.join(missing)}."
            if not isinstance(tool["description"], str):
                return f"Invalid tool at index {index}: 'description' must be a string."

        return tools_data

    def cache_tool_embeddings(self, tools_data: List[Dict], tools_cache_key: str, cache_tool: bool = True) -> Any:
        """
        Returns the tool-description embeddings, computing them on a cache miss.

        Args:
            tools_data: List of tool declarations; each dict must have a "description".
            tools_cache_key: Unique key identifying ``tools_data`` in the cache.
            cache_tool: Whether to store freshly computed embeddings in the cache.

        Returns:
            Normalised embeddings as returned by ``SentenceTransformer.encode``,
            one row per tool in ``tools_data`` order.
        """
        if tools_cache_key in self.tool_embeddings_cache:
            return self.tool_embeddings_cache[tools_cache_key]

        tool_descriptions = [tool["description"] for tool in tools_data]
        tool_description_embeddings = self.model.encode(
            tool_descriptions,
            normalize_embeddings=True,
            prompt_name=self.config.TOOL_PROMPT_NAME,
        )

        if cache_tool:
            limit = getattr(self.config, "MAX_CACHED_TOOL_SETS", None)
            if limit is not None and limit > 0:
                # Dicts keep insertion order, so the first key is the oldest
                # entry; evict until there is room for the new one.
                while len(self.tool_embeddings_cache) >= limit:
                    self.tool_embeddings_cache.pop(next(iter(self.tool_embeddings_cache)))
            self.tool_embeddings_cache[tools_cache_key] = tool_description_embeddings

        return tool_description_embeddings

    def find_similar_tools(self, query: str, tools_list: Union[str, List[Dict]], top_k: int, cache_tool_embs: bool = True) -> tuple[str, str]:
        """
        Finds the top_k most similar tools to a given query using Sentence Transformer embeddings.

        Args:
            query: The user query string.
            tools_list: JSON string (or list of dicts) where each dict is a tool declaration with "name" and "description".
            top_k: The number of top similar tools to return. Empty, zero or negative values fall back to the configured default.
            cache_tool_embs: Whether to cache the tool embeddings. Default is True.

        Returns:
            A tuple of (human-readable text listing the names and descriptions of the most similar tools, JSON string of those tool dicts, most similar first). On invalid input the text is the error message and the JSON is [{"Error": message}].
        """
        tools_data = self._validate_query_tools(query, tools_list)
        if isinstance(tools_data, str):
            # Validation failed: surface the message in both outputs.
            return tools_data, json.dumps([{"Error": tools_data}])

        if not tools_data:
            return "No tools found.", json.dumps([])

        # Non-positive or missing top_k would otherwise produce a bogus slice;
        # use the configured default instead, then cap at the number of tools.
        requested_k = int(top_k) if top_k and top_k > 0 else self.config.TOP_K
        actual_top_k = min(requested_k, len(tools_data))

        # Canonical dump of the tools data so equal tool lists share a cache entry.
        tools_cache_key = json.dumps(tools_data, sort_keys=True)
        tool_description_embeddings = self.cache_tool_embeddings(tools_data, tools_cache_key, cache_tool=cache_tool_embs)

        # The query changes on every call, so its embedding is never cached.
        query_embedding = self.model.encode(query, normalize_embeddings=True, prompt_name=self.config.QUERY_PROMPT_NAME)

        # One score per tool, flattened to 1-D for ranking.
        similarity_scores = self.model.similarity(query_embedding, tool_description_embeddings).cpu().flatten()

        # topk returns indices ordered from most to least similar.
        top_tool_indices = torch.topk(similarity_scores, k=actual_top_k).indices.tolist()
        top_tools = [tools_data[i] for i in top_tool_indices]

        # Format the output for the Gradio Textbox, separating tools with "---".
        entries = [
            f"{rank}. Name: {tool['name']}\n Description: {tool['description']}\n"
            for rank, tool in enumerate(top_tools, start=1)
        ]
        output_text = f"Top {actual_top_k} most similar tools:\n\n" + "---\n".join(entries)

        return output_text, json.dumps(top_tools)


def create_ui(model: SimilarityModel) -> gr.Blocks:
    """Builds the Gradio UI that exposes ``model.find_similar_tools``."""
    with gr.Blocks() as demo:
        gr.Interface(
            fn=model.find_similar_tools,
            inputs=[
                gr.Textbox(label="Query"),
                gr.Textbox(
                    lines=6,
                    label="Define tool declaration here",
                    info="Please enter a valid JSON string. For e.g, a list of dict's (name & desc 👍).",
                    placeholder='''[ { "name": "get_current_weather", "description": "Get the current weather in a given location" } ]'''),
                # Default mirrors the model's configured fallback top_k.
                gr.Number(label="Top K", value=model.config.TOP_K, precision=0),
                gr.Checkbox(label="Cache Tool Embeddings", value=True),
            ],
            outputs=[
                gr.TextArea(label="Similar Tools (Name and Description)", lines=5),
                gr.JSON(label="Similar Tools JSON-format"),
            ],
            title="Tool Similarity Finder using Embedding Gemma 300M",
            description="Enter a query and a list of tools to find the most similar tools based on embeddings.",
        )
    return demo


if __name__ == "__main__":
    similarity_model = SimilarityModel(config=Config())
    demo = create_ui(similarity_model)
    demo.launch(mcp_server=True)