from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from utility import load_data, process_data, CustomRetriever

# Load the regulation documents for undergraduates (sv), master's students
# (thacsi), and doctoral students (tiensi), then merge them into one corpus.
data1 = load_data('raw_data/sv')
data2 = load_data('raw_data/thacsi')
data3 = load_data('raw_data/tiensi')
data = data1 + data2 + data3
# Vietnamese sentence-embedding model (SimCSE-tuned PhoBERT), run on CPU.
embedding = HuggingFaceEmbeddings(
    model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base",
    model_kwargs={"device": "cpu"},
)
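# Illustrative sanity check (the sample query is an assumption, not from the
# original code): embed a short Vietnamese string and inspect the dimension.
#   vec = embedding.embed_query("điều kiện tốt nghiệp")
#   print(len(vec))  # 768 for a PhoBERT-base model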
# The splitter used to create smaller child chunks for indexing.
from langchain_text_splitters import RecursiveCharacterTextSplitter
child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
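# Quick illustration of the splitter (a sketch; output depends on the loaded
# documents):
#   chunks = child_text_splitter.split_documents(data)
#   print(len(chunks), chunks[0].page_content[:100])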
#####################################################################
# Build one vector store and dense retriever per corpus; the first indexes
# the combined corpus so it can handle general questions.
vectorstore1, retriever1 = process_data(data, child_text_splitter, embedding, "data")
vectorstore2, retriever2 = process_data(data2, child_text_splitter, embedding, "data2")
vectorstore3, retriever3 = process_data(data3, child_text_splitter, embedding, "data3")
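# Retrievers are LCEL runnables, so a quick spot check looks like this
# (illustrative query, not from the original):
#   docs = retriever1.invoke("điều kiện tốt nghiệp đại học")
#   print(docs[0].page_content[:200])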
##############################################################################
import os

# Anyscale credentials are read from the environment instead of being
# hardcoded (the original committed a credential id and a live API key,
# which should be revoked and rotated). The default base URL below is the
# usual Anyscale Endpoints address and is an assumption, not from the
# original source.
os.environ.setdefault("ANYSCALE_API_BASE", "https://api.endpoints.anyscale.com/v1")
assert os.environ.get("ANYSCALE_API_KEY"), "Set ANYSCALE_API_KEY before running."
ANYSCALE_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# Alternative models:
# ANYSCALE_MODEL_NAME = "meta-llama/Llama-3-8b-chat-hf"
# ANYSCALE_MODEL_NAME = "google/gemma-7b-it"
# ANYSCALE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
# ANYSCALE_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
from langchain_community.chat_models import ChatAnyscale
# from langchain_community.llms import Anyscale
# llm = Anyscale(model_name=ANYSCALE_MODEL_NAME)
llm = ChatAnyscale(model_name=ANYSCALE_MODEL_NAME, temperature=0)
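# Smoke test (illustrative; makes a network call, so it is left commented).
# Chat models return a message object, hence .content:
#   print(llm.invoke("Xin chào!").content)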
#####################################################################
from langchain_openai.llms.azure import AzureOpenAI

# Azure OpenAI completion model, used below for query expansion. The key is
# taken from the AZURE_OPENAI_API_KEY environment variable rather than being
# hardcoded (the original committed a live key, which should be rotated).
llm_openai = AzureOpenAI(
    deployment_name="gpt-35-turbo-instruct",
    # deployment_name="gpt-35-turbo-16k",
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version="2023-09-15-preview",
    azure_endpoint="https://bkchatbot.openai.azure.com/",
    temperature=0.0,
    max_tokens=500,
)
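# Smoke test (illustrative; unlike the chat model above, a completion LLM
# returns a plain string):
#   print(llm_openai.invoke("Viết một câu chào ngắn."))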
##########################################################################
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def format_docs(docs):
    # Join retrieved documents into a single context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)
# Build the QA prompt. The Vietnamese instructions tell the model to: answer
# from the supplied regulations with a short but complete final answer; omit
# notes and source citations; begin with "Theo quy định của Trường ĐH Bách
# Khoa Tp.HCM, ..." ("According to the regulations of HCMUT, ..."); and, when
# the context lacks the answer, reply "Xin lỗi, tôi không có thông tin cho
# câu hỏi này!" ("Sorry, I have no information for this question!").
template = """
Trả lời câu hỏi dựa trên những quy định được cung cấp, tổng hợp thông tin và đưa ra câu trả lời ngắn gọn và đầy đủ cuối cùng.
Không cần ghi chú và trích dẫn nguồn thông tin đã tham khảo trong câu trả lời.
Câu trả lời nên bắt đầu bằng: "Theo quy định của Trường ĐH Bách Khoa Tp.HCM, ..."
Nếu trong văn bản không có thông tin cho câu trả lời, vui lòng thông báo: "Xin lỗi, tôi không có thông tin cho câu hỏi này!"
Quy định: {context}
Câu hỏi: {question}
Câu trả lời:
"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
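# The filled prompt can be previewed without calling a model (illustrative
# placeholder values):
#   print(QA_CHAIN_PROMPT.format(context="...", question="Học phí là bao nhiêu?"))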
#############################################################################
from langchain_core.runnables import RunnableParallel

# Answer-generation chain: flatten the retrieved docs into a context string,
# fill the QA prompt, call the chat model, and parse the reply to a string.
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | QA_CHAIN_PROMPT
    | llm
    | StrOutputParser()
)
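# The chain expects a dict holding a list of Documents plus the question,
# e.g. (illustrative; the Document content is made up):
#   from langchain_core.documents import Document
#   docs = [Document(page_content="Điều 1. Sinh viên phải tích lũy đủ tín chỉ...")]
#   print(rag_chain_from_docs.invoke({"context": docs, "question": "Điều kiện tốt nghiệp?"}))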
###############################################################################
from langchain.prompts import ChatPromptTemplate

# Multi Query: different perspectives. The Vietnamese instructions ask the
# model to generate search queries semantically equivalent to the original
# question and to return a list of the original plus 2 alternative phrasings.
template = """
### Hãy tạo ra thêm các truy vấn tìm kiếm tương đương ngữ nghĩa với một câu hỏi ban đầu.
Kết quả hiển thị dạng list gồm câu hỏi ban đầu và 2 câu hỏi thay thế.
### Câu hỏi ban đầu: {question}
### Kết quả:
"""
prompt_perspectives = ChatPromptTemplate.from_template(template)
# Query-expansion chain: the Azure model rewrites the question, and the raw
# completion is split on newlines into a list of query strings.
generate_queries = (
    prompt_perspectives
    | llm_openai
    | StrOutputParser()
    | (lambda x: x.split("\n"))
)
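# The expansion step can be exercised on its own (illustrative; makes an
# API call):
#   queries = generate_queries.invoke({"question": "Điều kiện tốt nghiệp đại học?"})
#   print(queries)  # the original question plus two paraphrases, one per line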
#########################################################################################
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# Hybrid retrieval: pair a lexical BM25 retriever with the corresponding
# dense (Chroma) retriever, weighting the two equally.
bm25_retriever1 = BM25Retriever.from_documents(data, k=25)
ensemble_retriever1 = EnsembleRetriever(retrievers=[bm25_retriever1, retriever1], weights=[0.5, 0.5])
bm25_retriever2 = BM25Retriever.from_documents(data2, k=25)
ensemble_retriever2 = EnsembleRetriever(retrievers=[bm25_retriever2, retriever2], weights=[0.5, 0.5])
bm25_retriever3 = BM25Retriever.from_documents(data3, k=25)
ensemble_retriever3 = EnsembleRetriever(retrievers=[bm25_retriever3, retriever3], weights=[0.5, 0.5])
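# EnsembleRetriever fuses the two ranked lists with weighted reciprocal rank
# fusion; a quick look at what it returns (illustrative query):
#   for doc in ensemble_retriever1.invoke("học phí")[:3]:
#       print(doc.page_content[:100])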
#########################################################################################
custom_retriever1 = CustomRetriever(retriever=ensemble_retriever1)
custom_retriever2 = CustomRetriever(retriever=ensemble_retriever2)
custom_retriever3 = CustomRetriever(retriever=ensemble_retriever3)

# Multi-query retrieval: expand the question first, then retrieve with the
# hybrid retriever for each corpus.
multiq_chain1 = generate_queries | custom_retriever1
multiq_chain2 = generate_queries | custom_retriever2
multiq_chain3 = generate_queries | custom_retriever3

# Full RAG chains: fetch context and pass the question through in parallel,
# then attach the generated answer.
rag_chain_with_source1 = RunnableParallel(
    {"context": multiq_chain1, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
rag_chain_with_source2 = RunnableParallel(
    {"context": multiq_chain2, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
rag_chain_with_source3 = RunnableParallel(
    {"context": multiq_chain3, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
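# Each chain takes the raw question string and returns a dict with "context",
# "question", and "answer" keys (illustrative; makes API calls):
#   result = rag_chain_with_source1.invoke("Điều kiện tốt nghiệp là gì?")
#   print(result["answer"])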
############################################################################################
from flashtext import KeywordProcessor

# Keyword routing: detect which program a question is about.
# 'thạc sĩ' = master's degree, 'học viên' = (graduate) student,
# 'nghiên cứu sinh' = doctoral candidate, 'tiến sĩ' = doctorate.
keyword_processor = KeywordProcessor()
# keyword_processor.add_keyword(<unclean name>, <standardised name>)
keyword_processor.add_keyword('thạc sĩ')
keyword_processor.add_keyword('học viên')
keyword_processor.add_keyword('nghiên cứu sinh')
keyword_processor.add_keyword('tiến sĩ')
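# flashtext matches whole keywords in a single pass; for example
# (illustrative question, not from the original):
#   keyword_processor.extract_keywords("Học phí của học viên thạc sĩ")
#   # -> ['học viên', 'thạc sĩ']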
################################################################################
# Routing order: index 0 = general/undergraduate, 1 = master's, 2 = doctoral.
rag_chain = [rag_chain_with_source1, rag_chain_with_source2, rag_chain_with_source3]
###################################################################################
def rag(question: str) -> str:
    """Route the question to the master's, doctoral, or general chain by keyword."""
    keywords_found = keyword_processor.extract_keywords(question)
    if 'thạc sĩ' in keywords_found or 'học viên' in keywords_found:
        response = rag_chain[1].invoke(question)  # master's regulations
    elif 'nghiên cứu sinh' in keywords_found or 'tiến sĩ' in keywords_found:
        response = rag_chain[2].invoke(question)  # doctoral regulations
    else:
        response = rag_chain[0].invoke(question)  # general / undergraduate
    return response['answer']
###################################################################################
# Earlier implementation kept for reference: a RetrievalQA chain plus a small
# output cleaner.
# from langchain.chains import RetrievalQA
# qa_chain = RetrievalQA.from_chain_type(llm,
#                                        verbose=False,
#                                        # retriever=vectordb.as_retriever(),
#                                        retriever=custom_retriever,
#                                        return_source_documents=True,
#                                        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
# def remove_special_characters(text):
#     text = text.replace('].', '')
#     text = text.replace('/.', '')
#     text = text.replace('/.-', '')
#     text = text.replace('-', '')
#     return text
# def rag(question: str) -> str:
#     # call QA chain
#     response = qa_chain({"query": question})
#     return remove_special_characters(response["result"])
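# Minimal entry-point demo (a sketch; the sample question is an assumption,
# and the call hits both the Azure and Anyscale APIs):
if __name__ == "__main__":
    print(rag("Điều kiện để tốt nghiệp đại học là gì?"))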