import os
import subprocess
import sys
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
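
# Local clone of the LongCat-Video repository and the directory its weights are stored in.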
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
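
# Clone the upstream repository at startup so its `longcat_video` package and configs are available.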
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
            check=True,
            capture_output=True,
        )
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr.decode()}")
        sys.exit(1)
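
# The clone must be on sys.path before the `longcat_video` imports below can resolve.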
sys.path.insert(0, os.path.abspath(REPO_PATH))

from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video
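
# Fetch the model weights from the Hugging Face Hub on first launch.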
if not os.path.exists(CHECKPOINT_DIR):
    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
    try:
        snapshot_download(
            repo_id="meituan-longcat/LongCat-Video",
            local_dir=CHECKPOINT_DIR,
            # `local_dir_use_symlinks` is deprecated in recent huggingface_hub and omitted here.
            ignore_patterns=["*.md", "*.gitattributes", "assets/*"],
        )
        print("Model weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        sys.exit(1)
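
# Runtime configuration: bfloat16 on GPU, float32 on CPU. `pipe` is reset to None
# if model loading fails so the UI can report the failure instead of crashing.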
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models (loaded once at startup) ---")
try:
    # Single-GPU deployment: no context-parallel split of the latent grid.
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    print("Loading tokenizer and text encoder...")
    # torch_dtype applies to model weights only; the tokenizer and scheduler ignore it.
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer")
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)

    print("Loading VAE and Scheduler...")
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler")

    print("Loading DiT model...")
    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        subfolder="dit",
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=True,
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype,
    )

    print("Creating LongCatVideoPipeline...")
    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipe.to(device)

    print("Loading LoRA weights for optional modes...")
    # cfg_step_lora enables the few-step distilled sampling path.
    cfg_step_lora_path = os.path.join(CHECKPOINT_DIR, "lora/cfg_step_lora.safetensors")
    pipe.dit.load_lora(cfg_step_lora_path, "cfg_step_lora")

    # refinement_lora powers the optional second-stage refinement pass.
    refinement_lora_path = os.path.join(CHECKPOINT_DIR, "lora/refinement_lora.safetensors")
    pipe.dit.load_lora(refinement_lora_path, "refinement_lora")

    print("--- Models loaded successfully and are ready for inference. ---")

except Exception as e:
    print("--- FATAL ERROR: Failed to load models. ---")
    print(f"Details: {e}")
    pipe = None


def torch_gc():
    """Helper function to clean up GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
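

# ZeroGPU duration hint: `spaces.GPU` accepts a callable for `duration` that is invoked
# with the same arguments as the decorated function and returns the GPU time budget in
# seconds. The 900 s fallback covers the slow 50-step non-distilled path.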
def check_duration(
    mode,
    prompt,
    neg_prompt,
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    progress,
):
    if use_refine and resolution == "480p":
        return 240
    elif resolution == "720p":
        return 360
    else:
        return 900


@spaces.GPU(duration=check_duration)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Shared generation entry point for the text-to-video and image-to-video tabs;
    `mode` selects the branch.
    """
    if pipe is None:
        raise gr.Error("Models failed to load. Please check the console output for errors and restart the app.")

    generator = torch.Generator(device=device).manual_seed(int(seed))

    # --- Stage 1: base generation ---
    progress(0, desc="Starting Stage 1: Base Generation")

    num_frames = 93
    # Refinement always runs its base pass in distilled mode.
    is_distill = use_distill or use_refine

    if is_distill:
        # Distilled sampling: few steps, no classifier-free guidance, so the
        # negative prompt is unused.
        pipe.dit.enable_loras(["cfg_step_lora"])
        num_inference_steps = 16
        guidance_scale = 1.0
        current_neg_prompt = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        current_neg_prompt = neg_prompt

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    elif mode == "i2v":
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    else:
        raise gr.Error(f"Unknown generation mode: {mode}")

    if is_distill:
        pipe.dit.disable_all_loras()

    torch_gc()
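
    # --- Stage 2 (optional): refinement pass for higher-quality output ---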
    if use_refine:
        progress(0.5, desc="Starting Stage 2: Refinement")

        pipe.dit.enable_loras(["refinement_lora"])
        pipe.dit.enable_bsa()

        # Convert stage-1 float frames (0-1 range) to PIL images for the refiner.
        stage1_video_pil = [(frame * 255).astype(np.uint8) for frame in output]
        stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]

        # In i2v mode the input image also conditions the refinement pass.
        refine_image = Image.fromarray(image) if mode == "i2v" else None

        output = pipe.generate_refine(
            image=refine_image,
            prompt=prompt,
            stage1_video=stage1_video_pil,
            num_cond_frames=1 if mode == "i2v" else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    progress(1.0, desc="Exporting video")

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
        # Refined videos are exported at a higher frame rate.
        fps = 30 if use_refine else 15
        export_to_video(output, temp_video_file.name, fps=fps)
        return temp_video_file.name
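

# --- Gradio UI ---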
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown("[[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]")

    with gr.Tabs() as tabs:
        with gr.TabItem("Image-to-Video", id=1):
            mode_i2v = gr.State("i2v")
            with gr.Row():
                with gr.Column(scale=2):
                    image_i2v = gr.Image(type="numpy", label="Input Image")
                    prompt_i2v = gr.Textbox(label="Prompt", lines=4, placeholder="The cat in the image wags its tail and blinks.")

                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
                        resolution_i2v = gr.Dropdown(label="Resolution", choices=["480p", "720p"], value="480p")
                        seed_i2v = gr.Number(label="Seed", value=42, precision=0)
                        distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                        refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")

                    i2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_i2v = gr.Video(label="Generated Video", interactive=False)

        with gr.TabItem("Text-to-Video", id=0):
            mode_t2v = gr.State("t2v")
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_t2v = gr.Textbox(label="Prompt", lines=4, placeholder="A cinematic shot of a Corgi walking on the beach.")

                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
                        with gr.Row():
                            height_t2v = gr.Slider(label="Height", minimum=256, maximum=1024, value=480, step=64)
                            width_t2v = gr.Slider(label="Width", minimum=256, maximum=1024, value=832, step=64)
                        with gr.Row():
                            seed_t2v = gr.Number(label="Seed", value=42, precision=0)
                            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")

                    t2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_t2v = gr.Video(label="Generated Video", interactive=False)
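
    # Both click handlers share `generate_video`; each input list fills the
    # parameters its tab does not expose with `gr.State(None)` placeholders
    # (image and resolution for t2v; height and width for i2v).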
    t2v_inputs = [
        mode_t2v, prompt_t2v, neg_prompt_t2v,
        gr.State(None),  # image placeholder
        height_t2v, width_t2v,
        gr.State(None),  # resolution placeholder
        seed_t2v, distill_t2v, refine_t2v,
    ]
    t2v_button.click(fn=generate_video, inputs=t2v_inputs, outputs=video_output_t2v)

    i2v_inputs = [
        mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
        gr.State(None), gr.State(None),  # height and width placeholders
        resolution_i2v,
        seed_i2v, distill_i2v, refine_i2v,
    ]
    i2v_button.click(fn=generate_video, inputs=i2v_inputs, outputs=video_output_i2v)


if __name__ == "__main__":
    demo.launch()