File size: 12,844 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Math Vision Dataset Preprocessing Script

This script reads the existing Math Vision dataset from data/math_vision directory
and preprocesses it into the format expected by the dataloader.
The preprocessed data will be saved with fields: prompt, completion, solution, image_path

Usage:
    # Using config file
    uv run scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml
    
    # Manual parameters
    uv run scripts/math_vision_process.py --input_dir data/math_vision --output_dir data/math_vision
"""

import os
import re
import json
import logging
import argparse
from typing import Dict, List, Optional
import yaml
from datasets import load_dataset, DatasetDict

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def load_existing_dataset(data_path: str) -> DatasetDict:
    """Load the raw Math Vision dataset from JSON files on disk.

    Args:
        data_path: Directory expected to contain train.json and/or test.json

    Returns:
        DatasetDict with whichever of the train/test splits were found

    Raises:
        FileNotFoundError: If neither train.json nor test.json exists
    """
    data_files = {}
    # Probe for each known split file; missing splits are simply skipped.
    for split in ("train", "test"):
        candidate = os.path.join(data_path, f"{split}.json")
        if os.path.exists(candidate):
            data_files[split] = candidate
            logging.info(f"Found {split}.json at {candidate}")

    if not data_files:
        raise FileNotFoundError(f"No data files found in {data_path}")

    logging.info(f"Loading dataset from {data_path}")
    return load_dataset("json", data_files=data_files)


def split_train_valid(dataset_dict: DatasetDict, val_ratio: float = 0.1) -> DatasetDict:
    """Carve a validation split out of the train split when one is missing.

    Args:
        dataset_dict: Dataset dictionary (may contain train/valid/test)
        val_ratio: Fraction of the train split to hold out (must be in (0, 1))

    Returns:
        DatasetDict with train/valid splits plus the untouched test split, if any
    """
    # Guard clauses: nothing to do, or nothing we can do.
    if "valid" in dataset_dict:
        logging.info("Validation set already exists, skipping split")
        return dataset_dict

    if "train" not in dataset_dict:
        logging.warning("No train set found, cannot create validation split")
        return dataset_dict

    if not (0 < val_ratio < 1):
        logging.warning(f"Invalid val_ratio {val_ratio}, skipping validation split")
        return dataset_dict

    logging.info(f"Splitting train set with val_ratio={val_ratio}")
    # Fixed seed keeps the train/valid partition reproducible across runs.
    split = dataset_dict["train"].train_test_split(test_size=val_ratio, seed=42, shuffle=True)

    result = DatasetDict({
        "train": split["train"],
        "valid": split["test"],
    })

    # The original test split passes through untouched.
    original_test = dataset_dict.get("test", None)
    if original_test is not None:
        result["test"] = original_test

    logging.info(f"Split sizes - train: {len(result['train'])}, valid: {len(result['valid'])}")
    if original_test is not None:
        logging.info(f"Test size: {len(result['test'])}")

    return result


def preprocess_batch(batch: Dict, image_root: str) -> Dict:
    """Preprocess a batch of examples into prompt/completion/solution/image_path.

    Args:
        batch: Batch of raw examples with fields:
               - question, options (list), answer, solution, image
               (id/level/subject may also be present but are ignored)
        image_root: Root directory for images (unused, as absolute paths are in data)

    Returns:
        Preprocessed batch with fields:
        - prompt: formatted question prompt (instruction + question + options)
        - completion: formatted solution/answer text
        - solution: answer wrapped in \\boxed{} (for reward computation)
        - image_path: path to image file, or None when the example has no image
    """
    def _format_answer(answer: str) -> str:
        """Wrap the answer in \\boxed{...}; already-boxed answers pass through.

        Note: the original code special-cased multiple-choice letters but
        always emitted just the boxed letter anyway, so that dead branch
        (and the unused raw-answer extraction helper) has been removed.
        """
        answer = (answer or "").strip()
        if answer.startswith("\\boxed{") and answer.endswith("}"):
            return answer
        return "\\boxed{" + answer + "}"

    def _format_question_with_options(question: str, options: List[str] = None) -> str:
        """Append lettered multiple-choice options (A., B., ...) to the question."""
        formatted = question.strip()
        if options:
            formatted += "\nOptions:\n"
            for i, opt in enumerate(options):
                letter = chr(ord('A') + i)
                formatted += f"{letter}. {opt}\n"
        return formatted

    # Templates (kept byte-identical to preserve the trained prompt format).
    format_template = r"""Solve the problem and output the answer in the format of \boxed{your answer}."""
    prompt_template = "\n Question: {prompt}\n"

    # Get fields from batch, with per-example defaults for optional columns.
    questions: List[str] = batch.get("question", [])
    options_list: List[List[str]] = batch.get("options", [[]] * len(questions))
    answers: List[str] = batch.get("answer", [""] * len(questions))
    solutions: List[Optional[str]] = batch.get("solution", [None] * len(questions))
    image_paths_src: List[str] = batch.get("image", [""] * len(questions))

    prompts: List[str] = []
    completions: List[str] = []
    solution_labels: List[str] = []
    image_paths: List[str] = []

    for q, opts, ans, sol, img_path in zip(questions, options_list, answers, solutions, image_paths_src):
        # Format question with options and prepend the boxed-answer instruction.
        formatted_q = _format_question_with_options(q, opts)
        prompts.append(format_template + prompt_template.format(prompt=formatted_q))

        solution_label = _format_answer(ans)
        solution_labels.append(solution_label)

        # Prefer the worked solution as the completion; otherwise fall back to
        # a minimal sentence containing the boxed answer.
        if sol and sol.strip():
            completions.append(sol.strip())
        else:
            completions.append(f"The answer is {solution_label}")

        # Image path is already absolute in the data; empty string -> None.
        image_paths.append(img_path if img_path else None)

    return {
        "prompt": prompts,
        "completion": completions,
        "solution": solution_labels,
        "image_path": image_paths,
    }


def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict:
    """Preprocess every split and drop rows whose solution label is empty.

    Args:
        dataset_dict: Raw dataset dictionary
        image_root: Root directory for images (not used for this dataset)
        batch_size: Batch size for the map step

    Returns:
        Preprocessed DatasetDict with fields: prompt, completion, solution, image_path
    """
    keep_keys = ["prompt", "completion", "solution", "image_path"]

    def _is_nonempty_solution(example):
        # Keep only rows whose formatted solution is a non-blank string.
        sol = example.get("solution", "")
        return sol is not None and len(sol.strip()) > 0

    def _process(name):
        logging.info(f"Preprocessing {name} split with batch_size={batch_size}")
        raw = dataset_dict[name]
        mapped = raw.map(
            lambda batch: preprocess_batch(batch, image_root),
            batched=True,
            batch_size=batch_size,
            num_proc=None,
            remove_columns=raw.column_names,
            desc=f"Math_Vision preprocess ({name})",
        )

        kept = mapped.filter(_is_nonempty_solution, desc=f"Filter empty solutions ({name})")

        dropped = len(mapped) - len(kept)
        if dropped > 0:
            logging.info(f"Filtered {dropped} samples with empty solutions from {name}")

        return kept.select_columns(keep_keys)

    result = DatasetDict({name: _process(name) for name in dataset_dict.keys()})

    for name, ds in result.items():
        logging.info(f"Preprocessed {name}: {len(ds)} samples")

    return result


def save_dataset(dataset_dict: DatasetDict, output_dir: str):
    """Save preprocessed dataset to JSON files.
    
    Args:
        dataset_dict: Preprocessed dataset
        output_dir: Output directory
    """
    os.makedirs(output_dir, exist_ok=True)
    
    for split_name, ds in dataset_dict.items():
        output_path = os.path.join(output_dir, f"{split_name}.json")
        logging.info(f"Saving {split_name} split to {output_path}")
        
        # Convert to list of dicts and save as JSON
        data = []
        for example in ds:
            data.append({
                "prompt": example["prompt"],
                "completion": example["completion"],
                "solution": example["solution"],
                "image_path": example["image_path"],
            })
        
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        
        logging.info(f"Saved {len(data)} samples to {output_path}")


def main():
    """CLI entry point: load, split, preprocess, and save the Math Vision dataset."""
    parser = argparse.ArgumentParser(description="Preprocess Math Vision dataset")
    parser.add_argument("--config", type=str, help="Path to config YAML file")
    parser.add_argument("--input_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Input directory with train.json/test.json")
    parser.add_argument("--output_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Output directory for preprocessed data")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--image_root", type=str, default="/root/CVPR/MemGen/dataset/math_vision/images", help="Image root directory")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for preprocessing")
    args = parser.parse_args()

    # CLI values act as defaults; a YAML config (if given) may override
    # val_ratio and image_root via datasets.math_vision.<mode>.
    val_ratio, image_root = args.val_ratio, args.image_root
    if args.config:
        logging.info(f"Loading config from {args.config}")
        with open(args.config, "r") as f:
            cfg = yaml.safe_load(f)
        dataset_cfg = cfg.get("datasets", {}).get("math_vision", {})
        mode_cfg = dataset_cfg.get(dataset_cfg.get("mode", "sft"), {})
        val_ratio = mode_cfg.get("val_ratio", args.val_ratio)
        image_root = mode_cfg.get("image_root", args.image_root)

    banner = "=" * 80
    logging.info(banner)
    logging.info("Math Vision Dataset Preprocessing")
    logging.info(banner)
    logging.info(f"Input directory: {args.input_dir}")
    logging.info(f"Output directory: {args.output_dir}")
    logging.info(f"Validation ratio: {val_ratio}")
    logging.info(f"Image root: {image_root}")
    logging.info(f"Batch size: {args.batch_size}")

    # Pipeline: load -> split -> preprocess -> save.
    dataset_dict = load_existing_dataset(args.input_dir)
    dataset_dict = split_train_valid(dataset_dict, val_ratio=val_ratio)
    preprocessed = preprocess_dataset(dataset_dict, image_root, batch_size=args.batch_size)
    save_dataset(preprocessed, args.output_dir)

    logging.info(banner)
    logging.info("Preprocessing complete!")
    logging.info(banner)
    for split, ds in preprocessed.items():
        logging.info(f"{split}: {len(ds)} samples")
    logging.info(f"Output saved to: {args.output_dir}")


# Run the preprocessing pipeline only when executed as a script.
if __name__ == "__main__":
    main()