Spaces:
Runtime error
Runtime error
fix
Browse files
server/pipelines/controlnetSDTurbo.py
CHANGED
|
@@ -160,20 +160,19 @@ class Pipeline:
|
|
| 160 |
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
| 161 |
controlnet_canny = ControlNetModel.from_pretrained(
|
| 162 |
controlnet_model, torch_dtype=torch_dtype
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
self.pipes = {}
|
| 166 |
|
| 167 |
if args.safety_checker:
|
| 168 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 169 |
-
base_model,
|
| 170 |
-
controlnet=controlnet_canny,
|
| 171 |
)
|
| 172 |
else:
|
| 173 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 174 |
base_model,
|
| 175 |
controlnet=controlnet_canny,
|
| 176 |
safety_checker=None,
|
|
|
|
| 177 |
)
|
| 178 |
|
| 179 |
if args.taesd:
|
|
@@ -207,7 +206,7 @@ class Pipeline:
|
|
| 207 |
|
| 208 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
| 209 |
self.pipe.set_progress_bar_config(disable=True)
|
| 210 |
-
self.pipe.to(device=device, dtype=torch_dtype)
|
| 211 |
if device.type != "mps":
|
| 212 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
| 213 |
|
|
|
|
| 160 |
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
| 161 |
controlnet_canny = ControlNetModel.from_pretrained(
|
| 162 |
controlnet_model, torch_dtype=torch_dtype
|
| 163 |
+
)
|
|
|
|
| 164 |
self.pipes = {}
|
| 165 |
|
| 166 |
if args.safety_checker:
|
| 167 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 168 |
+
base_model, controlnet=controlnet_canny, torch_dtype=torch_dtype
|
|
|
|
| 169 |
)
|
| 170 |
else:
|
| 171 |
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 172 |
base_model,
|
| 173 |
controlnet=controlnet_canny,
|
| 174 |
safety_checker=None,
|
| 175 |
+
torch_dtype=torch_dtype,
|
| 176 |
)
|
| 177 |
|
| 178 |
if args.taesd:
|
|
|
|
| 206 |
|
| 207 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
| 208 |
self.pipe.set_progress_bar_config(disable=True)
|
| 209 |
+
self.pipe.to(device=device, dtype=torch_dtype)
|
| 210 |
if device.type != "mps":
|
| 211 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
| 212 |
|