Update convert.py (#14)
Browse files- Update convert.py (53a69cdc8651efe9d80bbc2bfcc4e30ea5aa3c3d)
- convert.py +5 -6
convert.py
CHANGED
|
@@ -13,7 +13,6 @@ from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, h
|
|
| 13 |
from huggingface_hub.file_download import repo_folder_name
|
| 14 |
from safetensors.torch import load_file, save_file
|
| 15 |
from transformers import AutoConfig
|
| 16 |
-
from transformers.pipelines.base import infer_framework_load_model
|
| 17 |
|
| 18 |
|
| 19 |
COMMIT_DESCRIPTION = """
|
|
@@ -72,14 +71,14 @@ def rename(pt_filename: str) -> str:
|
|
| 72 |
|
| 73 |
|
| 74 |
def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 75 |
-
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
|
| 76 |
with open(filename, "r") as f:
|
| 77 |
data = json.load(f)
|
| 78 |
|
| 79 |
filenames = set(data["weight_map"].values())
|
| 80 |
local_filenames = []
|
| 81 |
for filename in filenames:
|
| 82 |
-
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token)
|
| 83 |
|
| 84 |
sf_filename = rename(pt_filename)
|
| 85 |
sf_filename = os.path.join(folder, sf_filename)
|
|
@@ -103,7 +102,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
|
|
| 103 |
|
| 104 |
|
| 105 |
def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 106 |
-
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
|
| 107 |
|
| 108 |
sf_name = "model.safetensors"
|
| 109 |
sf_filename = os.path.join(folder, sf_name)
|
|
@@ -157,7 +156,7 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
|
|
| 157 |
|
| 158 |
|
| 159 |
def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
| 160 |
-
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
|
| 161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 162 |
config = AutoConfig.from_pretrained(folder)
|
| 163 |
|
|
@@ -244,7 +243,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Opti
|
|
| 244 |
for filename in filenames:
|
| 245 |
prefix, ext = os.path.splitext(filename)
|
| 246 |
if ext in extensions:
|
| 247 |
-
pt_filename = hf_hub_download(model_id, filename=filename, token=token)
|
| 248 |
dirname, raw_filename = os.path.split(filename)
|
| 249 |
if raw_filename == "pytorch_model.bin":
|
| 250 |
# XXX: This is a special case to handle `transformers` and the
|
|
|
|
| 13 |
from huggingface_hub.file_download import repo_folder_name
|
| 14 |
from safetensors.torch import load_file, save_file
|
| 15 |
from transformers import AutoConfig
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
COMMIT_DESCRIPTION = """
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 74 |
+
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
|
| 75 |
with open(filename, "r") as f:
|
| 76 |
data = json.load(f)
|
| 77 |
|
| 78 |
filenames = set(data["weight_map"].values())
|
| 79 |
local_filenames = []
|
| 80 |
for filename in filenames:
|
| 81 |
+
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token, cache_dir=folder)
|
| 82 |
|
| 83 |
sf_filename = rename(pt_filename)
|
| 84 |
sf_filename = os.path.join(folder, sf_filename)
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
|
| 105 |
+
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
|
| 106 |
|
| 107 |
sf_name = "model.safetensors"
|
| 108 |
sf_filename = os.path.join(folder, sf_name)
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
| 159 |
+
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder)
|
| 160 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 161 |
config = AutoConfig.from_pretrained(folder)
|
| 162 |
|
|
|
|
| 243 |
for filename in filenames:
|
| 244 |
prefix, ext = os.path.splitext(filename)
|
| 245 |
if ext in extensions:
|
| 246 |
+
pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder)
|
| 247 |
dirname, raw_filename = os.path.split(filename)
|
| 248 |
if raw_filename == "pytorch_model.bin":
|
| 249 |
# XXX: This is a special case to handle `transformers` and the
|