File size: 7,798 Bytes
0a5527f
 
 
b7b3c0d
0a5527f
 
 
 
 
 
 
b7b3c0d
0a5527f
 
 
 
 
 
 
b7b3c0d
0a5527f
 
 
 
b7b3c0d
0a5527f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b3c0d
 
 
 
 
 
 
 
 
 
0a5527f
 
b7b3c0d
0a5527f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b3c0d
0a5527f
b7b3c0d
0a5527f
 
 
 
 
b7b3c0d
 
 
0a5527f
 
 
 
b7b3c0d
0a5527f
b7b3c0d
0a5527f
 
 
 
 
 
 
 
 
b7b3c0d
 
 
 
0a5527f
 
 
 
 
b7b3c0d
0a5527f
b7b3c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a5527f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b3c0d
0a5527f
b7b3c0d
0a5527f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b3c0d
 
0a5527f
 
 
 
b7b3c0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#!/usr/bin/env python3
"""
DeepSeek-OCR Dataset Processing
Minimal adaptation of official run_dpsk_ocr_eval_batch.py for dataset processing
"""

import argparse
import json
import os
import sys
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

import torch
# Environment tweaks applied at import time, before vllm is imported below.
# When the installed torch build targets CUDA 11.8, point Triton at that
# toolkit's ptxas binary (path is hard-coded to the default install prefix).
if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"

# NOTE(review): presumably forces vLLM's legacy (V0) engine, which the
# logits-processor based sampling further down relies on — confirm against
# the vLLM version in use.
os.environ['VLLM_USE_V1'] = '0'

from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset
from huggingface_hub import login

# Import DeepSeek-OCR modules (unchanged from original)
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, PROMPT, CROP_MODE

# Register custom model (unchanged from original).
# The name registered here is the same architecture string that main() passes
# to the engine through hf_overrides={"architectures": [...]}.
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)


def check_cuda():
    """Ensure a CUDA device is usable, otherwise terminate the process.

    Prints the name of device 0 when CUDA is available. When it is not,
    prints an error message and exits with status code 1.
    """
    if torch.cuda.is_available():
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        return

    print("ERROR: CUDA is not available. This script requires a GPU.")
    sys.exit(1)


def process_single_image(image):
    """Build one vLLM request dict for a single image.

    A fresh ``DeepseekOCRProcessor`` is created for every call and asked to
    tokenize the image together with BOS/EOS markers, using the module-level
    ``CROP_MODE`` setting for cropping.

    Args:
        image: The image to preprocess (main() passes RGB PIL images).

    Returns:
        A dict with the shared ``PROMPT`` under ``"prompt"`` and the
        processor's output under ``"multi_modal_data"["image"]``.
    """
    image_payload = DeepseekOCRProcessor().tokenize_with_images(
        images=[image], bos=True, eos=True, cropping=CROP_MODE
    )
    return {
        "prompt": PROMPT,
        "multi_modal_data": {"image": image_payload},
    }


def _load_rgb_images(dataset, image_column):
    """Return one upright RGB PIL image per dataset row, or None on failure.

    Rows are read one index at a time (rather than iterating the column) so
    that a single undecodable image is reported and replaced by ``None``
    instead of aborting the whole run. The returned list always has
    ``len(dataset)`` entries, in dataset order.
    """
    images = []
    for idx in range(len(dataset)):
        try:
            image = dataset[idx][image_column]
            # String values are treated as file paths; anything else is
            # expected to already behave like a PIL image.
            if isinstance(image, str):
                image = Image.open(image)
            images.append(ImageOps.exif_transpose(image.convert('RGB')))
        except Exception as e:
            print(f"Error loading image {idx}: {e}")
            images.append(None)
    return images


def _preprocess_or_none(idx, image):
    """Preprocess one image, returning None if it is missing or fails.

    Mirrors the loader's best-effort policy: a failure is reported and the
    row later receives the "[OCR FAILED]" placeholder instead of the whole
    batch being lost.
    """
    if image is None:
        return None
    try:
        return process_single_image(image)
    except Exception as e:
        print(f"Error preprocessing image {idx}: {e}")
        return None


def _read_existing_inference_info(dataset):
    """Return the prior inference_info entries as a list (possibly empty).

    Only the first row is inspected; every row is assumed to carry the same
    JSON string. A missing column, an empty dataset, or an unparsable value
    yields an empty list.
    """
    if "inference_info" not in dataset.column_names:
        return []
    try:
        info = json.loads(dataset[0]["inference_info"])
    except (ValueError, TypeError, IndexError) as e:
        # ValueError covers json.JSONDecodeError; TypeError a non-string
        # cell; IndexError an empty dataset.
        print(f"Warning: ignoring unreadable inference_info: {e}")
        return []
    return info if isinstance(info, list) else [info]


