Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import PIL.Image | |
| import numpy as np | |
| from torch import nn | |
| import torch.distributed as dist | |
| import timm.models.hub as timm_hub | |
| """Modified from https://github.com/CompVis/taming-transformers.git""" | |
| import hashlib | |
| import requests | |
| from tqdm import tqdm | |
| try: | |
| import piq | |
| except: | |
| pass | |
# Module-level context-parallel state. Both stay None until
# set_context_parallel_group() or initialize_context_parallel() assigns them.
# Process group handle this rank uses for context-parallel communication.
_CONTEXT_PARALLEL_GROUP = None
# Number of ranks in each context-parallel group.
_CONTEXT_PARALLEL_SIZE = None
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and already initialized."""
    # Short-circuit keeps is_initialized() from being queried when the
    # distributed package is not available at all.
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when the calling process is rank 0."""
    current_rank = get_rank()
    return current_rank == 0
def is_context_parallel_initialized():
    """Return True once a context-parallel group has been stored in this module."""
    return _CONTEXT_PARALLEL_GROUP is not None
def set_context_parallel_group(size, group):
    """Store an externally created context-parallel group and its size.

    Parameters
    ----------
    size : value stored as the module's context-parallel size.
    group : value stored as the module's context-parallel group.
    """
    global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_SIZE
    _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_SIZE = group, size
def initialize_context_parallel(context_parallel_size):
    """Split the world into consecutive context-parallel groups.

    Ranks ``[0, size)``, ``[size, 2 * size)``, ... each form one group; the
    group containing the calling rank is stored in the module state together
    with ``context_parallel_size``.

    Parameters
    ----------
    context_parallel_size : int
        Number of ranks per context-parallel group. Must be positive and
        divide the world size evenly.

    Raises
    ------
    AssertionError
        If a context-parallel group has already been initialized.
    ValueError
        If ``context_parallel_size`` is not positive or does not divide the
        world size.
    """
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE
    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    if context_parallel_size <= 0 or world_size % context_parallel_size != 0:
        raise ValueError(
            "context_parallel_size ({}) must be positive and divide the world size ({})".format(
                context_parallel_size, world_size
            )
        )
    _CONTEXT_PARALLEL_SIZE = context_parallel_size

    # Every process must call new_group() for every group, in the same order,
    # even for groups it does not belong to. Do not stop early once this
    # rank's own group has been found.
    for start in range(0, world_size, context_parallel_size):
        ranks = list(range(start, start + context_parallel_size))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
def get_context_parallel_group():
    """Return the stored context-parallel group; asserts that one has been set."""
    group = _CONTEXT_PARALLEL_GROUP
    assert group is not None, "context parallel group is not initialized"
    return group
def get_context_parallel_world_size():
    """Return the stored context-parallel size; asserts that one has been set."""
    cp_size = _CONTEXT_PARALLEL_SIZE
    assert cp_size is not None, "context parallel size is not initialized"
    return cp_size
def get_context_parallel_rank():
    """Return this rank's position inside its context-parallel group.

    Computed as the global rank modulo the context-parallel size.
    """
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    return get_rank() % _CONTEXT_PARALLEL_SIZE
def get_context_parallel_group_rank():
    """Return the index of the context-parallel group this rank falls into.

    Computed as the global rank floor-divided by the context-parallel size.
    """
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    return get_rank() // _CONTEXT_PARALLEL_SIZE
def download_cached_file(url, check_hash=True, progress=False):
    """Fetch ``url`` into the timm cache directory and return the local path.

    Only the main process calls timm's downloader; when torch.distributed is
    initialized, every process then waits on a barrier before the path is
    returned.

    Parameters
    ----------
    url : str
        Location of the file to download.
    check_hash : bool
        Forwarded to ``timm_hub.download_cached_file``.
    progress : bool
        Forwarded to ``timm_hub.download_cached_file``.

    Returns
    -------
    str
        ``<timm cache dir>/<basename of the URL path>``.
    """
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    # Each process rebuilds the path locally instead of receiving it from the
    # main process.
    url_path = torch.hub.urlparse(url).path
    return os.path.join(timm_hub.get_cache_dir(), os.path.basename(url_path))
def convert_weights_to_fp16(model: nn.Module):
    """Cast the weights and biases of conv/linear layers in ``model`` to float16.

    The conversion happens in place on every Conv1d/Conv2d/Conv3d/Linear
    submodule; other module types are left untouched.
    """
    castable = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

    def _cast(module):
        if not isinstance(module, castable):
            return
        module.weight.data = module.weight.data.to(torch.float16)
        bias = module.bias
        if bias is not None:
            bias.data = bias.data.to(torch.float16)

    model.apply(_cast)
