Upload folder using huggingface_hub
Browse files
- advanced_rag.py +283 -90
- psyllm.py +0 -0
- requirements.txt +3 -0
advanced_rag.py
CHANGED

@@ -21,6 +21,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain_community.retrievers import BM25Retriever
+from langchain.embeddings.base import Embeddings
 from langchain.retrievers import EnsembleRetriever
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import StrOutputParser, Document

@@ -269,6 +270,51 @@ def count_tokens(text: str) -> int:
     return len(text.split())
     return len(text.split())

+# Add NebiusEmbedding class for Nebius platform embedding models
+class NebiusEmbedding(Embeddings):
+    """Custom embedding class for Nebius platform models"""
+
+    def __init__(self, model_name: str, api_key: str = None):
+        super().__init__()
+        self.model_name = model_name
+        self.api_key = api_key or os.environ.get("NEBIUS_API_KEY")
+
+        if not self.api_key:
+            raise ValueError("Please set the NEBIUS_API_KEY environment variable to use Nebius embedding models.")
+
+        try:
+            from openai import OpenAI
+            self.client = OpenAI(
+                base_url="https://api.studio.nebius.com/v1/",
+                api_key=self.api_key
+            )
+        except ImportError:
+            raise ImportError("openai package is required for Nebius embedding models.")
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents"""
+        try:
+            response = self.client.embeddings.create(
+                model=self.model_name,
+                input=texts
+            )
+            return [data.embedding for data in response.data]
+        except Exception as e:
+            debug_print(f"Error embedding documents with Nebius: {str(e)}")
+            raise e
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a single query"""
+        try:
+            response = self.client.embeddings.create(
+                model=self.model_name,
+                input=[text]
+            )
+            return response.data[0].embedding
+        except Exception as e:
+            debug_print(f"Error embedding query with Nebius: {str(e)}")
+            raise e
+

 # Add these imports at the top of your file
 import uuid
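A minimal usage sketch of the new embedding class (not part of the commit): it assumes the `openai` package is installed, `NEBIUS_API_KEY` is set, and the class above is in scope; the model name is one of the Nebius options this commit wires into the UI.

```python
# Sketch: using NebiusEmbedding directly (requires NEBIUS_API_KEY and the openai package).
emb = NebiusEmbedding(model_name="Qwen/Qwen3-Embedding-8B")
vectors = emb.embed_documents(["PhiRAG indexes PDFs.", "BM25 and FAISS retrievers are combined."])
query_vec = emb.embed_query("How are documents retrieved?")
print(len(vectors), len(query_vec))  # 2 document vectors, plus the query vector's dimension
```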
@@ -299,13 +345,11 @@ def process_in_background(job_id, function, args):
         debug_print(error_msg)
         results_queue.put((job_id, (error_msg, None, "", "Input tokens: 0", "Output tokens: 0")))

-def load_pdfs_async(file_links,
+def load_pdfs_async(file_links, prompt_template, bm25_weight, embedding_model):
     """Asynchronous version of load_pdfs_updated to prevent timeouts"""
     global last_job_id
     if not file_links:
-        return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""
-    global slider_max_tokens
-    slider_max_tokens = max_tokens_slider
+        return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""


     job_id = str(uuid.uuid4())
@@ -314,7 +358,7 @@ load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
     # Start background thread
     threading.Thread(
         target=process_in_background,
-        args=(job_id, load_pdfs_updated, [file_links,
+        args=(job_id, load_pdfs_updated, [file_links, prompt_template, bm25_weight, embedding_model])
     ).start()

     job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
@@ -333,7 +377,7 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
         f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
         f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
-        f"
+        f"Embedding model: {embedding_model}",
         job_id,  # Return job_id to update the job_id_input component
         job_query,  # Return job_query to update the job_query_display component
         get_job_list(),  # Return updated job list
@@ -343,7 +387,20 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
 def submit_query_async(query, model_choice, max_tokens_slider, temperature, top_p, top_k, bm25_weight, use_history):
     """Submit a query asynchronously"""
     try:
-
+        if not query:
+            return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
+
+        # Update BM25 weight and recreate ensemble retriever if needed
+        if hasattr(rag_chain, 'bm25_weight') and rag_chain.bm25_weight != bm25_weight:
+            rag_chain.bm25_weight = bm25_weight
+            rag_chain.faiss_weight = 1.0 - bm25_weight
+            rag_chain.ensemble_retriever = EnsembleRetriever(
+                retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
+                weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
+            )
+            debug_print(f"Updated ensemble retriever with BM25 weight: {bm25_weight}")
+
+        # Clear conversation history if checkbox is unchecked
         if not use_history:
             rag_chain.conversation_history = []
             debug_print("Conversation history cleared")
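The added block above rebuilds the ensemble retriever whenever the BM25 slider changes. For reference, a self-contained sketch of the same BM25 + FAISS weighting (not from the commit; it assumes the app's retrieval dependencies such as rank_bm25 and faiss are installed, and the toy documents are made up):

```python
# Sketch: BM25 (lexical) + FAISS (semantic) retrieval combined with adjustable weights.
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever

docs = [
    Document(page_content="BM25 scores exact term overlap."),
    Document(page_content="FAISS finds semantically similar passages."),
]
bm25_weight = 0.6  # same default as the app's slider
bm25 = BM25Retriever.from_documents(docs)
faiss = FAISS.from_documents(
    docs, HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
).as_retriever()
ensemble = EnsembleRetriever(retrievers=[bm25, faiss], weights=[bm25_weight, 1.0 - bm25_weight])
print(ensemble.get_relevant_documents("semantic similarity search"))
```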
@@ -726,12 +783,11 @@ def load_file_from_google_drive(link: str) -> list:

 class ElevatedRagChain:
     def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
-                 bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50
+                 bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50,
+                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
         debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
-        self.
-
-            model_kwargs={"device": "cpu"}
-        )
+        self.embedding_model = embedding_model
+        self.embed_func = self._create_embedding_function(embedding_model)
         self.bm25_weight = bm25_weight
         self.faiss_weight = 1.0 - bm25_weight
         self.top_k = top_k
@@ -745,6 +801,57 @@ class ElevatedRagChain:
         self.split_data = None
         self.elevated_rag_chain = None

+    def _create_embedding_function(self, embedding_model: str):
+        """Create the appropriate embedding function based on the model choice"""
+        debug_print(f"Creating embedding function for: {embedding_model}")
+
+        # Map display names to actual model names
+        model_mapping = {
+            # sentence-transformers Models (Free)
+            "🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)": "sentence-transformers/all-MiniLM-L6-v2",
+            "🤗 sentence-transformers/all-mpnet-base-v2 (768 dim, high-quality)": "sentence-transformers/all-mpnet-base-v2",
+            "🤗 sentence-transformers/all-distilroberta-v1 (768 dim, balanced)": "sentence-transformers/all-distilroberta-v1",
+            "🤗 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dim, multilingual)": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+            "🤗 sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (768 dim, multilingual)": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+
+            # HuggingFace Models (Free)
+            "🤗 BAAI/bge-small-en-v1.5 (384 dim, efficient)": "BAAI/bge-small-en-v1.5",
+            "🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)": "BAAI/bge-base-en-v1.5",
+            "🤗 BAAI/bge-large-en-v1.5 (1024 dim, powerful)": "BAAI/bge-large-en-v1.5",
+            "🤗 intfloat/e5-base-v2 (768 dim, general-purpose)": "intfloat/e5-base-v2",
+            "🤗 intfloat/e5-large-v2 (1024 dim, advanced)": "intfloat/e5-large-v2",
+
+            # Nebius Models (Cost)
+            "π¦ Qwen/Qwen3-Embedding-8B (1024 dim, advanced)": "Qwen/Qwen3-Embedding-8B",
+            "π¦ BAAI/bge-en-icl (1024 dim, instruction-tuned)": "BAAI/bge-en-icl",
+            "π¦ BAAI/bge-multilingual-gemma2 (1024 dim, multilingual)": "BAAI/bge-multilingual-gemma2"
+        }
+
+        # Get the actual model name
+        actual_model = model_mapping.get(embedding_model, embedding_model)
+
+        # Check if it's a Nebius model
+        if any(nebius_model in actual_model for nebius_model in [
+            "Qwen/Qwen3-Embedding-8B",
+            "BAAI/bge-en-icl",
+            "BAAI/bge-multilingual-gemma2"
+        ]):
+            try:
+                return NebiusEmbedding(model_name=actual_model)
+            except Exception as e:
+                debug_print(f"Failed to create Nebius embedding: {e}")
+                debug_print("Falling back to default HuggingFace embedding")
+                return HuggingFaceEmbeddings(
+                    model_name="sentence-transformers/all-MiniLM-L6-v2",
+                    model_kwargs={"device": "cpu"}
+                )
+        else:
+            # Default to HuggingFace embeddings for all other models
+            return HuggingFaceEmbeddings(
+                model_name=actual_model,
+                model_kwargs={"device": "cpu"}
+            )
+
     # Instance method to capture context and conversation history
     def capture_context(self, result):
         self.context = "\n".join([str(doc) for doc in result["context"]])
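A short sketch of how a dropdown label flows through the mapping above (not part of the commit; it assumes the module's globals such as default_prompt and debug_print are in scope so ElevatedRagChain can be constructed on its own):

```python
# Sketch: a dropdown label resolves to a local HuggingFaceEmbeddings backend.
chain = ElevatedRagChain(embedding_model="🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)")
print(chain.embedding_model)                 # stores the display label as passed in
vec = chain.embed_func.embed_query("test sentence")
print(len(vec))                              # 768 dimensions for bge-base-en-v1.5
```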
@@ -761,11 +868,10 @@ class ElevatedRagChain:
         return input_data["question"]

     # Improve error handling in the ElevatedRagChain class
-    def create_llm_pipeline(self):
+    def create_llm_pipeline(self, max_tokens_override=None):
         from langchain.llms.base import LLM  # Import LLM here so it's always defined
         from typing import Optional, List, Any
         from pydantic import PrivateAttr
-        global slider_max_tokens

         # Extract the model name without the flag emoji prefix
         clean_llm_choice = self.llm_choice.split(" ", 1)[-1] if " " in self.llm_choice else self.llm_choice
@@ -818,7 +924,8 @@ class ElevatedRagChain:
             raise ValueError(f"Unsupported model: {normalized}")
         model = model_map[model_key]
         max_tokens = model_token_limits.get(model, 4096)
-
+        if max_tokens_override is not None:
+            max_tokens = min(max_tokens_override, max_tokens)
         pricing_info = model_pricing.get(model_key, {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}})

         try:
@@ -1145,7 +1252,7 @@ class ElevatedRagChain:
 global rag_chain
 rag_chain = ElevatedRagChain()

-def load_pdfs_updated(file_links,
+def load_pdfs_updated(file_links, prompt_template, bm25_weight, embedding_model):
     debug_print("Inside load_pdfs function.")
     if not file_links:
         debug_print("Please enter non-empty URLs")
@@ -1154,31 +1261,35 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
         links = [link.strip() for link in file_links.split("\n") if link.strip()]
         global rag_chain
         if rag_chain.raw_data:
-
+            # Files already loaded, just update parameters
+            rag_chain.prompt_template = prompt_template
+            rag_chain.bm25_weight = bm25_weight
+            rag_chain.faiss_weight = 1.0 - bm25_weight
             context_display = rag_chain.get_current_context()
-            response_msg = f"Files already loaded.
+            response_msg = f"Files already loaded. Parameters updated."
             return (
                 response_msg,
                 f"Word count: {word_count(rag_chain.context)}",
-                f"
+                f"Embedding model: {rag_chain.embedding_model}",
                 f"Context:\n{context_display}"
             )
         else:
             rag_chain = ElevatedRagChain(
-                llm_choice=
+                llm_choice="Mistral-API",  # Default LLM choice
                 prompt_template=prompt_template,
                 bm25_weight=bm25_weight,
-                temperature=
-                top_p=
-                top_k=
+                temperature=0.5,  # Default values
+                top_p=0.95,
+                top_k=50,
+                embedding_model=embedding_model
             )
             rag_chain.add_pdfs_to_vectore_store(links)
             context_display = rag_chain.get_current_context()
-            response_msg = f"Files loaded successfully. Using model: {
+            response_msg = f"Files loaded successfully. Using embedding model: {embedding_model}"
             return (
                 response_msg,
                 f"Word count: {word_count(rag_chain.context)}",
-                f"
+                f"Embedding model: {rag_chain.embedding_model}",
                 f"Context:\n{context_display}"
             )
     except Exception as e:
@@ -1209,6 +1320,16 @@ def submit_query_updated(query, temperature, top_p, top_k, bm25_weight, use_hist
     if not query:
         return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"

+    # Update BM25 weight and recreate ensemble retriever if needed
+    if hasattr(rag_chain, 'bm25_weight') and rag_chain.bm25_weight != bm25_weight:
+        rag_chain.bm25_weight = bm25_weight
+        rag_chain.faiss_weight = 1.0 - bm25_weight
+        rag_chain.ensemble_retriever = EnsembleRetriever(
+            retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
+            weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
+        )
+        debug_print(f"Updated ensemble retriever with BM25 weight: {bm25_weight}")
+
     # Clear conversation history if checkbox is unchecked
     if not use_history:
         rag_chain.conversation_history = []
@@ -1388,7 +1509,13 @@ document.addEventListener('DOMContentLoaded', function() {
     gr.Markdown('''# PhiRAG - Async Version
     **PhiRAG** Query Your Data with Advanced RAG Techniques

-    **
+    **Embedding Models:** Choose from the following options:
+    - 🤗 **HuggingFace Models (Free)**: sentence-transformers, BAAI, intfloat models
+    - π¦ **Nebius Models (Cost)**: Qwen, BAAI models via Nebius platform
+    - **Dimensions**: 384 (fast), 768 (balanced), 1024 (powerful)
+    - **Languages**: English-focused and multilingual options available
+
+    **LLM Models:** Choose from the following options in the Query tabs:
     - 🇺🇸 Remote Meta-Llama-3 - has context windows of 8000 tokens
     - 🇪🇺 Mistral-API - has context windows of 32000 tokens

@@ -1412,50 +1539,48 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
     **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
     - When you load files or submit a query, you'll receive a Job ID
     - Use the "Check Job Status" tab to monitor and retrieve your results
+
+    **π API Keys Required:**
+    - For Nebius embedding models: Set the NEBIUS_API_KEY environment variable
+    - For OpenAI models: Set the OPENAI_API_KEY environment variable
+    - For Mistral models: Set the MISTRAL_API_KEY environment variable
+    - For HuggingFace models: Set the HF_API_TOKEN environment variable
     ''')

     with gr.Tabs() as tabs:
         with gr.TabItem("Setup & Load Files"):
             with gr.Row():
-                with gr.Column():
-                    model_dropdown = gr.Dropdown(
-                        choices=[
-                            "🇺🇸 GPT-3.5",
-                            "🇺🇸 GPT-4o",
-                            "🇺🇸 GPT-4o mini",
-                            "🇺🇸 o1-mini",
-                            "🇺🇸 o3-mini",
-                            "🇺🇸 Remote Meta-Llama-3",
-                            "🇪🇺 Mistral-API",
-                        ],
-                        value="🇪🇺 Mistral-API",
-                        label="Select Model"
-                    )
-                    temperature_slider = gr.Slider(
-                        minimum=0.1, maximum=1.0, value=0.5, step=0.1,
-                        label="Randomness (Temperature)"
-                    )
-                    top_p_slider = gr.Slider(
-                        minimum=0.1, maximum=0.99, value=0.95, step=0.05,
-                        label="Word Variety (Top-p)"
-                    )
-                    top_k_slider = gr.Slider(
-                        minimum=1, maximum=100, value=50, step=1,
-                        label="Token Selection (Top-k)"
-                    )
-                with gr.Column():
+                with gr.Column(scale=2):  # Expanded to take more space
                     pdf_input = gr.Textbox(
                         label="Enter your file URLs (one per line)",
                         placeholder="Enter one URL per line (.pdf or .txt)",
                         lines=4
                     )
-
-
-
-
-
+                with gr.Column(scale=1):  # Smaller column for controls
+                    embedding_dropdown = gr.Dropdown(
+                        choices=[
+                            # sentence-transformers Models (Free)
+                            "🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
+                            "🤗 sentence-transformers/all-mpnet-base-v2 (768 dim, high-quality)",
+                            "🤗 sentence-transformers/all-distilroberta-v1 (768 dim, balanced)",
+                            "🤗 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dim, multilingual)",
+                            "🤗 sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (768 dim, multilingual)",
+
+                            # HuggingFace Models (Free)
+                            "🤗 BAAI/bge-small-en-v1.5 (384 dim, efficient)",
+                            "🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)",
+                            "🤗 BAAI/bge-large-en-v1.5 (1024 dim, powerful)",
+                            "🤗 intfloat/e5-base-v2 (768 dim, general-purpose)",
+                            "🤗 intfloat/e5-large-v2 (1024 dim, advanced)",
+
+                            # Nebius Models (Cost)
+                            "π¦ Qwen/Qwen3-Embedding-8B (1024 dim, advanced)",
+                            "π¦ BAAI/bge-en-icl (1024 dim, instruction-tuned)",
+                            "π¦ BAAI/bge-multilingual-gemma2 (1024 dim, multilingual)",
+                        ],
+                        value="🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
+                        label="Select Embedding Model (🤗 = HuggingFace free, π¦ = Nebius cost)"
                     )
-                with gr.Column():
                     bm25_weight_slider = gr.Slider(
                         minimum=0.0, maximum=1.0, value=0.6, step=0.1,
                         label="Lexical vs Semantics (BM25 Weight)"
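For orientation, a stripped-down sketch of the Setup & Load wiring this hunk introduces: a URL textbox, the new embedding dropdown, and a button whose handler receives both values. The handler and the shortened choice list here are illustrative, not the app's own:

```python
# Sketch: minimal Gradio layout mirroring the new Setup & Load controls.
import gradio as gr

def fake_load(urls, embedding_model):
    # toy handler: the real app calls load_pdfs_async here
    return f"Would index {len(urls.splitlines())} URL(s) with {embedding_model}"

with gr.Blocks() as demo:
    urls = gr.Textbox(label="Enter your file URLs (one per line)", lines=4)
    emb = gr.Dropdown(
        choices=[
            "🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
            "🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)",
        ],
        value="🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
        label="Select Embedding Model",
    )
    out = gr.Markdown()
    gr.Button("Load").click(fake_load, inputs=[urls, emb], outputs=[out])

demo.launch()
```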
@@ -1477,6 +1602,56 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8

             with gr.Row():
                 model_output = gr.Markdown("**Current Model**: Not selected")
+
+            # Job Status Section for Setup & Load
+            with gr.Row():
+                with gr.Column(scale=1):
+                    setup_job_list = gr.Markdown(
+                        value="No jobs yet",
+                        label="Job List (Click to select)"
+                    )
+                    setup_refresh_button = gr.Button("Refresh Job List")
+                    setup_auto_refresh_checkbox = gr.Checkbox(
+                        label="Enable Auto Refresh",
+                        value=False
+                    )
+                    setup_df = gr.DataFrame(
+                        value=[],  # Empty initial value
+                        headers=["Number", "Square"],
+                        label="Query Results",
+                        visible=False
+                    )
+
+                with gr.Column(scale=2):
+                    setup_job_id_input = gr.Textbox(
+                        label="Job ID",
+                        placeholder="Job ID will appear here when selected from the list",
+                        lines=1
+                    )
+                    setup_job_query_display = gr.Textbox(
+                        label="Job Query",
+                        placeholder="The query associated with this job will appear here",
+                        lines=2,
+                        interactive=False
+                    )
+                    setup_check_button = gr.Button("Check Status")
+                    setup_cleanup_button = gr.Button("Cleanup Old Jobs")
+
+            with gr.Row():
+                setup_status_response = gr.Textbox(
+                    label="Job Result",
+                    placeholder="Job result will appear here",
+                    lines=6
+                )
+                setup_status_context = gr.Textbox(
+                    label="Context Information",
+                    placeholder="Context information will appear here",
+                    lines=6
+                )
+
+            with gr.Row():
+                setup_status_tokens1 = gr.Markdown("")
+                setup_status_tokens2 = gr.Markdown("")

             with gr.TabItem("Submit Query", elem_classes=["query-tab"]):
                 with gr.Row():
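The new job panel above plugs into the app's existing async pattern: each submission gets a UUID, runs in a background thread, and pushes its result onto a queue that the status widgets poll. A minimal sketch of that pattern outside Gradio, using simplified stand-ins for the app's process_in_background / check_job_status rather than their actual implementations:

```python
# Sketch: background job submission and polling via a results queue.
import queue, threading, uuid

results_queue = queue.Queue()

def process_in_background(job_id, function, args):
    # run the work and park the result for later pickup
    results_queue.put((job_id, function(*args)))

def submit(function, args):
    job_id = str(uuid.uuid4())
    threading.Thread(target=process_in_background, args=(job_id, function, args)).start()
    return job_id

def check_job_status(job_id, _done={}):
    # drain any finished jobs into a lookup table, then answer for this id
    while not results_queue.empty():
        finished_id, result = results_queue.get()
        _done[finished_id] = result
    return _done.get(job_id, "Processing...")

jid = submit(lambda x: x * x, (12,))
print(jid, check_job_status(jid))  # usually "Processing..." until the worker thread finishes
```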
@@ -1754,11 +1929,18 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
     # Add initialization info display
     init_info = gr.Markdown("")

-    # Update load_button click to include
+    # Update load_button click to include embedding model
     load_button.click(
-        load_pdfs_async,
-        inputs=[pdf_input,
-        outputs=[load_response, load_context, model_output,
+        lambda file_links, bm25_weight, embedding_model: load_pdfs_async(file_links, default_prompt, bm25_weight, embedding_model),
+        inputs=[pdf_input, bm25_weight_slider, embedding_dropdown],
+        outputs=[load_response, load_context, model_output, setup_job_id_input, setup_job_query_display, setup_job_list, init_info]
+    )
+
+    # Also update Setup & Load job list when files are loaded
+    load_button.click(
+        fn=lambda *args: get_job_list(),
+        inputs=[],
+        outputs=[setup_job_list]
     )

     # Add function to sync job IDs between tabs
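The first click handler above wraps load_pdfs_async in a lambda so the fixed default_prompt is bound without its own UI control. The same adapter pattern in isolation (toy names, not the app's):

```python
# Sketch: binding a fixed argument into an event handler with a lambda.
def load(urls, prompt_template, weight):
    return f"{len(urls.splitlines())} URL(s), prompt={prompt_template!r}, weight={weight}"

DEFAULT_PROMPT = "answer from the retrieved context"   # stand-in for the app's default_prompt
handler = lambda urls, weight: load(urls, DEFAULT_PROMPT, weight)
print(handler("http://a.pdf\nhttp://b.pdf", 0.6))
```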
@@ -1785,30 +1967,14 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
     )

-    #
-    def
-        return
+    # Sync BM25 weight between Setup & Load and Query tabs
+    def sync_bm25_weight(bm25_weight):
+        return bm25_weight

-    # Sync parameters between tabs
-    temperature_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
-    top_p_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
-    top_k_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
     bm25_weight_slider.change(
-        fn=
-        inputs=[
-        outputs=[
+        fn=sync_bm25_weight,
+        inputs=[bm25_weight_slider],
+        outputs=[query_bm25_weight_slider]
     )

     # Connect the buttons to their respective functions
@@ -1844,11 +2010,6 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[reset_response, reset_context, reset_model]
     )

-    model_dropdown.change(
-        fn=sync_model_dropdown,
-        inputs=model_dropdown,
-        outputs=query_model_dropdown
-    )

     # Add an event to refresh the job list on page load
     app.load(
@@ -1857,6 +2018,38 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=job_list
     )

+    # Setup & Load Job Status Event Handlers
+    setup_check_button.click(
+        check_job_status,
+        inputs=[setup_job_id_input],
+        outputs=[setup_status_response, setup_status_context, setup_status_tokens1, setup_status_tokens2, setup_job_query_display]
+    )
+
+    setup_refresh_button.click(
+        refresh_job_list,
+        inputs=[],
+        outputs=[setup_job_list]
+    )
+
+    setup_job_id_input.change(
+        job_selected,
+        inputs=[setup_job_id_input],
+        outputs=[setup_job_id_input, setup_job_query_display]
+    )
+
+    setup_cleanup_button.click(
+        cleanup_old_jobs,
+        inputs=[],
+        outputs=[setup_status_response, setup_status_context, setup_status_tokens1]
+    )
+
+    setup_auto_refresh_checkbox.change(
+        fn=periodic_update,
+        inputs=[setup_auto_refresh_checkbox],
+        outputs=[setup_job_list, setup_status_response, setup_df, setup_status_context],
+        every=2
+    )
+
     # Use the Checkbox to control the periodic updates
     auto_refresh_checkbox.change(
         fn=periodic_update,
psyllm.py
CHANGED
The diff for this file is too large to render. See raw diff.
requirements.txt
CHANGED
@@ -47,3 +47,6 @@ pydantic==2.9.0
 sentence-transformers>=2.4.0

 mistralai==1.5.0
+
+matplotlib>=3.0.0
+networkx>=2.0