def main(args):
    """Run DeepSeek-OCR over a Hub dataset and push the result back.

    Steps: verify CUDA, load (and optionally shuffle / truncate) the input
    dataset, start the vLLM engine, load and preprocess every image, run one
    batched generation call, then add a ``markdown`` column plus an
    ``inference_info`` JSON column and push the dataset to the Hub.

    Rows whose image cannot be loaded or preprocessed get the text
    ``"[OCR FAILED]"`` in the ``markdown`` column.

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``input_dataset``, ``output_dataset``, ``split``,
            ``image_column``, ``shuffle``, ``seed``, ``max_samples``,
            ``max_model_len``, ``max_tokens``, ``gpu_memory_utilization``,
            ``max_num_seqs``, ``num_workers``, ``hf_token``, ``private``.

    Exits with status 1 if CUDA is unavailable or the image column is
    missing from the dataset.
    """
    check_cuda()

    # Enable HF_TRANSFER
    # NOTE(review): huggingface_hub is already imported at module level by the
    # time this runs; if it only reads this variable at import time the
    # assignment has no effect in this process — confirm, and if so set it
    # before the imports instead.
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    # Login to HF if token provided (CLI flag wins over the environment)
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # Load dataset
    print(f"Loading dataset: {args.input_dataset}")
    dataset = load_dataset(args.input_dataset, split=args.split)

    if args.image_column not in dataset.column_names:
        print(f"ERROR: Column '{args.image_column}' not found")
        print(f"Available columns: {dataset.column_names}")
        sys.exit(1)

    # Shuffle if requested
    if args.shuffle:
        print(f"Shuffling with seed {args.seed}")
        dataset = dataset.shuffle(seed=args.seed)

    # Limit samples if requested (a value of 0 or None means "no limit")
    if args.max_samples:
        dataset = dataset.select(range(min(args.max_samples, len(dataset))))
        print(f"Processing {len(dataset)} samples")

    # Initialize vLLM engine (UNCHANGED from official batch script)
    print("Initializing vLLM engine...")
    llm = LLM(
        model=MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
        swap_space=0,
        max_num_seqs=args.max_num_seqs,
        tensor_parallel_size=1,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Sampling params (UNCHANGED from official batch script)
    logits_processors = [NoRepeatNGramLogitsProcessor(
        ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822}
    )]

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=args.max_tokens,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    # Load and preprocess images
    print("Loading images from dataset...")
    images = _load_rgb_images(dataset, args.image_column)

    # Preprocess images in parallel; executor.map keeps results in input
    # order, so position i in batch_inputs still corresponds to dataset row i.
    print("Preprocessing images...")
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        batch_inputs = list(tqdm(
            executor.map(_preprocess_or_none, range(len(images)), images),
            total=len(images),
            desc="Pre-processing images"
        ))

    # Filter out None entries and track their indices
    valid_indices = [i for i, inp in enumerate(batch_inputs) if inp is not None]
    valid_batch_inputs = [batch_inputs[i] for i in valid_indices]

    # Batch inference (UNCHANGED from official batch script)
    print(f"Running batch inference on {len(valid_batch_inputs)} images...")
    outputs_list = llm.generate(
        valid_batch_inputs,
        sampling_params=sampling_params
    )

    # Extract results; rows without a valid input keep the failure marker
    all_markdown = ["[OCR FAILED]"] * len(dataset)
    for idx, output in zip(valid_indices, outputs_list):
        all_markdown[idx] = output.outputs[0].text.strip()

    # Add markdown column
    print("Adding markdown column...")
    dataset = dataset.add_column("markdown", all_markdown)

    # Handle inference_info: carry forward earlier entries, then replace the
    # column with the extended history.
    existing_info = _read_existing_inference_info(dataset)
    if "inference_info" in dataset.column_names:
        dataset = dataset.remove_columns(["inference_info"])

    new_info = {
        "column_name": "markdown",
        "model_id": MODEL_PATH,
        "processing_date": datetime.now().isoformat(),
        "prompt": PROMPT,
        "max_tokens": args.max_tokens,
        "max_model_len": args.max_model_len,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "max_num_seqs": args.max_num_seqs,
        "script": "process_dataset.py",
        "implementation": "vllm-batch (official deepseek batch code)",
    }
    existing_info.append(new_info)

    # The same JSON string is stored on every row.
    info_json = json.dumps(existing_info, ensure_ascii=False)
    dataset = dataset.add_column("inference_info", [info_json] * len(dataset))

    # Push to hub
    print(f"Pushing to {args.output_dataset}")
    dataset.push_to_hub(args.output_dataset, private=args.private, token=hf_token)

    print("✅ Complete!")
    print(f"Dataset: https://huggingface.co/datasets/{args.output_dataset}")


if __name__ == "__main__":
    # Command-line interface: two positional dataset IDs followed by optional
    # tuning flags. Each entry is (flag, keyword options for add_argument).
    parser = argparse.ArgumentParser(description="Process images through DeepSeek-OCR")
    cli_spec = (
        ("input_dataset", {"help": "Input dataset ID"}),
        ("output_dataset", {"help": "Output dataset ID"}),
        ("--image-column", {"default": "image", "help": "Image column name"}),
        ("--split", {"default": "train", "help": "Dataset split"}),
        ("--max-samples", {"type": int, "help": "Limit number of samples"}),
        ("--shuffle", {"action": "store_true", "help": "Shuffle dataset"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        ("--max-model-len", {"type": int, "default": 8192}),
        ("--max-tokens", {"type": int, "default": 8192}),
        ("--gpu-memory-utilization", {"type": float, "default": 0.75}),
        ("--max-num-seqs", {"type": int, "default": 100, "help": "Max concurrent sequences"}),
        ("--num-workers", {"type": int, "default": 64, "help": "Image preprocessing workers"}),
        ("--hf-token", {"help": "HF API token"}),
        ("--private", {"action": "store_true", "help": "Make output private"}),
    )
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)