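"""Streamlit app for chatting with an uploaded PDF through a LangChain agent.

The uploaded PDF is chunked, embedded with a Hugging Face Hub embedding model,
and indexed in FAISS. A ReAct agent backed by Mistral-7B-Instruct then answers
questions using a RetrievalQA tool over the index and a DuckDuckGo search tool.
"""
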
import os

import streamlit as st
import torch
from PyPDF2 import PdfReader

from langchain import HuggingFaceHub, PromptTemplate
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
from langchain.tools import DuckDuckGoSearchRun
from langchain.utilities import SerpAPIWrapper
from langchain.vectorstores import FAISS, Chroma

from utils.ask_human import CustomAskHumanTool
from utils.model_params import get_model_params
from utils.prompts import create_agent_prompt, create_qa_prompt

# API tokens are read from the environment
hf_token = os.environ["HF_TOKEN"]
serp_token = os.environ["SERP_TOKEN"]
HUGGINGFACEHUB_API_TOKEN = hf_token

# Embeddings served through the Hugging Face Hub inference API
repo_id = "sentence-transformers/all-mpnet-base-v2"
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# LLM used by both the RetrievalQA chain and the agent
llm = HuggingFaceHub(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

### PAGE ELEMENTS
# st.set_page_config(
#     page_title="RAG Agent Demo",
#     page_icon="🦜",
#     layout="centered",
#     initial_sidebar_state="collapsed",
# )
# st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")

def main():
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF powered by Search Agents 💬")

    # Upload file
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")

    if pdf is not None:
        # Extract the raw text from every page of the PDF
        pdf_reader = PdfReader(pdf)
        text = ""
        for page in pdf_reader.pages:
            text += page.extract_text()

        # Split the document into small text snippets
        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
        texts = text_splitter.split_text(text)

        # Embed the snippets and index them in an in-memory FAISS store
        embeddings = hf
        knowledge_base = FAISS.from_texts(texts, embeddings)
        retriever = knowledge_base.as_retriever(search_kwargs={"k": 3})

        # RetrievalQA chain that answers questions over the indexed PDF
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={
                "prompt": create_qa_prompt(),
            },
        )

        # Windowed conversation memory: keep only the last 3 exchanges
        conversational_memory = ConversationBufferWindowMemory(
            memory_key="chat_history", k=3, return_messages=True
        )

        # Tool that lets the agent query the PDF knowledge base
        db_search_tool = Tool(
            name="dbRetrievalTool",
            func=qa_chain.run,
            description="Use this tool to answer document related questions. The input to this tool should be the question.",
        )

        # Optional Google search via SerpAPI (disabled in favour of DuckDuckGo)
        # search = SerpAPIWrapper(serpapi_api_key=serp_token)
        # google_searchtool = Tool(
        #     name="Current Search",
        #     func=search.run,
        #     description="use this tool to answer real time or current search related questions.",
        # )

        # Web search tool for questions the PDF cannot answer
        search = DuckDuckGoSearchRun()
        search_tool = Tool(
            name="search",
            func=search.run,
            description="use this tool to answer real time or current search related questions.",
        )

        # Tool that lets the agent ask the user for clarification
        human_ask_tool = CustomAskHumanTool()

        # Agent prompt (prefix, format instructions, suffix)
        prefix, format_instructions, suffix = create_agent_prompt()
        mode = "Agent with AskHuman tool"

        # Initialize the ReAct agent with the retrieval and search tools
        agent = initialize_agent(
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            tools=[db_search_tool, search_tool],
            llm=llm,
            verbose=True,
            max_iterations=5,
            early_stopping_method="generate",
            memory=conversational_memory,
            agent_kwargs={
                "prefix": prefix,
                "format_instructions": format_instructions,
                "suffix": suffix,
            },
            handle_parsing_errors=True,
        )

        # Question form
        with st.form(key="form"):
            user_input = st.text_input("Ask your question")
            submit_clicked = st.form_submit_button("Submit Question")

        # Output container
        output_container = st.empty()
        if submit_clicked:
            output_container = output_container.container()
            output_container.chat_message("user").write(user_input)

            answer_container = output_container.chat_message("assistant", avatar="🦜")
            st_callback = StreamlitCallbackHandler(answer_container)

            # Run the agent, streaming intermediate steps to the UI,
            # then write the final answer below the reasoning trace
            answer = agent.run(user_input, callbacks=[st_callback])
            answer_container.write(answer)


if __name__ == "__main__":
    main()