from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum
from hashlib import md5
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Union

import numpy as np
import requests
import torch
from cachetools import LRUCache
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoProcessor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TextModelType(str, Enum):
    """
    Enumeration of supported text models.
    """

    MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
    MULTILINGUAL_E5_BASE = "multilingual-e5-base"
    MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
    BGE_M3 = "bge-m3"
    GTE_MULTILINGUAL_BASE = "gte-multilingual-base"


class ImageModelType(str, Enum):
    """
    Enumeration of supported image models.
    """

    SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"


class ModelInfo(NamedTuple):
    """
    Metadata for a single model:
    - model_id: Hugging Face model ID (or local path)
    - onnx_file: path to an ONNX file inside the model repository (if available)
    """

    model_id: str
    onnx_file: Optional[str] = None


@dataclass
class ModelConfig:
    """
    Configuration for text and image models.
    """

    text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
    image_model_type: ImageModelType = (
        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
    )
    # exp(4.60517) ~= 100, the CLIP-style logit scale used in cross-modal similarity.
    logit_scale: float = 4.60517

    @property
    def text_model_info(self) -> ModelInfo:
        """
        Returns ModelInfo for the configured text_model_type.
        """
        text_configs = {
            TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
                model_id="Xenova/multilingual-e5-small",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
                model_id="Xenova/multilingual-e5-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
                model_id="Xenova/multilingual-e5-large",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
                model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.BGE_M3: ModelInfo(
                model_id="Xenova/bge-m3",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
                model_id="onnx-community/gte-multilingual-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
        }
        return text_configs[self.text_model_type]

    @property
    def image_model_info(self) -> ModelInfo:
        """
        Returns ModelInfo for the configured image_model_type.
        """
        image_configs = {
            ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
                model_id="google/siglip-base-patch16-256-multilingual"
            ),
        }
        return image_configs[self.image_model_type]
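

# Illustrative: resolving model metadata from a config (values match the
# tables above; `text_model_info` is a property, so no call parentheses):
#   cfg = ModelConfig(text_model_type=TextModelType.BGE_M3)
#   cfg.text_model_info.model_id   # -> "Xenova/bge-m3"
#   cfg.text_model_info.onnx_file  # -> "onnx/model_quantized.onnx"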


class ModelKind(str, Enum):
    TEXT = "text"
    IMAGE = "image"


def detect_model_kind(model_id: str) -> ModelKind:
    """
    Detect whether model_id belongs to a text or an image model.
    Raises ValueError if the model is not recognized.
    """
    if model_id in [m.value for m in TextModelType]:
        return ModelKind.TEXT
    elif model_id in [m.value for m in ImageModelType]:
        return ModelKind.IMAGE
    else:
        raise ValueError(
            f"Unrecognized model ID: {model_id}.\n"
            f"Valid text: {[m.value for m in TextModelType]}\n"
            f"Valid image: {[m.value for m in ImageModelType]}"
        )
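

# Illustrative (values taken from the enums above):
#   detect_model_kind("bge-m3")                                # -> ModelKind.TEXT
#   detect_model_kind("siglip-base-patch16-256-multilingual")  # -> ModelKind.IMAGE
#   detect_model_kind("not-a-model")                           # -> raises ValueError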


class EmbeddingsService:
    """
    Service for generating text/image embeddings and performing similarity ranking.
    Single and multiple inputs are handled uniformly; there is no batch-size parameter.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        self.lru_cache = LRUCache(maxsize=10_000)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = config or ModelConfig()

        # Dictionaries to hold preloaded models
        self.text_models: Dict[TextModelType, SentenceTransformer] = {}
        self.image_models: Dict[ImageModelType, AutoModel] = {}
        self.image_processors: Dict[ImageModelType, AutoProcessor] = {}

        # Load all relevant models on init
        self._load_all_models()

    def _load_all_models(self) -> None:
        """
        Pre-load all known text and image models for quick switching.
        """
        try:
            # Preload text models
            for t_model_type in TextModelType:
                info = ModelConfig(text_model_type=t_model_type).text_model_info
                logger.info("Loading text model: %s", info.model_id)
                if info.onnx_file:
                    logger.info("Using ONNX file: %s", info.onnx_file)
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        backend="onnx",
                        model_kwargs={
                            "provider": "CPUExecutionProvider",
                            "file_name": info.onnx_file,
                        },
                        trust_remote_code=True,
                    )
                else:
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        trust_remote_code=True,
                    )

            # Preload image models
            for i_model_type in ImageModelType:
                model_id = ModelConfig(
                    image_model_type=i_model_type
                ).image_model_info.model_id
                logger.info("Loading image model: %s", model_id)
                model = AutoModel.from_pretrained(model_id).to(self.device)
                processor = AutoProcessor.from_pretrained(model_id)
                self.image_models[i_model_type] = model
                self.image_processors[i_model_type] = processor

            logger.info("All models loaded successfully.")
        except Exception as e:
            msg = f"Error loading models: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

    @staticmethod
    def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
        """
        Convert text input into a non-empty list of strings.
        Raises ValueError if the input is invalid.
        """
        if isinstance(input_text, str):
            if not input_text.strip():
                raise ValueError("Text input cannot be empty.")
            return [input_text]

        if not isinstance(input_text, list) or not all(
            isinstance(x, str) for x in input_text
        ):
            raise ValueError("Text input must be a string or a list of strings.")
        if len(input_text) == 0:
            raise ValueError("Text input list cannot be empty.")
        return input_text

    @staticmethod
    def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
        """
        Convert image input into a non-empty list of image paths/URLs.
        Raises ValueError if the input is invalid.
        """
        if isinstance(input_images, str):
            if not input_images.strip():
                raise ValueError("Image input cannot be empty.")
            return [input_images]

        if not isinstance(input_images, list) or not all(
            isinstance(x, str) for x in input_images
        ):
            raise ValueError("Image input must be a string or a list of strings.")
        if len(input_images) == 0:
            raise ValueError("Image input list cannot be empty.")
        return input_images

    def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
        """
        Loads and processes a single image from a local path or URL.
        Returns a dictionary of tensors ready for the model.
        """
        try:
            if path_or_url.startswith("http"):
                resp = requests.get(path_or_url, timeout=10)
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content)).convert("RGB")
            else:
                img = Image.open(Path(path_or_url)).convert("RGB")

            processor = self.image_processors[self.config.image_model_type]
            processed_data = processor(images=img, return_tensors="pt").to(self.device)
            return processed_data
        except Exception as e:
            raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e

    def _generate_text_embeddings(
        self,
        model_id: TextModelType,
        texts: List[str],
    ) -> np.ndarray:
        """
        Generates text embeddings using the SentenceTransformer-based model.
        Utilizes an LRU cache for single-input scenarios.
        """
        try:
            if len(texts) == 1:
                single_text = texts[0]
                key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
                if key in self.lru_cache:
                    return self.lru_cache[key]
                model = self.text_models[model_id]
                emb = model.encode([single_text])
                self.lru_cache[key] = emb
                return emb

            # For multiple texts, no LRU cache is used
            model = self.text_models[model_id]
            return model.encode(texts)
        except Exception as e:
            raise RuntimeError(
                f"Error generating text embeddings with model '{model_id}': {e}"
            ) from e

    def _generate_image_embeddings(
        self,
        model_id: ImageModelType,
        images: List[str],
    ) -> np.ndarray:
        """
        Generates image embeddings using the CLIP-like transformer model.
        Handles single or multiple images uniformly (no batch-size parameter).
        """
        try:
            model = self.image_models[model_id]

            # Collect processed inputs in a single batch
            processed_tensors = [self._process_image(img_path) for img_path in images]

            # Keys should be the same for all processed outputs
            keys = processed_tensors[0].keys()

            # Concatenate along the batch dimension
            combined = {
                k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
            }

            with torch.no_grad():
                embeddings = model.get_image_features(**combined)
            return embeddings.cpu().numpy()
        except Exception as e:
            raise RuntimeError(
                f"Error generating image embeddings with model '{model_id}': {e}"
            ) from e

    async def generate_embeddings(
        self,
        model: str,
        inputs: Union[str, List[str]],
    ) -> np.ndarray:
        """
        Asynchronously generates embeddings for either text or image, based on
        the model type.
        """
        modality = detect_model_kind(model)

        if modality == ModelKind.TEXT:
            text_model_id = TextModelType(model)
            text_list = self._validate_text_list(inputs)
            return self._generate_text_embeddings(text_model_id, text_list)
        else:
            # detect_model_kind() guarantees this is an image model.
            image_model_id = ImageModelType(model)
            image_list = self._validate_image_list(inputs)
            return self._generate_image_embeddings(image_model_id, image_list)

    async def rank(
        self,
        model: str,
        queries: Union[str, List[str]],
        candidates: Union[str, List[str]],
    ) -> Dict[str, Any]:
        """
        Ranks `candidates` against `queries`. With a text model both sides are
        text; with an image model both sides are images.
        Always returns a dictionary of { probabilities, cosine_similarities, usage }.
        Note: this implementation uses the same model for both queries and candidates.
        For true cross-modal ranking, you might need separate models or a shared model.
        """
        modality = detect_model_kind(model)

        # Convert the string model to the appropriate enum
        if modality == ModelKind.TEXT:
            model_enum = TextModelType(model)
        else:
            model_enum = ImageModelType(model)

        # 1) Generate embeddings for queries
        query_embeds = await self.generate_embeddings(model_enum.value, queries)

        # 2) Generate embeddings for candidates (same modality as the queries,
        #    since a single model embeds both sides)
        candidate_embeds = await self.generate_embeddings(model_enum.value, candidates)

        # 3) Compute cosine similarity
        sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)

        # 4) Apply the logit scale and softmax to obtain probabilities
        scaled = np.exp(self.config.logit_scale) * sim_matrix
        probs = self.softmax(scaled)

        # 5) Estimate token usage if we are dealing with text
        if modality == ModelKind.TEXT:
            query_tokens = self.estimate_tokens(queries)
            candidate_tokens = self.estimate_tokens(candidates)
            total_tokens = query_tokens + candidate_tokens
        else:
            total_tokens = 0

        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }

        return {
            "probabilities": probs.tolist(),
            "cosine_similarities": sim_matrix.tolist(),
            "usage": usage,
        }
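
    # Illustrative `rank` output for one query and two candidates:
    #   {
    #       "probabilities": [[p1, p2]],
    #       "cosine_similarities": [[s1, s2]],
    #       "usage": {"prompt_tokens": n, "total_tokens": n},
    #   }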

    def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
        """
        Estimates token count using the configured text model's tokenizer.
        Only applicable when the configured model is a text model.
        Note: tokenized batches may be padded to a common length, so this can
        overestimate the count for mixed-length inputs.
        """
        texts = self._validate_text_list(input_data)
        model = self.text_models[self.config.text_model_type]
        tokenized = model.tokenize(texts)
        # Sum the lengths of input_ids across all examples
        return sum(len(ids) for ids in tokenized["input_ids"])

    @staticmethod
    def softmax(scores: np.ndarray) -> np.ndarray:
        """
        Applies the standard softmax function along the last dimension.
        """
        # Stabilize scores by subtracting the row-wise max before exponentiating
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)
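
    # For example, softmax(np.array([[0.0, 1.0]])) ~= [[0.269, 0.731]];
    # each row of the result sums to 1.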

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Computes the pairwise cosine similarity between all rows of a and b.
        a: (N, D)
        b: (M, D)
        Returns an (N, M) matrix of cosine similarities.
        """
        # The epsilon guards against division by zero for all-zero rows.
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
        return np.dot(a_norm, b_norm.T)
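

# Minimal usage sketch (illustrative only): the model name comes from the enums
# above, while the query/passage strings are hypothetical example inputs.
if __name__ == "__main__":
    import asyncio

    service = EmbeddingsService()  # preloads all configured models

    async def demo() -> None:
        # Embed a single text with the default multilingual-e5-small model.
        text_embeds = await service.generate_embeddings(
            model=TextModelType.MULTILINGUAL_E5_SMALL.value,
            inputs="query: where is the Eiffel Tower?",
        )
        print("text embedding shape:", text_embeds.shape)

        # Rank two candidate passages against the same query.
        result = await service.rank(
            model=TextModelType.MULTILINGUAL_E5_SMALL.value,
            queries="query: where is the Eiffel Tower?",
            candidates=["passage: Paris, France", "passage: Berlin, Germany"],
        )
        print("probabilities:", result["probabilities"])
        print("usage:", result["usage"])

    asyncio.run(demo())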