from typing import Literal, Union from langchain_huggingface import HuggingFaceEmbeddings from langchain_openai import OpenAIEmbeddings EmbeddingModelType = Literal["openai", "huggingface"] EmbeddingsModel = Union[OpenAIEmbeddings, HuggingFaceEmbeddings] def get_embedding_model( model_id: str, model_type: EmbeddingModelType = "huggingface", device: str = "cpu", ) -> EmbeddingsModel: """Gets an instance of the configured embedding model. The function returns either an OpenAI or HuggingFace embedding model based on the provided model type. Args: model_id (str): The ID/name of the embedding model to use model_type (EmbeddingModelType): The type of embedding model to use. Must be either "openai" or "huggingface". Defaults to "huggingface" device (str): The device to use for the embedding model. Defaults to "cpu" Returns: EmbeddingsModel: An embedding model instance based on the configuration settings Raises: ValueError: If model_type is not "openai" or "huggingface" """ if model_type == "openai": return get_openai_embedding_model(model_id) elif model_type == "huggingface": return get_huggingface_embedding_model(model_id, device) else: raise ValueError(f"Invalid embedding model type: {model_type}") def get_openai_embedding_model(model_id: str) -> OpenAIEmbeddings: """Gets an OpenAI embedding model instance. Args: model_id (str): The ID/name of the OpenAI embedding model to use Returns: OpenAIEmbeddings: A configured OpenAI embeddings model instance with special token handling enabled """ return OpenAIEmbeddings( model=model_id, allowed_special={"<|endoftext|>"}, ) def get_huggingface_embedding_model( model_id: str, device: str ) -> HuggingFaceEmbeddings: """Gets a HuggingFace embedding model instance. Args: model_id (str): The ID/name of the HuggingFace embedding model to use device (str): The compute device to run the model on (e.g. "cpu", "cuda") Returns: HuggingFaceEmbeddings: A configured HuggingFace embeddings model instance with remote code trust enabled and embedding normalization disabled """ return HuggingFaceEmbeddings( model_name=model_id, model_kwargs={"device": device, "trust_remote_code": True}, encode_kwargs={"normalize_embeddings": False}, )