Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
452f87d
1
Parent(s):
559f623
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,7 +17,7 @@ if not torch.cuda.is_available():
|
|
| 17 |
MAX_SEED = np.iinfo(np.int32).max
|
| 18 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 19 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
|
| 20 |
-
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 21 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 22 |
|
| 23 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -32,11 +32,11 @@ if torch.cuda.is_available():
|
|
| 32 |
add_watermarker=False,
|
| 33 |
variant="fp16"
|
| 34 |
)
|
| 35 |
-
if ENABLE_CPU_OFFLOAD:
|
| 36 |
-
pipe.enable_model_cpu_offload()
|
| 37 |
-
else:
|
| 38 |
-
pipe.to(device)
|
| 39 |
-
print("Loaded on Device!")
|
| 40 |
|
| 41 |
if USE_TORCH_COMPILE:
|
| 42 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
|
@@ -67,6 +67,7 @@ def generate(
|
|
| 67 |
use_resolution_binning: bool = True,
|
| 68 |
progress=gr.Progress(track_tqdm=True),
|
| 69 |
):
|
|
|
|
| 70 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 71 |
generator = torch.Generator().manual_seed(seed)
|
| 72 |
|
|
|
|
| 17 |
MAX_SEED = np.iinfo(np.int32).max
|
| 18 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 19 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
|
| 20 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 21 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 22 |
|
| 23 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 32 |
add_watermarker=False,
|
| 33 |
variant="fp16"
|
| 34 |
)
|
| 35 |
+
#if ENABLE_CPU_OFFLOAD:
|
| 36 |
+
# pipe.enable_model_cpu_offload()
|
| 37 |
+
#else:
|
| 38 |
+
# pipe.to(device)
|
| 39 |
+
# print("Loaded on Device!")
|
| 40 |
|
| 41 |
if USE_TORCH_COMPILE:
|
| 42 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
|
|
|
| 67 |
use_resolution_binning: bool = True,
|
| 68 |
progress=gr.Progress(track_tqdm=True),
|
| 69 |
):
|
| 70 |
+
pipe.to(device)
|
| 71 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 72 |
generator = torch.Generator().manual_seed(seed)
|
| 73 |
|