Spaces:
Runtime error
Commit
·
c6c5536
1
Parent(s):
38707b6
Update convert.py
Browse files
- convert.py +1 -53
convert.py
CHANGED
|
@@ -133,57 +133,6 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
|
|
| 133 |
errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
|
| 134 |
return "\n".join(errors)
|
| 135 |
|
| 136 |
-
|
| 137 |
-
def check_final_model(model_id: str, folder: str):
|
| 138 |
-
config = hf_hub_download(repo_id=model_id, filename="config.json")
|
| 139 |
-
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 140 |
-
config = AutoConfig.from_pretrained(folder)
|
| 141 |
-
|
| 142 |
-
_, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
|
| 143 |
-
_, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
|
| 144 |
-
|
| 145 |
-
if pt_infos != sf_infos:
|
| 146 |
-
error_string = create_diff(pt_infos, sf_infos)
|
| 147 |
-
raise ValueError(f"Different infos when reloading the model: {error_string}")
|
| 148 |
-
|
| 149 |
-
pt_params = pt_model.state_dict()
|
| 150 |
-
sf_params = sf_model.state_dict()
|
| 151 |
-
|
| 152 |
-
pt_shared = shared_pointers(pt_params)
|
| 153 |
-
sf_shared = shared_pointers(sf_params)
|
| 154 |
-
if pt_shared != sf_shared:
|
| 155 |
-
raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}")
|
| 156 |
-
|
| 157 |
-
sig = signature(pt_model.forward)
|
| 158 |
-
input_ids = torch.arange(10).unsqueeze(0)
|
| 159 |
-
pixel_values = torch.randn(1, 3, 224, 224)
|
| 160 |
-
input_values = torch.arange(1000).float().unsqueeze(0)
|
| 161 |
-
kwargs = {}
|
| 162 |
-
if "input_ids" in sig.parameters:
|
| 163 |
-
kwargs["input_ids"] = input_ids
|
| 164 |
-
if "decoder_input_ids" in sig.parameters:
|
| 165 |
-
kwargs["decoder_input_ids"] = input_ids
|
| 166 |
-
if "pixel_values" in sig.parameters:
|
| 167 |
-
kwargs["pixel_values"] = pixel_values
|
| 168 |
-
if "input_values" in sig.parameters:
|
| 169 |
-
kwargs["input_values"] = input_values
|
| 170 |
-
if "bbox" in sig.parameters:
|
| 171 |
-
kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
|
| 172 |
-
if "image" in sig.parameters:
|
| 173 |
-
kwargs["image"] = pixel_values
|
| 174 |
-
|
| 175 |
-
if torch.cuda.is_available():
|
| 176 |
-
pt_model = pt_model.cuda()
|
| 177 |
-
sf_model = sf_model.cuda()
|
| 178 |
-
kwargs = {k: v.cuda() for k, v in kwargs.items()}
|
| 179 |
-
|
| 180 |
-
pt_logits = pt_model(**kwargs)[0]
|
| 181 |
-
sf_logits = sf_model(**kwargs)[0]
|
| 182 |
-
|
| 183 |
-
torch.testing.assert_close(sf_logits, pt_logits)
|
| 184 |
-
print(f"Model {model_id} is ok !")
|
| 185 |
-
|
| 186 |
-
|
| 187 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
| 188 |
try:
|
| 189 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
|
@@ -218,7 +167,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
| 218 |
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
| 219 |
pr_title = "Adding `safetensors` variant of this model"
|
| 220 |
info = api.model_info(model_id)
|
| 221 |
-
filenames = set(s.rfilename for s in info.siblings)
|
| 222 |
|
| 223 |
with TemporaryDirectory() as d:
|
| 224 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
|
@@ -242,7 +191,6 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
| 242 |
operations = convert_multi(model_id, folder)
|
| 243 |
else:
|
| 244 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 245 |
-
check_final_model(model_id, folder)
|
| 246 |
else:
|
| 247 |
operations = convert_generic(model_id, folder, filenames)
|
| 248 |
|
|
|
|
| 133 |
errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
|
| 134 |
return "\n".join(errors)
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
| 137 |
try:
|
| 138 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
|
|
|
| 167 |
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
| 168 |
pr_title = "Adding `safetensors` variant of this model"
|
| 169 |
info = api.model_info(model_id)
|
| 170 |
+
filenames = set(s.rfilename for s in info.siblings if len(s.rfilename.split("/")) > 1)
|
| 171 |
|
| 172 |
with TemporaryDirectory() as d:
|
| 173 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
|
|
|
| 191 |
operations = convert_multi(model_id, folder)
|
| 192 |
else:
|
| 193 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
|
|
|
| 194 |
else:
|
| 195 |
operations = convert_generic(model_id, folder, filenames)
|
| 196 |
|