Spaces:
Runtime error
Runtime error
| import glob | |
| import os | |
| import subprocess | |
# Git executable used for every repository command in this module. The GIT
# environment variable overrides the default, which is plain "git".
git = os.getenv('GIT', "git")
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    """Run ``command`` through the shell and return its decoded stdout.

    @param command: Shell command line to execute (run with ``shell=True``).
    @param desc: Optional message printed before the command starts.
    @param errdesc: Optional first line of the error message on failure;
        defaults to 'Error running command'.
    @param custom_env: Environment mapping for the child process; the current
        ``os.environ`` is used when this is None.
    @param live: When true, the child's output is not captured and an empty
        string is returned on success.
    @return: The command's stdout decoded as UTF-8 (undecodable bytes are
        dropped), or "" in live mode.
    @raise RuntimeError: If the command exits with a non-zero status.
    """
    if desc is not None:
        print(desc)

    env = os.environ if custom_env is None else custom_env

    # Only the non-live mode pipes stdout/stderr back to us.
    capture = {} if live else {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE}
    proc = subprocess.run(command, shell=True, env=env, **capture)

    if proc.returncode == 0:
        if live:
            return ""
        return proc.stdout.decode(encoding="utf8", errors="ignore")

    report = [
        f"{errdesc or 'Error running command'}.",
        f"Command: {command}",
        f"Error code: {proc.returncode}",
    ]
    if not live:
        for label, raw in (("stdout", proc.stdout), ("stderr", proc.stderr)):
            shown = raw.decode(encoding="utf8", errors="ignore") if len(raw) > 0 else '<empty>'
            report.append(f"{label}: {shown}")
        # The captured-mode message ends with a trailing newline.
        report.append("")

    raise RuntimeError("\n".join(report))
def git_clone(url, dir, name, commithash=None):
    """Clone a repository, or move an existing checkout to a given commit.

    If ``dir`` already exists nothing is cloned. When ``commithash`` is given
    and differs from the checkout's current HEAD, the repository is fetched
    and that commit is checked out. If ``dir`` does not exist, ``url`` is
    cloned into it and ``commithash`` (if given) is checked out afterwards.

    @param url: Repository URL passed to ``git clone``.
    @param dir: Target directory of the checkout. (The name shadows the
        builtin but is kept because callers may pass it by keyword.)
    @param name: Human-readable repository name used in progress and error
        messages.
    @param commithash: Optional commit to check out; None leaves the checkout
        at whatever it already is / whatever the clone produced.
    @raise RuntimeError: Propagated from ``run`` when a git command fails.
    """
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        # The error description must be an f-string; without the prefix the
        # placeholders were reported literally as "{name}" / "{commithash}".
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and-done loader that tries to find the desired models in the specified directories.

    Directories are searched recursively. If ``command_path`` is given and
    differs from ``model_path``, its ``experiments/pretrained_models``
    sub-directory (or, failing that, ``command_path`` itself) is searched
    before ``model_path``.

    @param download_name: Specify to download from model_url immediately when no local model is found.
    @param model_url: Fallback used only when no models are found: downloaded right away
        if download_name is given, otherwise the URL itself is returned in the list.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line argument to search for models in first.
    @param ext_filter: An optional list of filename extensions (including the dot) to filter by.
    @param ext_blacklist: An optional list of filename suffixes; files ending with any of them are skipped.
    @return: A list of paths containing the desired model(s). Errors are reported
        and whatever was collected so far is returned.
    """
    output = []
    seen = set()  # O(1) duplicate detection; `output` keeps discovery order

    if ext_filter is None:
        ext_filter = []

    # str.endswith accepts a tuple, so one call replaces the per-suffix loop.
    blacklist = tuple(ext_blacklist) if ext_blacklist is not None else ()

    try:
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            if not os.path.exists(place):
                continue

            # Join with a real path separator: plain concatenation
            # (place + '**/**') also matched sibling directories that merely
            # share the prefix. Escape the directory so glob metacharacters
            # in its name are taken literally.
            pattern = os.path.join(glob.escape(place), '**', '*')

            for full_path in glob.iglob(pattern, recursive=True):
                if os.path.isdir(full_path):
                    continue
                if os.path.islink(full_path) and not os.path.exists(full_path):
                    print(f"Skipping broken symlink: {full_path}")
                    continue
                if blacklist and full_path.endswith(blacklist):
                    continue
                if len(ext_filter) != 0:
                    extension = os.path.splitext(full_path)[1]
                    if extension not in ext_filter:
                        continue
                if full_path not in seen:
                    seen.add(full_path)
                    output.append(full_path)

        if model_url is not None and len(output) == 0:
            if download_name is not None:
                # Imported lazily so the dependency is only needed when a
                # download is actually requested.
                from basicsr.utils.download_util import load_file_from_url
                dl = load_file_from_url(model_url, model_path, True, download_name)
                output.append(dl)
            else:
                output.append(model_url)

    except Exception as error:
        # Loading stays best-effort (callers still get a list), but the
        # failure is reported instead of being silently discarded.
        print(f"Error while loading models from {model_path}: {error}")

    return output