Spaces:
Build error
Build error
| import os | |
| import streamlit as st | |
| from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings | |
| from langchain.vectorstores.faiss import FAISS | |
| from langchain.chains import ChatVectorDBChain | |
| from huggingface_hub import snapshot_download | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.prompts.chat import ( | |
| ChatPromptTemplate, | |
| SystemMessagePromptTemplate, | |
| AIMessagePromptTemplate, | |
| HumanMessagePromptTemplate, | |
| ) | |
| from langchain.schema import ( | |
| AIMessage, | |
| HumanMessage, | |
| SystemMessage | |
| ) | |
# Browser-tab title and icon for the app.
# NOTE(review): the page_icon value looks like a mis-decoded emoji - confirm
# against the deployed app before changing it.
st.set_page_config(page_title="CFA Level 1", page_icon="π")

#### sidebar section 1 ####
# Radio button rendered in the sidebar; the selected label is later passed to
# load_embedding_models() to decide which embedding backend to build.
book = st.sidebar.radio("Choose an Embedding Model: ", ["Instruct", "Sbert"])
# Load embedding models
def load_embedding_models(model):
    """Build the embedding wrapper matching the selected model label.

    Args:
        model: Either ``"Sbert"`` or ``"Instruct"`` (the two sidebar options).

    Returns:
        A ``HuggingFaceEmbeddings`` instance for ``"Sbert"`` or a
        ``HuggingFaceInstructEmbeddings`` instance for ``"Instruct"``.

    Raises:
        ValueError: If ``model`` is not one of the supported labels.
    """
    if model == 'Sbert':
        model_sbert = "sentence-transformers/all-mpnet-base-v2"
        return HuggingFaceEmbeddings(model_name=model_sbert)

    if model == 'Instruct':
        embed_instruction = "Represent the financial paragraph for document retrieval: "
        query_instruction = "Represent the question for retrieving supporting documents: "
        model_instr = "hkunlp/instructor-large"
        return HuggingFaceInstructEmbeddings(
            model_name=model_instr,
            embed_instruction=embed_instruction,
            query_instruction=query_instruction,
        )

    # Previously an unknown label fell through to `return emb` with `emb`
    # never assigned, surfacing as an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported embedding model {model!r}; expected 'Sbert' or 'Instruct'"
    )
# Page heading and tagline.
st.title("Talk to CFA Level 1 Book")
st.markdown("#### Have a conversation with the CFA Curriculum by the CFA Institute π")

# Build the embedding backend for whichever model the sidebar radio selected.
embeddings = load_embedding_models(book)
##### functions ####
def load_vectorstore(_embeddings):
    """Download the pre-built FAISS index from the Hugging Face Hub and load it.

    Args:
        _embeddings: Embedding object handed to ``FAISS.load_local``.

    Returns:
        The FAISS vector store loaded from the downloaded ``CFA_Level_1`` folder.

    Raises:
        FileNotFoundError: If the downloaded snapshot does not contain the
            ``CFA_Level_1`` folder.
    """
    # download from hugging face
    cache_dir = "cfa_level_1_cache"
    target_dir = "CFA_Level_1"

    # snapshot_download returns the local folder of the snapshot it just
    # resolved, so the index folder can be addressed directly. The previous
    # os.walk over the cache left `target_path` unbound when nothing matched
    # and picked an arbitrary match when several revisions were cached.
    snapshot_path = snapshot_download(
        repo_id="nickmuchi/CFA_Level_1_Text_Embeddings",
        repo_type="dataset",
        revision="main",
        allow_patterns="CFA_Level_1/*",
        cache_dir=cache_dir,
    )

    target_path = os.path.join(snapshot_path, target_dir)
    if not os.path.isdir(target_path):
        raise FileNotFoundError(
            f"'{target_dir}' folder not found in downloaded snapshot: {snapshot_path}"
        )
    print(target_path)

    # load faiss
    docsearch = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
    return docsearch
