Spaces:
Running
on
Zero
Running
on
Zero
Update style_transfer.py
Browse files- style_transfer.py +6 -1
style_transfer.py
CHANGED
|
@@ -24,7 +24,7 @@ class DogStyleTransfer:
|
|
| 24 |
"""
|
| 25 |
def __init__(self):
|
| 26 |
self.models = {}
|
| 27 |
-
self.device =
|
| 28 |
|
| 29 |
# Check xformers availability
|
| 30 |
self.xformers_available = False
|
|
@@ -120,6 +120,11 @@ class DogStyleTransfer:
|
|
| 120 |
|
| 121 |
def load_model(self, style_name):
|
| 122 |
"""Load the appropriate model based on style, handling xformers compatibility"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
# Get model ID for the style
|
| 124 |
model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
|
| 125 |
|
|
|
|
| 24 |
"""
|
| 25 |
def __init__(self):
|
| 26 |
self.models = {}
|
| 27 |
+
self.device = cpu
|
| 28 |
|
| 29 |
# Check xformers availability
|
| 30 |
self.xformers_available = False
|
|
|
|
| 120 |
|
| 121 |
def load_model(self, style_name):
|
| 122 |
"""Load the appropriate model based on style, handling xformers compatibility"""
|
| 123 |
+
|
| 124 |
+
if not hasattr(self, '_cuda_initialized'):
|
| 125 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 126 |
+
self._cuda_initialized = True
|
| 127 |
+
|
| 128 |
# Get model ID for the style
|
| 129 |
model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
|
| 130 |
|