|
|
import pathlib |
|
|
import tempfile |
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
from finetrainers.data import ( |
|
|
ImageCaptionFilePairDataset, |
|
|
ImageFileCaptionFileListDataset, |
|
|
ImageFolderDataset, |
|
|
ValidationDataset, |
|
|
VideoCaptionFilePairDataset, |
|
|
VideoFileCaptionFileListDataset, |
|
|
VideoFolderDataset, |
|
|
VideoWebDataset, |
|
|
initialize_dataset, |
|
|
) |
|
|
from finetrainers.utils import find_files |
|
|
|
|
|
from .utils import create_dummy_directory_structure |
|
|
|
|
|
|
|
|
class DatasetTesterMixin:
    """Shared fixture logic for dataset tests that read from a temporary directory.

    Subclasses must set ``num_data_files`` and ``directory_structure``; both are
    forwarded, together with ``caption`` and ``metadata_extension``, to
    ``create_dummy_directory_structure``.
    """

    # Number of data files the fixture should contain; must be overridden.
    num_data_files = None
    # List of relative paths describing the fixture layout; must be overridden.
    directory_structure = None
    # Caption that the tests expect every dataset item to carry.
    caption = "A cat ruling the world"
    # Forwarded to the fixture helper; set by the image/video mixins.
    metadata_extension = None

    def setUp(self):
        """Validate the subclass configuration and build the temporary fixture.

        Raises:
            ValueError: If ``num_data_files`` or ``directory_structure`` is unset.
        """
        if self.num_data_files is None:
            raise ValueError("num_data_files is not defined")
        if self.directory_structure is None:
            raise ValueError("directory_structure is not defined")

        self.tmpdir = tempfile.TemporaryDirectory()
        try:
            create_dummy_directory_structure(
                self.directory_structure,
                self.tmpdir,
                self.num_data_files,
                self.caption,
                self.metadata_extension,
            )
        except Exception:
            # unittest does not call tearDown() when setUp() raises, so the
            # directory has to be removed here to avoid leaking it.
            self.tmpdir.cleanup()
            raise

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tmpdir.cleanup()
|
|
|
|
|
|
|
|
class ImageDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for image dataset tests; sets ``metadata_extension`` to ``"jpg"``."""

    # Passed to the fixture helper by the base ``setUp``.
    # NOTE(review): presumably selects the media file type of the fixture - confirm.
    metadata_extension = "jpg"
|
|
|
|
|
|
|
|
class VideoDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for video dataset tests; sets ``metadata_extension`` to ``"mp4"``."""

    # Passed to the fixture helper by the base ``setUp``.
    # NOTE(review): presumably selects the media file type of the fixture - confirm.
    metadata_extension = "mp4"
|
|
|
|
|
|
|
|
class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageCaptionFilePairDataset`` on a flat folder of ``N.jpg`` / ``N.txt`` pairs."""

    num_data_files = 3
    directory_structure = ["0.jpg", "1.jpg", "2.jpg", "0.txt", "1.txt", "2.txt"]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a (3, 64, 64) image tensor."""
        samples = iter(self.dataset)
        remaining = self.num_data_files
        while remaining > 0:
            sample = next(samples)
            self.assertEqual(sample["caption"], self.caption)
            self.assertTrue(torch.is_tensor(sample["image"]))
            self.assertEqual(sample["image"].shape, (3, 64, 64))
            remaining -= 1

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageCaptionFilePairDataset`` for this layout."""
        resolved = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(resolved, ImageCaptionFilePairDataset)
|
|
|
|
|
|
|
|
class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFileCaptionFileListDataset`` on a ``prompts.txt`` / ``images.txt`` layout.

    The fixture lists the image files under an ``images/`` sub-directory.
    """

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "images.txt",
        "images/",
        "images/0.jpg",
        "images/1.jpg",
        "images/2.jpg",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a (3, 64, 64) image tensor."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFileCaptionFileListDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFileCaptionFileListDataset)
|
|
|
|
|
|
|
|
class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder of images plus a ``metadata.csv`` file."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and an image tensor."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)
|
|
|
|
|
|
|
|
class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder of images plus a ``metadata.jsonl`` file."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and an image tensor."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)
|
|
|
|
|
|
|
|
class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoCaptionFilePairDataset`` on a flat folder of ``N.mp4`` / ``N.txt`` pairs."""

    num_data_files = 3
    directory_structure = ["0.mp4", "1.mp4", "2.mp4", "0.txt", "1.txt", "2.txt"]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-frame video of (3, 64, 64) frames."""
        samples = iter(self.dataset)
        remaining = self.num_data_files
        while remaining > 0:
            sample = next(samples)
            self.assertEqual(sample["caption"], self.caption)
            self.assertTrue(torch.is_tensor(sample["video"]))
            self.assertEqual(len(sample["video"]), 4)
            self.assertEqual(sample["video"][0].shape, (3, 64, 64))
            remaining -= 1

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoCaptionFilePairDataset`` for this layout."""
        resolved = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(resolved, VideoCaptionFilePairDataset)
|
|
|
|
|
|
|
|
class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFileCaptionFileListDataset`` on a ``prompts.txt`` / ``videos.txt`` layout.

    The fixture lists the video files under a ``videos/`` sub-directory.
    """

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "videos.txt",
        "videos/",
        "videos/0.mp4",
        "videos/1.mp4",
        "videos/2.mp4",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-frame video of (3, 64, 64) frames."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFileCaptionFileListDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFileCaptionFileListDataset)
