import os
import tempfile
from typing import Dict, List, Optional

import pandas as pd
import requests
import torch
import yt_dlp
from bs4 import BeautifulSoup
from langchain.docstore.document import Document
from langchain_community.document_loaders import YoutubeLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import BearlyInterpreterTool
from smolagents import (
    DuckDuckGoSearchTool,
    SpeechToTextTool,
    Tool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from transformers import AutoModelForImageTextToText, AutoProcessor

class RelevantInfoRetrieverTool(Tool):
    name = "relevant_info_retriever"
    description = "Retrieves information relevant to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query for which to retrieve information.",
        },
        "docs": {
            "type": "array",
            "description": "The source documents from which to retrieve relevant information.",
        },
    }
    output_type = "string"

    def forward(self, query: str, docs: List[Document]):
        self.retriever = BM25Retriever.from_documents(docs)
        # invoke() is the current retriever API; get_relevant_documents() is deprecated
        results = self.retriever.invoke(query)
        if results:
            return "\n\n".join(doc.page_content for doc in results)
        else:
            return "No relevant information found."

class YoutubeTranscriptTool(Tool):
    name = "youtube_transcript"
    description = "Fetches a YouTube video's transcript."
    inputs = {
        "youtube_url": {
            "type": "string",
            "description": "The YouTube video url",
        },
        "source_langs": {
            "type": "array",
            "description": "A list of language codes in descending priority for the video transcript.",
            "items": {"type": "string"},
            "default": ["en"],
            "required": False,
            "nullable": True,
        },
        "target_lang": {
            "type": "string",
            "description": "The language to which the transcript will be translated.",
            "default": "en",
            "required": False,
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(
        self,
        youtube_url: str,
        source_langs: Optional[List[str]] = None,
        target_lang: Optional[str] = "en",
    ):
        # Avoid a mutable default argument; fall back to English
        if source_langs is None:
            source_langs = ["en"]
        try:
            loader = YoutubeLoader.from_youtube_url(
                youtube_url,
                add_video_info=True,
                language=source_langs,
                translation=target_lang,
                # transcript_format=TranscriptFormat.CHUNKS,
                # chunk_size_seconds=30,
            )
            transcript_docs = loader.load()
            # output_type is "string", so join the loaded documents' contents
            return "\n\n".join(doc.page_content for doc in transcript_docs)
        except Exception as e:
            return f"Error fetching video's transcript: {e}"

class ReverseStringTool(Tool):
    name = "reverse_string"
    description = "Reverses the input string."
    inputs = {
        "string": {
            "type": "string",
            "description": "The string that needs to be reversed.",
        }
    }
    output_type = "string"

    def forward(self, string: str):
        try:
            return string[::-1]
        except Exception as e:
            return f"Error reversing string: {e}"

class SmolVLM2:
    """The parent class for visual analyzer tools (using the SmolVLM2-500M-Video model)"""

    def __init__(self):
        """Initializations for the analyzer tool"""
        model_path = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # _attn_implementation="flash_attention_2",
        ).to(device)

class ImagesAnalyzerTool(Tool, SmolVLM2):
    name = "image_analyzer"
    description = "Analyzes each input image according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the image will be analyzed.",
        },
        "images_urls": {
            "type": "array",
            "description": "A list of strings containing the images' urls",
            "items": {"type": "string"},
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, images_urls: List[str]):
        try:
            # Image message entities for the different images' urls
            image_message_ents = [{"type": "image", "url": iu} for iu in images_urls]
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query,
                        },
                    ]
                    + image_message_ents,
                },
            ]
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)
            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing image(s): {e}"

class VideoAnalyzerTool(Tool, SmolVLM2):
    name = "video_analyzer"
    description = "Analyzes video at a specified path according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the video will be analyzed.",
        },
        "video_path": {
            "type": "string",
            "description": "A string containing the video path",
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, video_path: str) -> str:
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_path},
                        {"type": "text", "text": query},
                    ],
                },
            ]
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)
            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing video: {e}"
        finally:
            # Remove the temporary video file once analysis is done
            if video_path and os.path.exists(video_path):
                os.remove(video_path)

class FileDownloaderTool(Tool):
    name = "file_downloader"
    description = "Downloads a file and returns the path of the temporary file it was saved to"
    inputs = {
        "file_url": {
            "type": "string",
            "description": "The url from which the file shall be downloaded.",
        },
    }
    output_type = "string"

    def forward(self, file_url: str) -> str:
        response = requests.get(file_url, stream=True)
        response.raise_for_status()
        # Naive Content-Disposition parse: take whatever follows the last "="
        original_filename = (
            response.headers.get("content-disposition", "")
            .split("=", -1)[-1]
            .strip('"')
        )
        # Even if original_filename is empty or has no extension, ext will be ""
        ext = os.path.splitext(original_filename)[-1]
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
            return tmp_file.name

class YoutubeVideoDownloaderTool(Tool):
    name = "youtube_video_downloader"
    description = "Downloads the video from the specified url and returns the path where the video was saved"
    inputs = {
        "video_url": {
            "type": "string",
            "description": "A string containing the video url",
        },
    }
    output_type = "string"

    def forward(self, video_url: str) -> str:
        try:
            saved_video_path = ""
            temp_dir = tempfile.gettempdir()
            ydl_opts = {
                "outtmpl": f"{temp_dir}/%(title)s.%(ext)s",  # Absolute or relative path
                "quiet": True,
            }
            # Download the YouTube video as a file in the tmp directory
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(video_url, download=True)
                saved_video_path = ydl.prepare_filename(info)
            return saved_video_path
        except Exception as e:
            return f"Error downloading video: {e}"

class LoadXlsxFileTool(Tool):
    name = "load_xlsx_file"
    description = "Loads an xlsx file into a pandas DataFrame and returns it"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "object"

    def forward(self, file_path: str) -> object:
        return pd.read_excel(file_path)

class LoadTextFileTool(Tool):
    name = "load_text_file"
    description = "Loads a plain-text file and returns its contents"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "string"

    def forward(self, file_path: str) -> str:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()

class WebpageTablesContextRetrieverTool(Tool):
    name = "webpage_tables_context_retriever"
    description = """Retrieves structural context for all tables on a webpage.
    Returns table indexes with captions, headers, and surrounding text to help identify relevant tables.
    Use this first to determine which table index to extract."""
    inputs = {
        "url": {"type": "string", "description": "The URL of the webpage to analyze"}
    }
    output_type = "object"

    def forward(self, url: str) -> Dict:
        """Retrieve context information for all tables on the page"""
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")
            tables = soup.find_all("table")
            if not tables:
                return {
                    "status": "success",
                    "tables": [],
                    "message": "No tables found on page",
                    "url": url,
                }
            results = []
            for i, table in enumerate(tables):
                context = {
                    "index": i,
                    "id": table.get("id", ""),
                    "class": " ".join(table.get("class", [])),
                    "summary": table.get("summary", ""),
                    "caption": self._get_table_caption(table),
                    "preceding_header": self._get_preceding_header(table),
                    "surrounding_text": self._get_surrounding_text(table),
                }
                results.append(context)
            return {
                "status": "success",
                "tables": results,
                "url": url,
                "message": f"Found {len(results)} tables with context information",
                "suggestion": "Use html_table_extractor with the most relevant index",
            }
        except Exception as e:
            return {
                "status": "error",
                "url": url,
                "message": f"Failed to retrieve table contexts: {str(e)}",
            }

    def _get_table_caption(self, table) -> str:
        """Extract table caption text if available"""
        caption = table.find("caption")
        return caption.get_text(strip=True) if caption else ""

    def _get_preceding_header(self, table) -> str:
        """Find the nearest preceding heading"""
        for tag in table.find_all_previous(["h1", "h2", "h3", "h4", "h5", "h6"]):
            return tag.get_text(strip=True)
        return ""

    def _get_surrounding_text(self, table, chars=150) -> str:
        """Get relevant text around the table"""
        prev_text = " ".join(
            t.strip()
            for t in table.find_all_previous(string=True, limit=3)
            if t.strip()
        )
        next_text = " ".join(
            t.strip() for t in table.find_all_next(string=True, limit=3) if t.strip()
        )
        return f"...{prev_text[-chars:]} [TABLE] {next_text[:chars]}..."

class HtmlTableExtractorTool(Tool):
    name = "html_table_extractor"
    description = """Extracts a specific HTML table as structured data.
    Use after webpage_tables_context_retriever to get the correct table index."""
    inputs = {
        "page_url": {
            "type": "string",
            "description": "The webpage URL containing the table",
        },
        "table_index": {
            "type": "integer",
            "description": "0-based index of the table to extract (from webpage_tables_context_retriever)",
        },
    }
    output_type = "object"

    def forward(self, page_url: str, table_index: int) -> Dict:
        """Extract a specific table by index"""
        try:
            # First verify the URL is accessible
            test_request = requests.head(page_url, timeout=5)
            test_request.raise_for_status()
            # Read all tables
            tables = pd.read_html(page_url)
            if not tables:
                return {
                    "status": "error",
                    "message": "No tables found at URL",
                    "url": page_url,
                }
            # Validate index
            if table_index < 0 or table_index >= len(tables):
                return {
                    "status": "error",
                    "message": f"Invalid table index {table_index}. Page has {len(tables)} tables.",
                    "url": page_url,
                    "available_indexes": list(range(len(tables))),
                }
            # Convert the DataFrame to a JSON-serializable list of row dicts
            df = tables[table_index]
            return {
                "status": "success",
                "table_index": table_index,
                "table_data": df.to_dict(orient="records"),
                "url": page_url,
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Table extraction failed: {str(e)}",
                "url": page_url,
                "table_index": table_index,
            }
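

# Minimal wiring sketch (assumptions: a recent smolagents release that exposes
# CodeAgent and InferenceClientModel; the model choice and task are illustrative).
#
# from smolagents import CodeAgent, InferenceClientModel
#
# agent = CodeAgent(
#     tools=[
#         DuckDuckGoSearchTool(),
#         VisitWebpageTool(),
#         WikipediaSearchTool(),
#         YoutubeTranscriptTool(),
#         ReverseStringTool(),
#         FileDownloaderTool(),
#         LoadXlsxFileTool(),
#         LoadTextFileTool(),
#         WebpageTablesContextRetrieverTool(),
#         HtmlTableExtractorTool(),
#     ],
#     model=InferenceClientModel(),
# )
# print(agent.run("Summarize the first table on <url>."))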