Spaces:
Running
Running
| """ | |
| Helpers for synchronizing the local Neuron compiler cache with the Hugging Face Hub cache repository by opening a pull request. | |
| """ | |
| import os | |
| import threading | |
| from datetime import datetime | |
| from pathlib import Path | |
| import torch_xla.core.xla_model as xm | |
| import torch_xla.runtime as xr | |
| from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy | |
| from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo | |
| from optimum.neuron.utils.require_utils import requires_torch_neuronx | |
| from optimum.neuron.utils.version_utils import get_neuronxcc_version | |
| from optimum.neuron.utils.import_utils import is_neuronx_available | |
| from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3 | |
| def synchronize_hub_cache_with_pr( | |
| cache_path: str | Path | None = None, | |
| cache_repo_id: str | None = None, | |
| commit_message: str | None = None, | |
| commit_description: str | None = None, | |
| token: str | None = None, | |
| non_blocking: bool = False, | |
| ): | |
| """Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request. | |
| Args: | |
| cache_path (`str | Path | None`, defaults to `None`): | |
| The path of the folder to use for synchronization. | |
| cache_repo_id (`str | None`, defaults to `None`): | |
| The id of the HuggingFace cache repository, in the form 'org|user/name'. | |
| non_blocking (`bool`, defaults to `False`): | |
| If `True`, the synchronization will be done in a non-blocking way. | |
| Yields: | |
| Status messages about the synchronization process. | |
| Returns: | |
| The URL of the created pull request or None if non_blocking=True. | |
| """ | |
| # Validate cache path if provided | |
| if cache_path is not None: | |
| cache_path = Path(cache_path) | |
| cache_path_str = cache_path.as_posix() | |
| if not cache_path.is_dir(): | |
| raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.") | |
| cache_url = CacheUrl(cache_path_str, url_type="fs") | |
| else: | |
| cache_url = None | |
| # Get default cache repo if not provided | |
| if cache_repo_id is None: | |
| cache_repo_id = get_hf_hub_cache_repo() | |
| # Create the hub cache proxy using the existing function | |
| hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id) | |
| # Check if S3 cache (not supported for PR workflow) | |
| if isinstance(hub_cache_proxy.default_cache, CompileCacheS3): | |
| raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.") | |
| def _create_pr(): | |
| """Internal function to create the PR""" | |
| try: | |
| pr_url = hub_cache_proxy.api.upload_folder( | |
| repo_id=cache_repo_id, | |
| folder_path=hub_cache_proxy.default_cache.cache_path, | |
| commit_message=commit_message, | |
| commit_description=commit_description, | |
| ignore_patterns="lock", | |
| create_pr=True, | |
| token=token | |
| ) | |
| yield f"Pull request created successfully: {pr_url}" | |
| return pr_url | |
| except Exception as e: | |
| yield f"Error: Failed to create PR for cache synchronization: {e}" | |
| raise | |
| # Handle distributed training scenario | |
| if os.environ.get("TORCHELASTIC_RUN_ID", None) is not None: | |
| # Multi-process execution | |
| pr_url = None | |
| if xr.local_ordinal() == 0: | |
| # Only the first process creates the PR | |
| if non_blocking: | |
| def sync_thread(): | |
| try: | |
| for status in _create_pr(): | |
| yield status | |
| except Exception as e: | |
| yield f"Error: Background sync failed: {e}" | |
| thread = threading.Thread(target=sync_thread) | |
| thread.start() | |
| yield "Cache synchronization started in background thread" | |
| else: | |
| for status in _create_pr(): | |
| yield status | |
| if "Pull request created successfully:" in status: | |
| pr_url = status.split(": ", 1)[1] | |
| # Synchronize all processes | |
| xm.rendezvous("synchronize_hub_cache_pr") | |
| return pr_url if not non_blocking else None | |
| # Single process execution | |
| if non_blocking: | |
| def sync_thread(): | |
| try: | |
| for status in _create_pr(): | |
| yield status | |
| except Exception as e: | |
| yield f"Error: Background sync failed: {e}" | |
| thread = threading.Thread(target=sync_thread) | |
| thread.start() | |
| yield "Cache synchronization started in background thread" | |
| return None | |
| else: | |
| pr_url = None | |
| for status in _create_pr(): | |
| yield status | |
| if "Pull request created successfully:" in status: | |
| pr_url = status.split(": ", 1)[1] | |
| return pr_url |