File size: 3,781 Bytes
1633fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import random
import math
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomCocoDataset(Dataset):
    """Image-only dataset yielding a random square crop plus a resized "hint" copy.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file directly inside ``img_folder``
    (matched case-insensitively, not searched recursively) is one sample.
    Despite the class name, no COCO annotations are read: the caption returned
    under ``txt`` is always an empty string.

    Args:
        img_folder: Directory containing the images.
        img_size: Side length of the square crop returned under ``"jpg"``.
        hint_size: Side length the crop is bicubically resized to for ``"hint"``.
    """

    # File extensions accepted when scanning ``img_folder`` (compared lowercased).
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the full file names so each image is reopened with its real
        # extension. Sorting makes the sample order reproducible, since
        # os.listdir returns entries in arbitrary order.
        self.files = sorted(
            f for f in os.listdir(img_folder)
            if f.lower().endswith(self.IMG_EXTENSIONS)
        )
        # File stems (name without extension), kept for callers that read ``ids``.
        self.ids = [os.path.splitext(f)[0] for f in self.files]

    def __len__(self):
        """Return the number of image files found in ``img_folder``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``dict(jpg=..., txt=..., hint=...)``.

        ``jpg`` is the ``img_size`` x ``img_size`` random crop as a float tensor
        (via ``to_tensor``); ``hint`` is the same crop bicubically resized to
        ``hint_size`` x ``hint_size``; ``txt`` is an empty string.
        """
        img_path = os.path.join(self.img_folder, self.files[index])
        # The context manager closes the file handle promptly. convert()
        # returns a fully loaded copy, so the image is usable after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Random rescale + square crop; returns a NumPy array.
        cropped_image = random_crop_arr(image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0)

        # Back to PIL so torchvision's functional resize can operate on it.
        cropped_image = Image.fromarray(cropped_image)

        # Full-resolution crop, plus a bicubically resized "hint" copy.
        jpg_image = transforms.functional.to_tensor(cropped_image)
        hint_image = transforms.functional.resize(cropped_image, (self.hint_size, self.hint_size), interpolation=transforms.InterpolationMode.BICUBIC)
        hint_image = transforms.functional.to_tensor(hint_image)

        # No captions are available; always return an empty prompt.
        prompt = ""

        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)

def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Rescale ``pil_image`` by a random factor, then cut a random square crop.

    The shorter side is first resized to a random integer length drawn from
    ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``.
    The final ``image_size`` x ``image_size`` crop therefore spans roughly
    ``min_crop_frac``..``max_crop_frac`` of the (rescaled) shorter side.

    Args:
        pil_image: Source PIL image.
        image_size: Side length of the returned square crop.
        min_crop_frac: Lower bound on (crop side / shorter side).
        max_crop_frac: Upper bound on (crop side / shorter side).

    Returns:
        A NumPy array slice of ``image_size`` rows by ``image_size`` columns,
        taken from ``np.array`` of the rescaled image.
    """
    lo = math.ceil(image_size / max_crop_frac)
    hi = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(lo, hi + 1)

    # Halve with BOX filtering while the image is still at least twice the
    # target short side. This emulates PIL's ``reducing_gap`` option by hand
    # and gives a cleaner result than one large bicubic downsample.
    img = pil_image
    while min(img.size) >= 2 * target_short:
        width, height = img.size
        img = img.resize((width // 2, height // 2), resample=Image.BOX)

    # Final bicubic resize bringing the shorter side to target_short.
    factor = target_short / min(img.size)
    width, height = img.size
    img = img.resize((round(width * factor), round(height * factor)), resample=Image.BICUBIC)

    pixels = np.array(img)
    top = random.randrange(pixels.shape[0] - image_size + 1)
    left = random.randrange(pixels.shape[1] - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]


if __name__ == "__main__":
    # Smoke test: build the dataset, print a sample's shapes, pull one batch,
    # and save side-by-side previews of each crop and its hint image.
    import sys

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # The image folder may be given as the first CLI argument. The default is
    # the original author's local path and will not exist on other machines.
    default_folder = "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    img_folder = sys.argv[1] if len(sys.argv) > 1 else default_folder

    dataset = CustomCocoDataset(img_folder)
    print(len(dataset))

    batch_size = 4
    # With drop_last=True, fewer than batch_size images yields an empty loader
    # and next(iter(...)) would raise a bare StopIteration.
    if len(dataset) < batch_size:
        sys.exit(f"Need at least {batch_size} images in {img_folder!r}, found {len(dataset)}.")

    sample = dataset[0]
    print(f"jpg: {tuple(sample['jpg'].shape)}, hint: {tuple(sample['hint'].shape)}, txt: {sample['txt']!r}")

    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=2,
        pin_memory=True, drop_last=True)

    # Take a single batch from the DataLoader.
    batch = next(iter(dataloader))

    # Extract the crops, hint images and prompts from the batch.
    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']

    print(f"Prompt: {prompts}")

    # Visualize each sample of the batch and save it to a PNG file.
    for i, (jpg, hint) in enumerate(zip(jpg_images, hint_images), start=1):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {i} ({dataset.img_size}x{dataset.img_size})")
        # CHW tensor -> HWC array, the layout imshow expects.
        plt.imshow(jpg.permute(1, 2, 0).numpy())

        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {i} ({dataset.hint_size}x{dataset.hint_size})")
        plt.imshow(hint.permute(1, 2, 0).numpy())

        plt.savefig(f'output_image_{i}.png')

        # Close the current figure to release its memory.
        plt.close()