|
|
|
|
|
|
|
|
class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder of videos plus a ``metadata.csv`` file."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-frame video of (3, 64, 64) frames."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)
|
|
|
|
|
|
|
|
class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder of videos plus a ``metadata.jsonl`` file."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        """Build the fixture directory, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-frame video of (3, 64, 64) frames."""
        iterator = iter(self.dataset)
        # Drive the loop from the class attribute so the fixture size is
        # declared in exactly one place.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)
|
|
|
|
|
|
|
|
class ImageWebDatasetFastTests(unittest.TestCase):
    """Placeholder test case; it currently defines no tests."""
|
|
|
|
|
|
|
|
class VideoWebDatasetFastTests(unittest.TestCase):
    """Tests ``VideoWebDataset`` against the ``finetrainers/dummy-squish-wds`` dataset.

    NOTE(review): the identifier looks like a remote dataset repository, so these
    tests presumably need network access - confirm.
    """

    def setUp(self):
        """Record the expected file count and open the dataset in finite mode."""
        self.num_data_files = 15
        self.dataset = VideoWebDataset("finetrainers/dummy-squish-wds", infinite=False)

    def test_getitem(self):
        """The first three items carry a caption and a 121-frame video of (3, 720, 1280) frames."""
        for position, sample in enumerate(self.dataset):
            # Only the first three samples are inspected.
            if position >= 3:
                break
            self.assertIn("caption", sample)
            self.assertIn("video", sample)
            clip = sample["video"]
            self.assertTrue(torch.is_tensor(clip))
            self.assertEqual(len(clip), 121)
            self.assertEqual(clip[0].shape, (3, 720, 1280))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoWebDataset`` for this dataset identifier."""
        resolved = initialize_dataset("finetrainers/dummy-squish-wds", "video", infinite=False)
        self.assertIsInstance(resolved, VideoWebDataset)
|
|
|
|
|
|
|
|
class DatasetUtilsFastTests(unittest.TestCase): |
|
|
def test_find_files_depth_0(self): |
|
|
with tempfile.TemporaryDirectory() as tmpdir: |
|
|
file1 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) |
|
|
file2 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) |
|
|
file3 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) |
|
|
|
|
|
files = find_files(tmpdir, "*.txt") |
|
|
self.assertEqual(len(files), 3) |
|
|
self.assertIn(file1.name, files) |
|
|
self.assertIn(file2.name, files) |
|
|
self.assertIn(file3.name, files) |
|
|
|
|
|
def test_find_files_depth_n(self): |
|
|
with tempfile.TemporaryDirectory() as tmpdir: |
|
|
dir1 = tempfile.TemporaryDirectory(dir=tmpdir) |
|
|
dir2 = tempfile.TemporaryDirectory(dir=dir1.name) |
|
|
file1 = tempfile.NamedTemporaryFile(dir=dir1.name, suffix=".txt", delete=False) |
|
|
file2 = tempfile.NamedTemporaryFile(dir=dir2.name, suffix=".txt", delete=False) |
|
|
|
|
|
files = find_files(tmpdir, "*.txt", depth=1) |
|
|
self.assertEqual(len(files), 1) |
|
|
self.assertIn(file1.name, files) |
|
|
self.assertNotIn(file2.name, files) |
|
|
|
|
|
files = find_files(tmpdir, "*.txt", depth=2) |
|
|
self.assertEqual(len(files), 2) |
|
|
self.assertIn(file1.name, files) |
|
|
self.assertIn(file2.name, files) |
|
|
self.assertNotIn(dir1.name, files) |
|
|
self.assertNotIn(dir2.name, files) |
|
|
|
|
|
|
|
|
class ValidationDatasetFastTests(unittest.TestCase):
    """Tests ``ValidationDataset`` on a ``metadata.csv`` that references three 64x64 JPEG images."""

    def setUp(self):
        """Write three blank RGB images and a matching ``metadata.csv``, then load the dataset."""
        self.tmpdir = tempfile.TemporaryDirectory()
        root = pathlib.Path(self.tmpdir.name)
        metadata_filename = root / "metadata.csv"
        total_images = 3

        # Header first, then one row per image; the video_path column stays empty.
        rows = ["caption,image_path,video_path\n"]
        for index in range(total_images):
            Image.new("RGB", (64, 64)).save((root / f"{index}.jpg").as_posix())
            rows.append(f"test caption,{self.tmpdir.name}/{index}.jpg,\n")

        with open(metadata_filename, "w") as handle:
            handle.writelines(rows)

        self.dataset = ValidationDataset(metadata_filename.as_posix())

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tmpdir.cleanup()

    def test_getitem(self):
        """Every yielded item has the expected image path and a 64x64 PIL image."""
        for index, sample in enumerate(self.dataset):
            expected_path = f"{self.tmpdir.name}/{index}.jpg"
            self.assertEqual(sample["image_path"], expected_path)
            picture = sample["image"]
            self.assertIsInstance(picture, Image.Image)
            self.assertEqual(picture.size, (64, 64))
|
|
|