""" Math Vision Dataset Preprocessing Script This script reads the existing Math Vision dataset from data/math_vision directory and preprocesses it into the format expected by the dataloader. The preprocessed data will be saved with fields: prompt, completion, solution, image_path Usage: # Using config file uv run scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml # Manual parameters uv run scripts/math_vision_process.py --input_dir data/math_vision --output_dir data/math_vision """ import os import re import json import logging import argparse from typing import Dict, List, Optional import yaml from datasets import load_dataset, DatasetDict logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def load_existing_dataset(data_path: str) -> DatasetDict: """Load existing Math Vision dataset. Args: data_path: Directory containing train.json and test.json Returns: DatasetDict with train/test splits """ data_files = {} train_path = os.path.join(data_path, "train.json") test_path = os.path.join(data_path, "test.json") if os.path.exists(train_path): data_files["train"] = train_path logging.info(f"Found train.json at {train_path}") if os.path.exists(test_path): data_files["test"] = test_path logging.info(f"Found test.json at {test_path}") if len(data_files) == 0: raise FileNotFoundError(f"No data files found in {data_path}") logging.info(f"Loading dataset from {data_path}") dataset_dict = load_dataset("json", data_files=data_files) return dataset_dict def split_train_valid(dataset_dict: DatasetDict, val_ratio: float = 0.1) -> DatasetDict: """Split train set into train/valid if valid doesn't exist. Args: dataset_dict: Dataset dictionary val_ratio: Validation set ratio Returns: DatasetDict with train/valid/test splits """ if "valid" in dataset_dict: logging.info("Validation set already exists, skipping split") return dataset_dict if "train" not in dataset_dict: logging.warning("No train set found, cannot create validation split") return dataset_dict if val_ratio <= 0 or val_ratio >= 1: logging.warning(f"Invalid val_ratio {val_ratio}, skipping validation split") return dataset_dict logging.info(f"Splitting train set with val_ratio={val_ratio}") train_test = dataset_dict["train"].train_test_split(test_size=val_ratio, seed=42, shuffle=True) # Preserve the original test set original_test = dataset_dict.get("test", None) new_dataset_dict = DatasetDict({ "train": train_test["train"], "valid": train_test["test"], }) # Add back the original test set if it exists if original_test is not None: new_dataset_dict["test"] = original_test logging.info(f"Split sizes - train: {len(new_dataset_dict['train'])}, valid: {len(new_dataset_dict['valid'])}") if original_test is not None: logging.info(f"Test size: {len(new_dataset_dict['test'])}") return new_dataset_dict def preprocess_batch(batch: Dict, image_root: str) -> Dict: """Preprocess a batch of examples. Args: batch: Batch of raw examples with fields: - id, question, options (list), answer, solution, level, subject, image image_root: Root directory for images (not used, as absolute paths are in data) Returns: Preprocessed batch with fields: - prompt: formatted question prompt - completion: formatted solution/answer text - solution: extracted answer (for reward computation) - image_path: path to image file """ def _format_answer(answer: str, options: List[str] = None) -> str: """Format answer in \\boxed{} format. For multiple choice, if answer is A/B/C/D/E, include the full option text. """ answer = (answer or "").strip() # If already in boxed format, return as is if answer.startswith("\\boxed{") and answer.endswith("}"): return answer # For multiple choice (single letter), optionally expand to full option if len(answer) == 1 and answer.upper() in ['A', 'B', 'C', 'D', 'E'] and options and len(options) > 0: # Map letter to option index idx = ord(answer.upper()) - ord('A') if 0 <= idx < len(options): option_text = options[idx] return f"\\boxed{{{answer}}}" # Keep just the letter for now return "\\boxed{" + answer + "}" def _extract_answer(answer_str: str, options: List[str] = None) -> str: """Extract raw answer without boxed formatting.""" answer = (answer_str or "").strip() # Remove boxed formatting if present if answer.startswith("\\boxed{") and answer.endswith("}"): answer = answer[7:-1].strip() return answer def _format_question_with_options(question: str, options: List[str] = None) -> str: """Format question with multiple choice options if available.""" formatted = question.strip() if options and len(options) > 0: formatted += "\nOptions:\n" for i, opt in enumerate(options): letter = chr(ord('A') + i) formatted += f"{letter}. {opt}\n" return formatted # Templates format_template = r"""Solve the problem and output the answer in the format of \boxed{your answer}.""" prompt_template = "\n Question: {prompt}\n" # Get fields from batch questions: List[str] = batch.get("question", []) options_list: List[List[str]] = batch.get("options", [[]] * len(questions)) answers: List[str] = batch.get("answer", [""] * len(questions)) solutions: List[Optional[str]] = batch.get("solution", [None] * len(questions)) image_paths_src: List[str] = batch.get("image", [""] * len(questions)) prompts: List[str] = [] completions: List[str] = [] solution_labels: List[str] = [] image_paths: List[str] = [] for q, opts, ans, sol, img_path in zip(questions, options_list, answers, solutions, image_paths_src): # Format question with options formatted_q = _format_question_with_options(q, opts) processed_prompt = format_template + prompt_template.format(prompt=formatted_q) # Extract and format answer raw_answer = _extract_answer(ans, opts) solution_label = _format_answer(ans, opts) # For completion, use solution if available, otherwise use answer if sol and sol.strip(): # Solution exists, use it as completion completion_text = sol.strip() else: # No solution, create a simple completion with the answer completion_text = f"The answer is {solution_label}" prompts.append(processed_prompt) completions.append(completion_text) solution_labels.append(solution_label) # Image path is already absolute in the data image_paths.append(img_path if img_path else None) return { "prompt": prompts, "completion": completions, "solution": solution_labels, "image_path": image_paths, } def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict: """Preprocess all splits. Args: dataset_dict: Raw dataset dictionary image_root: Root directory for images (not used for this dataset) batch_size: Batch size for processing Returns: Preprocessed DatasetDict with fields: prompt, completion, solution, image_path """ keep_keys = ["prompt", "completion", "solution", "image_path"] def _map(split): logging.info(f"Preprocessing {split} split with batch_size={batch_size}") ds = dataset_dict[split].map( lambda batch: preprocess_batch(batch, image_root), batched=True, batch_size=batch_size, num_proc=None, remove_columns=dataset_dict[split].column_names, desc=f"Math_Vision preprocess ({split})", ) # Filter out samples with empty solution def has_valid_solution(example): solution = example.get("solution", "") return solution is not None and len(solution.strip()) > 0 ds_filtered = ds.filter(has_valid_solution, desc=f"Filter empty solutions ({split})") num_filtered = len(ds) - len(ds_filtered) if num_filtered > 0: logging.info(f"Filtered {num_filtered} samples with empty solutions from {split}") return ds_filtered.select_columns(keep_keys) result = DatasetDict({split: _map(split) for split in dataset_dict.keys()}) for split, ds in result.items(): logging.info(f"Preprocessed {split}: {len(ds)} samples") return result def save_dataset(dataset_dict: DatasetDict, output_dir: str): """Save preprocessed dataset to JSON files. Args: dataset_dict: Preprocessed dataset output_dir: Output directory """ os.makedirs(output_dir, exist_ok=True) for split_name, ds in dataset_dict.items(): output_path = os.path.join(output_dir, f"{split_name}.json") logging.info(f"Saving {split_name} split to {output_path}") # Convert to list of dicts and save as JSON data = [] for example in ds: data.append({ "prompt": example["prompt"], "completion": example["completion"], "solution": example["solution"], "image_path": example["image_path"], }) with open(output_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) logging.info(f"Saved {len(data)} samples to {output_path}") def main(): parser = argparse.ArgumentParser(description="Preprocess Math Vision dataset") parser.add_argument("--config", type=str, help="Path to config YAML file") parser.add_argument("--input_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Input directory with train.json/test.json") parser.add_argument("--output_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Output directory for preprocessed data") parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio") parser.add_argument("--image_root", type=str, default="/root/CVPR/MemGen/dataset/math_vision/images", help="Image root directory") parser.add_argument("--batch_size", type=int, default=512, help="Batch size for preprocessing") args = parser.parse_args() # Load config if provided if args.config: logging.info(f"Loading config from {args.config}") with open(args.config, "r") as f: cfg = yaml.safe_load(f) # Extract dataset config dataset_cfg = cfg.get("datasets", {}).get("math_vision", {}) mode = dataset_cfg.get("mode", "sft") mode_cfg = dataset_cfg.get(mode, {}) val_ratio = mode_cfg.get("val_ratio", args.val_ratio) image_root = mode_cfg.get("image_root", args.image_root) else: val_ratio = args.val_ratio image_root = args.image_root input_dir = args.input_dir output_dir = args.output_dir batch_size = args.batch_size logging.info("=" * 80) logging.info("Math Vision Dataset Preprocessing") logging.info("=" * 80) logging.info(f"Input directory: {input_dir}") logging.info(f"Output directory: {output_dir}") logging.info(f"Validation ratio: {val_ratio}") logging.info(f"Image root: {image_root}") logging.info(f"Batch size: {batch_size}") # Step 1: Load existing dataset dataset_dict = load_existing_dataset(input_dir) # Step 2: Split train into train/valid if needed dataset_dict = split_train_valid(dataset_dict, val_ratio=val_ratio) # Step 3: Preprocess dataset preprocessed = preprocess_dataset(dataset_dict, image_root, batch_size=batch_size) # Step 4: Save preprocessed data save_dataset(preprocessed, output_dir) logging.info("=" * 80) logging.info("Preprocessing complete!") logging.info("=" * 80) for split, ds in preprocessed.items(): logging.info(f"{split}: {len(ds)} samples") logging.info(f"Output saved to: {output_dir}") if __name__ == "__main__": main()