APG checkbox
Browse files
app.py
CHANGED
|
@@ -12,6 +12,8 @@ from f_lite import FLitePipeline
|
|
| 12 |
|
| 13 |
# Trick required because it is not a native diffusers model
|
| 14 |
from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
|
|
|
|
|
|
|
| 15 |
LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
|
| 16 |
ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
|
| 17 |
|
|
@@ -26,7 +28,7 @@ else:
|
|
| 26 |
logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
|
| 27 |
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
-
model_repo_id = "
|
| 30 |
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
torch_dtype = torch.bfloat16
|
|
@@ -135,6 +137,7 @@ def infer(
|
|
| 135 |
guidance_scale,
|
| 136 |
num_inference_steps,
|
| 137 |
use_prompt_enrichment,
|
|
|
|
| 138 |
progress=gr.Progress(track_tqdm=True),
|
| 139 |
):
|
| 140 |
enriched_prompt_str = None
|
|
@@ -160,6 +163,7 @@ def infer(
|
|
| 160 |
width=width,
|
| 161 |
height=height,
|
| 162 |
generator=generator,
|
|
|
|
| 163 |
).images[0]
|
| 164 |
|
| 165 |
# Prepare Gradio updates for the enriched prompt display
|
|
@@ -287,6 +291,10 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
| 287 |
step=0.1,
|
| 288 |
value=6,
|
| 289 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
num_inference_steps = gr.Slider(
|
| 292 |
label="Number of inference steps",
|
|
@@ -334,6 +342,7 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
| 334 |
guidance_scale,
|
| 335 |
num_inference_steps,
|
| 336 |
use_prompt_enrichment,
|
|
|
|
| 337 |
],
|
| 338 |
outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
|
| 339 |
)
|
|
@@ -342,4 +351,4 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
| 342 |
gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
|
| 343 |
|
| 344 |
if __name__ == "__main__":
|
| 345 |
-
demo.launch()
|
|
|
|
| 12 |
|
| 13 |
# Trick required because it is not a native diffusers model
|
| 14 |
from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
|
| 15 |
+
|
| 16 |
+
from f_lite.pipeline import APGConfig
|
| 17 |
LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
|
| 18 |
ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
|
| 19 |
|
|
|
|
| 28 |
logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
|
| 29 |
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
model_repo_id = "./grpo_hf"
|
| 32 |
|
| 33 |
if torch.cuda.is_available():
|
| 34 |
torch_dtype = torch.bfloat16
|
|
|
|
| 137 |
guidance_scale,
|
| 138 |
num_inference_steps,
|
| 139 |
use_prompt_enrichment,
|
| 140 |
+
enable_apg,
|
| 141 |
progress=gr.Progress(track_tqdm=True),
|
| 142 |
):
|
| 143 |
enriched_prompt_str = None
|
|
|
|
| 163 |
width=width,
|
| 164 |
height=height,
|
| 165 |
generator=generator,
|
| 166 |
+
apg_config=APGConfig(enabled=enable_apg)
|
| 167 |
).images[0]
|
| 168 |
|
| 169 |
# Prepare Gradio updates for the enriched prompt display
|
|
|
|
| 291 |
step=0.1,
|
| 292 |
value=6,
|
| 293 |
)
|
| 294 |
+
enable_apg = gr.Checkbox(
|
| 295 |
+
label="Enable APG",
|
| 296 |
+
value=True,
|
| 297 |
+
)
|
| 298 |
|
| 299 |
num_inference_steps = gr.Slider(
|
| 300 |
label="Number of inference steps",
|
|
|
|
| 342 |
guidance_scale,
|
| 343 |
num_inference_steps,
|
| 344 |
use_prompt_enrichment,
|
| 345 |
+
enable_apg,
|
| 346 |
],
|
| 347 |
outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
|
| 348 |
)
|
|
|
|
| 351 |
gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
|
| 352 |
|
| 353 |
if __name__ == "__main__":
|
| 354 |
+
demo.launch() # server_name="0.0.0.0", share=True)
|