working rag
- app.py +33 -2
- rag/rag.py +113 -0
app.py CHANGED
@@ -1,4 +1,35 @@
+import weave
+
 import streamlit as st
+from rag.rag import SimpleRAGPipeline
+
+WANDB_PROJECT = "paper_reader"
+
+weave.init(f"{WANDB_PROJECT}")
+
+st.set_page_config(page_title="Chat with the Llama 3 paper!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
+st.title("Chat with the Llama 3 paper 💬🦙")
+
+@st.cache_resource(show_spinner=False)
+def load_rag_pipeline():
+    rag_pipeline = SimpleRAGPipeline()
+    rag_pipeline.build_query_engine()
+
+    return rag_pipeline
+
+if "rag_pipeline" not in st.session_state.keys():
+    st.session_state.rag_pipeline = load_rag_pipeline()
+
+rag_pipeline = st.session_state["rag_pipeline"]
+
+# openai_api_key = st.sidebar.text_input('OpenAI API Key', type='password')
+
+def generate_response(query):
+    response = rag_pipeline.predict(query)
+    st.write_stream(response.response_gen)
 
-
-st.
+with st.form('my_form'):
+    query = st.text_area('Ask your question about the Llama 3 paper here:')
+    submitted = st.form_submit_button('Submit')
+    if submitted:
+        generate_response(query)
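
Note on the streaming hookup above: predict returns a llama_index streaming response, and st.write_stream consumes its response_gen token generator. A minimal defensive sketch, assuming the pipeline might someday return a non-streaming Response as well (the render_response helper is hypothetical, not part of this commit):

import streamlit as st

def render_response(response):
    # A llama_index StreamingResponse exposes a response_gen token generator;
    # a plain Response object does not, so fall back to rendering it as text.
    if hasattr(response, "response_gen"):
        st.write_stream(response.response_gen)
    else:
        st.write(str(response))
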
rag/rag.py ADDED
@@ -0,0 +1,113 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import weave
+import pickle
+
+from llama_index.core import PromptTemplate
+from llama_index.core.node_parser import MarkdownNodeParser
+from llama_index.core import VectorStoreIndex
+from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core import get_response_synthesizer
+from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
+
+data_dir = "data/raw_docs/documents.pkl"
+with open(data_dir, "rb") as file:
+    docs_files = pickle.load(file)
+
+print(f"Number of files: {len(docs_files)}\n")
+
+SYSTEM_PROMPT_TEMPLATE = """
+Answer the user's question about the newly released Llama 3 405-billion-parameter model based on the context. Provide a helpful and complete answer. The paper contains information about training, inference, evaluation, and many other developments in machine learning.
+
+Answer based only on the context provided in the documents. The answer should be technical and informative. Do not make things up.
+
+User Query: {query_str}
+Context: {context_str}
+Answer:
+"""
+
+
+class SimpleRAGPipeline(weave.Model):
+    chat_llm: str = "gpt-4"
+    embedding_model: str = "text-embedding-3-small"
+    temperature: float = 0.0
+    similarity_top_k: int = 2
+    chunk_size: int = 512
+    chunk_overlap: int = 128
+    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
+    query_engine: RetrieverQueryEngine = None
+
+    def _get_llm(self):
+        return OpenAI(
+            model=self.chat_llm,
+            temperature=self.temperature,
+            max_tokens=4096,
+        )
+
+    def _get_embedding_model(self):
+        return OpenAIEmbedding(model=self.embedding_model)
+
+    def _get_text_qa_template(self):
+        return PromptTemplate(self.prompt_template)
+
+    def _load_documents_and_chunk(self, documents: list):
+        # Split the markdown documents into nodes along header boundaries.
+        parser = MarkdownNodeParser()
+        nodes = parser.get_nodes_from_documents(documents)
+        return nodes
+
+    def _create_vector_index(self, nodes):
+        index = VectorStoreIndex(
+            nodes,
+            embed_model=self._get_embedding_model(),
+            show_progress=True,
+            insert_batch_size=128,
+        )
+
+        return index
+
+    def _get_retriever(self, index):
+        retriever = VectorIndexRetriever(
+            index=index,
+            similarity_top_k=self.similarity_top_k,
+        )
+        return retriever
+
+    def _get_response_synthesizer(self):
+        llm = self._get_llm()
+        response_synthesizer = get_response_synthesizer(
+            llm=llm,
+            response_mode="compact",
+            text_qa_template=self._get_text_qa_template(),
+            streaming=True,
+        )
+        return response_synthesizer
+
+    def build_query_engine(self):
+        nodes = self._load_documents_and_chunk(docs_files)
+        index = self._create_vector_index(nodes)
+        retriever = self._get_retriever(index)
+        response_synthesizer = self._get_response_synthesizer()
+
+        self.query_engine = RetrieverQueryEngine(
+            retriever=retriever,
+            response_synthesizer=response_synthesizer,
+        )
+
+    @weave.op()
+    def predict(self, question: str):
+        response = self.query_engine.query(question)
+        return response
+
+
+if __name__ == "__main__":
+    rag_pipeline = SimpleRAGPipeline()
+    rag_pipeline.build_query_engine()
+
+    response = rag_pipeline.predict("What is the Llama 3 model?")
+    # The synthesizer streams, so print by consuming the token generator.
+    response.print_response_stream()