def convert_weights_to_bf16(model: nn.Module):
    """Cast the weights and biases of conv/linear layers in ``model`` to bfloat16.

    The conversion happens in place on every Conv1d/Conv2d/Conv3d/Linear
    submodule; other module types are left untouched.
    """
    castable = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

    def _cast(module):
        if not isinstance(module, castable):
            return
        module.weight.data = module.weight.data.to(torch.bfloat16)
        bias = module.bias
        if bias is not None:
            bias.data = bias.data.to(torch.bfloat16)

    model.apply(_cast)
def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
    """Write per-rank results to disk and merge them on the main process.

    Every rank dumps ``result`` to ``<result_dir>/<filename>_rank<rank>.json``.
    After a barrier, the main process concatenates all per-rank files,
    optionally de-duplicates the entries, and writes the merged list to
    ``<result_dir>/<filename>.<save_format>``.

    Parameters
    ----------
    result : list
        JSON-serializable entries produced by this rank.
    result_dir : str
        Directory that receives the per-rank and merged files.
    filename : str
        Base name (without extension) of the output files.
    remove_duplicate : str
        If non-empty, the key whose value identifies an entry; only the first
        entry per value is kept.
    save_format : str
        ``'json'`` or ``'jsonl'`` (the latter requires ``jsonlines``).

    Returns
    -------
    str
        Path of the merged result file (returned on every rank).
    """
    import json

    print("Dump result")

    # The main process creates the directory; exist_ok avoids failing when it
    # already exists. The barrier is unconditional so that every rank takes
    # part in it, regardless of when it observed the directory.
    if is_main_process():
        os.makedirs(result_dir, exist_ok=True)
    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    result_file = os.path.join(
        result_dir, "%s_rank%d.json" % (filename, get_rank())
    )
    final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")

    # Close (and therefore flush) the per-rank file before the barrier, so the
    # main process never reads a partially written file.
    with open(result_file, "w") as f:
        json.dump(result, f)

    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    if is_main_process():
        # combine results from all processes
        merged = []
        for rank in range(get_world_size()):
            rank_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
            with open(rank_file, "r") as f:
                merged += json.load(f)

        if remove_duplicate:
            seen_ids = set()
            unique = []
            for entry in merged:
                if entry[remove_duplicate] not in seen_ids:
                    seen_ids.add(entry[remove_duplicate])
                    unique.append(entry)
            merged = unique

        if save_format == 'json':
            with open(final_result_file, "w") as f:
                json.dump(merged, f)
        else:
            assert save_format == 'jsonl', "Only support json adn jsonl format"
            # Imported lazily so plain JSON output does not need jsonlines.
            import jsonlines
            with jsonlines.open(final_result_file, "w") as writer:
                writer.write_all(merged)

    return final_result_file
