|
|
from dotenv import load_dotenv |
|
|
from utils.src.utils import ppt_to_images, get_json_from_response |
|
|
import json |
|
|
|
|
|
from camel.models import ModelFactory |
|
|
from camel.agents import ChatAgent |
|
|
|
|
|
from utils.wei_utils import * |
|
|
|
|
|
from camel.messages import BaseMessage |
|
|
from PIL import Image |
|
|
import pickle as pkl |
|
|
from utils.pptx_utils import * |
|
|
from utils.critic_utils import * |
|
|
import yaml |
|
|
import argparse |
|
|
import shutil |
|
|
from jinja2 import Environment, StrictUndefined |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
import copy |
|
|
|
|
|
# Load environment variables (e.g. API keys) from a local .env file.
load_dotenv()




# Cap on critic/actor refinement rounds per leaf section (default for CLI runs).
MAX_ATTEMPTS = 5
|
|
|
|
|
def process_leaf_section(
    leaf_section,
    section_name,
    outline,
    content,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
):
    """
    Run the critic/actor refinement loop for a single leaf section.

    Each round renders a snapshot of the section on a blank slide, shows it
    to the critic agent (alongside a negative and a positive overflow
    example), and — unless the critic answers "NO" — asks the actor agent
    to edit the section code according to the critic's JSON feedback.  The
    loop ends when the critic approves or MAX_ATTEMPTS rounds are reached.

    Returns a dict with the final section code, the concatenated actor
    edit log, the (mutated) img/critic/actor log dicts, and the updated
    input/output token counters.

    Raises:
        Exception: if the actor's last edit attempt reports an error.
    """
    section_code = style_logs[section_name][-1]['code']
    log = []
    # A leaf is either a top-level outline entry or a subsection of its
    # parent section.
    if leaf_section in outline:
        leaf_name = outline[leaf_section]['name']
    else:
        leaf_name = outline[section_name]['subsections'][leaf_section]['name']

    num_rounds = 0
    while True:
        print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {num_rounds}")
        num_rounds += 1
        if num_rounds > MAX_ATTEMPTS:
            break

        # Render the current section code onto a fresh blank presentation
        # so the snapshot reflects only the latest code.
        poster = create_poster(slide_width, slide_height)
        add_blank_slide(poster)
        empty_poster_path = f'tmp/empty_poster_{section_name}_{leaf_section}.pptx'
        save_presentation(poster, file_name=empty_poster_path)

        # Only the zoomed-in image is used below; the location and image
        # path returned by the helper are not needed here.
        _, zoomed_in_img, _ = get_snapshot_from_section(
            leaf_section,
            section_name,
            name_to_hierarchy,
            leaf_name,
            section_code,
            empty_poster_path
        )

        img_logs.setdefault(leaf_section, []).append(zoomed_in_img)

        jinja_args = {
            'content_json': content[leaf_section] if leaf_section in content
            else content[section_name]['subsections'][leaf_section],
            'existing_code': section_code,
        }

        critic_prompt = critic_template.render(**jinja_args)

        # The critic sees a negative example, a positive example, and the
        # current snapshot, in that order.
        critic_msg = BaseMessage.make_user_message(
            role_name="User",
            content=critic_prompt,
            image_list=[neg_img, pos_img, zoomed_in_img],
        )

        critic_agent.reset()
        response = critic_agent.step(critic_msg)
        resp = response.msgs[0].content

        input_token, output_token = account_token(response)
        total_input_token += input_token
        total_output_token += output_token

        critic_logs.setdefault(leaf_section, []).append(response)

        # A bare "NO" (in any of the observed quoting variants) means the
        # critic found no issue: this leaf is done.
        if isinstance(resp, str):
            if resp in ['NO', 'NO.', '"NO"', "'NO'"]:
                break

        feedback = get_json_from_response(resp)
        print(feedback)

        jinja_args = {
            'content_json': content[leaf_section] if leaf_section in content
            else content[section_name]['subsections'][leaf_section],
            'function_docs': documentation,
            'existing_code': section_code,
            'suggestion_json': feedback,
        }

        actor_prompt = actor_template.render(**jinja_args)

        # Let the actor apply the critic's suggestions (up to 3 attempts).
        leaf_log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
        if leaf_log[-1]['error'] is not None:
            raise Exception(leaf_log[-1]['error'])

        # cumulative_tokens is (input_tokens, output_tokens).
        in_tok = leaf_log[-1]['cumulative_tokens'][0]
        out_tok = leaf_log[-1]['cumulative_tokens'][1]
        total_input_token += in_tok
        total_output_token += out_tok

        # Carry the actor's edited code into the next critic round.
        section_code = leaf_log[-1]['code']

        actor_logs.setdefault(leaf_section, []).append(leaf_log)

        log.extend(leaf_log)

    return {
        "section_code": section_code,
        "log": log,
        "img_logs": img_logs,
        "critic_logs": critic_logs,
        "actor_logs": actor_logs,
        "total_input_token": total_input_token,
        "total_output_token": total_output_token,
    }
