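"""RAG engine for PDF question answering.

Indexes uploaded PDFs (loaded with PyPDFLoader, chunked with
RecursiveCharacterTextSplitter) into a Chroma vector store, and answers
questions with a local Qwen-1.8B-Chat model, streaming tokens via
TextIteratorStreamer.
"""
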
import os
import shutil
from queue import Queue, Empty
from threading import Thread

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain_core.callbacks import BaseCallbackHandler
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM, pipeline

from book_title_extractor import BookTitleExtractor
from duplicate_detector import DuplicateDetector


class StreamingHandler(BaseCallbackHandler):
    """Collects streamed tokens and forwards each one to an optional callback."""

    def __init__(self):
        self.buffer = []
        self.token_callback = None

    def on_llm_new_token(self, token: str, **kwargs):
        # Invoked by LangChain for every new token from a streaming LLM.
        self.buffer.append(token)
        if self.token_callback:
            self.token_callback(token)


class RagEngine:
    def __init__(self, persist_dir="chroma_store", embed_model="nomic-embed-text",
                 llm_model="qwen:1.8b", temp_dir="chroma_temp"):
        # Note: embed_model and llm_model are currently unused; the embedding
        # and chat models below are hardcoded.
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.duplicate_detector = DuplicateDetector()
        self.title_extractor = BookTitleExtractor()
        self.embedding = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vectorstore = None
        self.retriever = None
        # index_pdf persists the store into temp_dir, so reload from there.
        self.persist_dir = self.temp_dir
        self._load_vectorstore()
        self.model_id = "Qwen/Qwen-1_8B-Chat"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="auto",
        )
        self.model.eval()
        # State expected by ask_question / ask_question_stream (see the
        # build_qa_chain sketch below).
        self.qa_chain = None
        self.handler = StreamingHandler()
        self._latest_result = {}

    def _load_vectorstore(self):
        # Reload a previously persisted Chroma store from disk, if present.
        if os.path.exists(os.path.join(self.persist_dir, "chroma.sqlite3")):
            self.vectorstore = Chroma(
                persist_directory=self.persist_dir,
                embedding_function=self.embedding,
            )
            self.retriever = self.vectorstore.as_retriever()

    def clear_temp(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.makedirs(self.temp_dir, exist_ok=True)

    def index_pdf(self, pdf_path):
        if self.duplicate_detector.is_duplicate(pdf_path):
            raise ValueError(f"Duplicate book detected, skipping index of: {pdf_path}")
        self.duplicate_detector.store_fingerprints(pdf_path)
        self.clear_temp()
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        # Tag every page with the extracted book title so source citations are readable.
        title = self.title_extractor.extract_book_title_from_documents(documents, max_docs=10)
        for doc in documents:
            doc.metadata["source"] = title
        documents = [doc for doc in documents if doc.page_content.strip()]
        if not documents:
            raise ValueError("No readable text in uploaded PDF")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
        chunks = splitter.split_documents(documents)
        if self.vectorstore is None:
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embedding,
                persist_directory=self.temp_dir,
            )
        else:
            self.vectorstore.add_documents(chunks)
        self.vectorstore.persist()
        self.retriever = self.vectorstore.as_retriever()

    def stream_answer(self, question):
        """Answer a question with the local Qwen model, yielding tokens as they arrive."""
        if not self.retriever:
            yield "data: ⚠️ Please upload and index a PDF first.\n\n"
            return
        docs = self.retriever.get_relevant_documents(question)
        if not docs:
            yield "data: ⚠️ No relevant documents found.\n\n"
            return
        sources = []
        for doc in docs:
            title = doc.metadata.get("source", "Unknown Title")
            page = doc.metadata.get("page", "Unknown Page")
            sources.append(f"{title} - Page {page}")
        context = "\n\n".join(doc.page_content for doc in docs[:3])
        # Qwen chat format: system / user / assistant turns delimited by im_start / im_end.
        system_prompt = "You are a helpful assistant that only replies in English."
        user_prompt = f"Context:\n{context}\n\nQuestion: {question}"
        prompt = (
            f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        print("Prompt token length:", inputs["input_ids"].shape[-1])
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_args = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 512,
            "streamer": streamer,
            "do_sample": False,  # greedy decoding; temperature/top_p are irrelevant here
        }
        # Run generation on a background thread so tokens can be yielded as they stream in.
        thread = Thread(target=self.model.generate, kwargs=generation_args)
        thread.start()
        for token in streamer:
            if token.strip():  # skip whitespace-only chunks
                yield f"{token} "
        if sources:
            sources_text = "\n\n📚 **Sources:**\n" + "\n".join(set(sources))
            for line in sources_text.splitlines():
                if line.strip():
                    yield f"{line} \n"
        yield "\n\n"
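
    # ask_question and ask_question_stream below expect self.qa_chain, which is
    # never constructed in this file. The following is a minimal sketch of one
    # plausible wiring using the RetrievalQA and HuggingFacePipeline imports
    # above; the exact original construction is an assumption.
    def build_qa_chain(self):
        if self.retriever is None:
            raise ValueError("Index a PDF before building the QA chain")
        hf_pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=HuggingFacePipeline(pipeline=hf_pipe),
            retriever=self.retriever,
            return_source_documents=True,
        )
        # Token-by-token streaming through StreamingHandler.on_llm_new_token
        # additionally requires an LLM that emits streaming callbacks.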

    def ask_question(self, question):
        if not self.qa_chain:
            return "Please upload and index a PDF document first"
        result = self.qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = []
        for doc in result["source_documents"]:
            sources.append(doc.metadata.get("source", "Unknown"))
        return {
            "answer": answer,
            "sources": list(set(sources)),  # remove duplicates
        }

    def ask_question_stream(self, question: str):
        if not self.qa_chain:
            yield "⚠️ Please upload and index a PDF document first."
            return
        q = Queue()
        # Tokens streamed through the handler are forwarded into the queue.
        self.handler.buffer = []
        self.handler.token_callback = q.put

        def run():
            result = self.qa_chain.invoke({"query": question})
            self._latest_result = result
            q.put(None)  # sentinel: generation finished

        Thread(target=run, daemon=True).start()
        while True:
            try:
                token = q.get(timeout=30)
                if token is None:
                    break
                yield token
            except Empty:
                print("Timed out waiting for token", flush=True)
                break
        sources = []
        for doc in self._latest_result.get("source_documents", []):
            sources.append(doc.metadata.get("source", "Unknown"))
        if sources:
            yield "\n\n📚 **Sources:**\n"
            for i, src in enumerate(set(sources)):
                yield f"[{i+1}] {src}\n"