Spaces:
Paused
Paused
| import os | |
| import json | |
| import random | |
| import glob | |
| import torch | |
| import einops | |
| import torchvision | |
| import safetensors.torch as sf | |
def write_to_json(data, file_path):
    """Write *data* as indented (4-space) UTF-8 JSON to *file_path*, replacing it.

    The JSON is first written to ``<file_path>.tmp`` and then moved over
    *file_path* with ``os.replace``. That rename is atomic on POSIX when both
    paths are on the same filesystem, so readers see either the old file or
    the complete new one.

    If serialization or the write fails, the temporary file is removed and
    the original exception propagates. In that case an existing *file_path*
    is left untouched.

    Args:
        data: JSON-serializable object.
        file_path: Destination path (``str`` or ``os.PathLike``).
    """
    # os.fspath lets pathlib.Path work too; plain str concatenation would not.
    temp_file_path = os.fspath(file_path) + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
        os.replace(temp_file_path, file_path)
    except BaseException:
        # json.dump may already have written a partial document; don't leave
        # a stale ".tmp" file behind.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return
def read_from_json(file_path):
    """Parse the UTF-8 JSON file at *file_path* and return the decoded object."""
    with open(file_path, mode='rt', encoding='utf-8') as handle:
        return json.load(handle)
def get_active_parameters(m):
    """Return ``{name: parameter}`` for every parameter of *m* that requires grad.

    Names are the dotted names yielded by ``m.named_parameters()``; frozen
    parameters (``requires_grad=False``) are skipped.
    """
    active = {}
    for name, param in m.named_parameters():
        if not param.requires_grad:
            continue
        active[name] = param
    return active
def cast_training_params(m, dtype=torch.float32):
    """Cast the trainable parameters of *m* to *dtype*, in place.

    Only parameters with ``requires_grad=True`` are converted; frozen ones
    keep their dtype. The conversion reassigns ``param.data``, so the
    Parameter objects registered on the module stay the same objects.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    for param in trainable:
        param.data = param.to(dtype)
    return
def set_attr_recursive(obj, attr, value):
    """Assign *value* to a nested attribute named by a dotted path.

    Every component but the last is resolved with ``getattr`` starting at
    *obj*, and the last one is set with ``setattr``. For example, with
    ``"a.b.c"`` this performs ``obj.a.b.c = value``. A missing intermediate
    attribute raises ``AttributeError``.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
    return
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
    """Pick each batch sample from *a* or *b* according to a per-sample mask.

    Args:
        a, b: Tensors with identical shapes; dim 0 is treated as the batch.
        probability_a: When *mask_a* is None, each sample independently comes
            from *a* if ``torch.rand`` < probability_a.
        mask_a: Optional per-sample selector with ``a.size(0)`` elements. It
            is moved to ``a.device`` and reshaped to broadcast over all
            trailing dims. NOTE(review): it is fed directly to
            ``torch.where``, so presumably it must be a bool tensor.

    Returns:
        A tensor shaped like *a*, holding *a*'s sample where the mask is
        true and *b*'s sample elsewhere.
    """
    assert a.shape == b.shape, "Tensors must have the same shape"
    n = a.size(0)
    if mask_a is None:
        mask_a = torch.rand(n) < probability_a
    # One mask entry per sample, broadcast across every non-batch dimension.
    broadcast_shape = (n,) + (1,) * (a.dim() - 1)
    selector = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(selector, a, b)
