Spaces:
Running
on
Zero
Running
on
Zero
| from functools import wraps | |
| import torch | |
| from huggingface_hub import HfApi | |
| import os | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class DeviceManager:
    """Process-wide singleton that selects and caches the torch device.

    Every ``DeviceManager()`` call returns the same instance, so the device
    chosen on first use (and any later reassignment of ``_current_device``)
    is shared by all callers in the process.
    """

    _instance = None

    def __new__(cls):
        # Lazily create the single shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every DeviceManager() call; initialise state once
        # so repeated construction does not reset the cached device.
        if self._initialized:
            return
        self._initialized = True
        self._current_device = None
        self._zero_gpu_available = None

    def check_zero_gpu_availability(self):
        """Report whether the process appears to run inside a Space.

        The result is also stored on ``self._zero_gpu_available``.

        Returns:
            bool: True when the ``SPACE_ID`` environment variable is set,
            False otherwise.
        """
        # The presence of SPACE_ID is used as the signal for a Spaces
        # environment. It says nothing about the hardware, so more specific
        # ZeroGPU checks can be added here; get_optimal_device() additionally
        # verifies that CUDA is usable before selecting it.
        self._zero_gpu_available = 'SPACE_ID' in os.environ
        return self._zero_gpu_available

    def get_optimal_device(self):
        """Return the device to run on, resolving and caching it on first use.

        CUDA is selected only when the Spaces check passes *and* torch
        reports CUDA as available; otherwise the CPU is used, so a Space
        without GPU hardware does not end up with an unusable device.

        Returns:
            torch.device: The cached ``cuda`` or ``cpu`` device.
        """
        if self._current_device is None:
            if self.check_zero_gpu_availability() and torch.cuda.is_available():
                self._current_device = torch.device('cuda')
                logger.info("Using ZeroGPU")
            else:
                self._current_device = torch.device('cpu')
                logger.info("Using CPU")
        return self._current_device

    def move_to_device(self, tensor_or_model):
        """Move an object to the managed device when it supports ``.to``.

        Args:
            tensor_or_model: Any object; only those exposing a ``to``
                attribute (tensors, modules, ...) are moved.

        Returns:
            The result of ``tensor_or_model.to(device)``, or the object
            unchanged when it has no ``to`` attribute.
        """
        device = self.get_optimal_device()
        if hasattr(tensor_or_model, 'to'):
            return tensor_or_model.to(device)
        return tensor_or_model
def device_handler(func):
    """Decorator for handling device placement of an async callable.

    Positional and keyword arguments that are tensors (or expose ``.to``)
    are moved to the device chosen by ``DeviceManager`` before ``func`` is
    awaited. A tensor result, or the tensor members of a tuple result, are
    moved to that device as well.

    If ``func`` raises a ``RuntimeError`` whose message contains
    "out of memory" while the managed device is not the CPU, the shared
    ``DeviceManager`` is switched to the CPU and the call is retried once
    with the original arguments. The switch affects every later user of the
    singleton. Any other ``RuntimeError``, or an out-of-memory error that
    occurs while already on the CPU, is re-raised.

    Args:
        func: The coroutine function to wrap.

    Returns:
        An async wrapper carrying ``func``'s name and docstring.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        device_mgr = DeviceManager()

        # Move input arguments to the managed device.
        def process_arg(arg):
            if torch.is_tensor(arg) or hasattr(arg, 'to'):
                return device_mgr.move_to_device(arg)
            return arg

        processed_args = [process_arg(arg) for arg in args]
        processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}

        try:
            result = await func(*processed_args, **processed_kwargs)

            # Move the output to the managed device.
            if torch.is_tensor(result):
                return device_mgr.move_to_device(result)
            elif isinstance(result, tuple):
                return tuple(
                    device_mgr.move_to_device(r) if torch.is_tensor(r) else r
                    for r in result
                )
            return result
        except RuntimeError as e:
            # Retry on the CPU only when we were not already there; retrying
            # an out-of-memory failure on the same CPU device would recurse
            # without ever making progress.
            on_cpu = device_mgr.get_optimal_device().type == 'cpu'
            if "out of memory" in str(e) and not on_cpu:
                logger.warning("GPU memory exceeded, falling back to CPU")
                device_mgr._current_device = torch.device('cpu')
                return await wrapper(*args, **kwargs)
            # Bare raise keeps the original traceback intact.
            raise

    return wrapper