"""Gradio demo for Rolling Forcing text-to-video generation.

At import time this script installs ``flash-attn``. ``main()`` then downloads
the Wan2.1-T2V-1.3B base model and the Rolling Forcing checkpoint from the
Hugging Face Hub and serves a single-prompt video generation UI.
"""

# Kept as the very first import: the Hugging Face `spaces` (ZeroGPU) package
# expects to be imported before torch initialises CUDA.
import spaces

import os
import subprocess
import sys

# Install flash-attn before project modules that may import it are loaded.
# The current environment is preserved; passing only the one extra variable
# would wipe PATH, HOME, virtualenv settings, and so on. pip is run through this
# interpreter so the package lands in the environment that will import it.
# Failure is reported but not fatal, keeping the install best-effort.
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
if _pip_result.returncode != 0:
    print(
        f"Warning: 'pip install flash-attn' exited with code {_pip_result.returncode}.",
        file=sys.stderr,
    )

import argparse
import re
import time
import uuid
from collections import OrderedDict
from typing import Mapping, Optional

import gradio as gr
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from torchvision.io import write_video

from pipeline import CausalInferencePipeline

# -----------------------------
# Globals (loaded once per process)
# -----------------------------
_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None

# Local directory the Wan2.1 base model is expected in (populated by main()).
_WAN_DIR = os.path.join("wan_models", "Wan2.1-T2V-1.3B")

# Key prefix added by FSDP wrapping that must be removed before loading.
_FSDP_PREFIX = "_fsdp_wrapped_module."


def _ensure_gpu() -> None:
    """Raise ``gr.Error`` unless CUDA is available, then select GPU 0."""
    if not torch.cuda.is_available():
        raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
    # Bind to GPU:0 by default
    torch.cuda.set_device(0)


def _check_wan_models() -> None:
    """Raise a friendly ``gr.Error`` if the Wan2.1 base model directory is missing."""
    if not os.path.isdir(_WAN_DIR):
        raise gr.Error(
            "Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
            "Please download it first, e.g.:\n"
            "huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
        )


