Spaces: Running on Zero
| from functools import wraps | |
| import torch | |
| import os | |
| import logging | |
| import spaces | |
# Configure root logging at INFO (level given by name) and create the
# module-level logger used by the classes and decorators below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class DeviceManager:
    """Process-wide singleton holding the torch device this app should use.

    The device is chosen once, on the first construction:
    ``torch.device('cuda')`` when the ``SPACE_ID`` environment variable is
    set to a non-empty value, ``torch.device('cpu')`` otherwise. Any exception
    raised while choosing falls back to CPU with a warning. Later
    constructions return the same object and skip re-initialisation.
    """

    # The single shared instance; created lazily by __new__.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Consulted by __init__ so the setup below runs only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every DeviceManager() call; only the first one
        # performs device selection.
        if self._initialized:
            return
        self._initialized = True
        self._current_device = None
        try:
            self._current_device = self._select_device()
        except Exception as e:
            logger.warning(f"Failed to initialize ZeroGPU: {e}")
            self._current_device = torch.device('cpu')

    @staticmethod
    def _select_device():
        """Return ``cuda`` when ``SPACE_ID`` is set (non-empty), else ``cpu``."""
        if not os.environ.get('SPACE_ID'):
            return torch.device('cpu')
        # NOTE(review): no `spaces` GPU wrapper is applied here. This only
        # builds a 'cuda' device handle, and torch.device() itself does not
        # check that a GPU is actually present.
        device = torch.device('cuda')
        logger.info("ZeroGPU initialized successfully")
        return device

    def get_optimal_device(self):
        """Return the currently selected ``torch.device``."""
        return self._current_device
def device_handler(func):
    """Decorator for handling device placement with ZeroGPU support.

    Wraps an ``async`` callable. If awaiting it raises a ``RuntimeError``
    whose message contains ``"out of memory"`` or ``"CUDA"``, the shared
    ``DeviceManager`` is switched to CPU and ``func`` is awaited once more
    with the same arguments. Any other ``RuntimeError`` is re-raised with its
    original traceback.

    Args:
        func: Coroutine function to wrap. Its name, docstring and signature
            metadata are copied onto the wrapper via ``functools.wraps``, so
            tools that inspect signatures still see the real parameters.

    Returns:
        An ``async`` wrapper accepting the same arguments as ``func``.

    Raises:
        RuntimeError: When the error message does not look GPU-related.
        Any exception raised by the CPU retry also propagates.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except RuntimeError as e:
            message = str(e)
            if "out of memory" in message or "CUDA" in message:
                logger.warning("ZeroGPU unavailable, falling back to CPU")
                device_mgr = DeviceManager()
                device_mgr._current_device = torch.device('cpu')
                # NOTE(review): the retry only runs on CPU if `func` reads its
                # device from DeviceManager -- confirm at the call sites.
                return await func(*args, **kwargs)
            # Bare `raise` keeps the original traceback intact.
            raise
    return wrapper