Spaces:
Paused
Paused
| from pymilvus import MilvusClient as Client | |
| from pymilvus import FieldSchema, DataType | |
| import json | |
| import logging | |
| from typing import Optional | |
| from open_webui.retrieval.vector.main import ( | |
| VectorDBBase, | |
| VectorItem, | |
| SearchResult, | |
| GetResult, | |
| ) | |
| from open_webui.config import ( | |
| MILVUS_URI, | |
| MILVUS_DB, | |
| MILVUS_TOKEN, | |
| MILVUS_INDEX_TYPE, | |
| MILVUS_METRIC_TYPE, | |
| MILVUS_HNSW_M, | |
| MILVUS_HNSW_EFCONSTRUCTION, | |
| MILVUS_IVF_FLAT_NLIST, | |
| ) | |
| from open_webui.env import SRC_LOG_LEVELS | |
# Module-level logger for this backend; its level is read from the "RAG"
# entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient(VectorDBBase):
    """Milvus-backed implementation of the ``VectorDBBase`` interface.

    Every collection managed here is stored in Milvus under the name
    ``"<collection_prefix>_<collection_name>"`` with dashes in the caller's
    name replaced by underscores. Rows carry four fields: ``id`` (VARCHAR
    primary key), ``vector``, ``data`` (JSON, ``{"text": ...}``) and
    ``metadata`` (JSON).
    """

    # Largest ``limit`` sent to Milvus in a single ``query`` request.
    _MAX_QUERY_BATCH = 16383
    # Upper bound used by ``query`` when the caller passes ``limit=None``.
    _DEFAULT_QUERY_LIMIT = 16384 * 10

    def __init__(self):
        """Open the Milvus connection described by the MILVUS_* config values.

        The token is only passed to the client when ``MILVUS_TOKEN`` is set.
        """
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _full_name(self, collection_name: str) -> str:
        """Return the prefixed Milvus collection name for ``collection_name``.

        Dashes are replaced by underscores; the operation is idempotent, so it
        is safe to call on a name that was already normalised.
        """
        normalised = collection_name.replace("-", "_")
        return f"{self.collection_prefix}_{normalised}"

    @staticmethod
    def _metadata_filter(conditions: dict) -> str:
        """Build a Milvus boolean expression matching all ``conditions``.

        Each key/value pair becomes ``metadata["<key>"] == <json value>`` and
        the pairs are joined with ``&&``. An empty dict yields ``""``.
        """
        return " && ".join(
            f'metadata["{key}"] == {json.dumps(value)}'
            for key, value in conditions.items()
        )

    @staticmethod
    def _to_rows(items: list[VectorItem]) -> list[dict]:
        """Convert vector items into the row dicts stored in Milvus."""
        return [
            {
                "id": item["id"],
                "vector": item["vector"],
                "data": {"text": item["text"]},
                "metadata": item["metadata"],
            }
            for item in items
        ]

    def _ensure_collection(
        self, collection_name: str, items: list[VectorItem], context: str = ""
    ) -> None:
        """Create the collection if it does not exist yet.

        The vector dimension is taken from the first item, so creating a
        collection requires at least one item.

        Args:
            collection_name: Un-prefixed collection name.
            items: Items about to be written.
            context: Text spliced into the log/error messages (``""`` for
                insert, ``" for upsert"`` for upsert).

        Raises:
            ValueError: If the collection is missing and ``items`` is empty.
        """
        full_name = self._full_name(collection_name)
        if self.client.has_collection(collection_name=full_name):
            return

        log.info(f"Collection {full_name} does not exist{context}. Creating now.")
        if not items:
            log.error(
                f"Cannot create collection {full_name}{context} without items to determine dimension."
            )
            raise ValueError(
                f"Cannot create Milvus collection{context} without items to determine vector dimension."
            )
        self._create_collection(
            collection_name=collection_name, dimension=len(items[0]["vector"])
        )

    def _result_to_get_result(self, result) -> GetResult:
        """Convert nested query rows into a ``GetResult``.

        ``result`` is an iterable of row groups; each row is a mapping with
        ``id``, ``data`` (holding ``text``) and ``metadata`` keys. One inner
        list is produced per group.
        """
        ids = []
        documents = []
        metadatas = []
        for match in result:
            group_ids = []
            group_documents = []
            group_metadatas = []
            for item in match:
                group_ids.append(item.get("id"))
                group_documents.append(item.get("data", {}).get("text"))
                group_metadatas.append(item.get("metadata"))
            ids.append(group_ids)
            documents.append(group_documents)
            metadatas.append(group_metadatas)
        return GetResult(ids=ids, documents=documents, metadatas=metadatas)

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert Milvus search hits into a ``SearchResult``.

        Each hit is read as a mapping with ``id``, ``distance`` and an
        ``entity`` mapping that holds ``data`` and ``metadata``. One inner
        list is produced per query vector.
        """
        ids = []
        distances = []
        documents = []
        metadatas = []
        for match in result:
            group_ids = []
            group_distances = []
            group_documents = []
            group_metadatas = []
            for item in match:
                entity = item.get("entity", {})
                group_ids.append(item.get("id"))
                # Normalise the Milvus score from [-1, 1] to [0, 1].
                # https://milvus.io/docs/metric.md
                # NOTE(review): this range only holds for cosine-style
                # metrics; MILVUS_METRIC_TYPE is configurable, so scores for
                # other metrics (e.g. L2) are presumably not in [0, 1] after
                # this step - confirm against callers.
                group_distances.append((item.get("distance") + 1.0) / 2.0)
                group_documents.append(entity.get("data", {}).get("text"))
                group_metadatas.append(entity.get("metadata"))
            ids.append(group_ids)
            distances.append(group_distances)
            documents.append(group_documents)
            metadatas.append(group_metadatas)
        return SearchResult(
            ids=ids,
            distances=distances,
            documents=documents,
            metadatas=metadatas,
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create a collection with the standard schema and a vector index.

        The index type, metric type and index build parameters come from the
        MILVUS_* configuration values.

        Args:
            collection_name: Un-prefixed collection name.
            dimension: Dimension of the ``vector`` field.
        """
        full_name = self._full_name(collection_name)

        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Use configurations from config.py
        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()
        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        index_creation_params = {}
        if index_type == "HNSW":
            index_creation_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
            log.info(f"HNSW params: {index_creation_params}")
        elif index_type == "IVF_FLAT":
            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
            log.info(f"IVF_FLAT params: {index_creation_params}")
        elif index_type in ("FLAT", "AUTOINDEX"):
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            # Unknown types are still forwarded to Milvus unchanged; if Milvus
            # rejects them the user has to correct MILVUS_INDEX_TYPE.
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
            )

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=full_name,
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{full_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    # ------------------------------------------------------------------
    # VectorDBBase interface
    # ------------------------------------------------------------------

    def has_collection(self, collection_name: str) -> bool:
        """Return whether the prefixed collection exists in Milvus."""
        return self.client.has_collection(
            collection_name=self._full_name(collection_name)
        )

    def delete_collection(self, collection_name: str):
        """Drop the prefixed collection and return the client's response."""
        return self.client.drop_collection(
            collection_name=self._full_name(collection_name)
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Return the ``limit`` nearest neighbours for each query vector.

        No search params (such as ``nprobe`` for IVF_FLAT) are passed; the
        Milvus defaults for the collection's index apply.
        """
        result = self.client.search(
            collection_name=self._full_name(collection_name),
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )
        return self._result_to_search_result(result)

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch rows whose metadata equals every key/value pair in ``filter``.

        Rows are fetched in batches of at most ``_MAX_QUERY_BATCH`` using
        offset paging until ``limit`` rows were collected or Milvus returns a
        short/empty batch.

        Args:
            collection_name: Un-prefixed collection name.
            filter: Metadata equality conditions; an empty dict matches all.
            limit: Maximum number of rows; ``None`` means up to
                ``_DEFAULT_QUERY_LIMIT``.

        Returns:
            A ``GetResult`` with a single inner list holding all rows, or
            ``None`` if the collection does not exist or the query raised.
        """
        full_name = self._full_name(collection_name)
        if not self.has_collection(collection_name):
            log.warning(f"Query attempted on non-existent collection: {full_name}")
            return None

        filter_string = self._metadata_filter(filter)

        if limit is None:
            limit = self._DEFAULT_QUERY_LIMIT
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        all_results = []
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {full_name} with filter: '{filter_string}', limit: {limit}"
            )
            # NOTE(review): Milvus documents a cap on offset + limit per
            # request (16384 by default). Paging past that window presumably
            # raises and ends up in the except branch below, discarding the
            # rows already fetched - confirm, and consider an iterator-based
            # query for very large collections.
            while remaining > 0:
                current_fetch = min(self._MAX_QUERY_BATCH, remaining)
                log.debug(
                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
                )

                results = self.client.query(
                    collection_name=full_name,
                    filter=filter_string,
                    # The vector itself is not needed by callers of query().
                    output_fields=["id", "data", "metadata"],
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                remaining -= results_count
                offset += results_count

                # A short batch means the matching rows are exhausted.
                if results_count < current_fetch:
                    log.debug(
                        "Fetched less than requested, assuming end of results for this query."
                    )
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {full_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return all rows of the collection via an unfiltered paged ``query``.

        This can be very resource-intensive for large collections.
        """
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Insert ``items``, creating the collection first if it is missing.

        Raises:
            ValueError: If the collection must be created but ``items`` is
                empty, so the vector dimension cannot be determined.
        """
        full_name = self._full_name(collection_name)
        self._ensure_collection(collection_name, items)

        log.info(f"Inserting {len(items)} items into collection {full_name}.")
        return self.client.insert(
            collection_name=full_name,
            data=self._to_rows(items),
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if missing.

        Raises:
            ValueError: If the collection must be created but ``items`` is
                empty, so the vector dimension cannot be determined.
        """
        full_name = self._full_name(collection_name)
        self._ensure_collection(collection_name, items, context=" for upsert")

        log.info(f"Upserting {len(items)} items into collection {full_name}.")
        return self.client.upsert(
            collection_name=full_name,
            data=self._to_rows(items),
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete rows by primary key or by metadata equality filter.

        ``ids`` takes precedence over ``filter`` when both are given. Returns
        the client's response, or ``None`` when the collection does not exist
        or neither ``ids`` nor ``filter`` was supplied.
        """
        full_name = self._full_name(collection_name)
        if not self.has_collection(collection_name):
            log.warning(f"Delete attempted on non-existent collection: {full_name}")
            return None

        if ids:
            log.info(f"Deleting items by IDs from {full_name}. IDs: {ids}")
            return self.client.delete(
                collection_name=full_name,
                ids=ids,
            )

        if filter:
            filter_string = self._metadata_filter(filter)
            log.info(
                f"Deleting items by filter from {full_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=full_name,
                filter=filter_string,
            )

        log.warning(
            f"Delete operation on {full_name} called without IDs or filter. No action taken."
        )
        return None

    def reset(self):
        """Drop every collection that belongs to this client's prefix.

        Failures to drop an individual collection are logged and skipped so
        the remaining collections are still processed.
        """
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        # Match on "<prefix>_" rather than the bare prefix so that unrelated
        # collections which merely start with the same letters (for example
        # "open_webui2_x") are not dropped.
        owned_prefix = f"{self.collection_prefix}_"
        deleted_collections = []
        for collection_name_full in self.client.list_collections():
            if not collection_name_full.startswith(owned_prefix):
                continue
            try:
                self.client.drop_collection(collection_name=collection_name_full)
                deleted_collections.append(collection_name_full)
                log.info(f"Deleted collection: {collection_name_full}")
            except Exception as e:
                log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")