Update app.py
Browse files
app.py
CHANGED
|
@@ -156,6 +156,7 @@ def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_le
|
|
| 156 |
return intermediate_tuples
|
| 157 |
|
| 158 |
@torch.no_grad()
|
|
|
|
| 159 |
def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
|
| 160 |
global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
|
| 161 |
|
|
@@ -196,6 +197,7 @@ def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="
|
|
| 196 |
|
| 197 |
|
| 198 |
@torch.no_grad()
|
|
|
|
| 199 |
def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
|
| 200 |
cfg_scale, remasking_strategy, thinking_mode_lm=False):
|
| 201 |
global MODEL, TOKENIZER, MASK_ID, DEVICE
|
|
@@ -349,6 +351,7 @@ def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temper
|
|
| 349 |
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
|
| 350 |
|
| 351 |
@torch.no_grad()
|
|
|
|
| 352 |
def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
|
| 353 |
cfg_scale, remasking_strategy, thinking_mode_mmu=False):
|
| 354 |
global MODEL, TOKENIZER, MASK_ID, DEVICE
|
|
|
|
| 156 |
return intermediate_tuples
|
| 157 |
|
| 158 |
@torch.no_grad()
|
| 159 |
+
@spaces.GPU
|
| 160 |
def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
|
| 161 |
global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
|
| 162 |
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
@torch.no_grad()
|
| 200 |
+
@spaces.GPU
|
| 201 |
def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
|
| 202 |
cfg_scale, remasking_strategy, thinking_mode_lm=False):
|
| 203 |
global MODEL, TOKENIZER, MASK_ID, DEVICE
|
|
|
|
| 351 |
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
|
| 352 |
|
| 353 |
@torch.no_grad()
|
| 354 |
+
@spaces.GPU
|
| 355 |
def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
|
| 356 |
cfg_scale, remasking_strategy, thinking_mode_mmu=False):
|
| 357 |
global MODEL, TOKENIZER, MASK_ID, DEVICE
|