|
|
|
|
|
|
|
|
def process_section(
    section_name,
    content,
    outline,
    sections,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
):
    """
    Process one section: run the critic/actor refinement loop over each of
    its leaf (sub)sections sequentially and collect the results.

    `sections` is unused in the body; it is kept for interface
    compatibility with existing callers.

    Returns a dict with the section name, the updated log dicts, and the
    accumulated token counters for this section.
    """
    section_code = style_logs[section_name][-1]['code']

    # A section without subsections is treated as a single leaf named
    # after itself.
    if 'subsections' in content[section_name]:
        subsections = list(content[section_name]['subsections'].keys())
    else:
        subsections = [section_name]

    all_logs_for_section = []

    for leaf_section in subsections:
        leaf_result = process_leaf_section(
            leaf_section,
            section_name,
            outline,
            content,
            style_logs,
            critic_logs,
            actor_logs,
            img_logs,
            slide_width,
            slide_height,
            name_to_hierarchy,
            critic_template,
            actor_template,
            critic_agent,
            actor_agent,
            neg_img,
            pos_img,
            MAX_ATTEMPTS,
            documentation,
            total_input_token,
            total_output_token,
        )

        # Carry logs and token counts forward into the next leaf.
        # NOTE(review): process_leaf_section re-reads its starting code
        # from style_logs, so the refined `section_code` captured here is
        # NOT seen by the next leaf — only the last leaf's result is
        # appended to style_logs below.  Confirm this is intentional.
        section_code = leaf_result["section_code"]
        all_logs_for_section.extend(leaf_result["log"])
        img_logs = leaf_result["img_logs"]
        critic_logs = leaf_result["critic_logs"]
        actor_logs = leaf_result["actor_logs"]
        total_input_token = leaf_result["total_input_token"]
        total_output_token = leaf_result["total_output_token"]

    # Record the final state of this section in the style log (the last
    # entry carries the final code).
    if all_logs_for_section:
        style_logs[section_name].append(all_logs_for_section[-1])

    return {
        "section_name": section_name,
        "style_logs": style_logs,
        "critic_logs": critic_logs,
        "actor_logs": actor_logs,
        "img_logs": img_logs,
        "total_input_token": total_input_token,
        "total_output_token": total_output_token
    }
|
|
|
|
|
def parallel_by_sections(
    sections,
    content,
    outline,
    style_logs,
    critic_logs,
    actor_logs,
    img_logs,
    slide_width,
    slide_height,
    name_to_hierarchy,
    critic_template,
    actor_template,
    critic_agent,
    actor_agent,
    neg_img,
    pos_img,
    MAX_ATTEMPTS,
    documentation,
    total_input_token,
    total_output_token,
    max_workers=4
):
    """
    Fan section processing out across a thread pool and merge the results.

    Each worker receives deep copies of the mutable log dicts so that
    concurrent mutation cannot corrupt shared state; the merged logs and
    token totals are rebuilt from the workers' return values afterwards.

    NOTE(review): critic_agent and actor_agent are shared across threads
    (each worker calls .reset()/.step() on them) — confirm the agents are
    thread-safe before relying on high max_workers values.

    Returns (style_logs, critic_logs, actor_logs, img_logs,
    total_input_token, total_output_token).
    """
    # Token baselines: every worker starts from these values and returns
    # baseline + its own consumption, so the merge below sums the deltas.
    base_input_token = total_input_token
    base_output_token = total_output_token

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []

        for section_name in sections:
            # Private deep copies per worker to avoid cross-thread mutation.
            _style_logs = copy.deepcopy(style_logs)
            _critic_logs = copy.deepcopy(critic_logs)
            _actor_logs = copy.deepcopy(actor_logs)
            _img_logs = copy.deepcopy(img_logs)

            futures.append(executor.submit(
                process_section,
                section_name,
                content,
                outline,
                sections,
                _style_logs,
                _critic_logs,
                _actor_logs,
                _img_logs,
                slide_width,
                slide_height,
                name_to_hierarchy,
                critic_template,
                actor_template,
                critic_agent,
                actor_agent,
                neg_img,
                pos_img,
                MAX_ATTEMPTS,
                documentation,
                base_input_token,
                base_output_token
            ))

        for future in futures:
            results.append(future.result())

    # Merge each worker's private results back into the shared structures.
    for res in results:
        sec_name = res["section_name"]

        style_logs[sec_name] = res["style_logs"][sec_name]
        critic_logs.update(res["critic_logs"])
        actor_logs.update(res["actor_logs"])
        img_logs.update(res["img_logs"])
        # Bug fix: the previous code overwrote the totals with whichever
        # worker's value came last, dropping every other section's token
        # consumption.  Accumulate each section's delta over the shared
        # baseline instead.
        total_input_token += res["total_input_token"] - base_input_token
        total_output_token += res["total_output_token"] - base_output_token

    return style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token
