Nathyboy committed
Commit 15b68a9 · Parent: 6ae6fd8

Add persistent folders, scripts, and small models

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. launch.py +48 -0
  3. modules/.DS_Store +0 -0
  4. modules/api/api.py +928 -0
  5. modules/api/models.py +329 -0
  6. modules/cache.py +123 -0
  7. modules/call_queue.py +134 -0
  8. modules/cmd_args.py +128 -0
  9. modules/codeformer_model.py +64 -0
  10. modules/config_states.py +198 -0
  11. modules/dat_model.py +79 -0
  12. modules/deepbooru.py +98 -0
  13. modules/deepbooru_model.py +678 -0
  14. modules/devices.py +295 -0
  15. modules/errors.py +150 -0
  16. modules/esrgan_model.py +62 -0
  17. modules/extensions.py +299 -0
  18. modules/extra_networks.py +225 -0
  19. modules/extra_networks_hypernet.py +28 -0
  20. modules/extras.py +330 -0
  21. modules/face_restoration.py +19 -0
  22. modules/face_restoration_utils.py +180 -0
  23. modules/fifo_lock.py +37 -0
  24. modules/gfpgan_model.py +69 -0
  25. modules/gitpython_hack.py +42 -0
  26. modules/gradio_extensons.py +83 -0
  27. modules/hashes.py +84 -0
  28. modules/hat_model.py +43 -0
  29. modules/hypernetworks/hypernetwork.py +783 -0
  30. modules/hypernetworks/ui.py +38 -0
  31. modules/images.py +877 -0
  32. modules/img2img.py +253 -0
  33. modules/import_hook.py +16 -0
  34. modules/infotext_utils.py +546 -0
  35. modules/infotext_versions.py +46 -0
  36. modules/initialize.py +169 -0
  37. modules/initialize_util.py +215 -0
  38. modules/interrogate.py +222 -0
  39. modules/launch_utils.py +482 -0
  40. modules/localization.py +37 -0
  41. modules/logging_config.py +58 -0
  42. modules/lowvram.py +165 -0
  43. modules/mac_specific.py +98 -0
  44. modules/masking.py +96 -0
  45. modules/memmon.py +92 -0
  46. modules/modelloader.py +197 -0
  47. modules/models/diffusion/ddpm_edit.py +1460 -0
  48. modules/models/diffusion/uni_pc/__init__.py +1 -0
  49. modules/models/diffusion/uni_pc/sampler.py +101 -0
  50. modules/models/diffusion/uni_pc/uni_pc.py +863 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
launch.py ADDED
@@ -0,0 +1,48 @@
+from modules import launch_utils
+
+args = launch_utils.args
+python = launch_utils.python
+git = launch_utils.git
+index_url = launch_utils.index_url
+dir_repos = launch_utils.dir_repos
+
+commit_hash = launch_utils.commit_hash
+git_tag = launch_utils.git_tag
+
+run = launch_utils.run
+is_installed = launch_utils.is_installed
+repo_dir = launch_utils.repo_dir
+
+run_pip = launch_utils.run_pip
+check_run_python = launch_utils.check_run_python
+git_clone = launch_utils.git_clone
+git_pull_recursive = launch_utils.git_pull_recursive
+list_extensions = launch_utils.list_extensions
+run_extension_installer = launch_utils.run_extension_installer
+prepare_environment = launch_utils.prepare_environment
+configure_for_tests = launch_utils.configure_for_tests
+start = launch_utils.start
+
+
+def main():
+    if args.dump_sysinfo:
+        filename = launch_utils.dump_sysinfo()
+
+        print(f"Sysinfo saved as {filename}. Exiting...")
+
+        exit(0)
+
+    launch_utils.startup_timer.record("initial startup")
+
+    with launch_utils.startup_timer.subcategory("prepare environment"):
+        if not args.skip_prepare_environment:
+            prepare_environment()
+
+    if args.test_server:
+        configure_for_tests()
+
+    start()
+
+
+if __name__ == "__main__":
+    main()
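launch.py is deliberately a thin wrapper: it re-exports the helpers from modules.launch_utils (run, is_installed, run_pip, git_clone, run_extension_installer, ...) at module level so other code can reach them with a plain `import launch`. A minimal sketch of how an extension's install script might lean on those re-exports — the package name is purely illustrative, and run_pip is assumed to keep its usual (args, desc) signature:

```python
# install.py of a hypothetical extension, picked up at startup by the
# run_extension_installer helper re-exported above.
import launch

# "example_pkg" is an illustrative dependency, not something this commit requires.
if not launch.is_installed("example_pkg"):
    launch.run_pip("install example_pkg", desc="example_pkg for the example extension")
```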
modules/.DS_Store ADDED
Binary file (6.15 kB).
 
modules/api/api.py ADDED
@@ -0,0 +1,928 @@
1
+ import base64
2
+ import io
3
+ import os
4
+ import time
5
+ import datetime
6
+ import uvicorn
7
+ import ipaddress
8
+ import requests
9
+ import gradio as gr
10
+ from threading import Lock
11
+ from io import BytesIO
12
+ from fastapi import APIRouter, Depends, FastAPI, Request, Response
13
+ from fastapi.security import HTTPBasic, HTTPBasicCredentials
14
+ from fastapi.exceptions import HTTPException
15
+ from fastapi.responses import JSONResponse
16
+ from fastapi.encoders import jsonable_encoder
17
+ from secrets import compare_digest
18
+
19
+ import modules.shared as shared
20
+ from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
21
+ from modules.api import models
22
+ from modules.shared import opts
23
+ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
24
+ from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
25
+ from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
26
+ from PIL import PngImagePlugin
27
+ from modules.sd_models_config import find_checkpoint_config_near_filename
28
+ from modules.realesrgan_model import get_realesrgan_models
29
+ from modules import devices
30
+ from typing import Any
31
+ import piexif
32
+ import piexif.helper
33
+ from contextlib import closing
34
+ from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
35
+
36
+ def script_name_to_index(name, scripts):
37
+ try:
38
+ return [script.title().lower() for script in scripts].index(name.lower())
39
+ except Exception as e:
40
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
41
+
42
+
43
+ def validate_sampler_name(name):
44
+ config = sd_samplers.all_samplers_map.get(name, None)
45
+ if config is None:
46
+ raise HTTPException(status_code=400, detail="Sampler not found")
47
+
48
+ return name
49
+
50
+
51
+ def setUpscalers(req: dict):
52
+ reqDict = vars(req)
53
+ reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
54
+ reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
55
+ return reqDict
56
+
57
+
58
+ def verify_url(url):
59
+ """Returns True if the url refers to a global resource."""
60
+
61
+ import socket
62
+ from urllib.parse import urlparse
63
+ try:
64
+ parsed_url = urlparse(url)
65
+ domain_name = parsed_url.netloc
66
+ host = socket.gethostbyname_ex(domain_name)
67
+ for ip in host[2]:
68
+ ip_addr = ipaddress.ip_address(ip)
69
+ if not ip_addr.is_global:
70
+ return False
71
+ except Exception:
72
+ return False
73
+
74
+ return True
75
+
76
+
77
+ def decode_base64_to_image(encoding):
78
+ if encoding.startswith("http://") or encoding.startswith("https://"):
79
+ if not opts.api_enable_requests:
80
+ raise HTTPException(status_code=500, detail="Requests not allowed")
81
+
82
+ if opts.api_forbid_local_requests and not verify_url(encoding):
83
+ raise HTTPException(status_code=500, detail="Request to local resource not allowed")
84
+
85
+ headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
86
+ response = requests.get(encoding, timeout=30, headers=headers)
87
+ try:
88
+ image = images.read(BytesIO(response.content))
89
+ return image
90
+ except Exception as e:
91
+ raise HTTPException(status_code=500, detail="Invalid image url") from e
92
+
93
+ if encoding.startswith("data:image/"):
94
+ encoding = encoding.split(";")[1].split(",")[1]
95
+ try:
96
+ image = images.read(BytesIO(base64.b64decode(encoding)))
97
+ return image
98
+ except Exception as e:
99
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
100
+
101
+
102
+ def encode_pil_to_base64(image):
103
+ with io.BytesIO() as output_bytes:
104
+ if isinstance(image, str):
105
+ return image
106
+ if opts.samples_format.lower() == 'png':
107
+ use_metadata = False
108
+ metadata = PngImagePlugin.PngInfo()
109
+ for key, value in image.info.items():
110
+ if isinstance(key, str) and isinstance(value, str):
111
+ metadata.add_text(key, value)
112
+ use_metadata = True
113
+ image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
114
+
115
+ elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
116
+ if image.mode in ("RGBA", "P"):
117
+ image = image.convert("RGB")
118
+ parameters = image.info.get('parameters', None)
119
+ exif_bytes = piexif.dump({
120
+ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
121
+ })
122
+ if opts.samples_format.lower() in ("jpg", "jpeg"):
123
+ image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
124
+ else:
125
+ image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
126
+
127
+ else:
128
+ raise HTTPException(status_code=500, detail="Invalid image format")
129
+
130
+ bytes_data = output_bytes.getvalue()
131
+
132
+ return base64.b64encode(bytes_data)
133
+
134
+
135
+ def api_middleware(app: FastAPI):
136
+ rich_available = False
137
+ try:
138
+ if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
139
+ import anyio # importing just so it can be placed on silent list
140
+ import starlette # importing just so it can be placed on silent list
141
+ from rich.console import Console
142
+ console = Console()
143
+ rich_available = True
144
+ except Exception:
145
+ pass
146
+
147
+ @app.middleware("http")
148
+ async def log_and_time(req: Request, call_next):
149
+ ts = time.time()
150
+ res: Response = await call_next(req)
151
+ duration = str(round(time.time() - ts, 4))
152
+ res.headers["X-Process-Time"] = duration
153
+ endpoint = req.scope.get('path', 'err')
154
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
155
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
156
+ t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
157
+ code=res.status_code,
158
+ ver=req.scope.get('http_version', '0.0'),
159
+ cli=req.scope.get('client', ('0:0.0.0', 0))[0],
160
+ prot=req.scope.get('scheme', 'err'),
161
+ method=req.scope.get('method', 'err'),
162
+ endpoint=endpoint,
163
+ duration=duration,
164
+ ))
165
+ return res
166
+
167
+ def handle_exception(request: Request, e: Exception):
168
+ err = {
169
+ "error": type(e).__name__,
170
+ "detail": vars(e).get('detail', ''),
171
+ "body": vars(e).get('body', ''),
172
+ "errors": str(e),
173
+ }
174
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
175
+ message = f"API error: {request.method}: {request.url} {err}"
176
+ if rich_available:
177
+ print(message)
178
+ console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
179
+ else:
180
+ errors.report(message, exc_info=True)
181
+ return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
182
+
183
+ @app.middleware("http")
184
+ async def exception_handling(request: Request, call_next):
185
+ try:
186
+ return await call_next(request)
187
+ except Exception as e:
188
+ return handle_exception(request, e)
189
+
190
+ @app.exception_handler(Exception)
191
+ async def fastapi_exception_handler(request: Request, e: Exception):
192
+ return handle_exception(request, e)
193
+
194
+ @app.exception_handler(HTTPException)
195
+ async def http_exception_handler(request: Request, e: HTTPException):
196
+ return handle_exception(request, e)
197
+
198
+
199
+ class Api:
200
+ def __init__(self, app: FastAPI, queue_lock: Lock):
201
+ if shared.cmd_opts.api_auth:
202
+ self.credentials = {}
203
+ for auth in shared.cmd_opts.api_auth.split(","):
204
+ user, password = auth.split(":")
205
+ self.credentials[user] = password
206
+
207
+ self.router = APIRouter()
208
+ self.app = app
209
+ self.queue_lock = queue_lock
210
+ api_middleware(self.app)
211
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
212
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
213
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
214
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
215
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
216
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
217
+ self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
218
+ self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
219
+ self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
220
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
221
+ self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
222
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
223
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
224
+ self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
225
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
226
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
227
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
228
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
229
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
230
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
231
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
232
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
233
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
234
+ self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
235
+ self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
236
+ self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
237
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
238
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
239
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
240
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
241
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
242
+ self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
243
+ self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
244
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
245
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
246
+ self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
247
+
248
+ if shared.cmd_opts.api_server_stop:
249
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
250
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
251
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
252
+
253
+ self.default_script_arg_txt2img = []
254
+ self.default_script_arg_img2img = []
255
+
256
+ txt2img_script_runner = scripts.scripts_txt2img
257
+ img2img_script_runner = scripts.scripts_img2img
258
+
259
+ if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
260
+ ui.create_ui()
261
+
262
+ if not txt2img_script_runner.scripts:
263
+ txt2img_script_runner.initialize_scripts(False)
264
+ if not self.default_script_arg_txt2img:
265
+ self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
266
+
267
+ if not img2img_script_runner.scripts:
268
+ img2img_script_runner.initialize_scripts(True)
269
+ if not self.default_script_arg_img2img:
270
+ self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
271
+
272
+
273
+
274
+ def add_api_route(self, path: str, endpoint, **kwargs):
275
+ if shared.cmd_opts.api_auth:
276
+ return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
277
+ return self.app.add_api_route(path, endpoint, **kwargs)
278
+
279
+ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
280
+ if credentials.username in self.credentials:
281
+ if compare_digest(credentials.password, self.credentials[credentials.username]):
282
+ return True
283
+
284
+ raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
285
+
286
+ def get_selectable_script(self, script_name, script_runner):
287
+ if script_name is None or script_name == "":
288
+ return None, None
289
+
290
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
291
+ script = script_runner.selectable_scripts[script_idx]
292
+ return script, script_idx
293
+
294
+ def get_scripts_list(self):
295
+ t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
296
+ i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
297
+
298
+ return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
299
+
300
+ def get_script_info(self):
301
+ res = []
302
+
303
+ for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
304
+ res += [script.api_info for script in script_list if script.api_info is not None]
305
+
306
+ return res
307
+
308
+ def get_script(self, script_name, script_runner):
309
+ if script_name is None or script_name == "":
310
+ return None, None
311
+
312
+ script_idx = script_name_to_index(script_name, script_runner.scripts)
313
+ return script_runner.scripts[script_idx]
314
+
315
+ def init_default_script_args(self, script_runner):
316
+ #find max idx from the scripts in runner and generate a none array to init script_args
317
+ last_arg_index = 1
318
+ for script in script_runner.scripts:
319
+ if last_arg_index < script.args_to:
320
+ last_arg_index = script.args_to
321
+ # None everywhere except position 0 to initialize script args
322
+ script_args = [None]*last_arg_index
323
+ script_args[0] = 0
324
+
325
+ # get default values
326
+ with gr.Blocks(): # will throw errors calling ui function without this
327
+ for script in script_runner.scripts:
328
+ if script.ui(script.is_img2img):
329
+ ui_default_values = []
330
+ for elem in script.ui(script.is_img2img):
331
+ ui_default_values.append(elem.value)
332
+ script_args[script.args_from:script.args_to] = ui_default_values
333
+ return script_args
334
+
335
+ def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
336
+ script_args = default_script_args.copy()
337
+
338
+ if input_script_args is not None:
339
+ for index, value in input_script_args.items():
340
+ script_args[index] = value
341
+
342
+ # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
343
+ if selectable_scripts:
344
+ script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
345
+ script_args[0] = selectable_idx + 1
346
+
347
+ # Now check for always on scripts
348
+ if request.alwayson_scripts:
349
+ for alwayson_script_name in request.alwayson_scripts.keys():
350
+ alwayson_script = self.get_script(alwayson_script_name, script_runner)
351
+ if alwayson_script is None:
352
+ raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
353
+ # Selectable script in always on script param check
354
+ if alwayson_script.alwayson is False:
355
+ raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
356
+ # always on script with no arg should always run so you don't really need to add them to the requests
357
+ if "args" in request.alwayson_scripts[alwayson_script_name]:
358
+ # min between arg length in scriptrunner and arg length in the request
359
+ for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
360
+ script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
361
+ return script_args
362
+
363
+ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
364
+ """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
365
+
366
+ If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
367
+
368
+ Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
369
+ """
370
+
371
+ if not request.infotext:
372
+ return {}
373
+
374
+ possible_fields = infotext_utils.paste_fields[tabname]["fields"]
375
+ set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
376
+ params = infotext_utils.parse_generation_parameters(request.infotext)
377
+
378
+ def get_field_value(field, params):
379
+ value = field.function(params) if field.function else params.get(field.label)
380
+ if value is None:
381
+ return None
382
+
383
+ if field.api in request.__fields__:
384
+ target_type = request.__fields__[field.api].type_
385
+ else:
386
+ target_type = type(field.component.value)
387
+
388
+ if target_type == type(None):
389
+ return None
390
+
391
+ if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
392
+ value = value.get('value')
393
+
394
+ if value is not None and not isinstance(value, target_type):
395
+ value = target_type(value)
396
+
397
+ return value
398
+
399
+ for field in possible_fields:
400
+ if not field.api:
401
+ continue
402
+
403
+ if field.api in set_fields:
404
+ continue
405
+
406
+ value = get_field_value(field, params)
407
+ if value is not None:
408
+ setattr(request, field.api, value)
409
+
410
+ if request.override_settings is None:
411
+ request.override_settings = {}
412
+
413
+ overridden_settings = infotext_utils.get_override_settings(params)
414
+ for _, setting_name, value in overridden_settings:
415
+ if setting_name not in request.override_settings:
416
+ request.override_settings[setting_name] = value
417
+
418
+ if script_runner is not None and mentioned_script_args is not None:
419
+ indexes = {v: i for i, v in enumerate(script_runner.inputs)}
420
+ script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
421
+
422
+ for field, index in script_fields:
423
+ value = get_field_value(field, params)
424
+
425
+ if value is None:
426
+ continue
427
+
428
+ mentioned_script_args[index] = value
429
+
430
+ return params
431
+
432
+ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
433
+ task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
434
+
435
+ script_runner = scripts.scripts_txt2img
436
+
437
+ infotext_script_args = {}
438
+ self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
439
+
440
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
441
+ sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)
442
+
443
+ populate = txt2imgreq.copy(update={ # Override __init__ params
444
+ "sampler_name": validate_sampler_name(sampler),
445
+ "do_not_save_samples": not txt2imgreq.save_images,
446
+ "do_not_save_grid": not txt2imgreq.save_images,
447
+ })
448
+ if populate.sampler_name:
449
+ populate.sampler_index = None # prevent a warning later on
450
+
451
+ if not populate.scheduler and scheduler != "Automatic":
452
+ populate.scheduler = scheduler
453
+
454
+ args = vars(populate)
455
+ args.pop('script_name', None)
456
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
457
+ args.pop('alwayson_scripts', None)
458
+ args.pop('infotext', None)
459
+
460
+ script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
461
+
462
+ send_images = args.pop('send_images', True)
463
+ args.pop('save_images', None)
464
+
465
+ add_task_to_queue(task_id)
466
+
467
+ with self.queue_lock:
468
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
469
+ p.is_api = True
470
+ p.scripts = script_runner
471
+ p.outpath_grids = opts.outdir_txt2img_grids
472
+ p.outpath_samples = opts.outdir_txt2img_samples
473
+
474
+ try:
475
+ shared.state.begin(job="scripts_txt2img")
476
+ start_task(task_id)
477
+ if selectable_scripts is not None:
478
+ p.script_args = script_args
479
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
480
+ else:
481
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
482
+ processed = process_images(p)
483
+ finish_task(task_id)
484
+ finally:
485
+ shared.state.end()
486
+ shared.total_tqdm.clear()
487
+
488
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
489
+
490
+ return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
491
+
492
+ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
493
+ task_id = img2imgreq.force_task_id or create_task_id("img2img")
494
+
495
+ init_images = img2imgreq.init_images
496
+ if init_images is None:
497
+ raise HTTPException(status_code=404, detail="Init image not found")
498
+
499
+ mask = img2imgreq.mask
500
+ if mask:
501
+ mask = decode_base64_to_image(mask)
502
+
503
+ script_runner = scripts.scripts_img2img
504
+
505
+ infotext_script_args = {}
506
+ self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
507
+
508
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
509
+ sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)
510
+
511
+ populate = img2imgreq.copy(update={ # Override __init__ params
512
+ "sampler_name": validate_sampler_name(sampler),
513
+ "do_not_save_samples": not img2imgreq.save_images,
514
+ "do_not_save_grid": not img2imgreq.save_images,
515
+ "mask": mask,
516
+ })
517
+ if populate.sampler_name:
518
+ populate.sampler_index = None # prevent a warning later on
519
+
520
+ if not populate.scheduler and scheduler != "Automatic":
521
+ populate.scheduler = scheduler
522
+
523
+ args = vars(populate)
524
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
525
+ args.pop('script_name', None)
526
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
527
+ args.pop('alwayson_scripts', None)
528
+ args.pop('infotext', None)
529
+
530
+ script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
531
+
532
+ send_images = args.pop('send_images', True)
533
+ args.pop('save_images', None)
534
+
535
+ add_task_to_queue(task_id)
536
+
537
+ with self.queue_lock:
538
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
539
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
540
+ p.is_api = True
541
+ p.scripts = script_runner
542
+ p.outpath_grids = opts.outdir_img2img_grids
543
+ p.outpath_samples = opts.outdir_img2img_samples
544
+
545
+ try:
546
+ shared.state.begin(job="scripts_img2img")
547
+ start_task(task_id)
548
+ if selectable_scripts is not None:
549
+ p.script_args = script_args
550
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
551
+ else:
552
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
553
+ processed = process_images(p)
554
+ finish_task(task_id)
555
+ finally:
556
+ shared.state.end()
557
+ shared.total_tqdm.clear()
558
+
559
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
560
+
561
+ if not img2imgreq.include_init_images:
562
+ img2imgreq.init_images = None
563
+ img2imgreq.mask = None
564
+
565
+ return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
566
+
567
+ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
568
+ reqDict = setUpscalers(req)
569
+
570
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
571
+
572
+ with self.queue_lock:
573
+ result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
574
+
575
+ return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
576
+
577
+ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
578
+ reqDict = setUpscalers(req)
579
+
580
+ image_list = reqDict.pop('imageList', [])
581
+ image_folder = [decode_base64_to_image(x.data) for x in image_list]
582
+
583
+ with self.queue_lock:
584
+ result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
585
+
586
+ return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
587
+
588
+ def pnginfoapi(self, req: models.PNGInfoRequest):
589
+ image = decode_base64_to_image(req.image.strip())
590
+ if image is None:
591
+ return models.PNGInfoResponse(info="")
592
+
593
+ geninfo, items = images.read_info_from_image(image)
594
+ if geninfo is None:
595
+ geninfo = ""
596
+
597
+ params = infotext_utils.parse_generation_parameters(geninfo)
598
+ script_callbacks.infotext_pasted_callback(geninfo, params)
599
+
600
+ return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
601
+
602
+ def progressapi(self, req: models.ProgressRequest = Depends()):
603
+ # copy from check_progress_call of ui.py
604
+
605
+ if shared.state.job_count == 0:
606
+ return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
607
+
608
+ # avoid dividing zero
609
+ progress = 0.01
610
+
611
+ if shared.state.job_count > 0:
612
+ progress += shared.state.job_no / shared.state.job_count
613
+ if shared.state.sampling_steps > 0:
614
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
615
+
616
+ time_since_start = time.time() - shared.state.time_start
617
+ eta = (time_since_start/progress)
618
+ eta_relative = eta-time_since_start
619
+
620
+ progress = min(progress, 1)
621
+
622
+ shared.state.set_current_image()
623
+
624
+ current_image = None
625
+ if shared.state.current_image and not req.skip_current_image:
626
+ current_image = encode_pil_to_base64(shared.state.current_image)
627
+
628
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
629
+
630
+ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
631
+ image_b64 = interrogatereq.image
632
+ if image_b64 is None:
633
+ raise HTTPException(status_code=404, detail="Image not found")
634
+
635
+ img = decode_base64_to_image(image_b64)
636
+ img = img.convert('RGB')
637
+
638
+ # Override object param
639
+ with self.queue_lock:
640
+ if interrogatereq.model == "clip":
641
+ processed = shared.interrogator.interrogate(img)
642
+ elif interrogatereq.model == "deepdanbooru":
643
+ processed = deepbooru.model.tag(img)
644
+ else:
645
+ raise HTTPException(status_code=404, detail="Model not found")
646
+
647
+ return models.InterrogateResponse(caption=processed)
648
+
649
+ def interruptapi(self):
650
+ shared.state.interrupt()
651
+
652
+ return {}
653
+
654
+ def unloadapi(self):
655
+ sd_models.unload_model_weights()
656
+
657
+ return {}
658
+
659
+ def reloadapi(self):
660
+ sd_models.send_model_to_device(shared.sd_model)
661
+
662
+ return {}
663
+
664
+ def skip(self):
665
+ shared.state.skip()
666
+
667
+ def get_config(self):
668
+ options = {}
669
+ for key in shared.opts.data.keys():
670
+ metadata = shared.opts.data_labels.get(key)
671
+ if(metadata is not None):
672
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
673
+ else:
674
+ options.update({key: shared.opts.data.get(key, None)})
675
+
676
+ return options
677
+
678
+ def set_config(self, req: dict[str, Any]):
679
+ checkpoint_name = req.get("sd_model_checkpoint", None)
680
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
681
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
682
+
683
+ for k, v in req.items():
684
+ shared.opts.set(k, v, is_api=True)
685
+
686
+ shared.opts.save(shared.config_filename)
687
+ return
688
+
689
+ def get_cmd_flags(self):
690
+ return vars(shared.cmd_opts)
691
+
692
+ def get_samplers(self):
693
+ return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
694
+
695
+ def get_schedulers(self):
696
+ return [
697
+ {
698
+ "name": scheduler.name,
699
+ "label": scheduler.label,
700
+ "aliases": scheduler.aliases,
701
+ "default_rho": scheduler.default_rho,
702
+ "need_inner_model": scheduler.need_inner_model,
703
+ }
704
+ for scheduler in sd_schedulers.schedulers]
705
+
706
+ def get_upscalers(self):
707
+ return [
708
+ {
709
+ "name": upscaler.name,
710
+ "model_name": upscaler.scaler.model_name,
711
+ "model_path": upscaler.data_path,
712
+ "model_url": None,
713
+ "scale": upscaler.scale,
714
+ }
715
+ for upscaler in shared.sd_upscalers
716
+ ]
717
+
718
+ def get_latent_upscale_modes(self):
719
+ return [
720
+ {
721
+ "name": upscale_mode,
722
+ }
723
+ for upscale_mode in [*(shared.latent_upscale_modes or {})]
724
+ ]
725
+
726
+ def get_sd_models(self):
727
+ import modules.sd_models as sd_models
728
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
729
+
730
+ def get_sd_vaes(self):
731
+ import modules.sd_vae as sd_vae
732
+ return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
733
+
734
+ def get_hypernetworks(self):
735
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
736
+
737
+ def get_face_restorers(self):
738
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
739
+
740
+ def get_realesrgan_models(self):
741
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
742
+
743
+ def get_prompt_styles(self):
744
+ styleList = []
745
+ for k in shared.prompt_styles.styles:
746
+ style = shared.prompt_styles.styles[k]
747
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
748
+
749
+ return styleList
750
+
751
+ def get_embeddings(self):
752
+ db = sd_hijack.model_hijack.embedding_db
753
+
754
+ def convert_embedding(embedding):
755
+ return {
756
+ "step": embedding.step,
757
+ "sd_checkpoint": embedding.sd_checkpoint,
758
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
759
+ "shape": embedding.shape,
760
+ "vectors": embedding.vectors,
761
+ }
762
+
763
+ def convert_embeddings(embeddings):
764
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
765
+
766
+ return {
767
+ "loaded": convert_embeddings(db.word_embeddings),
768
+ "skipped": convert_embeddings(db.skipped_embeddings),
769
+ }
770
+
771
+ def refresh_embeddings(self):
772
+ with self.queue_lock:
773
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
774
+
775
+ def refresh_checkpoints(self):
776
+ with self.queue_lock:
777
+ shared.refresh_checkpoints()
778
+
779
+ def refresh_vae(self):
780
+ with self.queue_lock:
781
+ shared_items.refresh_vae_list()
782
+
783
+ def create_embedding(self, args: dict):
784
+ try:
785
+ shared.state.begin(job="create_embedding")
786
+ filename = create_embedding(**args) # create empty embedding
787
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
788
+ return models.CreateResponse(info=f"create embedding filename: {filename}")
789
+ except AssertionError as e:
790
+ return models.TrainResponse(info=f"create embedding error: {e}")
791
+ finally:
792
+ shared.state.end()
793
+
794
+
795
+ def create_hypernetwork(self, args: dict):
796
+ try:
797
+ shared.state.begin(job="create_hypernetwork")
798
+ filename = create_hypernetwork(**args) # create empty embedding
799
+ return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
800
+ except AssertionError as e:
801
+ return models.TrainResponse(info=f"create hypernetwork error: {e}")
802
+ finally:
803
+ shared.state.end()
804
+
805
+ def train_embedding(self, args: dict):
806
+ try:
807
+ shared.state.begin(job="train_embedding")
808
+ apply_optimizations = shared.opts.training_xattention_optimizations
809
+ error = None
810
+ filename = ''
811
+ if not apply_optimizations:
812
+ sd_hijack.undo_optimizations()
813
+ try:
814
+ embedding, filename = train_embedding(**args) # can take a long time to complete
815
+ except Exception as e:
816
+ error = e
817
+ finally:
818
+ if not apply_optimizations:
819
+ sd_hijack.apply_optimizations()
820
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
821
+ except Exception as msg:
822
+ return models.TrainResponse(info=f"train embedding error: {msg}")
823
+ finally:
824
+ shared.state.end()
825
+
826
+ def train_hypernetwork(self, args: dict):
827
+ try:
828
+ shared.state.begin(job="train_hypernetwork")
829
+ shared.loaded_hypernetworks = []
830
+ apply_optimizations = shared.opts.training_xattention_optimizations
831
+ error = None
832
+ filename = ''
833
+ if not apply_optimizations:
834
+ sd_hijack.undo_optimizations()
835
+ try:
836
+ hypernetwork, filename = train_hypernetwork(**args)
837
+ except Exception as e:
838
+ error = e
839
+ finally:
840
+ shared.sd_model.cond_stage_model.to(devices.device)
841
+ shared.sd_model.first_stage_model.to(devices.device)
842
+ if not apply_optimizations:
843
+ sd_hijack.apply_optimizations()
844
+ shared.state.end()
845
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
846
+ except Exception as exc:
847
+ return models.TrainResponse(info=f"train embedding error: {exc}")
848
+ finally:
849
+ shared.state.end()
850
+
851
+ def get_memory(self):
852
+ try:
853
+ import os
854
+ import psutil
855
+ process = psutil.Process(os.getpid())
856
+ res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
857
+ ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
858
+ ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
859
+ except Exception as err:
860
+ ram = { 'error': f'{err}' }
861
+ try:
862
+ import torch
863
+ if torch.cuda.is_available():
864
+ s = torch.cuda.mem_get_info()
865
+ system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
866
+ s = dict(torch.cuda.memory_stats(shared.device))
867
+ allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
868
+ reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
869
+ active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
870
+ inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
871
+ warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
872
+ cuda = {
873
+ 'system': system,
874
+ 'active': active,
875
+ 'allocated': allocated,
876
+ 'reserved': reserved,
877
+ 'inactive': inactive,
878
+ 'events': warnings,
879
+ }
880
+ else:
881
+ cuda = {'error': 'unavailable'}
882
+ except Exception as err:
883
+ cuda = {'error': f'{err}'}
884
+ return models.MemoryResponse(ram=ram, cuda=cuda)
885
+
886
+ def get_extensions_list(self):
887
+ from modules import extensions
888
+ extensions.list_extensions()
889
+ ext_list = []
890
+ for ext in extensions.extensions:
891
+ ext: extensions.Extension
892
+ ext.read_info_from_repo()
893
+ if ext.remote is not None:
894
+ ext_list.append({
895
+ "name": ext.name,
896
+ "remote": ext.remote,
897
+ "branch": ext.branch,
898
+ "commit_hash":ext.commit_hash,
899
+ "commit_date":ext.commit_date,
900
+ "version":ext.version,
901
+ "enabled":ext.enabled
902
+ })
903
+ return ext_list
904
+
905
+ def launch(self, server_name, port, root_path):
906
+ self.app.include_router(self.router)
907
+ uvicorn.run(
908
+ self.app,
909
+ host=server_name,
910
+ port=port,
911
+ timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
912
+ root_path=root_path,
913
+ ssl_keyfile=shared.cmd_opts.tls_keyfile,
914
+ ssl_certfile=shared.cmd_opts.tls_certfile
915
+ )
916
+
917
+ def kill_webui(self):
918
+ restart.stop_program()
919
+
920
+ def restart_webui(self):
921
+ if restart.is_restartable():
922
+ restart.restart_program()
923
+ return Response(status_code=501)
924
+
925
+ def stop_webui(request):
926
+ shared.state.server_command = "stop"
927
+ return Response("Stopping.")
928
+
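The routes registered in Api.__init__ are ordinary FastAPI endpoints, so once the web UI is running with the API enabled they can be exercised with any HTTP client. A minimal sketch of a txt2img call — host, port, and the prompt/steps fields are illustrative (prompt and steps are assumed to be among the processing parameters exposed by the generated request model), and authentication is only needed if cmd_opts.api_auth was configured:

```python
import base64
import requests

# Illustrative endpoint; adjust host/port to wherever the web UI serves the API.
url = "http://127.0.0.1:7860/sdapi/v1/txt2img"

payload = {
    "prompt": "a lighthouse at dusk",  # assumed processing parameter
    "steps": 20,                       # assumed processing parameter
    "send_images": True,               # API-only field defined in modules/api/models.py
    "save_images": False,              # API-only field defined in modules/api/models.py
}

# With cmd_opts.api_auth set, Api.auth enforces HTTP Basic; pass auth=("user", "pass") then.
resp = requests.post(url, json=payload, timeout=120)
resp.raise_for_status()

# TextToImageResponse.images is a list of base64-encoded images
# (encoded in whatever format opts.samples_format selects).
for i, img_b64 in enumerate(resp.json()["images"]):
    with open(f"txt2img_{i}.png", "wb") as f:
        f.write(base64.b64decode(img_b64))
```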
modules/api/models.py ADDED
@@ -0,0 +1,329 @@
+import inspect
+
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional, Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+from modules.shared import sd_upscalers, opts, parser
+
+API_NOT_ALLOWED = [
+    "self",
+    "kwargs",
+    "sd_model",
+    "outpath_samples",
+    "outpath_grids",
+    "sampler_index",
+    # "do_not_save_samples",
+    # "do_not_save_grid",
+    "extra_generation_params",
+    "overlay_images",
+    "do_not_reload_embeddings",
+    "seed_enable_extras",
+    "prompt_for_display",
+    "sampler_noise_scheduler_override",
+    "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+    """Assistance Class for Pydantic Dynamic Model Generation"""
+
+    field: str
+    field_alias: str
+    field_type: Any
+    field_value: Any
+    field_exclude: bool = False
+
+
+class PydanticModelGenerator:
+    """
+    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+    source_data is a snapshot of the default values produced by the class
+    params are the names of the actual keys required by __init__
+    """
+
+    def __init__(
+        self,
+        model_name: str = None,
+        class_instance = None,
+        additional_fields = None,
+    ):
+        def field_type_generator(k, v):
+            field_type = v.annotation
+
+            if field_type == 'Image':
+                # images are sent as base64 strings via API
+                field_type = 'str'
+
+            return Optional[field_type]
+
+        def merge_class_params(class_):
+            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+            parameters = {}
+            for classes in all_classes:
+                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+            return parameters
+
+        self._model_name = model_name
+        self._class_data = merge_class_params(class_instance)
+
+        self._model_def = [
+            ModelDef(
+                field=underscore(k),
+                field_alias=k,
+                field_type=field_type_generator(k, v),
+                field_value=None if isinstance(v.default, property) else v.default
+            )
+            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+        ]
+
+        for fields in additional_fields:
+            self._model_def.append(ModelDef(
+                field=underscore(fields["key"]),
+                field_alias=fields["key"],
+                field_type=fields["type"],
+                field_value=fields["default"],
+                field_exclude=fields["exclude"] if "exclude" in fields else False))
+
+    def generate_model(self):
+        """
+        Creates a pydantic BaseModel
+        from the json and overrides provided at initialization
+        """
+        fields = {
+            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
+        }
+        DynamicModel = create_model(self._model_name, **fields)
+        DynamicModel.__config__.allow_population_by_field_name = True
+        DynamicModel.__config__.allow_mutation = True
+        return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+    "StableDiffusionProcessingTxt2Img",
+    StableDiffusionProcessingTxt2Img,
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "force_task_id", "type": str, "default": None},
+        {"key": "infotext", "type": str, "default": None},
+    ]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+    "StableDiffusionProcessingImg2Img",
+    StableDiffusionProcessingImg2Img,
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "init_images", "type": list, "default": None},
+        {"key": "denoising_strength", "type": float, "default": 0.75},
+        {"key": "mask", "type": str, "default": None},
+        {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+        {"key": "force_task_id", "type": str, "default": None},
+        {"key": "infotext", "type": str, "default": None},
+    ]
+).generate_model()
+
+class TextToImageResponse(BaseModel):
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    parameters: dict
+    info: str
+
+class ImageToImageResponse(BaseModel):
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    parameters: dict
+    info: str
+
+class ExtrasBaseRequest(BaseModel):
+    resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+    show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+    upscaling_resize: float = Field(default=2, title="Upscaling Factor", gt=0, description="By how much to upscale the image, only used when resize_mode=0.")
+    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
+    upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+    upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
+
+class ExtraBaseResponse(BaseModel):
+    html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class FileData(BaseModel):
+    data: str = Field(title="File data", description="Base64 representation of the file")
+    name: str = Field(title="File name")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+    imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+    images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+    image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+    info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
+    items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
+    parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
+
+class ProgressRequest(BaseModel):
+    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
+class ProgressResponse(BaseModel):
+    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+    eta_relative: float = Field(title="ETA in secs")
+    state: dict = Field(title="State", description="The current state snapshot")
+    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+class InterrogateRequest(BaseModel):
+    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+    model: str = Field(default="clip", title="Model", description="The interrogate model used.")
+
+class InterrogateResponse(BaseModel):
+    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
+class TrainResponse(BaseModel):
+    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+fields = {}
+for key, metadata in opts.data_labels.items():
+    value = opts.data.get(key)
+    optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
+
+    if metadata is not None:
+        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
+    else:
+        fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+    if(_options[key].dest != 'help'):
+        flag = _options[key]
+        _type = str
+        if _options[key].default is not None:
+            _type = type(_options[key].default)
+        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+    name: str = Field(title="Name")
+    aliases: list[str] = Field(title="Aliases")
+    options: dict[str, str] = Field(title="Options")
+
+class SchedulerItem(BaseModel):
+    name: str = Field(title="Name")
+    label: str = Field(title="Label")
+    aliases: Optional[list[str]] = Field(title="Aliases")
+    default_rho: Optional[float] = Field(title="Default Rho")
+    need_inner_model: Optional[bool] = Field(title="Needs Inner Model")
+
+class UpscalerItem(BaseModel):
+    name: str = Field(title="Name")
+    model_name: Optional[str] = Field(title="Model Name")
+    model_path: Optional[str] = Field(title="Path")
+    model_url: Optional[str] = Field(title="URL")
+    scale: Optional[float] = Field(title="Scale")
+
+class LatentUpscalerModeItem(BaseModel):
+    name: str = Field(title="Name")
+
+class SDModelItem(BaseModel):
+    title: str = Field(title="Title")
+    model_name: str = Field(title="Model Name")
+    hash: Optional[str] = Field(title="Short hash")
+    sha256: Optional[str] = Field(title="sha256 hash")
+    filename: str = Field(title="Filename")
+    config: Optional[str] = Field(title="Config file")
+
+class SDVaeItem(BaseModel):
+    model_name: str = Field(title="Model Name")
+    filename: str = Field(title="Filename")
+
+class HypernetworkItem(BaseModel):
+    name: str = Field(title="Name")
+    path: Optional[str] = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+    name: str = Field(title="Name")
+    cmd_dir: Optional[str] = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+    name: str = Field(title="Name")
+    path: Optional[str] = Field(title="Path")
+    scale: Optional[int] = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+    name: str = Field(title="Name")
+    prompt: Optional[str] = Field(title="Prompt")
+    negative_prompt: Optional[str] = Field(title="Negative Prompt")
+
+
+class EmbeddingItem(BaseModel):
+    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
293
+ class EmbeddingsResponse(BaseModel):
294
+ loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
295
+ skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
296
+
297
+ class MemoryResponse(BaseModel):
298
+ ram: dict = Field(title="RAM", description="System memory stats")
299
+ cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
300
+
301
+
302
+ class ScriptsList(BaseModel):
303
+ txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
304
+ img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
305
+
306
+
307
+ class ScriptArg(BaseModel):
308
+ label: str = Field(default=None, title="Label", description="Name of the argument in UI")
309
+ value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
310
+ minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
311
+ maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
312
+ step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
313
+ choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
314
+
315
+
316
+ class ScriptInfo(BaseModel):
317
+ name: str = Field(default=None, title="Name", description="Script name")
318
+ is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
319
+ is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
320
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
321
+
322
+ class ExtensionItem(BaseModel):
323
+ name: str = Field(title="Name", description="Extension name")
324
+ remote: str = Field(title="Remote", description="Extension Repository URL")
325
+ branch: str = Field(title="Branch", description="Extension Repository Branch")
326
+ commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
327
+ version: str = Field(title="Version", description="Extension Version")
328
+ commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
329
+ enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
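For reference, a minimal sketch of the pydantic.create_model pattern used above to build OptionsModel and FlagsModel; the two option names below are invented for illustration and are not part of this commit:

from typing import Optional
from pydantic import Field, create_model

# build a {name: (type, FieldInfo)} mapping, as the code above does for opts.data_labels
demo_fields = {
    "samples_save": (Optional[bool], Field(default=True, description="Always save all generated images")),
    "preview_refresh_period": (Optional[float], Field(default=1000, description="Preview update period, in milliseconds")),
}
DemoOptions = create_model("DemoOptions", **demo_fields)

opts_instance = DemoOptions()  # every field falls back to its Field(default=...)
print(opts_instance)           # prints the populated defaults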
modules/cache.py ADDED
@@ -0,0 +1,123 @@
1
+ import json
2
+ import os
3
+ import os.path
4
+ import threading
5
+
6
+ import diskcache
7
+ import tqdm
8
+
9
+ from modules.paths import data_path, script_path
10
+
11
+ cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
12
+ cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache"))
13
+ caches = {}
14
+ cache_lock = threading.Lock()
15
+
16
+
17
+ def dump_cache():
18
+ """old function for dumping cache to disk; does nothing since diskcache."""
19
+
20
+ pass
21
+
22
+
23
+ def make_cache(subsection: str) -> diskcache.Cache:
24
+ return diskcache.Cache(
25
+ os.path.join(cache_dir, subsection),
26
+ size_limit=2**32, # 4 GB, culling oldest first
27
+ disk_min_file_size=2**18, # keep up to 256KB in Sqlite
28
+ )
29
+
30
+
31
+ def convert_old_cached_data():
32
+ try:
33
+ with open(cache_filename, "r", encoding="utf8") as file:
34
+ data = json.load(file)
35
+ except FileNotFoundError:
36
+ return
37
+ except Exception:
38
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
39
+ print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')
40
+ return
41
+
42
+ total_count = sum(len(keyvalues) for keyvalues in data.values())
43
+
44
+ with tqdm.tqdm(total=total_count, desc="converting cache") as progress:
45
+ for subsection, keyvalues in data.items():
46
+ cache_obj = caches.get(subsection)
47
+ if cache_obj is None:
48
+ cache_obj = make_cache(subsection)
49
+ caches[subsection] = cache_obj
50
+
51
+ for key, value in keyvalues.items():
52
+ cache_obj[key] = value
53
+ progress.update(1)
54
+
55
+
56
+ def cache(subsection):
57
+ """
58
+ Retrieves or initializes a cache for a specific subsection.
59
+
60
+ Parameters:
61
+ subsection (str): The subsection identifier for the cache.
62
+
63
+ Returns:
64
+ diskcache.Cache: The cache data for the specified subsection.
65
+ """
66
+
67
+ cache_obj = caches.get(subsection)
68
+ if not cache_obj:
69
+ with cache_lock:
70
+ if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):
71
+ convert_old_cached_data()
72
+
73
+ cache_obj = caches.get(subsection)
74
+ if not cache_obj:
75
+ cache_obj = make_cache(subsection)
76
+ caches[subsection] = cache_obj
77
+
78
+ return cache_obj
79
+
80
+
81
+ def cached_data_for_file(subsection, title, filename, func):
82
+ """
83
+ Retrieves or generates data for a specific file, using a caching mechanism.
84
+
85
+ Parameters:
86
+ subsection (str): The subsection of the cache to use.
87
+ title (str): The title of the data entry in the subsection of the cache.
88
+ filename (str): The path to the file to be checked for modifications.
89
+ func (callable): A function that generates the data if it is not available in the cache.
90
+
91
+ Returns:
92
+ dict or None: The cached or generated data, or None if data generation fails.
93
+
94
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
95
+ It checks if the data associated with the given `title` is present in the cache and compares the
96
+ modification time of the file with the cached modification time. If the file has been modified,
97
+ the cache is considered invalid and the data is regenerated using the provided `func`.
98
+ Otherwise, the cached data is returned.
99
+
100
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
101
+ or cached data is returned as a dictionary.
102
+ """
103
+
104
+ existing_cache = cache(subsection)
105
+ ondisk_mtime = os.path.getmtime(filename)
106
+
107
+ entry = existing_cache.get(title)
108
+ if entry:
109
+ cached_mtime = entry.get("mtime", 0)
110
+ if ondisk_mtime > cached_mtime:
111
+ entry = None
112
+
113
+ if not entry or 'value' not in entry:
114
+ value = func()
115
+ if value is None:
116
+ return None
117
+
118
+ entry = {'mtime': ondisk_mtime, 'value': value}
119
+ existing_cache[title] = entry
120
+
121
+ dump_cache()
122
+
123
+ return entry['value']
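A hedged usage sketch for the helpers above: compute a file's SHA-256 once and reuse it until the file's mtime changes. The subsection/title strings, the file path, and the hashing helper are illustrative only:

import hashlib
from modules import cache

def compute_sha256(filename):
    hasher = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            hasher.update(chunk)
    return {"sha256": hasher.hexdigest()}

# regenerated only when model.safetensors is newer than the cached entry
info = cache.cached_data_for_file(
    "hashes", "checkpoint/model.safetensors", "model.safetensors",
    lambda: compute_sha256("model.safetensors"),
)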
modules/call_queue.py ADDED
@@ -0,0 +1,134 @@
1
+ import os.path
2
+ from functools import wraps
3
+ import html
4
+ import time
5
+
6
+ from modules import shared, progress, errors, devices, fifo_lock, profiling
7
+
8
+ queue_lock = fifo_lock.FIFOLock()
9
+
10
+
11
+ def wrap_queued_call(func):
12
+ def f(*args, **kwargs):
13
+ with queue_lock:
14
+ res = func(*args, **kwargs)
15
+
16
+ return res
17
+
18
+ return f
19
+
20
+
21
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
22
+ @wraps(func)
23
+ def f(*args, **kwargs):
24
+
25
+ # if the first argument is a string that says "task(...)", it is treated as a job id
26
+ if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
27
+ id_task = args[0]
28
+ progress.add_task_to_queue(id_task)
29
+ else:
30
+ id_task = None
31
+
32
+ with queue_lock:
33
+ shared.state.begin(job=id_task)
34
+ progress.start_task(id_task)
35
+
36
+ try:
37
+ res = func(*args, **kwargs)
38
+ progress.record_results(id_task, res)
39
+ finally:
40
+ progress.finish_task(id_task)
41
+
42
+ shared.state.end()
43
+
44
+ return res
45
+
46
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
47
+
48
+
49
+ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
50
+ @wraps(func)
51
+ def f(*args, **kwargs):
52
+ try:
53
+ res = func(*args, **kwargs)
54
+ finally:
55
+ shared.state.skipped = False
56
+ shared.state.interrupted = False
57
+ shared.state.stopping_generation = False
58
+ shared.state.job_count = 0
59
+ shared.state.job = ""
60
+ return res
61
+
62
+ return wrap_gradio_call_no_job(f, extra_outputs, add_stats)
63
+
64
+
65
+ def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):
66
+ @wraps(func)
67
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
68
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
69
+ if run_memmon:
70
+ shared.mem_mon.monitor()
71
+ t = time.perf_counter()
72
+
73
+ try:
74
+ res = list(func(*args, **kwargs))
75
+ except Exception as e:
76
+ # When printing out our debug argument list,
77
+ # do not print out more than a 100 KB of text
78
+ max_debug_str_len = 131072
79
+ message = "Error completing request"
80
+ arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
81
+ if len(arg_str) > max_debug_str_len:
82
+ arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
83
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
84
+
85
+ if extra_outputs_array is None:
86
+ extra_outputs_array = [None, '']
87
+
88
+ error_message = f'{type(e).__name__}: {e}'
89
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
90
+
91
+ devices.torch_gc()
92
+
93
+ if not add_stats:
94
+ return tuple(res)
95
+
96
+ elapsed = time.perf_counter() - t
97
+ elapsed_m = int(elapsed // 60)
98
+ elapsed_s = elapsed % 60
99
+ elapsed_text = f"{elapsed_s:.1f} sec."
100
+ if elapsed_m > 0:
101
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
102
+
103
+ if run_memmon:
104
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
105
+ active_peak = mem_stats['active_peak']
106
+ reserved_peak = mem_stats['reserved_peak']
107
+ sys_peak = mem_stats['system_peak']
108
+ sys_total = mem_stats['total']
109
+ sys_pct = sys_peak/max(sys_total, 1) * 100
110
+
111
+ toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
112
+ toltip_r = "Reserved: total amount of video memory allocated by the Torch library "
113
+ toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
114
+
115
+ text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
116
+ text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
117
+ text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
118
+
119
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
120
+ else:
121
+ vram_html = ''
122
+
123
+ if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
124
+ profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
125
+ else:
126
+ profiling_html = ''
127
+
128
+ # last item is always HTML
129
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
130
+
131
+ return tuple(res)
132
+
133
+ return f
134
+
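As a usage sketch (the generate_preview function is invented for illustration), wrapping a callable with wrap_queued_call makes concurrent Gradio callbacks take the FIFO queue_lock before touching the GPU:

from modules import call_queue

def generate_preview(prompt):
    # placeholder for GPU-heavy work
    return f"rendered: {prompt}"

generate_preview_queued = call_queue.wrap_queued_call(generate_preview)
# simultaneous callers now run one at a time, in arrival order
result = generate_preview_queued("a painting of a lighthouse")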
modules/cmd_args.py ADDED
@@ -0,0 +1,128 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
5
+
6
+ parser = argparse.ArgumentParser()
7
+
8
+ parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
9
+ parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
10
+ parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
11
+ parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
12
+ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
13
+ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
14
+ parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
15
+ parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
16
+ parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
17
+ parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
18
+ parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
19
+ parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
20
+ parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
21
+ parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
22
+ parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
23
+ parser.add_argument("--models-dir", type=normalized_filepath, default=None, help="base path where models are stored; overrides --data-dir")
24
+ parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",)
25
+ parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
26
+ parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints")
27
+ parser.add_argument("--vae-dir", type=normalized_filepath, default=None, help="Path to directory with VAE files")
28
+ parser.add_argument("--gfpgan-dir", type=normalized_filepath, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
29
+ parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN model file name", default=None)
30
+ parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
31
+ parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
32
+ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
33
+ parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything")
34
+ parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
35
+ parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
36
+ parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
37
+ parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help="localizations directory")
38
+ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
39
+ parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
40
+ parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
41
+ parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
42
+ parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
43
+ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
44
+ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
45
+ parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
46
+ parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
47
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
48
+ parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
49
+ parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
50
+ parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
51
+ parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
52
+ parser.add_argument("--codeformer-models-path", type=normalized_filepath, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
53
+ parser.add_argument("--gfpgan-models-path", type=normalized_filepath, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
54
+ parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
55
+ parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
56
+ parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
57
+ parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT'))
58
+ parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s).", default=None)
59
+ parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
60
+ parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
61
+ parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
62
+ parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
63
+ parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
64
+ parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
65
+ parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
66
+ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
67
+ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
68
+ parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
69
+ parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
70
+ parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
71
+ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
72
+ parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
73
+ parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
74
+ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
75
+ parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
76
+ parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
77
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
78
+ parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
79
+ parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
80
+ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
81
+ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
82
+ parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
83
+ parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
84
+ parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
85
+ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
86
+ parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
87
+ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
88
+ parser.add_argument("--gradio-auth-path", type=normalized_filepath, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
89
+ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
90
+ parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
91
+ parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
92
+ parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
93
+ parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
94
+ parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
95
+ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
96
+ parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
97
+ parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
98
+ parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
99
+ parser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
100
+ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
101
+ parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
102
+ parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
103
+ parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
104
+ parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
105
+ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
106
+ parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
107
+ parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
108
+ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
109
+ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
110
+ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
111
+ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
112
+ parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
113
+ parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
114
+ parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
115
+ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
116
+ parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
117
+ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
118
+ parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
119
+ parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
120
+ parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
121
+ parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
122
+ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
123
+ parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
124
+ parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
125
+ parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
126
+ parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
127
+ parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
128
+ parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
modules/codeformer_model.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import torch
6
+
7
+ from modules import (
8
+ devices,
9
+ errors,
10
+ face_restoration,
11
+ face_restoration_utils,
12
+ modelloader,
13
+ shared,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19
+ model_download_name = 'codeformer-v0.1.0.pth'
20
+
21
+ # used by e.g. postprocessing_codeformer.py
22
+ codeformer: face_restoration.FaceRestoration | None = None
23
+
24
+
25
+ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
26
+ def name(self):
27
+ return "CodeFormer"
28
+
29
+ def load_net(self) -> torch.nn.Module:
30
+ for model_path in modelloader.load_models(
31
+ model_path=self.model_path,
32
+ model_url=model_url,
33
+ command_path=self.model_path,
34
+ download_name=model_download_name,
35
+ ext_filter=['.pth'],
36
+ ):
37
+ return modelloader.load_spandrel_model(
38
+ model_path,
39
+ device=devices.device_codeformer,
40
+ expected_architecture='CodeFormer',
41
+ ).model
42
+ raise ValueError("No codeformer model found")
43
+
44
+ def get_device(self):
45
+ return devices.device_codeformer
46
+
47
+ def restore(self, np_image, w: float | None = None):
48
+ if w is None:
49
+ w = getattr(shared.opts, "code_former_weight", 0.5)
50
+
51
+ def restore_face(cropped_face_t):
52
+ assert self.net is not None
53
+ return self.net(cropped_face_t, weight=w, adain=True)[0]
54
+
55
+ return self.restore_with_helper(np_image, restore_face)
56
+
57
+
58
+ def setup_model(dirname: str) -> None:
59
+ global codeformer
60
+ try:
61
+ codeformer = FaceRestorerCodeFormer(dirname)
62
+ shared.face_restorers.append(codeformer)
63
+ except Exception:
64
+ errors.report("Error setting up CodeFormer", exc_info=True)
modules/config_states.py ADDED
@@ -0,0 +1,198 @@
1
+ """
2
+ Supports saving and restoring webui and extensions from a known working set of commits
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import tqdm
8
+
9
+ from datetime import datetime
10
+ import git
11
+
12
+ from modules import shared, extensions, errors
13
+ from modules.paths_internal import script_path, config_states_dir
14
+
15
+ all_config_states = {}
16
+
17
+
18
+ def list_config_states():
19
+ global all_config_states
20
+
21
+ all_config_states.clear()
22
+ os.makedirs(config_states_dir, exist_ok=True)
23
+
24
+ config_states = []
25
+ for filename in os.listdir(config_states_dir):
26
+ if filename.endswith(".json"):
27
+ path = os.path.join(config_states_dir, filename)
28
+ try:
29
+ with open(path, "r", encoding="utf-8") as f:
30
+ j = json.load(f)
31
+ assert "created_at" in j, '"created_at" does not exist'
32
+ j["filepath"] = path
33
+ config_states.append(j)
34
+ except Exception as e:
35
+ print(f'[ERROR]: failed to load config state {path}: {e}')
36
+
37
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
38
+
39
+ for cs in config_states:
40
+ timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
41
+ name = cs.get("name", "Config")
42
+ full_name = f"{name}: {timestamp}"
43
+ all_config_states[full_name] = cs
44
+
45
+ return all_config_states
46
+
47
+
48
+ def get_webui_config():
49
+ webui_repo = None
50
+
51
+ try:
52
+ if os.path.exists(os.path.join(script_path, ".git")):
53
+ webui_repo = git.Repo(script_path)
54
+ except Exception:
55
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
56
+
57
+ webui_remote = None
58
+ webui_commit_hash = None
59
+ webui_commit_date = None
60
+ webui_branch = None
61
+ if webui_repo and not webui_repo.bare:
62
+ try:
63
+ webui_remote = next(webui_repo.remote().urls, None)
64
+ head = webui_repo.head.commit
65
+ webui_commit_date = webui_repo.head.commit.committed_date
66
+ webui_commit_hash = head.hexsha
67
+ webui_branch = webui_repo.active_branch.name
68
+
69
+ except Exception:
70
+ webui_remote = None
71
+
72
+ return {
73
+ "remote": webui_remote,
74
+ "commit_hash": webui_commit_hash,
75
+ "commit_date": webui_commit_date,
76
+ "branch": webui_branch,
77
+ }
78
+
79
+
80
+ def get_extension_config():
81
+ ext_config = {}
82
+
83
+ for ext in extensions.extensions:
84
+ ext.read_info_from_repo()
85
+
86
+ entry = {
87
+ "name": ext.name,
88
+ "path": ext.path,
89
+ "enabled": ext.enabled,
90
+ "is_builtin": ext.is_builtin,
91
+ "remote": ext.remote,
92
+ "commit_hash": ext.commit_hash,
93
+ "commit_date": ext.commit_date,
94
+ "branch": ext.branch,
95
+ "have_info_from_repo": ext.have_info_from_repo
96
+ }
97
+
98
+ ext_config[ext.name] = entry
99
+
100
+ return ext_config
101
+
102
+
103
+ def get_config():
104
+ creation_time = datetime.now().timestamp()
105
+ webui_config = get_webui_config()
106
+ ext_config = get_extension_config()
107
+
108
+ return {
109
+ "created_at": creation_time,
110
+ "webui": webui_config,
111
+ "extensions": ext_config
112
+ }
113
+
114
+
115
+ def restore_webui_config(config):
116
+ print("* Restoring webui state...")
117
+
118
+ if "webui" not in config:
119
+ print("Error: No webui data saved to config")
120
+ return
121
+
122
+ webui_config = config["webui"]
123
+
124
+ if "commit_hash" not in webui_config:
125
+ print("Error: No commit saved to webui config")
126
+ return
127
+
128
+ webui_commit_hash = webui_config.get("commit_hash", None)
129
+ webui_repo = None
130
+
131
+ try:
132
+ if os.path.exists(os.path.join(script_path, ".git")):
133
+ webui_repo = git.Repo(script_path)
134
+ except Exception:
135
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
136
+ return
137
+
138
+ try:
139
+ webui_repo.git.fetch(all=True)
140
+ webui_repo.git.reset(webui_commit_hash, hard=True)
141
+ print(f"* Restored webui to commit {webui_commit_hash}.")
142
+ except Exception:
143
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
144
+
145
+
146
+ def restore_extension_config(config):
147
+ print("* Restoring extension state...")
148
+
149
+ if "extensions" not in config:
150
+ print("Error: No extension data saved to config")
151
+ return
152
+
153
+ ext_config = config["extensions"]
154
+
155
+ results = []
156
+ disabled = []
157
+
158
+ for ext in tqdm.tqdm(extensions.extensions):
159
+ if ext.is_builtin:
160
+ continue
161
+
162
+ ext.read_info_from_repo()
163
+ current_commit = ext.commit_hash
164
+
165
+ if ext.name not in ext_config:
166
+ ext.disabled = True
167
+ disabled.append(ext.name)
168
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
169
+ continue
170
+
171
+ entry = ext_config[ext.name]
172
+
173
+ if "commit_hash" in entry and entry["commit_hash"]:
174
+ try:
175
+ ext.fetch_and_reset_hard(entry["commit_hash"])
176
+ ext.read_info_from_repo()
177
+ if current_commit != entry["commit_hash"]:
178
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
179
+ except Exception as ex:
180
+ results.append((ext, current_commit[:8], False, ex))
181
+ else:
182
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
183
+
184
+ if not entry.get("enabled", False):
185
+ ext.disabled = True
186
+ disabled.append(ext.name)
187
+ else:
188
+ ext.disabled = False
189
+
190
+ shared.opts.disabled_extensions = disabled
191
+ shared.opts.save(shared.config_filename)
192
+
193
+ print("* Finished restoring extensions. Results:")
194
+ for ext, prev_commit, success, result in results:
195
+ if success:
196
+ print(f" + {ext.name}: {prev_commit} -> {result}")
197
+ else:
198
+ print(f" ! {ext.name}: FAILURE ({result})")
modules/dat_model.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+
3
+ from modules import modelloader, errors
4
+ from modules.shared import cmd_opts, opts
5
+ from modules.upscaler import Upscaler, UpscalerData
6
+ from modules.upscaler_utils import upscale_with_model
7
+
8
+
9
+ class UpscalerDAT(Upscaler):
10
+ def __init__(self, user_path):
11
+ self.name = "DAT"
12
+ self.user_path = user_path
13
+ self.scalers = []
14
+ super().__init__()
15
+
16
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
17
+ name = modelloader.friendly_name(file)
18
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
19
+ self.scalers.append(scaler_data)
20
+
21
+ for model in get_dat_models(self):
22
+ if model.name in opts.dat_enabled_models:
23
+ self.scalers.append(model)
24
+
25
+ def do_upscale(self, img, path):
26
+ try:
27
+ info = self.load_model(path)
28
+ except Exception:
29
+ errors.report(f"Unable to load DAT model {path}", exc_info=True)
30
+ return img
31
+
32
+ model_descriptor = modelloader.load_spandrel_model(
33
+ info.local_data_path,
34
+ device=self.device,
35
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
36
+ expected_architecture="DAT",
37
+ )
38
+ return upscale_with_model(
39
+ model_descriptor,
40
+ img,
41
+ tile_size=opts.DAT_tile,
42
+ tile_overlap=opts.DAT_tile_overlap,
43
+ )
44
+
45
+ def load_model(self, path):
46
+ for scaler in self.scalers:
47
+ if scaler.data_path == path:
48
+ if scaler.local_data_path.startswith("http"):
49
+ scaler.local_data_path = modelloader.load_file_from_url(
50
+ scaler.data_path,
51
+ model_dir=self.model_download_path,
52
+ )
53
+ if not os.path.exists(scaler.local_data_path):
54
+ raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
55
+ return scaler
56
+ raise ValueError(f"Unable to find model info: {path}")
57
+
58
+
59
+ def get_dat_models(scaler):
60
+ return [
61
+ UpscalerData(
62
+ name="DAT x2",
63
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
64
+ scale=2,
65
+ upscaler=scaler,
66
+ ),
67
+ UpscalerData(
68
+ name="DAT x3",
69
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
70
+ scale=3,
71
+ upscaler=scaler,
72
+ ),
73
+ UpscalerData(
74
+ name="DAT x4",
75
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
76
+ scale=4,
77
+ upscaler=scaler,
78
+ ),
79
+ ]
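An illustrative sketch of driving UpscalerDAT directly (assumes the webui options are initialized and at least one DAT model is present or enabled; the paths are examples):

from PIL import Image
from modules.dat_model import UpscalerDAT

upscaler = UpscalerDAT("models/DAT")       # user_path is an example
image = Image.open("input.png")
if upscaler.scalers:
    upscaled = upscaler.do_upscale(image, upscaler.scalers[0].data_path)
    upscaled.save("output.png")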
modules/deepbooru.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ from modules import modelloader, paths, deepbooru_model, devices, images, shared
8
+
9
+ re_special = re.compile(r'([\\()])')
10
+
11
+
12
+ class DeepDanbooru:
13
+ def __init__(self):
14
+ self.model = None
15
+
16
+ def load(self):
17
+ if self.model is not None:
18
+ return
19
+
20
+ files = modelloader.load_models(
21
+ model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
22
+ model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
23
+ ext_filter=[".pt"],
24
+ download_name='model-resnet_custom_v3.pt',
25
+ )
26
+
27
+ self.model = deepbooru_model.DeepDanbooruModel()
28
+ self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
29
+
30
+ self.model.eval()
31
+ self.model.to(devices.cpu, devices.dtype)
32
+
33
+ def start(self):
34
+ self.load()
35
+ self.model.to(devices.device)
36
+
37
+ def stop(self):
38
+ if not shared.opts.interrogate_keep_models_in_memory:
39
+ self.model.to(devices.cpu)
40
+ devices.torch_gc()
41
+
42
+ def tag(self, pil_image):
43
+ self.start()
44
+ res = self.tag_multi(pil_image)
45
+ self.stop()
46
+
47
+ return res
48
+
49
+ def tag_multi(self, pil_image, force_disable_ranks=False):
50
+ threshold = shared.opts.interrogate_deepbooru_score_threshold
51
+ use_spaces = shared.opts.deepbooru_use_spaces
52
+ use_escape = shared.opts.deepbooru_escape
53
+ alpha_sort = shared.opts.deepbooru_sort_alpha
54
+ include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
55
+
56
+ pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
57
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
58
+
59
+ with torch.no_grad(), devices.autocast():
60
+ x = torch.from_numpy(a).to(devices.device, devices.dtype)
61
+ y = self.model(x)[0].detach().cpu().numpy()
62
+
63
+ probability_dict = {}
64
+
65
+ for tag, probability in zip(self.model.tags, y):
66
+ if probability < threshold:
67
+ continue
68
+
69
+ if tag.startswith("rating:"):
70
+ continue
71
+
72
+ probability_dict[tag] = probability
73
+
74
+ if alpha_sort:
75
+ tags = sorted(probability_dict)
76
+ else:
77
+ tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
78
+
79
+ res = []
80
+
81
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
82
+
83
+ for tag in [x for x in tags if x not in filtertags]:
84
+ probability = probability_dict[tag]
85
+ tag_outformat = tag
86
+ if use_spaces:
87
+ tag_outformat = tag_outformat.replace('_', ' ')
88
+ if use_escape:
89
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
90
+ if include_ranks:
91
+ tag_outformat = f"({tag_outformat}:{probability:.3f})"
92
+
93
+ res.append(tag_outformat)
94
+
95
+ return ", ".join(res)
96
+
97
+
98
+ model = DeepDanbooru()
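Usage sketch for the module-level interrogator instance created above; the image path is illustrative, and the first call downloads the DeepDanbooru weights if they are missing:

from PIL import Image
from modules import deepbooru

caption = deepbooru.model.tag(Image.open("sample.png"))
print(caption)  # comma-separated tags, filtered by the thresholds in shared.opts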
modules/deepbooru_model.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from modules import devices
6
+
7
+ # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
8
+
9
+
10
+ class DeepDanbooruModel(nn.Module):
11
+ def __init__(self):
12
+ super(DeepDanbooruModel, self).__init__()
13
+
14
+ self.tags = []
15
+
16
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
17
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
18
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
19
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
20
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
21
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
22
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
23
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
24
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
25
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
26
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
27
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
28
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
29
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
30
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
31
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
32
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
33
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
34
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
35
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
36
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
37
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
38
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
39
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
40
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
41
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
42
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
43
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
44
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
45
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
46
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
47
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
48
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
49
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
50
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
51
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
52
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
53
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
54
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
55
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
56
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
57
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
58
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
59
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
60
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
61
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
62
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
63
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
64
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
65
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
66
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
67
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
68
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
69
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
70
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
71
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
72
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
73
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
74
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
75
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
76
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
77
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
78
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
79
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
80
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
81
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
82
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
83
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
84
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
85
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
86
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
87
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
88
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
89
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
90
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
91
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
92
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
93
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
94
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
95
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
96
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
97
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
98
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
99
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
100
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
101
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
102
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
103
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
104
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
105
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
106
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
107
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
108
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
109
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
110
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
111
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
112
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
113
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
114
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
115
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
116
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
117
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
118
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
126
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
128
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
129
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
130
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
131
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
132
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
133
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
134
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
135
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
136
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
137
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
138
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
139
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
140
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
141
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
142
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
143
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
144
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
145
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
146
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
147
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
148
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
149
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
150
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
151
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
152
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
153
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
154
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
155
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
156
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
157
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
158
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
159
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
160
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
161
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
162
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
163
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
164
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
165
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
166
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
167
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
168
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
169
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
170
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
171
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
172
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
173
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
174
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
175
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
176
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
177
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
178
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
179
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
180
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
181
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
182
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
183
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
184
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
185
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
186
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
187
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
188
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
189
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
190
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
191
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
192
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
193
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
194
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
195
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
196
+
197
+ def forward(self, *inputs):
198
+ t_358, = inputs
199
+ t_359 = t_358.permute(*[0, 3, 1, 2])
200
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
201
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
202
+ t_361 = F.relu(t_360)
203
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
204
+ t_362 = self.n_MaxPool_0(t_361)
205
+ t_363 = self.n_Conv_1(t_362)
206
+ t_364 = self.n_Conv_2(t_362)
207
+ t_365 = F.relu(t_364)
208
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
209
+ t_366 = self.n_Conv_3(t_365_padded)
210
+ t_367 = F.relu(t_366)
211
+ t_368 = self.n_Conv_4(t_367)
212
+ t_369 = torch.add(t_368, t_363)
213
+ t_370 = F.relu(t_369)
214
+ t_371 = self.n_Conv_5(t_370)
215
+ t_372 = F.relu(t_371)
216
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
217
+ t_373 = self.n_Conv_6(t_372_padded)
218
+ t_374 = F.relu(t_373)
219
+ t_375 = self.n_Conv_7(t_374)
220
+ t_376 = torch.add(t_375, t_370)
221
+ t_377 = F.relu(t_376)
222
+ t_378 = self.n_Conv_8(t_377)
223
+ t_379 = F.relu(t_378)
224
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
225
+ t_380 = self.n_Conv_9(t_379_padded)
226
+ t_381 = F.relu(t_380)
227
+ t_382 = self.n_Conv_10(t_381)
228
+ t_383 = torch.add(t_382, t_377)
229
+ t_384 = F.relu(t_383)
230
+ t_385 = self.n_Conv_11(t_384)
231
+ t_386 = self.n_Conv_12(t_384)
232
+ t_387 = F.relu(t_386)
233
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
234
+ t_388 = self.n_Conv_13(t_387_padded)
235
+ t_389 = F.relu(t_388)
236
+ t_390 = self.n_Conv_14(t_389)
237
+ t_391 = torch.add(t_390, t_385)
238
+ t_392 = F.relu(t_391)
239
+ t_393 = self.n_Conv_15(t_392)
240
+ t_394 = F.relu(t_393)
241
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
242
+ t_395 = self.n_Conv_16(t_394_padded)
243
+ t_396 = F.relu(t_395)
244
+ t_397 = self.n_Conv_17(t_396)
245
+ t_398 = torch.add(t_397, t_392)
246
+ t_399 = F.relu(t_398)
247
+ t_400 = self.n_Conv_18(t_399)
248
+ t_401 = F.relu(t_400)
249
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
250
+ t_402 = self.n_Conv_19(t_401_padded)
251
+ t_403 = F.relu(t_402)
252
+ t_404 = self.n_Conv_20(t_403)
253
+ t_405 = torch.add(t_404, t_399)
254
+ t_406 = F.relu(t_405)
255
+ t_407 = self.n_Conv_21(t_406)
256
+ t_408 = F.relu(t_407)
257
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
258
+ t_409 = self.n_Conv_22(t_408_padded)
259
+ t_410 = F.relu(t_409)
260
+ t_411 = self.n_Conv_23(t_410)
261
+ t_412 = torch.add(t_411, t_406)
262
+ t_413 = F.relu(t_412)
263
+ t_414 = self.n_Conv_24(t_413)
264
+ t_415 = F.relu(t_414)
265
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
266
+ t_416 = self.n_Conv_25(t_415_padded)
267
+ t_417 = F.relu(t_416)
268
+ t_418 = self.n_Conv_26(t_417)
269
+ t_419 = torch.add(t_418, t_413)
270
+ t_420 = F.relu(t_419)
271
+ t_421 = self.n_Conv_27(t_420)
272
+ t_422 = F.relu(t_421)
273
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
274
+ t_423 = self.n_Conv_28(t_422_padded)
275
+ t_424 = F.relu(t_423)
276
+ t_425 = self.n_Conv_29(t_424)
277
+ t_426 = torch.add(t_425, t_420)
278
+ t_427 = F.relu(t_426)
279
+ t_428 = self.n_Conv_30(t_427)
280
+ t_429 = F.relu(t_428)
281
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
282
+ t_430 = self.n_Conv_31(t_429_padded)
283
+ t_431 = F.relu(t_430)
284
+ t_432 = self.n_Conv_32(t_431)
285
+ t_433 = torch.add(t_432, t_427)
286
+ t_434 = F.relu(t_433)
287
+ t_435 = self.n_Conv_33(t_434)
288
+ t_436 = F.relu(t_435)
289
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
290
+ t_437 = self.n_Conv_34(t_436_padded)
291
+ t_438 = F.relu(t_437)
292
+ t_439 = self.n_Conv_35(t_438)
293
+ t_440 = torch.add(t_439, t_434)
294
+ t_441 = F.relu(t_440)
295
+ t_442 = self.n_Conv_36(t_441)
296
+ t_443 = self.n_Conv_37(t_441)
297
+ t_444 = F.relu(t_443)
298
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
299
+ t_445 = self.n_Conv_38(t_444_padded)
300
+ t_446 = F.relu(t_445)
301
+ t_447 = self.n_Conv_39(t_446)
302
+ t_448 = torch.add(t_447, t_442)
303
+ t_449 = F.relu(t_448)
304
+ t_450 = self.n_Conv_40(t_449)
305
+ t_451 = F.relu(t_450)
306
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
307
+ t_452 = self.n_Conv_41(t_451_padded)
308
+ t_453 = F.relu(t_452)
309
+ t_454 = self.n_Conv_42(t_453)
310
+ t_455 = torch.add(t_454, t_449)
311
+ t_456 = F.relu(t_455)
312
+ t_457 = self.n_Conv_43(t_456)
313
+ t_458 = F.relu(t_457)
314
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
315
+ t_459 = self.n_Conv_44(t_458_padded)
316
+ t_460 = F.relu(t_459)
317
+ t_461 = self.n_Conv_45(t_460)
318
+ t_462 = torch.add(t_461, t_456)
319
+ t_463 = F.relu(t_462)
320
+ t_464 = self.n_Conv_46(t_463)
321
+ t_465 = F.relu(t_464)
322
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
323
+ t_466 = self.n_Conv_47(t_465_padded)
324
+ t_467 = F.relu(t_466)
325
+ t_468 = self.n_Conv_48(t_467)
326
+ t_469 = torch.add(t_468, t_463)
327
+ t_470 = F.relu(t_469)
328
+ t_471 = self.n_Conv_49(t_470)
329
+ t_472 = F.relu(t_471)
330
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
331
+ t_473 = self.n_Conv_50(t_472_padded)
332
+ t_474 = F.relu(t_473)
333
+ t_475 = self.n_Conv_51(t_474)
334
+ t_476 = torch.add(t_475, t_470)
335
+ t_477 = F.relu(t_476)
336
+ t_478 = self.n_Conv_52(t_477)
337
+ t_479 = F.relu(t_478)
338
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
339
+ t_480 = self.n_Conv_53(t_479_padded)
340
+ t_481 = F.relu(t_480)
341
+ t_482 = self.n_Conv_54(t_481)
342
+ t_483 = torch.add(t_482, t_477)
343
+ t_484 = F.relu(t_483)
344
+ t_485 = self.n_Conv_55(t_484)
345
+ t_486 = F.relu(t_485)
346
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
347
+ t_487 = self.n_Conv_56(t_486_padded)
348
+ t_488 = F.relu(t_487)
349
+ t_489 = self.n_Conv_57(t_488)
350
+ t_490 = torch.add(t_489, t_484)
351
+ t_491 = F.relu(t_490)
352
+ t_492 = self.n_Conv_58(t_491)
353
+ t_493 = F.relu(t_492)
354
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
355
+ t_494 = self.n_Conv_59(t_493_padded)
356
+ t_495 = F.relu(t_494)
357
+ t_496 = self.n_Conv_60(t_495)
358
+ t_497 = torch.add(t_496, t_491)
359
+ t_498 = F.relu(t_497)
360
+ t_499 = self.n_Conv_61(t_498)
361
+ t_500 = F.relu(t_499)
362
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
363
+ t_501 = self.n_Conv_62(t_500_padded)
364
+ t_502 = F.relu(t_501)
365
+ t_503 = self.n_Conv_63(t_502)
366
+ t_504 = torch.add(t_503, t_498)
367
+ t_505 = F.relu(t_504)
368
+ t_506 = self.n_Conv_64(t_505)
369
+ t_507 = F.relu(t_506)
370
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
371
+ t_508 = self.n_Conv_65(t_507_padded)
372
+ t_509 = F.relu(t_508)
373
+ t_510 = self.n_Conv_66(t_509)
374
+ t_511 = torch.add(t_510, t_505)
375
+ t_512 = F.relu(t_511)
376
+ t_513 = self.n_Conv_67(t_512)
377
+ t_514 = F.relu(t_513)
378
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
379
+ t_515 = self.n_Conv_68(t_514_padded)
380
+ t_516 = F.relu(t_515)
381
+ t_517 = self.n_Conv_69(t_516)
382
+ t_518 = torch.add(t_517, t_512)
383
+ t_519 = F.relu(t_518)
384
+ t_520 = self.n_Conv_70(t_519)
385
+ t_521 = F.relu(t_520)
386
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
387
+ t_522 = self.n_Conv_71(t_521_padded)
388
+ t_523 = F.relu(t_522)
389
+ t_524 = self.n_Conv_72(t_523)
390
+ t_525 = torch.add(t_524, t_519)
391
+ t_526 = F.relu(t_525)
392
+ t_527 = self.n_Conv_73(t_526)
393
+ t_528 = F.relu(t_527)
394
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
395
+ t_529 = self.n_Conv_74(t_528_padded)
396
+ t_530 = F.relu(t_529)
397
+ t_531 = self.n_Conv_75(t_530)
398
+ t_532 = torch.add(t_531, t_526)
399
+ t_533 = F.relu(t_532)
400
+ t_534 = self.n_Conv_76(t_533)
401
+ t_535 = F.relu(t_534)
402
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
403
+ t_536 = self.n_Conv_77(t_535_padded)
404
+ t_537 = F.relu(t_536)
405
+ t_538 = self.n_Conv_78(t_537)
406
+ t_539 = torch.add(t_538, t_533)
407
+ t_540 = F.relu(t_539)
408
+ t_541 = self.n_Conv_79(t_540)
409
+ t_542 = F.relu(t_541)
410
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
411
+ t_543 = self.n_Conv_80(t_542_padded)
412
+ t_544 = F.relu(t_543)
413
+ t_545 = self.n_Conv_81(t_544)
414
+ t_546 = torch.add(t_545, t_540)
415
+ t_547 = F.relu(t_546)
416
+ t_548 = self.n_Conv_82(t_547)
417
+ t_549 = F.relu(t_548)
418
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
419
+ t_550 = self.n_Conv_83(t_549_padded)
420
+ t_551 = F.relu(t_550)
421
+ t_552 = self.n_Conv_84(t_551)
422
+ t_553 = torch.add(t_552, t_547)
423
+ t_554 = F.relu(t_553)
424
+ t_555 = self.n_Conv_85(t_554)
425
+ t_556 = F.relu(t_555)
426
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
427
+ t_557 = self.n_Conv_86(t_556_padded)
428
+ t_558 = F.relu(t_557)
429
+ t_559 = self.n_Conv_87(t_558)
430
+ t_560 = torch.add(t_559, t_554)
431
+ t_561 = F.relu(t_560)
432
+ t_562 = self.n_Conv_88(t_561)
433
+ t_563 = F.relu(t_562)
434
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
435
+ t_564 = self.n_Conv_89(t_563_padded)
436
+ t_565 = F.relu(t_564)
437
+ t_566 = self.n_Conv_90(t_565)
438
+ t_567 = torch.add(t_566, t_561)
439
+ t_568 = F.relu(t_567)
440
+ t_569 = self.n_Conv_91(t_568)
441
+ t_570 = F.relu(t_569)
442
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
443
+ t_571 = self.n_Conv_92(t_570_padded)
444
+ t_572 = F.relu(t_571)
445
+ t_573 = self.n_Conv_93(t_572)
446
+ t_574 = torch.add(t_573, t_568)
447
+ t_575 = F.relu(t_574)
448
+ t_576 = self.n_Conv_94(t_575)
449
+ t_577 = F.relu(t_576)
450
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
451
+ t_578 = self.n_Conv_95(t_577_padded)
452
+ t_579 = F.relu(t_578)
453
+ t_580 = self.n_Conv_96(t_579)
454
+ t_581 = torch.add(t_580, t_575)
455
+ t_582 = F.relu(t_581)
456
+ t_583 = self.n_Conv_97(t_582)
457
+ t_584 = F.relu(t_583)
458
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
459
+ t_585 = self.n_Conv_98(t_584_padded)
460
+ t_586 = F.relu(t_585)
461
+ t_587 = self.n_Conv_99(t_586)
462
+ t_588 = self.n_Conv_100(t_582)
463
+ t_589 = torch.add(t_587, t_588)
464
+ t_590 = F.relu(t_589)
465
+ t_591 = self.n_Conv_101(t_590)
466
+ t_592 = F.relu(t_591)
467
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
468
+ t_593 = self.n_Conv_102(t_592_padded)
469
+ t_594 = F.relu(t_593)
470
+ t_595 = self.n_Conv_103(t_594)
471
+ t_596 = torch.add(t_595, t_590)
472
+ t_597 = F.relu(t_596)
473
+ t_598 = self.n_Conv_104(t_597)
474
+ t_599 = F.relu(t_598)
475
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
476
+ t_600 = self.n_Conv_105(t_599_padded)
477
+ t_601 = F.relu(t_600)
478
+ t_602 = self.n_Conv_106(t_601)
479
+ t_603 = torch.add(t_602, t_597)
480
+ t_604 = F.relu(t_603)
481
+ t_605 = self.n_Conv_107(t_604)
482
+ t_606 = F.relu(t_605)
483
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
484
+ t_607 = self.n_Conv_108(t_606_padded)
485
+ t_608 = F.relu(t_607)
486
+ t_609 = self.n_Conv_109(t_608)
487
+ t_610 = torch.add(t_609, t_604)
488
+ t_611 = F.relu(t_610)
489
+ t_612 = self.n_Conv_110(t_611)
490
+ t_613 = F.relu(t_612)
491
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
492
+ t_614 = self.n_Conv_111(t_613_padded)
493
+ t_615 = F.relu(t_614)
494
+ t_616 = self.n_Conv_112(t_615)
495
+ t_617 = torch.add(t_616, t_611)
496
+ t_618 = F.relu(t_617)
497
+ t_619 = self.n_Conv_113(t_618)
498
+ t_620 = F.relu(t_619)
499
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
500
+ t_621 = self.n_Conv_114(t_620_padded)
501
+ t_622 = F.relu(t_621)
502
+ t_623 = self.n_Conv_115(t_622)
503
+ t_624 = torch.add(t_623, t_618)
504
+ t_625 = F.relu(t_624)
505
+ t_626 = self.n_Conv_116(t_625)
506
+ t_627 = F.relu(t_626)
507
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
508
+ t_628 = self.n_Conv_117(t_627_padded)
509
+ t_629 = F.relu(t_628)
510
+ t_630 = self.n_Conv_118(t_629)
511
+ t_631 = torch.add(t_630, t_625)
512
+ t_632 = F.relu(t_631)
513
+ t_633 = self.n_Conv_119(t_632)
514
+ t_634 = F.relu(t_633)
515
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
516
+ t_635 = self.n_Conv_120(t_634_padded)
517
+ t_636 = F.relu(t_635)
518
+ t_637 = self.n_Conv_121(t_636)
519
+ t_638 = torch.add(t_637, t_632)
520
+ t_639 = F.relu(t_638)
521
+ t_640 = self.n_Conv_122(t_639)
522
+ t_641 = F.relu(t_640)
523
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
524
+ t_642 = self.n_Conv_123(t_641_padded)
525
+ t_643 = F.relu(t_642)
526
+ t_644 = self.n_Conv_124(t_643)
527
+ t_645 = torch.add(t_644, t_639)
528
+ t_646 = F.relu(t_645)
529
+ t_647 = self.n_Conv_125(t_646)
530
+ t_648 = F.relu(t_647)
531
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
532
+ t_649 = self.n_Conv_126(t_648_padded)
533
+ t_650 = F.relu(t_649)
534
+ t_651 = self.n_Conv_127(t_650)
535
+ t_652 = torch.add(t_651, t_646)
536
+ t_653 = F.relu(t_652)
537
+ t_654 = self.n_Conv_128(t_653)
538
+ t_655 = F.relu(t_654)
539
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
540
+ t_656 = self.n_Conv_129(t_655_padded)
541
+ t_657 = F.relu(t_656)
542
+ t_658 = self.n_Conv_130(t_657)
543
+ t_659 = torch.add(t_658, t_653)
544
+ t_660 = F.relu(t_659)
545
+ t_661 = self.n_Conv_131(t_660)
546
+ t_662 = F.relu(t_661)
547
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
548
+ t_663 = self.n_Conv_132(t_662_padded)
549
+ t_664 = F.relu(t_663)
550
+ t_665 = self.n_Conv_133(t_664)
551
+ t_666 = torch.add(t_665, t_660)
552
+ t_667 = F.relu(t_666)
553
+ t_668 = self.n_Conv_134(t_667)
554
+ t_669 = F.relu(t_668)
555
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
556
+ t_670 = self.n_Conv_135(t_669_padded)
557
+ t_671 = F.relu(t_670)
558
+ t_672 = self.n_Conv_136(t_671)
559
+ t_673 = torch.add(t_672, t_667)
560
+ t_674 = F.relu(t_673)
561
+ t_675 = self.n_Conv_137(t_674)
562
+ t_676 = F.relu(t_675)
563
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
564
+ t_677 = self.n_Conv_138(t_676_padded)
565
+ t_678 = F.relu(t_677)
566
+ t_679 = self.n_Conv_139(t_678)
567
+ t_680 = torch.add(t_679, t_674)
568
+ t_681 = F.relu(t_680)
569
+ t_682 = self.n_Conv_140(t_681)
570
+ t_683 = F.relu(t_682)
571
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
572
+ t_684 = self.n_Conv_141(t_683_padded)
573
+ t_685 = F.relu(t_684)
574
+ t_686 = self.n_Conv_142(t_685)
575
+ t_687 = torch.add(t_686, t_681)
576
+ t_688 = F.relu(t_687)
577
+ t_689 = self.n_Conv_143(t_688)
578
+ t_690 = F.relu(t_689)
579
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
580
+ t_691 = self.n_Conv_144(t_690_padded)
581
+ t_692 = F.relu(t_691)
582
+ t_693 = self.n_Conv_145(t_692)
583
+ t_694 = torch.add(t_693, t_688)
584
+ t_695 = F.relu(t_694)
585
+ t_696 = self.n_Conv_146(t_695)
586
+ t_697 = F.relu(t_696)
587
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
588
+ t_698 = self.n_Conv_147(t_697_padded)
589
+ t_699 = F.relu(t_698)
590
+ t_700 = self.n_Conv_148(t_699)
591
+ t_701 = torch.add(t_700, t_695)
592
+ t_702 = F.relu(t_701)
593
+ t_703 = self.n_Conv_149(t_702)
594
+ t_704 = F.relu(t_703)
595
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
596
+ t_705 = self.n_Conv_150(t_704_padded)
597
+ t_706 = F.relu(t_705)
598
+ t_707 = self.n_Conv_151(t_706)
599
+ t_708 = torch.add(t_707, t_702)
600
+ t_709 = F.relu(t_708)
601
+ t_710 = self.n_Conv_152(t_709)
602
+ t_711 = F.relu(t_710)
603
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
604
+ t_712 = self.n_Conv_153(t_711_padded)
605
+ t_713 = F.relu(t_712)
606
+ t_714 = self.n_Conv_154(t_713)
607
+ t_715 = torch.add(t_714, t_709)
608
+ t_716 = F.relu(t_715)
609
+ t_717 = self.n_Conv_155(t_716)
610
+ t_718 = F.relu(t_717)
611
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
612
+ t_719 = self.n_Conv_156(t_718_padded)
613
+ t_720 = F.relu(t_719)
614
+ t_721 = self.n_Conv_157(t_720)
615
+ t_722 = torch.add(t_721, t_716)
616
+ t_723 = F.relu(t_722)
617
+ t_724 = self.n_Conv_158(t_723)
618
+ t_725 = self.n_Conv_159(t_723)
619
+ t_726 = F.relu(t_725)
620
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
621
+ t_727 = self.n_Conv_160(t_726_padded)
622
+ t_728 = F.relu(t_727)
623
+ t_729 = self.n_Conv_161(t_728)
624
+ t_730 = torch.add(t_729, t_724)
625
+ t_731 = F.relu(t_730)
626
+ t_732 = self.n_Conv_162(t_731)
627
+ t_733 = F.relu(t_732)
628
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
629
+ t_734 = self.n_Conv_163(t_733_padded)
630
+ t_735 = F.relu(t_734)
631
+ t_736 = self.n_Conv_164(t_735)
632
+ t_737 = torch.add(t_736, t_731)
633
+ t_738 = F.relu(t_737)
634
+ t_739 = self.n_Conv_165(t_738)
635
+ t_740 = F.relu(t_739)
636
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
637
+ t_741 = self.n_Conv_166(t_740_padded)
638
+ t_742 = F.relu(t_741)
639
+ t_743 = self.n_Conv_167(t_742)
640
+ t_744 = torch.add(t_743, t_738)
641
+ t_745 = F.relu(t_744)
642
+ t_746 = self.n_Conv_168(t_745)
643
+ t_747 = self.n_Conv_169(t_745)
644
+ t_748 = F.relu(t_747)
645
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
646
+ t_749 = self.n_Conv_170(t_748_padded)
647
+ t_750 = F.relu(t_749)
648
+ t_751 = self.n_Conv_171(t_750)
649
+ t_752 = torch.add(t_751, t_746)
650
+ t_753 = F.relu(t_752)
651
+ t_754 = self.n_Conv_172(t_753)
652
+ t_755 = F.relu(t_754)
653
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
654
+ t_756 = self.n_Conv_173(t_755_padded)
655
+ t_757 = F.relu(t_756)
656
+ t_758 = self.n_Conv_174(t_757)
657
+ t_759 = torch.add(t_758, t_753)
658
+ t_760 = F.relu(t_759)
659
+ t_761 = self.n_Conv_175(t_760)
660
+ t_762 = F.relu(t_761)
661
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
662
+ t_763 = self.n_Conv_176(t_762_padded)
663
+ t_764 = F.relu(t_763)
664
+ t_765 = self.n_Conv_177(t_764)
665
+ t_766 = torch.add(t_765, t_760)
666
+ t_767 = F.relu(t_766)
667
+ t_768 = self.n_Conv_178(t_767)
668
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
669
+ t_770 = torch.squeeze(t_769, 3)
670
+ t_770 = torch.squeeze(t_770, 2)
671
+ t_771 = torch.sigmoid(t_770)
672
+ return t_771
673
+
674
+ def load_state_dict(self, state_dict, **kwargs):
675
+ self.tags = state_dict.get('tags', [])
676
+
677
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
678
+
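Note on usage: a minimal standalone sketch of how the DeepDanbooruModel defined above might be driven, assuming a deepbooru checkpoint file on disk; the filename, the 512x512 input size and the 0.5 threshold are illustrative assumptions, not taken from this commit (modules/deepbooru.py in this repo is the actual caller).

# Hypothetical usage sketch; path, input size and threshold are assumptions.
import numpy as np
import torch
from PIL import Image

from modules.deepbooru_model import DeepDanbooruModel

model = DeepDanbooruModel()
# The checkpoint holds both the weights and a 'tags' list; the overridden
# load_state_dict() above splits the tags out into model.tags.
model.load_state_dict(torch.load("model-resnet_custom_v3.pt", map_location="cpu"))
model.eval()

image = Image.open("example.png").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, 0))

with torch.no_grad():
    probs = model(x)[0]  # forward() permutes NHWC -> NCHW and ends with a sigmoid

for tag, score in zip(model.tags, probs.tolist()):
    if score > 0.5:  # arbitrary example threshold
        print(tag, round(score, 3))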
modules/devices.py ADDED
@@ -0,0 +1,295 @@
1
+ import sys
2
+ import contextlib
3
+ from functools import lru_cache
4
+
5
+ import torch
6
+ from modules import errors, shared, npu_specific
7
+
8
+ if sys.platform == "darwin":
9
+ from modules import mac_specific
10
+
11
+ if shared.cmd_opts.use_ipex:
12
+ from modules import xpu_specific
13
+
14
+
15
+ def has_xpu() -> bool:
16
+ return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
17
+
18
+
19
+ def has_mps() -> bool:
20
+ if sys.platform != "darwin":
21
+ return False
22
+ else:
23
+ return mac_specific.has_mps
24
+
25
+
26
+ def cuda_no_autocast(device_id=None) -> bool:
27
+ if device_id is None:
28
+ device_id = get_cuda_device_id()
29
+ return (
30
+ torch.cuda.get_device_capability(device_id) == (7, 5)
31
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
32
+ )
33
+
34
+
35
+ def get_cuda_device_id():
36
+ return (
37
+ int(shared.cmd_opts.device_id)
38
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
39
+ else 0
40
+ ) or torch.cuda.current_device()
41
+
42
+
43
+ def get_cuda_device_string():
44
+ if shared.cmd_opts.device_id is not None:
45
+ return f"cuda:{shared.cmd_opts.device_id}"
46
+
47
+ return "cuda"
48
+
49
+
50
+ def get_optimal_device_name():
51
+ if torch.cuda.is_available():
52
+ return get_cuda_device_string()
53
+
54
+ if has_mps():
55
+ return "mps"
56
+
57
+ if has_xpu():
58
+ return xpu_specific.get_xpu_device_string()
59
+
60
+ if npu_specific.has_npu:
61
+ return npu_specific.get_npu_device_string()
62
+
63
+ return "cpu"
64
+
65
+
66
+ def get_optimal_device():
67
+ return torch.device(get_optimal_device_name())
68
+
69
+
70
+ def get_device_for(task):
71
+ if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
72
+ return cpu
73
+
74
+ return get_optimal_device()
75
+
76
+
77
+ def torch_gc():
78
+
79
+ if torch.cuda.is_available():
80
+ with torch.cuda.device(get_cuda_device_string()):
81
+ torch.cuda.empty_cache()
82
+ torch.cuda.ipc_collect()
83
+
84
+ if has_mps():
85
+ mac_specific.torch_mps_gc()
86
+
87
+ if has_xpu():
88
+ xpu_specific.torch_xpu_gc()
89
+
90
+ if npu_specific.has_npu:
91
+ torch_npu_set_device()
92
+ npu_specific.torch_npu_gc()
93
+
94
+
95
+ def torch_npu_set_device():
96
+ # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
97
+ if npu_specific.has_npu:
98
+ torch.npu.set_device(0)
99
+
100
+
101
+ def enable_tf32():
102
+ if torch.cuda.is_available():
103
+
104
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
105
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
106
+ if cuda_no_autocast():
107
+ torch.backends.cudnn.benchmark = True
108
+
109
+ torch.backends.cuda.matmul.allow_tf32 = True
110
+ torch.backends.cudnn.allow_tf32 = True
111
+
112
+
113
+ errors.run(enable_tf32, "Enabling TF32")
114
+
115
+ cpu: torch.device = torch.device("cpu")
116
+ fp8: bool = False
117
+ # Force fp16 for all models in inference. No casting during inference.
118
+ # This flag is controlled by "--precision half" command line arg.
119
+ force_fp16: bool = False
120
+ device: torch.device = None
121
+ device_interrogate: torch.device = None
122
+ device_gfpgan: torch.device = None
123
+ device_esrgan: torch.device = None
124
+ device_codeformer: torch.device = None
125
+ dtype: torch.dtype = torch.float16
126
+ dtype_vae: torch.dtype = torch.float16
127
+ dtype_unet: torch.dtype = torch.float16
128
+ dtype_inference: torch.dtype = torch.float16
129
+ unet_needs_upcast = False
130
+
131
+
132
+ def cond_cast_unet(input):
133
+ if force_fp16:
134
+ return input.to(torch.float16)
135
+ return input.to(dtype_unet) if unet_needs_upcast else input
136
+
137
+
138
+ def cond_cast_float(input):
139
+ return input.float() if unet_needs_upcast else input
140
+
141
+
142
+ nv_rng = None
143
+ patch_module_list = [
144
+ torch.nn.Linear,
145
+ torch.nn.Conv2d,
146
+ torch.nn.MultiheadAttention,
147
+ torch.nn.GroupNorm,
148
+ torch.nn.LayerNorm,
149
+ ]
150
+
151
+
152
+ def manual_cast_forward(target_dtype):
153
+ def forward_wrapper(self, *args, **kwargs):
154
+ if any(
155
+ isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
156
+ for arg in args
157
+ ):
158
+ args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
159
+ kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
160
+
161
+ org_dtype = target_dtype
162
+ for param in self.parameters():
163
+ if param.dtype != target_dtype:
164
+ org_dtype = param.dtype
165
+ break
166
+
167
+ if org_dtype != target_dtype:
168
+ self.to(target_dtype)
169
+ result = self.org_forward(*args, **kwargs)
170
+ if org_dtype != target_dtype:
171
+ self.to(org_dtype)
172
+
173
+ if target_dtype != dtype_inference:
174
+ if isinstance(result, tuple):
175
+ result = tuple(
176
+ i.to(dtype_inference)
177
+ if isinstance(i, torch.Tensor)
178
+ else i
179
+ for i in result
180
+ )
181
+ elif isinstance(result, torch.Tensor):
182
+ result = result.to(dtype_inference)
183
+ return result
184
+ return forward_wrapper
185
+
186
+
187
+ @contextlib.contextmanager
188
+ def manual_cast(target_dtype):
189
+ applied = False
190
+ for module_type in patch_module_list:
191
+ if hasattr(module_type, "org_forward"):
192
+ continue
193
+ applied = True
194
+ org_forward = module_type.forward
195
+ if module_type == torch.nn.MultiheadAttention:
196
+ module_type.forward = manual_cast_forward(torch.float32)
197
+ else:
198
+ module_type.forward = manual_cast_forward(target_dtype)
199
+ module_type.org_forward = org_forward
200
+ try:
201
+ yield None
202
+ finally:
203
+ if applied:
204
+ for module_type in patch_module_list:
205
+ if hasattr(module_type, "org_forward"):
206
+ module_type.forward = module_type.org_forward
207
+ delattr(module_type, "org_forward")
208
+
209
+
210
+ def autocast(disable=False):
211
+ if disable:
212
+ return contextlib.nullcontext()
213
+
214
+ if force_fp16:
215
+ # No casting during inference if force_fp16 is enabled.
216
+ # All tensor dtype conversion happens before inference.
217
+ return contextlib.nullcontext()
218
+
219
+ if fp8 and device == cpu:
220
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
221
+
222
+ if fp8 and dtype_inference == torch.float32:
223
+ return manual_cast(dtype)
224
+
225
+ if dtype == torch.float32 or dtype_inference == torch.float32:
226
+ return contextlib.nullcontext()
227
+
228
+ if has_xpu() or has_mps() or cuda_no_autocast():
229
+ return manual_cast(dtype)
230
+
231
+ return torch.autocast("cuda")
232
+
233
+
234
+ def without_autocast(disable=False):
235
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
236
+
237
+
238
+ class NansException(Exception):
239
+ pass
240
+
241
+
242
+ def test_for_nans(x, where):
243
+ if shared.cmd_opts.disable_nan_check:
244
+ return
245
+
246
+ if not torch.isnan(x[(0, ) * len(x.shape)]):
247
+ return
248
+
249
+ if where == "unet":
250
+ message = "A tensor with NaNs was produced in Unet."
251
+
252
+ if not shared.cmd_opts.no_half:
253
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
254
+
255
+ elif where == "vae":
256
+ message = "A tensor with NaNs was produced in VAE."
257
+
258
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
259
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
260
+ else:
261
+ message = "A tensor with NaNs was produced."
262
+
263
+ message += " Use --disable-nan-check commandline argument to disable this check."
264
+
265
+ raise NansException(message)
266
+
267
+
268
+ @lru_cache
269
+ def first_time_calculation():
270
+ """
271
+ just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
272
+ spends about 2.7 seconds doing that, at least with NVidia.
273
+ """
274
+
275
+ x = torch.zeros((1, 1)).to(device, dtype)
276
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
277
+ linear(x)
278
+
279
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
280
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
281
+ conv2d(x)
282
+
283
+
284
+ def force_model_fp16():
285
+ """
286
+ ldm and sgm have modules.diffusionmodules.util.GroupNorm32.forward, which
287
+ forces conversion of its input to float32. If force_fp16 is enabled, we need to
288
+ prevent this casting.
289
+ """
290
+ assert force_fp16
291
+ import sgm.modules.diffusionmodules.util as sgm_util
292
+ import ldm.modules.diffusionmodules.util as ldm_util
293
+ sgm_util.GroupNorm32 = torch.nn.GroupNorm
294
+ ldm_util.GroupNorm32 = torch.nn.GroupNorm
295
+ print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")
modules/errors.py ADDED
@@ -0,0 +1,150 @@
1
+ import sys
2
+ import textwrap
3
+ import traceback
4
+
5
+
6
+ exception_records = []
7
+
8
+
9
+ def format_traceback(tb):
10
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
11
+
12
+
13
+ def format_exception(e, tb):
14
+ return {"exception": str(e), "traceback": format_traceback(tb)}
15
+
16
+
17
+ def get_exceptions():
18
+ try:
19
+ return list(reversed(exception_records))
20
+ except Exception as e:
21
+ return str(e)
22
+
23
+
24
+ def record_exception():
25
+ _, e, tb = sys.exc_info()
26
+ if e is None:
27
+ return
28
+
29
+ if exception_records and exception_records[-1] == e:
30
+ return
31
+
32
+ exception_records.append(format_exception(e, tb))
33
+
34
+ if len(exception_records) > 5:
35
+ exception_records.pop(0)
36
+
37
+
38
+ def report(message: str, *, exc_info: bool = False) -> None:
39
+ """
40
+ Print an error message to stderr, with optional traceback.
41
+ """
42
+
43
+ record_exception()
44
+
45
+ for line in message.splitlines():
46
+ print("***", line, file=sys.stderr)
47
+ if exc_info:
48
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
49
+ print("---", file=sys.stderr)
50
+
51
+
52
+ def print_error_explanation(message):
53
+ record_exception()
54
+
55
+ lines = message.strip().split("\n")
56
+ max_len = max([len(x) for x in lines])
57
+
58
+ print('=' * max_len, file=sys.stderr)
59
+ for line in lines:
60
+ print(line, file=sys.stderr)
61
+ print('=' * max_len, file=sys.stderr)
62
+
63
+
64
+ def display(e: Exception, task, *, full_traceback=False):
65
+ record_exception()
66
+
67
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
68
+ te = traceback.TracebackException.from_exception(e)
69
+ if full_traceback:
70
+ # include frames leading up to the try-catch block
71
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
72
+ print(*te.format(), sep="", file=sys.stderr)
73
+
74
+ message = str(e)
75
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
76
+ print_error_explanation("""
77
+ The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
78
+ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
79
+ """)
80
+
81
+
82
+ already_displayed = {}
83
+
84
+
85
+ def display_once(e: Exception, task):
86
+ record_exception()
87
+
88
+ if task in already_displayed:
89
+ return
90
+
91
+ display(e, task)
92
+
93
+ already_displayed[task] = 1
94
+
95
+
96
+ def run(code, task):
97
+ try:
98
+ code()
99
+ except Exception as e:
100
+ display(e, task)
101
+
102
+
103
+ def check_versions():
104
+ from packaging import version
105
+ from modules import shared
106
+
107
+ import torch
108
+ import gradio
109
+
110
+ expected_torch_version = "2.1.2"
111
+ expected_xformers_version = "0.0.23.post1"
112
+ expected_gradio_version = "3.41.2"
113
+
114
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
115
+ print_error_explanation(f"""
116
+ You are running torch {torch.__version__}.
117
+ The program is tested to work with torch {expected_torch_version}.
118
+ To reinstall the desired version, run with commandline flag --reinstall-torch.
119
+ Beware that this will cause a lot of large files to be downloaded, and note that
120
+ there are reports of issues with the training tab on the latest version.
121
+
122
+ Use --skip-version-check commandline argument to disable this check.
123
+ """.strip())
124
+
125
+ if shared.xformers_available:
126
+ import xformers
127
+
128
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
129
+ print_error_explanation(f"""
130
+ You are running xformers {xformers.__version__}.
131
+ The program is tested to work with xformers {expected_xformers_version}.
132
+ To reinstall the desired version, run with commandline flag --reinstall-xformers.
133
+
134
+ Use --skip-version-check commandline argument to disable this check.
135
+ """.strip())
136
+
137
+ if gradio.__version__ != expected_gradio_version:
138
+ print_error_explanation(f"""
139
+ You are running gradio {gradio.__version__}.
140
+ The program is designed to work with gradio {expected_gradio_version}.
141
+ Using a different version of gradio is extremely likely to break the program.
142
+
143
+ Reasons why you have the mismatched gradio version can be:
144
+ - you use --skip-install flag.
145
+ - you use webui.py to start the program instead of launch.py.
146
+ - an extension installs the incompatible gradio version.
147
+
148
+ Use --skip-version-check commandline argument to disable this check.
149
+ """.strip())
150
+
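A short sketch of how callers typically use this module (illustrative; risky_setup is a made-up stand-in for real initialization work):

from modules import errors

def risky_setup():
    raise RuntimeError("example failure")  # stand-in for real work

# run() swallows the exception and pretty-prints it instead of crashing startup
errors.run(risky_setup, "doing risky setup")

# report() is for code that already caught the exception itself
try:
    risky_setup()
except Exception:
    errors.report("risky_setup failed", exc_info=True)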
modules/esrgan_model.py ADDED
@@ -0,0 +1,62 @@
1
+ from modules import modelloader, devices, errors
2
+ from modules.shared import opts
3
+ from modules.upscaler import Upscaler, UpscalerData
4
+ from modules.upscaler_utils import upscale_with_model
5
+
6
+
7
+ class UpscalerESRGAN(Upscaler):
8
+ def __init__(self, dirname):
9
+ self.name = "ESRGAN"
10
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
11
+ self.model_name = "ESRGAN_4x"
12
+ self.scalers = []
13
+ self.user_path = dirname
14
+ super().__init__()
15
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
16
+ scalers = []
17
+ if len(model_paths) == 0:
18
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
19
+ self.scalers.append(scaler_data)
20
+ for file in model_paths:
21
+ if file.startswith("http"):
22
+ name = self.model_name
23
+ else:
24
+ name = modelloader.friendly_name(file)
25
+
26
+ scaler_data = UpscalerData(name, file, self, 4)
27
+ self.scalers.append(scaler_data)
28
+
29
+ def do_upscale(self, img, selected_model):
30
+ try:
31
+ model = self.load_model(selected_model)
32
+ except Exception:
33
+ errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
34
+ return img
35
+ model.to(devices.device_esrgan)
36
+ return esrgan_upscale(model, img)
37
+
38
+ def load_model(self, path: str):
39
+ if path.startswith("http"):
40
+ # TODO: this doesn't use `path` at all?
41
+ filename = modelloader.load_file_from_url(
42
+ url=self.model_url,
43
+ model_dir=self.model_download_path,
44
+ file_name=f"{self.model_name}.pth",
45
+ )
46
+ else:
47
+ filename = path
48
+
49
+ return modelloader.load_spandrel_model(
50
+ filename,
51
+ device=('cpu' if devices.device_esrgan.type == 'mps' else None),
52
+ expected_architecture='ESRGAN',
53
+ )
54
+
55
+
56
+ def esrgan_upscale(model, img):
57
+ return upscale_with_model(
58
+ model,
59
+ img,
60
+ tile_size=opts.ESRGAN_tile,
61
+ tile_overlap=opts.ESRGAN_tile_overlap,
62
+ )
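For context, roughly how the upscaler above gets exercised (a sketch under assumptions: the models directory, the image file, and the data_path attribute of UpscalerData come from elsewhere in the repo, and the web UI's shared options and devices must already be initialized):

from PIL import Image
from modules.esrgan_model import UpscalerESRGAN

upscaler = UpscalerESRGAN("models/ESRGAN")   # assumed local model directory
img = Image.open("input.png").convert("RGB")

# pick the first registered scaler and upscale with its model file (or URL)
scaler = upscaler.scalers[0]
result = upscaler.do_upscale(img, scaler.data_path)
result.save("output_4x.png")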
modules/extensions.py ADDED
@@ -0,0 +1,299 @@
1
+ from __future__ import annotations
2
+
3
+ import configparser
4
+ import dataclasses
5
+ import os
6
+ import threading
7
+ import re
8
+
9
+ from modules import shared, errors, cache, scripts
10
+ from modules.gitpython_hack import Repo
11
+ from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
12
+
13
+ extensions: list[Extension] = []
14
+ extension_paths: dict[str, Extension] = {}
15
+ loaded_extensions: dict[str, Exception] = {}
16
+
17
+
18
+ os.makedirs(extensions_dir, exist_ok=True)
19
+
20
+
21
+ def active():
22
+ if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
23
+ return []
24
+ elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
25
+ return [x for x in extensions if x.enabled and x.is_builtin]
26
+ else:
27
+ return [x for x in extensions if x.enabled]
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class CallbackOrderInfo:
32
+ name: str
33
+ before: list
34
+ after: list
35
+
36
+
37
+ class ExtensionMetadata:
38
+ filename = "metadata.ini"
39
+ config: configparser.ConfigParser
40
+ canonical_name: str
41
+ requires: list
42
+
43
+ def __init__(self, path, canonical_name):
44
+ self.config = configparser.ConfigParser()
45
+
46
+ filepath = os.path.join(path, self.filename)
47
+ # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
48
+ # so no need to check whether the file exists beforehand.
49
+ try:
50
+ self.config.read(filepath)
51
+ except Exception:
52
+ errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
53
+
54
+ self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
55
+ self.canonical_name = self.canonical_name.lower().strip()
56
+
57
+ self.requires = None
58
+
59
+ def get_script_requirements(self, field, section, extra_section=None):
60
+ """reads a list of requirements from the config; field is the name of the field in the ini file,
61
+ like Requires or Before, and section is the name of the [section] in the ini file; additionally,
62
+ reads more requirements from [extra_section] if specified."""
63
+
64
+ x = self.config.get(section, field, fallback='')
65
+
66
+ if extra_section:
67
+ x = x + ', ' + self.config.get(extra_section, field, fallback='')
68
+
69
+ listed_requirements = self.parse_list(x.lower())
70
+ res = []
71
+
72
+ for requirement in listed_requirements:
73
+ loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions)
74
+ relevant_requirement = next(loaded_requirements, requirement)
75
+ res.append(relevant_requirement)
76
+
77
+ return res
78
+
79
+ def parse_list(self, text):
80
+ """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
81
+
82
+ if not text:
83
+ return []
84
+
85
+ # both "," and " " are accepted as separator
86
+ return [x for x in re.split(r"[,\s]+", text.strip()) if x]
87
+
88
+ def list_callback_order_instructions(self):
89
+ for section in self.config.sections():
90
+ if not section.startswith("callbacks/"):
91
+ continue
92
+
93
+ callback_name = section[10:]
94
+
95
+ if not callback_name.startswith(self.canonical_name):
96
+ errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
97
+ continue
98
+
99
+ before = self.parse_list(self.config.get(section, 'Before', fallback=''))
100
+ after = self.parse_list(self.config.get(section, 'After', fallback=''))
101
+
102
+ yield CallbackOrderInfo(callback_name, before, after)
103
+
104
+
105
+ class Extension:
106
+ lock = threading.Lock()
107
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
108
+ metadata: ExtensionMetadata
109
+
110
+ def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
111
+ self.name = name
112
+ self.path = path
113
+ self.enabled = enabled
114
+ self.status = ''
115
+ self.can_update = False
116
+ self.is_builtin = is_builtin
117
+ self.commit_hash = ''
118
+ self.commit_date = None
119
+ self.version = ''
120
+ self.branch = None
121
+ self.remote = None
122
+ self.have_info_from_repo = False
123
+ self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
124
+ self.canonical_name = self.metadata.canonical_name
125
+
126
+ def to_dict(self):
127
+ return {x: getattr(self, x) for x in self.cached_fields}
128
+
129
+ def from_dict(self, d):
130
+ for field in self.cached_fields:
131
+ setattr(self, field, d[field])
132
+
133
+ def read_info_from_repo(self):
134
+ if self.is_builtin or self.have_info_from_repo:
135
+ return
136
+
137
+ def read_from_repo():
138
+ with self.lock:
139
+ if self.have_info_from_repo:
140
+ return
141
+
142
+ self.do_read_info_from_repo()
143
+
144
+ return self.to_dict()
145
+
146
+ try:
147
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
148
+ self.from_dict(d)
149
+ except FileNotFoundError:
150
+ pass
151
+ self.status = 'unknown' if self.status == '' else self.status
152
+
153
+ def do_read_info_from_repo(self):
154
+ repo = None
155
+ try:
156
+ if os.path.exists(os.path.join(self.path, ".git")):
157
+ repo = Repo(self.path)
158
+ except Exception:
159
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
160
+
161
+ if repo is None or repo.bare:
162
+ self.remote = None
163
+ else:
164
+ try:
165
+ self.remote = next(repo.remote().urls, None)
166
+ commit = repo.head.commit
167
+ self.commit_date = commit.committed_date
168
+ if repo.active_branch:
169
+ self.branch = repo.active_branch.name
170
+ self.commit_hash = commit.hexsha
171
+ self.version = self.commit_hash[:8]
172
+
173
+ except Exception:
174
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
175
+ self.remote = None
176
+
177
+ self.have_info_from_repo = True
178
+
179
+ def list_files(self, subdir, extension):
180
+ dirpath = os.path.join(self.path, subdir)
181
+ if not os.path.isdir(dirpath):
182
+ return []
183
+
184
+ res = []
185
+ for filename in sorted(os.listdir(dirpath)):
186
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
187
+
188
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
189
+
190
+ return res
191
+
192
+ def check_updates(self):
193
+ repo = Repo(self.path)
194
+ branch_name = f'{repo.remote().name}/{self.branch}'
195
+ for fetch in repo.remote().fetch(dry_run=True):
196
+ if self.branch and fetch.name != branch_name:
197
+ continue
198
+ if fetch.flags != fetch.HEAD_UPTODATE:
199
+ self.can_update = True
200
+ self.status = "new commits"
201
+ return
202
+
203
+ try:
204
+ origin = repo.rev_parse(branch_name)
205
+ if repo.head.commit != origin:
206
+ self.can_update = True
207
+ self.status = "behind HEAD"
208
+ return
209
+ except Exception:
210
+ self.can_update = False
211
+ self.status = "unknown (remote error)"
212
+ return
213
+
214
+ self.can_update = False
215
+ self.status = "latest"
216
+
217
+ def fetch_and_reset_hard(self, commit=None):
218
+ repo = Repo(self.path)
219
+ if commit is None:
220
+ commit = f'{repo.remote().name}/{self.branch}'
221
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
222
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
223
+ repo.git.fetch(all=True)
224
+ repo.git.reset(commit, hard=True)
225
+ self.have_info_from_repo = False
226
+
227
+
228
+ def list_extensions():
229
+ extensions.clear()
230
+ extension_paths.clear()
231
+ loaded_extensions.clear()
232
+
233
+ if shared.cmd_opts.disable_all_extensions:
234
+ print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
235
+ elif shared.opts.disable_all_extensions == "all":
236
+ print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
237
+ elif shared.cmd_opts.disable_extra_extensions:
238
+ print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
239
+ elif shared.opts.disable_all_extensions == "extra":
240
+ print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
241
+
242
+
243
+ # scan through extensions directory and load metadata
244
+ for dirname in [extensions_builtin_dir, extensions_dir]:
245
+ if not os.path.isdir(dirname):
246
+ continue
247
+
248
+ for extension_dirname in sorted(os.listdir(dirname)):
249
+ path = os.path.join(dirname, extension_dirname)
250
+ if not os.path.isdir(path):
251
+ continue
252
+
253
+ canonical_name = extension_dirname
254
+ metadata = ExtensionMetadata(path, canonical_name)
255
+
256
+ # check for duplicated canonical names
257
+ already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
258
+ if already_loaded_extension is not None:
259
+ errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
260
+ continue
261
+
262
+ is_builtin = dirname == extensions_builtin_dir
263
+ extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
264
+ extensions.append(extension)
265
+ extension_paths[extension.path] = extension
266
+ loaded_extensions[canonical_name] = extension
267
+
268
+ for extension in extensions:
269
+ extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
270
+
271
+ # check for requirements
272
+ for extension in extensions:
273
+ if not extension.enabled:
274
+ continue
275
+
276
+ for req in extension.metadata.requires:
277
+ required_extension = loaded_extensions.get(req)
278
+ if required_extension is None:
279
+ errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
280
+ continue
281
+
282
+ if not required_extension.enabled:
283
+ errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
284
+ continue
285
+
286
+
287
+ def find_extension(filename):
288
+ parentdir = os.path.dirname(os.path.realpath(filename))
289
+
290
+ while parentdir != filename:
291
+ extension = extension_paths.get(parentdir)
292
+ if extension is not None:
293
+ return extension
294
+
295
+ filename = parentdir
296
+ parentdir = os.path.dirname(filename)
297
+
298
+ return None
299
+
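To make the metadata handling above concrete, a sketch of the kind of metadata.ini an extension might ship and how ExtensionMetadata reads it (the extension names are invented, and importing modules.extensions assumes the web UI environment is importable):

import os
import tempfile

from modules.extensions import ExtensionMetadata

ini_text = """\
[Extension]
Name = my-extension
Requires = some-base-extension

[callbacks/my-extension_on_ui]
Before = another-extension
"""

with tempfile.TemporaryDirectory() as path:
    with open(os.path.join(path, ExtensionMetadata.filename), "w") as f:
        f.write(ini_text)

    meta = ExtensionMetadata(path, "my-extension")
    print(meta.canonical_name)                                    # my-extension
    print(meta.get_script_requirements("Requires", "Extension"))  # ['some-base-extension']
    for info in meta.list_callback_order_instructions():
        print(info.name, info.before, info.after)                 # my-extension_on_ui ['another-extension'] []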
modules/extra_networks.py ADDED
@@ -0,0 +1,225 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import logging
5
+ from collections import defaultdict
6
+
7
+ from modules import errors
8
+
9
+ extra_network_registry = {}
10
+ extra_network_aliases = {}
11
+
12
+
13
+ def initialize():
14
+ extra_network_registry.clear()
15
+ extra_network_aliases.clear()
16
+
17
+
18
+ def register_extra_network(extra_network):
19
+ extra_network_registry[extra_network.name] = extra_network
20
+
21
+
22
+ def register_extra_network_alias(extra_network, alias):
23
+ extra_network_aliases[alias] = extra_network
24
+
25
+
26
+ def register_default_extra_networks():
27
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
28
+ register_extra_network(ExtraNetworkHypernet())
29
+
30
+
31
+ class ExtraNetworkParams:
32
+ def __init__(self, items=None):
33
+ self.items = items or []
34
+ self.positional = []
35
+ self.named = {}
36
+
37
+ for item in self.items:
38
+ parts = item.split('=', 2) if isinstance(item, str) else [item]
39
+ if len(parts) == 2:
40
+ self.named[parts[0]] = parts[1]
41
+ else:
42
+ self.positional.append(item)
43
+
44
+ def __eq__(self, other):
45
+ return self.items == other.items
46
+
47
+
48
+ class ExtraNetwork:
49
+ def __init__(self, name):
50
+ self.name = name
51
+
52
+ def activate(self, p, params_list):
53
+ """
54
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
55
+ Passes arguments related to this extra network in params_list.
56
+ The user passes arguments by specifying this in their prompt:
57
+
58
+ <name:arg1:arg2:arg3>
59
+
60
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any number of text arguments
61
+ separated by colons.
62
+
63
+ Even if the user does not mention this ExtraNetwork in their prompt, the call will still be made, with an empty params_list -
64
+ in this case, all effects of this extra network should be disabled.
65
+
66
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
67
+
68
+ For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:
69
+
70
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
71
+
72
+ params_list will be:
73
+
74
+ [
75
+ ExtraNetworkParams(items=["agm", "1.1"]),
76
+ ExtraNetworkParams(items=["ray"])
77
+ ]
78
+
79
+ """
80
+ raise NotImplementedError
81
+
82
+ def deactivate(self, p):
83
+ """
84
+ Called at the end of processing for housekeeping. No need to do anything here.
85
+ """
86
+
87
+ raise NotImplementedError
88
+
89
+
90
+ def lookup_extra_networks(extra_network_data):
91
+ """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
92
+
93
+ Example input:
94
+ {
95
+ 'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
96
+ 'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
97
+ 'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
98
+ }
99
+
100
+ Example output:
101
+
102
+ {
103
+ <extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
104
+ <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
105
+ }
106
+ """
107
+
108
+ res = {}
109
+
110
+ for extra_network_name, extra_network_args in list(extra_network_data.items()):
111
+ extra_network = extra_network_registry.get(extra_network_name, None)
112
+ alias = extra_network_aliases.get(extra_network_name, None)
113
+
114
+ if alias is not None and extra_network is None:
115
+ extra_network = alias
116
+
117
+ if extra_network is None:
118
+ logging.info(f"Skipping unknown extra network: {extra_network_name}")
119
+ continue
120
+
121
+ res.setdefault(extra_network, []).extend(extra_network_args)
122
+
123
+ return res
124
+
125
+
126
+ def activate(p, extra_network_data):
127
+ """call activate for extra networks in extra_network_data in specified order, then call
128
+ activate for all remaining registered networks with an empty argument list"""
129
+
130
+ activated = []
131
+
132
+ for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
133
+
134
+ try:
135
+ extra_network.activate(p, extra_network_args)
136
+ activated.append(extra_network)
137
+ except Exception as e:
138
+ errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
139
+
140
+ for extra_network_name, extra_network in extra_network_registry.items():
141
+ if extra_network in activated:
142
+ continue
143
+
144
+ try:
145
+ extra_network.activate(p, [])
146
+ except Exception as e:
147
+ errors.display(e, f"activating extra network {extra_network_name}")
148
+
149
+ if p.scripts is not None:
150
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
151
+
152
+
153
+ def deactivate(p, extra_network_data):
154
+ """call deactivate for extra networks in extra_network_data in specified order, then call
155
+ deactivate for all remaining registered networks"""
156
+
157
+ data = lookup_extra_networks(extra_network_data)
158
+
159
+ for extra_network in data:
160
+ try:
161
+ extra_network.deactivate(p)
162
+ except Exception as e:
163
+ errors.display(e, f"deactivating extra network {extra_network.name}")
164
+
165
+ for extra_network_name, extra_network in extra_network_registry.items():
166
+ if extra_network in data:
167
+ continue
168
+
169
+ try:
170
+ extra_network.deactivate(p)
171
+ except Exception as e:
172
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
173
+
174
+
175
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
176
+
177
+
178
+ def parse_prompt(prompt):
179
+ res = defaultdict(list)
180
+
181
+ def found(m):
182
+ name = m.group(1)
183
+ args = m.group(2)
184
+
185
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
186
+
187
+ return ""
188
+
189
+ prompt = re.sub(re_extra_net, found, prompt)
190
+
191
+ return prompt, res
192
+
193
+
194
+ def parse_prompts(prompts):
195
+ res = []
196
+ extra_data = None
197
+
198
+ for prompt in prompts:
199
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
200
+
201
+ if extra_data is None:
202
+ extra_data = parsed_extra_data
203
+
204
+ res.append(updated_prompt)
205
+
206
+ return res, extra_data
207
+
208
+
209
+ def get_user_metadata(filename, lister=None):
210
+ if filename is None:
211
+ return {}
212
+
213
+ basename, ext = os.path.splitext(filename)
214
+ metadata_filename = basename + '.json'
215
+
216
+ metadata = {}
217
+ try:
218
+ exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
219
+ if exists:
220
+ with open(metadata_filename, "r", encoding="utf8") as file:
221
+ metadata = json.load(file)
222
+ except Exception as e:
223
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
224
+
225
+ return metadata
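For clarity, a small hedged example of how parse_prompt() strips <name:args> tags and collects their arguments (the prompt text is made up):

    prompt, data = parse_prompt("1girl <hypernet:agm:1.1> <lora:detail:0.5>")
    # prompt -> "1girl  "  (the tags are removed, the surrounding text is kept)
    # data   -> {'hypernet': [ExtraNetworkParams(items=['agm', '1.1'])],
    #            'lora': [ExtraNetworkParams(items=['detail', '0.5'])]}
    # each ExtraNetworkParams also exposes .positional and .named (the latter for key=value arguments)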
modules/extra_networks_hypernet.py ADDED
@@ -0,0 +1,28 @@
1
+ from modules import extra_networks, shared
2
+ from modules.hypernetworks import hypernetwork
3
+
4
+
5
+ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6
+ def __init__(self):
7
+ super().__init__('hypernet')
8
+
9
+ def activate(self, p, params_list):
10
+ additional = shared.opts.sd_hypernetwork
11
+
12
+ if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
13
+ hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
14
+ p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
15
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
16
+
17
+ names = []
18
+ multipliers = []
19
+ for params in params_list:
20
+ assert params.items
21
+
22
+ names.append(params.items[0])
23
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
24
+
25
+ hypernetwork.load_hypernetworks(names, multipliers)
26
+
27
+ def deactivate(self, p):
28
+ pass
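The hypernetwork handler above is the template that other extra-network types follow: subclass ExtraNetwork, implement activate()/deactivate(), and register an instance. A minimal hedged sketch with made-up names:

    from modules import extra_networks

    class ExtraNetworkExample(extra_networks.ExtraNetwork):
        def __init__(self):
            super().__init__('example')          # used in prompts as <example:arg1:arg2>

        def activate(self, p, params_list):
            for params in params_list:           # empty when the prompt does not mention <example:...>
                ...                              # apply the effect using params.positional / params.named

        def deactivate(self, p):
            pass                                 # housekeeping only

    extra_networks.register_extra_network(ExtraNetworkExample())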
modules/extras.py ADDED
@@ -0,0 +1,330 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+ import json
5
+
6
+
7
+ import torch
8
+ import tqdm
9
+
10
+ from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
11
+ from modules.ui_common import plaintext_to_html
12
+ import gradio as gr
13
+ import safetensors.torch
14
+
15
+
16
+ def run_pnginfo(image):
17
+ if image is None:
18
+ return '', '', ''
19
+
20
+ geninfo, items = images.read_info_from_image(image)
21
+ items = {**{'parameters': geninfo}, **items}
22
+
23
+ info = ''
24
+ for key, text in items.items():
25
+ info += f"""
26
+ <div>
27
+ <p><b>{plaintext_to_html(str(key))}</b></p>
28
+ <p>{plaintext_to_html(str(text))}</p>
29
+ </div>
30
+ """.strip()+"\n"
31
+
32
+ if len(info) == 0:
33
+ message = "Nothing found in the image."
34
+ info = f"<div><p>{message}<p></div>"
35
+
36
+ return '', geninfo, info
37
+
38
+
39
+ def create_config(ckpt_result, config_source, a, b, c):
40
+ def config(x):
41
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
42
+ return res if res != shared.sd_default_config else None
43
+
44
+ if config_source == 0:
45
+ cfg = config(a) or config(b) or config(c)
46
+ elif config_source == 1:
47
+ cfg = config(b)
48
+ elif config_source == 2:
49
+ cfg = config(c)
50
+ else:
51
+ cfg = None
52
+
53
+ if cfg is None:
54
+ return
55
+
56
+ filename, _ = os.path.splitext(ckpt_result)
57
+ checkpoint_filename = filename + ".yaml"
58
+
59
+ print("Copying config:")
60
+ print(" from:", cfg)
61
+ print(" to:", checkpoint_filename)
62
+ shutil.copyfile(cfg, checkpoint_filename)
63
+
64
+
65
+ checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
66
+
67
+
68
+ def to_half(tensor, enable):
69
+ if enable and tensor.dtype == torch.float:
70
+ return tensor.half()
71
+
72
+ return tensor
73
+
74
+
75
+ def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
76
+ metadata = {}
77
+
78
+ for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
79
+ checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
80
+ if checkpoint_info is None:
81
+ continue
82
+
83
+ metadata.update(checkpoint_info.metadata)
84
+
85
+ return json.dumps(metadata, indent=4, ensure_ascii=False)
86
+
87
+
88
+ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
89
+ shared.state.begin(job="model-merge")
90
+
91
+ def fail(message):
92
+ shared.state.textinfo = message
93
+ shared.state.end()
94
+ return [*[gr.update() for _ in range(4)], message]
95
+
96
+ def weighted_sum(theta0, theta1, alpha):
97
+ return ((1 - alpha) * theta0) + (alpha * theta1)
98
+
99
+ def get_difference(theta1, theta2):
100
+ return theta1 - theta2
101
+
102
+ def add_difference(theta0, theta1_2_diff, alpha):
103
+ return theta0 + (alpha * theta1_2_diff)
104
+
105
+ def filename_weighted_sum():
106
+ a = primary_model_info.model_name
107
+ b = secondary_model_info.model_name
108
+ Ma = round(1 - multiplier, 2)
109
+ Mb = round(multiplier, 2)
110
+
111
+ return f"{Ma}({a}) + {Mb}({b})"
112
+
113
+ def filename_add_difference():
114
+ a = primary_model_info.model_name
115
+ b = secondary_model_info.model_name
116
+ c = tertiary_model_info.model_name
117
+ M = round(multiplier, 2)
118
+
119
+ return f"{a} + {M}({b} - {c})"
120
+
121
+ def filename_nothing():
122
+ return primary_model_info.model_name
123
+
124
+ theta_funcs = {
125
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
126
+ "Add difference": (filename_add_difference, get_difference, add_difference),
127
+ "No interpolation": (filename_nothing, None, None),
128
+ }
129
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
130
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
131
+
132
+ if not primary_model_name:
133
+ return fail("Failed: Merging requires a primary model.")
134
+
135
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
136
+
137
+ if theta_func2 and not secondary_model_name:
138
+ return fail("Failed: Merging requires a secondary model.")
139
+
140
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
141
+
142
+ if theta_func1 and not tertiary_model_name:
143
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
144
+
145
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
146
+
147
+ result_is_inpainting_model = False
148
+ result_is_instruct_pix2pix_model = False
149
+
150
+ if theta_func2:
151
+ shared.state.textinfo = "Loading B"
152
+ print(f"Loading {secondary_model_info.filename}...")
153
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
154
+ else:
155
+ theta_1 = None
156
+
157
+ if theta_func1:
158
+ shared.state.textinfo = "Loading C"
159
+ print(f"Loading {tertiary_model_info.filename}...")
160
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
161
+
162
+ shared.state.textinfo = 'Merging B and C'
163
+ shared.state.sampling_steps = len(theta_1.keys())
164
+ for key in tqdm.tqdm(theta_1.keys()):
165
+ if key in checkpoint_dict_skip_on_merge:
166
+ continue
167
+
168
+ if 'model' in key:
169
+ if key in theta_2:
170
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
171
+ theta_1[key] = theta_func1(theta_1[key], t2)
172
+ else:
173
+ theta_1[key] = torch.zeros_like(theta_1[key])
174
+
175
+ shared.state.sampling_step += 1
176
+ del theta_2
177
+
178
+ shared.state.nextjob()
179
+
180
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
181
+ print(f"Loading {primary_model_info.filename}...")
182
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
183
+
184
+ print("Merging...")
185
+ shared.state.textinfo = 'Merging A and B'
186
+ shared.state.sampling_steps = len(theta_0.keys())
187
+ for key in tqdm.tqdm(theta_0.keys()):
188
+ if theta_1 and 'model' in key and key in theta_1:
189
+
190
+ if key in checkpoint_dict_skip_on_merge:
191
+ continue
192
+
193
+ a = theta_0[key]
194
+ b = theta_1[key]
195
+
196
+ # this enables merging an inpainting model (A) with another one (B);
197
+ # where a normal model would have 4 channels for the latent space, an inpainting model would
198
+ # have another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
199
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
200
+ if a.shape[1] == 4 and b.shape[1] == 9:
201
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
202
+ if a.shape[1] == 4 and b.shape[1] == 8:
203
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
204
+
205
+ if a.shape[1] == 8 and b.shape[1] == 4:  # If we have an Instruct-Pix2Pix model...
206
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)  # Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
207
+ result_is_instruct_pix2pix_model = True
208
+ else:
209
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
210
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
211
+ result_is_inpainting_model = True
212
+ else:
213
+ theta_0[key] = theta_func2(a, b, multiplier)
214
+
215
+ theta_0[key] = to_half(theta_0[key], save_as_half)
216
+
217
+ shared.state.sampling_step += 1
218
+
219
+ del theta_1
220
+
221
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
222
+ if bake_in_vae_filename is not None:
223
+ print(f"Baking in VAE from {bake_in_vae_filename}")
224
+ shared.state.textinfo = 'Baking in VAE'
225
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
226
+
227
+ for key in vae_dict.keys():
228
+ theta_0_key = 'first_stage_model.' + key
229
+ if theta_0_key in theta_0:
230
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
231
+
232
+ del vae_dict
233
+
234
+ if save_as_half and not theta_func2:
235
+ for key in theta_0.keys():
236
+ theta_0[key] = to_half(theta_0[key], save_as_half)
237
+
238
+ if discard_weights:
239
+ regex = re.compile(discard_weights)
240
+ for key in list(theta_0):
241
+ if re.search(regex, key):
242
+ theta_0.pop(key, None)
243
+
244
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
245
+
246
+ filename = filename_generator() if custom_name == '' else custom_name
247
+ filename += ".inpainting" if result_is_inpainting_model else ""
248
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
249
+ filename += "." + checkpoint_format
250
+
251
+ output_modelname = os.path.join(ckpt_dir, filename)
252
+
253
+ shared.state.nextjob()
254
+ shared.state.textinfo = "Saving"
255
+ print(f"Saving to {output_modelname}...")
256
+
257
+ metadata = {}
258
+
259
+ if save_metadata and copy_metadata_fields:
260
+ if primary_model_info:
261
+ metadata.update(primary_model_info.metadata)
262
+ if secondary_model_info:
263
+ metadata.update(secondary_model_info.metadata)
264
+ if tertiary_model_info:
265
+ metadata.update(tertiary_model_info.metadata)
266
+
267
+ if save_metadata:
268
+ try:
269
+ metadata.update(json.loads(metadata_json))
270
+ except Exception as e:
271
+ errors.display(e, "readin metadata from json")
272
+
273
+ metadata["format"] = "pt"
274
+
275
+ if save_metadata and add_merge_recipe:
276
+ merge_recipe = {
277
+ "type": "webui", # indicate this model was merged with webui's built-in merger
278
+ "primary_model_hash": primary_model_info.sha256,
279
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
280
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
281
+ "interp_method": interp_method,
282
+ "multiplier": multiplier,
283
+ "save_as_half": save_as_half,
284
+ "custom_name": custom_name,
285
+ "config_source": config_source,
286
+ "bake_in_vae": bake_in_vae,
287
+ "discard_weights": discard_weights,
288
+ "is_inpainting": result_is_inpainting_model,
289
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
290
+ }
291
+
292
+ sd_merge_models = {}
293
+
294
+ def add_model_metadata(checkpoint_info):
295
+ checkpoint_info.calculate_shorthash()
296
+ sd_merge_models[checkpoint_info.sha256] = {
297
+ "name": checkpoint_info.name,
298
+ "legacy_hash": checkpoint_info.hash,
299
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
300
+ }
301
+
302
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
303
+
304
+ add_model_metadata(primary_model_info)
305
+ if secondary_model_info:
306
+ add_model_metadata(secondary_model_info)
307
+ if tertiary_model_info:
308
+ add_model_metadata(tertiary_model_info)
309
+
310
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
311
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
312
+
313
+ _, extension = os.path.splitext(output_modelname)
314
+ if extension.lower() == ".safetensors":
315
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
316
+ else:
317
+ torch.save(theta_0, output_modelname)
318
+
319
+ sd_models.list_models()
320
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
321
+ if created_model:
322
+ created_model.calculate_shorthash()
323
+
324
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
325
+
326
+ print(f"Checkpoint saved to {output_modelname}.")
327
+ shared.state.textinfo = "Checkpoint saved"
328
+ shared.state.end()
329
+
330
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
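The two interpolation modes above reduce to simple per-tensor arithmetic: "Weighted sum" computes (1 - M) * A + M * B and "Add difference" computes A + M * (B - C). A toy hedged illustration on scalars (a real merge applies this to every matching state-dict tensor):

    import torch

    a, b, c = torch.tensor(1.0), torch.tensor(3.0), torch.tensor(2.0)
    multiplier = 0.5
    weighted_sum = (1 - multiplier) * a + multiplier * b   # -> 2.0
    add_difference = a + multiplier * (b - c)              # -> 1.5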
modules/face_restoration.py ADDED
@@ -0,0 +1,19 @@
1
+ from modules import shared
2
+
3
+
4
+ class FaceRestoration:
5
+ def name(self):
6
+ return "None"
7
+
8
+ def restore(self, np_image):
9
+ return np_image
10
+
11
+
12
+ def restore_faces(np_image):
13
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14
+ if len(face_restorers) == 0:
15
+ return np_image
16
+
17
+ face_restorer = face_restorers[0]
18
+
19
+ return face_restorer.restore(np_image)
modules/face_restoration_utils.py ADDED
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from functools import cached_property
6
+ from typing import TYPE_CHECKING, Callable
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+
12
+ from modules import devices, errors, face_restoration, shared
13
+
14
+ if TYPE_CHECKING:
15
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
21
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
22
+ assert img.shape[2] == 3, "image must have 3 channels"
23
+ if img.dtype == "float64":
24
+ img = img.astype("float32")
25
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
26
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
27
+
28
+
29
+ def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
30
+ """
31
+ Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
32
+ """
33
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
34
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
35
+ assert tensor.dim() == 3, "tensor must be RGB"
36
+ img_np = tensor.numpy().transpose(1, 2, 0)
37
+ if img_np.shape[2] == 1: # gray image, no RGB/BGR required
38
+ return np.squeeze(img_np, axis=2)
39
+ return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
40
+
41
+
42
+ def create_face_helper(device) -> FaceRestoreHelper:
43
+ from facexlib.detection import retinaface
44
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
45
+ if hasattr(retinaface, 'device'):
46
+ retinaface.device = device
47
+ return FaceRestoreHelper(
48
+ upscale_factor=1,
49
+ face_size=512,
50
+ crop_ratio=(1, 1),
51
+ det_model='retinaface_resnet50',
52
+ save_ext='png',
53
+ use_parse=True,
54
+ device=device,
55
+ )
56
+
57
+
58
+ def restore_with_face_helper(
59
+ np_image: np.ndarray,
60
+ face_helper: FaceRestoreHelper,
61
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
62
+ ) -> np.ndarray:
63
+ """
64
+ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
65
+
66
+ `restore_face` should take a cropped face image and return a restored face image.
67
+ """
68
+ from torchvision.transforms.functional import normalize
69
+ np_image = np_image[:, :, ::-1]
70
+ original_resolution = np_image.shape[0:2]
71
+
72
+ try:
73
+ logger.debug("Detecting faces...")
74
+ face_helper.clean_all()
75
+ face_helper.read_image(np_image)
76
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
77
+ face_helper.align_warp_face()
78
+ logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
79
+ for cropped_face in face_helper.cropped_faces:
80
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
81
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
82
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
83
+
84
+ try:
85
+ with torch.no_grad():
86
+ cropped_face_t = restore_face(cropped_face_t)
87
+ devices.torch_gc()
88
+ except Exception:
89
+ errors.report('Failed face-restoration inference', exc_info=True)
90
+
91
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
92
+ restored_face = (restored_face * 255.0).astype('uint8')
93
+ face_helper.add_restored_face(restored_face)
94
+
95
+ logger.debug("Merging restored faces into image")
96
+ face_helper.get_inverse_affine(None)
97
+ img = face_helper.paste_faces_to_input_image()
98
+ img = img[:, :, ::-1]
99
+ if original_resolution != img.shape[0:2]:
100
+ img = cv2.resize(
101
+ img,
102
+ (0, 0),
103
+ fx=original_resolution[1] / img.shape[1],
104
+ fy=original_resolution[0] / img.shape[0],
105
+ interpolation=cv2.INTER_LINEAR,
106
+ )
107
+ logger.debug("Face restoration complete")
108
+ finally:
109
+ face_helper.clean_all()
110
+ return img
111
+
112
+
113
+ class CommonFaceRestoration(face_restoration.FaceRestoration):
114
+ net: torch.nn.Module | None
115
+ model_url: str
116
+ model_download_name: str
117
+
118
+ def __init__(self, model_path: str):
119
+ super().__init__()
120
+ self.net = None
121
+ self.model_path = model_path
122
+ os.makedirs(model_path, exist_ok=True)
123
+
124
+ @cached_property
125
+ def face_helper(self) -> FaceRestoreHelper:
126
+ return create_face_helper(self.get_device())
127
+
128
+ def send_model_to(self, device):
129
+ if self.net:
130
+ logger.debug("Sending %s to %s", self.net, device)
131
+ self.net.to(device)
132
+ if self.face_helper:
133
+ logger.debug("Sending face helper to %s", device)
134
+ self.face_helper.face_det.to(device)
135
+ self.face_helper.face_parse.to(device)
136
+
137
+ def get_device(self):
138
+ raise NotImplementedError("get_device must be implemented by subclasses")
139
+
140
+ def load_net(self) -> torch.nn.Module:
141
+ raise NotImplementedError("load_net must be implemented by subclasses")
142
+
143
+ def restore_with_helper(
144
+ self,
145
+ np_image: np.ndarray,
146
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
147
+ ) -> np.ndarray:
148
+ try:
149
+ if self.net is None:
150
+ self.net = self.load_net()
151
+ except Exception:
152
+ logger.warning("Unable to load face-restoration model", exc_info=True)
153
+ return np_image
154
+
155
+ try:
156
+ self.send_model_to(self.get_device())
157
+ return restore_with_face_helper(np_image, self.face_helper, restore_face)
158
+ finally:
159
+ if shared.opts.face_restoration_unload:
160
+ self.send_model_to(devices.cpu)
161
+
162
+
163
+ def patch_facexlib(dirname: str) -> None:
164
+ import facexlib.detection
165
+ import facexlib.parsing
166
+
167
+ det_facex_load_file_from_url = facexlib.detection.load_file_from_url
168
+ par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
169
+
170
+ def update_kwargs(kwargs):
171
+ return dict(kwargs, save_dir=dirname, model_dir=None)
172
+
173
+ def facex_load_file_from_url(**kwargs):
174
+ return det_facex_load_file_from_url(**update_kwargs(kwargs))
175
+
176
+ def facex_load_file_from_url2(**kwargs):
177
+ return par_facex_load_file_from_url(**update_kwargs(kwargs))
178
+
179
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
180
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
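The restorers built on these helpers expect each cropped face as a normalized RGB tensor in [-1, 1] and convert the result back to a BGR image; a rough hedged sketch of that round-trip (the random image is only a stand-in):

    import numpy as np
    from torchvision.transforms.functional import normalize

    bgr = np.random.rand(512, 512, 3).astype("float32")                # BGR image with values in [0..1]
    face_t = bgr_image_to_rgb_tensor(bgr)                              # CHW RGB float tensor in [0..1]
    normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)  # rescale to [-1..1] as the face models expect
    restored = rgb_tensor_to_bgr_image(face_t.unsqueeze(0), min_max=(-1, 1))  # back to HWC BGR in [0..1]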
modules/fifo_lock.py ADDED
@@ -0,0 +1,37 @@
1
+ import threading
2
+ import collections
3
+
4
+
5
+ # reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
6
+ class FIFOLock(object):
7
+ def __init__(self):
8
+ self._lock = threading.Lock()
9
+ self._inner_lock = threading.Lock()
10
+ self._pending_threads = collections.deque()
11
+
12
+ def acquire(self, blocking=True):
13
+ with self._inner_lock:
14
+ lock_acquired = self._lock.acquire(False)
15
+ if lock_acquired:
16
+ return True
17
+ elif not blocking:
18
+ return False
19
+
20
+ release_event = threading.Event()
21
+ self._pending_threads.append(release_event)
22
+
23
+ release_event.wait()
24
+ return self._lock.acquire()
25
+
26
+ def release(self):
27
+ with self._inner_lock:
28
+ if self._pending_threads:
29
+ release_event = self._pending_threads.popleft()
30
+ release_event.set()
31
+
32
+ self._lock.release()
33
+
34
+ __enter__ = acquire
35
+
36
+ def __exit__(self, t, v, tb):
37
+ self.release()
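A brief hedged usage sketch: FIFOLock is a drop-in context manager with the added guarantee that waiting threads acquire the lock in arrival order (the worker function is made up):

    import threading

    queue_lock = FIFOLock()

    def worker(i):
        with queue_lock:                          # __enter__ is acquire(); threads are served first come, first served
            print(f"job {i} holds the lock")

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()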
modules/gfpgan_model.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+
6
+ import torch
7
+
8
+ from modules import (
9
+ devices,
10
+ errors,
11
+ face_restoration,
12
+ face_restoration_utils,
13
+ modelloader,
14
+ shared,
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+ model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
19
+ model_download_name = "GFPGANv1.4.pth"
20
+ gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
21
+
22
+
23
+ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
24
+ def name(self):
25
+ return "GFPGAN"
26
+
27
+ def get_device(self):
28
+ return devices.device_gfpgan
29
+
30
+ def load_net(self) -> torch.nn.Module:
31
+ for model_path in modelloader.load_models(
32
+ model_path=self.model_path,
33
+ model_url=model_url,
34
+ command_path=self.model_path,
35
+ download_name=model_download_name,
36
+ ext_filter=['.pth'],
37
+ ):
38
+ if 'GFPGAN' in os.path.basename(model_path):
39
+ return modelloader.load_spandrel_model(
40
+ model_path,
41
+ device=self.get_device(),
42
+ expected_architecture='GFPGAN',
43
+ ).model
44
+ raise ValueError("No GFPGAN model found")
45
+
46
+ def restore(self, np_image):
47
+ def restore_face(cropped_face_t):
48
+ assert self.net is not None
49
+ return self.net(cropped_face_t, return_rgb=False)[0]
50
+
51
+ return self.restore_with_helper(np_image, restore_face)
52
+
53
+
54
+ def gfpgan_fix_faces(np_image):
55
+ if gfpgan_face_restorer:
56
+ return gfpgan_face_restorer.restore(np_image)
57
+ logger.warning("GFPGAN face restorer not set up")
58
+ return np_image
59
+
60
+
61
+ def setup_model(dirname: str) -> None:
62
+ global gfpgan_face_restorer
63
+
64
+ try:
65
+ face_restoration_utils.patch_facexlib(dirname)
66
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
67
+ shared.face_restorers.append(gfpgan_face_restorer)
68
+ except Exception:
69
+ errors.report("Error setting up GFPGAN", exc_info=True)
modules/gitpython_hack.py ADDED
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import subprocess
5
+
6
+ import git
7
+
8
+
9
+ class Git(git.Git):
10
+ """
11
+ Git subclassed to never use persistent processes.
12
+ """
13
+
14
+ def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
15
+ raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
16
+
17
+ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
18
+ ret = subprocess.check_output(
19
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
20
+ input=self._prepare_ref(ref),
21
+ cwd=self._working_dir,
22
+ timeout=2,
23
+ )
24
+ return self._parse_object_header(ret)
25
+
26
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
27
+ # Not really streaming, per se; this buffers the entire object in memory.
28
+ # Shouldn't be a problem for our use case, since we're only using this for
29
+ # object headers (commit objects).
30
+ ret = subprocess.check_output(
31
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
32
+ input=self._prepare_ref(ref),
33
+ cwd=self._working_dir,
34
+ timeout=30,
35
+ )
36
+ bio = io.BytesIO(ret)
37
+ hexsha, typename, size = self._parse_object_header(bio.readline())
38
+ return (hexsha, typename, size, self.CatFileContentStream(size, bio))
39
+
40
+
41
+ class Repo(git.Repo):
42
+ GitCommandWrapperType = Git
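A hedged usage sketch: this Repo subclass is meant to be used wherever git.Repo would be, so that reading commit information never relies on GitPython's long-lived `git cat-file --batch` processes (the path is illustrative):

    from modules.gitpython_hack import Repo

    repo = Repo("extensions/some-extension")     # same interface as git.Repo
    print(repo.head.commit.hexsha)               # object lookups go through short-lived subprocess calls instead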
modules/gradio_extensons.py ADDED
@@ -0,0 +1,83 @@
1
+ import gradio as gr
2
+
3
+ from modules import scripts, ui_tempdir, patches
4
+
5
+
6
+ def add_classes_to_gradio_component(comp):
7
+ """
8
+ this adds gradio-* classes to the component for CSS styling (i.e. gradio-button for gr.Button), as well as some others
9
+ """
10
+
11
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
12
+
13
+ if getattr(comp, 'multiselect', False):
14
+ comp.elem_classes.append('multiselect')
15
+
16
+
17
+ def IOComponent_init(self, *args, **kwargs):
18
+ self.webui_tooltip = kwargs.pop('tooltip', None)
19
+
20
+ if scripts.scripts_current is not None:
21
+ scripts.scripts_current.before_component(self, **kwargs)
22
+
23
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
24
+
25
+ res = original_IOComponent_init(self, *args, **kwargs)
26
+
27
+ add_classes_to_gradio_component(self)
28
+
29
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
30
+
31
+ if scripts.scripts_current is not None:
32
+ scripts.scripts_current.after_component(self, **kwargs)
33
+
34
+ return res
35
+
36
+
37
+ def Block_get_config(self):
38
+ config = original_Block_get_config(self)
39
+
40
+ webui_tooltip = getattr(self, 'webui_tooltip', None)
41
+ if webui_tooltip:
42
+ config["webui_tooltip"] = webui_tooltip
43
+
44
+ config.pop('example_inputs', None)
45
+
46
+ return config
47
+
48
+
49
+ def BlockContext_init(self, *args, **kwargs):
50
+ if scripts.scripts_current is not None:
51
+ scripts.scripts_current.before_component(self, **kwargs)
52
+
53
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
54
+
55
+ res = original_BlockContext_init(self, *args, **kwargs)
56
+
57
+ add_classes_to_gradio_component(self)
58
+
59
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
60
+
61
+ if scripts.scripts_current is not None:
62
+ scripts.scripts_current.after_component(self, **kwargs)
63
+
64
+ return res
65
+
66
+
67
+ def Blocks_get_config_file(self, *args, **kwargs):
68
+ config = original_Blocks_get_config_file(self, *args, **kwargs)
69
+
70
+ for comp_config in config["components"]:
71
+ if "example_inputs" in comp_config:
72
+ comp_config["example_inputs"] = {"serialized": []}
73
+
74
+ return config
75
+
76
+
77
+ original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
78
+ original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
79
+ original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
80
+ original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
81
+
82
+
83
+ ui_tempdir.install_ui_tempdir_override()
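With the IOComponent patch above in place, components accept an extra `tooltip` keyword that is emitted in their config as `webui_tooltip`; a minimal hedged sketch (importing this module is what applies the patches):

    import gradio as gr
    from modules import gradio_extensons  # noqa: F401 - importing applies the patches

    with gr.Blocks() as demo:
        btn = gr.Button("Generate", tooltip="Start image generation")  # tooltip is popped by the patched __init__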
modules/hashes.py ADDED
@@ -0,0 +1,84 @@
1
+ import hashlib
2
+ import os.path
3
+
4
+ from modules import shared
5
+ import modules.cache
6
+
7
+ dump_cache = modules.cache.dump_cache
8
+ cache = modules.cache.cache
9
+
10
+
11
+ def calculate_sha256(filename):
12
+ hash_sha256 = hashlib.sha256()
13
+ blksize = 1024 * 1024
14
+
15
+ with open(filename, "rb") as f:
16
+ for chunk in iter(lambda: f.read(blksize), b""):
17
+ hash_sha256.update(chunk)
18
+
19
+ return hash_sha256.hexdigest()
20
+
21
+
22
+ def sha256_from_cache(filename, title, use_addnet_hash=False):
23
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
24
+ try:
25
+ ondisk_mtime = os.path.getmtime(filename)
26
+ except FileNotFoundError:
27
+ return None
28
+
29
+ if title not in hashes:
30
+ return None
31
+
32
+ cached_sha256 = hashes[title].get("sha256", None)
33
+ cached_mtime = hashes[title].get("mtime", 0)
34
+
35
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
36
+ return None
37
+
38
+ return cached_sha256
39
+
40
+
41
+ def sha256(filename, title, use_addnet_hash=False):
42
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
43
+
44
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
45
+ if sha256_value is not None:
46
+ return sha256_value
47
+
48
+ if shared.cmd_opts.no_hashing:
49
+ return None
50
+
51
+ print(f"Calculating sha256 for {filename}: ", end='')
52
+ if use_addnet_hash:
53
+ with open(filename, "rb") as file:
54
+ sha256_value = addnet_hash_safetensors(file)
55
+ else:
56
+ sha256_value = calculate_sha256(filename)
57
+ print(f"{sha256_value}")
58
+
59
+ hashes[title] = {
60
+ "mtime": os.path.getmtime(filename),
61
+ "sha256": sha256_value,
62
+ }
63
+
64
+ dump_cache()
65
+
66
+ return sha256_value
67
+
68
+
69
+ def addnet_hash_safetensors(b):
70
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
71
+ hash_sha256 = hashlib.sha256()
72
+ blksize = 1024 * 1024
73
+
74
+ b.seek(0)
75
+ header = b.read(8)
76
+ n = int.from_bytes(header, "little")
77
+
78
+ offset = n + 8
79
+ b.seek(offset)
80
+ for chunk in iter(lambda: b.read(blksize), b""):
81
+ hash_sha256.update(chunk)
82
+
83
+ return hash_sha256.hexdigest()
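addnet_hash_safetensors() leans on the safetensors layout: the first 8 bytes are the little-endian length of the JSON header, so seeking to n + 8 skips the header entirely and the hash covers only the raw tensor data, ignoring embedded metadata. A hedged sketch of computing both hash variants directly (the filename is illustrative):

    with open("models/Lora/example.safetensors", "rb") as f:
        addnet_hash = addnet_hash_safetensors(f)                          # hash of tensor data only
    full_hash = calculate_sha256("models/Lora/example.safetensors")       # hash of the whole file, for comparison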
84
+
modules/hat_model.py ADDED
@@ -0,0 +1,43 @@
1
+ import os
2
+ import sys
3
+
4
+ from modules import modelloader, devices
5
+ from modules.shared import opts
6
+ from modules.upscaler import Upscaler, UpscalerData
7
+ from modules.upscaler_utils import upscale_with_model
8
+
9
+
10
+ class UpscalerHAT(Upscaler):
11
+ def __init__(self, dirname):
12
+ self.name = "HAT"
13
+ self.scalers = []
14
+ self.user_path = dirname
15
+ super().__init__()
16
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
17
+ name = modelloader.friendly_name(file)
18
+ scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
19
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
20
+ self.scalers.append(scaler_data)
21
+
22
+ def do_upscale(self, img, selected_model):
23
+ try:
24
+ model = self.load_model(selected_model)
25
+ except Exception as e:
26
+ print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
27
+ return img
28
+ model.to(devices.device_esrgan) # TODO: should probably be device_hat
29
+ return upscale_with_model(
30
+ model,
31
+ img,
32
+ tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
33
+ tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
34
+ )
35
+
36
+ def load_model(self, path: str):
37
+ if not os.path.isfile(path):
38
+ raise FileNotFoundError(f"Model file {path} not found")
39
+ return modelloader.load_spandrel_model(
40
+ path,
41
+ device=devices.device_esrgan, # TODO: should probably be device_hat
42
+ expected_architecture='HAT',
43
+ )
modules/hypernetworks/hypernetwork.py ADDED
@@ -0,0 +1,783 @@
1
+ import datetime
2
+ import glob
3
+ import html
4
+ import os
5
+ import inspect
6
+ from contextlib import closing
7
+
8
+ import modules.textual_inversion.dataset
9
+ import torch
10
+ import tqdm
11
+ from einops import rearrange, repeat
12
+ from ldm.util import default
13
+ from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
14
+ from modules.textual_inversion import textual_inversion, saving_settings
15
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
16
+ from torch import einsum
17
+ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
18
+
19
+ from collections import deque
20
+ from statistics import stdev, mean
21
+
22
+
23
+ optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
24
+
25
+ class HypernetworkModule(torch.nn.Module):
26
+ activation_dict = {
27
+ "linear": torch.nn.Identity,
28
+ "relu": torch.nn.ReLU,
29
+ "leakyrelu": torch.nn.LeakyReLU,
30
+ "elu": torch.nn.ELU,
31
+ "swish": torch.nn.Hardswish,
32
+ "tanh": torch.nn.Tanh,
33
+ "sigmoid": torch.nn.Sigmoid,
34
+ }
35
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
36
+
37
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
38
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
39
+ super().__init__()
40
+
41
+ self.multiplier = 1.0
42
+
43
+ assert layer_structure is not None, "layer_structure must not be None"
44
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
45
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
46
+
47
+ linears = []
48
+ for i in range(len(layer_structure) - 1):
49
+
50
+ # Add a fully-connected layer
51
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
52
+
53
+ # Add an activation func except last layer
54
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
55
+ pass
56
+ elif activation_func in self.activation_dict:
57
+ linears.append(self.activation_dict[activation_func]())
58
+ else:
59
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
60
+
61
+ # Add layer normalization
62
+ if add_layer_norm:
63
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
64
+
65
+ # Everything should now be parsed into the dropout structure and applied here.
66
+ # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
67
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
68
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
69
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
70
+ # Code explanation: [1, 2, 1] -> dropout is missing when last_layer_dropout is False. [1, 2, 2, 1] -> [0, 0.3, 0, 0]; when it's True, [0, 0.3, 0.3, 0].
71
+
72
+ self.linear = torch.nn.Sequential(*linears)
73
+
74
+ if state_dict is not None:
75
+ self.fix_old_state_dict(state_dict)
76
+ self.load_state_dict(state_dict)
77
+ else:
78
+ for layer in self.linear:
79
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
80
+ w, b = layer.weight.data, layer.bias.data
81
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
82
+ normal_(w, mean=0.0, std=0.01)
83
+ normal_(b, mean=0.0, std=0)
84
+ elif weight_init == 'XavierUniform':
85
+ xavier_uniform_(w)
86
+ zeros_(b)
87
+ elif weight_init == 'XavierNormal':
88
+ xavier_normal_(w)
89
+ zeros_(b)
90
+ elif weight_init == 'KaimingUniform':
91
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
92
+ zeros_(b)
93
+ elif weight_init == 'KaimingNormal':
94
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
95
+ zeros_(b)
96
+ else:
97
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
98
+ devices.torch_npu_set_device()
99
+ self.to(devices.device)
100
+
101
+ def fix_old_state_dict(self, state_dict):
102
+ changes = {
103
+ 'linear1.bias': 'linear.0.bias',
104
+ 'linear1.weight': 'linear.0.weight',
105
+ 'linear2.bias': 'linear.1.bias',
106
+ 'linear2.weight': 'linear.1.weight',
107
+ }
108
+
109
+ for fr, to in changes.items():
110
+ x = state_dict.get(fr, None)
111
+ if x is None:
112
+ continue
113
+
114
+ del state_dict[fr]
115
+ state_dict[to] = x
116
+
117
+ def forward(self, x):
118
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
119
+
120
+ def trainables(self):
121
+ layer_structure = []
122
+ for layer in self.linear:
123
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
124
+ layer_structure += [layer.weight, layer.bias]
125
+ return layer_structure
126
+
127
+
128
+ # params: layer_structure - sequence used for length; use_dropout - controlling boolean; last_layer_dropout - for compatibility check.
129
+ def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
130
+ if layer_structure is None:
131
+ layer_structure = [1, 2, 1]
132
+ if not use_dropout:
133
+ return [0] * len(layer_structure)
134
+ dropout_values = [0]
135
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
136
+ if last_layer_dropout:
137
+ dropout_values.append(0.3)
138
+ else:
139
+ dropout_values.append(0)
140
+ dropout_values.append(0)
141
+ return dropout_values
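As a quick illustration of what parse_dropout_structure() returns (these values follow directly from the code above):

    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=False)   # -> [0, 0.3, 0, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=True)    # -> [0, 0.3, 0.3, 0]
    parse_dropout_structure([1, 2, 1], use_dropout=False, last_layer_dropout=True)      # -> [0, 0, 0]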
142
+
143
+
144
+ class Hypernetwork:
145
+ filename = None
146
+ name = None
147
+
148
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
149
+ self.filename = None
150
+ self.name = name
151
+ self.layers = {}
152
+ self.step = 0
153
+ self.sd_checkpoint = None
154
+ self.sd_checkpoint_name = None
155
+ self.layer_structure = layer_structure
156
+ self.activation_func = activation_func
157
+ self.weight_init = weight_init
158
+ self.add_layer_norm = add_layer_norm
159
+ self.use_dropout = use_dropout
160
+ self.activate_output = activate_output
161
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
162
+ self.dropout_structure = kwargs.get('dropout_structure', None)
163
+ if self.dropout_structure is None:
164
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
165
+ self.optimizer_name = None
166
+ self.optimizer_state_dict = None
167
+ self.optional_info = None
168
+
169
+ for size in enable_sizes or []:
170
+ self.layers[size] = (
171
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
172
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
173
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
174
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
175
+ )
176
+ self.eval()
177
+
178
+ def weights(self):
179
+ res = []
180
+ for layers in self.layers.values():
181
+ for layer in layers:
182
+ res += layer.parameters()
183
+ return res
184
+
185
+ def train(self, mode=True):
186
+ for layers in self.layers.values():
187
+ for layer in layers:
188
+ layer.train(mode=mode)
189
+ for param in layer.parameters():
190
+ param.requires_grad = mode
191
+
192
+ def to(self, device):
193
+ for layers in self.layers.values():
194
+ for layer in layers:
195
+ layer.to(device)
196
+
197
+ return self
198
+
199
+ def set_multiplier(self, multiplier):
200
+ for layers in self.layers.values():
201
+ for layer in layers:
202
+ layer.multiplier = multiplier
203
+
204
+ return self
205
+
206
+ def eval(self):
207
+ for layers in self.layers.values():
208
+ for layer in layers:
209
+ layer.eval()
210
+ for param in layer.parameters():
211
+ param.requires_grad = False
212
+
213
+ def save(self, filename):
214
+ state_dict = {}
215
+ optimizer_saved_dict = {}
216
+
217
+ for k, v in self.layers.items():
218
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
219
+
220
+ state_dict['step'] = self.step
221
+ state_dict['name'] = self.name
222
+ state_dict['layer_structure'] = self.layer_structure
223
+ state_dict['activation_func'] = self.activation_func
224
+ state_dict['is_layer_norm'] = self.add_layer_norm
225
+ state_dict['weight_initialization'] = self.weight_init
226
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
227
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
228
+ state_dict['activate_output'] = self.activate_output
229
+ state_dict['use_dropout'] = self.use_dropout
230
+ state_dict['dropout_structure'] = self.dropout_structure
231
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
232
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
233
+
234
+ if self.optimizer_name is not None:
235
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
236
+
237
+ torch.save(state_dict, filename)
238
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
239
+ optimizer_saved_dict['hash'] = self.shorthash()
240
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
241
+ torch.save(optimizer_saved_dict, filename + '.optim')
242
+
243
+ def load(self, filename):
244
+ self.filename = filename
245
+ if self.name is None:
246
+ self.name = os.path.splitext(os.path.basename(filename))[0]
247
+
248
+ state_dict = torch.load(filename, map_location='cpu')
249
+
250
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
251
+ self.optional_info = state_dict.get('optional_info', None)
252
+ self.activation_func = state_dict.get('activation_func', None)
253
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
254
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
255
+ self.dropout_structure = state_dict.get('dropout_structure', None)
256
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
257
+ self.activate_output = state_dict.get('activate_output', True)
258
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
259
+ # Dropout structure should have the same length as layer structure; every value should be in [0, 1), and the last value must be 0.
260
+ if self.dropout_structure is None:
261
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
262
+
263
+ if shared.opts.print_hypernet_extra:
264
+ if self.optional_info is not None:
265
+ print(f" INFO:\n {self.optional_info}\n")
266
+
267
+ print(f" Layer structure: {self.layer_structure}")
268
+ print(f" Activation function: {self.activation_func}")
269
+ print(f" Weight initialization: {self.weight_init}")
270
+ print(f" Layer norm: {self.add_layer_norm}")
271
+ print(f" Dropout usage: {self.use_dropout}" )
272
+ print(f" Activate last layer: {self.activate_output}")
273
+ print(f" Dropout structure: {self.dropout_structure}")
274
+
275
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
276
+
277
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
278
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
279
+ else:
280
+ self.optimizer_state_dict = None
281
+ if self.optimizer_state_dict:
282
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
283
+ if shared.opts.print_hypernet_extra:
284
+ print("Loaded existing optimizer from checkpoint")
285
+ print(f"Optimizer name is {self.optimizer_name}")
286
+ else:
287
+ self.optimizer_name = "AdamW"
288
+ if shared.opts.print_hypernet_extra:
289
+ print("No saved optimizer exists in checkpoint")
290
+
291
+ for size, sd in state_dict.items():
292
+ if type(size) == int:
293
+ self.layers[size] = (
294
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
295
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
296
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
297
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
298
+ )
299
+
300
+ self.name = state_dict.get('name', self.name)
301
+ self.step = state_dict.get('step', 0)
302
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
303
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
304
+ self.eval()
305
+
306
+ def shorthash(self):
307
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
308
+
309
+ return sha256[0:10] if sha256 else None
310
+
311
+
312
+ def list_hypernetworks(path):
313
+ res = {}
314
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
315
+ name = os.path.splitext(os.path.basename(filename))[0]
316
+ # Prevent a hypothetical "None.pt" from being listed.
317
+ if name != "None":
318
+ res[name] = filename
319
+ return res
320
+
321
+
322
+ def load_hypernetwork(name):
323
+ path = shared.hypernetworks.get(name, None)
324
+
325
+ if path is None:
326
+ return None
327
+
328
+ try:
329
+ hypernetwork = Hypernetwork()
330
+ hypernetwork.load(path)
331
+ return hypernetwork
332
+ except Exception:
333
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
334
+ return None
335
+
336
+
337
+ def load_hypernetworks(names, multipliers=None):
338
+ already_loaded = {}
339
+
340
+ for hypernetwork in shared.loaded_hypernetworks:
341
+ if hypernetwork.name in names:
342
+ already_loaded[hypernetwork.name] = hypernetwork
343
+
344
+ shared.loaded_hypernetworks.clear()
345
+
346
+ for i, name in enumerate(names):
347
+ hypernetwork = already_loaded.get(name, None)
348
+ if hypernetwork is None:
349
+ hypernetwork = load_hypernetwork(name)
350
+
351
+ if hypernetwork is None:
352
+ continue
353
+
354
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
355
+ shared.loaded_hypernetworks.append(hypernetwork)
356
+
357
+
358
+ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
359
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
360
+
361
+ if hypernetwork_layers is None:
362
+ return context_k, context_v
363
+
364
+ if layer is not None:
365
+ layer.hyper_k = hypernetwork_layers[0]
366
+ layer.hyper_v = hypernetwork_layers[1]
367
+
368
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
369
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
370
+ return context_k, context_v
371
+
372
+
373
+ def apply_hypernetworks(hypernetworks, context, layer=None):
374
+ context_k = context
375
+ context_v = context
376
+ for hypernetwork in hypernetworks:
377
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
378
+
379
+ return context_k, context_v
380
+
381
+
382
+ def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
383
+ h = self.heads
384
+
385
+ q = self.to_q(x)
386
+ context = default(context, x)
387
+
388
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
389
+ k = self.to_k(context_k)
390
+ v = self.to_v(context_v)
391
+
392
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
393
+
394
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
395
+
396
+ if mask is not None:
397
+ mask = rearrange(mask, 'b ... -> b (...)')
398
+ max_neg_value = -torch.finfo(sim.dtype).max
399
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
400
+ sim.masked_fill_(~mask, max_neg_value)
401
+
402
+ # attention, what we cannot get enough of
403
+ attn = sim.softmax(dim=-1)
404
+
405
+ out = einsum('b i j, b j d -> b i d', attn, v)
406
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
407
+ return self.to_out(out)
408
+
409
+
410
+ def stack_conds(conds):
411
+ if len(conds) == 1:
412
+ return torch.stack(conds)
413
+
414
+ # same as in reconstruct_multicond_batch
415
+ token_count = max([x.shape[0] for x in conds])
416
+ for i in range(len(conds)):
417
+ if conds[i].shape[0] != token_count:
418
+ last_vector = conds[i][-1:]
419
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
420
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
421
+
422
+ return torch.stack(conds)
423
+
424
+
425
+ def statistics(data):
426
+ if len(data) < 2:
427
+ std = 0
428
+ else:
429
+ std = stdev(data)
430
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
431
+ recent_data = data[-32:]
432
+ if len(recent_data) < 2:
433
+ std = 0
434
+ else:
435
+ std = stdev(recent_data)
436
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
437
+ return total_information, recent_information
438
+
439
+
440
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
441
+ # Remove illegal characters from name.
442
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
443
+ assert name, "Name cannot be empty!"
444
+
445
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
446
+ if not overwrite_old:
447
+ assert not os.path.exists(fn), f"file {fn} already exists"
448
+
449
+ if type(layer_structure) == str:
450
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
451
+
452
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
453
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
454
+ else:
455
+ dropout_structure = [0] * len(layer_structure)
456
+
457
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
458
+ name=name,
459
+ enable_sizes=[int(x) for x in enable_sizes],
460
+ layer_structure=layer_structure,
461
+ activation_func=activation_func,
462
+ weight_init=weight_init,
463
+ add_layer_norm=add_layer_norm,
464
+ use_dropout=use_dropout,
465
+ dropout_structure=dropout_structure
466
+ )
467
+ hypernet.save(fn)
468
+
469
+ shared.reload_hypernetworks()
470
+
471
+
472
+ def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
473
+ from modules import images, processing
474
+
475
+ save_hypernetwork_every = save_hypernetwork_every or 0
476
+ create_image_every = create_image_every or 0
477
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
478
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
479
+ template_file = template_file.path
480
+
481
+ path = shared.hypernetworks.get(hypernetwork_name, None)
482
+ hypernetwork = Hypernetwork()
483
+ hypernetwork.load(path)
484
+ shared.loaded_hypernetworks = [hypernetwork]
485
+
486
+ shared.state.job = "train-hypernetwork"
487
+ shared.state.textinfo = "Initializing hypernetwork training..."
488
+ shared.state.job_count = steps
489
+
490
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
491
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
492
+
493
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
494
+ unload = shared.opts.unload_models_when_training
495
+
496
+ if save_hypernetwork_every > 0:
497
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
498
+ os.makedirs(hypernetwork_dir, exist_ok=True)
499
+ else:
500
+ hypernetwork_dir = None
501
+
502
+ if create_image_every > 0:
503
+ images_dir = os.path.join(log_directory, "images")
504
+ os.makedirs(images_dir, exist_ok=True)
505
+ else:
506
+ images_dir = None
507
+
508
+ checkpoint = sd_models.select_checkpoint()
509
+
510
+ initial_step = hypernetwork.step or 0
511
+ if initial_step >= steps:
512
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
513
+ return hypernetwork, filename
514
+
515
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
516
+
517
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
518
+ if clip_grad:
519
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
520
+
521
+ if shared.opts.training_enable_tensorboard:
522
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
523
+
524
+ # dataset loading may take a while, so input validations and early returns should be done before this
525
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
526
+
527
+ pin_memory = shared.opts.pin_memory
528
+
529
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
530
+
531
+ if shared.opts.save_training_settings_to_txt:
532
+ saved_params = dict(
533
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
534
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
535
+ )
536
+ saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})
537
+
538
+ latent_sampling_method = ds.latent_sampling_method
539
+
540
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
541
+
542
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
543
+
544
+ if unload:
545
+ shared.parallel_processing_allowed = False
546
+ shared.sd_model.cond_stage_model.to(devices.cpu)
547
+ shared.sd_model.first_stage_model.to(devices.cpu)
548
+
549
+ weights = hypernetwork.weights()
550
+ hypernetwork.train()
551
+
552
+ # Use the optimizer stored with the saved hypernetwork; this could also be exposed as a UI option.
553
+ if hypernetwork.optimizer_name in optimizer_dict:
554
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
555
+ optimizer_name = hypernetwork.optimizer_name
556
+ else:
557
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
558
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
559
+ optimizer_name = 'AdamW'
560
+
561
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
562
+ try:
563
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
564
+ except RuntimeError as e:
565
+ print("Cannot resume from saved optimizer!")
566
+ print(e)
567
+
568
+ scaler = torch.cuda.amp.GradScaler()
569
+
570
+ batch_size = ds.batch_size
571
+ gradient_step = ds.gradient_step
572
+ # n steps = batch_size * gradient_step * n image processed
573
+ steps_per_epoch = len(ds) // batch_size // gradient_step
574
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
575
+ loss_step = 0
576
+ _loss_step = 0 #internal
577
+ # size = len(ds.indexes)
578
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
579
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
580
+ # losses = torch.zeros((size,))
581
+ # previous_mean_losses = [0]
582
+ # previous_mean_loss = 0
583
+ # print("Mean loss of {} elements".format(size))
584
+
585
+ steps_without_grad = 0
586
+
587
+ last_saved_file = "<none>"
588
+ last_saved_image = "<none>"
589
+ forced_filename = "<none>"
590
+
591
+ pbar = tqdm.tqdm(total=steps - initial_step)
592
+ try:
593
+ sd_hijack_checkpoint.add()
594
+
595
+ for _ in range((steps-initial_step) * gradient_step):
596
+ if scheduler.finished:
597
+ break
598
+ if shared.state.interrupted:
599
+ break
600
+ for j, batch in enumerate(dl):
601
+ # works as a drop_last=True for gradient accumulation
602
+ if j == max_steps_per_epoch:
603
+ break
604
+ scheduler.apply(optimizer, hypernetwork.step)
605
+ if scheduler.finished:
606
+ break
607
+ if shared.state.interrupted:
608
+ break
609
+
610
+ if clip_grad:
611
+ clip_grad_sched.step(hypernetwork.step)
612
+
613
+ with devices.autocast():
614
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
615
+ if use_weight:
616
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
617
+ if tag_drop_out != 0 or shuffle_tags:
618
+ shared.sd_model.cond_stage_model.to(devices.device)
619
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
620
+ shared.sd_model.cond_stage_model.to(devices.cpu)
621
+ else:
622
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
623
+ if use_weight:
624
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
625
+ del w
626
+ else:
627
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
628
+ del x
629
+ del c
630
+
631
+ _loss_step += loss.item()
632
+ scaler.scale(loss).backward()
633
+
634
+ # go back until we reach gradient accumulation steps
635
+ if (j + 1) % gradient_step != 0:
636
+ continue
637
+ loss_logging.append(_loss_step)
638
+ if clip_grad:
639
+ clip_grad(weights, clip_grad_sched.learn_rate)
640
+
641
+ scaler.step(optimizer)
642
+ scaler.update()
643
+ hypernetwork.step += 1
644
+ pbar.update()
645
+ optimizer.zero_grad(set_to_none=True)
646
+ loss_step = _loss_step
647
+ _loss_step = 0
648
+
649
+ steps_done = hypernetwork.step + 1
650
+
651
+ epoch_num = hypernetwork.step // steps_per_epoch
652
+ epoch_step = hypernetwork.step % steps_per_epoch
653
+
654
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
655
+ pbar.set_description(description)
656
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
657
+ # Before saving, change name to match current checkpoint.
658
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
659
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
660
+ hypernetwork.optimizer_name = optimizer_name
661
+ if shared.opts.save_optimizer_state:
662
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
663
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
664
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
665
+
666
+
667
+
668
+ if shared.opts.training_enable_tensorboard:
669
+ epoch_num = hypernetwork.step // len(ds)
670
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
671
+ mean_loss = sum(loss_logging) / len(loss_logging)
672
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
673
+
674
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
675
+ "loss": f"{loss_step:.7f}",
676
+ "learn_rate": scheduler.learn_rate
677
+ })
678
+
679
+ if images_dir is not None and steps_done % create_image_every == 0:
680
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
681
+ last_saved_image = os.path.join(images_dir, forced_filename)
682
+ hypernetwork.eval()
683
+ rng_state = torch.get_rng_state()
684
+ cuda_rng_state = None
685
+ if torch.cuda.is_available():
686
+ cuda_rng_state = torch.cuda.get_rng_state_all()
687
+ shared.sd_model.cond_stage_model.to(devices.device)
688
+ shared.sd_model.first_stage_model.to(devices.device)
689
+
690
+ p = processing.StableDiffusionProcessingTxt2Img(
691
+ sd_model=shared.sd_model,
692
+ do_not_save_grid=True,
693
+ do_not_save_samples=True,
694
+ )
695
+
696
+ p.disable_extra_networks = True
697
+
698
+ if preview_from_txt2img:
699
+ p.prompt = preview_prompt
700
+ p.negative_prompt = preview_negative_prompt
701
+ p.steps = preview_steps
702
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
703
+ p.cfg_scale = preview_cfg_scale
704
+ p.seed = preview_seed
705
+ p.width = preview_width
706
+ p.height = preview_height
707
+ else:
708
+ p.prompt = batch.cond_text[0]
709
+ p.steps = 20
710
+ p.width = training_width
711
+ p.height = training_height
712
+
713
+ preview_text = p.prompt
714
+
715
+ with closing(p):
716
+ processed = processing.process_images(p)
717
+ image = processed.images[0] if len(processed.images) > 0 else None
718
+
719
+ if unload:
720
+ shared.sd_model.cond_stage_model.to(devices.cpu)
721
+ shared.sd_model.first_stage_model.to(devices.cpu)
722
+ torch.set_rng_state(rng_state)
723
+ if torch.cuda.is_available():
724
+ torch.cuda.set_rng_state_all(cuda_rng_state)
725
+ hypernetwork.train()
726
+ if image is not None:
727
+ shared.state.assign_current_image(image)
728
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
729
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
730
+ f"Validation at epoch {epoch_num}", image,
731
+ hypernetwork.step)
732
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
733
+ last_saved_image += f", prompt: {preview_text}"
734
+
735
+ shared.state.job_no = hypernetwork.step
736
+
737
+ shared.state.textinfo = f"""
738
+ <p>
739
+ Loss: {loss_step:.7f}<br/>
740
+ Step: {steps_done}<br/>
741
+ Last prompt: {html.escape(batch.cond_text[0])}<br/>
742
+ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
743
+ Last saved image: {html.escape(last_saved_image)}<br/>
744
+ </p>
745
+ """
746
+ except Exception:
747
+ errors.report("Exception in training hypernetwork", exc_info=True)
748
+ finally:
749
+ pbar.leave = False
750
+ pbar.close()
751
+ hypernetwork.eval()
752
+ sd_hijack_checkpoint.remove()
753
+
754
+
755
+
756
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
757
+ hypernetwork.optimizer_name = optimizer_name
758
+ if shared.opts.save_optimizer_state:
759
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
760
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
761
+
762
+ del optimizer
763
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
764
+ shared.sd_model.cond_stage_model.to(devices.device)
765
+ shared.sd_model.first_stage_model.to(devices.device)
766
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
767
+
768
+ return hypernetwork, filename
769
+
770
+ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
771
+ old_hypernetwork_name = hypernetwork.name
772
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
773
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
774
+ try:
775
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
776
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
777
+ hypernetwork.name = hypernetwork_name
778
+ hypernetwork.save(filename)
779
+ except:
780
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
781
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
782
+ hypernetwork.name = old_hypernetwork_name
783
+ raise
modules/hypernetworks/ui.py ADDED
@@ -0,0 +1,38 @@
 
1
+ import html
2
+
3
+ import gradio as gr
4
+ import modules.hypernetworks.hypernetwork
5
+ from modules import devices, sd_hijack, shared
6
+
7
+ not_available = ["hardswish", "multiheadattention"]
8
+ keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
9
+
10
+
11
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
12
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
13
+
14
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
15
+
16
+
17
+ def train_hypernetwork(*args):
18
+ shared.loaded_hypernetworks = []
19
+
20
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
21
+
22
+ try:
23
+ sd_hijack.undo_optimizations()
24
+
25
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
26
+
27
+ res = f"""
28
+ Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
29
+ Hypernetwork saved to {html.escape(filename)}
30
+ """
31
+ return res, ""
32
+ except Exception:
33
+ raise
34
+ finally:
35
+ shared.sd_model.cond_stage_model.to(devices.device)
36
+ shared.sd_model.first_stage_model.to(devices.device)
37
+ sd_hijack.apply_optimizations()
38
+
modules/images.py ADDED
@@ -0,0 +1,877 @@
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import functools
5
+ import pytz
6
+ import io
7
+ import math
8
+ import os
9
+ from collections import namedtuple
10
+ import re
11
+
12
+ import numpy as np
13
+ import piexif
14
+ import piexif.helper
15
+ from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
16
+ # pillow_avif needs to be imported somewhere in code for it to work
17
+ import pillow_avif # noqa: F401
18
+ import string
19
+ import json
20
+ import hashlib
21
+
22
+ from modules import sd_samplers, shared, script_callbacks, errors
23
+ from modules.paths_internal import roboto_ttf_file
24
+ from modules.shared import opts
25
+
26
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
27
+
28
+
29
+ def get_font(fontsize: int):
30
+ try:
31
+ return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
32
+ except Exception:
33
+ return ImageFont.truetype(roboto_ttf_file, fontsize)
34
+
35
+
36
+ def image_grid(imgs, batch_size=1, rows=None):
37
+ if rows is None:
38
+ if opts.n_rows > 0:
39
+ rows = opts.n_rows
40
+ elif opts.n_rows == 0:
41
+ rows = batch_size
42
+ elif opts.grid_prevent_empty_spots:
43
+ rows = math.floor(math.sqrt(len(imgs)))
44
+ while len(imgs) % rows != 0:
45
+ rows -= 1
46
+ else:
47
+ rows = math.sqrt(len(imgs))
48
+ rows = round(rows)
49
+ if rows > len(imgs):
50
+ rows = len(imgs)
51
+
52
+ cols = math.ceil(len(imgs) / rows)
53
+
54
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
55
+ script_callbacks.image_grid_callback(params)
56
+
57
+ w, h = map(max, zip(*(img.size for img in imgs)))
58
+ grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGB')
59
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=grid_background_color)
60
+
61
+ for i, img in enumerate(params.imgs):
62
+ img_w, img_h = img.size
63
+ w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2
64
+ grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset))
65
+
66
+ return grid
67
+
68
+
69
+ class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
70
+ @property
71
+ def tile_count(self) -> int:
72
+ """
73
+ The total number of tiles in the grid.
74
+ """
75
+ return sum(len(row[2]) for row in self.tiles)
76
+
77
+
78
+ def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
79
+ w, h = image.size
80
+
81
+ non_overlap_width = tile_w - overlap
82
+ non_overlap_height = tile_h - overlap
83
+
84
+ cols = math.ceil((w - overlap) / non_overlap_width)
85
+ rows = math.ceil((h - overlap) / non_overlap_height)
86
+
87
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
88
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
89
+
90
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
91
+ for row in range(rows):
92
+ row_images = []
93
+
94
+ y = int(row * dy)
95
+
96
+ if y + tile_h >= h:
97
+ y = h - tile_h
98
+
99
+ for col in range(cols):
100
+ x = int(col * dx)
101
+
102
+ if x + tile_w >= w:
103
+ x = w - tile_w
104
+
105
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
106
+
107
+ row_images.append([x, tile_w, tile])
108
+
109
+ grid.tiles.append([y, tile_h, row_images])
110
+
111
+ return grid
112
+
113
+
114
+ def combine_grid(grid):
115
+ def make_mask_image(r):
116
+ r = r * 255 / grid.overlap
117
+ r = r.astype(np.uint8)
118
+ return Image.fromarray(r, 'L')
119
+
120
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
121
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
122
+
123
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
124
+ for y, h, row in grid.tiles:
125
+ combined_row = Image.new("RGB", (grid.image_w, h))
126
+ for x, w, tile in row:
127
+ if x == 0:
128
+ combined_row.paste(tile, (0, 0))
129
+ continue
130
+
131
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
132
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
133
+
134
+ if y == 0:
135
+ combined_image.paste(combined_row, (0, 0))
136
+ continue
137
+
138
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
139
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
140
+
141
+ return combined_image
142
+
143
+
144
+ class GridAnnotation:
145
+ def __init__(self, text='', is_active=True):
146
+ self.text = text
147
+ self.is_active = is_active
148
+ self.size = None
149
+
150
+
151
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
152
+
153
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
154
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
155
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
156
+
157
+ def wrap(drawing, text, font, line_length):
158
+ lines = ['']
159
+ for word in text.split():
160
+ line = f'{lines[-1]} {word}'.strip()
161
+ if drawing.textlength(line, font=font) <= line_length:
162
+ lines[-1] = line
163
+ else:
164
+ lines.append(word)
165
+ return lines
166
+
167
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
168
+ for line in lines:
169
+ fnt = initial_fnt
170
+ fontsize = initial_fontsize
171
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
172
+ fontsize -= 1
173
+ fnt = get_font(fontsize)
174
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
175
+
176
+ if not line.is_active:
177
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
178
+
179
+ draw_y += line.size[1] + line_spacing
180
+
181
+ fontsize = (width + height) // 25
182
+ line_spacing = fontsize // 2
183
+
184
+ fnt = get_font(fontsize)
185
+
186
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
187
+
188
+ cols = im.width // width
189
+ rows = im.height // height
190
+
191
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
192
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
193
+
194
+ calc_img = Image.new("RGB", (1, 1), color_background)
195
+ calc_d = ImageDraw.Draw(calc_img)
196
+
197
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
198
+ items = [] + texts
199
+ texts.clear()
200
+
201
+ for line in items:
202
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
203
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
204
+
205
+ for line in texts:
206
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
207
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
208
+ line.allowed_width = allowed_width
209
+
210
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
211
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
212
+
213
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
214
+
215
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
216
+
217
+ for row in range(rows):
218
+ for col in range(cols):
219
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
220
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
221
+
222
+ d = ImageDraw.Draw(result)
223
+
224
+ for col in range(cols):
225
+ x = pad_left + (width + margin) * col + width / 2
226
+ y = pad_top / 2 - hor_text_heights[col] / 2
227
+
228
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
229
+
230
+ for row in range(rows):
231
+ x = pad_left / 2
232
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
233
+
234
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
235
+
236
+ return result
237
+
238
+
239
+ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
240
+ prompts = all_prompts[1:]
241
+ boundary = math.ceil(len(prompts) / 2)
242
+
243
+ prompts_horiz = prompts[:boundary]
244
+ prompts_vert = prompts[boundary:]
245
+
246
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
247
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
248
+
249
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
250
+
251
+
252
+ def resize_image(resize_mode, im, width, height, upscaler_name=None):
253
+ """
254
+ Resizes an image with the specified resize_mode, width, and height.
255
+
256
+ Args:
257
+ resize_mode: The mode to use when resizing the image.
258
+ 0: Resize the image to the specified width and height.
259
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
260
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
261
+ im: The image to resize.
262
+ width: The width to resize the image to.
263
+ height: The height to resize the image to.
264
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
265
+ """
266
+
267
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
268
+
269
+ def resize(im, w, h):
270
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
271
+ return im.resize((w, h), resample=LANCZOS)
272
+
273
+ scale = max(w / im.width, h / im.height)
274
+
275
+ if scale > 1.0:
276
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
277
+ if len(upscalers) == 0:
278
+ upscaler = shared.sd_upscalers[0]
279
+ print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
280
+ else:
281
+ upscaler = upscalers[0]
282
+
283
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
284
+
285
+ if im.width != w or im.height != h:
286
+ im = im.resize((w, h), resample=LANCZOS)
287
+
288
+ return im
289
+
290
+ if resize_mode == 0:
291
+ res = resize(im, width, height)
292
+
293
+ elif resize_mode == 1:
294
+ ratio = width / height
295
+ src_ratio = im.width / im.height
296
+
297
+ src_w = width if ratio > src_ratio else im.width * height // im.height
298
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
299
+
300
+ resized = resize(im, src_w, src_h)
301
+ res = Image.new("RGB", (width, height))
302
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
303
+
304
+ else:
305
+ ratio = width / height
306
+ src_ratio = im.width / im.height
307
+
308
+ src_w = width if ratio < src_ratio else im.width * height // im.height
309
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
310
+
311
+ resized = resize(im, src_w, src_h)
312
+ res = Image.new("RGB", (width, height))
313
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
314
+
315
+ if ratio < src_ratio:
316
+ fill_height = height // 2 - src_h // 2
317
+ if fill_height > 0:
318
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
319
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
320
+ elif ratio > src_ratio:
321
+ fill_width = width // 2 - src_w // 2
322
+ if fill_width > 0:
323
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
324
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
325
+
326
+ return res
327
+
328
+
329
+ if not shared.cmd_opts.unix_filenames_sanitization:
330
+ invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
331
+ else:
332
+ invalid_filename_chars = '/'
333
+ invalid_filename_prefix = ' '
334
+ invalid_filename_postfix = ' .'
335
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
336
+ re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
337
+ re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
338
+ max_filename_part_length = shared.cmd_opts.filenames_max_length
339
+ NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
340
+
341
+
342
+ def sanitize_filename_part(text, replace_spaces=True):
343
+ if text is None:
344
+ return None
345
+
346
+ if replace_spaces:
347
+ text = text.replace(' ', '_')
348
+
349
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
350
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
351
+ text = text.rstrip(invalid_filename_postfix)
352
+ return text
353
+
354
+
355
+ @functools.cache
356
+ def get_scheduler_str(sampler_name, scheduler_name):
357
+ """Returns {Scheduler} if the scheduler is applicable to the sampler"""
358
+ if scheduler_name == 'Automatic':
359
+ config = sd_samplers.find_sampler_config(sampler_name)
360
+ scheduler_name = config.options.get('scheduler', 'Automatic')
361
+ return scheduler_name.capitalize()
362
+
363
+
364
+ @functools.cache
365
+ def get_sampler_scheduler_str(sampler_name, scheduler_name):
366
+ """Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler"""
367
+ return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}'
368
+
369
+
370
+ def get_sampler_scheduler(p, sampler):
371
+ """Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'"""
372
+ if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'):
373
+ if sampler:
374
+ sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler)
375
+ else:
376
+ sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler)
377
+ return sanitize_filename_part(sampler_scheduler, replace_spaces=False)
378
+ return NOTHING_AND_SKIP_PREVIOUS_TEXT
379
+
380
+
381
+ class FilenameGenerator:
382
+ replacements = {
383
+ 'basename': lambda self: self.basename or 'img',
384
+ 'seed': lambda self: self.seed if self.seed is not None else '',
385
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
386
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
387
+ 'steps': lambda self: self.p and self.p.steps,
388
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
389
+ 'width': lambda self: self.image.width,
390
+ 'height': lambda self: self.image.height,
391
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
392
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
393
+ 'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True),
394
+ 'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False),
395
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
396
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
397
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
398
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
399
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
400
+ 'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
401
+ 'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
402
+ 'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
403
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
404
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
405
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
406
+ 'prompt_words': lambda self: self.prompt_words(),
407
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
408
+ 'batch_size': lambda self: self.p.batch_size,
409
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
410
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
411
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
412
+ 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
413
+ 'user': lambda self: self.p.user,
414
+ 'vae_filename': lambda self: self.get_vae_filename(),
415
+ 'none': lambda self: '', # Overrides the default, so you can get just the sequence number
416
+ 'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
417
+ }
418
+ default_time_format = '%Y%m%d%H%M%S'
419
+
420
+ def __init__(self, p, seed, prompt, image, zip=False, basename=""):
421
+ self.p = p
422
+ self.seed = seed
423
+ self.prompt = prompt
424
+ self.image = image
425
+ self.zip = zip
426
+ self.basename = basename
427
+
428
+ def get_vae_filename(self):
429
+ """Get the name of the VAE file."""
430
+
431
+ import modules.sd_vae as sd_vae
432
+
433
+ if sd_vae.loaded_vae_file is None:
434
+ return "NoneType"
435
+
436
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
437
+ split_file_name = file_name.split('.')
438
+ if len(split_file_name) > 1 and split_file_name[0] == '':
439
+ return split_file_name[1] # if the filename starts with ".", use the component after the leading dot
440
+ else:
441
+ return split_file_name[0]
442
+
443
+
444
+ def hasprompt(self, *args):
445
+ lower = self.prompt.lower()
446
+ if self.p is None or self.prompt is None:
447
+ return None
448
+ outres = ""
449
+ for arg in args:
450
+ if arg != "":
451
+ division = arg.split("|")
452
+ expected = division[0].lower()
453
+ default = division[1] if len(division) > 1 else ""
454
+ if lower.find(expected) >= 0:
455
+ outres = f'{outres}{expected}'
456
+ else:
457
+ outres = outres if default == "" else f'{outres}{default}'
458
+ return sanitize_filename_part(outres)
459
+
460
+ def prompt_no_style(self):
461
+ if self.p is None or self.prompt is None:
462
+ return None
463
+
464
+ prompt_no_style = self.prompt
465
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
466
+ if style:
467
+ for part in style.split("{prompt}"):
468
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
469
+
470
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
471
+
472
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
473
+
474
+ def prompt_words(self):
475
+ words = [x for x in re_nonletters.split(self.prompt or "") if x]
476
+ if len(words) == 0:
477
+ words = ["empty"]
478
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
479
+
480
+ def datetime(self, *args):
481
+ time_datetime = datetime.datetime.now()
482
+
483
+ time_format = args[0] if (args and args[0] != "") else self.default_time_format
484
+ try:
485
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
486
+ except pytz.exceptions.UnknownTimeZoneError:
487
+ time_zone = None
488
+
489
+ time_zone_time = time_datetime.astimezone(time_zone)
490
+ try:
491
+ formatted_time = time_zone_time.strftime(time_format)
492
+ except (ValueError, TypeError):
493
+ formatted_time = time_zone_time.strftime(self.default_time_format)
494
+
495
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
496
+
497
+ def image_hash(self, *args):
498
+ length = int(args[0]) if (args and args[0] != "") else None
499
+ return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
500
+
501
+ def string_hash(self, text, *args):
502
+ length = int(args[0]) if (args and args[0] != "") else 8
503
+ return hashlib.sha256(text.encode()).hexdigest()[0:length]
504
+
505
+ def apply(self, x):
506
+ res = ''
507
+
508
+ for m in re_pattern.finditer(x):
509
+ text, pattern = m.groups()
510
+
511
+ if pattern is None:
512
+ res += text
513
+ continue
514
+
515
+ pattern_args = []
516
+ while True:
517
+ m = re_pattern_arg.match(pattern)
518
+ if m is None:
519
+ break
520
+
521
+ pattern, arg = m.groups()
522
+ pattern_args.insert(0, arg)
523
+
524
+ fun = self.replacements.get(pattern.lower())
525
+ if fun is not None:
526
+ try:
527
+ replacement = fun(self, *pattern_args)
528
+ except Exception:
529
+ replacement = None
530
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
531
+
532
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
533
+ continue
534
+ elif replacement is not None:
535
+ res += text + str(replacement)
536
+ continue
537
+
538
+ res += f'{text}[{pattern}]'
539
+
540
+ return res
541
+
542
+
543
+ def get_next_sequence_number(path, basename):
544
+ """
545
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
546
+
547
+ The sequence starts at 0.
548
+ """
549
+ result = -1
550
+ if basename != '':
551
+ basename = f"{basename}-"
552
+
553
+ prefix_length = len(basename)
554
+ for p in os.listdir(path):
555
+ if p.startswith(basename):
556
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
557
+ try:
558
+ result = max(int(parts[0]), result)
559
+ except ValueError:
560
+ pass
561
+
562
+ return result + 1
563
+
564
+
565
+ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
566
+ """
567
+ Saves image to filename, including geninfo as text information for generation info.
568
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
569
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
570
+ """
571
+
572
+ if extension is None:
573
+ extension = os.path.splitext(filename)[1]
574
+
575
+ image_format = Image.registered_extensions()[extension]
576
+
577
+ if extension.lower() == '.png':
578
+ existing_pnginfo = existing_pnginfo or {}
579
+ if opts.enable_pnginfo:
580
+ existing_pnginfo[pnginfo_section_name] = geninfo
581
+
582
+ if opts.enable_pnginfo:
583
+ pnginfo_data = PngImagePlugin.PngInfo()
584
+ for k, v in (existing_pnginfo or {}).items():
585
+ pnginfo_data.add_text(k, str(v))
586
+ else:
587
+ pnginfo_data = None
588
+
589
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
590
+
591
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
592
+ if image.mode == 'RGBA':
593
+ image = image.convert("RGB")
594
+ elif image.mode == 'I;16':
595
+ image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
596
+
597
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
598
+
599
+ if opts.enable_pnginfo and geninfo is not None:
600
+ exif_bytes = piexif.dump({
601
+ "Exif": {
602
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
603
+ },
604
+ })
605
+
606
+ piexif.insert(exif_bytes, filename)
607
+ elif extension.lower() == '.avif':
608
+ if opts.enable_pnginfo and geninfo is not None:
609
+ exif_bytes = piexif.dump({
610
+ "Exif": {
611
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
612
+ },
613
+ })
614
+ else:
615
+ exif_bytes = None
616
+
617
+ image.save(filename,format=image_format, quality=opts.jpeg_quality, exif=exif_bytes)
618
+ elif extension.lower() == ".gif":
619
+ image.save(filename, format=image_format, comment=geninfo)
620
+ else:
621
+ image.save(filename, format=image_format, quality=opts.jpeg_quality)
622
+
623
+
624
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
625
+ """Save an image.
626
+
627
+ Args:
628
+ image (`PIL.Image`):
629
+ The image to be saved.
630
+ path (`str`):
631
+ The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
632
+ basename (`str`):
633
+ The base filename which will be applied to `filename pattern`.
634
+ seed, prompt, short_filename,
635
+ extension (`str`):
636
+ Image file extension, default is `png`.
637
+ pnginfo_section_name (`str`):
638
+ Specify the name of the section which `info` will be saved in.
639
+ info (`str` or `PngImagePlugin.iTXt`):
640
+ PNG info chunks.
641
+ existing_info (`dict`):
642
+ Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
643
+ no_prompt:
644
+ If True, prevents saving into a subdirectory even when the `save_to_dirs` option is enabled.
645
+ p (`StableDiffusionProcessing`)
646
+ forced_filename (`str`):
647
+ If specified, `basename` and filename pattern will be ignored.
648
+ save_to_dirs (bool):
649
+ If true, the image will be saved into a subdirectory of `path`.
650
+
651
+ Returns: (fullfn, txt_fullfn)
652
+ fullfn (`str`):
653
+ The full path of the saved image.
654
+ txt_fullfn (`str` or None):
655
+ If a text file is saved for this image, this will be its full path. Otherwise None.
656
+ """
657
+ namegen = FilenameGenerator(p, seed, prompt, image, basename=basename)
658
+
659
+ # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
660
+ if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
661
+ print('Image dimensions too large; saving as PNG')
662
+ extension = "png"
663
+
664
+ if save_to_dirs is None:
665
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
666
+
667
+ if save_to_dirs:
668
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
669
+ path = os.path.join(path, dirname)
670
+
671
+ os.makedirs(path, exist_ok=True)
672
+
673
+ if forced_filename is None:
674
+ if short_filename or seed is None:
675
+ file_decoration = ""
676
+ elif opts.save_to_dirs:
677
+ file_decoration = opts.samples_filename_pattern or "[seed]"
678
+ else:
679
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
680
+
681
+ file_decoration = namegen.apply(file_decoration) + suffix
682
+
683
+ add_number = opts.save_images_add_number or file_decoration == ''
684
+
685
+ if file_decoration != "" and add_number:
686
+ file_decoration = f"-{file_decoration}"
687
+
688
+ if add_number:
689
+ basecount = get_next_sequence_number(path, basename)
690
+ fullfn = None
691
+ for i in range(500):
692
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
693
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
694
+ if not os.path.exists(fullfn):
695
+ break
696
+ else:
697
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
698
+ else:
699
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
700
+
701
+ pnginfo = existing_info or {}
702
+ if info is not None:
703
+ pnginfo[pnginfo_section_name] = info
704
+
705
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
706
+ script_callbacks.before_image_saved_callback(params)
707
+
708
+ image = params.image
709
+ fullfn = params.filename
710
+ info = params.pnginfo.get(pnginfo_section_name, None)
711
+
712
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
713
+ """
714
+ save image with .tmp extension to avoid race condition when another process detects new image in the directory
715
+ """
716
+ temp_file_path = f"{filename_without_extension}.tmp"
717
+
718
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
719
+
720
+ filename = filename_without_extension + extension
721
+ if shared.opts.save_images_replace_action != "Replace":
722
+ n = 0
723
+ while os.path.exists(filename):
724
+ n += 1
725
+ filename = f"{filename_without_extension}-{n}{extension}"
726
+ os.replace(temp_file_path, filename)
727
+
728
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
729
+ if hasattr(os, 'statvfs'):
730
+ max_name_len = os.statvfs(path).f_namemax
731
+ fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
732
+ params.filename = fullfn_without_extension + extension
733
+ fullfn = params.filename
734
+ _atomically_save_image(image, fullfn_without_extension, extension)
735
+
736
+ image.already_saved_as = fullfn
737
+
738
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
739
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
740
+ ratio = image.width / image.height
741
+ resize_to = None
742
+ if oversize and ratio > 1:
743
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
744
+ elif oversize:
745
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
746
+
747
+ if resize_to is not None:
748
+ try:
749
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
750
+ image = image.resize(resize_to, LANCZOS)
751
+ except Exception:
752
+ image = image.resize(resize_to)
753
+ try:
754
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
755
+ except Exception as e:
756
+ errors.display(e, "saving image as downscaled JPG")
757
+
758
+ if opts.save_txt and info is not None:
759
+ txt_fullfn = f"{fullfn_without_extension}.txt"
760
+ with open(txt_fullfn, "w", encoding="utf8") as file:
761
+ file.write(f"{info}\n")
762
+ else:
763
+ txt_fullfn = None
764
+
765
+ script_callbacks.image_saved_callback(params)
766
+
767
+ return fullfn, txt_fullfn
768
+
769
+
770
+ IGNORED_INFO_KEYS = {
771
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
772
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
773
+ 'icc_profile', 'chromaticity', 'photoshop',
774
+ }
775
+
776
+
777
+ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
778
+ items = (image.info or {}).copy()
779
+
780
+ geninfo = items.pop('parameters', None)
781
+
782
+ if "exif" in items:
783
+ exif_data = items["exif"]
784
+ try:
785
+ exif = piexif.load(exif_data)
786
+ except OSError:
787
+ # memory / exif was not valid so piexif tried to read from a file
788
+ exif = None
789
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
790
+ try:
791
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
792
+ except ValueError:
793
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
794
+
795
+ if exif_comment:
796
+ geninfo = exif_comment
797
+ elif "comment" in items: # for gif
798
+ if isinstance(items["comment"], bytes):
799
+ geninfo = items["comment"].decode('utf8', errors="ignore")
800
+ else:
801
+ geninfo = items["comment"]
802
+
803
+ for field in IGNORED_INFO_KEYS:
804
+ items.pop(field, None)
805
+
806
+ if items.get("Software", None) == "NovelAI":
807
+ try:
808
+ json_info = json.loads(items["Comment"])
809
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
810
+
811
+ geninfo = f"""{items["Description"]}
812
+ Negative prompt: {json_info["uc"]}
813
+ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
814
+ except Exception:
815
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
816
+
817
+ return geninfo, items
818
+
819
+
820
+ def image_data(data):
821
+ import gradio as gr
822
+
823
+ try:
824
+ image = read(io.BytesIO(data))
825
+ textinfo, _ = read_info_from_image(image)
826
+ return textinfo, None
827
+ except Exception:
828
+ pass
829
+
830
+ try:
831
+ text = data.decode('utf8')
832
+ assert len(text) < 10000
833
+ return text, None
834
+
835
+ except Exception:
836
+ pass
837
+
838
+ return gr.update(), None
839
+
840
+
841
+ def flatten(img, bgcolor):
842
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
843
+
844
+ if img.mode == "RGBA":
845
+ background = Image.new('RGBA', img.size, bgcolor)
846
+ background.paste(img, mask=img)
847
+ img = background
848
+
849
+ return img.convert('RGB')
850
+
851
+
852
+ def read(fp, **kwargs):
853
+ image = Image.open(fp, **kwargs)
854
+ image = fix_image(image)
855
+
856
+ return image
857
+
858
+
859
+ def fix_image(image: Image.Image):
860
+ if image is None:
861
+ return None
862
+
863
+ try:
864
+ image = ImageOps.exif_transpose(image)
865
+ image = fix_png_transparency(image)
866
+ except Exception:
867
+ pass
868
+
869
+ return image
870
+
871
+
872
+ def fix_png_transparency(image: Image.Image):
873
+ if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
874
+ return image
875
+
876
+ image = image.convert("RGBA")
877
+ return image
modules/img2img.py ADDED
@@ -0,0 +1,253 @@
 
1
+ import os
2
+ from contextlib import closing
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
7
+ import gradio as gr
8
+
9
+ from modules import images
10
+ from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
11
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
12
+ from modules.shared import opts, state
13
+ from modules.sd_models import get_closet_checkpoint_match
14
+ import modules.shared as shared
15
+ import modules.processing as processing
16
+ from modules.ui import plaintext_to_html
17
+ import modules.scripts
18
+
19
+
20
+ def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
21
+ output_dir = output_dir.strip()
22
+ processing.fix_seed(p)
23
+
24
+ if isinstance(input, str):
25
+ batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
26
+ else:
27
+ batch_images = [os.path.abspath(x.name) for x in input]
28
+
29
+ is_inpaint_batch = False
30
+ if inpaint_mask_dir:
31
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
32
+ is_inpaint_batch = bool(inpaint_masks)
33
+
34
+ if is_inpaint_batch:
35
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
36
+
37
+ print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
38
+
39
+ state.job_count = len(batch_images) * p.n_iter
40
+
41
+ # extract "default" params to use in case getting png info fails
42
+ prompt = p.prompt
43
+ negative_prompt = p.negative_prompt
44
+ seed = p.seed
45
+ cfg_scale = p.cfg_scale
46
+ sampler_name = p.sampler_name
47
+ steps = p.steps
48
+ override_settings = p.override_settings
49
+ sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
50
+ batch_results = None
51
+ discard_further_results = False
52
+ for i, image in enumerate(batch_images):
53
+ state.job = f"{i+1} out of {len(batch_images)}"
54
+ if state.skipped:
55
+ state.skipped = False
56
+
57
+ if state.interrupted or state.stopping_generation:
58
+ break
59
+
60
+ try:
61
+ img = images.read(image)
62
+ except UnidentifiedImageError as e:
63
+ print(e)
64
+ continue
65
+ # Use the EXIF orientation of photos taken by smartphones.
66
+ img = ImageOps.exif_transpose(img)
67
+
68
+ if to_scale:
69
+ p.width = int(img.width * scale_by)
70
+ p.height = int(img.height * scale_by)
71
+
72
+ p.init_images = [img] * p.batch_size
73
+
74
+ image_path = Path(image)
75
+ if is_inpaint_batch:
76
+ # try to find corresponding mask for an image using simple filename matching
77
+ if len(inpaint_masks) == 1:
78
+ mask_image_path = inpaint_masks[0]
79
+ else:
80
+ # try to find corresponding mask for an image using simple filename matching
81
+ mask_image_dir = Path(inpaint_mask_dir)
82
+ masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
83
+
84
+ if len(masks_found) == 0:
85
+ print(f"Warning: no mask found for {image_path} in {mask_image_dir}. Skipping it.")
86
+ continue
87
+
88
+ # it should contain only 1 matching mask
89
+ # otherwise user has many masks with the same name but different extensions
90
+ mask_image_path = masks_found[0]
91
+
92
+ mask_image = images.read(mask_image_path)
93
+ p.image_mask = mask_image
94
+
95
+ if use_png_info:
96
+ try:
97
+ info_img = img
98
+ if png_info_dir:
99
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
100
+ info_img = images.read(info_img_path)
101
+ geninfo, _ = images.read_info_from_image(info_img)
102
+ parsed_parameters = parse_generation_parameters(geninfo)
103
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
104
+ except Exception:
105
+ parsed_parameters = {}
106
+
107
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
108
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
109
+ p.seed = int(parsed_parameters.get("Seed", seed))
110
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
111
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
112
+ p.steps = int(parsed_parameters.get("Steps", steps))
113
+
114
+ model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
115
+ if model_info is not None:
116
+ p.override_settings['sd_model_checkpoint'] = model_info.name
117
+ elif sd_model_checkpoint_override:
118
+ p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
119
+ else:
120
+ p.override_settings.pop("sd_model_checkpoint", None)
121
+
122
+ if output_dir:
123
+ p.outpath_samples = output_dir
124
+ p.override_settings['save_to_dirs'] = False
125
+ p.override_settings['save_images_replace_action'] = "Add number suffix"
126
+ if p.n_iter > 1 or p.batch_size > 1:
127
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
128
+ else:
129
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
130
+
131
+ proc = modules.scripts.scripts_img2img.run(p, *args)
132
+
133
+ if proc is None:
134
+ p.override_settings.pop('save_images_replace_action', None)
135
+ proc = process_images(p)
136
+
137
+ if not discard_further_results and proc:
138
+ if batch_results:
139
+ batch_results.images.extend(proc.images)
140
+ batch_results.infotexts.extend(proc.infotexts)
141
+ else:
142
+ batch_results = proc
143
+
144
+ if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
145
+ discard_further_results = True
146
+ batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
147
+ batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
148
+
149
+ return batch_results
150
+
151
+
152
+ def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):
153
+ override_settings = create_override_settings_dict(override_settings_texts)
154
+
155
+ is_batch = mode == 5
156
+
157
+ if mode == 0: # img2img
158
+ image = init_img
159
+ mask = None
160
+ elif mode == 1: # img2img sketch
161
+ image = sketch
162
+ mask = None
163
+ elif mode == 2: # inpaint
164
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
165
+ mask = processing.create_binary_mask(mask)
166
+ elif mode == 3: # inpaint sketch
167
+ image = inpaint_color_sketch
168
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
169
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
170
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
171
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
172
+ blur = ImageFilter.GaussianBlur(mask_blur)
173
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
174
+ elif mode == 4: # inpaint upload mask
175
+ image = init_img_inpaint
176
+ mask = init_mask_inpaint
177
+ else:
178
+ image = None
179
+ mask = None
180
+
181
+ image = images.fix_image(image)
182
+ mask = images.fix_image(mask)
183
+
184
+ if selected_scale_tab == 1 and not is_batch:
185
+ assert image, "Can't scale by because no image is selected"
186
+
187
+ width = int(image.width * scale_by)
188
+ height = int(image.height * scale_by)
189
+
190
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
191
+
192
+ p = StableDiffusionProcessingImg2Img(
193
+ sd_model=shared.sd_model,
194
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
195
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
196
+ prompt=prompt,
197
+ negative_prompt=negative_prompt,
198
+ styles=prompt_styles,
199
+ batch_size=batch_size,
200
+ n_iter=n_iter,
201
+ cfg_scale=cfg_scale,
202
+ width=width,
203
+ height=height,
204
+ init_images=[image],
205
+ mask=mask,
206
+ mask_blur=mask_blur,
207
+ inpainting_fill=inpainting_fill,
208
+ resize_mode=resize_mode,
209
+ denoising_strength=denoising_strength,
210
+ image_cfg_scale=image_cfg_scale,
211
+ inpaint_full_res=inpaint_full_res,
212
+ inpaint_full_res_padding=inpaint_full_res_padding,
213
+ inpainting_mask_invert=inpainting_mask_invert,
214
+ override_settings=override_settings,
215
+ )
216
+
217
+ p.scripts = modules.scripts.scripts_img2img
218
+ p.script_args = args
219
+
220
+ p.user = request.username
221
+
222
+ if shared.opts.enable_console_prompts:
223
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
224
+
225
+ with closing(p):
226
+ if is_batch:
227
+ if img2img_batch_source_type == "upload":
228
+ assert isinstance(img2img_batch_upload, list) and img2img_batch_upload
229
+ output_dir = ""
230
+ inpaint_mask_dir = ""
231
+ png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else ""
232
+ processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir)
233
+ else: # "from dir"
234
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
235
+ processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
236
+
237
+ if processed is None:
238
+ processed = Processed(p, [], p.seed, "")
239
+ else:
240
+ processed = modules.scripts.scripts_img2img.run(p, *args)
241
+ if processed is None:
242
+ processed = process_images(p)
243
+
244
+ shared.total_tqdm.clear()
245
+
246
+ generation_info_js = processed.js()
247
+ if opts.samples_log_stdout:
248
+ print(generation_info_js)
249
+
250
+ if opts.do_not_show_images:
251
+ processed.images = []
252
+
253
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
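The per-image mask lookup in process_batch reduces to a filename-stem match; a standalone sketch of the same idea, using hypothetical "inputs/" and "masks/" folders:

from pathlib import Path

def find_mask(image_path: Path, mask_dir: Path):
    # single mask in the folder: reuse it for every image; otherwise match by stem
    masks = sorted(p for p in mask_dir.iterdir() if p.is_file())
    if len(masks) == 1:
        return masks[0]
    matches = list(mask_dir.glob(f"{image_path.stem}.*"))
    return matches[0] if matches else None

print(find_mask(Path("inputs/cat.png"), Path("masks")))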
modules/import_hook.py ADDED
@@ -0,0 +1,16 @@
1
+ import sys
2
+
3
+ # this breaks any attempt to import xformers, which prevents the stable diffusion repo from trying to use it
4
+ if "--xformers" not in "".join(sys.argv):
5
+ sys.modules["xformers"] = None
6
+
7
+ # Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
8
+ # basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
9
+ try:
10
+ import torchvision.transforms.functional_tensor # noqa: F401
11
+ except ImportError:
12
+ try:
13
+ import torchvision.transforms.functional as functional
14
+ sys.modules["torchvision.transforms.functional_tensor"] = functional
15
+ except ImportError:
16
+ pass # shrug...
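The stubbing trick above works because Python raises ImportError when it finds None in sys.modules; a quick standalone illustration:

import sys

sys.modules["xformers"] = None        # same mechanism as the hook above
try:
    import xformers  # noqa: F401
except ImportError as e:
    print("import blocked:", e)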
modules/infotext_utils.py ADDED
@@ -0,0 +1,546 @@
1
+ from __future__ import annotations
2
+ import base64
3
+ import io
4
+ import json
5
+ import os
6
+ import re
7
+ import sys
8
+
9
+ import gradio as gr
10
+ from modules.paths import data_path
11
+ from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors
12
+ from PIL import Image
13
+
14
+ sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
15
+
16
+ re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
17
+ re_param = re.compile(re_param_code)
18
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
19
+ re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
20
+ type_of_gr_update = type(gr.update())
21
+
22
+
23
+ class ParamBinding:
24
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
25
+ self.paste_button = paste_button
26
+ self.tabname = tabname
27
+ self.source_text_component = source_text_component
28
+ self.source_image_component = source_image_component
29
+ self.source_tabname = source_tabname
30
+ self.override_settings_component = override_settings_component
31
+ self.paste_field_names = paste_field_names or []
32
+
33
+
34
+ class PasteField(tuple):
35
+ def __new__(cls, component, target, *, api=None):
36
+ return super().__new__(cls, (component, target))
37
+
38
+ def __init__(self, component, target, *, api=None):
39
+ super().__init__()
40
+
41
+ self.api = api
42
+ self.component = component
43
+ self.label = target if isinstance(target, str) else None
44
+ self.function = target if callable(target) else None
45
+
46
+
47
+ paste_fields: dict[str, dict] = {}
48
+ registered_param_bindings: list[ParamBinding] = []
49
+
50
+
51
+ def reset():
52
+ paste_fields.clear()
53
+ registered_param_bindings.clear()
54
+
55
+
56
+ def quote(text):
57
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
58
+ return text
59
+
60
+ return json.dumps(text, ensure_ascii=False)
61
+
62
+
63
+ def unquote(text):
64
+ if len(text) == 0 or text[0] != '"' or text[-1] != '"':
65
+ return text
66
+
67
+ try:
68
+ return json.loads(text)
69
+ except Exception:
70
+ return text
71
+
72
+
73
+ def image_from_url_text(filedata):
74
+ if filedata is None:
75
+ return None
76
+
77
+ if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
78
+ filedata = filedata[0]
79
+
80
+ if type(filedata) == dict and filedata.get("is_file", False):
81
+ filename = filedata["name"]
82
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
83
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
84
+
85
+ filename = filename.rsplit('?', 1)[0]
86
+ return images.read(filename)
87
+
88
+ if type(filedata) == list:
89
+ if len(filedata) == 0:
90
+ return None
91
+
92
+ filedata = filedata[0]
93
+
94
+ if filedata.startswith("data:image/png;base64,"):
95
+ filedata = filedata[len("data:image/png;base64,"):]
96
+
97
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
98
+ image = images.read(io.BytesIO(filedata))
99
+ return image
100
+
101
+
102
+ def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
103
+
104
+ if fields:
105
+ for i in range(len(fields)):
106
+ if not isinstance(fields[i], PasteField):
107
+ fields[i] = PasteField(*fields[i])
108
+
109
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
110
+
111
+ # backwards compatibility for existing extensions
112
+ import modules.ui
113
+ if tabname == 'txt2img':
114
+ modules.ui.txt2img_paste_fields = fields
115
+ elif tabname == 'img2img':
116
+ modules.ui.img2img_paste_fields = fields
117
+
118
+
119
+ def create_buttons(tabs_list):
120
+ buttons = {}
121
+ for tab in tabs_list:
122
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
123
+ return buttons
124
+
125
+
126
+ def bind_buttons(buttons, send_image, send_generate_info):
127
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
128
+ for tabname, button in buttons.items():
129
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
130
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
131
+
132
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
133
+
134
+
135
+ def register_paste_params_button(binding: ParamBinding):
136
+ registered_param_bindings.append(binding)
137
+
138
+
139
+ def connect_paste_params_buttons():
140
+ for binding in registered_param_bindings:
141
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
142
+ fields = paste_fields[binding.tabname]["fields"]
143
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
144
+
145
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
146
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
147
+
148
+ if binding.source_image_component and destination_image_component:
149
+ need_send_dementions = destination_width_component and binding.tabname != 'inpaint'
150
+ if isinstance(binding.source_image_component, gr.Gallery):
151
+ func = send_image_and_dimensions if need_send_dementions else image_from_url_text
152
+ jsfunc = "extract_image_from_gallery"
153
+ else:
154
+ func = send_image_and_dimensions if need_send_dementions else lambda x: x
155
+ jsfunc = None
156
+
157
+ binding.paste_button.click(
158
+ fn=func,
159
+ _js=jsfunc,
160
+ inputs=[binding.source_image_component],
161
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component],
162
+ show_progress=False,
163
+ )
164
+
165
+ if binding.source_text_component is not None and fields is not None:
166
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
167
+
168
+ if binding.source_tabname is not None and fields is not None:
169
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
170
+ binding.paste_button.click(
171
+ fn=lambda *x: x,
172
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
173
+ outputs=[field for field, name in fields if name in paste_field_names],
174
+ show_progress=False,
175
+ )
176
+
177
+ binding.paste_button.click(
178
+ fn=None,
179
+ _js=f"switch_to_{binding.tabname}",
180
+ inputs=None,
181
+ outputs=None,
182
+ show_progress=False,
183
+ )
184
+
185
+
186
+ def send_image_and_dimensions(x):
187
+ if isinstance(x, Image.Image):
188
+ img = x
189
+ else:
190
+ img = image_from_url_text(x)
191
+
192
+ if shared.opts.send_size and isinstance(img, Image.Image):
193
+ w = img.width
194
+ h = img.height
195
+ else:
196
+ w = gr.update()
197
+ h = gr.update()
198
+
199
+ return img, w, h
200
+
201
+
202
+ def restore_old_hires_fix_params(res):
203
+ """for infotexts that specify old First pass size parameter, convert it into
204
+ width, height, and hr scale"""
205
+
206
+ firstpass_width = res.get('First pass size-1', None)
207
+ firstpass_height = res.get('First pass size-2', None)
208
+
209
+ if shared.opts.use_old_hires_fix_width_height:
210
+ hires_width = int(res.get("Hires resize-1", 0))
211
+ hires_height = int(res.get("Hires resize-2", 0))
212
+
213
+ if hires_width and hires_height:
214
+ res['Size-1'] = hires_width
215
+ res['Size-2'] = hires_height
216
+ return
217
+
218
+ if firstpass_width is None or firstpass_height is None:
219
+ return
220
+
221
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
222
+ width = int(res.get("Size-1", 512))
223
+ height = int(res.get("Size-2", 512))
224
+
225
+ if firstpass_width == 0 or firstpass_height == 0:
226
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
227
+
228
+ res['Size-1'] = firstpass_width
229
+ res['Size-2'] = firstpass_height
230
+ res['Hires resize-1'] = width
231
+ res['Hires resize-2'] = height
232
+
233
+
234
+ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
235
+ """parses generation parameters string, the one you see in text field under the picture in UI:
236
+ ```
237
+ girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
238
+ Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
239
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
240
+ ```
241
+
242
+ returns a dict with field values
243
+ """
244
+ if skip_fields is None:
245
+ skip_fields = shared.opts.infotext_skip_pasting
246
+
247
+ res = {}
248
+
249
+ prompt = ""
250
+ negative_prompt = ""
251
+
252
+ done_with_prompt = False
253
+
254
+ *lines, lastline = x.strip().split("\n")
255
+ if len(re_param.findall(lastline)) < 3:
256
+ lines.append(lastline)
257
+ lastline = ''
258
+
259
+ for line in lines:
260
+ line = line.strip()
261
+ if line.startswith("Negative prompt:"):
262
+ done_with_prompt = True
263
+ line = line[16:].strip()
264
+ if done_with_prompt:
265
+ negative_prompt += ("" if negative_prompt == "" else "\n") + line
266
+ else:
267
+ prompt += ("" if prompt == "" else "\n") + line
268
+
269
+ for k, v in re_param.findall(lastline):
270
+ try:
271
+ if v[0] == '"' and v[-1] == '"':
272
+ v = unquote(v)
273
+
274
+ m = re_imagesize.match(v)
275
+ if m is not None:
276
+ res[f"{k}-1"] = m.group(1)
277
+ res[f"{k}-2"] = m.group(2)
278
+ else:
279
+ res[k] = v
280
+ except Exception:
281
+ print(f"Error parsing \"{k}: {v}\"")
282
+
283
+ # Extract styles from prompt
284
+ if shared.opts.infotext_styles != "Ignore":
285
+ found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
286
+
287
+ same_hr_styles = True
288
+ if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
289
+ hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
290
+ hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
291
+ if same_hr_styles := found_styles == hr_found_styles:
292
+ res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
293
+ res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
294
+
295
+ if same_hr_styles:
296
+ prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
297
+ if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
298
+ res['Styles array'] = found_styles
299
+
300
+ res["Prompt"] = prompt
301
+ res["Negative prompt"] = negative_prompt
302
+
303
+ # Missing CLIP skip means it was set to 1 (the default)
304
+ if "Clip skip" not in res:
305
+ res["Clip skip"] = "1"
306
+
307
+ hypernet = res.get("Hypernet", None)
308
+ if hypernet is not None:
309
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
310
+
311
+ if "Hires resize-1" not in res:
312
+ res["Hires resize-1"] = 0
313
+ res["Hires resize-2"] = 0
314
+
315
+ if "Hires sampler" not in res:
316
+ res["Hires sampler"] = "Use same sampler"
317
+
318
+ if "Hires schedule type" not in res:
319
+ res["Hires schedule type"] = "Use same scheduler"
320
+
321
+ if "Hires checkpoint" not in res:
322
+ res["Hires checkpoint"] = "Use same checkpoint"
323
+
324
+ if "Hires prompt" not in res:
325
+ res["Hires prompt"] = ""
326
+
327
+ if "Hires negative prompt" not in res:
328
+ res["Hires negative prompt"] = ""
329
+
330
+ if "Mask mode" not in res:
331
+ res["Mask mode"] = "Inpaint masked"
332
+
333
+ if "Masked content" not in res:
334
+ res["Masked content"] = 'original'
335
+
336
+ if "Inpaint area" not in res:
337
+ res["Inpaint area"] = "Whole picture"
338
+
339
+ if "Masked area padding" not in res:
340
+ res["Masked area padding"] = 32
341
+
342
+ restore_old_hires_fix_params(res)
343
+
344
+ # Missing RNG means the default was set, which is GPU RNG
345
+ if "RNG" not in res:
346
+ res["RNG"] = "GPU"
347
+
348
+ if "Schedule type" not in res:
349
+ res["Schedule type"] = "Automatic"
350
+
351
+ if "Schedule max sigma" not in res:
352
+ res["Schedule max sigma"] = 0
353
+
354
+ if "Schedule min sigma" not in res:
355
+ res["Schedule min sigma"] = 0
356
+
357
+ if "Schedule rho" not in res:
358
+ res["Schedule rho"] = 0
359
+
360
+ if "VAE Encoder" not in res:
361
+ res["VAE Encoder"] = "Full"
362
+
363
+ if "VAE Decoder" not in res:
364
+ res["VAE Decoder"] = "Full"
365
+
366
+ if "FP8 weight" not in res:
367
+ res["FP8 weight"] = "Disable"
368
+
369
+ if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
370
+ res["Cache FP16 weight for LoRA"] = False
371
+
372
+ prompt_attention = prompt_parser.parse_prompt_attention(prompt)
373
+ prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
374
+ prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
375
+ if "Emphasis" not in res and prompt_uses_emphasis:
376
+ res["Emphasis"] = "Original"
377
+
378
+ if "Refiner switch by sampling steps" not in res:
379
+ res["Refiner switch by sampling steps"] = False
380
+
381
+ infotext_versions.backcompat(res)
382
+
383
+ for key in skip_fields:
384
+ res.pop(key, None)
385
+
386
+ return res
387
+
388
+
389
+ infotext_to_setting_name_mapping = [
390
+
391
+ ]
392
+ """Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
393
+ Example content:
394
+
395
+ infotext_to_setting_name_mapping = [
396
+ ('Conditional mask weight', 'inpainting_mask_weight'),
397
+ ('Model hash', 'sd_model_checkpoint'),
398
+ ('ENSD', 'eta_noise_seed_delta'),
399
+ ('Schedule type', 'k_sched_type'),
400
+ ]
401
+ """
402
+
403
+
404
+ def create_override_settings_dict(text_pairs):
405
+ """creates processing's override_settings parameters from gradio's multiselect
406
+
407
+ Example input:
408
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
409
+
410
+ Example output:
411
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
412
+ """
413
+
414
+ res = {}
415
+
416
+ params = {}
417
+ for pair in text_pairs:
418
+ k, v = pair.split(":", maxsplit=1)
419
+
420
+ params[k] = v.strip()
421
+
422
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
423
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
424
+ value = params.get(param_name, None)
425
+
426
+ if value is None:
427
+ continue
428
+
429
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
430
+
431
+ return res
432
+
433
+
434
+ def get_override_settings(params, *, skip_fields=None):
435
+ """Returns a list of settings overrides from the infotext parameters dictionary.
436
+
437
+ This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
438
+ a list of tuples containing the parameter name, setting name, and new value cast to correct type.
439
+
440
+ It checks for conditions before adding an override:
441
+ - ignores settings that match the current value
442
+ - ignores parameter keys present in skip_fields argument.
443
+
444
+ Example input:
445
+ {"Clip skip": "2"}
446
+
447
+ Example output:
448
+ [("Clip skip", "CLIP_stop_at_last_layers", 2)]
449
+ """
450
+
451
+ res = []
452
+
453
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
454
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
455
+ if param_name in (skip_fields or {}):
456
+ continue
457
+
458
+ v = params.get(param_name, None)
459
+ if v is None:
460
+ continue
461
+
462
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
463
+ continue
464
+
465
+ v = shared.opts.cast_value(setting_name, v)
466
+ current_value = getattr(shared.opts, setting_name, None)
467
+
468
+ if v == current_value:
469
+ continue
470
+
471
+ res.append((param_name, setting_name, v))
472
+
473
+ return res
474
+
475
+
476
+ def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
477
+ def paste_func(prompt):
478
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
479
+ filename = os.path.join(data_path, "params.txt")
480
+ try:
481
+ with open(filename, "r", encoding="utf8") as file:
482
+ prompt = file.read()
483
+ except OSError:
484
+ pass
485
+
486
+ params = parse_generation_parameters(prompt)
487
+ script_callbacks.infotext_pasted_callback(prompt, params)
488
+ res = []
489
+
490
+ for output, key in paste_fields:
491
+ if callable(key):
492
+ try:
493
+ v = key(params)
494
+ except Exception:
495
+ errors.report(f"Error executing {key}", exc_info=True)
496
+ v = None
497
+ else:
498
+ v = params.get(key, None)
499
+
500
+ if v is None:
501
+ res.append(gr.update())
502
+ elif isinstance(v, type_of_gr_update):
503
+ res.append(v)
504
+ else:
505
+ try:
506
+ valtype = type(output.value)
507
+
508
+ if valtype == bool and v == "False":
509
+ val = False
510
+ elif valtype == int:
511
+ val = float(v)
512
+ else:
513
+ val = valtype(v)
514
+
515
+ res.append(gr.update(value=val))
516
+ except Exception:
517
+ res.append(gr.update())
518
+
519
+ return res
520
+
521
+ if override_settings_component is not None:
522
+ already_handled_fields = {key: 1 for _, key in paste_fields}
523
+
524
+ def paste_settings(params):
525
+ vals = get_override_settings(params, skip_fields=already_handled_fields)
526
+
527
+ vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
528
+
529
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
530
+
531
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
532
+
533
+ button.click(
534
+ fn=paste_func,
535
+ inputs=[input_comp],
536
+ outputs=[x[0] for x in paste_fields],
537
+ show_progress=False,
538
+ )
539
+ button.click(
540
+ fn=None,
541
+ _js=f"recalculate_prompts_{tabname}",
542
+ inputs=[],
543
+ outputs=[],
544
+ show_progress=False,
545
+ )
546
+
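The last infotext line is split by re_param; a standalone sketch of what that pattern extracts from the docstring's example line:

import re

re_param = re.compile(r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)')
lastline = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b"
print(dict(re_param.findall(lastline)))
# {'Steps': '20', 'Sampler': 'Euler a', 'CFG scale': '7', 'Seed': '965400086', 'Size': '512x512', 'Model hash': '45dee52b'}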
modules/infotext_versions.py ADDED
@@ -0,0 +1,46 @@
1
+ from modules import shared
2
+ from packaging import version
3
+ import re
4
+
5
+
6
+ v160 = version.parse("1.6.0")
7
+ v170_tsnr = version.parse("v1.7.0-225")
8
+ v180 = version.parse("1.8.0")
9
+ v180_hr_styles = version.parse("1.8.0-139")
10
+
11
+
12
+ def parse_version(text):
13
+ if text is None:
14
+ return None
15
+
16
+ m = re.match(r'([^-]+-[^-]+)-.*', text)
17
+ if m:
18
+ text = m.group(1)
19
+
20
+ try:
21
+ return version.parse(text)
22
+ except Exception:
23
+ return None
24
+
25
+
26
+ def backcompat(d):
27
+ """Checks infotext Version field, and enables backwards compatibility options according to it."""
28
+
29
+ if not shared.opts.auto_backcompat:
30
+ return
31
+
32
+ ver = parse_version(d.get("Version"))
33
+ if ver is None:
34
+ return
35
+
36
+ if ver < v160 and '[' in d.get('Prompt', ''):
37
+ d["Old prompt editing timelines"] = True
38
+
39
+ if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
40
+ d["Pad conds v0"] = True
41
+
42
+ if ver < v170_tsnr:
43
+ d["Downcast alphas_cumprod"] = True
44
+
45
+ if ver < v180 and d.get('Refiner'):
46
+ d["Refiner switch by sampling steps"] = True
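parse_version keeps only the first two dash-separated chunks so git-describe style tags still compare cleanly; a sketch assuming the packaging library is installed:

import re
from packaging import version

def parse_version(text):
    m = re.match(r'([^-]+-[^-]+)-.*', text)
    if m:
        text = m.group(1)
    return version.parse(text)

print(parse_version("v1.7.0-225-ga1b2c3d"))                       # 1.7.0.post225
print(parse_version("v1.7.0-224") < parse_version("v1.7.0-225"))  # True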
modules/initialize.py ADDED
@@ -0,0 +1,169 @@
1
+ import importlib
2
+ import logging
3
+ import os
4
+ import sys
5
+ import warnings
6
+ from threading import Thread
7
+
8
+ from modules.timer import startup_timer
9
+
10
+
11
+ def imports():
12
+ logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
13
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
14
+
15
+ import torch # noqa: F401
16
+ startup_timer.record("import torch")
17
+ import pytorch_lightning # noqa: F401
18
+ startup_timer.record("import pytorch_lightning")
19
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
20
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
21
+
22
+ os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
23
+ import gradio # noqa: F401
24
+ startup_timer.record("import gradio")
25
+
26
+ from modules import paths, timer, import_hook, errors # noqa: F401
27
+ startup_timer.record("setup paths")
28
+
29
+ import ldm.modules.encoders.modules # noqa: F401
30
+ startup_timer.record("import ldm")
31
+
32
+ import sgm.modules.encoders.modules # noqa: F401
33
+ startup_timer.record("import sgm")
34
+
35
+ from modules import shared_init
36
+ shared_init.initialize()
37
+ startup_timer.record("initialize shared")
38
+
39
+ from modules import processing, gradio_extensons, ui # noqa: F401
40
+ startup_timer.record("other imports")
41
+
42
+
43
+ def check_versions():
44
+ from modules.shared_cmd_options import cmd_opts
45
+
46
+ if not cmd_opts.skip_version_check:
47
+ from modules import errors
48
+ errors.check_versions()
49
+
50
+
51
+ def initialize():
52
+ from modules import initialize_util
53
+ initialize_util.fix_torch_version()
54
+ initialize_util.fix_pytorch_lightning()
55
+ initialize_util.fix_asyncio_event_loop_policy()
56
+ initialize_util.validate_tls_options()
57
+ initialize_util.configure_sigint_handler()
58
+ initialize_util.configure_opts_onchange()
59
+
60
+ from modules import sd_models
61
+ sd_models.setup_model()
62
+ startup_timer.record("setup SD model")
63
+
64
+ from modules.shared_cmd_options import cmd_opts
65
+
66
+ from modules import codeformer_model
67
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
68
+ codeformer_model.setup_model(cmd_opts.codeformer_models_path)
69
+ startup_timer.record("setup codeformer")
70
+
71
+ from modules import gfpgan_model
72
+ gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
73
+ startup_timer.record("setup gfpgan")
74
+
75
+ initialize_rest(reload_script_modules=False)
76
+
77
+
78
+ def initialize_rest(*, reload_script_modules=False):
79
+ """
80
+ Called both from initialize() and when reloading the webui.
81
+ """
82
+ from modules.shared_cmd_options import cmd_opts
83
+
84
+ from modules import sd_samplers
85
+ sd_samplers.set_samplers()
86
+ startup_timer.record("set samplers")
87
+
88
+ from modules import extensions
89
+ extensions.list_extensions()
90
+ startup_timer.record("list extensions")
91
+
92
+ from modules import initialize_util
93
+ initialize_util.restore_config_state_file()
94
+ startup_timer.record("restore config state file")
95
+
96
+ from modules import shared, upscaler, scripts
97
+ if cmd_opts.ui_debug_mode:
98
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
99
+ scripts.load_scripts()
100
+ return
101
+
102
+ from modules import sd_models
103
+ sd_models.list_models()
104
+ startup_timer.record("list SD models")
105
+
106
+ from modules import localization
107
+ localization.list_localizations(cmd_opts.localizations_dir)
108
+ startup_timer.record("list localizations")
109
+
110
+ with startup_timer.subcategory("load scripts"):
111
+ scripts.load_scripts()
112
+
113
+ if reload_script_modules and shared.opts.enable_reloading_ui_scripts:
114
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
115
+ importlib.reload(module)
116
+ startup_timer.record("reload script modules")
117
+
118
+ from modules import modelloader
119
+ modelloader.load_upscalers()
120
+ startup_timer.record("load upscalers")
121
+
122
+ from modules import sd_vae
123
+ sd_vae.refresh_vae_list()
124
+ startup_timer.record("refresh VAE")
125
+
126
+ from modules import textual_inversion
127
+ textual_inversion.textual_inversion.list_textual_inversion_templates()
128
+ startup_timer.record("refresh textual inversion templates")
129
+
130
+ from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
131
+ script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
132
+ sd_hijack.list_optimizers()
133
+ startup_timer.record("scripts list_optimizers")
134
+
135
+ from modules import sd_unet
136
+ sd_unet.list_unets()
137
+ startup_timer.record("scripts list_unets")
138
+
139
+ def load_model():
140
+ """
141
+ Accesses shared.sd_model property to load model.
142
+ After it's available, if it has been loaded before this access by some extension,
143
+ its optimization may be None because the list of optimizers has not been filled
144
+ by that time, so we apply optimization again.
145
+ """
146
+ from modules import devices
147
+ devices.torch_npu_set_device()
148
+
149
+ shared.sd_model # noqa: B018
150
+
151
+ if sd_hijack.current_optimizer is None:
152
+ sd_hijack.apply_optimizations()
153
+
154
+ devices.first_time_calculation()
155
+ if not shared.cmd_opts.skip_load_model_at_start:
156
+ Thread(target=load_model).start()
157
+
158
+ from modules import shared_items
159
+ shared_items.reload_hypernetworks()
160
+ startup_timer.record("reload hypernetworks")
161
+
162
+ from modules import ui_extra_networks
163
+ ui_extra_networks.initialize()
164
+ ui_extra_networks.register_default_pages()
165
+
166
+ from modules import extra_networks
167
+ extra_networks.initialize()
168
+ extra_networks.register_default_extra_networks()
169
+ startup_timer.record("initialize extra networks")
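The checkpoint load above is deliberately pushed onto a background thread so startup keeps going; a generic sketch of that pattern with a stand-in loader:

import time
from threading import Thread

def load_model():
    time.sleep(0.5)                   # stand-in for the first access to shared.sd_model
    print("model ready")

Thread(target=load_model).start()
print("startup continues while the model loads in the background")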
modules/initialize_util.py ADDED
@@ -0,0 +1,215 @@
1
+ import json
2
+ import os
3
+ import signal
4
+ import sys
5
+ import re
6
+
7
+ from modules.timer import startup_timer
8
+
9
+
10
+ def gradio_server_name():
11
+ from modules.shared_cmd_options import cmd_opts
12
+
13
+ if cmd_opts.server_name:
14
+ return cmd_opts.server_name
15
+ else:
16
+ return "0.0.0.0" if cmd_opts.listen else None
17
+
18
+
19
+ def fix_torch_version():
20
+ import torch
21
+
22
+ # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
23
+ if ".dev" in torch.__version__ or "+git" in torch.__version__:
24
+ torch.__long_version__ = torch.__version__
25
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
26
+
27
+ def fix_pytorch_lightning():
28
+ # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
29
+ if 'pytorch_lightning.utilities.distributed' not in sys.modules:
30
+ import pytorch_lightning
31
+ # Let the user know the module was not found, then alias it to pytorch_lightning.utilities.rank_zero
32
+ print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
33
+ sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
34
+
35
+ def fix_asyncio_event_loop_policy():
36
+ """
37
+ The default `asyncio` event loop policy only automatically creates
38
+ event loops in the main threads. Other threads must create event
39
+ loops explicitly or `asyncio.get_event_loop` (and therefore
40
+ `.IOLoop.current`) will fail. Installing this policy allows event
41
+ loops to be created automatically on any thread, matching the
42
+ behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
43
+ """
44
+
45
+ import asyncio
46
+
47
+ if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
48
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
49
+ # interface for composing policies so pick the right base.
50
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
51
+ else:
52
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
53
+
54
+ class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
55
+ """Event loop policy that allows loop creation on any thread.
56
+ Usage::
57
+
58
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
59
+ """
60
+
61
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
62
+ try:
63
+ return super().get_event_loop()
64
+ except (RuntimeError, AssertionError):
65
+ # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
66
+ # and changed to a RuntimeError in 3.4.3.
67
+ # "There is no current event loop in thread %r"
68
+ loop = self.new_event_loop()
69
+ self.set_event_loop(loop)
70
+ return loop
71
+
72
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
73
+
74
+
75
+ def restore_config_state_file():
76
+ from modules import shared, config_states
77
+
78
+ config_state_file = shared.opts.restore_config_state_file
79
+ if config_state_file == "":
80
+ return
81
+
82
+ shared.opts.restore_config_state_file = ""
83
+ shared.opts.save(shared.config_filename)
84
+
85
+ if os.path.isfile(config_state_file):
86
+ print(f"*** About to restore extension state from file: {config_state_file}")
87
+ with open(config_state_file, "r", encoding="utf-8") as f:
88
+ config_state = json.load(f)
89
+ config_states.restore_extension_config(config_state)
90
+ startup_timer.record("restore extension config")
91
+ elif config_state_file:
92
+ print(f"!!! Config state backup not found: {config_state_file}")
93
+
94
+
95
+ def validate_tls_options():
96
+ from modules.shared_cmd_options import cmd_opts
97
+
98
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
99
+ return
100
+
101
+ try:
102
+ if not os.path.exists(cmd_opts.tls_keyfile):
103
+ print("Invalid path to TLS keyfile given")
104
+ if not os.path.exists(cmd_opts.tls_certfile):
105
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
106
+ except TypeError:
107
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
108
+ print("TLS setup invalid, running webui without TLS")
109
+ else:
110
+ print("Running with TLS")
111
+ startup_timer.record("TLS")
112
+
113
+
114
+ def get_gradio_auth_creds():
115
+ """
116
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
117
+ an iterable of (username, password) tuples.
118
+ """
119
+ from modules.shared_cmd_options import cmd_opts
120
+
121
+ def process_credential_line(s):
122
+ s = s.strip()
123
+ if not s:
124
+ return None
125
+ return tuple(s.split(':', 1))
126
+
127
+ if cmd_opts.gradio_auth:
128
+ for cred in cmd_opts.gradio_auth.split(','):
129
+ cred = process_credential_line(cred)
130
+ if cred:
131
+ yield cred
132
+
133
+ if cmd_opts.gradio_auth_path:
134
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
135
+ for line in file.readlines():
136
+ for cred in line.strip().split(','):
137
+ cred = process_credential_line(cred)
138
+ if cred:
139
+ yield cred
140
+
141
+
142
+ def dumpstacks():
143
+ import threading
144
+ import traceback
145
+
146
+ id2name = {th.ident: th.name for th in threading.enumerate()}
147
+ code = []
148
+ for threadId, stack in sys._current_frames().items():
149
+ code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
150
+ for filename, lineno, name, line in traceback.extract_stack(stack):
151
+ code.append(f"""File: "{filename}", line {lineno}, in {name}""")
152
+ if line:
153
+ code.append(" " + line.strip())
154
+
155
+ print("\n".join(code))
156
+
157
+
158
+ def configure_sigint_handler():
159
+ # make the program just exit at ctrl+c without waiting for anything
160
+
161
+ from modules import shared
162
+
163
+ def sigint_handler(sig, frame):
164
+ print(f'Interrupted with signal {sig} in {frame}')
165
+
166
+ if shared.opts.dump_stacks_on_signal:
167
+ dumpstacks()
168
+
169
+ os._exit(0)
170
+
171
+ if not os.environ.get("COVERAGE_RUN"):
172
+ # Don't install the immediate-quit handler when running under coverage,
173
+ # as then the coverage report won't be generated.
174
+ signal.signal(signal.SIGINT, sigint_handler)
175
+
176
+
177
+ def configure_opts_onchange():
178
+ from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
179
+ from modules.call_queue import wrap_queued_call
180
+
181
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
182
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
183
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
184
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
185
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
186
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
187
+ shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
188
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
189
+ startup_timer.record("opts onchange")
190
+
191
+
192
+ def setup_middleware(app):
193
+ from starlette.middleware.gzip import GZipMiddleware
194
+
195
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
196
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
197
+ configure_cors_middleware(app)
198
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
199
+
200
+
201
+ def configure_cors_middleware(app):
202
+ from starlette.middleware.cors import CORSMiddleware
203
+ from modules.shared_cmd_options import cmd_opts
204
+
205
+ cors_options = {
206
+ "allow_methods": ["*"],
207
+ "allow_headers": ["*"],
208
+ "allow_credentials": True,
209
+ }
210
+ if cmd_opts.cors_allow_origins:
211
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
212
+ if cmd_opts.cors_allow_origins_regex:
213
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
214
+ app.add_middleware(CORSMiddleware, **cors_options)
215
+
modules/interrogate.py ADDED
@@ -0,0 +1,222 @@
1
+ import os
2
+ import sys
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+ import re
6
+
7
+ import torch
8
+ import torch.hub
9
+
10
+ from torchvision import transforms
11
+ from torchvision.transforms.functional import InterpolationMode
12
+
13
+ from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
14
+
15
+ blip_image_eval_size = 384
16
+ clip_model_name = 'ViT-L/14'
17
+
18
+ Category = namedtuple("Category", ["name", "topn", "items"])
19
+
20
+ re_topn = re.compile(r"\.top(\d+)$")
21
+
22
+ def category_types():
23
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
24
+
25
+
26
+ def download_default_clip_interrogate_categories(content_dir):
27
+ print("Downloading CLIP categories...")
28
+
29
+ tmpdir = f"{content_dir}_tmp"
30
+ category_types = ["artists", "flavors", "mediums", "movements"]
31
+
32
+ try:
33
+ os.makedirs(tmpdir, exist_ok=True)
34
+ for category_type in category_types:
35
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
36
+ os.rename(tmpdir, content_dir)
37
+
38
+ except Exception as e:
39
+ errors.display(e, "downloading default CLIP interrogate categories")
40
+ finally:
41
+ if os.path.exists(tmpdir):
42
+ os.removedirs(tmpdir)
43
+
44
+
45
+ class InterrogateModels:
46
+ blip_model = None
47
+ clip_model = None
48
+ clip_preprocess = None
49
+ dtype = None
50
+ running_on_cpu = None
51
+
52
+ def __init__(self, content_dir):
53
+ self.loaded_categories = None
54
+ self.skip_categories = []
55
+ self.content_dir = content_dir
56
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
57
+
58
+ def categories(self):
59
+ if not os.path.exists(self.content_dir):
60
+ download_default_clip_interrogate_categories(self.content_dir)
61
+
62
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
63
+ return self.loaded_categories
64
+
65
+ self.loaded_categories = []
66
+
67
+ if os.path.exists(self.content_dir):
68
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
69
+ category_types = []
70
+ for filename in Path(self.content_dir).glob('*.txt'):
71
+ category_types.append(filename.stem)
72
+ if filename.stem in self.skip_categories:
73
+ continue
74
+ m = re_topn.search(filename.stem)
75
+ topn = 1 if m is None else int(m.group(1))
76
+ with open(filename, "r", encoding="utf8") as file:
77
+ lines = [x.strip() for x in file.readlines()]
78
+
79
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
80
+
81
+ return self.loaded_categories
82
+
83
+ def create_fake_fairscale(self):
84
+ class FakeFairscale:
85
+ def checkpoint_wrapper(self):
86
+ pass
87
+
88
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
89
+
90
+ def load_blip_model(self):
91
+ self.create_fake_fairscale()
92
+ import models.blip
93
+
94
+ files = modelloader.load_models(
95
+ model_path=os.path.join(paths.models_path, "BLIP"),
96
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
97
+ ext_filter=[".pth"],
98
+ download_name='model_base_caption_capfilt_large.pth',
99
+ )
100
+
101
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
102
+ blip_model.eval()
103
+
104
+ return blip_model
105
+
106
+ def load_clip_model(self):
107
+ import clip
108
+
109
+ if self.running_on_cpu:
110
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
111
+ else:
112
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
113
+
114
+ model.eval()
115
+ model = model.to(devices.device_interrogate)
116
+
117
+ return model, preprocess
118
+
119
+ def load(self):
120
+ if self.blip_model is None:
121
+ self.blip_model = self.load_blip_model()
122
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
123
+ self.blip_model = self.blip_model.half()
124
+
125
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
126
+
127
+ if self.clip_model is None:
128
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
129
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
130
+ self.clip_model = self.clip_model.half()
131
+
132
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
133
+
134
+ self.dtype = torch_utils.get_param(self.clip_model).dtype
135
+
136
+ def send_clip_to_ram(self):
137
+ if not shared.opts.interrogate_keep_models_in_memory:
138
+ if self.clip_model is not None:
139
+ self.clip_model = self.clip_model.to(devices.cpu)
140
+
141
+ def send_blip_to_ram(self):
142
+ if not shared.opts.interrogate_keep_models_in_memory:
143
+ if self.blip_model is not None:
144
+ self.blip_model = self.blip_model.to(devices.cpu)
145
+
146
+ def unload(self):
147
+ self.send_clip_to_ram()
148
+ self.send_blip_to_ram()
149
+
150
+ devices.torch_gc()
151
+
152
+ def rank(self, image_features, text_array, top_count=1):
153
+ import clip
154
+
155
+ devices.torch_gc()
156
+
157
+ if shared.opts.interrogate_clip_dict_limit != 0:
158
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
159
+
160
+ top_count = min(top_count, len(text_array))
161
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
162
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
163
+ text_features /= text_features.norm(dim=-1, keepdim=True)
164
+
165
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
166
+ for i in range(image_features.shape[0]):
167
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
168
+ similarity /= image_features.shape[0]
169
+
170
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
171
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
172
+
173
+ def generate_caption(self, pil_image):
174
+ gpu_image = transforms.Compose([
175
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
176
+ transforms.ToTensor(),
177
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
178
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
179
+
180
+ with torch.no_grad():
181
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
182
+
183
+ return caption[0]
184
+
185
+ def interrogate(self, pil_image):
186
+ res = ""
187
+ shared.state.begin(job="interrogate")
188
+ try:
189
+ lowvram.send_everything_to_cpu()
190
+ devices.torch_gc()
191
+
192
+ self.load()
193
+
194
+ caption = self.generate_caption(pil_image)
195
+ self.send_blip_to_ram()
196
+ devices.torch_gc()
197
+
198
+ res = caption
199
+
200
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
201
+
202
+ with torch.no_grad(), devices.autocast():
203
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
204
+
205
+ image_features /= image_features.norm(dim=-1, keepdim=True)
206
+
207
+ for cat in self.categories():
208
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
209
+ for match, score in matches:
210
+ if shared.opts.interrogate_return_ranks:
211
+ res += f", ({match}:{score/100:.3f})"
212
+ else:
213
+ res += f", {match}"
214
+
215
+ except Exception:
216
+ errors.report("Error interrogating", exc_info=True)
217
+ res += "<error>"
218
+
219
+ self.unload()
220
+ shared.state.end()
221
+
222
+ return res
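
For reference, the ranking in rank() above boils down to a normalized dot product between CLIP image and text embeddings followed by a softmax. A minimal sketch with random tensors standing in for the CLIP features (the 512-dim size and the labels are arbitrary):

import torch

def rank_texts(image_features, text_features, labels, top_count=3):
    # illustration only: both feature tensors are assumed L2-normalized along the last dim,
    # mirroring what InterrogateModels.rank() gets from CLIP encode_image/encode_text
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)   # (1, num_texts)
    top_probs, top_idx = similarity.topk(min(top_count, len(labels)), dim=-1)
    return [(labels[i], p.item() * 100) for p, i in zip(top_probs[0], top_idx[0])]

image = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
texts = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
print(rank_texts(image, texts, ["painting", "photo", "sketch", "render"]))
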
modules/launch_utils.py ADDED
@@ -0,0 +1,482 @@
1
+ # this script installs necessary requirements and launches the main program in webui.py
2
+ import logging
3
+ import re
4
+ import subprocess
5
+ import os
6
+ import shutil
7
+ import sys
8
+ import importlib.util
9
+ import importlib.metadata
10
+ import platform
11
+ import json
12
+ import shlex
13
+ from functools import lru_cache
14
+
15
+ from modules import cmd_args, errors
16
+ from modules.paths_internal import script_path, extensions_dir
17
+ from modules.timer import startup_timer
18
+ from modules import logging_config
19
+
20
+ args, _ = cmd_args.parser.parse_known_args()
21
+ logging_config.setup_logging(args.loglevel)
22
+
23
+ python = sys.executable
24
+ git = os.environ.get('GIT', "git")
25
+ index_url = os.environ.get('INDEX_URL', "")
26
+ dir_repos = "repositories"
27
+
28
+ # Whether to default to printing command output
29
+ default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
30
+
31
+ os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
32
+
33
+
34
+ def check_python_version():
35
+ is_windows = platform.system() == "Windows"
36
+ major = sys.version_info.major
37
+ minor = sys.version_info.minor
38
+ micro = sys.version_info.micro
39
+
40
+ if is_windows:
41
+ supported_minors = [10]
42
+ else:
43
+ supported_minors = [7, 8, 9, 10, 11]
44
+
45
+ if not (major == 3 and minor in supported_minors):
46
+ import modules.errors
47
+
48
+ modules.errors.print_error_explanation(f"""
49
+ INCOMPATIBLE PYTHON VERSION
50
+
51
+ This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
52
+ If you encounter an error with "RuntimeError: Couldn't install torch." message,
53
+ or any other error regarding unsuccessful package (library) installation,
54
+ please downgrade (or upgrade) to the latest version of 3.10 Python
55
+ and delete current Python and "venv" folder in WebUI's directory.
56
+
57
+ You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
58
+
59
+ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
60
+
61
+ Use --skip-python-version-check to suppress this warning.
62
+ """)
63
+
64
+
65
+ @lru_cache()
66
+ def commit_hash():
67
+ try:
68
+ return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
69
+ except Exception:
70
+ return "<none>"
71
+
72
+
73
+ @lru_cache()
74
+ def git_tag():
75
+ try:
76
+ return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
77
+ except Exception:
78
+ try:
79
+
80
+ changelog_md = os.path.join(script_path, "CHANGELOG.md")
81
+ with open(changelog_md, "r", encoding="utf-8") as file:
82
+ line = next((line.strip() for line in file if line.strip()), "<none>")
83
+ line = line.replace("## ", "")
84
+ return line
85
+ except Exception:
86
+ return "<none>"
87
+
88
+
89
+ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
90
+ if desc is not None:
91
+ print(desc)
92
+
93
+ run_kwargs = {
94
+ "args": command,
95
+ "shell": True,
96
+ "env": os.environ if custom_env is None else custom_env,
97
+ "encoding": 'utf8',
98
+ "errors": 'ignore',
99
+ }
100
+
101
+ if not live:
102
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
103
+
104
+ result = subprocess.run(**run_kwargs)
105
+
106
+ if result.returncode != 0:
107
+ error_bits = [
108
+ f"{errdesc or 'Error running command'}.",
109
+ f"Command: {command}",
110
+ f"Error code: {result.returncode}",
111
+ ]
112
+ if result.stdout:
113
+ error_bits.append(f"stdout: {result.stdout}")
114
+ if result.stderr:
115
+ error_bits.append(f"stderr: {result.stderr}")
116
+ raise RuntimeError("\n".join(error_bits))
117
+
118
+ return (result.stdout or "")
119
+
120
+
121
+ def is_installed(package):
122
+ try:
123
+ dist = importlib.metadata.distribution(package)
124
+ except importlib.metadata.PackageNotFoundError:
125
+ try:
126
+ spec = importlib.util.find_spec(package)
127
+ except ModuleNotFoundError:
128
+ return False
129
+
130
+ return spec is not None
131
+
132
+ return dist is not None
133
+
134
+
135
+ def repo_dir(name):
136
+ return os.path.join(script_path, dir_repos, name)
137
+
138
+
139
+ def run_pip(command, desc=None, live=default_command_live):
140
+ if args.skip_install:
141
+ return
142
+
143
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
144
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
145
+
146
+
147
+ def check_run_python(code: str) -> bool:
148
+ result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
149
+ return result.returncode == 0
150
+
151
+
152
+ def git_fix_workspace(dir, name):
153
+ run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
154
+ run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
155
+ return
156
+
157
+
158
+ def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
159
+ try:
160
+ return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
161
+ except RuntimeError:
162
+ if not autofix:
163
+ raise
164
+
165
+ print(f"{errdesc}, attempting autofix...")
166
+ git_fix_workspace(dir, name)
167
+
168
+ return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
169
+
170
+
171
+ def git_clone(url, dir, name, commithash=None):
172
+ # TODO clone into temporary dir and move if successful
173
+
174
+ if os.path.exists(dir):
175
+ if commithash is None:
176
+ return
177
+
178
+ current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
179
+ if current_hash == commithash:
180
+ return
181
+
182
+ if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
183
+ run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
184
+
185
+ run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
186
+
187
+ run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
188
+
189
+ return
190
+
191
+ try:
192
+ run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
193
+ except RuntimeError:
194
+ shutil.rmtree(dir, ignore_errors=True)
195
+ raise
196
+
197
+ if commithash is not None:
198
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
199
+
200
+
201
+ def git_pull_recursive(dir):
202
+ for subdir, _, _ in os.walk(dir):
203
+ if os.path.exists(os.path.join(subdir, '.git')):
204
+ try:
205
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
206
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
207
+ except subprocess.CalledProcessError as e:
208
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
209
+
210
+
211
+ def version_check(commit):
212
+ try:
213
+ import requests
214
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
215
+ if commit != "<none>" and commits['commit']['sha'] != commit:
216
+ print("--------------------------------------------------------")
217
+ print("| You are not up to date with the most recent release. |")
218
+ print("| Consider running `git pull` to update. |")
219
+ print("--------------------------------------------------------")
220
+ elif commits['commit']['sha'] == commit:
221
+ print("You are up to date with the most recent release.")
222
+ else:
223
+ print("Not a git clone, can't perform version check.")
224
+ except Exception as e:
225
+ print("version check failed", e)
226
+
227
+
228
+ def run_extension_installer(extension_dir):
229
+ path_installer = os.path.join(extension_dir, "install.py")
230
+ if not os.path.isfile(path_installer):
231
+ return
232
+
233
+ try:
234
+ env = os.environ.copy()
235
+ env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}"
236
+
237
+ stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
238
+ if stdout:
239
+ print(stdout)
240
+ except Exception as e:
241
+ errors.report(str(e))
242
+
243
+
244
+ def list_extensions(settings_file):
245
+ settings = {}
246
+
247
+ try:
248
+ with open(settings_file, "r", encoding="utf8") as file:
249
+ settings = json.load(file)
250
+ except FileNotFoundError:
251
+ pass
252
+ except Exception:
253
+ errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
254
+ os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
255
+
256
+ disabled_extensions = set(settings.get('disabled_extensions', []))
257
+ disable_all_extensions = settings.get('disable_all_extensions', 'none')
258
+
259
+ if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
260
+ return []
261
+
262
+ return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
263
+
264
+
265
+ def run_extensions_installers(settings_file):
266
+ if not os.path.isdir(extensions_dir):
267
+ return
268
+
269
+ with startup_timer.subcategory("run extensions installers"):
270
+ for dirname_extension in list_extensions(settings_file):
271
+ logging.debug(f"Installing {dirname_extension}")
272
+
273
+ path = os.path.join(extensions_dir, dirname_extension)
274
+
275
+ if os.path.isdir(path):
276
+ run_extension_installer(path)
277
+ startup_timer.record(dirname_extension)
278
+
279
+
280
+ re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
281
+
282
+
283
+ def requirements_met(requirements_file):
284
+ """
285
+ Does a simple parse of a requirements.txt file to determine if all requirements in it
286
+ are already installed. Returns True if so, False if not installed or parsing fails.
287
+ """
288
+
289
+ import importlib.metadata
290
+ import packaging.version
291
+
292
+ with open(requirements_file, "r", encoding="utf8") as file:
293
+ for line in file:
294
+ if line.strip() == "":
295
+ continue
296
+
297
+ m = re.match(re_requirement, line)
298
+ if m is None:
299
+ return False
300
+
301
+ package = m.group(1).strip()
302
+ version_required = (m.group(2) or "").strip()
303
+
304
+ if version_required == "":
305
+ continue
306
+
307
+ try:
308
+ version_installed = importlib.metadata.version(package)
309
+ except Exception:
310
+ return False
311
+
312
+ if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
313
+ return False
314
+
315
+ return True
316
+
317
+
318
+ def prepare_environment():
319
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
320
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
321
+ if args.use_ipex:
322
+ if platform.system() == "Windows":
323
+ # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
324
+ # This is NOT an Intel official release so please use it at your own risk!!
325
+ # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
326
+ #
327
+ # Strengths (over official IPEX 2.0.110 windows release):
328
+ # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
329
+ # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
330
+ # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
331
+ # Limitation:
332
+ # - Only works for python 3.10
333
+ url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
334
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
335
+ else:
336
+ # Using official IPEX release for linux since it's already an AOT build.
337
+ # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
338
+ # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
339
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
340
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
341
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
342
+ requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
343
+
344
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
345
+ clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
346
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
347
+
348
+ assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
349
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
350
+ stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
351
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
352
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
353
+
354
+ assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
355
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
356
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
357
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
358
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
359
+
360
+ try:
361
+ # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
362
+ os.remove(os.path.join(script_path, "tmp", "restart"))
363
+ os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
364
+ except OSError:
365
+ pass
366
+
367
+ if not args.skip_python_version_check:
368
+ check_python_version()
369
+
370
+ startup_timer.record("checks")
371
+
372
+ commit = commit_hash()
373
+ tag = git_tag()
374
+ startup_timer.record("git version info")
375
+
376
+ print(f"Python {sys.version}")
377
+ print(f"Version: {tag}")
378
+ print(f"Commit hash: {commit}")
379
+
380
+ if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
381
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
382
+ startup_timer.record("install torch")
383
+
384
+ if args.use_ipex:
385
+ args.skip_torch_cuda_test = True
386
+ if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
387
+ raise RuntimeError(
388
+ 'Torch is not able to use GPU; '
389
+ 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
390
+ )
391
+ startup_timer.record("torch GPU test")
392
+
393
+ if not is_installed("clip"):
394
+ run_pip(f"install {clip_package}", "clip")
395
+ startup_timer.record("install clip")
396
+
397
+ if not is_installed("open_clip"):
398
+ run_pip(f"install {openclip_package}", "open_clip")
399
+ startup_timer.record("install open_clip")
400
+
401
+ if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
402
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
403
+ startup_timer.record("install xformers")
404
+
405
+ if not is_installed("ngrok") and args.ngrok:
406
+ run_pip("install ngrok", "ngrok")
407
+ startup_timer.record("install ngrok")
408
+
409
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
410
+
411
+ git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
412
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
413
+ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
414
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
415
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
416
+
417
+ startup_timer.record("clone repositories")
418
+
419
+ if not os.path.isfile(requirements_file):
420
+ requirements_file = os.path.join(script_path, requirements_file)
421
+
422
+ if not requirements_met(requirements_file):
423
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
424
+ startup_timer.record("install requirements")
425
+
426
+ if not os.path.isfile(requirements_file_for_npu):
427
+ requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
428
+
429
+ if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
430
+ run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
431
+ startup_timer.record("install requirements_for_npu")
432
+
433
+ if not args.skip_install:
434
+ run_extensions_installers(settings_file=args.ui_settings_file)
435
+
436
+ if args.update_check:
437
+ version_check(commit)
438
+ startup_timer.record("check version")
439
+
440
+ if args.update_all_extensions:
441
+ git_pull_recursive(extensions_dir)
442
+ startup_timer.record("update extensions")
443
+
444
+ if "--exit" in sys.argv:
445
+ print("Exiting because of --exit argument")
446
+ exit(0)
447
+
448
+
449
+ def configure_for_tests():
450
+ if "--api" not in sys.argv:
451
+ sys.argv.append("--api")
452
+ if "--ckpt" not in sys.argv:
453
+ sys.argv.append("--ckpt")
454
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
455
+ if "--skip-torch-cuda-test" not in sys.argv:
456
+ sys.argv.append("--skip-torch-cuda-test")
457
+ if "--disable-nan-check" not in sys.argv:
458
+ sys.argv.append("--disable-nan-check")
459
+
460
+ os.environ['COMMANDLINE_ARGS'] = ""
461
+
462
+
463
+ def start():
464
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
465
+ import webui
466
+ if '--nowebui' in sys.argv:
467
+ webui.api_only()
468
+ else:
469
+ webui.webui()
470
+
471
+
472
+ def dump_sysinfo():
473
+ from modules import sysinfo
474
+ import datetime
475
+
476
+ text = sysinfo.get()
477
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
478
+
479
+ with open(filename, "w", encoding="utf8") as file:
480
+ file.write(text)
481
+
482
+ return filename
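
For reference, requirements_met() above only understands exact `package==version` pins; a minimal sketch of the per-line check it performs:

import importlib.metadata
import packaging.version

def pin_satisfied(package: str, version_required: str) -> bool:
    # illustration only: mirrors the exact "==" comparison done for each requirements line
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return False
    return packaging.version.parse(version_required) == packaging.version.parse(installed)

print(pin_satisfied("packaging", importlib.metadata.version("packaging")))  # True
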
modules/localization.py ADDED
@@ -0,0 +1,37 @@
1
+ import json
2
+ import os
3
+
4
+ from modules import errors, scripts
5
+
6
+ localizations = {}
7
+
8
+
9
+ def list_localizations(dirname):
10
+ localizations.clear()
11
+
12
+ for file in os.listdir(dirname):
13
+ fn, ext = os.path.splitext(file)
14
+ if ext.lower() != ".json":
15
+ continue
16
+
17
+ localizations[fn] = [os.path.join(dirname, file)]
18
+
19
+ for file in scripts.list_scripts("localizations", ".json"):
20
+ fn, ext = os.path.splitext(file.filename)
21
+ if fn not in localizations:
22
+ localizations[fn] = []
23
+ localizations[fn].append(file.path)
24
+
25
+
26
+ def localization_js(current_localization_name: str) -> str:
27
+ fns = localizations.get(current_localization_name, None)
28
+ data = {}
29
+ if fns is not None:
30
+ for fn in fns:
31
+ try:
32
+ with open(fn, "r", encoding="utf8") as file:
33
+ data.update(json.load(file))
34
+ except Exception:
35
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
36
+
37
+ return f"window.localization = {json.dumps(data)}"
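
For reference, localization_js() simply serializes the merged translation dictionary into a JavaScript assignment; a sketch with hypothetical keys (real locale files are the localizations/*.json files listed above):

import json

data = {"Generate": "Generieren", "Settings": "Einstellungen"}  # hypothetical entries
print(f"window.localization = {json.dumps(data)}")
# window.localization = {"Generate": "Generieren", "Settings": "Einstellungen"}
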
modules/logging_config.py ADDED
@@ -0,0 +1,58 @@
1
+ import logging
2
+ import os
3
+
4
+ try:
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TqdmLoggingHandler(logging.Handler):
9
+ def __init__(self, fallback_handler: logging.Handler):
10
+ super().__init__()
11
+ self.fallback_handler = fallback_handler
12
+
13
+ def emit(self, record):
14
+ try:
15
+ # If there are active tqdm progress bars,
16
+ # attempt to not interfere with them.
17
+ if tqdm._instances:
18
+ tqdm.write(self.format(record))
19
+ else:
20
+ self.fallback_handler.emit(record)
21
+ except Exception:
22
+ self.fallback_handler.emit(record)
23
+
24
+ except ImportError:
25
+ TqdmLoggingHandler = None
26
+
27
+
28
+ def setup_logging(loglevel):
29
+ if loglevel is None:
30
+ loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
31
+
32
+ if not loglevel:
33
+ return
34
+
35
+ if logging.root.handlers:
36
+ # Already configured, do not interfere
37
+ return
38
+
39
+ formatter = logging.Formatter(
40
+ '%(asctime)s %(levelname)s [%(name)s] %(message)s',
41
+ '%Y-%m-%d %H:%M:%S',
42
+ )
43
+
44
+ if os.environ.get("SD_WEBUI_RICH_LOG"):
45
+ from rich.logging import RichHandler
46
+ handler = RichHandler()
47
+ else:
48
+ handler = logging.StreamHandler()
49
+ handler.setFormatter(formatter)
50
+
51
+ if TqdmLoggingHandler:
52
+ handler = TqdmLoggingHandler(handler)
53
+
54
+ handler.setFormatter(formatter)
55
+
56
+ log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
57
+ logging.root.setLevel(log_level)
58
+ logging.root.addHandler(handler)
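
A sketch of how setup_logging() is typically driven; the import assumes the package layout added in this commit, and passing None makes the level fall back to the SD_WEBUI_LOG_LEVEL environment variable:

import logging
import os

from modules import logging_config  # assumes the package layout added in this commit

os.environ.setdefault("SD_WEBUI_LOG_LEVEL", "INFO")
logging_config.setup_logging(None)   # None -> level read from SD_WEBUI_LOG_LEVEL
logging.getLogger(__name__).info("logging configured")
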
modules/lowvram.py ADDED
@@ -0,0 +1,165 @@
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+ from modules import devices, shared
5
+
6
+ module_in_gpu = None
7
+ cpu = torch.device("cpu")
8
+
9
+ ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=[None])
10
+
11
+ def send_everything_to_cpu():
12
+ global module_in_gpu
13
+
14
+ if module_in_gpu is not None:
15
+ module_in_gpu.to(cpu)
16
+
17
+ module_in_gpu = None
18
+
19
+
20
+ def is_needed(sd_model):
21
+ return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
22
+
23
+
24
+ def apply(sd_model):
25
+ enable = is_needed(sd_model)
26
+ shared.parallel_processing_allowed = not enable
27
+
28
+ if enable:
29
+ setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
30
+ else:
31
+ sd_model.lowvram = False
32
+
33
+
34
+ def setup_for_low_vram(sd_model, use_medvram):
35
+ if getattr(sd_model, 'lowvram', False):
36
+ return
37
+
38
+ sd_model.lowvram = True
39
+
40
+ parents = {}
41
+
42
+ def send_me_to_gpu(module, _):
43
+ """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
44
+ we add this as a forward_pre_hook to a lot of modules, and this way all but one of them will
45
+ be in CPU
46
+ """
47
+ global module_in_gpu
48
+
49
+ module = parents.get(module, module)
50
+
51
+ if module_in_gpu == module:
52
+ return
53
+
54
+ if module_in_gpu is not None:
55
+ module_in_gpu.to(cpu)
56
+
57
+ module.to(devices.device)
58
+ module_in_gpu = module
59
+
60
+ # see below for register_forward_pre_hook;
61
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
62
+ # useless here, and we just replace those methods
63
+
64
+ first_stage_model = sd_model.first_stage_model
65
+ first_stage_model_encode = sd_model.first_stage_model.encode
66
+ first_stage_model_decode = sd_model.first_stage_model.decode
67
+
68
+ def first_stage_model_encode_wrap(x):
69
+ send_me_to_gpu(first_stage_model, None)
70
+ return first_stage_model_encode(x)
71
+
72
+ def first_stage_model_decode_wrap(z):
73
+ send_me_to_gpu(first_stage_model, None)
74
+ return first_stage_model_decode(z)
75
+
76
+ to_remain_in_cpu = [
77
+ (sd_model, 'first_stage_model'),
78
+ (sd_model, 'depth_model'),
79
+ (sd_model, 'embedder'),
80
+ (sd_model, 'model'),
81
+ ]
82
+
83
+ is_sdxl = hasattr(sd_model, 'conditioner')
84
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
85
+
86
+ if hasattr(sd_model, 'medvram_fields'):
87
+ to_remain_in_cpu = sd_model.medvram_fields()
88
+ elif is_sdxl:
89
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
90
+ elif is_sd2:
91
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
92
+ else:
93
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
94
+
95
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
96
+ stored = []
97
+ for obj, field in to_remain_in_cpu:
98
+ module = getattr(obj, field, None)
99
+ stored.append(module)
100
+ setattr(obj, field, None)
101
+
102
+ # send the model to GPU.
103
+ sd_model.to(devices.device)
104
+
105
+ # put modules back. the modules will be in CPU.
106
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
107
+ setattr(obj, field, module)
108
+
109
+ # register hooks for the first three models
110
+ if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
111
+ for module in sd_model.cond_stage_model.medvram_modules():
112
+ if isinstance(module, ModuleWithParent):
113
+ parent = module.parent
114
+ module = module.module
115
+ else:
116
+ parent = None
117
+
118
+ if module:
119
+ module.register_forward_pre_hook(send_me_to_gpu)
120
+
121
+ if parent:
122
+ parents[module] = parent
123
+
124
+ elif is_sdxl:
125
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
126
+ elif is_sd2:
127
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
128
+ sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
129
+ parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
130
+ parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
131
+ else:
132
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
133
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
134
+
135
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
136
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
137
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
138
+ if getattr(sd_model, 'depth_model', None) is not None:
139
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
140
+ if getattr(sd_model, 'embedder', None) is not None:
141
+ sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
142
+
143
+ if use_medvram:
144
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
145
+ else:
146
+ diff_model = sd_model.model.diffusion_model
147
+
148
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
149
+ # so that only one of them is in GPU at a time
150
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
151
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
152
+ sd_model.model.to(devices.device)
153
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
154
+
155
+ # install hooks for bits of third model
156
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
157
+ for block in diff_model.input_blocks:
158
+ block.register_forward_pre_hook(send_me_to_gpu)
159
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
160
+ for block in diff_model.output_blocks:
161
+ block.register_forward_pre_hook(send_me_to_gpu)
162
+
163
+
164
+ def is_enabled(sd_model):
165
+ return sd_model.lowvram
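
For reference, the low-VRAM scheme above hinges on torch forward pre-hooks: a hook fires before a module's forward() and can move weights between devices. A standalone toy that only records which module would have been moved, instead of calling .to():

import torch
import torch.nn as nn

active_module = None

def send_me_to_gpu(module, _inputs):
    # in lowvram.py this is where module.to(devices.device) happens;
    # here we only record which module is "active"
    global active_module
    active_module = module

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for layer in net:
    layer.register_forward_pre_hook(send_me_to_gpu)

net(torch.randn(1, 4))
print(active_module)  # the layer whose forward ran last
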
modules/mac_specific.py ADDED
@@ -0,0 +1,98 @@
1
+ import logging
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ import platform
6
+ from modules.sd_hijack_utils import CondFunc
7
+ from packaging import version
8
+ from modules import shared
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
+ # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
14
+ # so check with `getattr` and try it for compatibility.
15
+ # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check mps availability,
16
+ # since torch 2.0.1+ nightly builds, getattr(torch, 'has_mps', False) has been deprecated, see https://github.com/pytorch/pytorch/pull/103279
17
+ def check_for_mps() -> bool:
18
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
19
+ if not getattr(torch, 'has_mps', False):
20
+ return False
21
+ try:
22
+ torch.zeros(1).to(torch.device("mps"))
23
+ return True
24
+ except Exception:
25
+ return False
26
+ else:
27
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
28
+
29
+
30
+ has_mps = check_for_mps()
31
+
32
+
33
+ def torch_mps_gc() -> None:
34
+ try:
35
+ if shared.state.current_latent is not None:
36
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
37
+ return
38
+ from torch.mps import empty_cache
39
+ empty_cache()
40
+ except Exception:
41
+ log.warning("MPS garbage collection failed", exc_info=True)
42
+
43
+
44
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
45
+ def cumsum_fix(input, cumsum_func, *args, **kwargs):
46
+ if input.device.type == 'mps':
47
+ output_dtype = kwargs.get('dtype', input.dtype)
48
+ if output_dtype == torch.int64:
49
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
50
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
51
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
52
+ return cumsum_func(input, *args, **kwargs)
53
+
54
+
55
+ # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
56
+ def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
57
+ try:
58
+ return orig_func(*args, **kwargs)
59
+ except RuntimeError as e:
60
+ if "not implemented for" in str(e) and "Half" in str(e):
61
+ input_tensor = args[0]
62
+ return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
63
+ else:
64
+ print(f"An unexpected RuntimeError occurred: {str(e)}")
65
+
66
+ if has_mps:
67
+ if platform.mac_ver()[0].startswith("13.2."):
68
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
69
+ CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
70
+
71
+ if version.parse(torch.__version__) < version.parse("1.13"):
72
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
73
+
74
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
75
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
76
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
77
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
78
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
79
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
80
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
81
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
82
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
83
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
84
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
85
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
86
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
87
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
88
+
89
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
90
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
91
+
92
+ # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
93
+ CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
94
+
95
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
96
+ if platform.processor() == 'i386':
97
+ for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
98
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
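
For reference, the fp32 fallback used for MPS above has roughly this shape when written as a plain wrapper; it runs anywhere, since the Half-precision failure it guards against only occurs on MPS:

import torch

def with_fp32_fallback(fn, x, *args, **kwargs):
    # illustration only: retry an op in float32 when it is not implemented for Half (as on MPS)
    try:
        return fn(x, *args, **kwargs)
    except RuntimeError as e:
        if "not implemented for" in str(e) and "Half" in str(e):
            return fn(x.to(torch.float32), *args, **kwargs).to(x.dtype)
        raise

out = with_fp32_fallback(torch.nn.functional.interpolate, torch.randn(1, 3, 8, 8), scale_factor=2)
print(out.shape)  # torch.Size([1, 3, 16, 16])
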
modules/masking.py ADDED
@@ -0,0 +1,96 @@
1
+ from PIL import Image, ImageFilter, ImageOps
2
+
3
+
4
+ def get_crop_region_v2(mask, pad=0):
5
+ """
6
+ Finds a rectangular region that contains all masked areas in a mask.
7
+ Returns None if the mask is completely black (all 0).
8
+
9
+ Parameters:
10
+ mask: PIL.Image.Image L mode or numpy 1d array
11
+ pad: int number of pixels that the region will be extended on all sides
12
+ Returns: (x1, y1, x2, y2) | None
13
+
14
+ Introduced post 1.9.0
15
+ """
16
+ mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
17
+ if box := mask.getbbox():
18
+ x1, y1, x2, y2 = box
19
+ return (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])) if pad else box
20
+
21
+
22
+ def get_crop_region(mask, pad=0):
23
+ """
24
+ Same as get_crop_region_v2, but handles a completely black mask (all 0) differently:
25
+ it still returns coordinates, but they may be invalid, i.e. x1 > x2 or y1 > y2.
26
+ Notes: it is possible for the coordinates to be "valid" again if pad size is sufficiently large
27
+ (mask_size.x-pad, mask_size.y-pad, pad, pad)
28
+
29
+ Extension developers should use get_crop_region_v2 instead, unless compatibility is a concern.
30
+ """
31
+ mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
32
+ if box := get_crop_region_v2(mask, pad):
33
+ return box
34
+ x1, y1 = mask.size
35
+ x2 = y2 = 0
36
+ return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])
37
+
38
+
39
+ def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
40
+ """expands crop region get_crop_region() to match the ratio of the image the region will be processed in; returns the expanded region
41
+ for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
42
+
43
+ x1, y1, x2, y2 = crop_region
44
+
45
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
46
+ ratio_processing = processing_width / processing_height
47
+
48
+ if ratio_crop_region > ratio_processing:
49
+ desired_height = (x2 - x1) / ratio_processing
50
+ desired_height_diff = int(desired_height - (y2-y1))
51
+ y1 -= desired_height_diff//2
52
+ y2 += desired_height_diff - desired_height_diff//2
53
+ if y2 >= image_height:
54
+ diff = y2 - image_height
55
+ y2 -= diff
56
+ y1 -= diff
57
+ if y1 < 0:
58
+ y2 -= y1
59
+ y1 -= y1
60
+ if y2 >= image_height:
61
+ y2 = image_height
62
+ else:
63
+ desired_width = (y2 - y1) * ratio_processing
64
+ desired_width_diff = int(desired_width - (x2-x1))
65
+ x1 -= desired_width_diff//2
66
+ x2 += desired_width_diff - desired_width_diff//2
67
+ if x2 >= image_width:
68
+ diff = x2 - image_width
69
+ x2 -= diff
70
+ x1 -= diff
71
+ if x1 < 0:
72
+ x2 -= x1
73
+ x1 -= x1
74
+ if x2 >= image_width:
75
+ x2 = image_width
76
+
77
+ return x1, y1, x2, y2
78
+
79
+
80
+ def fill(image, mask):
81
+ """fills masked regions with colors from image using blur. Not extremely effective."""
82
+
83
+ image_mod = Image.new('RGBA', (image.width, image.height))
84
+
85
+ image_masked = Image.new('RGBa', (image.width, image.height))
86
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
87
+
88
+ image_masked = image_masked.convert('RGBa')
89
+
90
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
91
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
92
+ for _ in range(repeats):
93
+ image_mod.alpha_composite(blurred)
94
+
95
+ return image_mod.convert("RGB")
96
+
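
A sketch of how the crop helpers above compose for inpaint-only-masked processing; the import assumes the package layout added in this commit, and the 512x512 sizes and rectangle coordinates are arbitrary:

from PIL import Image, ImageDraw

from modules import masking  # assumes the package layout added in this commit

mask = Image.new("L", (512, 512), 0)
ImageDraw.Draw(mask).rectangle([100, 120, 180, 150], fill=255)

region = masking.get_crop_region_v2(mask, pad=32)                 # padded box around the masked area
region = masking.expand_crop_region(region, 512, 512, 512, 512)   # match the processing aspect ratio
print(region)
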
modules/memmon.py ADDED
@@ -0,0 +1,92 @@
1
+ import threading
2
+ import time
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+
7
+
8
+ class MemUsageMonitor(threading.Thread):
9
+ run_flag = None
10
+ device = None
11
+ disabled = False
12
+ opts = None
13
+ data = None
14
+
15
+ def __init__(self, name, device, opts):
16
+ threading.Thread.__init__(self)
17
+ self.name = name
18
+ self.device = device
19
+ self.opts = opts
20
+
21
+ self.daemon = True
22
+ self.run_flag = threading.Event()
23
+ self.data = defaultdict(int)
24
+
25
+ try:
26
+ self.cuda_mem_get_info()
27
+ torch.cuda.memory_stats(self.device)
28
+ except Exception as e: # AMD or whatever
29
+ print(f"Warning: caught exception '{e}', memory monitor disabled")
30
+ self.disabled = True
31
+
32
+ def cuda_mem_get_info(self):
33
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
34
+ return torch.cuda.mem_get_info(index)
35
+
36
+ def run(self):
37
+ if self.disabled:
38
+ return
39
+
40
+ while True:
41
+ self.run_flag.wait()
42
+
43
+ torch.cuda.reset_peak_memory_stats()
44
+ self.data.clear()
45
+
46
+ if self.opts.memmon_poll_rate <= 0:
47
+ self.run_flag.clear()
48
+ continue
49
+
50
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
51
+
52
+ while self.run_flag.is_set():
53
+ free, total = self.cuda_mem_get_info()
54
+ self.data["min_free"] = min(self.data["min_free"], free)
55
+
56
+ time.sleep(1 / self.opts.memmon_poll_rate)
57
+
58
+ def dump_debug(self):
59
+ print(self, 'recorded data:')
60
+ for k, v in self.read().items():
61
+ print(k, -(v // -(1024 ** 2)))
62
+
63
+ print(self, 'raw torch memory stats:')
64
+ tm = torch.cuda.memory_stats(self.device)
65
+ for k, v in tm.items():
66
+ if 'bytes' not in k:
67
+ continue
68
+ print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
69
+
70
+ print(torch.cuda.memory_summary())
71
+
72
+ def monitor(self):
73
+ self.run_flag.set()
74
+
75
+ def read(self):
76
+ if not self.disabled:
77
+ free, total = self.cuda_mem_get_info()
78
+ self.data["free"] = free
79
+ self.data["total"] = total
80
+
81
+ torch_stats = torch.cuda.memory_stats(self.device)
82
+ self.data["active"] = torch_stats["active.all.current"]
83
+ self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
84
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
85
+ self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
86
+ self.data["system_peak"] = total - self.data["min_free"]
87
+
88
+ return self.data
89
+
90
+ def stop(self):
91
+ self.run_flag.clear()
92
+ return self.read()
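
A sketch of how MemUsageMonitor is typically driven; the import assumes the package layout added in this commit, opts only needs a memmon_poll_rate attribute here, and meaningful numbers require a CUDA device:

import types

import torch

from modules.memmon import MemUsageMonitor  # assumes the package layout added in this commit

opts = types.SimpleNamespace(memmon_poll_rate=8)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
mem_mon = MemUsageMonitor("MemMon", device, opts)
mem_mon.start()     # daemon thread; idles until monitor() sets the run flag
mem_mon.monitor()   # begin polling before a generation job
# ... run the job ...
print(mem_mon.stop())   # collected free/total/peak numbers (empty if CUDA is unavailable)
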
modules/modelloader.py ADDED
@@ -0,0 +1,197 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import logging
5
+ import os
6
+ from typing import TYPE_CHECKING
7
+ from urllib.parse import urlparse
8
+
9
+ import torch
10
+
11
+ from modules import shared
12
+ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
13
+
14
+ if TYPE_CHECKING:
15
+ import spandrel
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def load_file_from_url(
21
+ url: str,
22
+ *,
23
+ model_dir: str,
24
+ progress: bool = True,
25
+ file_name: str | None = None,
26
+ hash_prefix: str | None = None,
27
+ ) -> str:
28
+ """Download a file from `url` into `model_dir`, using the file present if possible.
29
+
30
+ Returns the path to the downloaded file.
31
+ """
32
+ os.makedirs(model_dir, exist_ok=True)
33
+ if not file_name:
34
+ parts = urlparse(url)
35
+ file_name = os.path.basename(parts.path)
36
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
37
+ if not os.path.exists(cached_file):
38
+ print(f'Downloading: "{url}" to {cached_file}\n')
39
+ from torch.hub import download_url_to_file
40
+ download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
41
+ return cached_file
42
+
43
+
44
+ def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
45
+ """
46
+ A one-and-done loader to try finding the desired models in specified directories.
47
+
48
+ @param download_name: Specify to download from model_url immediately.
49
+ @param model_url: If no other models are found, this will be downloaded on upscale.
50
+ @param model_path: The location to store/find models in.
51
+ @param command_path: A command-line argument to search for models in first.
52
+ @param ext_filter: An optional list of filename extensions to filter by
53
+ @param hash_prefix: the expected sha256 of the model_url
54
+ @return: A list of paths containing the desired model(s)
55
+ """
56
+ output = []
57
+
58
+ try:
59
+ places = []
60
+
61
+ if command_path is not None and command_path != model_path:
62
+ pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
63
+ if os.path.exists(pretrained_path):
64
+ print(f"Appending path: {pretrained_path}")
65
+ places.append(pretrained_path)
66
+ elif os.path.exists(command_path):
67
+ places.append(command_path)
68
+
69
+ places.append(model_path)
70
+
71
+ for place in places:
72
+ for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
73
+ if os.path.islink(full_path) and not os.path.exists(full_path):
74
+ print(f"Skipping broken symlink: {full_path}")
75
+ continue
76
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
77
+ continue
78
+ if full_path not in output:
79
+ output.append(full_path)
80
+
81
+ if model_url is not None and len(output) == 0:
82
+ if download_name is not None:
83
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix))
84
+ else:
85
+ output.append(model_url)
86
+
87
+ except Exception:
88
+ pass
89
+
90
+ return output
91
+
92
+
93
+ def friendly_name(file: str):
94
+ if file.startswith("http"):
95
+ file = urlparse(file).path
96
+
97
+ file = os.path.basename(file)
98
+ model_name, extension = os.path.splitext(file)
99
+ return model_name
100
+
101
+
102
+ def load_upscalers():
103
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
104
+ # so we'll try to import any _model.py files before looking in __subclasses__
105
+ modules_dir = os.path.join(shared.script_path, "modules")
106
+ for file in os.listdir(modules_dir):
107
+ if "_model.py" in file:
108
+ model_name = file.replace("_model.py", "")
109
+ full_model = f"modules.{model_name}_model"
110
+ try:
111
+ importlib.import_module(full_model)
112
+ except Exception:
113
+ pass
114
+
115
+ data = []
116
+ commandline_options = vars(shared.cmd_opts)
117
+
118
+ # some of the upscaler classes will not go away after reloading their modules, and we'll end
119
+ # up with two copies of those classes. The newest copy will always be the last in the list,
120
+ # so we go from end to beginning and ignore duplicates
121
+ used_classes = {}
122
+ for cls in reversed(Upscaler.__subclasses__()):
123
+ classname = str(cls)
124
+ if classname not in used_classes:
125
+ used_classes[classname] = cls
126
+
127
+ for cls in reversed(used_classes.values()):
128
+ name = cls.__name__
129
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
130
+ commandline_model_path = commandline_options.get(cmd_name, None)
131
+ scaler = cls(commandline_model_path)
132
+ scaler.user_path = commandline_model_path
133
+ scaler.model_download_path = commandline_model_path or scaler.model_path
134
+ data += scaler.scalers
135
+
136
+ shared.sd_upscalers = sorted(
137
+ data,
138
+ # Special case for UpscalerNone keeps it at the beginning of the list.
139
+ key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
140
+ )
141
+
142
+ # None: not loaded, False: failed to load, True: loaded
143
+ _spandrel_extra_init_state = None
144
+
145
+
146
+ def _init_spandrel_extra_archs() -> None:
147
+ """
148
+ Try to initialize `spandrel_extra_archs` (exactly once).
149
+ """
150
+ global _spandrel_extra_init_state
151
+ if _spandrel_extra_init_state is not None:
152
+ return
153
+
154
+ try:
155
+ import spandrel
156
+ import spandrel_extra_arches
157
+ spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
158
+ _spandrel_extra_init_state = True
159
+ except Exception:
160
+ logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
161
+ _spandrel_extra_init_state = False
162
+
163
+
164
+ def load_spandrel_model(
165
+ path: str | os.PathLike,
166
+ *,
167
+ device: str | torch.device | None,
168
+ prefer_half: bool = False,
169
+ dtype: str | torch.dtype | None = None,
170
+ expected_architecture: str | None = None,
171
+ ) -> spandrel.ModelDescriptor:
172
+ global _spandrel_extra_init_state
173
+
174
+ import spandrel
175
+ _init_spandrel_extra_archs()
176
+
177
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
178
+ arch = model_descriptor.architecture
179
+ if expected_architecture and arch.name != expected_architecture:
180
+ logger.warning(
181
+ f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
182
+ )
183
+ half = False
184
+ if prefer_half:
185
+ if model_descriptor.supports_half:
186
+ model_descriptor.model.half()
187
+ half = True
188
+ else:
189
+ logger.info("Model %s does not support half precision, ignoring --half", path)
190
+ if dtype:
191
+ model_descriptor.model.to(dtype=dtype)
192
+ model_descriptor.model.eval()
193
+ logger.debug(
194
+ "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
195
+ arch, path, device, half, dtype,
196
+ )
197
+ return model_descriptor
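
For reference, friendly_name() above just reduces a URL or path to a bare model name; the URL and path here are hypothetical:

from modules import modelloader  # assumes the package layout added in this commit

print(modelloader.friendly_name("https://example.com/models/RealESRGAN_x4plus.pth"))   # RealESRGAN_x4plus
print(modelloader.friendly_name("some/dir/4x-UltraSharp.safetensors"))                 # 4x-UltraSharp
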
modules/models/diffusion/ddpm_edit.py ADDED
@@ -0,0 +1,1460 @@
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
+ # See more details in LICENSE.
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import pytorch_lightning as pl
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ from einops import rearrange, repeat
18
+ from contextlib import contextmanager
19
+ from functools import partial
20
+ from tqdm import tqdm
21
+ from torchvision.utils import make_grid
22
+ from pytorch_lightning.utilities.distributed import rank_zero_only
23
+
24
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
+ from ldm.modules.ema import LitEma
26
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
28
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
+ from ldm.models.diffusion.ddim import DDIMSampler
30
+
31
+ try:
32
+ from ldm.models.autoencoder import VQModelInterface
33
+ except Exception:
34
+ class VQModelInterface:
35
+ pass
36
+
37
+ __conditioning_keys__ = {'concat': 'c_concat',
38
+ 'crossattn': 'c_crossattn',
39
+ 'adm': 'y'}
40
+
41
+
42
+ def disabled_train(self, mode=True):
43
+ """Overwrite model.train with this function to make sure train/eval mode
44
+ does not change anymore."""
45
+ return self
46
+
47
+
48
+ def uniform_on_device(r1, r2, shape, device):
49
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
50
+
51
+
52
+ class DDPM(pl.LightningModule):
53
+ # classic DDPM with Gaussian diffusion, in image space
54
+ def __init__(self,
55
+ unet_config,
56
+ timesteps=1000,
57
+ beta_schedule="linear",
58
+ loss_type="l2",
59
+ ckpt_path=None,
60
+ ignore_keys=None,
61
+ load_only_unet=False,
62
+ monitor="val/loss",
63
+ use_ema=True,
64
+ first_stage_key="image",
65
+ image_size=256,
66
+ channels=3,
67
+ log_every_t=100,
68
+ clip_denoised=True,
69
+ linear_start=1e-4,
70
+ linear_end=2e-2,
71
+ cosine_s=8e-3,
72
+ given_betas=None,
73
+ original_elbo_weight=0.,
74
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
75
+ l_simple_weight=1.,
76
+ conditioning_key=None,
77
+ parameterization="eps", # all assuming fixed variance schedules
78
+ scheduler_config=None,
79
+ use_positional_encodings=False,
80
+ learn_logvar=False,
81
+ logvar_init=0.,
82
+ load_ema=True,
83
+ ):
84
+ super().__init__()
85
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
86
+ self.parameterization = parameterization
87
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
88
+ self.cond_stage_model = None
89
+ self.clip_denoised = clip_denoised
90
+ self.log_every_t = log_every_t
91
+ self.first_stage_key = first_stage_key
92
+ self.image_size = image_size # try conv?
93
+ self.channels = channels
94
+ self.use_positional_encodings = use_positional_encodings
95
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
96
+ count_params(self.model, verbose=True)
97
+ self.use_ema = use_ema
98
+
99
+ self.use_scheduler = scheduler_config is not None
100
+ if self.use_scheduler:
101
+ self.scheduler_config = scheduler_config
102
+
103
+ self.v_posterior = v_posterior
104
+ self.original_elbo_weight = original_elbo_weight
105
+ self.l_simple_weight = l_simple_weight
106
+
107
+ if monitor is not None:
108
+ self.monitor = monitor
109
+
110
+ if self.use_ema and load_ema:
111
+ self.model_ema = LitEma(self.model)
112
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
113
+
114
+ if ckpt_path is not None:
115
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
116
+
117
+ # If initializing from an EMA-only checkpoint, create the EMA model after loading.
118
+ if self.use_ema and not load_ema:
119
+ self.model_ema = LitEma(self.model)
120
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
121
+
122
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
123
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
124
+
125
+ self.loss_type = loss_type
126
+
127
+ self.learn_logvar = learn_logvar
128
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
129
+ if self.learn_logvar:
130
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
131
+
132
+
133
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
134
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
135
+ if exists(given_betas):
136
+ betas = given_betas
137
+ else:
138
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
139
+ cosine_s=cosine_s)
140
+ alphas = 1. - betas
141
+ alphas_cumprod = np.cumprod(alphas, axis=0)
142
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
143
+
144
+ timesteps, = betas.shape
145
+ self.num_timesteps = int(timesteps)
146
+ self.linear_start = linear_start
147
+ self.linear_end = linear_end
148
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
149
+
150
+ to_torch = partial(torch.tensor, dtype=torch.float32)
151
+
152
+ self.register_buffer('betas', to_torch(betas))
153
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
154
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
155
+
156
+ # calculations for diffusion q(x_t | x_{t-1}) and others
157
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
158
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
159
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
160
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
161
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
162
+
163
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
164
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
165
+ 1. - alphas_cumprod) + self.v_posterior * betas
166
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
167
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
168
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
169
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
170
+ self.register_buffer('posterior_mean_coef1', to_torch(
171
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
172
+ self.register_buffer('posterior_mean_coef2', to_torch(
173
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
174
+
175
+ if self.parameterization == "eps":
176
+ lvlb_weights = self.betas ** 2 / (
177
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
178
+ elif self.parameterization == "x0":
179
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
180
+ else:
181
+ raise NotImplementedError("mu not supported")
182
+ # TODO how to choose this term
183
+ lvlb_weights[0] = lvlb_weights[1]
184
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
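+ # Note: lvlb_weights[t] rescales the per-timestep simple loss into the variational-bound
+ # term (loss_vlb) computed in p_losses; entry 0 is copied from entry 1 above because the
+ # posterior variance is zero at t=0, which would otherwise make the weight blow up.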
185
+ assert not torch.isnan(self.lvlb_weights).any()
186
+
187
+ @contextmanager
188
+ def ema_scope(self, context=None):
189
+ if self.use_ema:
190
+ self.model_ema.store(self.model.parameters())
191
+ self.model_ema.copy_to(self.model)
192
+ if context is not None:
193
+ print(f"{context}: Switched to EMA weights")
194
+ try:
195
+ yield None
196
+ finally:
197
+ if self.use_ema:
198
+ self.model_ema.restore(self.model.parameters())
199
+ if context is not None:
200
+ print(f"{context}: Restored training weights")
201
+
202
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
203
+ ignore_keys = ignore_keys or []
204
+
205
+ sd = torch.load(path, map_location="cpu")
206
+ if "state_dict" in list(sd.keys()):
207
+ sd = sd["state_dict"]
208
+ keys = list(sd.keys())
209
+
210
+ # Our model adds additional channels to the first layer to condition on an input image.
211
+ # For the first layer, copy existing channel weights and initialize new channel weights to zero.
212
+ input_keys = [
213
+ "model.diffusion_model.input_blocks.0.0.weight",
214
+ "model_ema.diffusion_modelinput_blocks00weight",
215
+ ]
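+ # The second key appears intentional rather than a typo: the EMA wrapper (LitEma) stores its
+ # shadow parameters under names with the dots stripped, so the EMA copy of the first conv
+ # weight shows up in checkpoints in that dot-less form.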
216
+
217
+ self_sd = self.state_dict()
218
+ for input_key in input_keys:
219
+ if input_key not in sd or input_key not in self_sd:
220
+ continue
221
+
222
+ input_weight = self_sd[input_key]
223
+
224
+ if input_weight.size() != sd[input_key].size():
225
+ print(f"Manual init: {input_key}")
226
+ input_weight.zero_()
227
+ input_weight[:, :4, :, :].copy_(sd[input_key])
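+ # Only the first 4 (latent) channels come from the pretrained checkpoint; the extra
+ # image-conditioning channels added by InstructPix2Pix stay zero-initialized.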
228
+ ignore_keys.append(input_key)
229
+
230
+ for k in keys:
231
+ for ik in ignore_keys:
232
+ if k.startswith(ik):
233
+ print(f"Deleting key {k} from state_dict.")
234
+ del sd[k]
235
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
236
+ sd, strict=False)
237
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
238
+ if missing:
239
+ print(f"Missing Keys: {missing}")
240
+ if unexpected:
241
+ print(f"Unexpected Keys: {unexpected}")
242
+
243
+ def q_mean_variance(self, x_start, t):
244
+ """
245
+ Get the distribution q(x_t | x_0).
246
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
247
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
248
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
249
+ """
250
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
251
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
252
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
253
+ return mean, variance, log_variance
254
+
255
+ def predict_start_from_noise(self, x_t, t, noise):
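+ # Inverts q_sample: x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise.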
256
+ return (
257
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
258
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
259
+ )
260
+
261
+ def q_posterior(self, x_start, x_t, t):
262
+ posterior_mean = (
263
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
264
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
265
+ )
266
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
267
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
268
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
269
+
270
+ def p_mean_variance(self, x, t, clip_denoised: bool):
271
+ model_out = self.model(x, t)
272
+ if self.parameterization == "eps":
273
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
274
+ elif self.parameterization == "x0":
275
+ x_recon = model_out
276
+ if clip_denoised:
277
+ x_recon.clamp_(-1., 1.)
278
+
279
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
280
+ return model_mean, posterior_variance, posterior_log_variance
281
+
282
+ @torch.no_grad()
283
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
284
+ b, *_, device = *x.shape, x.device
285
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
286
+ noise = noise_like(x.shape, device, repeat_noise)
287
+ # no noise when t == 0
288
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
289
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
290
+
291
+ @torch.no_grad()
292
+ def p_sample_loop(self, shape, return_intermediates=False):
293
+ device = self.betas.device
294
+ b = shape[0]
295
+ img = torch.randn(shape, device=device)
296
+ intermediates = [img]
297
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
298
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
299
+ clip_denoised=self.clip_denoised)
300
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
301
+ intermediates.append(img)
302
+ if return_intermediates:
303
+ return img, intermediates
304
+ return img
305
+
306
+ @torch.no_grad()
307
+ def sample(self, batch_size=16, return_intermediates=False):
308
+ image_size = self.image_size
309
+ channels = self.channels
310
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
311
+ return_intermediates=return_intermediates)
312
+
313
+ def q_sample(self, x_start, t, noise=None):
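+ # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.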
314
+ noise = default(noise, lambda: torch.randn_like(x_start))
315
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
316
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
317
+
318
+ def get_loss(self, pred, target, mean=True):
319
+ if self.loss_type == 'l1':
320
+ loss = (target - pred).abs()
321
+ if mean:
322
+ loss = loss.mean()
323
+ elif self.loss_type == 'l2':
324
+ if mean:
325
+ loss = torch.nn.functional.mse_loss(target, pred)
326
+ else:
327
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
328
+ else:
329
+ raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
330
+
331
+ return loss
332
+
333
+ def p_losses(self, x_start, t, noise=None):
334
+ noise = default(noise, lambda: torch.randn_like(x_start))
335
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
336
+ model_out = self.model(x_noisy, t)
337
+
338
+ loss_dict = {}
339
+ if self.parameterization == "eps":
340
+ target = noise
341
+ elif self.parameterization == "x0":
342
+ target = x_start
343
+ else:
344
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
345
+
346
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
347
+
348
+ log_prefix = 'train' if self.training else 'val'
349
+
350
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
351
+ loss_simple = loss.mean() * self.l_simple_weight
352
+
353
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
354
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
355
+
356
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
357
+
358
+ loss_dict.update({f'{log_prefix}/loss': loss})
359
+
360
+ return loss, loss_dict
361
+
362
+ def forward(self, x, *args, **kwargs):
363
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
364
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
365
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
366
+ return self.p_losses(x, t, *args, **kwargs)
367
+
368
+ def get_input(self, batch, k):
369
+ return batch[k]
370
+
371
+ def shared_step(self, batch):
372
+ x = self.get_input(batch, self.first_stage_key)
373
+ loss, loss_dict = self(x)
374
+ return loss, loss_dict
375
+
376
+ def training_step(self, batch, batch_idx):
377
+ loss, loss_dict = self.shared_step(batch)
378
+
379
+ self.log_dict(loss_dict, prog_bar=True,
380
+ logger=True, on_step=True, on_epoch=True)
381
+
382
+ self.log("global_step", self.global_step,
383
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
384
+
385
+ if self.use_scheduler:
386
+ lr = self.optimizers().param_groups[0]['lr']
387
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
388
+
389
+ return loss
390
+
391
+ @torch.no_grad()
392
+ def validation_step(self, batch, batch_idx):
393
+ _, loss_dict_no_ema = self.shared_step(batch)
394
+ with self.ema_scope():
395
+ _, loss_dict_ema = self.shared_step(batch)
396
+ loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
397
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
398
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
399
+
400
+ def on_train_batch_end(self, *args, **kwargs):
401
+ if self.use_ema:
402
+ self.model_ema(self.model)
403
+
404
+ def _get_rows_from_list(self, samples):
405
+ n_imgs_per_row = len(samples)
406
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
407
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
408
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
409
+ return denoise_grid
410
+
411
+ @torch.no_grad()
412
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
413
+ log = {}
414
+ x = self.get_input(batch, self.first_stage_key)
415
+ N = min(x.shape[0], N)
416
+ n_row = min(x.shape[0], n_row)
417
+ x = x.to(self.device)[:N]
418
+ log["inputs"] = x
419
+
420
+ # get diffusion row
421
+ diffusion_row = []
422
+ x_start = x[:n_row]
423
+
424
+ for t in range(self.num_timesteps):
425
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
426
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
427
+ t = t.to(self.device).long()
428
+ noise = torch.randn_like(x_start)
429
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
430
+ diffusion_row.append(x_noisy)
431
+
432
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
433
+
434
+ if sample:
435
+ # get denoise row
436
+ with self.ema_scope("Plotting"):
437
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
438
+
439
+ log["samples"] = samples
440
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
441
+
442
+ if return_keys:
443
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
444
+ return log
445
+ else:
446
+ return {key: log[key] for key in return_keys}
447
+ return log
448
+
449
+ def configure_optimizers(self):
450
+ lr = self.learning_rate
451
+ params = list(self.model.parameters())
452
+ if self.learn_logvar:
453
+ params = params + [self.logvar]
454
+ opt = torch.optim.AdamW(params, lr=lr)
455
+ return opt
456
+
457
+
458
+ class LatentDiffusion(DDPM):
459
+ """main class"""
460
+ def __init__(self,
461
+ first_stage_config,
462
+ cond_stage_config,
463
+ num_timesteps_cond=None,
464
+ cond_stage_key="image",
465
+ cond_stage_trainable=False,
466
+ concat_mode=True,
467
+ cond_stage_forward=None,
468
+ conditioning_key=None,
469
+ scale_factor=1.0,
470
+ scale_by_std=False,
471
+ load_ema=True,
472
+ *args, **kwargs):
473
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
474
+ self.scale_by_std = scale_by_std
475
+ assert self.num_timesteps_cond <= kwargs['timesteps']
476
+ # for backwards compatibility after implementation of DiffusionWrapper
477
+ if conditioning_key is None:
478
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
479
+ if cond_stage_config == '__is_unconditional__':
480
+ conditioning_key = None
481
+ ckpt_path = kwargs.pop("ckpt_path", None)
482
+ ignore_keys = kwargs.pop("ignore_keys", [])
483
+ super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
484
+ self.concat_mode = concat_mode
485
+ self.cond_stage_trainable = cond_stage_trainable
486
+ self.cond_stage_key = cond_stage_key
487
+ try:
488
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
489
+ except Exception:
490
+ self.num_downs = 0
491
+ if not scale_by_std:
492
+ self.scale_factor = scale_factor
493
+ else:
494
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
495
+ self.instantiate_first_stage(first_stage_config)
496
+ self.instantiate_cond_stage(cond_stage_config)
497
+ self.cond_stage_forward = cond_stage_forward
498
+ self.clip_denoised = False
499
+ self.bbox_tokenizer = None
500
+
501
+ self.restarted_from_ckpt = False
502
+ if ckpt_path is not None:
503
+ self.init_from_ckpt(ckpt_path, ignore_keys)
504
+ self.restarted_from_ckpt = True
505
+
506
+ if self.use_ema and not load_ema:
507
+ self.model_ema = LitEma(self.model)
508
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
509
+
510
+ def make_cond_schedule(self, ):
511
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
512
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
513
+ self.cond_ids[:self.num_timesteps_cond] = ids
514
+
515
+ @rank_zero_only
516
+ @torch.no_grad()
517
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
518
+ # only for very first batch
519
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
520
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
521
+ # set rescale weight to 1./std of encodings
522
+ print("### USING STD-RESCALING ###")
523
+ x = super().get_input(batch, self.first_stage_key)
524
+ x = x.to(self.device)
525
+ encoder_posterior = self.encode_first_stage(x)
526
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
527
+ del self.scale_factor
528
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
529
+ print(f"setting self.scale_factor to {self.scale_factor}")
530
+ print("### USING STD-RESCALING ###")
531
+
532
+ def register_schedule(self,
533
+ given_betas=None, beta_schedule="linear", timesteps=1000,
534
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
535
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
536
+
537
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
538
+ if self.shorten_cond_schedule:
539
+ self.make_cond_schedule()
540
+
541
+ def instantiate_first_stage(self, config):
542
+ model = instantiate_from_config(config)
543
+ self.first_stage_model = model.eval()
544
+ self.first_stage_model.train = disabled_train
545
+ for param in self.first_stage_model.parameters():
546
+ param.requires_grad = False
547
+
548
+ def instantiate_cond_stage(self, config):
549
+ if not self.cond_stage_trainable:
550
+ if config == "__is_first_stage__":
551
+ print("Using first stage also as cond stage.")
552
+ self.cond_stage_model = self.first_stage_model
553
+ elif config == "__is_unconditional__":
554
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
555
+ self.cond_stage_model = None
556
+ # self.be_unconditional = True
557
+ else:
558
+ model = instantiate_from_config(config)
559
+ self.cond_stage_model = model.eval()
560
+ self.cond_stage_model.train = disabled_train
561
+ for param in self.cond_stage_model.parameters():
562
+ param.requires_grad = False
563
+ else:
564
+ assert config != '__is_first_stage__'
565
+ assert config != '__is_unconditional__'
566
+ model = instantiate_from_config(config)
567
+ self.cond_stage_model = model
568
+
569
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
570
+ denoise_row = []
571
+ for zd in tqdm(samples, desc=desc):
572
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
573
+ force_not_quantize=force_no_decoder_quantization))
574
+ n_imgs_per_row = len(denoise_row)
575
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
576
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
577
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
578
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
579
+ return denoise_grid
580
+
581
+ def get_first_stage_encoding(self, encoder_posterior):
582
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
583
+ z = encoder_posterior.sample()
584
+ elif isinstance(encoder_posterior, torch.Tensor):
585
+ z = encoder_posterior
586
+ else:
587
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
588
+ return self.scale_factor * z
589
+
590
+ def get_learned_conditioning(self, c):
591
+ if self.cond_stage_forward is None:
592
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
593
+ c = self.cond_stage_model.encode(c)
594
+ if isinstance(c, DiagonalGaussianDistribution):
595
+ c = c.mode()
596
+ else:
597
+ c = self.cond_stage_model(c)
598
+ else:
599
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
600
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
601
+ return c
602
+
603
+ def meshgrid(self, h, w):
604
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
605
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
606
+
607
+ arr = torch.cat([y, x], dim=-1)
608
+ return arr
609
+
610
+ def delta_border(self, h, w):
611
+ """
612
+ :param h: height
613
+ :param w: width
614
+ :return: normalized distance to image border,
615
+ with min distance = 0 at border and max dist = 0.5 at image center
616
+ """
617
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
618
+ arr = self.meshgrid(h, w) / lower_right_corner
619
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
620
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
621
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
622
+ return edge_dist
623
+
624
+ def get_weighting(self, h, w, Ly, Lx, device):
625
+ weighting = self.delta_border(h, w)
626
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
627
+ self.split_input_params["clip_max_weight"], )
628
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
629
+
630
+ if self.split_input_params["tie_braker"]:
631
+ L_weighting = self.delta_border(Ly, Lx)
632
+ L_weighting = torch.clip(L_weighting,
633
+ self.split_input_params["clip_min_tie_weight"],
634
+ self.split_input_params["clip_max_tie_weight"])
635
+
636
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
637
+ weighting = weighting * L_weighting
638
+ return weighting
639
+
640
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
641
+ """
642
+ :param x: img of size (bs, c, h, w)
643
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
644
+ """
645
+ bs, nc, h, w = x.shape
646
+
647
+ # number of crops in image
648
+ Ly = (h - kernel_size[0]) // stride[0] + 1
649
+ Lx = (w - kernel_size[1]) // stride[1] + 1
650
+
651
+ if uf == 1 and df == 1:
652
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
653
+ unfold = torch.nn.Unfold(**fold_params)
654
+
655
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
656
+
657
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
658
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
659
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
660
+
661
+ elif uf > 1 and df == 1:
662
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
663
+ unfold = torch.nn.Unfold(**fold_params)
664
+
665
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
666
+ dilation=1, padding=0,
667
+ stride=(stride[0] * uf, stride[1] * uf))
668
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
669
+
670
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
671
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
672
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
673
+
674
+ elif df > 1 and uf == 1:
675
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
676
+ unfold = torch.nn.Unfold(**fold_params)
677
+
678
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
679
+ dilation=1, padding=0,
680
+ stride=(stride[0] // df, stride[1] // df))
681
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
682
+
683
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
684
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
685
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
686
+
687
+ else:
688
+ raise NotImplementedError
689
+
690
+ return fold, unfold, normalization, weighting
691
+
692
+ @torch.no_grad()
693
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
694
+ cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
695
+ x = super().get_input(batch, k)
696
+ if bs is not None:
697
+ x = x[:bs]
698
+ x = x.to(self.device)
699
+ encoder_posterior = self.encode_first_stage(x)
700
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
701
+ cond_key = cond_key or self.cond_stage_key
702
+ xc = super().get_input(batch, cond_key)
703
+ if bs is not None:
704
+ xc["c_crossattn"] = xc["c_crossattn"][:bs]
705
+ xc["c_concat"] = xc["c_concat"][:bs]
706
+ cond = {}
707
+
708
+ # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
709
+ random = torch.rand(x.size(0), device=x.device)
710
+ prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
711
+ input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
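+ # With the default uncond=0.05, these masks partition the batch as: random in [0, uncond)
+ # drops the text prompt only, [uncond, 2*uncond) drops both, and [2*uncond, 3*uncond) drops
+ # the conditioning image only, matching the 5%/5%/5% rates described above.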
712
+
713
+ null_prompt = self.get_learned_conditioning([""])
714
+ cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
715
+ cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
716
+
717
+ out = [z, cond]
718
+ if return_first_stage_outputs:
719
+ xrec = self.decode_first_stage(z)
720
+ out.extend([x, xrec])
721
+ if return_original_cond:
722
+ out.append(xc)
723
+ return out
724
+
725
+ @torch.no_grad()
726
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
727
+ if predict_cids:
728
+ if z.dim() == 4:
729
+ z = torch.argmax(z.exp(), dim=1).long()
730
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
731
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
732
+
733
+ z = 1. / self.scale_factor * z
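+ # Undo the scale_factor applied in get_first_stage_encoding before handing z to the decoder.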
734
+
735
+ if hasattr(self, "split_input_params"):
736
+ if self.split_input_params["patch_distributed_vq"]:
737
+ ks = self.split_input_params["ks"] # eg. (128, 128)
738
+ stride = self.split_input_params["stride"] # eg. (64, 64)
739
+ uf = self.split_input_params["vqf"]
740
+ bs, nc, h, w = z.shape
741
+ if ks[0] > h or ks[1] > w:
742
+ ks = (min(ks[0], h), min(ks[1], w))
743
+ print("reducing Kernel")
744
+
745
+ if stride[0] > h or stride[1] > w:
746
+ stride = (min(stride[0], h), min(stride[1], w))
747
+ print("reducing stride")
748
+
749
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
750
+
751
+ z = unfold(z) # (bn, nc * prod(**ks), L)
752
+ # 1. Reshape to img shape
753
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
754
+
755
+ # 2. apply model loop over last dim
756
+ if isinstance(self.first_stage_model, VQModelInterface):
757
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
758
+ force_not_quantize=predict_cids or force_not_quantize)
759
+ for i in range(z.shape[-1])]
760
+ else:
761
+
762
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
763
+ for i in range(z.shape[-1])]
764
+
765
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
766
+ o = o * weighting
767
+ # Reverse 1. reshape to img shape
768
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
769
+ # stitch crops together
770
+ decoded = fold(o)
771
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
772
+ return decoded
773
+ else:
774
+ if isinstance(self.first_stage_model, VQModelInterface):
775
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
776
+ else:
777
+ return self.first_stage_model.decode(z)
778
+
779
+ else:
780
+ if isinstance(self.first_stage_model, VQModelInterface):
781
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
782
+ else:
783
+ return self.first_stage_model.decode(z)
784
+
785
+ # same as above but without decorator
786
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
787
+ if predict_cids:
788
+ if z.dim() == 4:
789
+ z = torch.argmax(z.exp(), dim=1).long()
790
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
791
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
792
+
793
+ z = 1. / self.scale_factor * z
794
+
795
+ if hasattr(self, "split_input_params"):
796
+ if self.split_input_params["patch_distributed_vq"]:
797
+ ks = self.split_input_params["ks"] # eg. (128, 128)
798
+ stride = self.split_input_params["stride"] # eg. (64, 64)
799
+ uf = self.split_input_params["vqf"]
800
+ bs, nc, h, w = z.shape
801
+ if ks[0] > h or ks[1] > w:
802
+ ks = (min(ks[0], h), min(ks[1], w))
803
+ print("reducing Kernel")
804
+
805
+ if stride[0] > h or stride[1] > w:
806
+ stride = (min(stride[0], h), min(stride[1], w))
807
+ print("reducing stride")
808
+
809
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
810
+
811
+ z = unfold(z) # (bn, nc * prod(**ks), L)
812
+ # 1. Reshape to img shape
813
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
814
+
815
+ # 2. apply model loop over last dim
816
+ if isinstance(self.first_stage_model, VQModelInterface):
817
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
818
+ force_not_quantize=predict_cids or force_not_quantize)
819
+ for i in range(z.shape[-1])]
820
+ else:
821
+
822
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
823
+ for i in range(z.shape[-1])]
824
+
825
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
826
+ o = o * weighting
827
+ # Reverse 1. reshape to img shape
828
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
829
+ # stitch crops together
830
+ decoded = fold(o)
831
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
832
+ return decoded
833
+ else:
834
+ if isinstance(self.first_stage_model, VQModelInterface):
835
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
836
+ else:
837
+ return self.first_stage_model.decode(z)
838
+
839
+ else:
840
+ if isinstance(self.first_stage_model, VQModelInterface):
841
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
842
+ else:
843
+ return self.first_stage_model.decode(z)
844
+
845
+ @torch.no_grad()
846
+ def encode_first_stage(self, x):
847
+ if hasattr(self, "split_input_params"):
848
+ if self.split_input_params["patch_distributed_vq"]:
849
+ ks = self.split_input_params["ks"] # eg. (128, 128)
850
+ stride = self.split_input_params["stride"] # eg. (64, 64)
851
+ df = self.split_input_params["vqf"]
852
+ self.split_input_params['original_image_size'] = x.shape[-2:]
853
+ bs, nc, h, w = x.shape
854
+ if ks[0] > h or ks[1] > w:
855
+ ks = (min(ks[0], h), min(ks[1], w))
856
+ print("reducing Kernel")
857
+
858
+ if stride[0] > h or stride[1] > w:
859
+ stride = (min(stride[0], h), min(stride[1], w))
860
+ print("reducing stride")
861
+
862
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
863
+ z = unfold(x) # (bn, nc * prod(**ks), L)
864
+ # Reshape to img shape
865
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
866
+
867
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
868
+ for i in range(z.shape[-1])]
869
+
870
+ o = torch.stack(output_list, axis=-1)
871
+ o = o * weighting
872
+
873
+ # Reverse reshape to img shape
874
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
875
+ # stitch crops together
876
+ decoded = fold(o)
877
+ decoded = decoded / normalization
878
+ return decoded
879
+
880
+ else:
881
+ return self.first_stage_model.encode(x)
882
+ else:
883
+ return self.first_stage_model.encode(x)
884
+
885
+ def shared_step(self, batch, **kwargs):
886
+ x, c = self.get_input(batch, self.first_stage_key)
887
+ loss = self(x, c)
888
+ return loss
889
+
890
+ def forward(self, x, c, *args, **kwargs):
891
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
892
+ if self.model.conditioning_key is not None:
893
+ assert c is not None
894
+ if self.cond_stage_trainable:
895
+ c = self.get_learned_conditioning(c)
896
+ if self.shorten_cond_schedule: # TODO: drop this option
897
+ tc = self.cond_ids[t].to(self.device)
898
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
899
+ return self.p_losses(x, c, t, *args, **kwargs)
900
+
901
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
902
+
903
+ if isinstance(cond, dict):
904
+ # hybrid case, cond is expected to be a dict
905
+ pass
906
+ else:
907
+ if not isinstance(cond, list):
908
+ cond = [cond]
909
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
910
+ cond = {key: cond}
911
+
912
+ if hasattr(self, "split_input_params"):
913
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
914
+ assert not return_ids
915
+ ks = self.split_input_params["ks"] # eg. (128, 128)
916
+ stride = self.split_input_params["stride"] # eg. (64, 64)
917
+
918
+ h, w = x_noisy.shape[-2:]
919
+
920
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
921
+
922
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
923
+ # Reshape to img shape
924
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
925
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
926
+
927
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
928
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
929
+ c_key = next(iter(cond.keys())) # get key
930
+ c = next(iter(cond.values())) # get value
931
+ assert (len(c) == 1) # todo extend to list with more than one elem
932
+ c = c[0] # get element
933
+
934
+ c = unfold(c)
935
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
936
+
937
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
938
+
939
+ elif self.cond_stage_key == 'coordinates_bbox':
940
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
941
+
942
+ # assuming padding of unfold is always 0 and its dilation is always 1
943
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
944
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
945
+ # as we are operating on latents, we need the factor from the original image size to the
946
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
947
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
948
+ rescale_latent = 2 ** (num_downs)
949
+
950
+ # get top-left positions of patches as required by the bbox tokenizer, therefore we
951
+ # need to rescale the tl patch coordinates to be in between (0,1)
952
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
953
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
954
+ for patch_nr in range(z.shape[-1])]
955
+
956
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
957
+ patch_limits = [(x_tl, y_tl,
958
+ rescale_latent * ks[0] / full_img_w,
959
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
960
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
961
+
962
+ # tokenize crop coordinates for the bounding boxes of the respective patches
963
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
964
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
965
+ print(patch_limits_tknzd[0].shape)
966
+ # cut tknzd crop position from conditioning
967
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
968
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
969
+ print(cut_cond.shape)
970
+
971
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
972
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
973
+ print(adapted_cond.shape)
974
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
975
+ print(adapted_cond.shape)
976
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
977
+ print(adapted_cond.shape)
978
+
979
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
980
+
981
+ else:
982
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
983
+
984
+ # apply model by loop over crops
985
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
986
+ assert not isinstance(output_list[0],
987
+ tuple) # todo cant deal with multiple model outputs check this never happens
988
+
989
+ o = torch.stack(output_list, axis=-1)
990
+ o = o * weighting
991
+ # Reverse reshape to img shape
992
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
993
+ # stitch crops together
994
+ x_recon = fold(o) / normalization
995
+
996
+ else:
997
+ x_recon = self.model(x_noisy, t, **cond)
998
+
999
+ if isinstance(x_recon, tuple) and not return_ids:
1000
+ return x_recon[0]
1001
+ else:
1002
+ return x_recon
1003
+
1004
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
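+ # Inverse of predict_start_from_noise: recover the noise eps implied by x_t and a predicted x_0.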
1005
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1006
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1007
+
1008
+ def _prior_bpd(self, x_start):
1009
+ """
1010
+ Get the prior KL term for the variational lower-bound, measured in
1011
+ bits-per-dim.
1012
+ This term can't be optimized, as it only depends on the encoder.
1013
+ :param x_start: the [N x C x ...] tensor of inputs.
1014
+ :return: a batch of [N] KL values (in bits), one per batch element.
1015
+ """
1016
+ batch_size = x_start.shape[0]
1017
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1018
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1019
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1020
+ return mean_flat(kl_prior) / np.log(2.0)
1021
+
1022
+ def p_losses(self, x_start, cond, t, noise=None):
1023
+ noise = default(noise, lambda: torch.randn_like(x_start))
1024
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1025
+ model_output = self.apply_model(x_noisy, t, cond)
1026
+
1027
+ loss_dict = {}
1028
+ prefix = 'train' if self.training else 'val'
1029
+
1030
+ if self.parameterization == "x0":
1031
+ target = x_start
1032
+ elif self.parameterization == "eps":
1033
+ target = noise
1034
+ else:
1035
+ raise NotImplementedError()
1036
+
1037
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1038
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1039
+
1040
+ logvar_t = self.logvar[t].to(self.device)
1041
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1042
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1043
+ if self.learn_logvar:
1044
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1045
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1046
+
1047
+ loss = self.l_simple_weight * loss.mean()
1048
+
1049
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1050
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1051
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1052
+ loss += (self.original_elbo_weight * loss_vlb)
1053
+ loss_dict.update({f'{prefix}/loss': loss})
1054
+
1055
+ return loss, loss_dict
1056
+
1057
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1058
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1059
+ t_in = t
1060
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1061
+
1062
+ if score_corrector is not None:
1063
+ assert self.parameterization == "eps"
1064
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1065
+
1066
+ if return_codebook_ids:
1067
+ model_out, logits = model_out
1068
+
1069
+ if self.parameterization == "eps":
1070
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1071
+ elif self.parameterization == "x0":
1072
+ x_recon = model_out
1073
+ else:
1074
+ raise NotImplementedError()
1075
+
1076
+ if clip_denoised:
1077
+ x_recon.clamp_(-1., 1.)
1078
+ if quantize_denoised:
1079
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1080
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1081
+ if return_codebook_ids:
1082
+ return model_mean, posterior_variance, posterior_log_variance, logits
1083
+ elif return_x0:
1084
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1085
+ else:
1086
+ return model_mean, posterior_variance, posterior_log_variance
1087
+
1088
+ @torch.no_grad()
1089
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1090
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1091
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1092
+ b, *_, device = *x.shape, x.device
1093
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1094
+ return_codebook_ids=return_codebook_ids,
1095
+ quantize_denoised=quantize_denoised,
1096
+ return_x0=return_x0,
1097
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1098
+ if return_codebook_ids:
1099
+ raise DeprecationWarning("Support dropped.")
1100
+ model_mean, _, model_log_variance, logits = outputs
1101
+ elif return_x0:
1102
+ model_mean, _, model_log_variance, x0 = outputs
1103
+ else:
1104
+ model_mean, _, model_log_variance = outputs
1105
+
1106
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1107
+ if noise_dropout > 0.:
1108
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1109
+ # no noise when t == 0
1110
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1111
+
1112
+ if return_codebook_ids:
1113
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1114
+ if return_x0:
1115
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1116
+ else:
1117
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1118
+
1119
+ @torch.no_grad()
1120
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1121
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1122
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1123
+ log_every_t=None):
1124
+ if not log_every_t:
1125
+ log_every_t = self.log_every_t
1126
+ timesteps = self.num_timesteps
1127
+ if batch_size is not None:
1128
+ b = batch_size if batch_size is not None else shape[0]
1129
+ shape = [batch_size] + list(shape)
1130
+ else:
1131
+ b = batch_size = shape[0]
1132
+ if x_T is None:
1133
+ img = torch.randn(shape, device=self.device)
1134
+ else:
1135
+ img = x_T
1136
+ intermediates = []
1137
+ if cond is not None:
1138
+ if isinstance(cond, dict):
1139
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1140
+ [x[:batch_size] for x in cond[key]] for key in cond}
1141
+ else:
1142
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1143
+
1144
+ if start_T is not None:
1145
+ timesteps = min(timesteps, start_T)
1146
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1147
+ total=timesteps) if verbose else reversed(
1148
+ range(0, timesteps))
1149
+ if type(temperature) == float:
1150
+ temperature = [temperature] * timesteps
1151
+
1152
+ for i in iterator:
1153
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1154
+ if self.shorten_cond_schedule:
1155
+ assert self.model.conditioning_key != 'hybrid'
1156
+ tc = self.cond_ids[ts].to(cond.device)
1157
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1158
+
1159
+ img, x0_partial = self.p_sample(img, cond, ts,
1160
+ clip_denoised=self.clip_denoised,
1161
+ quantize_denoised=quantize_denoised, return_x0=True,
1162
+ temperature=temperature[i], noise_dropout=noise_dropout,
1163
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1164
+ if mask is not None:
1165
+ assert x0 is not None
1166
+ img_orig = self.q_sample(x0, ts)
1167
+ img = img_orig * mask + (1. - mask) * img
1168
+
1169
+ if i % log_every_t == 0 or i == timesteps - 1:
1170
+ intermediates.append(x0_partial)
1171
+ if callback:
1172
+ callback(i)
1173
+ if img_callback:
1174
+ img_callback(img, i)
1175
+ return img, intermediates
1176
+
1177
+ @torch.no_grad()
1178
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1179
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1180
+ mask=None, x0=None, img_callback=None, start_T=None,
1181
+ log_every_t=None):
1182
+
1183
+ if not log_every_t:
1184
+ log_every_t = self.log_every_t
1185
+ device = self.betas.device
1186
+ b = shape[0]
1187
+ if x_T is None:
1188
+ img = torch.randn(shape, device=device)
1189
+ else:
1190
+ img = x_T
1191
+
1192
+ intermediates = [img]
1193
+ if timesteps is None:
1194
+ timesteps = self.num_timesteps
1195
+
1196
+ if start_T is not None:
1197
+ timesteps = min(timesteps, start_T)
1198
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1199
+ range(0, timesteps))
1200
+
1201
+ if mask is not None:
1202
+ assert x0 is not None
1203
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1204
+
1205
+ for i in iterator:
1206
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1207
+ if self.shorten_cond_schedule:
1208
+ assert self.model.conditioning_key != 'hybrid'
1209
+ tc = self.cond_ids[ts].to(cond.device)
1210
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1211
+
1212
+ img = self.p_sample(img, cond, ts,
1213
+ clip_denoised=self.clip_denoised,
1214
+ quantize_denoised=quantize_denoised)
1215
+ if mask is not None:
1216
+ img_orig = self.q_sample(x0, ts)
1217
+ img = img_orig * mask + (1. - mask) * img
1218
+
1219
+ if i % log_every_t == 0 or i == timesteps - 1:
1220
+ intermediates.append(img)
1221
+ if callback:
1222
+ callback(i)
1223
+ if img_callback:
1224
+ img_callback(img, i)
1225
+
1226
+ if return_intermediates:
1227
+ return img, intermediates
1228
+ return img
1229
+
1230
+ @torch.no_grad()
1231
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1232
+ verbose=True, timesteps=None, quantize_denoised=False,
1233
+ mask=None, x0=None, shape=None,**kwargs):
1234
+ if shape is None:
1235
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1236
+ if cond is not None:
1237
+ if isinstance(cond, dict):
1238
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1239
+ [x[:batch_size] for x in cond[key]] for key in cond}
1240
+ else:
1241
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1242
+ return self.p_sample_loop(cond,
1243
+ shape,
1244
+ return_intermediates=return_intermediates, x_T=x_T,
1245
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1246
+ mask=mask, x0=x0)
1247
+
1248
+ @torch.no_grad()
1249
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
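+ # Sample with the DDIM sampler when ddim=True, otherwise fall back to the ancestral
+ # p_sample_loop via self.sample(); both branches return (samples, intermediates).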
1250
+
1251
+ if ddim:
1252
+ ddim_sampler = DDIMSampler(self)
1253
+ shape = (self.channels, self.image_size, self.image_size)
1254
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1255
+ shape,cond,verbose=False,**kwargs)
1256
+
1257
+ else:
1258
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1259
+ return_intermediates=True,**kwargs)
1260
+
1261
+ return samples, intermediates
1262
+
1263
+
1264
+ @torch.no_grad()
1265
+ def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1266
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1267
+ plot_diffusion_rows=False, **kwargs):
1268
+
1269
+ use_ddim = False
1270
+
1271
+ log = {}
1272
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1273
+ return_first_stage_outputs=True,
1274
+ force_c_encode=True,
1275
+ return_original_cond=True,
1276
+ bs=N, uncond=0)
1277
+ N = min(x.shape[0], N)
1278
+ n_row = min(x.shape[0], n_row)
1279
+ log["inputs"] = x
1280
+ log["reals"] = xc["c_concat"]
1281
+ log["reconstruction"] = xrec
1282
+ if self.model.conditioning_key is not None:
1283
+ if hasattr(self.cond_stage_model, "decode"):
1284
+ xc = self.cond_stage_model.decode(c)
1285
+ log["conditioning"] = xc
1286
+ elif self.cond_stage_key in ["caption"]:
1287
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1288
+ log["conditioning"] = xc
1289
+ elif self.cond_stage_key == 'class_label':
1290
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1291
+ log['conditioning'] = xc
1292
+ elif isimage(xc):
1293
+ log["conditioning"] = xc
1294
+ if ismap(xc):
1295
+ log["original_conditioning"] = self.to_rgb(xc)
1296
+
1297
+ if plot_diffusion_rows:
1298
+ # get diffusion row
1299
+ diffusion_row = []
1300
+ z_start = z[:n_row]
1301
+ for t in range(self.num_timesteps):
1302
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1303
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1304
+ t = t.to(self.device).long()
1305
+ noise = torch.randn_like(z_start)
1306
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1307
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1308
+
1309
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1310
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1311
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1312
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1313
+ log["diffusion_row"] = diffusion_grid
1314
+
1315
+ if sample:
1316
+ # get denoise row
1317
+ with self.ema_scope("Plotting"):
1318
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1319
+ ddim_steps=ddim_steps,eta=ddim_eta)
1320
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1321
+ x_samples = self.decode_first_stage(samples)
1322
+ log["samples"] = x_samples
1323
+ if plot_denoise_rows:
1324
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1325
+ log["denoise_row"] = denoise_grid
1326
+
1327
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1328
+ self.first_stage_model, IdentityFirstStage):
1329
+ # also display when quantizing x0 while sampling
1330
+ with self.ema_scope("Plotting Quantized Denoised"):
1331
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1332
+ ddim_steps=ddim_steps,eta=ddim_eta,
1333
+ quantize_denoised=True)
1334
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1335
+ # quantize_denoised=True)
1336
+ x_samples = self.decode_first_stage(samples.to(self.device))
1337
+ log["samples_x0_quantized"] = x_samples
1338
+
1339
+ if inpaint:
1340
+ # make a simple center square
1341
+ h, w = z.shape[2], z.shape[3]
1342
+ mask = torch.ones(N, h, w).to(self.device)
1343
+ # zeros will be filled in
1344
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1345
+ mask = mask[:, None, ...]
1346
+ with self.ema_scope("Plotting Inpaint"):
1347
+
1348
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1349
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1350
+ x_samples = self.decode_first_stage(samples.to(self.device))
1351
+ log["samples_inpainting"] = x_samples
1352
+ log["mask"] = mask
1353
+
1354
+ # outpaint
1355
+ with self.ema_scope("Plotting Outpaint"):
1356
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1357
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1358
+ x_samples = self.decode_first_stage(samples.to(self.device))
1359
+ log["samples_outpainting"] = x_samples
1360
+
1361
+ if plot_progressive_rows:
1362
+ with self.ema_scope("Plotting Progressives"):
1363
+ img, progressives = self.progressive_denoising(c,
1364
+ shape=(self.channels, self.image_size, self.image_size),
1365
+ batch_size=N)
1366
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1367
+ log["progressive_row"] = prog_row
1368
+
1369
+ if return_keys:
1370
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1371
+ return log
1372
+ else:
1373
+ return {key: log[key] for key in return_keys}
1374
+ return log
1375
+
1376
+ def configure_optimizers(self):
1377
+ lr = self.learning_rate
1378
+ params = list(self.model.parameters())
1379
+ if self.cond_stage_trainable:
1380
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1381
+ params = params + list(self.cond_stage_model.parameters())
1382
+ if self.learn_logvar:
1383
+ print('Diffusion model optimizing logvar')
1384
+ params.append(self.logvar)
1385
+ opt = torch.optim.AdamW(params, lr=lr)
1386
+ if self.use_scheduler:
1387
+ assert 'target' in self.scheduler_config
1388
+ scheduler = instantiate_from_config(self.scheduler_config)
1389
+
1390
+ print("Setting up LambdaLR scheduler...")
1391
+ scheduler = [
1392
+ {
1393
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1394
+ 'interval': 'step',
1395
+ 'frequency': 1
1396
+ }]
1397
+ return [opt], scheduler
1398
+ return opt
1399
+
1400
+ @torch.no_grad()
1401
+ def to_rgb(self, x):
1402
+ x = x.float()
1403
+ if not hasattr(self, "colorize"):
1404
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1405
+ x = nn.functional.conv2d(x, weight=self.colorize)
1406
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1407
+ return x
1408
+
1409
+
1410
+ class DiffusionWrapper(pl.LightningModule):
1411
+ def __init__(self, diff_model_config, conditioning_key):
1412
+ super().__init__()
1413
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1414
+ self.conditioning_key = conditioning_key
1415
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
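+ # 'concat' appends the conditioning as extra input channels, 'crossattn' feeds it as
+ # cross-attention context, 'hybrid' does both, and 'adm' passes it as a class-style label y.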
1416
+
1417
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1418
+ if self.conditioning_key is None:
1419
+ out = self.diffusion_model(x, t)
1420
+ elif self.conditioning_key == 'concat':
1421
+ xc = torch.cat([x] + c_concat, dim=1)
1422
+ out = self.diffusion_model(xc, t)
1423
+ elif self.conditioning_key == 'crossattn':
1424
+ cc = torch.cat(c_crossattn, 1)
1425
+ out = self.diffusion_model(x, t, context=cc)
1426
+ elif self.conditioning_key == 'hybrid':
1427
+ xc = torch.cat([x] + c_concat, dim=1)
1428
+ cc = torch.cat(c_crossattn, 1)
1429
+ out = self.diffusion_model(xc, t, context=cc)
1430
+ elif self.conditioning_key == 'adm':
1431
+ cc = c_crossattn[0]
1432
+ out = self.diffusion_model(x, t, y=cc)
1433
+ else:
1434
+ raise NotImplementedError()
1435
+
1436
+ return out
1437
+
1438
+
1439
+ class Layout2ImgDiffusion(LatentDiffusion):
1440
+ # TODO: move all layout-specific hacks to this class
1441
+ def __init__(self, cond_stage_key, *args, **kwargs):
1442
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1443
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
1444
+
1445
+ def log_images(self, batch, N=8, *args, **kwargs):
1446
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
1447
+
1448
+ key = 'train' if self.training else 'validation'
1449
+ dset = self.trainer.datamodule.datasets[key]
1450
+ mapper = dset.conditional_builders[self.cond_stage_key]
1451
+
1452
+ bbox_imgs = []
1453
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1454
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1455
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1456
+ bbox_imgs.append(bboximg)
1457
+
1458
+ cond_img = torch.stack(bbox_imgs, dim=0)
1459
+ logs['bbox_image'] = cond_img
1460
+ return logs
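
Note on the DiffusionWrapper added above: its forward() routes conditioning to the UNet in one of four ways ('concat', 'crossattn', 'hybrid', 'adm'). Below is a minimal sketch of the two most common branches with a stand-in callable in place of the UNet; the helper name route_conditioning and the toy shapes are illustrative only and are not part of this commit.

import torch

def route_conditioning(unet, x, t, conditioning_key, c_concat=None, c_crossattn=None):
    # mirrors the 'concat' and 'crossattn' branches of DiffusionWrapper.forward
    if conditioning_key == 'concat':
        # conditioning images are stacked onto the latent as extra channels
        return unet(torch.cat([x] + c_concat, dim=1), t)
    if conditioning_key == 'crossattn':
        # text embeddings are passed as cross-attention context
        return unet(x, t, context=torch.cat(c_crossattn, dim=1))
    raise NotImplementedError(conditioning_key)

# toy usage: a lambda stands in for the diffusion UNet
unet = lambda x, t, context=None: x
x = torch.randn(2, 4, 8, 8)
t = torch.randint(0, 1000, (2,))
out = route_conditioning(unet, x, t, 'crossattn', c_crossattn=[torch.randn(2, 77, 768)])
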
modules/models/diffusion/uni_pc/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import UniPCSampler # noqa: F401
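
This package __init__ simply re-exports UniPCSampler, so callers can depend on the package rather than reaching into the sampler module directly. An illustrative import (not part of the diff):

from modules.models.diffusion.uni_pc import UniPCSampler
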
modules/models/diffusion/uni_pc/sampler.py ADDED
@@ -0,0 +1,101 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
6
+ from modules import shared, devices
7
+
8
+
9
+ class UniPCSampler(object):
10
+ def __init__(self, model, **kwargs):
11
+ super().__init__()
12
+ self.model = model
13
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
14
+ self.before_sample = None
15
+ self.after_sample = None
16
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != devices.device:
21
+ attr = attr.to(devices.device)
22
+ setattr(self, name, attr)
23
+
24
+ def set_hooks(self, before_sample, after_sample, after_update):
25
+ self.before_sample = before_sample
26
+ self.after_sample = after_sample
27
+ self.after_update = after_update
28
+
29
+ @torch.no_grad()
30
+ def sample(self,
31
+ S,
32
+ batch_size,
33
+ shape,
34
+ conditioning=None,
35
+ callback=None,
36
+ normals_sequence=None,
37
+ img_callback=None,
38
+ quantize_x0=False,
39
+ eta=0.,
40
+ mask=None,
41
+ x0=None,
42
+ temperature=1.,
43
+ noise_dropout=0.,
44
+ score_corrector=None,
45
+ corrector_kwargs=None,
46
+ verbose=True,
47
+ x_T=None,
48
+ log_every_t=100,
49
+ unconditional_guidance_scale=1.,
50
+ unconditional_conditioning=None,
51
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
52
+ **kwargs
53
+ ):
54
+ if conditioning is not None:
55
+ if isinstance(conditioning, dict):
56
+ ctmp = conditioning[list(conditioning.keys())[0]]
57
+ while isinstance(ctmp, list):
58
+ ctmp = ctmp[0]
59
+ cbs = ctmp.shape[0]
60
+ if cbs != batch_size:
61
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
62
+
63
+ elif isinstance(conditioning, list):
64
+ for ctmp in conditioning:
65
+ if ctmp.shape[0] != batch_size:
66
+ print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
67
+
68
+ else:
69
+ if conditioning.shape[0] != batch_size:
70
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
71
+
72
+ # sampling
73
+ C, H, W = shape
74
+ size = (batch_size, C, H, W)
75
+ # print(f'Data shape for UniPC sampling is {size}')
76
+
77
+ device = self.model.betas.device
78
+ if x_T is None:
79
+ img = torch.randn(size, device=device)
80
+ else:
81
+ img = x_T
82
+
83
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
84
+
85
+ # SD 1.X is "noise", SD 2.X is "v"
86
+ model_type = "v" if self.model.parameterization == "v" else "noise"
87
+
88
+ model_fn = model_wrapper(
89
+ lambda x, t, c: self.model.apply_model(x, t, c),
90
+ ns,
91
+ model_type=model_type,
92
+ guidance_type="classifier-free",
93
+ #condition=conditioning,
94
+ #unconditional_condition=unconditional_conditioning,
95
+ guidance_scale=unconditional_guidance_scale,
96
+ )
97
+
98
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
99
+ x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
100
+
101
+ return x.to(device), None
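
For context, a hedged sketch of how a caller might drive UniPCSampler.sample; the names sd_model, cond and uncond, and the 4x64x64 latent shape for a 512x512 SD 1.x image, are assumptions rather than code from this commit.

# assumed: sd_model is a loaded LatentDiffusion model; cond/uncond come from its text encoder
sampler = UniPCSampler(sd_model)
sampler.set_hooks(before_sample=None, after_sample=None, after_update=None)  # webui installs callbacks here
samples, _ = sampler.sample(
    S=20,                               # number of UniPC steps
    batch_size=1,
    shape=(4, 64, 64),                  # latent channels, height, width
    conditioning=cond,                  # prompt conditioning
    unconditional_conditioning=uncond,  # empty-prompt conditioning for CFG
    unconditional_guidance_scale=7.0,
    x_T=None,                           # let the sampler draw the initial noise
)
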
modules/models/diffusion/uni_pc/uni_pc.py ADDED
@@ -0,0 +1,863 @@
1
+ import torch
2
+ import math
3
+ import tqdm
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
19
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention to the `alphas_cumprod` arg:
52
+ The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ['discrete', 'linear', 'cosine']:
96
+ raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
97
+
98
+ self.schedule = schedule
99
+ if schedule == 'discrete':
100
+ if betas is not None:
101
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
102
+ else:
103
+ assert alphas_cumprod is not None
104
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
105
+ self.total_N = len(log_alphas)
106
+ self.T = 1.
107
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
108
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
109
+ else:
110
+ self.total_N = 1000
111
+ self.beta_0 = continuous_beta_0
112
+ self.beta_1 = continuous_beta_1
113
+ self.cosine_s = 0.008
114
+ self.cosine_beta_max = 999.
115
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
116
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
117
+ self.schedule = schedule
118
+ if schedule == 'cosine':
119
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
120
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
121
+ self.T = 0.9946
122
+ else:
123
+ self.T = 1.
124
+
125
+ def marginal_log_mean_coeff(self, t):
126
+ """
127
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
+ """
129
+ if self.schedule == 'discrete':
130
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
131
+ elif self.schedule == 'linear':
132
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
+ elif self.schedule == 'cosine':
134
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
135
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
136
+ return log_alpha_t
137
+
138
+ def marginal_alpha(self, t):
139
+ """
140
+ Compute alpha_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.exp(self.marginal_log_mean_coeff(t))
143
+
144
+ def marginal_std(self, t):
145
+ """
146
+ Compute sigma_t of a given continuous-time label t in [0, T].
147
+ """
148
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
149
+
150
+ def marginal_lambda(self, t):
151
+ """
152
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
153
+ """
154
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
155
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
156
+ return log_mean_coeff - log_std
157
+
158
+ def inverse_lambda(self, lamb):
159
+ """
160
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
161
+ """
162
+ if self.schedule == 'linear':
163
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
+ Delta = self.beta_0**2 + tmp
165
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
+ elif self.schedule == 'discrete':
167
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
169
+ return t.reshape((-1,))
170
+ else:
171
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
172
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
173
+ t = t_fn(log_alpha)
174
+ return t
175
+
176
+
177
+ def model_wrapper(
178
+ model,
179
+ noise_schedule,
180
+ model_type="noise",
181
+ model_kwargs=None,
182
+ guidance_type="uncond",
183
+ #condition=None,
184
+ #unconditional_condition=None,
185
+ guidance_scale=1.,
186
+ classifier_fn=None,
187
+ classifier_kwargs=None,
188
+ ):
189
+ """Create a wrapper function for the noise prediction model.
190
+
191
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
192
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
193
+
194
+ We support four types of the diffusion model by setting `model_type`:
195
+
196
+ 1. "noise": noise prediction model. (Trained by predicting noise).
197
+
198
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
199
+
200
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
201
+ The derivation of the "v" prediction is detailed in Appendix D of [1], and it is used in Imagen-Video [2].
202
+
203
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
204
+ arXiv preprint arXiv:2202.00512 (2022).
205
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
206
+ arXiv preprint arXiv:2210.02303 (2022).
207
+
208
+ 4. "score": marginal score function. (Trained by denoising score matching).
209
+ Note that the score function and the noise prediction model follows a simple relationship:
210
+ ```
211
+ noise(x_t, t) = -sigma_t * score(x_t, t)
212
+ ```
213
+
214
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
215
+ 1. "uncond": unconditional sampling by DPMs.
216
+ The input `model` has the following format:
217
+ ``
218
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
219
+ ``
220
+
221
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
222
+ The input `model` has the following format:
223
+ ``
224
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
225
+ ``
226
+
227
+ The input `classifier_fn` has the following format:
228
+ ``
229
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
230
+ ``
231
+
232
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
233
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
234
+
235
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
236
+ The input `model` has the following format:
237
+ ``
238
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
239
+ ``
240
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
241
+
242
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
243
+ arXiv preprint arXiv:2207.12598 (2022).
244
+
245
+
246
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
247
+ or continuous-time labels (i.e. epsilon to T).
248
+
249
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
250
+ ``
251
+ def model_fn(x, t_continuous) -> noise:
252
+ t_input = get_model_input_time(t_continuous)
253
+ return noise_pred(model, x, t_input, **model_kwargs)
254
+ ``
255
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
256
+
257
+ ===============================================================
258
+
259
+ Args:
260
+ model: A diffusion model with the corresponding format described above.
261
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
262
+ model_type: A `str`. The parameterization type of the diffusion model.
263
+ "noise" or "x_start" or "v" or "score".
264
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
265
+ guidance_type: A `str`. The type of the guidance for sampling.
266
+ "uncond" or "classifier" or "classifier-free".
267
+ condition: A pytorch tensor. The condition for the guided sampling.
268
+ Only used for "classifier" or "classifier-free" guidance type.
269
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
270
+ Only used for "classifier-free" guidance type.
271
+ guidance_scale: A `float`. The scale for the guided sampling.
272
+ classifier_fn: A classifier function. Only used for the classifier guidance.
273
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
274
+ Returns:
275
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
276
+ """
277
+
278
+ model_kwargs = model_kwargs or {}
279
+ classifier_kwargs = classifier_kwargs or {}
280
+
281
+ def get_model_input_time(t_continuous):
282
+ """
283
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
284
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
285
+ For continuous-time DPMs, we just use `t_continuous`.
286
+ """
287
+ if noise_schedule.schedule == 'discrete':
288
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
289
+ else:
290
+ return t_continuous
291
+
292
+ def noise_pred_fn(x, t_continuous, cond=None):
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ t_input = get_model_input_time(t_continuous)
296
+ if cond is None:
297
+ output = model(x, t_input, None, **model_kwargs)
298
+ else:
299
+ output = model(x, t_input, cond, **model_kwargs)
300
+ if model_type == "noise":
301
+ return output
302
+ elif model_type == "x_start":
303
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
304
+ dims = x.dim()
305
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
306
+ elif model_type == "v":
307
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
308
+ dims = x.dim()
309
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
310
+ elif model_type == "score":
311
+ sigma_t = noise_schedule.marginal_std(t_continuous)
312
+ dims = x.dim()
313
+ return -expand_dims(sigma_t, dims) * output
314
+
315
+ def cond_grad_fn(x, t_input, condition):
316
+ """
317
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
318
+ """
319
+ with torch.enable_grad():
320
+ x_in = x.detach().requires_grad_(True)
321
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
322
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
323
+
324
+ def model_fn(x, t_continuous, condition, unconditional_condition):
325
+ """
326
+ The noise prediction model function that is used for DPM-Solver.
327
+ """
328
+ if t_continuous.reshape((-1,)).shape[0] == 1:
329
+ t_continuous = t_continuous.expand((x.shape[0]))
330
+ if guidance_type == "uncond":
331
+ return noise_pred_fn(x, t_continuous)
332
+ elif guidance_type == "classifier":
333
+ assert classifier_fn is not None
334
+ t_input = get_model_input_time(t_continuous)
335
+ cond_grad = cond_grad_fn(x, t_input, condition)
336
+ sigma_t = noise_schedule.marginal_std(t_continuous)
337
+ noise = noise_pred_fn(x, t_continuous)
338
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
339
+ elif guidance_type == "classifier-free":
340
+ if guidance_scale == 1. or unconditional_condition is None:
341
+ return noise_pred_fn(x, t_continuous, cond=condition)
342
+ else:
343
+ x_in = torch.cat([x] * 2)
344
+ t_in = torch.cat([t_continuous] * 2)
345
+ if isinstance(condition, dict):
346
+ assert isinstance(unconditional_condition, dict)
347
+ c_in = {}
348
+ for k in condition:
349
+ if isinstance(condition[k], list):
350
+ c_in[k] = [torch.cat([
351
+ unconditional_condition[k][i],
352
+ condition[k][i]]) for i in range(len(condition[k]))]
353
+ else:
354
+ c_in[k] = torch.cat([
355
+ unconditional_condition[k],
356
+ condition[k]])
357
+ elif isinstance(condition, list):
358
+ c_in = []
359
+ assert isinstance(unconditional_condition, list)
360
+ for i in range(len(condition)):
361
+ c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
362
+ else:
363
+ c_in = torch.cat([unconditional_condition, condition])
364
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
365
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
366
+
367
+ assert model_type in ["noise", "x_start", "v"]
368
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
369
+ return model_fn
370
+
371
+
372
+ class UniPC:
373
+ def __init__(
374
+ self,
375
+ model_fn,
376
+ noise_schedule,
377
+ predict_x0=True,
378
+ thresholding=False,
379
+ max_val=1.,
380
+ variant='bh1',
381
+ condition=None,
382
+ unconditional_condition=None,
383
+ before_sample=None,
384
+ after_sample=None,
385
+ after_update=None
386
+ ):
387
+ """Construct a UniPC.
388
+
389
+ We support both data_prediction and noise_prediction.
390
+ """
391
+ self.model_fn_ = model_fn
392
+ self.noise_schedule = noise_schedule
393
+ self.variant = variant
394
+ self.predict_x0 = predict_x0
395
+ self.thresholding = thresholding
396
+ self.max_val = max_val
397
+ self.condition = condition
398
+ self.unconditional_condition = unconditional_condition
399
+ self.before_sample = before_sample
400
+ self.after_sample = after_sample
401
+ self.after_update = after_update
402
+
403
+ def dynamic_thresholding_fn(self, x0, t=None):
404
+ """
405
+ The dynamic thresholding method.
406
+ """
407
+ dims = x0.dim()
408
+ p = self.dynamic_thresholding_ratio
409
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
410
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
411
+ x0 = torch.clamp(x0, -s, s) / s
412
+ return x0
413
+
414
+ def model(self, x, t):
415
+ cond = self.condition
416
+ uncond = self.unconditional_condition
417
+ if self.before_sample is not None:
418
+ x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
419
+ res = self.model_fn_(x, t, cond, uncond)
420
+ if self.after_sample is not None:
421
+ x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
422
+
423
+ if isinstance(res, tuple):
424
+ # (None, pred_x0)
425
+ res = res[1]
426
+
427
+ return res
428
+
429
+ def noise_prediction_fn(self, x, t):
430
+ """
431
+ Return the noise prediction model.
432
+ """
433
+ return self.model(x, t)
434
+
435
+ def data_prediction_fn(self, x, t):
436
+ """
437
+ Return the data prediction model (with thresholding).
438
+ """
439
+ noise = self.noise_prediction_fn(x, t)
440
+ dims = x.dim()
441
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
442
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
443
+ if self.thresholding:
444
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
445
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
446
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
447
+ x0 = torch.clamp(x0, -s, s) / s
448
+ return x0
449
+
450
+ def model_fn(self, x, t):
451
+ """
452
+ Convert the model to the noise prediction model or the data prediction model.
453
+ """
454
+ if self.predict_x0:
455
+ return self.data_prediction_fn(x, t)
456
+ else:
457
+ return self.noise_prediction_fn(x, t)
458
+
459
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
460
+ """Compute the intermediate time steps for sampling.
461
+ """
462
+ if skip_type == 'logSNR':
463
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
464
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
465
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
466
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
467
+ elif skip_type == 'time_uniform':
468
+ return torch.linspace(t_T, t_0, N + 1).to(device)
469
+ elif skip_type == 'time_quadratic':
470
+ t_order = 2
471
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
472
+ return t
473
+ else:
474
+ raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
475
+
476
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
477
+ """
478
+ Get the order of each step for sampling by the singlestep DPM-Solver.
479
+ """
480
+ if order == 3:
481
+ K = steps // 3 + 1
482
+ if steps % 3 == 0:
483
+ orders = [3,] * (K - 2) + [2, 1]
484
+ elif steps % 3 == 1:
485
+ orders = [3,] * (K - 1) + [1]
486
+ else:
487
+ orders = [3,] * (K - 1) + [2]
488
+ elif order == 2:
489
+ if steps % 2 == 0:
490
+ K = steps // 2
491
+ orders = [2,] * K
492
+ else:
493
+ K = steps // 2 + 1
494
+ orders = [2,] * (K - 1) + [1]
495
+ elif order == 1:
496
+ K = steps
497
+ orders = [1,] * steps
498
+ else:
499
+ raise ValueError("'order' must be '1' or '2' or '3'.")
500
+ if skip_type == 'logSNR':
501
+ # To reproduce the results in DPM-Solver paper
502
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
503
+ else:
504
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
505
+ return timesteps_outer, orders
506
+
507
+ def denoise_to_zero_fn(self, x, s):
508
+ """
509
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
510
+ """
511
+ return self.data_prediction_fn(x, s)
512
+
513
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
514
+ if len(t.shape) == 0:
515
+ t = t.view(-1)
516
+ if 'bh' in self.variant:
517
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
518
+ else:
519
+ assert self.variant == 'vary_coeff'
520
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
521
+
522
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
523
+ #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
524
+ ns = self.noise_schedule
525
+ assert order <= len(model_prev_list)
526
+
527
+ # first compute rks
528
+ t_prev_0 = t_prev_list[-1]
529
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
530
+ lambda_t = ns.marginal_lambda(t)
531
+ model_prev_0 = model_prev_list[-1]
532
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
533
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
534
+ alpha_t = torch.exp(log_alpha_t)
535
+
536
+ h = lambda_t - lambda_prev_0
537
+
538
+ rks = []
539
+ D1s = []
540
+ for i in range(1, order):
541
+ t_prev_i = t_prev_list[-(i + 1)]
542
+ model_prev_i = model_prev_list[-(i + 1)]
543
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
544
+ rk = (lambda_prev_i - lambda_prev_0) / h
545
+ rks.append(rk)
546
+ D1s.append((model_prev_i - model_prev_0) / rk)
547
+
548
+ rks.append(1.)
549
+ rks = torch.tensor(rks, device=x.device)
550
+
551
+ K = len(rks)
552
+ # build C matrix
553
+ C = []
554
+
555
+ col = torch.ones_like(rks)
556
+ for k in range(1, K + 1):
557
+ C.append(col)
558
+ col = col * rks / (k + 1)
559
+ C = torch.stack(C, dim=1)
560
+
561
+ if len(D1s) > 0:
562
+ D1s = torch.stack(D1s, dim=1) # (B, K)
563
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
564
+ A_p = C_inv_p
565
+
566
+ if use_corrector:
567
+ #print('using corrector')
568
+ C_inv = torch.linalg.inv(C)
569
+ A_c = C_inv
570
+
571
+ hh = -h if self.predict_x0 else h
572
+ h_phi_1 = torch.expm1(hh)
573
+ h_phi_ks = []
574
+ factorial_k = 1
575
+ h_phi_k = h_phi_1
576
+ for k in range(1, K + 2):
577
+ h_phi_ks.append(h_phi_k)
578
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
579
+ factorial_k *= (k + 1)
580
+
581
+ model_t = None
582
+ if self.predict_x0:
583
+ x_t_ = (
584
+ sigma_t / sigma_prev_0 * x
585
+ - alpha_t * h_phi_1 * model_prev_0
586
+ )
587
+ # now predictor
588
+ x_t = x_t_
589
+ if len(D1s) > 0:
590
+ # compute the residuals for predictor
591
+ for k in range(K - 1):
592
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
593
+ # now corrector
594
+ if use_corrector:
595
+ model_t = self.model_fn(x_t, t)
596
+ D1_t = (model_t - model_prev_0)
597
+ x_t = x_t_
598
+ k = 0
599
+ for k in range(K - 1):
600
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
601
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
602
+ else:
603
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
604
+ x_t_ = (
605
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
606
+ - (sigma_t * h_phi_1) * model_prev_0
607
+ )
608
+ # now predictor
609
+ x_t = x_t_
610
+ if len(D1s) > 0:
611
+ # compute the residuals for predictor
612
+ for k in range(K - 1):
613
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
614
+ # now corrector
615
+ if use_corrector:
616
+ model_t = self.model_fn(x_t, t)
617
+ D1_t = (model_t - model_prev_0)
618
+ x_t = x_t_
619
+ k = 0
620
+ for k in range(K - 1):
621
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
622
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
623
+ return x_t, model_t
624
+
625
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
626
+ #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
627
+ ns = self.noise_schedule
628
+ assert order <= len(model_prev_list)
629
+ dims = x.dim()
630
+
631
+ # first compute rks
632
+ t_prev_0 = t_prev_list[-1]
633
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
634
+ lambda_t = ns.marginal_lambda(t)
635
+ model_prev_0 = model_prev_list[-1]
636
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
637
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
638
+ alpha_t = torch.exp(log_alpha_t)
639
+
640
+ h = lambda_t - lambda_prev_0
641
+
642
+ rks = []
643
+ D1s = []
644
+ for i in range(1, order):
645
+ t_prev_i = t_prev_list[-(i + 1)]
646
+ model_prev_i = model_prev_list[-(i + 1)]
647
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
648
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
649
+ rks.append(rk)
650
+ D1s.append((model_prev_i - model_prev_0) / rk)
651
+
652
+ rks.append(1.)
653
+ rks = torch.tensor(rks, device=x.device)
654
+
655
+ R = []
656
+ b = []
657
+
658
+ hh = -h[0] if self.predict_x0 else h[0]
659
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
660
+ h_phi_k = h_phi_1 / hh - 1
661
+
662
+ factorial_i = 1
663
+
664
+ if self.variant == 'bh1':
665
+ B_h = hh
666
+ elif self.variant == 'bh2':
667
+ B_h = torch.expm1(hh)
668
+ else:
669
+ raise NotImplementedError()
670
+
671
+ for i in range(1, order + 1):
672
+ R.append(torch.pow(rks, i - 1))
673
+ b.append(h_phi_k * factorial_i / B_h)
674
+ factorial_i *= (i + 1)
675
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
676
+
677
+ R = torch.stack(R)
678
+ b = torch.tensor(b, device=x.device)
679
+
680
+ # now predictor
681
+ use_predictor = len(D1s) > 0 and x_t is None
682
+ if len(D1s) > 0:
683
+ D1s = torch.stack(D1s, dim=1) # (B, K)
684
+ if x_t is None:
685
+ # for order 2, we use a simplified version
686
+ if order == 2:
687
+ rhos_p = torch.tensor([0.5], device=b.device)
688
+ else:
689
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
690
+ else:
691
+ D1s = None
692
+
693
+ if use_corrector:
694
+ #print('using corrector')
695
+ # for order 1, we use a simplified version
696
+ if order == 1:
697
+ rhos_c = torch.tensor([0.5], device=b.device)
698
+ else:
699
+ rhos_c = torch.linalg.solve(R, b)
700
+
701
+ model_t = None
702
+ if self.predict_x0:
703
+ x_t_ = (
704
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
705
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
706
+ )
707
+
708
+ if x_t is None:
709
+ if use_predictor:
710
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
711
+ else:
712
+ pred_res = 0
713
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
714
+
715
+ if use_corrector:
716
+ model_t = self.model_fn(x_t, t)
717
+ if D1s is not None:
718
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
719
+ else:
720
+ corr_res = 0
721
+ D1_t = (model_t - model_prev_0)
722
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
723
+ else:
724
+ x_t_ = (
725
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
726
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
727
+ )
728
+ if x_t is None:
729
+ if use_predictor:
730
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
731
+ else:
732
+ pred_res = 0
733
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
734
+
735
+ if use_corrector:
736
+ model_t = self.model_fn(x_t, t)
737
+ if D1s is not None:
738
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
739
+ else:
740
+ corr_res = 0
741
+ D1_t = (model_t - model_prev_0)
742
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
743
+ return x_t, model_t
744
+
745
+
746
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
747
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
748
+ atol=0.0078, rtol=0.05, corrector=False,
749
+ ):
750
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
751
+ t_T = self.noise_schedule.T if t_start is None else t_start
752
+ device = x.device
753
+ if method == 'multistep':
754
+ assert steps >= order, "UniPC order must not exceed the number of sampling steps"
755
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
756
+ #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
757
+ assert timesteps.shape[0] - 1 == steps
758
+ with torch.no_grad():
759
+ vec_t = timesteps[0].expand((x.shape[0]))
760
+ model_prev_list = [self.model_fn(x, vec_t)]
761
+ t_prev_list = [vec_t]
762
+ with tqdm.tqdm(total=steps) as pbar:
763
+ # Init the first `order` values by lower order multistep DPM-Solver.
764
+ for init_order in range(1, order):
765
+ vec_t = timesteps[init_order].expand(x.shape[0])
766
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
767
+ if model_x is None:
768
+ model_x = self.model_fn(x, vec_t)
769
+ if self.after_update is not None:
770
+ self.after_update(x, model_x)
771
+ model_prev_list.append(model_x)
772
+ t_prev_list.append(vec_t)
773
+ pbar.update()
774
+
775
+ for step in range(order, steps + 1):
776
+ vec_t = timesteps[step].expand(x.shape[0])
777
+ if lower_order_final:
778
+ step_order = min(order, steps + 1 - step)
779
+ else:
780
+ step_order = order
781
+ #print('this step order:', step_order)
782
+ if step == steps:
783
+ #print('do not run corrector at the last step')
784
+ use_corrector = False
785
+ else:
786
+ use_corrector = True
787
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
788
+ if self.after_update is not None:
789
+ self.after_update(x, model_x)
790
+ for i in range(order - 1):
791
+ t_prev_list[i] = t_prev_list[i + 1]
792
+ model_prev_list[i] = model_prev_list[i + 1]
793
+ t_prev_list[-1] = vec_t
794
+ # We do not need to evaluate the final model value.
795
+ if step < steps:
796
+ if model_x is None:
797
+ model_x = self.model_fn(x, vec_t)
798
+ model_prev_list[-1] = model_x
799
+ pbar.update()
800
+ else:
801
+ raise NotImplementedError()
802
+ if denoise_to_zero:
803
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
804
+ return x
805
+
806
+
807
+ #############################################################
808
+ # other utility functions
809
+ #############################################################
810
+
811
+ def interpolate_fn(x, xp, yp):
812
+ """
813
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
814
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
815
+ The function f(x) is well-defined over the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
816
+
817
+ Args:
818
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
819
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
820
+ yp: PyTorch tensor with shape [C, K].
821
+ Returns:
822
+ The function values f(x), with shape [N, C].
823
+ """
824
+ N, K = x.shape[0], xp.shape[1]
825
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
826
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
827
+ x_idx = torch.argmin(x_indices, dim=2)
828
+ cand_start_idx = x_idx - 1
829
+ start_idx = torch.where(
830
+ torch.eq(x_idx, 0),
831
+ torch.tensor(1, device=x.device),
832
+ torch.where(
833
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
834
+ ),
835
+ )
836
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
837
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
838
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
839
+ start_idx2 = torch.where(
840
+ torch.eq(x_idx, 0),
841
+ torch.tensor(0, device=x.device),
842
+ torch.where(
843
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
844
+ ),
845
+ )
846
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
847
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
848
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
849
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
850
+ return cand
851
+
852
+
853
+ def expand_dims(v, dims):
854
+ """
855
+ Expand the tensor `v` to the dim `dims`.
856
+
857
+ Args:
858
+ `v`: a PyTorch tensor with shape [N].
859
+ `dims`: an `int`.
860
+ Returns:
861
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
862
+ """
863
+ return v[(...,) + (None,)*(dims - 1)]
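
As a sanity check on the schedule utilities above, a small self-contained sketch. It assumes a DDPM-style linear beta schedule and that NoiseScheduleVP and expand_dims are importable from this module; the numbers are illustrative and not part of this commit.

import torch
from modules.models.diffusion.uni_pc.uni_pc import NoiseScheduleVP, expand_dims

betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1. - betas, dim=0)
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

t = torch.tensor([0.5])
alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
# half-logSNR identity used throughout the solver: lambda_t = log(alpha_t) - log(sigma_t)
assert torch.allclose(ns.marginal_lambda(t), torch.log(alpha_t) - torch.log(sigma_t))

# expand_dims broadcasts a per-sample scalar over an image-shaped batch
v = torch.randn(8)
x = torch.randn(8, 4, 64, 64)
scaled = expand_dims(v, x.dim()) * x   # v is viewed as shape (8, 1, 1, 1)
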