"""
Stream batches of arrays out of .npz files without loading everything into memory.
"""
import glob
import io
import os
import re
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Tuple

import numpy as np
@dataclass
class NumpyArrayInfo:
    """
    Information about an array in an npz file.
    """

    # Key of the array inside the npz file (member name without ".npy").
    name: str
    dtype: np.dtype
    shape: Tuple[int, ...]

    @classmethod
    def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """Extract array infos from the first file matched by glob_path."""
        paths, _ = _npz_paths_and_length(glob_path)
        return cls.infos_from_file(paths[0])

    @classmethod
    def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """
        Extract the info of every array in an npz file.

        :raises FileNotFoundError: if npz_path does not exist.
        :raises ValueError: if a member uses an unsupported .npy format version.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"batch of samples was not found: {npz_path}")
        results = {}
        with open(npz_path, "rb") as f:
            with zipfile.ZipFile(f, "r") as zip_f:
                for name in zip_f.namelist():
                    if not name.endswith(".npy"):
                        continue
                    key_name = name[: -len(".npy")]
                    with zip_f.open(name, "r") as arr_f:
                        # Parse only the .npy header; the array data is never read.
                        version = np.lib.format.read_magic(arr_f)
                        if version == (1, 0):
                            header = np.lib.format.read_array_header_1_0(arr_f)
                        elif version == (2, 0):
                            header = np.lib.format.read_array_header_2_0(arr_f)
                        else:
                            raise ValueError(f"unknown numpy array version: {version}")
                        shape, _, dtype = header
                        results[key_name] = cls(name=key_name, dtype=dtype, shape=shape)
        return results

    def elem_shape(self) -> Tuple[int, ...]:
        """Shape of a single element, i.e. the array shape without the batch axis."""
        return self.shape[1:]

    def validate(self):
        """
        Check that this array's dtype/shape match the conventions implied by
        its name. Raises ValueError on any mismatch.
        """
        if self.name in {"R", "G", "B"}:
            if len(self.shape) != 2:
                raise ValueError(
                    f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}"
                )
        elif self.name == "arr_0":
            if len(self.shape) < 2:
                raise ValueError(f"expecting at least 2-D shape but got: {self.shape}")
            elif len(self.shape) == 3:
                # For audio, we require continuous samples.
                if not np.issubdtype(self.dtype, np.floating):
                    raise ValueError(
                        f"invalid dtype for audio batch: {self.dtype} (expected float)"
                    )
            elif self.dtype != np.uint8:
                raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)")
class NpzStreamer:
    """
    Stream fixed-size batches of arrays from one or more .npz files matched
    by a glob, optionally truncated via a trailing ``[:N]`` suffix on the glob.
    """

    def __init__(self, glob_path: str):
        # trunc_length is None when no [:N] suffix was given.
        self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
        self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])

    def keys(self) -> List[str]:
        """Names of the arrays available in the first matched npz file."""
        return list(self.infos.keys())

    def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
        """
        Yield dicts mapping each key to a batch of rows.

        Batches may span file boundaries; every yielded batch has exactly
        batch_size rows except possibly the final one. At most
        self.trunc_length rows are yielded in total (when it is not None).
        """
        cur_batch = None
        num_remaining = self.trunc_length
        for path in self.paths:
            if num_remaining is not None and num_remaining <= 0:
                break
            with open_npz_arrays(path, keys) as readers:
                combined_reader = CombinedReader(keys, readers)
                while num_remaining is None or num_remaining > 0:
                    # Ask for just enough rows to top up the partial batch,
                    # capped by the remaining truncation budget.
                    read_bs = batch_size
                    if cur_batch is not None:
                        read_bs -= _dict_batch_size(cur_batch)
                    if num_remaining is not None:
                        read_bs = min(read_bs, num_remaining)

                    batch = combined_reader.read_batch(read_bs)
                    if batch is None:
                        break
                    if num_remaining is not None:
                        num_remaining -= _dict_batch_size(batch)

                    if cur_batch is None:
                        cur_batch = batch
                    else:
                        cur_batch = {
                            # pylint: disable=unsubscriptable-object
                            k: np.concatenate([cur_batch[k], v], axis=0)
                            for k, v in batch.items()
                        }
                    if _dict_batch_size(cur_batch) == batch_size:
                        yield cur_batch
                        cur_batch = None
        if cur_batch is not None:
            # Flush the final, possibly smaller, batch.
            yield cur_batch
| def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]: | |
| # Match slice syntax like path[:100]. | |
| count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path) | |
| if count_match: | |
| raw_path = count_match[1] | |
| max_count = int(count_match[2]) | |
| else: | |
| raw_path = glob_path | |
| max_count = None | |
| paths = sorted(glob.glob(raw_path)) | |
| if not len(paths): | |
| raise ValueError(f"no paths found matching: {glob_path}") | |
| return paths, max_count | |
class NpzArrayReader(ABC):
    """Abstract reader yielding successive batches of rows from one array."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to batch_size rows, or None once the array is exhausted."""
