| import torch | |
| import openai | |
| from sentence_transformers import SentenceTransformer | |
| from abc import ABC, abstractmethod | |
class Embedder(ABC):
    """Common interface for text-embedding backends.

    Concrete subclasses must override :meth:`embed`; the base class cannot be
    instantiated on its own.
    """

    @abstractmethod
    def embed(self, texts):
        """Return one embedding vector per entry in *texts*.

        Args:
            texts: The inputs to embed (the subclasses in this file forward
                them unchanged to their backend).

        Returns:
            A list with one embedding (a list of numbers) per input.
        """
class HfEmbedder(Embedder):
    """Embedder backed by a locally loaded ``SentenceTransformer`` model."""

    def __init__(self, model_name):
        """Load the model and switch it to evaluation mode.

        Args:
            model_name: Identifier passed straight to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)
        self.model.eval()

    def embed(self, texts):
        """Encode *texts* into normalized embedding vectors.

        Args:
            texts: Inputs forwarded to ``SentenceTransformer.encode``
                (assumed to be a list of strings -- confirm with callers).

        Returns:
            A list with one embedding per input, each a list of plain
            Python floats.
        """
        encoded = self.model.encode(texts, normalize_embeddings=True)
        # ``tolist()`` yields built-in floats; ``list(vec)`` would leave
        # numpy scalar objects in the result, which are not JSON-serializable
        # and differ from what OpenAIEmbedder returns.
        return [vec.tolist() for vec in encoded]
class OpenAIEmbedder(Embedder):
    """Embedder that delegates to the OpenAI embeddings endpoint."""

    def __init__(self, model_name):
        """Remember which OpenAI model to request.

        Args:
            model_name: Sent as ``engine`` on every embedding request.
        """
        self.model_name = model_name

    def embed(self, texts):
        """Request embeddings for *texts* from OpenAI.

        Args:
            texts: Passed unchanged as the ``input`` of the request.

        Returns:
            The ``'embedding'`` value of every entry in the response's
            ``'data'``, in the order the response lists them.
        """
        # NOTE(review): ``openai.Embedding`` looks like the pre-1.0 SDK
        # interface -- confirm the pinned ``openai`` version before upgrading.
        api_result = openai.Embedding.create(input=texts, engine=self.model_name)
        vectors = []
        for item in api_result['data']:
            vectors.append(item['embedding'])
        return vectors
class EmbedderFactory:
    """Builds the embedder that matches a supported model name."""

    @staticmethod
    def get_embedder(type):
        """Return an embedder instance for the given model name.

        Declared as a static method so it works both on the class and on an
        instance (``EmbedderFactory.get_embedder(...)`` and
        ``EmbedderFactory().get_embedder(...)``).

        Args:
            type: Model name. The parameter shadows the builtin, but the name
                is kept so existing keyword callers (``type=...``) still work.

        Returns:
            An ``HfEmbedder`` for ``"sentence-transformers/all-MiniLM-L6-v2"``
            or an ``OpenAIEmbedder`` for ``"text-embedding-ada-002"``.

        Raises:
            ValueError: If *type* is not one of the supported model names.
        """
        if type == "sentence-transformers/all-MiniLM-L6-v2":
            return HfEmbedder(type)
        if type == "text-embedding-ada-002":
            return OpenAIEmbedder(type)
        raise ValueError(f"Unsupported embedder type: {type}")