Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| from inspect import isfunction | |
| import torch | |
| from torch.nn.utils.rnn import pad_sequence | |
| from scepter.modules.utils.distribute import we | |
| def exists(x): | |
| return x is not None | |
| def default(val, d): | |
| if exists(val): | |
| return val | |
| return d() if isfunction(d) else d | |
| def disabled_train(self, mode=True): | |
| """Overwrite model.train with this function to make sure train/eval mode | |
| does not change anymore.""" | |
| return self | |
| def transfer_size(para_num): | |
| if para_num > 1000 * 1000 * 1000 * 1000: | |
| bill = para_num / (1000 * 1000 * 1000 * 1000) | |
| return '{:.2f}T'.format(bill) | |
| elif para_num > 1000 * 1000 * 1000: | |
| gyte = para_num / (1000 * 1000 * 1000) | |
| return '{:.2f}B'.format(gyte) | |
| elif para_num > (1000 * 1000): | |
| meta = para_num / (1000 * 1000) | |
| return '{:.2f}M'.format(meta) | |
| elif para_num > 1000: | |
| kelo = para_num / 1000 | |
| return '{:.2f}K'.format(kelo) | |
| else: | |
| return para_num | |
| def count_params(model): | |
| total_params = sum(p.numel() for p in model.parameters()) | |
| return transfer_size(total_params) | |
| def expand_dims_like(x, y): | |
| while x.dim() != y.dim(): | |
| x = x.unsqueeze(-1) | |
| return x | |
| def unpack_tensor_into_imagelist(image_tensor, shapes): | |
| image_list = [] | |
| for img, shape in zip(image_tensor, shapes): | |
| h, w = shape[0], shape[1] | |
| image_list.append(img[:, :h * w].view(1, -1, h, w)) | |
| return image_list | |
| def find_example(tensor_list, image_list): | |
| for i in tensor_list: | |
| if isinstance(i, torch.Tensor): | |
| return torch.zeros_like(i) | |
| for i in image_list: | |
| if isinstance(i, torch.Tensor): | |
| _, c, h, w = i.size() | |
| return torch.zeros_like(i.view(c, h * w).transpose(1, 0)) | |
| return None | |
| def pack_imagelist_into_tensor_v2(image_list): | |
| # allow None | |
| example = None | |
| image_tensor, shapes = [], [] | |
| for img in image_list: | |
| if img is None: | |
| example = find_example(image_tensor, | |
| image_list) if example is None else example | |
| image_tensor.append(example) | |
| shapes.append(None) | |
| continue | |
| _, c, h, w = img.size() | |
| image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c | |
| shapes.append((h, w)) | |
| image_tensor = pad_sequence(image_tensor, | |
| batch_first=True).permute(0, 2, 1) # b, c, l | |
| return image_tensor, shapes | |
| def to_device(inputs, strict=True): | |
| if inputs is None: | |
| return None | |
| if strict: | |
| assert all(isinstance(i, torch.Tensor) for i in inputs) | |
| return [i.to(we.device_id) if i is not None else None for i in inputs] | |
| def check_list_of_list(ll): | |
| return isinstance(ll, list) and all(isinstance(i, list) for i in ll) | |