Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import spaces
|
|
| 4 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
from safetensors.torch import load_file
|
|
|
|
| 7 |
|
| 8 |
assert torch.cuda.is_available()
|
| 9 |
|
|
@@ -36,7 +37,9 @@ def generate(prompt, option, progress=gr.Progress()):
|
|
| 36 |
print(prompt, option)
|
| 37 |
ckpt, step = opts[option]
|
| 38 |
if any(word in prompt for word in filter_words):
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
progress((0, step))
|
| 41 |
if step != step_loaded:
|
| 42 |
print(f"Switching checkpoint from {step_loaded} to {step}")
|
|
@@ -46,7 +49,17 @@ def generate(prompt, option, progress=gr.Progress()):
|
|
| 46 |
def inference_callback(p, i, t, kwargs):
|
| 47 |
progress((i+1, step))
|
| 48 |
return kwargs
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
with gr.Blocks(css="style.css") as demo:
|
| 52 |
gr.HTML(
|
|
|
|
| 4 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
from safetensors.torch import load_file
|
| 7 |
+
from PIL import Image
|
| 8 |
|
| 9 |
assert torch.cuda.is_available()
|
| 10 |
|
|
|
|
| 37 |
print(prompt, option)
|
| 38 |
ckpt, step = opts[option]
|
| 39 |
if any(word in prompt for word in filter_words):
|
| 40 |
+
gr.Warning("Safety checker triggered.")
|
| 41 |
+
print(f"Safety checker triggered on prompt: {prompt}")
|
| 42 |
+
return Image.new("RGB", (512, 512))
|
| 43 |
progress((0, step))
|
| 44 |
if step != step_loaded:
|
| 45 |
print(f"Switching checkpoint from {step_loaded} to {step}")
|
|
|
|
| 49 |
def inference_callback(p, i, t, kwargs):
|
| 50 |
progress((i+1, step))
|
| 51 |
return kwargs
|
| 52 |
+
results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback)
|
| 53 |
+
nsfw_content_detected = (
|
| 54 |
+
results.nsfw_content_detected[0]
|
| 55 |
+
if "nsfw_content_detected" in results
|
| 56 |
+
else False
|
| 57 |
+
)
|
| 58 |
+
if nsfw_content_detected:
|
| 59 |
+
gr.Warning("Safety checker triggered.")
|
| 60 |
+
print(f"Safety checker triggered on prompt: {prompt}")
|
| 61 |
+
return Image.new("RGB", (512, 512))
|
| 62 |
+
return results.images[0]
|
| 63 |
|
| 64 |
with gr.Blocks(css="style.css") as demo:
|
| 65 |
gr.HTML(
|