alx-d committed on
Commit 4fc1e9c · verified · 1 Parent(s): 10a7b38

Upload folder using huggingface_hub
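This commit message is the default one produced by huggingface_hub's upload_folder helper. A minimal sketch of the kind of call that generates such a commit (the repo id, folder path, and repo type below are assumptions, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()  # assumes prior authentication, e.g. `huggingface-cli login` or an HF_TOKEN env variable
api.upload_folder(
    folder_path=".",  # local folder holding advanced_rag.py and requirements.txt
    repo_id="<user>/<repo>",  # assumed target repository; not shown on this page
    repo_type="space",  # assumed: the Gradio app in this file is typical of a Space
    commit_message="Upload folder using huggingface_hub",  # default message, as seen above
)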

Files changed (2)
  1. advanced_rag.py +1168 -1113
  2. requirements.txt +1 -1
advanced_rag.py CHANGED
@@ -1,1113 +1,1168 @@
1
- import os
2
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
- import datetime
4
- import functools
5
- import traceback
6
- from typing import List, Optional, Any, Dict
7
-
8
- import torch
9
- import transformers
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
- from langchain_community.llms import HuggingFacePipeline
12
-
13
- # Other LangChain and community imports
14
- from langchain_community.document_loaders import OnlinePDFLoader
15
- from langchain.text_splitter import RecursiveCharacterTextSplitter
16
- from langchain_community.vectorstores import FAISS
17
- from langchain.embeddings import HuggingFaceEmbeddings
18
- from langchain_community.retrievers import BM25Retriever
19
- from langchain.retrievers import EnsembleRetriever
20
- from langchain.prompts import ChatPromptTemplate
21
- from langchain.schema import StrOutputParser, Document
22
- from langchain_core.runnables import RunnableParallel, RunnableLambda
23
- from transformers.quantizers.auto import AutoQuantizationConfig
24
- import gradio as gr
25
- import requests
26
- from pydantic import PrivateAttr
27
- import pydantic
28
-
29
- from langchain.llms.base import LLM
30
- from typing import Any, Optional, List
31
- import typing
32
- import time
33
-
34
- print("Pydantic Version: ")
35
- print(pydantic.__version__)
36
- # Add Mistral imports with fallback handling
37
-
38
- try:
39
- from mistralai import Mistral
40
- MISTRAL_AVAILABLE = True
41
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
- debug_print("Loaded latest Mistral client library")
43
- except ImportError:
44
- MISTRAL_AVAILABLE = False
45
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
- debug_print("Mistral client library not found. Install with: pip install mistralai")
47
-
48
- def debug_print(message: str):
49
- print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
-
51
- def word_count(text: str) -> int:
52
- return len(text.split())
53
-
54
- # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
- def initialize_tokenizer():
56
- try:
57
- return AutoTokenizer.from_pretrained("gpt2")
58
- except Exception as e:
59
- debug_print("Failed to initialize tokenizer: " + str(e))
60
- return None
61
-
62
- global_tokenizer = initialize_tokenizer()
63
-
64
- def count_tokens(text: str) -> int:
65
- if global_tokenizer:
66
- try:
67
- return len(global_tokenizer.encode(text))
68
- except Exception as e:
69
- return len(text.split())
70
- return len(text.split())
71
-
72
-
73
- # Add these imports at the top of your file
74
- import uuid
75
- import threading
76
- import queue
77
- from typing import Dict, Any, Tuple, Optional
78
- import time
79
-
80
- # Global storage for jobs and results
81
- jobs = {} # Stores job status and results
82
- results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
- processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
-
85
- # Add a global variable to store the last job ID
86
- last_job_id = None
87
-
88
- # Add these missing async processing functions
89
-
90
- def process_in_background(job_id, function, args):
91
- """Process a function in the background and store results"""
92
- try:
93
- debug_print(f"Processing job {job_id} in background")
94
- result = function(*args)
95
- results_queue.put((job_id, result))
96
- debug_print(f"Job {job_id} completed and added to results queue")
97
- except Exception as e:
98
- debug_print(f"Error in background job {job_id}: {str(e)}")
99
- error_result = (f"Error processing job: {str(e)}", "", "", "")
100
- results_queue.put((job_id, error_result))
101
-
102
- def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
103
- """Asynchronous version of load_pdfs_updated to prevent timeouts"""
104
- global last_job_id
105
- if not file_links:
106
- return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
107
-
108
- job_id = str(uuid.uuid4())
109
- debug_print(f"Starting async job {job_id} for file loading")
110
-
111
- # Start background thread
112
- threading.Thread(
113
- target=process_in_background,
114
- args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
115
- ).start()
116
-
117
- job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
- jobs[job_id] = {
119
- "status": "processing",
120
- "type": "load_files",
121
- "start_time": time.time(),
122
- "query": job_query
123
- }
124
-
125
- last_job_id = job_id
126
-
127
- return (
128
- f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
129
- f"Use 'Check Job Status' tab with this ID to get results.",
130
- f"Job ID: {job_id}",
131
- f"Model requested: {model_choice}",
132
- job_id, # Return job_id to update the job_id_input component
133
- job_query, # Return job_query to update the job_query_display component
134
- get_job_list() # Return updated job list
135
- )
136
-
137
- def submit_query_async(query, model_choice=None):
138
- """Asynchronous version of submit_query_updated to prevent timeouts"""
139
- global last_job_id
140
- if not query:
141
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
142
-
143
- job_id = str(uuid.uuid4())
144
- debug_print(f"Starting async job {job_id} for query: {query}")
145
-
146
- # Update model if specified
147
- if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
148
- debug_print(f"Updating model to {model_choice} for this query")
149
- rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
150
- rag_chain.prompt_template, rag_chain.bm25_weight)
151
-
152
- # Start background thread
153
- threading.Thread(
154
- target=process_in_background,
155
- args=(job_id, submit_query_updated, [query])
156
- ).start()
157
-
158
- jobs[job_id] = {
159
- "status": "processing",
160
- "type": "query",
161
- "start_time": time.time(),
162
- "query": query,
163
- "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
164
- }
165
-
166
- last_job_id = job_id
167
-
168
- return (
169
- f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
170
- f"Use 'Check Job Status' tab with this ID to get results.",
171
- f"Job ID: {job_id}",
172
- f"Input tokens: {count_tokens(query)}",
173
- "Output tokens: pending",
174
- job_id, # Return job_id to update the job_id_input component
175
- query, # Return query to update the job_query_display component
176
- get_job_list() # Return updated job list
177
- )
178
-
179
- def update_ui_with_last_job_id():
180
- # This function doesn't need to do anything anymore
181
- # We'll update the UI directly in the functions that call this
182
- pass
183
-
184
- # Function to display all jobs as a clickable list
185
- def get_job_list():
186
- job_list_md = "### Submitted Jobs\n\n"
187
-
188
- if not jobs:
189
- return "No jobs found. Submit a query or load files to create jobs."
190
-
191
- # Sort jobs by start time (newest first)
192
- sorted_jobs = sorted(
193
- [(job_id, job_info) for job_id, job_info in jobs.items()],
194
- key=lambda x: x[1].get("start_time", 0),
195
- reverse=True
196
- )
197
-
198
- for job_id, job_info in sorted_jobs:
199
- status = job_info.get("status", "unknown")
200
- job_type = job_info.get("type", "unknown")
201
- query = job_info.get("query", "")
202
- start_time = job_info.get("start_time", 0)
203
- time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
204
-
205
- # Create a shortened query preview
206
- query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
207
-
208
- # Add color and icons based on status
209
- if status == "processing":
210
- # Red color with processing icon for processing jobs
211
- status_formatted = f"<span style='color: red'>⏳ {status}</span>"
212
- elif status == "completed":
213
- # Green color with checkmark for completed jobs
214
- status_formatted = f"<span style='color: green'>✅ {status}</span>"
215
- else:
216
- # Default formatting for unknown status
217
- status_formatted = f"<span style='color: orange'>❓ {status}</span>"
218
-
219
- # Create clickable links using Markdown
220
- if job_type == "query":
221
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - Query: {query_preview}\n"
222
- else:
223
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - File Load Job\n"
224
-
225
- return job_list_md
226
-
227
- # Function to handle job list clicks
228
- def job_selected(job_id):
229
- if job_id in jobs:
230
- return job_id, jobs[job_id].get("query", "No query for this job")
231
- return job_id, "Job not found"
232
-
233
- # Function to refresh the job list
234
- def refresh_job_list():
235
- return get_job_list()
236
-
237
- # Function to sync model dropdown boxes
238
- def sync_model_dropdown(value):
239
- return value
240
-
241
- # Function to check job status
242
- def check_job_status(job_id):
243
- if not job_id:
244
- return "Please enter a job ID", "", "", "", ""
245
-
246
- # Process any completed jobs in the queue
247
- try:
248
- while not results_queue.empty():
249
- completed_id, result = results_queue.get_nowait()
250
- if completed_id in jobs:
251
- jobs[completed_id]["status"] = "completed"
252
- jobs[completed_id]["result"] = result
253
- jobs[completed_id]["end_time"] = time.time()
254
- debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
255
- except queue.Empty:
256
- pass
257
-
258
- # Check if the requested job exists
259
- if job_id not in jobs:
260
- return "Job not found. Please check the ID and try again.", "", "", "", ""
261
-
262
- job = jobs[job_id]
263
- job_query = job.get("query", "No query available for this job")
264
-
265
- # If job is still processing
266
- if job["status"] == "processing":
267
- elapsed_time = time.time() - job["start_time"]
268
- job_type = job.get("type", "unknown")
269
-
270
- if job_type == "load_files":
271
- return (
272
- f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
273
- f"Try checking again in a few seconds.",
274
- f"Job ID: {job_id}",
275
- f"Status: Processing",
276
- "",
277
- job_query
278
- )
279
- else: # query job
280
- return (
281
- f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
282
- f"Try checking again in a few seconds.",
283
- f"Job ID: {job_id}",
284
- f"Input tokens: {count_tokens(job.get('query', ''))}",
285
- "Output tokens: pending",
286
- job_query
287
- )
288
-
289
- # If job is completed
290
- if job["status"] == "completed":
291
- result = job["result"]
292
- processing_time = job["end_time"] - job["start_time"]
293
-
294
- if job.get("type") == "load_files":
295
- return (
296
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
297
- result[1],
298
- result[2],
299
- "",
300
- job_query
301
- )
302
- else: # query job
303
- return (
304
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
305
- result[1],
306
- result[2],
307
- result[3],
308
- job_query
309
- )
310
-
311
- # Fallback for unknown status
312
- return f"Job status: {job['status']}", "", "", "", job_query
313
-
314
- # Function to clean up old jobs
315
- def cleanup_old_jobs():
316
- current_time = time.time()
317
- to_delete = []
318
-
319
- for job_id, job in jobs.items():
320
- # Keep completed jobs for 1 hour, processing jobs for 2 hours
321
- if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
322
- to_delete.append(job_id)
323
- elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
324
- to_delete.append(job_id)
325
-
326
- for job_id in to_delete:
327
- del jobs[job_id]
328
-
329
- debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
330
- return f"Cleaned up {len(to_delete)} old jobs", "", ""
331
-
332
- # Improve the truncate_prompt function to be more aggressive with limiting context
333
- def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
334
- """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
335
- if not prompt:
336
- return ""
337
-
338
- if global_tokenizer:
339
- try:
340
- tokens = global_tokenizer.encode(prompt)
341
- if len(tokens) > max_tokens:
342
- # For prompts, we often want to keep the beginning instructions and the end context
343
- # So we'll keep the first 20% and the last 80% of the max tokens
344
- beginning_tokens = int(max_tokens * 0.2)
345
- ending_tokens = max_tokens - beginning_tokens
346
-
347
- new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
348
- return global_tokenizer.decode(new_tokens)
349
- except Exception as e:
350
- debug_print(f"Truncation error: {str(e)}")
351
-
352
- # Fallback to word-based truncation
353
- words = prompt.split()
354
- if len(words) > max_tokens:
355
- beginning_words = int(max_tokens * 0.2)
356
- ending_words = max_tokens - beginning_words
357
-
358
- return " ".join(words[:beginning_words] + words[-(ending_words):])
359
-
360
- return prompt
361
-
362
-
363
-
364
-
365
- default_prompt = """\
366
- {conversation_history}
367
- Use the following context to provide a detailed technical answer to the user's question.
368
- Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
369
- If you don't know the answer, please respond with "I don't know".
370
-
371
- Context:
372
- {context}
373
-
374
- User's question:
375
- {question}
376
- """
377
-
378
- def load_txt_from_url(url: str) -> Document:
379
- response = requests.get(url)
380
- if response.status_code == 200:
381
- text = response.text.strip()
382
- if not text:
383
- raise ValueError(f"TXT file at {url} is empty.")
384
- return Document(page_content=text, metadata={"source": url})
385
- else:
386
- raise Exception(f"Failed to load {url} with status {response.status_code}")
387
-
388
- class ElevatedRagChain:
389
- def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
390
- bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
391
- debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
392
- self.embed_func = HuggingFaceEmbeddings(
393
- model_name="sentence-transformers/all-MiniLM-L6-v2",
394
- model_kwargs={"device": "cpu"}
395
- )
396
- self.bm25_weight = bm25_weight
397
- self.faiss_weight = 1.0 - bm25_weight
398
- self.top_k = 5
399
- self.llm_choice = llm_choice
400
- self.temperature = temperature
401
- self.top_p = top_p
402
- self.prompt_template = prompt_template
403
- self.context = ""
404
- self.conversation_history: List[Dict[str, str]] = []
405
- self.raw_data = None
406
- self.split_data = None
407
- self.elevated_rag_chain = None
408
-
409
- # Instance method to capture context and conversation history
410
- def capture_context(self, result):
411
- self.context = "\n".join([str(doc) for doc in result["context"]])
412
- result["context"] = self.context
413
- history_text = (
414
- "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
415
- if self.conversation_history else ""
416
- )
417
- result["conversation_history"] = history_text
418
- return result
419
-
420
- # Instance method to extract question from input data
421
- def extract_question(self, input_data):
422
- return input_data["question"]
423
-
424
- # Improve error handling in the ElevatedRagChain class
425
- def create_llm_pipeline(self):
426
- from langchain.llms.base import LLM # Import LLM here so it's always defined
427
- normalized = self.llm_choice.lower()
428
- try:
429
- if "remote" in normalized:
430
- debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
431
- from huggingface_hub import InferenceClient
432
- repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
433
- hf_api_token = os.environ.get("HF_API_TOKEN")
434
- if not hf_api_token:
435
- raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
436
-
437
- client = InferenceClient(token=hf_api_token, timeout=120)
438
-
439
- # We no longer use wait_for_model because it's unsupported
440
- def remote_generate(prompt: str) -> str:
441
- max_retries = 3
442
- backoff = 2 # start with 2 seconds
443
- for attempt in range(max_retries):
444
- try:
445
- debug_print(f"Remote generation attempt {attempt+1}")
446
- response = client.text_generation(
447
- prompt,
448
- model=repo_id,
449
- temperature=self.temperature,
450
- top_p=self.top_p,
451
- max_new_tokens=512 # Reduced token count for speed
452
- )
453
- return response
454
- except Exception as e:
455
- debug_print(f"Attempt {attempt+1} failed with error: {e}")
456
- if attempt == max_retries - 1:
457
- raise
458
- time.sleep(backoff)
459
- backoff *= 2 # exponential backoff
460
- return "Failed to generate response after multiple attempts."
461
-
462
- class RemoteLLM(LLM):
463
- @property
464
- def _llm_type(self) -> str:
465
- return "remote_llm"
466
-
467
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
468
- return remote_generate(prompt)
469
-
470
- @property
471
- def _identifying_params(self) -> dict:
472
- return {"model": repo_id}
473
-
474
- debug_print("Remote Meta-Llama-3 pipeline created successfully.")
475
- return RemoteLLM()
476
-
477
- elif "mistral-api" in normalized:
478
- debug_print("Creating Mistral API pipeline...")
479
- mistral_api_key = os.environ.get("MISTRAL_API_KEY")
480
- if not mistral_api_key:
481
- raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
482
- try:
483
- from mistralai import Mistral
484
- debug_print("Mistral library imported successfully")
485
- except ImportError:
486
- debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
487
- normalized = "llama"
488
- if normalized != "llama":
489
- # from pydantic import PrivateAttr
490
- # from langchain.llms.base import LLM
491
- # from typing import Any, Optional, List
492
- # import typing
493
-
494
- class MistralLLM(LLM):
495
- temperature: float = 0.7
496
- top_p: float = 0.95
497
- _client: Any = PrivateAttr(default=None)
498
-
499
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
500
- try:
501
- super().__init__(**kwargs)
502
- # Bypass Pydantic's __setattr__ to assign to _client
503
- object.__setattr__(self, '_client', Mistral(api_key=api_key))
504
- self.temperature = temperature
505
- self.top_p = top_p
506
- except Exception as e:
507
- debug_print(f"Init Mistral failed with error: {e}")
508
-
509
- @property
510
- def _llm_type(self) -> str:
511
- return "mistral_llm"
512
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
513
- try:
514
- debug_print("Calling Mistral API...")
515
- response = self._client.chat.complete(
516
- model="mistral-small-latest",
517
- messages=[{"role": "user", "content": prompt}],
518
- temperature=self.temperature,
519
- top_p=self.top_p
520
- )
521
- return response.choices[0].message.content
522
- except Exception as e:
523
- debug_print(f"Mistral API error: {str(e)}")
524
- return f"Error generating response: {str(e)}"
525
- @property
526
- def _identifying_params(self) -> dict:
527
- return {"model": "mistral-small-latest"}
528
- debug_print("Creating Mistral LLM instance")
529
- mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
530
- debug_print("Mistral API pipeline created successfully.")
531
- return mistral_llm
532
-
533
- else:
534
- # Default case - using a fallback model (or Llama)
535
- debug_print("Using local/fallback model pipeline")
536
- model_id = "facebook/opt-350m" # Use a smaller model as fallback
537
- pipe = pipeline(
538
- "text-generation",
539
- model=model_id,
540
- device=-1, # CPU
541
- max_length=1024
542
- )
543
-
544
- class LocalLLM(LLM):
545
- @property
546
- def _llm_type(self) -> str:
547
- return "local_llm"
548
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
549
- # For this fallback, truncate prompt if it exceeds limits
550
- reserved_gen = 128
551
- max_total = 1024
552
- max_prompt_tokens = max_total - reserved_gen
553
- truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
554
- generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
555
- return generated
556
- @property
557
- def _identifying_params(self) -> dict:
558
- return {"model": model_id, "max_length": 1024}
559
-
560
- debug_print("Local fallback pipeline created.")
561
- return LocalLLM()
562
-
563
- except Exception as e:
564
- debug_print(f"Error creating LLM pipeline: {str(e)}")
565
- # Return a dummy LLM that explains the error
566
- class ErrorLLM(LLM):
567
- @property
568
- def _llm_type(self) -> str:
569
- return "error_llm"
570
-
571
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
572
- return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
573
-
574
- @property
575
- def _identifying_params(self) -> dict:
576
- return {"model": "error"}
577
-
578
- return ErrorLLM()
579
-
580
-
581
- def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
582
- debug_print(f"Updating chain with new model: {new_model_choice}")
583
- self.llm_choice = new_model_choice
584
- self.temperature = temperature
585
- self.top_p = top_p
586
- self.prompt_template = prompt_template
587
- self.bm25_weight = bm25_weight
588
- self.faiss_weight = 1.0 - bm25_weight
589
- self.llm = self.create_llm_pipeline()
590
- def format_response(response: str) -> str:
591
- input_tokens = count_tokens(self.context + self.prompt_template)
592
- output_tokens = count_tokens(response)
593
- formatted = f"### Response\n\n{response}\n\n---\n"
594
- formatted += f"- **Input tokens:** {input_tokens}\n"
595
- formatted += f"- **Output tokens:** {output_tokens}\n"
596
- formatted += f"- **Generated using:** {self.llm_choice}\n"
597
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
598
- return formatted
599
- base_runnable = RunnableParallel({
600
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
601
- "question": RunnableLambda(self.extract_question)
602
- }) | self.capture_context
603
- self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
604
- debug_print("Chain updated successfully with new LLM pipeline.")
605
-
606
- def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
607
- debug_print(f"Processing files using {self.llm_choice}")
608
- self.raw_data = []
609
- for link in file_links:
610
- if link.lower().endswith(".pdf"):
611
- debug_print(f"Loading PDF: {link}")
612
- loaded_docs = OnlinePDFLoader(link).load()
613
- if loaded_docs:
614
- self.raw_data.append(loaded_docs[0])
615
- else:
616
- debug_print(f"No content found in PDF: {link}")
617
- elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
618
- debug_print(f"Loading TXT: {link}")
619
- try:
620
- self.raw_data.append(load_txt_from_url(link))
621
- except Exception as e:
622
- debug_print(f"Error loading TXT file {link}: {e}")
623
- else:
624
- debug_print(f"File type not supported for URL: {link}")
625
- if not self.raw_data:
626
- raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
627
- debug_print("Files loaded successfully.")
628
- debug_print("Starting text splitting...")
629
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
630
- self.split_data = self.text_splitter.split_documents(self.raw_data)
631
- if not self.split_data:
632
- raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
633
- debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
634
- debug_print("Creating BM25 retriever...")
635
- self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
636
- self.bm25_retriever.k = self.top_k
637
- debug_print("BM25 retriever created.")
638
- debug_print("Embedding chunks and creating FAISS vector store...")
639
- self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
640
- self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
641
- debug_print("FAISS vector store created successfully.")
642
- self.ensemble_retriever = EnsembleRetriever(
643
- retrievers=[self.bm25_retriever, self.faiss_retriever],
644
- weights=[self.bm25_weight, self.faiss_weight]
645
- )
646
-
647
- base_runnable = RunnableParallel({
648
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
649
- "question": RunnableLambda(self.extract_question)
650
- }) | self.capture_context
651
-
652
- # Ensure the prompt template is set
653
- self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
654
- if self.rag_prompt is None:
655
- raise ValueError("Prompt template could not be created from the given template.")
656
- prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
657
-
658
- self.str_output_parser = StrOutputParser()
659
- debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
660
- self.llm = self.create_llm_pipeline()
661
- if self.llm is None:
662
- raise ValueError("LLM pipeline creation failed.")
663
-
664
- def format_response(response: str) -> str:
665
- input_tokens = count_tokens(self.context + self.prompt_template)
666
- output_tokens = count_tokens(response)
667
- formatted = f"### Response\n\n{response}\n\n---\n"
668
- formatted += f"- **Input tokens:** {input_tokens}\n"
669
- formatted += f"- **Output tokens:** {output_tokens}\n"
670
- formatted += f"- **Generated using:** {self.llm_choice}\n"
671
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
672
- return formatted
673
-
674
- self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
675
- debug_print("Elevated RAG chain successfully built and ready to use.")
676
-
677
-
678
-
679
- def get_current_context(self) -> str:
680
- base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
681
- history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
682
- recent = self.conversation_history[-3:]
683
- if recent:
684
- for i, conv in enumerate(recent, 1):
685
- history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
686
- else:
687
- history_summary += "No conversation history."
688
- return base_context + history_summary
689
-
690
- # ----------------------------
691
- # Gradio Interface Functions
692
- # ----------------------------
693
- global rag_chain
694
- rag_chain = ElevatedRagChain()
695
-
696
- def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
697
- debug_print("Inside load_pdfs function.")
698
- if not file_links:
699
- debug_print("Please enter non-empty URLs")
700
- return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
701
- try:
702
- links = [link.strip() for link in file_links.split("\n") if link.strip()]
703
- global rag_chain
704
- if rag_chain.raw_data:
705
- rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
706
- context_display = rag_chain.get_current_context()
707
- response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
708
- return (
709
- response_msg,
710
- f"Word count: {word_count(rag_chain.context)}",
711
- f"Model used: {rag_chain.llm_choice}",
712
- f"Context:\n{context_display}"
713
- )
714
- else:
715
- rag_chain = ElevatedRagChain(
716
- llm_choice=model_choice,
717
- prompt_template=prompt_template,
718
- bm25_weight=bm25_weight,
719
- temperature=temperature,
720
- top_p=top_p
721
- )
722
- rag_chain.add_pdfs_to_vectore_store(links)
723
- context_display = rag_chain.get_current_context()
724
- response_msg = f"Files loaded successfully. Using model: {model_choice}"
725
- return (
726
- response_msg,
727
- f"Word count: {word_count(rag_chain.context)}",
728
- f"Model used: {rag_chain.llm_choice}",
729
- f"Context:\n{context_display}"
730
- )
731
- except Exception as e:
732
- error_msg = traceback.format_exc()
733
- debug_print("Could not load files. Error: " + error_msg)
734
- return (
735
- "Error loading files: " + str(e),
736
- f"Word count: {word_count('')}",
737
- f"Model used: {rag_chain.llm_choice}",
738
- "Context: N/A"
739
- )
740
-
741
- def update_model(new_model: str):
742
- global rag_chain
743
- if rag_chain and rag_chain.raw_data:
744
- rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
745
- rag_chain.prompt_template, rag_chain.bm25_weight)
746
- debug_print(f"Model updated to {rag_chain.llm_choice}")
747
- return f"Model updated to: {rag_chain.llm_choice}"
748
- else:
749
- return "No files loaded; please load files first."
750
-
751
-
752
- # Update submit_query_updated to better handle context limitation
753
- def submit_query_updated(query):
754
- debug_print(f"Processing query: {query}")
755
- if not query:
756
- debug_print("Empty query received")
757
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
758
-
759
- if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
760
- debug_print("RAG chain not initialized")
761
- return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
762
-
763
- try:
764
- # Determine max context size based on model
765
- model_name = rag_chain.llm_choice.lower()
766
- max_context_tokens = 32000 if "mistral" in model_name else 4096
767
-
768
- # Reserve 20% of tokens for the question and response generation
769
- reserved_tokens = int(max_context_tokens * 0.2)
770
- max_context_tokens -= reserved_tokens
771
-
772
- # Collect conversation history (last 2 only to save tokens)
773
- if rag_chain.conversation_history:
774
- recent_history = rag_chain.conversation_history[-2:]
775
- history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
776
- for conv in recent_history])
777
- else:
778
- history_text = ""
779
-
780
- # Get history token count
781
- history_tokens = count_tokens(history_text)
782
-
783
- # Adjust context tokens based on history size
784
- context_tokens = max_context_tokens - history_tokens
785
-
786
- # Ensure we have some minimum context
787
- context_tokens = max(context_tokens, 1000)
788
-
789
- # Truncate context if needed
790
- context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
791
-
792
- debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
793
-
794
- prompt_variables = {
795
- "conversation_history": history_text,
796
- "context": context,
797
- "question": query
798
- }
799
-
800
- debug_print("Invoking RAG chain")
801
- response = rag_chain.elevated_rag_chain.invoke({"question": query})
802
-
803
- # Store only a reasonable amount of the response in history
804
- trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
805
- rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
806
-
807
- input_token_count = count_tokens(query)
808
- output_token_count = count_tokens(response)
809
-
810
- debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
811
-
812
- return (
813
- response,
814
- rag_chain.get_current_context(),
815
- f"Input tokens: {input_token_count}",
816
- f"Output tokens: {output_token_count}"
817
- )
818
- except Exception as e:
819
- error_msg = traceback.format_exc()
820
- debug_print(f"LLM error: {error_msg}")
821
- return (
822
- f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
823
- "",
824
- "Input tokens: 0",
825
- "Output tokens: 0"
826
- )
827
-
828
- def reset_app_updated():
829
- global rag_chain
830
- rag_chain = ElevatedRagChain()
831
- debug_print("App reset successfully.")
832
- return (
833
- "App reset successfully. You can now load new files",
834
- "",
835
- "Model used: Not selected"
836
- )
837
-
838
- # ----------------------------
839
- # Gradio Interface Setup
840
- # ----------------------------
841
- custom_css = """
842
- textarea {
843
- overflow-y: scroll !important;
844
- max-height: 200px;
845
- }
846
- """
847
-
848
- # Update the Gradio interface to include job status checking
849
- with gr.Blocks(css=custom_css, js="""
850
- document.addEventListener('DOMContentLoaded', function() {
851
- // Add event listener for job list clicks
852
- const jobListInterval = setInterval(() => {
853
- const jobLinks = document.querySelectorAll('.job-list-container a');
854
- if (jobLinks.length > 0) {
855
- jobLinks.forEach(link => {
856
- link.addEventListener('click', function(e) {
857
- e.preventDefault();
858
- const jobId = this.textContent.split(' ')[0];
859
- // Find the job ID input textbox and set its value
860
- const jobIdInput = document.querySelector('.job-id-input input');
861
- if (jobIdInput) {
862
- jobIdInput.value = jobId;
863
- // Trigger the input event to update Gradio's state
864
- jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
865
- }
866
- });
867
- });
868
- clearInterval(jobListInterval);
869
- }
870
- }, 500);
871
- });
872
- """) as app:
873
- gr.Markdown('''# PhiRAG - Async Version
874
- **PhiRAG** Query Your Data with Advanced RAG Techniques
875
-
876
- **Model Selection & Parameters:** Choose from the following options:
877
- - 🇺🇸 Remote Meta-Llama-3 - has context windows of 8000 tokens
878
- - 🇪🇺 Mistral-API - has context windows of 32000 tokens
879
-
880
- **🔥 Randomness (Temperature):** Adjusts output predictability.
881
- - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
882
-
883
- **🎯 Word Variety (Top‑p):** Limits word choices to a set probability percentage.
884
- - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
885
-
886
- **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.
887
- - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
888
-
889
- **✏️ Prompt Template:** Edit as desired.
890
-
891
- **🔗 File URLs:** Enter one URL per line (.pdf or .txt).\
892
- - Example: Provide one URL per line, such as
893
- https://www.gutenberg.org/ebooks/8438.txt.utf-8
894
-
895
- **🔍 Query:** Enter your query below.
896
-
897
- **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
898
- - When you load files or submit a query, you'll receive a Job ID
899
- - Use the "Check Job Status" tab to monitor and retrieve your results
900
- ''')
901
-
902
- with gr.Tabs() as tabs:
903
- with gr.TabItem("Setup & Load Files"):
904
- with gr.Row():
905
- with gr.Column():
906
- model_dropdown = gr.Dropdown(
907
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
908
- value="🇺🇸 Remote Meta-Llama-3",
909
- label="Select Model"
910
- )
911
- temperature_slider = gr.Slider(
912
- minimum=0.1, maximum=1.0, value=0.5, step=0.1,
913
- label="Randomness (Temperature)"
914
- )
915
- top_p_slider = gr.Slider(
916
- minimum=0.1, maximum=0.99, value=0.95, step=0.05,
917
- label="Word Variety (Top-p)"
918
- )
919
- with gr.Column():
920
- pdf_input = gr.Textbox(
921
- label="Enter your file URLs (one per line)",
922
- placeholder="Enter one URL per line (.pdf or .txt)",
923
- lines=4
924
- )
925
- prompt_input = gr.Textbox(
926
- label="Custom Prompt Template",
927
- placeholder="Enter your custom prompt template here",
928
- lines=8,
929
- value=default_prompt
930
- )
931
- with gr.Column():
932
- bm25_weight_slider = gr.Slider(
933
- minimum=0.0, maximum=1.0, value=0.6, step=0.1,
934
- label="Lexical vs Semantics (BM25 Weight)"
935
- )
936
- load_button = gr.Button("Load Files (Async)")
937
- load_status = gr.Markdown("Status: Waiting for files")
938
-
939
- with gr.Row():
940
- load_response = gr.Textbox(
941
- label="Load Response",
942
- placeholder="Response will appear here",
943
- lines=4
944
- )
945
- load_context = gr.Textbox(
946
- label="Context Info",
947
- placeholder="Context info will appear here",
948
- lines=4
949
- )
950
-
951
- with gr.Row():
952
- model_output = gr.Markdown("**Current Model**: Not selected")
953
-
954
- with gr.TabItem("Submit Query"):
955
- with gr.Row():
956
- # Add this line to define the query_model_dropdown
957
- query_model_dropdown = gr.Dropdown(
958
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
959
- value="🇺🇸 Remote Meta-Llama-3",
960
- label="Query Model"
961
- )
962
-
963
- query_input = gr.Textbox(
964
- label="Enter your query here",
965
- placeholder="Type your query",
966
- lines=4
967
- )
968
- submit_button = gr.Button("Submit Query (Async)")
969
-
970
- with gr.Row():
971
- query_response = gr.Textbox(
972
- label="Query Response",
973
- placeholder="Response will appear here (formatted as Markdown)",
974
- lines=6
975
- )
976
- query_context = gr.Textbox(
977
- label="Context Information",
978
- placeholder="Retrieved context and conversation history will appear here",
979
- lines=6
980
- )
981
-
982
- with gr.Row():
983
- input_tokens = gr.Markdown("Input tokens: 0")
984
- output_tokens = gr.Markdown("Output tokens: 0")
985
-
986
- with gr.TabItem("Check Job Status"):
987
- with gr.Row():
988
- with gr.Column(scale=1):
989
- job_list = gr.Markdown(
990
- value="No jobs yet",
991
- label="Job List (Click to select)"
992
- )
993
- refresh_button = gr.Button("Refresh Job List")
994
-
995
- with gr.Column(scale=2):
996
- job_id_input = gr.Textbox(
997
- label="Job ID",
998
- placeholder="Job ID will appear here when selected from the list",
999
- lines=1
1000
- )
1001
- job_query_display = gr.Textbox(
1002
- label="Job Query",
1003
- placeholder="The query associated with this job will appear here",
1004
- lines=2,
1005
- interactive=False
1006
- )
1007
- check_button = gr.Button("Check Status")
1008
- cleanup_button = gr.Button("Cleanup Old Jobs")
1009
-
1010
- with gr.Row():
1011
- status_response = gr.Textbox(
1012
- label="Job Result",
1013
- placeholder="Job result will appear here",
1014
- lines=6
1015
- )
1016
- status_context = gr.Textbox(
1017
- label="Context Information",
1018
- placeholder="Context information will appear here",
1019
- lines=6
1020
- )
1021
-
1022
- with gr.Row():
1023
- status_tokens1 = gr.Markdown("")
1024
- status_tokens2 = gr.Markdown("")
1025
-
1026
- with gr.TabItem("App Management"):
1027
- with gr.Row():
1028
- reset_button = gr.Button("Reset App")
1029
-
1030
- with gr.Row():
1031
- reset_response = gr.Textbox(
1032
- label="Reset Response",
1033
- placeholder="Reset confirmation will appear here",
1034
- lines=2
1035
- )
1036
- reset_context = gr.Textbox(
1037
- label="",
1038
- placeholder="",
1039
- lines=2,
1040
- visible=False
1041
- )
1042
-
1043
- with gr.Row():
1044
- reset_model = gr.Markdown("")
1045
-
1046
- # Connect the buttons to their respective functions
1047
- load_button.click(
1048
- load_pdfs_async,
1049
- inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1050
- outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1051
- )
1052
-
1053
- # Also sync in the other direction
1054
- query_model_dropdown.change(
1055
- fn=sync_model_dropdown,
1056
- inputs=query_model_dropdown,
1057
- outputs=model_dropdown
1058
- )
1059
-
1060
- submit_button.click(
1061
- submit_query_async,
1062
- inputs=[query_input, query_model_dropdown],
1063
- outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1064
- )
1065
-
1066
- check_button.click(
1067
- check_job_status,
1068
- inputs=[job_id_input],
1069
- outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1070
- )
1071
-
1072
- refresh_button.click(
1073
- refresh_job_list,
1074
- inputs=[],
1075
- outputs=[job_list]
1076
- )
1077
-
1078
- # Connect the job list selection event (this is handled by JavaScript)
1079
- job_id_input.change(
1080
- job_selected,
1081
- inputs=[job_id_input],
1082
- outputs=[job_id_input, job_query_display]
1083
- )
1084
-
1085
- cleanup_button.click(
1086
- cleanup_old_jobs,
1087
- inputs=[],
1088
- outputs=[status_response, status_context, status_tokens1]
1089
- )
1090
-
1091
- reset_button.click(
1092
- reset_app_updated,
1093
- inputs=[],
1094
- outputs=[reset_response, reset_context, reset_model]
1095
- )
1096
-
1097
-
1098
- model_dropdown.change(
1099
- fn=sync_model_dropdown,
1100
- inputs=model_dropdown,
1101
- outputs=query_model_dropdown
1102
- )
1103
-
1104
- # Add an event to refresh the job list on page load
1105
- app.load(
1106
- fn=refresh_job_list,
1107
- inputs=None,
1108
- outputs=job_list
1109
- )
1110
-
1111
- if __name__ == "__main__":
1112
- debug_print("Launching Gradio interface.")
1113
- app.launch(share=False)
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+ import datetime
4
+ import functools
5
+ import traceback
6
+ from typing import List, Optional, Any, Dict
7
+
8
+ import torch
9
+ import transformers
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain_community.llms import HuggingFacePipeline
12
+
13
+ # Other LangChain and community imports
14
+ from langchain_community.document_loaders import OnlinePDFLoader
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain_community.vectorstores import FAISS
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.retrievers import BM25Retriever
19
+ from langchain.retrievers import EnsembleRetriever
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema import StrOutputParser, Document
22
+ from langchain_core.runnables import RunnableParallel, RunnableLambda
23
+ from transformers.quantizers.auto import AutoQuantizationConfig
24
+ import gradio as gr
25
+ import requests
26
+ from pydantic import PrivateAttr
27
+ import pydantic
28
+
29
+ from langchain.llms.base import LLM
30
+ from typing import Any, Optional, List
31
+ import typing
32
+ import time
33
+
34
+ print("Pydantic Version: ")
35
+ print(pydantic.__version__)
36
+ # Add Mistral imports with fallback handling
37
+
38
+ try:
39
+ from mistralai import Mistral
40
+ MISTRAL_AVAILABLE = True
41
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
+ debug_print("Loaded latest Mistral client library")
43
+ except ImportError:
44
+ MISTRAL_AVAILABLE = False
45
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
+ debug_print("Mistral client library not found. Install with: pip install mistralai")
47
+
48
+ def debug_print(message: str):
49
+ print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
+
51
+ def word_count(text: str) -> int:
52
+ return len(text.split())
53
+
54
+ # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
+ def initialize_tokenizer():
56
+ try:
57
+ return AutoTokenizer.from_pretrained("gpt2")
58
+ except Exception as e:
59
+ debug_print("Failed to initialize tokenizer: " + str(e))
60
+ return None
61
+
62
+ global_tokenizer = initialize_tokenizer()
63
+
64
+ def count_tokens(text: str) -> int:
65
+ if global_tokenizer:
66
+ try:
67
+ return len(global_tokenizer.encode(text))
68
+ except Exception as e:
69
+ return len(text.split())
70
+ return len(text.split())
71
+
72
+
73
+ # Add these imports at the top of your file
74
+ import uuid
75
+ import threading
76
+ import queue
77
+ from typing import Dict, Any, Tuple, Optional
78
+ import time
79
+
80
+ # Global storage for jobs and results
81
+ jobs = {} # Stores job status and results
82
+ results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
+ processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
+
85
+ # Add a global variable to store the last job ID
86
+ last_job_id = None
87
+
88
+ # Add these missing async processing functions
89
+
90
+ def process_in_background(job_id, function, args):
91
+ """Process a function in the background and store results"""
92
+ try:
93
+ debug_print(f"Processing job {job_id} in background")
94
+ result = function(*args)
95
+ results_queue.put((job_id, result))
96
+ debug_print(f"Job {job_id} completed and added to results queue")
97
+ except Exception as e:
98
+ debug_print(f"Error in background job {job_id}: {str(e)}")
99
+ error_result = (f"Error processing job: {str(e)}", "", "", "")
100
+ results_queue.put((job_id, error_result))
101
+
102
+ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
103
+ """Asynchronous version of load_pdfs_updated to prevent timeouts"""
104
+ global last_job_id
105
+ if not file_links:
106
+ return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
107
+
108
+ job_id = str(uuid.uuid4())
109
+ debug_print(f"Starting async job {job_id} for file loading")
110
+
111
+ # Start background thread
112
+ threading.Thread(
113
+ target=process_in_background,
114
+ args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
115
+ ).start()
116
+
117
+ job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
+ jobs[job_id] = {
119
+ "status": "processing",
120
+ "type": "load_files",
121
+ "start_time": time.time(),
122
+ "query": job_query
123
+ }
124
+
125
+ last_job_id = job_id
126
+
127
+ return (
128
+ f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
129
+ f"Use 'Check Job Status' tab with this ID to get results.",
130
+ f"Job ID: {job_id}",
131
+ f"Model requested: {model_choice}",
132
+ job_id, # Return job_id to update the job_id_input component
133
+ job_query, # Return job_query to update the job_query_display component
134
+ get_job_list() # Return updated job list
135
+ )
136
+
137
+ def submit_query_async(query, model_choice=None):
138
+ """Asynchronous version of submit_query_updated to prevent timeouts"""
139
+ global last_job_id
140
+ if not query:
141
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
142
+
143
+ job_id = str(uuid.uuid4())
144
+ debug_print(f"Starting async job {job_id} for query: {query}")
145
+
146
+ # Update model if specified
147
+ if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
148
+ debug_print(f"Updating model to {model_choice} for this query")
149
+ rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
150
+ rag_chain.prompt_template, rag_chain.bm25_weight)
151
+
152
+ # Start background thread
153
+ threading.Thread(
154
+ target=process_in_background,
155
+ args=(job_id, submit_query_updated, [query])
156
+ ).start()
157
+
158
+ jobs[job_id] = {
159
+ "status": "processing",
160
+ "type": "query",
161
+ "start_time": time.time(),
162
+ "query": query,
163
+ "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
164
+ }
165
+
166
+ last_job_id = job_id
167
+
168
+ return (
169
+ f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
170
+ f"Use 'Check Job Status' tab with this ID to get results.",
171
+ f"Job ID: {job_id}",
172
+ f"Input tokens: {count_tokens(query)}",
173
+ "Output tokens: pending",
174
+ job_id, # Return job_id to update the job_id_input component
175
+ query, # Return query to update the job_query_display component
176
+ get_job_list() # Return updated job list
177
+ )
178
+
179
+ def update_ui_with_last_job_id():
180
+ # This function doesn't need to do anything anymore
181
+ # We'll update the UI directly in the functions that call this
182
+ pass
183
+
184
+ # Function to display all jobs as a clickable list
185
+ def get_job_list():
186
+ job_list_md = "### Submitted Jobs\n\n"
187
+
188
+ if not jobs:
189
+ return "No jobs found. Submit a query or load files to create jobs."
190
+
191
+ # Sort jobs by start time (newest first)
192
+ sorted_jobs = sorted(
193
+ [(job_id, job_info) for job_id, job_info in jobs.items()],
194
+ key=lambda x: x[1].get("start_time", 0),
195
+ reverse=True
196
+ )
197
+
198
+ for job_id, job_info in sorted_jobs:
199
+ status = job_info.get("status", "unknown")
200
+ job_type = job_info.get("type", "unknown")
201
+ query = job_info.get("query", "")
202
+ start_time = job_info.get("start_time", 0)
203
+ time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
204
+
205
+ # Create a shortened query preview
206
+ query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
207
+
208
+ # Add color and icons based on status
209
+ if status == "processing":
210
+ # Red color with processing icon for processing jobs
211
+ status_formatted = f"<span style='color: red'>⏳ {status}</span>"
212
+ elif status == "completed":
213
+ # Green color with checkmark for completed jobs
214
+ status_formatted = f"<span style='color: green'>✅ {status}</span>"
215
+ else:
216
+ # Default formatting for unknown status
217
+ status_formatted = f"<span style='color: orange'>❓ {status}</span>"
218
+
219
+ # Create clickable links using Markdown
220
+ if job_type == "query":
221
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - Query: {query_preview}\n"
222
+ else:
223
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - File Load Job\n"
224
+
225
+ return job_list_md
226
+
227
+ # Function to handle job list clicks
228
+ def job_selected(job_id):
229
+ if job_id in jobs:
230
+ return job_id, jobs[job_id].get("query", "No query for this job")
231
+ return job_id, "Job not found"
232
+
233
+ # Function to refresh the job list
234
+ def refresh_job_list():
235
+ return get_job_list()
236
+
237
+ # Function to sync model dropdown boxes
238
+ def sync_model_dropdown(value):
239
+ return value
240
+
241
+ # Function to check job status
242
+ def check_job_status(job_id):
243
+ if not job_id:
244
+ return "Please enter a job ID", "", "", "", ""
245
+
246
+ # Process any completed jobs in the queue
247
+ try:
248
+ while not results_queue.empty():
249
+ completed_id, result = results_queue.get_nowait()
250
+ if completed_id in jobs:
251
+ jobs[completed_id]["status"] = "completed"
252
+ jobs[completed_id]["result"] = result
253
+ jobs[completed_id]["end_time"] = time.time()
254
+ debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
255
+ except queue.Empty:
256
+ pass
257
+
258
+ # Check if the requested job exists
259
+ if job_id not in jobs:
260
+ return "Job not found. Please check the ID and try again.", "", "", "", ""
261
+
262
+ job = jobs[job_id]
263
+ job_query = job.get("query", "No query available for this job")
264
+
265
+ # If job is still processing
266
+ if job["status"] == "processing":
267
+ elapsed_time = time.time() - job["start_time"]
268
+ job_type = job.get("type", "unknown")
269
+
270
+ if job_type == "load_files":
271
+ return (
272
+ f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
273
+ f"Try checking again in a few seconds.",
274
+ f"Job ID: {job_id}",
275
+ f"Status: Processing",
276
+ "",
277
+ job_query
278
+ )
279
+ else: # query job
280
+ return (
281
+ f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
282
+ f"Try checking again in a few seconds.",
283
+ f"Job ID: {job_id}",
284
+ f"Input tokens: {count_tokens(job.get('query', ''))}",
285
+ "Output tokens: pending",
286
+ job_query
287
+ )
288
+
289
+ # If job is completed
290
+ if job["status"] == "completed":
291
+ result = job["result"]
292
+ processing_time = job["end_time"] - job["start_time"]
293
+
294
+ if job.get("type") == "load_files":
295
+ return (
296
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
297
+ result[1],
298
+ result[2],
299
+ "",
300
+ job_query
301
+ )
302
+ else: # query job
303
+ return (
304
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
305
+ result[1],
306
+ result[2],
307
+ result[3],
308
+ job_query
309
+ )
310
+
311
+ # Fallback for unknown status
312
+ return f"Job status: {job['status']}", "", "", "", job_query
313
+
314
+ # Function to clean up old jobs
315
+ def cleanup_old_jobs():
316
+ current_time = time.time()
317
+ to_delete = []
318
+
319
+ for job_id, job in jobs.items():
320
+ # Keep completed jobs for 1 hour, processing jobs for 2 hours
321
+ if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
322
+ to_delete.append(job_id)
323
+ elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
324
+ to_delete.append(job_id)
325
+
326
+ for job_id in to_delete:
327
+ del jobs[job_id]
328
+
329
+ debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
330
+ return f"Cleaned up {len(to_delete)} old jobs", "", ""
331
+
332
+ # Improve the truncate_prompt function to be more aggressive with limiting context
333
+ def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
334
+ """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
335
+ if not prompt:
336
+ return ""
337
+
338
+ if global_tokenizer:
339
+ try:
340
+ tokens = global_tokenizer.encode(prompt)
341
+ if len(tokens) > max_tokens:
342
+ # For prompts, we often want to keep the beginning instructions and the end context
343
+ # So we'll keep the first 20% and the last 80% of the max tokens
344
+ beginning_tokens = int(max_tokens * 0.2)
345
+ ending_tokens = max_tokens - beginning_tokens
346
+
347
+ new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
348
+ return global_tokenizer.decode(new_tokens)
349
+ except Exception as e:
350
+ debug_print(f"Truncation error: {str(e)}")
351
+
352
+ # Fallback to word-based truncation
353
+ words = prompt.split()
354
+ if len(words) > max_tokens:
355
+ beginning_words = int(max_tokens * 0.2)
356
+ ending_words = max_tokens - beginning_words
357
+
358
+ return " ".join(words[:beginning_words] + words[-(ending_words):])
359
+
360
+ return prompt
361
+
362
+
363
+
364
+
365
+ default_prompt = """\
366
+ {conversation_history}
367
+ Use the following context to provide a detailed technical answer to the user's question.
368
+ Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
369
+ If you don't know the answer, please respond with "I don't know".
370
+
371
+ Context:
372
+ {context}
373
+
374
+ User's question:
375
+ {question}
376
+ """
377
+
378
+ def load_txt_from_url(url: str) -> Document:
379
+ response = requests.get(url)
380
+ if response.status_code == 200:
381
+ text = response.text.strip()
382
+ if not text:
383
+ raise ValueError(f"TXT file at {url} is empty.")
384
+ return Document(page_content=text, metadata={"source": url})
385
+ else:
386
+ raise Exception(f"Failed to load {url} with status {response.status_code}")
387
+
388
+ class ElevatedRagChain:
389
+ def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
390
+ bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
391
+ debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
392
+ self.embed_func = HuggingFaceEmbeddings(
393
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
394
+ model_kwargs={"device": "cpu"}
395
+ )
396
+ self.bm25_weight = bm25_weight
397
+ self.faiss_weight = 1.0 - bm25_weight
398
+ self.top_k = 5
399
+ self.llm_choice = llm_choice
400
+ self.temperature = temperature
401
+ self.top_p = top_p
402
+ self.prompt_template = prompt_template
403
+ self.context = ""
404
+ self.conversation_history: List[Dict[str, str]] = []
405
+ self.raw_data = None
406
+ self.split_data = None
407
+ self.elevated_rag_chain = None
408
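+ # raw_data, split_data and elevated_rag_chain stay None until add_pdfs_to_vectore_store()
+ # loads documents and assembles the retriever -> prompt -> LLM chain; the Gradio handlers
+ # below refuse to run queries before that has happened.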
+
409
+ # Instance method to capture context and conversation history
410
+ def capture_context(self, result):
411
+ self.context = "\n".join([str(doc) for doc in result["context"]])
412
+ result["context"] = self.context
413
+ history_text = (
414
+ "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
415
+ if self.conversation_history else ""
416
+ )
417
+ result["conversation_history"] = history_text
418
+ return result
419
+
420
+ # Instance method to extract question from input data
421
+ def extract_question(self, input_data):
422
+ return input_data["question"]
423
+
424
+ # Build the LLM pipeline for the selected backend, with retries and error fallbacks
425
+ def create_llm_pipeline(self):
426
+ from langchain.llms.base import LLM # Import LLM here so it's always defined
427
+ normalized = self.llm_choice.lower()
428
+ try:
429
+ if "remote" in normalized:
430
+ debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
431
+ from huggingface_hub import InferenceClient
432
+ repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
433
+ hf_api_token = os.environ.get("HF_API_TOKEN")
434
+ if not hf_api_token:
435
+ raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
436
+
437
+ client = InferenceClient(token=hf_api_token, timeout=120)
438
+
439
+ # We no longer use wait_for_model because it's unsupported
440
+ def remote_generate(prompt: str) -> str:
441
+ max_retries = 3
442
+ backoff = 2 # start with 2 seconds
443
+ for attempt in range(max_retries):
444
+ try:
445
+ debug_print(f"Remote generation attempt {attempt+1}")
446
+ response = client.text_generation(
447
+ prompt,
448
+ model=repo_id,
449
+ temperature=self.temperature,
450
+ top_p=self.top_p,
451
+ max_new_tokens=512 # Reduced token count for speed
452
+ )
453
+ return response
454
+ except Exception as e:
455
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
456
+ if attempt == max_retries - 1:
457
+ raise
458
+ time.sleep(backoff)
459
+ backoff *= 2 # exponential backoff
460
+ return "Failed to generate response after multiple attempts."
461
+
462
+ class RemoteLLM(LLM):
463
+ @property
464
+ def _llm_type(self) -> str:
465
+ return "remote_llm"
466
+
467
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
468
+ return remote_generate(prompt)
469
+
470
+ @property
471
+ def _identifying_params(self) -> dict:
472
+ return {"model": repo_id}
473
+
474
+ debug_print("Remote Meta-Llama-3 pipeline created successfully.")
475
+ return RemoteLLM()
476
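+ # remote_generate retries up to 3 attempts with exponential backoff (2 s, then 4 s);
+ # if the final attempt also fails, the exception propagates to whatever invoked the chain.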
+
477
+ elif "mistral-api" in normalized:
478
+ debug_print("Creating Mistral API pipeline...")
479
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
480
+ if not mistral_api_key:
481
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
482
+ try:
483
+ from mistralai import Mistral
484
+ debug_print("Mistral library imported successfully")
485
+ except ImportError as ie:
486
+ debug_print("Mistral client library is not installed; install it with `pip install mistralai`.")
487
+ raise ValueError("Mistral API model selected but the `mistralai` package is not available.") from ie
488
+ if normalized != "llama":
489
+ # from pydantic import PrivateAttr
490
+ # from langchain.llms.base import LLM
491
+ # from typing import Any, Optional, List
492
+ # import typing
493
+
494
+ class MistralLLM(LLM):
495
+ temperature: float = 0.7
496
+ top_p: float = 0.95
497
+ _client: Any = PrivateAttr(default=None)
498
+
499
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
500
+ try:
501
+ super().__init__(**kwargs)
502
+ # Bypass Pydantic's __setattr__ to assign to _client
503
+ object.__setattr__(self, '_client', Mistral(api_key=api_key))
504
+ self.temperature = temperature
505
+ self.top_p = top_p
506
+ except Exception as e:
507
+ debug_print(f"Init Mistral failed with error: {e}")
508
+
509
+ @property
510
+ def _llm_type(self) -> str:
511
+ return "mistral_llm"
512
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
513
+ try:
514
+ debug_print("Calling Mistral API...")
515
+ response = self._client.chat.complete(
516
+ model="mistral-small-latest",
517
+ messages=[{"role": "user", "content": prompt}],
518
+ temperature=self.temperature,
519
+ top_p=self.top_p
520
+ )
521
+ return response.choices[0].message.content
522
+ except Exception as e:
523
+ debug_print(f"Mistral API error: {str(e)}")
524
+ return f"Error generating response: {str(e)}"
525
+ @property
526
+ def _identifying_params(self) -> dict:
527
+ return {"model": "mistral-small-latest"}
528
+ debug_print("Creating Mistral LLM instance")
529
+ mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
530
+ debug_print("Mistral API pipeline created successfully.")
531
+ return mistral_llm
532
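+ # The Mistral path always targets the hosted "mistral-small-latest" model via
+ # chat.complete; only temperature and top_p are forwarded from the UI sliders,
+ # everything else uses the client defaults.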
+
533
+ else:
534
+ # Default case - using a fallback model (or Llama)
535
+ debug_print("Using local/fallback model pipeline")
536
+ model_id = "facebook/opt-350m" # Use a smaller model as fallback
537
+ pipe = pipeline(
538
+ "text-generation",
539
+ model=model_id,
540
+ device=-1, # CPU
541
+ max_length=1024
542
+ )
543
+
544
+ class LocalLLM(LLM):
545
+ @property
546
+ def _llm_type(self) -> str:
547
+ return "local_llm"
548
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
549
+ # For this fallback, truncate prompt if it exceeds limits
550
+ reserved_gen = 128
551
+ max_total = 1024
552
+ max_prompt_tokens = max_total - reserved_gen
553
+ truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
554
+ generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
555
+ return generated
556
+ @property
557
+ def _identifying_params(self) -> dict:
558
+ return {"model": model_id, "max_length": 1024}
559
+
560
+ debug_print("Local fallback pipeline created.")
561
+ return LocalLLM()
562
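+ # The local fallback budgets its 1024-token window as 896 prompt tokens plus 128
+ # generated tokens (reserved_gen), relying on truncate_prompt above for the trimming.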
+
563
+ except Exception as e:
564
+ debug_print(f"Error creating LLM pipeline: {str(e)}")
565
+ # Return a dummy LLM that explains the error
+ error_message = str(e)  # capture the message now; `e` goes out of scope when the except block ends
566
+ class ErrorLLM(LLM):
567
+ @property
568
+ def _llm_type(self) -> str:
569
+ return "error_llm"
570
+
571
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
572
+ return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
573
+
574
+ @property
575
+ def _identifying_params(self) -> dict:
576
+ return {"model": "error"}
577
+
578
+ return ErrorLLM()
579
+
580
+
581
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
582
+ debug_print(f"Updating chain with new model: {new_model_choice}")
583
+ self.llm_choice = new_model_choice
584
+ self.temperature = temperature
585
+ self.top_p = top_p
586
+ self.prompt_template = prompt_template
587
+ self.bm25_weight = bm25_weight
588
+ self.faiss_weight = 1.0 - bm25_weight
589
+ self.llm = self.create_llm_pipeline()
590
+ def format_response(response: str) -> str:
591
+ input_tokens = count_tokens(self.context + self.prompt_template)
592
+ output_tokens = count_tokens(response)
593
+ formatted = f"### Response\n\n{response}\n\n---\n"
594
+ formatted += f"- **Input tokens:** {input_tokens}\n"
595
+ formatted += f"- **Output tokens:** {output_tokens}\n"
596
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
597
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
598
+ return formatted
599
+ base_runnable = RunnableParallel({
600
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
601
+ "question": RunnableLambda(self.extract_question)
602
+ }) | self.capture_context
603
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
604
+ debug_print("Chain updated successfully with new LLM pipeline.")
605
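+ # Rough data flow of the rebuilt chain, for orientation:
+ #   {"question": ...}
+ #     -> RunnableParallel(context=ensemble retriever, question=passthrough)
+ #     -> capture_context (flattens docs to text, appends conversation history)
+ #     -> rag_prompt -> llm -> format_response (markdown plus token counts)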
+
606
+ def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
607
+ debug_print(f"Processing files using {self.llm_choice}")
608
+ self.raw_data = []
609
+ for link in file_links:
610
+ if link.lower().endswith(".pdf"):
611
+ debug_print(f"Loading PDF: {link}")
612
+ loaded_docs = OnlinePDFLoader(link).load()
613
+ if loaded_docs:
614
+ self.raw_data.append(loaded_docs[0])
615
+ else:
616
+ debug_print(f"No content found in PDF: {link}")
617
+ elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
618
+ debug_print(f"Loading TXT: {link}")
619
+ try:
620
+ self.raw_data.append(load_txt_from_url(link))
621
+ except Exception as e:
622
+ debug_print(f"Error loading TXT file {link}: {e}")
623
+ else:
624
+ debug_print(f"File type not supported for URL: {link}")
625
+ if not self.raw_data:
626
+ raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
627
+ debug_print("Files loaded successfully.")
628
+ debug_print("Starting text splitting...")
629
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
630
+ self.split_data = self.text_splitter.split_documents(self.raw_data)
631
+ if not self.split_data:
632
+ raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
633
+ debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
634
+ debug_print("Creating BM25 retriever...")
635
+ self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
636
+ self.bm25_retriever.k = self.top_k
637
+ debug_print("BM25 retriever created.")
638
+ debug_print("Embedding chunks and creating FAISS vector store...")
639
+ self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
640
+ self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
641
+ debug_print("FAISS vector store created successfully.")
642
+ self.ensemble_retriever = EnsembleRetriever(
643
+ retrievers=[self.bm25_retriever, self.faiss_retriever],
644
+ weights=[self.bm25_weight, self.faiss_weight]
645
+ )
646
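+ # EnsembleRetriever merges the BM25 and FAISS result lists with weighted Reciprocal
+ # Rank Fusion, so bm25_weight favours exact keyword hits and faiss_weight
+ # (1 - bm25_weight) favours embedding similarity.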
+
647
+ base_runnable = RunnableParallel({
648
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
649
+ "question": RunnableLambda(self.extract_question)
650
+ }) | self.capture_context
651
+
652
+ # Ensure the prompt template is set
653
+ self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
654
+ if self.rag_prompt is None:
655
+ raise ValueError("Prompt template could not be created from the given template.")
656
+ prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
657
+
658
+ self.str_output_parser = StrOutputParser()
659
+ debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
660
+ self.llm = self.create_llm_pipeline()
661
+ if self.llm is None:
662
+ raise ValueError("LLM pipeline creation failed.")
663
+
664
+ def format_response(response: str) -> str:
665
+ input_tokens = count_tokens(self.context + self.prompt_template)
666
+ output_tokens = count_tokens(response)
667
+ formatted = f"### Response\n\n{response}\n\n---\n"
668
+ formatted += f"- **Input tokens:** {input_tokens}\n"
669
+ formatted += f"- **Output tokens:** {output_tokens}\n"
670
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
671
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
672
+ return formatted
673
+
674
+ self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
675
+ debug_print("Elevated RAG chain successfully built and ready to use.")
676
+
677
+
678
+
679
+ def get_current_context(self) -> str:
680
+ base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
681
+ history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
682
+ recent = self.conversation_history[-3:]
683
+ if recent:
684
+ for i, conv in enumerate(recent, 1):
685
+ history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
686
+ else:
687
+ history_summary += "No conversation history."
688
+ return base_context + history_summary
689
+
690
+ # ----------------------------
691
+ # Gradio Interface Functions
692
+ # ----------------------------
693
+ global rag_chain
694
+ rag_chain = ElevatedRagChain()
695
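+ # Minimal usage sketch of the class outside Gradio (illustrative only; the query text
+ # is a made-up example):
+ #   chain = ElevatedRagChain(llm_choice="🇪🇺 Mistral-API")
+ #   chain.add_pdfs_to_vectore_store(["https://www.gutenberg.org/ebooks/8438.txt.utf-8"])
+ #   answer = chain.elevated_rag_chain.invoke({"question": "What is this text about?"})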
+
696
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
697
+ debug_print("Inside load_pdfs function.")
698
+ if not file_links:
699
+ debug_print("Please enter non-empty URLs")
700
+ return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
701
+ try:
702
+ links = [link.strip() for link in file_links.split("\n") if link.strip()]
703
+ global rag_chain
704
+ if rag_chain.raw_data:
705
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
706
+ context_display = rag_chain.get_current_context()
707
+ response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
708
+ return (
709
+ response_msg,
710
+ f"Word count: {word_count(rag_chain.context)}",
711
+ f"Model used: {rag_chain.llm_choice}",
712
+ f"Context:\n{context_display}"
713
+ )
714
+ else:
715
+ rag_chain = ElevatedRagChain(
716
+ llm_choice=model_choice,
717
+ prompt_template=prompt_template,
718
+ bm25_weight=bm25_weight,
719
+ temperature=temperature,
720
+ top_p=top_p
721
+ )
722
+ rag_chain.add_pdfs_to_vectore_store(links)
723
+ context_display = rag_chain.get_current_context()
724
+ response_msg = f"Files loaded successfully. Using model: {model_choice}"
725
+ return (
726
+ response_msg,
727
+ f"Word count: {word_count(rag_chain.context)}",
728
+ f"Model used: {rag_chain.llm_choice}",
729
+ f"Context:\n{context_display}"
730
+ )
731
+ except Exception as e:
732
+ error_msg = traceback.format_exc()
733
+ debug_print("Could not load files. Error: " + error_msg)
734
+ return (
735
+ "Error loading files: " + str(e),
736
+ f"Word count: {word_count('')}",
737
+ f"Model used: {rag_chain.llm_choice}",
738
+ "Context: N/A"
739
+ )
740
+
741
+ def update_model(new_model: str):
742
+ global rag_chain
743
+ if rag_chain and rag_chain.raw_data:
744
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
745
+ rag_chain.prompt_template, rag_chain.bm25_weight)
746
+ debug_print(f"Model updated to {rag_chain.llm_choice}")
747
+ return f"Model updated to: {rag_chain.llm_choice}"
748
+ else:
749
+ return "No files loaded; please load files first."
750
+
751
+
752
+ # Handle a query, trimming context and history so the prompt fits the selected model's context window
753
+ def submit_query_updated(query):
754
+ debug_print(f"Processing query: {query}")
755
+ if not query:
756
+ debug_print("Empty query received")
757
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
758
+
759
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
760
+ debug_print("RAG chain not initialized")
761
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
762
+
763
+ try:
764
+ # Determine max context size based on model
765
+ model_name = rag_chain.llm_choice.lower()
766
+ max_context_tokens = 32000 if "mistral" in model_name else 4096
767
+
768
+ # Reserve 20% of tokens for the question and response generation
769
+ reserved_tokens = int(max_context_tokens * 0.2)
770
+ max_context_tokens -= reserved_tokens
771
+
772
+ # Collect conversation history (last 2 only to save tokens)
773
+ if rag_chain.conversation_history:
774
+ recent_history = rag_chain.conversation_history[-2:]
775
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
776
+ for conv in recent_history])
777
+ else:
778
+ history_text = ""
779
+
780
+ # Get history token count
781
+ history_tokens = count_tokens(history_text)
782
+
783
+ # Adjust context tokens based on history size
784
+ context_tokens = max_context_tokens - history_tokens
785
+
786
+ # Ensure we have some minimum context
787
+ context_tokens = max(context_tokens, 1000)
788
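+ # Example budget for Mistral: 32000 tokens total -> 6400 reserved for the question and
+ # generation, leaving 25600 for context minus however many tokens the trimmed history
+ # uses, floored at 1000 tokens of context.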
+
789
+ # Truncate context if needed
790
+ context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
791
+
792
+ debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
793
+
794
+ prompt_variables = {
795
+ "conversation_history": history_text,
796
+ "context": context,
797
+ "question": query
798
+ }
799
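+ # Note: prompt_variables is not passed to the chain below; the chain re-derives context
+ # via the retriever and history via capture_context, so the truncated values above are
+ # currently used only for the debug logging.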
+
800
+ debug_print("Invoking RAG chain")
801
+ response = rag_chain.elevated_rag_chain.invoke({"question": query})
802
+
803
+ # Store only a reasonable amount of the response in history
804
+ trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
805
+ rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
806
+
807
+ input_token_count = count_tokens(query)
808
+ output_token_count = count_tokens(response)
809
+
810
+ debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
811
+
812
+ return (
813
+ response,
814
+ rag_chain.get_current_context(),
815
+ f"Input tokens: {input_token_count}",
816
+ f"Output tokens: {output_token_count}"
817
+ )
818
+ except Exception as e:
819
+ error_msg = traceback.format_exc()
820
+ debug_print(f"LLM error: {error_msg}")
821
+ return (
822
+ f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
823
+ "",
824
+ "Input tokens: 0",
825
+ "Output tokens: 0"
826
+ )
827
+
828
+ def reset_app_updated():
829
+ global rag_chain
830
+ rag_chain = ElevatedRagChain()
831
+ debug_print("App reset successfully.")
832
+ return (
833
+ "App reset successfully. You can now load new files",
834
+ "",
835
+ "Model used: Not selected"
836
+ )
837
+
838
+ # ----------------------------
839
+ # Gradio Interface Setup
840
+ # ----------------------------
841
+ custom_css = """
842
+ textarea {
843
+ overflow-y: scroll !important;
844
+ max-height: 200px;
845
+ }
846
+ """
847
+
848
+ # Function to add dots and reset
849
+ def add_dots_and_reset():
850
+ if not hasattr(add_dots_and_reset, "dots"):
851
+ add_dots_and_reset.dots = "" # Initialize the attribute
852
+
853
+ # Add a dot
854
+ add_dots_and_reset.dots += "."
855
+
856
+ # Reset after 5 dots
857
+ if len(add_dots_and_reset.dots) > 5:
858
+ add_dots_and_reset.dots = ""
859
+
860
+ print(f"Current dots: {add_dots_and_reset.dots}") # Debugging print
861
+ return add_dots_and_reset.dots
862
+
863
+ # Define a dummy function to simulate data retrieval
864
+ def run_query(max_value):
865
+ # Simulate a data retrieval or processing function
866
+ return [[i, i**2] for i in range(1, max_value + 1)]
867
+
868
+ # Function to call both refresh_job_list and check_job_status using the last job ID
869
+ def periodic_update(is_checked):
870
+ if is_checked:
871
+ global last_job_id
872
+ job_list_md = refresh_job_list()
873
+ job_status = check_job_status(last_job_id) if last_job_id else ("No job ID available", "", "", "", "")
874
+ query_results = run_query(10) # Use a fixed value or another logic if needed
875
+ context_info = rag_chain.get_current_context() if rag_chain else "No context available."
876
+ return job_list_md, job_status[0], query_results, context_info
877
+ else:
878
+ return "", "", [], ""
879
+
880
+ # Update the Gradio interface to include job status checking
881
+ with gr.Blocks(css=custom_css, js="""
882
+ document.addEventListener('DOMContentLoaded', function() {
883
+ // Add event listener for job list clicks
884
+ const jobListInterval = setInterval(() => {
885
+ const jobLinks = document.querySelectorAll('.job-list-container a');
886
+ if (jobLinks.length > 0) {
887
+ jobLinks.forEach(link => {
888
+ link.addEventListener('click', function(e) {
889
+ e.preventDefault();
890
+ const jobId = this.textContent.split(' ')[0];
891
+ // Find the job ID input textbox and set its value
892
+ const jobIdInput = document.querySelector('.job-id-input input');
893
+ if (jobIdInput) {
894
+ jobIdInput.value = jobId;
895
+ // Trigger the input event to update Gradio's state
896
+ jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
897
+ }
898
+ });
899
+ });
900
+ clearInterval(jobListInterval);
901
+ }
902
+ }, 500);
903
+ });
904
+ """) as app:
905
+ gr.Markdown('''# PhiRAG - Async Version
906
+ **PhiRAG**: Query Your Data with Advanced RAG Techniques
907
+
908
+ **Model Selection & Parameters:** Choose from the following options:
909
+ - 🇺🇸 Remote Meta-Llama-3 - has a context window of 8,000 tokens
910
+ - 🇪🇺 Mistral-API - has a context window of 32,000 tokens
911
+
912
+ **🔥 Randomness (Temperature):** Adjusts output predictability.
913
+ - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
914
+
915
+ **🎯 Word Variety (Top‑p):** Keeps only the smallest set of tokens whose cumulative probability reaches p, discarding the rest.
916
+ - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
917
+
918
+ **⚖️ BM25 Weight:** Balances lexical (keyword) matching against semantic (embedding) similarity.
919
+ - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
920
+
921
+ **✏️ Prompt Template:** Edit as desired.
922
+
923
+ **🔗 File URLs:** Enter one URL per line (.pdf or .txt).
924
+ - Example: Provide one URL per line, such as
925
+ https://www.gutenberg.org/ebooks/8438.txt.utf-8
926
+
927
+ **🔍 Query:** Enter your query below.
928
+
929
+ **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
930
+ - When you load files or submit a query, you'll receive a Job ID
931
+ - Use the "Check Job Status" tab to monitor and retrieve your results
932
+ ''')
933
+
934
+ with gr.Tabs() as tabs:
935
+ with gr.TabItem("Setup & Load Files"):
936
+ with gr.Row():
937
+ with gr.Column():
938
+ model_dropdown = gr.Dropdown(
939
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
940
+ value="🇺🇸 Remote Meta-Llama-3",
941
+ label="Select Model"
942
+ )
943
+ temperature_slider = gr.Slider(
944
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
945
+ label="Randomness (Temperature)"
946
+ )
947
+ top_p_slider = gr.Slider(
948
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
949
+ label="Word Variety (Top-p)"
950
+ )
951
+ with gr.Column():
952
+ pdf_input = gr.Textbox(
953
+ label="Enter your file URLs (one per line)",
954
+ placeholder="Enter one URL per line (.pdf or .txt)",
955
+ lines=4
956
+ )
957
+ prompt_input = gr.Textbox(
958
+ label="Custom Prompt Template",
959
+ placeholder="Enter your custom prompt template here",
960
+ lines=8,
961
+ value=default_prompt
962
+ )
963
+ with gr.Column():
964
+ bm25_weight_slider = gr.Slider(
965
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
966
+ label="Lexical vs Semantics (BM25 Weight)"
967
+ )
968
+ load_button = gr.Button("Load Files (Async)")
969
+ load_status = gr.Markdown("Status: Waiting for files")
970
+
971
+ with gr.Row():
972
+ load_response = gr.Textbox(
973
+ label="Load Response",
974
+ placeholder="Response will appear here",
975
+ lines=4
976
+ )
977
+ load_context = gr.Textbox(
978
+ label="Context Info",
979
+ placeholder="Context info will appear here",
980
+ lines=4
981
+ )
982
+
983
+ with gr.Row():
984
+ model_output = gr.Markdown("**Current Model**: Not selected")
985
+
986
+ with gr.TabItem("Submit Query"):
987
+ with gr.Row():
988
+ # Add this line to define the query_model_dropdown
989
+ query_model_dropdown = gr.Dropdown(
990
+ choices=["���🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
991
+ value="🇺🇸 Remote Meta-Llama-3",
992
+ label="Query Model"
993
+ )
994
+
995
+ query_input = gr.Textbox(
996
+ label="Enter your query here",
997
+ placeholder="Type your query",
998
+ lines=4
999
+ )
1000
+ submit_button = gr.Button("Submit Query (Async)")
1001
+
1002
+ with gr.Row():
1003
+ query_response = gr.Textbox(
1004
+ label="Query Response",
1005
+ placeholder="Response will appear here (formatted as Markdown)",
1006
+ lines=6
1007
+ )
1008
+ query_context = gr.Textbox(
1009
+ label="Context Information",
1010
+ placeholder="Retrieved context and conversation history will appear here",
1011
+ lines=6
1012
+ )
1013
+
1014
+ with gr.Row():
1015
+ input_tokens = gr.Markdown("Input tokens: 0")
1016
+ output_tokens = gr.Markdown("Output tokens: 0")
1017
+
1018
+ with gr.TabItem("Check Job Status"):
1019
+ with gr.Row():
1020
+ with gr.Column(scale=1):
1021
+ job_list = gr.Markdown(
1022
+ value="No jobs yet",
1023
+ label="Job List (Click to select)"
1024
+ )
1025
+ # Add the Refresh Job List button
1026
+ refresh_button = gr.Button("Refresh Job List")
1027
+
1028
+ # Use a Checkbox to control the periodic updates
1029
+ auto_refresh_checkbox = gr.Checkbox(
1030
+ label="Enable Auto Refresh",
1031
+ value=False # Default to unchecked
1032
+ )
1033
+
1034
+ # Use a DataFrame to display results
1035
+ df = gr.DataFrame(
1036
+ value=run_query(10), # Initial value
1037
+ headers=["Number", "Square"],
1038
+ label="Query Results",
1039
+ visible=False # Set the DataFrame to be invisible
1040
+ )
1041
+
1042
+ with gr.Column(scale=2):
1043
+ job_id_input = gr.Textbox(
1044
+ label="Job ID",
1045
+ placeholder="Job ID will appear here when selected from the list",
1046
+ lines=1
1047
+ )
1048
+ job_query_display = gr.Textbox(
1049
+ label="Job Query",
1050
+ placeholder="The query associated with this job will appear here",
1051
+ lines=2,
1052
+ interactive=False
1053
+ )
1054
+ check_button = gr.Button("Check Status")
1055
+ cleanup_button = gr.Button("Cleanup Old Jobs")
1056
+
1057
+ with gr.Row():
1058
+ status_response = gr.Textbox(
1059
+ label="Job Result",
1060
+ placeholder="Job result will appear here",
1061
+ lines=6
1062
+ )
1063
+ status_context = gr.Textbox(
1064
+ label="Context Information",
1065
+ placeholder="Context information will appear here",
1066
+ lines=6
1067
+ )
1068
+
1069
+ with gr.Row():
1070
+ status_tokens1 = gr.Markdown("")
1071
+ status_tokens2 = gr.Markdown("")
1072
+
1073
+ with gr.TabItem("App Management"):
1074
+ with gr.Row():
1075
+ reset_button = gr.Button("Reset App")
1076
+
1077
+ with gr.Row():
1078
+ reset_response = gr.Textbox(
1079
+ label="Reset Response",
1080
+ placeholder="Reset confirmation will appear here",
1081
+ lines=2
1082
+ )
1083
+ reset_context = gr.Textbox(
1084
+ label="",
1085
+ placeholder="",
1086
+ lines=2,
1087
+ visible=False
1088
+ )
1089
+
1090
+ with gr.Row():
1091
+ reset_model = gr.Markdown("")
1092
+
1093
+ # Connect the buttons to their respective functions
1094
+ load_button.click(
1095
+ load_pdfs_async,
1096
+ inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1097
+ outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1098
+ )
1099
+
1100
+ # Keep the Setup-tab dropdown in sync when the Query-tab model selection changes
1101
+ query_model_dropdown.change(
1102
+ fn=sync_model_dropdown,
1103
+ inputs=query_model_dropdown,
1104
+ outputs=model_dropdown
1105
+ )
1106
+
1107
+ submit_button.click(
1108
+ submit_query_async,
1109
+ inputs=[query_input, query_model_dropdown],
1110
+ outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1111
+ )
1112
+
1113
+ check_button.click(
1114
+ check_job_status,
1115
+ inputs=[job_id_input],
1116
+ outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1117
+ )
1118
+
1119
+ # Connect the refresh button to the refresh_job_list function
1120
+ refresh_button.click(
1121
+ refresh_job_list,
1122
+ inputs=[],
1123
+ outputs=[job_list]
1124
+ )
1125
+
1126
+ # Connect the job list selection event (this is handled by JavaScript)
1127
+ job_id_input.change(
1128
+ job_selected,
1129
+ inputs=[job_id_input],
1130
+ outputs=[job_id_input, job_query_display]
1131
+ )
1132
+
1133
+ cleanup_button.click(
1134
+ cleanup_old_jobs,
1135
+ inputs=[],
1136
+ outputs=[status_response, status_context, status_tokens1]
1137
+ )
1138
+
1139
+ reset_button.click(
1140
+ reset_app_updated,
1141
+ inputs=[],
1142
+ outputs=[reset_response, reset_context, reset_model]
1143
+ )
1144
+
1145
+ model_dropdown.change(
1146
+ fn=sync_model_dropdown,
1147
+ inputs=model_dropdown,
1148
+ outputs=query_model_dropdown
1149
+ )
1150
+
1151
+ # Add an event to refresh the job list on page load
1152
+ app.load(
1153
+ fn=refresh_job_list,
1154
+ inputs=None,
1155
+ outputs=job_list
1156
+ )
1157
+
1158
+ # Use the Checkbox to control the periodic updates
1159
+ auto_refresh_checkbox.change(
1160
+ fn=periodic_update,
1161
+ inputs=[auto_refresh_checkbox],
1162
+ outputs=[job_list, status_response, df, status_context],
1163
+ every=2
1164
+ )
1165
+
1166
+ if __name__ == "__main__":
1167
+ debug_print("Launching Gradio interface.")
1168
+ app.queue().launch(share=False)
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- gradio==4.44.1
+ gradio==3.40.0
  langchain-community==0.0.19
  langchain_core==0.1.22
  langchain-openai==0.0.5