|
|
|
|
|
|
|
|
def deoverflow(args, actor_config, critic_config):
    """
    Remove overflow/overlap from a generated poster via critic/actor agents.

    Loads the style and outline checkpoints produced by earlier pipeline
    stages, runs the overlap-critic / editor-actor loop over every section
    in parallel, renders the fixed poster, copies results into the result
    directory, and writes a new "non_overlap" checkpoint.

    Returns (total_input_token, total_output_token) consumed by the agents.
    """
    total_input_token, total_output_token = 0, 0
    # Context managers ensure checkpoint file handles are closed promptly
    # (the previous pkl.load(open(...)) pattern leaked them).
    with open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'rb') as f:
        style_ckpt = pkl.load(f)
    with open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb') as f:
        logs_ckpt = pkl.load(f)

    style_logs = style_ckpt['style_logs']
    # 'meta' stores poster-level info (width/height), not a real section.
    sections = [s for s in style_logs.keys() if s != 'meta']

    slide_width = style_ckpt['outline']['meta']['width']
    slide_height = style_ckpt['outline']['meta']['height']

    with open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r') as f:
        content = json.load(f)
    outline = logs_ckpt['outline']

    name_to_hierarchy = get_hierarchy(outline, 1)

    # Load prompt templates for the overlap critic and the editor actor.
    critic_agent_name = 'critic_overlap_agent'
    with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
        deoverflow_critic_config = yaml.safe_load(f)

    actor_agent_name = 'actor_editor_agent'
    with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
        deoverflow_actor_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    actor_sys_msg = deoverflow_actor_config['system_prompt']

    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    critic_model = ModelFactory.create(
        model_platform=critic_config['model_platform'],
        model_type=critic_config['model_type'],
        model_config_dict=critic_config['model_config'],
    )

    critic_sys_msg = deoverflow_critic_config['system_prompt']

    critic_agent = ChatAgent(
        system_message=critic_sys_msg,
        model=critic_model,
        message_window_size=None,
    )

    # StrictUndefined makes missing template variables fail loudly instead
    # of rendering as empty strings.
    jinja_env = Environment(undefined=StrictUndefined)

    actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
    critic_template = jinja_env.from_string(deoverflow_critic_config["template"])

    critic_logs = {}
    actor_logs = {}
    img_logs = {}

    # Few-shot images: an overflowing (negative) vs. well-fitted (positive)
    # layout example shown to the critic.
    neg_img = Image.open('overflow_example/neg.jpg')
    pos_img = Image.open('overflow_example/pos.jpg')

    style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token = parallel_by_sections(
        sections=sections,
        content=content,
        outline=outline,
        style_logs=style_logs,
        critic_logs=critic_logs,
        actor_logs=actor_logs,
        img_logs=img_logs,
        slide_width=slide_width,
        slide_height=slide_height,
        name_to_hierarchy=name_to_hierarchy,
        critic_template=critic_template,
        actor_template=actor_template,
        critic_agent=critic_agent,
        actor_agent=actor_agent,
        neg_img=neg_img,
        pos_img=pos_img,
        MAX_ATTEMPTS=MAX_ATTEMPTS,
        documentation=documentation,
        total_input_token=total_input_token,
        total_output_token=total_output_token,
        max_workers=100,
    )

    # Concatenate each section's final code and render the fixed poster.
    final_code = ''
    for section in sections:
        final_code += style_logs[section][-1]['code'] + '\n'

    run_code_with_utils(final_code, utils_functions)
    ppt_to_images('poster.pptx', 'tmp/non_overlap_preview')

    result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(result_dir, exist_ok=True)
    shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
    ppt_to_images('poster.pptx', f'{result_dir}/non_overlap_poster_preview')

    final_code_by_section = {
        section: style_logs[section][-1]['code'] for section in sections
    }

    non_overlap_ckpt = {
        'critic_logs': critic_logs,
        'actor_logs': actor_logs,
        'img_logs': img_logs,
        'name_to_hierarchy': name_to_hierarchy,
        'final_code': final_code,
        'final_code_by_section': final_code_by_section,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token
    }

    with open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'wb') as f:
        pkl.dump(non_overlap_ckpt, f)

    return total_input_token, total_output_token
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point for running the de-overflow stage standalone.
    parser = argparse.ArgumentParser()
    cli_options = [
        ('--poster_name', dict(type=str, default=None)),
        ('--model_name', dict(type=str, default='4o')),
        ('--poster_path', dict(type=str, required=True)),
        ('--index', dict(type=int, default=0)),
        ('--max_retry', dict(type=int, default=3)),
    ]
    for flag, opts in cli_options:
        parser.add_argument(flag, **opts)
    args = parser.parse_args()

    # Actor and critic use the same model configuration.
    actor_config = get_agent_config(args.model_name)
    critic_config = get_agent_config(args.model_name)

    # Default the poster name to the PDF's base name with spaces replaced.
    if args.poster_name is None:
        pdf_basename = args.poster_path.split('/')[-1]
        args.poster_name = pdf_basename.replace('.pdf', '').replace(' ', '_')

    input_token, output_token = deoverflow(args, actor_config, critic_config)

    print(f'Token consumption: {input_token} -> {output_token}')