def zero_module(module):
    """Overwrite every parameter of *module* with zeros, in place.

    The writes are made without autograd tracking, and only
    ``module.parameters()`` is touched, not buffers. Returns the same
    *module* so the call can wrap a constructor expression.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def load_last_state(model, folder='accelerator_output'):
    """Load the newest ``model.safetensors`` found under *folder* into *model*.

    *folder* is searched recursively; ``**`` also matches zero directories,
    so *folder* itself is included. The file with the latest modification
    time is loaded with ``load_state_dict(..., strict=False)``, and any
    missing or unexpected keys are printed.

    If no checkpoint is found, a message is printed and *model* is left
    unchanged. Always returns None.
    """
    pattern = os.path.join(folder, '**', 'model.safetensors')
    candidates = glob.glob(pattern, recursive=True)
    if not candidates:
        print("No model.safetensors files found in the specified folder.")
        return

    latest = max(candidates, key=os.path.getmtime)
    missing, unexpected = model.load_state_dict(sf.load_file(latest), strict=False)

    # Report key mismatches (missing first, then unexpected) only if present.
    for label, keys in (("Missing keys:", missing), ("Unexpected keys:", unexpected)):
        if keys:
            print(label, keys)

    print("Loaded model state from:", latest)
    return
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Return a prompt built from a random subset of ``', '``-separated tags.

    A target count is drawn with ``random.randint(min_length, max_length)``
    (inclusive) and capped at the number of tags. That many tags are then
    drawn without replacement with ``random.sample``, in random order, and
    re-joined with ``', '``.
    """
    all_tags = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    chosen = random.sample(all_tags, k=min(wanted, len(all_tags)))
    return ', '.join(chosen)
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Tile a batch of clips into one grid video and write it as H.264 MP4.

    Args:
        x: 5-D tensor unpacked as (b, c, t, h, w). Values are clamped to
            [-1, 1], mapped to [0, 255] and truncated to uint8.
            NOTE(review): write_video receives (t, H, W, c) frames, so c is
            presumably expected to be 3 — confirm with callers.
        output_filename: Destination path; missing parent directories are
            created.
        fps: Frame rate passed to ``torchvision.io.write_video``.

    The grid uses the first of 6, 5, 4, 3, 2 that divides b as the number of
    columns; otherwise all b samples go in one row. The video is encoded with
    ``crf='0'``.

    Returns:
        The uint8 tensor of shape (t, rows*h, cols*w, c) that was written.
    """
    b, _, _, _, _ = x.shape
    per_row = next((p for p in (6, 5, 4, 3, 2) if b % p == 0), b)

    out_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(out_dir, exist_ok=True)

    scaled = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = scaled.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(pixels, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='h264', options={'crf': '0'})
    return frames
def save_bcthw_as_png(x, output_filename):
    """Write a (b, c, t, h, w) tensor to a single PNG laid out as a grid.

    Batch items are stacked top-to-bottom and frames left-to-right. Values
    are clamped to [-1, 1], mapped to [0, 255] and truncated to uint8.
    Missing parent directories of *output_filename* are created.

    Returns:
        *output_filename*, unchanged.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    scaled = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    as_bytes = scaled.detach().cpu().to(torch.uint8)
    image = einops.rearrange(as_bytes, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(image, output_filename)
    return output_filename
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors of equal rank, zero-padding each to their common shape.

    If the shapes match, this is just ``tensor1 + tensor2``. Otherwise the
    result takes, per dimension, the larger of the two sizes. Each input is
    placed at the leading (index 0) corner and positions it does not cover
    count as zero.

    The padded result has the same promoted dtype that ``tensor1 + tensor2``
    would have and is allocated on ``tensor1``'s device.

    Raises:
        ValueError: If the tensors have different numbers of dimensions.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2
    if tensor1.dim() != tensor2.dim():
        # zip() would silently truncate the longer shape and fail later with
        # an obscure indexing error.
        raise ValueError(
            "Tensors must have the same number of dimensions, got "
            f"{tuple(tensor1.shape)} and {tuple(tensor2.shape)}"
        )
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape))
    # Use the dtype/device `tensor1 + tensor2` would use. A bare torch.zeros()
    # would be CPU float32: GPU inputs would move to CPU and int/float64 inputs
    # would be cast to float32.
    result = tensor1.new_zeros(new_shape, dtype=torch.result_type(tensor1, tensor2))
    result[tuple(slice(0, s) for s in tensor1.shape)] = tensor1
    result[tuple(slice(0, s) for s in tensor2.shape)].add_(tensor2)
    return result