# File size: 15,308 Bytes
# 4910a8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
# Copyright (c) 2025 Fudan University. All rights reserved.


from io import BytesIO
import random
from PIL import Image
import numpy as np
import cv2
import insightface
import torch
from torchvision import transforms
from torch.cuda.amp import autocast

def face_preserving_resize(img, face_bboxes, target_size=512):
    """
    Square-crop an image around its faces and resize it to target_size.

    The image is cropped along its longer side to a square whose side equals
    the shorter side. The crop offset is drawn at random from the range of
    offsets that keep the union of all face boxes inside the square, then
    clamped to the image. The square is finally resized to
    ``target_size`` x ``target_size``.

    Args:
        img: PIL Image
        face_bboxes: Sequence of [x1, y1, x2, y2] face coordinates in pixels.
            Any number of boxes is accepted; coordinates are truncated to int.
        target_size: Side length (in pixels) of the square output image

    Returns:
        Tuple of (resized image, new_bboxes), where new_bboxes are the input
        boxes mapped into the output image (integers, not clamped to the
        output bounds). Returns (None, None) if no boxes are given, if the
        union of the boxes has a negative coordinate, or if the faces cannot
        fit into the square crop.
    """
    # len() instead of truthiness so that array-like inputs work as well.
    if face_bboxes is None or len(face_bboxes) == 0:
        return None, None

    # Integer copies of the boxes; these are transformed and returned.
    new_bboxes = [list(map(int, bbox)) for bbox in face_bboxes]

    # Union of all face boxes.
    min_x1 = min(bbox[0] for bbox in new_bboxes)
    min_y1 = min(bbox[1] for bbox in new_bboxes)
    max_x2 = max(bbox[2] for bbox in new_bboxes)
    max_y2 = max(bbox[3] for bbox in new_bboxes)

    # Detections partially outside the image produce negative coordinates;
    # such samples are rejected rather than clamped.
    if min_x1 < 0 or min_y1 < 0 or max_x2 < 0 or max_y2 < 0:
        return None, None

    # The square side is the shorter image side. For boxes that lie inside
    # the image, this check is equivalent to requiring the face region to fit
    # into the square along the axis that gets cropped.
    face_width = max_x2 - min_x1
    face_height = max_y2 - min_y1
    if face_width > img.height or face_height > img.width:
        return None, None

    crop_along_x = img.width > img.height
    if crop_along_x:
        square_size, long_side = img.height, img.width
        first_start, last_start = max_x2 - square_size, min_x1
    else:
        square_size, long_side = img.width, img.height
        first_start, last_start = max_y2 - square_size, min_y1

    # The fit check above guarantees first_start <= last_start, so a window
    # containing every face always exists along the cropped axis.
    if first_start < last_start:
        start = random.randint(first_start, last_start)
    else:
        start = first_start
    start = max(0, min(start, long_side - square_size))  # keep window inside the image

    if crop_along_x:
        crop_box = (start, 0, start + square_size, square_size)
        shift_x, shift_y = start, 0
    else:
        crop_box = (0, start, square_size, start + square_size)
        shift_x, shift_y = 0, start
    cropped_img = img.crop(crop_box)

    # Map the boxes into the cropped, then resized, coordinate system.
    scale_factor = target_size / square_size
    for bbox in new_bboxes:
        bbox[0] = int((bbox[0] - shift_x) * scale_factor)
        bbox[1] = int((bbox[1] - shift_y) * scale_factor)
        bbox[2] = int((bbox[2] - shift_x) * scale_factor)
        bbox[3] = int((bbox[3] - shift_y) * scale_factor)

    resized_img = cropped_img.resize((target_size, target_size), Image.Resampling.LANCZOS)

    return resized_img, new_bboxes

