|
|
import os |
|
|
import random |
|
|
import math |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
class CustomCocoDataset(Dataset):
    """Images from a flat folder, served as a random square crop plus a resized hint.

    Every item is a dict with three entries:

    * ``jpg``  - an ``img_size`` x ``img_size`` random crop of the image,
      converted with torchvision's ``to_tensor``;
    * ``txt``  - the prompt, always the empty string;
    * ``hint`` - the same crop resized (bicubic) to ``hint_size`` x
      ``hint_size`` and converted with ``to_tensor``.
    """

    # File-name suffixes (case-sensitive) that are picked up from the folder.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        """Index the image files located directly inside ``img_folder``.

        Args:
            img_folder: Directory that is scanned (non-recursively) for files
                ending in one of ``IMG_EXTENSIONS``.
            img_size: Side length of the square crop returned as ``jpg``.
            hint_size: Side length of the resized copy returned as ``hint``.
        """
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the real file names: the extension differs from file to file,
        # so a path cannot be rebuilt from the stem alone.  Sorting gives a
        # stable index -> file mapping (os.listdir order is arbitrary).
        self.img_files = sorted(
            f for f in os.listdir(img_folder) if f.endswith(self.IMG_EXTENSIONS)
        )
        # Stems without extension, parallel to ``img_files``.
        self.ids = [os.path.splitext(f)[0] for f in self.img_files]

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.ids)

    def _image_path(self, index):
        """Return the full path of the image file stored at ``index``."""
        return os.path.join(self.img_folder, self.img_files[index])

    def __getitem__(self, index):
        """Load image ``index`` and return ``dict(jpg=..., txt=..., hint=...)``.

        Raises:
            IndexError: If ``index`` is outside the indexed file list.
        """
        img_path = self._image_path(index)
        # The context manager closes the file handle right away; ``convert``
        # returns an independent, fully loaded RGB copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Random rescale + random square crop; result is a NumPy array.
        cropped_image = random_crop_arr(
            image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0
        )
        cropped_image = Image.fromarray(cropped_image)

        jpg_image = transforms.functional.to_tensor(cropped_image)
        hint_image = transforms.functional.resize(
            cropped_image,
            (self.hint_size, self.hint_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        hint_image = transforms.functional.to_tensor(hint_image)

        prompt = ""

        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)
|
|
|
|
|
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale ``pil_image`` and cut a random square window out of it.

    The shorter side of the image is resized to a length drawn at random from
    ``ceil(image_size / max_crop_frac)`` .. ``ceil(image_size / min_crop_frac)``
    (both ends included).  A window of ``image_size`` x ``image_size`` pixels
    is then taken at a random position of the resized image.

    Args:
        pil_image: Source image; it must provide ``size`` and ``resize`` the
            way a PIL image does.
        image_size: Side length, in pixels, of the returned square crop.
        min_crop_frac: Divisor producing the upper bound of the target
            shorter-side length.
        max_crop_frac: Divisor producing the lower bound of the target
            shorter-side length.

    Returns:
        The crop as a slice of ``np.array`` of the resized image.
    """
    lower_bound = math.ceil(image_size / max_crop_frac)
    upper_bound = math.ceil(image_size / min_crop_frac)
    target_short_side = random.randint(lower_bound, upper_bound)

    # Halve with a box filter for as long as the shorter side is at least
    # twice the target; afterwards it is less than twice the target.
    while min(pil_image.size) >= 2 * target_short_side:
        width, height = pil_image.size
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)

    # One bicubic resize brings the shorter side to the target length.
    width, height = pil_image.size
    factor = target_short_side / min(width, height)
    pil_image = pil_image.resize(
        (round(width * factor), round(height * factor)), resample=Image.BICUBIC
    )

    pixels = np.array(pil_image)
    # Axis 0 indexes rows (y), axis 1 indexes columns (x); y is drawn first.
    top = random.randrange(pixels.shape[0] - image_size + 1)
    left = random.randrange(pixels.shape[1] - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset, pull one batch through a DataLoader and
    # save every (crop, hint) pair side by side as a PNG in the working
    # directory.
    import sys

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # The image folder may be passed as the first command-line argument;
    # without one, the previously hard-coded location is used.
    default_folder = "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    img_folder = sys.argv[1] if len(sys.argv) > 1 else default_folder

    dataset = CustomCocoDataset(img_folder)
    print(len(dataset))
    print(dataset[0])

    dataloader = DataLoader(
        dataset, batch_size=4, num_workers=2,
        pin_memory=True, drop_last=True)

    # Only the first batch is inspected.
    batch = next(iter(dataloader))

    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']

    print(f"Prompt: {prompts}")

    for i, (jpg, hint) in enumerate(zip(jpg_images, hint_images), start=1):
        plt.figure(figsize=(10, 5))

        # Each image tensor is channel-first; permute(1, 2, 0) moves the
        # channel axis last for imshow.  Titles report the actual
        # width x height of the tensors.
        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {i} ({jpg.shape[-1]}x{jpg.shape[-2]})")
        plt.imshow(jpg.permute(1, 2, 0))

        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {i} ({hint.shape[-1]}x{hint.shape[-2]})")
        plt.imshow(hint.permute(1, 2, 0))

        plt.savefig(f'output_image_{i}.png')

        # Free the figure before the next iteration creates a new one.
        plt.close()
|
|
|
|
|
|
|
|
|