# StableAvatar / app.py
# Last change by dangthr: "Update app.py" (commit 160e694, 13.8 kB).
import torch
import psutil
import argparse
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
import requests
import uuid
# Device and dtype setup: prefer CUDA when present. On GPUs whose compute
# capability major version is >= 8 use bfloat16, on older GPUs float16;
# on CPU fall back to full float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    dtype = torch.float32
elif torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
def filter_kwargs(cls, kwargs):
    """Return the subset of *kwargs* whose keys are parameters of ``cls.__init__``.

    Lets a config mapping be splatted into a constructor that would reject
    unknown keyword arguments. ``self`` and ``cls`` are never forwarded.

    Args:
        cls: Class whose ``__init__`` signature is inspected.
        kwargs: Mapping of candidate keyword arguments.

    Returns:
        A new dict, in the original key order, containing only accepted entries.
    """
    import inspect

    accepted = inspect.signature(cls.__init__).parameters
    return {
        name: value
        for name, value in kwargs.items()
        if name in accepted and name not in ("self", "cls")
    }
def load_transformer_model(model_version, repo_root, transformer=None):
    """Load a StableAvatar transformer checkpoint into an existing model.

    Reads ``<repo_root>/StableAvatar-1.3B/transformer3d-<model_version>.pt`` and
    copies its weights into the target module with ``strict=False``, printing
    how many keys were missing or unexpected.

    Args:
        model_version: Checkpoint variant suffix, e.g. ``"square"``.
        repo_root: Root directory of the downloaded StableAvatar snapshot.
        transformer: Module to load the weights into. Defaults to the
            module-level ``transformer3d`` so existing two-argument calls keep
            working.

    Returns:
        The updated module, or ``None`` if the checkpoint file does not exist.
    """
    transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", f"transformer3d-{model_version}.pt")
    print(f"Loading model: {transformer_path}")
    if not os.path.isfile(transformer_path):
        print(f"Error: Model file does not exist: {transformer_path}")
        return None

    # Resolve the target only once we know there is something to load into it.
    model = transformer3d if transformer is None else transformer

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    state_dict = torch.load(transformer_path, map_location="cpu")
    # Some checkpoints wrap the weights in a {"state_dict": ...} container.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Model loaded successfully: {transformer_path}")
    print(f"Missing keys: {len(missing)}; Unexpected keys: {len(unexpected)}")
    return model
def download_file(url, local_path, timeout=60):
    """Download *url* to *local_path*, streaming the response body in chunks.

    Args:
        url: HTTP(S) URL to fetch.
        local_path: Destination file path; overwritten if it already exists.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes, so a stalled server cannot hang the script forever.

    Returns:
        *local_path* on success. On a network or file error, returns ``None``
        after printing the error and removing any partially written file.
    """
    started_writing = False
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                started_writing = True
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return local_path
    except (requests.RequestException, OSError) as e:
        print(f"Error downloading file from {url}: {e}")
        if started_writing:
            # Do not leave a truncated file for a later step to pick up.
            try:
                os.remove(local_path)
            except OSError:
                pass
        return None
def _remote_extension(url, file_type):
    """Return the URL path's extension if it is a known one for *file_type*, else the default.

    Images default to ``.png`` and everything else to ``.wav``. Query strings
    and fragments are ignored.
    """
    from urllib.parse import urlparse

    ext = os.path.splitext(urlparse(url).path)[1].lower()
    if file_type == "image":
        known, default = (".png", ".jpg", ".jpeg", ".webp", ".bmp"), ".png"
    else:
        known, default = (".wav", ".mp3", ".flac", ".ogg", ".m4a"), ".wav"
    return ext if ext in known else default


