Spaces:
Build error
Update app.py
app.py CHANGED
@@ -15,6 +15,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import os
 import transformers
 import torch
+from langchain_retrieval import BaseRetrieverChain
 # from dotenv import load_dotenv
 
 # load_dotenv()
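The added import pulls in only BaseRetrieverChain, yet the hunk below references langchain_retrieval, langchain_core, Optional, and Callable as bare names that nothing in the file is shown to import; that mismatch alone would be enough to break the Space at startup. A sketch of the import block the new annotations appear to assume (hypothetical: langchain_retrieval is taken from the diff and is not a package shipped with standard LangChain, which splits into langchain, langchain-core, and langchain-community):

from typing import Callable, Optional

import langchain_core.prompts.chat  # ChatPromptValue is importable from here in recent langchain-core releases
import langchain_retrieval           # assumed to exist per the diff; only BaseRetrieverChain is imported above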
@@ -61,21 +62,49 @@ def get_context_retriever_chain(vector_store,llm):
     return retriever_chain
 
 
-def get_conversational_rag_chain(retriever_chain,llm):
-
-    llm=llm
-
+# def get_conversational_rag_chain(retriever_chain,llm):
+
+#     llm=llm
+
+#     template = "Answer the user's questions based on the below context:\n\n{context}"
+#     human_template = "{input}"
+
+#     prompt = ChatPromptTemplate.from_messages([
+#         ("system", template),
+#         MessagesPlaceholder(variable_name="chat_history"),
+#         ("user", human_template),
+#     ])
+
+#     stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
+
+#     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
+def get_conversational_rag_chain(
+    retriever_chain: Optional[langchain_retrieval.BaseRetrieverChain],
+    llm: Callable[[str], str],
+    chat_history: Optional[langchain_core.prompts.chat.ChatPromptValue] = None,
+) -> langchain_retrieval.BaseRetrieverChain:
+
+    if not retriever_chain:
+        raise ValueError("`retriever_chain` cannot be None or an empty object.")
+
     template = "Answer the user's questions based on the below context:\n\n{context}"
     human_template = "{input}"
-
+
     prompt = ChatPromptTemplate.from_messages([
         ("system", template),
         MessagesPlaceholder(variable_name="chat_history"),
         ("user", human_template),
     ])
-
-    stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
-
+
+    def safe_llm(input_str: str) -> str:
+        if isinstance(input_str, langchain_core.prompts.chat.ChatPromptValue):
+            input_str = str(input_str)
+
+        # Call the original llm, which should now work correctly
+        return llm(input_str)
+
+    stuff_documents_chain = create_stuff_documents_chain(safe_llm, prompt)
+
     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
 
 def get_response(user_input):
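For orientation, a minimal sketch of how the rewritten chain might be driven, assuming a vector_store and a string-to-string llm callable already exist; in current LangChain releases the object returned by create_retrieval_chain is invoked with a dict and returns a dict whose generated text sits under the "answer" key:

# Hypothetical driver, not part of the commit.
retriever_chain = get_context_retriever_chain(vector_store, llm)
rag_chain = get_conversational_rag_chain(retriever_chain, llm)

result = rag_chain.invoke({
    "chat_history": [],  # MessagesPlaceholder expects a list of messages
    "input": "What does the uploaded document say about deployment?",
})
print(result["answer"])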
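The safe_llm wrapper addresses a real quirk: when a plain Python callable stands in for a chat model inside create_stuff_documents_chain, LCEL coerces it to a runnable and hands it the prompt's ChatPromptValue rather than a string, so converting with str() before calling the model avoids a type error. One way to build such a callable from the transformers and torch imports at the top of the file; the checkpoint and generation settings here are illustrative, not from the commit:

import torch
import transformers

# Illustrative text-generation pipeline; any causal-LM checkpoint would do.
generator = transformers.pipeline(
    "text-generation",
    model="distilgpt2",
    torch_dtype=torch.float32,
)

def llm(prompt: str) -> str:
    # Return the prompt plus generated continuation as plain text.
    out = generator(prompt, max_new_tokens=128, do_sample=False)
    return out[0]["generated_text"]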