# DiffICM / 4_ControlModule / dataset / build_dataset.py
# Author: Qiyp
# Commit 1633fcc: "code of stage1 & 3, remove large files"
import os
import random
import math
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class CustomCocoDataset(Dataset):
    """Caption-free image dataset returning a random crop and a resized "hint" copy.

    Every ``.jpg``/``.jpeg``/``.png`` file directly inside ``img_folder`` is one
    sample (the extension match is case-insensitive). Each sample is a random
    ``img_size`` x ``img_size`` crop (produced by ``random_crop_arr``) plus the
    same crop bicubically resized to ``hint_size`` x ``hint_size``.

    Args:
        img_folder: Directory containing the images (not searched recursively).
        img_size: Side length of the square crop returned under ``'jpg'``.
        hint_size: Side length of the resized crop returned under ``'hint'``.
    """

    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the full file names so __getitem__ opens the file that was
        # actually listed. Previously only the stem was stored and '.png' was
        # re-appended, so every .jpg/.jpeg sample failed to load.
        # Sorting makes the index -> file mapping deterministic, because
        # os.listdir order is arbitrary.
        self.files = sorted(
            name for name in os.listdir(img_folder)
            if name.lower().endswith(self._EXTENSIONS)
        )
        # File stems, kept for callers that read ``ids``.
        self.ids = [os.path.splitext(name)[0] for name in self.files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with keys:
            - ``'jpg'``: the random ``img_size`` crop, converted with
              ``transforms.functional.to_tensor``;
            - ``'txt'``: always the empty string (captions are not used);
            - ``'hint'``: the same crop resized to ``hint_size`` x ``hint_size``
              (bicubic), also converted with ``to_tensor``.
        """
        img_path = os.path.join(self.img_folder, self.files[index])
        # Release the file handle immediately instead of waiting for GC.
        # convert() loads the pixels and returns a new image, so it stays
        # usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        crop = Image.fromarray(
            random_crop_arr(image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0)
        )
        jpg_image = transforms.functional.to_tensor(crop)
        hint_image = transforms.functional.to_tensor(
            transforms.functional.resize(
                crop,
                (self.hint_size, self.hint_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            )
        )
        return dict(jpg=jpg_image, txt="", hint=hint_image)
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale a PIL image, then cut out a random square crop.

    The shorter side is resized to a length drawn uniformly from the integers
    ``ceil(image_size / max_crop_frac)`` .. ``ceil(image_size / min_crop_frac)``
    (inclusive). A square of ``image_size`` pixels is then taken at a random
    offset, so it spans roughly between ``min_crop_frac`` and ``max_crop_frac``
    of the rescaled shorter side.

    Args:
        pil_image: Source ``PIL.Image``.
        image_size: Side length of the square crop, in pixels.
        min_crop_frac: Lower bound on the crop's share of the shorter side.
        max_crop_frac: Upper bound on the crop's share of the shorter side.

    Returns:
        A numpy array whose first two axes are ``image_size`` x ``image_size``.
        Any trailing channel axis is the one ``np.array`` produced from the
        resized image.
    """
    shortest_allowed = math.ceil(image_size / max_crop_frac)
    longest_allowed = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(shortest_allowed, longest_allowed + 1)

    # Halve with BOX filtering while the image is at least twice the target,
    # so the final bicubic resize never shrinks by 2x or more in one step.
    # This is done by hand rather than with PIL's ``reducing_gap`` argument.
    img = pil_image
    while min(img.size) >= 2 * target_short:
        halved = (img.size[0] // 2, img.size[1] // 2)
        img = img.resize(halved, resample=Image.BOX)

    ratio = target_short / min(img.size)
    rescaled = tuple(round(side * ratio) for side in img.size)
    pixels = np.array(img.resize(rescaled, resample=Image.BICUBIC))

    top = random.randrange(pixels.shape[0] - image_size + 1)
    left = random.randrange(pixels.shape[1] - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]
if __name__ == "__main__":
    # Quick visual sanity check: load one batch and save side-by-side figures
    # of each full-size crop and its hint image.
    import sys

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # The image folder can be given on the command line. The fallback is the
    # author's original, machine-specific path.
    default_folder = "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    img_folder = sys.argv[1] if len(sys.argv) > 1 else default_folder

    dataset = CustomCocoDataset(img_folder)
    print(len(dataset))
    if len(dataset) == 0:
        sys.exit(f"No .jpg/.jpeg/.png images found in {img_folder}")

    # Summarize the first sample (tensor shapes, caption text).
    sample = dataset[0]
    print({key: getattr(value, "shape", value) for key, value in sample.items()})

    # With drop_last=True, a folder holding fewer images than the batch size
    # would yield no batch at all, and next() would raise StopIteration.
    # Cap the batch size at the dataset size to avoid that.
    dataloader = DataLoader(
        dataset, batch_size=min(4, len(dataset)), num_workers=2,
        pin_memory=True, drop_last=True,
    )

    # Take a single batch and pull out the crops, hints and prompts.
    batch = next(iter(dataloader))
    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']
    print(f"Prompt: {prompts}")

    # Save one figure per sample: crop on the left, hint on the right.
    for i, (jpg, hint) in enumerate(zip(jpg_images, hint_images), start=1):
        fig, (ax_jpg, ax_hint) = plt.subplots(1, 2, figsize=(10, 5))
        ax_jpg.set_title(f"JPG Image {i} ({dataset.img_size}x{dataset.img_size})")
        ax_jpg.imshow(jpg.permute(1, 2, 0))  # CHW -> HWC layout for imshow
        ax_hint.set_title(f"Hint Image {i} ({dataset.hint_size}x{dataset.hint_size})")
        ax_hint.imshow(hint.permute(1, 2, 0))  # CHW -> HWC layout for imshow
        fig.savefig(f'output_image_{i}.png')
        plt.close(fig)  # free the figure's memory before the next one