def prepare_input_file(input_path, file_type="image"):
    """Resolve a local path or HTTP(S) URL to a readable local file.

    A remote input is downloaded into ``temp/`` under a random name. The name
    keeps the URL's extension when it is a recognized one, and otherwise falls
    back to ``.png`` for images or ``.wav`` for audio. A local input is
    returned unchanged.

    Args:
        input_path: Local file path or ``http://``/``https://`` URL. ``None``
            or ``""`` is treated as missing.
        file_type: ``"image"`` or ``"audio"``. It selects the download
            extension and appears in error messages.

    Returns:
        A local file path. Returns ``None`` if the input is missing, is not a
        regular file, or could not be downloaded.
    """
    if not input_path:
        print(f"Error: No {file_type} file was provided")
        return None
    if input_path.startswith(("http://", "https://")):
        os.makedirs("temp", exist_ok=True)
        local_path = os.path.join("temp", f"{uuid.uuid4()}{_remote_extension(input_path, file_type)}")
        return download_file(input_path, local_path)
    if os.path.isfile(input_path):
        return input_path
    print(f"Error: {file_type.capitalize()} file {input_path} does not exist")
    return None
# Initialize model paths: fetch (or reuse the local Hugging Face cache of)
# only the parts of the StableAvatar repository this script needs.
REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)
# Base Wan2.1 model folder, wav2vec2 folder, and the audio-separator ONNX file
# (NOTE(review): the latter is not referenced elsewhere in this file).
pretrained_model_name_or_path, pretrained_wav2vec_path, audio_separator_model_file = (
    os.path.join(repo_root, component)
    for component in ("Wan2.1-Fun-V1.1-1.3B-InP", "wav2vec2-base-960h", "Kim_Vocal_2.onnx")
)
# Load configuration and models
config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
clip_sample_n_frames = 81


def _component_dir(section, key, default):
    """Return a Wan sub-model folder, honouring a ``key`` override in ``config[section]``."""
    return os.path.join(pretrained_model_name_or_path, config[section].get(key, default))


def _component_kwargs(section):
    """Return a plain-container copy of ``config[section]`` for a model constructor."""
    return OmegaConf.to_container(config[section])


tokenizer = AutoTokenizer.from_pretrained(
    _component_dir("text_encoder_kwargs", "tokenizer_subpath", "tokenizer")
)
text_encoder = WanT5EncoderModel.from_pretrained(
    _component_dir("text_encoder_kwargs", "text_encoder_subpath", "text_encoder"),
    additional_kwargs=_component_kwargs("text_encoder_kwargs"),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
).eval()
vae = AutoencoderKLWan.from_pretrained(
    _component_dir("vae_kwargs", "vae_subpath", "vae"),
    additional_kwargs=_component_kwargs("vae_kwargs"),
)
# The wav2vec2 model is placed on the CPU at load time.
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
clip_image_encoder = CLIPModel.from_pretrained(
    _component_dir("image_encoder_kwargs", "image_encoder_subpath", "image_encoder")
).eval()
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    _component_dir("transformer_additional_kwargs", "transformer_subpath", "transformer"),
    transformer_additional_kwargs=_component_kwargs("transformer_additional_kwargs"),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)
# Load the default ("square") StableAvatar weights on top of the base Wan
# transformer loaded above. If the checkpoint is missing, generation would
# silently run on those base weights, so treat it as fatal instead.
if load_transformer_model("square", repo_root) is None:
    raise FileNotFoundError(
        "StableAvatar transformer checkpoint 'transformer3d-square.pt' was not found under "
        f"{os.path.join(repo_root, 'StableAvatar-1.3B')}"
    )
