|
|
""" |
|
|
Math Vision Dataset Preprocessing Script |
|
|
|
|
|
This script reads the existing Math Vision dataset from data/math_vision directory |
|
|
and preprocesses it into the format expected by the dataloader. |
|
|
The preprocessed data will be saved with fields: prompt, completion, solution, image_path |
|
|
|
|
|
Usage: |
|
|
# Using config file |
|
|
uv run scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml |
|
|
|
|
|
# Manual parameters |
|
|
uv run scripts/math_vision_process.py --input_dir data/math_vision --output_dir data/math_vision |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import logging |
|
|
import argparse |
|
|
from typing import Dict, List, Optional |
|
|
import yaml |
|
|
from datasets import load_dataset, DatasetDict |
|
|
|
|
|
# Configure the root logger when the module is executed: INFO level, with each
# record formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
def load_existing_dataset(data_path: str) -> DatasetDict:
    """Load whichever raw Math Vision splits are present in a directory.

    Looks for ``train.json`` and ``test.json`` (in that order) directly under
    ``data_path`` and hands every file that exists to the ``json`` loader.

    Args:
        data_path: Directory expected to contain train.json and/or test.json

    Returns:
        DatasetDict keyed by the split names that were found

    Raises:
        FileNotFoundError: If neither split file exists in ``data_path``
    """
    split_files = {}
    for split_name in ("train", "test"):
        candidate = os.path.join(data_path, f"{split_name}.json")
        if os.path.exists(candidate):
            split_files[split_name] = candidate
            logging.info(f"Found {split_name}.json at {candidate}")

    # Nothing to load: fail early instead of passing an empty mapping on.
    if not split_files:
        raise FileNotFoundError(f"No data files found in {data_path}")

    logging.info(f"Loading dataset from {data_path}")
    return load_dataset("json", data_files=split_files)
|
|
|
|
|
|
|
|
def split_train_valid(dataset_dict: DatasetDict, val_ratio: float = 0.1) -> DatasetDict:
    """Carve a validation split out of the train split when none exists.

    The input is returned untouched when it already has a "valid" split, has
    no "train" split, or when ``val_ratio`` lies outside the open interval
    (0, 1). Otherwise the train split is shuffled and divided with a fixed
    seed of 42, and an existing "test" split is carried over unchanged.

    Args:
        dataset_dict: Dataset dictionary
        val_ratio: Fraction of the train split to move into "valid"

    Returns:
        DatasetDict with train/valid (and test, if it was present) splits
    """
    if "valid" in dataset_dict:
        logging.info("Validation set already exists, skipping split")
        return dataset_dict

    if "train" not in dataset_dict:
        logging.warning("No train set found, cannot create validation split")
        return dataset_dict

    if val_ratio >= 1 or val_ratio <= 0:
        logging.warning(f"Invalid val_ratio {val_ratio}, skipping validation split")
        return dataset_dict

    logging.info(f"Splitting train set with val_ratio={val_ratio}")
    parts = dataset_dict["train"].train_test_split(test_size=val_ratio, seed=42, shuffle=True)

    # Remember the untouched test split (if any) before building the new dict.
    held_out = dataset_dict.get("test", None)

    # The "test" half of the split becomes the validation set.
    result = DatasetDict({
        "train": parts["train"],
        "valid": parts["test"],
    })
    if held_out is not None:
        result["test"] = held_out

    train_size = len(result["train"])
    valid_size = len(result["valid"])
    logging.info(f"Split sizes - train: {train_size}, valid: {valid_size}")
    if held_out is not None:
        test_size = len(result["test"])
        logging.info(f"Test size: {test_size}")

    return result
|
|
|
|
|
|
|
|
def preprocess_batch(batch: Dict, image_root: str) -> Dict:
    """Convert a batch of raw Math Vision rows into dataloader fields.

    Args:
        batch: Column-oriented batch (column name -> list of per-row values).
            The columns read are ``question``, ``options``, ``answer``,
            ``solution`` and ``image``. Every column except ``question`` is
            optional; a missing column is replaced by a per-row default
            (no options, empty answer, no solution, no image).
        image_root: Unused. Kept so the signature stays stable for callers;
            values from the ``image`` column are passed through unchanged.

    Returns:
        Preprocessed batch with fields:
        - prompt: instruction line followed by the question and, when
          options are present, a lettered "Options:" list
        - completion: the stripped solution text, or
          "The answer is \\boxed{...}" when the row has no solution text
        - solution: the answer wrapped in \\boxed{...} (for reward computation)
        - image_path: the value of the ``image`` column, or None when empty
    """
    instruction = r"""Solve the problem and output the answer in the format of \boxed{your answer}."""
    question_template = "\n Question: {prompt}\n"
    boxed_prefix = "\\boxed{"

    def _as_text(value) -> str:
        """Return ``value`` as a stripped string; None becomes the empty string."""
        # str() keeps numeric answers such as 0 instead of treating them as empty.
        return "" if value is None else str(value).strip()

    def _format_answer(answer) -> str:
        """Wrap the answer in \\boxed{...} unless it is already wrapped."""
        text = _as_text(answer)
        if text.startswith(boxed_prefix) and text.endswith("}"):
            return text
        # Multiple-choice letters are boxed as-is (the option text is NOT expanded).
        return boxed_prefix + text + "}"

    def _format_question_with_options(question, options: Optional[List[str]] = None) -> str:
        """Append a lettered option list (A., B., ...) to the question when options exist."""
        text = _as_text(question)
        if not options:
            return text
        option_lines = [
            f"{chr(ord('A') + index)}. {option}\n"
            for index, option in enumerate(options)
        ]
        return text + "\nOptions:\n" + "".join(option_lines)

    questions: List[str] = batch.get("question", [])
    row_count = len(questions)
    options_list: List[List[str]] = batch.get("options", [[]] * row_count)
    answers: List[str] = batch.get("answer", [""] * row_count)
    solutions: List[Optional[str]] = batch.get("solution", [None] * row_count)
    image_paths_src: List[str] = batch.get("image", [""] * row_count)

    prompts: List[str] = []
    completions: List[str] = []
    solution_labels: List[str] = []
    image_paths: List[Optional[str]] = []

    for question, options, answer, solution, image_path in zip(
        questions, options_list, answers, solutions, image_paths_src
    ):
        formatted_question = _format_question_with_options(question, options)
        solution_label = _format_answer(answer)

        # Prefer the worked solution as the training target; fall back to a
        # one-line statement of the boxed answer when no solution text exists.
        solution_text = _as_text(solution)
        if not solution_text:
            solution_text = f"The answer is {solution_label}"

        prompts.append(instruction + question_template.format(prompt=formatted_question))
        completions.append(solution_text)
        solution_labels.append(solution_label)
        # Normalise empty image entries to None so consumers can test for absence.
        image_paths.append(image_path if image_path else None)

    return {
        "prompt": prompts,
        "completion": completions,
        "solution": solution_labels,
        "image_path": image_paths,
    }
|
|
|
|
|
|
|
|
def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict:
    """Preprocess every split and drop rows that carry no usable answer.

    Each split is mapped through ``preprocess_batch`` in batches, the raw
    columns are removed, and rows whose ``solution`` is empty are filtered
    out. A ``solution`` that is only an empty ``\\boxed{}`` wrapper counts as
    empty: the label is always boxed, so checking the raw string alone would
    never filter anything.

    Args:
        dataset_dict: Raw dataset dictionary
        image_root: Root directory for images; forwarded to ``preprocess_batch``
        batch_size: Batch size for processing

    Returns:
        Preprocessed DatasetDict with fields: prompt, completion, solution, image_path
    """
    keep_keys = ["prompt", "completion", "solution", "image_path"]
    boxed_prefix = "\\boxed{"

    def has_valid_solution(example):
        """Return True when the row's solution has non-empty content."""
        solution = (example.get("solution") or "").strip()
        # Look inside the \boxed{...} wrapper so "\boxed{}" is treated as empty.
        if solution.startswith(boxed_prefix) and solution.endswith("}"):
            solution = solution[len(boxed_prefix):-1].strip()
        return bool(solution)

    def _map(split):
        """Preprocess and filter a single split."""
        raw_split = dataset_dict[split]
        logging.info(f"Preprocessing {split} split with batch_size={batch_size}")
        ds = raw_split.map(
            lambda batch: preprocess_batch(batch, image_root),
            batched=True,
            batch_size=batch_size,
            num_proc=None,
            # Drop all raw columns; only the preprocessed fields are kept.
            remove_columns=raw_split.column_names,
            desc=f"Math_Vision preprocess ({split})",
        )

        ds_filtered = ds.filter(has_valid_solution, desc=f"Filter empty solutions ({split})")

        num_filtered = len(ds) - len(ds_filtered)
        if num_filtered > 0:
            logging.info(f"Filtered {num_filtered} samples with empty solutions from {split}")

        return ds_filtered.select_columns(keep_keys)

    result = DatasetDict({split: _map(split) for split in dataset_dict.keys()})

    for split, ds in result.items():
        logging.info(f"Preprocessed {split}: {len(ds)} samples")

    return result
