add pruna
server/pipelines/txt2img.py
CHANGED

@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
 from util import ParamsModel
 from PIL import Image
 from typing import List
+from pruna import SmashConfig, smash

 base_model = "SimianLuo/LCM_Dreamshaper_v7"
 taesd_model = "madebyollin/taesd"
@@ -86,6 +87,13 @@ class Pipeline:
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
             ).to(device)

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            # smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
                 compile,
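For context, a minimal standalone sketch of the pruna path this commit adds to txt2img.py. The SmashConfig keys and the smash() call mirror the diff; loading the model with diffusers' DiffusionPipeline.from_pretrained, the float16 dtype, the "cuda" device, and the sample prompt are illustrative assumptions, not part of the commit.

# Minimal sketch, not part of the commit: the pruna block above, outside the Pipeline class.
import torch
from diffusers import DiffusionPipeline
from pruna import SmashConfig, smash

base_model = "SimianLuo/LCM_Dreamshaper_v7"

# Assumed loading code; the repo builds its pipeline elsewhere.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")

# Same configuration as the diff: stable_fast compiler, deepcache cacher left optional.
smash_config = SmashConfig()
# smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "stable_fast"
pipe = smash(model=pipe, smash_config=smash_config)

# The smashed pipeline is used as a drop-in replacement (LCM-style call, assumed values).
image = pipe(prompt="portrait photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]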
server/pipelines/txt2imgLora.py
CHANGED

@@ -12,6 +12,7 @@ from config import Args
 from pydantic import BaseModel, Field
 from util import ParamsModel
 from PIL import Image
+from pruna import SmashConfig, smash

 base_model = "wavymulder/Analog-Diffusion"
 lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
@@ -93,6 +94,13 @@ class Pipeline:
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
             ).to(device)

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            # smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
         self.pipe.set_progress_bar_config(disable=True)
         self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
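A corresponding sketch for the LoRA variant: in txt2imgLora.py the new pruna block runs before the LCM scheduler swap and the LoRA load, as the diff shows. Model IDs, the scheduler call, and load_lora_weights come from the diff; the loading code, dtype, and device are again assumptions for illustration.

# Minimal sketch, not part of the commit: txt2imgLora.py ordering with pruna.
import torch
from diffusers import DiffusionPipeline, LCMScheduler
from pruna import SmashConfig, smash

base_model = "wavymulder/Analog-Diffusion"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"

# Assumed loading code; the repo builds its pipeline elsewhere.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")

smash_config = SmashConfig()
smash_config["compiler"] = "stable_fast"
pipe = smash(model=pipe, smash_config=smash_config)

# As in the diff, the scheduler and LCM LoRA adapter are set on the smashed pipeline.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)
pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")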