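"""A simple RAG pipeline over the Llama 3 405B research paper.

Chunks the pre-parsed paper with LlamaIndex, indexes the chunks with OpenAI embeddings,
and answers questions with GPT-4o, citing page numbers. Calls are traced with Weave.
Assumes an OpenAI API key is available via the .env file loaded below.
"""
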
from dotenv import load_dotenv

load_dotenv()

import pickle

import weave
from llama_index.core import PromptTemplate, VectorStoreIndex, get_response_synthesizer
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

# Load the pre-parsed paper documents and tag each one with a 1-based page number
# so answers can cite the page they came from.
data_dir = "data/raw_docs/documents.pkl"
with open(data_dir, "rb") as file:
    docs_files = pickle.load(file)

for i, doc in enumerate(docs_files, start=1):
    doc.metadata["page"] = i
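
# System prompt: constrains the model to answer only from the retrieved snippets and to
# cite snippet section titles and page numbers as Markdown footnotes.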
SYSTEM_PROMPT_TEMPLATE = """
Answer the following question about the newly released Llama 3 405 billion parameter model based on the provided snippets from the research paper.
Provide helpful, complete, and accurate answers to the question using only the information contained in these snippets.

Here are the relevant snippets from the Llama 3 405B model research paper:
<snippets>
{context_str}
</snippets>

To answer the question:
1. Carefully read and analyze the provided snippets.
2. Identify information that is directly relevant to the user's question.
3. Formulate a comprehensive answer based solely on the information in the snippets.
4. Do not include any information or claims that are not supported by the provided snippets.

Guidelines for your answer:
1. Be technical and informative, providing as much detail as the snippets allow.
2. If the snippets do not contain enough information to fully answer the question, state this clearly and provide what information you can based on the available snippets.
3. Do not make up or infer information beyond what is explicitly stated in the snippets.
4. If the question cannot be answered at all based on the provided snippets, state this clearly and explain why.
5. Use appropriate technical language and terminology as used in the snippets.
6. Cite the relevant sentences from the snippets and their page numbers to support your answer.
7. Answer in MFAQ format (Minimal Facts Answerable Question), providing the most concise and accurate response possible.
8. Use Markdown to format your response and include citation footnotes to indicate the snippets and the page number used to derive your answer.
9. Your answer must always contain footnotes citing the snippets used to derive the answer.

Here's an example of a question and an answer. You must use this as a template to format your response:
<example>
<question>
What was the main mix of the training data? How much data was used to train the model?
</question>

## Answer

The main mix of the training data for the Llama 3 405 billion parameter model is as follows:

- **General knowledge**: 50%
- **Mathematical and reasoning tokens**: 25%
- **Code tokens**: 17%
- **Multilingual tokens**: 8%[^1^].

Regarding the amount of data used to train the model, the snippets do not provide a specific total volume of data in terms of tokens or bytes. However, they do mention that the model was pre-trained on a large dataset containing knowledge until the end of 2023[^2^]. Additionally, the training process involved pre-training on 2.87 trillion tokens before further adjustments[^3^].

[^1^]: "Scaling Laws for Data Mix," page 6.
[^2^]: "Pre-Training Data," page 4.
[^3^]: "Initial Pre-Training," page 14.
</example>

Remember, your role is to accurately convey the information from the research paper snippets, not to speculate or provide information from other sources.

<question>
{query_str}
</question>

Answer:
"""
class SimpleRAGPipeline(weave.Model):
    chat_llm: str = "gpt-4o"
    embedding_model: str = "text-embedding-3-small"
    temperature: float = 0.1
    similarity_top_k: int = 15
    chunk_size: int = 512
    chunk_overlap: int = 128
    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
    query_engine: RetrieverQueryEngine = None

    def _get_llm(self):
        return OpenAI(
            model=self.chat_llm,
            temperature=self.temperature,
            max_tokens=4096,
        )

    def _get_embedding_model(self):
        return OpenAIEmbedding(model=self.embedding_model)

    def _get_text_qa_template(self):
        return PromptTemplate(self.prompt_template)

    def _load_documents_and_chunk(self, documents: list):
        parser = MarkdownNodeParser()
        nodes = parser.get_nodes_from_documents(documents)
        return nodes

    def _create_vector_index(self, nodes):
        index = VectorStoreIndex(
            nodes,
            embed_model=self._get_embedding_model(),
            show_progress=True,
            insert_batch_size=512,
        )
        return index

    def _get_retriever(self, index):
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=self.similarity_top_k,
        )
        return retriever

    def _get_response_synthesizer(self):
        llm = self._get_llm()
        response_synthesizer = get_response_synthesizer(
            llm=llm,
            response_mode="compact",
            text_qa_template=self._get_text_qa_template(),
            streaming=True,
        )
        return response_synthesizer

    def build_query_engine(self):
        nodes = self._load_documents_and_chunk(docs_files)
        index = self._create_vector_index(nodes)
        retriever = self._get_retriever(index)
        response_synthesizer = self._get_response_synthesizer()
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
        )
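
    # weave.get_current_call() resolves only while a traced op is executing (i.e. after
    # weave.init(...) has been called), hence the @weave.op() decorator below.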
    @weave.op()
    def predict(self, question: str):
        response = self.query_engine.query(question)
        return {
            "response": response,
            "call_id": weave.get_current_call().id,
            "url": weave.get_current_call().ui_url,
        }
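

# Example run: build the index once, then stream the answer tokens for a sample question.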
if __name__ == "__main__":
    weave.init("llama3-paper-rag")  # placeholder project name; replace with your own W&B project
    rag_pipeline = SimpleRAGPipeline()
    rag_pipeline.build_query_engine()
    result = rag_pipeline.predict(
        "How does the model perform in comparison to the GPT-4 model?"
    )
    # predict() returns a dict; the streaming response object is stored under "response".
    for token in result["response"].response_gen:
        print(token, end="")