def _strip_fsdp_prefix(state_dict: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
    """Return *state_dict* with every ``_fsdp_wrapped_module.`` occurrence removed from its keys.

    The input is returned unchanged (including any ``_metadata`` attribute)
    when no key contains the prefix.
    """
    if not any(_FSDP_PREFIX in key for key in state_dict):
        return state_dict
    return OrderedDict((key.replace(_FSDP_PREFIX, ""), value) for key, value in state_dict.items())


def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
    """Build the Rolling Forcing pipeline once per process and cache it.

    Args:
        config_path: YAML config merged on top of ``configs/default_config.yaml``
            (values in ``config_path`` take precedence).
        checkpoint_path: Optional ``.pt`` checkpoint whose generator weights are
            loaded into ``pipeline.generator``. If it does not exist, a warning
            is printed and the pipeline keeps its base weights.
        use_ema: Prefer the ``generator_ema`` weights when the checkpoint has them.

    Returns:
        The cached pipeline, moved to ``cuda:0`` in bfloat16 and put in eval mode.

    Raises:
        gr.Error: If no CUDA GPU is available or the Wan2.1 base model is missing.
    """
    global _PIPELINE, _DEVICE
    if _PIPELINE is not None:
        return _PIPELINE

    _ensure_gpu()
    # Checked before the pipeline is built so that a missing download surfaces
    # as the friendly message. Otherwise the user could see a less clear
    # failure from the pipeline constructor, which presumably reads these
    # weights (TODO confirm).
    _check_wan_models()
    _DEVICE = torch.device("cuda:0")

    # Load and merge configs
    config = OmegaConf.load(config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    pipeline = CausalInferencePipeline(config, device=_DEVICE)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if use_ema and "generator_ema" in state_dict:
            weights = state_dict["generator_ema"]
        else:
            weights = state_dict.get("generator", state_dict)
        # The FSDP prefix is stripped from both the EMA and non-EMA weights.
        # With strict=False, keys still carrying the prefix would be skipped
        # silently instead of raising.
        result = pipeline.generator.load_state_dict(_strip_fsdp_prefix(weights), strict=False)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing or unexpected:
            print(
                f"Checkpoint loaded with {len(missing)} missing and "
                f"{len(unexpected)} unexpected keys (strict=False)."
            )
    elif checkpoint_path:
        print(
            f"Warning: checkpoint '{checkpoint_path}' not found; running with base weights only.",
            file=sys.stderr,
        )

    # The codebase assumes bfloat16 on GPU
    pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
    pipeline.eval()

    _PIPELINE = pipeline
    return _PIPELINE


def _filename_stub(prompt: str, max_len: int = 60) -> str:
    """Turn the first *max_len* characters of *prompt* into a filesystem-safe stub.

    Each run of characters other than word characters and ``-`` becomes a single
    ``_``, which removes path separators, newlines, ``:``, ``*`` and similar
    characters. Returns ``"video"`` if nothing usable remains.
    """
    stub = re.sub(r"[^\w\-]+", "_", prompt.strip()[:max_len]).strip("_")
    return stub or "video"


def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
    """Create the Gradio callback that turns a prompt into a saved ``.mp4`` file.

    ``output_dir`` is created immediately. The pipeline itself is loaded lazily
    on the first call, inside the ``@spaces.GPU`` function.

    Returns:
        ``predict(prompt, num_frames) -> str``, the path of the written video.
    """
    os.makedirs(output_dir, exist_ok=True)

    @spaces.GPU
    def predict(prompt: str, num_frames: int) -> str:
        """Generate a video for *prompt* with *num_frames* latent frames and return its path.

        Raises:
            gr.Error: For an empty prompt, or if ``num_frames`` is not a
                multiple of 3 in [21, 252].
        """
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty text prompt.")
        num_frames = int(num_frames)
        if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
            raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")

        pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
        prompt = prompt.strip()

        # Noise shape [1, num_frames, 16, 60, 104]. This is presumably
        # (batch, latent frames, latent channels, height, width), matching the
        # [B, T, C, H, W] layout of the output below (TODO confirm).
        noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)

        # inference_mode already disables autograd for this call, so no global
        # grad-mode switch is needed.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            video = pipeline.inference_rolling_forcing(
                noise=noise,
                text_prompts=[prompt],
                return_latents=False,
                initial_latent=None,
            )

        # video: [B=1, T, C, H, W] in [0,1] -> [T, H, W, C] uint8 for write_video.
        # Scaling is done in float32 (the output may be bfloat16), and values
        # are rounded to the nearest level rather than truncated.
        frames = rearrange(video, "b t c h w -> b t h w c")[0]
        frames_uint8 = (frames.float() * 255.0).round().clamp(0, 255).to(torch.uint8).cpu()

        # The random suffix avoids overwriting when the same prompt is
        # submitted twice within one second.
        filename = f"{_filename_stub(prompt)}_{int(time.time())}_{uuid.uuid4().hex[:8]}.mp4"
        filepath = os.path.join(output_dir, filename)
        write_video(filepath, frames_uint8, fps=16)
        print(f"Saved generated video to {filepath}")
        return filepath

    return predict


def main() -> None:
    """Parse CLI args, fetch model weights from the Hub, and launch the Gradio UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
                        help='Path to the model config')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
                        help='Path to rolling forcing checkpoint (.pt). If missing, will run with base weights only if available.')
    parser.add_argument('--output_dir', type=str, default='videos/gradio',
                        help='Where to save generated videos')
    parser.add_argument('--no_ema', action='store_true',
                        help='Disable EMA weights when loading checkpoint')
    args = parser.parse_args()

    # Download checkpoint from HuggingFace if not present
    # 1️⃣ Equivalent to:
    # huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
    wan_model_dir = snapshot_download(
        repo_id="Wan-AI/Wan2.1-T2V-1.3B",
        local_dir="wan_models/Wan2.1-T2V-1.3B",
        local_dir_use_symlinks=False,  # same as --local-dir-use-symlinks False
    )
    print("Wan model downloaded to:", wan_model_dir)

    # 2️⃣ Equivalent to:
    # huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
    rolling_ckpt_path = hf_hub_download(
        repo_id="TencentARC/RollingForcing",
        filename="checkpoints/rolling_forcing_dmd.pt",
        local_dir=".",  # where to store it
        local_dir_use_symlinks=False,
    )
    print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)

    predict = build_predict(
        config_path=args.config_path,
        checkpoint_path=args.checkpoint_path,
        output_dir=args.output_dir,
        use_ema=not args.no_ema,
    )

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
            gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
        ],
        outputs=gr.Video(label="Generated Video", format="mp4"),
        title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
        description=(
            "Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
            "**Note:** although Rolling Forcing generates videos autoregressively, the current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
            "\n"
            "If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
            "You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
        ),
        allow_flagging='never',
    )

    try:
        # Gradio <= 3.x
        demo.queue(concurrency_count=1, max_size=2)
    except TypeError:
        # Gradio >= 4.x
        demo.queue(max_size=2)
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()