construction all
Browse files

app.py CHANGED
@@ -228,18 +228,36 @@ def calculate_iou(box1, box2):
     iou = intersection_area / union_area
     return iou
 
-
-def buildmodel(**kwargs):
+def construction_all():
     global model
     global quantizer
     global tokenizer
+    global pipeline
+    global transp_vae
     from modeling_crello import CrelloModel, CrelloModelConfig
     from quantizer import get_quantizer
+    from custom_model_mmdit import CustomFluxTransformer2DModel
+    from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
+    from custom_pipeline import CustomFluxPipelineCfg
+
+    params_dict = {
+        "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
+        "resume": "/openseg_blob/v-sirui/temporary/2024-02-21/SVD/Int2lay_1016/checkpoint/int2lay_1031/1031_test/checkpoint-26000/",
+        "seed": 0,
+        "mask_values": False,
+        "quantizer_version": 'v4',
+        "mask_type": 'cm3',
+        "decimal_quantize_types": [],
+        "num_mask_tokens": 0,
+        "width": 512,
+        "height": 512,
+        "device": 0,
+    }
+
     # seed / input model / resume
-    seed = kwargs.get('seed', None)
-    input_model = kwargs.get('input_model', None)
-    resume = kwargs.get('resume', None)
-    quantizer_version = kwargs.get('quantizer_version', 'v4')
+    seed = params_dict.get('seed', None)
+    input_model = params_dict.get('input_model', None)
+    quantizer_version = params_dict.get('quantizer_version', 'v4')
 
     set_seed(seed)
     # old_tokenizer = AutoTokenizer.from_pretrained(input_model, trust_remote_code=True)
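The hunk above folds the former `buildmodel(**kwargs)` entry point into a single `construction_all()` that hard-codes its own `params_dict` and publishes everything through module globals, so the Gradio callbacks can reach the heavy objects without passing them around. A minimal, runnable sketch of that pattern (the loader stand-ins are illustrative, not the repo's actual CrelloModel/Flux code):

```python
# One-shot global-initializer pattern; stand-in loaders, not the real ones.
model = None
pipeline = None

def construction_all():
    """Build every heavy object once and publish it via module globals."""
    global model, pipeline
    model = {"name": "crello-lm"}  # stand-in for CrelloModel.from_pretrained(...)
    pipeline = {"name": "flux"}    # stand-in for CustomFluxPipelineCfg.from_pretrained(...)

def handler(prompt: str) -> str:
    # Callbacks read the prebuilt globals; nothing is reloaded per request.
    return f'{pipeline["name"]} handles: {prompt}'

if __name__ == "__main__":
    construction_all()
    print(handler("a summer-sale poster"))
```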
@@ -261,13 +279,13 @@ def buildmodel(**kwargs):
     quantizer = get_quantizer(
         quantizer_version,
         update_vocab = False,
-        decimal_quantize_types = kwargs.get('decimal_quantize_types'),
-        mask_values = kwargs['mask_values'],
-        width = kwargs['width'],
-        height = kwargs['height'],
+        decimal_quantize_types = params_dict.get('decimal_quantize_types'),
+        mask_values = params_dict['mask_values'],
+        width = params_dict['width'],
+        height = params_dict['height'],
         simplify_json = False,
         num_mask_tokens = 0,
-        mask_type = kwargs.get('mask_type'),
+        mask_type = params_dict.get('mask_type'),
     )
     quantizer.setup_tokenizer(tokenizer)
 
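The rewritten `get_quantizer` call mixes `params_dict.get(...)` and `params_dict[...]`; the distinction matters because `.get` silently yields `None` for a missing key while bracket access raises. A quick demonstration:

```python
params_dict = {"mask_values": False, "width": 512, "height": 512}

print(params_dict.get("decimal_quantize_types"))  # None; missing keys fall back silently
print(params_dict["width"])                       # 512; key is present

try:
    params_dict["mask_type"]                      # bracket access raises on missing keys
except KeyError as err:
    print("KeyError:", err)
```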
@@ -280,11 +298,7 @@ def buildmodel(**kwargs):
     model_args.freeze_lm = False
     model_args.opt_version = input_model
     model_args.use_lora = False
-    model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
-    # model = CrelloModel.from_pretrained(
-    #     resume,
-    #     config=model_args
-    # ).to(device)
+    model_args.load_in_4bit = params_dict.get('load_in_4bit', False)
 
     model = CrelloModel.from_pretrained(
         "WYBar/LLM_For_Layout_Planning",
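`model_args.load_in_4bit` now defaults to `False` via `params_dict.get('load_in_4bit', False)`. How CrelloModel consumes the flag is not shown in this diff; in the broader transformers ecosystem a boolean like this is typically mapped onto a bitsandbytes config, roughly as in this hedged sketch (assumes `bitsandbytes` is installed; `load_lm` is a hypothetical helper, not repo code):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_lm(name: str, load_in_4bit: bool = False):
    # Map a boolean flag onto an explicit 4-bit quantization config.
    quant_cfg = (
        BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        if load_in_4bit
        else None
    )
    return AutoModelForCausalLM.from_pretrained(
        name,
        quantization_config=quant_cfg,
        torch_dtype=torch.bfloat16,
    )
```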
@@ -300,63 +314,46 @@ def buildmodel(**kwargs):
     for token in added_special_tokens_list:
         quantizer.additional_special_tokens.add(token)
 
-
+    transformer = CustomFluxTransformer2DModel.from_pretrained(
+        "WYBar/ART_test_weights",
+        subfolder="fused_transformer",
+        torch_dtype=torch.bfloat16,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    )
+
+    transp_vae = CustomVAE.from_pretrained(
+        "WYBar/ART_test_weights",
+        subfolder="custom_vae",
+        torch_dtype=torch.float32,
+        use_safetensors=True,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    )
+
+    token = os.environ.get("HF_TOKEN")
+    pipeline = CustomFluxPipelineCfg.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        transformer=transformer,
+        torch_dtype=torch.bfloat16,
+        token=token,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
+    ).to("cuda")
+    pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
+
+    print(f"before .to(device):{model.device} {model.lm.device} {pipeline.device}")
     model = model.to("cuda")
-
+    pipeline = pipeline.to("cuda")
+    print(f"after .to(device):{model.device} {model.lm.device} {pipeline.device}")
     model = model.bfloat16()
     model.eval()
-
-    # tokenizer = tokenizer.to("cuda")
-    # model.lm = model.lm.to("cuda")
-    print(model.lm.device)
-
-    # return model, quantizer, tokenizer
-
-def construction_layout():
-    params_dict = {
-        # needs to be modified
-        "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
-        "resume": "/openseg_blob/v-sirui/temporary/2024-02-21/SVD/Int2lay_1016/checkpoint/int2lay_1031/1031_test/checkpoint-26000/",
-
-        "seed": 0,
-        "mask_values": False,
-        "quantizer_version": 'v4',
-        "mask_type": 'cm3',
-        "decimal_quantize_types": [],
-        "num_mask_tokens": 0,
-        "width": 512,
-        "height": 512,
-        "device": 0,
-    }
-    device = "cuda"
-    # Init model
-    buildmodel(**params_dict)
-    # model, quantizer, tokenizer = buildmodel(**params_dict)
-
-    # print('resize token embeddings to match the tokenizer', 129423)
-    # model.lm.resize_token_embeddings(129423)
-    # model.input_embeddings = model.lm.get_input_embeddings()
-    # print('after token embeddings to match the tokenizer', 129423)
-
-    # print("before .to(device)")
-    # model = model.to("cuda")
-    # print("after .to(device)")
-    # model = model.bfloat16()
-    # model.eval()
-    # # quantizer = quantizer.to("cuda")
-    # # tokenizer = tokenizer.to("cuda")
-    # # model.lm = model.lm.to("cuda")
-    # print(model.lm.device)
-
-    return params_dict["width"], params_dict["height"], device
-    # return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
+    print(f"after bf16 & eval .to(device):{model.device} {model.lm.device} {pipeline.device}")
 
 @torch.no_grad()
 @spaces.GPU(duration=120)
-def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
-    print(model.lm.device)
+def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
+    print(f"evaluate_v1 {model.device} {model.lm.device} {pipeline.device}")
     json_example = inputs
     input_intension = '{"wholecaption":"' + json_example["wholecaption"] + '","layout":[{"layer":'
+
     print("tokenizer1")
     inputs = tokenizer(
         input_intension, return_tensors="pt"
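`evaluate_v1` loses its `device` parameter (placement now happens once in `construction_all`) and threads `do_sample`, `temperature`, `top_p`, and `top_k` through to generation. On a standard transformers causal LM those knobs map directly onto `model.generate`; a minimal sketch with a generic GPT-2 rather than the Crello LM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = '{"wholecaption":"a summer-sale poster","layout":[{"layer":'
inputs = tok(prompt, return_tensors="pt")

with torch.no_grad():
    out = lm.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=True,   # False selects greedy decoding; the knobs below then do nothing
        temperature=1.0,
        top_p=1.0,
        top_k=50,
    )
print(tok.decode(out[0], skip_special_tokens=True))
```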
@@ -395,7 +392,7 @@ def evaluate_v1(inputs, model, quantizer, tokenizer, width, height, device, do_sample=False, temperature=1.0, top_p=1.0, top_k=50):
         pred_json_example = None
     return pred_json_example
 
-def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
+def inference(generate_method, intention, model, quantizer, tokenizer, width, height, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
     rawdata = {}
     rawdata["wholecaption"] = intention
     rawdata["layout"] = []
@@ -404,7 +401,7 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
         max_try_time = 5
         preddata = None
         while preddata is None and max_try_time > 0:
-            preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, device, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
+            preddata = evaluate_v1(rawdata, model, quantizer, tokenizer, width, height, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k)
             max_try_time -= 1
     else:
         print("Please input correct generate method")
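`evaluate_v1` returns `None` whenever the generated layout fails to parse, so `inference` keeps the retry loop shown above: up to five attempts, stopping at the first parseable result. The same pattern in isolation:

```python
import json
import random

def evaluate_once():
    # Stand-in for evaluate_v1: None signals an unparseable generation.
    text = random.choice(['{"layout": []}', "not json"])
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None

preddata = None
max_try_time = 5
while preddata is None and max_try_time > 0:
    preddata = evaluate_once()
    max_try_time -= 1

print(preddata)  # may still be None if all five attempts failed
```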
@@ -412,41 +409,6 @@ def inference(generate_method, intention, model, quantizer, tokenizer, width, height, device, do_sample=True, temperature=1.0, top_p=1.0, top_k=50):
 
     return preddata
 
-# @spaces.GPU(enable_queue=True, duration=120)
-def construction():
-    global pipeline
-    global transp_vae
-    from custom_model_mmdit import CustomFluxTransformer2DModel
-    from custom_model_transp_vae import AutoencoderKLTransformerTraining as CustomVAE
-    from custom_pipeline import CustomFluxPipelineCfg
-
-    transformer = CustomFluxTransformer2DModel.from_pretrained(
-        "WYBar/ART_test_weights",
-        subfolder="fused_transformer",
-        torch_dtype=torch.bfloat16,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    )
-
-    transp_vae = CustomVAE.from_pretrained(
-        "WYBar/ART_test_weights",
-        subfolder="custom_vae",
-        torch_dtype=torch.float32,
-        use_safetensors=True,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    )
-
-    token = os.environ.get("HF_TOKEN")
-    pipeline = CustomFluxPipelineCfg.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
-        transformer=transformer,
-        torch_dtype=torch.bfloat16,
-        token=token,
-        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir"
-    ).to("cuda")
-    pipeline.enable_model_cpu_offload(gpu_id=0) # Save GPU memory
-
-    # return pipeline, transp_vae
-
 @spaces.GPU(duration=120)
 def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
     print(validation_box)
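The deleted `construction()` body survives verbatim inside `construction_all()`, including `enable_model_cpu_offload(gpu_id=0)`. In stock diffusers that offload hook manages device placement itself, which is why it normally replaces a blanket `.to("cuda")` rather than following one. A hedged sketch with the plain `FluxPipeline` (requires the gated FLUX.1-dev weights and an HF token; the repo's `CustomFluxPipelineCfg` is assumed to behave analogously):

```python
import os
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=os.environ.get("HF_TOKEN"),
)
# Submodules are moved to GPU only while they run, then returned to CPU,
# trading some speed for a much lower peak VRAM footprint.
pipe.enable_model_cpu_offload(gpu_id=0)
```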
@@ -477,7 +439,7 @@ def test_one_sample(validation_box, validation_prompt, true_gs, inference_steps, pipeline, generator, transp_vae):
     return output_gradio
 
 def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
-    print("svg_test_one_sample")
+    print(f"svg_test_one_sample {model.device} {model.lm.device} {pipeline.device}")
     generator = torch.Generator().manual_seed(seed)
     try:
         validation_box = ast.literal_eval(validation_box_str)
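`svg_test_one_sample` parses the user-supplied box string with `ast.literal_eval`, which evaluates Python literals only and rejects arbitrary expressions, unlike `eval`:

```python
import ast

validation_box_str = "[(0, 0, 512, 512), (48, 168, 168, 240)]"
validation_box = ast.literal_eval(validation_box_str)  # a real list of tuples
print(validation_box[1])  # (48, 168, 168, 240)

try:
    ast.literal_eval("__import__('os').system('ls')")  # not a literal
except ValueError as err:
    print("rejected:", err)
```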
@@ -511,7 +473,7 @@ def svg_test_one_sample(validation_prompt, validation_box_str, seed, true_gs, inference_steps, pipeline, transp_vae):
     return result_images, svg_file_path
 
 def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
-    print("precess_svg")
+    print(f"precess_svg {model.device} {model.lm.device} {pipeline.device}")
     result_images = []
     result_images, svg_file_path = svg_test_one_sample(text_input, tuple_input, seed, true_gs, inference_steps, pipeline=pipeline, transp_vae=transp_vae)
     # result_images, svg_file_path = gradio_test_one_sample_partial(text_input, tuple_input, seed, true_gs, inference_steps)
@@ -534,64 +496,52 @@ def process_svg(text_input, tuple_input, seed, true_gs, inference_steps):
     """
 
     return result_images, svg_file_path, svg_editor
-
-def process_preddate(intention, temperature, top_p, generate_method='v1'):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if temperature == 0.0:
-        # print("looking for greedy decoding strategies, set `do_sample=False`.")
-        # preddata = inference_partial(generate_method, intention, do_sample=False)
-        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=width, height=height, device=device, do_sample=False)
+
+def process_preddate(intention, temperature, top_p, generate_method='v1'):
+    intention = intention.replace('\n', '').replace('\r', '').replace('\\', '')
+    intention = ensure_space_after_period(intention)
+    print(f"process_preddate: {model.lm.device}")
+    if temperature == 0.0:
+        # print("looking for greedy decoding strategies, set `do_sample=False`.")
+        # preddata = inference_partial(generate_method, intention, do_sample=False)
+        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, do_sample=False)
     else:
-        # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
-        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=width, height=height, device=device, temperature=temperature, top_p=top_p)
-
-    layouts = preddata["layout"]
-    list_box = []
-    for i, layout in enumerate(layouts):
-        x, y = layout["x"], layout["y"]
-        width, height = layout["width"], layout["height"]
-        if i == 0:
-            list_box.append((0, 0, width, height))
-            list_box.append((0, 0, width, height))
-        else:
-            left = x - width // 2
-            top = y - height // 2
-            right = x + width // 2
-            bottom = y + height // 2
-            list_box.append((left, top, right, bottom))
-
-    # print(list_box)
-    filtered_boxes = list_box[:2]
-    for i in range(2, len(list_box)):
-        keep = True
-        for j in range(1, len(filtered_boxes)):
-            iou = calculate_iou(list_box[i], filtered_boxes[j])
-            if iou > 0.65:
-                print(list_box[i], filtered_boxes[j])
-                keep = False
-                break
-        if keep:
-            filtered_boxes.append(list_box[i])
+        # preddata = inference_partial(generate_method, intention, temperature=temperature, top_p=top_p)
+        preddata = inference(generate_method, intention, model=model, quantizer=quantizer, tokenizer=tokenizer, width=512, height=512, temperature=temperature, top_p=top_p)
+
+    layouts = preddata["layout"]
+    list_box = []
+    for i, layout in enumerate(layouts):
+        x, y = layout["x"], layout["y"]
+        width, height = layout["width"], layout["height"]
+        if i == 0:
+            list_box.append((0, 0, width, height))
+            list_box.append((0, 0, width, height))
+        else:
+            left = x - width // 2
+            top = y - height // 2
+            right = x + width // 2
+            bottom = y + height // 2
+            list_box.append((left, top, right, bottom))
+
+    # print(list_box)
+    filtered_boxes = list_box[:2]
+    for i in range(2, len(list_box)):
+        keep = True
+        for j in range(1, len(filtered_boxes)):
+            iou = calculate_iou(list_box[i], filtered_boxes[j])
+            if iou > 0.65:
+                print(list_box[i], filtered_boxes[j])
+                keep = False
+                break
+        if keep:
+            filtered_boxes.append(list_box[i])
 
-
+    return str(filtered_boxes), intention, str(filtered_boxes)
+
+def main():
+    construction_all()
+    print(f"after construction_all:{model.device} {model.lm.device} {pipeline.device}")
 
 # def process_preddate(intention, generate_method='v1'):
 #     list_box = [(0, 0, 512, 512), (0, 0, 512, 512), (136, 184, 512, 512), (144, 0, 512, 512), (0, 0, 328, 136), (160, 112, 512, 360), (168, 112, 512, 360), (40, 232, 112, 296), (32, 88, 248, 176), (48, 424, 144, 448), (48, 464, 144, 488), (240, 464, 352, 488), (384, 464, 488, 488), (48, 480, 144, 504), (240, 480, 360, 504), (456, 0, 512, 56), (0, 0, 56, 40), (440, 0, 512, 40), (0, 24, 48, 88), (48, 168, 168, 240)]
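`process_preddate` converts the planner's center-based layout entries (`x`, `y`, `width`, `height`) to corner boxes, keeps the first two canvas boxes unconditionally, and drops any later box whose IoU with an already-kept box exceeds 0.65. A self-contained version of that post-processing (the local `calculate_iou` here is assumed equivalent to the one defined at the top of app.py):

```python
def calculate_iou(box1, box2):
    # Standard IoU over (left, top, right, bottom) boxes; assumed equivalent
    # to the calculate_iou defined earlier in app.py.
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union else 0.0

layouts = [
    {"x": 0, "y": 0, "width": 512, "height": 512},    # i == 0: canvas, kept twice
    {"x": 256, "y": 256, "width": 100, "height": 80},
    {"x": 258, "y": 256, "width": 100, "height": 80},  # near-duplicate of the previous box
]

list_box = []
for i, layout in enumerate(layouts):
    x, y = layout["x"], layout["y"]
    w, h = layout["width"], layout["height"]
    if i == 0:
        list_box.append((0, 0, w, h))
        list_box.append((0, 0, w, h))
    else:
        # Center-based (x, y, w, h) -> corner-based (left, top, right, bottom).
        list_box.append((x - w // 2, y - h // 2, x + w // 2, y + h // 2))

filtered_boxes = list_box[:2]
for i in range(2, len(list_box)):
    if all(calculate_iou(list_box[i], filtered_boxes[j]) <= 0.65
           for j in range(1, len(filtered_boxes))):
        filtered_boxes.append(list_box[i])

print(filtered_boxes)  # the near-duplicate third layout is filtered out
```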
@@ -600,7 +550,6 @@ def main():
     # return wholecaption, str(list_box), json_file
 
     # pipeline, transp_vae = construction()
-    construction()
 
     # gradio_test_one_sample_partial = partial(
     #     svg_test_one_sample,