Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -6,7 +6,7 @@ import logging
 import os
 import sys
 import time
-import spaces
+from huggingface_hub import spaces
 import gradio as gr
 import torch
 from PIL import Image
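A note on this hunk: the `@spaces.GPU` decorator used later in the commit is normally provided by the standalone `spaces` package that Hugging Face preinstalls on ZeroGPU Spaces, not by `huggingface_hub`; to the best of my knowledge `from huggingface_hub import spaces` raises ImportError on current `huggingface_hub` releases. A minimal sketch of the conventional form (assuming a ZeroGPU Space, where `spaces` is importable; locally it is `pip install spaces`):

import spaces
import torch

@spaces.GPU  # a GPU is attached only while this function runs on ZeroGPU
def ping_gpu() -> str:
    # CUDA becomes available inside @spaces.GPU-decorated calls
    return f"cuda available: {torch.cuda.is_available()}"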
@@ -34,6 +34,10 @@ logger = logging.getLogger("gradio_web_server")
 LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
 os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
 
+# Get default model from environment variable or use a fallback
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")
+logger.info(f"Using model: {DEFAULT_MODEL}")
+
 default_taxonomy = policy_v1
 
 
@@ -147,6 +151,7 @@ disable_btn = gr.Button(interactive=False)
 
 
 # Model loading function
+@spaces.GPU
 def load_model(model_path):
     global tokenizer, model, processor, context_len
 
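For context, `spaces.GPU` also accepts a `duration` argument when a call may exceed the default ZeroGPU time slice, which can matter for slow model loads like this one. A hedged sketch (the 120-second figure and the loader body are illustrative, not from this commit):

import spaces
import torch
from transformers import AutoModelForCausalLM

@spaces.GPU(duration=120)  # request a longer slice; 120s is an arbitrary example
def load_model(model_path: str):
    # hypothetical loader mirroring the decorated function above
    return AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16).to("cuda")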
@@ -183,16 +188,6 @@ def load_model(model_path):
     return  # Remove return value to avoid Gradio warnings
 
 
-def get_model_list():
-    models = [
-        'AIML-TUDA/QwenGuard-v1.2-3B',
-        'AIML-TUDA/QwenGuard-v1.2-7B',
-        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
-    ]
-    return models
-
-
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@@ -206,7 +201,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     global model, tokenizer, processor
 
     if model is None or processor is None:
-        return "Model not loaded. Please
+        return "Model not loaded. Please wait for model to initialize."
     try:
         # Check if it's a Qwen model
        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
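Since `run_inference` branches on `Qwen2_5_VLForConditionalGeneration`, here is a hedged sketch of what such a branch typically looks like with the transformers API. The chat-template and decoding flow follows the documented Qwen2.5-VL usage; the model id, prompt text, and image path are placeholders, not taken from this commit:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "AIML-TUDA/QwenGuard-v1.2-7B"  # placeholder choice
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("example.jpg")  # placeholder image
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "Assess the image against the safety policy."},
]}]
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=512, do_sample=True,
                         temperature=0.2, top_p=0.95)
# Strip the prompt tokens before decoding the reply
reply = processor.batch_decode(out[:, inputs["input_ids"].shape[1]:],
                               skip_special_tokens=True)[0]
print(reply)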
@@ -290,57 +285,43 @@ function() {
 
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    models = get_model_list()
-
-    dropdown_update = gr.Dropdown(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown(value=model, visible=True)
-            load_model(model)
-
     state = default_conversation.copy()
-    return state, dropdown_update
+    return state
 
 
-def load_demo_refresh_model_list(request: gr.Request):
+def load_demo_refresh(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
-    models = get_model_list()
     state = default_conversation.copy()
-    dropdown_update = gr.Dropdown(
-        choices=models,
-        value=models[0] if len(models) > 0 else ""
-    )
-    return state, dropdown_update
+    return state
 
 
-def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+def vote_last_response(state, vote_type, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(time.time(), 4),
             "type": vote_type,
-            "model": model_selector,
+            "model": DEFAULT_MODEL,
             "state": state.dict(),
             "ip": request.client.host,
         }
         fout.write(json.dumps(data) + "\n")
 
 
-def upvote_last_response(state, model_selector, request: gr.Request):
+def upvote_last_response(state, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
-    vote_last_response(state, "upvote", model_selector, request)
+    vote_last_response(state, "upvote", request)
     return ("",) + (disable_btn,) * 3
 
 
-def downvote_last_response(state, model_selector, request: gr.Request):
+def downvote_last_response(state, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
-    vote_last_response(state, "downvote", model_selector, request)
+    vote_last_response(state, "downvote", request)
     return ("",) + (disable_btn,) * 3
 
 
-def flag_last_response(state, model_selector, request: gr.Request):
+def flag_last_response(state, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
-    vote_last_response(state, "flag", model_selector, request)
+    vote_last_response(state, "flag", request)
     return ("",) + (disable_btn,) * 3
 
 
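One detail worth noting in this hunk: the handlers keep a `request: gr.Request` parameter even though it never appears in the `click()` input lists further down. Gradio injects the request automatically based on the type annotation, so `[state]` alone is a sufficient inputs list. A minimal standalone sketch of that mechanism (toy app, not this Space's code):

import gradio as gr

def upvote(state, request: gr.Request):
    # `request` is injected by Gradio; only `state` comes from the inputs list
    print("vote from", request.client.host)
    return ""

with gr.Blocks() as demo:
    state = gr.State()
    textbox = gr.Textbox()
    gr.Button("Upvote").click(upvote, [state], [textbox])

if __name__ == "__main__":
    demo.launch()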
@@ -390,7 +371,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
 
-def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+def llava_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
 
     if state.skip_next:
@@ -410,10 +391,6 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
 
-    # Load model if needed
-    if model is None or model_selector != getattr(model, "_name_or_path", ""):
-        load_model(model_selector)
-
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
 
@@ -434,7 +411,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
     data = {
         "tstamp": round(finish_tstamp, 4),
         "type": "chat",
-        "model": model_selector,
+        "model": DEFAULT_MODEL,
         "start": round(start_tstamp, 4),
         "finish": round(finish_tstamp, 4),
         "state": state.dict(),
@@ -477,8 +454,6 @@ block_css = """
 
 
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
-    models = get_model_list()
-
     with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
@@ -487,13 +462,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         with gr.Row():
             with gr.Column(scale=3):
-
-                model_selector = gr.Dropdown(
-                    choices=models,
-                    value=models[0] if len(models) > 0 else "",
-                    interactive=True,
-                    show_label=False,
-                    container=False)
+                # Model selector removed
 
                 imagebox = gr.Image(type="pil", label="Image", container=False)
                 image_process_mode = gr.Radio(
@@ -559,35 +528,29 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         upvote_btn.click(
             upvote_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         downvote_btn.click(
             downvote_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         flag_btn.click(
             flag_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
-        model_selector.change(
-            load_model,
-            [model_selector],
-            None
-        )
-
         regenerate_btn.click(
             regenerate,
             [state, image_process_mode],
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
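The `click(...).then(...)` pattern above runs the fast UI update first, then chains the (possibly streaming) bot call, with `concurrency_limit` capping simultaneous runs of that event. A toy sketch of the same wiring (all names illustrative, not from this app):

import gradio as gr

def regenerate(history):
    if history:  # clear the last bot reply
        history = history[:-1] + [[history[-1][0], None]]
    return history

def bot(history):
    if history:
        history = history[:-1] + [[history[-1][0], "regenerated reply"]]
    yield history  # generator handlers stream partial updates

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    regen_btn = gr.Button("Regenerate")
    regen_btn.click(regenerate, [chatbot], [chatbot]).then(
        bot, [chatbot], [chatbot], concurrency_limit=10)

if __name__ == "__main__":
    demo.launch()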
@@ -606,7 +569,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             queue=False
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
@@ -617,15 +580,15 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
 
         demo.load(
-            load_demo_refresh_model_list,
+            load_demo_refresh,
             None,
-            [state, model_selector],
+            [state],
             queue=False
         )
 
@@ -658,6 +621,8 @@ if api_key:
     login(token=api_key)
     logger.info("Logged in to Hugging Face Hub")
 
+# Load model at startup
+load_model(DEFAULT_MODEL)
 
 demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
 demo.queue(
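Calling `load_model(DEFAULT_MODEL)` once at startup replaces the old per-selection loading. On ZeroGPU specifically, the documented pattern is to instantiate the model at module level (ZeroGPU tolerates moving weights to CUDA there) and reserve `@spaces.GPU` for the functions that actually run GPU work; a hedged sketch with placeholder names:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # placeholder, not this Space's model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.to("cuda")  # accepted at startup on ZeroGPU; needs a real GPU elsewhere

@spaces.GPU
def generate(prompt: str) -> str:
    ids = tokenizer(prompt, return_tensors="pt").to("cuda")
    out = model.generate(**ids, max_new_tokens=32)
    return tokenizer.decode(out[0], skip_special_tokens=True)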
@@ -667,4 +632,4 @@ demo.queue(
     server_name=args.host,
     server_port=args.port,
     share=args.share
-)
+)