Commit
·
0d3afce
1
Parent(s):
6f2f71c
Update convert.py
Browse files- convert.py +41 -19
convert.py
CHANGED
|
@@ -10,7 +10,7 @@ import torch
|
|
| 10 |
|
| 11 |
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
|
| 12 |
from huggingface_hub.file_download import repo_folder_name
|
| 13 |
-
from safetensors.torch import
|
| 14 |
|
| 15 |
|
| 16 |
COMMIT_DESCRIPTION = """
|
|
@@ -32,6 +32,7 @@ Feel free to ignore this PR.
|
|
| 32 |
|
| 33 |
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
|
| 34 |
|
|
|
|
| 35 |
def _remove_duplicate_names(
|
| 36 |
state_dict: Dict[str, torch.Tensor],
|
| 37 |
*,
|
|
@@ -48,9 +49,7 @@ def _remove_duplicate_names(
|
|
| 48 |
shareds = _find_shared_tensors(state_dict)
|
| 49 |
to_remove = defaultdict(list)
|
| 50 |
for shared in shareds:
|
| 51 |
-
complete_names = set(
|
| 52 |
-
[name for name in shared if _is_complete(state_dict[name])]
|
| 53 |
-
)
|
| 54 |
if not complete_names:
|
| 55 |
if len(shared) == 1:
|
| 56 |
# Force contiguous
|
|
@@ -81,11 +80,13 @@ def _remove_duplicate_names(
|
|
| 81 |
to_remove[keep_name].append(name)
|
| 82 |
return to_remove
|
| 83 |
|
|
|
|
| 84 |
def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
|
| 85 |
try:
|
| 86 |
-
import transformers
|
| 87 |
import json
|
| 88 |
|
|
|
|
|
|
|
| 89 |
config_filename = hf_hub_download(
|
| 90 |
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
|
| 91 |
)
|
|
@@ -98,10 +99,11 @@ def get_discard_names(model_id: str, revision: Optional[str], folder: str, token
|
|
| 98 |
# Name for this varible depends on transformers version.
|
| 99 |
discard_names = getattr(class_, "_tied_weights_keys", [])
|
| 100 |
|
| 101 |
-
except Exception
|
| 102 |
discard_names = []
|
| 103 |
return discard_names
|
| 104 |
|
|
|
|
| 105 |
class AlreadyExists(Exception):
|
| 106 |
pass
|
| 107 |
|
|
@@ -126,8 +128,12 @@ def rename(pt_filename: str) -> str:
|
|
| 126 |
return local
|
| 127 |
|
| 128 |
|
| 129 |
-
def convert_multi(
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
with open(filename, "r") as f:
|
| 132 |
data = json.load(f)
|
| 133 |
|
|
@@ -157,8 +163,12 @@ def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token:
|
|
| 157 |
return operations, errors
|
| 158 |
|
| 159 |
|
| 160 |
-
def convert_single(
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
sf_name = "model.safetensors"
|
| 164 |
sf_filename = os.path.join(folder, sf_name)
|
|
@@ -217,20 +227,22 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
|
|
| 217 |
|
| 218 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
|
| 219 |
try:
|
| 220 |
-
|
| 221 |
-
discussions = api.get_repo_discussions(repo_id=model_id
|
| 222 |
except Exception:
|
| 223 |
return None
|
| 224 |
for discussion in discussions:
|
| 225 |
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
|
| 226 |
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
|
| 227 |
|
| 228 |
-
if
|
| 229 |
return discussion
|
| 230 |
return None
|
| 231 |
|
| 232 |
|
| 233 |
-
def convert_generic(
|
|
|
|
|
|
|
| 234 |
operations = []
|
| 235 |
errors = []
|
| 236 |
|
|
@@ -238,7 +250,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
|
|
| 238 |
for filename in filenames:
|
| 239 |
prefix, ext = os.path.splitext(filename)
|
| 240 |
if ext in extensions:
|
| 241 |
-
pt_filename = hf_hub_download(
|
|
|
|
|
|
|
| 242 |
dirname, raw_filename = os.path.split(filename)
|
| 243 |
if raw_filename == "pytorch_model.bin":
|
| 244 |
# XXX: This is a special case to handle `transformers` and the
|
|
@@ -255,7 +269,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
|
|
| 255 |
return operations, errors
|
| 256 |
|
| 257 |
|
| 258 |
-
def convert(
|
|
|
|
|
|
|
| 259 |
pr_title = "Adding `safetensors` variant of this model"
|
| 260 |
info = api.model_info(model_id, revision=revision)
|
| 261 |
filenames = set(s.rfilename for s in info.siblings)
|
|
@@ -279,13 +295,19 @@ def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force:
|
|
| 279 |
|
| 280 |
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
| 281 |
if "pytorch_model.bin" in filenames:
|
| 282 |
-
operations, errors = convert_single(
|
|
|
|
|
|
|
| 283 |
elif "pytorch_model.bin.index.json" in filenames:
|
| 284 |
-
operations, errors = convert_multi(
|
|
|
|
|
|
|
| 285 |
else:
|
| 286 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 287 |
else:
|
| 288 |
-
operations, errors = convert_generic(
|
|
|
|
|
|
|
| 289 |
|
| 290 |
if operations:
|
| 291 |
new_pr = api.create_commit(
|
|
|
|
| 10 |
|
| 11 |
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
|
| 12 |
from huggingface_hub.file_download import repo_folder_name
|
| 13 |
+
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
|
| 14 |
|
| 15 |
|
| 16 |
COMMIT_DESCRIPTION = """
|
|
|
|
| 32 |
|
| 33 |
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
|
| 34 |
|
| 35 |
+
|
| 36 |
def _remove_duplicate_names(
|
| 37 |
state_dict: Dict[str, torch.Tensor],
|
| 38 |
*,
|
|
|
|
| 49 |
shareds = _find_shared_tensors(state_dict)
|
| 50 |
to_remove = defaultdict(list)
|
| 51 |
for shared in shareds:
|
| 52 |
+
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
|
|
|
|
|
|
|
| 53 |
if not complete_names:
|
| 54 |
if len(shared) == 1:
|
| 55 |
# Force contiguous
|
|
|
|
| 80 |
to_remove[keep_name].append(name)
|
| 81 |
return to_remove
|
| 82 |
|
| 83 |
+
|
| 84 |
def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
|
| 85 |
try:
|
|
|
|
| 86 |
import json
|
| 87 |
|
| 88 |
+
import transformers
|
| 89 |
+
|
| 90 |
config_filename = hf_hub_download(
|
| 91 |
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
|
| 92 |
)
|
|
|
|
| 99 |
# Name for this varible depends on transformers version.
|
| 100 |
discard_names = getattr(class_, "_tied_weights_keys", [])
|
| 101 |
|
| 102 |
+
except Exception:
|
| 103 |
discard_names = []
|
| 104 |
return discard_names
|
| 105 |
|
| 106 |
+
|
| 107 |
class AlreadyExists(Exception):
|
| 108 |
pass
|
| 109 |
|
|
|
|
| 128 |
return local
|
| 129 |
|
| 130 |
|
| 131 |
+
def convert_multi(
|
| 132 |
+
model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
|
| 133 |
+
) -> ConversionResult:
|
| 134 |
+
filename = hf_hub_download(
|
| 135 |
+
repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
|
| 136 |
+
)
|
| 137 |
with open(filename, "r") as f:
|
| 138 |
data = json.load(f)
|
| 139 |
|
|
|
|
| 163 |
return operations, errors
|
| 164 |
|
| 165 |
|
| 166 |
+
def convert_single(
|
| 167 |
+
model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
|
| 168 |
+
) -> ConversionResult:
|
| 169 |
+
pt_filename = hf_hub_download(
|
| 170 |
+
repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
|
| 171 |
+
)
|
| 172 |
|
| 173 |
sf_name = "model.safetensors"
|
| 174 |
sf_filename = os.path.join(folder, sf_name)
|
|
|
|
| 227 |
|
| 228 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
|
| 229 |
try:
|
| 230 |
+
revision_commit = api.model_info(model_id, revision=revision).sha
|
| 231 |
+
discussions = api.get_repo_discussions(repo_id=model_id)
|
| 232 |
except Exception:
|
| 233 |
return None
|
| 234 |
for discussion in discussions:
|
| 235 |
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
|
| 236 |
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
|
| 237 |
|
| 238 |
+
if revision_commit == commits[1].commit_id:
|
| 239 |
return discussion
|
| 240 |
return None
|
| 241 |
|
| 242 |
|
| 243 |
+
def convert_generic(
|
| 244 |
+
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
|
| 245 |
+
) -> ConversionResult:
|
| 246 |
operations = []
|
| 247 |
errors = []
|
| 248 |
|
|
|
|
| 250 |
for filename in filenames:
|
| 251 |
prefix, ext = os.path.splitext(filename)
|
| 252 |
if ext in extensions:
|
| 253 |
+
pt_filename = hf_hub_download(
|
| 254 |
+
model_id, revision=revision, filename=filename, token=token, cache_dir=folder
|
| 255 |
+
)
|
| 256 |
dirname, raw_filename = os.path.split(filename)
|
| 257 |
if raw_filename == "pytorch_model.bin":
|
| 258 |
# XXX: This is a special case to handle `transformers` and the
|
|
|
|
| 269 |
return operations, errors
|
| 270 |
|
| 271 |
|
| 272 |
+
def convert(
|
| 273 |
+
api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
|
| 274 |
+
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
|
| 275 |
pr_title = "Adding `safetensors` variant of this model"
|
| 276 |
info = api.model_info(model_id, revision=revision)
|
| 277 |
filenames = set(s.rfilename for s in info.siblings)
|
|
|
|
| 295 |
|
| 296 |
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
| 297 |
if "pytorch_model.bin" in filenames:
|
| 298 |
+
operations, errors = convert_single(
|
| 299 |
+
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
| 300 |
+
)
|
| 301 |
elif "pytorch_model.bin.index.json" in filenames:
|
| 302 |
+
operations, errors = convert_multi(
|
| 303 |
+
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
| 304 |
+
)
|
| 305 |
else:
|
| 306 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 307 |
else:
|
| 308 |
+
operations, errors = convert_generic(
|
| 309 |
+
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
|
| 310 |
+
)
|
| 311 |
|
| 312 |
if operations:
|
| 313 |
new_pr = api.create_commit(
|