Update app.py
app.py CHANGED
@@ -70,8 +70,6 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
         truncation=True,
         return_tensors="pt",
     )
-    for name, param in text_encoder.named_parameters():
-        print(name, param.device)
 
     print(f"Text Encoder Device: {text_encoder.device}")
     text_input_ids = text_inputs.input_ids.cuda()
@@ -90,7 +88,7 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
 
 
 @torch.no_grad()
-def model_main(args, master_port, rank, request_queue, response_queue, mp_barrier):
+def model_main(args, master_port, rank):
     # import here to avoid huggingface Tokenizer parallelism warnings
     from diffusers.models import AutoencoderKL
     from transformers import AutoModel, AutoTokenizer
@@ -106,10 +104,10 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     # Override the built-in print with the new version
     builtins.print = print
 
-    os.environ["MASTER_PORT"] = str(master_port)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(args.num_gpus)
+    # os.environ["MASTER_PORT"] = str(master_port)
+    # os.environ["MASTER_ADDR"] = "127.0.0.1"
+    # os.environ["RANK"] = str(rank)
+    # os.environ["WORLD_SIZE"] = str(args.num_gpus)
 
 
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
@@ -159,8 +157,12 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     ckpt = torch.load(ckpt_path, map_location="cuda")
     model.load_state_dict(ckpt, strict=True)
     print('load model finish')
-
+
+    return text_encoder, tokenizer, vae, model
+
 
+@torch.no_grad()
+def inference(args, infer_args, text_encoder, tokenizer, vae, model):
     with torch.autocast("cuda", dtype):
         while True:
             (
@@ -178,7 +180,7 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 scaling_method,
                 scaling_watershed,
                 proportional_attn,
-            ) = request_queue.get()
+            ) = infer_args
 
 
             system_prompt = system_type
@@ -243,13 +245,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
             torch.random.manual_seed(int(seed))
            z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
             z = z.repeat(2, 1, 1, 1)
-
+
             with torch.no_grad():
                 if neg_cap != "":
                     cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
                 else:
                     cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
-
+
             cap_mask = cap_mask.to(cap_feats.device)
 
             model_kwargs = dict(
@@ -297,12 +299,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 img = to_pil_image(samples[0, :].float())
                 print("> generated image, done.")
 
-                if response_queue is not None:
-                    response_queue.put((img, metadata))
-
+                # if response_queue is not None:
+                #     response_queue.put((img, metadata))
+                return img, metadata
             except Exception:
                 print(traceback.format_exc())
-                response_queue.put(ModelFailure())
+                return ModelFailure()
+                # response_queue.put(ModelFailure())
 
 
 def none_or_str(value):
@@ -389,25 +392,27 @@ def main():
 
     master_port = find_free_port()
     #mp.set_start_method("fork")
-    processes = []
-    request_queues = []
-    response_queue = mp.Queue()
-    mp_barrier = mp.Barrier(args.num_gpus + 1)
-    for i in range(args.num_gpus):
-        request_queues.append(mp.Queue())
-        p = mp.Process(
-            target=model_main,
-            args=(
-                args,
-                master_port,
-                i,
-                request_queues[i],
-                response_queue if i == 0 else None,
-                mp_barrier,
-            ),
-        )
-        p.start()
-        processes.append(p)
+    # processes = []
+    # request_queues = []
+    # response_queue = mp.Queue()
+    # mp_barrier = mp.Barrier(args.num_gpus + 1)
+    # for i in range(args.num_gpus):
+    #     request_queues.append(mp.Queue())
+    #     p = mp.Process(
+    #         target=model_main,
+    #         args=(
+    #             args,
+    #             master_port,
+    #             i,
+    #             request_queues[i],
+    #             response_queue if i == 0 else None,
+    #             mp_barrier,
+    #         ),
+    #     )
+    #     p.start()
+    #     processes.append(p)
+
+    model_main(args, master_port, 0)
 
     description = args.ckpt.split('/')[-1]
     #"""
@@ -552,15 +557,18 @@ def main():
         ) # noqa
 
         @spaces.GPU(duration=200)
-        def on_submit(*args):
-            for q in request_queues:
-                q.put(args)
-            result = response_queue.get()
+        def on_submit(*infer_args):
+            # for q in request_queues:
+            #     q.put(args)
+            # result = response_queue.get()
+            # if isinstance(result, ModelFailure):
+            #     raise RuntimeError
+            # img, metadata = result
+            result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
-                raise RuntimeError
-            img, metadata = result
+                raise RuntimeError("Model failed to generate the image.")
 
-            return img, metadata
+            return result
 
         submit_btn.click(
             on_submit,
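Note on the change, with a small illustration: the commit drops the per-GPU worker processes (the request/response mp.Queues and the mp.Barrier), comments out the torch.distributed environment variables, and splits the old worker loop in two. model_main now only builds and returns the text encoder, tokenizer, VAE and diffusion model, while a new inference function consumes the unpacked UI arguments and returns either (img, metadata) or a ModelFailure sentinel, which the @spaces.GPU-decorated on_submit handler turns into a RuntimeError. The sketch below mirrors only that wiring and is not the Space's actual code: the loaders and the sampler are replaced with stand-ins so it runs anywhere, and binding model_main's return value in main() is an assumption about how the model handles reach on_submit.

    """Minimal, runnable sketch of the single-process flow this commit moves to."""


    class ModelFailure:
        """Sentinel returned by inference() when generation fails."""


    def model_main(args, master_port, rank):
        # In the real app this loads the tokenizer, text encoder, VAE and
        # diffusion model from args.ckpt; stand-in objects keep the sketch
        # runnable without a GPU or checkpoint.
        text_encoder, tokenizer, vae, model = object(), object(), object(), object()
        print("load model finish")
        return text_encoder, tokenizer, vae, model


    def inference(args, infer_args, text_encoder, tokenizer, vae, model):
        # The real function unpacks infer_args (cap, neg_cap, ..., proportional_attn),
        # runs the sampler, and returns (img, metadata); any exception becomes a
        # ModelFailure() instead of being pushed through a response queue.
        try:
            cap = infer_args[0]
            img = f"<image for prompt {cap!r}>"
            metadata = {"prompt": cap}
            return img, metadata
        except Exception:
            return ModelFailure()


    def main():
        args = object()      # stand-in for the parsed CLI arguments
        master_port = 12345  # unused on the single-process path, kept for the signature

        # Load everything once in the main process (no mp.Queue / mp.Barrier workers).
        # Assumption: the returned handles are bound here so on_submit can use them.
        text_encoder, tokenizer, vae, model = model_main(args, master_port, 0)

        # Equivalent of the @spaces.GPU(duration=200)-decorated Gradio handler.
        def on_submit(*infer_args):
            result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
            if isinstance(result, ModelFailure):
                raise RuntimeError("Model failed to generate the image.")
            return result

        print(on_submit("a photo of a corgi"))


    if __name__ == "__main__":
        main()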