# admaker/core/video_generator.py
import torch
from pathlib import Path
from typing import List, Optional

from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
from huggingface_hub import hf_hub_download
from PIL import Image

# ---------------- MODEL CONFIG ----------------
MODEL_REPO = "ByteDance/AnimateDiff-Lightning"
# Use the diffusers-format checkpoint: the *_comfyui variant is keyed for
# ComfyUI and will not load cleanly into a diffusers MotionAdapter.
MODEL_FILENAME = "animatediff_lightning_8step_diffusers.safetensors"
MODEL_DIR = Path("/tmp/models/animatediff_lightning")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ---------------- MODEL DOWNLOAD ----------------
def download_model() -> Path:
    model_path = MODEL_DIR / MODEL_FILENAME
    if not model_path.exists():
        print("[VideoGen] Downloading AnimateDiff Lightning 8-step...")
        # hf_hub_download returns the local file path as a str.
        model_path = Path(
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                local_dir=str(MODEL_DIR),
            )
        )
        print(f"[VideoGen] Model downloaded to: {model_path}")
    else:
        print("[VideoGen] AnimateDiff model already exists.")
    return model_path
# ---------------- PIPELINE LOAD ----------------
def load_pipeline() -> AnimateDiffPipeline:
    model_path = download_model()
    print("[VideoGen] Loading AnimateDiff pipeline...")
    # Keep adapter and UNet in the same dtype to avoid a fp16/fp32 mismatch.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    adapter = MotionAdapter.from_single_file(str(model_path), torch_dtype=dtype)
    pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism",
        motion_adapter=adapter,
        torch_dtype=dtype,
    )
    # AnimateDiff Lightning is distilled against a trailing-timestep Euler
    # scheduler (per the ByteDance/AnimateDiff-Lightning model card).
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="linear",
    )
    if torch.cuda.is_available():
        pipe.to("cuda")
    print("[VideoGen] AnimateDiff ready.")
    return pipe
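
# Optional VRAM savers for small GPUs (stock diffusers APIs, not enabled
# here); call them on the pipeline returned by load_pipeline():
#     pipe.enable_vae_slicing()
#     pipe.enable_model_cpu_offload()  # use instead of pipe.to("cuda")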
# ---------------- GLOBAL PIPELINE CACHE ----------------
pipe: AnimateDiffPipeline | None = None
# ---------------- VIDEO GENERATION ----------------
def generate_video(
    keyframe_images: List[Image.Image],
    seed: Optional[int] = None,
    num_frames: int = 16,
) -> List[Image.Image]:
    """
    Generates a short video by interpolating between input keyframe images.

    Args:
        keyframe_images (List[PIL.Image]): PIL images representing keyframes.
        seed (int, optional): Seed for deterministic generation.
        num_frames (int): Total number of frames in the generated video.

    Returns:
        List[PIL.Image]: Interpolated video frames.
    """
    global pipe
    if pipe is None:
        pipe = load_pipeline()

    if len(keyframe_images) < 2:
        raise ValueError("At least 2 keyframe images are required to generate a video.")

    print(f"[VideoGen] Generating video from {len(keyframe_images)} keyframes, "
          f"{num_frames} frames, seed={seed}")

    generator = None
    if seed is not None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device).manual_seed(seed)

    # NOTE: the stock diffusers AnimateDiffPipeline is text-to-video and does
    # not accept `init_images`; this call assumes a patched pipeline that
    # supports keyframe conditioning.
    output = pipe(
        init_images=keyframe_images,
        num_frames=num_frames,
        guidance_scale=1.0,     # Lightning checkpoints are distilled for CFG ~1
        num_inference_steps=8,  # must match the 8-step distilled checkpoint
        generator=generator,
    )
    # Pipelines return a batch of videos; frames[0] is the first (only) one.
    video_frames = output.frames[0]
    print("[VideoGen] Video generated successfully.")
    return video_frames
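
# ---------------- LOCAL SMOKE TEST ----------------
# Minimal usage sketch, not exercised by the app. The keyframe paths and the
# output location are placeholders; export_to_gif is the stock diffusers helper.
if __name__ == "__main__":
    from diffusers.utils import export_to_gif

    keyframes = [Image.open("keyframe_a.png"), Image.open("keyframe_b.png")]
    frames = generate_video(keyframes, seed=42, num_frames=16)
    export_to_gif(frames, "/tmp/admaker_preview.gif")
    print("Saved preview to /tmp/admaker_preview.gif")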