File size: 8,136 Bytes
7c08dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import json
import os
import shutil
from collections import defaultdict

from jinja2 import Template

import src.llms as llms
from src.model_utils import get_cluster, get_image_embedding, images_cosine_similarity
from src.presentation import Presentation
from src.utils import Config, pexists, pjoin, tenacity


class SlideInducter:
    """
    Stage I: Presentation Analysis.
    This stage is to analyze the presentation: cluster slides into different layouts, and extract content schema for each layout.
    """

    def __init__(
        self,
        prs: Presentation,
        ppt_image_folder: str,
        template_image_folder: str,
        config: Config,
        image_models: list,
    ):
        """
        Initialize the SlideInducter.

        Args:
            prs (Presentation): The presentation object.
            ppt_image_folder (str): The folder containing PPT images.
            template_image_folder (str): The folder containing normalized slide images.
            config (Config): The configuration object.
            image_models (list): A list of image models.
        """
        self.prs = prs
        self.config = config
        self.ppt_image_folder = ppt_image_folder
        self.template_image_folder = template_image_folder
        assert (
            len(os.listdir(template_image_folder))
            == len(prs)
            == len(os.listdir(ppt_image_folder))
        )
        self.image_models = image_models
        self.slide_induction = defaultdict(lambda: defaultdict(list))
        model_identifier = llms.get_simple_modelname(
            [llms.language_model, llms.vision_model]
        )
        self.output_dir = pjoin(config.RUN_DIR, "template_induct", model_identifier)
        self.split_cache = pjoin(self.output_dir, f"split_cache.json")
        self.induct_cache = pjoin(self.output_dir, f"induct_cache.json")
        os.makedirs(self.output_dir, exist_ok=True)

    def layout_induct(self):
        """
        Perform layout induction for the presentation.
        """
        if pexists(self.induct_cache):
            return json.load(open(self.induct_cache))
        content_slides_index, functional_cluster = self.category_split()
        for layout_name, cluster in functional_cluster.items():
            for slide_idx in cluster:
                content_type = self.prs.slides[slide_idx - 1].get_content_type()
                self.slide_induction[layout_name + ":" + content_type]["slides"].append(
                    slide_idx
                )
        for layout_name, cluster in self.slide_induction.items():
            cluster["template_id"] = cluster["slides"][-1]

        functional_keys = list(self.slide_induction.keys())
        function_slides_index = set()
        for layout_name, cluster in self.slide_induction.items():
            function_slides_index.update(cluster["slides"])
        used_slides_index = function_slides_index.union(content_slides_index)
        for i in range(len(self.prs.slides)):
            if i + 1 not in used_slides_index:
                content_slides_index.add(i + 1)
        self.layout_split(content_slides_index)
        if self.config.DEBUG:
            for layout_name, cluster in self.slide_induction.items():
                cluster_dir = pjoin(self.output_dir, "cluster_slides", layout_name)
                os.makedirs(cluster_dir, exist_ok=True)
                for slide_idx in cluster["slides"]:
                    shutil.copy(
                        pjoin(self.ppt_image_folder, f"slide_{slide_idx:04d}.jpg"),
                        pjoin(cluster_dir, f"slide_{slide_idx:04d}.jpg"),
                    )
        self.slide_induction["functional_keys"] = functional_keys
        json.dump(
            self.slide_induction,
            open(self.induct_cache, "w"),
            indent=4,
            ensure_ascii=False,
        )
        return self.slide_induction

    def category_split(self):
        """
        Split slides into categories based on their functional purpose.
        """
        if pexists(self.split_cache):
            split = json.load(open(self.split_cache))
            return set(split["content_slides_index"]), split["functional_cluster"]
        category_split_template = Template(open("prompts/category_split.txt").read())
        functional_cluster = llms.language_model(
            category_split_template.render(slides=self.prs.to_text()),
            return_json=True,
        )
        functional_slides = set(sum(functional_cluster.values(), []))
        content_slides_index = set(range(1, len(self.prs) + 1)) - functional_slides

        json.dump(
            {
                "content_slides_index": list(content_slides_index),
                "functional_cluster": functional_cluster,
            },
            open(self.split_cache, "w"),
            indent=4,
            ensure_ascii=False,
        )
        return content_slides_index, functional_cluster

    def layout_split(self, content_slides_index: set[int]):
        """
        Cluster slides into different layouts.
        """
        embeddings = get_image_embedding(self.template_image_folder, *self.image_models)
        assert len(embeddings) == len(self.prs)
        template = Template(open("prompts/ask_category.txt").read())
        content_split = defaultdict(list)
        for slide_idx in content_slides_index:
            slide = self.prs.slides[slide_idx - 1]
            content_type = slide.get_content_type()
            layout_name = slide.slide_layout_name
            content_split[(layout_name, content_type)].append(slide_idx)

        for (layout_name, content_type), slides in content_split.items():
            sub_embeddings = [
                embeddings[f"slide_{slide_idx:04d}.jpg"] for slide_idx in slides
            ]
            similarity = images_cosine_similarity(sub_embeddings)
            for cluster in get_cluster(similarity):
                slide_indexs = [slides[i] for i in cluster]
                template_id = max(
                    slide_indexs,
                    key=lambda x: len(self.prs.slides[x - 1].shapes),
                )
                cluster_name = (
                    llms.vision_model(
                        template.render(
                            existed_layoutnames=list(self.slide_induction.keys()),
                        ),
                        pjoin(self.ppt_image_folder, f"slide_{template_id:04d}.jpg"),
                    )
                    + ":"
                    + content_type
                )
                self.slide_induction[cluster_name]["template_id"] = template_id
                self.slide_induction[cluster_name]["slides"] = slide_indexs

    @tenacity
    def content_induct(self):
        """
        Perform content schema extraction for the presentation.
        """
        self.slide_induction = self.layout_induct()
        content_induct_prompt = Template(open("prompts/content_induct.txt").read())
        for layout_name, cluster in self.slide_induction.items():
            if "template_id" in cluster and "content_schema" not in cluster:
                schema = llms.language_model(
                    content_induct_prompt.render(
                        slide=self.prs.slides[cluster["template_id"] - 1].to_html(
                            element_id=False, paragraph_id=False
                        )
                    ),
                    return_json=True,
                )
                for k in list(schema.keys()):
                    if "data" not in schema[k]:
                        raise ValueError(f"Cannot find `data` in {k}\n{schema[k]}")
                    if len(schema[k]["data"]) == 0:
                        print(f"Empty content schema: {schema[k]}")
                        schema.pop(k)
                assert len(schema) > 0, "No content schema generated"
                self.slide_induction[layout_name]["content_schema"] = schema
        json.dump(
            self.slide_induction,
            open(self.induct_cache, "w"),
            indent=4,
            ensure_ascii=False,
        )
        return self.slide_induction