import json
import os
from typing import List, Dict, Optional
import uuid
import statistics
import time

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import Language
from langchain_core.embeddings import Embeddings
from litellm import embedding
import litellm
import tiktoken
from tqdm import tqdm
from langfuse import Langfuse

from mllm_tools.utils import _prepare_text_inputs
from task_generator import get_prompt_detect_plugins


class RAGVectorStore:
    """A class for managing vector stores for RAG (Retrieval-Augmented Generation).

    This class handles creation, loading, and querying of vector stores for both
    Manim core and plugin documentation.

    Args:
        chroma_db_path (str): Path to the ChromaDB storage directory
        manim_docs_path (str): Path to the Manim documentation files
        embedding_model (str): Name of the embedding model to use
        trace_id (str, optional): Trace identifier for logging. Defaults to None
        session_id (str, optional): Session identifier. Defaults to None
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        helper_model: Helper model for processing. Defaults to None
    """

    def __init__(self,
                 chroma_db_path: str = "chroma_db",
                 manim_docs_path: str = "rag/manim_docs",
                 embedding_model: str = "text-embedding-ada-002",
                 trace_id: Optional[str] = None,
                 session_id: Optional[str] = None,
                 use_langfuse: bool = True,
                 helper_model=None):
        self.chroma_db_path = chroma_db_path
        self.manim_docs_path = manim_docs_path
        self.embedding_model = embedding_model
        self.trace_id = trace_id
        self.session_id = session_id
        self.use_langfuse = use_langfuse
        self.helper_model = helper_model
        # Tokenizer used for the token-count statistics reported in logs.
        self.enc = tiktoken.encoding_for_model("gpt-4")
        self.plugin_stores = {}
        self.vector_store = self._load_or_create_vector_store()

    def _load_or_create_vector_store(self):
        """Loads existing or creates new ChromaDB vector stores.

        Creates/loads vector stores for both the Manim core documentation and any
        available plugins. Stores are persisted to disk for future reuse.

        Returns:
            Chroma: The core Manim vector store instance
        """
        print("Entering _load_or_create_vector_store with trace_id:", self.trace_id)
        core_path = os.path.join(self.chroma_db_path, "manim_core")

        # Load or create the core vector store
        if os.path.exists(core_path):
            print("Loading existing core ChromaDB...")
            self.core_vector_store = Chroma(
                collection_name="manim_core",
                persist_directory=core_path,
                embedding_function=self._get_embedding_function()
            )
        else:
            print("Creating new core ChromaDB...")
            self.core_vector_store = self._create_core_store()

        # Plugin docs live under <manim_docs_path>/plugin_docs/<plugin_name>
        plugin_docs_path = os.path.join(self.manim_docs_path, "plugin_docs")
        print(f"Plugin docs path: {plugin_docs_path}")
        if os.path.exists(plugin_docs_path):
            for plugin_name in os.listdir(plugin_docs_path):
                plugin_store_path = os.path.join(self.chroma_db_path, f"manim_plugin_{plugin_name}")
                if os.path.exists(plugin_store_path):
                    print(f"Loading existing plugin store: {plugin_name}")
                    self.plugin_stores[plugin_name] = Chroma(
                        collection_name=f"manim_plugin_{plugin_name}",
                        persist_directory=plugin_store_path,
                        embedding_function=self._get_embedding_function()
                    )
                else:
                    print(f"Creating new plugin store: {plugin_name}")
                    plugin_path = os.path.join(plugin_docs_path, plugin_name)
                    if os.path.isdir(plugin_path):
                        plugin_store = Chroma(
                            collection_name=f"manim_plugin_{plugin_name}",
                            embedding_function=self._get_embedding_function(),
                            persist_directory=plugin_store_path
                        )
                        plugin_docs = self._process_documentation_folder(plugin_path)
                        if plugin_docs:
                            self._add_documents_to_store(plugin_store, plugin_docs, plugin_name)
                        self.plugin_stores[plugin_name] = plugin_store

        return self.core_vector_store  # Return the core store for backward compatibility

    def _get_embedding_function(self) -> Embeddings:
        """Creates an embedding function using litellm.

        Returns:
            Embeddings: A LangChain Embeddings instance that wraps litellm functionality
        """
        class LiteLLMEmbeddings(Embeddings):
            def __init__(self, embedding_model):
                self.embedding_model = embedding_model

            def embed_documents(self, texts: list[str]) -> list[list[float]]:
                # Clear Langfuse callbacks for the raw embedding call, then restore them.
                litellm.success_callback = []
                litellm.failure_callback = []
                response = embedding(
                    model=self.embedding_model,
                    input=texts,
                    task_type="CODE_RETRIEVAL_QUERY" if self.embedding_model == "vertex_ai/text-embedding-005" else None
                )
                litellm.success_callback = ["langfuse"]
                litellm.failure_callback = ["langfuse"]
                return [r["embedding"] for r in response["data"]]

            def embed_query(self, text: str) -> list[float]:
                litellm.success_callback = []
                litellm.failure_callback = []
                response = embedding(
                    model=self.embedding_model,
                    input=[text],
                    task_type="CODE_RETRIEVAL_QUERY" if self.embedding_model == "vertex_ai/text-embedding-005" else None
                )
                litellm.success_callback = ["langfuse"]
                litellm.failure_callback = ["langfuse"]
                return response["data"][0]["embedding"]

        return LiteLLMEmbeddings(self.embedding_model)

    def _create_core_store(self):
        """Creates the main ChromaDB vector store for Manim core documentation.

        Returns:
            Chroma: The initialized and populated core vector store
        """
        core_vector_store = Chroma(
            collection_name="manim_core",
            embedding_function=self._get_embedding_function(),
            persist_directory=os.path.join(self.chroma_db_path, "manim_core")
        )
        # Process the Manim core docs
        core_docs = self._process_documentation_folder(os.path.join(self.manim_docs_path, "manim_core"))
        if core_docs:
            self._add_documents_to_store(core_vector_store, core_docs, "manim_core")
        return core_vector_store

    def _process_documentation_folder(self, folder_path: str) -> List[Document]:
        """Processes documentation files from a folder into LangChain documents.

        Args:
            folder_path (str): Path to the folder containing documentation files

        Returns:
            List[Document]: List of processed LangChain documents
        """
        all_docs = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                if file.endswith(('.md', '.py')):
                    file_path = os.path.join(root, file)
                    try:
                        loader = TextLoader(file_path)
                        documents = loader.load()
                        for doc in documents:
                            doc.metadata['source'] = file_path
                        all_docs.extend(documents)
                    except Exception as e:
                        print(f"Error loading file {file_path}: {e}")

        if not all_docs:
            print(f"No markdown or python files found in {folder_path}")
            return []

        # Split documents using the language-appropriate splitter, prefixing each
        # chunk with its source path so retrieval results stay attributable.
        markdown_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.MARKDOWN
        )
        python_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.PYTHON
        )
        split_docs = []
        for doc in all_docs:
            source = doc.metadata['source']
            splitter = markdown_splitter if source.endswith('.md') else python_splitter
            temp_docs = splitter.split_documents([doc])
            for temp_doc in temp_docs:
                temp_doc.page_content = f"Source: {source}\n\n{temp_doc.page_content}"
            split_docs.extend(temp_docs)
        return split_docs

    def _add_documents_to_store(self, vector_store: Chroma, documents: List[Document], store_name: str):
        """Adds documents to a vector store in batches with rate limiting.

        Args:
            vector_store (Chroma): The vector store to add documents to
            documents (List[Document]): List of documents to add
            store_name (str): Name of the store for logging purposes
        """
        print(f"Adding documents to {store_name} store")

        # Calculate token statistics
        token_lengths = [len(self.enc.encode(doc.page_content)) for doc in documents]
        print(f"Token length statistics for {store_name}: "
              f"Min: {min(token_lengths)}, Max: {max(token_lengths)}, "
              f"Mean: {sum(token_lengths) / len(token_lengths):.1f}, "
              f"Median: {statistics.median(token_lengths)}, "
              f"Std: {statistics.stdev(token_lengths):.1f}")

        batch_size = 10
        request_count = 0
        for i in tqdm(range(0, len(documents), batch_size), desc=f"Processing {store_name} batches"):
            batch_docs = documents[i:i + batch_size]
            batch_ids = [str(uuid.uuid4()) for _ in batch_docs]
            vector_store.add_documents(documents=batch_docs, ids=batch_ids)
            request_count += 1
            if request_count % 100 == 0:
                time.sleep(60)  # Pause for 60 seconds after every 100 batches to respect rate limits
        vector_store.persist()

    def find_relevant_docs(self, queries: List[Dict], k: int = 5, trace_id: Optional[str] = None, topic: Optional[str] = None, scene_number: Optional[int] = None) -> str:
        """Finds relevant documentation based on the provided queries.

        Args:
            queries (List[Dict]): List of query dictionaries with 'type' and 'query' keys
            k (int, optional): Number of results to return per query. Defaults to 5
            trace_id (str, optional): Trace identifier for logging. Defaults to None
            topic (str, optional): Topic name for logging. Defaults to None
            scene_number (int, optional): Scene number for logging. Defaults to None

        Returns:
            str: Formatted string containing relevant documentation snippets
        """
        manim_core_formatted_results = []
        manim_plugin_formatted_results = []

        # Create a Langfuse span if enabled; keep it None otherwise so the
        # search path below works without Langfuse.
        span = None
        if self.use_langfuse:
            langfuse = Langfuse()
            span = langfuse.span(
                trace_id=trace_id,  # Use the passed trace_id
                name=f"RAG search for {topic} - scene {scene_number}",
                metadata={
                    "topic": topic,
                    "scene_number": scene_number,
                    "session_id": self.session_id
                }
            )

        # Separate queries by type
        manim_core_queries = [query for query in queries if query["type"] == "manim-core"]
        manim_plugin_queries = [query for query in queries if query["type"] != "manim-core" and query["type"] in self.plugin_stores]
        if len([q for q in queries if q["type"] != "manim-core"]) != len(manim_plugin_queries):
            print("Warning: Some plugin queries were skipped because their types weren't found in the available plugin stores")

        # Search the core Manim docs
        for query in manim_core_queries:
            query_text = query["query"]
            if span is not None:
                self.core_vector_store._embedding_function.parent_observation_id = span.id
            manim_core_results = self.core_vector_store.similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in manim_core_results:
                manim_core_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })

        # Search the relevant plugin docs (queries were already filtered to loaded stores)
        for query in manim_plugin_queries:
            plugin_name = query["type"]
            query_text = query["query"]
            if span is not None:
                self.plugin_stores[plugin_name]._embedding_function.parent_observation_id = span.id
            plugin_results = self.plugin_stores[plugin_name].similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in plugin_results:
                manim_plugin_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })
| print(f"Number of results before removing duplicates: {len(manim_core_formatted_results) + len(manim_plugin_formatted_results)}") | |
| # Remove duplicates based on content | |
| manim_core_unique_results = [] | |
| manim_plugin_unique_results = [] | |
| seen = set() | |
| for item in manim_core_formatted_results: | |
| key = item['content'] | |
| if key not in seen: | |
| manim_core_unique_results.append(item) | |
| seen.add(key) | |
| for item in manim_plugin_formatted_results: | |
| key = item['content'] | |
| if key not in seen: | |
| manim_plugin_unique_results.append(item) | |
| seen.add(key) | |
| print(f"Number of results after removing duplicates: {len(manim_core_unique_results) + len(manim_plugin_unique_results)}") | |
| total_tokens = sum(len(self.enc.encode(res['content'])) for res in manim_core_unique_results + manim_plugin_unique_results) | |
| print(f"Total tokens for the RAG search: {total_tokens}") | |
| # Update Langfuse with the deduplicated results | |
| if self.use_langfuse: | |
| filtered_results_markdown = json.dumps(manim_core_unique_results + manim_plugin_unique_results, indent=2) | |
| span.update( # Use span.update, not span.end | |
| output=filtered_results_markdown, | |
| metadata={ | |
| "total_tokens": total_tokens, | |
| "initial_results_count": len(manim_core_formatted_results) + len(manim_plugin_formatted_results), | |
| "filtered_results_count": len(manim_core_unique_results) + len(manim_plugin_unique_results) | |
| } | |
| ) | |
| manim_core_results = "Please refer to the following Manim core documentation that may be helpful for the code generation:\n\n" + "\n\n".join([f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}" for res in manim_core_unique_results]) | |
| manim_plugin_results = "Please refer to the following Manim plugin documentation that may be helpful for the code generation:\n\n" + "\n\n".join([f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}" for res in manim_plugin_unique_results]) | |
| return manim_core_results + "\n\n" + manim_plugin_results |
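

# Minimal usage sketch, assuming the chroma_db/ and rag/manim_docs/ paths exist
# and are populated, and that litellm has credentials for the chosen embedding
# model. The query string below is hypothetical.
if __name__ == "__main__":
    store = RAGVectorStore(
        chroma_db_path="chroma_db",
        manim_docs_path="rag/manim_docs",
        embedding_model="text-embedding-ada-002",
        use_langfuse=False,  # skip Langfuse spans for a quick local test
    )
    context = store.find_relevant_docs(
        queries=[{"type": "manim-core", "query": "How do I animate a circle transforming into a square?"}],
        k=3,
    )
    print(context)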