badaoui HF Staff commited on
Commit
e3faddb
·
verified ·
1 Parent(s): 677386b

Create synchronizer.py

Browse files
Files changed (1) hide show
  1. synchronizer.py +132 -0
synchronizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ External module for creating PR-based cache synchronization.
3
+ """
4
+
5
+ import os
6
+ import threading
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+
10
+ import torch_xla.core.xla_model as xm
11
+ import torch_xla.runtime as xr
12
+
13
+ from optimum.neuron.cache.hub_cache import create_hub_compile_cache_proxy
14
+ from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repo
15
+ from optimum.neuron.utils.require_utils import requires_torch_neuronx
16
+ from optimum.neuron.utils.version_utils import get_neuronxcc_version
17
+ from optimum.neuron.utils.import_utils import is_neuronx_available
18
+ from libneuronxla.neuron_cc_cache import CacheUrl, CompileCacheS3
19
+
20
@requires_torch_neuronx
def synchronize_hub_cache_with_pr(
    cache_path: str | Path | None = None,
    cache_repo_id: str | None = None,
    commit_message: str | None = None,
    commit_description: str | None = None,
    token: str | None = None,
    non_blocking: bool = False,
):
    """Synchronize the neuronx compiler cache with the optimum-neuron hub cache via a Pull Request.

    This is a generator function: nothing (including argument validation) runs
    until the returned generator is iterated.

    Args:
        cache_path (`str | Path | None`, defaults to `None`):
            The path of the folder to use for synchronization. When `None`,
            no explicit cache URL is passed to the hub cache proxy factory.
        cache_repo_id (`str | None`, defaults to `None`):
            The id of the HuggingFace cache repository, in the form 'org|user/name'.
            When `None`, the value of `get_hf_hub_cache_repo()` is used.
        commit_message (`str | None`, defaults to `None`):
            Forwarded as-is to `upload_folder`.
        commit_description (`str | None`, defaults to `None`):
            Forwarded as-is to `upload_folder`.
        token (`str | None`, defaults to `None`):
            Forwarded as-is to `upload_folder`.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the upload runs in a background thread and its status
            messages are logged instead of yielded.

    Yields:
        Status messages about the synchronization process.

    Returns:
        The URL of the created pull request (carried by `StopIteration.value`,
        or obtained with `yield from`), or `None` if `non_blocking=True` or if
        this process is not the one performing the upload.

    Raises:
        ValueError: If `cache_path` is not an existing directory, or if the
            resolved default cache is an S3 cache.
        Exception: Any error raised by the upload is re-raised in blocking
            mode, after an "Error: ..." status message has been yielded.
    """
    # Function-scope import so this block does not depend on a file-level change.
    import logging

    logger = logging.getLogger(__name__)

    # Validate cache path if provided
    if cache_path is not None:
        cache_path = Path(cache_path)
        cache_path_str = cache_path.as_posix()
        if not cache_path.is_dir():
            raise ValueError(f"The {cache_path_str} directory does not exist, cannot synchronize.")
        cache_url = CacheUrl(cache_path_str, url_type="fs")
    else:
        cache_url = None

    # Get default cache repo if not provided
    if cache_repo_id is None:
        cache_repo_id = get_hf_hub_cache_repo()

    # Create the hub cache proxy using the existing function
    hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)

    # Check if S3 cache (not supported for PR workflow)
    if isinstance(hub_cache_proxy.default_cache, CompileCacheS3):
        raise ValueError("Hugging Face hub compiler cache synchronization via PR is not supported for S3.")

    def _create_pr():
        """Upload the cache folder as a PR, yielding status messages and returning the PR URL."""
        try:
            commit_info = hub_cache_proxy.api.upload_folder(
                repo_id=cache_repo_id,
                folder_path=hub_cache_proxy.default_cache.cache_path,
                commit_message=commit_message,
                commit_description=commit_description,
                ignore_patterns="lock",
                create_pr=True,
                token=token,
            )
        except Exception as e:
            yield f"Error: Failed to create PR for cache synchronization: {e}"
            raise
        # NOTE(review): assumes `hub_cache_proxy.api` is a huggingface_hub `HfApi`,
        # whose `upload_folder` result exposes the PR link as `pr_url` while its
        # string value is the commit URL - confirm. Falls back to the string value.
        pr_url = getattr(commit_info, "pr_url", None) or str(commit_info)
        yield f"Pull request created successfully: {pr_url}"
        return pr_url

    def _start_background_sync():
        """Run `_create_pr` to completion in a separate thread."""

        def _drain():
            # The thread target must be a plain function: a generator function
            # used as target would only build a generator object and never
            # execute the upload.
            try:
                for status in _create_pr():
                    logger.info(status)
            except Exception as e:
                logger.error(f"Error: Background sync failed: {e}")

        thread = threading.Thread(target=_drain)
        thread.start()

    # A TorchElastic run id in the environment means multi-process execution.
    is_distributed = os.environ.get("TORCHELASTIC_RUN_ID", None) is not None

    pr_url = None
    # In a multi-process run, only the first local process creates the PR.
    if not is_distributed or xr.local_ordinal() == 0:
        if non_blocking:
            _start_background_sync()
            yield "Cache synchronization started in background thread"
        else:
            # `yield from` forwards every status message and captures the
            # generator's return value, so the URL is not parsed out of text.
            pr_url = yield from _create_pr()

    if is_distributed:
        # Synchronize all processes
        xm.rendezvous("synchronize_hub_cache_pr")

    return pr_url