#!/usr/bin/env python3
"""
Async Batch Processor for GAIA Questions
Concurrent processing with progress tracking and error handling
"""

import asyncio
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from pathlib import Path
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from tests.async_batch_logger import AsyncBatchLogger, QuestionResult
from tests.async_batch_gaia_solver import AsyncGAIASolver
from main import GAIASolver
from question_classifier import QuestionClassifier


class BatchQuestionProcessor:
    """
    Async batch processor for GAIA questions.

    Features: concurrency control, progress tracking, error resilience,
    real-time logging.
    """

    def __init__(self,
                 max_concurrent: int = 3,
                 question_timeout: int = 300,   # 5 minutes per question
                 progress_interval: int = 10):  # Progress update every 10 seconds
        self.max_concurrent = max_concurrent
        self.question_timeout = question_timeout
        self.progress_interval = progress_interval

        # Semaphore for concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)
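        # Workers acquire this semaphore in _process_single_question, so at
        # most max_concurrent questions are being solved at any one time.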

        # Progress tracking
        self.completed_count = 0
        self.total_questions = 0
        self.start_time = None

        # Logger
        self.logger = AsyncBatchLogger()

    async def process_questions_batch(self,
                                      questions: List[Dict[str, Any]],
                                      solver_kwargs: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Process a batch of questions with full async concurrency.

        Args:
            questions: List of question dictionaries
            solver_kwargs: Kwargs to pass to GAIASolver initialization

        Returns:
            Comprehensive batch results with classification analysis
        """
        self.total_questions = len(questions)
        self.start_time = time.time()

        # Initialize batch logging
        await self.logger.log_batch_start(self.total_questions, self.max_concurrent)

        # Default solver configuration
        if solver_kwargs is None:
            solver_kwargs = {
                "use_kluster": True,
                "kluster_model": "qwen3-235b"
            }
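        # Note: these defaults assume a Kluster-hosted Qwen3 model is
        # available; pass your own solver_kwargs to override them.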

        # Create async solver
        async_solver = AsyncGAIASolver(
            solver_class=GAIASolver,
            classifier_class=QuestionClassifier,
            **solver_kwargs
        )

        # Start progress tracking task
        progress_task = asyncio.create_task(self._track_progress())

        try:
            # Process all questions concurrently
            print(f"🚀 Starting concurrent processing of {len(questions)} questions...")
            print(f"📊 Max concurrent: {self.max_concurrent} | Timeout: {self.question_timeout}s")

            tasks = []
            for question_data in questions:
                task = asyncio.create_task(
                    self._process_single_question(async_solver, question_data)
                )
                tasks.append(task)

            # Wait for all questions to complete
            results = await asyncio.gather(*tasks, return_exceptions=True)
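            # With return_exceptions=True, a failed task shows up here as its
            # exception object instead of aborting the whole gather;
            # _compile_batch_results filters these out with isinstance checks.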

            # Process results
            batch_results = await self._compile_batch_results(results, questions)

            # Complete batch logging
            await self.logger.log_batch_complete()

            return batch_results
        finally:
            # Stop progress tracking
            progress_task.cancel()
            try:
                await progress_task
            except asyncio.CancelledError:
                pass

    async def _process_single_question(self,
                                       async_solver: AsyncGAIASolver,
                                       question_data: Dict[str, Any]) -> QuestionResult:
        """Process a single question with full error handling and logging"""
        task_id = question_data.get('task_id', 'unknown')

        async with self.semaphore:  # Acquire semaphore for concurrency control
            question_start = time.time()  # Per-question timer for error reporting
            try:
                # Log question start
                await self.logger.log_question_start(task_id, question_data)

                # Process with timeout
                result = await asyncio.wait_for(
                    async_solver.solve_question_async(question_data, task_id),
                    timeout=self.question_timeout
                )
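                # asyncio.wait_for cancels the solving coroutine if it exceeds
                # question_timeout and raises TimeoutError (handled below).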

                # Create QuestionResult object
                classification = result.get('classification', {})
                validation = result.get('validation', {})
                timing = result.get('timing_info', {})
                question_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification=classification.get('primary_agent', 'unknown'),
                    complexity=classification.get('complexity', 0),
                    confidence=classification.get('confidence', 0.0),
                    expected_answer=validation.get('expected', ''),
                    our_answer=result.get('answer', ''),
                    status=validation.get('status', 'UNKNOWN'),
                    accuracy_score=validation.get('accuracy_score', 0.0),
                    total_duration=timing.get('total_duration', 0.0),
                    classification_time=timing.get('classification_time', 0.0),
                    solving_time=timing.get('solving_time', 0.0),
                    validation_time=timing.get('validation_time', 0.0),
                    error_type=result.get('error_type'),
                    error_details=str(result.get('error_details', '')),
                    tools_used=classification.get('tools_needed', []),
                    anti_hallucination_applied=False,  # TODO: Track this from solver
                    override_reason=None
                )

                # Log classification details
                if classification:
                    await self.logger.log_classification(task_id, classification)

                # Log answer processing (if available in result)
                if result.get('answer'):
                    answer = str(result['answer'])
                    await self.logger.log_answer_processing(task_id, answer, answer)

                # Log question completion
                await self.logger.log_question_complete(task_id, question_result)

                # Update progress
                self.completed_count += 1

                return question_result

            except asyncio.TimeoutError:
                print(f"⏱️ [{task_id[:8]}...] Question timed out after {self.question_timeout}s")
                timeout_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification='timeout',
                    complexity=0,
                    confidence=0.0,
                    expected_answer='',
                    our_answer='',
                    status='TIMEOUT',
                    accuracy_score=0.0,
                    total_duration=self.question_timeout,
                    classification_time=0.0,
                    solving_time=self.question_timeout,
                    validation_time=0.0,
                    error_type='timeout',
                    error_details=f'Question processing timed out after {self.question_timeout} seconds',
                    tools_used=[],
                    anti_hallucination_applied=False,
                    override_reason=None
                )
                await self.logger.log_question_complete(task_id, timeout_result)
                self.completed_count += 1
                return timeout_result

            except Exception as e:
                print(f"❌ [{task_id[:8]}...] Unexpected error: {str(e)}")
                error_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification='error',
                    complexity=0,
                    confidence=0.0,
                    expected_answer='',
                    our_answer='',
                    status='ERROR',
                    accuracy_score=0.0,
                    # Duration of this question, not of the whole batch
                    total_duration=time.time() - question_start,
                    classification_time=0.0,
                    solving_time=0.0,
                    validation_time=0.0,
                    error_type='unexpected_error',
                    error_details=str(e),
                    tools_used=[],
                    anti_hallucination_applied=False,
                    override_reason=None
                )
                await self.logger.log_question_complete(task_id, error_result)
                self.completed_count += 1
                return error_result

    async def _track_progress(self):
        """Background task for real-time progress tracking"""
        while True:
            try:
                await asyncio.sleep(self.progress_interval)
                await self.logger.log_batch_progress()
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"⚠️ Progress tracking error: {e}")

    async def _compile_batch_results(self,
                                     results: List[Any],
                                     questions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compile comprehensive batch results with analysis"""
        # results may mix QuestionResult objects and exceptions, since
        # asyncio.gather was called with return_exceptions=True

        # Count results by status
        status_counts = {
            "CORRECT": 0,
            "PARTIAL": 0,
            "INCORRECT": 0,
            "TIMEOUT": 0,
            "ERROR": 0
        }

        # Count by classification
        classification_counts = {}

        # Timing analysis
        total_duration = 0.0
        successful_questions = []

        for result in results:
            if isinstance(result, QuestionResult):
                # Status counting
                if result.status in status_counts:
                    status_counts[result.status] += 1

                # Classification counting
                classification_counts[result.classification] = \
                    classification_counts.get(result.classification, 0) + 1

                # Timing analysis
                total_duration += result.total_duration
                if result.status in ["CORRECT", "PARTIAL"]:
                    successful_questions.append(result)

        # Calculate accuracy metrics
        total_completed = len([r for r in results if isinstance(r, QuestionResult)])
        accuracy_rate = status_counts["CORRECT"] / total_completed if total_completed > 0 else 0.0
        success_rate = (status_counts["CORRECT"] + status_counts["PARTIAL"]) / total_completed if total_completed > 0 else 0.0
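        # Note: accuracy_rate credits only CORRECT answers, while
        # success_rate also counts PARTIAL matches.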

        # Performance metrics
        avg_duration = total_duration / total_completed if total_completed > 0 else 0.0

        batch_summary = {
            "timestamp": datetime.now().isoformat(),
            "total_questions": self.total_questions,
            "completed_questions": total_completed,
            "accuracy_metrics": {
                "accuracy_rate": accuracy_rate,
                "success_rate": success_rate,
                "correct_answers": status_counts["CORRECT"],
                "partial_answers": status_counts["PARTIAL"],
                "incorrect_answers": status_counts["INCORRECT"],
                "timeouts": status_counts["TIMEOUT"],
                "errors": status_counts["ERROR"]
            },
            "classification_breakdown": classification_counts,
            "performance_metrics": {
                "total_duration": total_duration,
                "average_duration": avg_duration,
                "max_concurrent": self.max_concurrent,
                "question_timeout": self.question_timeout
            },
            "detailed_results": [r for r in results if isinstance(r, QuestionResult)]
        }

        return batch_summary


async def main():
    """Test the async batch processor with a small subset of questions"""
    try:
        # Import required classes
        from gaia_web_loader import GAIAQuestionLoaderWeb

        print("🧪 Testing Async Batch Processor")
        print("=" * 60)

        # Load a few test questions
        print("📋 Loading test questions...")
        loader = GAIAQuestionLoaderWeb()
        all_questions = loader.questions

        # Use first 3 questions for testing
        test_questions = all_questions[:3]
        print(f"✅ Loaded {len(test_questions)} test questions")
        for i, q in enumerate(test_questions):
            task_id = q.get('task_id', 'unknown')
            question = q.get('question', '')[:50] + "..."
            print(f"  {i+1}. {task_id[:8]}... - {question}")

        # Initialize processor
        print(f"\n🔄 Initializing batch processor...")
        processor = BatchQuestionProcessor(
            max_concurrent=2,      # Lower concurrency for testing
            question_timeout=180,  # 3 minutes timeout for testing
            progress_interval=5    # Progress updates every 5 seconds
        )

        # Process batch
        print(f"\n🚀 Starting batch processing...")
        results = await processor.process_questions_batch(test_questions)

        # Display results
        print(f"\n📊 BATCH RESULTS:")
        print("=" * 60)
        accuracy = results["accuracy_metrics"]["accuracy_rate"]
        success = results["accuracy_metrics"]["success_rate"]
        print(f"✅ Accuracy Rate: {accuracy:.1%}")
        print(f"🎯 Success Rate: {success:.1%}")
        print(f"⏱️ Total Duration: {results['performance_metrics']['total_duration']:.1f}s")
        print(f"⚡ Average Duration: {results['performance_metrics']['average_duration']:.1f}s")

        print(f"\n📊 Classification Breakdown:")
        for classification, count in results["classification_breakdown"].items():
            print(f"  - {classification}: {count}")

        print(f"\n📋 Status Breakdown:")
        for status, count in results["accuracy_metrics"].items():
            if isinstance(count, int):
                print(f"  - {status}: {count}")

        print(f"\n✅ Async batch processing test completed successfully!")

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
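
# Usage sketch (hypothetical, for illustration): running a larger batch with
# custom solver settings. All names below come from this module and
# gaia_web_loader; adjust concurrency and timeout to your environment.
#
#   from gaia_web_loader import GAIAQuestionLoaderWeb
#
#   questions = GAIAQuestionLoaderWeb().questions
#   processor = BatchQuestionProcessor(max_concurrent=5, question_timeout=600)
#   results = asyncio.run(processor.process_questions_batch(
#       questions,
#       solver_kwargs={"use_kluster": True, "kluster_model": "qwen3-235b"},
#   ))
#   print(f"Accuracy: {results['accuracy_metrics']['accuracy_rate']:.1%}")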