import asyncio, os
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, Security
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from pydantic import BaseModel, HttpUrl, Field
from typing import Optional, List, Dict, Any, Union
import psutil
import time
import uuid
from collections import defaultdict
from urllib.parse import urlparse
import math
import logging
from enum import Enum
from dataclasses import dataclass, field
import json

from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
    LLMExtractionStrategy,
    CosineStrategy,
    JsonCssExtractionStrategy,
)

__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TaskStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class CrawlerType(str, Enum):
    BASIC = "basic"
    LLM = "llm"
    COSINE = "cosine"
    JSON_CSS = "json_css"


class ExtractionConfig(BaseModel):
    type: CrawlerType
    params: Dict[str, Any] = {}


class ChunkingStrategy(BaseModel):
    type: str
    params: Dict[str, Any] = {}


class ContentFilter(BaseModel):
    type: str = "bm25"
    params: Dict[str, Any] = {}


class CrawlRequest(BaseModel):
    urls: Union[HttpUrl, List[HttpUrl]]
    word_count_threshold: int = MIN_WORD_THRESHOLD
    extraction_config: Optional[ExtractionConfig] = None
    chunking_strategy: Optional[ChunkingStrategy] = None
    content_filter: Optional[ContentFilter] = None
    js_code: Optional[List[str]] = None
    wait_for: Optional[str] = None
    css_selector: Optional[str] = None
    screenshot: bool = False
    magic: bool = False
    extra: Optional[Dict[str, Any]] = {}
    session_id: Optional[str] = None
    cache_mode: Optional[CacheMode] = CacheMode.ENABLED
    priority: int = Field(default=5, ge=1, le=10)
    ttl: Optional[int] = 3600
    crawler_params: Dict[str, Any] = {}
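
# Example request body for this model (illustrative, not part of the original source).
# "urls" may be a single URL or a list; omitted fields fall back to the defaults above:
#
#     {
#         "urls": "https://example.com",
#         "priority": 8,
#         "screenshot": true,
#         "extraction_config": {"type": "cosine", "params": {}}
#     }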


@dataclass
class TaskInfo:
    id: str
    status: TaskStatus
    result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)  # set per task, not at class definition
    ttl: int = 3600
    request: Optional[CrawlRequest] = None  # attached by CrawlerService.submit_task


class ResourceMonitor:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.max_concurrent_tasks = max_concurrent_tasks
        self.memory_threshold = 0.85
        self.cpu_threshold = 0.90
        self._last_check = 0
        self._check_interval = 1  # seconds
        self._last_available_slots = max_concurrent_tasks

    async def get_available_slots(self) -> int:
        current_time = time.time()
        if current_time - self._last_check < self._check_interval:
            return self._last_available_slots

        mem_usage = psutil.virtual_memory().percent / 100
        cpu_usage = psutil.cpu_percent() / 100

        memory_factor = max(
            0, (self.memory_threshold - mem_usage) / self.memory_threshold
        )
        cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)

        self._last_available_slots = math.floor(
            self.max_concurrent_tasks * min(memory_factor, cpu_factor)
        )
        self._last_check = current_time
        return self._last_available_slots
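
# Worked example of the throttling math above (illustrative, not from the original
# source): with max_concurrent_tasks=10, memory at 45% and CPU at 60%,
# memory_factor = (0.85 - 0.45) / 0.85 ~= 0.47 and cpu_factor = (0.90 - 0.60) / 0.90
# ~= 0.33, so floor(10 * min(0.47, 0.33)) = 3 slots are reported. If either resource
# exceeds its threshold, the corresponding factor clamps to 0 and no slots are offered.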


class TaskManager:
    def __init__(self, cleanup_interval: int = 300):
        self.tasks: Dict[str, TaskInfo] = {}
        self.high_priority = asyncio.PriorityQueue()
        self.low_priority = asyncio.PriorityQueue()
        self.cleanup_interval = cleanup_interval
        self.cleanup_task = None

    async def start(self):
        self.cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        if self.cleanup_task:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass

    async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
        task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
        self.tasks[task_id] = task_info
        queue = self.high_priority if priority > 5 else self.low_priority
        await queue.put((-priority, task_id))  # Negative for proper priority ordering

    async def get_next_task(self) -> Optional[str]:
        try:
            # Try high priority first
            _, task_id = await asyncio.wait_for(self.high_priority.get(), timeout=0.1)
            return task_id
        except asyncio.TimeoutError:
            try:
                # Then try low priority
                _, task_id = await asyncio.wait_for(
                    self.low_priority.get(), timeout=0.1
                )
                return task_id
            except asyncio.TimeoutError:
                return None

    def update_task(
        self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
    ):
        if task_id in self.tasks:
            task_info = self.tasks[task_id]
            task_info.status = status
            task_info.result = result
            task_info.error = error

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        return self.tasks.get(task_id)

    async def _cleanup_loop(self):
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                current_time = time.time()
                expired_tasks = [
                    task_id
                    for task_id, task in self.tasks.items()
                    if current_time - task.created_at > task.ttl
                    and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
                ]
                for task_id in expired_tasks:
                    del self.tasks[task_id]
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
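
# Note on scheduling: add_task routes priority > 5 to high_priority and everything else
# to low_priority, storing (-priority, task_id) so higher priorities sort first within
# a queue, and get_next_task always drains high_priority before low_priority. For
# example, priority 8 is queued as (-8, ...) and is picked before priority 6 as (-6, ...).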


class CrawlerPool:
    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
        self._lock = asyncio.Lock()

    async def acquire(self, **kwargs) -> AsyncWebCrawler:
        async with self._lock:
            # Clean up inactive crawlers
            current_time = time.time()
            inactive = [
                crawler
                for crawler, last_used in self.active_crawlers.items()
                if current_time - last_used > 600  # 10 minutes timeout
            ]
            for crawler in inactive:
                await crawler.__aexit__(None, None, None)
                del self.active_crawlers[crawler]

            # Create new crawler if needed
            if len(self.active_crawlers) < self.max_size:
                crawler = AsyncWebCrawler(**kwargs)
                await crawler.__aenter__()
                self.active_crawlers[crawler] = current_time
                return crawler

            # Reuse least recently used crawler
            crawler = min(self.active_crawlers.items(), key=lambda x: x[1])[0]
            self.active_crawlers[crawler] = current_time
            return crawler

    async def release(self, crawler: AsyncWebCrawler):
        async with self._lock:
            if crawler in self.active_crawlers:
                self.active_crawlers[crawler] = time.time()

    async def cleanup(self):
        async with self._lock:
            for crawler in list(self.active_crawlers.keys()):
                await crawler.__aexit__(None, None, None)
            self.active_crawlers.clear()
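
# Illustrative use of the pool (a sketch mirroring what CrawlerService does below;
# kwargs passed to acquire() are forwarded to AsyncWebCrawler):
#
#     pool = CrawlerPool(max_size=2)
#     crawler = await pool.acquire()
#     try:
#         result = await crawler.arun(url="https://example.com")
#     finally:
#         await pool.release(crawler)
#     await pool.cleanup()  # on shutdown, closes every pooled crawler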


class CrawlerService:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
        self.task_manager = TaskManager()
        self.crawler_pool = CrawlerPool(max_concurrent_tasks)
        self._processing_task = None

    async def start(self):
        await self.task_manager.start()
        self._processing_task = asyncio.create_task(self._process_queue())

    async def stop(self):
        if self._processing_task:
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                pass
        await self.task_manager.stop()
        await self.crawler_pool.cleanup()

    def _create_extraction_strategy(self, config: ExtractionConfig):
        if not config:
            return None
        if config.type == CrawlerType.LLM:
            return LLMExtractionStrategy(**config.params)
        elif config.type == CrawlerType.COSINE:
            return CosineStrategy(**config.params)
        elif config.type == CrawlerType.JSON_CSS:
            return JsonCssExtractionStrategy(**config.params)
        return None
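
    # Illustrative extraction_config payloads routed by the method above (the param
    # shapes are assumptions based on the imported strategies, not taken from this file):
    #
    #     {"type": "llm", "params": {"provider": "openai/gpt-4o-mini", "instruction": "..."}}
    #     {"type": "json_css", "params": {"schema": {
    #         "name": "items",
    #         "baseSelector": "div.item",
    #         "fields": [{"name": "title", "selector": "h2", "type": "text"}]}}}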

    async def submit_task(self, request: CrawlRequest) -> str:
        task_id = str(uuid.uuid4())
        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)

        # Store request data with task
        self.task_manager.tasks[task_id].request = request
        return task_id

    async def _process_queue(self):
        while True:
            try:
                available_slots = await self.resource_monitor.get_available_slots()
                if available_slots <= 0:
                    await asyncio.sleep(1)
                    continue

                task_id = await self.task_manager.get_next_task()
                if not task_id:
                    await asyncio.sleep(1)
                    continue

                task_info = self.task_manager.get_task(task_id)
                if not task_info:
                    continue

                request = task_info.request
                self.task_manager.update_task(task_id, TaskStatus.PROCESSING)

                try:
                    crawler = await self.crawler_pool.acquire(**request.crawler_params)
                    extraction_strategy = self._create_extraction_strategy(
                        request.extraction_config
                    )

                    # Honor the per-request threshold (defaults to MIN_WORD_THRESHOLD)
                    if isinstance(request.urls, list):
                        results = await crawler.arun_many(
                            urls=[str(url) for url in request.urls],
                            word_count_threshold=request.word_count_threshold,
                            extraction_strategy=extraction_strategy,
                            js_code=request.js_code,
                            wait_for=request.wait_for,
                            css_selector=request.css_selector,
                            screenshot=request.screenshot,
                            magic=request.magic,
                            session_id=request.session_id,
                            cache_mode=request.cache_mode,
                            **request.extra,
                        )
                    else:
                        results = await crawler.arun(
                            url=str(request.urls),
                            word_count_threshold=request.word_count_threshold,
                            extraction_strategy=extraction_strategy,
                            js_code=request.js_code,
                            wait_for=request.wait_for,
                            css_selector=request.css_selector,
                            screenshot=request.screenshot,
                            magic=request.magic,
                            session_id=request.session_id,
                            cache_mode=request.cache_mode,
                            **request.extra,
                        )

                    await self.crawler_pool.release(crawler)
                    self.task_manager.update_task(
                        task_id, TaskStatus.COMPLETED, results
                    )
                except Exception as e:
                    logger.error(f"Error processing task {task_id}: {str(e)}")
                    self.task_manager.update_task(
                        task_id, TaskStatus.FAILED, error=str(e)
                    )
            except Exception as e:
                logger.error(f"Error in queue processing: {str(e)}")
                await asyncio.sleep(1)


app = FastAPI(title="Crawl4AI API")

# CORS configuration
origins = ["*"]  # Allow all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # List of origins that are allowed to make requests
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

# API token security
security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")


async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    if not CRAWL4AI_API_TOKEN:
        return credentials  # No token verification if CRAWL4AI_API_TOKEN is not set
    if credentials.credentials != CRAWL4AI_API_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid token")
    return credentials


def secure_endpoint():
    """Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
    return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
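
# When CRAWL4AI_API_TOKEN is set, the protected routes below require a bearer token,
# e.g. (illustrative):
#
#     curl -X POST http://localhost:11235/crawl \
#          -H "Authorization: Bearer $CRAWL4AI_API_TOKEN" \
#          -H "Content-Type: application/json" \
#          -d '{"urls": "https://example.com", "priority": 5}'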

# Check if site directory exists
if os.path.exists(__location__ + "/site"):
    # Mount the site directory as a static directory
    app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
    site_templates = Jinja2Templates(directory=__location__ + "/site")

crawler_service = CrawlerService()


@app.on_event("startup")
async def startup_event():
    await crawler_service.start()


@app.on_event("shutdown")
async def shutdown_event():
    await crawler_service.stop()


# @app.get("/")
# def read_root():
#     if os.path.exists(__location__ + "/site"):
#         return RedirectResponse(url="/mkdocs")
#     # Return a json response
#     return {"message": "Crawl4AI API service is running"}


@app.get("/")
async def root():
    return RedirectResponse(url="/docs")


@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
    task_id = await crawler_service.submit_task(request)
    return {"task_id": task_id}


@app.get(
    "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
    task_info = crawler_service.task_manager.get_task(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    response = {
        "status": task_info.status,
        "created_at": task_info.created_at,
    }

    if task_info.status == TaskStatus.COMPLETED:
        # Convert CrawlResult to dict for JSON response
        if isinstance(task_info.result, list):
            response["results"] = [result.dict() for result in task_info.result]
        else:
            response["result"] = task_info.result.dict()
    elif task_info.status == TaskStatus.FAILED:
        response["error"] = task_info.error

    return response


@app.post(
    "/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
    task_id = await crawler_service.submit_task(request)

    # Wait up to 60 seconds for task completion
    for _ in range(60):
        task_info = crawler_service.task_manager.get_task(task_id)
        if not task_info:
            raise HTTPException(status_code=404, detail="Task not found")

        if task_info.status == TaskStatus.COMPLETED:
            # Return same format as /task/{task_id} endpoint
            if isinstance(task_info.result, list):
                return {
                    "status": task_info.status,
                    "results": [result.dict() for result in task_info.result],
                }
            return {"status": task_info.status, "result": task_info.result.dict()}

        if task_info.status == TaskStatus.FAILED:
            raise HTTPException(status_code=500, detail=task_info.error)

        await asyncio.sleep(1)

    # If we get here, task didn't complete within timeout
    raise HTTPException(status_code=408, detail="Task timed out")


@app.post(
    "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
    logger.info("Received request to crawl directly.")
    try:
        logger.debug("Acquiring crawler from the crawler pool.")
        crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
        logger.debug("Crawler acquired successfully.")

        logger.debug("Creating extraction strategy based on the request configuration.")
        extraction_strategy = crawler_service._create_extraction_strategy(
            request.extraction_config
        )
        logger.debug("Extraction strategy created successfully.")

        try:
            if isinstance(request.urls, list):
                logger.info("Processing multiple URLs.")
                results = await crawler.arun_many(
                    urls=[str(url) for url in request.urls],
                    extraction_strategy=extraction_strategy,
                    js_code=request.js_code,
                    wait_for=request.wait_for,
                    css_selector=request.css_selector,
                    screenshot=request.screenshot,
                    magic=request.magic,
                    cache_mode=request.cache_mode,
                    session_id=request.session_id,
                    **request.extra,
                )
                logger.info("Crawling completed for multiple URLs.")
                return {"results": [result.dict() for result in results]}
            else:
                logger.info("Processing a single URL.")
                result = await crawler.arun(
                    url=str(request.urls),
                    extraction_strategy=extraction_strategy,
                    js_code=request.js_code,
                    wait_for=request.wait_for,
                    css_selector=request.css_selector,
                    screenshot=request.screenshot,
                    magic=request.magic,
                    cache_mode=request.cache_mode,
                    session_id=request.session_id,
                    **request.extra,
                )
                logger.info("Crawling completed for a single URL.")
                return {"result": result.dict()}
        finally:
            logger.debug("Releasing crawler back to the pool.")
            await crawler_service.crawler_pool.release(crawler)
            logger.debug("Crawler released successfully.")
    except Exception as e:
        logger.error(f"Error in direct crawl: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    available_slots = await crawler_service.resource_monitor.get_available_slots()
    memory = psutil.virtual_memory()
    return {
        "status": "healthy",
        "available_slots": available_slots,
        "memory_usage": memory.percent,
        "cpu_usage": psutil.cpu_percent(),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=11235)
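

# Illustrative client flow against this service (a sketch; `requests` is not a
# dependency of this module, and an Authorization header is only needed when
# CRAWL4AI_API_TOKEN is set):
#
#     import time, requests
#
#     base = "http://localhost:11235"
#     task_id = requests.post(f"{base}/crawl", json={"urls": "https://example.com"}).json()["task_id"]
#     while True:
#         status = requests.get(f"{base}/task/{task_id}").json()
#         if status["status"] in ("completed", "failed"):
#             break
#         time.sleep(1)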