def extract_moref(img, json_data, face_size_restriction=100):
    """
    Crop every face listed in ``json_data['bboxes']`` out of an image.

    Each box is expanded along its shorter side to a square, clipped to the
    image bounds, cropped, and resized to 512x512. Clipping happens after the
    expansion, so a box close to the image border may no longer be square and
    is then stretched by the resize.

    Args:
        img: PIL Image, or encoded image bytes readable by ``Image.open``
        json_data: Mapping with a 'bboxes' entry holding a list of
            [x1, y1, x2, y2] face coordinates
        face_size_restriction: Minimum width and height (in pixels) that
            every box must have; if any box is smaller, no faces are returned

    Returns:
        List of 512x512 PIL Images, one per box. An empty list is returned if
        any box is smaller than ``face_size_restriction`` or if any error
        occurs while processing (the error is printed, not raised).
    """
    try:
        # Anything that is not already an image is treated as encoded bytes.
        # JPEG/PNG/... file classes all derive from Image.Image, so a single
        # isinstance check covers them.
        # NOTE(review): tensors are passed through unopened as before, but the
        # code below only uses PIL APIs, so tensor input presumably ends up in
        # the except path - confirm whether tensor support is intended.
        if not isinstance(img, (Image.Image, torch.Tensor)):
            img = Image.open(BytesIO(img))

        bboxes = json_data['bboxes']

        # Reject the whole image as soon as one face is too small.
        if any(
            x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction
            for x1, y1, x2, y2 in bboxes
        ):
            return []

        img_width, img_height = img.size

        faces = []
        for bbox in bboxes:
            x1, y1, x2, y2 = map(int, bbox)
            width = x2 - x1
            height = y2 - y1

            # Grow the shorter side symmetrically until the box is square;
            # an odd difference puts the extra pixel on the far edge.
            if width > height:
                diff = width - height
                y1 -= diff // 2
                y2 += diff - (diff // 2)
            elif height > width:
                diff = height - width
                x1 -= diff // 2
                x2 += diff - (diff // 2)

            # Clip to the image so the crop never reads outside of it.
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(img_width, x2)
            y2 = min(img_height, y2)

            face_region = img.crop((x1, y1, x2, y2))
            faces.append(face_region.resize((512, 512), Image.LANCZOS))
        return faces
    except Exception as e:
        # Best-effort extraction: report the problem and return no faces.
        print(f"Error processing image: {e}")
        return []

def general_face_preserving_resize(img, face_bboxes, target_size=512):
    """
    Square-crop an image around its faces and resize it to target_size.

    The image is cropped along its longer side to a square whose side equals
    the shorter side. When the union of the face boxes fits into the square,
    the crop offset is drawn at random among the offsets that contain it;
    otherwise the crop is centred on the union. Boxes are mapped into the
    output image and clamped to its bounds.

    Args:
        img: PIL Image
        face_bboxes: List of [x1, y1, x2, y2] face coordinates (any count)
        target_size: Side length (in pixels) of the square output image

    Returns:
        Tuple of (resized image, new_bboxes). (None, None) is returned only
        when ``face_bboxes`` is empty; oversized or partially negative face
        regions are clamped instead of being rejected.
    """
    if not face_bboxes:
        print("Warning: No face bounding boxes provided.")
        return None, None

    # Union of all face boxes.
    left = min(box[0] for box in face_bboxes)
    top = min(box[1] for box in face_bboxes)
    right = max(box[2] for box in face_bboxes)
    bottom = max(box[3] for box in face_bboxes)

    # Negative coordinates: pull the top-left corner back onto the image.
    if left < 0 or top < 0 or right < 0 or bottom < 0:
        left = max(left, 0)
        top = max(top, 0)

    # Face region too large: clamp it to the image instead of giving up.
    if right - left > img.height or bottom - top > img.width:
        right = min(right, img.width)
        bottom = min(bottom, img.height)
        left = max(left, 0)
        top = max(top, 0)

    # Integer working copies of the boxes; these are what gets returned.
    boxes = [list(map(int, box)) for box in face_bboxes]

    # axis 0: landscape image, crop along x; axis 1: crop along y.
    axis = 0 if img.width > img.height else 1
    if axis == 0:
        side, long_side = img.height, img.width
        near, far = left, right
    else:
        side, long_side = img.width, img.height
        near, far = top, bottom

    # Offsets in [earliest, latest] keep the whole face region in the window.
    latest = near
    earliest = far - side
    if earliest <= latest:
        if earliest < latest:
            offset = random.randint(int(earliest), int(latest))
        else:
            offset = int(earliest)
        offset = max(0, min(offset, long_side - side))
    else:
        # Faces span more than the square: centre the window on them.
        midpoint = (near + far) // 2
        offset = max(0, min(midpoint - (side // 2), long_side - side))

    if axis == 0:
        window = (offset, 0, offset + side, side)
    else:
        window = (0, offset, side, offset + side)
    cropped = img.crop(window)

    # Shift the boxes into the cropped coordinate system.
    for box in boxes:
        box[axis] -= offset
        box[axis + 2] -= offset

    # Scale the boxes from the square crop to the output resolution.
    ratio = target_size / side
    for box in boxes:
        box[:4] = [int(value * ratio) for value in box[:4]]

    resized = cropped.resize((target_size, target_size), Image.Resampling.LANCZOS)

    # Clamp so every box lies inside the output image.
    for box in boxes:
        x1, y1, x2, y2 = box[:4]
        box[0] = max(0, min(x1, target_size - 1))
        box[1] = max(0, min(y1, target_size - 1))
        box[2] = max(1, min(x2, target_size))
        box[3] = max(1, min(y2, target_size))

    return resized, boxes



def horizontal_concat(images):
    """
    Paste images side by side, left to right, onto one new RGB canvas.

    Args:
        images: Iterable of PIL Images. It is materialized once, so one-shot
            iterables such as generators are supported.

    Returns:
        New RGB PIL Image whose width is the sum of the input widths and
        whose height is the largest input height. Images are top-aligned;
        the area below shorter images keeps the canvas fill (black, the
        default of Pillow's ``Image.new``).

    Raises:
        ValueError: If ``images`` is empty.
    """
    # The images are needed twice (size scan, then paste), so a generator
    # must be materialized first or the paste loop would see nothing.
    images = list(images)
    if not images:
        raise ValueError("horizontal_concat() requires at least one image")

    total_width = sum(img.size[0] for img in images)
    max_height = max(img.size[1] for img in images)

    new_im = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for img in images:
        new_im.paste(img, (x_offset, 0))
        x_offset += img.size[0]

    return new_im

def extract_object(birefnet, image):
    """
    Cut the foreground predicted by ``birefnet`` out of an image.

    Args:
        birefnet: Callable model; it is called with a (1, C, H, W) bfloat16
            CUDA tensor and the last element of its output is passed through
            a sigmoid to obtain the mask.
        image: PIL Image; converted to RGB when it is in another mode

    Returns:
        Tuple of (result, mask): ``result`` is the RGB image pasted onto a
        black canvas through the mask, ``mask`` is the predicted mask as a
        PIL Image resized to the image size.

    Requires a CUDA device, because the input is moved to 'cuda'.
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

    # Add a batch dimension and move to the GPU in bfloat16.
    to_tensor = transforms.ToTensor()
    batch = to_tensor(rgb_image).unsqueeze(0)
    batch = batch.to('cuda', dtype=torch.bfloat16)

    # Inference only: no gradients, bfloat16 autocast.
    with torch.no_grad(), autocast(dtype=torch.bfloat16):
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # First batch item, back in float32 so it can be converted to a PIL image.
    first_map = probabilities[0].squeeze().float()
    mask = transforms.ToPILImage()(first_map).resize(rgb_image.size)

    # 8-bit grayscale paste mask; note that it is not thresholded, so the
    # composite is blended softly where the prediction is uncertain.
    paste_mask = mask.convert("L")

    result = Image.new("RGB", rgb_image.size, (0, 0, 0))
    result.paste(rgb_image, (0, 0), paste_mask)

    return result, mask

class FaceExtractor:
    """Detects faces with an insightface ``FaceAnalysis`` model and crops them."""

    def __init__(self, name="antelopev2", root="./", ctx_id=0, det_thresh=0.4):
        """
        Load and prepare the face analysis model.

        Args:
            name: Model pack name passed to ``FaceAnalysis``
            root: Root directory passed to ``FaceAnalysis``
            ctx_id: Device/context id passed to ``prepare``
            det_thresh: Detection threshold passed to ``prepare``
        """
        self.model = insightface.app.FaceAnalysis(name=name, root=root)
        self.model.prepare(ctx_id=ctx_id, det_thresh=det_thresh)

    def _detect(self, image):
        """Run the model on a PIL image and return its list of detections."""
        # COLOR_RGB2BGR needs a 3-channel array, so other modes (RGBA, L, ...)
        # are converted for detection only; callers keep the original image.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        return self.model.get(image_np)

    def extract(self, image: Image.Image):
        """
        Crop the first detected face.

        Returns:
            Tuple of (512x512 face crop, embedding of that face), or
            (None, None) if no face is detected or the crop fails.
        """
        faces = self._detect(image)
        if len(faces) == 0:
            return None, None
        face = faces[0]

        # extract_moref returns [] on failure, so guard before indexing.
        crops = extract_moref(image, {"bboxes": [face["bbox"]]}, 1)
        if not crops:
            return None, None
        return crops[0], face["embedding"]

    def locate_bboxes(self, image: Image.Image, target_size=512):
        """
        Locate all faces in the coordinates of the square-resized image.

        Args:
            image: PIL Image
            target_size: Side length of the square image that the boxes are
                mapped into (see ``general_face_preserving_resize``)

        Returns:
            List of [x1, y1, x2, y2] boxes, each expanded along its shorter
            side towards a square and clipped to ``[0, target_size]``, or
            None if no face is detected.
        """
        faces = self._detect(image)
        if len(faces) == 0:
            return None
        bboxes = [face["bbox"] for face in faces]

        _, resized_bboxes = general_face_preserving_resize(image, bboxes, target_size)
        if resized_bboxes is None:
            return None

        # Expand the shorter side of every box; clipping at the image border
        # may leave a box slightly non-square.
        new_bboxes = []
        for x1, y1, x2, y2 in resized_bboxes:
            w = x2 - x1
            h = y2 - y1
            if w > h:
                diff = w - h
                y1 = max(0, y1 - diff // 2)
                y2 = min(target_size, y2 + diff // 2 + diff % 2)
            else:
                diff = h - w
                x1 = max(0, x1 - diff // 2)
                x2 = min(target_size, x2 + diff // 2 + diff % 2)
            new_bboxes.append([x1, y1, x2, y2])

        return new_bboxes

    def extract_refs(self, image: Image.Image):
        """
        Extracts reference faces from the image.

        Returns:
            Tuple of (list of 512x512 face crops, list of their embeddings),
            index-aligned. Faces whose crop fails are skipped. (None, None)
            is returned if no face is detected or no crop succeeds.
        """
        faces = self._detect(image)
        if len(faces) == 0:
            return None, None

        ref_imgs = []
        arcface_embeddings = []
        for face in faces:
            crops = extract_moref(image, {"bboxes": [face["bbox"]]}, 1)
            if not crops:
                # Keep images and embeddings aligned by dropping both.
                continue
            ref_imgs.append(crops[0])
            arcface_embeddings.append(face["embedding"])

        if not ref_imgs:
            return None, None
        return ref_imgs, arcface_embeddings