# admaker / pipeline/pipeline.py
# Last updated by karthikeya1212 in commit 813764f ("Update pipeline/pipeline.py", verified).
import os
from pathlib import Path
# CACHE PATCH BLOCK: place FIRST in pipeline.py!
# Everything below redirects model/cache downloads (Hugging Face, torch, XDG)
# into one writable directory under /tmp. It must run before the model
# libraries are imported -- presumably because they resolve their cache
# paths at import time.
HF_CACHE_DIR = Path("/tmp/hf_cache")
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Point every cache-related variable -- and HOME/TMPDIR, so anything that
# derives paths from them also lands here -- at HF_CACHE_DIR.
os.environ.update({
    "HF_HOME": str(HF_CACHE_DIR),
    "HF_HUB_CACHE": str(HF_CACHE_DIR),
    "DIFFUSERS_CACHE": str(HF_CACHE_DIR),
    "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
    "XDG_CACHE_HOME": str(HF_CACHE_DIR),
    "HF_DATASETS_CACHE": str(HF_CACHE_DIR),
    "HF_MODULES_CACHE": str(HF_CACHE_DIR),
    "TMPDIR": str(HF_CACHE_DIR),
    "CACHE_DIR": str(HF_CACHE_DIR),
    "TORCH_HOME": str(HF_CACHE_DIR),
    "HOME": str(HF_CACHE_DIR),
})

import os.path

# Save the real implementation exactly once, so re-executing this module
# never makes the wrapper end up calling itself.
if not hasattr(os.path, "expanduser_original"):
    os.path.expanduser_original = os.path.expanduser

# Absolute cache roots that are remapped onto HF_CACHE_DIR (sub-path kept).
_REDIRECTED_CACHE_ROOTS = ("/root/.cache", "/.cache")


def _in_cache_dir(rest):
    """Return HF_CACHE_DIR joined with the remainder *rest*, as a str."""
    # An absolute tail would make os.path.join discard the base directory.
    rest = rest.lstrip("/")
    return os.path.join(str(HF_CACHE_DIR), rest) if rest else str(HF_CACHE_DIR)


def safe_expanduser(path):
    """Drop-in replacement for ``os.path.expanduser`` that keeps paths in HF_CACHE_DIR.

    Mapping:
        - ``~``, ``~/rest`` and ``~user/rest`` -> ``HF_CACHE_DIR/rest``.
        - ``/root/.cache[/rest]`` and ``/.cache[/rest]`` -> ``HF_CACHE_DIR[/rest]``.
        - Anything else goes to the real ``expanduser``. Bytes ``~`` paths are
          then expanded against HOME (set to HF_CACHE_DIR above); non-``~``
          paths come back unchanged.

    Args:
        path: str, bytes or os.PathLike -- the same inputs the stdlib
            function accepts.

    Returns:
        The expanded path: str for str/PathLike input, bytes for bytes input.
    """
    fs_path = os.fspath(path)  # accept pathlib.Path and other PathLike objects
    if isinstance(fs_path, bytes):
        return os.path.expanduser_original(fs_path)
    if fs_path.startswith("~"):
        # Keep everything after the "~" / "~user" component instead of
        # collapsing every home-relative path onto the cache root.
        _, _, rest = fs_path.partition("/")
        return _in_cache_dir(rest)
    for root in _REDIRECTED_CACHE_ROOTS:
        # Match whole path components only: "/root/.cachefoo" is left alone.
        if fs_path == root or fs_path.startswith(root + "/"):
            return _in_cache_dir(fs_path[len(root):])
    return os.path.expanduser_original(fs_path)


os.path.expanduser = safe_expanduser
import asyncio
import logging
import core.script_gen as script_gen
import core.story_script as story_script
import core.image_generator as image_gen
# import core.video_gen as video_gen
# import core.music_gen as music_gen
# import core.assemble as assemble
# Root logging setup for the pipeline: INFO and above, each record prefixed
# with its timestamp and level name. (Per stdlib semantics, basicConfig is a
# no-op if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
async def run_pipeline(task: dict, confirmation_event: asyncio.Event):
    """
    Execute the ad-generation workflow for one task.

    Stages:
        1. Script generation from ``task["idea"]``.
        2. Wait until ``confirmation_event`` is set (user approval of the script).
        3. Story generation from the approved script.
        4. Image generation from the story.
    Video / music / final assembly are not implemented yet.

    ``task`` is mutated in place (presumably so a poller can observe progress):
    ``task["status"]`` moves "waiting_for_confirmation" -> "confirmed" ->
    "completed", ``task["confirmation_required"]`` flags the wait, and each
    stage's output is stored under ``task["result"]``. If a stage raises, the
    task is marked ``"failed"``, the message is stored in ``task["error"]``,
    and the exception is re-raised.

    Args:
        task: Mutable task record. Must contain "id" and "idea"; a "result"
            dict is created if it is missing.
        confirmation_event: Event the caller sets once the user has confirmed
            the generated script.

    Returns:
        ``task["result"]``, holding "script", "story_script" and "images".

    Raises:
        KeyError: if ``task`` lacks "id" or "idea".
        Exception: whatever a generation stage raises, re-raised after the
            task has been marked as failed.
    """
    task_id = task["id"]
    idea = task["idea"]
    task.setdefault("result", {})

    try:
        logging.info("[Pipeline] Starting script generation for task %s", task_id)
        script = await script_gen.generate_script(idea)
        task["result"]["script"] = script
        task["status"] = "waiting_for_confirmation"
        task["confirmation_required"] = True
        logging.info("[Pipeline] Script ready for task %s, waiting confirmation...", task_id)

        # Block here until the user confirms the script.
        await confirmation_event.wait()
        task["status"] = "confirmed"
        task["confirmation_required"] = False
        logging.info("[Pipeline] Task %s confirmed. Continuing pipeline...", task_id)

        logging.info("[Pipeline] Generating story for task %s", task_id)
        story = await story_script.generate_story(script)
        # Was a bare print(); route it through logging so it honours the
        # configured level and handlers.
        logging.debug("[Pipeline] Story for task %s: %s", task_id, story)
        task["result"]["story_script"] = story

        logging.info("[Pipeline] Generating images for task %s", task_id)
        images = await image_gen.generate_images(story)
        task["result"]["images"] = images
    except Exception as exc:
        # Without this, a failing stage would leave the status stuck at
        # "waiting_for_confirmation"/"confirmed" forever.
        task["status"] = "failed"
        task["error"] = str(exc)
        logging.exception("[Pipeline] Task %s failed", task_id)
        raise

    # TODO: video / music generation and final assembly stages
    # (task["result"]["video"], ["music"], ["final_output"]).

    task["status"] = "completed"
    logging.info("[Pipeline] Task %s completed. Output: %s", task_id, images)
    return task["result"]