Upload folder using huggingface_hub
advanced_rag.py CHANGED (+731 -125)
|
@@ -36,6 +36,189 @@ from langchain_community.document_loaders import PyMuPDFLoader # Updated loader
|
|
| 36 |
import tempfile
|
| 37 |
import mimetypes
|
| 38 |
|
| 39 |
def get_mime_type(file_path):
|
| 40 |
return mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
|
| 41 |
|
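For reference, the `get_mime_type` helper shown above is a thin wrapper around the standard library; a minimal standalone illustration (not part of the diff):

import mimetypes

def get_mime_type(file_path):
    # Fall back to a generic binary type when the extension is unknown.
    return mimetypes.guess_type(file_path)[0] or 'application/octet-stream'

print(get_mime_type("paper.pdf"))       # application/pdf
print(get_mime_type("archive.xyz123"))  # application/octet-stream (unknown extension)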
|
@@ -43,6 +226,8 @@ print("Pydantic Version: ")
|
|
| 43 |
print(pydantic.__version__)
|
| 44 |
# Add Mistral imports with fallback handling
|
| 45 |
|
| 46 |
try:
|
| 47 |
from mistralai import Mistral
|
| 48 |
MISTRAL_AVAILABLE = True
|
|
@@ -107,11 +292,14 @@ def process_in_background(job_id, function, args):
|
|
| 107 |
error_result = (f"Error processing job: {str(e)}", "", "", "")
|
| 108 |
results_queue.put((job_id, error_result))
|
| 109 |
|
| 110 |
-
def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
|
| 111 |
"""Asynchronous version of load_pdfs_updated to prevent timeouts"""
|
| 112 |
global last_job_id
|
| 113 |
if not file_links:
|
| 114 |
-
return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
|
| 115 |
|
| 116 |
job_id = str(uuid.uuid4())
|
| 117 |
debug_print(f"Starting async job {job_id} for file loading")
|
|
@@ -119,7 +307,7 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
|
|
| 119 |
# Start background thread
|
| 120 |
threading.Thread(
|
| 121 |
target=process_in_background,
|
| 122 |
-
args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
|
| 123 |
).start()
|
| 124 |
|
| 125 |
job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
|
|
@@ -132,6 +320,8 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
|
|
| 132 |
|
| 133 |
last_job_id = job_id
|
| 134 |
|
| 135 |
return (
|
| 136 |
f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
|
| 137 |
f"Use 'Check Job Status' tab with this ID to get results.",
|
|
@@ -139,14 +329,17 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
|
|
| 139 |
f"Model requested: {model_choice}",
|
| 140 |
job_id, # Return job_id to update the job_id_input component
|
| 141 |
job_query, # Return job_query to update the job_query_display component
|
| 142 |
-
get_job_list() # Return updated job list
|
|
|
|
| 143 |
)
|
| 144 |
|
| 145 |
-
def submit_query_async(query, model_choice=None):
|
| 146 |
"""Asynchronous version of submit_query_updated to prevent timeouts"""
|
| 147 |
global last_job_id
|
| 148 |
if not query:
|
| 149 |
return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 150 |
|
| 151 |
job_id = str(uuid.uuid4())
|
| 152 |
debug_print(f"Starting async job {job_id} for query: {query}")
|
|
@@ -154,13 +347,13 @@ def submit_query_async(query, model_choice=None):
|
|
| 154 |
# Update model if specified
|
| 155 |
if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
|
| 156 |
debug_print(f"Updating model to {model_choice} for this query")
|
| 157 |
-
rag_chain.update_llm_pipeline(model_choice,
|
| 158 |
-
rag_chain.prompt_template,
|
| 159 |
|
| 160 |
# Start background thread
|
| 161 |
threading.Thread(
|
| 162 |
target=process_in_background,
|
| 163 |
-
args=(job_id, submit_query_updated, [query])
|
| 164 |
).start()
|
| 165 |
|
| 166 |
jobs[job_id] = {
|
|
@@ -550,7 +743,7 @@ def load_file_from_google_drive(link: str) -> list:
|
|
| 550 |
|
| 551 |
class ElevatedRagChain:
|
| 552 |
def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
|
| 553 |
-
bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
|
| 554 |
debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
|
| 555 |
self.embed_func = HuggingFaceEmbeddings(
|
| 556 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
|
@@ -558,7 +751,7 @@ class ElevatedRagChain:
|
|
| 558 |
)
|
| 559 |
self.bm25_weight = bm25_weight
|
| 560 |
self.faiss_weight = 1.0 - bm25_weight
|
| 561 |
-
self.top_k =
|
| 562 |
self.llm_choice = llm_choice
|
| 563 |
self.temperature = temperature
|
| 564 |
self.top_p = top_p
|
|
@@ -587,9 +780,119 @@ class ElevatedRagChain:
|
|
| 587 |
# Improve error handling in the ElevatedRagChain class
|
| 588 |
def create_llm_pipeline(self):
|
| 589 |
from langchain.llms.base import LLM # Import LLM here so it's always defined
|
| 590 |
-
|
| 591 |
try:
|
| 592 |
-
|
| 593 |
debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
|
| 594 |
from huggingface_hub import InferenceClient
|
| 595 |
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|
@@ -598,20 +901,19 @@ class ElevatedRagChain:
|
|
| 598 |
raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
|
| 599 |
|
| 600 |
client = InferenceClient(token=hf_api_token, timeout=120)
|
| 601 |
-
|
| 602 |
-
# We no longer use wait_for_model because it's unsupported
|
| 603 |
def remote_generate(prompt: str) -> str:
|
| 604 |
max_retries = 3
|
| 605 |
backoff = 2 # start with 2 seconds
|
| 606 |
for attempt in range(max_retries):
|
| 607 |
try:
|
| 608 |
-
debug_print(f"Remote generation attempt {attempt+1}")
|
| 609 |
response = client.text_generation(
|
| 610 |
prompt,
|
| 611 |
model=repo_id,
|
| 612 |
temperature=self.temperature,
|
| 613 |
top_p=self.top_p,
|
| 614 |
-
|
| 615 |
)
|
| 616 |
return response
|
| 617 |
except Exception as e:
|
|
@@ -623,6 +925,11 @@ class ElevatedRagChain:
|
|
| 623 |
return "Failed to generate response after multiple attempts."
|
| 624 |
|
| 625 |
class RemoteLLM(LLM):
|
|
| 626 |
@property
|
| 627 |
def _llm_type(self) -> str:
|
| 628 |
return "remote_llm"
|
|
@@ -632,97 +939,74 @@ class ElevatedRagChain:
|
|
| 632 |
|
| 633 |
@property
|
| 634 |
def _identifying_params(self) -> dict:
|
| 635 |
-
return {"model":
|
| 636 |
|
| 637 |
debug_print("Remote Meta-Llama-3 pipeline created successfully.")
|
| 638 |
return RemoteLLM()
|
| 639 |
-
|
| 640 |
-
|
|
|
|
| 641 |
debug_print("Creating Mistral API pipeline...")
|
| 642 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
| 643 |
if not mistral_api_key:
|
| 644 |
raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
|
|
|
|
| 645 |
try:
|
| 646 |
-
from mistralai import Mistral
|
| 647 |
debug_print("Mistral library imported successfully")
|
| 648 |
except ImportError:
|
| 649 |
-
|
| 650 |
-
normalized = "llama"
|
| 651 |
-
if normalized != "llama":
|
| 652 |
-
# from pydantic import PrivateAttr
|
| 653 |
-
# from langchain.llms.base import LLM
|
| 654 |
-
# from typing import Any, Optional, List
|
| 655 |
-
# import typing
|
| 656 |
-
|
| 657 |
-
class MistralLLM(LLM):
|
| 658 |
-
temperature: float = 0.7
|
| 659 |
-
top_p: float = 0.95
|
| 660 |
-
_client: Any = PrivateAttr(default=None)
|
| 661 |
-
|
| 662 |
-
def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
|
| 663 |
-
try:
|
| 664 |
-
super().__init__(**kwargs)
|
| 665 |
-
# Bypass Pydantic's __setattr__ to assign to _client
|
| 666 |
-
object.__setattr__(self, '_client', Mistral(api_key=api_key))
|
| 667 |
-
self.temperature = temperature
|
| 668 |
-
self.top_p = top_p
|
| 669 |
-
except Exception as e:
|
| 670 |
-
debug_print(f"Init Mistral failed with error: {e}")
|
| 671 |
-
|
| 672 |
-
@property
|
| 673 |
-
def _llm_type(self) -> str:
|
| 674 |
-
return "mistral_llm"
|
| 675 |
-
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
| 676 |
-
try:
|
| 677 |
-
debug_print("Calling Mistral API...")
|
| 678 |
-
response = self._client.chat.complete(
|
| 679 |
-
model="mistral-small-latest",
|
| 680 |
-
messages=[{"role": "user", "content": prompt}],
|
| 681 |
-
temperature=self.temperature,
|
| 682 |
-
top_p=self.top_p
|
| 683 |
-
)
|
| 684 |
-
return response.choices[0].message.content
|
| 685 |
-
except Exception as e:
|
| 686 |
-
debug_print(f"Mistral API error: {str(e)}")
|
| 687 |
-
return f"Error generating response: {str(e)}"
|
| 688 |
-
@property
|
| 689 |
-
def _identifying_params(self) -> dict:
|
| 690 |
-
return {"model": "mistral-small-latest"}
|
| 691 |
-
debug_print("Creating Mistral LLM instance")
|
| 692 |
-
mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
|
| 693 |
-
debug_print("Mistral API pipeline created successfully.")
|
| 694 |
-
return mistral_llm
|
| 695 |
-
|
| 696 |
-
else:
|
| 697 |
-
# Default case - using a fallback model (or Llama)
|
| 698 |
-
debug_print("Using local/fallback model pipeline")
|
| 699 |
-
model_id = "facebook/opt-350m" # Use a smaller model as fallback
|
| 700 |
-
pipe = pipeline(
|
| 701 |
-
"text-generation",
|
| 702 |
-
model=model_id,
|
| 703 |
-
device=-1, # CPU
|
| 704 |
-
max_length=1024
|
| 705 |
-
)
|
| 706 |
|
| 707 |
-
class LocalLLM(LLM):
|
|
| 708 |
@property
|
| 709 |
def _llm_type(self) -> str:
|
| 710 |
-
return "
|
|
|
|
| 711 |
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
|
|
| 719 |
@property
|
| 720 |
def _identifying_params(self) -> dict:
|
| 721 |
-
return {"model":
|
|
| 722 |
|
| 723 |
-
debug_print("Local fallback pipeline created.")
|
| 724 |
-
return LocalLLM()
|
| 725 |
-
|
| 726 |
except Exception as e:
|
| 727 |
debug_print(f"Error creating LLM pipeline: {str(e)}")
|
| 728 |
# Return a dummy LLM that explains the error
|
|
@@ -741,11 +1025,12 @@ class ElevatedRagChain:
|
|
| 741 |
return ErrorLLM()
|
| 742 |
|
| 743 |
|
| 744 |
-
def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
|
| 745 |
debug_print(f"Updating chain with new model: {new_model_choice}")
|
| 746 |
self.llm_choice = new_model_choice
|
| 747 |
self.temperature = temperature
|
| 748 |
self.top_p = top_p
|
|
|
|
| 749 |
self.prompt_template = prompt_template
|
| 750 |
self.bm25_weight = bm25_weight
|
| 751 |
self.faiss_weight = 1.0 - bm25_weight
|
|
@@ -753,7 +1038,14 @@ class ElevatedRagChain:
|
|
| 753 |
def format_response(response: str) -> str:
|
| 754 |
input_tokens = count_tokens(self.context + self.prompt_template)
|
| 755 |
output_tokens = count_tokens(response)
|
| 756 |
-
formatted = f"
|
|
| 757 |
formatted += f"- **Input tokens:** {input_tokens}\n"
|
| 758 |
formatted += f"- **Output tokens:** {output_tokens}\n"
|
| 759 |
formatted += f"- **Generated using:** {self.llm_choice}\n"
|
|
@@ -836,7 +1128,14 @@ class ElevatedRagChain:
|
|
| 836 |
def format_response(response: str) -> str:
|
| 837 |
input_tokens = count_tokens(self.context + self.prompt_template)
|
| 838 |
output_tokens = count_tokens(response)
|
| 839 |
-
formatted = f"
|
|
| 840 |
formatted += f"- **Input tokens:** {input_tokens}\n"
|
| 841 |
formatted += f"- **Output tokens:** {output_tokens}\n"
|
| 842 |
formatted += f"- **Generated using:** {self.llm_choice}\n"
|
|
@@ -863,7 +1162,7 @@ class ElevatedRagChain:
|
|
| 863 |
global rag_chain
|
| 864 |
rag_chain = ElevatedRagChain()
|
| 865 |
|
| 866 |
-
def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
|
| 867 |
debug_print("Inside load_pdfs function.")
|
| 868 |
if not file_links:
|
| 869 |
debug_print("Please enter non-empty URLs")
|
|
@@ -872,7 +1171,7 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
|
|
| 872 |
links = [link.strip() for link in file_links.split("\n") if link.strip()]
|
| 873 |
global rag_chain
|
| 874 |
if rag_chain.raw_data:
|
| 875 |
-
rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
|
| 876 |
context_display = rag_chain.get_current_context()
|
| 877 |
response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
|
| 878 |
return (
|
|
@@ -887,7 +1186,8 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
|
|
| 887 |
prompt_template=prompt_template,
|
| 888 |
bm25_weight=bm25_weight,
|
| 889 |
temperature=temperature,
|
| 890 |
-
top_p=top_p
|
|
|
|
| 891 |
)
|
| 892 |
rag_chain.add_pdfs_to_vectore_store(links)
|
| 893 |
context_display = rag_chain.get_current_context()
|
|
@@ -911,7 +1211,7 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
|
|
| 911 |
def update_model(new_model: str):
|
| 912 |
global rag_chain
|
| 913 |
if rag_chain and rag_chain.raw_data:
|
| 914 |
-
rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
|
| 915 |
rag_chain.prompt_template, rag_chain.bm25_weight)
|
| 916 |
debug_print(f"Model updated to {rag_chain.llm_choice}")
|
| 917 |
return f"Model updated to: {rag_chain.llm_choice}"
|
|
@@ -920,7 +1220,7 @@ def update_model(new_model: str):
|
|
| 920 |
|
| 921 |
|
| 922 |
# Update submit_query_updated to better handle context limitation
|
| 923 |
-
def submit_query_updated(query):
|
| 924 |
debug_print(f"Processing query: {query}")
|
| 925 |
if not query:
|
| 926 |
debug_print("Empty query received")
|
|
@@ -931,6 +1231,19 @@ def submit_query_updated(query):
|
|
| 931 |
return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
|
| 932 |
|
| 933 |
try:
|
|
| 934 |
# Determine max context size based on model
|
| 935 |
model_name = rag_chain.llm_choice.lower()
|
| 936 |
max_context_tokens = 32000 if "mistral" in model_name else 4096
|
|
@@ -1077,6 +1390,43 @@ document.addEventListener('DOMContentLoaded', function() {
|
|
| 1077 |
clearInterval(jobListInterval);
|
| 1078 |
}
|
| 1079 |
}, 500);
|
|
| 1080 |
});
|
| 1081 |
""") as app:
|
| 1082 |
gr.Markdown('''# PhiRAG - Async Version
|
|
@@ -1113,8 +1463,16 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1113 |
with gr.Row():
|
| 1114 |
with gr.Column():
|
| 1115 |
model_dropdown = gr.Dropdown(
|
| 1116 |
-
choices=[
|
| 1117 |
-
|
| 1118 |
label="Select Model"
|
| 1119 |
)
|
| 1120 |
temperature_slider = gr.Slider(
|
|
@@ -1125,6 +1483,10 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1125 |
minimum=0.1, maximum=0.99, value=0.95, step=0.05,
|
| 1126 |
label="Word Variety (Top-p)"
|
| 1127 |
)
|
|
| 1128 |
with gr.Column():
|
| 1129 |
pdf_input = gr.Textbox(
|
| 1130 |
label="Enter your file URLs (one per line)",
|
|
@@ -1160,21 +1522,46 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1160 |
with gr.Row():
|
| 1161 |
model_output = gr.Markdown("**Current Model**: Not selected")
|
| 1162 |
|
| 1163 |
-
with gr.TabItem("Submit Query"):
|
| 1164 |
with gr.Row():
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
|
| 1178 |
|
| 1179 |
with gr.Row():
|
| 1180 |
query_response = gr.Textbox(
|
|
@@ -1247,6 +1634,138 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1247 |
status_tokens1 = gr.Markdown("")
|
| 1248 |
status_tokens2 = gr.Markdown("")
|
| 1249 |
|
| 1250 |
with gr.TabItem("App Management"):
|
| 1251 |
with gr.Row():
|
| 1252 |
reset_button = gr.Button("Reset App")
|
|
@@ -1267,26 +1786,50 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1267 |
with gr.Row():
|
| 1268 |
reset_model = gr.Markdown("")
|
| 1269 |
|
| 1270 |
-
#
|
|
| 1271 |
load_button.click(
|
| 1272 |
load_pdfs_async,
|
| 1273 |
-
inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
|
| 1274 |
-
outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
|
| 1275 |
-
)
|
| 1276 |
-
|
| 1277 |
-
# Also sync in the other direction
|
| 1278 |
-
query_model_dropdown.change(
|
| 1279 |
-
fn=sync_model_dropdown,
|
| 1280 |
-
inputs=query_model_dropdown,
|
| 1281 |
-
outputs=model_dropdown
|
| 1282 |
)
|
| 1283 |
|
|
|
|
| 1284 |
submit_button.click(
|
| 1285 |
submit_query_async,
|
| 1286 |
-
inputs=[query_input, query_model_dropdown],
|
| 1287 |
outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
|
| 1288 |
)
|
| 1290 |
check_button.click(
|
| 1291 |
check_job_status,
|
| 1292 |
inputs=[job_id_input],
|
|
@@ -1340,6 +1883,69 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
|
|
| 1340 |
every=2 #if auto_refresh_checkbox.value else None # Directly set `every` based on the checkbox state
|
| 1341 |
)
|
| 1342 |
|
| 1343 |
if __name__ == "__main__":
|
| 1344 |
debug_print("Launching Gradio interface.")
|
| 1345 |
app.queue().launch(share=False)
|
|
|
|
| 36 |
import tempfile
|
| 37 |
import mimetypes
|
| 38 |
|
| 39 |
+
# Add batch processing helper functions
|
| 40 |
+
def generate_parameter_values(min_val, max_val, num_values):
|
| 41 |
+
"""Generate evenly spaced values between min and max"""
|
| 42 |
+
if num_values == 1:
|
| 43 |
+
return [min_val]
|
| 44 |
+
step = (max_val - min_val) / (num_values - 1)
|
| 45 |
+
return [min_val + (step * i) for i in range(num_values)]
|
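A quick standalone check of how `generate_parameter_values` spaces its output, using ranges that appear later in this file (illustrative only, not part of the diff):

def generate_parameter_values(min_val, max_val, num_values):
    # Evenly spaced values from min_val to max_val, inclusive of both endpoints.
    if num_values == 1:
        return [min_val]
    step = (max_val - min_val) / (num_values - 1)
    return [min_val + (step * i) for i in range(num_values)]

print(generate_parameter_values(0.0, 1.0, 5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(generate_parameter_values(1, 100, 3))    # [1.0, 50.5, 100.0]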
| 46 |
+
|
| 47 |
+
def process_batch_query(query, model_choice, max_tokens, param_configs, slider_values, job_id):
|
| 48 |
+
"""Process a batch of queries with different parameter combinations"""
|
| 49 |
+
results = []
|
| 50 |
+
|
| 51 |
+
# Generate all parameter combinations
|
| 52 |
+
temp_values = [slider_values['temperature']] if param_configs['temperature'] == "Constant" else generate_parameter_values(0.1, 1.0, int(param_configs['temperature'].split()[2]))
|
| 53 |
+
top_p_values = [slider_values['top_p']] if param_configs['top_p'] == "Constant" else generate_parameter_values(0.1, 0.99, int(param_configs['top_p'].split()[2]))
|
| 54 |
+
top_k_values = [slider_values['top_k']] if param_configs['top_k'] == "Constant" else generate_parameter_values(1, 100, int(param_configs['top_k'].split()[2]))
|
| 55 |
+
bm25_values = [slider_values['bm25']] if param_configs['bm25'] == "Constant" else generate_parameter_values(0.0, 1.0, int(param_configs['bm25'].split()[2]))
|
| 56 |
+
|
| 57 |
+
total_combinations = len(temp_values) * len(top_p_values) * len(top_k_values) * len(bm25_values)
|
| 58 |
+
current = 0
|
| 59 |
+
|
| 60 |
+
for temp in temp_values:
|
| 61 |
+
for top_p in top_p_values:
|
| 62 |
+
for top_k in top_k_values:
|
| 63 |
+
for bm25 in bm25_values:
|
| 64 |
+
current += 1
|
| 65 |
+
try:
|
| 66 |
+
# Update parameters
|
| 67 |
+
rag_chain.temperature = temp
|
| 68 |
+
rag_chain.top_p = top_p
|
| 69 |
+
rag_chain.top_k = top_k
|
| 70 |
+
rag_chain.bm25_weight = bm25
|
| 71 |
+
rag_chain.faiss_weight = 1.0 - bm25
|
| 72 |
+
|
| 73 |
+
# Update ensemble retriever
|
| 74 |
+
rag_chain.ensemble_retriever = EnsembleRetriever(
|
| 75 |
+
retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
|
| 76 |
+
weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Process query
|
| 80 |
+
response = rag_chain.elevated_rag_chain.invoke({"question": query})
|
| 81 |
+
|
| 82 |
+
# Format result
|
| 83 |
+
result = {
|
| 84 |
+
"Parameters": f"Temp: {temp:.2f}, Top-p: {top_p:.2f}, Top-k: {top_k}, BM25: {bm25:.2f}",
|
| 85 |
+
"Response": response,
|
| 86 |
+
"Progress": f"Query {current}/{total_combinations}"
|
| 87 |
+
}
|
| 88 |
+
results.append(result)
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
results.append({
|
| 92 |
+
"Parameters": f"Temp: {temp:.2f}, Top-p: {top_p:.2f}, Top-k: {top_k}, BM25: {bm25:.2f}",
|
| 93 |
+
"Response": f"Error: {str(e)}",
|
| 94 |
+
"Progress": f"Query {current}/{total_combinations}"
|
| 95 |
+
})
|
| 96 |
+
|
| 97 |
+
# Format final results
|
| 98 |
+
formatted_results = "### Batch Query Results\n\n"
|
| 99 |
+
for result in results:
|
| 100 |
+
formatted_results += f"#### {result['Parameters']}\n"
|
| 101 |
+
formatted_results += f"**Progress:** {result['Progress']}\n\n"
|
| 102 |
+
formatted_results += f"{result['Response']}\n\n"
|
| 103 |
+
formatted_results += "---\n\n"
|
| 104 |
+
|
| 105 |
+
return (
|
| 106 |
+
formatted_results,
|
| 107 |
+
f"Job ID: {job_id}",
|
| 108 |
+
f"Input tokens: {count_tokens(query)}",
|
| 109 |
+
f"Output tokens: {sum(count_tokens(r['Response']) for r in results)}"
|
| 110 |
+
)
|
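Because the nested loops above enumerate a full Cartesian product of the four parameter lists, the number of LLM calls multiplies quickly. A standalone sketch of the same enumeration, reusing `generate_parameter_values` from above (itertools is used here only for illustration):

from itertools import product

# "Whole range 5 values" for temperature and top-p, constants for the rest:
temp_values  = generate_parameter_values(0.1, 1.0, 5)
top_p_values = generate_parameter_values(0.1, 0.99, 5)
top_k_values = [50]    # constant
bm25_values  = [0.6]   # constant

combinations = list(product(temp_values, top_p_values, top_k_values, bm25_values))
print(len(combinations))  # 5 * 5 * 1 * 1 = 25 queries against the RAG chain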
| 111 |
+
|
| 112 |
+
def process_batch_query_async(query, model_choice, max_tokens, param_configs, slider_values):
|
| 113 |
+
"""Asynchronous version of batch query processing"""
|
| 114 |
+
global last_job_id
|
| 115 |
+
if not query:
|
| 116 |
+
return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 117 |
+
|
| 118 |
+
if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
|
| 119 |
+
return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 120 |
+
|
| 121 |
+
job_id = str(uuid.uuid4())
|
| 122 |
+
debug_print(f"Starting async batch job {job_id} for query: {query}")
|
| 123 |
+
|
| 124 |
+
# Get slider values
|
| 125 |
+
slider_values = {
|
| 126 |
+
'temperature': slider_values['temperature'],
|
| 127 |
+
'top_p': slider_values['top_p'],
|
| 128 |
+
'top_k': slider_values['top_k'],
|
| 129 |
+
'bm25': slider_values['bm25']
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
# Start background thread
|
| 133 |
+
threading.Thread(
|
| 134 |
+
target=process_in_background,
|
| 135 |
+
args=(job_id, process_batch_query, [query, model_choice, max_tokens, param_configs, slider_values, job_id])
|
| 136 |
+
).start()
|
| 137 |
+
|
| 138 |
+
jobs[job_id] = {
|
| 139 |
+
"status": "processing",
|
| 140 |
+
"type": "batch_query",
|
| 141 |
+
"start_time": time.time(),
|
| 142 |
+
"query": query,
|
| 143 |
+
"model": model_choice,
|
| 144 |
+
"param_configs": param_configs
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
last_job_id = job_id
|
| 148 |
+
|
| 149 |
+
return (
|
| 150 |
+
f"Batch query submitted and processing in the background (Job ID: {job_id}).\n\n"
|
| 151 |
+
f"Use 'Check Job Status' tab with this ID to get results.",
|
| 152 |
+
f"Job ID: {job_id}",
|
| 153 |
+
f"Input tokens: {count_tokens(query)}",
|
| 154 |
+
"Output tokens: pending",
|
| 155 |
+
job_id, # Return job_id to update the job_id_input component
|
| 156 |
+
query, # Return query to update the job_query_display component
|
| 157 |
+
get_job_list() # Return updated job list
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
def submit_batch_query_async(query, model_choice, max_tokens, temp_config, top_p_config, top_k_config, bm25_config,
|
| 161 |
+
temp_slider, top_p_slider, top_k_slider, bm25_slider):
|
| 162 |
+
"""Handle batch query submission with async processing"""
|
| 163 |
+
if not query:
|
| 164 |
+
return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 165 |
+
|
| 166 |
+
if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
|
| 167 |
+
return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 168 |
+
|
| 169 |
+
# Get slider values
|
| 170 |
+
slider_values = {
|
| 171 |
+
'temperature': temp_slider,
|
| 172 |
+
'top_p': top_p_slider,
|
| 173 |
+
'top_k': top_k_slider,
|
| 174 |
+
'bm25': bm25_slider
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
param_configs = {
|
| 178 |
+
'temperature': temp_config,
|
| 179 |
+
'top_p': top_p_config,
|
| 180 |
+
'top_k': top_k_config,
|
| 181 |
+
'bm25': bm25_config
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
return process_batch_query_async(query, model_choice, max_tokens, param_configs, slider_values)
|
| 185 |
+
|
| 186 |
+
def submit_batch_query(query, model_choice, max_tokens, temp_config, top_p_config, top_k_config, bm25_config,
|
| 187 |
+
temp_slider, top_p_slider, top_k_slider, bm25_slider):
|
| 188 |
+
"""Handle batch query submission"""
|
| 189 |
+
if not query:
|
| 190 |
+
return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
|
| 191 |
+
|
| 192 |
+
if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
|
| 193 |
+
return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
|
| 194 |
+
|
| 195 |
+
# Get slider values
|
| 196 |
+
slider_values = {
|
| 197 |
+
'temperature': temp_slider,
|
| 198 |
+
'top_p': top_p_slider,
|
| 199 |
+
'top_k': top_k_slider,
|
| 200 |
+
'bm25': bm25_slider
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
results = process_batch_query(query, model_choice, max_tokens,
|
| 205 |
+
{'temperature': temp_config, 'top_p': top_p_config,
|
| 206 |
+
'top_k': top_k_config, 'bm25': bm25_config},
|
| 207 |
+
slider_values)
|
| 208 |
+
|
| 209 |
+
# Format results for display
|
| 210 |
+
formatted_results = "### Batch Query Results\n\n"
|
| 211 |
+
for result in results:
|
| 212 |
+
formatted_results += f"#### {result['Parameters']}\n"
|
| 213 |
+
formatted_results += f"**Progress:** {result['Progress']}\n\n"
|
| 214 |
+
formatted_results += f"{result['Response']}\n\n"
|
| 215 |
+
formatted_results += "---\n\n"
|
| 216 |
+
|
| 217 |
+
return formatted_results, "", f"Input tokens: {count_tokens(query)}", f"Output tokens: {sum(count_tokens(r['Response']) for r in results)}"
|
| 218 |
+
|
| 219 |
+
except Exception as e:
|
| 220 |
+
return f"Error processing batch query: {str(e)}", "", "Input tokens: 0", "Output tokens: 0"
|
| 221 |
+
|
| 222 |
def get_mime_type(file_path):
|
| 223 |
return mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
|
| 224 |
|
|
|
|
| 226 |
print(pydantic.__version__)
|
| 227 |
# Add Mistral imports with fallback handling
|
| 228 |
|
| 229 |
+
slider_max_tokens = None
|
| 230 |
+
|
| 231 |
try:
|
| 232 |
from mistralai import Mistral
|
| 233 |
MISTRAL_AVAILABLE = True
|
|
|
|
| 292 |
error_result = (f"Error processing job: {str(e)}", "", "", "")
|
| 293 |
results_queue.put((job_id, error_result))
|
| 294 |
|
| 295 |
+
def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k, max_tokens_slider):
|
| 296 |
"""Asynchronous version of load_pdfs_updated to prevent timeouts"""
|
| 297 |
global last_job_id
|
| 298 |
if not file_links:
|
| 299 |
+
return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list(), ""
|
| 300 |
+
global slider_max_tokens
|
| 301 |
+
slider_max_tokens = max_tokens_slider
|
| 302 |
+
|
| 303 |
|
| 304 |
job_id = str(uuid.uuid4())
|
| 305 |
debug_print(f"Starting async job {job_id} for file loading")
|
|
|
|
| 307 |
# Start background thread
|
| 308 |
threading.Thread(
|
| 309 |
target=process_in_background,
|
| 310 |
+
args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k])
|
| 311 |
).start()
|
| 312 |
|
| 313 |
job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
|
|
|
|
| 320 |
|
| 321 |
last_job_id = job_id
|
| 322 |
|
| 323 |
+
init_message = "Vector database initialized using the files.\nThe above parameters were used in the initialization of the RAG chain."
|
| 324 |
+
|
| 325 |
return (
|
| 326 |
f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
|
| 327 |
f"Use 'Check Job Status' tab with this ID to get results.",
|
|
|
|
| 329 |
f"Model requested: {model_choice}",
|
| 330 |
job_id, # Return job_id to update the job_id_input component
|
| 331 |
job_query, # Return job_query to update the job_query_display component
|
| 332 |
+
get_job_list(), # Return updated job list
|
| 333 |
+
init_message # Return initialization message
|
| 334 |
)
|
| 335 |
|
| 336 |
+
def submit_query_async(query, model_choice, max_tokens_slider, temperature, top_p, top_k, bm25_weight):
|
| 337 |
"""Asynchronous version of submit_query_updated to prevent timeouts"""
|
| 338 |
global last_job_id
|
| 339 |
if not query:
|
| 340 |
return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
|
| 341 |
+
global slider_max_tokens
|
| 342 |
+
slider_max_tokens = max_tokens_slider
|
| 343 |
|
| 344 |
job_id = str(uuid.uuid4())
|
| 345 |
debug_print(f"Starting async job {job_id} for query: {query}")
|
|
|
|
| 347 |
# Update model if specified
|
| 348 |
if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
|
| 349 |
debug_print(f"Updating model to {model_choice} for this query")
|
| 350 |
+
rag_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k,
|
| 351 |
+
rag_chain.prompt_template, bm25_weight)
|
| 352 |
|
| 353 |
# Start background thread
|
| 354 |
threading.Thread(
|
| 355 |
target=process_in_background,
|
| 356 |
+
args=(job_id, submit_query_updated, [query, temperature, top_p, top_k, bm25_weight])
|
| 357 |
).start()
|
| 358 |
|
| 359 |
jobs[job_id] = {
|
|
|
|
| 743 |
|
| 744 |
class ElevatedRagChain:
|
| 745 |
def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
|
| 746 |
+
bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50) -> None:
|
| 747 |
debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
|
| 748 |
self.embed_func = HuggingFaceEmbeddings(
|
| 749 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
|
|
|
| 751 |
)
|
| 752 |
self.bm25_weight = bm25_weight
|
| 753 |
self.faiss_weight = 1.0 - bm25_weight
|
| 754 |
+
self.top_k = top_k
|
| 755 |
self.llm_choice = llm_choice
|
| 756 |
self.temperature = temperature
|
| 757 |
self.top_p = top_p
|
|
|
|
| 780 |
# Improve error handling in the ElevatedRagChain class
|
| 781 |
def create_llm_pipeline(self):
|
| 782 |
from langchain.llms.base import LLM # Import LLM here so it's always defined
|
| 783 |
+
from typing import Optional, List, Any
|
| 784 |
+
from pydantic import PrivateAttr
|
| 785 |
+
global slider_max_tokens
|
| 786 |
+
|
| 787 |
+
# Extract the model name without the flag emoji prefix
|
| 788 |
+
clean_llm_choice = self.llm_choice.split(" ", 1)[-1] if " " in self.llm_choice else self.llm_choice
|
| 789 |
+
normalized = clean_llm_choice.lower()
|
| 790 |
+
print(f"Normalized model name: {normalized}")
|
| 791 |
+
|
| 792 |
+
# Model configurations from the second file
|
| 793 |
+
model_token_limits = {
|
| 794 |
+
"gpt-3.5": 16385,
|
| 795 |
+
"gpt-4o": 128000,
|
| 796 |
+
"gpt-4o-mini": 128000,
|
| 797 |
+
"meta-llama-3": 4096,
|
| 798 |
+
"mistral-api": 128000,
|
| 799 |
+
"o1-mini": 128000,
|
| 800 |
+
"o3-mini": 128000
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
model_map = {
|
| 804 |
+
"gpt-3.5": "gpt-3.5-turbo",
|
| 805 |
+
"gpt-4o": "gpt-4o",
|
| 806 |
+
"gpt-4o mini": "gpt-4o-mini",
|
| 807 |
+
"o1-mini": "gpt-4o-mini",
|
| 808 |
+
"o3-mini": "gpt-4o-mini",
|
| 809 |
+
"mistral": "mistral-small-latest",
|
| 810 |
+
"mistral-api": "mistral-small-latest",
|
| 811 |
+
"meta-llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
|
| 812 |
+
"remote meta-llama-3": "meta-llama/Meta-Llama-3-8B-Instruct"
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
model_pricing = {
|
| 816 |
+
"gpt-3.5": {"USD": {"input": 0.0000005, "output": 0.0000015}, "RON": {"input": 0.0000023, "output": 0.0000069}},
|
| 817 |
+
"gpt-4o": {"USD": {"input": 0.0000025, "output": 0.00001}, "RON": {"input": 0.0000115, "output": 0.000046}},
|
| 818 |
+
"gpt-4o-mini": {"USD": {"input": 0.00000015, "output": 0.0000006}, "RON": {"input": 0.0000007, "output": 0.0000028}},
|
| 819 |
+
"o1-mini": {"USD": {"input": 0.0000011, "output": 0.0000044}, "RON": {"input": 0.0000051, "output": 0.0000204}},
|
| 820 |
+
"o3-mini": {"USD": {"input": 0.0000011, "output": 0.0000044}, "RON": {"input": 0.0000051, "output": 0.0000204}},
|
| 821 |
+
"meta-llama-3": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}},
|
| 822 |
+
"mistral": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}},
|
| 823 |
+
"mistral-api": {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}}
|
| 824 |
+
}
|
| 825 |
+
pricing_info = ""
|
| 826 |
+
|
| 827 |
+
# Find the matching model
|
| 828 |
+
model_key = None
|
| 829 |
+
for key in model_map:
|
| 830 |
+
if key.lower() in normalized:
|
| 831 |
+
model_key = key
|
| 832 |
+
break
|
| 833 |
+
|
| 834 |
+
if not model_key:
|
| 835 |
+
raise ValueError(f"Unsupported model: {normalized}")
|
| 836 |
+
model = model_map[model_key]
|
| 837 |
+
max_tokens = model_token_limits.get(model, 4096)
|
| 838 |
+
max_tokens = min(slider_max_tokens, max_tokens)
|
| 839 |
+
pricing_info = model_pricing.get(model_key, {"USD": {"input": 0.00, "output": 0.00}, "RON": {"input": 0.00, "output": 0.00}})
|
| 840 |
+
|
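A trimmed standalone trace of the lookup above (same key insertion order; the first substring match wins, so a "GPT-4o mini" label resolves via the "gpt-4o" key, and the token limit is looked up by the mapped model name with a 4096 fallback):

model_map = {"gpt-3.5": "gpt-3.5-turbo", "gpt-4o": "gpt-4o", "gpt-4o mini": "gpt-4o-mini"}
model_token_limits = {"gpt-3.5": 16385, "gpt-4o": 128000, "gpt-4o-mini": 128000}
slider_max_tokens = 3000

for llm_choice in ["🇺🇸 GPT-4o", "🇺🇸 GPT-4o mini"]:
    clean = llm_choice.split(" ", 1)[-1] if " " in llm_choice else llm_choice
    normalized = clean.lower()
    model_key = next(k for k in model_map if k.lower() in normalized)
    model = model_map[model_key]
    max_tokens = min(slider_max_tokens, model_token_limits.get(model, 4096))
    print(normalized, "->", model_key, "->", model, "max_tokens:", max_tokens)
# gpt-4o      -> gpt-4o -> gpt-4o max_tokens: 3000
# gpt-4o mini -> gpt-4o -> gpt-4o max_tokens: 3000  (matched before the "gpt-4o mini" key)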
| 841 |
try:
|
| 842 |
+
# OpenAI models (GPT-3.5, GPT-4o, GPT-4o mini, o1-mini, o3-mini)
|
| 843 |
+
if any(model in normalized for model in ["gpt-3.5", "gpt-4o", "o1-mini", "o3-mini"]):
|
| 844 |
+
debug_print(f"Creating OpenAI API pipeline for {normalized}...")
|
| 845 |
+
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
| 846 |
+
if not openai_api_key:
|
| 847 |
+
raise ValueError("Please set the OPENAI_API_KEY environment variable to use OpenAI API.")
|
| 848 |
+
|
| 849 |
+
import openai
|
| 850 |
+
|
| 851 |
+
class OpenAILLM(LLM):
|
| 852 |
+
model_name: str = model
|
| 853 |
+
llm_choice: str = model
|
| 854 |
+
max_context_tokens: int = max_tokens
|
| 855 |
+
pricing: dict = pricing_info
|
| 856 |
+
temperature: float = 0.7
|
| 857 |
+
top_p: float = 0.95
|
| 858 |
+
top_k: int = 50
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
@property
|
| 862 |
+
def _llm_type(self) -> str:
|
| 863 |
+
return "openai_llm"
|
| 864 |
+
|
| 865 |
+
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
| 866 |
+
try:
|
| 867 |
+
openai.api_key = openai_api_key
|
| 868 |
+
print(f" tokens: {max_tokens}")
|
| 869 |
+
response = openai.ChatCompletion.create(
|
| 870 |
+
model=self.model_name,
|
| 871 |
+
messages=[{"role": "user", "content": prompt}],
|
| 872 |
+
temperature=self.temperature,
|
| 873 |
+
top_p=self.top_p,
|
| 874 |
+
max_tokens=max_tokens
|
| 875 |
+
)
|
| 876 |
+
return response["choices"][0]["message"]["content"]
|
| 877 |
+
except Exception as e:
|
| 878 |
+
debug_print(f"OpenAI API error: {str(e)}")
|
| 879 |
+
return f"Error generating response: {str(e)}"
|
| 880 |
+
|
| 881 |
+
@property
|
| 882 |
+
def _identifying_params(self) -> dict:
|
| 883 |
+
return {
|
| 884 |
+
"model": self.model_name,
|
| 885 |
+
"max_tokens": self.max_context_tokens,
|
| 886 |
+
"temperature": self.temperature,
|
| 887 |
+
"top_p": self.top_p,
|
| 888 |
+
"top_k": self.top_k
|
| 889 |
+
}
|
| 890 |
+
|
| 891 |
+
debug_print(f"OpenAI {model} pipeline created successfully.")
|
| 892 |
+
return OpenAILLM()
|
| 893 |
+
|
| 894 |
+
# Meta-Llama-3 model
|
| 895 |
+
elif "meta-llama" in normalized or "llama" in normalized:
|
| 896 |
debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
|
| 897 |
from huggingface_hub import InferenceClient
|
| 898 |
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|
|
|
| 901 |
raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
|
| 902 |
|
| 903 |
client = InferenceClient(token=hf_api_token, timeout=120)
|
| 904 |
+
|
|
|
|
| 905 |
def remote_generate(prompt: str) -> str:
|
| 906 |
max_retries = 3
|
| 907 |
backoff = 2 # start with 2 seconds
|
| 908 |
for attempt in range(max_retries):
|
| 909 |
try:
|
| 910 |
+
debug_print(f"Remote generation attempt {attempt+1} tokens: {self.max_tokens}")
|
| 911 |
response = client.text_generation(
|
| 912 |
prompt,
|
| 913 |
model=repo_id,
|
| 914 |
temperature=self.temperature,
|
| 915 |
top_p=self.top_p,
|
| 916 |
+
max_tokens= max_tokens # Reduced token count for speed
|
| 917 |
)
|
| 918 |
return response
|
| 919 |
except Exception as e:
|
|
|
|
| 925 |
return "Failed to generate response after multiple attempts."
|
| 926 |
|
| 927 |
class RemoteLLM(LLM):
|
| 928 |
+
model_name: str = repo_id
|
| 929 |
+
llm_choice: str = repo_id
|
| 930 |
+
max_context_tokens: int = max_tokens
|
| 931 |
+
pricing: dict = pricing_info
|
| 932 |
+
|
| 933 |
@property
|
| 934 |
def _llm_type(self) -> str:
|
| 935 |
return "remote_llm"
|
|
|
|
| 939 |
|
| 940 |
@property
|
| 941 |
def _identifying_params(self) -> dict:
|
| 942 |
+
return {"model": self.model_name, "max_tokens": self.max_context_tokens}
|
| 943 |
|
| 944 |
debug_print("Remote Meta-Llama-3 pipeline created successfully.")
|
| 945 |
return RemoteLLM()
|
| 946 |
+
|
| 947 |
+
# Mistral API model
|
| 948 |
+
elif "mistral" in normalized:
|
| 949 |
debug_print("Creating Mistral API pipeline...")
|
| 950 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
| 951 |
if not mistral_api_key:
|
| 952 |
raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
|
| 953 |
+
|
| 954 |
try:
|
| 955 |
+
from mistralai import Mistral
|
| 956 |
debug_print("Mistral library imported successfully")
|
| 957 |
except ImportError:
|
| 958 |
+
raise ImportError("Mistral client library not installed. Please install with 'pip install mistralai'.")
|
|
| 959 |
|
| 960 |
+
class MistralLLM(LLM):
|
| 961 |
+
temperature: float = 0.7
|
| 962 |
+
top_p: float = 0.95
|
| 963 |
+
model_name: str = model
|
| 964 |
+
llm_choice: str = model
|
| 965 |
+
|
| 966 |
+
pricing: dict = pricing_info
|
| 967 |
+
_client: Any = PrivateAttr(default=None)
|
| 968 |
+
|
| 969 |
+
def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
|
| 970 |
+
try:
|
| 971 |
+
super().__init__(**kwargs)
|
| 972 |
+
# Bypass Pydantic's __setattr__ to assign to _client
|
| 973 |
+
object.__setattr__(self, '_client', Mistral(api_key=api_key))
|
| 974 |
+
self.temperature = temperature
|
| 975 |
+
self.top_p = top_p
|
| 976 |
+
except Exception as e:
|
| 977 |
+
debug_print(f"Init Mistral failed with error: {e}")
|
| 978 |
+
|
| 979 |
@property
|
| 980 |
def _llm_type(self) -> str:
|
| 981 |
+
return "mistral_llm"
|
| 982 |
+
|
| 983 |
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
| 984 |
+
try:
|
| 985 |
+
debug_print(f"Calling Mistral API... tokens: {max_tokens}")
|
| 986 |
+
response = self._client.chat.complete(
|
| 987 |
+
model=self.model_name,
|
| 988 |
+
messages=[{"role": "user", "content": prompt}],
|
| 989 |
+
temperature=self.temperature,
|
| 990 |
+
top_p=self.top_p,
|
| 991 |
+
max_tokens= max_tokens
|
| 992 |
+
)
|
| 993 |
+
return response.choices[0].message.content
|
| 994 |
+
except Exception as e:
|
| 995 |
+
debug_print(f"Mistral API error: {str(e)}")
|
| 996 |
+
return f"Error generating response: {str(e)}"
|
| 997 |
+
|
| 998 |
@property
|
| 999 |
def _identifying_params(self) -> dict:
|
| 1000 |
+
return {"model": self.model_name, "max_tokens": max_tokens}
|
| 1001 |
+
|
| 1002 |
+
debug_print("Creating Mistral LLM instance")
|
| 1003 |
+
mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
|
| 1004 |
+
debug_print("Mistral API pipeline created successfully.")
|
| 1005 |
+
return mistral_llm
|
| 1006 |
+
|
| 1007 |
+
else:
|
| 1008 |
+
raise ValueError(f"Unsupported model choice: {self.llm_choice}")
|
| 1009 |
|
| 1010 |
except Exception as e:
|
| 1011 |
debug_print(f"Error creating LLM pipeline: {str(e)}")
|
| 1012 |
# Return a dummy LLM that explains the error
|
|
|
|
| 1025 |
return ErrorLLM()
|
| 1026 |
|
| 1027 |
|
| 1028 |
+
def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, top_k: int, prompt_template: str, bm25_weight: float):
|
| 1029 |
debug_print(f"Updating chain with new model: {new_model_choice}")
|
| 1030 |
self.llm_choice = new_model_choice
|
| 1031 |
self.temperature = temperature
|
| 1032 |
self.top_p = top_p
|
| 1033 |
+
self.top_k = top_k
|
| 1034 |
self.prompt_template = prompt_template
|
| 1035 |
self.bm25_weight = bm25_weight
|
| 1036 |
self.faiss_weight = 1.0 - bm25_weight
|
|
|
|
| 1038 |
def format_response(response: str) -> str:
|
| 1039 |
input_tokens = count_tokens(self.context + self.prompt_template)
|
| 1040 |
output_tokens = count_tokens(response)
|
| 1041 |
+
formatted = f"✅ Response:\n\n"
|
| 1042 |
+
formatted += f"Model: {self.llm_choice}\n"
|
| 1043 |
+
formatted += f"Model Parameters:\n"
|
| 1044 |
+
formatted += f"- Temperature: {self.temperature}\n"
|
| 1045 |
+
formatted += f"- Top-p: {self.top_p}\n"
|
| 1046 |
+
formatted += f"- Top-k: {self.top_k}\n"
|
| 1047 |
+
formatted += f"- BM25 Weight: {self.bm25_weight}\n\n"
|
| 1048 |
+
formatted += f"{response}\n\n---\n"
|
| 1049 |
formatted += f"- **Input tokens:** {input_tokens}\n"
|
| 1050 |
formatted += f"- **Output tokens:** {output_tokens}\n"
|
| 1051 |
formatted += f"- **Generated using:** {self.llm_choice}\n"
|
|
|
|
| 1128 |
def format_response(response: str) -> str:
|
| 1129 |
input_tokens = count_tokens(self.context + self.prompt_template)
|
| 1130 |
output_tokens = count_tokens(response)
|
| 1131 |
+
formatted = f"✅ Response:\n\n"
|
| 1132 |
+
formatted += f"Model: {self.llm_choice}\n"
|
| 1133 |
+
formatted += f"Model Parameters:\n"
|
| 1134 |
+
formatted += f"- Temperature: {self.temperature}\n"
|
| 1135 |
+
formatted += f"- Top-p: {self.top_p}\n"
|
| 1136 |
+
formatted += f"- Top-k: {self.top_k}\n"
|
| 1137 |
+
formatted += f"- BM25 Weight: {self.bm25_weight}\n\n"
|
| 1138 |
+
formatted += f"{response}\n\n---\n"
|
| 1139 |
formatted += f"- **Input tokens:** {input_tokens}\n"
|
| 1140 |
formatted += f"- **Output tokens:** {output_tokens}\n"
|
| 1141 |
formatted += f"- **Generated using:** {self.llm_choice}\n"
|
|
|
|
| 1162 |
global rag_chain
|
| 1163 |
rag_chain = ElevatedRagChain()
|
| 1164 |
|
| 1165 |
+
def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p, top_k):
|
| 1166 |
debug_print("Inside load_pdfs function.")
|
| 1167 |
if not file_links:
|
| 1168 |
debug_print("Please enter non-empty URLs")
|
|
|
|
| 1171 |
links = [link.strip() for link in file_links.split("\n") if link.strip()]
|
| 1172 |
global rag_chain
|
| 1173 |
if rag_chain.raw_data:
|
| 1174 |
+
rag_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k, prompt_template, bm25_weight)
|
| 1175 |
context_display = rag_chain.get_current_context()
|
| 1176 |
response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
|
| 1177 |
return (
|
|
|
|
| 1186 |
prompt_template=prompt_template,
|
| 1187 |
bm25_weight=bm25_weight,
|
| 1188 |
temperature=temperature,
|
| 1189 |
+
top_p=top_p,
|
| 1190 |
+
top_k=top_k
|
| 1191 |
)
|
| 1192 |
rag_chain.add_pdfs_to_vectore_store(links)
|
| 1193 |
context_display = rag_chain.get_current_context()
|
|
|
|
| 1211 |
def update_model(new_model: str):
|
| 1212 |
global rag_chain
|
| 1213 |
if rag_chain and rag_chain.raw_data:
|
| 1214 |
+
rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p, rag_chain.top_k,
|
| 1215 |
rag_chain.prompt_template, rag_chain.bm25_weight)
|
| 1216 |
debug_print(f"Model updated to {rag_chain.llm_choice}")
|
| 1217 |
return f"Model updated to: {rag_chain.llm_choice}"
|
|
|
|
| 1220 |
|
| 1221 |
|
| 1222 |
# Update submit_query_updated to better handle context limitation
|
| 1223 |
+
def submit_query_updated(query, temperature, top_p, top_k, bm25_weight):
|
| 1224 |
debug_print(f"Processing query: {query}")
|
| 1225 |
if not query:
|
| 1226 |
debug_print("Empty query received")
|
|
|
|
| 1231 |
return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
|
| 1232 |
|
| 1233 |
try:
|
| 1234 |
+
# Update all parameters for this query
|
| 1235 |
+
rag_chain.temperature = temperature
|
| 1236 |
+
rag_chain.top_p = top_p
|
| 1237 |
+
rag_chain.top_k = top_k
|
| 1238 |
+
rag_chain.bm25_weight = bm25_weight
|
| 1239 |
+
rag_chain.faiss_weight = 1.0 - bm25_weight
|
| 1240 |
+
|
| 1241 |
+
# Update the ensemble retriever weights
|
| 1242 |
+
rag_chain.ensemble_retriever = EnsembleRetriever(
|
| 1243 |
+
retrievers=[rag_chain.bm25_retriever, rag_chain.faiss_retriever],
|
| 1244 |
+
weights=[rag_chain.bm25_weight, rag_chain.faiss_weight]
|
| 1245 |
+
)
|
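The block above rebuilds the ensemble because the BM25/FAISS weights are supplied at construction time; a minimal sketch of that step as a helper (hypothetical function name; the `EnsembleRetriever` import path is assumed, the class itself is the one used above):

from langchain.retrievers import EnsembleRetriever  # assumed import path for the class used above

def reweight_ensemble(bm25_retriever, faiss_retriever, bm25_weight: float) -> EnsembleRetriever:
    # Keep the two weights complementary, mirroring rag_chain.faiss_weight = 1.0 - bm25_weight.
    faiss_weight = 1.0 - bm25_weight
    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[bm25_weight, faiss_weight],
    )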
| 1246 |
+
|
| 1247 |
# Determine max context size based on model
|
| 1248 |
model_name = rag_chain.llm_choice.lower()
|
| 1249 |
max_context_tokens = 32000 if "mistral" in model_name else 4096
|
|
|
|
| 1390 |
clearInterval(jobListInterval);
|
| 1391 |
}
|
| 1392 |
}, 500);
|
| 1393 |
+
|
| 1394 |
+
// Function to disable sliders
|
| 1395 |
+
function disableSliders() {
|
| 1396 |
+
const sliders = document.querySelectorAll('input[type="range"]');
|
| 1397 |
+
sliders.forEach(slider => {
|
| 1398 |
+
if (!slider.closest('.query-tab')) { // Don't disable sliders in query tab
|
| 1399 |
+
slider.disabled = true;
|
| 1400 |
+
slider.style.opacity = '0.5';
|
| 1401 |
+
}
|
| 1402 |
+
});
|
| 1403 |
+
}
|
| 1404 |
+
|
| 1405 |
+
// Function to enable sliders
|
| 1406 |
+
function enableSliders() {
|
| 1407 |
+
const sliders = document.querySelectorAll('input[type="range"]');
|
| 1408 |
+
sliders.forEach(slider => {
|
| 1409 |
+
slider.disabled = false;
|
| 1410 |
+
slider.style.opacity = '1';
|
| 1411 |
+
});
|
| 1412 |
+
}
|
| 1413 |
+
|
| 1414 |
+
// Add event listener for load button
|
| 1415 |
+
const loadButton = document.querySelector('button:contains("Load Files (Async)")');
|
| 1416 |
+
if (loadButton) {
|
| 1417 |
+
loadButton.addEventListener('click', function() {
|
| 1418 |
+
// Wait for the response to come back
|
| 1419 |
+
setTimeout(disableSliders, 1000);
|
| 1420 |
+
});
|
| 1421 |
+
}
|
| 1422 |
+
|
| 1423 |
+
// Add event listener for reset button
|
| 1424 |
+
const resetButton = document.querySelector('button:contains("Reset App")');
|
| 1425 |
+
if (resetButton) {
|
| 1426 |
+
resetButton.addEventListener('click', function() {
|
| 1427 |
+
enableSliders();
|
| 1428 |
+
});
|
| 1429 |
+
}
|
| 1430 |
});
|
| 1431 |
""") as app:
|
| 1432 |
gr.Markdown('''# PhiRAG - Async Version
|
|
|
|
| 1463 |
with gr.Row():
|
| 1464 |
with gr.Column():
|
| 1465 |
model_dropdown = gr.Dropdown(
|
| 1466 |
+
choices=[
|
| 1467 |
+
"🇺🇸 GPT-3.5",
|
| 1468 |
+
"🇺🇸 GPT-4o",
|
| 1469 |
+
"🇺🇸 GPT-4o mini",
|
| 1470 |
+
"🇺🇸 o1-mini",
|
| 1471 |
+
"🇺🇸 o3-mini",
|
| 1472 |
+
"🇺🇸 Remote Meta-Llama-3",
|
| 1473 |
+
"🇪🇺 Mistral-API",
|
| 1474 |
+
],
|
| 1475 |
+
value="🇪🇺 Mistral-API",
|
| 1476 |
label="Select Model"
|
| 1477 |
)
|
| 1478 |
temperature_slider = gr.Slider(
|
|
|
|
| 1483 |
minimum=0.1, maximum=0.99, value=0.95, step=0.05,
|
| 1484 |
label="Word Variety (Top-p)"
|
| 1485 |
)
|
| 1486 |
+
top_k_slider = gr.Slider(
|
| 1487 |
+
minimum=1, maximum=100, value=50, step=1,
|
| 1488 |
+
label="Token Selection (Top-k)"
|
| 1489 |
+
)
|
| 1490 |
with gr.Column():
|
| 1491 |
pdf_input = gr.Textbox(
|
| 1492 |
label="Enter your file URLs (one per line)",
|
|
|
|
| 1522 |
with gr.Row():
|
| 1523 |
model_output = gr.Markdown("**Current Model**: Not selected")
|
| 1524 |
|
| 1525 |
+
with gr.TabItem("Submit Query", elem_classes=["query-tab"]):
|
| 1526 |
with gr.Row():
|
| 1527 |
+
with gr.Column():
|
| 1528 |
+
query_model_dropdown = gr.Dropdown(
|
| 1529 |
+
choices=[
|
| 1530 |
+
"🇺🇸 GPT-3.5",
|
| 1531 |
+
"🇺🇸 GPT-4o",
|
| 1532 |
+
"🇺🇸 GPT-4o mini",
|
| 1533 |
+
"🇺🇸 o1-mini",
|
| 1534 |
+
"🇺🇸 o3-mini",
|
| 1535 |
+
"🇺🇸 Remote Meta-Llama-3",
|
| 1536 |
+
"🇪🇺 Mistral-API",
|
| 1537 |
+
],
|
| 1538 |
+
value="🇪🇺 Mistral-API",
|
| 1539 |
+
label="Query Model"
|
| 1540 |
+
)
|
| 1541 |
+
query_temperature_slider = gr.Slider(
|
| 1542 |
+
minimum=0.1, maximum=1.0, value=0.5, step=0.1,
|
| 1543 |
+
label="Randomness (Temperature)"
|
| 1544 |
+
)
|
| 1545 |
+
query_top_p_slider = gr.Slider(
|
| 1546 |
+
minimum=0.1, maximum=0.99, value=0.95, step=0.05,
|
| 1547 |
+
label="Word Variety (Top-p)"
|
| 1548 |
+
)
|
| 1549 |
+
query_top_k_slider = gr.Slider(
|
| 1550 |
+
minimum=1, maximum=100, value=50, step=1,
|
| 1551 |
+
label="Token Selection (Top-k)"
|
| 1552 |
+
)
|
| 1553 |
+
query_bm25_weight_slider = gr.Slider(
|
| 1554 |
+
minimum=0.0, maximum=1.0, value=0.6, step=0.1,
|
| 1555 |
+
label="Lexical vs Semantics (BM25 Weight)"
|
| 1556 |
+
)
|
| 1557 |
+
with gr.Column():
|
| 1558 |
+
max_tokens_slider = gr.Slider(minimum=1000, maximum=128000, value=3000, label="🔢 Max Tokens", step=1000)
|
| 1559 |
+
query_input = gr.Textbox(
|
| 1560 |
+
label="Enter your query here",
|
| 1561 |
+
placeholder="Type your query",
|
| 1562 |
+
lines=4
|
| 1563 |
+
)
|
| 1564 |
+
submit_button = gr.Button("Submit Query (Async)")
|
| 1565 |
|
| 1566 |
with gr.Row():
|
| 1567 |
query_response = gr.Textbox(
|
|
|
|
| 1634 |
status_tokens1 = gr.Markdown("")
|
| 1635 |
status_tokens2 = gr.Markdown("")
|
| 1636 |
|
| 1637 |
+
with gr.TabItem("Batch Query"):
|
| 1638 |
+
with gr.Row():
|
| 1639 |
+
with gr.Column():
|
| 1640 |
+
batch_model_dropdown = gr.Dropdown(
|
| 1641 |
+
choices=[
|
| 1642 |
+
"🇺🇸 GPT-3.5",
|
| 1643 |
+
"🇺🇸 GPT-4o",
|
| 1644 |
+
"🇺🇸 GPT-4o mini",
|
| 1645 |
+
"🇺🇸 o1-mini",
|
| 1646 |
+
"🇺🇸 o3-mini",
|
| 1647 |
+
"🇺🇸 Remote Meta-Llama-3",
|
| 1648 |
+
"🇪🇺 Mistral-API",
|
| 1649 |
+
],
|
| 1650 |
+
value="🇪🇺 Mistral-API",
|
| 1651 |
+
label="Query Model"
|
| 1652 |
+
)
|
| 1653 |
+
with gr.Row():
|
| 1654 |
+
temp_variation = gr.Dropdown(
|
| 1655 |
+
choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
|
| 1656 |
+
value="Constant",
|
| 1657 |
+
label="Temperature Variation"
|
| 1658 |
+
)
|
| 1659 |
+
batch_temperature_slider = gr.Slider(
|
| 1660 |
+
minimum=0.1, maximum=1.0, value=0.5, step=0.1,
|
| 1661 |
+
label="Randomness (Temperature)"
|
| 1662 |
+
)
|
| 1663 |
+
with gr.Row():
|
| 1664 |
+
top_p_variation = gr.Dropdown(
|
| 1665 |
+
choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
|
| 1666 |
+
value="Constant",
|
| 1667 |
+
label="Top-p Variation"
|
| 1668 |
+
)
|
| 1669 |
+
batch_top_p_slider = gr.Slider(
|
| 1670 |
+
minimum=0.1, maximum=0.99, value=0.95, step=0.05,
|
| 1671 |
+
label="Word Variety (Top-p)"
|
| 1672 |
+
)
|
| 1673 |
+
with gr.Row():
|
| 1674 |
+
top_k_variation = gr.Dropdown(
|
| 1675 |
+
choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
|
| 1676 |
+
value="Constant",
|
| 1677 |
+
label="Top-k Variation"
|
| 1678 |
+
)
|
| 1679 |
+
batch_top_k_slider = gr.Slider(
|
| 1680 |
+
minimum=1, maximum=100, value=50, step=1,
|
| 1681 |
+
label="Token Selection (Top-k)"
|
| 1682 |
+
)
|
| 1683 |
+
with gr.Row():
|
| 1684 |
+
bm25_variation = gr.Dropdown(
|
| 1685 |
+
choices=["Constant", "Whole range 3 values", "Whole range 5 values", "Whole range 7 values", "Whole range 10 values"],
|
| 1686 |
+
value="Constant",
|
| 1687 |
+
label="BM25 Weight Variation"
|
| 1688 |
+
)
|
| 1689 |
+
batch_bm25_weight_slider = gr.Slider(
|
| 1690 |
+
minimum=0.0, maximum=1.0, value=0.6, step=0.1,
|
| 1691 |
+
label="Lexical vs Semantics (BM25 Weight)"
|
| 1692 |
+
)
|
| 1693 |
+
with gr.Column():
|
| 1694 |
+
batch_max_tokens_slider = gr.Slider(
|
| 1695 |
+
minimum=1000, maximum=128000, value=3000, label="🔢 Max Tokens", step=1000
|
| 1696 |
+
)
|
| 1697 |
+
batch_query_input = gr.Textbox(
|
| 1698 |
+
label="Enter your query here",
|
| 1699 |
+
placeholder="Type your query",
|
| 1700 |
+
lines=4
|
| 1701 |
+
)
|
| 1702 |
+
batch_submit_button = gr.Button("Submit Batch Query (Async)")
|
| 1703 |
+
|
| 1704 |
+
with gr.Row():
|
| 1705 |
+
batch_query_response = gr.Textbox(
|
| 1706 |
+
label="Batch Query Results",
|
| 1707 |
+
placeholder="Results will appear here (formatted as Markdown)",
|
| 1708 |
+
lines=10
|
| 1709 |
+
)
|
| 1710 |
+
batch_query_context = gr.Textbox(
|
| 1711 |
+
label="Context Information",
|
| 1712 |
+
placeholder="Retrieved context will appear here",
|
| 1713 |
+
lines=6
|
| 1714 |
+
)
|
| 1715 |
+
|
| 1716 |
+
with gr.Row():
|
| 1717 |
+
batch_input_tokens = gr.Markdown("Input tokens: 0")
|
| 1718 |
+
batch_output_tokens = gr.Markdown("Output tokens: 0")
|
| 1719 |
+
|
| 1720 |
+
with gr.Row():
|
| 1721 |
+
with gr.Column(scale=1):
|
| 1722 |
+
batch_job_list = gr.Markdown(
|
| 1723 |
+
value="No jobs yet",
|
| 1724 |
+
label="Job List (Click to select)"
|
| 1725 |
+
)
|
| 1726 |
+
batch_refresh_button = gr.Button("Refresh Job List")
|
| 1727 |
+
batch_auto_refresh_checkbox = gr.Checkbox(
|
| 1728 |
+
label="Enable Auto Refresh",
|
| 1729 |
+
value=False
|
| 1730 |
+
)
|
| 1731 |
+
batch_df = gr.DataFrame(
|
| 1732 |
+
value=run_query(10),
|
| 1733 |
+
headers=["Number", "Square"],
|
| 1734 |
+
label="Query Results",
|
| 1735 |
+
visible=False
|
| 1736 |
+
)
|
| 1737 |
+
|
| 1738 |
+
with gr.Column(scale=2):
|
| 1739 |
+
batch_job_id_input = gr.Textbox(
|
| 1740 |
+
label="Job ID",
|
| 1741 |
+
placeholder="Job ID will appear here when selected from the list",
|
| 1742 |
+
lines=1
|
| 1743 |
+
)
|
| 1744 |
+
batch_job_query_display = gr.Textbox(
|
| 1745 |
+
label="Job Query",
|
| 1746 |
+
placeholder="The query associated with this job will appear here",
|
| 1747 |
+
lines=2,
|
| 1748 |
+
interactive=False
|
| 1749 |
+
)
|
| 1750 |
+
batch_check_button = gr.Button("Check Status")
|
| 1751 |
+
batch_cleanup_button = gr.Button("Cleanup Old Jobs")
|
| 1752 |
+
|
| 1753 |
+
with gr.Row():
|
| 1754 |
+
batch_status_response = gr.Textbox(
|
| 1755 |
+
label="Job Result",
|
| 1756 |
+
placeholder="Job result will appear here",
|
| 1757 |
+
lines=6
|
| 1758 |
+
)
|
| 1759 |
+
batch_status_context = gr.Textbox(
|
| 1760 |
+
label="Context Information",
|
| 1761 |
+
placeholder="Context information will appear here",
|
| 1762 |
+
lines=6
|
| 1763 |
+
)
|
| 1764 |
+
|
| 1765 |
+
with gr.Row():
|
| 1766 |
+
batch_status_tokens1 = gr.Markdown("")
|
| 1767 |
+
batch_status_tokens2 = gr.Markdown("")
|
| 1768 |
+
|
| 1769 |
with gr.TabItem("App Management"):
|
| 1770 |
with gr.Row():
|
| 1771 |
reset_button = gr.Button("Reset App")
|
|
|
|
| 1786 |
with gr.Row():
|
| 1787 |
reset_model = gr.Markdown("")
|
| 1788 |
|
| 1789 |
+
# Add initialization info display
|
| 1790 |
+
init_info = gr.Markdown("")
|
| 1791 |
+
|
| 1792 |
+
# Update load_button click to include top_k
|
| 1793 |
load_button.click(
|
| 1794 |
load_pdfs_async,
|
| 1795 |
+
inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider, top_k_slider, max_tokens_slider],
|
| 1796 |
+
outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list, init_info]
|
|
| 1797 |
)
|
| 1798 |
|
| 1799 |
+
# Update submit_button click to include top_k
|
| 1800 |
submit_button.click(
|
| 1801 |
submit_query_async,
|
| 1802 |
+
inputs=[query_input, query_model_dropdown, max_tokens_slider, query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider],
|
| 1803 |
outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
|
| 1804 |
)
|
| 1805 |
|
| 1806 |
+
# Add function to sync all parameters
|
| 1807 |
+
def sync_parameters(temperature, top_p, top_k, bm25_weight):
|
| 1808 |
+
return temperature, top_p, top_k, bm25_weight
|
| 1809 |
+
|
| 1810 |
+
# Sync parameters between tabs
|
| 1811 |
+
temperature_slider.change(
|
| 1812 |
+
fn=sync_parameters,
|
| 1813 |
+
inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
|
| 1814 |
+
outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
|
| 1815 |
+
)
|
| 1816 |
+
top_p_slider.change(
|
| 1817 |
+
fn=sync_parameters,
|
| 1818 |
+
inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
|
| 1819 |
+
outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
|
| 1820 |
+
)
|
| 1821 |
+
top_k_slider.change(
|
| 1822 |
+
fn=sync_parameters,
|
| 1823 |
+
inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
|
| 1824 |
+
outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
|
| 1825 |
+
)
|
| 1826 |
+
bm25_weight_slider.change(
|
| 1827 |
+
fn=sync_parameters,
|
| 1828 |
+
inputs=[temperature_slider, top_p_slider, top_k_slider, bm25_weight_slider],
|
| 1829 |
+
outputs=[query_temperature_slider, query_top_p_slider, query_top_k_slider, query_bm25_weight_slider]
|
| 1830 |
+
)
|
| 1831 |
+
|
| 1832 |
+
# Connect the buttons to their respective functions
|
| 1833 |
check_button.click(
|
| 1834 |
check_job_status,
|
| 1835 |
inputs=[job_id_input],
|
|
|
|
| 1883 |
every=2 #if auto_refresh_checkbox.value else None # Directly set `every` based on the checkbox state
|
| 1884 |
)
|
| 1885 |
|
| 1886 |
+
# Add batch query button click handler
|
| 1887 |
+
batch_submit_button.click(
|
| 1888 |
+
submit_batch_query_async,
|
| 1889 |
+
inputs=[
|
| 1890 |
+
batch_query_input,
|
| 1891 |
+
batch_model_dropdown,
|
| 1892 |
+
batch_max_tokens_slider,
|
| 1893 |
+
temp_variation,
|
| 1894 |
+
top_p_variation,
|
| 1895 |
+
top_k_variation,
|
| 1896 |
+
bm25_variation,
|
| 1897 |
+
batch_temperature_slider,
|
| 1898 |
+
batch_top_p_slider,
|
| 1899 |
+
batch_top_k_slider,
|
| 1900 |
+
batch_bm25_weight_slider
|
| 1901 |
+
],
|
| 1902 |
+
outputs=[
|
| 1903 |
+
batch_query_response,
|
| 1904 |
+
batch_query_context,
|
| 1905 |
+
batch_input_tokens,
|
| 1906 |
+
batch_output_tokens,
|
| 1907 |
+
batch_job_id_input,
|
| 1908 |
+
batch_job_query_display,
|
| 1909 |
+
batch_job_list
|
| 1910 |
+
]
|
| 1911 |
+
)
|
| 1912 |
+
|
| 1913 |
+
# Add batch job status checking
|
| 1914 |
+
batch_check_button.click(
|
| 1915 |
+
check_job_status,
|
| 1916 |
+
inputs=[batch_job_id_input],
|
| 1917 |
+
outputs=[batch_status_response, batch_status_context, batch_status_tokens1, batch_status_tokens2, batch_job_query_display]
|
| 1918 |
+
)
|
| 1919 |
+
|
| 1920 |
+
# Add batch job list refresh
|
| 1921 |
+
batch_refresh_button.click(
|
| 1922 |
+
refresh_job_list,
|
| 1923 |
+
inputs=[],
|
| 1924 |
+
outputs=[batch_job_list]
|
| 1925 |
+
)
|
| 1926 |
+
|
| 1927 |
+
# Add batch job list selection
|
| 1928 |
+
batch_job_id_input.change(
|
| 1929 |
+
job_selected,
|
| 1930 |
+
inputs=[batch_job_id_input],
|
| 1931 |
+
outputs=[batch_job_id_input, batch_job_query_display]
|
| 1932 |
+
)
|
| 1933 |
+
|
| 1934 |
+
# Add batch cleanup
|
| 1935 |
+
batch_cleanup_button.click(
|
| 1936 |
+
cleanup_old_jobs,
|
| 1937 |
+
inputs=[],
|
| 1938 |
+
outputs=[batch_status_response, batch_status_context, batch_status_tokens1]
|
| 1939 |
+
)
|
| 1940 |
+
|
| 1941 |
+
# Add batch auto-refresh
|
| 1942 |
+
batch_auto_refresh_checkbox.change(
|
| 1943 |
+
fn=periodic_update,
|
| 1944 |
+
inputs=[batch_auto_refresh_checkbox],
|
| 1945 |
+
outputs=[batch_job_list, batch_status_response, batch_df, batch_status_context],
|
| 1946 |
+
every=2
|
| 1947 |
+
)
|
| 1948 |
+
|
| 1949 |
if __name__ == "__main__":
|
| 1950 |
debug_print("Launching Gradio interface.")
|
| 1951 |
app.queue().launch(share=False)
|