Spaces:
Running
Running
| from langchain.chains.query_constructor.base import AttributeInfo | |
| from langchain.retrievers.self_query.base import SelfQueryRetriever | |
| from langflow.custom import Component | |
| from langflow.inputs import HandleInput, MessageTextInput | |
| from langflow.io import Output | |
| from langflow.schema import Data | |
| from langflow.schema.message import Message | |
class SelfQueryRetrieverComponent(Component):
    """Langflow component wrapping LangChain's ``SelfQueryRetriever``.

    The component takes a query, a vector store, a list of metadata field
    descriptions, a description of the document contents and an LLM. It
    builds a ``SelfQueryRetriever`` from them and exposes the retrieved
    documents as a list of ``Data`` objects.
    """

    display_name = "Self Query Retriever"
    description = "Retriever that uses a vector store and an LLM to generate the vector store queries."
    name = "SelfQueryRetriever"
    icon = "LangChain"
    legacy: bool = True

    inputs = [
        HandleInput(
            name="query",
            display_name="Query",
            info="Query to be passed as input.",
            input_types=["Message", "Text"],
        ),
        HandleInput(
            name="vectorstore",
            display_name="Vector Store",
            info="Vector Store to be passed as input.",
            input_types=["VectorStore"],
        ),
        HandleInput(
            name="attribute_infos",
            display_name="Metadata Field Info",
            info="Metadata Field Info to be passed as input.",
            input_types=["Data"],
            is_list=True,
        ),
        MessageTextInput(
            name="document_content_description",
            display_name="Document Content Description",
            info="Document Content Description to be passed as input.",
        ),
        HandleInput(
            name="llm",
            display_name="LLM",
            info="LLM to be passed as input.",
            input_types=["LanguageModel"],
        ),
    ]

    outputs = [
        Output(
            display_name="Retrieved Documents",
            name="documents",
            method="retrieve_documents",
        ),
    ]

    def retrieve_documents(self) -> list[Data]:
        """Run the query through a self-query retriever and return the hits.

        Returns:
            One ``Data`` object per document returned by the retriever. The
            same list is also stored on ``self.status``.

        Raises:
            TypeError: If ``self.query`` is neither a ``Message`` nor a ``str``.
        """
        # Validate the query before doing any other work, so an unsupported
        # input fails fast instead of after the retriever has been built.
        if isinstance(self.query, Message):
            input_text = self.query.text
        elif isinstance(self.query, str):
            input_text = self.query
        else:
            msg = f"Query type {type(self.query)} not supported."
            raise TypeError(msg)

        # Treat a missing/empty metadata input as "no metadata fields" rather
        # than crashing while iterating over a falsy non-list value.
        # NOTE(review): assumes each entry's ``.data`` dict holds exactly the
        # keyword arguments AttributeInfo accepts - confirm against callers.
        metadata_field_infos = [
            AttributeInfo(**value.data) for value in (self.attribute_infos or [])
        ]

        self_query_retriever = SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=self.vectorstore,
            document_contents=self.document_content_description,
            metadata_field_info=metadata_field_infos,
            enable_limit=True,
        )

        documents = self_query_retriever.invoke(
            input=input_text,
            config={"callbacks": self.get_langchain_callbacks()},
        )
        data = [Data.from_document(document) for document in documents]
        self.status = data
        return data