import os
import fnmatch
import itertools
import random

import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from diffusers.optimization import get_scheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf

from dataset import *
from models.unet.motion_embeddings import *
from .lora import *
from .lora_handler import *
def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')):
    """Recursively collect all video files under `directory` with a matching extension."""
    video_files = []
    for root, dirs, files in os.walk(directory):
        for extension in extensions:
            for filename in fnmatch.filter(files, '*' + extension):
                video_files.append(os.path.join(root, filename))
    return video_files
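
# Usage sketch for find_videos (the directory path below is hypothetical):
#   video_paths = find_videos('./data/train')
#   # -> ['./data/train/clip_001.mp4', './data/train/sub/clip_002.gif', ...]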
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model with its training condition for create_optimizer_params."""
    # Guard against extra_params being None (the default) before checking its length.
    extra_params = extra_params if extra_params and len(extra_params) > 0 else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation
    }
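
# Sketch of the dict param_optim produces; `my_unet` is a placeholder module:
#   entry = param_optim(my_unet, condition=True, extra_params={'lr': 1e-4})
#   # -> {'model': my_unet, 'condition': True, 'extra_params': {'lr': 1e-4},
#   #     'is_lora': False, 'negation': None}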
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    params = {
        "name": name,
        "params": params,
        "lr": lr
    }
    if extra_params is not None:
        params.update(extra_params)
    return params
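
# Sketch: build a single optimizer parameter group (names are illustrative):
#   group = create_optim_params('temporal_emb', some_params, lr=1e-4,
#                               extra_params={'weight_decay': 0.0})
#   # -> {'name': 'temporal_emb', 'params': some_params, 'lr': 1e-4,
#   #     'weight_decay': 0.0}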
def create_optimizer_params(model_list, lr):
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()

        # LoRA training with a list of parameter groups: chain them into one group.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        # LoRA training on a single model: keep only the LoRA parameters.
        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                # Skip LoRA parameters when we are not doing LoRA training.
                should_negate = 'lora' in n and not is_lora
                if should_negate:
                    continue
                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params
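
# Sketch: turn param_optim entries into optimizer groups and feed them to AdamW
# (`unet` and `text_encoder` are placeholders for modules created elsewhere):
#   groups = create_optimizer_params(
#       [param_optim(unet, True), param_optim(text_encoder, False)],
#       lr=5e-6,
#   )
#   optimizer = torch.optim.AdamW(groups)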
def get_optimizer(use_8bit_adam):
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW
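
# Sketch (`model` is a placeholder): get the optimizer class, then instantiate it:
#   optimizer_cls = get_optimizer(use_8bit_adam=False)  # -> torch.optim.AdamW
#   optimizer = optimizer_cls(model.parameters(), lr=5e-6)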
# Initialize the optimizers and learning-rate schedulers.
def prepare_optimizers(params, config, **extra_params):
    optimizer_cls = get_optimizer(config.train.use_8bit_adam)

    optimizer_temporal = optimizer_cls(
        params,
        lr=config.loss.learning_rate
    )

    lr_scheduler_temporal = get_scheduler(
        config.loss.lr_scheduler,
        optimizer=optimizer_temporal,
        num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
        num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
    )

    # Insert spatial LoRAs, one optimizer/scheduler pair per LoRA.
    if config.loss.type == 'DebiasedHybrid':
        unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', [])
        spatial_lora_num = extra_params.get('spatial_lora_num', 1)

        optimizer_spatial_list = []
        lr_scheduler_spatial_list = []
        for i in range(spatial_lora_num):
            unet_lora_params_spatial = unet_lora_params_spatial_list[i]

            optimizer_spatial = optimizer_cls(
                create_optimizer_params(
                    [
                        param_optim(
                            unet_lora_params_spatial,
                            config.loss.use_unet_lora,
                            is_lora=True,
                            extra_params={"lr": config.loss.learning_rate_spatial}
                        )
                    ],
                    config.loss.learning_rate_spatial
                ),
                lr=config.loss.learning_rate_spatial
            )
            optimizer_spatial_list.append(optimizer_spatial)

            # Scheduler
            lr_scheduler_spatial = get_scheduler(
                config.loss.lr_scheduler,
                optimizer=optimizer_spatial,
                num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
                num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
            )
            lr_scheduler_spatial_list.append(lr_scheduler_spatial)
    else:
        optimizer_spatial_list = []
        lr_scheduler_spatial_list = []

    return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list
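
# Sketch: typical call site, assuming `params` comes from prepare_params below
# and `config` is the loaded OmegaConf training config:
#   optimizers, schedulers = prepare_optimizers(params, config, **extra_params)
#   # optimizers[0] / schedulers[0] are the temporal pair; any remaining entries
#   # are the per-LoRA spatial pairs used when config.loss.type == 'DebiasedHybrid'.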
def sample_noise(latents, noise_strength, use_offset_noise=False):
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents)

    if use_offset_noise:
        # Offset noise: one random value per (batch, channel, frame), broadcast
        # over the spatial dimensions.
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents
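
# Sketch: offset noise adds a bias shared across each frame's spatial grid,
# which helps the model shift overall brightness; `latents` is (b, c, f, h, w):
#   noise = sample_noise(latents, noise_strength=0.1, use_offset_noise=True)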
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    # Fold the frame dimension into the batch so the image VAE can encode each frame.
    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    # Scale by the Stable Diffusion VAE's latent scaling factor.
    latents = latents * 0.18215

    return latents
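
# Sketch: `pixel_values` is assumed to be a (batch, frames, channels, height,
# width) tensor already in the VAE's expected input range:
#   latents = tensor_to_vae_latent(pixel_values, vae)  # -> (b, 4, f, h/8, w/8)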
def prepare_data(config, tokenizer):
    # Get the training dataset based on types (json, single_video, image).
    # config.dataset is assumed to be a DictConfig object.
    dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True)

    # Remove the 'type' key; the default of None avoids a KeyError if it is absent.
    dataset_params_dict.pop('type', None)

    train_datasets = []

    # Loop through all available dataset classes and instantiate every requested type.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in config.dataset.type:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer))

    if len(train_datasets) == 0:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")

    train_dataset = train_datasets[0]
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.train_batch_size,
        shuffle=True
    )

    return train_dataloader, train_dataset
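
# Sketch: typical call site, with `tokenizer` taken from the loaded text encoder:
#   train_dataloader, train_dataset = prepare_data(config, tokenizer)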
# Create the parameters for optimization.
def prepare_params(unet, config, train_dataset):
    extra_params = {}

    params, embedding_layers = inject_motion_embeddings(
        unet,
        combinations=config.model.motion_embeddings.combinations,
        config=config
    )
    config.model.embedding_layers = embedding_layers

    if config.loss.type == "DebiasedHybrid":
        if config.loss.spatial_lora_num == -1:
            config.loss.spatial_lora_num = len(train_dataset)

        lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras(
            unet=unet,
            use_unet_lora=True,
            lora_unet_dropout=0.1,
            lora_path='',
            lora_rank=32,
            spatial_lora_num=1,
        )

        extra_params['lora_managers_spatial'] = lora_managers_spatial
        extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list
        extra_params['unet_negation_all'] = unet_negation_all

    return params, extra_params
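
# Sketch of the overall setup flow these helpers support; `unet`, `tokenizer`,
# and `config` are assumed to be created elsewhere in the training script:
#   train_dataloader, train_dataset = prepare_data(config, tokenizer)
#   params, extra_params = prepare_params(unet, config, train_dataset)
#   optimizers, schedulers = prepare_optimizers(params, config, **extra_params)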