AutoPage / utils /src /experiment /experiments.py
Mqleet's picture
upd code
fcaa164
import json
import os
import sys
import shutil
from functools import partial
from glob import glob
from time import sleep
from typing import Type
os.environ['OPENAI_API_KEY'] = 'Your key here'
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, root_dir)
import func_argparse
import torch
import src.llms as llms
from src.experiment.ablation import (
PPTCrew_wo_Decoupling,
PPTCrew_wo_HTML,
PPTCrew_wo_LayoutInduction,
PPTCrew_wo_SchemaInduction,
PPTCrew_wo_Structure,
)
from src.experiment.preprocess import process_filetype
from src.model_utils import get_text_model
from src.multimodal import ImageLabler
from src.pptgen import PPTCrew
from src.presentation import Presentation
from src.utils import Config, older_than, pbasename, pexists, pjoin, ppt_to_images
# language_model vision_model
EVAL_MODELS = [
(llms.qwen2_5, llms.qwen_vl),
(llms.gpt4o, llms.gpt4o),
(llms.qwen_vl, llms.qwen_vl),
]
# ablation
# 0: w/o layout induction
# 1: w/o schema induction
# 2: w/o decoupling
# 3: w/o html
# 4: with gpt4o template
# 5: w/o structure information
# 6: retry 5 times
AGENT_CLASS = {
-1: PPTCrew,
0: PPTCrew_wo_LayoutInduction,
1: PPTCrew_wo_SchemaInduction,
2: PPTCrew_wo_Decoupling,
3: PPTCrew_wo_HTML,
4: PPTCrew,
5: PPTCrew_wo_Structure,
6: PPTCrew,
}
def get_setting(setting_id: int, ablation_id: int):
assert ablation_id in AGENT_CLASS, f"ablation_id {ablation_id} not in {AGENT_CLASS}"
assert (
ablation_id == -1 or setting_id == 0
), "ablation_id == -1 only when setting_id != 0"
language_model, vision_model = EVAL_MODELS[setting_id]
agent_class = AGENT_CLASS.get(ablation_id)
llms.language_model = language_model
llms.vision_model = vision_model
if ablation_id == -1:
setting_name = "PPTCrew-" + llms.get_simple_modelname(
[language_model, vision_model]
)
elif ablation_id == 6:
setting_name = "PPTCrew_retry_5"
agent_class = partial(agent_class, retry_times=5)
else:
setting_name = agent_class.__name__
model_identifier = llms.get_simple_modelname(
[llms.language_model, llms.vision_model]
)
if ablation_id == 4:
setting_name = "PPTCrew_with_gpt4o"
model_identifier = "gpt-4o+gpt-4o"
return agent_class, setting_name, model_identifier
def do_generate(
genclass: Type[PPTCrew],
setting: str,
model_identifier: str,
debug: bool,
ppt_folder: str,
thread_id: int,
num_slides: int = 12,
):
app_config = Config(rundir=ppt_folder, debug=debug)
text_model = get_text_model(f"cuda:{thread_id % torch.cuda.device_count()}")
presentation = Presentation.from_file(
pjoin(ppt_folder, "source.pptx"),
app_config,
)
ImageLabler(presentation, app_config).caption_images()
induct_cache = pjoin(
app_config.RUN_DIR, "template_induct", model_identifier, "induct_cache.json"
)
if not older_than(induct_cache, wait=True):
print(f"induct_cache not found: {induct_cache}")
return
slide_induction = json.load(open(induct_cache))
try:
pptgen: PPTCrew = genclass(text_model).set_reference(presentation, slide_induction)
except:
print("set_reference failed")
pptgen: PPTCrew = genclass(text_model).set_reference(presentation, slide_induction)
topic = ppt_folder.split("/")[1]
for pdf_folder in glob(f"data/{topic}/pdf/*"):
app_config.set_rundir(pjoin(ppt_folder, setting, pbasename(pdf_folder)))
if pexists(pjoin(app_config.RUN_DIR, "history")):
continue
images = json.load(
open(pjoin(pdf_folder, "image_caption.json"), "r"),
)
doc_json = json.load(
open(pjoin(pdf_folder, "refined_doc.json"), "r"),
)
pptgen.generate_pres(app_config, images, num_slides, doc_json)
def generate_pres(
setting_id: int = 0,
setting_name: str = None,
ablation_id: int = -1,
thread_num: int = 8,
debug: bool = False,
topic: str = "*",
num_slides: int = 12,
):
agent_class, setting, model_identifier = get_setting(setting_id, ablation_id)
setting = setting_name or setting
print("generating slides using:", setting)
generate = partial(
do_generate,
agent_class,
setting,
model_identifier,
debug,
num_slides=num_slides,
)
process_filetype("pptx", generate, thread_num, topic)
def pptx2images(settings: str = "*"):
while True:
for folder in glob(f"data/*/pptx/*/{settings}/*/history"):
folder = os.path.dirname(folder)
pptx = pjoin(folder, "final.pptx")
ppt_folder, setting, pdf = folder.rsplit("/", 2)
dst = pjoin(ppt_folder, "final_images", setting, pdf)
if not pexists(pptx):
if pexists(dst):
print(f"remove {dst}")
shutil.rmtree(dst)
continue
older_than(pptx)
if pexists(dst):
continue
try:
ppt_to_images(pptx, dst)
except:
print("pptx to images failed")
sleep(60)
print("keep scanning for new pptx")
if __name__ == "__main__":
func_argparse.main(
generate_pres,
pptx2images,
)