Spaces:
Sleeping
Sleeping
try handle OOM errors
Browse files
app.py
CHANGED
|
@@ -108,14 +108,25 @@ def generate_image(setup_args, num_iterations):
|
|
| 108 |
|
| 109 |
# Function to run main in a separate thread
|
| 110 |
def run_main():
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Start main in a separate thread
|
| 114 |
main_thread = threading.Thread(target=run_main)
|
| 115 |
main_thread.start()
|
| 116 |
|
| 117 |
last_step_yielded = 0
|
| 118 |
-
while main_thread.is_alive()
|
| 119 |
# Check if new steps have been completed
|
| 120 |
if steps_completed and steps_completed[-1] > last_step_yielded:
|
| 121 |
last_step_yielded = steps_completed[-1]
|
|
@@ -130,21 +141,35 @@ def generate_image(setup_args, num_iterations):
|
|
| 130 |
# Small sleep to prevent busy waiting
|
| 131 |
time.sleep(0.1)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
# After main is complete, yield the final image
|
| 136 |
-
final_image_path = os.path.join(save_dir, "best_image.png")
|
| 137 |
-
if os.path.exists(final_image_path):
|
| 138 |
-
iter_images = list_iter_images(save_dir)
|
| 139 |
torch.cuda.empty_cache() # Free up cached memory
|
| 140 |
-
yield (
|
| 141 |
else:
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
-
except Exception as e:
|
| 146 |
torch.cuda.empty_cache() # Free up cached memory
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
def show_gallery_output(gallery_state):
|
| 150 |
if gallery_state is not None:
|
|
|
|
| 108 |
|
| 109 |
# Function to run main in a separate thread
|
| 110 |
def run_main():
|
| 111 |
+
try:
|
| 112 |
+
# Call main and handle any potential OOM errors
|
| 113 |
+
result_container["best_image"], result_container["total_init_rewards"], result_container["total_best_rewards"] = execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback)
|
| 114 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 115 |
+
# Handle CUDA OOM error
|
| 116 |
+
print("CUDA Out of Memory Error: ", e)
|
| 117 |
+
status["error_occurred"] = True # Update status on error
|
| 118 |
+
except RuntimeError as e:
|
| 119 |
+
if 'out of memory' in str(e):
|
| 120 |
+
status["error_occurred"] = True # Update status on error
|
| 121 |
+
else:
|
| 122 |
+
raise # Reraise if it's not a CUDA OOM error
|
| 123 |
|
| 124 |
# Start main in a separate thread
|
| 125 |
main_thread = threading.Thread(target=run_main)
|
| 126 |
main_thread.start()
|
| 127 |
|
| 128 |
last_step_yielded = 0
|
| 129 |
+
while main_thread.is_alive() and not status["error_occurred"]:
|
| 130 |
# Check if new steps have been completed
|
| 131 |
if steps_completed and steps_completed[-1] > last_step_yielded:
|
| 132 |
last_step_yielded = steps_completed[-1]
|
|
|
|
| 141 |
# Small sleep to prevent busy waiting
|
| 142 |
time.sleep(0.1)
|
| 143 |
|
| 144 |
+
# If an error occurred, clean up resources and stop
|
| 145 |
+
if status["error_occurred"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
torch.cuda.empty_cache() # Free up cached memory
|
| 147 |
+
yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
|
| 148 |
else:
|
| 149 |
+
main_thread.join()
|
| 150 |
+
|
| 151 |
+
# After main is complete, yield the final image
|
| 152 |
+
final_image_path = os.path.join(save_dir, "best_image.png")
|
| 153 |
+
if os.path.exists(final_image_path):
|
| 154 |
+
iter_images = list_iter_images(save_dir)
|
| 155 |
+
torch.cuda.empty_cache() # Free up cached memory
|
| 156 |
+
yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
|
| 157 |
+
else:
|
| 158 |
+
torch.cuda.empty_cache() # Free up cached memory
|
| 159 |
+
yield (None, "Image generation completed, but no final image was found.", None)
|
| 160 |
|
|
|
|
| 161 |
torch.cuda.empty_cache() # Free up cached memory
|
| 162 |
+
|
| 163 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 164 |
+
# Handle CUDA OOM error globally
|
| 165 |
+
yield (None, "CUDA out of memory.", None)
|
| 166 |
+
except RuntimeError as e:
|
| 167 |
+
if 'out of memory' in str(e):
|
| 168 |
+
yield (None, "CUDA out of memory.", None)
|
| 169 |
+
else:
|
| 170 |
+
yield (None, f"An error occurred: {str(e)}", None)
|
| 171 |
+
except Exception as e:
|
| 172 |
+
yield (None, f"An unexpected error occurred: {str(e)}", None)
|
| 173 |
|
| 174 |
def show_gallery_output(gallery_state):
|
| 175 |
if gallery_state is not None:
|