import torch
import psutil
import argparse
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
import requests
import uuid
# Device and dtype setup
if torch.cuda.is_available():
    device = "cuda"
    if torch.cuda.get_device_capability()[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
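# Note: bfloat16 is preferred on compute capability 8.0+ (Ampere and newer), which
# supports it natively; older GPUs fall back to float16, and CPU-only runs use float32.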
def filter_kwargs(cls, kwargs):
    import inspect
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return filtered_kwargs
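# Illustrative example (the extra key is hypothetical):
#   filter_kwargs(FlowMatchEulerDiscreteScheduler, {"shift": 3.0, "not_a_param": 1})
# keeps only {"shift": 3.0}, since the scheduler's constructor does not accept "not_a_param".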
def load_transformer_model(model_version, repo_root):
    # Loads a StableAvatar checkpoint into the module-level transformer3d (defined below).
    transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", f"transformer3d-{model_version}.pt")
    print(f"Loading model: {transformer_path}")
    if os.path.exists(transformer_path):
        state_dict = torch.load(transformer_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = transformer3d.load_state_dict(state_dict, strict=False)
        print(f"Model loaded successfully: {transformer_path}")
        print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
        return transformer3d
    else:
        print(f"Error: Model file does not exist: {transformer_path}")
        return None
def download_file(url, local_path):
    """Download a file from a URL to a local path."""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return local_path
    except Exception as e:
        print(f"Error downloading file from {url}: {e}")
        return None
def prepare_input_file(input_path, file_type="image"):
    """Handle local or remote file inputs."""
    if input_path.startswith("http://") or input_path.startswith("https://"):
        ext = ".png" if file_type == "image" else ".wav"
        local_path = os.path.join("temp", f"{uuid.uuid4()}{ext}")
        os.makedirs("temp", exist_ok=True)
        return download_file(input_path, local_path)
    elif os.path.exists(input_path):
        return input_path
    else:
        print(f"Error: {file_type.capitalize()} file {input_path} does not exist")
        return None
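# Note: remote inputs are saved under ./temp with a random name; the extension is inferred
# from file_type, so image URLs are assumed to point at PNG data and audio URLs at WAV data.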
# Initialize model paths: snapshot_download fetches only the listed patterns from the
# Hugging Face Hub into the local cache and returns the local repo root.
REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)
pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")
# Load configuration and models
config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
clip_sample_n_frames = 81
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'))
)
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
).eval()
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
clip_image_encoder = CLIPModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))
).eval()
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)
# Load the default transformer checkpoint (the "square" variant)
load_transformer_model("square", repo_root)
# Initialize scheduler and pipeline
scheduler_dict = {"Flow": FlowMatchEulerDiscreteScheduler}
Chosen_Scheduler = scheduler_dict[sampler_name]
scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
pipeline = WanI2VTalkingInferenceLongPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer3d,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)
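# The pipeline is assembled on CPU here; device placement and offloading are decided
# per call in generate() based on GPU_memory_mode.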
def generate(
    GPU_memory_mode="model_cpu_offload",
    teacache_threshold=0,
    num_skip_start_steps=5,
    image_path=None,
    audio_path=None,
    prompt="",
    # Default negative prompt (Chinese); roughly: "garish colors, overexposed, static,
    # blurry details, subtitles, style, artwork, painting, still frame, overall gray,
    # worst/low quality, JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn
    # hands/face, deformed, disfigured, malformed limbs, fused fingers, frozen frame,
    # cluttered background, three legs, crowded background, walking backwards".
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    width=512,
    height=512,
    guidance_scale=6.0,
    num_inference_steps=50,
    text_guide_scale=3.0,
    audio_guide_scale=5.0,
    motion_frame=25,
    fps=25,
    overlap_window_length=10,
    seed_param=42,
    overlapping_weight_scheme="uniform",
):
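    """Run audio-driven image-to-video generation; returns (output_video_path, seed, message)."""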
    global pipeline, transformer3d
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = seed_param
    # Handle input files
    image_path = prepare_input_file(image_path, "image")
    audio_path = prepare_input_file(audio_path, "audio")
    if not image_path or not audio_path:
        return None, None, "Error: Invalid input file paths"
    # Configure pipeline based on GPU memory mode
    if GPU_memory_mode == "sequential_cpu_offload":
        # Lowest VRAM: offload submodule by submodule, keeping modulation params on device.
        replace_parameters_by_name(transformer3d, ["modulation"], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # Model-level offload plus float8 weight storage (modulation modules excluded).
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation"])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        # "Normal": keep the entire pipeline resident on the device.
        pipeline.to(device=device)
    # Enable TeaCache if specified
    if teacache_threshold > 0:
        coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
        pipeline.transformer.enable_teacache(
            coefficients,
            num_inference_steps,
            teacache_threshold,
            num_skip_start_steps=num_skip_start_steps,
        )
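    # TeaCache (roughly): reuse cached transformer outputs across denoising steps when the
    # timestep-embedding-modulated inputs change little; teacache_threshold controls how
    # aggressively steps are skipped, trading some fidelity for speed.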
    # Perform inference
    with torch.no_grad():
        # Clamp the frame count to 1 + a multiple of the VAE's temporal compression ratio.
        video_length = (
            int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1
            if clip_sample_n_frames != 1
            else 1
        )
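        # Worked example, assuming the Wan VAE's temporal_compression_ratio is 4:
        # clip_sample_n_frames=81 -> (80 // 4) * 4 + 1 = 81, i.e. already aligned.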
        input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
        # wav2vec2-base-960h expects 16 kHz audio, so resample on load.
        sr = 16000
        vocal_input, sample_rate = librosa.load(audio_path, sr=sr)
        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
            text_guide_scale=text_guide_scale,
            audio_guide_scale=audio_guide_scale,
            vocal_input_values=vocal_input,
            motion_frame=motion_frame,
            fps=fps,
            sr=sr,
            cond_file_path=image_path,
            overlap_window_length=overlap_window_length,
            seed=seed,
            overlapping_weight_scheme=overlapping_weight_scheme,
        ).videos
    os.makedirs("outputs", exist_ok=True)
    video_path = os.path.join("outputs", f"{timestamp}.mp4")
    save_videos_grid(sample, video_path, fps=fps)
    # Mux the input audio onto the silent video (copy the video stream, encode audio as AAC).
    output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
        "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
        output_video_with_audio,
    ], check=True)
    return output_video_with_audio, seed, f"Generated {output_video_with_audio}"
def main():
    parser = argparse.ArgumentParser(description="StableAvatar Inference Script")
    parser.add_argument("--prompt", type=str, default="", help="Text prompt for generation")
    parser.add_argument("--seed", type=int, default=42, help="Random seed; any negative value picks a random seed")
    parser.add_argument("--input_image", type=str, required=True, help="Path or URL to input image (e.g., ./image.png or https://example.com/image.png)")
    parser.add_argument("--input_audio", type=str, required=True, help="Path or URL to input audio (e.g., ./audio.wav or https://example.com/audio.wav)")
    parser.add_argument("--GPU_memory_mode", type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"], help="GPU memory mode")
    parser.add_argument("--teacache_threshold", type=float, default=0, help="TeaCache threshold; 0 disables TeaCache")
    parser.add_argument("--num_skip_start_steps", type=int, default=5, help="Number of start steps to skip")
    # Same Chinese negative prompt as the generate() default; see the English gloss there.
    parser.add_argument("--negative_prompt", type=str, default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", help="Negative prompt")
    parser.add_argument("--width", type=int, default=512, help="Output video width")
    parser.add_argument("--height", type=int, default=512, help="Output video height")
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--text_guide_scale", type=float, default=3.0, help="Text guidance scale")
    parser.add_argument("--audio_guide_scale", type=float, default=5.0, help="Audio guidance scale")
    parser.add_argument("--motion_frame", type=int, default=25, help="Motion frame")
    parser.add_argument("--fps", type=int, default=25, help="Frames per second")
    parser.add_argument("--overlap_window_length", type=int, default=10, help="Overlap window length")
    parser.add_argument("--overlapping_weight_scheme", type=str, default="uniform", choices=["uniform", "log"], help="Overlapping weight scheme")
    args = parser.parse_args()
    video_path, seed, message = generate(
        GPU_memory_mode=args.GPU_memory_mode,
        teacache_threshold=args.teacache_threshold,
        num_skip_start_steps=args.num_skip_start_steps,
        image_path=args.input_image,
        audio_path=args.input_audio,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        width=args.width,
        height=args.height,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        text_guide_scale=args.text_guide_scale,
        audio_guide_scale=args.audio_guide_scale,
        motion_frame=args.motion_frame,
        fps=args.fps,
        overlap_window_length=args.overlap_window_length,
        seed_param=args.seed,
        overlapping_weight_scheme=args.overlapping_weight_scheme,
    )
    if video_path:
        print(f"{message}\nSeed: {seed}")
    else:
        print("Generation failed.")


if __name__ == "__main__":
    main()
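# Example invocation (file paths are illustrative; the script name is assumed):
#   python inference.py \
#       --input_image ./examples/reference.png \
#       --input_audio ./examples/speech.wav \
#       --prompt "A person talking naturally" \
#       --GPU_memory_mode model_cpu_offload \
#       --seed 42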