ZaynZhu
Clean version without large assets
7c08dc3
from dotenv import load_dotenv
from utils.src.utils import ppt_to_images, get_json_from_response
import json
import shutil
from camel.models import ModelFactory
from camel.agents import ChatAgent
from utils.wei_utils import *
from camel.messages import BaseMessage
from PIL import Image
import pickle as pkl
from utils.pptx_utils import *
from utils.critic_utils import *
import yaml
from jinja2 import Environment, StrictUndefined
from pdf2image import convert_from_path
import argparse
load_dotenv()
def poster_apply_theme(args, actor_config, critic_config):
total_input_token, total_output_token = 0, 0
extract_input_token, extract_output_token = 0, 0
gen_input_token, gen_output_token = 0, 0
non_overlap_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'rb'))
non_overlap_code = non_overlap_ckpt['final_code_by_section']
sections = list(non_overlap_code.keys())
sections = [s for s in sections if s != 'meta']
template_img = convert_from_path(args.template_path)[0]
image_bytes = io.BytesIO()
template_img.save(image_bytes, format="PNG")
image_bytes.seek(0)
# Reload the image from memory as a standard PIL.Image.Image
template_img = Image.open(image_bytes)
title_actor_agent_name = 'theme_agent_title'
with open(f"prompt_templates/{title_actor_agent_name}.yaml", "r") as f:
title_theme_actor_config = yaml.safe_load(f)
section_actor_agent_name = 'theme_agent_section'
with open(f"prompt_templates/{section_actor_agent_name}.yaml", "r") as f:
section_theme_actor_config = yaml.safe_load(f)
title_actor_model = ModelFactory.create(
model_platform=actor_config['model_platform'],
model_type=actor_config['model_type'],
model_config_dict=actor_config['model_config'], # [Optional] the config for model
)
title_actor_sys_msg = title_theme_actor_config['system_prompt']
title_actor_agent = ChatAgent(
system_message=title_actor_sys_msg,
model=title_actor_model,
message_window_size=10, # [Optional] the length for chat memory
)
section_actor_model = ModelFactory.create(
model_platform=actor_config['model_platform'],
model_type=actor_config['model_type'],
model_config_dict=actor_config['model_config'], # [Optional] the config for model
)
section_actor_sys_msg = section_theme_actor_config['system_prompt']
section_actor_agent = ChatAgent(
system_message=section_actor_sys_msg,
model=section_actor_model,
message_window_size=10, # [Optional] the length for chat memory
)
critic_model = ModelFactory.create(
model_platform=critic_config['model_platform'],
model_type=critic_config['model_type'],
model_config_dict=critic_config['model_config'],
)
critic_sys_msg = 'You are a helpful assistant.'
critic_agent = ChatAgent(
system_message=critic_sys_msg,
model=critic_model,
message_window_size=None,
)
theme_aspects = {
'background': ['background'],
'title': ['title_author', 'title_author_border'],
'section': ['section_body', 'section_title', 'section_border']
}
theme_styles = {}
for aspect in theme_aspects.keys():
theme_styles[aspect] = {}
for aspect, prompt_types in theme_aspects.items():
for prompt_type in prompt_types:
print(f'Getting style for {prompt_type}')
with open(f"prompt_templates/theme_templates/theme_{prompt_type}.txt", "r") as f:
prompt = f.read()
msg = BaseMessage.make_user_message(
role_name="User",
content=prompt,
image_list=[template_img],
)
critic_agent.reset()
response = critic_agent.step(msg)
input_token, output_token = account_token(response)
total_input_token += input_token
total_output_token += output_token
extract_input_token += input_token
extract_output_token += output_token
theme_style = get_json_from_response(response.msgs[0].content)
theme_styles[aspect][prompt_type] = theme_style
if 'fontStyle' in theme_styles['section']['section_body']:
del theme_styles['section']['section_body']['fontStyle']
outline_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
outline = json.load(open(outline_path, 'r'))
outline_skeleton = {}
for key, val in outline.items():
if key == 'meta':
continue
if not 'subsections' in val:
outline_skeleton[key] = {
'section': key
}
else:
for subsection_name, subsection_dict in val['subsections'].items():
outline_skeleton[subsection_dict['name']] = {
'section': key
}
for key in outline_skeleton.keys():
if 'title' in key.lower() or 'author' in key.lower():
outline_skeleton[key]['style'] = theme_styles['section']['section_title']
else:
outline_skeleton[key]['style'] = theme_styles['section']['section_body']
outline_skeleton_list = []
for section in sections[1:]:
# append all subsections whose section key is the current section
for key, val in outline_skeleton.items():
if val['section'] == section:
outline_skeleton_list.append({key: val})
theme_logs = {}
theme_code = {}
concatenated_code = {}
# Title
jinja_env = Environment(undefined=StrictUndefined)
title_actor_template = jinja_env.from_string(title_theme_actor_config["template"])
# Title section
print(f'Processing section {sections[0]}')
curr_title_code = non_overlap_code[sections[0]]
for style in ['background', 'title']:
for sub_style in theme_styles[style].keys():
print(f' Applying theme for {sub_style}')
jinja_args = {
'style_json': {sub_style: theme_styles[style][sub_style]},
'function_docs': documentation,
'existing_code': curr_title_code
}
actor_prompt = title_actor_template.render(**jinja_args)
log = apply_theme(title_actor_agent, actor_prompt, args.max_retry, existing_code='')
if log[-1]['error'] is not None:
raise Exception(log[-1]['error'])
input_token, output_token = log[-1]['cumulative_tokens']
total_input_token += input_token
total_output_token += output_token
gen_input_token += input_token
gen_output_token += output_token
shutil.copy('poster.pptx', f'tmp/theme_poster_<{sections[0]}>_<{style}>_<{sub_style}>.pptx')
if not style in theme_logs:
theme_logs[style] = {}
theme_logs[style][sub_style] = log
curr_title_code = log[-1]['code']
theme_code[sections[0]] = curr_title_code
concatenated_code[sections[0]] = log[-1]['concatenated_code']
# Remaining sections
jinja_env = Environment(undefined=StrictUndefined)
section_actor_template = jinja_env.from_string(section_theme_actor_config["template"])
prev_section = None
for style_dict in outline_skeleton_list:
curr_subsection = list(style_dict.keys())[0]
curr_section = style_dict[curr_subsection]['section']
section_index = sections.index(curr_section)
print(f'Processing section {curr_section}')
if prev_section != curr_section:
prev_section = curr_section
curr_section_code = non_overlap_code[curr_section]
print(f' Applying theme for {curr_subsection}')
jinja_args = {
'style_json': json.dumps({curr_subsection: style_dict[curr_subsection]['style']}, indent=4),
'function_docs': documentation,
'existing_code': curr_section_code
}
actor_prompt = section_actor_template.render(**jinja_args)
existing_code = concatenated_code[sections[section_index - 1]]
log = apply_theme(section_actor_agent, actor_prompt, args.max_retry, existing_code=existing_code)
if log[-1]['error'] is not None:
raise Exception(log[-1]['error'])
input_token, output_token = log[-1]['cumulative_tokens']
total_input_token += input_token
total_output_token += output_token
gen_input_token += input_token
gen_output_token += output_token
shutil.copy('poster.pptx', f'tmp/theme_poster_<{curr_section}>_<{curr_subsection}>.pptx')
if not style in theme_logs:
theme_logs[style] = {}
theme_logs[style][sub_style] = log
curr_section_code = log[-1]['code']
theme_code[curr_section] = curr_section_code
concatenated_code[curr_section] = log[-1]['concatenated_code']
ppt_to_images(f'poster.pptx', 'tmp/theme_preview')
result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
shutil.copy('poster.pptx', f'{result_dir}/theme_poster.pptx')
ppt_to_images(f'poster.pptx', f'{result_dir}/theme_poster_preview')
ckpt = {
'theme_styles': theme_styles,
'theme_logs': theme_logs,
'theme_code': theme_code,
'concatenated_code': concatenated_code,
'total_input_token': total_input_token,
'total_output_token': total_output_token,
'extract_input_token': extract_input_token,
'extract_output_token': extract_output_token,
'gen_input_token': gen_input_token,
'gen_output_token': gen_output_token
}
pkl.dump(ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_theme_ckpt.pkl', 'wb'))
return total_input_token, total_output_token
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--poster_name', type=str, default=None)
parser.add_argument('--model_name', type=str, default='4o')
parser.add_argument('--poster_path', type=str, required=True)
parser.add_argument('--index', type=int, default=0)
parser.add_argument('--template_path', type=str)
parser.add_argument('--max_retry', type=int, default=3)
args = parser.parse_args()
actor_config = get_agent_config(args.model_name)
critic_config = get_agent_config(args.model_name)
if args.poster_name is None:
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
print(f'Token consumption: {input_token} -> {output_token}')