class StreamingNpzArrayReader(NpzArrayReader):
    """
    Read batches of rows directly from an open .npy member file, without
    ever loading the full array into memory.
    """

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f  # file-like object positioned at the raw array data
        self.shape = shape  # full array shape; shape[0] is the total row count
        self.dtype = dtype
        self.idx = 0  # number of rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        if self.dtype.itemsize == 0:
            # Zero-size dtypes occupy no bytes; fabricate an array of the
            # right shape without reading anything.
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
class MemoryNpzArrayReader(NpzArrayReader):
    """
    Reader backed by an array fully loaded into memory. Used as a fallback
    when streaming is not possible.
    """

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0  # number of rows consumed so far

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load the array named arr_name from the npz file at path."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None
        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res
@contextmanager
def open_npz_arrays(path: str, arr_names: Sequence[str]) -> List[NpzArrayReader]:
    """
    Context manager yielding one NpzArrayReader per name in arr_names for the
    npz file at path, opening the members recursively so that all streaming
    file handles stay alive for the duration of the with-block.
    """
    if not len(arr_names):
        yield []
        return
    arr_name = arr_names[0]
    with open_array(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header = None
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)

        if header is None:
            # Unknown .npy format version; fall back to loading into memory.
            reader = MemoryNpzArrayReader.load(path, arr_name)
        else:
            shape, fortran, dtype = header
            if fortran or dtype.hasobject:
                # Fortran-ordered and object arrays cannot be streamed
                # row-by-row from raw bytes.
                reader = MemoryNpzArrayReader.load(path, arr_name)
            else:
                reader = StreamingNpzArrayReader(arr_f, shape, dtype)

        with open_npz_arrays(path, arr_names[1:]) as next_readers:
            yield [reader] + next_readers
class CombinedReader:
    """
    Zip several readers together, yielding dicts keyed by array name and
    enforcing that every array contains the same number of rows.
    """

    def __init__(self, keys: List[str], readers: List["NpzArrayReader"]):
        self.keys = keys
        self.readers = readers

    def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
        """
        Read up to batch_size rows from every reader.

        :return: dict mapping key -> rows, or None once all are exhausted.
        :raises RuntimeError: if the readers disagree on the row count.
        """
        batches = [r.read_batch(batch_size) for r in self.readers]
        any_none = any(x is None for x in batches)
        all_none = all(x is None for x in batches)
        if any_none != all_none:
            raise RuntimeError("different keys had different numbers of elements")
        if any_none:
            return None
        if any(len(x) != len(batches[0]) for x in batches):
            raise RuntimeError("different keys had different numbers of elements")
        return dict(zip(self.keys, batches))
| def _read_bytes(fp, size, error_template="ran out of data"): | |
| """ | |
| Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 | |
| Read from file-like object until size bytes are read. | |
| Raises ValueError if not EOF is encountered before size bytes are read. | |
| Non-blocking objects only supported if they derive from io objects. | |
| Required as e.g. ZipExtFile in python 2.6 can return less data than | |
| requested. | |
| """ | |
| data = bytes() | |
| while True: | |
| # io files (default in python3) return None or raise on | |
| # would-block, python2 file will truncate, probably nothing can be | |
| # done about that. note that regular files can't be non-blocking | |
| try: | |
| r = fp.read(size - len(data)) | |
| data += r | |
| if len(r) == 0 or len(data) == size: | |
| break | |
| except io.BlockingIOError: | |
| pass | |
| if len(data) != size: | |
| msg = "EOF: reading %s, expected %d bytes got %d" | |
| raise ValueError(msg % (error_template, size, len(data))) | |
| else: | |
| return data | |
@contextmanager
def open_array(path: str, arr_name: str):
    """
    Context manager yielding the raw (compressed-stream) file object for the
    member ``{arr_name}.npy`` inside the npz (zip) file at path.

    :raises ValueError: if the member is missing from the archive.
    """
    with open(path, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_f:
            if f"{arr_name}.npy" not in zip_f.namelist():
                raise ValueError(f"missing {arr_name} in npz file")
            with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
                yield arr_f
| def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int: | |
| return len(next(iter(objs.values()))) | |