# Initialize scheduler and pipeline.
# Only the flow-matching Euler scheduler is registered. Keys in the YAML's
# scheduler_kwargs that its constructor does not accept are dropped.
scheduler_dict = dict(Flow=FlowMatchEulerDiscreteScheduler)
Choosen_Scheduler = scheduler_dict[sampler_name]
scheduler = Choosen_Scheduler(
    **filter_kwargs(
        Choosen_Scheduler,
        OmegaConf.to_container(config["scheduler_kwargs"]),
    )
)
# Assemble the image-to-video talking pipeline from the components loaded above.
pipeline = WanI2VTalkingInferenceLongPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer3d,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)
def _configure_memory_mode(mode):
    """Apply the requested GPU memory strategy to the global ``pipeline``.

    ``"sequential_cpu_offload"``, ``"model_cpu_offload_and_qfloat8"`` and
    ``"model_cpu_offload"`` enable the corresponding offloading. Any other
    value (e.g. ``"Normal"``) moves the whole pipeline to ``device``.
    """
    if mode == "sequential_cpu_offload":
        replace_parameters_by_name(transformer3d, ["modulation"], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif mode == "model_cpu_offload_and_qfloat8":
        # NOTE(review): this converts transformer weights in place on every
        # call; confirm the fp8 helpers tolerate being applied repeatedly.
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation"])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)


def _mux_audio(video_path, audio_path, output_path):
    """Write *output_path*: the video stream of *video_path* copied as-is, plus *audio_path* encoded as AAC.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
        "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
        output_path,
    ], check=True)


def generate(
    GPU_memory_mode="model_cpu_offload",
    teacache_threshold=0,
    num_skip_start_steps=5,
    image_path=None,
    audio_path=None,
    prompt="",
    # Default negative prompt (Chinese): over-saturated colors, overexposure,
    # static, blurry details, subtitles, ..., low quality, JPEG artifacts,
    # extra fingers, badly drawn hands/faces, deformed limbs, fused fingers,
    # cluttered background, three legs, crowded background, walking backwards.
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    width=512,
    height=512,
    guidance_scale=6.0,
    num_inference_steps=50,
    text_guide_scale=3.0,
    audio_guide_scale=5.0,
    motion_frame=25,
    fps=25,
    overlap_window_length=10,
    seed_param=42,
    overlapping_weight_scheme="uniform"
):
    """Generate an audio-driven video from a reference image and an audio clip.

    Args:
        GPU_memory_mode: ``"sequential_cpu_offload"``,
            ``"model_cpu_offload_and_qfloat8"``, ``"model_cpu_offload"``, or
            anything else to move the whole pipeline onto ``device``.
        teacache_threshold: Enables TeaCache on the transformer when ``> 0``.
        num_skip_start_steps: Forwarded to ``enable_teacache``.
        image_path: Local path or http(s) URL of the reference image.
        audio_path: Local path or http(s) URL of the driving audio. It is
            loaded at 16 kHz and also muxed into the output video.
        prompt, negative_prompt: Text conditioning for the pipeline.
        width, height: Output frame size in pixels.
        guidance_scale, text_guide_scale, audio_guide_scale: Guidance weights
            forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        motion_frame, overlap_window_length, overlapping_weight_scheme:
            Forwarded unchanged to the long-video pipeline.
        fps: Forwarded to the pipeline and used when saving the video.
        seed_param: Random seed; a negative value draws a random int32 seed.

    Returns:
        ``(video_with_audio_path, seed, message)`` on success, or
        ``(None, None, error_message)`` if an input could not be resolved.
        Files that were downloaded from URLs are deleted before returning.

    Raises:
        subprocess.CalledProcessError: if muxing the audio with ffmpeg fails.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    seed = random.randint(0, np.iinfo(np.int32).max) if seed_param < 0 else seed_param

    # Handle input files (remote inputs are downloaded into temp/).
    local_image = prepare_input_file(image_path, "image")
    local_audio = prepare_input_file(audio_path, "audio")
    downloaded = [
        local
        for local, given in ((local_image, image_path), (local_audio, audio_path))
        if local and local != given
    ]
    try:
        if not local_image or not local_audio:
            return None, None, "Error: Invalid input file paths"

        _configure_memory_mode(GPU_memory_mode)

        if teacache_threshold > 0:
            coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
            pipeline.transformer.enable_teacache(
                coefficients,
                num_inference_steps,
                teacache_threshold,
                num_skip_start_steps=num_skip_start_steps,
            )

        with torch.no_grad():
            if clip_sample_n_frames == 1:
                video_length = 1
            else:
                # Round down to 1 + a multiple of the VAE temporal compression ratio.
                ratio = vae.config.temporal_compression_ratio
                video_length = int((clip_sample_n_frames - 1) // ratio * ratio) + 1
            input_video, input_video_mask, clip_image = get_image_to_video_latent(
                local_image, None, video_length=video_length, sample_size=[height, width]
            )
            sr = 16000
            vocal_input, _ = librosa.load(local_audio, sr=sr)
            sample = pipeline(
                prompt,
                num_frames=video_length,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                generator=torch.Generator().manual_seed(seed),
                num_inference_steps=num_inference_steps,
                video=input_video,
                mask_video=input_video_mask,
                clip_image=clip_image,
                text_guide_scale=text_guide_scale,
                audio_guide_scale=audio_guide_scale,
                vocal_input_values=vocal_input,
                motion_frame=motion_frame,
                fps=fps,
                sr=sr,
                cond_file_path=local_image,
                overlap_window_length=overlap_window_length,
                seed=seed,
                overlapping_weight_scheme=overlapping_weight_scheme,
            ).videos

        os.makedirs("outputs", exist_ok=True)
        video_path = os.path.join("outputs", f"{timestamp}.mp4")
        save_videos_grid(sample, video_path, fps=fps)
        output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
        _mux_audio(video_path, local_audio, output_video_with_audio)
        # Report the file actually returned (the one with audio), not the silent one.
        return output_video_with_audio, seed, f"Generated {output_video_with_audio}"
    finally:
        for path in downloaded:
            try:
                os.remove(path)
            except OSError:
                pass
def _build_parser():
    """Create the command-line argument parser for this inference script."""
    parser = argparse.ArgumentParser(description="StableAvatar Inference Script")
    parser.add_argument("--prompt", type=str, default="", help="Text prompt for generation")
    parser.add_argument("--seed", type=int, default=42, help="Random seed, -1 for random")
    parser.add_argument("--input_image", type=str, required=True, help="Path or URL to input image (e.g., ./image.png or https://example.com/image.png)")
    parser.add_argument("--input_audio", type=str, required=True, help="Path or URL to input audio (e.g., ./audio.wav or https://example.com/audio.wav)")
    parser.add_argument("--GPU_memory_mode", type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"], help="GPU memory mode")
    parser.add_argument("--teacache_threshold", type=float, default=0, help="TeaCache threshold, 0 to disable")
    parser.add_argument("--num_skip_start_steps", type=int, default=5, help="Number of start steps to skip")
    parser.add_argument("--negative_prompt", type=str, default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", help="Negative prompt")
    parser.add_argument("--width", type=int, default=512, help="Output video width")
    parser.add_argument("--height", type=int, default=512, help="Output video height")
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--text_guide_scale", type=float, default=3.0, help="Text guidance scale")
    parser.add_argument("--audio_guide_scale", type=float, default=5.0, help="Audio guidance scale")
    parser.add_argument("--motion_frame", type=int, default=25, help="Motion frame")
    parser.add_argument("--fps", type=int, default=25, help="Frames per second")
    parser.add_argument("--overlap_window_length", type=int, default=10, help="Overlap window length")
    parser.add_argument("--overlapping_weight_scheme", type=str, default="uniform", choices=["uniform", "log"], help="Overlapping weight scheme")
    return parser


def main():
    """Parse CLI arguments, run ``generate()``, and report the outcome.

    On failure, prints the reason reported by ``generate()`` and exits with
    status 1, so calling shell scripts can detect the error.
    """
    args = _build_parser().parse_args()
    video_path, seed, message = generate(
        GPU_memory_mode=args.GPU_memory_mode,
        teacache_threshold=args.teacache_threshold,
        num_skip_start_steps=args.num_skip_start_steps,
        image_path=args.input_image,
        audio_path=args.input_audio,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        width=args.width,
        height=args.height,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        text_guide_scale=args.text_guide_scale,
        audio_guide_scale=args.audio_guide_scale,
        motion_frame=args.motion_frame,
        fps=args.fps,
        overlap_window_length=args.overlap_window_length,
        seed_param=args.seed,
        overlapping_weight_scheme=args.overlapping_weight_scheme,
    )
    if video_path:
        print(f"{message}\nSeed: {seed}")
    else:
        print(f"Generation failed. {message}")
        raise SystemExit(1)


if __name__ == "__main__":
    main()