Brandon May committed
Commit · 26791f7
Parent(s): 77b08da

Add theia

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.

Files changed:
- LICENSE +24 -0
- theia/__init__.py +1 -0
- theia/configs/dataset/ego4d.yaml +5 -0
- theia/configs/dataset/epic_kitchen.yaml +5 -0
- theia/configs/dataset/image_video_default.yaml +7 -0
- theia/configs/dataset/image_video_mix.yaml +8 -0
- theia/configs/dataset/imagenet.yaml +5 -0
- theia/configs/dataset/oxe_octo_mix.yaml +12 -0
- theia/configs/dataset/ssv2.yaml +5 -0
- theia/configs/logging/default.yaml +6 -0
- theia/configs/model/backbone/deit.yaml +2 -0
- theia/configs/model/backbone/deit_nocls.yaml +2 -0
- theia/configs/model/backbone/deit_reg.yaml +3 -0
- theia/configs/model/translator/conv.yaml +3 -0
- theia/configs/model/translator/lconv.yaml +3 -0
- theia/configs/model/translator/mlp.yaml +4 -0
- theia/configs/model/translator/transformer.yaml +5 -0
- theia/configs/train_rvfm_imagenet.yaml +9 -0
- theia/configs/training/frame_level.yaml +35 -0
- theia/configs/training/target_models/cdds.yaml +6 -0
- theia/configs/training/target_models/cddsv.yaml +7 -0
- theia/configs/training/target_models/cddv.yaml +6 -0
- theia/configs/training/target_models/cdesv.yaml +6 -0
- theia/configs/training/target_models/cdis.yaml +5 -0
- theia/configs/training/target_models/cdisv.yaml +6 -0
- theia/configs/training/target_models/cdiv.yaml +5 -0
- theia/configs/training/target_models/clip.yaml +3 -0
- theia/configs/training/target_models/ddsv.yaml +6 -0
- theia/configs/training/target_models/depth_anything.yaml +3 -0
- theia/configs/training/target_models/dinov2.yaml +3 -0
- theia/configs/training/target_models/sam.yaml +3 -0
- theia/configs/training/target_models/vit.yaml +3 -0
- theia/dataset/__init__.py +5 -0
- theia/dataset/data_utils.py +591 -0
- theia/dataset/image/__init__.py +3 -0
- theia/dataset/image/image_common.py +5 -0
- theia/dataset/oxe/__init__.py +1 -0
- theia/dataset/oxe/oxe_common.py +430 -0
- theia/dataset/oxe/oxe_mixes.py +139 -0
- theia/dataset/oxe/oxe_transforms.py +15 -0
- theia/dataset/video/__init__.py +3 -0
- theia/dataset/video/video_common.py +11 -0
- theia/decoding/__init__.py +5 -0
- theia/decoding/decode.py +198 -0
- theia/decoding/depth_anything.py +57 -0
- theia/decoding/dinov2.py +69 -0
- theia/decoding/sam.py +191 -0
- theia/example/decode_to_vfms.ipynb +69 -0
- theia/foundation_models/__init__.py +9 -0
- theia/foundation_models/common.py +87 -0
LICENSE
ADDED
@@ -0,0 +1,24 @@
+Copyright (c) 2024 Boston Dynamics AI Institute LLC
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+1. Redistributions of source code must retain the copyright notice included
+with the software, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the copyright notice, this
+list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+3. Modified versions of the software must be conspicuously marked as such.
+4. The software may only be used for non-commercial research purposes.
+For profit enterprises may use the software, subject to this limitation.
+
+THIS SOFTWARE IS PROVIDED BY THE AI INSTITUTE AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, NON-
+INFRINGEMENT, TITLE, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE AI INSTITUTE OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, DAMAGES ARISING OUT OF CLAIMS OF
+INTELLECTUAL PROPERTY RIGHTS INFRINGEMENT; PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
theia/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
theia/configs/dataset/ego4d.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - image_video_default
+
+dataset_mix:
+  - "ego4d_1in150"
theia/configs/dataset/epic_kitchen.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - image_video_default
+
+dataset_mix:
+  - "epic_kitchen_1in60"
theia/configs/dataset/image_video_default.yaml
ADDED
@@ -0,0 +1,7 @@
+return_metadata: False
+shuffle: True
+shuffle_buffer_size: 1024
+feature_norm: True
+dataset_root: "/storage/nfs/datasets/jshang/"
+dataset_ratio: 0.1
+load_action: False
theia/configs/dataset/image_video_mix.yaml
ADDED
@@ -0,0 +1,8 @@
+defaults:
+  - image_video_default
+
+dataset_mix:
+  - "ego4d_1in150"
+  - "ssv2_1in32"
+  - "epic_kitchen_1in60"
+  - "imagenet"
theia/configs/dataset/imagenet.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - image_video_default
+
+dataset_mix:
+  - "imagenet"
theia/configs/dataset/oxe_octo_mix.yaml
ADDED
@@ -0,0 +1,12 @@
+_target_: dataset.oxe.oxe_data_utils.OXEDataset
+dataset_mix: "oxe_magic_soup"
+image_action_set_root: "/storage/nfs/datasets/jshang/oxe_image_action"
+feature_set_root: "/storage/nfs/datasets/jshang/oxe_vfm_features"
+image_views: null
+split: "train"
+data_portion: 0.01
+load_action: False
+bf16: True
+safe_tensors: True
+trajectory_subsample_len: 32
+return_metadata: False
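
The `_target_` key above follows Hydra's object-instantiation convention, so the training code can turn this node directly into the referenced OXEDataset. A minimal sketch, assuming the config is loaded from this path and that `dataset.oxe.oxe_data_utils.OXEDataset` is importable (that module is not among the files shown in this truncated view); the split override is an example value:

    # Hedged sketch: build the dataset object described by oxe_octo_mix.yaml.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("theia/configs/dataset/oxe_octo_mix.yaml")
    cfg.split = "val"  # optionally override fields before instantiation
    dataset = instantiate(cfg)  # calls OXEDataset(**everything except _target_)
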
theia/configs/dataset/ssv2.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - image_video_default
+
+dataset_mix:
+  - "ssv2_1in32"
theia/configs/logging/default.yaml
ADDED
@@ -0,0 +1,6 @@
+model_path: "/storage/nfs/jshang/trained_models"
+log_path: "/storage/nfs/jshang/logs"
+save_ckpt_interval: 20000
+notes: ""
+run_identifier_prefix: ""
+project: "theia"
theia/configs/model/backbone/deit.yaml
ADDED
@@ -0,0 +1,2 @@
+backbone: facebook/deit-small-patch16-224
+pretrained: False
theia/configs/model/backbone/deit_nocls.yaml
ADDED
@@ -0,0 +1,2 @@
+backbone: nocls-facebook/deit-tiny-patch16-224
+pretrained: False
theia/configs/model/backbone/deit_reg.yaml
ADDED
@@ -0,0 +1,3 @@
+backbone: reg-facebook/deit-tiny-patch16-224
+pretrained: False
+num_reg_tokens: 7
theia/configs/model/translator/conv.yaml
ADDED
@@ -0,0 +1,3 @@
+type: "conv"
+kwargs:
+  translator_hidden_size: 1024
theia/configs/model/translator/lconv.yaml
ADDED
@@ -0,0 +1,3 @@
+type: "lconv"
+kwargs:
+  hidden_size_factor: 1.0
theia/configs/model/translator/mlp.yaml
ADDED
@@ -0,0 +1,4 @@
+type: "mlp"
+kwargs:
+  translator_n_layer: 3
+  hidden_size: 1024
theia/configs/model/translator/transformer.yaml
ADDED
@@ -0,0 +1,5 @@
+type: "transformer"
+kwargs:
+  translator_n_layers: 2
+  translator_n_heads: 8
+  translator_hidden_size: 1024
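
Each translator config above shares the same shape: a `type` string plus a `kwargs` block forwarded to the chosen module. The translator implementations themselves are not among the files in this view, so the sketch below is a hypothetical illustration of the dispatch pattern only; none of its names come from the commit.

    # Hypothetical sketch of consuming a type/kwargs translator config.
    from typing import Any

    class _PlaceholderTranslator:
        """Stand-in for the real translator modules, which are not in this diff."""
        def __init__(self, **kwargs: Any) -> None:
            self.kwargs = kwargs

    # Placeholder registry keyed by the `type` field of the configs above.
    TRANSLATORS: dict[str, type] = {
        "conv": _PlaceholderTranslator,
        "lconv": _PlaceholderTranslator,
        "mlp": _PlaceholderTranslator,
        "transformer": _PlaceholderTranslator,
    }

    def build_translator(cfg: dict[str, Any]) -> Any:
        """Dispatch on `type` and forward `kwargs` to the selected class."""
        return TRANSLATORS[cfg["type"]](**cfg.get("kwargs", {}))

    # e.g. the lconv config above:
    translator = build_translator({"type": "lconv", "kwargs": {"hidden_size_factor": 1.0}})
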
theia/configs/train_rvfm_imagenet.yaml
ADDED
@@ -0,0 +1,9 @@
+defaults:
+  - dataset: imagenet
+  - model/backbone: deit
+  - model/translator: lconv
+  - training: frame_level
+  - logging: default
+  - _self_
+
+seed: 0
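
This root config stitches one entry from each config group together (dataset, model/backbone, model/translator, training, logging). A minimal sketch of composing it with Hydra's compose API; the relative config_path and the particular overrides are assumptions for illustration, not something this commit pins down:

    # Hedged sketch: compose the training config and swap config-group choices.
    from hydra import compose, initialize
    from omegaconf import OmegaConf

    with initialize(version_base=None, config_path="theia/configs"):
        cfg = compose(
            config_name="train_rvfm_imagenet",
            overrides=["model/backbone=deit_reg", "training/target_models=cddsv", "seed=1"],
        )
    print(OmegaConf.to_yaml(cfg))  # inspect the merged configuration
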
theia/configs/training/frame_level.yaml
ADDED
@@ -0,0 +1,35 @@
+defaults:
+  - target_models: cdiv
+
+epochs: 50
+warm_up_steps_ratio: 0.1
+
+base_lr: 2e-3
+batch_size: 16
+random_target_models: -1
+num_workers: 8
+# base training settings to scale lr, rarely changed
+base_batch_size: 64
+base_world_size: 8
+
+weight_decay: 0.01
+
+
+optimizer:
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.999]
+
+lr_scheduler:
+  _target_: theia.lr_schedulers.get_constant_lrs_with_linear_warm_up
+  warm_up_lr_start_factor: 1e-2
+
+
+grad_clip: False
+grad_clip_norm_warmup: 10.0
+grad_clip_norm: 1.0
+
+freeze_translator: False
+freeze_translator_start_steps_ratio: 0.2
+translator_lr_factor: 1.0
+
+main_loss: cos_l1
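
`base_lr` is defined against the reference setup of `base_batch_size` x `base_world_size`, which the comment marks as the baseline for scaling the learning rate. The training script is not part of this view, so the exact rule is an assumption; the usual linear-scaling form these fields suggest is sketched below.

    # Assumed linear LR scaling implied by base_batch_size/base_world_size;
    # the actual training script is outside this 50-file view.
    base_lr = 2e-3
    batch_size = 16       # per-process batch size from this config
    world_size = 8        # example number of DDP processes
    base_batch_size = 64  # reference values base_lr was tuned for
    base_world_size = 8

    lr = base_lr * (batch_size * world_size) / (base_batch_size * base_world_size)
    # 2e-3 * 128 / 512 = 5e-4 for the numbers above
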
theia/configs/training/target_models/cdds.yaml
ADDED
@@ -0,0 +1,6 @@
+target_model_names:
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+  - "facebook/sam-vit-huge"
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/cddsv.yaml
ADDED
@@ -0,0 +1,7 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+  - "facebook/sam-vit-huge"
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/cddv.yaml
ADDED
@@ -0,0 +1,6 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/cdesv.yaml
ADDED
@@ -0,0 +1,6 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "openai/clip-vit-large-patch14"
+  - "facebook/sam-vit-huge"
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/cdis.yaml
ADDED
@@ -0,0 +1,5 @@
+target_model_names:
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+  - "facebook/sam-vit-huge"
+target_model_weights: null
theia/configs/training/target_models/cdisv.yaml
ADDED
@@ -0,0 +1,6 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+  - "facebook/sam-vit-huge"
+target_model_weights: null
theia/configs/training/target_models/cdiv.yaml
ADDED
@@ -0,0 +1,5 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "facebook/dinov2-large"
+  - "openai/clip-vit-large-patch14"
+target_model_weights: null
theia/configs/training/target_models/clip.yaml
ADDED
@@ -0,0 +1,3 @@
+target_model_names:
+  - "openai/clip-vit-large-patch14"
+target_model_weights: null
theia/configs/training/target_models/ddsv.yaml
ADDED
@@ -0,0 +1,6 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+  - "facebook/dinov2-large"
+  - "facebook/sam-vit-huge"
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/depth_anything.yaml
ADDED
@@ -0,0 +1,3 @@
+target_model_names:
+  - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
theia/configs/training/target_models/dinov2.yaml
ADDED
@@ -0,0 +1,3 @@
+target_model_names:
+  - "facebook/dinov2-large"
+target_model_weights: null
theia/configs/training/target_models/sam.yaml
ADDED
@@ -0,0 +1,3 @@
+target_model_names:
+  - "facebook/sam-vit-huge"
+target_model_weights: null
theia/configs/training/target_models/vit.yaml
ADDED
@@ -0,0 +1,3 @@
+target_model_names:
+  - "google/vit-huge-patch14-224-in21k"
+target_model_weights: null
theia/dataset/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .image.image_common import ALL_IMAGE_DATASETS
+from .oxe.oxe_common import ALL_OXE_DATASETS
+from .video.video_common import ALL_VIDEO_DATASETS
theia/dataset/data_utils.py
ADDED
@@ -0,0 +1,591 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""Defines PyTorch datasets and dataloaders for multiple image, video, and OXE datasets.
+Should use with webdataset >= 0.2.90. See https://github.com/webdataset/webdataset/pull/347"""
+
+import glob
+import json
+import math
+import os.path as osp
+from collections import OrderedDict
+from functools import partial
+from io import BytesIO
+from typing import Any, Callable, Generator, Iterator, Literal, Optional
+
+import cv2
+import numpy as np
+import omegaconf
+import torch
+import webdataset as wds
+from datasets.combine import DatasetType
+from einops import rearrange
+from numpy.typing import NDArray
+from safetensors.torch import load as sft_load
+from torch import default_generator
+from torch.utils.data import DataLoader, Dataset, IterableDataset, default_collate
+
+from theia.foundation_models.common import MODELS
+from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS
+from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES
+
+PACKED_FEATURES = [model_name for model_name in MODELS if "llava" not in model_name]
+
+
+def normalize_ds_weights_by_ds_len(weights: list[float], lengths: list[int]) -> tuple[list[float], float | Literal[0]]:
+    """Normalize dataset weights by dataset lengths (frames).
+
+    Args:
+        weights (list[float]): assigned weights.
+        lengths (list[int]): lengths of datasets.
+
+    Returns:
+        tuple[list[float], int]: normalized weights, and sum of the expected lengths of datasets
+    """
+    expected_lengths = [weight * length for weight, length in zip(weights, lengths, strict=False)]
+    sum_expected_lengths = sum(expected_lengths)
+    if sum_expected_lengths == 0:
+        raise ValueError("Sum of dataset length is 0.")
+    normalized_weights = [length * 1.0 / sum_expected_lengths for length in expected_lengths]
+    return normalized_weights, sum_expected_lengths
+
+
+def get_vo_keys(dataset_name: str, image_views: Optional[list | str | dict[str, str | list[str]]] = None) -> list[str]:
+    """Get visual observation keys of datasets (to be compatible with OXE).
+
+    Args:
+        dataset_name (str): name of the dataset.
+        image_views (Optional[dict[str, str | list[str]]], optional): keys of selected views.
+            Defaults to None.
+
+    Returns:
+        list[str]: keys to the views in the dataset.
+    """
+    default_visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"][:1]
+    visual_observation_keys = []
+    if image_views is None:
+        visual_observation_keys = default_visual_observation_keys
+    elif isinstance(image_views, list):
+        visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
+    elif isinstance(image_views, str):
+        if image_views == "static":
+            visual_observation_keys = [
+                k
+                for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
+                if "wrist" not in k and "hand" not in k
+            ]
+        elif image_views == "wrist":
+            visual_observation_keys = [
+                k for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] if "wrist" in k or "hand" in k
+            ]
+    if len(visual_observation_keys) == 0:
+        visual_observation_keys = default_visual_observation_keys
+    return visual_observation_keys
+
+
+class RandomMix(IterableDataset):
+    """A random interleave of multiple iterable datasets."""
+
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probs: list[float] | NDArray | None = None,
+        stopping_strategy: str = "all_exhausted",
+        seed: Optional[int | str] = 0,
+    ) -> None:
+        """Initialization of a random interleave dataset.
+
+        Args:
+            datasets (list[IterableDataset]): datasets to be interleaved.
+            probs (list[float] | NDArray, optional): probability of each dataset. Defaults to None.
+            stopping_strategy (str, optional): when to end the sampling for one epoch. Defaults to `all_exhausted`.
+                `all_exhausted`: each sample in the dataset will be sampled at least once.
+                `first_exhausted`: when the first dataset runs out, this epoch ends.
+                See also https://huggingface.co/docs/datasets/en/stream#interleave for definitions.
+            seed (Optional[int | str]): seed. Defaults to 0.
+        """
+        self.datasets = datasets
+        if probs is None:
+            self.probs = [1.0] * len(self.datasets)
+        elif isinstance(probs, np.ndarray):
+            self.probs = probs.tolist()
+        else:
+            self.probs = probs
+        self.stopping_strategy = stopping_strategy
+        self.seed = seed
+
+    def __iter__(self) -> Generator:
+        """Return an iterator over the sources."""
+        sources = [iter(d) for d in self.datasets]
+        probs = self.probs[:]
+        seed_gen = torch.Generator()
+        seed_gen.manual_seed(self.seed)
+        cum = (np.array(probs) / np.sum(probs)).cumsum()
+        while len(sources) > 0:
+            r = torch.rand(1, generator=seed_gen).item()
+            i = np.searchsorted(cum, r)
+            try:
+                yield next(sources[i])
+            except StopIteration:
+                if self.stopping_strategy == "all_exhausted":
+                    del sources[i]
+                    del probs[i]
+                    cum = (np.array(probs) / np.sum(probs)).cumsum()
+                elif self.stopping_strategy == "first_exhausted":
+                    break
+
+
+def decode_sample(
+    key: str, data: bytes, image_transform: Optional[Callable] = None, feature_transform: Optional[Callable] = None
+) -> Any:
+    """Decode a sample from bytes with optional image and feature transforms.
+
+    Args:
+        key (str): key of an attribute (a column) of the sample.
+        data (bytes): original data bytes.
+        image_transform (Optional[Callable], optional): image transform. Defaults to None.
+        feature_transform (Optional[Callable], optional): feature transform. Defaults to None.
+
+    Returns:
+        Any: decoded data.
+    """
+    if ".safetensors" in key:
+        sft = sft_load(data)
+        embedding = rearrange(sft["embedding"], "c h w -> (h w) c")
+        if feature_transform is not None:
+            embedding = feature_transform(embedding)
+        if "cls_token" in sft:
+            cls = sft["cls_token"]
+            if feature_transform is not None:
+                cls = feature_transform(cls)
+            return {"embedding": embedding, "cls": cls}
+        return {"embedding": embedding}
+    elif key == ".image":
+        image = np.load(BytesIO(data))
+        if len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif len(image.shape) == 3 and image.shape[-1] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+        if image_transform is not None:
+            return image_transform(image)
+        return image
+    else:
+        return data
+
+
+def get_oxe_frame_dataset(
+    dataset_root: str,
+    dataset_mix: Optional[str | dict[str, float] | list] = "oxe_magic_soup",
+    feature_models: Optional[list[str]] = None,
+    split: str = "train",
+    dataset_ratio: float = 1.0,
+    image_views: Optional[dict[str, str | list[str]]] = None,
+    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
+    seed: Optional[int | str] = 0,
+    shuffle: bool = False,
+    world_size: int = 1,
+) -> tuple[dict[str, DatasetType], float | Literal[0]]:
+    """Get OXE datasets at frame level.
+
+    Args:
+        dataset_root (str): root dir of the datasets.
+        dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
+            Defaults to "oxe_magic_soup".
+        feature_models (Optional[list[str]], optional): models to load their features. Defaults to None.
+        split (str, optional): split "train" or "val" or "test". Defaults to "train".
+        dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0.
+        image_views (Optional[dict[str, str | list[str]]], optional): image views to select. Defaults to None.
+        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
+            Defaults to None.
+        seed (Optional[int | str], optional): seed. Defaults to 0.
+        shuffle (bool, optional): shuffle or not. Defaults to False.
+        world_size (int, optional): world size of DDP training. Defaults to 1.
+
+    Returns:
+        tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
+    """
+    # read dataset mix from any acceptable form
+    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
+        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
+    elif isinstance(dataset_mix, dict):
+        dataset_mix = OrderedDict(**dataset_mix)
+    elif isinstance(dataset_mix, list):
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+    else:
+        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
+
+    if split == "eval" or split == "val":
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+
+    # note down the dataset weights
+    dataset_weights: list[float] = []
+    # get frame level length
+    dataset_lens: list[int] = []
+
+    all_feature_datasets: dict[str, DatasetType] = {}
+    for dataset in dataset_mix:
+        visual_observation_keys = get_vo_keys(dataset_name=dataset, image_views=image_views)
+
+        if feature_models is None:
+            feature_models = PACKED_FEATURES
+
+        with open(osp.join(dataset_root, dataset, "splits.json"), "r") as splitf:
+            dataset_len = json.load(splitf)[split]
+        # if the length is 0, skip
+        # this may happen for small datasets with very few shards
+        if dataset_len == 0:
+            continue
+
+        for vo_key in visual_observation_keys:
+            for model_name in feature_models:
+                if model_name not in PACKED_FEATURES:
+                    feature_set_name = model_name
+                    path_pattern = osp.join(
+                        dataset_root, dataset, vo_key + f"_{model_name.replace('/', '_')}", f"*-{split}*.tar"
+                    )
+                    rename_kw = {model_name: model_name.replace("/", "_") + ".safetensors"}  # replace v by k
+                elif "packed" in all_feature_datasets:
+                    continue
+                else:
+                    feature_set_name = "packed"
+                    path_pattern = osp.join(dataset_root, dataset, vo_key, f"*-{split}*.tar")
+                    rename_kw = {
+                        name: name.replace("/", "_") + ".safetensors" for name in PACKED_FEATURES
+                    }  # replace v by k
+                    rename_kw["image"] = "image"
+
+                if feature_set_name not in all_feature_datasets:
+                    all_feature_datasets[feature_set_name] = []
+
+                shard_paths = sorted(glob.glob(path_pattern))
+                num_shards = len(shard_paths)
+                if num_shards < world_size * 8:
+                    shard_paths *= math.ceil(world_size * 8 / num_shards)
+                ds = (
+                    wds.WebDataset(
+                        shard_paths,
+                        nodesplitter=wds.split_by_node,
+                        workersplitter=wds.split_by_worker,
+                        detshuffle=True,
+                        shardshuffle=shuffle,
+                        seed=seed,
+                    )
+                    .decode(partial(decode_sample, image_transform=image_transform))
+                    .rename(keep=False, **rename_kw)
+                )
+                all_feature_datasets[feature_set_name].append(ds)
+
+        dataset_weights.append(dataset_mix[dataset])
+        dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
+
+    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
+
+    combined_feature_datasets: dict[str, Dataset] = {}
+    for feature_set_name, fds in all_feature_datasets.items():
+        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted")
+        combined_feature_datasets[feature_set_name] = ds
+
+    return combined_feature_datasets, sum_expected_lengths
+
+
+def get_oxe_frame_dataloader(
+    datasets: dict[str, DatasetType], batch_size: Optional[int] = None, shuffle_buffer_size: int = 1_000, **kwargs: Any
+) -> dict[str, DataLoader]:
+    """Get dataloaders of OXE datasets. Corresponding to `get_oxe_frame_dataset()`.
+
+    Args:
+        datasets (dict[str, DatasetType]): OXE datasets from `get_oxe_frame_dataset()`.
+        batch_size (Optional[int], optional): batch size. Defaults to None.
+        shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.
+
+    Returns:
+        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
+    """
+    loaders = {
+        k: (
+            wds.WebLoader(datasets[k], batch_size=None, **kwargs)
+            .shuffle(shuffle_buffer_size)  # shuffle after mix
+            .batched(batch_size, collation_fn=default_collate)
+        )
+        for k in datasets
+    }
+    return loaders
+
+
+def get_oxe_frame_iterator(
+    data_loaders: dict[str, DataLoader],
+) -> Iterator[dict[str, Any]]:
+    """Get iterator from dataloaders. Corresponding to `get_oxe_frame_dataloader()`.
+
+    Args:
+        data_loaders (dict[str, DataLoader]): dataloaders from `get_oxe_frame_dataloader()`.
+
+    Yields:
+        Iterator[dict[str, Any]]: data sample.
+    """
+    packed_loader = data_loaders.get("packed", None)
+    # place packed_loader at the first
+    if packed_loader is not None:
+        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
+    else:
+        loaders = list(data_loaders.values())
+
+    # merge dicts
+    for data in zip(*loaders, strict=False):
+        # yield data
+        for i in range(1, len(loaders)):
+            for k in data[i]:
+                if k not in data[0]:
+                    data[0][k] = data[i][k]
+        yield data[0]
+
+
+def normalize_feature(
+    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Normalize the feature given mean and std.
+
+    Args:
+        x (torch.Tensor): input features
+        mean (Optional[torch.Tensor], optional): mean values. Defaults to None.
+        std (Optional[torch.Tensor], optional): std values. Defaults to None.
+
+    Returns:
+        torch.Tensor: feature after normalization
+    """
+    return x if mean is None or std is None else (x - mean) / std
+
+
+def load_feature_stats(
+    dataset_root: str, feature_models: list[str]
+) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    """Load feature statistics (mean and variance).
+
+    Args:
+        dataset_root (str): root dir of the dataset (or where to hold the statistics).
+        feature_models (list[str]): names of the models/features.
+
+    Returns:
+        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances. Keys are model names.
+    """
+    feature_means: dict[str, torch.Tensor] = {}
+    feature_vars: dict[str, torch.Tensor] = {}
+    for model in feature_models:
+        model_name = model.replace("/", "_")
+        feature_means[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_mean_{model_name}.npy"))).to(
+            torch.bfloat16
+        )
+        feature_vars[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_var_{model_name}.npy"))).to(
+            torch.bfloat16
+        )
+    return feature_means, feature_vars
+
+
+def pad_shard_paths(shard_paths: list[str], num_shards: int, num_parts: int) -> list[str]:
+    """Pad shard paths to be divided by number of partitions (ranks*nodes).
+
+    Args:
+        shard_paths (list[str]): paths of dataset shards.
+        num_shards (int): number of shards.
+        num_parts (int): number of partitions.
+
+    Returns:
+        list[str]: shard paths padded.
+    """
+    final_shard_paths = shard_paths
+    if num_shards % num_parts != 0:
+        if num_shards < num_parts - num_shards:
+            for _ in range(math.floor((num_parts - num_shards) / num_shards)):
+                final_shard_paths += shard_paths[:]
+            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
+        else:
+            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
+    return final_shard_paths
+
+
+def get_image_video_dataset(
+    dataset_root: str,
+    feature_models: list[str],
+    dataset_mix: Optional[str | dict[str, float] | list] = None,
+    split: str = "train",
+    dataset_ratio: float = 1.0,
+    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
+    feature_norm: bool = False,
+    seed: Optional[int | str] = 0,
+    shuffle: bool = False,
+    world_size: int = 1,
+    **kwargs: Any,
+) -> tuple[dict[str, DatasetType], float | Literal[0]]:
+    """Get image and video datasets at frame level.
+
+    Args:
+        dataset_root (str): root dir of the datasets.
+        feature_models (list[str]): models to load their features.
+        dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
+        split (str, optional): split "train" or "val" or "test". Defaults to "train".
+        dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0.
+        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
+            Defaults to None.
+        feature_norm: (bool, optional): whether to normalize the feature. Defaults to False.
+        seed (Optional[int | str], optional): seed. Defaults to 0.
+        shuffle (bool, optional): shuffle or not. Defaults to False.
+        world_size (int, optional): world size of DDP training. Defaults to 1.
+        kwargs (Any): arguments to pass-through.
+
+    Returns:
+        tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
+    """
+    # read dataset mix from any acceptable form
+    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
+        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
+    elif isinstance(dataset_mix, dict):
+        dataset_mix = OrderedDict(**dataset_mix)
+    elif isinstance(dataset_mix, list) or isinstance(dataset_mix, omegaconf.listconfig.ListConfig):
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+    else:
+        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
+
+    if split == "eval" or split == "val":
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+
+    # note down the dataset weights
+    dataset_weights: list[float] = []
+    # get frame level length
+    dataset_lens: list[int] = []
+
+    all_feature_datasets: dict[str, DatasetType] = {}
+
+    if feature_norm:
+        feature_means, feature_vars = load_feature_stats(dataset_root, feature_models)
+
+    for d in dataset_mix:
+
+        with open(osp.join(dataset_root, d, "splits.json"), "r") as splitf:
+            dataset_len = json.load(splitf)[split]
+
+        # if the length is 0, skip
+        # this may happen for small datasets with very few shards
+        if dataset_len == 0:
+            continue
+
+        path_pattern = osp.join(dataset_root, d, "images", f"*-{split}.tar")
+        if "image" not in all_feature_datasets:
+            all_feature_datasets["image"] = []
+        shard_paths = sorted(glob.glob(path_pattern))
+        num_shards = len(shard_paths)
+        num_parts = world_size
+        final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
+        ds = wds.WebDataset(
+            final_shard_paths,
+            nodesplitter=wds.split_by_node,
+            workersplitter=wds.split_by_worker,
+            detshuffle=True,
+            shardshuffle=shuffle,
+            seed=seed,
+        ).decode(partial(decode_sample, image_transform=image_transform))
+        all_feature_datasets["image"].append(ds)
+
+        for model_name in feature_models:
+            path_pattern = osp.join(dataset_root, d, f"{model_name.replace('/', '_')}", f"*-{split}.tar")
+            rename_kw = {model_name: model_name.replace("/", "_").lower() + ".safetensors"}  # replace v by k
+
+            if model_name not in all_feature_datasets:
+                all_feature_datasets[model_name] = []
+
+            shard_paths = sorted(glob.glob(path_pattern))
+            num_shards = len(shard_paths)
+            num_parts = world_size
+            final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
+            if feature_norm:
+                feature_transform = partial(
+                    normalize_feature, mean=feature_means[model_name], std=feature_vars[model_name]
+                )
+            else:
+                feature_transform = None
+            ds = (
+                wds.WebDataset(
+                    final_shard_paths,
+                    nodesplitter=wds.split_by_node,
+                    workersplitter=wds.split_by_worker,
+                    detshuffle=True,
+                    shardshuffle=shuffle,
+                    seed=seed,
+                )
+                .decode(partial(decode_sample, image_transform=image_transform, feature_transform=feature_transform))
+                .rename(keep=False, **rename_kw)
+            )
+            all_feature_datasets[model_name].append(ds)
+
+        dataset_weights.append(dataset_mix[d])
+        dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
+
+    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
+
+    combined_feature_datasets: dict[str, Dataset] = {}
+    for feature_set_name, fds in all_feature_datasets.items():
+        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted", seed=seed)
+        combined_feature_datasets[feature_set_name] = ds
+
+    return combined_feature_datasets, sum_expected_lengths
+
+
+def get_frame_dataloader(
+    datasets: dict[str, DatasetType],
+    batch_size: Optional[int] = None,
+    shuffle: bool = False,
+    shuffle_buffer_size: int = 1_000,
+    seed: Optional[int] = 0,
+    **kwargs: Any,
+) -> dict[str, DataLoader]:
+    """Get dataloaders of image and video datasets. Corresponding to `get_image_video_dataset()`.
+
+    Args:
+        datasets (dict[str, DatasetType]): image and video datasets from `get_image_video_dataset()`.
+        batch_size (Optional[int], optional): batch size. Defaults to None.
+        shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.
+
+    Returns:
+        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
+    """
+    loaders = {}
+    for k in datasets:
+        loader = wds.WebLoader(datasets[k], batch_size=None, generator=default_generator, **kwargs)
+        if shuffle:
+            loader = loader.shuffle(shuffle_buffer_size, seed=seed)  # shuffle after mix
+        loader = loader.batched(batch_size, collation_fn=default_collate)
+        loaders[k] = loader
+    return loaders
+
+
+def get_frame_iterator(
+    data_loaders: dict[str, DataLoader],
+) -> Iterator[dict[str, Any]]:
+    """Get iterator from image and video dataset dataloaders. Corresponding to `get_frame_dataloader()`.
+
+    Args:
+        data_loaders (dict[str, DataLoader]): dataloaders from `get_frame_dataloader()`.
+
+    Yields:
+        Iterator[dict[str, Any]]: data sample.
+    """
+    packed_loader = data_loaders.get("packed", None)
+    # place packed_loader at the first
+    if packed_loader is not None:
+        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
+    else:
+        loaders = list(data_loaders.values())
+
+    # merge dicts
+    # this is to accommodate the old organization of datasets (each shard contains one or more columns,
+    # and images are duplicated columns).
+    # In the new (current) dataset organization (columns are completely separated),
+    # column keys are all different except some "built-in" keys added by webdataset,
+    # but they are not related to any data, training, and so on.
+    # During the transition from old to new, where the two organizations exist at the same time,
+    # this is to ignore the extra "image" field in loaded datasets.
+    for data in zip(*loaders, strict=False):
+        # yield data
+        for i in range(1, len(loaders)):
+            for k in data[i]:
+                if k not in data[0]:
+                    data[0][k] = data[i][k]
+        yield data[0]
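
Taken together, the helpers in data_utils.py build one WebDataset per feature stream, wrap each in a WebLoader, and then merge the per-stream loaders into a single sample iterator. A minimal usage sketch with the functions defined above; the dataset root, model list, and batch size are illustrative values only, and the sample keys depend on what the shards actually contain:

    # Hedged usage sketch for theia/dataset/data_utils.py (example values only).
    from theia.dataset.data_utils import (
        get_frame_dataloader,
        get_frame_iterator,
        get_image_video_dataset,
    )

    datasets, expected_len = get_image_video_dataset(
        dataset_root="/path/to/datasets",          # example path
        feature_models=["facebook/dinov2-large"],  # example target model
        dataset_mix=["imagenet"],
        split="train",
        shuffle=True,
        world_size=1,
    )
    loaders = get_frame_dataloader(datasets, batch_size=16, shuffle=True, num_workers=2)

    for batch in get_frame_iterator(loaders):
        images = batch["image"]                # decoded frames
        dino = batch["facebook/dinov2-large"]  # {"embedding": ..., "cls": ...} per model
        break
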
theia/dataset/image/__init__.py
ADDED
@@ -0,0 +1,3 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .image_common import ALL_IMAGE_DATASETS
theia/dataset/image/image_common.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from collections import OrderedDict
+
+ALL_IMAGE_DATASETS = OrderedDict({"imagenet": {"steps": 1_281_167}})
theia/dataset/oxe/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
theia/dataset/oxe/oxe_common.py
ADDED
@@ -0,0 +1,430 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from collections import OrderedDict
+from typing import Optional
+
+"""
+The ALL_OXE_DATASETS dict below records metadata for all subsets of the OXE dataset.
+The datasets are in alphabetical order.
+
+versions (list[str]): available and usable versions, sorted from older to newer.
+    Usually use the last one.
+episodes (int): total episodes in the dataset.
+steps (int): total steps in the dataset.
+visual_observation_keys (list[str]): keys to specify image observations.
+"""
+ALL_OXE_DATASETS: OrderedDict = OrderedDict(
+    {
+        "agent_aware_affordances": {
+            "versions": ["1.0.0"],
+            "episodes": 118,
+            "steps": 151628,
+            "visual_observation_keys": ["image"],
+        },
+        "asu_table_top_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 110,
+            "steps": 26113,
+            "visual_observation_keys": ["image"],
+        },
+        "austin_buds_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 50,
+            "steps": 34112,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "austin_sailor_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 240,
+            "steps": 353094,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "austin_sirius_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 559,
+            "steps": 279939,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "bc_z": {
+            "versions": [
+                "0.1.0",  # "1.0.0", "old1.0.1", and "1.0.1" are not usable
+            ],
+            "episodes": 39350,
+            "steps": 5471693,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_autolab_ur5": {
+            "versions": ["0.1.0"],
+            "episodes": 896,
+            "steps": 87783,
+            "visual_observation_keys": ["image", "hand_image"],
+        },
+        "berkeley_cable_routing": {
+            "versions": ["0.1.0"],
+            "episodes": 1482,
+            "steps": 38240,
+            "visual_observation_keys": ["image", "top_image", "wrist225_image", "wrist45_image"],
+        },
+        "berkeley_fanuc_manipulation": {
+            "versions": ["0.1.0"],
+            "episodes": 415,
+            "steps": 62613,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "berkeley_gnm_cory_hall": {
+            "versions": ["0.1.0"],
+            "episodes": 7331,
+            "steps": 156012,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_gnm_recon": {
+            "versions": ["0.1.0"],
+            "episodes": 11834,
+            "steps": 610907,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_gnm_sac_son": {
+            "versions": ["0.1.0"],
+            "episodes": 2955,
+            "steps": 241059,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_mvp_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 480,
+            "steps": 45308,
+            "visual_observation_keys": ["hand_image"],
+        },
+        "berkeley_rpt_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 908,
+            "steps": 392578,
+            "visual_observation_keys": ["hand_image"],
+        },
+        "bridge": {"versions": ["0.1.0"], "episodes": 25460, "steps": 864292, "visual_observation_keys": ["image"]},
+        "cmu_franka_exploration_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 199,
+            "steps": 1990,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_play_fusion": {
+            "versions": ["0.1.0"],
+            "episodes": 576,
+            "steps": 235922,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_playing_with_food": {  # this dataset seems to be corrupted
+            "versions": ["1.0.0"],
+            "episodes": 4200,
+            "steps": 83240,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_stretch": {"versions": ["0.1.0"], "episodes": 135, "steps": 25016, "visual_observation_keys": ["image"]},
+        "columbia_cairlab_pusht_real": {
+            "versions": ["0.1.0"],
+            "episodes": 122,
+            "steps": 24924,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "dlr_edan_shared_control_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 104,
+            "steps": 8928,
+            "visual_observation_keys": ["image"],
+        },
+        "dlr_sara_grid_clamp_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 107,
+            "steps": 7622,
+            "visual_observation_keys": ["image"],
+        },
+        "dlr_sara_pour_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 100,
+            "steps": 12971,
+            "visual_observation_keys": ["image"],
+        },
+        "eth_agent_affordances": {
+            "versions": ["0.1.0"],
+            "episodes": 118,
+            "steps": 151628,
+            "visual_observation_keys": ["image"],
+        },
+        "fanuc_manipulation_v2": {
+            "versions": ["1.0.0"],
+            "episodes": 415,
+            "steps": 62613,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "fractal20220817_data": {
+            "versions": ["0.1.0"],
+            "episodes": 87212,
+            "steps": 3786400,
+            "visual_observation_keys": ["image"],
+        },
+        "furniture_bench_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 5100,
+            "steps": 3948057,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 631,
+            "steps": 146241,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "imperial_wrist_dataset": {
+            "versions": ["1.0.0"],
+            "episodes": 170,
+            "steps": 7148,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "imperialcollege_sawyer_wrist_cam": {
+            "versions": ["0.1.0"],
+            "episodes": 170,
+            "steps": 7148,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "jaco_play": {
+            "versions": ["0.1.0"],
+            "episodes": 976,
+            "steps": 70127,
+            "visual_observation_keys": ["image", "image_wrist"],
+        },
+        "kaist_nonprehensile_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 201,
+            "steps": 32429,
+            "visual_observation_keys": ["image"],
+        },
+        "kuka": {"versions": ["0.1.0"], "episodes": 580392, "steps": 8583978, "visual_observation_keys": ["image"]},
+        "language_table": {
+            "versions": ["0.0.1", "0.1.0"],
+            "episodes": 442226,
+            "steps": 7045476,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoabsolute_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 15866385,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_4block_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 8298,
+            "steps": 326768,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 12970620,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 8000,
+            "steps": 351688,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblockrelative_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 13016749,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktorelative_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 8655815,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_separate_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 3196661,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 181020,
+            "steps": 4665423,
+            "visual_observation_keys": ["rgb"],
+        },
+        "maniskill_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 30213,
+            "steps": 4537402,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "mutex_dataset": {
+            "versions": ["1.0.0"],
+            "episodes": 1500,
+            "steps": 361883,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "nyu_door_opening_surprising_effectiveness": {
+            "versions": ["0.1.0"],
+            "episodes": 435,
+            "steps": 18196,
+            "visual_observation_keys": ["image"],
+        },
+        "nyu_franka_play_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 365,
+            "steps": 34448,
+            "visual_observation_keys": ["image", "image_additional_view"],
+        },
+        "nyu_rot_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 14,
+            "steps": 440,
+            "visual_observation_keys": ["image"],
+        },
+        "qut_dexterous_manpulation": {
+            "versions": ["0.1.0"],
+            "episodes": 200,
+            "steps": 176278,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "robo_net": {
+            "versions": ["0.1.0", "1.0.0"],
+            "episodes": 82775,
+            "steps": 2483250,
+            "visual_observation_keys": ["image", "image1", "image2"],
+        },
+        "robot_vqa": {
+            "versions": ["0.1.0"],
+            "episodes": 3331523,
+            "steps": 3331523,
+            "visual_observation_keys": ["images"],
+        },
+        "roboturk": {
+            "versions": ["0.1.0"],
+            "episodes": 1796,
+            "steps": 168423,
+            "visual_observation_keys": ["front_rgb"],
+        },
+        "stanford_hydra_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 570,
+            "steps": 358234,
+            "visual_observation_keys": ["image", "wrist_image"],
|
| 316 |
+
},
|
| 317 |
+
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
| 318 |
+
"versions": ["0.1.0"],
|
| 319 |
+
"episodes": 3000,
|
| 320 |
+
"steps": 149985,
|
| 321 |
+
"visual_observation_keys": ["image"],
|
| 322 |
+
},
|
| 323 |
+
"stanford_mask_vit_converted_externally_to_rlds": {
|
| 324 |
+
"versions": ["0.1.0"],
|
| 325 |
+
"episodes": 9109,
|
| 326 |
+
"steps": 282379,
|
| 327 |
+
"visual_observation_keys": ["image"],
|
| 328 |
+
},
|
| 329 |
+
"stanford_robocook_converted_externally_to_rlds": {
|
| 330 |
+
"versions": ["0.1.0"],
|
| 331 |
+
"episodes": 2460,
|
| 332 |
+
"steps": 112980,
|
| 333 |
+
"visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
|
| 334 |
+
},
|
| 335 |
+
"taco_play": {
|
| 336 |
+
"versions": ["0.1.0"],
|
| 337 |
+
"episodes": 3242,
|
| 338 |
+
"steps": 213972,
|
| 339 |
+
"visual_observation_keys": ["rgb_static", "rgb_gripper"],
|
| 340 |
+
},
|
| 341 |
+
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
| 342 |
+
"versions": ["0.1.0"],
|
| 343 |
+
"episodes": 50,
|
| 344 |
+
"steps": 11925,
|
| 345 |
+
"visual_observation_keys": ["image"],
|
| 346 |
+
},
|
| 347 |
+
"toto": {"versions": ["0.1.0"], "episodes": 902, "steps": 294139, "visual_observation_keys": ["image"]},
|
| 348 |
+
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
| 349 |
+
"versions": ["0.1.0"],
|
| 350 |
+
"episodes": 150,
|
| 351 |
+
"steps": 3970,
|
| 352 |
+
"visual_observation_keys": ["image"],
|
| 353 |
+
},
|
| 354 |
+
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
| 355 |
+
"versions": ["0.1.0"],
|
| 356 |
+
"episodes": 1355,
|
| 357 |
+
"steps": 67750,
|
| 358 |
+
"visual_observation_keys": ["image"],
|
| 359 |
+
},
|
| 360 |
+
"uiuc_d3field": { # this dataset seems to be corrupted
|
| 361 |
+
"versions": ["0.1.0", "1.1.2"],
|
| 362 |
+
"episodes": 196,
|
| 363 |
+
"steps": 13384,
|
| 364 |
+
"visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
|
| 365 |
+
},
|
| 366 |
+
"usc_cloth_sim_converted_externally_to_rlds": {
|
| 367 |
+
"versions": ["0.1.0"],
|
| 368 |
+
"episodes": 800,
|
| 369 |
+
"steps": 80000,
|
| 370 |
+
"visual_observation_keys": ["image"],
|
| 371 |
+
},
|
| 372 |
+
"utaustin_mutex": {
|
| 373 |
+
"versions": ["0.1.0"],
|
| 374 |
+
"episodes": 1500,
|
| 375 |
+
"steps": 361883,
|
| 376 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
| 377 |
+
},
|
| 378 |
+
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
| 379 |
+
"versions": ["0.1.0"],
|
| 380 |
+
"episodes": 64,
|
| 381 |
+
"steps": 9140,
|
| 382 |
+
"visual_observation_keys": ["image"],
|
| 383 |
+
},
|
| 384 |
+
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
| 385 |
+
"versions": ["0.1.0"],
|
| 386 |
+
"episodes": 192,
|
| 387 |
+
"steps": 26346,
|
| 388 |
+
"visual_observation_keys": ["image"],
|
| 389 |
+
},
|
| 390 |
+
"utokyo_saytap_converted_externally_to_rlds": {
|
| 391 |
+
"versions": ["0.1.0"],
|
| 392 |
+
"episodes": 20,
|
| 393 |
+
"steps": 22937,
|
| 394 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
| 395 |
+
},
|
| 396 |
+
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
| 397 |
+
"versions": ["0.1.0"],
|
| 398 |
+
"episodes": 64,
|
| 399 |
+
"steps": 1388,
|
| 400 |
+
"visual_observation_keys": ["image"],
|
| 401 |
+
},
|
| 402 |
+
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
| 403 |
+
"versions": ["0.1.0"],
|
| 404 |
+
"episodes": 92,
|
| 405 |
+
"steps": 6789,
|
| 406 |
+
"visual_observation_keys": ["image", "hand_image", "image2"],
|
| 407 |
+
},
|
| 408 |
+
"viola": {
|
| 409 |
+
"versions": ["0.1.0"],
|
| 410 |
+
"episodes": 135,
|
| 411 |
+
"steps": 68913,
|
| 412 |
+
"visual_observation_keys": ["agentview_rgb", "eye_in_hand_rgb"],
|
| 413 |
+
},
|
| 414 |
+
}
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def oxe_dsname2path(dataset_name: str, version: Optional[str] = None) -> str:
|
| 419 |
+
"""From dataset name to remote google clound path to the dataset.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
dataset_name (str): dataset name.
|
| 423 |
+
version (Optional[str]): version string.
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
str: google clound path
|
| 427 |
+
"""
|
| 428 |
+
if version is None:
|
| 429 |
+
version = ALL_OXE_DATASETS[dataset_name]["versions"][-1]
|
| 430 |
+
return f"gs://gresearch/robotics/{dataset_name}/{version}"
|
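As a quick illustration of how the registry and path helper above fit together, here is a minimal usage sketch. It is not part of the commit; the module path `theia.dataset.oxe.oxe_common` is taken from the file list of this diff.

# Hypothetical usage sketch for ALL_OXE_DATASETS / oxe_dsname2path (not part of the commit).
from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS, oxe_dsname2path

# Resolve the remote path of a dataset; the latest listed version is used when none is given.
path = oxe_dsname2path("cmu_stretch")  # gs://gresearch/robotics/cmu_stretch/0.1.0
path_v = oxe_dsname2path("language_table", version="0.0.1")

# Aggregate simple statistics over the registry.
total_steps = sum(meta["steps"] for meta in ALL_OXE_DATASETS.values())
print(path, path_v, total_steps)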
theia/dataset/oxe/oxe_mixes.py
ADDED
@@ -0,0 +1,139 @@
# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

"""MIT License Copyright (c) 2023 Robotic AI & Learning Lab Berkeley

From Octo https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_mixes.py
"""

BRIDGE_MIX = [
    ("bridge_dataset", 1.0),
]

RT_X_MIX = [
    ("fractal20220817_data", 0.54087122203),
    ("kuka", 0.8341046294),
    ("bridge_dataset", 1.0),
    ("taco_play", 2.0),
    ("jaco_play", 2.0),
    ("berkeley_cable_routing", 3.0),
    ("roboturk", 1.0),
    ("nyu_door_opening_surprising_effectiveness", 5.0),
    ("viola", 2.0),
    ("berkeley_autolab_ur5", 1.0),
    ("toto", 1.0),
]


OXE_FRANKA_MIX = [
    ("taco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("viola", 1.0),
    ("toto", 1.0),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 3.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
    ("maniskill_dataset_converted_externally_to_rlds", 0.1),
    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
    ("kaist_nonprehensile_converted_externally_to_rlds", 3.0),
    ("stanford_robocook_converted_externally_to_rlds", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    ("utaustin_mutex", 1.0),
    # ("cmu_playing_with_food", 1.0),
    ("cmu_play_fusion", 1.0),
]

OXE_MAGIC_SOUP = [
    ("fractal20220817_data", 0.54087122203),
    ("kuka", 0.8341046294),
    ("bridge", 1.0),
    ("taco_play", 2.0),
    ("jaco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("roboturk", 2.0),
    ("nyu_door_opening_surprising_effectiveness", 1.0),
    ("viola", 2.0),
    ("berkeley_autolab_ur5", 2.0),
    ("toto", 1.0),
    ("language_table", 0.1),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("bc_z", 0.2),
    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    # ("uiuc_d3field", 1.0), --> somehow raw data is broken
    ("utaustin_mutex", 1.0),
    ("berkeley_fanuc_manipulation", 2.0),
    ("cmu_stretch", 1.0),
]


OXE_FULL_MIX = [
    ("fractal20220817_data", 1.0),
    ("kuka", 1.0),
    ("bridge_dataset", 1),
    ("taco_play", 1.0),
    ("jaco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("roboturk", 1.0),
    ("nyu_door_opening_surprising_effectiveness", 1.0),
    ("viola", 1.0),
    ("berkeley_autolab_ur5", 1.0),
    ("toto", 1.0),
    ("language_table", 1.0),
    ("columbia_cairlab_pusht_real", 1.0),
    ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_rot_dataset_converted_externally_to_rlds", 1.0),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0),
    ("maniskill_dataset_converted_externally_to_rlds", 1.0),
    ("furniture_bench_dataset_converted_externally_to_rlds", 1.0),
    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0),
    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0),
    ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("bc_z", 1.0),
    ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0),
    ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0),
    ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0),
    ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0),
    ("robo_net", 1.0),
    ("berkeley_mvp_converted_externally_to_rlds", 1.0),
    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
    ("kaist_nonprehensile_converted_externally_to_rlds", 1.0),
    ("stanford_mask_vit_converted_externally_to_rlds", 1.0),
    ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0),
    ("dlr_sara_pour_converted_externally_to_rlds", 1.0),
    ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0),
    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
    ("asu_table_top_converted_externally_to_rlds", 1.0),
    ("stanford_robocook_converted_externally_to_rlds", 1.0),
    ("imperialcollege_sawyer_wrist_cam", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    ("uiuc_d3field", 1.0),
    ("utaustin_mutex", 1.0),
    ("berkeley_fanuc_manipulation", 1.0),
    ("cmu_playing_with_food", 1.0),
    ("cmu_play_fusion", 1.0),
    ("cmu_stretch", 1.0),
    ("berkeley_gnm_recon", 1.0),
    ("berkeley_gnm_cory_hall", 1.0),
    ("berkeley_gnm_sac_son", 1.0),
]

OXE_NAMED_MIXES = {
    "bridge": BRIDGE_MIX,
    "rtx": RT_X_MIX,
    "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX,
    "oxe_magic_soup": OXE_MAGIC_SOUP,
}
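The mixes above pair dataset names with relative sampling weights. The sketch below shows one way such a mix could be turned into normalized sampling probabilities; it is an illustration only and not the project's own loader logic.

# Hypothetical helper: normalize a (dataset_name, weight) mix into sampling probabilities.
from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES


def mix_to_probabilities(mix: list[tuple[str, float]]) -> dict[str, float]:
    """Normalize per-dataset weights of a mix so they sum to 1."""
    total = sum(weight for _, weight in mix)
    return {name: weight / total for name, weight in mix}


probs = mix_to_probabilities(OXE_NAMED_MIXES["rtx"])
print(probs["bridge_dataset"])  # relative sampling probability of the bridge dataset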
theia/dataset/oxe/oxe_transforms.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import torch
from numpy.typing import NDArray
from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage


def totensor(arr: NDArray) -> torch.Tensor:
    """Convert ndarray to tensor."""
    return torch.from_numpy(arr)


oxe_image_transform = Compose(
    [ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)  # ImageNet statistics normalization
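A brief sketch of applying `totensor` and `oxe_image_transform` to a raw HWC uint8 frame. It is illustrative only; the random frame and its shape are assumptions.

# Hypothetical usage of the OXE image transform on a dummy 224x224 RGB frame.
import numpy as np
from theia.dataset.oxe.oxe_transforms import oxe_image_transform, totensor

frame = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)  # HWC uint8 image
tensor = oxe_image_transform(frame)  # CHW float32, ImageNet-normalized
raw = totensor(frame)                # plain tensor view, no normalization
print(tensor.shape, tensor.dtype, raw.shape)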
theia/dataset/video/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .video_common import ALL_VIDEO_DATASETS
theia/dataset/video/video_common.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from collections import OrderedDict

ALL_VIDEO_DATASETS = OrderedDict(
    {
        "ego4d_1in150": {"steps": 2_800_871},
        "epic_kitchen_1in60": {"steps": 333_117},
        "ssv2_1in32": {"steps": 312_772},
    }
)
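For reference, the subsampled video datasets can be aggregated the same way as the OXE registry; a trivial, illustrative snippet:

# Illustrative only: total number of frames across the subsampled video datasets.
from theia.dataset.video import ALL_VIDEO_DATASETS

total_video_steps = sum(d["steps"] for d in ALL_VIDEO_DATASETS.values())
print(total_video_steps)  # 2_800_871 + 333_117 + 312_772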
theia/decoding/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .decode import decode_everything, load_feature_stats
from .depth_anything import prepare_depth_decoder
from .sam import prepare_mask_generator
theia/decoding/decode.py
ADDED
@@ -0,0 +1,198 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from numpy.typing import NDArray
from PIL import Image
from sklearn.decomposition import PCA
from transformers import SamModel, SamProcessor
from transformers.pipelines import MaskGenerationPipeline

from theia.decoding.depth_anything import decode_depth_anything
from theia.decoding.dinov2 import decode_dinov2
from theia.decoding.sam import decode_sam
from theia.preprocessing.feature_extraction_core import (
    get_feature_outputs,
    get_model,
)


def denormalize_feature(
    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Denormalize the features using mean and std.

    Args:
        x (torch.Tensor): features to be denormalized.
        mean (Optional[torch.Tensor], optional): mean value of the features. Defaults to None.
        std (Optional[torch.Tensor], optional): std value of the features. Defaults to None.

    Returns:
        torch.Tensor: denormalized features.
    """
    if mean is None and std is None:
        return x
    elif mean is None and std is not None:
        return x * std
    elif mean is not None and std is None:
        return x + mean
    return x * std + mean


def load_feature_stats(
    feature_models: list[str], stat_file_root: str
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Load the statistics (mean and variance) of the features, per model.

    Args:
        feature_models (list[str]): names of the models. Note: there are `/` in the names.
        stat_file_root (str): directory that holds the feature stat files.

    Returns:
        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances.
    """
    feature_means: dict[str, torch.Tensor] = {}
    feature_vars: dict[str, torch.Tensor] = {}
    for model in feature_models:
        model_name = model.replace("/", "_")
        feature_means[model] = torch.from_numpy(
            np.load(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy"))
        )
        feature_vars[model] = torch.from_numpy(np.load(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy")))
    return feature_means, feature_vars


def decode_everything(
    theia_model: nn.Module,
    feature_means: dict[str, torch.Tensor],
    feature_vars: dict[str, torch.Tensor],
    images: list[Image.Image],
    mask_generator: MaskGenerationPipeline,
    sam_model: SamModel,
    depth_anything_decoder: nn.Module,
    pred_iou_thresh: float = 0.9,
    stability_score_thresh: float = 0.9,
    gt: bool = False,
    pca: Optional[PCA] = None,
    device: int | str | torch.device = 0,
) -> tuple[list[NDArray], Optional[list[NDArray]]]:
    """Decode features from the given `theia_model` into outputs corresponding to the upstream models, including
    DINOv2, SAM, and Depth-Anything.

    Args:
        theia_model (nn.Module): theia model.
        feature_means (dict[str, torch.Tensor]): means of the features for denormalization.
        feature_vars (dict[str, torch.Tensor]): variances of the features for denormalization.
        images (list[Image.Image]): input images.
        mask_generator (MaskGenerationPipeline): mask generation pipeline.
        sam_model (SamModel): SAM model.
        depth_anything_decoder (nn.Module): Depth-Anything decoder.
        pred_iou_thresh (float, optional): iou threshold for mask generation.
            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
        stability_score_thresh (float, optional): stability score threshold for mask generation.
            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
        gt (bool): whether to attach the ground truth result in the visualization. Defaults to False.
        pca (Optional[PCA]): pca for DINOv2 decoding. If provided, this particular pca is used. Defaults to None.
        device (int | str | torch.device, optional): device for decoding. Defaults to 0.

    Returns:
        tuple[list[NDArray], Optional[list[NDArray]]]: decoding results from the given model,
            and ground truth (if `gt=True`).
    """
    features: dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for im in images:
            feature = theia_model([im])
            if len(features) == 0:
                features = {k: [] for k in feature}
            for k in feature:
                features[k].append(feature[k].detach().cpu())
        for k in features:
            features[k] = torch.cat(features[k], dim=0)
        for m in features:
            features[m] = denormalize_feature(features[m], feature_means[m], feature_vars[m])

    dino_model_name = "facebook/dinov2-large"
    sam_model_name = "facebook/sam-vit-huge"
    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"

    pca = None
    # gt
    gt_decode_results = None
    if gt:
        def legit_model_name(model_name: str) -> str:
            return model_name.replace("/", "_")

        dino_model, dino_processor = get_model(dino_model_name, device=device)
        dino_gt_feature = []
        for im in images:
            dino_gt_feature.append(
                get_feature_outputs(
                    legit_model_name(dino_model_name), dino_model, dino_processor, [im], dtype=torch.float
                )[legit_model_name(dino_model_name)]["embedding"]
                .detach()
                .cpu()
            )
        dino_gt_feature = torch.cat(dino_gt_feature, dim=0)
        dino_gt_feature = rearrange(dino_gt_feature, "b c h w -> b (h w) c")
        dino_gt_dec, pca = decode_dinov2(dino_gt_feature, pca=pca)
        sam_processor = SamProcessor.from_pretrained(sam_model_name)
        sam_gt_feature = []
        for im in images:
            sam_inputs = sam_processor(images=[im], return_tensors="pt").to(device)
            with torch.no_grad():
                sam_gt_feature.append(sam_model.get_image_embeddings(sam_inputs["pixel_values"]).detach().cpu())
        sam_gt_feature = torch.cat(sam_gt_feature, dim=0)
        sam_gt_feature = rearrange(sam_gt_feature, "b c h w -> b (h w) c")
        sam_gt_dec = decode_sam(
            sam_gt_feature, images, mask_generator, pred_iou_thresh=0.9, stability_score_thresh=0.9, device=device
        )
        depth_anything_model, depth_anything_processor = get_model(depth_anything_model_name, device=device)
        depth_anything_gt_feature = []
        for im in images:
            depth_anything_gt_feature.append(
                get_feature_outputs(
                    legit_model_name(depth_anything_model_name),
                    depth_anything_model,
                    depth_anything_processor,
                    [im],
                    dtype=torch.float,
                )[legit_model_name(depth_anything_model_name)]["embedding"]
                .detach()
                .cpu()
            )
        depth_anything_gt_feature = torch.cat(depth_anything_gt_feature, dim=0)
        depth_anything_gt_feature = rearrange(depth_anything_gt_feature, "b c h w -> b (h w) c")
        depth_gt_dec = decode_depth_anything(depth_anything_gt_feature, depth_anything_decoder, device=device)

        gt_decode_results = [
            np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_gt_dec[i], sam_gt_dec[i], depth_gt_dec[i]])
            for i in range(len(images))
        ]

    dino_dec, _ = decode_dinov2(features[dino_model_name], pca=pca)

    try:
        sam_dec = decode_sam(
            features[sam_model_name],
            images,
            mask_generator,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            device=device,
        )
    except IndexError:
        sam_dec = np.zeros_like(dino_dec)
    depth_dec = decode_depth_anything(features[depth_anything_model_name], depth_anything_decoder, device=device)

    theia_decode_results = [
        np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_dec[i], sam_dec[i], depth_dec[i]])
        for i in range(len(images))
    ]

    return theia_decode_results, gt_decode_results
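To make the normalization convention concrete, here is a small sketch that loads per-model statistics and denormalizes a feature tensor, mirroring how `decode_everything` uses them. The stats directory path and the dummy feature shape are assumptions.

# Hypothetical sketch: load feature statistics and denormalize a dummy feature tensor.
import torch
from theia.decoding.decode import denormalize_feature, load_feature_stats

model_name = "facebook/dinov2-large"
means, variances = load_feature_stats([model_name], stat_file_root="feature_stats")  # path is an assumption

normalized = torch.randn(1, 256, 1024)  # dummy [batch, num_tokens, latent_dim] feature
restored = denormalize_feature(normalized, means[model_name], variances[model_name])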
theia/decoding/depth_anything.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import torch
import torch.nn as nn
from einops import rearrange
from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation
from numpy.typing import NDArray
from torch.nn.functional import interpolate


def prepare_depth_decoder(model_name: str, device: int | str | torch.device = 0) -> tuple[nn.Module, int]:
    """Prepare a depth decoder using DepthAnythingForDepthEstimation.

    Args:
        model_name (str): name of the depth anything model.
        device (int | str | torch.device, optional): device to put the model on. Defaults to 0.

    Returns:
        tuple[nn.Module, int]: the decoder, and the patch size for depth anything model.
    """
    decoder_head = DepthAnythingForDepthEstimation.from_pretrained(model_name)
    patch_size = decoder_head.config.patch_size
    decoder_head = decoder_head.head
    decoder_head = decoder_head.to(device)
    return decoder_head, patch_size


def decode_depth_anything(features: torch.Tensor, decoder: nn.Module, device: int | str | torch.device = 0) -> NDArray:
    """Decode features to predicted depth using depth anything.

    Args:
        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        decoder (nn.Module): depth anything decoder.
        device (int | str | torch.device, optional): device to perform the decoding. Defaults to 0.

    Returns:
        NDArray: decoded depth in image format, represented by an NDArray in size [batch_size, height, width, channels]
            with value between [0, 1]. The depth values are min-max normalized to [0, 1] to generate images.
    """
    with torch.no_grad():
        P = int(features.size(1) ** 0.5)
        features = rearrange(features, "b (h w) c -> b c h w", h=P, w=P)
        features = interpolate(features, (224, 224))
        predicted_depths = []
        for feature in features:
            feature = feature.unsqueeze(0).to(device)

            predicted_depth = decoder.activation1(feature)
            predicted_depth = decoder.conv3(predicted_depth)
            predicted_depth = decoder.activation2(predicted_depth)
            predicted_depth = predicted_depth.squeeze(dim=1)  # shape (batch_size, height, width)
            for i in range(len(predicted_depth)):
                min_depth, max_depth = predicted_depth[i].min(), predicted_depth[i].max()
                predicted_depth[i] = (predicted_depth[i] - min_depth) / (max_depth - min_depth)
            predicted_depths.append(predicted_depth.detach().cpu())
        predicted_depths = torch.cat(predicted_depths, dim=0)
    return predicted_depths.unsqueeze(-1).repeat((1, 1, 1, 3)).numpy()  # type: ignore [attr-defined]
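A minimal sketch of running the depth head on dummy Depth-Anything-shaped features (64x64 tokens with 32 channels, following MODEL_FEATURE_SIZES elsewhere in this commit). This mirrors how `decode_everything` calls it; the random input and printed shape are assumptions.

# Hypothetical sketch: decode dummy Depth-Anything-sized features into a depth visualization.
import torch
from theia.decoding.depth_anything import decode_depth_anything, prepare_depth_decoder

device = "cuda:0" if torch.cuda.is_available() else "cpu"
decoder, patch_size = prepare_depth_decoder("LiheYoung/depth-anything-large-hf", device=device)

features = torch.randn(2, 64 * 64, 32)  # dummy [batch, num_tokens, latent_dim] features
depth_images = decode_depth_anything(features, decoder, device=device)
print(depth_images.shape)  # expected (2, 224, 224, 3), values in [0, 1]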
theia/decoding/dinov2.py
ADDED
@@ -0,0 +1,69 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from typing import Optional

import cv2
import numpy as np
from numpy.typing import NDArray
from sklearn.decomposition import PCA
from sklearn.preprocessing import minmax_scale


def decode_dinov2(
    features: NDArray, threshold: int | float = -100, interpolation: bool = False, pca: Optional[PCA] = None
) -> tuple[NDArray, PCA]:
    """
    Decode the input `features` in DINOv2 style using PCA.

    Args:
        features (NDArray): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        threshold (int | float): threshold of foreground-background split in PCA visualization.
            Defaults to -100 (all patches are included).
        interpolation (bool): whether to interpolate the 16x16 pca map to the original image size.
        pca (Optional[PCA]): if provided, use the provided PCA. This is to keep visualizations stable across samples.

    Returns:
        tuple[NDArray, PCA]: the rendered image of this visualization, in NDArray in size
            [batch_size, height, width, channels] with value ranges [0, 1], and the PCA used in this visualization.
    """
    features = features.numpy()
    batch_size, spatial_size, latent_dim = features.shape
    h = w = int(spatial_size**0.5)

    features = features.reshape(-1, latent_dim)

    if pca is None:
        pca = PCA(n_components=3)
        pca.fit(features)

    pca_features = pca.transform(features)

    # segment using the first component
    bg_mask = pca_features[:, 0] < threshold
    fg_mask = ~bg_mask

    # PCA for only foreground patches
    # pca.fit(features[fg_mask])
    pca_features_fg = pca.transform(features[fg_mask])
    for i in range(3):
        pca_features_fg[:, i] = minmax_scale(pca_features_fg[:, i])

    pca_features_rgb = pca_features.copy()
    pca_features_rgb[bg_mask] = 0
    pca_features_rgb[fg_mask] = pca_features_fg

    pca_features_rgb = pca_features_rgb.reshape(batch_size, h, w, 3)
    if not interpolation:
        H = W = 224
        scale = H // h
        interpolated_pca_features = np.zeros((batch_size, H, W, 3), dtype=pca_features_rgb.dtype)
        for i in range(len(pca_features_rgb)):
            for j in range(h):
                for k in range(w):
                    interpolated_pca_features[i, scale * j : scale * (j + 1), scale * k : scale * (k + 1)] = (
                        pca_features_rgb[i, j, k]
                    )
        pca_features_rgb = interpolated_pca_features
    else:
        pca_features_rgb = np.stack([cv2.resize(p, (224, 224)) for p in pca_features_rgb])
    return pca_features_rgb, pca
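A small sketch of the PCA visualization on dummy DINOv2-sized features (16x16 tokens, 1024 dims). Note that despite the NDArray annotation, the function calls `.numpy()` on its input, so a torch tensor is passed here, matching how decode.py uses it; the random inputs are assumptions.

# Hypothetical sketch: PCA-decode dummy DINOv2-sized features, then reuse the fitted PCA.
import torch
from theia.decoding.dinov2 import decode_dinov2

features = torch.randn(4, 16 * 16, 1024)  # dummy [batch, num_tokens, latent_dim]
rgb, pca = decode_dinov2(features)        # fits a new 3-component PCA
rgb2, _ = decode_dinov2(torch.randn(4, 16 * 16, 1024), pca=pca)  # reuse the PCA for stable colors
print(rgb.shape)  # expected (4, 224, 224, 3)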
theia/decoding/sam.py
ADDED
@@ -0,0 +1,191 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from typing import Any, Generator, Optional

import numpy as np
import torch
from einops import rearrange
from numpy.typing import NDArray
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers.image_utils import load_image
from transformers.pipelines import MaskGenerationPipeline


class MaskGenerationPipelineWithEmbeddings(MaskGenerationPipeline):
    """
    The wrapper class for huggingface transformers.pipelines.MaskGenerationPipeline
    that can decode from intermediate SAM embeddings.
    """

    def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict[str, Any], ...]:
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        forward_params = {}
        # preprocess args
        if "embeddings" in kwargs:  # inject embeddings here
            preprocess_kwargs["embeddings"] = kwargs["embeddings"]
        if "points_per_batch" in kwargs:
            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
        if "points_per_crop" in kwargs:
            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
        if "crops_n_layers" in kwargs:
            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
        if "crop_overlap_ratio" in kwargs:
            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
        if "crop_n_points_downscale_factor" in kwargs:
            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]
        # postprocess args
        if "pred_iou_thresh" in kwargs:
            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
        if "stability_score_offset" in kwargs:
            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
        if "mask_threshold" in kwargs:
            forward_params["mask_threshold"] = kwargs["mask_threshold"]
        if "stability_score_thresh" in kwargs:
            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
        if "crops_nms_thresh" in kwargs:
            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
        if "output_rle_mask" in kwargs:
            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
        if "output_bboxes_mask" in kwargs:
            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
        return preprocess_kwargs, forward_params, postprocess_kwargs

    def preprocess(
        self,
        image: list[Image.Image],
        points_per_batch: int = 64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: int = 32,
        crop_n_points_downscale_factor: int = 1,
        timeout: Optional[float] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Generator[Any, Any, Any]:
        image = load_image(image, timeout=timeout)
        target_size = self.image_processor.size["longest_edge"]
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")

        with self.device_placement():
            if self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    if embeddings is None:
                        image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                    else:
                        model_inputs.pop("pixel_values")
                        image_embeddings = embeddings
                    model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to return batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            is_last = i == n_points - points_per_batch
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }


def draw_mask(mask: NDArray, random_color: bool = False) -> NDArray:
    """Draw the mask on an image.

    Args:
        mask (NDArray): mask in shape [height, width].
        random_color (bool): whether to use a random color. Defaults to False.

    Returns:
        NDArray: NDArray format of the image.
    """
    if random_color:
        color = np.concatenate([np.random.random(3)], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    return mask_image


def decode_sam(
    features: torch.Tensor,
    images: list[Image.Image],
    mask_generator: Any,
    points_per_batch: int = 64,
    pred_iou_thresh: float = 0.5,
    stability_score_thresh: float = 0.6,
    random_color: bool = True,
    device: int | str | torch.device = 0,
) -> NDArray:
    """Decode features using the SAM (auto-prompting) mask generation pipeline.

    Args:
        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        images (list[Image.Image]): images corresponding to these features.
        mask_generator (Any): mask generation pipeline.
        points_per_batch (int): points per batch for auto-prompting. Defaults to 64.
            See transformers.pipelines.MaskGenerationPipeline for more details. Same below.
        pred_iou_thresh (float): iou threshold. Defaults to 0.5.
        stability_score_thresh (float): stability threshold. Defaults to 0.6.
        random_color (bool): whether to use a random color. Defaults to True.
        device (int | str | torch.device): device to perform the decoding. Defaults to 0.

    Returns:
        NDArray: decoded masks rendered in image format, represented by an NDArray in size
            [batch_size, height, width, channels] with value between [0, 1].
    """
    masks_rgbs = []
    num_patches = int(features.size(1) ** 0.5)
    features = rearrange(features, "b (h w) c -> b c h w", h=num_patches, w=num_patches)
    with torch.no_grad():
        for im, feature in zip(images, features, strict=False):
            predicted_outputs = mask_generator(
                im,
                points_per_batch=points_per_batch,
                embeddings=feature.unsqueeze(0).to(device),
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
            )
            predicted_masks = predicted_outputs["masks"]
            masks_rgb = np.zeros((224, 224, 3), dtype=np.float32)
            for mask in predicted_masks:
                masks_rgb += draw_mask(mask, random_color=random_color)
            # masks_rgb = cv2.cvtColor(masks_rgb, cv2.COLOR_RGBA2RGB)
            masks_rgbs.append(masks_rgb)
    return np.stack(masks_rgbs)


def prepare_mask_generator(device: int | str | torch.device = 0) -> tuple[MaskGenerationPipeline, SamModel]:
    """Prepare a mask generation pipeline and the underlying SAM model on device `device`.

    Args:
        device (int | str | torch.device): device to perform mask generation. Defaults to 0.

    Returns:
        tuple[MaskGenerationPipeline, SamModel]: mask generator and the SAM model.
    """
    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    sam_model.eval()
    mask_generator = MaskGenerationPipelineWithEmbeddings(
        task="mask_generation", model=sam_model, image_processor=processor.image_processor, device=device
    )
    return mask_generator, sam_model
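The key detail above is that the wrapped pipeline accepts an `embeddings=` keyword so SAM's own image encoder can be bypassed. A compact sketch of that call path, mirroring the ground-truth branch in decode.py; the blank placeholder image is an assumption and should be swapped for a real frame to get meaningful masks.

# Hypothetical sketch: generate masks from precomputed SAM image embeddings.
import torch
from einops import rearrange
from PIL import Image
from transformers import SamProcessor
from theia.decoding.sam import decode_sam, prepare_mask_generator

device = "cuda:0" if torch.cuda.is_available() else "cpu"
mask_generator, sam_model = prepare_mask_generator(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

image = Image.new("RGB", (224, 224), color="white")  # placeholder; use a real image in practice
inputs = processor(images=[image], return_tensors="pt").to(device)
with torch.no_grad():
    emb = sam_model.get_image_embeddings(inputs["pixel_values"]).cpu()  # [1, 256, 64, 64]
emb = rearrange(emb, "b c h w -> b (h w) c")

masks_rgb = decode_sam(emb, [image], mask_generator, device=device)  # (1, 224, 224, 3)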
theia/example/decode_to_vfms.ipynb
ADDED
@@ -0,0 +1,69 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import torch\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from transformers import AutoModel\n",
    "from torchvision.io import read_video, write_video\n",
    "from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything\n",
    "\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "theia_model = AutoModel.from_pretrained(\"theaiinstitute/theia-base-patch16-224-cdiv\", trust_remote_code=True)\n",
    "theia_model = theia_model.to(device)\n",
    "target_model_names = [\n",
    "    \"google/vit-huge-patch14-224-in21k\",\n",
    "    \"facebook/dinov2-large\",\n",
    "    \"openai/clip-vit-large-patch14\",\n",
    "    \"facebook/sam-vit-huge\",\n",
    "    \"LiheYoung/depth-anything-large-hf\",\n",
    "]\n",
    "feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=\"../../../feature_stats\")\n",
    "\n",
    "mask_generator, sam_model = prepare_mask_generator(device)\n",
    "depth_anything_model_name = \"LiheYoung/depth-anything-large-hf\"\n",
    "depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)\n",
    "\n",
    "example_video_path = \"../../../media/example_video_to_visualize.mp4\"\n",
    "video, _, _ = read_video(example_video_path, pts_unit=\"sec\", output_format=\"THWC\")\n",
    "video = video.numpy()\n",
    "images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]\n",
    "\n",
    "theia_decode_results, gt_decode_results = decode_everything(\n",
    "    theia_model=theia_model,\n",
    "    feature_means=feature_means,\n",
    "    feature_vars=feature_vars,\n",
    "    images=images,\n",
    "    mask_generator=mask_generator,\n",
    "    sam_model=sam_model,\n",
    "    depth_anything_decoder=depth_anything_decoder,\n",
    "    pred_iou_thresh=0.5,\n",
    "    stability_score_thresh=0.7,\n",
    "    gt=True,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "vis_video = np.stack(\n",
    "    [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]\n",
    ")\n",
    "vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)\n",
    "vis_save_path = \"./visualized.mp4\"\n",
    "write_video(vis_save_path, vis_video, fps=10)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
theia/foundation_models/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .vision_language_models.clip import get_clip_feature, get_clip_model
from .vision_language_models.llava import get_llava_vision_model, get_llava_visual_feature
from .vision_models.deit import get_deit_feature, get_deit_model
from .vision_models.depth_anything import get_depth_anything_feature, get_depth_anything_model
from .vision_models.dinov2 import get_dinov2_feature, get_dinov2_model
from .vision_models.sam import get_sam_feature, get_sam_model
from .vision_models.vit import get_vit_feature, get_vit_model
theia/foundation_models/common.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import math

import torch

MODELS = [
    "facebook/dinov2-large",
    "facebook/sam-vit-huge",
    "google/vit-huge-patch14-224-in21k",
    "llava-hf/llava-1.5-7b-hf",
    "openai/clip-vit-large-patch14",
    "LiheYoung/depth-anything-large-hf",
]

# handy model feature size constants
# in the format of (latent_dim, width, height)
MODEL_FEATURE_SIZES = {
    "facebook/dinov2-large": (1024, 16, 16),
    "facebook/sam-vit-huge": (256, 64, 64),
    "google/vit-huge-patch14-224-in21k": (1280, 16, 16),
    "llava-hf/llava-1.5-7b-hf": (1024, 24, 24),
    "openai/clip-vit-large-patch14": (1024, 16, 16),
    "LiheYoung/depth-anything-large-hf": (32, 64, 64),
}


def get_model_feature_size(
    model_name: str, keep_spatial: bool = False, return_torch_size: bool = False
) -> tuple[int, ...] | torch.Size:
    """
    Get the size of the queried model feature.

    Args:
        model_name (str): name of the model.
        keep_spatial (bool): whether to preserve the spatial dim. Defaults to False.
        return_torch_size (bool): return torch.Size instead of a python tuple. Defaults to False.

    Returns:
        tuple[int, ...] | torch.Size: the size of the feature.
    """
    size: tuple[int, ...] = MODEL_FEATURE_SIZES[model_name]

    if not keep_spatial:
        size = (size[0], math.prod(size[1:]))

    if return_torch_size:
        size = torch.Size(size)

    return size


def get_max_model_spatial_size(
    keep_spatial: bool = True,
    return_torch_size: bool = False,
    return_model_name: bool = False,
) -> tuple[int, ...] | tuple[tuple[int, ...], str]:
    """Get the maximal spatial dimensions among the available models.

    Args:
        keep_spatial (bool): whether to preserve the spatial dim. Defaults to True.
        return_torch_size (bool): return torch.Size instead of a python tuple. Defaults to False.
        return_model_name (bool): whether to also return the name of the model with the maximal size.
            Defaults to False.

    Returns:
        tuple[int, ...] | tuple[tuple[int, ...], str]: the maximal size and, optionally, the model name.
    """
    max_flatten_size = -1
    max_size: tuple[int, ...] = ()
    max_size_model_name: str = ""
    for model, size in MODEL_FEATURE_SIZES.items():
        flatten_size = math.prod(size[1:])
        if flatten_size > max_flatten_size:
            max_flatten_size = flatten_size
            max_size = size[1:]
            max_size_model_name = model

    if not keep_spatial:
        max_size = (max_flatten_size,)

    if return_torch_size:
        max_size = torch.Size(max_size)

    if return_model_name:
        return max_size, max_size_model_name
    else:
        return max_size
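A short sketch of the two helpers above; the printed values follow directly from MODEL_FEATURE_SIZES and are illustrative only.

# Hypothetical usage of the feature-size helpers.
from theia.foundation_models.common import get_max_model_spatial_size, get_model_feature_size

print(get_model_feature_size("facebook/dinov2-large"))                     # (1024, 256): spatial dims flattened
print(get_model_feature_size("facebook/sam-vit-huge", keep_spatial=True))  # (256, 64, 64)
print(get_max_model_spatial_size(return_model_name=True))                  # ((64, 64), "facebook/sam-vit-huge")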