# Spaces: Sleeping
"""Run the pretrained Octo-small 1.5 policy on one BridgeV2 image and print actions.

Loads ``rail-berkeley/octo-small-1.5`` from the Hugging Face Hub, downloads a
single example frame from the BridgeV2 dataset, asks the policy to
"pick up the fork", and prints the predicted actions after mapping them out of
the model's normalized action space with the ``bridge_dataset`` statistics.
"""
import os

# JAX reads JAX_PLATFORMS when it is imported, so pin the CPU backend before
# importing jax or anything (such as octo) that imports jax transitively.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# Download one example BridgeV2 image.
IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
response = requests.get(IMAGE_URL, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# convert("RGB") guarantees three channels, so the array is (256, 256, 3).
img = np.array(Image.open(response.raw).convert("RGB").resize((256, 256)))

# Add batch and time-horizon axes of size 1: (H, W, C) -> (1, 1, H, W, C).
img = img[np.newaxis, np.newaxis, ...]
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
task = model.create_tasks(texts=["pick up the fork"])

norm_actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
norm_actions = np.asarray(norm_actions[0])  # remove batch axis

# Undo the per-dimension action normalization: x * std + mean.
# If the statistics carry a boolean "mask", dimensions where it is False were
# never normalized and are passed through unchanged; without a mask every
# dimension is unnormalized (same as a plain x * std + mean).
action_stats = model.dataset_statistics["bridge_dataset"]["action"]
action_mean = np.asarray(action_stats["mean"])
action_std = np.asarray(action_stats["std"])
action_mask = np.asarray(
    action_stats.get("mask", np.ones_like(action_mean, dtype=bool)), dtype=bool
)
actions = np.where(action_mask, norm_actions * action_std + action_mean, norm_actions)

print(actions)