Spaces:
Runtime error
Runtime error
| from utils.dataset_utils import * | |
class CachedDataset(Dataset):
    """Dataset over pre-computed items stored as ``.pt`` files in one directory.

    Each item is whatever object ``torch.load`` returns for one file. Files are
    served in sorted path order.

    Args:
        cache_dir: Directory scanned (non-recursively) for entries whose name
            ends with ``.pt``. An empty string means the current working
            directory.
        map_location: Forwarded to ``torch.load``. When ``None`` (the default)
            it resolves to ``'cuda:0'`` if CUDA is available and ``'cpu'``
            otherwise, so the dataset also loads on CPU-only hosts.
    """

    def __init__(self, cache_dir: str = '', map_location=None):
        self.cache_dir = cache_dir
        if map_location is None:
            # Previously always 'cuda:0', which makes every load fail without a GPU.
            map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.map_location = map_location
        self.cached_data_list = self.get_files_list()

    def get_files_list(self):
        """Return the sorted paths of all ``.pt`` entries directly inside ``cache_dir``."""
        # os.listdir('') raises, so the empty default is treated as the current directory.
        directory = self.cache_dir or '.'
        # os.path.join avoids the '/name.pt' (root-anchored) path an f-string
        # produced for an empty cache_dir, and '//' for a trailing-slash dir.
        tensors_list = [
            os.path.join(self.cache_dir, name)
            for name in os.listdir(directory)
            if name.endswith('.pt')
        ]
        return sorted(tensors_list)

    def __len__(self):
        """Return the number of cached ``.pt`` files found at construction time."""
        return len(self.cached_data_list)

    def __getitem__(self, index):
        """Load and return the object stored in the ``index``-th cached file."""
        # NOTE(review): torch.load unpickles the file contents; only point
        # cache_dir at files you produced yourself. Loading onto a CUDA device
        # inside DataLoader worker processes may need care — confirm setup.
        return torch.load(self.cached_data_list[index], map_location=self.map_location)