Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files- .gitignore +3 -0
- .gradio/certificate.pem +31 -0
- README.md +17 -8
- app.py +180 -0
- go.py +188 -0
- gui_actor/__init__.py +0 -0
- gui_actor/__pycache__/__init__.cpython-312.pyc +0 -0
- gui_actor/__pycache__/constants.cpython-312.pyc +0 -0
- gui_actor/__pycache__/inference.cpython-312.pyc +0 -0
- gui_actor/__pycache__/modeling_qwen25vl.cpython-312.pyc +0 -0
- gui_actor/__pycache__/trainer.cpython-312.pyc +0 -0
- gui_actor/constants.py +40 -0
- gui_actor/dataset.py +533 -0
- gui_actor/inference.py +300 -0
- gui_actor/modeling.py +361 -0
- gui_actor/modeling_qwen25vl.py +376 -0
- gui_actor/trainer.py +313 -0
- gui_actor/utils.py +90 -0
- requirements.txt +19 -0
- run.py +173 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
server.py
|
| 2 |
+
poster.py
|
| 3 |
+
test.png
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
|
@@ -1,12 +1,21 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji: 📉
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.45.0
|
| 8 |
app_file: app.py
|
| 9 |
-
|
|
|
|
| 10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: gui-actor-demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
app_file: app.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 5.44.1
|
| 6 |
---
|
| 7 |
+
# :rocket: Demo for GUI Actor
|
| 8 |
+
|
| 9 |
+
### 1. Install dependencies
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
### 2. Run the demo
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
python app.py
|
| 19 |
+
```
|
| 20 |
|
| 21 |
+

|
app.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64, os
|
| 2 |
+
# import spaces
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from qwen_vl_utils import process_vision_info
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from transformers import AutoProcessor
|
| 13 |
+
from gui_actor.constants import chat_template
|
| 14 |
+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
|
| 15 |
+
from gui_actor.inference import inference
|
| 16 |
+
|
| 17 |
+
MAX_PIXELS = 3200 * 1800
|
| 18 |
+
|
| 19 |
+
def resize_image(image, resize_to_pixels=MAX_PIXELS):
|
| 20 |
+
image_width, image_height = image.size
|
| 21 |
+
if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
|
| 22 |
+
resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
|
| 23 |
+
image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio)
|
| 24 |
+
image = image.resize((image_width_resized, image_height_resized))
|
| 25 |
+
return image
|
| 26 |
+
|
| 27 |
+
# @spaces.GPU
|
| 28 |
+
@torch.inference_mode()
|
| 29 |
+
def draw_point(image: Image.Image, point: list, radius=8, color=(255, 0, 0, 128)):
|
| 30 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
| 31 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
| 32 |
+
x, y = point
|
| 33 |
+
overlay_draw.ellipse(
|
| 34 |
+
[(x - radius, y - radius), (x + radius, y + radius)],
|
| 35 |
+
outline=color,
|
| 36 |
+
width=5 # Adjust thickness as needed
|
| 37 |
+
)
|
| 38 |
+
image = image.convert('RGBA')
|
| 39 |
+
combined = Image.alpha_composite(image, overlay)
|
| 40 |
+
combined = combined.convert('RGB')
|
| 41 |
+
return combined
|
| 42 |
+
|
| 43 |
+
# @spaces.GPU
|
| 44 |
+
@torch.inference_mode()
|
| 45 |
+
def get_attn_map(image, attn_scores, n_width, n_height):
|
| 46 |
+
w, h = image.size
|
| 47 |
+
scores = np.array(attn_scores[0]).reshape(n_height, n_width)
|
| 48 |
+
|
| 49 |
+
scores_norm = (scores - scores.min()) / (scores.max() - scores.min())
|
| 50 |
+
# Resize score map to match image size
|
| 51 |
+
score_map = Image.fromarray((scores_norm * 255).astype(np.uint8)).resize((w, h), resample=Image.NEAREST) # BILINEAR)
|
| 52 |
+
# Apply colormap
|
| 53 |
+
colormap = plt.get_cmap('jet')
|
| 54 |
+
colored_score_map = colormap(np.array(score_map) / 255.0) # returns RGBA
|
| 55 |
+
colored_score_map = (colored_score_map[:, :, :3] * 255).astype(np.uint8)
|
| 56 |
+
colored_overlay = Image.fromarray(colored_score_map)
|
| 57 |
+
|
| 58 |
+
# Blend with original image
|
| 59 |
+
blended = Image.blend(image, colored_overlay, alpha=0.3)
|
| 60 |
+
return blended
|
| 61 |
+
|
| 62 |
+
# load model
|
| 63 |
+
if torch.cuda.is_available():
|
| 64 |
+
# os.system('pip install flash-attn --no-build-isolation')
|
| 65 |
+
model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2.5-VL"
|
| 66 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 67 |
+
tokenizer = data_processor.tokenizer
|
| 68 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 69 |
+
model_name_or_path,
|
| 70 |
+
torch_dtype=torch.bfloat16,
|
| 71 |
+
device_map="cuda:0",
|
| 72 |
+
attn_implementation="flash_attention_2"
|
| 73 |
+
).eval()
|
| 74 |
+
else:
|
| 75 |
+
model_name_or_path = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
|
| 76 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 77 |
+
tokenizer = data_processor.tokenizer
|
| 78 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 79 |
+
model_name_or_path,
|
| 80 |
+
torch_dtype=torch.bfloat16,
|
| 81 |
+
device_map="cpu"
|
| 82 |
+
).eval()
|
| 83 |
+
|
| 84 |
+
title = "GUI-Actor"
|
| 85 |
+
header = """
|
| 86 |
+
<div align="center">
|
| 87 |
+
<h1 style="padding-bottom: 10px; padding-top: 10px;">🎯 <strong>GUI-Actor</strong>: Coordinate-Free Visual Grounding for GUI Agents</h1>
|
| 88 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 89 |
+
Qianhui Wu*, Kanzhi Cheng*, Rui Yang*, Chaoyun Zhang, Jianwei Yang, Huiqiang Jiang, Jian Mu, Baolin Peng, Bo Qiao, Reuben Tan, Si Qin, Lars Liden<br>
|
| 90 |
+
Qingwei Lin, Huan Zhang, Tong Zhang, Jianbing Zhang, Dongmei Zhang, Jianfeng Gao<br/>
|
| 91 |
+
</div>
|
| 92 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 93 |
+
<a href="https://microsoft.github.io/GUI-Actor/">🌐 Project Page</a> | <a href="https://arxiv.org/abs/2403.12968">📄 arXiv Paper</a> | <a href="https://github.com/microsoft/GUI-Actor">💻 Github Repo</a><br/>
|
| 94 |
+
</div>
|
| 95 |
+
</div>
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
theme = "soft"
|
| 99 |
+
css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
|
| 100 |
+
#anno-img .mask.active {opacity: 0.7}"""
|
| 101 |
+
|
| 102 |
+
# @spaces.GPU
|
| 103 |
+
@torch.inference_mode()
|
| 104 |
+
def process(image, instruction):
|
| 105 |
+
# resize image
|
| 106 |
+
w, h = image.size
|
| 107 |
+
if w * h > MAX_PIXELS:
|
| 108 |
+
image = resize_image(image)
|
| 109 |
+
|
| 110 |
+
conversation = [
|
| 111 |
+
{
|
| 112 |
+
"role": "system",
|
| 113 |
+
"content": [
|
| 114 |
+
{
|
| 115 |
+
"type": "text",
|
| 116 |
+
"text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>).",
|
| 117 |
+
}
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"role": "user",
|
| 122 |
+
"content": [
|
| 123 |
+
{
|
| 124 |
+
"type": "image",
|
| 125 |
+
"image": image, # PIL.Image.Image or str to path
|
| 126 |
+
# "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"type": "text",
|
| 130 |
+
"text": instruction,
|
| 131 |
+
},
|
| 132 |
+
],
|
| 133 |
+
},
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(e)
|
| 140 |
+
return image, f"Error: {e}", None
|
| 141 |
+
|
| 142 |
+
px, py = pred["topk_points"][0]
|
| 143 |
+
output_coord = f"({px:.4f}, {py:.4f})"
|
| 144 |
+
img_with_point = draw_point(image, (px * w, py * h))
|
| 145 |
+
|
| 146 |
+
n_width, n_height = pred["n_width"], pred["n_height"]
|
| 147 |
+
attn_scores = pred["attn_scores"]
|
| 148 |
+
att_map = get_attn_map(image, attn_scores, n_width, n_height)
|
| 149 |
+
|
| 150 |
+
return img_with_point, output_coord, att_map
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
with gr.Blocks(title=title, css=css) as demo:
|
| 154 |
+
gr.Markdown(header)
|
| 155 |
+
with gr.Row():
|
| 156 |
+
with gr.Column():
|
| 157 |
+
input_image = gr.Image(
|
| 158 |
+
type='pil', label='Upload image')
|
| 159 |
+
# text box
|
| 160 |
+
input_instruction = gr.Textbox(label='Instruction', placeholder='Text your (low-level) instruction here')
|
| 161 |
+
submit_button = gr.Button(
|
| 162 |
+
value='Submit', variant='primary')
|
| 163 |
+
with gr.Column():
|
| 164 |
+
image_with_point = gr.Image(type='pil', label='Image with Point (red circle)')
|
| 165 |
+
with gr.Accordion('Detailed prediction'):
|
| 166 |
+
pred_xy = gr.Textbox(label='Predicted Coordinates', placeholder='(x, y)')
|
| 167 |
+
att_map = gr.Image(type='pil', label='Attention Map')
|
| 168 |
+
|
| 169 |
+
submit_button.click(
|
| 170 |
+
fn=process,
|
| 171 |
+
inputs=[
|
| 172 |
+
input_image,
|
| 173 |
+
input_instruction
|
| 174 |
+
],
|
| 175 |
+
outputs=[image_with_point, pred_xy, att_map]
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# demo.launch(debug=False, show_error=True, share=True)
|
| 179 |
+
demo.launch(share=True, server_port=5566, server_name='0.0.0.0')
|
| 180 |
+
# demo.queue().launch(share=False)
|
go.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64, os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import argparse # 新增:导入argparse
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from qwen_vl_utils import process_vision_info
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from transformers import AutoProcessor
|
| 13 |
+
from gui_actor.constants import chat_template
|
| 14 |
+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
|
| 15 |
+
from gui_actor.inference import inference
|
| 16 |
+
|
| 17 |
+
MAX_PIXELS = 3200 * 1800
|
| 18 |
+
|
| 19 |
+
def resize_image(image, resize_to_pixels=MAX_PIXELS):
|
| 20 |
+
image_width, image_height = image.size
|
| 21 |
+
if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
|
| 22 |
+
resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
|
| 23 |
+
image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio)
|
| 24 |
+
image = image.resize((image_width_resized, image_height_resized))
|
| 25 |
+
return image
|
| 26 |
+
|
| 27 |
+
@torch.inference_mode()
|
| 28 |
+
def draw_point(image: Image.Image, point: list, radius=8, color=(255, 0, 0, 128)):
|
| 29 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
| 30 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
| 31 |
+
x, y = point
|
| 32 |
+
overlay_draw.ellipse(
|
| 33 |
+
[(x - radius, y - radius), (x + radius, y + radius)],
|
| 34 |
+
outline=color,
|
| 35 |
+
width=5
|
| 36 |
+
)
|
| 37 |
+
image = image.convert('RGBA')
|
| 38 |
+
combined = Image.alpha_composite(image, overlay)
|
| 39 |
+
combined = combined.convert('RGB')
|
| 40 |
+
return combined
|
| 41 |
+
|
| 42 |
+
@torch.inference_mode()
|
| 43 |
+
def get_attn_map(image, attn_scores, n_width, n_height):
|
| 44 |
+
w, h = image.size
|
| 45 |
+
scores = np.array(attn_scores[0]).reshape(n_height, n_width)
|
| 46 |
+
|
| 47 |
+
scores_norm = (scores - scores.min()) / (scores.max() - scores.min())
|
| 48 |
+
score_map = Image.fromarray((scores_norm * 255).astype(np.uint8)).resize((w, h), resample=Image.NEAREST)
|
| 49 |
+
colormap = plt.get_cmap('jet')
|
| 50 |
+
colored_score_map = colormap(np.array(score_map) / 255.0)
|
| 51 |
+
colored_score_map = (colored_score_map[:, :, :3] * 255).astype(np.uint8)
|
| 52 |
+
colored_overlay = Image.fromarray(colored_score_map)
|
| 53 |
+
|
| 54 |
+
blended = Image.blend(image, colored_overlay, alpha=0.3)
|
| 55 |
+
return blended
|
| 56 |
+
|
| 57 |
+
# 加载模型
|
| 58 |
+
if torch.cuda.is_available():
|
| 59 |
+
model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2.5-VL"
|
| 60 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 61 |
+
tokenizer = data_processor.tokenizer
|
| 62 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 63 |
+
model_name_or_path,
|
| 64 |
+
torch_dtype=torch.bfloat16,
|
| 65 |
+
device_map="cuda:0",
|
| 66 |
+
attn_implementation="flash_attention_2"
|
| 67 |
+
).eval()
|
| 68 |
+
else:
|
| 69 |
+
model_name_or_path = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
|
| 70 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 71 |
+
tokenizer = data_processor.tokenizer
|
| 72 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 73 |
+
model_name_or_path,
|
| 74 |
+
torch_dtype=torch.bfloat16,
|
| 75 |
+
device_map="cpu"
|
| 76 |
+
).eval()
|
| 77 |
+
|
| 78 |
+
title = "GUI-Actor"
|
| 79 |
+
header = """
|
| 80 |
+
<div align="center">
|
| 81 |
+
<h1 style="padding-bottom: 10px; padding-top: 10px;">🎯 <strong>GUI-Actor</strong>: Coordinate-Free Visual Grounding for GUI Agents</h1>
|
| 82 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 83 |
+
Qianhui Wu*, Kanzhi Cheng*, Rui Yang*, Chaoyun Zhang, Jianwei Yang, Huiqiang Jiang, Jian Mu, Baolin Peng, Bo Qiao, Reuben Tan, Si Qin, Lars Liden<br>
|
| 84 |
+
Qingwei Lin, Huan Zhang, Tong Zhang, Jianbing Zhang, Dongmei Zhang, Jianfeng Gao<br/>
|
| 85 |
+
</div>
|
| 86 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 87 |
+
<a href="https://microsoft.github.io/GUI-Actor/">🌐 Project Page</a> | <a href="https://arxiv.org/abs/2403.12968">📄 arXiv Paper</a> | <a href="https://github.com/microsoft/GUI-Actor">💻 Github Repo</a><br/>
|
| 88 |
+
</div>
|
| 89 |
+
</div>
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
theme = "soft"
|
| 93 |
+
css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
|
| 94 |
+
#anno-img .mask.active {opacity: 0.7}"""
|
| 95 |
+
|
| 96 |
+
@torch.inference_mode()
|
| 97 |
+
def process(image, instruction):
|
| 98 |
+
# 调整图像大小
|
| 99 |
+
w, h = image.size
|
| 100 |
+
if w * h > MAX_PIXELS:
|
| 101 |
+
image = resize_image(image)
|
| 102 |
+
|
| 103 |
+
conversation = [
|
| 104 |
+
{
|
| 105 |
+
"role": "system",
|
| 106 |
+
"content": [
|
| 107 |
+
{
|
| 108 |
+
"type": "text",
|
| 109 |
+
"text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>).",
|
| 110 |
+
}
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"role": "user",
|
| 115 |
+
"content": [
|
| 116 |
+
{
|
| 117 |
+
"type": "image",
|
| 118 |
+
"image": image,
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"type": "text",
|
| 122 |
+
"text": instruction,
|
| 123 |
+
},
|
| 124 |
+
],
|
| 125 |
+
},
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(e)
|
| 132 |
+
return image, f"Error: {e}", None
|
| 133 |
+
|
| 134 |
+
px, py = pred["topk_points"][0]
|
| 135 |
+
output_coord = f"({px:.4f}, {py:.4f})"
|
| 136 |
+
img_with_point = draw_point(image, (px * w, py * h))
|
| 137 |
+
|
| 138 |
+
n_width, n_height = pred["n_width"], pred["n_height"]
|
| 139 |
+
attn_scores = pred["attn_scores"]
|
| 140 |
+
att_map = get_attn_map(image, attn_scores, n_width, n_height)
|
| 141 |
+
|
| 142 |
+
return img_with_point, output_coord, att_map
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main(): # 新增:主函数,使用argparse解析参数
|
| 146 |
+
parser = argparse.ArgumentParser(description="GUI-Actor 服务")
|
| 147 |
+
parser.add_argument("--port", type=int, default=9876, help="服务端口(默认:9876)")
|
| 148 |
+
parser.add_argument("--host", default="localhost", help="服务主机(默认:localhost)")
|
| 149 |
+
|
| 150 |
+
args = parser.parse_args()
|
| 151 |
+
|
| 152 |
+
# 创建Gradio界面
|
| 153 |
+
with gr.Blocks(title=title, css=css) as demo:
|
| 154 |
+
gr.Markdown(header)
|
| 155 |
+
with gr.Row():
|
| 156 |
+
with gr.Column():
|
| 157 |
+
input_image = gr.Image(
|
| 158 |
+
type='pil', label='Upload image')
|
| 159 |
+
input_instruction = gr.Textbox(label='Instruction', placeholder='Text your (low-level) instruction here')
|
| 160 |
+
submit_button = gr.Button(
|
| 161 |
+
value='Submit', variant='primary')
|
| 162 |
+
with gr.Column():
|
| 163 |
+
image_with_point = gr.Image(type='pil', label='Image with Point (red circle)')
|
| 164 |
+
with gr.Accordion('Detailed prediction'):
|
| 165 |
+
pred_xy = gr.Textbox(label='Predicted Coordinates', placeholder='(x, y)')
|
| 166 |
+
att_map = gr.Image(type='pil', label='Attention Map')
|
| 167 |
+
|
| 168 |
+
submit_button.click(
|
| 169 |
+
fn=process,
|
| 170 |
+
inputs=[
|
| 171 |
+
input_image,
|
| 172 |
+
input_instruction
|
| 173 |
+
],
|
| 174 |
+
outputs=[image_with_point, pred_xy, att_map]
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# 启动服务(使用解析的参数)
|
| 178 |
+
print(f"🚀 GUI-Actor 服务启动中...")
|
| 179 |
+
print(f"🌐 访问地址: http://{args.host}:{args.port}")
|
| 180 |
+
|
| 181 |
+
demo.queue().launch(
|
| 182 |
+
server_name=args.host,
|
| 183 |
+
server_port=args.port,
|
| 184 |
+
share=True
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__": # 新增:程序入口
|
| 188 |
+
main()
|
gui_actor/__init__.py
ADDED
|
File without changes
|
gui_actor/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
gui_actor/__pycache__/constants.cpython-312.pyc
ADDED
|
Binary file (2.61 kB). View file
|
|
|
gui_actor/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
gui_actor/__pycache__/modeling_qwen25vl.cpython-312.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
gui_actor/__pycache__/trainer.cpython-312.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
gui_actor/constants.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 4 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 5 |
+
|
| 6 |
+
LOGDIR = "."
|
| 7 |
+
|
| 8 |
+
# Model Constants
|
| 9 |
+
IGNORE_INDEX = -100
|
| 10 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 11 |
+
DEFAULT_POINTER_START_TOKEN = "<|pointer_start|>"
|
| 12 |
+
DEFAULT_POINTER_END_TOKEN = "<|pointer_end|>"
|
| 13 |
+
DEFAULT_POINTER_PAD_TOKEN = "<|pointer_pad|>"
|
| 14 |
+
|
| 15 |
+
# UNMASK_TOKEN_IDS = [198, 151644, 151645]
|
| 16 |
+
|
| 17 |
+
# System Message
|
| 18 |
+
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
|
| 19 |
+
|
| 20 |
+
# Chat Template
|
| 21 |
+
chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
| 22 |
+
|
| 23 |
+
assistant_template = "{% for message in messages %}{{'<|im_start|>' + message['role']}}{% if 'recipient' in message %}<|recipient|>{{ message['recipient'] }}{% endif %}{{'\n' + message['content'][0]['text']}}{% if 'end_turn' in message and message['end_turn'] %}{{'<|diff_marker|>\n'}}{% else %}{{'<|im_end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|recipient|>' }}{% endif %}"
|
| 24 |
+
|
| 25 |
+
# Special Tokens
|
| 26 |
+
ADDITIONAL_SPECIAL_TOKENS = [
|
| 27 |
+
"<|recipient|>",
|
| 28 |
+
"<|diff_marker|>",
|
| 29 |
+
DEFAULT_POINTER_START_TOKEN,
|
| 30 |
+
DEFAULT_POINTER_END_TOKEN,
|
| 31 |
+
DEFAULT_POINTER_PAD_TOKEN,
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Action Patterns to be replaced with special tokens
|
| 35 |
+
ACTION_PATTENS_XY = [
|
| 36 |
+
r"x=([0-9.]+), y=([0-9.]+)",
|
| 37 |
+
r"from_coord=\[([0-9.]+), ([0-9.]+)\], to_coord=\[([0-9.]+), ([0-9.]+)\]",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
until = ["<|diff_marker|>"]
|
gui_actor/dataset.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import ast
|
| 8 |
+
from typing import Dict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import transformers
|
| 12 |
+
import yaml
|
| 13 |
+
from qwen_vl_utils import smart_resize, process_vision_info
|
| 14 |
+
from torch.utils.data import Dataset
|
| 15 |
+
|
| 16 |
+
from gui_actor.constants import (
|
| 17 |
+
IGNORE_INDEX,
|
| 18 |
+
DEFAULT_IMAGE_TOKEN,
|
| 19 |
+
DEFAULT_POINTER_START_TOKEN,
|
| 20 |
+
DEFAULT_POINTER_PAD_TOKEN,
|
| 21 |
+
DEFAULT_POINTER_END_TOKEN,
|
| 22 |
+
ACTION_PATTENS_XY,
|
| 23 |
+
ADDITIONAL_SPECIAL_TOKENS,
|
| 24 |
+
assistant_template,
|
| 25 |
+
chat_template,
|
| 26 |
+
grounding_system_message,
|
| 27 |
+
)
|
| 28 |
+
from gui_actor.trainer import rank0_print
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def reformat_coordinates(text):
|
| 32 |
+
"""
|
| 33 |
+
(1) Find all the coordinates in the text.
|
| 34 |
+
(2) Replace the coordinates with the special tokens.
|
| 35 |
+
(3) Return the new text and the coordinates as a list of (x, y), where x in [0, 1] and y in [0, 1].
|
| 36 |
+
"""
|
| 37 |
+
epsilon = 0.001
|
| 38 |
+
def adjust_coord(c):
|
| 39 |
+
"""
|
| 40 |
+
Adjust coordinate if it is too close to 0 or 1.
|
| 41 |
+
"""
|
| 42 |
+
if abs(c) < epsilon:
|
| 43 |
+
return epsilon
|
| 44 |
+
elif abs(c - 1) < epsilon:
|
| 45 |
+
return 1 - epsilon
|
| 46 |
+
return c
|
| 47 |
+
|
| 48 |
+
all_matches = []
|
| 49 |
+
for pattern in ACTION_PATTENS_XY:
|
| 50 |
+
matches = list(re.finditer(pattern, text))
|
| 51 |
+
for match in matches:
|
| 52 |
+
all_matches.append((match.start(), match.groups()))
|
| 53 |
+
if pattern == ACTION_PATTENS_XY[0]:
|
| 54 |
+
target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
|
| 55 |
+
else:
|
| 56 |
+
target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, {DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
|
| 57 |
+
text = re.sub(
|
| 58 |
+
pattern,
|
| 59 |
+
target_text,
|
| 60 |
+
text
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
coordinates = []
|
| 64 |
+
all_matches.sort(key=lambda x: x[0])
|
| 65 |
+
# Extract coordinates in order
|
| 66 |
+
for _, groups in all_matches:
|
| 67 |
+
# When two coordinate values are found, parse them as one (x, y) pair.
|
| 68 |
+
if len(groups) == 2:
|
| 69 |
+
x_str, y_str = groups
|
| 70 |
+
x = adjust_coord(ast.literal_eval(x_str))
|
| 71 |
+
y = adjust_coord(ast.literal_eval(y_str))
|
| 72 |
+
coordinates.append((x, y))
|
| 73 |
+
# When four coordinate values are found, parse them as two pairs.
|
| 74 |
+
elif len(groups) == 4:
|
| 75 |
+
x1_str, y1_str, x2_str, y2_str = groups
|
| 76 |
+
x1 = adjust_coord(ast.literal_eval(x1_str))
|
| 77 |
+
y1 = adjust_coord(ast.literal_eval(y1_str))
|
| 78 |
+
x2 = adjust_coord(ast.literal_eval(x2_str))
|
| 79 |
+
y2 = adjust_coord(ast.literal_eval(y2_str))
|
| 80 |
+
coordinates.append((x1, y1))
|
| 81 |
+
coordinates.append((x2, y2))
|
| 82 |
+
|
| 83 |
+
return text, coordinates
|
| 84 |
+
|
| 85 |
+
def get_token_index(image_processor, image, point_x, point_y):
|
| 86 |
+
"""
|
| 87 |
+
Get the index of the visual token that contains the point (x, y).
|
| 88 |
+
Args:
|
| 89 |
+
image_processor: the image processor
|
| 90 |
+
image: the image in PIL format
|
| 91 |
+
point_x: the x coordinate of the point, in [0, 1].
|
| 92 |
+
point_y: the y coordinate of the point, in [0, 1].
|
| 93 |
+
"""
|
| 94 |
+
if len(image) != 1:
|
| 95 |
+
raise ValueError(f"Expected 1 image, got {len(image)}")
|
| 96 |
+
|
| 97 |
+
# get the original image size and the resized image size
|
| 98 |
+
image = image[0]
|
| 99 |
+
w, h = image.size
|
| 100 |
+
px, py = w * point_x, h * point_y
|
| 101 |
+
# rank0_print(f"px: {px}, py: {py}")
|
| 102 |
+
# get the token index
|
| 103 |
+
merge_patch_size = image_processor.patch_size * image_processor.merge_size
|
| 104 |
+
x_index = math.floor(px / merge_patch_size)
|
| 105 |
+
y_index = math.floor(py / merge_patch_size)
|
| 106 |
+
|
| 107 |
+
visual_token_index = y_index * (w // merge_patch_size) + x_index
|
| 108 |
+
|
| 109 |
+
# merge all above print into one line
|
| 110 |
+
return visual_token_index
|
| 111 |
+
|
| 112 |
+
def get_multi_patch_labels(image_processor, image, bbox_gt):
|
| 113 |
+
"""
|
| 114 |
+
Get the multi-patch labels for the bounding box.
|
| 115 |
+
Args:
|
| 116 |
+
image_processor: the image processor
|
| 117 |
+
image: the image in PIL format
|
| 118 |
+
bbox_gt: the bounding box in the format of (x_min, y_min, x_max, y_max) [0,1]
|
| 119 |
+
"""
|
| 120 |
+
if len(image) != 1:
|
| 121 |
+
raise ValueError(f"Expected 1 image, got {len(image)}")
|
| 122 |
+
|
| 123 |
+
# Get the original image size and the resized image size
|
| 124 |
+
image = image[0]
|
| 125 |
+
w, h = image.size
|
| 126 |
+
|
| 127 |
+
bbox_gt = [bbox_gt[0]*w, bbox_gt[1]*h, bbox_gt[2]*w, bbox_gt[3]*h]
|
| 128 |
+
# Extract bounding box coordinates
|
| 129 |
+
x_min, y_min, x_max, y_max = bbox_gt
|
| 130 |
+
x_min = max(0, x_min)
|
| 131 |
+
y_min = max(0, y_min)
|
| 132 |
+
x_max = min(w, x_max)
|
| 133 |
+
y_max = min(h, y_max)
|
| 134 |
+
|
| 135 |
+
merge_patch_size = image_processor.patch_size * image_processor.merge_size
|
| 136 |
+
assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
|
| 137 |
+
grid_h, grid_w = h // merge_patch_size, w // merge_patch_size
|
| 138 |
+
|
| 139 |
+
binary_mask = torch.zeros(grid_h * grid_w)
|
| 140 |
+
# Iterate through all patches, check if they overlap with the bounding box
|
| 141 |
+
for y_idx in range(grid_h):
|
| 142 |
+
for x_idx in range(grid_w):
|
| 143 |
+
# Calculate patch boundaries
|
| 144 |
+
patch_x_min = x_idx * merge_patch_size
|
| 145 |
+
patch_y_min = y_idx * merge_patch_size
|
| 146 |
+
patch_x_max = patch_x_min + merge_patch_size
|
| 147 |
+
patch_y_max = patch_y_min + merge_patch_size
|
| 148 |
+
|
| 149 |
+
# Check if patch overlaps with the bounding box
|
| 150 |
+
if not (patch_x_max <= x_min or patch_x_min >= x_max or
|
| 151 |
+
patch_y_max <= y_min or patch_y_min >= y_max):
|
| 152 |
+
# Calculate patch index in the flattened grid
|
| 153 |
+
patch_idx = y_idx * grid_w + x_idx
|
| 154 |
+
binary_mask[patch_idx] = 1
|
| 155 |
+
|
| 156 |
+
return binary_mask
|
| 157 |
+
|
| 158 |
+
def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
|
| 159 |
+
merge_patch_size = image_processor.patch_size * image_processor.merge_size
|
| 160 |
+
x_index = visual_token_index % (image_width // merge_patch_size)
|
| 161 |
+
y_index = visual_token_index // (image_width // merge_patch_size)
|
| 162 |
+
px = x_index * merge_patch_size + merge_patch_size / 2
|
| 163 |
+
py = y_index * merge_patch_size + merge_patch_size / 2
|
| 164 |
+
return px, py
|
| 165 |
+
|
| 166 |
+
class LazySupervisedDataset(Dataset):
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 170 |
+
processor: transformers.ProcessorMixin,
|
| 171 |
+
data_path: str,
|
| 172 |
+
data_args,
|
| 173 |
+
):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.tokenizer = tokenizer
|
| 176 |
+
self.processor = processor
|
| 177 |
+
self.list_data_dict = []
|
| 178 |
+
self.list_image_path = []
|
| 179 |
+
self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
|
| 180 |
+
self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
|
| 181 |
+
self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
|
| 182 |
+
|
| 183 |
+
# Handle multiple JSON files specified in the data_path
|
| 184 |
+
if "{" in data_path and "}" in data_path:
|
| 185 |
+
base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
|
| 186 |
+
file_names = file_pattern.split(",")
|
| 187 |
+
rank0_print(f"Loading {file_names} from {base_path}")
|
| 188 |
+
data_args.dataset_paths = []
|
| 189 |
+
for file_name in file_names:
|
| 190 |
+
data_args.dataset_paths.append(f"{base_path}{file_name}.json")
|
| 191 |
+
full_path = f"{base_path}{file_name}.json"
|
| 192 |
+
rank0_print(f"Loading {full_path}")
|
| 193 |
+
with open(full_path) as file:
|
| 194 |
+
cur_data_dict = json.load(file)
|
| 195 |
+
rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
|
| 196 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 197 |
+
elif data_path.endswith(".yaml"):
|
| 198 |
+
with open(data_path) as file:
|
| 199 |
+
yaml_data = yaml.safe_load(file)
|
| 200 |
+
datasets = yaml_data.get("datasets")
|
| 201 |
+
# file should be in the format of:
|
| 202 |
+
# datasets:
|
| 203 |
+
# - json_path: xxxx1.json
|
| 204 |
+
# sampling_strategy: first:1000
|
| 205 |
+
# - json_path: xxxx2.json
|
| 206 |
+
# sampling_strategy: end:3000
|
| 207 |
+
# - json_path: xxxx3.json
|
| 208 |
+
# sampling_strategy: random:999
|
| 209 |
+
data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
|
| 210 |
+
for dataset in datasets:
|
| 211 |
+
json_path = dataset.get("json_path")
|
| 212 |
+
sampling_strategy = dataset.get("sampling_strategy", "all")
|
| 213 |
+
images_folder = dataset.get("images_folder")
|
| 214 |
+
sampling_number = None
|
| 215 |
+
|
| 216 |
+
rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
|
| 217 |
+
|
| 218 |
+
if json_path.endswith(".jsonl"):
|
| 219 |
+
cur_data_dict = []
|
| 220 |
+
with open(json_path) as json_file:
|
| 221 |
+
for line in json_file:
|
| 222 |
+
cur_data_dict.append(json.loads(line.strip()))
|
| 223 |
+
elif json_path.endswith(".json"):
|
| 224 |
+
# NOTE: we only use json_path with .json now
|
| 225 |
+
# Handle the images_folder in yaml
|
| 226 |
+
with open(json_path) as json_file:
|
| 227 |
+
cur_data_dict = json.load(json_file)
|
| 228 |
+
else:
|
| 229 |
+
raise ValueError(f"Unsupported file type: {json_path}")
|
| 230 |
+
|
| 231 |
+
if ":" in sampling_strategy:
|
| 232 |
+
sampling_strategy, sampling_number = sampling_strategy.split(":")
|
| 233 |
+
if "%" in sampling_number:
|
| 234 |
+
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
|
| 235 |
+
else:
|
| 236 |
+
sampling_number = int(sampling_number)
|
| 237 |
+
|
| 238 |
+
# Apply the sampling strategy
|
| 239 |
+
if sampling_strategy == "first" and sampling_number is not None:
|
| 240 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 241 |
+
elif sampling_strategy == "end" and sampling_number is not None:
|
| 242 |
+
cur_data_dict = cur_data_dict[-sampling_number:]
|
| 243 |
+
elif sampling_strategy == "random" and sampling_number is not None:
|
| 244 |
+
random.shuffle(cur_data_dict)
|
| 245 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 246 |
+
|
| 247 |
+
rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
|
| 248 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 249 |
+
self.list_image_path.extend([images_folder] * len(cur_data_dict))
|
| 250 |
+
else:
|
| 251 |
+
data_args.dataset_paths = [data_path]
|
| 252 |
+
rank0_print(f"Loading {data_path}")
|
| 253 |
+
with open(data_path) as file:
|
| 254 |
+
cur_data_dict = json.load(file)
|
| 255 |
+
rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
|
| 256 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 257 |
+
self.list_image_path.extend([""] * len(cur_data_dict)) # NOTE: the image subfolder is empty...
|
| 258 |
+
|
| 259 |
+
rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
|
| 260 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
| 261 |
+
self.tokenizer = tokenizer
|
| 262 |
+
self.data_args = data_args
|
| 263 |
+
|
| 264 |
+
def __len__(self):
|
| 265 |
+
return len(self.list_data_dict)
|
| 266 |
+
|
| 267 |
+
@property
|
| 268 |
+
def lengths(self):
|
| 269 |
+
length_list = []
|
| 270 |
+
for sample in self.list_data_dict:
|
| 271 |
+
img_tokens = (
|
| 272 |
+
1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
|
| 273 |
+
)
|
| 274 |
+
length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
|
| 275 |
+
return length_list
|
| 276 |
+
|
| 277 |
+
@property
|
| 278 |
+
def modality_lengths(self):
|
| 279 |
+
length_list = []
|
| 280 |
+
for sample in self.list_data_dict:
|
| 281 |
+
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
|
| 282 |
+
assert cur_len > 0, f"Conversation length is 0 for {sample}"
|
| 283 |
+
|
| 284 |
+
img_tokens = (
|
| 285 |
+
1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if "image" in sample or "video" in sample or self.data_args.early_mix_text:
|
| 289 |
+
length_list.append(cur_len + img_tokens)
|
| 290 |
+
else:
|
| 291 |
+
length_list.append(-cur_len)
|
| 292 |
+
return length_list
|
| 293 |
+
|
| 294 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 295 |
+
sample = self._get_item(i)
|
| 296 |
+
if sample is None:
|
| 297 |
+
new_index = random.randint(0, len(self.list_data_dict) - 1)
|
| 298 |
+
return self.__getitem__(new_index)
|
| 299 |
+
else:
|
| 300 |
+
return sample
|
| 301 |
+
try:
|
| 302 |
+
sample = self._get_item(i)
|
| 303 |
+
if sample is None:
|
| 304 |
+
new_index = random.randint(0, len(self.list_data_dict) - 1)
|
| 305 |
+
return self.__getitem__(new_index)
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"Failed to fetch sample {i}. Exception:", e)
|
| 308 |
+
new_index = random.randint(0, len(self.list_data_dict) - 1)
|
| 309 |
+
return self.__getitem__(new_index)
|
| 310 |
+
return sample
|
| 311 |
+
|
| 312 |
+
def _get_item(self, i) -> Dict[str, torch.Tensor]:
|
| 313 |
+
sources = self.list_data_dict[i]
|
| 314 |
+
image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])
|
| 315 |
+
|
| 316 |
+
if "image" in sources:
|
| 317 |
+
image_file = self.list_data_dict[i]["image"]
|
| 318 |
+
if type(image_file) is list:
|
| 319 |
+
image_list = [os.path.join(image_path, image_file) for image_file in image_file]
|
| 320 |
+
else:
|
| 321 |
+
image_list = [os.path.join(image_path, image_file)]
|
| 322 |
+
|
| 323 |
+
sources = copy.deepcopy(sources["conversations"])
|
| 324 |
+
elif "video" in sources:
|
| 325 |
+
raise NotImplementedError("Video is not supported for Qwen2VL")
|
| 326 |
+
else:
|
| 327 |
+
sources = copy.deepcopy(sources["conversations"])
|
| 328 |
+
|
| 329 |
+
item_id = self.list_data_dict[i].get("id", i)
|
| 330 |
+
|
| 331 |
+
data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
|
| 332 |
+
if isinstance(i, int):
|
| 333 |
+
data_dict = {
|
| 334 |
+
"input_ids": data_dict["input_ids"][0],
|
| 335 |
+
"labels": data_dict["labels"][0],
|
| 336 |
+
"coordinates": data_dict["coordinates"][0],
|
| 337 |
+
"visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
|
| 338 |
+
"pixel_values": data_dict["pixel_values"],
|
| 339 |
+
"image_grid_thw": data_dict["image_grid_thw"],
|
| 340 |
+
"multi_patch_labels": data_dict["multi_patch_labels"][0], # add multi_patch_labels
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
data_dict["id"] = item_id
|
| 344 |
+
|
| 345 |
+
# return None if the input_ids is longer than the model_max_length
|
| 346 |
+
n_image_tokens = (
|
| 347 |
+
data_dict["image_grid_thw"][0][0] *
|
| 348 |
+
data_dict["image_grid_thw"][0][1] *
|
| 349 |
+
data_dict["image_grid_thw"][0][2] /
|
| 350 |
+
self.processor.image_processor.merge_size /
|
| 351 |
+
self.processor.image_processor.merge_size
|
| 352 |
+
)
|
| 353 |
+
if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
|
| 354 |
+
rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
|
| 355 |
+
return None
|
| 356 |
+
|
| 357 |
+
return data_dict
|
| 358 |
+
|
| 359 |
+
def preprocess_qwen2vl(
|
| 360 |
+
self,
|
| 361 |
+
source, # conversations
|
| 362 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 363 |
+
processor: transformers.ProcessorMixin,
|
| 364 |
+
image: list,
|
| 365 |
+
system_message: str = grounding_system_message,
|
| 366 |
+
agent_mode: bool = True,
|
| 367 |
+
chat_template: str = chat_template,
|
| 368 |
+
assistant_template: str = assistant_template,
|
| 369 |
+
id: int = None,
|
| 370 |
+
) -> Dict:
|
| 371 |
+
roles = {"human": "user", "gpt": "assistant", "system": "system"}
|
| 372 |
+
assistant_template = assistant_template if agent_mode else chat_template
|
| 373 |
+
processor.tokenizer = tokenizer
|
| 374 |
+
assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS
|
| 375 |
+
|
| 376 |
+
# Apply prompt templates
|
| 377 |
+
pixel_values, image_grid_thw = None, None
|
| 378 |
+
|
| 379 |
+
input_id, target = [], []
|
| 380 |
+
coordinates = []
|
| 381 |
+
visual_token_indices_of_coordinates = []
|
| 382 |
+
multi_patch_labels = []
|
| 383 |
+
|
| 384 |
+
image_list = []
|
| 385 |
+
image_index = 0
|
| 386 |
+
|
| 387 |
+
## prepare the system message
|
| 388 |
+
if roles[source[0]["from"]] == "system":
|
| 389 |
+
system_message = source[0]["value"]
|
| 390 |
+
source = source[1:self.data_args.max_conv_turns]
|
| 391 |
+
# else: use the constant system message
|
| 392 |
+
system_input_id = tokenizer.apply_chat_template(
|
| 393 |
+
conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
|
| 394 |
+
chat_template=chat_template,
|
| 395 |
+
)
|
| 396 |
+
input_id += system_input_id
|
| 397 |
+
target += [IGNORE_INDEX] * len(system_input_id)
|
| 398 |
+
|
| 399 |
+
## prepare user-assistant conversation
|
| 400 |
+
for conv in source:
|
| 401 |
+
# regularize the conversation format
|
| 402 |
+
try:
|
| 403 |
+
role = conv["role"]
|
| 404 |
+
content = conv["content"]
|
| 405 |
+
except Exception:
|
| 406 |
+
role = conv["from"]
|
| 407 |
+
content = conv["value"]
|
| 408 |
+
role = roles.get(role, role)
|
| 409 |
+
|
| 410 |
+
# Count the number of <image> tokens in the content
|
| 411 |
+
image_count = content.count(DEFAULT_IMAGE_TOKEN)
|
| 412 |
+
if image_count > 0:
|
| 413 |
+
assert role == "user", "Images are only supported for user messages"
|
| 414 |
+
# include image information regarding to current conversation turn
|
| 415 |
+
image_placeholders = []
|
| 416 |
+
for _ in range(image_count):
|
| 417 |
+
image_placeholders.append({
|
| 418 |
+
"type": "image",
|
| 419 |
+
"image": image[image_index],
|
| 420 |
+
"min_pixels": self.processor.image_processor.min_pixels,
|
| 421 |
+
"max_pixels": self.processor.image_processor.max_pixels,
|
| 422 |
+
})
|
| 423 |
+
image_index += 1
|
| 424 |
+
|
| 425 |
+
content = content.replace(DEFAULT_IMAGE_TOKEN, "")
|
| 426 |
+
conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}
|
| 427 |
+
|
| 428 |
+
image_inputs, _ = process_vision_info([conv]) # list of PIL.Image.Image
|
| 429 |
+
image_list.extend(image_inputs)
|
| 430 |
+
|
| 431 |
+
templated_conv = tokenizer.apply_chat_template(
|
| 432 |
+
conversation=[conv], chat_template=chat_template, tokenize=False
|
| 433 |
+
)
|
| 434 |
+
inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")
|
| 435 |
+
|
| 436 |
+
if pixel_values is None and image_grid_thw is None:
|
| 437 |
+
pixel_values = inputs["pixel_values"]
|
| 438 |
+
image_grid_thw = inputs["image_grid_thw"]
|
| 439 |
+
else:
|
| 440 |
+
pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
|
| 441 |
+
image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
|
| 442 |
+
else:
|
| 443 |
+
if role in ["user", "system"]:
|
| 444 |
+
conv = {"role": role, "content": [{"type": "text", "text": content}]}
|
| 445 |
+
else: # assistant
|
| 446 |
+
conv = {
|
| 447 |
+
"role": role,
|
| 448 |
+
"content": [{"type": "text", "text": content}],
|
| 449 |
+
"recipient": conv.get("recipient", "os"),
|
| 450 |
+
"end_turn": conv.get("end_turn", True),
|
| 451 |
+
"bbox_gt": conv.get("bbox_gt", None),
|
| 452 |
+
}
|
| 453 |
+
if conv["recipient"] == "os":
|
| 454 |
+
if len(image_inputs) == 0:
|
| 455 |
+
raise ValueError("No image found for visual grounding")
|
| 456 |
+
# replace the coordinates with the special tokens
|
| 457 |
+
text, coord = reformat_coordinates(conv["content"][0]["text"])
|
| 458 |
+
conv["content"][0]["text"] = text
|
| 459 |
+
# rank0_print(f"coord: {coord}")
|
| 460 |
+
|
| 461 |
+
# get the visual token indices of the coordinates
|
| 462 |
+
coordinates.extend(coord)
|
| 463 |
+
for (point_x, point_y) in coord:
|
| 464 |
+
visual_token_index = get_token_index(
|
| 465 |
+
processor.image_processor,
|
| 466 |
+
image_list,
|
| 467 |
+
point_x,
|
| 468 |
+
point_y
|
| 469 |
+
)
|
| 470 |
+
# px, py = token_index_to_coordinates(
|
| 471 |
+
# processor.image_processor,
|
| 472 |
+
# visual_token_index,
|
| 473 |
+
# image_list[0].size[0], # make sure the size here is after qwen2vl processing
|
| 474 |
+
# image_list[0].size[1]
|
| 475 |
+
# )
|
| 476 |
+
# rank0_print(f"estimated px: {px}, py: {py}")
|
| 477 |
+
visual_token_indices_of_coordinates.append(visual_token_index)
|
| 478 |
+
|
| 479 |
+
if conv["bbox_gt"] is not None:
|
| 480 |
+
patch_mask = get_multi_patch_labels(
|
| 481 |
+
processor.image_processor,
|
| 482 |
+
image_list,
|
| 483 |
+
conv["bbox_gt"]
|
| 484 |
+
)
|
| 485 |
+
multi_patch_labels.append(patch_mask)
|
| 486 |
+
|
| 487 |
+
templated_conv = tokenizer.apply_chat_template(
|
| 488 |
+
conversation=[conv],
|
| 489 |
+
chat_template=assistant_template,
|
| 490 |
+
tokenize=False,
|
| 491 |
+
)
|
| 492 |
+
inputs = processor(text=[templated_conv], return_tensors="pt")
|
| 493 |
+
|
| 494 |
+
encode_id = inputs.input_ids[0].tolist()
|
| 495 |
+
|
| 496 |
+
input_id += encode_id
|
| 497 |
+
if role in ["user", "system"]:
|
| 498 |
+
target += [IGNORE_INDEX] * len(encode_id)
|
| 499 |
+
else:
|
| 500 |
+
target += encode_id
|
| 501 |
+
|
| 502 |
+
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
|
| 503 |
+
|
| 504 |
+
# make the labels of all pointer_end_token_id to be IGNORE_INDEX
|
| 505 |
+
target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]
|
| 506 |
+
|
| 507 |
+
input_ids = torch.tensor([input_id], dtype=torch.long)
|
| 508 |
+
targets = torch.tensor([target], dtype=torch.long)
|
| 509 |
+
visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
|
| 510 |
+
coordinates = [coordinates] if len(coordinates) > 0 else [None]
|
| 511 |
+
|
| 512 |
+
# process multi_patch_labels
|
| 513 |
+
if len(multi_patch_labels) > 0:
|
| 514 |
+
multi_patch_labels = [torch.stack(multi_patch_labels)]
|
| 515 |
+
else:
|
| 516 |
+
multi_patch_labels = [None]
|
| 517 |
+
|
| 518 |
+
data_dict = {
|
| 519 |
+
"input_ids": input_ids, # tensor(bs x seq_len)
|
| 520 |
+
"labels": targets, # tensor(bs x seq_len)
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
if pixel_values is not None:
|
| 524 |
+
data_dict["pixel_values"] = pixel_values
|
| 525 |
+
data_dict["image_grid_thw"] = image_grid_thw
|
| 526 |
+
|
| 527 |
+
# if len(coordinates[0]) != len(visual_token_indices_of_coordinates[0]):
|
| 528 |
+
# raise ValueError(f"The number of coordinates ({len(coordinates[0])}) does not match the number of image token indices ({len(visual_token_indices_of_coordinates[0])})")
|
| 529 |
+
data_dict["coordinates"] = coordinates
|
| 530 |
+
data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
|
| 531 |
+
data_dict["multi_patch_labels"] = multi_patch_labels
|
| 532 |
+
|
| 533 |
+
return data_dict
|
gui_actor/inference.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
from qwen_vl_utils import process_vision_info
|
| 6 |
+
from transformers import (
|
| 7 |
+
Qwen2VLForConditionalGeneration,
|
| 8 |
+
LogitsProcessor,
|
| 9 |
+
LogitsProcessorList,
|
| 10 |
+
AutoModelForCausalLM,
|
| 11 |
+
AutoTokenizer
|
| 12 |
+
)
|
| 13 |
+
from gui_actor.constants import (
|
| 14 |
+
DEFAULT_POINTER_END_TOKEN,
|
| 15 |
+
DEFAULT_POINTER_PAD_TOKEN,
|
| 16 |
+
chat_template
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
class ForceFollowTokensLogitsProcessor(LogitsProcessor):
|
| 20 |
+
"""
|
| 21 |
+
Forces tokens B (pointer_pad_token) and C (pointer_end_token) to follow token A (pointer_start_token).
|
| 22 |
+
Whenever token_a_id is generated, enqueue the forced_sequence (e.g. [B, C]).
|
| 23 |
+
As long as forced tokens remain in the queue, force them in the output.
|
| 24 |
+
"""
|
| 25 |
+
def __init__(self, token_a_id, forced_sequence=[DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN]):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.token_a_id = token_a_id
|
| 28 |
+
self.forced_sequence = forced_sequence # list of token IDs, e.g. [B_id, C_id]
|
| 29 |
+
self.force_queue = [] # holds the tokens we still need to force
|
| 30 |
+
|
| 31 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 32 |
+
"""
|
| 33 |
+
Called at each decoding step to modify `scores`.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
input_ids: shape (batch_size, seq_len). The already-decoded tokens.
|
| 37 |
+
scores: shape (batch_size, vocab_size). Model logits for the next token.
|
| 38 |
+
"""
|
| 39 |
+
batch_size = input_ids.shape[0]
|
| 40 |
+
if batch_size > 1:
|
| 41 |
+
raise NotImplementedError("Batch size must be 1 for this logits processor.")
|
| 42 |
+
|
| 43 |
+
# We assume batch_size=1 for simplicity; if you have multiple sequences,
|
| 44 |
+
# you'll need to adapt the logic to handle each item in the batch.
|
| 45 |
+
last_token_id = input_ids[0, -1].item()
|
| 46 |
+
|
| 47 |
+
# If the last token was A, enqueue B and C
|
| 48 |
+
if last_token_id == self.token_a_id:
|
| 49 |
+
self.force_queue.extend(self.forced_sequence)
|
| 50 |
+
|
| 51 |
+
# If we have forced tokens waiting in the queue, override the distribution
|
| 52 |
+
if len(self.force_queue) > 0:
|
| 53 |
+
forced_token = self.force_queue.pop(0) # next token to force
|
| 54 |
+
# Create a mask of -inf for all tokens except the forced one
|
| 55 |
+
new_scores = torch.full_like(scores, float('-inf'))
|
| 56 |
+
new_scores[0, forced_token] = 0.0 # log prob = 0 => prob = 1
|
| 57 |
+
return new_scores
|
| 58 |
+
|
| 59 |
+
# Otherwise, return scores unmodified
|
| 60 |
+
return scores
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30, activation_threshold=0.3, return_all_regions=True, rect_center=False):
|
| 64 |
+
"""
|
| 65 |
+
1. Select activated patches
|
| 66 |
+
2. Divide connected patches into different regions
|
| 67 |
+
3. Calculate the average activation value for each region
|
| 68 |
+
4. Select the region with the highest average activation value
|
| 69 |
+
5. Return the center point of that region as the final prediction point
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
# Get patches with activation values greater than a certain proportion of the maximum activation value as activated patches
|
| 73 |
+
# Get the highest activation value and threshold
|
| 74 |
+
max_score = attn_scores[0].max().item()
|
| 75 |
+
threshold = max_score * activation_threshold
|
| 76 |
+
# Select all patches above the threshold
|
| 77 |
+
mask = attn_scores[0] > threshold
|
| 78 |
+
valid_indices = torch.nonzero(mask).squeeze(-1)
|
| 79 |
+
topk_values = attn_scores[0][valid_indices]
|
| 80 |
+
topk_indices = valid_indices
|
| 81 |
+
|
| 82 |
+
# Convert indices to 2D coordinates
|
| 83 |
+
topk_coords = []
|
| 84 |
+
for idx in topk_indices.tolist():
|
| 85 |
+
y = idx // n_width
|
| 86 |
+
x = idx % n_width
|
| 87 |
+
topk_coords.append((y, x, idx))
|
| 88 |
+
|
| 89 |
+
# Divide into connected regions
|
| 90 |
+
regions = []
|
| 91 |
+
visited = set()
|
| 92 |
+
for i, (y, x, idx) in enumerate(topk_coords):
|
| 93 |
+
if idx in visited:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
# Start a new region
|
| 97 |
+
region = [(y, x, idx, topk_values[i].item())]
|
| 98 |
+
visited.add(idx)
|
| 99 |
+
queue = [(y, x, idx, topk_values[i].item())]
|
| 100 |
+
|
| 101 |
+
# BFS to find connected points
|
| 102 |
+
while queue:
|
| 103 |
+
cy, cx, c_idx, c_val = queue.pop(0)
|
| 104 |
+
|
| 105 |
+
# Check 4 adjacent directions
|
| 106 |
+
for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
|
| 107 |
+
ny, nx = cy + dy, cx + dx
|
| 108 |
+
n_idx = ny * n_width + nx
|
| 109 |
+
|
| 110 |
+
# Check if this adjacent point is in the topk list
|
| 111 |
+
for j, (ty, tx, t_idx) in enumerate(topk_coords):
|
| 112 |
+
if ty == ny and tx == nx and t_idx not in visited:
|
| 113 |
+
visited.add(t_idx)
|
| 114 |
+
region.append((ny, nx, t_idx, topk_values[j].item()))
|
| 115 |
+
queue.append((ny, nx, t_idx, topk_values[j].item()))
|
| 116 |
+
|
| 117 |
+
regions.append(region)
|
| 118 |
+
|
| 119 |
+
# Calculate the average activation value for each region
|
| 120 |
+
region_scores = []
|
| 121 |
+
region_centers = []
|
| 122 |
+
region_points = []
|
| 123 |
+
|
| 124 |
+
for region in regions:
|
| 125 |
+
# Calculate average score for the region
|
| 126 |
+
avg_score = sum(item[3] for item in region) / len(region)
|
| 127 |
+
region_scores.append(avg_score)
|
| 128 |
+
|
| 129 |
+
# Calculate normalized center coordinates for each patch, then take the average
|
| 130 |
+
normalized_centers = []
|
| 131 |
+
weights = []
|
| 132 |
+
y_coords = set()
|
| 133 |
+
x_coords = set()
|
| 134 |
+
|
| 135 |
+
for y, x, _, score in region:
|
| 136 |
+
# Normalized coordinates of the center point for each patch
|
| 137 |
+
center_y = (y + 0.5) / n_height
|
| 138 |
+
center_x = (x + 0.5) / n_width
|
| 139 |
+
normalized_centers.append((center_x, center_y))
|
| 140 |
+
weights.append(score)
|
| 141 |
+
|
| 142 |
+
y_coords.add(center_y)
|
| 143 |
+
x_coords.add(center_x)
|
| 144 |
+
|
| 145 |
+
region_points.append(normalized_centers)
|
| 146 |
+
|
| 147 |
+
# Calculate the average of normalized coordinates as the region center
|
| 148 |
+
if not rect_center:
|
| 149 |
+
# Weighted average
|
| 150 |
+
total_weight = sum(weights)
|
| 151 |
+
weighted_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight
|
| 152 |
+
weighted_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight
|
| 153 |
+
avg_center_x, avg_center_y = weighted_x, weighted_y
|
| 154 |
+
# # Simple average
|
| 155 |
+
# avg_center_x = sum(nc[0] for nc in normalized_centers) / len(normalized_centers)
|
| 156 |
+
# avg_center_y = sum(nc[1] for nc in normalized_centers) / len(normalized_centers)
|
| 157 |
+
else:
|
| 158 |
+
avg_center_x = sum(x_coords) / len(x_coords)
|
| 159 |
+
avg_center_y = sum(y_coords) / len(y_coords)
|
| 160 |
+
region_centers.append((avg_center_x, avg_center_y))
|
| 161 |
+
|
| 162 |
+
# Select the region with the highest average activation value
|
| 163 |
+
sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
|
| 164 |
+
sorted_scores = [region_scores[i] for i in sorted_indices]
|
| 165 |
+
sorted_centers = [region_centers[i] for i in sorted_indices]
|
| 166 |
+
sorted_points = [region_points[i] for i in sorted_indices]
|
| 167 |
+
best_point = sorted_centers[0]
|
| 168 |
+
|
| 169 |
+
if return_all_regions:
|
| 170 |
+
# Outputs:
|
| 171 |
+
# 1. best_point: the center point of the region with the highest average activation value
|
| 172 |
+
# 2. sorted_centers: the center points of all regions, sorted by the average activation value in descending order
|
| 173 |
+
# 3. sorted_scores: the average activation values of all regions, sorted in descending order
|
| 174 |
+
# 4. sorted_points: the normalized center coordinates of all patches, sorted by the average activation value in descending order
|
| 175 |
+
return best_point, sorted_centers, sorted_scores, sorted_points
|
| 176 |
+
else:
|
| 177 |
+
return best_point
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5):
|
| 181 |
+
"""
|
| 182 |
+
conversation = [
|
| 183 |
+
{
|
| 184 |
+
"role": "system",
|
| 185 |
+
"content": [
|
| 186 |
+
{
|
| 187 |
+
"type": "text",
|
| 188 |
+
"text": grounding_system_message,
|
| 189 |
+
}
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"role": "user",
|
| 194 |
+
"content": [
|
| 195 |
+
{
|
| 196 |
+
"type": "image",
|
| 197 |
+
"image": example["image"], # PIL.Image.Image or str to path
|
| 198 |
+
# "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"type": "text",
|
| 202 |
+
"text": example["instruction"]
|
| 203 |
+
},
|
| 204 |
+
],
|
| 205 |
+
},
|
| 206 |
+
]
|
| 207 |
+
"""
|
| 208 |
+
if logits_processor is None:
|
| 209 |
+
logits_processor = ForceFollowTokensLogitsProcessor(
|
| 210 |
+
token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
|
| 211 |
+
forced_sequence=[
|
| 212 |
+
tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
|
| 213 |
+
]
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
assiatant_starter = "" if not use_placeholder else "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)"
|
| 217 |
+
|
| 218 |
+
pred = {
|
| 219 |
+
"output_text": None, # generated text
|
| 220 |
+
"n_width": None, # number of patch_tokens in width dimension
|
| 221 |
+
"n_height": None, # number of patch_tokens in height dimension
|
| 222 |
+
"attn_scores": None, # attention scores over the image patches
|
| 223 |
+
"topk_points": None, # topk points
|
| 224 |
+
"topk_values": None, # topk values
|
| 225 |
+
"topk_points_all": None, # all points
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
# prepare text
|
| 229 |
+
text = data_processor.apply_chat_template(conversation,
|
| 230 |
+
tokenize=False,
|
| 231 |
+
add_generation_prompt=False,
|
| 232 |
+
chat_template=chat_template
|
| 233 |
+
)
|
| 234 |
+
text += assiatant_starter
|
| 235 |
+
|
| 236 |
+
# prepare inputs
|
| 237 |
+
image_inputs, video_inputs = process_vision_info(conversation)
|
| 238 |
+
inputs = data_processor(text=[text],
|
| 239 |
+
images=image_inputs,
|
| 240 |
+
videos=video_inputs,
|
| 241 |
+
padding=True,
|
| 242 |
+
return_tensors="pt"
|
| 243 |
+
)
|
| 244 |
+
inputs = inputs.to(model.device)
|
| 245 |
+
|
| 246 |
+
# generate
|
| 247 |
+
results = model.generate(**inputs,
|
| 248 |
+
max_new_tokens=2048 if not use_placeholder else 1,
|
| 249 |
+
logits_processor=LogitsProcessorList([logits_processor]),
|
| 250 |
+
return_dict_in_generate=True,
|
| 251 |
+
output_hidden_states=True
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# decode the generated ids
|
| 256 |
+
input_ids = inputs["input_ids"][0]
|
| 257 |
+
generated_ids = results.sequences[0][len(input_ids):]
|
| 258 |
+
output_text = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 259 |
+
pred["output_text"] = output_text
|
| 260 |
+
|
| 261 |
+
# check if there are <POINTER_TOKEN> is inside the input_ids or generated_ids
|
| 262 |
+
if use_placeholder:
|
| 263 |
+
pointer_pad_mask = (inputs["input_ids"][0] == model.config.pointer_pad_token_id) # n_all_input_tokens
|
| 264 |
+
else:
|
| 265 |
+
pointer_pad_mask = (generated_ids[:-1] == model.config.pointer_pad_token_id) # seq_len_generated_ids-1
|
| 266 |
+
|
| 267 |
+
# if there are no <POINTER_TOKEN> in the input_ids or generated_ids, return the pred
|
| 268 |
+
if len(pointer_pad_mask) == 0:
|
| 269 |
+
return pred
|
| 270 |
+
|
| 271 |
+
# otherwise, get the coordinate from the action head
|
| 272 |
+
if use_placeholder:
|
| 273 |
+
decoder_hidden_states = results.hidden_states[0][-1][0] # n_all_input_tokens, hidden_size
|
| 274 |
+
else:
|
| 275 |
+
decoder_hidden_states = [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]]
|
| 276 |
+
decoder_hidden_states = torch.cat(decoder_hidden_states, dim=0) # seq_len_generated_ids-1, hidden_size
|
| 277 |
+
decoder_hidden_states = decoder_hidden_states[pointer_pad_mask] # n_pointer_pad_tokens, hidden_size
|
| 278 |
+
|
| 279 |
+
# get the image embeddings as encoder vectors
|
| 280 |
+
# image_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"]) # n_image_tokens, hidden_size
|
| 281 |
+
image_mask = (inputs["input_ids"][0] == tokenizer.encode("<|image_pad|>")[0])
|
| 282 |
+
image_embeds = results.hidden_states[0][0][0][image_mask] # n_image_tokens, hidden_size
|
| 283 |
+
|
| 284 |
+
attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states)
|
| 285 |
+
pred["attn_scores"] = attn_scores.tolist()
|
| 286 |
+
|
| 287 |
+
_, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist()
|
| 288 |
+
pred["n_width"] = n_width
|
| 289 |
+
pred["n_height"] = n_height
|
| 290 |
+
|
| 291 |
+
# get the topk points according to the attention scores
|
| 292 |
+
best_point, region_points, region_scores, region_points_all = get_prediction_region_point(attn_scores, n_width, n_height, return_all_regions=True, rect_center=False)
|
| 293 |
+
topk_points = region_points[:topk] if len(region_points) > topk else region_points
|
| 294 |
+
topk_values = region_scores[:topk] if len(region_scores) > topk else region_scores
|
| 295 |
+
topk_points_all = region_points_all[:topk] if len(region_points_all) > topk else region_points_all
|
| 296 |
+
pred["topk_points"] = topk_points
|
| 297 |
+
pred["topk_values"] = topk_values
|
| 298 |
+
pred["topk_points_all"] = topk_points_all
|
| 299 |
+
|
| 300 |
+
return pred
|
gui_actor/modeling.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
|
| 6 |
+
from gui_actor.constants import IGNORE_INDEX
|
| 7 |
+
from typing import List, Tuple, Union, Optional
|
| 8 |
+
from gui_actor.trainer import rank0_print
|
| 9 |
+
|
| 10 |
+
class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
|
| 11 |
+
"""
|
| 12 |
+
Output class for Qwen2VL with pointer head, extending the base output class.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 16 |
+
Language modeling loss.
|
| 17 |
+
pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 18 |
+
Vision pointer network loss.
|
| 19 |
+
pointer_scores (`List[torch.FloatTensor]`, *optional*):
|
| 20 |
+
Attention scores from the pointer network, one tensor per batch item.
|
| 21 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 22 |
+
Combined loss (weighted sum of lm_loss and pointer_loss).
|
| 23 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 24 |
+
Prediction scores from the language modeling head.
|
| 25 |
+
past_key_values, hidden_states, attentions, rope_deltas:
|
| 26 |
+
Same as parent class.
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
|
| 29 |
+
super().__init__(*args, **kwargs)
|
| 30 |
+
self.lm_loss = lm_loss
|
| 31 |
+
self.pointer_loss = pointer_loss
|
| 32 |
+
self.pointer_scores = pointer_scores
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class VisionHead_MultiPatch(nn.Module):
|
| 36 |
+
def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.d_model = d_model
|
| 39 |
+
|
| 40 |
+
# Note: We omit additional normalization here because Qwen2VL
|
| 41 |
+
# already normalizes hidden states using RMSNorm.
|
| 42 |
+
self.projection_enc = nn.Sequential(
|
| 43 |
+
nn.Linear(d_model, projection_dim),
|
| 44 |
+
nn.GELU(),
|
| 45 |
+
nn.Linear(projection_dim, d_model)
|
| 46 |
+
)
|
| 47 |
+
self.projection_dec = nn.Sequential(
|
| 48 |
+
nn.Linear(d_model, projection_dim),
|
| 49 |
+
nn.GELU(),
|
| 50 |
+
nn.Linear(projection_dim, d_model)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Add self-attention layer for visual features
|
| 54 |
+
self.self_attention = nn.MultiheadAttention(
|
| 55 |
+
embed_dim=d_model,
|
| 56 |
+
num_heads=num_attention_heads,
|
| 57 |
+
dropout=dropout_rate,
|
| 58 |
+
batch_first=True
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Layer normalization and residual connection
|
| 62 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 63 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 64 |
+
|
| 65 |
+
def forward(self,
|
| 66 |
+
hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size
|
| 67 |
+
hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample
|
| 68 |
+
labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox
|
| 69 |
+
do_single_patch: bool = False,
|
| 70 |
+
):
|
| 71 |
+
|
| 72 |
+
enc_input = hidden_state_enc.unsqueeze(0)
|
| 73 |
+
attn_output, _ = self.self_attention(
|
| 74 |
+
query=enc_input,
|
| 75 |
+
key=enc_input,
|
| 76 |
+
value=enc_input,
|
| 77 |
+
# attn_mask=attention_mask,
|
| 78 |
+
need_weights=False
|
| 79 |
+
)
|
| 80 |
+
# Residual connection and layer normalization
|
| 81 |
+
hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
|
| 82 |
+
# Remove batch dimension
|
| 83 |
+
hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model]
|
| 84 |
+
|
| 85 |
+
# Apply the projection networks.
|
| 86 |
+
proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
|
| 87 |
+
proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
|
| 88 |
+
|
| 89 |
+
# Compute scaled dot-product attention scores.
|
| 90 |
+
# Scaling by sqrt(d_model) is critical regardless of variable n_enc.
|
| 91 |
+
scaling = self.d_model ** 0.5
|
| 92 |
+
patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
|
| 93 |
+
|
| 94 |
+
# Softmax normalization is applied along the encoder dimension.
|
| 95 |
+
attn_weights = F.softmax(patch_logits, dim=-1)
|
| 96 |
+
|
| 97 |
+
loss = None
|
| 98 |
+
if (labels is not None) and (not do_single_patch):
|
| 99 |
+
epsilon = 1e-8
|
| 100 |
+
labels_float = labels.float()
|
| 101 |
+
# Normalize each row to get target probability distribution
|
| 102 |
+
target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
|
| 103 |
+
|
| 104 |
+
# Apply log_softmax to logits
|
| 105 |
+
pred_log_probs = F.log_softmax(patch_logits, dim=-1)
|
| 106 |
+
# Use KL divergence as loss
|
| 107 |
+
loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
|
| 108 |
+
|
| 109 |
+
if do_single_patch and (labels is not None):
|
| 110 |
+
loss = F.cross_entropy(attn_scores, labels)
|
| 111 |
+
|
| 112 |
+
return attn_weights, loss
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
|
| 116 |
+
def __init__(self, *args, **kwargs):
|
| 117 |
+
super().__init__(*args, **kwargs)
|
| 118 |
+
self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
|
| 119 |
+
self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
|
| 120 |
+
self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
|
| 121 |
+
self.post_init()
|
| 122 |
+
|
| 123 |
+
def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
|
| 124 |
+
self.pointer_loss_weight = pointer_loss_weight
|
| 125 |
+
self.lm_loss_weight = lm_loss_weight
|
| 126 |
+
|
| 127 |
+
def forward(self,
|
| 128 |
+
input_ids: torch.LongTensor = None, # (batch_size, seq_len)
|
| 129 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 130 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 131 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 132 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 133 |
+
labels: Optional[torch.LongTensor] = None,
|
| 134 |
+
use_cache: Optional[bool] = None,
|
| 135 |
+
output_attentions: Optional[bool] = None,
|
| 136 |
+
output_hidden_states: Optional[bool] = None,
|
| 137 |
+
return_dict: Optional[bool] = None,
|
| 138 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 139 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 140 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 141 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 142 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
| 143 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 144 |
+
# Grounding
|
| 145 |
+
visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token
|
| 146 |
+
multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
|
| 147 |
+
if_multi_patch: bool = True,
|
| 148 |
+
coordinates: Optional[List[Tuple[float, float]]] = None,
|
| 149 |
+
verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
|
| 150 |
+
|
| 151 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 152 |
+
output_hidden_states = (
|
| 153 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 154 |
+
)
|
| 155 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 156 |
+
|
| 157 |
+
if verbose:
|
| 158 |
+
rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
|
| 159 |
+
rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
|
| 160 |
+
rank0_print(f"pixel_values: {pixel_values.shape}")
|
| 161 |
+
rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
|
| 162 |
+
rank0_print(f"coordinates: {coordinates}")
|
| 163 |
+
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
|
| 164 |
+
rank0_print(f"return_dict: {return_dict}")
|
| 165 |
+
|
| 166 |
+
if inputs_embeds is None:
|
| 167 |
+
inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
|
| 168 |
+
if pixel_values is not None:
|
| 169 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
| 170 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
| 171 |
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
| 172 |
+
n_image_features = image_embeds.shape[0]
|
| 173 |
+
if n_image_tokens != n_image_features:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 176 |
+
)
|
| 177 |
+
image_mask = (
|
| 178 |
+
(input_ids == self.config.image_token_id)
|
| 179 |
+
.unsqueeze(-1)
|
| 180 |
+
.expand_as(inputs_embeds)
|
| 181 |
+
.to(inputs_embeds.device)
|
| 182 |
+
)
|
| 183 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 184 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 185 |
+
|
| 186 |
+
if pixel_values_videos is not None:
|
| 187 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
| 188 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
| 189 |
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
| 190 |
+
n_video_features = video_embeds.shape[0]
|
| 191 |
+
if n_video_tokens != n_video_features:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
| 194 |
+
)
|
| 195 |
+
video_mask = (
|
| 196 |
+
(input_ids == self.config.video_token_id)
|
| 197 |
+
.unsqueeze(-1)
|
| 198 |
+
.expand_as(inputs_embeds)
|
| 199 |
+
.to(inputs_embeds.device)
|
| 200 |
+
)
|
| 201 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 202 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 203 |
+
|
| 204 |
+
if attention_mask is not None:
|
| 205 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
| 206 |
+
|
| 207 |
+
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
| 208 |
+
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
| 209 |
+
# calculate RoPE index once per generation in the pre-fill stage only
|
| 210 |
+
if (
|
| 211 |
+
(cache_position is not None and cache_position[0] == 0)
|
| 212 |
+
or self.rope_deltas is None
|
| 213 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 214 |
+
):
|
| 215 |
+
position_ids, rope_deltas = self.get_rope_index(
|
| 216 |
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
| 217 |
+
)
|
| 218 |
+
self.rope_deltas = rope_deltas
|
| 219 |
+
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
| 220 |
+
else:
|
| 221 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 222 |
+
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
| 223 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 224 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 225 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
| 226 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
| 227 |
+
delta = delta.to(position_ids.device)
|
| 228 |
+
position_ids = position_ids.add(delta)
|
| 229 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 230 |
+
|
| 231 |
+
outputs = self.model(
|
| 232 |
+
input_ids=None,
|
| 233 |
+
position_ids=position_ids,
|
| 234 |
+
attention_mask=attention_mask,
|
| 235 |
+
past_key_values=past_key_values,
|
| 236 |
+
inputs_embeds=inputs_embeds,
|
| 237 |
+
use_cache=use_cache,
|
| 238 |
+
output_attentions=output_attentions,
|
| 239 |
+
output_hidden_states=output_hidden_states,
|
| 240 |
+
return_dict=return_dict,
|
| 241 |
+
cache_position=cache_position,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model)
|
| 245 |
+
logits = self.lm_head(hidden_states)
|
| 246 |
+
|
| 247 |
+
lm_loss = None
|
| 248 |
+
if labels is not None and self.lm_loss_weight > 0:
|
| 249 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 250 |
+
logits = logits.float()
|
| 251 |
+
# Shift so that tokens < n predict n
|
| 252 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 253 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 254 |
+
# Flatten the tokens
|
| 255 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 256 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 257 |
+
shift_labels = shift_labels.view(-1)
|
| 258 |
+
# Enable model parallelism
|
| 259 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 260 |
+
lm_loss = loss_fct(shift_logits, shift_labels)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# If vision supervision is requested, process the action head.
|
| 264 |
+
pointer_loss = None
|
| 265 |
+
pointer_scores = []
|
| 266 |
+
if visual_token_indices_of_coordinates is not None:
|
| 267 |
+
batch_size = input_ids.shape[0]
|
| 268 |
+
pointer_losses = []
|
| 269 |
+
|
| 270 |
+
# Process each sample individually because the number of visual and target tokens may vary.
|
| 271 |
+
for i in range(batch_size):
|
| 272 |
+
dummy_target = False
|
| 273 |
+
|
| 274 |
+
# Get the token ids and corresponding hidden states for sample i.
|
| 275 |
+
token_ids = input_ids[i] # shape: (seq_length,)
|
| 276 |
+
hs = hidden_states[i] # shape: (seq_length, d_model)
|
| 277 |
+
|
| 278 |
+
# Identify visual tokens indices.
|
| 279 |
+
visual_mask = (token_ids == self.config.image_token_id)
|
| 280 |
+
visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,)
|
| 281 |
+
|
| 282 |
+
# Identify target tokens (the ones that should attend to visual features).
|
| 283 |
+
target_mask = (token_ids == self.config.pointer_pad_token_id)
|
| 284 |
+
target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
|
| 285 |
+
|
| 286 |
+
# If either visual or target tokens are missing, skip this sample.
|
| 287 |
+
if visual_indices.numel() == 0:
|
| 288 |
+
raise ValueError(f"No visual or target tokens found for sample {i}.")
|
| 289 |
+
if target_indices.numel() == 0:
|
| 290 |
+
target_indices = torch.tensor([hs.shape[0] - 1]) # take the last token as the dummy target token
|
| 291 |
+
gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
|
| 292 |
+
if if_multi_patch: # task the first 4 visual tokens as the ground truth
|
| 293 |
+
sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
|
| 294 |
+
sample_labels[0][:4] = 1
|
| 295 |
+
dummy_target = True
|
| 296 |
+
else:
|
| 297 |
+
# For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
|
| 298 |
+
# where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
|
| 299 |
+
gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
|
| 300 |
+
if if_multi_patch:
|
| 301 |
+
sample_labels = multi_patch_labels[i]
|
| 302 |
+
|
| 303 |
+
# Gather the corresponding hidden state representations.
|
| 304 |
+
# visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
|
| 305 |
+
visual_embeds = inputs_embeds[i][visual_indices]
|
| 306 |
+
target_hidden = hs[target_indices] # shape: (n_target, d_model)
|
| 307 |
+
|
| 308 |
+
# Calculate loss for multi-patch mode
|
| 309 |
+
if if_multi_patch:
|
| 310 |
+
# Ensure the number of targets matches between sample and labels
|
| 311 |
+
if sample_labels.shape[0] != target_indices.shape[0]:
|
| 312 |
+
raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
|
| 313 |
+
|
| 314 |
+
# Process using VisionHead_MultiPatch
|
| 315 |
+
attn_scores, loss_v = self.multi_patch_pointer_head(
|
| 316 |
+
visual_embeds,
|
| 317 |
+
target_hidden,
|
| 318 |
+
labels=sample_labels
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
else:
|
| 322 |
+
# Deprecated branch - single patch mode is no longer used
|
| 323 |
+
# Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
|
| 324 |
+
attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
|
| 325 |
+
|
| 326 |
+
pointer_scores.append(attn_scores.detach().cpu())
|
| 327 |
+
|
| 328 |
+
pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
|
| 329 |
+
|
| 330 |
+
pointer_loss = torch.stack(pointer_losses).mean()
|
| 331 |
+
|
| 332 |
+
# Combine the LM loss and vision loss using the provided loss weights.
|
| 333 |
+
|
| 334 |
+
if lm_loss is None:
|
| 335 |
+
total_loss = pointer_loss
|
| 336 |
+
elif pointer_loss is None:
|
| 337 |
+
total_loss = lm_loss
|
| 338 |
+
else:
|
| 339 |
+
total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
|
| 340 |
+
|
| 341 |
+
if return_dict:
|
| 342 |
+
return QwenVLwithVisionHeadOutputWithPast(
|
| 343 |
+
lm_loss=lm_loss,
|
| 344 |
+
pointer_loss=pointer_loss,
|
| 345 |
+
pointer_scores=pointer_scores,
|
| 346 |
+
loss=total_loss,
|
| 347 |
+
logits=logits,
|
| 348 |
+
past_key_values=outputs.past_key_values,
|
| 349 |
+
hidden_states=outputs.hidden_states,
|
| 350 |
+
attentions=outputs.attentions,
|
| 351 |
+
rope_deltas=self.rope_deltas,
|
| 352 |
+
)
|
| 353 |
+
else:
|
| 354 |
+
# When labels are provided, parent's forward returns a tuple with loss as the first element.
|
| 355 |
+
if labels is not None:
|
| 356 |
+
# Replace the LM loss with the combined loss.
|
| 357 |
+
output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
|
| 358 |
+
print(f"returning: total_loss, logits, pointer_scores, ...")
|
| 359 |
+
return (total_loss,) + output if total_loss is not None else output
|
| 360 |
+
else:
|
| 361 |
+
return outputs
|
gui_actor/modeling_qwen25vl.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast, Qwen2_5_VLForConditionalGeneration
|
| 6 |
+
from gui_actor.constants import IGNORE_INDEX
|
| 7 |
+
from typing import List, Tuple, Union, Optional
|
| 8 |
+
from gui_actor.trainer import rank0_print
|
| 9 |
+
|
| 10 |
+
class QwenVLwithVisionHeadOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
|
| 11 |
+
"""
|
| 12 |
+
Output class for Qwen2_5_VL with pointer head, extending the base output class.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 16 |
+
Language modeling loss.
|
| 17 |
+
pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 18 |
+
Vision pointer network loss.
|
| 19 |
+
pointer_scores (`List[torch.FloatTensor]`, *optional*):
|
| 20 |
+
Attention scores from the pointer network, one tensor per batch item.
|
| 21 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 22 |
+
Combined loss (weighted sum of lm_loss and pointer_loss).
|
| 23 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 24 |
+
Prediction scores from the language modeling head.
|
| 25 |
+
past_key_values, hidden_states, attentions, rope_deltas:
|
| 26 |
+
Same as parent class.
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
|
| 29 |
+
super().__init__(*args, **kwargs)
|
| 30 |
+
self.lm_loss = lm_loss
|
| 31 |
+
self.pointer_loss = pointer_loss
|
| 32 |
+
self.pointer_scores = pointer_scores
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class VisionHead_MultiPatch(nn.Module):
|
| 36 |
+
def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.d_model = d_model
|
| 39 |
+
|
| 40 |
+
# Note: We omit additional normalization here because Qwen2VL
|
| 41 |
+
# already normalizes hidden states using RMSNorm.
|
| 42 |
+
self.projection_enc = nn.Sequential(
|
| 43 |
+
nn.Linear(d_model, projection_dim),
|
| 44 |
+
nn.GELU(),
|
| 45 |
+
nn.Linear(projection_dim, d_model)
|
| 46 |
+
)
|
| 47 |
+
self.projection_dec = nn.Sequential(
|
| 48 |
+
nn.Linear(d_model, projection_dim),
|
| 49 |
+
nn.GELU(),
|
| 50 |
+
nn.Linear(projection_dim, d_model)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Add self-attention layer for visual features
|
| 54 |
+
self.self_attention = nn.MultiheadAttention(
|
| 55 |
+
embed_dim=d_model,
|
| 56 |
+
num_heads=num_attention_heads,
|
| 57 |
+
dropout=dropout_rate,
|
| 58 |
+
batch_first=True
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Layer normalization and residual connection
|
| 62 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 63 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 64 |
+
|
| 65 |
+
def forward(self,
|
| 66 |
+
hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size
|
| 67 |
+
hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample
|
| 68 |
+
labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox
|
| 69 |
+
do_single_patch: bool = False,
|
| 70 |
+
):
|
| 71 |
+
|
| 72 |
+
enc_input = hidden_state_enc.unsqueeze(0)
|
| 73 |
+
attn_output, _ = self.self_attention(
|
| 74 |
+
query=enc_input,
|
| 75 |
+
key=enc_input,
|
| 76 |
+
value=enc_input,
|
| 77 |
+
# attn_mask=attention_mask,
|
| 78 |
+
need_weights=False
|
| 79 |
+
)
|
| 80 |
+
# Residual connection and layer normalization
|
| 81 |
+
hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
|
| 82 |
+
# Remove batch dimension
|
| 83 |
+
hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model]
|
| 84 |
+
|
| 85 |
+
# Apply the projection networks.
|
| 86 |
+
proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
|
| 87 |
+
proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
|
| 88 |
+
|
| 89 |
+
# Compute scaled dot-product attention scores.
|
| 90 |
+
# Scaling by sqrt(d_model) is critical regardless of variable n_enc.
|
| 91 |
+
scaling = self.d_model ** 0.5
|
| 92 |
+
patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
|
| 93 |
+
|
| 94 |
+
# Softmax normalization is applied along the encoder dimension.
|
| 95 |
+
attn_weights = F.softmax(patch_logits, dim=-1)
|
| 96 |
+
|
| 97 |
+
loss = None
|
| 98 |
+
if (labels is not None) and (not do_single_patch):
|
| 99 |
+
epsilon = 1e-8
|
| 100 |
+
labels_float = labels.float()
|
| 101 |
+
# Normalize each row to get target probability distribution
|
| 102 |
+
target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
|
| 103 |
+
|
| 104 |
+
# Apply log_softmax to logits
|
| 105 |
+
pred_log_probs = F.log_softmax(patch_logits, dim=-1)
|
| 106 |
+
# Use KL divergence as loss
|
| 107 |
+
loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
|
| 108 |
+
|
| 109 |
+
if do_single_patch and (labels is not None):
|
| 110 |
+
loss = F.cross_entropy(attn_scores, labels)
|
| 111 |
+
|
| 112 |
+
return attn_weights, loss
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Qwen2_5_VLForConditionalGenerationWithPointer(Qwen2_5_VLForConditionalGeneration):
|
| 116 |
+
def __init__(self, *args, **kwargs):
|
| 117 |
+
super().__init__(*args, **kwargs)
|
| 118 |
+
self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
|
| 119 |
+
self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
|
| 120 |
+
self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
|
| 121 |
+
self.post_init()
|
| 122 |
+
|
| 123 |
+
def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
|
| 124 |
+
self.pointer_loss_weight = pointer_loss_weight
|
| 125 |
+
self.lm_loss_weight = lm_loss_weight
|
| 126 |
+
|
| 127 |
+
def forward(self,
|
| 128 |
+
input_ids: torch.LongTensor = None, # (batch_size, seq_len)
|
| 129 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 130 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 131 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 132 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 133 |
+
labels: Optional[torch.LongTensor] = None,
|
| 134 |
+
use_cache: Optional[bool] = None,
|
| 135 |
+
output_attentions: Optional[bool] = None,
|
| 136 |
+
output_hidden_states: Optional[bool] = None,
|
| 137 |
+
return_dict: Optional[bool] = None,
|
| 138 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 139 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 140 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 141 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 142 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
| 143 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 144 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
| 145 |
+
# Grounding
|
| 146 |
+
visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token
|
| 147 |
+
multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
|
| 148 |
+
if_multi_patch: bool = True,
|
| 149 |
+
coordinates: Optional[List[Tuple[float, float]]] = None,
|
| 150 |
+
verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
|
| 151 |
+
|
| 152 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 153 |
+
output_hidden_states = (
|
| 154 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 155 |
+
)
|
| 156 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 157 |
+
|
| 158 |
+
if verbose:
|
| 159 |
+
rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
|
| 160 |
+
rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
|
| 161 |
+
rank0_print(f"pixel_values: {pixel_values.shape}")
|
| 162 |
+
rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
|
| 163 |
+
rank0_print(f"coordinates: {coordinates}")
|
| 164 |
+
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
|
| 165 |
+
rank0_print(f"return_dict: {return_dict}")
|
| 166 |
+
|
| 167 |
+
if inputs_embeds is None:
|
| 168 |
+
inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
|
| 169 |
+
if pixel_values is not None:
|
| 170 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
| 171 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
| 172 |
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
| 173 |
+
n_image_features = image_embeds.shape[0]
|
| 174 |
+
if n_image_tokens != n_image_features:
|
| 175 |
+
raise ValueError(
|
| 176 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 177 |
+
)
|
| 178 |
+
image_mask = (
|
| 179 |
+
(input_ids == self.config.image_token_id)
|
| 180 |
+
.unsqueeze(-1)
|
| 181 |
+
.expand_as(inputs_embeds)
|
| 182 |
+
.to(inputs_embeds.device)
|
| 183 |
+
)
|
| 184 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 185 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 186 |
+
|
| 187 |
+
if pixel_values_videos is not None:
|
| 188 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
| 189 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
| 190 |
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
| 191 |
+
n_video_features = video_embeds.shape[0]
|
| 192 |
+
if n_video_tokens != n_video_features:
|
| 193 |
+
raise ValueError(
|
| 194 |
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
| 195 |
+
)
|
| 196 |
+
video_mask = (
|
| 197 |
+
(input_ids == self.config.video_token_id)
|
| 198 |
+
.unsqueeze(-1)
|
| 199 |
+
.expand_as(inputs_embeds)
|
| 200 |
+
.to(inputs_embeds.device)
|
| 201 |
+
)
|
| 202 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 203 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 204 |
+
|
| 205 |
+
if attention_mask is not None:
|
| 206 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
| 207 |
+
|
| 208 |
+
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
| 209 |
+
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
| 210 |
+
# calculate RoPE index once per generation in the pre-fill stage only
|
| 211 |
+
if (
|
| 212 |
+
(cache_position is not None and cache_position[0] == 0)
|
| 213 |
+
or self.rope_deltas is None
|
| 214 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 215 |
+
):
|
| 216 |
+
position_ids, rope_deltas = self.get_rope_index(
|
| 217 |
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
| 218 |
+
)
|
| 219 |
+
self.rope_deltas = rope_deltas
|
| 220 |
+
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
| 221 |
+
else:
|
| 222 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 223 |
+
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
| 224 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 225 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 226 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
| 227 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
| 228 |
+
delta = delta.to(position_ids.device)
|
| 229 |
+
position_ids = position_ids.add(delta)
|
| 230 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 231 |
+
|
| 232 |
+
outputs = self.model(
|
| 233 |
+
input_ids=None,
|
| 234 |
+
position_ids=position_ids,
|
| 235 |
+
attention_mask=attention_mask,
|
| 236 |
+
past_key_values=past_key_values,
|
| 237 |
+
inputs_embeds=inputs_embeds,
|
| 238 |
+
use_cache=use_cache,
|
| 239 |
+
output_attentions=output_attentions,
|
| 240 |
+
output_hidden_states=output_hidden_states,
|
| 241 |
+
return_dict=return_dict,
|
| 242 |
+
cache_position=cache_position,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model)
|
| 246 |
+
logits = self.lm_head(hidden_states)
|
| 247 |
+
|
| 248 |
+
lm_loss = None
|
| 249 |
+
if labels is not None and self.lm_loss_weight > 0:
|
| 250 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 251 |
+
logits = logits.float()
|
| 252 |
+
# Shift so that tokens < n predict n
|
| 253 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 254 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 255 |
+
# Flatten the tokens
|
| 256 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 257 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 258 |
+
shift_labels = shift_labels.view(-1)
|
| 259 |
+
# Enable model parallelism
|
| 260 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 261 |
+
lm_loss = loss_fct(shift_logits, shift_labels)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# If vision supervision is requested, process the action head.
|
| 265 |
+
pointer_loss = None
|
| 266 |
+
pointer_scores = []
|
| 267 |
+
if visual_token_indices_of_coordinates is not None:
|
| 268 |
+
batch_size = input_ids.shape[0]
|
| 269 |
+
pointer_losses = []
|
| 270 |
+
|
| 271 |
+
# Process each sample individually because the number of visual and target tokens may vary.
|
| 272 |
+
for i in range(batch_size):
|
| 273 |
+
dummy_target = False
|
| 274 |
+
|
| 275 |
+
# Get the token ids and corresponding hidden states for sample i.
|
| 276 |
+
token_ids = input_ids[i] # shape: (seq_length,)
|
| 277 |
+
hs = hidden_states[i] # shape: (seq_length, d_model)
|
| 278 |
+
|
| 279 |
+
# Identify visual tokens indices.
|
| 280 |
+
visual_mask = (token_ids == self.config.image_token_id)
|
| 281 |
+
visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,)
|
| 282 |
+
|
| 283 |
+
# Identify target tokens (the ones that should attend to visual features).
|
| 284 |
+
target_mask = (token_ids == self.config.pointer_pad_token_id)
|
| 285 |
+
target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
|
| 286 |
+
|
| 287 |
+
# If either visual or target tokens are missing, skip this sample.
|
| 288 |
+
if visual_indices.numel() == 0:
|
| 289 |
+
raise ValueError(f"No visual or target tokens found for sample {i}.")
|
| 290 |
+
if target_indices.numel() == 0:
|
| 291 |
+
target_indices = torch.tensor([hs.shape[0] - 1]) # take the last token as the dummy target token
|
| 292 |
+
gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
|
| 293 |
+
if if_multi_patch: # task the first 4 visual tokens as the ground truth
|
| 294 |
+
sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
|
| 295 |
+
sample_labels[0][:4] = 1
|
| 296 |
+
# n_t = target_indices.size(0) # 目标 token 个数
|
| 297 |
+
# n_v = visual_indices.size(0)
|
| 298 |
+
# sample_labels = torch.zeros(
|
| 299 |
+
# (n_t, n_v), device=hs.device, dtype=torch.float
|
| 300 |
+
# )
|
| 301 |
+
# sample_labels[:, :min(4, n_v)] = 1
|
| 302 |
+
dummy_target = True
|
| 303 |
+
else:
|
| 304 |
+
# For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
|
| 305 |
+
# where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
|
| 306 |
+
gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
|
| 307 |
+
if if_multi_patch:
|
| 308 |
+
sample_labels = multi_patch_labels[i]
|
| 309 |
+
# if sample_labels is None:
|
| 310 |
+
# n_t = target_indices.size(0) # 目标 token 个数
|
| 311 |
+
# n_v = visual_indices.size(0)
|
| 312 |
+
# sample_labels = torch.zeros(
|
| 313 |
+
# (n_t, n_v), device=hs.device, dtype=torch.float
|
| 314 |
+
# )
|
| 315 |
+
# sample_labels[:, :min(4, n_v)] = 1
|
| 316 |
+
# dummy_target = True
|
| 317 |
+
|
| 318 |
+
# Gather the corresponding hidden state representations.
|
| 319 |
+
# visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
|
| 320 |
+
visual_embeds = inputs_embeds[i][visual_indices]
|
| 321 |
+
target_hidden = hs[target_indices] # shape: (n_target, d_model)
|
| 322 |
+
|
| 323 |
+
# Calculate loss for multi-patch mode
|
| 324 |
+
if if_multi_patch:
|
| 325 |
+
# Ensure the number of targets matches between sample and labels
|
| 326 |
+
if sample_labels.shape[0] != target_indices.shape[0]:
|
| 327 |
+
raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
|
| 328 |
+
|
| 329 |
+
# Process using VisionHead_MultiPatch
|
| 330 |
+
attn_scores, loss_v = self.multi_patch_pointer_head(
|
| 331 |
+
visual_embeds,
|
| 332 |
+
target_hidden,
|
| 333 |
+
labels=sample_labels
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
else:
|
| 337 |
+
# Deprecated branch - single patch mode is no longer used
|
| 338 |
+
# Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
|
| 339 |
+
attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
|
| 340 |
+
|
| 341 |
+
pointer_scores.append(attn_scores.detach().cpu())
|
| 342 |
+
|
| 343 |
+
pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
|
| 344 |
+
|
| 345 |
+
pointer_loss = torch.stack(pointer_losses).mean()
|
| 346 |
+
|
| 347 |
+
# Combine the LM loss and vision loss using the provided loss weights.
|
| 348 |
+
|
| 349 |
+
if lm_loss is None:
|
| 350 |
+
total_loss = pointer_loss
|
| 351 |
+
elif pointer_loss is None:
|
| 352 |
+
total_loss = lm_loss
|
| 353 |
+
else:
|
| 354 |
+
total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
|
| 355 |
+
|
| 356 |
+
if return_dict:
|
| 357 |
+
return QwenVLwithVisionHeadOutputWithPast(
|
| 358 |
+
lm_loss=lm_loss,
|
| 359 |
+
pointer_loss=pointer_loss,
|
| 360 |
+
pointer_scores=pointer_scores,
|
| 361 |
+
loss=total_loss,
|
| 362 |
+
logits=logits,
|
| 363 |
+
past_key_values=outputs.past_key_values,
|
| 364 |
+
hidden_states=outputs.hidden_states,
|
| 365 |
+
attentions=outputs.attentions,
|
| 366 |
+
rope_deltas=self.rope_deltas,
|
| 367 |
+
)
|
| 368 |
+
else:
|
| 369 |
+
# When labels are provided, parent's forward returns a tuple with loss as the first element.
|
| 370 |
+
if labels is not None:
|
| 371 |
+
# Replace the LM loss with the combined loss.
|
| 372 |
+
output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
|
| 373 |
+
print(f"returning: total_loss, logits, pointer_scores, ...")
|
| 374 |
+
return (total_loss,) + output if total_loss is not None else output
|
| 375 |
+
else:
|
| 376 |
+
return outputs
|
gui_actor/trainer.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import timedelta
|
| 2 |
+
from functools import wraps
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import transformers
|
| 8 |
+
from accelerate import Accelerator, DataLoaderConfiguration
|
| 9 |
+
from accelerate.utils import GradientAccumulationPlugin, InitProcessGroupKwargs
|
| 10 |
+
from torch.utils.data import DataLoader, RandomSampler
|
| 11 |
+
from transformers import Trainer
|
| 12 |
+
from transformers.trainer import (
|
| 13 |
+
ALL_LAYERNORM_LAYERS,
|
| 14 |
+
get_parameter_names,
|
| 15 |
+
has_length,
|
| 16 |
+
is_accelerate_available,
|
| 17 |
+
is_datasets_available,
|
| 18 |
+
is_sagemaker_mp_enabled,
|
| 19 |
+
)
|
| 20 |
+
from transformers.trainer_pt_utils import LengthGroupedSampler as HFLengthGroupedSampler
|
| 21 |
+
from transformers.trainer_utils import seed_worker
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
if is_datasets_available():
|
| 25 |
+
import datasets
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def rank0_print(*args):
|
| 29 |
+
if dist.is_initialized():
|
| 30 |
+
if dist.get_rank() == 0:
|
| 31 |
+
print(f"Rank {dist.get_rank()}: ", *args)
|
| 32 |
+
else:
|
| 33 |
+
print(*args)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
| 37 |
+
from deepspeed import zero
|
| 38 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 39 |
+
|
| 40 |
+
if hasattr(param, "ds_id"):
|
| 41 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
|
| 42 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
| 43 |
+
with zero.GatheredParameters([param]):
|
| 44 |
+
param = param.data.detach().cpu().clone()
|
| 45 |
+
else:
|
| 46 |
+
param = param.detach().cpu().clone()
|
| 47 |
+
return param
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
| 51 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
| 52 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
| 53 |
+
return to_return
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
| 57 |
+
"""Collects the state dict and dump to disk."""
|
| 58 |
+
trainer.accelerator.wait_for_everyone()
|
| 59 |
+
torch.cuda.synchronize()
|
| 60 |
+
|
| 61 |
+
if trainer.deepspeed:
|
| 62 |
+
trainer.save_model(output_dir)
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
state_dict = trainer.model.state_dict()
|
| 66 |
+
if trainer.args.should_save:
|
| 67 |
+
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
| 68 |
+
del state_dict
|
| 69 |
+
trainer._save(output_dir, state_dict=cpu_state_dict)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class AGUVISTrainer(Trainer):
|
| 73 |
+
|
| 74 |
+
def __init__(self, *args, **kwargs):
|
| 75 |
+
super().__init__(*args, **kwargs)
|
| 76 |
+
|
| 77 |
+
original_save = self._save
|
| 78 |
+
original_save_model = self.save_model
|
| 79 |
+
|
| 80 |
+
def modify_eos_token(func):
|
| 81 |
+
@wraps(func)
|
| 82 |
+
def wrapper(*args, **kwargs):
|
| 83 |
+
tokenizer = self.processing_class.tokenizer
|
| 84 |
+
old_config_id = self.model.config.eos_token_id
|
| 85 |
+
old_eos_token = tokenizer.eos_token
|
| 86 |
+
old_generation_config_eos_token_id = (
|
| 87 |
+
self.model.generation_config.eos_token_id if hasattr(self.model, "generation_config") else None
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
new_eos_token_id = tokenizer.convert_tokens_to_ids("<|diff_marker|>")
|
| 92 |
+
self.model.config.eos_token_id = [new_eos_token_id]
|
| 93 |
+
tokenizer.eos_token = "<|diff_marker|>"
|
| 94 |
+
if hasattr(self.model, "generation_config"):
|
| 95 |
+
self.model.generation_config.eos_token_id = [new_eos_token_id]
|
| 96 |
+
|
| 97 |
+
print("Set eos token id to", new_eos_token_id)
|
| 98 |
+
print("Set eos token to", "<|diff_marker|>")
|
| 99 |
+
print("Set generation config eos token id to", [new_eos_token_id])
|
| 100 |
+
|
| 101 |
+
result = func(*args, **kwargs)
|
| 102 |
+
return result
|
| 103 |
+
finally:
|
| 104 |
+
self.model.config.eos_token_id = old_config_id
|
| 105 |
+
tokenizer.eos_token = old_eos_token
|
| 106 |
+
if hasattr(self.model, "generation_config") and old_generation_config_eos_token_id is not None:
|
| 107 |
+
self.model.generation_config.eos_token_id = old_generation_config_eos_token_id
|
| 108 |
+
|
| 109 |
+
print("Set eos token id back to", old_config_id)
|
| 110 |
+
print("Set eos token back to", old_eos_token)
|
| 111 |
+
if old_generation_config_eos_token_id is not None:
|
| 112 |
+
print("Set generation config eos token id back to", old_generation_config_eos_token_id)
|
| 113 |
+
|
| 114 |
+
return wrapper
|
| 115 |
+
|
| 116 |
+
self._save = modify_eos_token(original_save)
|
| 117 |
+
self.save_model = modify_eos_token(original_save_model)
|
| 118 |
+
|
| 119 |
+
def create_accelerator_and_postprocess(self):
|
| 120 |
+
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
|
| 121 |
+
grad_acc_kwargs["sync_with_dataloader"] = False
|
| 122 |
+
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
| 123 |
+
|
| 124 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 125 |
+
|
| 126 |
+
# create accelerator object
|
| 127 |
+
dispatch_batches = getattr(self.args, "dispatch_batches", None)
|
| 128 |
+
split_batches = getattr(self.args, "split_batches", None)
|
| 129 |
+
self.dataloader_config = DataLoaderConfiguration(
|
| 130 |
+
dispatch_batches=dispatch_batches,
|
| 131 |
+
split_batches=split_batches,
|
| 132 |
+
)
|
| 133 |
+
self.accelerator = Accelerator(
|
| 134 |
+
dataloader_config=self.dataloader_config,
|
| 135 |
+
deepspeed_plugin=self.args.deepspeed_plugin,
|
| 136 |
+
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
| 137 |
+
kwargs_handlers=[accelerator_kwargs],
|
| 138 |
+
)
|
| 139 |
+
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
| 140 |
+
self.gather_function = self.accelerator.gather_for_metrics
|
| 141 |
+
|
| 142 |
+
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
| 143 |
+
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
| 144 |
+
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
| 145 |
+
|
| 146 |
+
# post accelerator creation setup
|
| 147 |
+
if self.is_fsdp_enabled:
|
| 148 |
+
fsdp_plugin = self.accelerator.state.fsdp_plugin
|
| 149 |
+
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
|
| 150 |
+
"limit_all_gathers", fsdp_plugin.limit_all_gathers
|
| 151 |
+
)
|
| 152 |
+
if is_accelerate_available("0.23.0"):
|
| 153 |
+
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
|
| 154 |
+
"activation_checkpointing", fsdp_plugin.activation_checkpointing
|
| 155 |
+
)
|
| 156 |
+
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
|
| 157 |
+
raise ValueError(
|
| 158 |
+
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
|
| 159 |
+
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
|
| 160 |
+
"when using FSDP."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
|
| 164 |
+
self.propagate_args_to_deepspeed()
|
| 165 |
+
|
| 166 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 167 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 168 |
+
return None
|
| 169 |
+
|
| 170 |
+
if self.args.group_by_length:
|
| 171 |
+
lengths = self.train_dataset.lengths
|
| 172 |
+
return HFLengthGroupedSampler(
|
| 173 |
+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
| 174 |
+
dataset=self.train_dataset,
|
| 175 |
+
lengths=lengths,
|
| 176 |
+
)
|
| 177 |
+
elif self.args.group_by_modality_length:
|
| 178 |
+
lengths = self.train_dataset.modality_lengths
|
| 179 |
+
return HFLengthGroupedSampler(
|
| 180 |
+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
| 181 |
+
dataset=self.train_dataset,
|
| 182 |
+
lengths=lengths,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
return RandomSampler(self.train_dataset)
|
| 186 |
+
|
| 187 |
+
def get_train_dataloader(self) -> DataLoader:
|
| 188 |
+
"""
|
| 189 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
| 190 |
+
|
| 191 |
+
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
|
| 192 |
+
training if necessary) otherwise.
|
| 193 |
+
|
| 194 |
+
Subclass and override this method if you want to inject some custom behavior.
|
| 195 |
+
"""
|
| 196 |
+
if self.train_dataset is None:
|
| 197 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
| 198 |
+
|
| 199 |
+
train_dataset = self.train_dataset
|
| 200 |
+
data_collator = self.data_collator
|
| 201 |
+
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
| 202 |
+
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
| 203 |
+
else:
|
| 204 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
| 205 |
+
|
| 206 |
+
dataloader_params = {
|
| 207 |
+
"batch_size": self._train_batch_size,
|
| 208 |
+
"collate_fn": data_collator,
|
| 209 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 210 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 211 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
| 215 |
+
dataloader_params["sampler"] = self._get_train_sampler()
|
| 216 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 217 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
| 218 |
+
dataloader_params["prefetch_factor"] = (
|
| 219 |
+
self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
| 223 |
+
|
| 224 |
+
return dataloader
|
| 225 |
+
|
| 226 |
+
def create_optimizer(self):
|
| 227 |
+
"""
|
| 228 |
+
Setup the optimizer.
|
| 229 |
+
|
| 230 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 231 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 232 |
+
"""
|
| 233 |
+
if is_sagemaker_mp_enabled():
|
| 234 |
+
return super().create_optimizer()
|
| 235 |
+
|
| 236 |
+
opt_model = self.model
|
| 237 |
+
|
| 238 |
+
if self.optimizer is None:
|
| 239 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
| 240 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 241 |
+
optimizer_grouped_parameters = [
|
| 242 |
+
{
|
| 243 |
+
"params": [
|
| 244 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
|
| 245 |
+
],
|
| 246 |
+
"weight_decay": self.args.weight_decay,
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"params": [
|
| 250 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
| 251 |
+
],
|
| 252 |
+
"weight_decay": 0.0,
|
| 253 |
+
},
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
| 257 |
+
|
| 258 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 259 |
+
|
| 260 |
+
return self.optimizer
|
| 261 |
+
|
| 262 |
+
def create_optimizer_with_different_learning_rates(self):
|
| 263 |
+
"""
|
| 264 |
+
Setup the optimizer.
|
| 265 |
+
|
| 266 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 267 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 268 |
+
"""
|
| 269 |
+
if is_sagemaker_mp_enabled():
|
| 270 |
+
raise NotImplementedError("Sagemaker MP is not supported for separate learning rate yet")
|
| 271 |
+
return super().create_optimizer()
|
| 272 |
+
|
| 273 |
+
opt_model = self.model
|
| 274 |
+
|
| 275 |
+
if self.optimizer is None:
|
| 276 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
| 277 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 278 |
+
|
| 279 |
+
new_parameters = []
|
| 280 |
+
for name, param in opt_model.named_parameters():
|
| 281 |
+
if ("pointer_head" in name) or ("embed_tokens" in name):
|
| 282 |
+
new_parameters.append(name)
|
| 283 |
+
rank0_print(f"new_parameters: {len(new_parameters)}")
|
| 284 |
+
|
| 285 |
+
optimizer_grouped_parameters = [
|
| 286 |
+
{
|
| 287 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
|
| 288 |
+
"weight_decay": self.args.weight_decay,
|
| 289 |
+
"lr": self.args.learning_rate,
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
|
| 293 |
+
"weight_decay": 0.0,
|
| 294 |
+
"lr": self.args.learning_rate,
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n in new_parameters) and p.requires_grad)],
|
| 298 |
+
"weight_decay": self.args.weight_decay,
|
| 299 |
+
"lr": self.args.learning_rate_new_params,
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n in new_parameters) and p.requires_grad)],
|
| 303 |
+
"weight_decay": 0.0,
|
| 304 |
+
"lr": self.args.learning_rate_new_params,
|
| 305 |
+
},
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) # {'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08}
|
| 309 |
+
optimizer_kwargs.pop("lr")
|
| 310 |
+
|
| 311 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 312 |
+
|
| 313 |
+
return self.optimizer
|
gui_actor/utils.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw, ImageColor
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def dump_args_to_json(model_config, data_processor, model_args, data_args, training_args, output_dir):
|
| 6 |
+
def is_json_serializable(v):
|
| 7 |
+
try:
|
| 8 |
+
json.dumps(v)
|
| 9 |
+
return True
|
| 10 |
+
except:
|
| 11 |
+
return False
|
| 12 |
+
|
| 13 |
+
save_path = f"{output_dir}/args.json"
|
| 14 |
+
if not os.path.exists(save_path):
|
| 15 |
+
with open(save_path, "w") as f:
|
| 16 |
+
json.dump({
|
| 17 |
+
"model_config": {k: v for k, v in model_config.__dict__.items() if is_json_serializable(v)},
|
| 18 |
+
"data_processor_config": {k: v for k, v in data_processor.__dict__.items() if is_json_serializable(v)},
|
| 19 |
+
"image_processor_config": {k: v for k, v in data_processor.image_processor.__dict__.items() if is_json_serializable(v)},
|
| 20 |
+
"model_args": {k: v for k, v in model_args.__dict__.items() if is_json_serializable(v)},
|
| 21 |
+
"data_args": {k: v for k, v in data_args.__dict__.items() if is_json_serializable(v)},
|
| 22 |
+
"training_args": {k: v for k, v in training_args.__dict__.items() if is_json_serializable(v)},
|
| 23 |
+
}, f, indent=4)
|
| 24 |
+
|
| 25 |
+
def draw_point(image: Image.Image, point: list, color=None):
|
| 26 |
+
if isinstance(color, str):
|
| 27 |
+
try:
|
| 28 |
+
color = ImageColor.getrgb(color)
|
| 29 |
+
color = color + (128,)
|
| 30 |
+
except ValueError:
|
| 31 |
+
color = (255, 0, 0, 128)
|
| 32 |
+
else:
|
| 33 |
+
color = (255, 0, 0, 128)
|
| 34 |
+
|
| 35 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
| 36 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
| 37 |
+
radius = 14
|
| 38 |
+
x, y = point
|
| 39 |
+
|
| 40 |
+
overlay_draw.rectangle(
|
| 41 |
+
[x - radius, y - radius, x + radius, y + radius],
|
| 42 |
+
fill=color
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
center_radius = radius * 0.1
|
| 46 |
+
overlay_draw.ellipse(
|
| 47 |
+
[(x - center_radius, y - center_radius),
|
| 48 |
+
(x + center_radius, y + center_radius)],
|
| 49 |
+
fill=(0, 255, 0, 255)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
image = image.convert('RGBA')
|
| 53 |
+
combined = Image.alpha_composite(image, overlay)
|
| 54 |
+
|
| 55 |
+
return combined.convert('RGB')
|
| 56 |
+
|
| 57 |
+
def draw_bbox(image: Image.Image, bbox: list, color=None):
|
| 58 |
+
"""bbox is in the format of [x1, y1, x2, y2]"""
|
| 59 |
+
if isinstance(color, str):
|
| 60 |
+
try:
|
| 61 |
+
color = ImageColor.getrgb(color)
|
| 62 |
+
color = color + (128,)
|
| 63 |
+
except ValueError:
|
| 64 |
+
color = (255, 0, 0, 128)
|
| 65 |
+
else:
|
| 66 |
+
color = (255, 0, 0, 128)
|
| 67 |
+
|
| 68 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
| 69 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
| 70 |
+
overlay_draw.rectangle(bbox, fill=color)
|
| 71 |
+
return Image.alpha_composite(image, overlay).convert('RGB')
|
| 72 |
+
|
| 73 |
+
def do_boxes_overlap(box1, box2):
|
| 74 |
+
"""
|
| 75 |
+
Check if two boxes overlap.
|
| 76 |
+
|
| 77 |
+
Each box is represented as a tuple: (x1, y1, x2, y2)
|
| 78 |
+
Where (x1, y1) is the top-left and (x2, y2) is the bottom-right corner.
|
| 79 |
+
"""
|
| 80 |
+
# Unpack the coordinates
|
| 81 |
+
x1_min, y1_min, x1_max, y1_max = box1
|
| 82 |
+
x2_min, y2_min, x2_max, y2_max = box2
|
| 83 |
+
|
| 84 |
+
# Check for no overlap
|
| 85 |
+
if x1_max < x2_min or x2_max < x1_min:
|
| 86 |
+
return False
|
| 87 |
+
if y1_max < y2_min or y2_max < y1_min:
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
return True
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.7.1
|
| 2 |
+
torchvision
|
| 3 |
+
torchaudio
|
| 4 |
+
|
| 5 |
+
accelerate==1.1.1
|
| 6 |
+
gradio
|
| 7 |
+
gradio_client
|
| 8 |
+
spaces
|
| 9 |
+
|
| 10 |
+
Pillow==11.1.0
|
| 11 |
+
opencv-python-headless==4.11.0.86
|
| 12 |
+
datasets==3.6.0
|
| 13 |
+
|
| 14 |
+
transformers==4.51.3
|
| 15 |
+
qwen-vl-utils==0.0.8
|
| 16 |
+
|
| 17 |
+
pre-commit==4.2.0
|
| 18 |
+
matplotlib
|
| 19 |
+
flash-attn==2.7.3
|
run.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64, os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from qwen_vl_utils import process_vision_info
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from transformers import AutoProcessor
|
| 12 |
+
from gui_actor.constants import chat_template
|
| 13 |
+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
|
| 14 |
+
from gui_actor.inference import inference
|
| 15 |
+
|
| 16 |
+
MAX_PIXELS = 3200 * 1800
|
| 17 |
+
|
| 18 |
+
def resize_image(image, resize_to_pixels=MAX_PIXELS):
|
| 19 |
+
image_width, image_height = image.size
|
| 20 |
+
if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
|
| 21 |
+
resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
|
| 22 |
+
image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio)
|
| 23 |
+
image = image.resize((image_width_resized, image_height_resized))
|
| 24 |
+
return image
|
| 25 |
+
|
| 26 |
+
@torch.inference_mode()
|
| 27 |
+
def draw_point(image: Image.Image, point: list, radius=8, color=(255, 0, 0, 128)):
|
| 28 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
| 29 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
| 30 |
+
x, y = point
|
| 31 |
+
overlay_draw.ellipse(
|
| 32 |
+
[(x - radius, y - radius), (x + radius, y + radius)],
|
| 33 |
+
outline=color,
|
| 34 |
+
width=5
|
| 35 |
+
)
|
| 36 |
+
image = image.convert('RGBA')
|
| 37 |
+
combined = Image.alpha_composite(image, overlay)
|
| 38 |
+
combined = combined.convert('RGB')
|
| 39 |
+
return combined
|
| 40 |
+
|
| 41 |
+
@torch.inference_mode()
|
| 42 |
+
def get_attn_map(image, attn_scores, n_width, n_height):
|
| 43 |
+
w, h = image.size
|
| 44 |
+
scores = np.array(attn_scores[0]).reshape(n_height, n_width)
|
| 45 |
+
|
| 46 |
+
scores_norm = (scores - scores.min()) / (scores.max() - scores.min())
|
| 47 |
+
score_map = Image.fromarray((scores_norm * 255).astype(np.uint8)).resize((w, h), resample=Image.NEAREST)
|
| 48 |
+
colormap = plt.get_cmap('jet')
|
| 49 |
+
colored_score_map = colormap(np.array(score_map) / 255.0)
|
| 50 |
+
colored_score_map = (colored_score_map[:, :, :3] * 255).astype(np.uint8)
|
| 51 |
+
colored_overlay = Image.fromarray(colored_score_map)
|
| 52 |
+
|
| 53 |
+
blended = Image.blend(image, colored_overlay, alpha=0.3)
|
| 54 |
+
return blended
|
| 55 |
+
|
| 56 |
+
# 加载模型
|
| 57 |
+
if torch.cuda.is_available():
|
| 58 |
+
model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2.5-VL"
|
| 59 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 60 |
+
tokenizer = data_processor.tokenizer
|
| 61 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 62 |
+
model_name_or_path,
|
| 63 |
+
torch_dtype=torch.bfloat16,
|
| 64 |
+
device_map="cuda:0",
|
| 65 |
+
attn_implementation="flash_attention_2"
|
| 66 |
+
).eval()
|
| 67 |
+
else:
|
| 68 |
+
model_name_or_path = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
|
| 69 |
+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
|
| 70 |
+
tokenizer = data_processor.tokenizer
|
| 71 |
+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
|
| 72 |
+
model_name_or_path,
|
| 73 |
+
torch_dtype=torch.bfloat16,
|
| 74 |
+
device_map="cpu"
|
| 75 |
+
).eval()
|
| 76 |
+
|
| 77 |
+
title = "GUI-Actor"
|
| 78 |
+
header = """
|
| 79 |
+
<div align="center">
|
| 80 |
+
<h1 style="padding-bottom: 10px; padding-top: 10px;">🎯 <strong>GUI-Actor</strong>: Coordinate-Free Visual Grounding for GUI Agents</h1>
|
| 81 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 82 |
+
Qianhui Wu*, Kanzhi Cheng*, Rui Yang*, Chaoyun Zhang, Jianwei Yang, Huiqiang Jiang, Jian Mu, Baolin Peng, Bo Qiao, Reuben Tan, Si Qin, Lars Liden<br>
|
| 83 |
+
Qingwei Lin, Huan Zhang, Tong Zhang, Jianbing Zhang, Dongmei Zhang, Jianfeng Gao<br/>
|
| 84 |
+
</div>
|
| 85 |
+
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
|
| 86 |
+
<a href="https://microsoft.github.io/GUI-Actor/">🌐 Project Page</a> | <a href="https://arxiv.org/abs/2403.12968">📄 arXiv Paper</a> | <a href="https://github.com/microsoft/GUI-Actor">💻 Github Repo</a><br/>
|
| 87 |
+
</div>
|
| 88 |
+
</div>
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
theme = "soft"
|
| 92 |
+
css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
|
| 93 |
+
#anno-img .mask.active {opacity: 0.7}"""
|
| 94 |
+
|
| 95 |
+
@torch.inference_mode()
|
| 96 |
+
def process(image, instruction):
|
| 97 |
+
# 调整图像大小
|
| 98 |
+
w, h = image.size
|
| 99 |
+
if w * h > MAX_PIXELS:
|
| 100 |
+
image = resize_image(image)
|
| 101 |
+
|
| 102 |
+
conversation = [
|
| 103 |
+
{
|
| 104 |
+
"role": "system",
|
| 105 |
+
"content": [
|
| 106 |
+
{
|
| 107 |
+
"type": "text",
|
| 108 |
+
"text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>).",
|
| 109 |
+
}
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"role": "user",
|
| 114 |
+
"content": [
|
| 115 |
+
{
|
| 116 |
+
"type": "image",
|
| 117 |
+
"image": image,
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"type": "text",
|
| 121 |
+
"text": instruction,
|
| 122 |
+
},
|
| 123 |
+
],
|
| 124 |
+
},
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(e)
|
| 131 |
+
return image, f"Error: {e}", None
|
| 132 |
+
|
| 133 |
+
px, py = pred["topk_points"][0]
|
| 134 |
+
output_coord = f"({px:.4f}, {py:.4f})"
|
| 135 |
+
img_with_point = draw_point(image, (px * w, py * h))
|
| 136 |
+
|
| 137 |
+
n_width, n_height = pred["n_width"], pred["n_height"]
|
| 138 |
+
attn_scores = pred["attn_scores"]
|
| 139 |
+
att_map = get_attn_map(image, attn_scores, n_width, n_height)
|
| 140 |
+
|
| 141 |
+
return img_with_point, output_coord, att_map
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
with gr.Blocks(title=title, css=css) as demo:
|
| 145 |
+
gr.Markdown(header)
|
| 146 |
+
with gr.Row():
|
| 147 |
+
with gr.Column():
|
| 148 |
+
input_image = gr.Image(
|
| 149 |
+
type='pil', label='Upload image')
|
| 150 |
+
input_instruction = gr.Textbox(label='Instruction', placeholder='Text your (low-level) instruction here')
|
| 151 |
+
submit_button = gr.Button(
|
| 152 |
+
value='Submit', variant='primary')
|
| 153 |
+
with gr.Column():
|
| 154 |
+
image_with_point = gr.Image(type='pil', label='Image with Point (red circle)')
|
| 155 |
+
with gr.Accordion('Detailed prediction'):
|
| 156 |
+
pred_xy = gr.Textbox(label='Predicted Coordinates', placeholder='(x, y)')
|
| 157 |
+
att_map = gr.Image(type='pil', label='Attention Map')
|
| 158 |
+
|
| 159 |
+
submit_button.click(
|
| 160 |
+
fn=process,
|
| 161 |
+
inputs=[
|
| 162 |
+
input_image,
|
| 163 |
+
input_instruction
|
| 164 |
+
],
|
| 165 |
+
outputs=[image_with_point, pred_xy, att_map]
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# 关键修改:仅在本地9876端口启动服务,不启用公网转发
|
| 169 |
+
demo.queue().launch(
|
| 170 |
+
server_port=9876, # 指定端口为9876
|
| 171 |
+
server_name='127.0.0.1',# 仅本地可访问(如需局域网访问可改为'0.0.0.0')
|
| 172 |
+
share=False # 禁用公网转发服务,避免临时链接
|
| 173 |
+
)
|