Spaces:
Running
Running
| # from langflow.field_typing import Data | |
| import numpy as np | |
| # TODO: remove ignore once the google package is published with types | |
| from google.ai.generativelanguage_v1beta.types import BatchEmbedContentsRequest | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| from langchain_google_genai._common import GoogleGenerativeAIError | |
| from langflow.custom import Component | |
| from langflow.io import MessageTextInput, Output, SecretStrInput | |
| class GoogleGenerativeAIEmbeddingsComponent(Component): | |
| display_name = "Google Generative AI Embeddings" | |
| description = ( | |
| "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, " | |
| "found in the langchain-google-genai package." | |
| ) | |
| documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/" | |
| icon = "Google" | |
| name = "Google Generative AI Embeddings" | |
| inputs = [ | |
| SecretStrInput(name="api_key", display_name="API Key"), | |
| MessageTextInput(name="model_name", display_name="Model Name", value="models/text-embedding-004"), | |
| ] | |
| outputs = [ | |
| Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), | |
| ] | |
| def build_embeddings(self) -> Embeddings: | |
| if not self.api_key: | |
| msg = "API Key is required" | |
| raise ValueError(msg) | |
| class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings): | |
| def __init__(self, *args, **kwargs) -> None: | |
| super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs) | |
| def embed_documents( | |
| self, | |
| texts: list[str], | |
| *, | |
| batch_size: int = 100, | |
| task_type: str | None = None, | |
| titles: list[str] | None = None, | |
| output_dimensionality: int | None = 1536, | |
| ) -> list[list[float]]: | |
| """Embed a list of strings. | |
| Google Generative AI currently sets a max batch size of 100 strings. | |
| Args: | |
| texts: List[str] The list of strings to embed. | |
| batch_size: [int] The batch size of embeddings to send to the model | |
| task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType) | |
| titles: An optional list of titles for texts provided. | |
| Only applicable when TaskType is RETRIEVAL_DOCUMENT. | |
| output_dimensionality: Optional reduced dimension for the output embedding. | |
| https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest | |
| Returns: | |
| List of embeddings, one for each text. | |
| """ | |
| embeddings: list[list[float]] = [] | |
| batch_start_index = 0 | |
| for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size): | |
| if titles: | |
| titles_batch = titles[batch_start_index : batch_start_index + len(batch)] | |
| batch_start_index += len(batch) | |
| else: | |
| titles_batch = [None] * len(batch) # type: ignore[list-item] | |
| requests = [ | |
| self._prepare_request( | |
| text=text, | |
| task_type=task_type, | |
| title=title, | |
| output_dimensionality=output_dimensionality, | |
| ) | |
| for text, title in zip(batch, titles_batch, strict=True) | |
| ] | |
| try: | |
| result = self.client.batch_embed_contents( | |
| BatchEmbedContentsRequest(requests=requests, model=self.model) | |
| ) | |
| except Exception as e: | |
| msg = f"Error embedding content: {e}" | |
| raise GoogleGenerativeAIError(msg) from e | |
| embeddings.extend([list(np.pad(e.values, (0, 768), "constant")) for e in result.embeddings]) | |
| return embeddings | |
| def embed_query( | |
| self, | |
| text: str, | |
| task_type: str | None = None, | |
| title: str | None = None, | |
| output_dimensionality: int | None = 1536, | |
| ) -> list[float]: | |
| """Embed a text. | |
| Args: | |
| text: The text to embed. | |
| task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType) | |
| title: An optional title for the text. | |
| Only applicable when TaskType is RETRIEVAL_DOCUMENT. | |
| output_dimensionality: Optional reduced dimension for the output embedding. | |
| https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest | |
| Returns: | |
| Embedding for the text. | |
| """ | |
| task_type = task_type or "RETRIEVAL_QUERY" | |
| return self.embed_documents( | |
| [text], | |
| task_type=task_type, | |
| titles=[title] if title else None, | |
| output_dimensionality=output_dimensionality, | |
| )[0] | |
| return HotaGoogleGenerativeAIEmbeddings(model=self.model_name, google_api_key=self.api_key) | |