deepseek-ocr / process_dataset.py
davanstrien's picture
davanstrien HF Staff
Switch to batch processing pattern from official run_dpsk_ocr_eval_batch.py
b7b3c0d
#!/usr/bin/env python3
"""
DeepSeek-OCR Dataset Processing
Minimal adaptation of official run_dpsk_ocr_eval_batch.py for dataset processing
"""
import argparse
import json
import os
import sys
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import torch
import importlib.util

# Process-wide environment setup. It must run before the vllm / datasets /
# huggingface_hub imports that follow.
if torch.version.cuda == '11.8':
    # PyTorch built against CUDA 11.8: point Triton at that toolkit's ptxas.
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
# Disable vLLM's V1 engine (VLLM_USE_V1=0) before vllm is imported below.
os.environ['VLLM_USE_V1'] = '0'
# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is first imported,
# so it has to be set here rather than inside main(). Only turn it on when the
# hf_transfer package is installed, because huggingface_hub raises on download
# if the flag is set but the package is missing. An explicit value already in
# the environment is respected.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset
from huggingface_hub import login
# Import DeepSeek-OCR modules (unchanged from original)
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, PROMPT, CROP_MODE
# Register the custom DeepSeek-OCR architecture with vLLM's ModelRegistry
# (unchanged from the original script). This runs once at import time, before
# main() constructs the LLM. The name registered here is the same string that
# main() passes via hf_overrides={"architectures": [...]}, so keep the two in sync.
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def check_cuda():
    """Ensure PyTorch can see a CUDA device before doing any work.

    On success, prints the name of device 0. Otherwise prints an error and
    terminates the process with exit status 1.
    """
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"Using GPU: {gpu_name}")
        return
    print("ERROR: CUDA is not available. This script requires a GPU.")
    sys.exit(1)
def process_single_image(image):
    """Build one vLLM generation request for a single page image.

    The image is tokenized by a fresh ``DeepseekOCRProcessor``, with cropping
    controlled by ``CROP_MODE``, and paired with the configured ``PROMPT``.
    This mirrors the per-image helper of the official batch script.

    Returns:
        A dict with ``"prompt"`` and ``"multi_modal_data"`` keys. One such dict
        is used per element of the list handed to ``LLM.generate``.
    """
    processor = DeepseekOCRProcessor()
    image_features = processor.tokenize_with_images(
        images=[image],
        bos=True,
        eos=True,
        cropping=CROP_MODE,
    )
    return {
        "prompt": PROMPT,
        "multi_modal_data": {"image": image_features},
    }
def _load_images(dataset, image_column):
    """Load every row's image from *image_column* as an upright RGB PIL image.

    Returns a list aligned with the dataset rows. A row whose image could not
    be loaded holds ``None``, so it is reported as failed instead of aborting
    the run.
    """
    images = []
    # Access rows one at a time (rather than the whole column at once) so a
    # single undecodable image only fails its own row.
    for idx in range(len(dataset)):
        try:
            image = dataset[idx][image_column]
            if isinstance(image, str):
                # The column stores a file path rather than a decoded image.
                image = Image.open(image)
            # Apply EXIF orientation *before* converting. convert() builds a
            # new image and can drop format-specific orientation metadata
            # (e.g. TIFF tags), which would leave the page rotated.
            image = ImageOps.exif_transpose(image).convert('RGB')
            images.append(image)
        except Exception as e:
            print(f"Error loading image {idx}: {e}")
            images.append(None)
    return images


def _preprocess_images(images, num_workers):
    """Turn loaded images into vLLM request dicts using a thread pool.

    Returns a list aligned with *images*. ``None`` marks an entry that was
    missing or whose preprocessing raised.
    """
    def _safe_process(idx, image):
        if image is None:
            return None
        try:
            return process_single_image(image)
        except Exception as e:
            # One bad image must not abort the whole batch; its row ends up
            # as "[OCR FAILED]" instead.
            print(f"Error preprocessing image {idx}: {e}")
            return None

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        return list(tqdm(
            executor.map(_safe_process, range(len(images)), images),
            total=len(images),
            desc="Pre-processing images",
        ))


def _build_inference_info(dataset, args):
    """Return the JSON string stored in the ``inference_info`` column.

    If the dataset already has ``inference_info``, the value from its first
    row is parsed. A non-list value is wrapped in a list. This run's entry is
    then appended, so earlier history is kept. Missing, empty or unparseable
    metadata starts a fresh list.
    """
    existing_info = []
    if "inference_info" in dataset.column_names:
        try:
            existing_info = json.loads(dataset[0]["inference_info"])
        except (ValueError, TypeError, IndexError):
            # ValueError: invalid JSON; TypeError: non-string value (e.g. None);
            # IndexError: empty dataset.
            existing_info = []
        else:
            if not isinstance(existing_info, list):
                existing_info = [existing_info]
    existing_info.append({
        "column_name": "markdown",
        "model_id": MODEL_PATH,
        "processing_date": datetime.now().isoformat(),
        "prompt": PROMPT,
        "max_tokens": args.max_tokens,
        "max_model_len": args.max_model_len,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "max_num_seqs": args.max_num_seqs,
        "script": "process_dataset.py",
        "implementation": "vllm-batch (official deepseek batch code)",
    })
    return json.dumps(existing_info, ensure_ascii=False)


