|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
from collections.abc import Mapping, Sequence |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Names of the demo samples that ``load`` fetches from the "pointcept/demo"
# dataset repository on the HuggingFace Hub (each stored there as "<name>.npz").
DATAS = ["sample1", "sample1_high_res", "sample1_dino"]
|
|
|
|
|
|
|
|
def load(
    name: str = "sonata",
    download_root: str = None,
) -> dict:
    """
    Load a point-cloud sample stored as an ``.npz`` archive.

    Args:
        name: either one of the demo sample names listed in ``DATAS``
            (downloaded from the "pointcept/demo" HuggingFace dataset repo as
            ``<name>.npz``), or a path to an existing local file.
        download_root: directory to download demo samples into; ``None``
            falls back to ``~/.cache/sonata/data``. Unused for local paths.

    Returns:
        A dict mapping each array name in the archive to its NumPy array.

    Raises:
        RuntimeError: if ``name`` is neither a known demo sample nor an
            existing local file.
    """
    # NOTE(review): the default "sonata" is not in DATAS, so calling load()
    # with no arguments raises unless a local file named "sonata" exists —
    # confirm whether the default is intended.
    if name in DATAS:
        print(f"Loading data from HuggingFace: {name} ...")
        data_path = hf_hub_download(
            repo_id="pointcept/demo",
            filename=f"{name}.npz",
            repo_type="dataset",
            revision="main",
            local_dir=download_root or os.path.expanduser("~/.cache/sonata/data"),
        )
    elif os.path.isfile(name):
        print(f"Loading data in local path: {name} ...")
        data_path = name
    else:
        raise RuntimeError(f"Data {name} not found; available data = {DATAS}")

    # np.load on an .npz returns an NpzFile that keeps the file open and reads
    # arrays lazily; dict(...) reads every array into memory, after which the
    # archive can be closed instead of leaking the handle.
    with np.load(data_path) as archive:
        return dict(archive)
|
|
|
|
|
|
|
|
from torch.utils.data.dataloader import default_collate |
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """
    Collate a batch of point-cloud samples, supporting dicts and lists.

    Dispatch is on the type of the first element of ``batch``:

    * ``torch.Tensor``: concatenated along dim 0 with ``torch.cat``. Per-sample
      point sets are joined into one flat tensor rather than stacked along a
      new batch dimension.
    * ``str``: returned as a plain list.
    * ``Sequence`` (list/tuple of fields): each sample gets an extra trailing
      field holding ``data[0].shape[0]``. The first field (e.g. 'coord')
      therefore determines the offset. Fields are collated position-wise, and
      the trailing counts become a cumulative ``int32`` offset tensor.
    * ``Mapping``: collated key by key. Keys containing the substring "offset"
      are treated as per-sample cumulative offsets. They are turned back into
      counts, concatenated, and re-accumulated so that they index into the
      concatenated batch.
    * anything else: delegated to ``default_collate``.

    Args:
        batch: a non-empty sequence of samples sharing the same structure.

    Returns:
        The collated batch, mirroring the structure of one sample. The
        Sequence case additionally appends the offset tensor as the last entry.

    Raises:
        TypeError: if ``batch`` itself is not a Sequence.
    """
    if not isinstance(batch, Sequence):
        # Report the type: arbitrary objects (e.g. dicts) have no ``dtype``.
        raise TypeError(f"{type(batch)} is not supported.")

    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        return torch.cat(list(batch))
    elif isinstance(elem, str):
        return list(batch)
    elif isinstance(elem, Sequence):
        # Build fresh per-sample lists rather than appending in place. The old
        # in-place ``data.append`` mutated the caller's samples (growing them on
        # every call) and failed outright for tuple samples.
        batch = [[*data, torch.tensor([data[0].shape[0]])] for data in batch]
        batch = [collate_fn(samples) for samples in zip(*batch)]
        batch[-1] = torch.cumsum(batch[-1], dim=0).int()
        return batch
    elif isinstance(elem, Mapping):
        return {
            key: (
                collate_fn([d[key] for d in batch])
                if "offset" not in key
                # Offsets -> per-sample counts -> concatenate -> re-accumulate.
                # The prepended zero is created with new_zeros so that it
                # shares the offset tensor's device and dtype.
                else torch.cumsum(
                    collate_fn(
                        [d[key].diff(prepend=d[key].new_zeros(1)) for d in batch]
                    ),
                    dim=0,
                )
            )
            for key in elem
        }
    else:
        return default_collate(batch)
|
|
|