alx-d committed on
Commit e8b305d · verified · 1 parent: 01330c2

Upload folder using huggingface_hub

Files changed (1)
  1. advanced_rag.py +731 -125
advanced_rag.py CHANGED
@@ -36,6 +36,189 @@ from langchain_community.document_loaders import PyMuPDFLoader # Updated loader
36
  import tempfile
37
  import mimetypes
38
 
39
  def get_mime_type(file_path):
40
  return mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
41
 
@@ -43,6 +226,8 @@ print("Pydantic Version: ")
43
  print(pydantic.__version__)
44
  # Add Mistral imports with fallback handling
45
 
 
 
46
  try:
47
  from mistralai import Mistral
48
  MISTRAL_AVAILABLE = True
@@ -107,11 +292,14 @@ def process_in_background(job_id, function, args):
107
  error_result = (f"Error processing job: {str(e)}", "", "", "")
108
  results_queue.put((job_id, error_result))
109
 
110
- def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
111
  """Asynchronous version of load_pdfs_updated to prevent timeouts"""
112
  global last_job_id
113
  if not file_links:
114
- return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
115
 
116
  job_id = str(uuid.uuid4())
117
  debug_print(f"Starting async job {job_id} for file loading")
@@ -119,7 +307,7 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
119
  # Start background thread
120
  threading.Thread(
121
  target=process_in_background,
122
- args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
123
  ).start()
124
 
125
  job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
@@ -132,6 +320,8 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
132
 
133
  last_job_id = job_id
134
 
 
 
135
  return (
136
  f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
137
  f"Use 'Check Job Status' tab with this ID to get results.",
@@ -139,14 +329,17 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
139
  f"Model requested: {model_choice}",
140
  job_id, # Return job_id to update the job_id_input component
141
  job_query, # Return job_query to update the job_query_display component
142
- get_job_list() # Return updated job list
 
143
  )
144
 
145
- def submit_query_async(query, model_choice=None):
146
  """Asynchronous version of submit_query_updated to prevent timeouts"""
147
  global last_job_id
148
  if not query:
149
  return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
 
 
150
 
151
  job_id = str(uuid.uuid4())
152
  debug_print(f"Starting async job {job_id} for query: {query}")
@@ -154,13 +347,13 @@ def submit_query_async(query, model_choice=None):
154
  # Update model if specified
155
  if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
156
  debug_print(f"Updating model to {model_choice} for this query")
157
- rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
158
- rag_chain.prompt_template, rag_chain.bm25_weight)
159
 
160
  # Start background thread
161
  threading.Thread(
162
  target=process_in_background,
163
- args=(job_id, submit_query_updated, [query])
164
  ).start()
165
 
166
  jobs[job_id] = {
@@ -550,7 +743,7 @@ def load_file_from_google_drive(link: str) -> list:
550
 
551
  class ElevatedRagChain:
552
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
553
- bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
554
  debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
555
  self.embed_func = HuggingFaceEmbeddings(
556
  model_name="sentence-transformers/all-MiniLM-L6-v2",
@@ -558,7 +751,7 @@ class ElevatedRagChain:
558
  )
559
  self.bm25_weight = bm25_weight
560
  self.faiss_weight = 1.0 - bm25_weight
561
- self.top_k = 5
562
  self.llm_choice = llm_choice
563
  self.temperature = temperature
564
  self.top_p = top_p
@@ -587,9 +780,119 @@ class ElevatedRagChain:
587
  # Improve error handling in the ElevatedRagChain class
588
  def create_llm_pipeline(self):
589
  from langchain.llms.base import LLM # Import LLM here so it's always defined
590
- normalized = self.llm_choice.lower()
591
  try:
592
- if "remote" in normalized:
593
  debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
594
  from huggingface_hub import InferenceClient
595
  repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -598,20 +901,19 @@ class ElevatedRagChain:
598
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
599
 
600
  client = InferenceClient(token=hf_api_token, timeout=120)
601
-
602
- # We no longer use wait_for_model because it's unsupported
603
  def remote_generate(prompt: str) -> str:
604
  max_retries = 3
605
  backoff = 2 # start with 2 seconds
606
  for attempt in range(max_retries):
607
  try:
608
- debug_print(f"Remote generation attempt {attempt+1}")
609
  response = client.text_generation(
610
  prompt,
611
  model=repo_id,
612
  temperature=self.temperature,
613
  top_p=self.top_p,
614
- max_new_tokens=512 # Reduced token count for speed
615
  )
616
  return response
617
  except Exception as e:
@@ -623,6 +925,11 @@ class ElevatedRagChain:
623
  return "Failed to generate response after multiple attempts."
624
 
625
  class RemoteLLM(LLM):
626
  @property
627
  def _llm_type(self) -> str:
628
  return "remote_llm"
@@ -632,97 +939,74 @@ class ElevatedRagChain:
632
 
633
  @property
634
  def _identifying_params(self) -> dict:
635
- return {"model": repo_id}
636
 
637
  debug_print("Remote Meta-Llama-3 pipeline created successfully.")
638
  return RemoteLLM()
639
-
640
- elif "mistral-api" in normalized:
 
641
  debug_print("Creating Mistral API pipeline...")
642
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
643
  if not mistral_api_key:
644
  raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
 
645
  try:
646
- from mistralai import Mistral
647
  debug_print("Mistral library imported successfully")
648
  except ImportError:
649
- debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
650
- normalized = "llama"
651
- if normalized != "llama":
652
- # from pydantic import PrivateAttr
653
- # from langchain.llms.base import LLM
654
- # from typing import Any, Optional, List
655
- # import typing
656
-
657
- class MistralLLM(LLM):
658
- temperature: float = 0.7
659
- top_p: float = 0.95
660
- _client: Any = PrivateAttr(default=None)
661
-
662
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
663
- try:
664
- super().__init__(**kwargs)
665
- # Bypass Pydantic's __setattr__ to assign to _client
666
- object.__setattr__(self, '_client', Mistral(api_key=api_key))
667
- self.temperature = temperature
668
- self.top_p = top_p
669
- except Exception as e:
670
- debug_print(f"Init Mistral failed with error: {e}")
671
-
672
- @property
673
- def _llm_type(self) -> str:
674
- return "mistral_llm"
675
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
676
- try:
677
- debug_print("Calling Mistral API...")
678
- response = self._client.chat.complete(
679
- model="mistral-small-latest",
680
- messages=[{"role": "user", "content": prompt}],
681
- temperature=self.temperature,
682
- top_p=self.top_p
683
- )
684
- return response.choices[0].message.content
685
- except Exception as e:
686
- debug_print(f"Mistral API error: {str(e)}")
687
- return f"Error generating response: {str(e)}"
688
- @property
689
- def _identifying_params(self) -> dict:
690
- return {"model": "mistral-small-latest"}
691
- debug_print("Creating Mistral LLM instance")
692
- mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
693
- debug_print("Mistral API pipeline created successfully.")
694
- return mistral_llm
695
-
696
- else:
697
- # Default case - using a fallback model (or Llama)
698
- debug_print("Using local/fallback model pipeline")
699
- model_id = "facebook/opt-350m" # Use a smaller model as fallback
700
- pipe = pipeline(
701
- "text-generation",
702
- model=model_id,
703
- device=-1, # CPU
704
- max_length=1024
705
- )
706
 
707
- class LocalLLM(LLM):
708
  @property
709
  def _llm_type(self) -> str:
710
- return "local_llm"
 
711
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
712
- # For this fallback, truncate prompt if it exceeds limits
713
- reserved_gen = 128
714
- max_total = 1024
715
- max_prompt_tokens = max_total - reserved_gen
716
- truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
717
- generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
718
- return generated
719
  @property
720
  def _identifying_params(self) -> dict:
721
- return {"model": model_id, "max_length": 1024}
722
 
723
- debug_print("Local fallback pipeline created.")
724
- return LocalLLM()
725
-
726
  except Exception as e:
727
  debug_print(f"Error creating LLM pipeline: {str(e)}")
728
  # Return a dummy LLM that explains the error
@@ -741,11 +1025,12 @@ class ElevatedRagChain:
741
  return ErrorLLM()
742
 
743
 
744
- def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
745
  debug_print(f"Updating chain with new model: {new_model_choice}")
746
  self.llm_choice = new_model_choice
747
  self.temperature = temperature
748
  self.top_p = top_p
 
749
  self.prompt_template = prompt_template
750
  self.bm25_weight = bm25_weight
751
  self.faiss_weight = 1.0 - bm25_weight
@@ -753,7 +1038,14 @@ class ElevatedRagChain:
753
  def format_response(response: str) -> str:
754
  input_tokens = count_tokens(self.context + self.prompt_template)
755
  output_tokens = count_tokens(response)
756
- formatted = f"### Response\n\n{response}\n\n---\n"
757
  formatted += f"- **Input tokens:** {input_tokens}\n"
758
  formatted += f"- **Output tokens:** {output_tokens}\n"
759
  formatted += f"- **Generated using:** {self.llm_choice}\n"
@@ -836,7 +1128,14 @@ class ElevatedRagChain:
836
  def format_response(response: str) -> str:
837
  input_tokens = count_tokens(self.context + self.prompt_template)
838
  output_tokens = count_tokens(response)
839
- formatted = f"### Response\n\n{response}\n\n---\n"
840
  formatted += f"- **Input tokens:** {input_tokens}\n"
841
  formatted += f"- **Output tokens:** {output_tokens}\n"
842
  formatted += f"- **Generated using:** {self.llm_choice}\n"
@@ -863,7 +1162,7 @@ class ElevatedRagChain:
863
  global rag_chain
864
  rag_chain = ElevatedRagChain()
865
 
866
- def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
867
  debug_print("Inside load_pdfs function.")
868
  if not file_links:
869
  debug_print("Please enter non-empty URLs")
@@ -872,7 +1171,7 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
872
  links = [link.strip() for link in file_links.split("\n") if link.strip()]
873
  global rag_chain
874
  if rag_chain.raw_data:
875
- rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
876
  context_display = rag_chain.get_current_context()
877
  response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
878
  return (
@@ -887,7 +1186,8 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
887
  prompt_template=prompt_template,
888
  bm25_weight=bm25_weight,
889
  temperature=temperature,
890
- top_p=top_p
 
891
  )
892
  rag_chain.add_pdfs_to_vectore_store(links)
893
  context_display = rag_chain.get_current_context()
@@ -911,7 +1211,7 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
911
  def update_model(new_model: str):
912
  global rag_chain
913
  if rag_chain and rag_chain.raw_data:
914
- rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
915
  rag_chain.prompt_template, rag_chain.bm25_weight)
916
  debug_print(f"Model updated to {rag_chain.llm_choice}")
917
  return f"Model updated to: {rag_chain.llm_choice}"
@@ -920,7 +1220,7 @@ def update_model(new_model: str):
920
 
921
 
922
  # Update submit_query_updated to better handle context limitation
923
- def submit_query_updated(query):
924
  debug_print(f"Processing query: {query}")
925
  if not query:
926
  debug_print("Empty query received")
@@ -931,6 +1231,19 @@ def submit_query_updated(query):
931
  return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
932
 
933
  try:
934
  # Determine max context size based on model
935
  model_name = rag_chain.llm_choice.lower()
936
  max_context_tokens = 32000 if "mistral" in model_name else 4096
@@ -1077,6 +1390,43 @@ document.addEventListener('DOMContentLoaded', function() {
1077
  clearInterval(jobListInterval);
1078
  }
1079
  }, 500);
1080
  });
1081
  """) as app:
1082
  gr.Markdown('''# PhiRAG - Async Version
@@ -1113,8 +1463,16 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1113
  with gr.Row():
1114
  with gr.Column():
1115
  model_dropdown = gr.Dropdown(
1116
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
1117
- value="🇺🇸 Remote Meta-Llama-3",
1118
  label="Select Model"
1119
  )
1120
  temperature_slider = gr.Slider(
@@ -1125,6 +1483,10 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1125
  minimum=0.1, maximum=0.99, value=0.95, step=0.05,
1126
  label="Word Variety (Top-p)"
1127
  )
1128
  with gr.Column():
1129
  pdf_input = gr.Textbox(
1130
  label="Enter your file URLs (one per line)",
@@ -1160,21 +1522,46 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1160
  with gr.Row():
1161
  model_output = gr.Markdown("**Current Model**: Not selected")
1162
 
1163
- with gr.TabItem("Submit Query"):
1164
  with gr.Row():
1165
- # Add this line to define the query_model_dropdown
1166
- query_model_dropdown = gr.Dropdown(
1167
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
1168
- value="🇺🇸 Remote Meta-Llama-3",
1169
- label="Query Model"
1170
- )
1171
-
1172
- query_input = gr.Textbox(
1173
- label="Enter your query here",
1174
- placeholder="Type your query",
1175
- lines=4
1176
- )
1177
- submit_button = gr.Button("Submit Query (Async)")
1178
 
1179
  with gr.Row():
1180
  query_response = gr.Textbox(
@@ -1247,6 +1634,138 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1247
  status_tokens1 = gr.Markdown("")
1248
  status_tokens2 = gr.Markdown("")
1249
 
1250
  with gr.TabItem("App Management"):
1251
  with gr.Row():
1252
  reset_button = gr.Button("Reset App")
@@ -1267,26 +1786,50 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1267
  with gr.Row():
1268
  reset_model = gr.Markdown("")
1269
 
1270
- # Connect the buttons to their respective functions
1271
  load_button.click(
1272
  load_pdfs_async,
1273
- inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1274
- outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1275
- )
1276
-
1277
- # Also sync in the other direction
1278
- query_model_dropdown.change(
1279
- fn=sync_model_dropdown,
1280
- inputs=query_model_dropdown,
1281
- outputs=model_dropdown
1282
  )
1283
 
 
1284
  submit_button.click(
1285
  submit_query_async,
1286
- inputs=[query_input, query_model_dropdown],
1287
  outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1288
  )
1289
 
1290
  check_button.click(
1291
  check_job_status,
1292
  inputs=[job_id_input],
@@ -1340,6 +1883,69 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
1340
  every=2 #if auto_refresh_checkbox.value else None # Directly set `every` based on the checkbox state
1341
  )
1342
 
1343
  if __name__ == "__main__":
1344
  debug_print("Launching Gradio interface.")
1345
  app.queue().launch(share=False)
 
36
  import tempfile
37
  import mimetypes
38
 
39
+ # Add batch processing helper functions
40
+ def generate_parameter_values(min_val, max_val, num_values):
41
+ """Generate evenly spaced values between min and max"""
42
+ if num_values == 1:
43
+ return [min_val]
44
+ step = (max_val - min_val) / (num_values - 1)
45
+ return [min_val + (step * i) for i in range(num_values)]
46
+
47
+ def process_batch_query(query, model_choice, max_tokens, param_configs, slider_values, job_id):
48
+ """Process a batch of queries with different parameter combinations"""
49
+ results = []
50
+
51
+ # Generate all parameter combinations
52
+ temp_values = [slider_values['temperature']] if param_configs['temperature'] == "Constant" else generate_parameter_values(0.1, 1.0, int(param_configs['temperature'].split()[2]))
53
+ top_p_values = [slider_values['top_p']] if param_configs['top_p'] == "Constant" else generate_parameter_values(0.1, 0.99, int(param_configs['top_p'].split()[2]))
54
+ top_k_values = [slider_values['top_k']] if param_configs['top_k'] == "Constant" else generate_parameter_values(1, 100, int(param_configs['top_k'].split()[2]))
55
+ bm25_values = [slider_values['bm25']] if param_configs['bm25'] == "Constant" else generate_parameter_values(0.0, 1.0, int(param_configs['bm25'].split()[2]))
56
+
57
+ total_combinations = len(temp_values) * len(top_p_values) * len(top_k_values) * len(bm25_values)
58
+ current = 0
59
+
60
+ for temp in temp_values:
61
+ for top_p in top_p_values:
62
+ for top_k in top_k_values:
63
+ for bm25 in bm25_values:
64
+ current += 1
65
+ try:
66
+ # Update parameters
67
+ rag_chain.temperature = temp
68
+ rag_chain.top_p = top_p
69
+ rag_chain.top_k = top_k
70
+ rag_chain.bm25_weight = bm25
71
+ rag_chain.faiss_weight = 1.0 - bm25
72
+
73
+ # Update ensemble retriever
74
+ rag_chain.ensemble_retriever = EnsembleRetriever(
75
+ retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
76
+ weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
77
+ )
78
+
79
+ # Process query
80
+ response = rag_chain.elevated_rag_chain.invoke({"question": query})
81
+
82
+ # Format result
83
+ result = {
84
+ "Parameters": f"Temp: {temp:.2f}, Top-p: {top_p:.2f}, Top-k: {top_k}, BM25: {bm25:.2f}",
85
+ "Response": response,
86
+ "Progress": f"Query {current}/{total_combinations}"
87
+ }
88
+ results.append(result)
89
+
90
+ except Exception as e:
91
+ results.append({
92
+ "Parameters": f"Temp: {temp:.2f}, Top-p: {top_p:.2f}, Top-k: {top_k}, BM25: {bm25:.2f}",
93
+ "Response": f"Error: {str(e)}",
94
+ "Progress": f"Query {current}/{total_combinations}"
95
+ })
96
+
97
+ # Format final results
98
+ formatted_results = "### Batch Query Results\n\n"
99
+ for result in results:
100
+ formatted_results += f"#### {result['Parameters']}\n"
101
+ formatted_results += f"**Progress:** {result['Progress']}\n\n"
102
+ formatted_results += f"{result['Response']}\n\n"
103
+ formatted_results += "---\n\n"
104
+
105
+ return (
106
+ formatted_results,
107
+ f"Job ID: {job_id}",
108
+ f"Input tokens: {count_tokens(query)}",
109
+ f"Output tokens: {sum(count_tokens(r['Response']) for r in results)}"
110
+ )
111
+
112
+ def process_batch_query_async(query, model_choice, max_tokens, param_configs, slider_values):
113
+ """Asynchronous version of batch query processing"""
114
+ global last_job_id
115
+ if not query:
116
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
117
+
118
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
119
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
120
+
121
+ job_id = str(uuid.uuid4())
122
+ debug_print(f"Starting async batch job {job_id} for query: {query}")
123
+
124
+ # Get slider values
125
+ slider_values = {
126
+ 'temperature': slider_values['temperature'],
127
+ 'top_p': slider_values['top_p'],
128
+ 'top_k': slider_values['top_k'],
129
+ 'bm25': slider_values['bm25']
130
+ }
131
+
132
+ # Start background thread
133
+ threading.Thread(
134
+ target=process_in_background,
135
+ args=(job_id, process_batch_query, [query, model_choice, max_tokens, param_configs, slider_values, job_id])
136
+ ).start()
137
+
138
+ jobs[job_id] = {
139
+ "status": "processing",
140
+ "type": "batch_query",
141
+ "start_time": time.time(),
142
+ "query": query,
143
+ "model": model_choice,
144
+ "param_configs": param_configs
145
+ }
146
+
147
+ last_job_id = job_id
148
+
149
+ return (
150
+ f"Batch query submitted and processing in the background (Job ID: {job_id}).\n\n"
151
+ f"Use 'Check Job Status' tab with this ID to get results.",
152
+ f"Job ID: {job_id}",
153
+ f"Input tokens: {count_tokens(query)}",
154
+ "Output tokens: pending",
155
+ job_id, # Return job_id to update the job_id_input component
156
+ query, # Return query to update the job_query_display component
157
+ get_job_list() # Return updated job list
158
+ )
159
+
160
+ def submit_batch_query_async(query, model_choice, max_tokens, temp_config, top_p_config, top_k_config, bm25_config,
161
+ temp_slider, top_p_slider, top_k_slider, bm25_slider):
162
+ """Handle batch query submission with async processing"""
163
+ if not query:
164
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
165
+
166
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
167
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
168
+
169
+ # Get slider values
170
+ slider_values = {
171
+ 'temperature': temp_slider,
172
+ 'top_p': top_p_slider,
173
+ 'top_k': top_k_slider,
174
+ 'bm25': bm25_slider
175
+ }
176
+
177
+ param_configs = {
178
+ 'temperature': temp_config,
179
+ 'top_p': top_p_config,
180
+ 'top_k': top_k_config,
181
+ 'bm25': bm25_config
182
+ }
183
+
184
+ return process_batch_query_async(query, model_choice, max_tokens, param_configs, slider_values)
185
+
186
+ def submit_batch_query(query, model_choice, max_tokens, temp_config, top_p_config, top_k_config, bm25_config,
187
+ temp_slider, top_p_slider, top_k_slider, bm25_slider):
188
+ """Handle batch query submission"""
189
+ if not query:
190
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
191
+
192
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
193
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
194
+
195
+ # Get slider values
196
+ slider_values = {
197
+ 'temperature': temp_slider,
198
+ 'top_p': top_p_slider,
199
+ 'top_k': top_k_slider,
200
+ 'bm25': bm25_slider
201
+ }
202
+
203
+ try:
204
+ # process_batch_query expects a job_id and already returns
+ # (formatted_results, job_info, input_tokens, output_tokens)
+ job_id = str(uuid.uuid4())
+ formatted_results, _, input_tokens_str, output_tokens_str = process_batch_query(
+ query, model_choice, max_tokens,
+ {'temperature': temp_config, 'top_p': top_p_config,
+ 'top_k': top_k_config, 'bm25': bm25_config},
+ slider_values, job_id)
+
+ return formatted_results, "", input_tokens_str, output_tokens_str
218
+
219
+ except Exception as e:
220
+ return f"Error processing batch query: {str(e)}", "", "Input tokens: 0", "Output tokens: 0"
221
+
222
  def get_mime_type(file_path):
223
  return mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
224
 
 
226
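The batch helpers above turn each "Constant" / "Whole range N values" selection into a list of evenly spaced values and then sweep every combination of temperature, top-p, top-k and BM25 weight. A minimal, self-contained sketch of that grid expansion (the expand_grid helper and the demo values are illustrative additions, not part of the committed file):

from itertools import product

def generate_parameter_values(min_val, max_val, num_values):
    # Evenly spaced values from min_val to max_val, inclusive
    if num_values == 1:
        return [min_val]
    step = (max_val - min_val) / (num_values - 1)
    return [min_val + step * i for i in range(num_values)]

def expand_grid(param_configs, slider_values):
    # "Constant" keeps the slider value; "Whole range N values" sweeps the parameter's range
    ranges = {'temperature': (0.1, 1.0), 'top_p': (0.1, 0.99), 'top_k': (1, 100), 'bm25': (0.0, 1.0)}
    axes = {}
    for name, (low, high) in ranges.items():
        config = param_configs[name]
        if config == "Constant":
            axes[name] = [slider_values[name]]
        else:
            # e.g. "Whole range 5 values" -> 5
            axes[name] = generate_parameter_values(low, high, int(config.split()[2]))
    names = list(axes)
    return [dict(zip(names, combo)) for combo in product(*axes.values())]

grid = expand_grid(
    {'temperature': "Whole range 3 values", 'top_p': "Constant", 'top_k': "Constant", 'bm25': "Whole range 3 values"},
    {'temperature': 0.5, 'top_p': 0.95, 'top_k': 50, 'bm25': 0.6},
)
print(len(grid))  # 3 * 1 * 1 * 3 = 9 combinations, matching total_combinations in process_batch_query

process_batch_query walks the same grid with nested loops, rebuilding the EnsembleRetriever with the new BM25/FAISS weights on every iteration before invoking the chain.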
  print(pydantic.__version__)
227
  # Add Mistral imports with fallback handling
228
 
229
+ slider_max_tokens = None
230
+
231
  try:
232
  from mistralai import Mistral
233
  MISTRAL_AVAILABLE = True
 
292
  error_result = (f"Error processing job: {str(e)}", "", "", "")
293
  results_queue.put((job_id, error_result))
294
 
295
+ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k, max_tokens_slider):
296
  """Asynchronous version of load_pdfs_updated to prevent timeouts"""
297
  global last_job_id
298
  if not file_links:
299
+ return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""
300
+ global slider_max_tokens
301
+ slider_max_tokens = max_tokens_slider
302
+
303
 
304
  job_id = str(uuid.uuid4())
305
  debug_print(f"Starting async job {job_id} for file loading")
 
307
  # Start background thread
308
  threading.Thread(
309
  target=process_in_background,
310
+ args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k])
311
  ).start()
312
 
313
  job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
 
320
 
321
  last_job_id = job_id
322
 
323
+ init_message = "Vector database initialized using the files.\nThe above parameters were used in the initialization of the RAG chain."
324
+
325
  return (
326
  f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
327
  f"Use 'Check Job Status' tab with this ID to get results.",
 
329
  f"Model requested: {model_choice}",
330
  job_id, # Return job_id to update the job_id_input component
331
  job_query, # Return job_query to update the job_query_display component
332
+ get_job_list(), # Return updated job list
333
+ init_message # Return initialization message
334
  )
335
 
336
+ def submit_query_async(query, model_choice, max_tokens_slider, temperature, top_p, top_k, bm25_weight):
337
  """Asynchronous version of submit_query_updated to prevent timeouts"""
338
  global last_job_id
339
  if not query:
340
  return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
341
+ global slider_max_tokens
342
+ slider_max_tokens = max_tokens_slider
343
 
344
  job_id = str(uuid.uuid4())
345
  debug_print(f"Starting async job {job_id} for query: {query}")
 
347
  # Update model if specified
348
  if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
349
  debug_print(f"Updating model to {model_choice} for this query")
350
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k,
351
+ rag_chain.prompt_template, bm25_weight)
352
 
353
  # Start background thread
354
  threading.Thread(
355
  target=process_in_background,
356
+ args=(job_id, submit_query_updated, [query, temperature, top_p, top_k, bm25_weight])
357
  ).start()
358
 
359
  jobs[job_id] = {
 
743
 
744
  class ElevatedRagChain:
745
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
746
+ bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50) -> None:
747
  debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
748
  self.embed_func = HuggingFaceEmbeddings(
749
  model_name="sentence-transformers/all-MiniLM-L6-v2",
 
751
  )
752
  self.bm25_weight = bm25_weight
753
  self.faiss_weight = 1.0 - bm25_weight
754
+ self.top_k = top_k
755
  self.llm_choice = llm_choice
756
  self.temperature = temperature
757
  self.top_p = top_p
 
780
  # Improve error handling in the ElevatedRagChain class
781
  def create_llm_pipeline(self):
782
  from langchain.llms.base import LLM # Import LLM here so it's always defined
783
+ from typing import Optional, List, Any
784
+ from pydantic import PrivateAttr
785
+ global slider_max_tokens
786
+
787
+ # Extract the model name without the flag emoji prefix
788
+ clean_llm_choice = self.llm_choice.split(" ", 1)[-1] if " " in self.llm_choice else self.llm_choice
789
+ normalized = clean_llm_choice.lower()
790
+ print(f"Normalized model name: {normalized}")
791
+
792
+ # Model configurations from the second file
793
+ model_token_limits = {
794
+ "gpt-3.5": 16385,
795
+ "gpt-4o": 128000,
796
+ "gpt-4o-mini": 128000,
797
+ "meta-llama-3": 4096,
798
+ "mistral-api": 128000,
799
+ "o1-mini": 128000,
800
+ "o3-mini": 128000
801
+ }
802
+
803
+ model_map = {
804
+ "gpt-3.5": "gpt-3.5-turbo",
805
+ "gpt-4o": "gpt-4o",
806
+ "gpt-4o mini": "gpt-4o-mini",
807
+ "o1-mini": "gpt-4o-mini",
808
+ "o3-mini": "gpt-4o-mini",
809
+ "mistral": "mistral-small-latest",
810
+ "mistral-api": "mistral-small-latest",
811
+ "meta-llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
812
+ "remote meta-llama-3": "meta-llama/Meta-Llama-3-8B-Instruct"
813
+ }
814
+
815
+ model_pricing = {
816
+ "gpt-3.5": {"USD": {"input": 0.0000005, "output": 0.0000015}, "RON": {"input": 0.0000023, "output": 0.0000069}},
817
+ "gpt-4o": {"USD": {"input": 0.0000025, "output": 0.00001}, "RON": {"input": 0.0000115, "output": 0.000046}},
818
+ "gpt-4o-mini": {"USD": {"input": 0.00000015, "output": 0.0000006}, "RON": {"input": 0.0000007, "output": 0.0000028}},
819
+ "o1-mini": {"USD": {"input": 0.0000011, "output": 0.0000044}, "RON": {"input": 0.0000051, "output": 0.0000204}},
820
+ "o3-mini": {"USD": {"input": 0.0000011, "output": 0.0000044}, "RON": {"input": 0.0000051, "output": 0.0000204}},
821
+ "meta-llama-3": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}},
822
+ "mistral": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}},
823
+ "mistral-api": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}}
824
+ }
825
+ pricing_info = ""
826
+
827
+ # Find the matching model
828
+ model_key = None
829
+ for key in model_map:
830
+ if key.lower() in normalized:
831
+ model_key = key
832
+ break
833
+
834
+ if not model_key:
835
+ raise ValueError(f"Unsupported model: {normalized}")
836
+ model = model_map[model_key]
837
+ max_tokens = model_token_limits.get(model_key, 4096) # limits are keyed by model_key, not by the mapped API model name
838
+ max_tokens = min(slider_max_tokens, max_tokens) if slider_max_tokens else max_tokens
839
+ pricing_info = model_pricing.get(model_key, {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}})
840
+
841
  try:
842
+ # OpenAI models (GPT-3.5, GPT-4o, GPT-4o mini, o1-mini, o3-mini)
843
+ if any(model in normalized for model in ["gpt-3.5", "gpt-4o", "o1-mini", "o3-mini"]):
844
+ debug_print(f"Creating OpenAI API pipeline for {normalized}...")
845
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
846
+ if not openai_api_key:
847
+ raise ValueError("Please set the OPENAI_API_KEY environment variable to use OpenAI API.")
848
+
849
+ import openai
850
+
851
+ class OpenAILLM(LLM):
852
+ model_name: str = model
853
+ llm_choice: str = model
854
+ max_context_tokens: int = max_tokens
855
+ pricing: dict = pricing_info
856
+ temperature: float = 0.7
857
+ top_p: float = 0.95
858
+ top_k: int = 50
859
+
860
+
861
+ @property
862
+ def _llm_type(self) -> str:
863
+ return "openai_llm"
864
+
865
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
866
+ try:
867
+ openai.api_key = openai_api_key
868
+ print(f" tokens: {max_tokens}")
869
+ response = openai.ChatCompletion.create(
870
+ model=self.model_name,
871
+ messages=[{"role": "user", "content": prompt}],
872
+ temperature=self.temperature,
873
+ top_p=self.top_p,
874
+ max_tokens=max_tokens
875
+ )
876
+ return response["choices"][0]["message"]["content"]
877
+ except Exception as e:
878
+ debug_print(f"OpenAI API error: {str(e)}")
879
+ return f"Error generating response: {str(e)}"
880
+
881
+ @property
882
+ def _identifying_params(self) -> dict:
883
+ return {
884
+ "model": self.model_name,
885
+ "max_tokens": self.max_context_tokens,
886
+ "temperature": self.temperature,
887
+ "top_p": self.top_p,
888
+ "top_k": self.top_k
889
+ }
890
+
891
+ debug_print(f"OpenAI {model} pipeline created successfully.")
892
+ return OpenAILLM()
893
+
894
+ # Meta-Llama-3 model
895
+ elif "meta-llama" in normalized or "llama" in normalized:
896
  debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
897
  from huggingface_hub import InferenceClient
898
  repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
901
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
902
 
903
  client = InferenceClient(token=hf_api_token, timeout=120)
904
+
 
905
  def remote_generate(prompt: str) -> str:
906
  max_retries = 3
907
  backoff = 2 # start with 2 seconds
908
  for attempt in range(max_retries):
909
  try:
910
+ debug_print(f"Remote generation attempt {attempt+1} tokens: {self.max_tokens}")
911
  response = client.text_generation(
912
  prompt,
913
  model=repo_id,
914
  temperature=self.temperature,
915
  top_p=self.top_p,
916
+ max_new_tokens=max_tokens # text_generation expects max_new_tokens; capped by the Max Tokens slider
917
  )
918
  return response
919
  except Exception as e:
 
925
  return "Failed to generate response after multiple attempts."
926
 
927
  class RemoteLLM(LLM):
928
+ model_name: str = repo_id
929
+ llm_choice: str = repo_id
930
+ max_context_tokens: int = max_tokens
931
+ pricing: dict = pricing_info
932
+
933
  @property
934
  def _llm_type(self) -> str:
935
  return "remote_llm"
 
939
 
940
  @property
941
  def _identifying_params(self) -> dict:
942
+ return {"model": self.model_name, "max_tokens": self.max_context_tokens}
943
 
944
  debug_print("Remote Meta-Llama-3 pipeline created successfully.")
945
  return RemoteLLM()
946
+
947
+ # Mistral API model
948
+ elif "mistral" in normalized:
949
  debug_print("Creating Mistral API pipeline...")
950
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
951
  if not mistral_api_key:
952
  raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
953
+
954
  try:
955
+ from mistralai import Mistral
956
  debug_print("Mistral library imported successfully")
957
  except ImportError:
958
+ raise ImportError("Mistral client library not installed. Please install with 'pip install mistralai'.")
959
 
960
+ class MistralLLM(LLM):
961
+ temperature: float = 0.7
962
+ top_p: float = 0.95
963
+ model_name: str = model
964
+ llm_choice: str = model
965
+
966
+ pricing: dict = pricing_info
967
+ _client: Any = PrivateAttr(default=None)
968
+
969
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
970
+ try:
971
+ super().__init__(**kwargs)
972
+ # Bypass Pydantic's __setattr__ to assign to _client
973
+ object.__setattr__(self, '_client', Mistral(api_key=api_key))
974
+ self.temperature = temperature
975
+ self.top_p = top_p
976
+ except Exception as e:
977
+ debug_print(f"Init Mistral failed with error: {e}")
978
+
979
  @property
980
  def _llm_type(self) -> str:
981
+ return "mistral_llm"
982
+
983
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
984
+ try:
985
+ debug_print(f"Calling Mistral API... tokens: {max_tokens}")
986
+ response = self._client.chat.complete(
987
+ model=self.model_name,
988
+ messages=[{"role": "user", "content": prompt}],
989
+ temperature=self.temperature,
990
+ top_p=self.top_p,
991
+ max_tokens=max_tokens
992
+ )
993
+ return response.choices[0].message.content
994
+ except Exception as e:
995
+ debug_print(f"Mistral API error: {str(e)}")
996
+ return f"Error generating response: {str(e)}"
997
+
998
  @property
999
  def _identifying_params(self) -> dict:
1000
+ return {"model": self.model_name, "max_tokens": max_tokens}
1001
+
1002
+ debug_print("Creating Mistral LLM instance")
1003
+ mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
1004
+ debug_print("Mistral API pipeline created successfully.")
1005
+ return mistral_llm
1006
+
1007
+ else:
1008
+ raise ValueError(f"Unsupported model choice: {self.llm_choice}")
1009
 
 
 
 
1010
  except Exception as e:
1011
  debug_print(f"Error creating LLM pipeline: {str(e)}")
1012
  # Return a dummy LLM that explains the error
 
1025
  return ErrorLLM()
1026
 
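Each LLM wrapper created above stores the matching model_pricing entry (per-token prices in USD and RON) as its pricing attribute; no cost is computed anywhere in this diff, but a per-request estimate is just a weighted sum of the token counts. A hypothetical illustration using the gpt-4o-mini USD row from the table above:

# Assumed request size: 1200 prompt tokens, 300 completion tokens
pricing = {"input": 0.00000015, "output": 0.0000006}  # USD per token, gpt-4o-mini row
input_tokens, output_tokens = 1200, 300
cost_usd = input_tokens * pricing["input"] + output_tokens * pricing["output"]
print(f"${cost_usd:.5f}")  # $0.00036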
1027
 
1028
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, top_k: int, prompt_template: str, bm25_weight: float):
1029
  debug_print(f"Updating chain with new model: {new_model_choice}")
1030
  self.llm_choice = new_model_choice
1031
  self.temperature = temperature
1032
  self.top_p = top_p
1033
+ self.top_k = top_k
1034
  self.prompt_template = prompt_template
1035
  self.bm25_weight = bm25_weight
1036
  self.faiss_weight = 1.0 - bm25_weight
 
1038
  def format_response(response: str) -> str:
1039
  input_tokens = count_tokens(self.context + self.prompt_template)
1040
  output_tokens = count_tokens(response)
1041
+ formatted = f" Response:\n\n"
1042
+ formatted += f"Model: {self.llm_choice}\n"
1043
+ formatted += f"Model Parameters:\n"
1044
+ formatted += f"- Temperature: {self.temperature}\n"
1045
+ formatted += f"- Top-p: {self.top_p}\n"
1046
+ formatted += f"- Top-k: {self.top_k}\n"
1047
+ formatted += f"- BM25 Weight: {self.bm25_weight}\n\n"
1048
+ formatted += f"{response}\n\n---\n"
1049
  formatted += f"- **Input tokens:** {input_tokens}\n"
1050
  formatted += f"- **Output tokens:** {output_tokens}\n"
1051
  formatted += f"- **Generated using:** {self.llm_choice}\n"
 
1128
  def format_response(response: str) -> str:
1129
  input_tokens = count_tokens(self.context + self.prompt_template)
1130
  output_tokens = count_tokens(response)
1131
+ formatted = f" Response:\n\n"
1132
+ formatted += f"Model: {self.llm_choice}\n"
1133
+ formatted += f"Model Parameters:\n"
1134
+ formatted += f"- Temperature: {self.temperature}\n"
1135
+ formatted += f"- Top-p: {self.top_p}\n"
1136
+ formatted += f"- Top-k: {self.top_k}\n"
1137
+ formatted += f"- BM25 Weight: {self.bm25_weight}\n\n"
1138
+ formatted += f"{response}\n\n---\n"
1139
  formatted += f"- **Input tokens:** {input_tokens}\n"
1140
  formatted += f"- **Output tokens:** {output_tokens}\n"
1141
  formatted += f"- **Generated using:** {self.llm_choice}\n"
 
1162
  global rag_chain
1163
  rag_chain = ElevatedRagChain()
1164
 
1165
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k):
1166
  debug_print("Inside load_pdfs function.")
1167
  if not file_links:
1168
  debug_print("Please enter non-empty URLs")
 
1171
  links = [link.strip() for link in file_links.split("\n") if link.strip()]
1172
  global rag_chain
1173
  if rag_chain.raw_data:
1174
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k, prompt_template, bm25_weight)
1175
  context_display = rag_chain.get_current_context()
1176
  response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
1177
  return (
 
1186
  prompt_template=prompt_template,
1187
  bm25_weight=bm25_weight,
1188
  temperature=temperature,
1189
+ top_p=top_p,
1190
+ top_k=top_k
1191
  )
1192
  rag_chain.add_pdfs_to_vectore_store(links)
1193
  context_display = rag_chain.get_current_context()
 
1211
  def update_model(new_model: str):
1212
  global rag_chain
1213
  if rag_chain and rag_chain.raw_data:
1214
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p, rag_chain.top_k,
1215
  rag_chain.prompt_template, rag_chain.bm25_weight)
1216
  debug_print(f"Model updated to {rag_chain.llm_choice}")
1217
  return f"Model updated to: {rag_chain.llm_choice}"
 
1220
 
1221
 
1222
  # Update submit_query_updated to better handle context limitation
1223
+ def submit_query_updated(query, temperature, top_p, top_k, bm25_weight):
1224
  debug_print(f"Processing query: {query}")
1225
  if not query:
1226
  debug_print("Empty query received")
 
1231
  return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
1232
 
1233
  try:
1234
+ # Update all parameters for this query
1235
+ rag_chain.temperature = temperature
1236
+ rag_chain.top_p = top_p
1237
+ rag_chain.top_k = top_k
1238
+ rag_chain.bm25_weight = bm25_weight
1239
+ rag_chain.faiss_weight = 1.0 - bm25_weight
1240
+
1241
+ # Update the ensemble retriever weights
1242
+ rag_chain.ensemble_retriever = EnsembleRetriever(
1243
+ retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
1244
+ weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
1245
+ )
1246
+
1247
  # Determine max context size based on model
1248
  model_name = rag_chain.llm_choice.lower()
1249
  max_context_tokens = 32000 if "mistral" in model_name else 4096
 
1390
  clearInterval(jobListInterval);
1391
  }
1392
  }, 500);
1393
+
1394
+ // Function to disable sliders
1395
+ function disableSliders() {
1396
+ const sliders = document.querySelectorAll('input[type="range"]');
1397
+ sliders.forEach(slider => {
1398
+ if (!slider.closest('.query-tab')) { // Don't disable sliders in query tab
1399
+ slider.disabled = true;
1400
+ slider.style.opacity = '0.5';
1401
+ }
1402
+ });
1403
+ }
1404
+
1405
+ // Function to enable sliders
1406
+ function enableSliders() {
1407
+ const sliders = document.querySelectorAll('input[type="range"]');
1408
+ sliders.forEach(slider => {
1409
+ slider.disabled = false;
1410
+ slider.style.opacity = '1';
1411
+ });
1412
+ }
1413
+
1414
+ // Add event listener for load button
1415
+ const loadButton = document.querySelector('button:contains("Load Files (Async)")');
1416
+ if (loadButton) {
1417
+ loadButton.addEventListener('click', function() {
1418
+ // Wait for the response to come back
1419
+ setTimeout(disableSliders, 1000);
1420
+ });
1421
+ }
1422
+
1423
+ // Add event listener for reset button
1424
+ const resetButton = document.querySelector('button:contains("Reset App")');
1425
+ if (resetButton) {
1426
+ resetButton.addEventListener('click', function() {
1427
+ enableSliders();
1428
+ });
1429
+ }
1430
  });
1431
  """) as app:
1432
  gr.Markdown('''# PhiRAG - Async Version
 
1463
  with gr.Row():
1464
  with gr.Column():
1465
  model_dropdown = gr.Dropdown(
1466
+ choices=[
1467
+ "🇺🇸 GPT-3.5",
1468
+ "🇺🇸 GPT-4o",
1469
+ "🇺🇸 GPT-4o mini",
1470
+ "🇺🇸 o1-mini",
1471
+ "🇺🇸 o3-mini",
1472
+ "🇺🇸 Remote Meta-Llama-3",
1473
+ "🇪🇺 Mistral-API",
1474
+ ],
1475
+ value="🇪🇺 Mistral-API",
1476
  label="Select Model"
1477
  )
1478
  temperature_slider = gr.Slider(
 
1483
  minimum=0.1, maximum=0.99, value=0.95, step=0.05,
1484
  label="Word Variety (Top-p)"
1485
  )
1486
+ top_k_slider = gr.Slider(
1487
+ minimum=1, maximum=100, value=50, step=1,
1488
+ label="Token Selection (Top-k)"
1489
+ )
1490
  with gr.Column():
1491
  pdf_input = gr.Textbox(
1492
  label="Enter your file URLs (one per line)",
 
1522
  with gr.Row():
1523
  model_output = gr.Markdown("**Current Model**: Not selected")
1524
 
1525
+ with gr.TabItem("Submit Query", elem_classes=["query-tab"]):
1526
  with gr.Row():
1527
+ with gr.Column():
1528
+ query_model_dropdown = gr.Dropdown(
1529
+ choices=[
1530
+ "🇺🇸 GPT-3.5",
1531
+ "🇺🇸 GPT-4o",
1532
+ "🇺🇸 GPT-4o mini",
1533
+ "🇺🇸 o1-mini",
1534
+ "🇺🇸 o3-mini",
1535
+ "🇺🇸 Remote Meta-Llama-3",
1536
+ "🇪🇺 Mistral-API",
1537
+ ],
1538
+ value="🇪🇺 Mistral-API",
1539
+ label="Query Model"
1540
+ )
1541
+ query_temperature_slider = gr.Slider(
1542
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
1543
+ label="Randomness (Temperature)"
1544
+ )
1545
+ query_top_p_slider = gr.Slider(
1546
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
1547
+ label="Word Variety (Top-p)"
1548
+ )
1549
+ query_top_k_slider = gr.Slider(
1550
+ minimum=1, maximum=100, value=50, step=1,
1551
+ label="Token Selection (Top-k)"
1552
+ )
1553
+ query_bm25_weight_slider = gr.Slider(
1554
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
1555
+ label="Lexical vs Semantics (BM25 Weight)"
1556
+ )
1557
+ with gr.Column():
1558
+ max_tokens_slider = gr.Slider(minimum=1000, maximum=128000, value=3000, label="🔢 Max Tokens", step=1000)
1559
+ query_input = gr.Textbox(
1560
+ label="Enter your query here",
1561
+ placeholder="Type your query",
1562
+ lines=4
1563
+ )
1564
+ submit_button = gr.Button("Submit Query (Async)")
1565
 
1566
  with gr.Row():
1567
  query_response = gr.Textbox(
 
1634
  status_tokens1 = gr.Markdown("")
1635
  status_tokens2 = gr.Markdown("")
1636
 
1637
+ with gr.TabItem("Batch Query"):
1638
+ with gr.Row():
1639
+ with gr.Column():
1640
+ batch_model_dropdown = gr.Dropdown(
1641
+ choices=[
1642
+ "🇺🇸 GPT-3.5",
1643
+ "🇺🇸 GPT-4o",
1644
+ "🇺🇸 GPT-4o mini",
1645
+ "🇺🇸 o1-mini",
1646
+ "🇺🇸 o3-mini",
1647
+ "🇺🇸 Remote Meta-Llama-3",
1648
+ "🇪🇺 Mistral-API",
1649
+ ],
1650
+ value="🇪🇺 Mistral-API",
1651
+ label="Query Model"
1652
+ )
1653
+ with gr.Row():
1654
+ temp_variation = gr.Dropdown(
1655
+ choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
1656
+ value="Constant",
1657
+ label="Temperature Variation"
1658
+ )
1659
+ batch_temperature_slider = gr.Slider(
1660
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
1661
+ label="Randomness (Temperature)"
1662
+ )
1663
+ with gr.Row():
1664
+ top_p_variation = gr.Dropdown(
1665
+ choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
1666
+ value="Constant",
1667
+ label="Top-p Variation"
1668
+ )
1669
+ batch_top_p_slider = gr.Slider(
1670
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
1671
+ label="Word Variety (Top-p)"
1672
+ )
1673
+ with gr.Row():
1674
+ top_k_variation = gr.Dropdown(
1675
+ choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
1676
+ value="Constant",
1677
+ label="Top-k Variation"
1678
+ )
1679
+ batch_top_k_slider = gr.Slider(
1680
+ minimum=1, maximum=100, value=50, step=1,
1681
+ label="Token Selection (Top-k)"
1682
+ )
1683
+ with gr.Row():
1684
+ bm25_variation = gr.Dropdown(
1685
+ choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
1686
+ value="Constant",
1687
+ label="BM25 Weight Variation"
1688
+ )
1689
+ batch_bm25_weight_slider = gr.Slider(
1690
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
1691
+ label="Lexical vs Semantics (BM25 Weight)"
1692
+ )
1693
+ with gr.Column():
1694
+ batch_max_tokens_slider = gr.Slider(
1695
+ minimum=1000, maximum=128000, value=3000, label="🔢 Max Tokens", step=1000
1696
+ )
1697
+ batch_query_input = gr.Textbox(
1698
+ label="Enter your query here",
1699
+ placeholder="Type your query",
1700
+ lines=4
1701
+ )
1702
+ batch_submit_button = gr.Button("Submit Batch Query (Async)")
1703
+
1704
+ with gr.Row():
1705
+ batch_query_response = gr.Textbox(
1706
+ label="Batch Query Results",
1707
+ placeholder="Results will appear here (formatted as Markdown)",
1708
+ lines=10
1709
+ )
1710
+ batch_query_context = gr.Textbox(
1711
+ label="Context Information",
1712
+ placeholder="Retrieved context will appear here",
1713
+ lines=6
1714
+ )
1715
+
1716
+ with gr.Row():
1717
+ batch_input_tokens = gr.Markdown("Input tokens: 0")
1718
+ batch_output_tokens = gr.Markdown("Output tokens: 0")
1719
+
1720
+ with gr.Row():
1721
+ with gr.Column(scale=1):
1722
+ batch_job_list = gr.Markdown(
1723
+ value="No jobs yet",
1724
+ label="Job List (Click to select)"
1725
+ )
1726
+ batch_refresh_button = gr.Button("Refresh Job List")
1727
+ batch_auto_refresh_checkbox = gr.Checkbox(
1728
+ label="Enable Auto Refresh",
1729
+ value=False
1730
+ )
1731
+ batch_df = gr.DataFrame(
1732
+ value=run_query(10),
1733
+ headers=["Number", "Square"],
1734
+ label="Query Results",
1735
+ visible=False
1736
+ )
1737
+
1738
+ with gr.Column(scale=2):
1739
+ batch_job_id_input = gr.Textbox(
1740
+ label="Job ID",
1741
+ placeholder="Job ID will appear here when selected from the list",
1742
+ lines=1
1743
+ )
1744
+ batch_job_query_display = gr.Textbox(
1745
+ label="Job Query",
1746
+ placeholder="The query associated with this job will appear here",
1747
+ lines=2,
1748
+ interactive=False
1749
+ )
1750
+ batch_check_button = gr.Button("Check Status")
1751
+ batch_cleanup_button = gr.Button("Cleanup Old Jobs")
1752
+
1753
+ with gr.Row():
1754
+ batch_status_response = gr.Textbox(
1755
+ label="Job Result",
1756
+ placeholder="Job result will appear here",
1757
+ lines=6
1758
+ )
1759
+ batch_status_context = gr.Textbox(
1760
+ label="Context Information",
1761
+ placeholder="Context information will appear here",
1762
+ lines=6
1763
+ )
1764
+
1765
+ with gr.Row():
1766
+ batch_status_tokens1 = gr.Markdown("")
1767
+ batch_status_tokens2 = gr.Markdown("")
1768
+
1769
  with gr.TabItem("App Management"):
1770
  with gr.Row():
1771
  reset_button = gr.Button("Reset App")
 
1786
  with gr.Row():
1787
  reset_model = gr.Markdown("")
1788
 
1789
+ # Add initialization info display
1790
+ init_info = gr.Markdown("")
1791
+
1792
+ # Update load_button click to include top_k
1793
  load_button.click(
1794
  load_pdfs_async,
1795
+ inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider, top_k_slider, max_tokens_slider],
1796
+ outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list, init_info]
1797
  )
1798
 
1799
+ # Update submit_button click to include top_k
1800
  submit_button.click(
1801
  submit_query_async,
1802
+ inputs=[query_input, query_model_dropdown, max_tokens_slider, query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider],
1803
  outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1804
  )
1805
 
1806
+ # Add function to sync all parameters
1807
+ def sync_parameters(temperature, top_p, top_k, bm25_weight):
1808
+ return temperature, top_p, top_k, bm25_weight
1809
+
1810
+ # Sync parameters between tabs
1811
+ temperature_slider.change(
1812
+ fn=sync_parameters,
1813
+ inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
1814
+ outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
1815
+ )
1816
+ top_p_slider.change(
1817
+ fn=sync_parameters,
1818
+ inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
1819
+ outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
1820
+ )
1821
+ top_k_slider.change(
1822
+ fn=sync_parameters,
1823
+ inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
1824
+ outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
1825
+ )
1826
+ bm25_weight_slider.change(
1827
+ fn=sync_parameters,
1828
+ inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
1829
+ outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
1830
+ )
1831
+
1832
+ # Connect the buttons to their respective functions
1833
  check_button.click(
1834
  check_job_status,
1835
  inputs=[job_id_input],
 
1883
  every=2 #if auto_refresh_checkbox.value else None # Directly set `every` based on the checkbox state
1884
  )
1885
 
1886
+ # Add batch query button click handler
1887
+ batch_submit_button.click(
1888
+ submit_batch_query_async,
1889
+ inputs=[
1890
+ batch_query_input,
1891
+ batch_model_dropdown,
1892
+ batch_max_tokens_slider,
1893
+ temp_variation,
1894
+ top_p_variation,
1895
+ top_k_variation,
1896
+ bm25_variation,
1897
+ batch_temperature_slider,
1898
+ batch_top_p_slider,
1899
+ batch_top_k_slider,
1900
+ batch_bm25_weight_slider
1901
+ ],
1902
+ outputs=[
1903
+ batch_query_response,
1904
+ batch_query_context,
1905
+ batch_input_tokens,
1906
+ batch_output_tokens,
1907
+ batch_job_id_input,
1908
+ batch_job_query_display,
1909
+ batch_job_list
1910
+ ]
1911
+ )
1912
+
1913
+ # Add batch job status checking
1914
+ batch_check_button.click(
1915
+ check_job_status,
1916
+ inputs=[batch_job_id_input],
1917
+ outputs=[batch_status_response, batch_status_context, batch_status_tokens1, batch_status_tokens2, batch_job_query_display]
1918
+ )
1919
+
1920
+ # Add batch job list refresh
1921
+ batch_refresh_button.click(
1922
+ refresh_job_list,
1923
+ inputs=[],
1924
+ outputs=[batch_job_list]
1925
+ )
1926
+
1927
+ # Add batch job list selection
1928
+ batch_job_id_input.change(
1929
+ job_selected,
1930
+ inputs=[batch_job_id_input],
1931
+ outputs=[batch_job_id_input, batch_job_query_display]
1932
+ )
1933
+
1934
+ # Add batch cleanup
1935
+ batch_cleanup_button.click(
1936
+ cleanup_old_jobs,
1937
+ inputs=[],
1938
+ outputs=[batch_status_response, batch_status_context, batch_status_tokens1]
1939
+ )
1940
+
1941
+ # Add batch auto-refresh
1942
+ batch_auto_refresh_checkbox.change(
1943
+ fn=periodic_update,
1944
+ inputs=[batch_auto_refresh_checkbox],
1945
+ outputs=[batch_job_list, batch_status_response, batch_df, batch_status_context],
1946
+ every=2
1947
+ )
1948
+
1949
  if __name__ == "__main__":
1950
  debug_print("Launching Gradio interface.")
1951
  app.queue().launch(share=False)
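For reference, create_llm_pipeline resolves the dropdown label (flag-emoji prefix included) to an API model name and a token cap before building the wrapper. A standalone sketch of that resolution logic; resolve_model is a hypothetical helper name, the token limits are re-keyed by the same lookup keys for simplicity, and checking "gpt-4o mini" before "gpt-4o" is an assumption to keep the substring match unambiguous:

def resolve_model(llm_choice, slider_max_tokens=None):
    # Map a UI label such as "🇪🇺 Mistral-API" to an API model name and a max-token cap
    model_map = {
        "gpt-3.5": "gpt-3.5-turbo",
        "gpt-4o mini": "gpt-4o-mini",
        "gpt-4o": "gpt-4o",
        "o1-mini": "gpt-4o-mini",
        "o3-mini": "gpt-4o-mini",
        "mistral-api": "mistral-small-latest",
        "mistral": "mistral-small-latest",
        "meta-llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
    }
    model_token_limits = {
        "gpt-3.5": 16385, "gpt-4o": 128000, "gpt-4o mini": 128000,
        "o1-mini": 128000, "o3-mini": 128000,
        "meta-llama-3": 4096, "mistral-api": 128000, "mistral": 128000,
    }
    # Drop the flag-emoji prefix: "🇺🇸 Remote Meta-Llama-3" -> "remote meta-llama-3"
    normalized = (llm_choice.split(" ", 1)[-1] if " " in llm_choice else llm_choice).lower()
    model_key = next((key for key in model_map if key in normalized), None)
    if model_key is None:
        raise ValueError(f"Unsupported model: {normalized}")
    max_tokens = model_token_limits.get(model_key, 4096)
    if slider_max_tokens:  # cap by the Max Tokens slider when it is set
        max_tokens = min(slider_max_tokens, max_tokens)
    return model_map[model_key], max_tokens

print(resolve_model("🇪🇺 Mistral-API", slider_max_tokens=3000))  # ('mistral-small-latest', 3000)
print(resolve_model("🇺🇸 Remote Meta-Llama-3"))                  # ('meta-llama/Meta-Llama-3-8B-Instruct', 4096)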