alx-d committed
Commit b256930 · verified · 1 Parent(s): b4842b9

Upload folder using huggingface_hub

Files changed (3)
  1. advanced_rag.py +283 -90
  2. psyllm.py +0 -0
  3. requirements.txt +3 -0
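
The commit message above indicates the folder was pushed with huggingface_hub. For orientation, a minimal sketch of how such an upload is typically produced; the repo id, repo type, and local path below are placeholders, not values taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi()  # token taken from HF_TOKEN or the local `huggingface-cli login` cache
    api.upload_folder(
        folder_path=".",                        # local folder to push (placeholder)
        repo_id="alx-d/<space-name>",           # placeholder repo id
        repo_type="space",                      # assumption: this repository is a Space
        commit_message="Upload folder using huggingface_hub",
    )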
advanced_rag.py CHANGED
@@ -21,6 +21,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain_community.retrievers import BM25Retriever
+from langchain.embeddings.base import Embeddings
 from langchain.retrievers import EnsembleRetriever
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import StrOutputParser, Document
@@ -269,6 +270,51 @@ def count_tokens(text: str) -> int:
         return len(text.split())
     return len(text.split())
 
+# Add NebiusEmbedding class for Nebius platform embedding models
+class NebiusEmbedding(Embeddings):
+    """Custom embedding class for Nebius platform models"""
+
+    def __init__(self, model_name: str, api_key: str = None):
+        super().__init__()
+        self.model_name = model_name
+        self.api_key = api_key or os.environ.get("NEBIUS_API_KEY")
+
+        if not self.api_key:
+            raise ValueError("Please set the NEBIUS_API_KEY environment variable to use Nebius embedding models.")
+
+        try:
+            from openai import OpenAI
+            self.client = OpenAI(
+                base_url="https://api.studio.nebius.com/v1/",
+                api_key=self.api_key
+            )
+        except ImportError:
+            raise ImportError("openai package is required for Nebius embedding models.")
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents"""
+        try:
+            response = self.client.embeddings.create(
+                model=self.model_name,
+                input=texts
+            )
+            return [data.embedding for data in response.data]
+        except Exception as e:
+            debug_print(f"Error embedding documents with Nebius: {str(e)}")
+            raise e
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a single query"""
+        try:
+            response = self.client.embeddings.create(
+                model=self.model_name,
+                input=[text]
+            )
+            return response.data[0].embedding
+        except Exception as e:
+            debug_print(f"Error embedding query with Nebius: {str(e)}")
+            raise e
+
 
 # Add these imports at the top of your file
 import uuid
@@ -299,13 +345,11 @@ def process_in_background(job_id, function, args):
         debug_print(error_msg)
         results_queue.put((job_id, (error_msg, None, "", "Input tokens: 0", "Output tokens: 0")))
 
-def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k, max_tokens_slider):
+def load_pdfs_async(file_links, prompt_template, bm25_weight, embedding_model):
     """Asynchronous version of load_pdfs_updated to prevent timeouts"""
     global last_job_id
     if not file_links:
-        return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""
-    global slider_max_tokens
-    slider_max_tokens = max_tokens_slider
+        return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""
 
 
     job_id = str(uuid.uuid4())
@@ -314,7 +358,7 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
     # Start background thread
     threading.Thread(
         target=process_in_background,
-        args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k])
+        args=(job_id, load_pdfs_updated, [file_links, prompt_template, bm25_weight, embedding_model])
     ).start()
 
     job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
@@ -333,7 +377,7 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
         f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
         f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
-        f"Model requested: {model_choice}",
+        f"Embedding model: {embedding_model}",
         job_id, # Return job_id to update the job_id_input component
         job_query, # Return job_query to update the job_query_display component
         get_job_list(), # Return updated job list
@@ -343,7 +387,20 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
 def submit_query_async(query, model_choice, max_tokens_slider, temperature, top_p, top_k, bm25_weight, use_history):
     """Submit a query asynchronously"""
     try:
-        # ... existing code ...
+        if not query:
+            return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
+
+        # Update BM25 weight and recreate ensemble retriever if needed
+        if hasattr(rag_chain, 'bm25_weight') and rag_chain.bm25_weight != bm25_weight:
+            rag_chain.bm25_weight = bm25_weight
+            rag_chain.faiss_weight = 1.0 - bm25_weight
+            rag_chain.ensemble_retriever = EnsembleRetriever(
+                retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
+                weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
+            )
+            debug_print(f"Updated ensemble retriever with BM25 weight: {bm25_weight}")
+
+        # Clear conversation history if checkbox is unchecked
         if not use_history:
             rag_chain.conversation_history = []
             debug_print("Conversation history cleared")
@@ -726,12 +783,11 @@ def load_file_from_google_drive(link: str) -> list:
 
 class ElevatedRagChain:
     def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
-                 bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50) -> None:
+                 bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50,
+                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
         debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
-        self.embed_func = HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-MiniLM-L6-v2",
-            model_kwargs={"device": "cpu"}
-        )
+        self.embedding_model = embedding_model
+        self.embed_func = self._create_embedding_function(embedding_model)
         self.bm25_weight = bm25_weight
         self.faiss_weight = 1.0 - bm25_weight
         self.top_k = top_k
@@ -745,6 +801,57 @@ class ElevatedRagChain:
         self.split_data = None
         self.elevated_rag_chain = None
 
+    def _create_embedding_function(self, embedding_model: str):
+        """Create the appropriate embedding function based on the model choice"""
+        debug_print(f"Creating embedding function for: {embedding_model}")
+
+        # Map display names to actual model names
+        model_mapping = {
+            # sentence-transformers Models (Free)
+            "🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)": "sentence-transformers/all-MiniLM-L6-v2",
+            "🤗 sentence-transformers/all-mpnet-base-v2 (768 dim, high-quality)": "sentence-transformers/all-mpnet-base-v2",
+            "🤗 sentence-transformers/all-distilroberta-v1 (768 dim, balanced)": "sentence-transformers/all-distilroberta-v1",
+            "🤗 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dim, multilingual)": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+            "🤗 sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (768 dim, multilingual)": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+
+            # HuggingFace Models (Free)
+            "🤗 BAAI/bge-small-en-v1.5 (384 dim, efficient)": "BAAI/bge-small-en-v1.5",
+            "🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)": "BAAI/bge-base-en-v1.5",
+            "🤗 BAAI/bge-large-en-v1.5 (1024 dim, powerful)": "BAAI/bge-large-en-v1.5",
+            "🤗 intfloat/e5-base-v2 (768 dim, general-purpose)": "intfloat/e5-base-v2",
+            "🤗 intfloat/e5-large-v2 (1024 dim, advanced)": "intfloat/e5-large-v2",
+
+            # Nebius Models (Cost)
+            "🟦 Qwen/Qwen3-Embedding-8B (1024 dim, advanced)": "Qwen/Qwen3-Embedding-8B",
+            "🟦 BAAI/bge-en-icl (1024 dim, instruction-tuned)": "BAAI/bge-en-icl",
+            "🟦 BAAI/bge-multilingual-gemma2 (1024 dim, multilingual)": "BAAI/bge-multilingual-gemma2"
+        }
+
+        # Get the actual model name
+        actual_model = model_mapping.get(embedding_model, embedding_model)
+
+        # Check if it's a Nebius model
+        if any(nebius_model in actual_model for nebius_model in [
+            "Qwen/Qwen3-Embedding-8B",
+            "BAAI/bge-en-icl",
+            "BAAI/bge-multilingual-gemma2"
+        ]):
+            try:
+                return NebiusEmbedding(model_name=actual_model)
+            except Exception as e:
+                debug_print(f"Failed to create Nebius embedding: {e}")
+                debug_print("Falling back to default HuggingFace embedding")
+                return HuggingFaceEmbeddings(
+                    model_name="sentence-transformers/all-MiniLM-L6-v2",
+                    model_kwargs={"device": "cpu"}
+                )
+        else:
+            # Default to HuggingFace embeddings for all other models
+            return HuggingFaceEmbeddings(
+                model_name=actual_model,
+                model_kwargs={"device": "cpu"}
+            )
+
     # Instance method to capture context and conversation history
     def capture_context(self, result):
         self.context = "\n".join([str(doc) for doc in result["context"]])
@@ -761,11 +868,10 @@ class ElevatedRagChain:
         return input_data["question"]
 
     # Improve error handling in the ElevatedRagChain class
-    def create_llm_pipeline(self):
+    def create_llm_pipeline(self, max_tokens_override=None):
         from langchain.llms.base import LLM # Import LLM here so it's always defined
         from typing import Optional, List, Any
         from pydantic import PrivateAttr
-        global slider_max_tokens
 
         # Extract the model name without the flag emoji prefix
        clean_llm_choice = self.llm_choice.split(" ", 1)[-1] if " " in self.llm_choice else self.llm_choice
@@ -818,7 +924,8 @@ class ElevatedRagChain:
             raise ValueError(f"Unsupported model: {normalized}")
         model = model_map[model_key]
         max_tokens = model_token_limits.get(model, 4096)
-        max_tokens = min(slider_max_tokens, max_tokens)
+        if max_tokens_override is not None:
+            max_tokens = min(max_tokens_override, max_tokens)
         pricing_info = model_pricing.get(model_key, {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}})
 
         try:
@@ -1145,7 +1252,7 @@ class ElevatedRagChain:
 global rag_chain
 rag_chain = ElevatedRagChain()
 
-def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k):
+def load_pdfs_updated(file_links, prompt_template, bm25_weight, embedding_model):
     debug_print("Inside load_pdfs function.")
     if not file_links:
         debug_print("Please enter non-empty URLs")
@@ -1154,31 +1261,35 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
         links = [link.strip() for link in file_links.split("\n") if link.strip()]
         global rag_chain
         if rag_chain.raw_data:
-            rag_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k, prompt_template, bm25_weight)
+            # Files already loaded, just update parameters
+            rag_chain.prompt_template = prompt_template
+            rag_chain.bm25_weight = bm25_weight
+            rag_chain.faiss_weight = 1.0 - bm25_weight
             context_display = rag_chain.get_current_context()
-            response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
+            response_msg = f"Files already loaded. Parameters updated."
             return (
                 response_msg,
                 f"Word count: {word_count(rag_chain.context)}",
-                f"Model used: {rag_chain.llm_choice}",
+                f"Embedding model: {rag_chain.embedding_model}",
                 f"Context:\n{context_display}"
             )
         else:
             rag_chain = ElevatedRagChain(
-                llm_choice=model_choice,
+                llm_choice="Mistral-API", # Default LLM choice
                 prompt_template=prompt_template,
                 bm25_weight=bm25_weight,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k
+                temperature=0.5, # Default values
+                top_p=0.95,
+                top_k=50,
+                embedding_model=embedding_model
            )
            rag_chain.add_pdfs_to_vectore_store(links)
            context_display = rag_chain.get_current_context()
-            response_msg = f"Files loaded successfully. Using model: {model_choice}"
+            response_msg = f"Files loaded successfully. Using embedding model: {embedding_model}"
            return (
                response_msg,
                f"Word count: {word_count(rag_chain.context)}",
-                f"Model used: {rag_chain.llm_choice}",
+                f"Embedding model: {rag_chain.embedding_model}",
                f"Context:\n{context_display}"
            )
     except Exception as e:
@@ -1209,6 +1320,16 @@ def submit_query_updated(query, temperature, top_p, top_k, bm25_weight, use_hist
     if not query:
         return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
 
+    # Update BM25 weight and recreate ensemble retriever if needed
+    if hasattr(rag_chain, 'bm25_weight') and rag_chain.bm25_weight != bm25_weight:
+        rag_chain.bm25_weight = bm25_weight
+        rag_chain.faiss_weight = 1.0 - bm25_weight
+        rag_chain.ensemble_retriever = EnsembleRetriever(
+            retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
+            weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
+        )
+        debug_print(f"Updated ensemble retriever with BM25 weight: {bm25_weight}")
+
     # Clear conversation history if checkbox is unchecked
     if not use_history:
         rag_chain.conversation_history = []
@@ -1388,7 +1509,13 @@ document.addEventListener('DOMContentLoaded', function() {
     gr.Markdown('''# PhiRAG - Async Version
 **PhiRAG** Query Your Data with Advanced RAG Techniques
 
-**Model Selection & Parameters:** Choose from the following options:
+**Embedding Models:** Choose from the following options:
+- 🤗 **HuggingFace Models (Free)**: sentence-transformers, BAAI, intfloat models
+- 🟦 **Nebius Models (Cost)**: Qwen, BAAI models via Nebius platform
+- **Dimensions**: 384 (fast), 768 (balanced), 1024 (powerful)
+- **Languages**: English-focused and multilingual options available
+
+**LLM Models:** Choose from the following options in the Query tabs:
 - 🇺🇸 Remote Meta-Llama-3 - has context windows of 8000 tokens
 - 🇪🇺 Mistral-API - has context windows of 32000 tokens
 
@@ -1412,50 +1539,48 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
 **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
 - When you load files or submit a query, you'll receive a Job ID
 - Use the "Check Job Status" tab to monitor and retrieve your results
+
+**🔑 API Keys Required:**
+- For Nebius embedding models: Set the NEBIUS_API_KEY environment variable
+- For OpenAI models: Set the OPENAI_API_KEY environment variable
+- For Mistral models: Set the MISTRAL_API_KEY environment variable
+- For HuggingFace models: Set the HF_API_TOKEN environment variable
     ''')
 
     with gr.Tabs() as tabs:
         with gr.TabItem("Setup & Load Files"):
             with gr.Row():
-                with gr.Column():
-                    model_dropdown = gr.Dropdown(
-                        choices=[
-                            "🇺🇸 GPT-3.5",
-                            "🇺🇸 GPT-4o",
-                            "🇺🇸 GPT-4o mini",
-                            "🇺🇸 o1-mini",
-                            "🇺🇸 o3-mini",
-                            "🇺🇸 Remote Meta-Llama-3",
-                            "🇪🇺 Mistral-API",
-                        ],
-                        value="🇪🇺 Mistral-API",
-                        label="Select Model"
-                    )
-                    temperature_slider = gr.Slider(
-                        minimum=0.1, maximum=1.0, value=0.5, step=0.1,
-                        label="Randomness (Temperature)"
-                    )
-                    top_p_slider = gr.Slider(
-                        minimum=0.1, maximum=0.99, value=0.95, step=0.05,
-                        label="Word Variety (Top-p)"
-                    )
-                    top_k_slider = gr.Slider(
-                        minimum=1, maximum=100, value=50, step=1,
-                        label="Token Selection (Top-k)"
-                    )
-                with gr.Column():
+                with gr.Column(scale=2): # Expanded to take more space
                     pdf_input = gr.Textbox(
                         label="Enter your file URLs (one per line)",
                         placeholder="Enter one URL per line (.pdf or .txt)",
                         lines=4
                     )
-                    prompt_input = gr.Textbox(
-                        label="Custom Prompt Template",
-                        placeholder="Enter your custom prompt template here",
-                        lines=8,
-                        value=default_prompt
+                with gr.Column(scale=1): # Smaller column for controls
+                    embedding_dropdown = gr.Dropdown(
+                        choices=[
+                            # sentence-transformers Models (Free)
+                            "🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
+                            "🤗 sentence-transformers/all-mpnet-base-v2 (768 dim, high-quality)",
+                            "🤗 sentence-transformers/all-distilroberta-v1 (768 dim, balanced)",
+                            "🤗 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dim, multilingual)",
+                            "🤗 sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (768 dim, multilingual)",
+
+                            # HuggingFace Models (Free)
+                            "🤗 BAAI/bge-small-en-v1.5 (384 dim, efficient)",
+                            "🤗 BAAI/bge-base-en-v1.5 (768 dim, excellent)",
+                            "🤗 BAAI/bge-large-en-v1.5 (1024 dim, powerful)",
+                            "🤗 intfloat/e5-base-v2 (768 dim, general-purpose)",
+                            "🤗 intfloat/e5-large-v2 (1024 dim, advanced)",
+
+                            # Nebius Models (Cost)
+                            "🟦 Qwen/Qwen3-Embedding-8B (1024 dim, advanced)",
+                            "🟦 BAAI/bge-en-icl (1024 dim, instruction-tuned)",
+                            "🟦 BAAI/bge-multilingual-gemma2 (1024 dim, multilingual)",
+                        ],
+                        value="🤗 sentence-transformers/all-MiniLM-L6-v2 (384 dim, fast)",
+                        label="Select Embedding Model (🤗 = HuggingFace free, 🟦 = Nebius cost)"
                     )
-                with gr.Column():
                     bm25_weight_slider = gr.Slider(
                         minimum=0.0, maximum=1.0, value=0.6, step=0.1,
                         label="Lexical vs Semantics (BM25 Weight)"
@@ -1477,6 +1602,56 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
 
             with gr.Row():
                 model_output = gr.Markdown("**Current Model**: Not selected")
+
+            # Job Status Section for Setup & Load
+            with gr.Row():
+                with gr.Column(scale=1):
+                    setup_job_list = gr.Markdown(
+                        value="No jobs yet",
+                        label="Job List (Click to select)"
+                    )
+                    setup_refresh_button = gr.Button("Refresh Job List")
+                    setup_auto_refresh_checkbox = gr.Checkbox(
+                        label="Enable Auto Refresh",
+                        value=False
+                    )
+                    setup_df = gr.DataFrame(
+                        value=[], # Empty initial value
+                        headers=["Number", "Square"],
+                        label="Query Results",
+                        visible=False
+                    )
+
+                with gr.Column(scale=2):
+                    setup_job_id_input = gr.Textbox(
+                        label="Job ID",
+                        placeholder="Job ID will appear here when selected from the list",
+                        lines=1
+                    )
+                    setup_job_query_display = gr.Textbox(
+                        label="Job Query",
+                        placeholder="The query associated with this job will appear here",
+                        lines=2,
+                        interactive=False
+                    )
+                    setup_check_button = gr.Button("Check Status")
+                    setup_cleanup_button = gr.Button("Cleanup Old Jobs")
+
+            with gr.Row():
+                setup_status_response = gr.Textbox(
+                    label="Job Result",
+                    placeholder="Job result will appear here",
+                    lines=6
+                )
+                setup_status_context = gr.Textbox(
+                    label="Context Information",
+                    placeholder="Context information will appear here",
+                    lines=6
+                )
+
+            with gr.Row():
+                setup_status_tokens1 = gr.Markdown("")
+                setup_status_tokens2 = gr.Markdown("")
 
         with gr.TabItem("Submit Query", elem_classes=["query-tab"]):
             with gr.Row():
@@ -1754,11 +1929,18 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
     # Add initialization info display
     init_info = gr.Markdown("")
 
-    # Update load_button click to include top_k
+    # Update load_button click to include embedding model
     load_button.click(
-        load_pdfs_async,
-        inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider, top_k_slider, max_tokens_slider],
-        outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list, init_info]
+        lambda file_links, bm25_weight, embedding_model: load_pdfs_async(file_links, default_prompt, bm25_weight, embedding_model),
+        inputs=[pdf_input, bm25_weight_slider, embedding_dropdown],
+        outputs=[load_response, load_context, model_output, setup_job_id_input, setup_job_query_display, setup_job_list, init_info]
+    )
+
+    # Also update Setup & Load job list when files are loaded
+    load_button.click(
+        fn=lambda *args: get_job_list(),
+        inputs=[],
+        outputs=[setup_job_list]
     )
 
     # Add function to sync job IDs between tabs
@@ -1785,30 +1967,14 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
     )
 
-    # Add function to sync all parameters
-    def sync_parameters(temperature, top_p, top_k, bm25_weight):
-        return temperature, top_p, top_k, bm25_weight
+    # Sync BM25 weight between Setup & Load and Query tabs
+    def sync_bm25_weight(bm25_weight):
+        return bm25_weight
 
-    # Sync parameters between tabs
-    temperature_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
-    top_p_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
-    top_k_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
-    )
     bm25_weight_slider.change(
-        fn=sync_parameters,
-        inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
-        outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
+        fn=sync_bm25_weight,
+        inputs=[bm25_weight_slider],
+        outputs=[query_bm25_weight_slider]
     )
 
     # Connect the buttons to their respective functions
@@ -1844,11 +2010,6 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[reset_response, reset_context, reset_model]
     )
 
-    model_dropdown.change(
-        fn=sync_model_dropdown,
-        inputs=model_dropdown,
-        outputs=query_model_dropdown
-    )
 
     # Add an event to refresh the job list on page load
     app.load(
@@ -1857,6 +2018,38 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=job_list
     )
 
+    # Setup & Load Job Status Event Handlers
+    setup_check_button.click(
+        check_job_status,
+        inputs=[setup_job_id_input],
+        outputs=[setup_status_response, setup_status_context, setup_status_tokens1, setup_status_tokens2, setup_job_query_display]
+    )
+
+    setup_refresh_button.click(
+        refresh_job_list,
+        inputs=[],
+        outputs=[setup_job_list]
+    )
+
+    setup_job_id_input.change(
+        job_selected,
+        inputs=[setup_job_id_input],
+        outputs=[setup_job_id_input, setup_job_query_display]
+    )
+
+    setup_cleanup_button.click(
+        cleanup_old_jobs,
+        inputs=[],
+        outputs=[setup_status_response, setup_status_context, setup_status_tokens1]
+    )
+
+    setup_auto_refresh_checkbox.change(
+        fn=periodic_update,
+        inputs=[setup_auto_refresh_checkbox],
+        outputs=[setup_job_list, setup_status_response, setup_df, setup_status_context],
+        every=2
+    )
+
     # Use the Checkbox to control the periodic updates
     auto_refresh_checkbox.change(
         fn=periodic_update,
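
For reference, a hypothetical usage sketch of the two pieces this diff adds to advanced_rag.py: the NebiusEmbedding class and the new embedding_model argument on ElevatedRagChain. The import path, placeholder API key, and chosen model names are assumptions for illustration, not part of the commit:

    import os
    from advanced_rag import NebiusEmbedding, ElevatedRagChain  # assumes both are importable from the module changed above

    # Nebius-hosted embeddings require NEBIUS_API_KEY (enforced in NebiusEmbedding.__init__ above).
    os.environ.setdefault("NEBIUS_API_KEY", "<your-key>")

    emb = NebiusEmbedding(model_name="Qwen/Qwen3-Embedding-8B")
    doc_vectors = emb.embed_documents(["first chunk of text", "second chunk of text"])
    query_vector = emb.embed_query("What does the corpus say about hybrid retrieval?")
    print(len(doc_vectors), len(query_vector))

    # The chain now selects its embedding backend from the embedding_model argument;
    # non-Nebius names fall back to HuggingFaceEmbeddings on CPU, per _create_embedding_function above.
    chain = ElevatedRagChain(
        llm_choice="Mistral-API",
        bm25_weight=0.6,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    )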
psyllm.py CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -47,3 +47,6 @@ pydantic==2.9.0
 sentence-transformers>=2.4.0
 
 mistralai==1.5.0
+
+matplotlib>=3.0.0
+networkx>=2.0
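
requirements.txt gains matplotlib and networkx. A quick, hypothetical sanity check that the new pins resolve after installing from the updated file:

    # Run after `pip install -r requirements.txt`.
    import matplotlib
    import networkx as nx

    print(matplotlib.__version__)  # should satisfy >=3.0.0
    print(nx.__version__)          # should satisfy >=2.0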