| # resizing utils | |
| # TODO: clean up later | |
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Gaussian-blur ``input`` and then resize it to ``size``.

    Parameters
    ----------
    input : tensor whose last two dimensions are height and width.
    size : pair ``(out_h, out_w)`` passed to ``interpolate``.
    interpolation : interpolation mode passed to ``interpolate``.
    align_corners : forwarded to ``interpolate``.

    Returns
    -------
    The blurred and resized tensor.
    """
    src_h, src_w = input.shape[-2:]
    scale_h = src_h / size[0]
    scale_w = src_w / size[1]

    # Sigma per axis, following skimage:
    # https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = tuple(max((scale - 1.0) / 2.0, 0.001) for scale in (scale_h, scale_w))

    def _odd_kernel_size(sigma):
        # 3 sigma gives good results but is slow; Pillow uses 1 sigma in two
        # passes (https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206).
        # 2 sigma is used here, with a minimum of 3 and forced to be odd.
        k = int(max(2.0 * 2 * sigma, 3))
        return k + 1 if k % 2 == 0 else k

    ks = (_odd_kernel_size(sigmas[0]), _odd_kernel_size(sigmas[1]))

    blurred = _gaussian_blur2d(input, ks, sigmas)
    return torch.nn.functional.interpolate(
        blurred, size=size, mode=interpolation, align_corners=align_corners
    )
| def _compute_padding(kernel_size): | |
| """Compute padding tuple.""" | |
| # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) | |
| # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad | |
| if len(kernel_size) < 2: | |
| raise AssertionError(kernel_size) | |
| computed = [k - 1 for k in kernel_size] | |
| # for even kernels we need to do asymmetric padding :( | |
| out_padding = 2 * len(kernel_size) * [0] | |
| for i in range(len(kernel_size)): | |
| computed_tmp = computed[-(i + 1)] | |
| pad_front = computed_tmp // 2 | |
| pad_rear = computed_tmp - pad_front | |
| out_padding[2 * i + 0] = pad_front | |
| out_padding[2 * i + 1] = pad_rear | |
| return out_padding | |
def _filter2d(input, kernel):
    """Convolve each channel of ``input`` with ``kernel``, keeping the spatial size.

    ``input`` is unpacked as (b, c, h, w). It is reflect-padded so that the
    grouped convolution returns a tensor of the same (b, c, h, w) shape.
    """
    b, c, h, w = input.shape

    # One kernel copy per channel, on the input's device and dtype.
    weights = kernel[:, None, ...].to(device=input.device, dtype=input.dtype).expand(-1, c, -1, -1)
    k_h, k_w = weights.shape[-2:]

    padded = torch.nn.functional.pad(input, _compute_padding([k_h, k_w]), mode="reflect")

    # Reshape kernel and input so that the grouped conv applies one kernel
    # per (batch-kernel, channel) slice.
    weights = weights.reshape(-1, 1, k_h, k_w)
    n_groups = weights.size(0)
    padded = padded.view(-1, n_groups, padded.size(-2), padded.size(-1))

    filtered = torch.nn.functional.conv2d(padded, weights, groups=n_groups, padding=0, stride=1)
    return filtered.view(b, c, h, w)
| def _gaussian(window_size: int, sigma): | |
| if isinstance(sigma, float): | |
| sigma = torch.tensor([[sigma]]) | |
| batch_size = sigma.shape[0] | |
| x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) | |
| if window_size % 2 == 0: | |
| x = x + 0.5 | |
| gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) | |
| return gauss / gauss.sum(-1, keepdim=True) | |
def _gaussian_blur2d(input, kernel_size, sigma):
    """Blur ``input`` with a separable Gaussian filter.

    Parameters
    ----------
    input : tensor to blur (passed to ``_filter2d``).
    kernel_size : pair ``(ky, kx)`` of kernel lengths.
    sigma : tuple ``(sigma_y, sigma_x)`` or a tensor with a batch dimension
        first and the two sigmas in its second dimension.
    """
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    k_rows, k_cols = int(kernel_size[0]), int(kernel_size[1])
    batch = sigma.shape[0]

    horizontal = _gaussian(k_cols, sigma[:, 1].view(batch, 1))
    vertical = _gaussian(k_rows, sigma[:, 0].view(batch, 1))

    # Separable blur: filter along the width first, then along the height.
    blurred_rows = _filter2d(input, horizontal[..., None, :])
    return _filter2d(blurred_rows, vertical[..., None])
# The three tables below are keyed by checkpoint name and are read together
# by get_ckpt_path().

# Download URL of each checkpoint.
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
# File name under the caller-supplied root directory.
CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}
# Expected MD5 hex digest of the downloaded file.
MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    Parameters
    ----------
    url : str
        Address to fetch with ``requests.get``.
    local_path : str
        Destination file; its parent directory is created if needed.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status, so that an error page is
        never written to disk as if it were the requested file.
    """
    target_dir = os.path.split(local_path)[0]
    # A bare file name has no directory part, and os.makedirs("") would fail.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed in chunks of ``chunk_size`` bytes so that large
    checkpoint files are never loaded into memory in full.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    Parameters
    ----------
    name : str
        Key into URL_MAP / CKPT_MAP / MD5_MAP.
    root : str
        Directory in which the checkpoint file is stored.
    check : bool
        When True, an existing file whose MD5 digest does not match MD5_MAP
        is downloaded again.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    AssertionError
        If ``name`` is unknown, or a freshly downloaded file has the wrong
        MD5 digest.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    # The file is hashed only after the existence test has passed. Hashing a
    # missing file would raise FileNotFoundError exactly when a download is
    # required.
    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved in a nested container.

    Attributes set on the instance: ``cause`` (the underlying error),
    ``keys`` (the full key path) and ``visited`` (the keys resolved so far).
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause, self.keys, self.visited = cause, keys, visited

        # Optional context lines first; the cause is always reported last.
        lines = []
        if keys is not None:
            lines.append(f"Key not found: {keys}")
        if visited is not None:
            lines.append(f"Visited: {visited}")
        lines.append(f"Cause:\n{cause}")

        super().__init__("\n".join(lines))
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default": the lookup error is raised instead.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return a ``(value, success)`` tuple, where ``success``
        is ``False`` when ``default`` was substituted.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``. With
    :attr:`pass_success` the value is paired with the success flag.

    Raises
    ------
    KeyNotFoundError if ``key`` not in ``list_or_dict`` and :attr:`default`
    is ``None``.
    """
    keys = key.split(splitval)

    success = True
    try:
        visited = []
        # Container and key of the previous step, used to write the result of
        # an expanded callable back into the structure.
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                # Replace the callable node by its result, in place.
                # NOTE(review): if the root object itself is callable, parent
                # is still None here and this assignment raises TypeError.
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                # Dicts are indexed with the key string; anything else is
                # indexed with the key converted to an integer.
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        # None doubles as the "no default" marker, so None cannot be
        # requested as a fallback value.
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success