Spaces:
Paused
Paused
Delete app.py
Browse files
app.py
DELETED
|
@@ -1,286 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
author: caishaofei <caishaofei@stu.pku.edu.cn>
|
| 3 |
-
date: 2024-09-20 20:10:44
|
| 4 |
-
Copyright © Team CraftJarvis All rights reserved
|
| 5 |
-
'''
|
| 6 |
-
import re
|
| 7 |
-
import os
|
| 8 |
-
import cv2
|
| 9 |
-
import time
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
import argparse
|
| 12 |
-
import requests
|
| 13 |
-
import gradio as gr
|
| 14 |
-
import torch
|
| 15 |
-
import numpy as np
|
| 16 |
-
from io import BytesIO
|
| 17 |
-
from PIL import Image, ImageDraw
|
| 18 |
-
from rocket.arm.sessions import Session, Pointer
|
| 19 |
-
|
| 20 |
-
# Fixed palette of fifteen 3-component colour tuples.
# NOTE(review): presumably RGB and meant for drawing overlays, but nothing in
# the visible listing reads this constant - confirm before relying on it.
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    (255, 255, 255), (0, 0, 0), (128, 128, 128),
    (128, 0, 0), (128, 128, 0), (0, 128, 0),
    (128, 0, 128), (0, 128, 128), (0, 0, 128),
]
|
| 27 |
-
|
| 28 |
-
# Maps a segment-type name to an integer id.
# NOTE(review): "Use" and "Interact" both map to 3 and id 1 is never used;
# this looks deliberate (an alias) but should be confirmed against the model
# that consumes these ids. Values are left exactly as found.
SEGMENT_MAPPING = {
    "Hunt": 0,
    "Use": 3,
    "Mine": 2,
    "Interact": 3,
    "Craft": 4,
    "Switch": 5,
    "Approach": 6,
}
|
| 31 |
-
|
| 32 |
-
# Template action in which every button is released (0) and the camera delta
# is a zero 2-vector. step_fn() copies it and flips one entry; the key order
# here is also the order of the "Action" dropdown built from these keys.
NOOP_ACTION = {
    "back": 0,
    "drop": 0,
    "forward": 0,
    "hotbar.1": 0,
    "hotbar.2": 0,
    "hotbar.3": 0,
    "hotbar.4": 0,
    "hotbar.5": 0,
    "hotbar.6": 0,
    "hotbar.7": 0,
    "hotbar.8": 0,
    "hotbar.9": 0,
    "inventory": 0,
    "jump": 0,
    "left": 0,
    "right": 0,
    "sneak": 0,
    "sprint": 0,
    "camera": np.array([0, 0]),
    "attack": 0,
    "use": 0,
}
|
| 55 |
-
|
| 56 |
-
def reset_fn(env_name, session):
    """Reset ``session`` to the environment named ``env_name``.

    Returns:
        ``(image, session)`` where ``image`` is whatever ``session.reset``
        returned; the session is handed back so the Gradio state is updated.
    """
    image = session.reset(env_name)
    return image, session
|
| 59 |
-
|
| 60 |
-
def step_fn(act_key, session):
    """Advance the session by one manually chosen action.

    Args:
        act_key: Name of the action entry to set to 1. ``None`` or the string
            ``"null"`` means "do nothing" and sends the unmodified no-op.
        session: Object exposing ``step(action)``.

    Returns:
        ``(image, session)`` where ``image`` is the result of ``session.step``.
    """
    # A plain dict.copy() is shallow, so the module-level camera array would be
    # shared with whatever receives the action; copy array values as well.
    action = {
        key: value.copy() if isinstance(value, np.ndarray) else value
        for key, value in NOOP_ACTION.items()
    }
    # None arrives when no dropdown entry is selected; previously it was
    # inserted into the action as a bogus key.
    if act_key is not None and act_key != "null":
        # NOTE(review): choosing "camera" replaces the 2-vector with the scalar
        # 1, exactly as before - confirm the environment accepts that.
        action[act_key] = 1
    image = session.step(action)
    return image, session
|
| 66 |
-
|
| 67 |
-
def loop_step_fn(steps, session):
    """Let the agent act for ``steps`` iterations, streaming every frame.

    Args:
        steps: Number of iterations. Coerced with ``int()`` because a slider
            may deliver the value as a float, which ``range`` rejects.
        session: Object exposing ``step()`` and ``num_steps``.

    Yields:
        ``(image, session.num_steps, status, session)`` after each step.
    """
    total = int(steps)
    for i in range(total):
        image = session.step()
        status = f"Running Agent `Rocket` steps: {i+1}/{total}. "
        yield image, session.num_steps, status, session
|
| 72 |
-
|
| 73 |
-
def clear_memory_fn(session):
    """Clear the agent's memory and reset the displayed memory counter.

    The current image is read before ``clear_agent_memory()`` is called, as in
    the original statement order.

    Returns:
        ``(image, "0", session)`` - the frame, the counter text, the session.
    """
    image = session.current_image
    session.clear_agent_memory()
    return image, "0", session
|
| 77 |
-
|
| 78 |
-
def get_points_with_draw(image, label, session, evt: gr.SelectData):
    """Record a clicked point on the session and draw it on ``image``.

    Args:
        image: Displayed frame; the marker is drawn onto it in place.
        label: ``'Add Points'`` records a positive point (label 1, green
            marker); any other value records a negative one (label 0, red).
        session: Object whose ``points`` / ``points_label`` lists are appended to.
        evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` give x and y.

    Returns:
        ``(image, session)``.
    """
    x, y = evt.index[0], evt.index[1]
    is_add = label == 'Add Points'
    point_radius = 5
    point_color = (0, 255, 0) if is_add else (255, 0, 0)
    session.points.append([x, y])
    session.points_label.append(1 if is_add else 0)
    # Thickness -1 asks OpenCV for a filled circle.
    cv2.circle(image, (x, y), point_radius, point_color, -1)
    return image, session
|
| 87 |
-
|
| 88 |
-
def clear_points_fn(session):
    """Drop all prompt points via ``session.clear_points()``.

    Returns:
        ``(session.current_image, session)``, read after the points are cleared.
    """
    session.clear_points()
    return session.current_image, session
|
| 91 |
-
|
| 92 |
-
def segment_fn(session):
    """Segment from the session's prompt points and return the masked frame.

    When no points have been added, nothing is segmented and the current image
    is returned unchanged.

    Returns:
        ``(image, session)`` where ``image`` comes from ``session.apply_mask()``
        or, with no points, from ``session.current_image``.
    """
    if not session.points:
        return session.current_image, session
    session.segment()
    image = session.apply_mask()
    return image, session
|
| 98 |
-
|
| 99 |
-
def clear_segment_fn(session):
    """Discard the object mask and switch tracking off.

    Returns:
        ``(session.current_image, False, session)``; the ``False`` is the new
        value for the tracking checkbox wired as the second output.
    """
    session.clear_obj_mask()
    session.tracking_flag = False
    return session.current_image, False, session
|
| 103 |
-
|
| 104 |
-
def set_tracking_mode(tracking_flag, session):
    """Store ``tracking_flag`` on ``session.tracking_flag`` and return the session."""
    session.tracking_flag = tracking_flag
    return session
|
| 107 |
-
|
| 108 |
-
def set_segment_type(segment_type, session):
    """Store ``segment_type`` on ``session.segment_type`` and return the session."""
    session.segment_type = segment_type
    return session
|
| 111 |
-
|
| 112 |
-
def play_fn(session):
    """Run a single agent step (``session.step()`` with no explicit action).

    NOTE(review): no event in the visible listing is wired to this function.

    Returns:
        ``(image, session)``.
    """
    image = session.step()
    return image, session
|
| 115 |
-
|
| 116 |
-
# Read-only counter box, created at module level rather than inside the
# layout: draw_gradio_components() lists it as an event output and only later
# places it on the page with memory_length.render().
memory_length = gr.Textbox(value="0", interactive=False, show_label=False)
|
| 117 |
-
|
| 118 |
-
def make_video_fn(session, make_video, save_video, progress=gr.Progress()):
    """Encode the session's frame history into ``rocket.mp4``.

    Args:
        session: Object exposing ``image_history``, a sequence of frames.
        make_video: Current "Make Video" button value, returned untouched when
            there is nothing to encode.
        save_video: Current download button value, returned likewise.
        progress: Gradio progress tracker used to wrap the frame loop.

    Returns:
        ``(session, make_video_button, download_button)``. On success the
        history is emptied, the make button is hidden and the download button
        is shown pointing at the written file.
    """
    images = session.image_history
    if len(images) == 0:
        return session, make_video, save_video
    # NOTE(review): the output path is fixed, so concurrent users would
    # overwrite each other's file - left unchanged here.
    filepath = "rocket.mp4"
    h, w = images[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h))
    try:
        for image in progress.tqdm(images):
            # Frames are converted RGB -> BGR before being written.
            video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer even if a frame fails to convert or write.
        video.release()
    session.image_history = []
    return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True)
|
| 132 |
-
|
| 133 |
-
def save_video_fn(session, make_video, save_video):
    """Swap the buttons back after a download: show "Make Video", hide "Download!".

    ``make_video`` and ``save_video`` are accepted because the event passes
    them in, but they are not read.

    Returns:
        ``(session, make_video_button, download_button)``.
    """
    return session, gr.Button("Make Video", visible=True), gr.DownloadButton("Download!", visible=False)
|
| 135 |
-
|
| 136 |
-
def choose_sam_fn(sam_choice, session):
    """Record the chosen checkpoint on the session and (re)load it.

    ``session.sam_choice`` is assigned before ``session.load_sam()`` is called.

    Returns:
        The same ``session``.
    """
    session.sam_choice = sam_choice
    session.load_sam()
    return session
|
| 140 |
-
|
| 141 |
-
def molmo_fn(molmo_text, molmo_session, rocket_session, display_image):
    """Ask the pointing model for points and add them as positive prompts.

    Args:
        molmo_text: Prompt passed to ``molmo_session.gen_point``.
        molmo_session: Object exposing ``gen_point(image=..., prompt=...)`` and,
            afterwards, ``molmo_result``.
        rocket_session: Session whose current image is queried (a copy is sent)
            and whose ``points`` / ``points_label`` lists are extended.
        display_image: Displayed frame; a green marker is drawn on it in place
            for every returned point.

    Returns:
        ``(molmo_result, display_image)``.
    """
    image = rocket_session.current_image.copy()
    points = molmo_session.gen_point(image=image, prompt=molmo_text)
    # Read only after gen_point() has run, as in the original statement order.
    molmo_result = molmo_session.molmo_result
    point_radius, point_color = 5, (0, 255, 0)
    for x, y in points:
        x, y = int(x), int(y)
        rocket_session.points.append([x, y])
        rocket_session.points_label.append(1)
        # Thickness -1 asks OpenCV for a filled circle.
        cv2.circle(display_image, (x, y), point_radius, point_color, -1)
    return molmo_result, display_image
|
| 152 |
-
|
| 153 |
-
def extract_points(data, width=640, height=360):
    """Extract point coordinates from ``<point>`` / ``<points>`` markup.

    Coordinates in the text are treated as percentages and scaled to pixels.

    Args:
        data: Text containing attributes such as ``x="12.5" y="40"`` or the
            numbered form ``x1="..." y1="..." x2="..." y2="..."``.
        width: Frame width in pixels that 100% of x maps to (default 640).
        height: Frame height in pixels that 100% of y maps to (default 360).

    Returns:
        A list of ``(x, y)`` float tuples in pixel units, in order of
        appearance; empty when nothing matches.
    """
    # Match the x and y coordinate values; both <points> and <point> tags are
    # supported. The optional sign now applies to integers as well as
    # decimals, and the numeric suffix may have any number of digits, so
    # x10/y10 and beyond are recognised.
    number = r'[-+]?(?:\d*\.\d+|\d+)'
    pattern = r'x\d*="(' + number + r')" y\d*="(' + number + r')"'
    matches = re.findall(pattern, data)
    # Convert the captured strings to floats and scale percent -> pixels.
    return [(float(x) / 100 * width, float(y) / 100 * height) for x, y in matches]
|
| 160 |
-
|
| 161 |
-
def draw_gradio_components(args):
    """Build the Gradio page, wire every event handler and launch the server.

    Args:
        args: Namespace providing ``sam_path`` and ``port``. ``molmo_id`` and
            ``molmo_url`` are used when present; otherwise the previously
            hard-coded model id and URL are used.

    NOTE(review): the source listing had lost all indentation, so the nesting
    of the layout containers below is a reconstruction - confirm it against
    the intended page layout.
    """
    with gr.Blocks() as demo:

        # The text is kept flush-left so the string matches the listing exactly.
        gr.Markdown(
            """
# Welcome to Explore ROCKET-1 in Minecraft!!
## Please follow next steps to interact with the agent:
1. Reset the environment by selecting an environment name.
2. Select a SAM2 checkpoint to load.
3. Use your mouse to add or remove points on the image.
4. Select the segment type you want to perform.
5. Enable `tracking` mode if you want to track objects while stepping actions.
6. Click `New Segment` to segment the image based on the points you added.
7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps.
## Hints:
1. You can use the `Make Video` button to generate a video of the agent's actions.
2. You can use the `Clear Memory` button to clear the ROCKET-1's memory.
3. You can use the `Clear Segment` button to clear SAM's memory.
4. You can use the `Manually Step` button to manually step the agent.
"""
        )

        rocket_session = gr.State(Session(
            sam_path=args.sam_path,
        ))
        # --molmo-id / --molmo-url used to be parsed but ignored in favour of
        # hard-coded values; honour them, keeping the old values as fallback
        # for callers whose namespace lacks these attributes.
        molmo_session = gr.State(Pointer(
            model_id=getattr(args, "molmo_id", "molmo-72b-0924"),
            model_url=getattr(args, "molmo_url", "http://172.17.30.127:8000/v1"),
        ))
        with gr.Row():

            with gr.Column(scale=2):
                # start_image = Image.open("start.png").resize((640, 360))
                start_image = np.zeros((360, 640, 3), dtype=np.uint8)

                with gr.Group():
                    display_image = gr.Image(
                        value=np.array(start_image),
                        interactive=False,
                        show_label=False,
                        label="Real-time Environment Observation",
                        streaming=True
                    )
                    display_status = gr.Textbox("Status Bar", interactive=False, show_label=False)

            with gr.Column(scale=1):

                sam_choice = gr.Radio(
                    choices=["large", "base", "small", "tiny"],
                    value="base",
                    label="Select SAM2 checkpoint",
                )
                sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False)

                with gr.Group():
                    add_or_remove = gr.Radio(
                        choices=["Add Points", "Remove Areas"],
                        value="Add Points",
                        label="Use you mouse to add or remove points",
                    )
                    clear_points_btn = gr.Button("Clear Points")
                    clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)

                with gr.Group():
                    segment_type = gr.Radio(
                        choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"],
                        value="Approach",
                        label="What do you want with this segment?",
                    )
                    track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions")
                    track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False)
                    with gr.Group(), gr.Row():
                        new_segment_btn = gr.Button("New Segment")
                        clear_segment_btn = gr.Button("Clear Segment")
                    new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
                    clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True)

                display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session])
                segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False)

        with gr.Row():
            with gr.Group():
                # Every *.yaml config except those whose file name contains
                # "base"; sorted so the dropdown order is deterministic.
                env_list = sorted(
                    f"rocket/{x.stem}"
                    for x in Path("../env_configs/rocket").glob("*.yaml")
                    if 'base' not in x.name
                )
                env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name")
                reset_btn = gr.Button("Reset Environment")
                reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True)

            with gr.Group():
                action_list = list(NOOP_ACTION)
                act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action")
                step_btn = gr.Button("Manually Step")
                step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False)

            with gr.Group():
                steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False)
                play_btn = gr.Button("Call Rocket")
                # memory_length is the module-level textbox; it is referenced
                # here as an output and rendered into the next group.
                play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False)

            with gr.Group():
                memory_length.render()
                clear_states_btn = gr.Button("Clear Memory")
                clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False)

            make_video_btn = gr.Button("Make Video")
            save_video_btn = gr.DownloadButton("Download!!", visible=False)
            make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
            save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
        with gr.Row():
            with gr.Group():
                molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200)
                molmo_btn = gr.Button("Generate")
                output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200)
            molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image],outputs=[output_text, display_image],show_progress=False)

    demo.queue()
    demo.launch(share=False,server_port=args.port)
|
| 278 |
-
|
| 279 |
-
if __name__ == '__main__':
    # Script entry point: parse the command-line options and hand the
    # resulting namespace to draw_gradio_components(), which builds and
    # launches the app.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--sam-path", type=str, default="/app/ROCKET-1/rocket/realtime_sam/checkpoints")
    parser.add_argument("--molmo-id", type=str, default="molmo-72b-0924")
    parser.add_argument("--molmo-url", type=str, default="http://127.0.0.1:8000/v1")
    args = parser.parse_args()
    draw_gradio_components(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|