"""
Math Vision Dataset Preprocessing Script
This script reads the existing Math Vision dataset from the data/math_vision directory
and preprocesses it into the format expected by the dataloader.
The preprocessed data will be saved with fields: prompt, completion, solution, image_path
Usage:
# Using config file
uv run scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml
# Manual parameters
uv run scripts/math_vision_process.py --input_dir data/math_vision --output_dir data/math_vision
"""
import os
import json
import logging
import argparse
from typing import Dict, List, Optional
import yaml
from datasets import load_dataset, DatasetDict
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_existing_dataset(data_path: str) -> DatasetDict:
"""Load existing Math Vision dataset.
Args:
data_path: Directory containing train.json and test.json
Returns:
DatasetDict with train/test splits
"""
data_files = {}
train_path = os.path.join(data_path, "train.json")
test_path = os.path.join(data_path, "test.json")
if os.path.exists(train_path):
data_files["train"] = train_path
logging.info(f"Found train.json at {train_path}")
if os.path.exists(test_path):
data_files["test"] = test_path
logging.info(f"Found test.json at {test_path}")
if len(data_files) == 0:
raise FileNotFoundError(f"No data files found in {data_path}")
logging.info(f"Loading dataset from {data_path}")
dataset_dict = load_dataset("json", data_files=data_files)
return dataset_dict
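# Usage sketch (illustrative): with data/math_vision/ holding train.json and
# test.json, load_existing_dataset("data/math_vision") returns a DatasetDict
# with "train" and "test" splits; either file may be missing, but not both.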
def split_train_valid(dataset_dict: DatasetDict, val_ratio: float = 0.1) -> DatasetDict:
"""Split train set into train/valid if valid doesn't exist.
Args:
dataset_dict: Dataset dictionary
val_ratio: Validation set ratio
Returns:
DatasetDict with train/valid/test splits
"""
if "valid" in dataset_dict:
logging.info("Validation set already exists, skipping split")
return dataset_dict
if "train" not in dataset_dict:
logging.warning("No train set found, cannot create validation split")
return dataset_dict
if val_ratio <= 0 or val_ratio >= 1:
logging.warning(f"Invalid val_ratio {val_ratio}, skipping validation split")
return dataset_dict
logging.info(f"Splitting train set with val_ratio={val_ratio}")
train_test = dataset_dict["train"].train_test_split(test_size=val_ratio, seed=42, shuffle=True)
# Preserve the original test set
original_test = dataset_dict.get("test", None)
new_dataset_dict = DatasetDict({
"train": train_test["train"],
"valid": train_test["test"],
})
# Add back the original test set if it exists
if original_test is not None:
new_dataset_dict["test"] = original_test
logging.info(f"Split sizes - train: {len(new_dataset_dict['train'])}, valid: {len(new_dataset_dict['valid'])}")
if original_test is not None:
logging.info(f"Test size: {len(new_dataset_dict['test'])}")
return new_dataset_dict
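# Worked example (hypothetical counts): with 3,000 train rows and
# val_ratio=0.1, the split yields 2,700 train / 300 valid rows; the fixed
# seed=42 keeps the split reproducible across runs.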
def preprocess_batch(batch: Dict, image_root: str) -> Dict:
"""Preprocess a batch of examples.
Args:
batch: Batch of raw examples with fields:
- id, question, options (list), answer, solution, level, subject, image
image_root: Root directory for images (not used, as absolute paths are in data)
Returns:
Preprocessed batch with fields:
- prompt: formatted question prompt
- completion: formatted solution/answer text
            - solution: answer in \\boxed{} format (for reward computation)
- image_path: path to image file
"""
    def _format_answer(answer: str) -> str:
        """Format answer in \\boxed{} format.
        A multiple-choice letter (A-E) is boxed as-is; the option text is
        intentionally not expanded.
        """
        answer = (answer or "").strip()
        # If already in boxed format, return as is
        if answer.startswith("\\boxed{") and answer.endswith("}"):
            return answer
        return "\\boxed{" + answer + "}"
def _format_question_with_options(question: str, options: List[str] = None) -> str:
"""Format question with multiple choice options if available."""
formatted = question.strip()
if options and len(options) > 0:
formatted += "\nOptions:\n"
for i, opt in enumerate(options):
letter = chr(ord('A') + i)
formatted += f"{letter}. {opt}\n"
return formatted
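    # Example (illustrative):
    #   _format_question_with_options("What is x?", ["1", "2"])
    #   -> "What is x?\nOptions:\nA. 1\nB. 2\n"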
# Templates
format_template = r"""Solve the problem and output the answer in the format of \boxed{your answer}."""
prompt_template = "\n Question: {prompt}\n"
# Get fields from batch
questions: List[str] = batch.get("question", [])
options_list: List[List[str]] = batch.get("options", [[]] * len(questions))
answers: List[str] = batch.get("answer", [""] * len(questions))
solutions: List[Optional[str]] = batch.get("solution", [None] * len(questions))
image_paths_src: List[str] = batch.get("image", [""] * len(questions))
prompts: List[str] = []
completions: List[str] = []
solution_labels: List[str] = []
image_paths: List[str] = []
for q, opts, ans, sol, img_path in zip(questions, options_list, answers, solutions, image_paths_src):
# Format question with options
formatted_q = _format_question_with_options(q, opts)
processed_prompt = format_template + prompt_template.format(prompt=formatted_q)
        # Format the boxed answer used as the solution label
        solution_label = _format_answer(ans)
# For completion, use solution if available, otherwise use answer
if sol and sol.strip():
# Solution exists, use it as completion
completion_text = sol.strip()
else:
# No solution, create a simple completion with the answer
completion_text = f"The answer is {solution_label}"
prompts.append(processed_prompt)
completions.append(completion_text)
solution_labels.append(solution_label)
# Image path is already absolute in the data
image_paths.append(img_path if img_path else None)
return {
"prompt": prompts,
"completion": completions,
"solution": solution_labels,
"image_path": image_paths,
}
def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict:
"""Preprocess all splits.
Args:
dataset_dict: Raw dataset dictionary
image_root: Root directory for images (not used for this dataset)
batch_size: Batch size for processing
Returns:
Preprocessed DatasetDict with fields: prompt, completion, solution, image_path
"""
keep_keys = ["prompt", "completion", "solution", "image_path"]
def _map(split):
logging.info(f"Preprocessing {split} split with batch_size={batch_size}")
ds = dataset_dict[split].map(
lambda batch: preprocess_batch(batch, image_root),
batched=True,
batch_size=batch_size,
num_proc=None,
remove_columns=dataset_dict[split].column_names,
desc=f"Math_Vision preprocess ({split})",
)
# Filter out samples with empty solution
def has_valid_solution(example):
solution = example.get("solution", "")
return solution is not None and len(solution.strip()) > 0
ds_filtered = ds.filter(has_valid_solution, desc=f"Filter empty solutions ({split})")
num_filtered = len(ds) - len(ds_filtered)
if num_filtered > 0:
logging.info(f"Filtered {num_filtered} samples with empty solutions from {split}")
return ds_filtered.select_columns(keep_keys)
result = DatasetDict({split: _map(split) for split in dataset_dict.keys()})
for split, ds in result.items():
logging.info(f"Preprocessed {split}: {len(ds)} samples")
return result
def save_dataset(dataset_dict: DatasetDict, output_dir: str):
"""Save preprocessed dataset to JSON files.
Args:
dataset_dict: Preprocessed dataset
output_dir: Output directory
"""
os.makedirs(output_dir, exist_ok=True)
for split_name, ds in dataset_dict.items():
output_path = os.path.join(output_dir, f"{split_name}.json")
logging.info(f"Saving {split_name} split to {output_path}")
# Convert to list of dicts and save as JSON
data = []
for example in ds:
data.append({
"prompt": example["prompt"],
"completion": example["completion"],
"solution": example["solution"],
"image_path": example["image_path"],
})
with open(output_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logging.info(f"Saved {len(data)} samples to {output_path}")
def main():
parser = argparse.ArgumentParser(description="Preprocess Math Vision dataset")
parser.add_argument("--config", type=str, help="Path to config YAML file")
parser.add_argument("--input_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Input directory with train.json/test.json")
parser.add_argument("--output_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Output directory for preprocessed data")
parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
parser.add_argument("--image_root", type=str, default="/root/CVPR/MemGen/dataset/math_vision/images", help="Image root directory")
parser.add_argument("--batch_size", type=int, default=512, help="Batch size for preprocessing")
args = parser.parse_args()
# Load config if provided
if args.config:
logging.info(f"Loading config from {args.config}")
with open(args.config, "r") as f:
cfg = yaml.safe_load(f)
# Extract dataset config
dataset_cfg = cfg.get("datasets", {}).get("math_vision", {})
mode = dataset_cfg.get("mode", "sft")
mode_cfg = dataset_cfg.get(mode, {})
val_ratio = mode_cfg.get("val_ratio", args.val_ratio)
image_root = mode_cfg.get("image_root", args.image_root)
else:
val_ratio = args.val_ratio
image_root = args.image_root
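    # Expected config layout (sketch inferred from the lookups above; the
    # values shown are hypothetical):
    # datasets:
    #   math_vision:
    #     mode: sft
    #     sft:
    #       val_ratio: 0.1
    #       image_root: data/math_vision/images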
input_dir = args.input_dir
output_dir = args.output_dir
batch_size = args.batch_size
logging.info("=" * 80)
logging.info("Math Vision Dataset Preprocessing")
logging.info("=" * 80)
logging.info(f"Input directory: {input_dir}")
logging.info(f"Output directory: {output_dir}")
logging.info(f"Validation ratio: {val_ratio}")
logging.info(f"Image root: {image_root}")
logging.info(f"Batch size: {batch_size}")
# Step 1: Load existing dataset
dataset_dict = load_existing_dataset(input_dir)
# Step 2: Split train into train/valid if needed
dataset_dict = split_train_valid(dataset_dict, val_ratio=val_ratio)
# Step 3: Preprocess dataset
preprocessed = preprocess_dataset(dataset_dict, image_root, batch_size=batch_size)
# Step 4: Save preprocessed data
save_dataset(preprocessed, output_dir)
logging.info("=" * 80)
logging.info("Preprocessing complete!")
logging.info("=" * 80)
for split, ds in preprocessed.items():
logging.info(f"{split}: {len(ds)} samples")
logging.info(f"Output saved to: {output_dir}")
if __name__ == "__main__":
main()