Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import asyncio | |
| from typing_extensions import TypedDict, List | |
| from IPython.display import Image, display | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from langchain.schema import Document | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain.prompts import PromptTemplate | |
| import uuid | |
| from langchain_groq import ChatGroq | |
| from langchain_community.utilities import GoogleSerperAPIWrapper | |
| from langchain_chroma import Chroma | |
| from langchain_community.document_loaders import NewsURLLoader | |
| from langchain_community.retrievers.wikipedia import WikipediaRetriever | |
| from sentence_transformers import SentenceTransformer | |
| from langchain.vectorstores import Chroma | |
| from langchain_community.document_loaders import UnstructuredURLLoader, NewsURLLoader | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.output_parsers import JsonOutputParser | |
| from langchain_community.vectorstores.utils import filter_complex_metadata | |
| from langchain.schema import Document | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain_community.document_loaders.directory import DirectoryLoader | |
| from langchain.document_loaders import TextLoader | |
| from functions import * | |
# --- Environment configuration ---------------------------------------------
# API keys are read from the process environment under the names below and
# re-exported under the upper-case variable names assigned further down.
lang_api_key = os.getenv("lang_api_key")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("groq_api_key")

# Static tracing settings (always applied).
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
os.environ["LANGCHAIN_PROJECT"] = "Lithuanian_Law_RAG_QA"

# os.environ only accepts str values: assigning the None that os.getenv()
# returns for an unset variable raises TypeError at import time. Export only
# the keys that are actually present so a missing key does not crash the
# module before the app can start.
for _env_name, _env_value in (
    ("LANGCHAIN_API_KEY", lang_api_key),
    ("GROQ_API_KEY", groq_api_key),
    ("SERPER_API_KEY", SERPER_API_KEY),
):
    if _env_value is not None:
        os.environ[_env_name] = _env_value
def main():
    """Render the "Info Assistant" Streamlit page and answer one question.

    Sets up the page, seeds the chat history in ``st.session_state``, exposes
    the retrieval controls (search type and number of documents), builds the
    LangGraph workflow and, when the user has typed a question, runs
    ``handle_userinput`` on it with the compiled graph.

    Workflow wiring (node and router callables come from ``functions``):
        ask_question --(toxicity router: "good")--> retrieve
        ask_question --(toxicity router: "bad")---> END
        retrieve -> grade_documents --("search")--> web_search -> generate
                                    --("generate")-> generate
        generate --("not supported")--> generate
                 --("useful")---------> END
                 --("not useful")-----> transform_query -> retrieve
    """
    st.set_page_config(page_title="Info Assistant: ", page_icon=":books:")
    st.header("Info Assistant :" ":books:")
    st.markdown(
        "\n"
        '###### Get support of **"Info Assistant"**, who has in memory a lot of Data Science related articles.\n'
        "If it can't answer based on its knowledge base, information will be found on the internet :books:\n"
    )

    # Seed the conversation once per browser session.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hi, I'm a chatbot who is based on respublic of Lithuania law documents. How can I help you?"}
        ]

    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            search: whether to add search
            documents: list of documents
            steps: list of steps (filled in by the node functions, which are
                defined outside this file)
            generation_count: generations count
        """
        question: str
        generation: str
        search: str
        documents: List[str]
        steps: List[str]
        generation_count: int

    # The label previously described the "mmr" option as "(similarity)".
    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr) , Similarity search (similarity). Default value (similarity)]",
        options=["mmr", "similarity"],
        index=1,
    )

    # Single source of truth for the default, so the label cannot drift away
    # from the value actually used (the label used to say 5 while the slider
    # started at 4).
    default_k = 4
    k = st.select_slider(
        f"Select amount of documents to be retrieved. Default value ({default_k}): ",
        options=list(range(2, 16)),
        value=default_k,
    )

    llm = ChatGroq(
        model="gemma2-9b-it",  # Specify the Gemma2 9B model
        temperature=0.0,
        max_tokens=400,
        max_retries=3,
    )

    retriever = create_retriever_from_chroma(
        vectorstore_path="docs/chroma/",
        search_type=search_type,
        k=k,
        chunk_size=550,
        chunk_overlap=40,
    )

    # Graph
    workflow = StateGraph(GraphState)

    # Define the nodes.
    # NOTE(review): the chain factories below are called inside the lambdas,
    # i.e. on every node invocation; kept as-is because the helpers are
    # defined elsewhere and may not be safe to build eagerly.
    workflow.add_node("ask_question", lambda state: ask_question(state, retriever))
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_node(
        "grade_documents",
        lambda state: grade_documents(state, retrieval_grader_grader(llm)),
    )  # grade documents
    workflow.add_node("generate", lambda state: generate(state, QA_chain(llm)))  # generate
    workflow.add_node("web_search", web_search)  # web search
    workflow.add_node(
        "transform_query",
        lambda state: transform_query(state, create_question_rewriter(llm)),
    )

    # Build graph
    workflow.set_entry_point("ask_question")
    workflow.add_conditional_edges(
        "ask_question",
        lambda state: grade_question_toxicity(state, create_toxicity_checker(llm)),
        {
            "good": "retrieve",
            "bad": END,
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    workflow.add_conditional_edges(
        "generate",
        lambda state: grade_generation_v_documents_and_question(
            state,
            create_hallucination_checker(llm),
            create_helpfulness_checker(llm),
        ),
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "retrieve")

    custom_graph = workflow.compile()

    # An empty text box yields a falsy string, so nothing runs until the
    # user has actually typed a question.
    if user_question := st.text_input("Ask a question about your documents:"):
        asyncio.run(handle_userinput(user_question, custom_graph))
# Script entry point: build and run the app only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()