import gradio as gr
import json
from typing import Any, List, Dict, Union
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
import os

# Get the currently available device
device = "cuda" if torch.cuda.is_available() else "cpu"


# Similarity model configuration
class Config:
    """Configuration settings for the application."""
    EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
    QUERY_PROMPT_NAME = "query"
    TOOL_PROMPT_NAME = "document"
    TOP_K = 3
    HF_TOKEN = os.getenv('HF_TOKEN')
    DEVICE = device
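
# Usage sketch (assumption: the embedding model may be gated, in which case a
# Hugging Face token is expected in the environment before launching):
#   HF_TOKEN=hf_xxx python app.py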


# Encapsulated similarity model
class SimilarityModel:
    """
    Finds the tools most similar to a given query using Sentence Transformer embeddings.
    """
    def __init__(self, config: Config):
        self.config = config
        self._login_to_hf()
        self.model = self._load_model()
        self.tool_embeddings_cache = {}

    def _login_to_hf(self):
        """Logs into Hugging Face Hub if a token is provided."""
        if self.config.HF_TOKEN:
            print("Logging into Hugging Face Hub...")
            login(token=self.config.HF_TOKEN)
        else:
            print("HF_TOKEN not found. Proceeding without login.")
            print("Note: This may fail if the model is gated.")

    def _load_model(self) -> SentenceTransformer:
        """Loads the Sentence Transformer model."""
        print(f"Initializing embedding model: {self.config.EMBEDDING_MODEL_ID}...")
        try:
            return SentenceTransformer(self.config.EMBEDDING_MODEL_ID).to(self.config.DEVICE)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    def _validate_query_tools(self, query: Union[str, Any], tools_list: Union[List[Dict], Any]) -> Union[str, List[Dict]]:
        """
        Validates the query and the tools data.

        Args:
            query: The user query string.
            tools_list: A list of dicts (or a JSON string encoding one), where each
                dict represents a tool declaration.

        Returns:
            The tools data as a list of dicts if both inputs are valid;
            otherwise an error message string describing the problem.
        """
        is_valid_query = isinstance(query, str) and len(query.strip()) > 0
        if not is_valid_query:
            return "Invalid query. It should be a non-empty string."
        # tools_list may already be a list of dicts.
        is_already_valid_tools = isinstance(tools_list, list) and all(isinstance(d, dict) for d in tools_list)
        if is_already_valid_tools:
            return tools_list
        # Otherwise tools_list should be a JSON string encoding a list of dicts.
        try:
            tools_data = json.loads(tools_list)
        except (json.JSONDecodeError, TypeError):
            return "Invalid JSON format for tools data."
        is_valid_tools = isinstance(tools_data, list) and all(isinstance(d, dict) for d in tools_data)
        if not is_valid_tools:
            return "Invalid tools data. It should be a list of dictionaries."
        return tools_data
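
    # Illustrative behavior (hypothetical values):
    #   _validate_query_tools("weather?", '[{"name": "t", "description": "d"}]')
    #       -> [{"name": "t", "description": "d"}]
    #   _validate_query_tools("", "[]")
    #       -> "Invalid query. It should be a non-empty string."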
    def cache_tool_embeddings(self, tools_data: List[Dict], tools_cache_key: str, cache_tool: bool = True) -> torch.Tensor:
        """
        Returns cached tool embeddings if available; otherwise computes them
        and, optionally, stores them in the cache.

        Args:
            tools_data: A list of dicts, where each dict represents a tool declaration.
            tools_cache_key: Unique cache key derived from the tools data.
            cache_tool: Whether to cache the computed tool embeddings.

        Returns:
            The embeddings of the tool descriptions.
        """
        if tools_cache_key in self.tool_embeddings_cache:
            tool_description_embeddings = self.tool_embeddings_cache[tools_cache_key]
        else:
            tool_descriptions = [tool["description"] for tool in tools_data]
            tool_description_embeddings = self.model.encode(tool_descriptions, normalize_embeddings=True, prompt_name=self.config.TOOL_PROMPT_NAME)
            if cache_tool:
                self.tool_embeddings_cache[tools_cache_key] = tool_description_embeddings
        return tool_description_embeddings
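
    # Note: the cache is keyed by the full JSON dump of the tools list and grows
    # without bound; for long-running deployments it can be cleared manually, e.g.
    #   similarity_model.tool_embeddings_cache.clear()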
    def find_similar_tools(self, query: str, tools_list: List[Dict], top_k: int, cache_tool_embs: bool = True):
        """
        Finds the top_k tools most similar to a given query using Sentence Transformer embeddings.

        Args:
            query: The user query string.
            tools_list: A list of dicts (or a JSON string encoding one), where each
                dict represents a tool declaration.
            top_k: The number of top similar tools to return.
            cache_tool_embs: Whether to cache the tool embeddings.

        Returns:
            A tuple of (formatted string with the names and descriptions of the
            top_k similar tools, list of the matching tool dicts).
        """
        # Validate query and tools_list
        tools_data = self._validate_query_tools(query, tools_list)
        if not (isinstance(tools_data, list) and all(isinstance(d, dict) for d in tools_data)):
            # Validation failed: tools_data is an error message string.
            return tools_data, []
        # Create a unique key for caching based on the tools data
        tools_cache_key = json.dumps(tools_data, sort_keys=True)
        # Compute tool embeddings or fetch them from the cache
        tool_description_embeddings = self.cache_tool_embeddings(tools_data, tools_cache_key, cache_tool=cache_tool_embs)
        # The query embedding is computed every time, since each user query is new
        query_embedding = self.model.encode(query, normalize_embeddings=True, prompt_name=self.config.QUERY_PROMPT_NAME)
        # Similarity scores between the user query and the tool embeddings
        similarity_scores = self.model.similarity(query_embedding, tool_description_embeddings).cpu()
        # Ensure top_k does not exceed the number of available tools
        actual_top_k = min(top_k or self.config.TOP_K, len(tools_data))
        top_tool_indices = similarity_scores.argsort().flatten()[-actual_top_k:]
        # Reverse the indices so the most similar tool comes first
        top_tool_indices = top_tool_indices.tolist()[::-1]
        top_tools = [tools_data[int(i)] for i in top_tool_indices]
        # Format the output for the Gradio Textbox
        output_text = f"Top {actual_top_k} most similar tools:\n\n"
        for i, tool in enumerate(top_tools):
            output_text += f"{i + 1}. Name: {tool['name']}\n"
            output_text += f"   Description: {tool['description']}\n"
            if i < len(top_tools) - 1:
                output_text += "---\n"  # Separator between tools
        if not top_tools:
            output_text = "No tools found."
        return output_text, top_tools
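
# Usage sketch for SimilarityModel.find_similar_tools (hypothetical tool list):
#   model = SimilarityModel(Config())
#   text, tools = model.find_similar_tools(
#       "What's the weather like in Paris?",
#       [{"name": "get_current_weather",
#         "description": "Get the current weather in a given location"}],
#       top_k=1,
#   )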


def create_ui(model: SimilarityModel):
    """Builds the Gradio UI for users to interact with the model."""
    with gr.Blocks() as demo:
        gr.Interface(
            fn=model.find_similar_tools,
            inputs=[
                gr.Textbox(label="Query"),
                gr.Textbox(
                    lines=10,
                    label="Define tool declarations here",
                    info="Enter a valid JSON string: a list of dicts, each with 'name' and 'description' keys.",
                    placeholder='''[
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location"
    }
]'''),
                gr.Number(label="Top K", value=3, precision=0),
                gr.Checkbox(label="Cache Tool Embeddings", value=True)
            ],
            outputs=[
                gr.TextArea(label="Similar Tools (Name and Description)", lines=5),
                gr.JSON(label="Similar Tools (JSON format)")
            ],
            title="Tool Similarity Finder using Embedding Gemma 300M",
            description="Enter a query and a list of tools to find the most similar tools based on embeddings."
        )
    return demo
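
# Note: mcp_server=True exposes the app's functions (here find_similar_tools)
# as MCP tools; this assumes the `gradio[mcp]` extra is installed in the
# deployment environment.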
if __name__ == "__main__":
    similarity_model = SimilarityModel(config=Config())
    demo = create_ui(similarity_model)
    demo.launch(
        mcp_server=True
    )