Commit b7b3c0d: Switch to batch processing pattern from official run_dpsk_ocr_eval_batch.py
#!/usr/bin/env python3
"""
DeepSeek-OCR Dataset Processing

Minimal adaptation of the official run_dpsk_ocr_eval_batch.py for dataset processing.
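
Example usage (illustrative; the dataset IDs below are placeholders):

    python process_dataset.py your-user/source-docs your-user/source-docs-ocr \
        --image-column image --max-samples 100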
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from datetime import datetime | |
| from concurrent.futures import ThreadPoolExecutor | |
| import torch | |
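
# Workarounds carried over from the official script: point Triton at the CUDA
# 11.8 ptxas, and pin vLLM's legacy v0 engine (VLLM_USE_V1='0'), which supports
# the per-request logits_processors used below.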
if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'

from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset
from huggingface_hub import login

# Import DeepSeek-OCR modules (unchanged from original)
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, PROMPT, CROP_MODE

# Register custom model (unchanged from original)
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
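# Registration maps the "DeepseekOCRForCausalLM" architecture name (injected
# via hf_overrides when the LLM is constructed below) to the local implementation.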


def check_cuda():
    """Check CUDA availability"""
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This script requires a GPU.")
        sys.exit(1)
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")


def process_single_image(image):
    """Preprocess single image (unchanged from official batch script)"""
    prompt_in = PROMPT
    cache_item = {
        "prompt": prompt_in,
        "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(
            images=[image], bos=True, eos=True, cropping=CROP_MODE
        )},
    }
    return cache_item


def main(args):
    """Main processing function"""
    check_cuda()

    # Enable HF_TRANSFER
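    # (speeds up dataset download/upload when the hf_transfer package is installed)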
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| # Login to HF if token provided | |
| HF_TOKEN = args.hf_token or os.environ.get("HF_TOKEN") | |
| if HF_TOKEN: | |
| login(token=HF_TOKEN) | |
| # Load dataset | |
| print(f"Loading dataset: {args.input_dataset}") | |
| dataset = load_dataset(args.input_dataset, split=args.split) | |
| if args.image_column not in dataset.column_names: | |
| print(f"ERROR: Column '{args.image_column}' not found") | |
| print(f"Available columns: {dataset.column_names}") | |
| sys.exit(1) | |
| # Shuffle if requested | |
| if args.shuffle: | |
| print(f"Shuffling with seed {args.seed}") | |
| dataset = dataset.shuffle(seed=args.seed) | |
| # Limit samples if requested | |
| if args.max_samples: | |
| dataset = dataset.select(range(min(args.max_samples, len(dataset)))) | |
| print(f"Processing {len(dataset)} samples") | |
| # Initialize vLLM engine (UNCHANGED from official batch script) | |
| print("Initializing vLLM engine...") | |
| llm = LLM( | |
| model=MODEL_PATH, | |
| hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]}, | |
| block_size=256, | |
| enforce_eager=False, | |
| trust_remote_code=True, | |
| max_model_len=args.max_model_len, | |
| swap_space=0, | |
| max_num_seqs=args.max_num_seqs, | |
| tensor_parallel_size=1, | |
| gpu_memory_utilization=args.gpu_memory_utilization, | |
| ) | |
| # Sampling params (UNCHANGED from official batch script) | |
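    # The n-gram processor blocks any 40-gram from repeating within a sliding
    # 90-token window, suppressing the repetition loops that greedy (temperature
    # 0.0) OCR decoding is prone to; the whitelisted token IDs are exempt.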
    logits_processors = [NoRepeatNGramLogitsProcessor(
        ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822}
    )]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=args.max_tokens,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    # Load and preprocess images
    print("Loading images from dataset...")
    images = []
    for idx in range(len(dataset)):
        try:
            image = dataset[idx][args.image_column]
            if not isinstance(image, Image.Image):
                image = Image.open(image) if isinstance(image, str) else image
            image = ImageOps.exif_transpose(image.convert('RGB'))
            images.append(image)
        except Exception as e:
            print(f"Error loading image {idx}: {e}")
            images.append(None)

    # Preprocess images in parallel (UNCHANGED from official batch script)
    print("Preprocessing images...")
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        batch_inputs = list(tqdm(
            executor.map(lambda img: process_single_image(img) if img else None, images),
            total=len(images),
            desc="Pre-processing images"
        ))

    # Filter out None entries and track their indices
    valid_indices = [i for i, inp in enumerate(batch_inputs) if inp is not None]
    valid_batch_inputs = [inp for inp in batch_inputs if inp is not None]
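    # valid_indices preserves each input's original row position so results can
    # be written back in order and failed rows marked explicitly below.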

    # Batch inference (UNCHANGED from official batch script)
    print(f"Running batch inference on {len(valid_batch_inputs)} images...")
    outputs_list = llm.generate(
        valid_batch_inputs,
        sampling_params=sampling_params
    )

    # Extract results
    all_markdown = ["[OCR FAILED]"] * len(dataset)
    for idx, output in zip(valid_indices, outputs_list):
        all_markdown[idx] = output.outputs[0].text.strip()

    # Add markdown column
    print("Adding markdown column...")
    dataset = dataset.add_column("markdown", all_markdown)

    # Handle inference_info
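    # inference_info carries per-run provenance as a JSON list; append this
    # run's metadata to any history already present on the dataset.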
| if "inference_info" in dataset.column_names: | |
| try: | |
| existing_info = json.loads(dataset[0]["inference_info"]) | |
| if not isinstance(existing_info, list): | |
| existing_info = [existing_info] | |
| except: | |
| existing_info = [] | |
| dataset = dataset.remove_columns(["inference_info"]) | |
| else: | |
| existing_info = [] | |

    new_info = {
        "column_name": "markdown",
        "model_id": MODEL_PATH,
        "processing_date": datetime.now().isoformat(),
        "prompt": PROMPT,
        "max_tokens": args.max_tokens,
        "max_model_len": args.max_model_len,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "max_num_seqs": args.max_num_seqs,
        "script": "process_dataset.py",
        "implementation": "vllm-batch (official deepseek batch code)",
    }
    existing_info.append(new_info)
    info_json = json.dumps(existing_info, ensure_ascii=False)
    dataset = dataset.add_column("inference_info", [info_json] * len(dataset))

    # Push to hub
    print(f"Pushing to {args.output_dataset}")
    dataset.push_to_hub(args.output_dataset, private=args.private, token=HF_TOKEN)

    print("✅ Complete!")
    print(f"Dataset: https://huggingface.co/datasets/{args.output_dataset}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process images through DeepSeek-OCR"
    )
    parser.add_argument("input_dataset", help="Input dataset ID")
    parser.add_argument("output_dataset", help="Output dataset ID")
    parser.add_argument("--image-column", default="image", help="Image column name")
    parser.add_argument("--split", default="train", help="Dataset split")
    parser.add_argument("--max-samples", type=int, help="Limit number of samples")
    parser.add_argument("--shuffle", action="store_true", help="Shuffle dataset")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max-model-len", type=int, default=8192)
    parser.add_argument("--max-tokens", type=int, default=8192)
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.75)
    parser.add_argument("--max-num-seqs", type=int, default=100, help="Max concurrent sequences")
    parser.add_argument("--num-workers", type=int, default=64, help="Image preprocessing workers")
    parser.add_argument("--hf-token", help="HF API token")
    parser.add_argument("--private", action="store_true", help="Make output private")
    args = parser.parse_args()
    main(args)