Spaces:
Sleeping
Sleeping
| from abc import abstractmethod | |
| from functools import cache | |
| import os | |
| from qdrant_client import QdrantClient | |
| from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
| from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore | |
| from qdrant_client.models import VectorParams, Distance | |
| from db.embedding import Embedding, EMBEDDINGS | |
| class Store: | |
| def get_embedding(): | |
| embedding = os.getenv("EMBEDDING") | |
| if not embedding: | |
| return EMBEDDINGS["OPEN_AI"] | |
| return EMBEDDINGS[embedding] | |
| def get_instance(): | |
| vector_store = os.getenv("STORE") | |
| if vector_store == "ELASTIC": | |
| return ElasticVectorStore(Store.get_embedding()) | |
| elif vector_store == "QDRANT": | |
| return QdrantVectorStore(Store.get_embedding()) | |
| else: | |
| raise ValueError(f"Invalid vector store {vector_store}") | |
| def __init__(self, embedding: Embedding): | |
| self.embedding = embedding | |
| def get_collection(self, collection: str="test") -> VectorStore: | |
| """ | |
| get an instance of vector store | |
| of collection | |
| """ | |
| pass | |
| def create_collection(self, collection: str) -> None: | |
| """ | |
| create an instance of vector store | |
| with collection name | |
| """ | |
| pass | |
| def list_collections(self) -> list[dict]: | |
| """ | |
| Return a list of collections in the vecot store. | |
| """ | |
| pass | |
| class ElasticVectorStore(Store): | |
| def __init__(self, embeddings): | |
| super().__init__(embeddings) | |
| def get_collection(self, collection:str) -> ElasticVectorSearch: | |
| return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"), | |
| index_name= collection, embedding=self.embedding.embedding) | |
| def create_collection(self, collection: str) -> None: | |
| store = self.get_collection(collection) | |
| store.create_index(store.client,collection, dict()) | |
| def list_collections(self) -> list[dict]: | |
| #TODO: not impelented | |
| return [] | |
| class QdrantVectorStore(Store): | |
| def __init__(self, embeddings): | |
| super().__init__(embeddings) | |
| self.client = QdrantClient(url=os.getenv("QDRANT_URL"), | |
| api_key=os.getenv("QDRANT_API_KEY")) | |
| def get_collection(self, collection: str) -> Qdrant: | |
| return Qdrant(client=self.client,collection_name=collection, | |
| embeddings=self.embedding.embedding) | |
| def create_collection(self, collection: str) -> None: | |
| self.client.create_collection(collection_name=collection, | |
| vectors_config=VectorParams(size=self.embedding.dimension, | |
| distance=Distance.COSINE)) | |
| def list_collections(self) -> list[dict]: | |
| """ return a list of collections. | |
| """ | |
| return [ c for i,c in enumerate(self.client.get_collections().collections)] | |