#!/usr/bin/env python3
"""
LLM-based Question Classifier for Multi-Agent GAIA Solver
Routes questions to appropriate specialist agents based on content analysis
"""

import os
import json
import re
from typing import Dict, List, Optional, Tuple
from enum import Enum

from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Import LLM (using the same setup as the main solver)
try:
    from smolagents import InferenceClientModel
except ImportError:
    # Fallback for newer smolagents versions
    try:
        from smolagents.models import InferenceClientModel
    except ImportError:
        # If all imports fail, we'll handle this in the class
        InferenceClientModel = None
class AgentType(Enum):
    """Available specialist agent types"""
    MULTIMEDIA = "multimedia"            # Video, audio, image analysis
    RESEARCH = "research"                # Web search, Wikipedia, academic papers
    LOGIC_MATH = "logic_math"            # Puzzles, calculations, pattern recognition
    FILE_PROCESSING = "file_processing"  # Excel, Python code, document analysis
    GENERAL = "general"                  # Fallback for unclear cases


# Regular expression patterns for better content type detection
YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/.+?(?=\s|$)'
# Enhanced YouTube URL pattern with more variations (shortened links, IDs, watch URLs, etc.)
ENHANCED_YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/(?:watch\?v=|embed/|v/|shorts/|playlist\?list=|channel/|user/|[^/\s]+/?)?([^\s&?/]+)'
VIDEO_PATTERNS = [r'youtube\.(com|be)', r'video', r'watch\?v=']
AUDIO_PATTERNS = [r'\.mp3\b', r'\.wav\b', r'audio', r'sound', r'listen', r'music', r'podcast']
IMAGE_PATTERNS = [r'\.jpg\b', r'\.jpeg\b', r'\.png\b', r'\.gif\b', r'image', r'picture', r'photo']
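
# Illustrative (not exhaustive) URL forms the enhanced pattern is intended to match:
#   https://www.youtube.com/watch?v=dQw4w9WgXcQ
#   https://youtu.be/dQw4w9WgXcQ
#   https://www.youtube.com/embed/dQw4w9WgXcQ
#   youtube.com/shorts/abc123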

class QuestionClassifier:
    """LLM-powered question classifier for agent routing"""

    def __init__(self):
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is required")

        # Initialize lightweight model for classification
        if InferenceClientModel is not None:
            self.classifier_model = InferenceClientModel(
                model_id="Qwen/Qwen2.5-7B-Instruct",  # Smaller, faster model for classification
                token=self.hf_token
            )
        else:
            # Fallback: use a simple rule-based classifier
            self.classifier_model = None
            print("⚠️ Using fallback rule-based classification (InferenceClientModel not available)")
    def classify_question(self, question: str, file_name: str = "") -> Dict:
        """
        Classify a GAIA question and determine the best agent routing.

        Args:
            question: The question text
            file_name: Associated file name (if any)

        Returns:
            Dict with classification results and routing information
        """
        # First, check for a direct YouTube URL as a fast path (enhanced detection)
        if re.search(ENHANCED_YOUTUBE_URL_PATTERN, question):
            return self._create_youtube_video_classification(question, file_name)

        # Secondary check for YouTube keywords plus URL-like text
        question_lower = question.lower()
        if "youtube" in question_lower and any(term in question_lower for term in ["video", "watch", "channel"]):
            # Possible YouTube question; check more carefully
            if re.search(r'(youtube\.com|youtu\.be)', question):
                return self._create_youtube_video_classification(question, file_name)

        # Continue with regular classification: create the classification prompt
        classification_prompt = f"""
Analyze this GAIA benchmark question and classify it for routing to specialist agents.

Question: {question}
Associated file: {file_name if file_name else "None"}

Classify this question into ONE primary category and optionally secondary categories:

AGENT CATEGORIES:
1. MULTIMEDIA - Questions involving video analysis, audio transcription, image analysis
   Examples: YouTube videos, MP3 files, PNG images, visual content analysis

2. RESEARCH - Questions requiring web search, Wikipedia lookup, or factual data retrieval
   Examples: Factual lookups, biographical info, historical data, citations, sports statistics, company information, academic papers
   Note: If a question requires looking up data first (even for later calculations), classify as RESEARCH

3. LOGIC_MATH - Questions involving pure mathematical calculations or logical reasoning with given data
   Examples: Mathematical puzzles with provided numbers, algebraic equations, geometric calculations, logical deduction puzzles
   Note: Use this ONLY when all data is provided and no external lookup is needed

4. FILE_PROCESSING - Questions requiring file analysis (Excel, Python code, documents)
   Examples: Spreadsheet analysis, code execution, document parsing

5. GENERAL - Simple questions or unclear classification

ANALYSIS REQUIRED:
1. Primary agent type (required)
2. Secondary agent types (if question needs multiple specialists)
3. Complexity level (1-5, where 5 is most complex)
4. Tools needed (list specific tools that would be useful)
5. Reasoning (explain your classification choice)

Respond in JSON format:
{{
    "primary_agent": "AGENT_TYPE",
    "secondary_agents": ["AGENT_TYPE2", "AGENT_TYPE3"],
    "complexity": 3,
    "confidence": 0.95,
    "tools_needed": ["tool1", "tool2"],
    "reasoning": "explanation of classification",
    "requires_multimodal": false,
    "estimated_steps": 5
}}
"""
        try:
            # Get classification from LLM or fallback
            if self.classifier_model is not None:
                messages = [{"role": "user", "content": classification_prompt}]
                response = self.classifier_model(messages)
            else:
                # Fallback to rule-based classification
                return self._fallback_classification(question, file_name)

            # Parse JSON response
            classification_text = response.content.strip()

            # Extract JSON if wrapped in code blocks
            if "```json" in classification_text:
                json_start = classification_text.find("```json") + 7
                json_end = classification_text.find("```", json_start)
                classification_text = classification_text[json_start:json_end].strip()
            elif "```" in classification_text:
                json_start = classification_text.find("```") + 3
                json_end = classification_text.find("```", json_start)
                classification_text = classification_text[json_start:json_end].strip()

            classification = json.loads(classification_text)

            # Validate and normalize the response
            return self._validate_classification(classification, question, file_name)

        except Exception as e:
            print(f"Classification error: {e}")
            # Fallback classification
            return self._fallback_classification(question, file_name)
    def _create_youtube_video_classification(self, question: str, file_name: str = "") -> Dict:
        """Create a specialized classification for YouTube video questions"""
        # Use enhanced pattern for more robust URL detection
        youtube_url_match = re.search(ENHANCED_YOUTUBE_URL_PATTERN, question)
        if not youtube_url_match:
            # Fall back to the original pattern
            youtube_url_match = re.search(YOUTUBE_URL_PATTERN, question)

        # Extract the URL
        if youtube_url_match:
            youtube_url = youtube_url_match.group(0)
        else:
            # If we can't extract a URL but it looks like a YouTube question
            question_lower = question.lower()
            if "youtube" in question_lower:
                # Try to find any URL-like pattern
                url_match = re.search(r'https?://\S+', question)
                youtube_url = url_match.group(0) if url_match else "unknown_youtube_url"
            else:
                youtube_url = "unknown_youtube_url"

        # Determine complexity based on the question
        question_lower = question.lower()
        complexity = 3     # Default
        confidence = 0.98  # High default confidence for YouTube questions

        # Analyze the task more specifically
        if any(term in question_lower for term in ['count', 'how many', 'highest number']):
            complexity = 2  # Counting tasks
            task_type = "counting"
        elif any(term in question_lower for term in ['relationship', 'compare', 'difference']):
            complexity = 4  # Comparative analysis
            task_type = "comparison"
        elif any(term in question_lower for term in ['say', 'speech', 'dialogue', 'talk', 'speak']):
            complexity = 3  # Speech analysis
            task_type = "speech_analysis"
        elif any(term in question_lower for term in ['scene', 'visual', 'appear', 'shown']):
            complexity = 3  # Visual analysis
            task_type = "visual_analysis"
        else:
            task_type = "general_video_analysis"

        # Always use analyze_youtube_video as the primary tool; keeping it first in
        # the tools list ensures it takes priority even if other tools are suggested
        tools_needed = ["analyze_youtube_video"]

        # Add secondary tools if the task might need them
        if "audio" in question_lower or any(term in question_lower for term in ['say', 'speech', 'dialogue']):
            tools_needed.append("analyze_audio_file")  # Add as fallback

        return {
            "primary_agent": "multimedia",
            "secondary_agents": [],
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": tools_needed,
            "reasoning": f"Question contains a YouTube URL and requires {task_type}",
            "requires_multimodal": True,
            "estimated_steps": 3,
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name),
            "media_type": "youtube_video",
            "media_url": youtube_url,
            "task_type": task_type  # Task type for more specific handling
        }
    def _validate_classification(self, classification: Dict, question: str, file_name: str) -> Dict:
        """Validate and normalize classification response"""
        # Ensure primary agent is valid (normalize case before comparing)
        primary_agent = str(classification.get("primary_agent", "GENERAL")).upper()
        if primary_agent not in [agent.value.upper() for agent in AgentType]:
            primary_agent = "GENERAL"

        # Validate secondary agents
        secondary_agents = classification.get("secondary_agents", [])
        valid_secondary = [
            agent for agent in secondary_agents
            if agent.upper() in [a.value.upper() for a in AgentType]
        ]

        # Ensure confidence is between 0 and 1
        confidence = max(0.0, min(1.0, classification.get("confidence", 0.5)))

        # Ensure complexity is between 1 and 5
        complexity = max(1, min(5, classification.get("complexity", 3)))

        return {
            "primary_agent": primary_agent.lower(),
            "secondary_agents": [agent.lower() for agent in valid_secondary],
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": classification.get("tools_needed", []),
            "reasoning": classification.get("reasoning", "Automated classification"),
            "requires_multimodal": classification.get("requires_multimodal", False),
            "estimated_steps": classification.get("estimated_steps", 5),
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name)
        }
    def _fallback_classification(self, question: str, file_name: str = "") -> Dict:
        """Fallback classification when the LLM is unavailable or fails"""
        # Simple heuristic-based fallback
        question_lower = question.lower()

        # Check for a YouTube URL first (most specific case) - use the enhanced pattern
        youtube_match = re.search(ENHANCED_YOUTUBE_URL_PATTERN, question)
        if youtube_match:
            # Use the dedicated method for YouTube classification to ensure consistency
            return self._create_youtube_video_classification(question, file_name)

        # Secondary check for YouTube references (may not have a valid URL format)
        if "youtube" in question_lower and any(keyword in question_lower for keyword in
                                               ["video", "watch", "link", "url", "channel"]):
            # Likely a YouTube question even without a perfect URL match;
            # create a custom classification with high confidence
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.85,
                "tools_needed": ["analyze_youtube_video"],
                "reasoning": "Fallback detected YouTube reference without complete URL",
                "requires_multimodal": True,
                "estimated_steps": 3,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "youtube_video",
                "media_url": "youtube_reference_detected"  # Placeholder
            }

        # Check other multimedia patterns
        # Video patterns (beyond YouTube)
        elif any(re.search(pattern, question_lower) for pattern in VIDEO_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.8,
                "tools_needed": ["analyze_video_frames"],
                "reasoning": "Fallback detected video-related content",
                "requires_multimodal": True,
                "estimated_steps": 4,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "video"
            }

        # Audio patterns
        elif any(re.search(pattern, question_lower) for pattern in AUDIO_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.8,
                "tools_needed": ["analyze_audio_file"],
                "reasoning": "Fallback detected audio-related content",
                "requires_multimodal": True,
                "estimated_steps": 3,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "audio"
            }

        # Image patterns
        elif any(re.search(pattern, question_lower) for pattern in IMAGE_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 2,
                "confidence": 0.8,
                "tools_needed": ["analyze_image_with_gemini"],
                "reasoning": "Fallback detected image-related content",
                "requires_multimodal": True,
                "estimated_steps": 2,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "image"
            }

        # General multimedia keywords
        elif any(keyword in question_lower for keyword in ["multimedia", "visual", "picture", "screenshot"]):
            primary_agent = "multimedia"
            tools_needed = ["analyze_image_with_gemini"]

        # Research patterns
        elif any(keyword in question_lower for keyword in ["wikipedia", "search", "find", "who", "what", "when", "where"]):
            primary_agent = "research"
            tools_needed = ["research_with_comprehensive_fallback"]

        # Math/Logic patterns
        elif any(keyword in question_lower for keyword in ["calculate", "number", "count", "math", "opposite", "pattern"]):
            primary_agent = "logic_math"
            tools_needed = ["advanced_calculator"]

        # File processing
        elif file_name and any(ext in file_name.lower() for ext in [".xlsx", ".py", ".csv", ".pdf"]):
            primary_agent = "file_processing"
            if ".xlsx" in file_name.lower():
                tools_needed = ["analyze_excel_file"]
            elif ".py" in file_name.lower():
                tools_needed = ["analyze_python_code"]
            else:
                tools_needed = ["analyze_text_file"]

        # Default
        else:
            primary_agent = "general"
            tools_needed = []

        return {
            "primary_agent": primary_agent,
            "secondary_agents": [],
            "complexity": 3,
            "confidence": 0.6,
            "tools_needed": tools_needed,
            "reasoning": "Fallback heuristic classification",
            "requires_multimodal": bool(file_name),
            "estimated_steps": 5,
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name)
        }
    def batch_classify(self, questions: List[Dict]) -> List[Dict]:
        """Classify multiple questions in batch"""
        results = []
        for q in questions:
            question_text = q.get("question", "")
            file_name = q.get("file_name", "")
            task_id = q.get("task_id", "")

            classification = self.classify_question(question_text, file_name)
            classification["task_id"] = task_id
            results.append(classification)

        return results
    def get_routing_recommendation(self, classification: Dict) -> Dict:
        """Get specific routing recommendations based on classification"""
        primary_agent = classification["primary_agent"]
        complexity = classification["complexity"]

        routing = {
            "primary_route": primary_agent,
            "requires_coordination": len(classification["secondary_agents"]) > 0,
            "parallel_execution": False,
            "estimated_duration": "medium",
            "special_requirements": []
        }

        # Add special requirements based on agent type
        if primary_agent == "multimedia":
            routing["special_requirements"].extend([
                "Requires yt-dlp and ffmpeg for video processing",
                "Needs Gemini Vision API for image analysis",
                "May need large temp storage for video files"
            ])
        elif primary_agent == "research":
            routing["special_requirements"].extend([
                "Requires web search and Wikipedia API access",
                "May need academic database access",
                "Benefits from citation tracking tools"
            ])
        elif primary_agent == "file_processing":
            routing["special_requirements"].extend([
                "Requires file processing libraries (pandas, openpyxl)",
                "May need sandboxed code execution environment",
                "Needs secure file handling"
            ])

        # Adjust duration estimate based on complexity
        if complexity >= 4:
            routing["estimated_duration"] = "long"
        elif complexity <= 2:
            routing["estimated_duration"] = "short"

        # Suggest parallel execution for multi-agent scenarios
        if len(classification["secondary_agents"]) >= 2:
            routing["parallel_execution"] = True

        return routing

def test_classifier():
    """Test the classifier with sample GAIA questions"""
    # Sample questions from our GAIA set
    test_questions = [
        {
            "task_id": "video_test",
            "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?",
            "file_name": ""
        },
        {
            "task_id": "youtube_short_test",
            "question": "Check this YouTube video https://youtu.be/L1vXCYZAYYM and count the birds",
            "file_name": ""
        },
        {
            "task_id": "video_url_variation",
            "question": "How many people appear in the YouTube video at youtube.com/watch?v=dQw4w9WgXcQ",
            "file_name": ""
        },
        {
            "task_id": "research_test",
            "question": "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
            "file_name": ""
        },
        {
            "task_id": "logic_test",
            "question": ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",
            "file_name": ""
        },
        {
            "task_id": "file_test",
            "question": "What is the final numeric output from the attached Python code?",
            "file_name": "script.py"
        }
    ]

    classifier = QuestionClassifier()
    print("🧪 Testing Question Classifier")
    print("=" * 50)

    for question in test_questions:
        print(f"\n📝 Question: {question['question'][:80]}...")
        classification = classifier.classify_question(
            question["question"],
            question["file_name"]
        )

        print(f"🎯 Primary Agent: {classification['primary_agent']}")
        print(f"🔧 Tools Needed: {classification['tools_needed']}")
        print(f"📊 Complexity: {classification['complexity']}/5")
        print(f"🎲 Confidence: {classification['confidence']:.2f}")
        print(f"💭 Reasoning: {classification['reasoning']}")

        routing = classifier.get_routing_recommendation(classification)
        print(f"🚀 Routing: {routing['primary_route']} ({'coordination needed' if routing['requires_coordination'] else 'single agent'})")

if __name__ == "__main__":
    test_classifier()
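
# Illustrative sketch of programmatic use from another module (the module name
# "question_classifier" is assumed here; adjust the import to the actual file name):
#
#     from question_classifier import QuestionClassifier
#
#     classifier = QuestionClassifier()
#     classification = classifier.classify_question(
#         "How many studio albums were published by Mercedes Sosa between 2000 and 2009?"
#     )
#     routing = classifier.get_routing_recommendation(classification)
#     print(routing["primary_route"], routing["estimated_duration"])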