def load_prompt():
    """Assemble the chat prompt used for question answering.

    Returns:
        A ``ChatPromptTemplate`` made of a system message (instructions plus a
        ``context`` slot) followed by a human message holding the ``question``
        slot.
    """
    # Single-brace names are template slots; the doubled braces around the
    # JSON example are presumably escapes that render as literal braces.
    system_text = """You are an expert in finance, economics, investing, ethics, derivatives and markets.
Use the following pieces of context to answer the users question. If you don't know the answer,
just say that you don't know, don't try to make up an answer. Provide a source reference.
ALWAYS return a "sources" part in your answer.
The "sources" part should be a reference to the source of the documents from which you got your answer. List all sources used
The output should be a markdown code snippet formatted in the following schema:
```json
{{
answer: is foo
sources: xyz
}}
```
Begin!
----------------
{context}"""
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_text),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
def load_chain():
    """Wire the chat model, vector store and prompt into a question-answering chain.

    Returns:
        A ``ChatVectorDBChain`` configured to also return the source documents
        it retrieved.
    """
    chat_model = ChatOpenAI(temperature=0)
    # Uses the module-level `embeddings` chosen from the sidebar selection.
    vectorstore = load_vectorstore(embeddings)
    prompt = load_prompt()
    return ChatVectorDBChain.from_llm(
        chat_model,
        vectorstore,
        qa_prompt=prompt,
        return_source_documents=True,
    )
def get_answer(question):
    """Run the question through the QA chain and format the supporting sources.

    Args:
        question: The user's question text.

    Returns:
        A tuple ``(answer, pages, extract)`` where ``answer`` is the chain's
        answer, ``pages`` is a comma-separated string of the distinct source
        pages (in the order they were first retrieved), and ``extract`` is a
        markdown string listing every retrieved passage under its page number.
    """
    chain = load_chain()

    # ChatVectorDBChain takes "question" plus "chat_history" and returns its
    # text under "answer"; the "query"/"result" keys used previously belong to
    # the RetrievalQA-style chains and fail input validation here. Each call is
    # a fresh conversation, so the history is empty.
    result = chain({"question": question, "chat_history": []})
    answer = result["answer"]
    source_documents = result["source_documents"]

    # pages: de-duplicate while keeping first-seen order (a set would make the
    # displayed order arbitrary). Looks like: 1, 2, 3
    unique_pages = dict.fromkeys(doc.metadata['page'] for doc in source_documents)
    pages = ", ".join(str(page) for page in unique_pages)

    # source text, one markdown bullet per retrieved passage:
    # - Page: {number}
    #   {extracted text from book}
    extract = "".join(
        f"- **Page: {doc.metadata['page']}**\n{doc.page_content}\n\n"
        for doc in source_documents
    )
    return answer, pages, extract
##### sidebar section 2 ####
# Fail fast with a readable message when the key is missing instead of letting
# a raw KeyError traceback reach the page. `api_key` is not passed anywhere in
# the visible code; presumably the OpenAI client reads the same environment
# variable itself - verify before removing this check.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    st.error("The OPENAI_API_KEY environment variable is not set.")
    st.stop()

##### main ####
user_input = st.text_input("Your question", "What is an MBS and who are the main issuer and investors of the MBS market?", key="input")
col1, col2 = st.columns([10, 1])

# show question
col1.write(f"**You:** {user_input}")

# ask button to the right of the displayed question
ask = col2.button("Ask", type="primary")

if ask:
    # Do not spend a model call on an empty prompt.
    if not user_input.strip():
        st.warning("Please enter a question first.")
        st.stop()

    # NOTE(review): the trailing characters of the spinner text look like
    # mis-decoded emoji - confirm against the deployed app before changing.
    with st.spinner("this can take about a minute for your first question because some models have to be downloaded π₯Ίππ»ππ»"):
        try:
            answer, pages, extract = get_answer(question=user_input)
        except Exception as e:
            # Top-level UI boundary: report the failure on the page and end
            # this script run rather than showing a traceback.
            st.write(f"Error with Download: {e}")
            st.stop()

    st.write(f"{answer}")

    # sources
    with st.expander(label = f"From pages: {pages}", expanded = False):
        st.markdown(extract)