import torch
import torchvision
import torchaudio
import torchvision.transforms as transforms
from diffusers import UNet2DConditionModel, ControlNetModel
from foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline
from foleycrafter.pipelines.auffusion_pipeline import AuffusionNoAdapterPipeline, Generator
from foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel
from diffusers.models import AutoencoderKLTemporalDecoder, AutoencoderKL
from diffusers.schedulers import EulerDiscreteScheduler, DDIMScheduler, PNDMScheduler, KarrasDiffusionSchedulers
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, \
    SpeechT5HifiGan, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, \
    CLIPTextModel, CLIPTokenizer
import glob
from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip, VideoClip
from moviepy.audio.AudioClip import AudioArrayClip
import numpy as np
from safetensors import safe_open
import random
from typing import Union, Optional
import decord
import os
import os.path as osp
import imageio
import soundfile as sf
from PIL import Image, ImageOps
import torch.distributed as dist
import io
from omegaconf import OmegaConf
import json
from dataclasses import dataclass
from enum import Enum
import typing as T
import warnings
import pydub
from scipy.io import wavfile
from einops import rearrange
def zero_rank_print(s):
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s, flush=True)

def build_foleycrafter(
    pretrained_model_name_or_path: str = "auffusion/auffusion-full-no-adapter",
) -> StableDiffusionControlNetPipeline:
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = af_UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1)

    pipe = StableDiffusionControlNetPipeline(
        vae=vae,
        controlnet=controlnet,
        unet=unet,
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        feature_extractor=None,
        safety_checker=None,
        requires_safety_checker=False,
    )
    return pipe
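
# Illustrative sketch (not part of the original module): how build_foleycrafter()
# is typically consumed. Assumes the default "auffusion/auffusion-full-no-adapter"
# weights are reachable on the Hugging Face Hub and a CUDA device is available.
def _example_build_pipeline():
    pipe = build_foleycrafter()  # assembles VAE, UNet, ControlNet, scheduler, and CLIP text stack
    pipe = pipe.to("cuda")       # move every sub-module to the GPU
    return pipe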
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    if len(videos.shape) == 4:
        videos = videos.unsqueeze(0)
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)
        outputs.append(x)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

def save_videos_from_pil_list(videos: list, path: str, fps=7):
    for i in range(len(videos)):
        videos[i] = ImageOps.scale(videos[i], 255)
    imageio.mimwrite(path, videos, fps=fps)

def seed_everything(seed: int) -> None:
    r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
    :obj:`numpy` and :python:`Python`.

    Args:
        seed (int): The desired seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_video_frames(video: np.ndarray, num_frames: int = 200):
    video_length = video.shape[0]
    video_idx = np.linspace(0, video_length - 1, num_frames, dtype=int)
    video = video[video_idx, ...]
    return video

def random_audio_video_clip(audio: np.ndarray, video: np.ndarray, fps: float,
                            sample_rate: int = 16000, duration: int = 5, num_frames: int = 20):
    """
    Randomly sample an aligned audio/video clip of the given duration.
    """
    video_length = video.shape[0]
    audio_length = audio.shape[-1]
    av_duration = int(video_length / fps)
    assert av_duration >= duration, \
        f"video duration {av_duration} is less than {duration}"
    # randomly sample a start time
    start_time = random.uniform(0, av_duration - duration)
    end_time = start_time + duration
    start_idx, end_idx = start_time / av_duration, end_time / av_duration
    video_start_frame, video_end_frame = video_length * start_idx, video_length * end_idx
    audio_start_frame, audio_end_frame = audio_length * start_idx, audio_length * end_idx
    # print(f"time_idx : {start_time}:{end_time}")
    # print(f"video_idx: {video_start_frame}:{video_end_frame}")
    # print(f"audio_idx: {audio_start_frame}:{audio_end_frame}")
    audio_idx = np.linspace(audio_start_frame, audio_end_frame, sample_rate * duration, dtype=int)
    video_idx = np.linspace(video_start_frame, video_end_frame, num_frames, dtype=int)
    audio = audio[..., audio_idx]
    video = video[video_idx, ...]
    return audio, video

def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader]) -> np.ndarray:
    if isinstance(reader, decord.VideoReader):
        return np.linspace(0, len(reader) - 1, len(reader), dtype=int)
    elif isinstance(reader, decord.AudioReader):
        return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int)

def get_frames(video_path: str, onset_list, frame_nums=1024):
    video = decord.VideoReader(video_path)
    video_frame = len(video)
    frames_list = []
    for start, end in onset_list:
        video_start = int(start / frame_nums * video_frame)
        video_end = int(end / frame_nums * video_frame)
        frames_list.extend(range(video_start, video_end))
    frames = video.get_batch(frames_list).asnumpy()
    return frames

def get_frames_in_video(video_path: str, onset_list, frame_nums=1024, audio_length_in_s=10):
    # this function takes the actual video length into account
    video = decord.VideoReader(video_path)
    video_frame = len(video)
    duration = video_frame / video.get_avg_fps()
    frames_list = []
    video_onset_list = []
    for start, end in onset_list:
        if int(start / frame_nums * duration) >= audio_length_in_s:
            continue
        video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame)
        if video_start >= video_frame:
            continue
        video_end = int(end / audio_length_in_s * duration / frame_nums * video_frame)
        video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)])
        frames_list.extend(range(video_start, video_end))
    frames = video.get_batch(frames_list).asnumpy()
    return frames, video_onset_list

def save_multimodal(video, audio, output_path, audio_fps: int = 16000, video_fps: int = 8, remove_audio: bool = True):
    imgs = [img for img in video]
    # if audio.shape[0] == 1 or audio.shape[0] == 2:
    #     audio = audio.T  # [len, channel]
    #     audio = np.repeat(audio, 2, axis=1)
    output_dir = osp.dirname(output_path)
    try:
        wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)
    except Exception:
        sf.write(osp.join(output_dir, "audio.wav"), audio, audio_fps)
    audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav"))
    # audio_clip = AudioArrayClip(audio, fps=audio_fps)
    video_clip = ImageSequenceClip(imgs, fps=video_fps)
    video_clip = video_clip.set_audio(audio_clip)
    video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps)
    if remove_audio:
        os.remove(osp.join(output_dir, "audio.wav"))
    return

def save_multimodal_by_frame(video, audio, output_path, audio_fps: int = 16000):
    imgs = [img for img in video]
    # if audio.shape[0] == 1 or audio.shape[0] == 2:
    #     audio = audio.T  # [len, channel]
    #     audio = np.repeat(audio, 2, axis=1)
    # output_dir = osp.dirname(output_path)
    output_dir = output_path
    wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)
    audio_clip = AudioFileClip(osp.join(output_dir, "audio.wav"))
    # audio_clip = AudioArrayClip(audio, fps=audio_fps)
    os.makedirs(osp.join(output_dir, 'frames'), exist_ok=True)
    for num, img in enumerate(imgs):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype(np.uint8))
        img.save(osp.join(output_dir, 'frames', f"{num}.jpg"))
    return

def sanity_check(data: dict, save_path: str = "sanity_check", batch_size: int = 4, sample_rate: int = 16000):
    video_path = osp.join(save_path, 'video')
    audio_path = osp.join(save_path, 'audio')
    av_path = osp.join(save_path, 'av')
    video, audio, text = data['pixel_values'], data['audio'], data['text']
    video = (video / 2 + 0.5).clamp(0, 1)
    zero_rank_print(f"Saving {text} audio: {audio[0].shape} video: {video[0].shape}")
    for bsz in range(batch_size):
        os.makedirs(video_path, exist_ok=True)
        os.makedirs(audio_path, exist_ok=True)
        os.makedirs(av_path, exist_ok=True)
        # save_videos_grid(video[bsz:bsz+1, ...], f"{osp.join(video_path, str(bsz) + '.mp4')}")
        bsz_audio = audio[bsz, ...].permute(1, 0).cpu().numpy()
        bsz_video = video_tensor_to_np(video[bsz, ...])
        sf.write(f"{osp.join(audio_path, str(bsz) + '.wav')}", bsz_audio, sample_rate)
        save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + '.mp4'))

def video_tensor_to_np(video: torch.Tensor, rescale: bool = True, scale: bool = False):
    if scale:
        video = (video / 2 + 0.5).clamp(0, 1)
    # c f h w -> f h w c
    if video.shape[0] == 3:
        video = video.permute(1, 2, 3, 0).detach().cpu().numpy()
    elif video.shape[1] == 3:
        video = video.permute(0, 2, 3, 1).detach().cpu().numpy()
    if rescale:
        video = video * 255
    return video

def composite_audio_video(video: str, audio: str, path: str, video_fps: int = 7, audio_sample_rate: int = 16000):
    video = decord.VideoReader(video)
    audio = decord.AudioReader(audio, sample_rate=audio_sample_rate)
    audio = audio.get_batch(get_full_indices(audio)).asnumpy()
    video = video.get_batch(get_full_indices(video)).asnumpy()
    save_multimodal(video, audio, path, audio_fps=audio_sample_rate, video_fps=video_fps)
    return

# for video pipeline
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
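
# Illustrative sketch (not part of the original module): append_dims() is the usual
# helper for broadcasting a per-sample scalar (e.g. a noise level) against a full
# latent tensor. The tensor shapes below are made up for the example.
def _example_append_dims():
    sigma = torch.tensor([0.5, 1.0])          # one scalar per batch element, shape (2,)
    latents = torch.randn(2, 4, 32, 32)       # (batch, channels, height, width)
    sigma = append_dims(sigma, latents.ndim)  # -> shape (2, 1, 1, 1)
    return latents * sigma                    # broadcasts cleanly over the batch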
def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)
    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
    return output

def _gaussian_blur2d(input, kernel_size, sigma):
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])
    return out

def _filter2d(input, kernel):
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
    height, width = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
    out = output.view(b, c, h, w)
    return out

def _gaussian(window_size: int, sigma):
    if isinstance(sigma, float):
        sigma = torch.tensor([[sigma]])

    batch_size = sigma.shape[0]
    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
    return gauss / gauss.sum(-1, keepdim=True)

def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]
    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]
        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front
        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear
    return out_padding
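
# Worked example (not part of the original module): what _compute_padding() returns
# for an asymmetric kernel, and how it lines up with torch.nn.functional.pad's
# (left, right, top, bottom) convention used in _filter2d above.
def _example_compute_padding():
    pads = _compute_padding([5, 3])  # kernel height 5, width 3
    # computed = [4, 2]; width pads fill slots 0-1, height pads fill slots 2-3
    assert pads == [1, 1, 2, 2]      # (pad_left, pad_right, pad_top, pad_bottom)
    return pads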
def print_gpu_memory_usage(info: str, cuda_id: int = 0):
    print(f">>> {info} <<<")
    reserved = torch.cuda.memory_reserved(cuda_id) / 1024 ** 3
    used = torch.cuda.memory_allocated(cuda_id) / 1024 ** 3
    print("reserved: ", reserved, "G")
    print("used: ", used, "G")
    print("available: ", reserved - used, "G")

# used for DSP mel2spec
@dataclass
class SpectrogramParams:
    """
    Parameters for the conversion from audio to spectrograms to images and back.

    Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
    within spectrogram images.

    To understand what these parameters do and to customize them, read `spectrogram_converter.py`
    and the linked torchaudio documentation.
    """

    # Whether the audio is stereo or mono
    stereo: bool = False

    # FFT parameters
    sample_rate: int = 44100
    step_size_ms: int = 10
    window_duration_ms: int = 100
    padded_duration_ms: int = 400

    # Mel scale parameters
    num_frequencies: int = 200
    # TODO(hayk): Set these to [20, 20000] for newer models
    min_frequency: int = 0
    max_frequency: int = 10000
    mel_scale_norm: T.Optional[str] = None
    mel_scale_type: str = "htk"
    max_mel_iters: int = 200

    # Griffin-Lim parameters
    num_griffin_lim_iters: int = 32

    # Image parameterization
    power_for_image: float = 0.25

    class ExifTags(Enum):
        """
        Custom EXIF tags for the spectrogram image.
        """

        SAMPLE_RATE = 11000
        STEREO = 11005
        STEP_SIZE_MS = 11010
        WINDOW_DURATION_MS = 11020
        PADDED_DURATION_MS = 11030
        NUM_FREQUENCIES = 11040
        MIN_FREQUENCY = 11050
        MAX_FREQUENCY = 11060
        POWER_FOR_IMAGE = 11070
        MAX_VALUE = 11080

    @property
    def n_fft(self) -> int:
        """
        The number of samples in each STFT window, with padding.
        """
        return int(self.padded_duration_ms / 1000.0 * self.sample_rate)

    @property
    def win_length(self) -> int:
        """
        The number of samples in each STFT window.
        """
        return int(self.window_duration_ms / 1000.0 * self.sample_rate)

    @property
    def hop_length(self) -> int:
        """
        The number of samples between each STFT window.
        """
        return int(self.step_size_ms / 1000.0 * self.sample_rate)

    def to_exif(self) -> T.Dict[int, T.Any]:
        """
        Return a dictionary of EXIF tags for the current values.
        """
        return {
            self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
            self.ExifTags.STEREO.value: self.stereo,
            self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
            self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
            self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
            self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
            self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
            self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
            self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
        }
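
# Worked example (not part of the original module): the STFT sizes implied by the
# default SpectrogramParams values above, assuming the @property accessors.
def _example_spectrogram_params():
    p = SpectrogramParams()
    # padded_duration_ms=400 at 44100 Hz  -> n_fft      = 17640 samples
    # window_duration_ms=100 at 44100 Hz  -> win_length = 4410 samples
    # step_size_ms=10 at 44100 Hz         -> hop_length = 441 samples
    return p.n_fft, p.win_length, p.hop_length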
class SpectrogramImageConverter:
    """
    Convert between spectrogram images and audio segments.

    This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
    to images and back. The real audio processing lives in SpectrogramConverter.
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params
        self.device = device
        self.converter = SpectrogramConverter(params=params, device=device)

    def spectrogram_image_from_audio(
        self,
        segment: pydub.AudioSegment,
    ) -> Image.Image:
        """
        Compute a spectrogram image from an audio segment.

        Args:
            segment: Audio segment to convert

        Returns:
            Spectrogram image (in pillow format)
        """
        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"

        if self.p.stereo:
            if segment.channels == 1:
                print("WARNING: Mono audio but stereo=True, cloning channel")
                segment = segment.set_channels(2)
            elif segment.channels > 2:
                print("WARNING: Multi channel audio, reducing to stereo")
                segment = segment.set_channels(2)
        else:
            if segment.channels > 1:
                print("WARNING: Stereo audio but stereo=False, setting to mono")
                segment = segment.set_channels(1)

        spectrogram = self.converter.spectrogram_from_audio(segment)

        image = image_from_spectrogram(
            spectrogram,
            power=self.p.power_for_image,
        )

        # Store conversion params in exif metadata of the image
        exif_data = self.p.to_exif()
        exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
        exif = image.getexif()
        exif.update(exif_data.items())

        return image

    def audio_from_spectrogram_image(
        self,
        image: Image.Image,
        apply_filters: bool = True,
        max_value: float = 30e6,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram image.

        Args:
            image: Spectrogram image (in pillow format)
            apply_filters: Apply post-processing to improve the reconstructed audio
            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
        """
        spectrogram = spectrogram_from_image(
            image,
            max_value=max_value,
            power=self.p.power_for_image,
            stereo=self.p.stereo,
        )
        segment = self.converter.audio_from_spectrogram(
            spectrogram,
            apply_filters=apply_filters,
        )
        return segment

def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
    """
    Compute a spectrogram image from a spectrogram magnitude array.

    This is the inverse of spectrogram_from_image, except for discretization error from
    quantizing to uint8.

    Args:
        spectrogram: (channels, frequency, time)
        power: A power curve to apply to the spectrogram to preserve contrast

    Returns:
        image: (frequency, time, channels)
    """
    # Rescale to 0-1
    max_value = np.max(spectrogram)
    data = spectrogram / max_value

    # Apply the power curve
    data = np.power(data, power)

    # Rescale to 0-255
    data = data * 255

    # Invert
    data = 255 - data

    # Convert to uint8
    data = data.astype(np.uint8)

    # Munge channels into a PIL image
    if data.shape[0] == 1:
        # TODO(hayk): Do we want to write single channel to disk instead?
        image = Image.fromarray(data[0], mode="L").convert("RGB")
    elif data.shape[0] == 2:
        data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
        image = Image.fromarray(data, mode="RGB")
    else:
        raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")

    # Flip Y
    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    return image

def spectrogram_from_image(
    image: Image.Image,
    power: float = 0.25,
    stereo: bool = False,
    max_value: float = 30e6,
) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    This is the inverse of image_from_spectrogram, except for discretization error from
    quantizing to uint8.

    Args:
        image: (frequency, time, channels)
        power: The power curve applied to the spectrogram
        stereo: Whether the spectrogram encodes stereo data
        max_value: The max value of the original spectrogram. In practice doesn't matter.

    Returns:
        spectrogram: (channels, frequency, time)
    """
    # Convert to RGB if single channel
    if image.mode in ("P", "L"):
        image = image.convert("RGB")

    # Flip Y
    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    # Munge channels into a numpy array of (channels, frequency, time)
    data = np.array(image).transpose(2, 0, 1)
    if stereo:
        # Take the G and B channels as done in image_from_spectrogram
        data = data[[1, 2], :, :]
    else:
        data = data[0:1, :, :]

    # Convert to floats
    data = data.astype(np.float32)

    # Invert
    data = 255 - data

    # Rescale to 0-1
    data = data / 255

    # Reverse the power curve
    data = np.power(data, 1 / power)

    # Rescale to max value
    data = data * max_value

    return data
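
# Illustrative sketch (not part of the original module): image_from_spectrogram and
# spectrogram_from_image are approximate inverses; the round trip only loses the
# uint8 quantization, and the absolute scale is re-applied through max_value.
def _example_spectrogram_image_round_trip():
    spec = np.abs(np.random.randn(1, 200, 512)).astype(np.float32)  # (channels, frequency, time)
    img = image_from_spectrogram(spec, power=0.25)
    recovered = spectrogram_from_image(img, power=0.25, stereo=False, max_value=float(spec.max()))
    return spec, recovered  # recovered approximates spec up to quantization error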
class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
    that represent the amplitude of the frequency at that time bucket (in the frequency domain).
    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
    used in some functions is "mel amplitudes".

    The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
    returns the amplitude, because the phase is chaotic and hard to learn. The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
    approximates the phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params
        self.device = check_device(device)
        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            # max_iter=params.max_mel_iters,  # for higher versions of torchaudio
            # tolerance_loss=1e-5,  # for higher versions of torchaudio
            # tolerance_change=1e-8,  # for higher versions of torchaudio
            # sgdargs=None,  # for higher versions of torchaudio
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process with normalization and compression

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = apply_filters_func(
                segment,
                compression=False,
            )

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)

def check_device(device: str, backup: str = "cpu") -> str:
    """
    Check that the device is valid and available; if not, fall back to `backup`.
    """
    cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
    mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()

    if cuda_not_found or mps_not_found:
        warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
        return backup

    return device

def audio_from_waveform(
    samples: np.ndarray, sample_rate: int, normalize: bool = False
) -> pydub.AudioSegment:
    """
    Convert a numpy array of samples of a waveform to an audio segment.

    Args:
        samples: (channels, samples) array
    """
    # Normalize volume to fit in int16
    if normalize:
        samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))

    # Transpose and convert to int16
    samples = samples.transpose(1, 0)
    samples = samples.astype(np.int16)

    # Write to the bytes of a WAV file
    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, samples)
    wav_bytes.seek(0)

    # Read into pydub
    return pydub.AudioSegment.from_wav(wav_bytes)

def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
    """
    Apply post-processing filters to the audio segment to compress it and
    keep it at a -10 dBFS level.
    """
    # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
    # TODO(hayk): Is this going to make audio unbalanced between sequential clips?
    if compression:
        segment = pydub.effects.normalize(
            segment,
            headroom=0.1,
        )
        segment = segment.apply_gain(-10 - segment.dBFS)

        # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
        segment = pydub.effects.compress_dynamic_range(
            segment,
            threshold=-20.0,
            ratio=4.0,
            attack=5.0,
            release=50.0,
        )

    desired_db = -12
    segment = segment.apply_gain(desired_db - segment.dBFS)
    segment = pydub.effects.normalize(
        segment,
        headroom=0.1,
    )
    return segment
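
# End-to-end usage sketch (not part of the original module): audio -> spectrogram
# image -> audio with the classes above. The WAV path is a hypothetical placeholder,
# and the clip must already be at the 44100 Hz sample rate of the default params.
def _example_audio_image_round_trip():
    params = SpectrogramParams()
    converter = SpectrogramImageConverter(params=params, device="cpu")
    segment = pydub.AudioSegment.from_wav("example.wav")           # hypothetical input file
    image = converter.spectrogram_image_from_audio(segment)        # PIL image with params in EXIF
    reconstructed = converter.audio_from_spectrogram_image(image)  # Griffin-Lim reconstruction
    return reconstructed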
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")
        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")
        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping

def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item
        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping

def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item
        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')
        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping

def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item
        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")
        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")
        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")
        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")
        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping

def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)
            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif 'to_out.0.weight' in new_path:
            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
        elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]):
            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]

def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]

def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        unet_params = original_config.model.params.unet_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config

def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config

def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular

def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())
    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )
        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding
        orig_index = 0
        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )
        orig_index += 2

        diffusers_index = 0
        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
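
# Usage sketch (not part of the original module): converting an original LDM/Stable
# Diffusion checkpoint into a diffusers-format UNet. The .yaml and .ckpt paths are
# hypothetical placeholders, and the checkpoint is assumed to store its weights
# under a "state_dict" key, as LDM checkpoints usually do.
def _example_convert_ldm_unet(config_path="v1-inference.yaml", ckpt_path="model.ckpt"):
    original_config = OmegaConf.load(config_path)                       # LDM-style config
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    unet_config = create_unet_diffusers_config(original_config, image_size=512)
    unet_weights = convert_ldm_unet_checkpoint(state_dict, unet_config, path=ckpt_path)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(unet_weights)
    return unet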
| def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False): | |
| # extract state dict for VAE | |
| vae_state_dict = {} | |
| vae_key = "first_stage_model." | |
| keys = list(checkpoint.keys()) | |
| for key in keys: | |
| if key.startswith(vae_key): | |
| vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | |
| new_checkpoint = {} | |
| new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | |
| new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | |
| new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | |
| new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | |
| new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | |
| new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | |
| new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | |
| new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | |
| new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | |
| new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | |
| new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | |
| new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | |
| new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | |
| new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | |
| new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | |
| new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | |
| # Retrieves the keys for the encoder down blocks only | |
| num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | |
| down_blocks = { | |
| layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | |
| } | |
| # Retrieves the keys for the decoder up blocks only | |
| num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | |
| up_blocks = { | |
| layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | |
| } | |
| for i in range(num_down_blocks): | |
| resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | |
| if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | |
| new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | |
| f"encoder.down.{i}.downsample.conv.weight" | |
| ) | |
| new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | |
| f"encoder.down.{i}.downsample.conv.bias" | |
| ) | |
| paths = renew_vae_resnet_paths(resnets) | |
| meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | |
| num_mid_res_blocks = 2 | |
| for i in range(1, num_mid_res_blocks + 1): | |
| resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | |
| paths = renew_vae_resnet_paths(resnets) | |
| meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | |
| paths = renew_vae_attention_paths(mid_attentions) | |
| meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| conv_attn_to_linear(new_checkpoint) | |
| for i in range(num_up_blocks): | |
| block_id = num_up_blocks - 1 - i | |
| resnets = [ | |
| key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | |
| ] | |
| if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | |
| new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | |
| f"decoder.up.{block_id}.upsample.conv.weight" | |
| ] | |
| new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | |
| f"decoder.up.{block_id}.upsample.conv.bias" | |
| ] | |
| paths = renew_vae_resnet_paths(resnets) | |
| meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | |
| num_mid_res_blocks = 2 | |
| for i in range(1, num_mid_res_blocks + 1): | |
| resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | |
| paths = renew_vae_resnet_paths(resnets) | |
| meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | |
| paths = renew_vae_attention_paths(mid_attentions) | |
| meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | |
| assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | |
| conv_attn_to_linear(new_checkpoint) | |
| if only_decoder: | |
| new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')} | |
| elif only_encoder: | |
| new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')} | |
| return new_checkpoint | |
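| # Usage sketch (illustrative, not a definitive recipe): load a converted VAE state | |
| # dict into a diffusers AutoencoderKL. It assumes the converter above follows the | |
| # diffusers convention, i.e. roughly convert_ldm_vae_checkpoint(checkpoint, config); | |
| # the checkpoint path is a placeholder. | |
| # | |
| #   vae = AutoencoderKL.from_pretrained("auffusion/auffusion-full-no-adapter", subfolder="vae") | |
| #   ldm_sd = torch.load("path/to/ldm_checkpoint.ckpt", map_location="cpu")["state_dict"] | |
| #   converted = convert_ldm_vae_checkpoint(ldm_sd, vae.config) | |
| #   missing, unexpected = vae.load_state_dict(converted, strict=False) | |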
| def convert_ldm_clip_checkpoint(checkpoint): | |
| # extract the CLIP text-encoder weights by stripping the "cond_stage_model.transformer." prefix | |
| keys = list(checkpoint.keys()) | |
| text_model_dict = {} | |
| for key in keys: | |
| if key.startswith("cond_stage_model.transformer"): | |
| text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] | |
| return text_model_dict | |
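| # Usage sketch (illustrative): load the extracted weights into a transformers | |
| # CLIPTextModel; the checkpoint path is a placeholder and strict=False tolerates | |
| # buffers such as position_ids. | |
| # | |
| #   ldm_sd = torch.load("path/to/sd_checkpoint.ckpt", map_location="cpu")["state_dict"] | |
| #   text_sd = convert_ldm_clip_checkpoint(ldm_sd) | |
| #   text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | |
| #   text_encoder.load_state_dict(text_sd, strict=False) | |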
| def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): | |
| """convert lora in model level instead of pipeline leval | |
| """ | |
| visited = [] | |
| # directly update weight in diffusers model | |
| for key in state_dict: | |
| # it can help to print each key; they usually look like | |
| # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" | |
| # alpha is fixed via the function argument, so per-key ".alpha" entries are skipped | |
| if ".alpha" in key or key in visited: | |
| continue | |
| if "text" in key: | |
| layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") | |
| assert text_encoder is not None, ( | |
| 'text_encoder must be passed since lora contains text encoder layers') | |
| curr_layer = text_encoder | |
| else: | |
| layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") | |
| curr_layer = unet | |
| # walk the module tree to find the target layer; name parts that fail getattr | |
| # are re-joined with "_" and retried, since module names may themselves contain "_" | |
| temp_name = layer_infos.pop(0) | |
| while len(layer_infos) > -1:  # loop exits via the break below | |
| try: | |
| curr_layer = curr_layer.__getattr__(temp_name) | |
| if len(layer_infos) > 0: | |
| temp_name = layer_infos.pop(0) | |
| elif len(layer_infos) == 0: | |
| break | |
| except Exception: | |
| if len(temp_name) > 0: | |
| temp_name += "_" + layer_infos.pop(0) | |
| else: | |
| temp_name = layer_infos.pop(0) | |
| pair_keys = [] | |
| if "lora_down" in key: | |
| pair_keys.append(key.replace("lora_down", "lora_up")) | |
| pair_keys.append(key) | |
| else: | |
| pair_keys.append(key) | |
| pair_keys.append(key.replace("lora_up", "lora_down")) | |
| # update weight | |
| # NOTE: LoCon/LyCORIS-style conv LoRA weights; this path may still have bugs :( | |
| if 'conv_in' in pair_keys[0]: | |
| weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
| weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
| weight_up = weight_up.view(weight_up.size(0), -1) | |
| weight_down = weight_down.view(weight_down.size(0), -1) | |
| shape = [e for e in curr_layer.weight.data.shape] | |
| shape[1] = 4 | |
| curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape) | |
| elif 'conv' in pair_keys[0]: | |
| weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
| weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
| weight_up = weight_up.view(weight_up.size(0), -1) | |
| weight_down = weight_down.view(weight_down.size(0), -1) | |
| shape = [e for e in curr_layer.weight.data.shape] | |
| curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) | |
| elif len(state_dict[pair_keys[0]].shape) == 4: | |
| weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) | |
| weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) | |
| curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) | |
| else: | |
| weight_up = state_dict[pair_keys[0]].to(torch.float32) | |
| weight_down = state_dict[pair_keys[1]].to(torch.float32) | |
| curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) | |
| # update visited list | |
| for item in pair_keys: | |
| visited.append(item) | |
| return unet, text_encoder | |
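| # Usage sketch (illustrative): merge a LoRA checkpoint stored as safetensors into a | |
| # pipeline's unet/text_encoder; the file name and the `pipe` object are placeholders. | |
| # | |
| #   lora_sd = {} | |
| #   with safe_open("path/to/lora.safetensors", framework="pt", device="cpu") as f: | |
| #       for k in f.keys(): | |
| #           lora_sd[k] = f.get_tensor(k) | |
| #   pipe.unet, pipe.text_encoder = convert_lora_model_level( | |
| #       lora_sd, pipe.unet, pipe.text_encoder, alpha=0.6) | |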
| def denormalize_spectrogram( | |
| data: torch.Tensor, | |
| max_value: float = 200, | |
| min_value: float = 1e-5, | |
| power: float = 1, | |
| inverse: bool = False, | |
| ) -> torch.Tensor: | |
| max_value = np.log(max_value) | |
| min_value = np.log(min_value) | |
| # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner | |
| data = torch.flip(data, [1]) | |
| assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape)) | |
| if data.shape[0] == 1: | |
| data = data.repeat(3, 1, 1) | |
| assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0]) | |
| data = data[0] | |
| # Reverse the power curve | |
| data = torch.pow(data, 1 / power) | |
| # Invert | |
| if inverse: | |
| data = 1 - data | |
| # Rescale from [0, 1] back to the log-amplitude range [log(min_value), log(max_value)] | |
| spectrogram = data * (max_value - min_value) + min_value | |
| return spectrogram | |
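| # Usage sketch (illustrative): turn a generated spectrogram image (a (3, H, W) | |
| # tensor in [0, 1]) back into a waveform. The Generator API (from_pretrained / | |
| # inference) and the 16 kHz sampling rate are assumptions based on the imports above. | |
| # | |
| #   spec = denormalize_spectrogram(spec_img)  # spec_img: torch.Tensor, shape (3, H, W) | |
| #   vocoder = Generator.from_pretrained("auffusion/auffusion-full-no-adapter", subfolder="vocoder") | |
| #   waveform = vocoder.inference(spec) | |
| #   sf.write("sample.wav", np.asarray(waveform).reshape(-1), 16000) | |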
| class ToTensor1D(torchvision.transforms.ToTensor): | |
| def __call__(self, tensor: np.ndarray): | |
| tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) | |
| return tensor_2d.squeeze_(0) | |
| def scale(old_value, old_min, old_max, new_min, new_max): | |
| old_range = (old_max - old_min) | |
| new_range = (new_max - new_min) | |
| new_value = (((old_value - old_min) * new_range) / old_range) + new_min | |
| return new_value | |
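| # Example (illustrative): linearly remap a value from one range to another, e.g. | |
| # an energy curve in [0, 1] onto 8-bit pixel intensities. | |
| # | |
| #   scale(0.5, 0.0, 1.0, 0, 255)  # -> 127.5 | |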
| def read_frames_with_moviepy(video_path, max_frame_nums=None): | |
| clip = VideoFileClip(video_path) | |
| duration = clip.duration | |
| frames = [] | |
| for frame in clip.iter_frames(): | |
| frames.append(frame) | |
| if max_frame_nums is not None: | |
| frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int) | |
| return np.array(frames)[frames_idx, ...], duration | |
| # no subsampling requested: return every frame | |
| return np.array(frames), duration | |
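| # Usage sketch (illustrative): uniformly sample a fixed number of frames from a | |
| # video; the path is a placeholder. | |
| # | |
| #   frames, duration = read_frames_with_moviepy("path/to/video.mp4", max_frame_nums=150) | |
| #   # frames: uint8 array of shape (150, H, W, 3); duration: length in seconds | |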
| def read_frames_with_moviepy_resample(video_path, save_path): | |
| vision_transform_list = [ | |
| transforms.Resize((128, 128)), | |
| transforms.CenterCrop((112, 112)), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ] | |
| video_transform = transforms.Compose(vision_transform_list) | |
| os.makedirs(save_path, exist_ok=True) | |
| command = f'ffmpeg -v quiet -y -i \"{video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}/frame%06d.jpg\"' | |
| os.system(command) | |
| frame_list = glob.glob(f'{save_path}/*.jpg') | |
| frame_list.sort() | |
| convert_tensor = transforms.ToTensor() | |
| frame_list = [convert_tensor(np.array(Image.open(frame))) for frame in frame_list] | |
| imgs = torch.stack(frame_list, dim=0) | |
| imgs = video_transform(imgs) | |
| imgs = imgs.permute(1, 0, 2, 3) | |
| return imgs |
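| # Usage sketch (illustrative): extract 15 fps frames with ffmpeg and get a normalized | |
| # clip tensor of shape (3, T, 112, 112) for a video encoder; paths are placeholders | |
| # and ffmpeg must be available on PATH. | |
| # | |
| #   clip_tensor = read_frames_with_moviepy_resample("path/to/video.mp4", "tmp/frames") | |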