alx-d committed on
Commit
97f878b
·
verified ·
1 Parent(s): 7f0ef09

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. advanced_rag.py +68 -54
  2. requirements.txt +1 -1
advanced_rag.py CHANGED
@@ -22,7 +22,6 @@ from langchain.schema import StrOutputParser, Document
22
  from langchain_core.runnables import RunnableParallel, RunnableLambda
23
  from transformers.quantizers.auto import AutoQuantizationConfig
24
  import gradio as gr
25
- import requests
26
  from pydantic import PrivateAttr
27
  import pydantic
28
 
@@ -33,7 +32,7 @@ import time
33
  import re
34
  import requests
35
  from langchain.schema import Document
36
- from langchain.document_loaders import PyPDFLoader
37
  import tempfile
38
  import mimetypes
39
 
@@ -395,67 +394,82 @@ def load_txt_from_url(url: str) -> Document:
395
  else:
396
  raise Exception(f"Failed to load {url} with status {response.status_code}")
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  def load_file_from_google_drive(link: str) -> list:
399
  """
400
- Load PDF or text from a Google Drive shared link by detecting the file type
 
401
  """
402
- # Extract the file ID from the Google Drive link
403
- file_id_match = re.search(r'\/d\/(.*?)\/view', link)
404
- if not file_id_match:
405
- raise ValueError(f"Could not extract file ID from Google Drive link: {link}")
406
-
407
- file_id = file_id_match.group(1)
408
-
409
- # Create direct download link
410
- download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
411
-
412
- # Download the file to a temporary location
413
- response = requests.get(download_url, stream=True)
414
- if response.status_code != 200:
415
- raise ValueError(f"Failed to download file from Google Drive. Status code: {response.status_code}")
416
-
417
- # Create a temporary file
418
  with tempfile.NamedTemporaryFile(delete=False) as temp_file:
419
  temp_path = temp_file.name
420
- # Write content to the temp file
421
- for chunk in response.iter_content(chunk_size=1024):
422
- if chunk:
423
- temp_file.write(chunk)
424
- # With:
425
 
426
  try:
427
- # Detect file type using python-magic
428
- mime_type = get_mime_type(temp_path)
429
- debug_print(f"Detected MIME type: {mime_type}")
430
-
431
- if mime_type == 'application/pdf':
432
- # Handle PDF file
433
- loader = PyPDFLoader(temp_path)
434
- documents = loader.load()
435
-
436
- # Update metadata to include source URL
437
- for doc in documents:
438
- doc.metadata["source"] = link
439
-
440
- debug_print(f"Loaded PDF with {len(documents)} pages")
441
- return documents
442
- else:
443
- # Handle as text file
444
- with open(temp_path, 'r', encoding='utf-8', errors='ignore') as file:
445
- content = file.read()
446
-
447
- metadata = {"source": link}
448
- return [Document(page_content=content, metadata=metadata)]
449
- except Exception as e:
450
- # Log the error for debugging
451
- debug_print(f"Error processing file: {str(e)}")
452
- raise e
453
  finally:
454
- # Clean up the temporary file
455
  if os.path.exists(temp_path):
456
- os.unlink(temp_path)
457
-
458
-
459
  class ElevatedRagChain:
460
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
461
  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
 
22
  from langchain_core.runnables import RunnableParallel, RunnableLambda
23
  from transformers.quantizers.auto import AutoQuantizationConfig
24
  import gradio as gr
 
25
  from pydantic import PrivateAttr
26
  import pydantic
27
 
 
32
  import re
33
  import requests
34
  from langchain.schema import Document
35
+ from langchain_community.document_loaders import PyMuPDFLoader # Updated loader
36
  import tempfile
37
  import mimetypes
38
 
 
394
  else:
395
  raise Exception(f"Failed to load {url} with status {response.status_code}")
396
 
397
+ from pdfminer.high_level import extract_text
398
+ from langchain_core.documents import Document
399
+
400
+
401
+ def get_confirm_token(response):
402
+ for key, value in response.cookies.items():
403
+ if key.startswith("download_warning"):
404
+ return value
405
+ return None
406
+
407
+
408
+ def download_file_from_google_drive(file_id, destination):
409
+ """
410
+ Download a file from Google Drive handling large file confirmation.
411
+ """
412
+ URL = "https://docs.google.com/uc?export=download&confirm=1"
413
+ session = requests.Session()
414
+
415
+ response = session.get(URL, params={"id": file_id}, stream=True)
416
+ token = get_confirm_token(response)
417
+
418
+ if token:
419
+ params = {"id": file_id, "confirm": token}
420
+ response = session.get(URL, params=params, stream=True)
421
+
422
+ save_response_content(response, destination)
423
+
424
+
425
+ def save_response_content(response, destination):
426
+ CHUNK_SIZE = 32768
427
+ with open(destination, "wb") as f:
428
+ for chunk in response.iter_content(CHUNK_SIZE):
429
+ if chunk:
430
+ f.write(chunk)
431
+
432
+
433
+ def extract_file_id(drive_link: str) -> str:
434
+ match = re.search(r"/d/([a-zA-Z0-9_-]+)", drive_link)
435
+ if match:
436
+ return match.group(1)
437
+ raise ValueError("Could not extract file ID from the provided Google Drive link.")
438
+
439
+
440
  def load_file_from_google_drive(link: str) -> list:
441
  """
442
+ Load a document from a Google Drive link using pdfminer to extract text.
443
+ Returns a list of LangChain Document objects.
444
  """
445
+ file_id = extract_file_id(link)
446
+ print(f"[DEBUG] Extracted file ID: {file_id}")
447
+
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  with tempfile.NamedTemporaryFile(delete=False) as temp_file:
449
  temp_path = temp_file.name
 
 
 
 
 
450
 
451
  try:
452
+ download_file_from_google_drive(file_id, temp_path)
453
+ print(f"[DEBUG] File downloaded to: {temp_path}")
454
+
455
+ try:
456
+ full_text = extract_text(temp_path)
457
+ if not full_text.strip():
458
+ raise ValueError("Extracted text is empty. The PDF might be image-based.")
459
+ print("[DEBUG] Extracted preview text from PDF:")
460
+ print(full_text[:1000]) # Preview first 1000 characters
461
+
462
+ document = Document(page_content=full_text, metadata={"source": link})
463
+ return [document]
464
+
465
+ except Exception as e:
466
+ print(f"[ERROR] Could not extract text from PDF: {e}")
467
+ return []
468
+
 
 
 
 
 
 
 
 
 
469
  finally:
 
470
  if os.path.exists(temp_path):
471
+ os.remove(temp_path)
472
+
 
473
  class ElevatedRagChain:
474
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
475
  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
requirements.txt CHANGED
@@ -46,4 +46,4 @@ pydantic==2.9.0
46
 
47
  sentence-transformers>=2.4.0
48
 
49
- mistralai==1.5.0
 
46
 
47
  sentence-transformers>=2.4.0
48
 
49
+ mistralai==1.5.0