|
|
|
|
|
|
|
|
def save_dataset(dataset_dict: DatasetDict, output_dir: str):
    """Write each preprocessed split to ``<output_dir>/<split>.json``.

    The output directory is created when missing. Every split is written as
    one JSON array (UTF-8, non-ASCII characters kept, 2-space indent) whose
    records contain only the four dataloader fields.

    Args:
        dataset_dict: Preprocessed dataset
        output_dir: Output directory
    """
    record_fields = ("prompt", "completion", "solution", "image_path")
    os.makedirs(output_dir, exist_ok=True)

    for split_name, ds in dataset_dict.items():
        output_path = os.path.join(output_dir, f"{split_name}.json")
        logging.info(f"Saving {split_name} split to {output_path}")

        # Materialise plain dicts so the split can be dumped in a single call.
        records = [
            {field: example[field] for field in record_fields}
            for example in ds
        ]

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(records, f, ensure_ascii=False, indent=2)

        logging.info(f"Saved {len(records)} samples to {output_path}")
|
|
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None):
    """Run the full pipeline: load, split, preprocess and save the dataset.

    When ``--config`` is given, ``val_ratio`` and ``image_root`` are read from
    ``datasets.math_vision.<mode>`` in the YAML file (``mode`` defaults to
    "sft"), falling back to the command-line values for keys that are absent.
    An empty config file or empty/null sections are treated as "no overrides".

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Preprocess Math Vision dataset")
    parser.add_argument("--config", type=str, help="Path to config YAML file")
    # NOTE(review): these path defaults are machine-specific absolute paths,
    # unlike the relative paths shown in the module usage notes; confirm they
    # are intended before relying on them.
    parser.add_argument("--input_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Input directory with train.json/test.json")
    parser.add_argument("--output_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Output directory for preprocessed data")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--image_root", type=str, default="/root/CVPR/MemGen/dataset/math_vision/images", help="Image root directory")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for preprocessing")

    args = parser.parse_args(argv)

    val_ratio = args.val_ratio
    image_root = args.image_root

    if args.config:
        logging.info(f"Loading config from {args.config}")
        with open(args.config, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            cfg = yaml.safe_load(f) or {}

        # Each level may be missing or explicitly null in the YAML file.
        dataset_cfg = (cfg.get("datasets") or {}).get("math_vision") or {}
        mode = dataset_cfg.get("mode", "sft")
        mode_cfg = dataset_cfg.get(mode) or {}

        val_ratio = mode_cfg.get("val_ratio", args.val_ratio)
        image_root = mode_cfg.get("image_root", args.image_root)

    input_dir = args.input_dir
    output_dir = args.output_dir
    batch_size = args.batch_size

    logging.info("=" * 80)
    logging.info("Math Vision Dataset Preprocessing")
    logging.info("=" * 80)
    logging.info(f"Input directory: {input_dir}")
    logging.info(f"Output directory: {output_dir}")
    logging.info(f"Validation ratio: {val_ratio}")
    logging.info(f"Image root: {image_root}")
    logging.info(f"Batch size: {batch_size}")

    # Pipeline: load raw splits -> create validation split -> preprocess -> save.
    dataset_dict = load_existing_dataset(input_dir)
    dataset_dict = split_train_valid(dataset_dict, val_ratio=val_ratio)
    preprocessed = preprocess_dataset(dataset_dict, image_root, batch_size=batch_size)
    save_dataset(preprocessed, output_dir)

    logging.info("=" * 80)
    logging.info("Preprocessing complete!")
    logging.info("=" * 80)
    for split, ds in preprocessed.items():
        logging.info(f"{split}: {len(ds)} samples")
    logging.info(f"Output saved to: {output_dir}")
|
|
|
|
|
|
|
|
# Run the preprocessing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|