|
|
import os |
|
|
import tempfile |
|
|
import unittest |
|
|
|
|
|
from finetrainers.data import ( |
|
|
InMemoryDistributedDataPreprocessor, |
|
|
PrecomputedDistributedDataPreprocessor, |
|
|
VideoCaptionFilePairDataset, |
|
|
initialize_preprocessor, |
|
|
wrap_iterable_dataset_for_preprocessing, |
|
|
) |
|
|
from finetrainers.data.precomputation import PRECOMPUTED_DATA_DIR |
|
|
from finetrainers.utils import find_files |
|
|
|
|
|
from .utils import create_dummy_directory_structure |
|
|
|
|
|
|
|
|
class PreprocessorFastTests(unittest.TestCase):
    """Fast checks for the preprocessors returned by ``initialize_preprocessor``.

    Every test runs against a fresh temporary directory populated by
    ``create_dummy_directory_structure`` and uses two small stand-in processor
    functions registered under the keys ``"latent"`` and ``"condition"``.
    """

    def setUp(self):
        """Build the temporary data directory and the wrapped dataset shared by the tests."""
        self.rank = 0
        self.world_size = 1
        self.num_items = 3
        self.processor_fn = {
            "latent": self._latent_processor_fn,
            "condition": self._condition_processor_fn,
        }
        self.save_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away: unittest does not call tearDown() when
        # setUp() raises, but registered cleanups still run, so the directory
        # is removed even if one of the steps below fails.
        self.addCleanup(self.save_dir.cleanup)

        directory_structure = [
            "0.mp4",
            "1.mp4",
            "2.mp4",
            "0.txt",
            "1.txt",
            "2.txt",
        ]
        # NOTE(review): the helper is handed the TemporaryDirectory object
        # itself (not ``.name``) -- confirm that is what it expects.
        create_dummy_directory_structure(
            directory_structure, self.save_dir, self.num_items, "a cat ruling the world", "mp4"
        )

        dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True)
        self.dataset = wrap_iterable_dataset_for_preprocessing(
            dataset,
            dataset_type="video",
            config={
                "video_resolution_buckets": [[2, 32, 32]],
                "reshape_mode": "bicubic",
            },
        )

    @staticmethod
    def _latent_processor_fn(**data):
        """Stand-in latent processor: keep the first 16 entries of the 3rd and 4th video axes."""
        data["video"] = data["video"][:, :, :16, :16]
        return data

    @staticmethod
    def _condition_processor_fn(**data):
        """Stand-in condition processor: append a fixed suffix to the caption."""
        data["caption"] = data["caption"] + " surrounded by mystical aura"
        return data

    def _build_preprocessor(self, enable_precomputation):
        """Return a preprocessor configured from the fixture attributes set in ``setUp``."""
        return initialize_preprocessor(
            self.rank,
            self.world_size,
            self.num_items,
            self.processor_fn,
            self.save_dir.name,
            enable_precomputation=enable_precomputation,
        )

    @staticmethod
    def _open_iterators(consume_fn, data_iterator):
        """Call ``consume_fn`` for "condition" and then "latent"; return both iterators.

        ``consume_fn`` is the bound ``consume`` or ``consume_once`` method of a
        preprocessor. The condition pass caches its samples; the latent pass
        reuses the cached samples and drops them afterwards.
        """
        condition_iterator = consume_fn(
            "condition", components={}, data_iterator=data_iterator, cache_samples=True
        )
        latent_iterator = consume_fn(
            "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
        )
        return condition_iterator, latent_iterator

    def _assert_precomputed_file_counts(self):
        """Assert that one condition file and one latent file exist per item on disk."""
        precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR)
        for pattern in ("condition-*", "latent-*"):
            matched_files = find_files(precomputed_data_dir, pattern)
            # Compare against num_items rather than a literal so the check
            # stays valid if the fixture size changes.
            self.assertEqual(len(matched_files), self.num_items)

    def _assert_processed_items(self, condition_iterator, latent_iterator):
        """Draw ``num_items`` items from both iterators and check the processor output."""
        for _ in range(self.num_items):
            condition_item = next(condition_iterator)
            latent_item = next(latent_iterator)
            self.assertIn("caption", condition_item)
            self.assertIn("video", latent_item)
            self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
            self.assertEqual(latent_item["video"].shape[-2:], (16, 16))

    def test_initialize_preprocessor(self):
        """``enable_precomputation`` selects which preprocessor class is returned."""
        expectations = (
            (False, InMemoryDistributedDataPreprocessor),
            (True, PrecomputedDistributedDataPreprocessor),
        )
        for enable_precomputation, expected_cls in expectations:
            with self.subTest(enable_precomputation=enable_precomputation):
                preprocessor = self._build_preprocessor(enable_precomputation)
                self.assertIsInstance(preprocessor, expected_cls)

    def test_in_memory_preprocessor_consume(self):
        """In-memory ``consume``: data is needed again once every item has been drawn."""
        data_iterator = iter(self.dataset)
        preprocessor = self._build_preprocessor(enable_precomputation=False)
        condition_iterator, latent_iterator = self._open_iterators(preprocessor.consume, data_iterator)

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertTrue(preprocessor.requires_data)

    def test_in_memory_preprocessor_consume_once(self):
        """In-memory ``consume_once``: ``requires_data`` stays false after drawing every item."""
        data_iterator = iter(self.dataset)
        preprocessor = self._build_preprocessor(enable_precomputation=False)
        condition_iterator, latent_iterator = self._open_iterators(preprocessor.consume_once, data_iterator)

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertFalse(preprocessor.requires_data)

    def test_precomputed_preprocessor_consume(self):
        """Precomputed ``consume``: files are on disk before iteration; data is needed afterwards."""
        data_iterator = iter(self.dataset)
        preprocessor = self._build_preprocessor(enable_precomputation=True)
        condition_iterator, latent_iterator = self._open_iterators(preprocessor.consume, data_iterator)

        # Checked before any next() call, as in the original test.
        self._assert_precomputed_file_counts()

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertTrue(preprocessor.requires_data)

    def test_precomputed_preprocessor_consume_once(self):
        """Precomputed ``consume_once``: files are on disk and ``requires_data`` stays false."""
        data_iterator = iter(self.dataset)
        preprocessor = self._build_preprocessor(enable_precomputation=True)
        condition_iterator, latent_iterator = self._open_iterators(preprocessor.consume_once, data_iterator)

        # Checked before any next() call, as in the original test.
        self._assert_precomputed_file_counts()

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertFalse(preprocessor.requires_data)
|
|
|