# test.py — In-process callable version (for ZeroGPU stateless)
# Keep original logic; add build_parser(), run_cli(args_list), and run_inference(args)
# Do NOT initialize CUDA at import-time.

import os
from os import path
from argparse import ArgumentParser
import shutil

# Do not apply @spaces.GPU here: it would double-dispatch with app.py's @spaces.GPU.
# import spaces

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from inference.data.test_datasets import DAVISTestDataset_221128_TransColorization_batch
from inference.data.mask_mapper import MaskMapper
from model.network import ColorMNet
from inference.inference_core import InferenceCore

from progressbar import progressbar

from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans

from skimage import color, io
import cv2

try:
    import hickle as hkl
except ImportError:
    print('Failed to import hickle. Fine if not using multi-scale testing.')


# ----------------- small utils -----------------
def detach_to_cpu(x):
    return x.detach().cpu()

def tensor_to_np_float(image):
    image_np = image.numpy().astype('float32')
    return image_np

def lab2rgb_transform_PIL(mask):
    mask_d = detach_to_cpu(mask)
    mask_d = inv_lll2rgb_trans(mask_d)
    im = tensor_to_np_float(mask_d)

    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    im = color.lab2rgb(im)

    return im.clip(0, 1)
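
# Illustrative usage (shapes assumed from the call site in run_inference below):
#     lab = torch.cat([rgb[:1, :, :], prob], dim=0)   # (3, H, W): input L + predicted ab
#     rgb_float = lab2rgb_transform_PIL(lab)          # (H, W, 3) float RGB in [0, 1]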


# ----------------- argparse -----------------
def build_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
    parser.add_argument('--FirstFrameIsNotExemplar', help='Set when the provided reference frame is not the first input frame', action='store_true')

    # dataset setting
    parser.add_argument('--d16_batch_path', default='input', help='Point to folder A/ which contains <video_name>/00000.png etc.')
    parser.add_argument('--ref_path', default='ref', help='Kept for parity; when args is passed, the dataset also reads ref.png under each video folder')
    parser.add_argument('--output', default='result', help='Directory to save results')

    parser.add_argument('--reverse', default=False, action='store_true', help='whether to reverse the frame order')
    parser.add_argument('--allow_resume', action='store_true', 
                        help='skip existing videos that have been colorized')

    # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
    parser.add_argument('--generic_path')
    parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch')
    parser.add_argument('--split', help='val/test', default='val')
    parser.add_argument('--save_all', action='store_true',
                        help='Save all frames. Useful only for YouTubeVOS / long videos')
    parser.add_argument('--benchmark', action='store_true', help='Disable AMP for accurate FPS benchmarking')
            
    # Long-term memory options
    parser.add_argument('--disable_long_term', action='store_true')
    parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
    parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
    parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', 
                                                    type=int, default=10000)
    parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)

    parser.add_argument('--top_k', type=int, default=30)
    parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
    parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)

    # Multi-scale options
    parser.add_argument('--save_scores', action='store_true')
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('--size', default=-1, type=int, 
                        help='Resize the shorter side to this size. -1 to use original resolution. ')
    return parser
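
# Example (illustrative; the argument values are hypothetical placeholders):
#     parser = build_parser()
#     args = parser.parse_args(['--d16_batch_path', 'input', '--output', 'result'])
#     # Unspecified options keep the defaults above, e.g. args.mem_every == 5.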


# ----------------- core inference -----------------
def run_inference(args):
    """
    真正的推理流程。必须在 ZeroGPU 的调度上下文里被调用(由 app.py 的 @spaces.GPU 包裹)。
    不要在导入模块时做任何 CUDA 初始化。
    """
    config = vars(args)
    config['enable_long_term'] = not config['disable_long_term']

    if args.output is None:
        args.output = f'.output/{args.dataset}_{args.split}'
        print(f'Output path not provided. Defaulting to {args.output}')

    # ----- Data preparation -----
    is_youtube = args.dataset.startswith('Y')
    is_davis = args.dataset.startswith('D')
    is_lv = args.dataset.startswith('LV')

    if is_youtube or args.save_scores:
        out_path = path.join(args.output, 'Annotations')
    else:
        out_path = args.output

    if args.split != 'val':
        raise NotImplementedError('Only split=val is supported in this script.')

    # Dataset: supports A/<video>/00000.png ... and also reads A/<video>/ref.png
    meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
        args.d16_batch_path, imset=args.ref_path, size=args.size, args=args
    )
    palette = None  # kept for compatibility with the original script

    torch.autograd.set_grad_enabled(False)

    # Set up loader list (video readers)
    meta_loader = meta_dataset.get_datasets()

    # Load checkpoint/model
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    network = ColorMNet(config, args.model).to(device).eval()
    if args.model is not None:
        # Load the checkpoint onto the resolved device (under ZeroGPU this is the dispatched GPU).
        model_weights = torch.load(args.model, map_location=device)
        network.load_weights(model_weights, init_as_zero_if_needed=True)
    else:
        print('No model loaded.')

    total_process_time = 0.0
    total_frames = 0

    # ----- Start eval over videos -----
    for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
        # Note: the ZeroGPU/Spaces environment does not allow multi-process data loading; keep num_workers=0.
        loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
        vid_name = vid_reader.vid_name
        vid_length = len(loader)

        # Per the original logic: enable long-term memory counting when the
        # estimated element count, vid_length / (T_max - T_min) * P, reaches LT_max.
        config['enable_long_term_count_usage'] = (
            config['enable_long_term'] and
            (vid_length
                / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
                * config['num_prototypes'])
            >= config['max_long_term_elements']
        )

        mapper = MaskMapper()
        processor = InferenceCore(network, config=config)
        first_mask_loaded = False

        # skip existing videos 
        if args.allow_resume:
            this_out_path = path.join(out_path, vid_name)
            if path.exists(this_out_path):
                print(f'Skipping {this_out_path} because output already exists.')
                continue

        for ti, data in enumerate(loader):
            with torch.cuda.amp.autocast(enabled=not args.benchmark):
                rgb = data['rgb'].to(device)[0]

                msk = data.get('mask')
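                # The loaded 'mask' is the Lab reference frame; in the default
                # exemplar mode only its two ab channels are kept below (the L
                # channel comes from the input frame itself; see the save step).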
                if not config['FirstFrameIsNotExemplar']:
                    msk = msk[:, 1:3, :, :] if msk is not None else None
                    
                info = data['info']
                frame = info['frame'][0]
                shape = info['shape']
                need_resize = info['need_resize'][0]

                # Timing uses CUDA events, so skip it on the CPU fallback path.
                if device.type == 'cuda':
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()

                # The first processed frame must come with a mask.
                if not first_mask_loaded:
                    if msk is not None:
                        first_mask_loaded = True
                    else:
                        continue

                if args.flip:
                    rgb = torch.flip(rgb, dims=[-1])
                    msk = torch.flip(msk, dims=[-1]) if msk is not None else None

                # Move the reference (ab) channels to the device, resize if needed, and register the two labels
                if msk is not None:
                    msk = torch.Tensor(msk[0]).to(device)
                    if need_resize:
                        msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
                    processor.set_all_labels(list(range(1, 3)))
                    labels = range(1, 3)
                else:
                    labels = None
        
                # Run the model on this frame
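                # (Assumed semantics:) step_AnyExemplar receives the exemplar as
                # a 3-channel image built by repeating its L channel, plus the
                # two ab channels as the color reference; step() consumes the
                # ab mask directly.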
                if config['FirstFrameIsNotExemplar']:
                    prob = processor.step_AnyExemplar(
                        rgb,
                        msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
                        msk[1:3, :, :] if msk is not None else None,
                        labels,
                        end=(ti == vid_length - 1)
                    )
                else:
                    prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))

                # Upsample to original size if needed
                if need_resize:
                    prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]

                if device.type == 'cuda':
                    end.record()
                    torch.cuda.synchronize()
                    total_process_time += (start.elapsed_time(end) / 1000)
                total_frames += 1

                if args.flip:
                    prob = torch.flip(prob, dims=[-1])

                if args.save_scores:
                    prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)

                # Save the mask
                if args.save_all or info['save'][0]:
                    this_out_path = path.join(out_path, vid_name)
                    os.makedirs(this_out_path, exist_ok=True)

                    out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
                    out_mask_final = (out_mask_final * 255).astype(np.uint8)

                    out_img = Image.fromarray(out_mask_final)
                    out_img.save(os.path.join(this_out_path, frame[:-4] + '.png'))

    print(f'Total processing time: {total_process_time}')
    print(f'Total processed frames: {total_frames}')
    if total_process_time > 0:
        print(f'FPS: {total_frames / total_process_time}')
    if torch.cuda.is_available():
        print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')

    # Same as the original script: only build a zip when save_scores is off, and only for specific datasets/splits.
    if not args.save_scores:
        if is_youtube:
            print('Making zip for YouTubeVOS...')
            shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
        elif is_davis and args.split == 'test':
            print('Making zip for DAVIS test-dev...')
            shutil.make_archive(args.output, 'zip', args.output)


# ----------------- public entrypoints -----------------
def run_cli(args_list=None):
    """
    供 app.py 同进程调用:test.run_cli(args_list)
    """
    parser = build_parser()
    args = parser.parse_args(args_list)
    return run_inference(args)
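
# Illustrative in-process caller (an assumed app.py pattern; `colorize` and its
# arguments are hypothetical, @spaces.GPU comes from the `spaces` package):
#
#     import spaces
#     import test
#
#     @spaces.GPU
#     def colorize(input_dir, output_dir):
#         return test.run_cli(['--d16_batch_path', input_dir, '--output', output_dir])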


def main():
    """
    保留命令行可运行:python test.py --d16_batch_path A --output result ...
    注意:若在 Hugging Face Spaces/ZeroGPU 无状态环境下直接 run main(),
    需要由上层(如 app.py 的 @spaces.GPU)提供调度上下文。
    """
    run_cli()


if __name__ == '__main__':
    main()