alx-d committed
Commit 8d12b8e · verified · 1 Parent(s): 6f869ac

Upload folder using huggingface_hub

.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
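The final workflow step, `gradio deploy`, pushes the repository to the Space declared in README.md, which matches the commit message "Upload folder using huggingface_hub". For reference only, a rough Python equivalent using `huggingface_hub` directly; the Space id below is an illustrative assumption and is not part of this commit:

```python
# Illustrative sketch, not part of the workflow: upload the working directory
# to a Space using huggingface_hub instead of the `gradio deploy` CLI step.
import os
from huggingface_hub import HfApi, login

login(token=os.environ["HF_TOKEN"])   # same login call the workflow runs inline
api = HfApi()
api.upload_folder(
    folder_path=".",                  # upload the checked-out repository
    repo_id="alx-d/PhiRAG",           # hypothetical Space id
    repo_type="space",
)
```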
.gitignore ADDED
@@ -0,0 +1 @@
+ **/.DS_Store
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Andrew Nedilko
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,48 @@
- ---
- title: PhiRAG
- emoji: 😻
- colorFrom: indigo
- colorTo: blue
- sdk: gradio
- sdk_version: 5.17.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: PhiRAG
+ app_file: advanced_rag.py
+ sdk: gradio
+ sdk_version: 3.40.0
+ ---
+
+ # Advanced RAG System
+
+ This repository contains the code for a Gradio web app that demonstrates a Retrieval-Augmented Generation (RAG) system. The app lets users load multiple documents of their choice into a vector database, submit queries, and receive answers generated by a RAG pipeline that combines modern natural language processing and information retrieval techniques.
+
+ ## Features
+
+ #### 1. Dynamic Processing
+ - Users can load multiple source documents of their choice into a vector store in real time.
+ - Users can submit queries, which are processed in real time for retrieval and generation.
+
+ #### 2. PDF Integration
+ - The system can load multiple PDF documents into a vector store, enabling the RAG system to retrieve information from a large corpus.
+
+ #### 3. Advanced RAG System
+ The system integrates several components:
+ - **UI**: Lets users enter document URLs and queries, and displays the LLM response.
+ - **Document Loader**: Loads documents from URLs.
+ - **Text Splitter**: Chunks loaded documents.
+ - **Vector Store**: Embeds text chunks and adds them to a FAISS vector store; also embeds user queries.
+ - **Retrievers**: Uses an ensemble of BM25 and FAISS retrievers, along with a Cohere reranker, to retrieve relevant document chunks for user queries.
+ - **Language Model**: Uses a Llama 2 large language model to generate responses based on the user query and retrieved context.
+
+ #### 4. PDF and Query Error Handling
+ - Validates PDF URLs and queries, ensuring they are non-empty and well formed.
+ - Displays error messages for empty queries or issues with the RAG system.
+
+ #### 5. Refresh Mechanism
+ - Instructs users to refresh the page to clear/reset the RAG system.
+
+ ## Installation
+
+ To run this application, you need Python and Gradio installed. Follow these steps:
+
+ 1. Clone this repository to your local machine.
+ 2. Create and activate a virtual environment of your choice (venv, conda, etc.).
+ 3. Install dependencies from the requirements.txt file by running `pip install -r requirements.txt`.
+ 4. Set up the environment variables REPLICATE_API_TOKEN (for a Llama 2 model hosted on replicate.com) and COHERE_API_KEY (for the embedding and reranking services on cohere.com).
+ 5. Start the Gradio app by running `python app.py`.
+
+ ## License
+ MIT License
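The hybrid retrieval described under Feature 3 above pairs a lexical BM25 retriever with a semantic FAISS retriever. Below is a minimal sketch of that pairing using the same LangChain classes that advanced_rag.py imports; the sample documents, weights, and query are illustrative assumptions, and the Cohere reranker mentioned in the README is not shown:

```python
# Illustrative sketch (not part of this commit): BM25 + FAISS ensemble retrieval.
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

docs = [
    Document(page_content="FAISS is a library for efficient similarity search."),
    Document(page_content="BM25 is a classic lexical ranking function."),
]

bm25 = BM25Retriever.from_documents(docs)            # lexical matching
bm25.k = 5
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
faiss = FAISS.from_documents(docs, embeddings).as_retriever(search_kwargs={"k": 5})  # semantic matching

# Weighted fusion of both retrievers (0.6 lexical / 0.4 semantic is an assumption).
ensemble = EnsembleRetriever(retrievers=[bm25, faiss], weights=[0.6, 0.4])
results = ensemble.get_relevant_documents("What is FAISS?")
```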
__pycache__/advanced_rag.cpython-311.pyc ADDED
Binary file (19.4 kB).
 
advanced_rag.py ADDED
@@ -0,0 +1,541 @@
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ import datetime
+ import functools
+ import traceback
+ from typing import List, Optional, Any, Dict
+
+ import torch
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from langchain_community.llms import HuggingFacePipeline
+
+ # Other LangChain and community imports
+ from langchain_community.document_loaders import OnlinePDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain_community.retrievers import BM25Retriever
+ from langchain.retrievers import EnsembleRetriever
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema import StrOutputParser, Document
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
+ from transformers.quantizers.auto import AutoQuantizationConfig
+ import gradio as gr
+ import requests
+
+ # Add Mistral imports with fallback handling
+ try:
+     # Try importing from the latest package structure
+     from mistralai import Mistral
+     MISTRAL_AVAILABLE = True
+     debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
+     debug_print("Loaded latest Mistral client library")
+ except ImportError:
+     MISTRAL_AVAILABLE = False
+     debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
+     debug_print("Mistral client library not found. Install with: pip install mistralai")
+
+ # Debug print function (already defined above in the try block)
+ def debug_print(message: str):
+     print(f"[{datetime.datetime.now().isoformat()}] {message}")
+
+ def word_count(text: str) -> int:
+     return len(text.split())
+
+ # Initialize tokenizer for counting
+ def initialize_tokenizer():
+     try:
+         return AutoTokenizer.from_pretrained("gpt2")
+     except Exception as e:
+         debug_print("Failed to initialize tokenizer: " + str(e))
+         return None
+
+ global_tokenizer = initialize_tokenizer()
+
+ def count_tokens(text: str) -> int:
+     if global_tokenizer:
+         try:
+             return len(global_tokenizer.encode(text))
+         except Exception as e:
+             return len(text.split())
+     return len(text.split())
+
+ # Updated prompt template to include conversation history
+ default_prompt = """\
+ {conversation_history}
+ Use the following context to provide a detailed technical answer to the user's question.
+ Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
+ If you don't know the answer, please respond with "I don't know".
+
+ Context:
+ {context}
+
+ User's question:
+ {question}
+ """
+
+ # Helper function to load TXT files from URL with error checking
+ def load_txt_from_url(url: str) -> Document:
+     response = requests.get(url)
+     if response.status_code == 200:
+         text = response.text.strip()
+         if not text:
+             raise ValueError(f"TXT file at {url} is empty.")
+         return Document(page_content=text, metadata={"source": url})
+     else:
+         raise Exception(f"Failed to load {url} with status {response.status_code}")
+
+
+ class ElevatedRagChain:
+     def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
+                  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
+         debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
+
+         # Check for required API keys based on model choice
+         if "mistral-api" in llm_choice.lower() and not os.environ.get("MISTRAL_API_KEY"):
+             debug_print("WARNING: Mistral API selected but MISTRAL_API_KEY environment variable not set")
+         if not MISTRAL_AVAILABLE:
+             debug_print("WARNING: Mistral API package not installed. Install with: pip install mistralai")
+
+         self.embed_func = HuggingFaceEmbeddings(
+             model_name="sentence-transformers/all-MiniLM-L6-v2",
+             model_kwargs={"device": "cpu"}
+         )
+         self.bm25_weight = bm25_weight
+         self.faiss_weight = 1.0 - bm25_weight
+         self.top_k = 5
+         self.llm_choice = llm_choice
+         self.temperature = temperature
+         self.top_p = top_p
+         self.prompt_template = prompt_template
+         self.context = ""
+         self.conversation_history: List[Dict[str, str]] = []  # List of dicts with keys "query" and "response"
+
+     def create_llm_pipeline(self):
+         if "remote" in self.llm_choice.lower():
+             debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
+             from huggingface_hub import InferenceClient
+             repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+             hf_api_token = os.environ.get("HF_API_TOKEN")
+             if not hf_api_token:
+                 raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
+             client = InferenceClient(token=hf_api_token)
+
+             def remote_generate(prompt: str) -> str:
+                 response = client.text_generation(
+                     prompt,
+                     model=repo_id,
+                     # max_new_tokens=512,
+                     temperature=self.temperature,
+                     top_p=self.top_p,
+                     repetition_penalty=1.1
+                 )
+                 return response
+
+             from langchain.llms.base import LLM
+             class RemoteLLM(LLM):
+                 @property
+                 def _llm_type(self) -> str:
+                     return "remote_llm"
+                 def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+                     return remote_generate(prompt)
+                 @property
+                 def _identifying_params(self) -> dict:
+                     return {"model": repo_id}
+             debug_print("Remote Meta-Llama-3 pipeline created successfully.")
+             return RemoteLLM()
+         elif "mistral-api" in self.llm_choice.lower():
+             debug_print("Creating Mistral API pipeline...")
+
+             mistral_api_key = os.environ.get("MISTRAL_API_KEY")
+             if not mistral_api_key:
+                 raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
+
+             if not MISTRAL_AVAILABLE:
+                 raise ImportError("Mistral client library not installed. Install with: pip install mistralai")
+
+             # Initialize the Mistral client with latest API
+             mistral_client = Mistral(api_key=mistral_api_key)
+
+             # Define the model to use - updated to match current model names
+             mistral_model = "mistral-small-latest"
+
+             from langchain.llms.base import LLM
+             class MistralLLM(LLM):
+                 temperature: float = 0.7
+                 top_p: float = 0.95
+
+                 def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95):
+                     super().__init__()  # Important to call the parent constructor
+                     self.client = Mistral(api_key=api_key)
+                     self.temperature = temperature
+                     self.top_p = top_p
+
+                 @property
+                 def _llm_type(self) -> str:
+                     return "mistral_llm"
+
+                 def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+                     response = self.client.chat.complete(
+                         model="mistral-small-latest",  # Replace with the actual model name if different
+                         messages=[{"role": "user", "content": prompt}],
+                         temperature=self.temperature,
+                         top_p=self.top_p,
+                         max_tokens=512
+                     )
+                     return response.choices[0].message.content
+
+                 @property
+                 def _identifying_params(self) -> dict:
+                     return {"model": "mistral-small-latest"}
+
+             # Initialize and return the MistralLLM instance
+             mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
+             debug_print("Mistral API pipeline created successfully.")
+             return mistral_llm
+
+         else:
+             model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+             if "deepseek" in self.llm_choice.lower():
+                 model_id = "deepseek-ai/DeepSeek-R1"
+             elif "gemini" in self.llm_choice.lower():
+                 model_id = "gemini/flash-1.5"
+             elif "mistralai" in self.llm_choice.lower():
+                 model_id = "mistralai/Mistral-Small-24B-Instruct-2501"
+
+             pipe = pipeline(
+                 "text-generation",
+                 model=model_id,
+                 model_kwargs={"torch_dtype": torch.bfloat16},
+                 max_length=4096,
+                 do_sample=True,
+                 temperature=self.temperature,
+                 top_p=self.top_p,
+                 device=-1
+             )
+             return HuggingFacePipeline(pipeline=pipe)
+
+     def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
+         debug_print(f"Processing files using {self.llm_choice}")
+         self.raw_data = []
+         for link in file_links:
+             if link.lower().endswith(".pdf"):
+                 debug_print(f"Loading PDF: {link}")
+                 # Ensure that the PDF loader returns a non-empty list.
+                 loaded_docs = OnlinePDFLoader(link).load()
+                 if loaded_docs:
+                     self.raw_data.append(loaded_docs[0])
+                 else:
+                     debug_print(f"No content found in PDF: {link}")
+             elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
+                 debug_print(f"Loading TXT: {link}")
+                 try:
+                     self.raw_data.append(load_txt_from_url(link))
+                 except Exception as e:
+                     debug_print(f"Error loading TXT file {link}: {e}")
+             else:
+                 debug_print(f"File type not supported for URL: {link}")
+
+         if not self.raw_data:
+             raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
+
+         debug_print("Files loaded successfully.")
+
+         debug_print("Starting text splitting...")
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
+         self.split_data = self.text_splitter.split_documents(self.raw_data)
+         if not self.split_data:
+             raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
+         debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
+
+         debug_print("Creating BM25 retriever...")
+         self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
+         self.bm25_retriever.k = self.top_k
+         debug_print("BM25 retriever created.")
+
+         debug_print("Embedding chunks and creating FAISS vector store...")
+         self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
+         self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
+         debug_print("FAISS vector store created successfully.")
+
+         ensemble = EnsembleRetriever(
+             retrievers=[self.bm25_retriever, self.faiss_retriever],
+             weights=[self.bm25_weight, self.faiss_weight]
+         )
+
+         def capture_context(result):
+             # Convert each Document to a string and update the context.
+             self.context = "\n".join([str(doc) for doc in result["context"]])
+             result["context"] = self.context
+             # Add conversation_history from self.conversation_history (if any) as a string.
+             history_text = (
+                 "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
+                 if self.conversation_history else ""
+             )
+             result["conversation_history"] = history_text
+             return result
+
+         def extract_question(input_data):
+             # Expecting input_data to be a dict with a key "question"
+             return input_data["question"]
+
+         # Build the chain so that the ensemble (BM25 + FAISS) gets only the question string.
+         base_runnable = RunnableParallel({
+             "context": RunnableLambda(extract_question) | ensemble,
+             "question": RunnableLambda(extract_question)
+         }) | capture_context
+
+         self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
+         self.str_output_parser = StrOutputParser()
+         debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
+         self.llm = self.create_llm_pipeline()
+
+         def format_response(response: str) -> str:
+             input_tokens = count_tokens(self.context + self.prompt_template)
+             output_tokens = count_tokens(response)
+             # Format the response as Markdown for better visual rendering
+             formatted = f"### Response\n\n{response}\n\n---\n"
+             formatted += f"- **Input tokens:** {input_tokens}\n"
+             formatted += f"- **Output tokens:** {output_tokens}\n"
+             formatted += f"- **Generated using:** {self.llm_choice}\n"
+             # Append conversation history summary
+             formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
+             return formatted
+
+         self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
+         debug_print("Elevated RAG chain successfully built and ready to use.")
+
+     def get_current_context(self) -> str:
+         # Show a sample of the document context along with a summary of conversation history.
+         base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if hasattr(self, "split_data") and self.split_data else "No context available."
+         history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
+         recent = self.conversation_history[-3:]
+         if recent:
+             for i, conv in enumerate(recent, 1):
+                 history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
+         else:
+             history_summary += "No conversation history."
+         return base_context + history_summary
+
+ # ----------------------------
+ # Gradio Interface Functions
+ # ----------------------------
+ global rag_chain
+ rag_chain = ElevatedRagChain()
+
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
+     debug_print("Inside load_pdfs function.")
+     if not file_links:
+         debug_print("Please enter non-empty URLs")
+         return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
+     try:
+         links = [link.strip() for link in file_links.split("\n") if link.strip()]
+         global rag_chain
+         rag_chain = ElevatedRagChain(
+             llm_choice=model_choice,
+             prompt_template=prompt_template,
+             bm25_weight=bm25_weight,
+             temperature=temperature,
+             top_p=top_p
+         )
+         rag_chain.add_pdfs_to_vectore_store(links)
+         context_display = rag_chain.get_current_context()
+         response_msg = f"Files loaded successfully. Using model: {model_choice}"
+         debug_print(response_msg)
+         return (
+             response_msg,
+             f"Word count: {word_count(rag_chain.context)}",
+             f"Model used: {rag_chain.llm_choice}",
+             f"Context:\n{context_display}"
+         )
+     except Exception as e:
+         error_msg = traceback.format_exc()
+         debug_print("Could not load files. Error: " + error_msg)
+         return (
+             "Error loading files: " + str(e),
+             f"Word count: {word_count('')}",
+             f"Model used: {rag_chain.llm_choice}",
+             "Context: N/A"
+         )
+
+ def submit_query_updated(query):
+     debug_print("Inside submit_query function.")
+     if not query:
+         debug_print("Please enter a non-empty query")
+         return "Please enter a non-empty query", "Word count: 0", f"Model used: {rag_chain.llm_choice}", ""
+     if hasattr(rag_chain, 'elevated_rag_chain'):
+         try:
+             # Incorporate conversation history by joining previous Q&A pairs.
+             history_text = ""
+             if rag_chain.conversation_history:
+                 history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in rag_chain.conversation_history])
+
+             # Build the prompt variables dictionary for the chain.
+             prompt_variables = {
+                 "conversation_history": history_text,
+                 "context": rag_chain.context,
+                 "question": query
+             }
+
+             response = rag_chain.elevated_rag_chain.invoke(prompt_variables)
+             # Save the current conversation to history
+             rag_chain.conversation_history.append({"query": query, "response": response})
+             input_token_count = count_tokens(query)
+             output_token_count = count_tokens(response)
+             return (
+                 response,
+                 rag_chain.get_current_context(),
+                 f"Input tokens: {input_token_count}",
+                 f"Output tokens: {output_token_count}"
+             )
+         except Exception as e:
+             error_msg = traceback.format_exc()
+             debug_print("LLM error. Error: " + error_msg)
+             return (
+                 "Query error: " + str(e),
+                 "",
+                 "Input tokens: 0",
+                 "Output tokens: 0"
+             )
+     return (
+         "Please load files first.",
+         "",
+         "Input tokens: 0",
+         "Output tokens: 0"
+     )
+
+ def reset_app_updated():
+     global rag_chain
+     rag_chain = ElevatedRagChain()
+     debug_print("App reset successfully.")
+     return (
+         "App reset successfully. You can now load new files",
+         "",
+         "Model used: Not selected"
+     )
+
+ # ----------------------------
+ # Gradio Interface Setup
+ # ----------------------------
+ custom_css = """
+ button {
+     background-color: grey !important;
+     font-family: Arial !important;
+     font-weight: bold !important;
+     color: blue !important;
+ }
+ """
+
+ with gr.Blocks(css=custom_css) as app:
+     gr.Markdown('''# PhiRAG
+ **PhiRAG** Query Your Data with Advanced RAG Techniques
+
+ **Model Selection & Parameters:** Choose from the following options:
+ - 🇺🇸 Remote Meta-Llama-3
+ - 🇪🇺 Mistral-API
+
+ **🔥 Randomness (Temperature):** Temperature adjusts how predictable or varied the output is. A low temperature makes the model choose very predictable words (which can be repetitive), while a high temperature introduces more randomness for diverse, creative text.
+
+ **🎯 Word Variety (Top‑p):** Top‑p limits the model’s word choices to those that make up a set percentage (p) of the total probability. Lower values yield focused outputs; higher values increase variety and creativity.
+
+ **✏️ Prompt Template:** Edit the prompt template if desired.
+
+ **🔗 File URLs:** Enter one or more file URLs (PDF or TXT, one per line).
+
+ **⚖️ Weight Controls:** Adjust Lexical vs Semantics (BM25 Weight).
+
+ **🔍 Query:** Enter your query below.
+
+ The response displays the model used, word count, and the current context (including conversation history).
+ """
+ ''')
+     with gr.Row():
+         with gr.Column():
+             model_dropdown = gr.Dropdown(
+                 choices=[
+                     "🇺🇸 Remote Meta-Llama-3",
+                     "🇪🇺 Mistral-API"
+                     # "DeepSeek-R1",  # Option commented out
+                     # "Gemini Flash 1.5",  # Option commented out
+                     # "Mistralai/Mistral-Small-24B-Instruct-2501"  # Option commented out
+                 ],
+                 value="🇺🇸 Remote Meta-Llama-3",
+                 label="Select Model"
+             )
+             temperature_slider = gr.Slider(
+                 minimum=0.1, maximum=1.0, value=0.5, step=0.1,
+                 label="Randomness (Temperature)"
+             )
+             top_p_slider = gr.Slider(
+                 minimum=0.1, maximum=0.99, value=0.95, step=0.05,
+                 label="Word Variety (Top-p)"
+             )
+         with gr.Column():
+             pdf_input = gr.Textbox(
+                 label="Enter your file URLs (one per line)",
+                 placeholder="Enter one URL per line (.pdf or .txt)",
+                 lines=4
+             )
+             prompt_input = gr.Textbox(
+                 label="Custom Prompt Template",
+                 placeholder="Enter your custom prompt template here",
+                 lines=8,
+                 value=default_prompt
+             )
+         with gr.Column():
+             bm25_weight_slider = gr.Slider(
+                 minimum=0.0, maximum=1.0, value=0.6, step=0.1,
+                 label="Lexical vs Semantics (BM25 Weight)"
+             )
+             load_button = gr.Button("Load Files")
+
+     with gr.Row():
+         with gr.Column():
+             query_input = gr.Textbox(
+                 label="Enter your query here",
+                 placeholder="Type your query",
+                 lines=4
+             )
+             submit_button = gr.Button("Submit")
+         with gr.Column():
+             reset_button = gr.Button("Reset App")
+
+     with gr.Row():
+         response_output = gr.Textbox(
+             label="Response",
+             placeholder="Response will appear here (formatted as Markdown)",
+             lines=6
+         )
+         context_output = gr.Textbox(
+             label="Current Context",
+             placeholder="Retrieved context and conversation history will appear here",
+             lines=6
+         )
+
+     with gr.Row():
+         input_tokens = gr.Markdown("Input tokens: 0")
+         output_tokens = gr.Markdown("Output tokens: 0")
+         model_output = gr.Markdown("**Current Model**: Not selected")
+
+     load_button.click(
+         load_pdfs_updated,
+         inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
+         outputs=[response_output, context_output, model_output]
+     )
+
+     submit_button.click(
+         submit_query_updated,
+         inputs=[query_input],
+         outputs=[response_output, context_output, input_tokens, output_tokens]
+     )
+
+     reset_button.click(
+         reset_app_updated,
+         inputs=[],
+         outputs=[response_output, context_output, model_output]
+     )
+
+ if __name__ == "__main__":
+     debug_print("Launching Gradio interface.")
+     app.launch(share=True)
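Because advanced_rag.py exposes the pipeline as a plain class, it can also be driven without the Gradio UI. A minimal sketch, mirroring what load_pdfs_updated and submit_query_updated do above; the PDF URL is a placeholder assumption, and either HF_API_TOKEN or MISTRAL_API_KEY must be set as in the code:

```python
# Illustrative sketch (not part of this commit): use ElevatedRagChain directly.
from advanced_rag import ElevatedRagChain

chain = ElevatedRagChain(llm_choice="🇪🇺 Mistral-API", temperature=0.5, top_p=0.95)
chain.add_pdfs_to_vectore_store(["https://example.com/sample.pdf"])  # placeholder URL

answer = chain.elevated_rag_chain.invoke({
    "conversation_history": "",      # rebuilt by capture_context at run time
    "context": chain.context,
    "question": "What is this document about?",
})
print(answer)
```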
dropdown.py ADDED
@@ -0,0 +1,11 @@
+ import gradio as gr
+
+ def test_fn(x):
+     return x
+
+ with gr.Blocks() as demo:
+     dropdown = gr.Dropdown(choices=["Option 1", "Option 2"], value="Option 1", label="Select Option")
+     demo_button = gr.Button("Submit")
+     output = gr.Textbox(label="Output")
+     demo_button.click(test_fn, inputs=dropdown, outputs=output)
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,49 @@
+ gradio<=3.x
+ langchain==0.1.6
+ langchain-community==0.0.19
+ langchain_core==0.1.22
+ langchain-openai==0.0.5
+ faiss-cpu==1.7.3
+ huggingface-hub==0.20.1
+ google-generativeai==0.3.2
+ openai==1.11.1
+ opencv-python==4.9.0.80
+ pdf2image==1.17.0
+ pdfminer-six==20221105
+ pikepdf==8.12.0
+ pypdf==4.0.1
+ rank-bm25==0.2.2
+ replicate==0.23.1
+ tiktoken==0.5.2
+ unstructured==0.12.3
+ unstructured-pytesseract==0.3.12
+ unstructured-inference==0.7.23
+
+ # generated
+
+ # Transformers for the DeepSeek model and cross-encoder reranker
+ transformers>=4.34.0
+
+ # PyTorch required by DeepSeek and many Hugging Face models
+ torch>=2.0.0
+
+ # LangChain (the main package) – adjust the version if needed
+ langchain>=0.0.200
+
+ # LangChain Community components (for document loaders, vector stores, retrievers, etc.)
+ langchain-community
+
+ # LangChain Core components (for runnables, etc.)
+ langchain-core
+
+ # SentenceTransformers for embedding via HuggingFaceEmbeddings
+ sentence-transformers
+
+ # FAISS for vector storage and similarity search (CPU version)
+ faiss-cpu
+
+ # PDF parsing (e.g., used by OnlinePDFLoader)
+ pdfminer.six
+
+ # Pin Pydantic to a version < 2 (to avoid compatibility issues with LangChain)
+ pydantic<2