# Spaces: Running on L4
import os
import random

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from seva.model import Seva, SevaParams
def seed_everything(seed: int = 0) -> None:
    """Seed the process-global RNGs and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG (only if NumPy
    is importable), and PyTorch's CPU and CUDA generators. It then forces cuDNN
    into deterministic mode with autotuning (``benchmark``) disabled.

    Args:
        seed: Seed applied to every generator. Defaults to 0.
    """
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:  # NumPy is optional here; nothing to seed without it.
        pass
    else:
        # np.random.seed only accepts 0 <= seed < 2**32, while torch accepts a
        # wider range; reduce so callers passing large/negative seeds still work.
        np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    # Per the PyTorch docs, these are safe to call when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print a report of state-dict keys that did not match the model.

    Each non-empty list gets its own report: a count line followed by one
    tab-indented key per line. If both lists are non-empty, a rule of 79 dashes
    separates the two reports. Nothing is printed when both lists are empty.

    Args:
        missing: Keys reported as missing by ``load_state_dict``.
        unexpected: Keys reported as unexpected by ``load_state_dict``.
    """
    has_missing = len(missing) > 0
    if has_missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(unexpected) > 0:
        if has_missing:
            # Only needed when a missing-keys report was printed just above.
            print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_model(
    pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera",
    weight_name: str = "model.safetensors",
    device: str | torch.device = "cuda",
    verbose: bool = False,
) -> Seva:
    """Instantiate ``Seva`` and populate it from a safetensors checkpoint.

    Args:
        pretrained_model_name_or_path: An existing local directory containing
            ``weight_name``. Otherwise it is treated as a Hugging Face Hub repo
            id, and ``weight_name`` plus ``config.yaml`` are fetched with
            ``hf_hub_download``.
        weight_name: File name of the safetensors checkpoint.
        device: Device the checkpoint tensors are loaded onto; it is passed to
            ``safetensors.torch.load_file`` as ``str(device)``.
        verbose: When true, print any missing / unexpected state-dict keys.

    Returns:
        A ``Seva`` built from default ``SevaParams()`` with the checkpoint
        tensors assigned into it.
    """
    source = pretrained_model_name_or_path
    if not os.path.isdir(source):
        ckpt_file = hf_hub_download(repo_id=source, filename=weight_name)
        # The returned path is discarded; presumably this only ensures
        # config.yaml is present in the local HF cache (intent not shown here).
        hf_hub_download(repo_id=source, filename="config.yaml")
    else:
        ckpt_file = os.path.join(source, weight_name)

    tensors = safetensors.torch.load_file(ckpt_file, device=str(device))

    # Build the module on the meta device so its initial weights get no real
    # storage; assign=True below installs the loaded tensors directly.
    with torch.device("meta"):
        model = Seva(SevaParams()).to(torch.bfloat16)
    # NOTE(review): with strict=False, keys absent from the checkpoint do not
    # raise, and such entries presumably remain meta tensors. Per the PyTorch
    # docs, assign=True keeps the checkpoint tensors' properties (e.g. dtype),
    # so the bfloat16 cast above may not decide the final dtype. Confirm both.
    missing_keys, unexpected_keys = model.load_state_dict(
        tensors, strict=False, assign=True
    )
    if verbose:
        print_load_warning(missing_keys, unexpected_keys)
    return model