def main(args):
    """OCR every image of a Hub dataset with DeepSeek-OCR and push the result.

    Steps:
        1. Load the dataset.
        2. Optionally shuffle and limit it.
        3. Run all valid images through one batched ``llm.generate`` call.
        4. Add a ``markdown`` column, where rows that failed get
           ``"[OCR FAILED]"``.
        5. Add an ``inference_info`` JSON column.
        6. Push to ``args.output_dataset``.

    Exits with status 1 if CUDA is unavailable or the image column is missing.
    """
    check_cuda()

    # NOTE: HF_HUB_ENABLE_HF_TRANSFER is not set here. huggingface_hub reads it
    # when first imported, which has already happened by the time main() runs,
    # so it must be configured before the top-level imports.

    # Login to HF if token provided
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # Load dataset
    print(f"Loading dataset: {args.input_dataset}")
    dataset = load_dataset(args.input_dataset, split=args.split)
    if args.image_column not in dataset.column_names:
        print(f"ERROR: Column '{args.image_column}' not found")
        print(f"Available columns: {dataset.column_names}")
        sys.exit(1)

    # Shuffle if requested
    if args.shuffle:
        print(f"Shuffling with seed {args.seed}")
        dataset = dataset.shuffle(seed=args.seed)

    # Limit samples if requested. `is not None` so that an explicit
    # --max-samples 0 is honored rather than treated as "no limit".
    if args.max_samples is not None:
        dataset = dataset.select(range(min(args.max_samples, len(dataset))))
    print(f"Processing {len(dataset)} samples")

    # Initialize vLLM engine (settings unchanged from official batch script)
    print("Initializing vLLM engine...")
    llm = LLM(
        model=MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
        swap_space=0,
        max_num_seqs=args.max_num_seqs,
        tensor_parallel_size=1,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Sampling params (unchanged from official batch script)
    logits_processors = [NoRepeatNGramLogitsProcessor(
        ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822}
    )]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=args.max_tokens,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    print("Loading images from dataset...")
    images = _load_images(dataset, args.image_column)

    print("Preprocessing images...")
    batch_inputs = _preprocess_images(images, args.num_workers)

    # Keep only usable requests, remembering which dataset row each came from.
    valid_indices = [i for i, inp in enumerate(batch_inputs) if inp is not None]
    valid_batch_inputs = [batch_inputs[i] for i in valid_indices]

    # Batch inference; skipped entirely when nothing survived preprocessing.
    print(f"Running batch inference on {len(valid_batch_inputs)} images...")
    outputs_list = (
        llm.generate(valid_batch_inputs, sampling_params=sampling_params)
        if valid_batch_inputs
        else []
    )

    # Rows without an output keep the failure marker.
    all_markdown = ["[OCR FAILED]"] * len(dataset)
    for idx, output in zip(valid_indices, outputs_list):
        all_markdown[idx] = output.outputs[0].text.strip()

    print("Adding markdown column...")
    if "markdown" in dataset.column_names:
        # add_column rejects duplicate column names. Without this, processing
        # a dataset that already has OCR output would crash only after the
        # expensive inference has finished.
        print("Replacing existing 'markdown' column")
        dataset = dataset.remove_columns(["markdown"])
    dataset = dataset.add_column("markdown", all_markdown)

    # Record provenance (appending to any existing history).
    info_json = _build_inference_info(dataset, args)
    if "inference_info" in dataset.column_names:
        dataset = dataset.remove_columns(["inference_info"])
    dataset = dataset.add_column("inference_info", [info_json] * len(dataset))

    # Push to hub
    print(f"Pushing to {args.output_dataset}")
    dataset.push_to_hub(args.output_dataset, private=args.private, token=hf_token)
    print("✅ Complete!")
    print(f"Dataset: https://huggingface.co/datasets/{args.output_dataset}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Process images through DeepSeek-OCR"
)
parser.add_argument("input_dataset", help="Input dataset ID")
parser.add_argument("output_dataset", help="Output dataset ID")
parser.add_argument("--image-column", default="image", help="Image column name")
parser.add_argument("--split", default="train", help="Dataset split")
parser.add_argument("--max-samples", type=int, help="Limit number of samples")
parser.add_argument("--shuffle", action="store_true", help="Shuffle dataset")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--max-model-len", type=int, default=8192)
parser.add_argument("--max-tokens", type=int, default=8192)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.75)
parser.add_argument("--max-num-seqs", type=int, default=100, help="Max concurrent sequences")
parser.add_argument("--num-workers", type=int, default=64, help="Image preprocessing workers")
parser.add_argument("--hf-token", help="HF API token")
parser.add_argument("--private", action="store_true", help="Make output private")
args = parser.parse_args()
main(args)