Spaces:
Running
Running
| from langchain.chains import RetrievalQA | |
| from langflow.base.chains.model import LCChainComponent | |
| from langflow.field_typing import Message | |
| from langflow.inputs import BoolInput, DropdownInput, HandleInput, MultilineInput | |
class RetrievalQAComponent(LCChainComponent):
    """Answer a question with LangChain's ``RetrievalQA`` chain.

    The component wires a language model, a retriever and (optionally) a
    chat memory into ``RetrievalQA.from_chain_type`` and runs the chain on
    ``input_value``. The answer text is returned; when
    ``return_source_documents`` is enabled, references built from the
    retrieved documents are appended to it.
    """

    display_name = "Retrieval QA"
    description = "Chain for question-answering querying sources from a retriever."
    name = "RetrievalQA"
    # Flags the component as legacy; how that flag is surfaced is decided by
    # the Langflow base class / UI, which is not visible here.
    legacy: bool = True
    icon = "LangChain"

    inputs = [
        # The question; invoke_chain sends it to the chain under the "query" key.
        MultilineInput(
            name="input_value",
            display_name="Input",
            info="The input value to pass to the chain.",
            required=True,
        ),
        # Human-readable labels; invoke_chain lower-cases them and replaces
        # spaces with underscores (e.g. "Map Reduce" -> "map_reduce") before
        # passing them as ``chain_type``.
        DropdownInput(
            name="chain_type",
            display_name="Chain Type",
            info="Chain type to use.",
            options=["Stuff", "Map Reduce", "Refine", "Map Rerank"],
            value="Stuff",
            advanced=True,
        ),
        HandleInput(
            name="llm",
            display_name="Language Model",
            input_types=["LanguageModel"],
            required=True,
        ),
        HandleInput(
            name="retriever",
            display_name="Retriever",
            input_types=["Retriever"],
            required=True,
        ),
        # Optional. When connected, invoke_chain rewrites its input/output keys
        # to match the keys used for the chain ("query" / "result").
        HandleInput(
            name="memory",
            display_name="Memory",
            input_types=["BaseChatMemory"],
        ),
        # Only controls whether references are appended to the returned text;
        # the chain itself is always built with return_source_documents=True.
        BoolInput(
            name="return_source_documents",
            display_name="Return Source Documents",
            value=False,
        ),
    ]

    def invoke_chain(self) -> Message:
        """Run the RetrievalQA chain on ``self.input_value``.

        Returns:
            The chain's ``"result"`` value converted to ``str`` (``""`` when
            absent). It is followed by a newline and the output of
            ``create_references_from_data`` when ``return_source_documents``
            is set and at least one source document was returned.

        Side effects:
            * Overwrites ``input_key`` / ``output_key`` on the connected
              memory object, if any.
            * Sets ``self.status`` to the raw chain output, with
              ``"source_documents"`` replaced by the converted documents and
              an extra ``"output"`` entry holding the returned string.
        """
        # Convert the dropdown label to the identifier form, e.g.
        # "Map Reduce" -> "map_reduce".
        chain_type = self.chain_type.lower().replace(" ", "_")

        if self.memory:
            # Align the memory with the keys used below for the chain's input
            # ({"query": ...}) and output (result["result"]).
            self.memory.input_key = "query"
            self.memory.output_key = "result"

        runnable = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type=chain_type,
            retriever=self.retriever,
            memory=self.memory,
            # Always request source documents so they end up in self.status
            # for debugging, regardless of the return_source_documents toggle.
            return_source_documents=True,
        )

        result = runnable.invoke(
            {"query": self.input_value},
            config={"callbacks": self.get_langchain_callbacks()},
        )

        # dict.get accepts its default only positionally; passing it as a
        # keyword argument raises TypeError.
        source_docs = self.to_data(result.get("source_documents", []))
        result_str = str(result.get("result", ""))

        if self.return_source_documents and len(source_docs):
            references_str = self.create_references_from_data(source_docs)
            result_str = f"{result_str}\n{references_str}"

        # Expose the full chain output plus the final text in the component
        # status for debugging.
        self.status = {**result, "source_documents": source_docs, "output": result_str}
        # NOTE(review): annotated ``-> Message`` but a plain ``str`` is
        # returned; confirm the framework converts it as intended.
        return result_str