import gc
import os
from enum import Enum

import torch
import numpy as np
import jax
import jax.numpy as jnp
import einops
import tomesd
from flax import jax_utils
from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)

from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.models.controlnet_flax import FlaxControlNetModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)
import utils.utils as utils
import utils.gradio_utils as gradio_utils

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

# Collapse the device axis added for pmap: (devices, batch, ...) -> (devices * batch, ...).
unshard = lambda x: einops.rearrange(x, "d b ... -> (d b) ...")


class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    # Add a leading device axis and repeat the input once per local device,
    # the layout expected by pmapped (jit=True) pipeline calls.
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
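

# Shape sketch for the two helpers above (the concrete values are illustrative
# assumptions, not taken from the original code): with 8 local devices and
# token ids of shape (2, 77),
#     replicate_devices(ids)   # (2, 77)     -> (8, 2, 77), one copy per device
#     unshard(images)          # (8, b, ...) -> (8 * b, ...), merge device axis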


class ControlAnimationModel:
    def __init__(self, device, dtype, **kwargs):
        self.device = device
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        # Torch RNG used to draw the initial latents in the process_* methods.
        self.generator = torch.Generator(device=device)
        self.pipe_dict = {
            ModelType.Text2Video: FlaxTextToVideoPipeline,  # TODO: Replace with our TextToVideo JAX Pipeline
            ModelType.ControlNetPose: FlaxStableDiffusionControlNetPipeline,
        }
        self.pipe = None
        self.model_type = None

        self.states = {}
        self.model_name = ""
        self.from_local = True  # whether the attention-adapted model is available locally (after running adapt_attn.py)

    def set_model(
        self,
        model_type: ModelType,
        model_id: str,
        controlnet=None,
        controlnet_params=None,
        **kwargs,
    ):
        # The tokenizer, scheduler, and all sub-model weights are loaded from
        # model_id below, so only the ControlNet (if any) is passed in.
        if hasattr(self, "pipe") and self.pipe is not None:
            del self.pipe
            self.pipe = None
        gc.collect()
        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        if self.from_local:
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                f'./{model_id.split("/")[-1]}',
                subfolder="unet",
                from_pt=True,
                dtype=self.dtype,
            )
        else:
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                model_id, subfolder="unet", from_pt=True, dtype=self.dtype
            )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        # Replicate the parameter pytree across all local devices for pmapped calls.
        self.p_params = jax_utils.replicate(self.params)
        self.model_type = model_type
        self.model_name = model_id
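
    # Intended usage (a sketch; the repository ids below are the ones used by
    # process_controlnet_pose, not fixed by this method):
    #     controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    #         "fusing/stable-diffusion-v1-5-controlnet-openpose", from_pt=True
    #     )
    #     model.set_model(
    #         ModelType.ControlNetPose,
    #         model_id="runwayml/stable-diffusion-v1-5",
    #         controlnet=controlnet,
    #         controlnet_params=controlnet_params,
    #     )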

    def inference_chunk(self, image, frame_ids, prompt, negative_prompt, **kwargs):
        # Run one chunk of frames through the pmapped pipeline and return
        # device 0's images. NOTE: the original authors marked this chunked
        # path as not tested.
        prompt_ids = self.pipe.prepare_text_inputs(prompt)
        n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
        latents = kwargs.pop("latents", None)
        prng, self.rng = jax.random.split(self.rng)
        prng_seed = jax.random.split(prng, jax.device_count())
        image = replicate_devices(image[frame_ids])
        if latents is not None:
            latents = replicate_devices(latents)
        prompt_ids = replicate_devices(prompt_ids)
        n_prompt_ids = replicate_devices(n_prompt_ids)
        return self.pipe(
            image=image,
            latents=latents,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.p_params,
            prng_seed=prng_seed,
            jit=True,
        ).images[0]

    def inference(self, image=None, split_to_chunks=False, chunk_size=8, **kwargs):
        if not hasattr(self, "pipe") or self.pipe is None:
            return
        if "merging_ratio" in kwargs:
            merging_ratio = kwargs.pop("merging_ratio")
            if merging_ratio > 0:
                tomesd.apply_patch(self.pipe, ratio=merging_ratio)
        assert "prompt" in kwargs
        prompt = [kwargs.pop("prompt")]
        negative_prompt = [kwargs.pop("negative_prompt", "")]
        frames_counter = 0
        # Process the clip chunk-by-chunk so long videos fit in device memory;
        # chunking needs per-frame control images, hence the image check.
        if split_to_chunks and image is not None:
            # NOTE: the original authors marked this chunked path as not tested.
            f = image.shape[0]
            latents = kwargs.pop("latents", None)
            # Chunks overlap by one frame: frame 0 is prepended to every chunk
            # as a shared anchor for cross-frame attention.
            chunk_ids = np.arange(0, f, chunk_size - 1)
            result = []
            for i in range(len(chunk_ids)):
                ch_start = chunk_ids[i]
                ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
                frame_ids = [0] + list(range(ch_start, ch_end))
                print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
                chunk_images = self.inference_chunk(
                    image=image,
                    frame_ids=frame_ids,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    latents=None if latents is None else latents[frame_ids],
                    **kwargs,
                )
                # Drop the anchor frame so it appears only once in the output.
                result.append(chunk_images[1:])
                frames_counter += len(frame_ids) - 1
                # Cap the total frame budget when running on the shared Space.
                if on_huggingspace and frames_counter >= 80:
                    break
            result = np.concatenate(result)
            return result
        else:
            if "jit" in kwargs and kwargs.pop("jit"):
                # pmapped path: replicate every input across local devices and
                # keep only device 0's images.
                prompt_ids = self.pipe.prepare_text_inputs(prompt)
                n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
                latents = kwargs.pop("latents", None)
                prng, self.rng = jax.random.split(self.rng)
                prng_seed = jax.random.split(prng, jax.device_count())
                if image is not None:
                    image = replicate_devices(image)
                if latents is not None:
                    latents = replicate_devices(latents)
                prompt_ids = replicate_devices(prompt_ids)
                n_prompt_ids = replicate_devices(n_prompt_ids)
                return self.pipe(
                    image=image,
                    latents=latents,
                    prompt_ids=prompt_ids,
                    neg_prompt_ids=n_prompt_ids,
                    params=self.p_params,
                    prng_seed=prng_seed,
                    jit=True,
                ).images[0]
            else:
                # Single-device path: unreplicated params and a single PRNG key.
                prompt_ids = self.pipe.prepare_text_inputs(prompt)
                n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
                latents = kwargs.pop("latents", None)
                prng_seed, self.rng = jax.random.split(self.rng)
                return self.pipe(
                    image=image,
                    latents=latents,
                    prompt_ids=prompt_ids,
                    neg_prompt_ids=n_prompt_ids,
                    params=self.params,
                    prng_seed=prng_seed,
                    jit=False,
                ).images
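
    # Call sketch (argument values are assumptions): the jitted call shards the
    # work across devices with the replicated self.p_params, while jit=False
    # runs on a single device with self.params:
    #     images = model.inference(
    #         image=control, prompt="a dancing astronaut",
    #         latents=latents, jit=True,
    #     )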

    def process_controlnet_pose(
        self,
        video_path,
        prompt,
        chunk_size=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        num_inference_steps=20,
        controlnet_conditioning_scale=1.0,
        guidance_scale=9.0,
        seed=42,
        eta=0.0,
        resolution=512,
        use_cf_attn=True,
        save_path=None,
    ):
| print("Module Pose") | |
| video_path = gradio_utils.motion_to_video_path(video_path) | |
| if self.model_type != ModelType.ControlNetPose: | |
| controlnet = FlaxControlNetModel.from_pretrained( | |
| "fusing/stable-diffusion-v1-5-controlnet-openpose" | |
| ) | |
| self.set_model( | |
| ModelType.ControlNetPose, | |
| model_id="runwayml/stable-diffusion-v1-5", | |
| controlnet=controlnet, | |
| ) | |
| self.pipe.scheduler = FlaxDDIMScheduler.from_config( | |
| self.pipe.scheduler.config | |
| ) | |
| if use_cf_attn: | |
| self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc) | |
| self.pipe.controlnet.set_attn_processor( | |
| processor=self.controlnet_attn_proc | |
| ) | |
| video_path = ( | |
| gradio_utils.motion_to_video_path(video_path) | |
| if "Motion" in video_path | |
| else video_path | |
| ) | |
| added_prompt = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth" | |
| negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" | |
| video, fps = utils.prepare_video( | |
| video_path, resolution, self.device, self.dtype, False, output_fps=4 | |
| ) | |
| control = ( | |
| utils.pre_process_pose(video, apply_pose_detect=False) | |
| .to(self.device) | |
| .to(self.dtype) | |
| ) | |
| f, _, h, w = video.shape | |
| self.generator.manual_seed(seed) | |
| latents = torch.randn( | |
| (1, 4, h // 8, w // 8), | |
| dtype=self.dtype, | |
| device=self.device, | |
| generator=self.generator, | |
| ) | |
| latents = latents.repeat(f, 1, 1, 1) | |
        result = self.inference(
            image=control,
            prompt=prompt + ", " + added_prompt,
            height=h,
            width=w,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            eta=eta,
            latents=latents,
            seed=seed,
            output_type="numpy",
            split_to_chunks=True,
            chunk_size=chunk_size,
            merging_ratio=merging_ratio,
        )
        return utils.create_gif(
            result,
            fps,
            path=save_path,
            watermark=gradio_utils.logo_name_to_path(watermark),
        )
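
    # Example call (a sketch; the motion preset name and output path are
    # assumptions):
    #     gif = model.process_controlnet_pose(
    #         "Motion 1", prompt="an astronaut dancing", save_path="./pose.gif"
    #     )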

    def process_text2video(
        self,
        prompt,
        model_name="dreamlike-art/dreamlike-photoreal-2.0",
        motion_field_strength_x=12,
        motion_field_strength_y=12,
        t0=44,
        t1=47,
        n_prompt="",
        chunk_size=8,
        video_length=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        seed=0,
        resolution=512,
        fps=2,
        use_cf_attn=True,
        use_motion_field=True,
        smooth_bg=False,
        smooth_bg_strength=0.4,
        path=None,
    ):
| print("Module Text2Video") | |
| if self.model_type != ModelType.Text2Video or model_name != self.model_name: | |
| print("Model update") | |
| unet = FlaxUNet2DConditionModel.from_pretrained( | |
| model_name, subfolder="unet" | |
| ) | |
| self.set_model(ModelType.Text2Video, model_id=model_name, unet=unet) | |
| self.pipe.scheduler = FlaxDDIMScheduler.from_config( | |
| self.pipe.scheduler.config | |
| ) | |
| if use_cf_attn: | |
| self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc) | |
| self.generator.manual_seed(seed) | |
| added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting" | |
| negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" | |
| prompt = prompt.rstrip() | |
| if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."): | |
| prompt = prompt.rstrip()[:-1] | |
| prompt = prompt.rstrip() | |
| prompt = prompt + ", " + added_prompt | |
| if len(n_prompt) > 0: | |
| negative_prompt = n_prompt | |
| else: | |
| negative_prompt = None | |
        result = self.inference(
            prompt=prompt,
            video_length=video_length,
            height=resolution,
            width=resolution,
            num_inference_steps=50,
            guidance_scale=7.5,
            guidance_stop_step=1.0,
            t0=t0,
            t1=t1,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            use_motion_field=use_motion_field,
            smooth_bg=smooth_bg,
            smooth_bg_strength=smooth_bg_strength,
            seed=seed,
            output_type="numpy",
            negative_prompt=negative_prompt,
            merging_ratio=merging_ratio,
            split_to_chunks=True,
            chunk_size=chunk_size,
        )
        return utils.create_video(
            result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
        )
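
    # Example call (a sketch; the prompt and output path are assumptions):
    #     video = model.process_text2video(
    #         "a panda surfing a wave", video_length=8, fps=2, path="./t2v.mp4"
    #     )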

    def generate_animation(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        motion_field_strength_x: int = 12,
        motion_field_strength_y: int = 12,
        t0: int = 44,
        t1: int = 47,
        n_prompt: str = "",
        chunk_size: int = 8,
        video_length: int = 8,
        merging_ratio: float = 0.0,
        seed: int = 0,
        resolution: int = 512,
        fps: int = 2,
        use_cf_attn: bool = True,
        use_motion_field: bool = True,
        smooth_bg: bool = False,
        smooth_bg_strength: float = 0.4,
        path: str = None,
    ):
        # TODO: only safetensors loading is sketched here; the animation step
        # itself is not implemented yet.
        if is_safetensor and model_link.endswith(".safetensors"):
            pipe = utils.load_safetensors_model(model_link)
        return

    def generate_initial_frames(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        n_prompt: str = "",
        width: int = 512,
        height: int = 512,
        # batch_count: int = 4,
        # batch_size: int = 1,
        cfg_scale: float = 7.0,
        seed: int = 0,
    ):
        print(f">>> prompt: {prompt}, model_link: {model_link}")
        pipe = StableDiffusionPipeline.from_pretrained(model_link)
        # Generate four candidate frames from the same prompt; the seeded
        # generator makes the batch reproducible.
        batch_size = 4
        prompt = [prompt] * batch_size
        negative_prompt = [n_prompt] * batch_size
        generator = torch.Generator().manual_seed(seed)
        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=cfg_scale,
            generator=generator,
        ).images
        return images
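

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original app): it exercises
    # only generate_initial_frames, which runs the plain torch Stable Diffusion
    # pipeline; the prompt and dtype here are assumptions.
    model = ControlAnimationModel(device="cpu", dtype=jnp.float16)
    frames = model.generate_initial_frames(prompt="an astronaut riding a horse")
    print(f"generated {len(frames)} candidate frames")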