Commit 5b53c67 · Parent: 6ba6dce

clean

Changed: interactive_demo.py (+16, -39)
--- a/interactive_demo.py
+++ b/interactive_demo.py
@@ -47,20 +47,12 @@ def heart_beat_worker(controller):
 
 
 class ModelWorker:
-    def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm,
+    def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_name):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
         self.model_name = model_name
-
-        # logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.vlm = vlm
-        self.tokenizer, self.model, self.image_processor, self.context_len = (
-            vlm.tokenizer,
-            vlm.model,
-            vlm.image_processor,
-            vlm.max_length,
-        )
 
         if not no_register:
             self.register_to_controller()
@@ -68,18 +60,12 @@ class ModelWorker:
         self.heart_beat_thread.start()
 
     def register_to_controller(self):
-        # logger.info("Register to controller")
-
         url = self.controller_addr + "/register_worker"
         data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
         r = requests.post(url, json=data)
         assert r.status_code == 200
 
     def send_heart_beat(self):
-        # logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-        #             f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-        #             f"global_counter: {global_counter}")
-
         url = self.controller_addr + "/receive_heart_beat"
 
         while True:
@@ -91,7 +77,6 @@ class ModelWorker:
                 break
             except requests.exceptions.RequestException:
                 pass
-                # logger.error(f"heart beat error: {e}")
             time.sleep(5)
 
         if not exist:
@@ -145,12 +130,12 @@ class ModelWorker:
         else:
             question_prompt = [prompt_fn()]
 
-        if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
+        if isinstance(self.vlm.image_processor, Compose) or hasattr(self.vlm.image_processor, "is_prismatic"):
             # This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
-            pixel_values = self.image_processor(images[0].convert("RGB"))
+            pixel_values = self.vlm.image_processor(images[0].convert("RGB"))
         else:
             # Assume `image_transform` is a HF ImageProcessor...
-            pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
+            pixel_values = self.vlm.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
 
         if type(pixel_values) is dict:
             for k in pixel_values.keys():
@@ -227,31 +212,29 @@ overwatch = initialize_overwatch(__name__)
 class DemoConfig:
     # fmt: off
 
-    # === Model Parameters =>>
-    model_family: str = "
-    model_id: str = "
-    model_dir:
-        "resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
-    )
+    # === Model Parameters =>> Prismatic ===
+    model_family: str = "prismatic"            # Model family to load from in < `prismatic` | `llava-v15` | ... >
+    model_id: str = "prism-dinosiglip+7b"      # Model ID to load and run (instance of `model_family`)
+    model_dir: str = None                      # Can optionally supply model_dir instead of model_id
 
     # === Model Parameters =>> Official LLaVa ===
     # model_family: str = "llava-v15"
     # model_id: str = "llava-v1.5-13b"
     # model_dir: Path = "liuhaotian/llava-v1.5-13b"
 
+    # === Model Parameters =>> Official InstructBLIP ===
+    # model_family: str = "instruct-blip"
+    # model_id: str = "instructblip-vicuna-7b"
+    # model_dir: Path = "Salesforce/instructblip-vicuna-7b"
+
     # Model Worker Parameters
     host: str = "0.0.0.0"
     port: int = 40000
     controller_address: str = "http://localhost:10000"
-    model_base: str = "llava-v15"
     limit_model_concurrency: int = 5
     stream_interval: int = 1
     no_register: bool = False
 
-    # Inference Parameters
-    device_batch_size: int = 1    # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
-    num_workers: int = 2          # Number of Dataloader Workers (on each process)
-
     # HF Hub Credentials (for LLaMa-2)
     hf_token: Union[str, Path] = Path(".hf_token")    # Environment variable or Path to HF Token
 
@@ -259,14 +242,8 @@ class DemoConfig:
     seed: int = 21    # Random Seed (for reproducibility)
 
     def __post_init__(self) -> None:
-
-
-        self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
-        elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
-            self.model_name = MODEL_ID_TO_NAME[self.model_id]
-            self.run_dir = self.model_dir
-        else:
-            raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
+        self.run_dir = self.model_dir
+        self.model_name = MODEL_ID_TO_NAME[str(self.model_id)]
         self.worker_address = f"http://localhost:{self.port}"
 
     # fmt: on
@@ -286,7 +263,7 @@ def interactive_demo(cfg: DemoConfig):
     global limit_model_concurrency
     limit_model_concurrency = cfg.limit_model_concurrency
     worker = ModelWorker(
-        cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.
+        cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_name
    )
     uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
 
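For reference, a minimal usage sketch of the simplified worker wiring after this commit. It is an illustration only: `worker_id` and the `vlm` handle are hypothetical stand-ins (their creation happens outside this diff), and the real demo builds the config and worker inside `interactive_demo(cfg)` before handing the app to uvicorn.

import uuid

cfg = DemoConfig()                    # defaults now select the Prismatic family ("prism-dinosiglip+7b")
worker_id = str(uuid.uuid4())[:6]     # hypothetical short worker id
vlm = ...                             # loaded model object exposing `.image_processor` (loading is outside this diff)

# The constructor keeps only the `vlm` handle (no more tokenizer/model/image_processor unpacking)
# and, unless `no_register` is set, registers with the controller and starts the heartbeat thread.
worker = ModelWorker(
    cfg.controller_address,           # "http://localhost:10000"
    cfg.worker_address,               # f"http://localhost:{cfg.port}", set in __post_init__
    worker_id,
    cfg.no_register,
    vlm,
    cfg.model_name,                   # resolved in __post_init__ via MODEL_ID_TO_NAME
)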