Spaces:
Sleeping
Sleeping
Add persistent folders, scripts, and small models
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- launch.py +48 -0
- modules/.DS_Store +0 -0
- modules/api/api.py +928 -0
- modules/api/models.py +329 -0
- modules/cache.py +123 -0
- modules/call_queue.py +134 -0
- modules/cmd_args.py +128 -0
- modules/codeformer_model.py +64 -0
- modules/config_states.py +198 -0
- modules/dat_model.py +79 -0
- modules/deepbooru.py +98 -0
- modules/deepbooru_model.py +678 -0
- modules/devices.py +295 -0
- modules/errors.py +150 -0
- modules/esrgan_model.py +62 -0
- modules/extensions.py +299 -0
- modules/extra_networks.py +225 -0
- modules/extra_networks_hypernet.py +28 -0
- modules/extras.py +330 -0
- modules/face_restoration.py +19 -0
- modules/face_restoration_utils.py +180 -0
- modules/fifo_lock.py +37 -0
- modules/gfpgan_model.py +69 -0
- modules/gitpython_hack.py +42 -0
- modules/gradio_extensons.py +83 -0
- modules/hashes.py +84 -0
- modules/hat_model.py +43 -0
- modules/hypernetworks/hypernetwork.py +783 -0
- modules/hypernetworks/ui.py +38 -0
- modules/images.py +877 -0
- modules/img2img.py +253 -0
- modules/import_hook.py +16 -0
- modules/infotext_utils.py +546 -0
- modules/infotext_versions.py +46 -0
- modules/initialize.py +169 -0
- modules/initialize_util.py +215 -0
- modules/interrogate.py +222 -0
- modules/launch_utils.py +482 -0
- modules/localization.py +37 -0
- modules/logging_config.py +58 -0
- modules/lowvram.py +165 -0
- modules/mac_specific.py +98 -0
- modules/masking.py +96 -0
- modules/memmon.py +92 -0
- modules/modelloader.py +197 -0
- modules/models/diffusion/ddpm_edit.py +1460 -0
- modules/models/diffusion/uni_pc/__init__.py +1 -0
- modules/models/diffusion/uni_pc/sampler.py +101 -0
- modules/models/diffusion/uni_pc/uni_pc.py +863 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
launch.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import launch_utils
|
| 2 |
+
|
| 3 |
+
args = launch_utils.args
|
| 4 |
+
python = launch_utils.python
|
| 5 |
+
git = launch_utils.git
|
| 6 |
+
index_url = launch_utils.index_url
|
| 7 |
+
dir_repos = launch_utils.dir_repos
|
| 8 |
+
|
| 9 |
+
commit_hash = launch_utils.commit_hash
|
| 10 |
+
git_tag = launch_utils.git_tag
|
| 11 |
+
|
| 12 |
+
run = launch_utils.run
|
| 13 |
+
is_installed = launch_utils.is_installed
|
| 14 |
+
repo_dir = launch_utils.repo_dir
|
| 15 |
+
|
| 16 |
+
run_pip = launch_utils.run_pip
|
| 17 |
+
check_run_python = launch_utils.check_run_python
|
| 18 |
+
git_clone = launch_utils.git_clone
|
| 19 |
+
git_pull_recursive = launch_utils.git_pull_recursive
|
| 20 |
+
list_extensions = launch_utils.list_extensions
|
| 21 |
+
run_extension_installer = launch_utils.run_extension_installer
|
| 22 |
+
prepare_environment = launch_utils.prepare_environment
|
| 23 |
+
configure_for_tests = launch_utils.configure_for_tests
|
| 24 |
+
start = launch_utils.start
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
if args.dump_sysinfo:
|
| 29 |
+
filename = launch_utils.dump_sysinfo()
|
| 30 |
+
|
| 31 |
+
print(f"Sysinfo saved as {filename}. Exiting...")
|
| 32 |
+
|
| 33 |
+
exit(0)
|
| 34 |
+
|
| 35 |
+
launch_utils.startup_timer.record("initial startup")
|
| 36 |
+
|
| 37 |
+
with launch_utils.startup_timer.subcategory("prepare environment"):
|
| 38 |
+
if not args.skip_prepare_environment:
|
| 39 |
+
prepare_environment()
|
| 40 |
+
|
| 41 |
+
if args.test_server:
|
| 42 |
+
configure_for_tests()
|
| 43 |
+
|
| 44 |
+
start()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
|
| 48 |
+
main()
|
modules/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
modules/api/api.py
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import datetime
|
| 6 |
+
import uvicorn
|
| 7 |
+
import ipaddress
|
| 8 |
+
import requests
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from threading import Lock
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
| 13 |
+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 14 |
+
from fastapi.exceptions import HTTPException
|
| 15 |
+
from fastapi.responses import JSONResponse
|
| 16 |
+
from fastapi.encoders import jsonable_encoder
|
| 17 |
+
from secrets import compare_digest
|
| 18 |
+
|
| 19 |
+
import modules.shared as shared
|
| 20 |
+
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
|
| 21 |
+
from modules.api import models
|
| 22 |
+
from modules.shared import opts
|
| 23 |
+
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
| 24 |
+
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
| 25 |
+
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
| 26 |
+
from PIL import PngImagePlugin
|
| 27 |
+
from modules.sd_models_config import find_checkpoint_config_near_filename
|
| 28 |
+
from modules.realesrgan_model import get_realesrgan_models
|
| 29 |
+
from modules import devices
|
| 30 |
+
from typing import Any
|
| 31 |
+
import piexif
|
| 32 |
+
import piexif.helper
|
| 33 |
+
from contextlib import closing
|
| 34 |
+
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
|
| 35 |
+
|
| 36 |
+
def script_name_to_index(name, scripts):
|
| 37 |
+
try:
|
| 38 |
+
return [script.title().lower() for script in scripts].index(name.lower())
|
| 39 |
+
except Exception as e:
|
| 40 |
+
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def validate_sampler_name(name):
|
| 44 |
+
config = sd_samplers.all_samplers_map.get(name, None)
|
| 45 |
+
if config is None:
|
| 46 |
+
raise HTTPException(status_code=400, detail="Sampler not found")
|
| 47 |
+
|
| 48 |
+
return name
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def setUpscalers(req: dict):
|
| 52 |
+
reqDict = vars(req)
|
| 53 |
+
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
| 54 |
+
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
| 55 |
+
return reqDict
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def verify_url(url):
|
| 59 |
+
"""Returns True if the url refers to a global resource."""
|
| 60 |
+
|
| 61 |
+
import socket
|
| 62 |
+
from urllib.parse import urlparse
|
| 63 |
+
try:
|
| 64 |
+
parsed_url = urlparse(url)
|
| 65 |
+
domain_name = parsed_url.netloc
|
| 66 |
+
host = socket.gethostbyname_ex(domain_name)
|
| 67 |
+
for ip in host[2]:
|
| 68 |
+
ip_addr = ipaddress.ip_address(ip)
|
| 69 |
+
if not ip_addr.is_global:
|
| 70 |
+
return False
|
| 71 |
+
except Exception:
|
| 72 |
+
return False
|
| 73 |
+
|
| 74 |
+
return True
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def decode_base64_to_image(encoding):
|
| 78 |
+
if encoding.startswith("http://") or encoding.startswith("https://"):
|
| 79 |
+
if not opts.api_enable_requests:
|
| 80 |
+
raise HTTPException(status_code=500, detail="Requests not allowed")
|
| 81 |
+
|
| 82 |
+
if opts.api_forbid_local_requests and not verify_url(encoding):
|
| 83 |
+
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
|
| 84 |
+
|
| 85 |
+
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
| 86 |
+
response = requests.get(encoding, timeout=30, headers=headers)
|
| 87 |
+
try:
|
| 88 |
+
image = images.read(BytesIO(response.content))
|
| 89 |
+
return image
|
| 90 |
+
except Exception as e:
|
| 91 |
+
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
| 92 |
+
|
| 93 |
+
if encoding.startswith("data:image/"):
|
| 94 |
+
encoding = encoding.split(";")[1].split(",")[1]
|
| 95 |
+
try:
|
| 96 |
+
image = images.read(BytesIO(base64.b64decode(encoding)))
|
| 97 |
+
return image
|
| 98 |
+
except Exception as e:
|
| 99 |
+
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def encode_pil_to_base64(image):
|
| 103 |
+
with io.BytesIO() as output_bytes:
|
| 104 |
+
if isinstance(image, str):
|
| 105 |
+
return image
|
| 106 |
+
if opts.samples_format.lower() == 'png':
|
| 107 |
+
use_metadata = False
|
| 108 |
+
metadata = PngImagePlugin.PngInfo()
|
| 109 |
+
for key, value in image.info.items():
|
| 110 |
+
if isinstance(key, str) and isinstance(value, str):
|
| 111 |
+
metadata.add_text(key, value)
|
| 112 |
+
use_metadata = True
|
| 113 |
+
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
| 114 |
+
|
| 115 |
+
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
| 116 |
+
if image.mode in ("RGBA", "P"):
|
| 117 |
+
image = image.convert("RGB")
|
| 118 |
+
parameters = image.info.get('parameters', None)
|
| 119 |
+
exif_bytes = piexif.dump({
|
| 120 |
+
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
| 121 |
+
})
|
| 122 |
+
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
| 123 |
+
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
| 124 |
+
else:
|
| 125 |
+
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
| 126 |
+
|
| 127 |
+
else:
|
| 128 |
+
raise HTTPException(status_code=500, detail="Invalid image format")
|
| 129 |
+
|
| 130 |
+
bytes_data = output_bytes.getvalue()
|
| 131 |
+
|
| 132 |
+
return base64.b64encode(bytes_data)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def api_middleware(app: FastAPI):
|
| 136 |
+
rich_available = False
|
| 137 |
+
try:
|
| 138 |
+
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
|
| 139 |
+
import anyio # importing just so it can be placed on silent list
|
| 140 |
+
import starlette # importing just so it can be placed on silent list
|
| 141 |
+
from rich.console import Console
|
| 142 |
+
console = Console()
|
| 143 |
+
rich_available = True
|
| 144 |
+
except Exception:
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
@app.middleware("http")
|
| 148 |
+
async def log_and_time(req: Request, call_next):
|
| 149 |
+
ts = time.time()
|
| 150 |
+
res: Response = await call_next(req)
|
| 151 |
+
duration = str(round(time.time() - ts, 4))
|
| 152 |
+
res.headers["X-Process-Time"] = duration
|
| 153 |
+
endpoint = req.scope.get('path', 'err')
|
| 154 |
+
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
| 155 |
+
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
| 156 |
+
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
| 157 |
+
code=res.status_code,
|
| 158 |
+
ver=req.scope.get('http_version', '0.0'),
|
| 159 |
+
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
|
| 160 |
+
prot=req.scope.get('scheme', 'err'),
|
| 161 |
+
method=req.scope.get('method', 'err'),
|
| 162 |
+
endpoint=endpoint,
|
| 163 |
+
duration=duration,
|
| 164 |
+
))
|
| 165 |
+
return res
|
| 166 |
+
|
| 167 |
+
def handle_exception(request: Request, e: Exception):
|
| 168 |
+
err = {
|
| 169 |
+
"error": type(e).__name__,
|
| 170 |
+
"detail": vars(e).get('detail', ''),
|
| 171 |
+
"body": vars(e).get('body', ''),
|
| 172 |
+
"errors": str(e),
|
| 173 |
+
}
|
| 174 |
+
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
| 175 |
+
message = f"API error: {request.method}: {request.url} {err}"
|
| 176 |
+
if rich_available:
|
| 177 |
+
print(message)
|
| 178 |
+
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
| 179 |
+
else:
|
| 180 |
+
errors.report(message, exc_info=True)
|
| 181 |
+
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
| 182 |
+
|
| 183 |
+
@app.middleware("http")
|
| 184 |
+
async def exception_handling(request: Request, call_next):
|
| 185 |
+
try:
|
| 186 |
+
return await call_next(request)
|
| 187 |
+
except Exception as e:
|
| 188 |
+
return handle_exception(request, e)
|
| 189 |
+
|
| 190 |
+
@app.exception_handler(Exception)
|
| 191 |
+
async def fastapi_exception_handler(request: Request, e: Exception):
|
| 192 |
+
return handle_exception(request, e)
|
| 193 |
+
|
| 194 |
+
@app.exception_handler(HTTPException)
|
| 195 |
+
async def http_exception_handler(request: Request, e: HTTPException):
|
| 196 |
+
return handle_exception(request, e)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class Api:
|
| 200 |
+
def __init__(self, app: FastAPI, queue_lock: Lock):
|
| 201 |
+
if shared.cmd_opts.api_auth:
|
| 202 |
+
self.credentials = {}
|
| 203 |
+
for auth in shared.cmd_opts.api_auth.split(","):
|
| 204 |
+
user, password = auth.split(":")
|
| 205 |
+
self.credentials[user] = password
|
| 206 |
+
|
| 207 |
+
self.router = APIRouter()
|
| 208 |
+
self.app = app
|
| 209 |
+
self.queue_lock = queue_lock
|
| 210 |
+
api_middleware(self.app)
|
| 211 |
+
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
| 212 |
+
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
| 213 |
+
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
| 214 |
+
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
| 215 |
+
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
| 216 |
+
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
| 217 |
+
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
| 218 |
+
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
| 219 |
+
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
| 220 |
+
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
| 221 |
+
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
| 222 |
+
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
| 223 |
+
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
| 224 |
+
self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
|
| 225 |
+
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
| 226 |
+
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
| 227 |
+
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
| 228 |
+
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
|
| 229 |
+
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
| 230 |
+
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
| 231 |
+
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
| 232 |
+
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
| 233 |
+
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
| 234 |
+
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
|
| 235 |
+
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
| 236 |
+
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
| 237 |
+
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
| 238 |
+
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
| 239 |
+
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
| 240 |
+
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
| 241 |
+
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
| 242 |
+
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
| 243 |
+
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
| 244 |
+
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
| 245 |
+
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
| 246 |
+
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
| 247 |
+
|
| 248 |
+
if shared.cmd_opts.api_server_stop:
|
| 249 |
+
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
| 250 |
+
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
|
| 251 |
+
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
|
| 252 |
+
|
| 253 |
+
self.default_script_arg_txt2img = []
|
| 254 |
+
self.default_script_arg_img2img = []
|
| 255 |
+
|
| 256 |
+
txt2img_script_runner = scripts.scripts_txt2img
|
| 257 |
+
img2img_script_runner = scripts.scripts_img2img
|
| 258 |
+
|
| 259 |
+
if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
|
| 260 |
+
ui.create_ui()
|
| 261 |
+
|
| 262 |
+
if not txt2img_script_runner.scripts:
|
| 263 |
+
txt2img_script_runner.initialize_scripts(False)
|
| 264 |
+
if not self.default_script_arg_txt2img:
|
| 265 |
+
self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
|
| 266 |
+
|
| 267 |
+
if not img2img_script_runner.scripts:
|
| 268 |
+
img2img_script_runner.initialize_scripts(True)
|
| 269 |
+
if not self.default_script_arg_img2img:
|
| 270 |
+
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def add_api_route(self, path: str, endpoint, **kwargs):
|
| 275 |
+
if shared.cmd_opts.api_auth:
|
| 276 |
+
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
| 277 |
+
return self.app.add_api_route(path, endpoint, **kwargs)
|
| 278 |
+
|
| 279 |
+
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
|
| 280 |
+
if credentials.username in self.credentials:
|
| 281 |
+
if compare_digest(credentials.password, self.credentials[credentials.username]):
|
| 282 |
+
return True
|
| 283 |
+
|
| 284 |
+
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
| 285 |
+
|
| 286 |
+
def get_selectable_script(self, script_name, script_runner):
|
| 287 |
+
if script_name is None or script_name == "":
|
| 288 |
+
return None, None
|
| 289 |
+
|
| 290 |
+
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
| 291 |
+
script = script_runner.selectable_scripts[script_idx]
|
| 292 |
+
return script, script_idx
|
| 293 |
+
|
| 294 |
+
def get_scripts_list(self):
|
| 295 |
+
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
| 296 |
+
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
| 297 |
+
|
| 298 |
+
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
| 299 |
+
|
| 300 |
+
def get_script_info(self):
|
| 301 |
+
res = []
|
| 302 |
+
|
| 303 |
+
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
| 304 |
+
res += [script.api_info for script in script_list if script.api_info is not None]
|
| 305 |
+
|
| 306 |
+
return res
|
| 307 |
+
|
| 308 |
+
def get_script(self, script_name, script_runner):
|
| 309 |
+
if script_name is None or script_name == "":
|
| 310 |
+
return None, None
|
| 311 |
+
|
| 312 |
+
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
| 313 |
+
return script_runner.scripts[script_idx]
|
| 314 |
+
|
| 315 |
+
def init_default_script_args(self, script_runner):
|
| 316 |
+
#find max idx from the scripts in runner and generate a none array to init script_args
|
| 317 |
+
last_arg_index = 1
|
| 318 |
+
for script in script_runner.scripts:
|
| 319 |
+
if last_arg_index < script.args_to:
|
| 320 |
+
last_arg_index = script.args_to
|
| 321 |
+
# None everywhere except position 0 to initialize script args
|
| 322 |
+
script_args = [None]*last_arg_index
|
| 323 |
+
script_args[0] = 0
|
| 324 |
+
|
| 325 |
+
# get default values
|
| 326 |
+
with gr.Blocks(): # will throw errors calling ui function without this
|
| 327 |
+
for script in script_runner.scripts:
|
| 328 |
+
if script.ui(script.is_img2img):
|
| 329 |
+
ui_default_values = []
|
| 330 |
+
for elem in script.ui(script.is_img2img):
|
| 331 |
+
ui_default_values.append(elem.value)
|
| 332 |
+
script_args[script.args_from:script.args_to] = ui_default_values
|
| 333 |
+
return script_args
|
| 334 |
+
|
| 335 |
+
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
|
| 336 |
+
script_args = default_script_args.copy()
|
| 337 |
+
|
| 338 |
+
if input_script_args is not None:
|
| 339 |
+
for index, value in input_script_args.items():
|
| 340 |
+
script_args[index] = value
|
| 341 |
+
|
| 342 |
+
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
|
| 343 |
+
if selectable_scripts:
|
| 344 |
+
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
|
| 345 |
+
script_args[0] = selectable_idx + 1
|
| 346 |
+
|
| 347 |
+
# Now check for always on scripts
|
| 348 |
+
if request.alwayson_scripts:
|
| 349 |
+
for alwayson_script_name in request.alwayson_scripts.keys():
|
| 350 |
+
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
| 351 |
+
if alwayson_script is None:
|
| 352 |
+
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
| 353 |
+
# Selectable script in always on script param check
|
| 354 |
+
if alwayson_script.alwayson is False:
|
| 355 |
+
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
| 356 |
+
# always on script with no arg should always run so you don't really need to add them to the requests
|
| 357 |
+
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
| 358 |
+
# min between arg length in scriptrunner and arg length in the request
|
| 359 |
+
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
| 360 |
+
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
| 361 |
+
return script_args
|
| 362 |
+
|
| 363 |
+
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
| 364 |
+
"""Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
|
| 365 |
+
|
| 366 |
+
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
| 367 |
+
|
| 368 |
+
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
if not request.infotext:
|
| 372 |
+
return {}
|
| 373 |
+
|
| 374 |
+
possible_fields = infotext_utils.paste_fields[tabname]["fields"]
|
| 375 |
+
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
|
| 376 |
+
params = infotext_utils.parse_generation_parameters(request.infotext)
|
| 377 |
+
|
| 378 |
+
def get_field_value(field, params):
|
| 379 |
+
value = field.function(params) if field.function else params.get(field.label)
|
| 380 |
+
if value is None:
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
if field.api in request.__fields__:
|
| 384 |
+
target_type = request.__fields__[field.api].type_
|
| 385 |
+
else:
|
| 386 |
+
target_type = type(field.component.value)
|
| 387 |
+
|
| 388 |
+
if target_type == type(None):
|
| 389 |
+
return None
|
| 390 |
+
|
| 391 |
+
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
|
| 392 |
+
value = value.get('value')
|
| 393 |
+
|
| 394 |
+
if value is not None and not isinstance(value, target_type):
|
| 395 |
+
value = target_type(value)
|
| 396 |
+
|
| 397 |
+
return value
|
| 398 |
+
|
| 399 |
+
for field in possible_fields:
|
| 400 |
+
if not field.api:
|
| 401 |
+
continue
|
| 402 |
+
|
| 403 |
+
if field.api in set_fields:
|
| 404 |
+
continue
|
| 405 |
+
|
| 406 |
+
value = get_field_value(field, params)
|
| 407 |
+
if value is not None:
|
| 408 |
+
setattr(request, field.api, value)
|
| 409 |
+
|
| 410 |
+
if request.override_settings is None:
|
| 411 |
+
request.override_settings = {}
|
| 412 |
+
|
| 413 |
+
overridden_settings = infotext_utils.get_override_settings(params)
|
| 414 |
+
for _, setting_name, value in overridden_settings:
|
| 415 |
+
if setting_name not in request.override_settings:
|
| 416 |
+
request.override_settings[setting_name] = value
|
| 417 |
+
|
| 418 |
+
if script_runner is not None and mentioned_script_args is not None:
|
| 419 |
+
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
|
| 420 |
+
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
|
| 421 |
+
|
| 422 |
+
for field, index in script_fields:
|
| 423 |
+
value = get_field_value(field, params)
|
| 424 |
+
|
| 425 |
+
if value is None:
|
| 426 |
+
continue
|
| 427 |
+
|
| 428 |
+
mentioned_script_args[index] = value
|
| 429 |
+
|
| 430 |
+
return params
|
| 431 |
+
|
| 432 |
+
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
| 433 |
+
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
|
| 434 |
+
|
| 435 |
+
script_runner = scripts.scripts_txt2img
|
| 436 |
+
|
| 437 |
+
infotext_script_args = {}
|
| 438 |
+
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
| 439 |
+
|
| 440 |
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
| 441 |
+
sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)
|
| 442 |
+
|
| 443 |
+
populate = txt2imgreq.copy(update={ # Override __init__ params
|
| 444 |
+
"sampler_name": validate_sampler_name(sampler),
|
| 445 |
+
"do_not_save_samples": not txt2imgreq.save_images,
|
| 446 |
+
"do_not_save_grid": not txt2imgreq.save_images,
|
| 447 |
+
})
|
| 448 |
+
if populate.sampler_name:
|
| 449 |
+
populate.sampler_index = None # prevent a warning later on
|
| 450 |
+
|
| 451 |
+
if not populate.scheduler and scheduler != "Automatic":
|
| 452 |
+
populate.scheduler = scheduler
|
| 453 |
+
|
| 454 |
+
args = vars(populate)
|
| 455 |
+
args.pop('script_name', None)
|
| 456 |
+
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
| 457 |
+
args.pop('alwayson_scripts', None)
|
| 458 |
+
args.pop('infotext', None)
|
| 459 |
+
|
| 460 |
+
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
| 461 |
+
|
| 462 |
+
send_images = args.pop('send_images', True)
|
| 463 |
+
args.pop('save_images', None)
|
| 464 |
+
|
| 465 |
+
add_task_to_queue(task_id)
|
| 466 |
+
|
| 467 |
+
with self.queue_lock:
|
| 468 |
+
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
| 469 |
+
p.is_api = True
|
| 470 |
+
p.scripts = script_runner
|
| 471 |
+
p.outpath_grids = opts.outdir_txt2img_grids
|
| 472 |
+
p.outpath_samples = opts.outdir_txt2img_samples
|
| 473 |
+
|
| 474 |
+
try:
|
| 475 |
+
shared.state.begin(job="scripts_txt2img")
|
| 476 |
+
start_task(task_id)
|
| 477 |
+
if selectable_scripts is not None:
|
| 478 |
+
p.script_args = script_args
|
| 479 |
+
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
| 480 |
+
else:
|
| 481 |
+
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
| 482 |
+
processed = process_images(p)
|
| 483 |
+
finish_task(task_id)
|
| 484 |
+
finally:
|
| 485 |
+
shared.state.end()
|
| 486 |
+
shared.total_tqdm.clear()
|
| 487 |
+
|
| 488 |
+
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
| 489 |
+
|
| 490 |
+
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
| 491 |
+
|
| 492 |
+
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
| 493 |
+
task_id = img2imgreq.force_task_id or create_task_id("img2img")
|
| 494 |
+
|
| 495 |
+
init_images = img2imgreq.init_images
|
| 496 |
+
if init_images is None:
|
| 497 |
+
raise HTTPException(status_code=404, detail="Init image not found")
|
| 498 |
+
|
| 499 |
+
mask = img2imgreq.mask
|
| 500 |
+
if mask:
|
| 501 |
+
mask = decode_base64_to_image(mask)
|
| 502 |
+
|
| 503 |
+
script_runner = scripts.scripts_img2img
|
| 504 |
+
|
| 505 |
+
infotext_script_args = {}
|
| 506 |
+
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
| 507 |
+
|
| 508 |
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
| 509 |
+
sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)
|
| 510 |
+
|
| 511 |
+
populate = img2imgreq.copy(update={ # Override __init__ params
|
| 512 |
+
"sampler_name": validate_sampler_name(sampler),
|
| 513 |
+
"do_not_save_samples": not img2imgreq.save_images,
|
| 514 |
+
"do_not_save_grid": not img2imgreq.save_images,
|
| 515 |
+
"mask": mask,
|
| 516 |
+
})
|
| 517 |
+
if populate.sampler_name:
|
| 518 |
+
populate.sampler_index = None # prevent a warning later on
|
| 519 |
+
|
| 520 |
+
if not populate.scheduler and scheduler != "Automatic":
|
| 521 |
+
populate.scheduler = scheduler
|
| 522 |
+
|
| 523 |
+
args = vars(populate)
|
| 524 |
+
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
| 525 |
+
args.pop('script_name', None)
|
| 526 |
+
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
| 527 |
+
args.pop('alwayson_scripts', None)
|
| 528 |
+
args.pop('infotext', None)
|
| 529 |
+
|
| 530 |
+
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
| 531 |
+
|
| 532 |
+
send_images = args.pop('send_images', True)
|
| 533 |
+
args.pop('save_images', None)
|
| 534 |
+
|
| 535 |
+
add_task_to_queue(task_id)
|
| 536 |
+
|
| 537 |
+
with self.queue_lock:
|
| 538 |
+
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
| 539 |
+
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
| 540 |
+
p.is_api = True
|
| 541 |
+
p.scripts = script_runner
|
| 542 |
+
p.outpath_grids = opts.outdir_img2img_grids
|
| 543 |
+
p.outpath_samples = opts.outdir_img2img_samples
|
| 544 |
+
|
| 545 |
+
try:
|
| 546 |
+
shared.state.begin(job="scripts_img2img")
|
| 547 |
+
start_task(task_id)
|
| 548 |
+
if selectable_scripts is not None:
|
| 549 |
+
p.script_args = script_args
|
| 550 |
+
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
| 551 |
+
else:
|
| 552 |
+
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
| 553 |
+
processed = process_images(p)
|
| 554 |
+
finish_task(task_id)
|
| 555 |
+
finally:
|
| 556 |
+
shared.state.end()
|
| 557 |
+
shared.total_tqdm.clear()
|
| 558 |
+
|
| 559 |
+
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
| 560 |
+
|
| 561 |
+
if not img2imgreq.include_init_images:
|
| 562 |
+
img2imgreq.init_images = None
|
| 563 |
+
img2imgreq.mask = None
|
| 564 |
+
|
| 565 |
+
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
| 566 |
+
|
| 567 |
+
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
| 568 |
+
reqDict = setUpscalers(req)
|
| 569 |
+
|
| 570 |
+
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
| 571 |
+
|
| 572 |
+
with self.queue_lock:
|
| 573 |
+
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
| 574 |
+
|
| 575 |
+
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
| 576 |
+
|
| 577 |
+
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
| 578 |
+
reqDict = setUpscalers(req)
|
| 579 |
+
|
| 580 |
+
image_list = reqDict.pop('imageList', [])
|
| 581 |
+
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
| 582 |
+
|
| 583 |
+
with self.queue_lock:
|
| 584 |
+
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
| 585 |
+
|
| 586 |
+
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
| 587 |
+
|
| 588 |
+
def pnginfoapi(self, req: models.PNGInfoRequest):
|
| 589 |
+
image = decode_base64_to_image(req.image.strip())
|
| 590 |
+
if image is None:
|
| 591 |
+
return models.PNGInfoResponse(info="")
|
| 592 |
+
|
| 593 |
+
geninfo, items = images.read_info_from_image(image)
|
| 594 |
+
if geninfo is None:
|
| 595 |
+
geninfo = ""
|
| 596 |
+
|
| 597 |
+
params = infotext_utils.parse_generation_parameters(geninfo)
|
| 598 |
+
script_callbacks.infotext_pasted_callback(geninfo, params)
|
| 599 |
+
|
| 600 |
+
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
| 601 |
+
|
| 602 |
+
def progressapi(self, req: models.ProgressRequest = Depends()):
|
| 603 |
+
# copy from check_progress_call of ui.py
|
| 604 |
+
|
| 605 |
+
if shared.state.job_count == 0:
|
| 606 |
+
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
| 607 |
+
|
| 608 |
+
# avoid dividing zero
|
| 609 |
+
progress = 0.01
|
| 610 |
+
|
| 611 |
+
if shared.state.job_count > 0:
|
| 612 |
+
progress += shared.state.job_no / shared.state.job_count
|
| 613 |
+
if shared.state.sampling_steps > 0:
|
| 614 |
+
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
| 615 |
+
|
| 616 |
+
time_since_start = time.time() - shared.state.time_start
|
| 617 |
+
eta = (time_since_start/progress)
|
| 618 |
+
eta_relative = eta-time_since_start
|
| 619 |
+
|
| 620 |
+
progress = min(progress, 1)
|
| 621 |
+
|
| 622 |
+
shared.state.set_current_image()
|
| 623 |
+
|
| 624 |
+
current_image = None
|
| 625 |
+
if shared.state.current_image and not req.skip_current_image:
|
| 626 |
+
current_image = encode_pil_to_base64(shared.state.current_image)
|
| 627 |
+
|
| 628 |
+
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
|
| 629 |
+
|
| 630 |
+
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
| 631 |
+
image_b64 = interrogatereq.image
|
| 632 |
+
if image_b64 is None:
|
| 633 |
+
raise HTTPException(status_code=404, detail="Image not found")
|
| 634 |
+
|
| 635 |
+
img = decode_base64_to_image(image_b64)
|
| 636 |
+
img = img.convert('RGB')
|
| 637 |
+
|
| 638 |
+
# Override object param
|
| 639 |
+
with self.queue_lock:
|
| 640 |
+
if interrogatereq.model == "clip":
|
| 641 |
+
processed = shared.interrogator.interrogate(img)
|
| 642 |
+
elif interrogatereq.model == "deepdanbooru":
|
| 643 |
+
processed = deepbooru.model.tag(img)
|
| 644 |
+
else:
|
| 645 |
+
raise HTTPException(status_code=404, detail="Model not found")
|
| 646 |
+
|
| 647 |
+
return models.InterrogateResponse(caption=processed)
|
| 648 |
+
|
| 649 |
+
def interruptapi(self):
|
| 650 |
+
shared.state.interrupt()
|
| 651 |
+
|
| 652 |
+
return {}
|
| 653 |
+
|
| 654 |
+
def unloadapi(self):
|
| 655 |
+
sd_models.unload_model_weights()
|
| 656 |
+
|
| 657 |
+
return {}
|
| 658 |
+
|
| 659 |
+
def reloadapi(self):
|
| 660 |
+
sd_models.send_model_to_device(shared.sd_model)
|
| 661 |
+
|
| 662 |
+
return {}
|
| 663 |
+
|
| 664 |
+
def skip(self):
|
| 665 |
+
shared.state.skip()
|
| 666 |
+
|
| 667 |
+
def get_config(self):
|
| 668 |
+
options = {}
|
| 669 |
+
for key in shared.opts.data.keys():
|
| 670 |
+
metadata = shared.opts.data_labels.get(key)
|
| 671 |
+
if(metadata is not None):
|
| 672 |
+
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
|
| 673 |
+
else:
|
| 674 |
+
options.update({key: shared.opts.data.get(key, None)})
|
| 675 |
+
|
| 676 |
+
return options
|
| 677 |
+
|
| 678 |
+
def set_config(self, req: dict[str, Any]):
|
| 679 |
+
checkpoint_name = req.get("sd_model_checkpoint", None)
|
| 680 |
+
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
| 681 |
+
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
| 682 |
+
|
| 683 |
+
for k, v in req.items():
|
| 684 |
+
shared.opts.set(k, v, is_api=True)
|
| 685 |
+
|
| 686 |
+
shared.opts.save(shared.config_filename)
|
| 687 |
+
return
|
| 688 |
+
|
| 689 |
+
def get_cmd_flags(self):
|
| 690 |
+
return vars(shared.cmd_opts)
|
| 691 |
+
|
| 692 |
+
def get_samplers(self):
|
| 693 |
+
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
| 694 |
+
|
| 695 |
+
def get_schedulers(self):
|
| 696 |
+
return [
|
| 697 |
+
{
|
| 698 |
+
"name": scheduler.name,
|
| 699 |
+
"label": scheduler.label,
|
| 700 |
+
"aliases": scheduler.aliases,
|
| 701 |
+
"default_rho": scheduler.default_rho,
|
| 702 |
+
"need_inner_model": scheduler.need_inner_model,
|
| 703 |
+
}
|
| 704 |
+
for scheduler in sd_schedulers.schedulers]
|
| 705 |
+
|
| 706 |
+
def get_upscalers(self):
|
| 707 |
+
return [
|
| 708 |
+
{
|
| 709 |
+
"name": upscaler.name,
|
| 710 |
+
"model_name": upscaler.scaler.model_name,
|
| 711 |
+
"model_path": upscaler.data_path,
|
| 712 |
+
"model_url": None,
|
| 713 |
+
"scale": upscaler.scale,
|
| 714 |
+
}
|
| 715 |
+
for upscaler in shared.sd_upscalers
|
| 716 |
+
]
|
| 717 |
+
|
| 718 |
+
def get_latent_upscale_modes(self):
|
| 719 |
+
return [
|
| 720 |
+
{
|
| 721 |
+
"name": upscale_mode,
|
| 722 |
+
}
|
| 723 |
+
for upscale_mode in [*(shared.latent_upscale_modes or {})]
|
| 724 |
+
]
|
| 725 |
+
|
| 726 |
+
def get_sd_models(self):
|
| 727 |
+
import modules.sd_models as sd_models
|
| 728 |
+
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
| 729 |
+
|
| 730 |
+
def get_sd_vaes(self):
|
| 731 |
+
import modules.sd_vae as sd_vae
|
| 732 |
+
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
|
| 733 |
+
|
| 734 |
+
def get_hypernetworks(self):
|
| 735 |
+
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
| 736 |
+
|
| 737 |
+
def get_face_restorers(self):
|
| 738 |
+
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
| 739 |
+
|
| 740 |
+
def get_realesrgan_models(self):
|
| 741 |
+
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
| 742 |
+
|
| 743 |
+
def get_prompt_styles(self):
|
| 744 |
+
styleList = []
|
| 745 |
+
for k in shared.prompt_styles.styles:
|
| 746 |
+
style = shared.prompt_styles.styles[k]
|
| 747 |
+
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
|
| 748 |
+
|
| 749 |
+
return styleList
|
| 750 |
+
|
| 751 |
+
def get_embeddings(self):
|
| 752 |
+
db = sd_hijack.model_hijack.embedding_db
|
| 753 |
+
|
| 754 |
+
def convert_embedding(embedding):
|
| 755 |
+
return {
|
| 756 |
+
"step": embedding.step,
|
| 757 |
+
"sd_checkpoint": embedding.sd_checkpoint,
|
| 758 |
+
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
| 759 |
+
"shape": embedding.shape,
|
| 760 |
+
"vectors": embedding.vectors,
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
def convert_embeddings(embeddings):
|
| 764 |
+
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
| 765 |
+
|
| 766 |
+
return {
|
| 767 |
+
"loaded": convert_embeddings(db.word_embeddings),
|
| 768 |
+
"skipped": convert_embeddings(db.skipped_embeddings),
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
def refresh_embeddings(self):
|
| 772 |
+
with self.queue_lock:
|
| 773 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
| 774 |
+
|
| 775 |
+
def refresh_checkpoints(self):
|
| 776 |
+
with self.queue_lock:
|
| 777 |
+
shared.refresh_checkpoints()
|
| 778 |
+
|
| 779 |
+
def refresh_vae(self):
|
| 780 |
+
with self.queue_lock:
|
| 781 |
+
shared_items.refresh_vae_list()
|
| 782 |
+
|
| 783 |
+
def create_embedding(self, args: dict):
|
| 784 |
+
try:
|
| 785 |
+
shared.state.begin(job="create_embedding")
|
| 786 |
+
filename = create_embedding(**args) # create empty embedding
|
| 787 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
| 788 |
+
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
| 789 |
+
except AssertionError as e:
|
| 790 |
+
return models.TrainResponse(info=f"create embedding error: {e}")
|
| 791 |
+
finally:
|
| 792 |
+
shared.state.end()
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def create_hypernetwork(self, args: dict):
|
| 796 |
+
try:
|
| 797 |
+
shared.state.begin(job="create_hypernetwork")
|
| 798 |
+
filename = create_hypernetwork(**args) # create empty embedding
|
| 799 |
+
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
| 800 |
+
except AssertionError as e:
|
| 801 |
+
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
| 802 |
+
finally:
|
| 803 |
+
shared.state.end()
|
| 804 |
+
|
| 805 |
+
def train_embedding(self, args: dict):
|
| 806 |
+
try:
|
| 807 |
+
shared.state.begin(job="train_embedding")
|
| 808 |
+
apply_optimizations = shared.opts.training_xattention_optimizations
|
| 809 |
+
error = None
|
| 810 |
+
filename = ''
|
| 811 |
+
if not apply_optimizations:
|
| 812 |
+
sd_hijack.undo_optimizations()
|
| 813 |
+
try:
|
| 814 |
+
embedding, filename = train_embedding(**args) # can take a long time to complete
|
| 815 |
+
except Exception as e:
|
| 816 |
+
error = e
|
| 817 |
+
finally:
|
| 818 |
+
if not apply_optimizations:
|
| 819 |
+
sd_hijack.apply_optimizations()
|
| 820 |
+
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
| 821 |
+
except Exception as msg:
|
| 822 |
+
return models.TrainResponse(info=f"train embedding error: {msg}")
|
| 823 |
+
finally:
|
| 824 |
+
shared.state.end()
|
| 825 |
+
|
| 826 |
+
def train_hypernetwork(self, args: dict):
|
| 827 |
+
try:
|
| 828 |
+
shared.state.begin(job="train_hypernetwork")
|
| 829 |
+
shared.loaded_hypernetworks = []
|
| 830 |
+
apply_optimizations = shared.opts.training_xattention_optimizations
|
| 831 |
+
error = None
|
| 832 |
+
filename = ''
|
| 833 |
+
if not apply_optimizations:
|
| 834 |
+
sd_hijack.undo_optimizations()
|
| 835 |
+
try:
|
| 836 |
+
hypernetwork, filename = train_hypernetwork(**args)
|
| 837 |
+
except Exception as e:
|
| 838 |
+
error = e
|
| 839 |
+
finally:
|
| 840 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
| 841 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
| 842 |
+
if not apply_optimizations:
|
| 843 |
+
sd_hijack.apply_optimizations()
|
| 844 |
+
shared.state.end()
|
| 845 |
+
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
| 846 |
+
except Exception as exc:
|
| 847 |
+
return models.TrainResponse(info=f"train embedding error: {exc}")
|
| 848 |
+
finally:
|
| 849 |
+
shared.state.end()
|
| 850 |
+
|
| 851 |
+
def get_memory(self):
|
| 852 |
+
try:
|
| 853 |
+
import os
|
| 854 |
+
import psutil
|
| 855 |
+
process = psutil.Process(os.getpid())
|
| 856 |
+
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
| 857 |
+
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
| 858 |
+
ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
|
| 859 |
+
except Exception as err:
|
| 860 |
+
ram = { 'error': f'{err}' }
|
| 861 |
+
try:
|
| 862 |
+
import torch
|
| 863 |
+
if torch.cuda.is_available():
|
| 864 |
+
s = torch.cuda.mem_get_info()
|
| 865 |
+
system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
|
| 866 |
+
s = dict(torch.cuda.memory_stats(shared.device))
|
| 867 |
+
allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
|
| 868 |
+
reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
|
| 869 |
+
active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
|
| 870 |
+
inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
|
| 871 |
+
warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
|
| 872 |
+
cuda = {
|
| 873 |
+
'system': system,
|
| 874 |
+
'active': active,
|
| 875 |
+
'allocated': allocated,
|
| 876 |
+
'reserved': reserved,
|
| 877 |
+
'inactive': inactive,
|
| 878 |
+
'events': warnings,
|
| 879 |
+
}
|
| 880 |
+
else:
|
| 881 |
+
cuda = {'error': 'unavailable'}
|
| 882 |
+
except Exception as err:
|
| 883 |
+
cuda = {'error': f'{err}'}
|
| 884 |
+
return models.MemoryResponse(ram=ram, cuda=cuda)
|
| 885 |
+
|
| 886 |
+
def get_extensions_list(self):
|
| 887 |
+
from modules import extensions
|
| 888 |
+
extensions.list_extensions()
|
| 889 |
+
ext_list = []
|
| 890 |
+
for ext in extensions.extensions:
|
| 891 |
+
ext: extensions.Extension
|
| 892 |
+
ext.read_info_from_repo()
|
| 893 |
+
if ext.remote is not None:
|
| 894 |
+
ext_list.append({
|
| 895 |
+
"name": ext.name,
|
| 896 |
+
"remote": ext.remote,
|
| 897 |
+
"branch": ext.branch,
|
| 898 |
+
"commit_hash":ext.commit_hash,
|
| 899 |
+
"commit_date":ext.commit_date,
|
| 900 |
+
"version":ext.version,
|
| 901 |
+
"enabled":ext.enabled
|
| 902 |
+
})
|
| 903 |
+
return ext_list
|
| 904 |
+
|
| 905 |
+
def launch(self, server_name, port, root_path):
|
| 906 |
+
self.app.include_router(self.router)
|
| 907 |
+
uvicorn.run(
|
| 908 |
+
self.app,
|
| 909 |
+
host=server_name,
|
| 910 |
+
port=port,
|
| 911 |
+
timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
|
| 912 |
+
root_path=root_path,
|
| 913 |
+
ssl_keyfile=shared.cmd_opts.tls_keyfile,
|
| 914 |
+
ssl_certfile=shared.cmd_opts.tls_certfile
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
def kill_webui(self):
|
| 918 |
+
restart.stop_program()
|
| 919 |
+
|
| 920 |
+
def restart_webui(self):
|
| 921 |
+
if restart.is_restartable():
|
| 922 |
+
restart.restart_program()
|
| 923 |
+
return Response(status_code=501)
|
| 924 |
+
|
| 925 |
+
def stop_webui(request):
|
| 926 |
+
shared.state.server_command = "stop"
|
| 927 |
+
return Response("Stopping.")
|
| 928 |
+
|
modules/api/models.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field, create_model
|
| 4 |
+
from typing import Any, Optional, Literal
|
| 5 |
+
from inflection import underscore
|
| 6 |
+
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
| 7 |
+
from modules.shared import sd_upscalers, opts, parser
|
| 8 |
+
|
| 9 |
+
API_NOT_ALLOWED = [
|
| 10 |
+
"self",
|
| 11 |
+
"kwargs",
|
| 12 |
+
"sd_model",
|
| 13 |
+
"outpath_samples",
|
| 14 |
+
"outpath_grids",
|
| 15 |
+
"sampler_index",
|
| 16 |
+
# "do_not_save_samples",
|
| 17 |
+
# "do_not_save_grid",
|
| 18 |
+
"extra_generation_params",
|
| 19 |
+
"overlay_images",
|
| 20 |
+
"do_not_reload_embeddings",
|
| 21 |
+
"seed_enable_extras",
|
| 22 |
+
"prompt_for_display",
|
| 23 |
+
"sampler_noise_scheduler_override",
|
| 24 |
+
"ddim_discretize"
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
class ModelDef(BaseModel):
|
| 28 |
+
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
| 29 |
+
|
| 30 |
+
field: str
|
| 31 |
+
field_alias: str
|
| 32 |
+
field_type: Any
|
| 33 |
+
field_value: Any
|
| 34 |
+
field_exclude: bool = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PydanticModelGenerator:
|
| 38 |
+
"""
|
| 39 |
+
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
| 40 |
+
source_data is a snapshot of the default values produced by the class
|
| 41 |
+
params are the names of the actual keys required by __init__
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
model_name: str = None,
|
| 47 |
+
class_instance = None,
|
| 48 |
+
additional_fields = None,
|
| 49 |
+
):
|
| 50 |
+
def field_type_generator(k, v):
|
| 51 |
+
field_type = v.annotation
|
| 52 |
+
|
| 53 |
+
if field_type == 'Image':
|
| 54 |
+
# images are sent as base64 strings via API
|
| 55 |
+
field_type = 'str'
|
| 56 |
+
|
| 57 |
+
return Optional[field_type]
|
| 58 |
+
|
| 59 |
+
def merge_class_params(class_):
|
| 60 |
+
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
| 61 |
+
parameters = {}
|
| 62 |
+
for classes in all_classes:
|
| 63 |
+
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
| 64 |
+
return parameters
|
| 65 |
+
|
| 66 |
+
self._model_name = model_name
|
| 67 |
+
self._class_data = merge_class_params(class_instance)
|
| 68 |
+
|
| 69 |
+
self._model_def = [
|
| 70 |
+
ModelDef(
|
| 71 |
+
field=underscore(k),
|
| 72 |
+
field_alias=k,
|
| 73 |
+
field_type=field_type_generator(k, v),
|
| 74 |
+
field_value=None if isinstance(v.default, property) else v.default
|
| 75 |
+
)
|
| 76 |
+
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
for fields in additional_fields:
|
| 80 |
+
self._model_def.append(ModelDef(
|
| 81 |
+
field=underscore(fields["key"]),
|
| 82 |
+
field_alias=fields["key"],
|
| 83 |
+
field_type=fields["type"],
|
| 84 |
+
field_value=fields["default"],
|
| 85 |
+
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
| 86 |
+
|
| 87 |
+
def generate_model(self):
|
| 88 |
+
"""
|
| 89 |
+
Creates a pydantic BaseModel
|
| 90 |
+
from the json and overrides provided at initialization
|
| 91 |
+
"""
|
| 92 |
+
fields = {
|
| 93 |
+
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
| 94 |
+
}
|
| 95 |
+
DynamicModel = create_model(self._model_name, **fields)
|
| 96 |
+
DynamicModel.__config__.allow_population_by_field_name = True
|
| 97 |
+
DynamicModel.__config__.allow_mutation = True
|
| 98 |
+
return DynamicModel
|
| 99 |
+
|
| 100 |
+
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
| 101 |
+
"StableDiffusionProcessingTxt2Img",
|
| 102 |
+
StableDiffusionProcessingTxt2Img,
|
| 103 |
+
[
|
| 104 |
+
{"key": "sampler_index", "type": str, "default": "Euler"},
|
| 105 |
+
{"key": "script_name", "type": str, "default": None},
|
| 106 |
+
{"key": "script_args", "type": list, "default": []},
|
| 107 |
+
{"key": "send_images", "type": bool, "default": True},
|
| 108 |
+
{"key": "save_images", "type": bool, "default": False},
|
| 109 |
+
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
| 110 |
+
{"key": "force_task_id", "type": str, "default": None},
|
| 111 |
+
{"key": "infotext", "type": str, "default": None},
|
| 112 |
+
]
|
| 113 |
+
).generate_model()
|
| 114 |
+
|
| 115 |
+
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
| 116 |
+
"StableDiffusionProcessingImg2Img",
|
| 117 |
+
StableDiffusionProcessingImg2Img,
|
| 118 |
+
[
|
| 119 |
+
{"key": "sampler_index", "type": str, "default": "Euler"},
|
| 120 |
+
{"key": "init_images", "type": list, "default": None},
|
| 121 |
+
{"key": "denoising_strength", "type": float, "default": 0.75},
|
| 122 |
+
{"key": "mask", "type": str, "default": None},
|
| 123 |
+
{"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
|
| 124 |
+
{"key": "script_name", "type": str, "default": None},
|
| 125 |
+
{"key": "script_args", "type": list, "default": []},
|
| 126 |
+
{"key": "send_images", "type": bool, "default": True},
|
| 127 |
+
{"key": "save_images", "type": bool, "default": False},
|
| 128 |
+
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
| 129 |
+
{"key": "force_task_id", "type": str, "default": None},
|
| 130 |
+
{"key": "infotext", "type": str, "default": None},
|
| 131 |
+
]
|
| 132 |
+
).generate_model()
|
| 133 |
+
|
| 134 |
+
class TextToImageResponse(BaseModel):
|
| 135 |
+
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
| 136 |
+
parameters: dict
|
| 137 |
+
info: str
|
| 138 |
+
|
| 139 |
+
class ImageToImageResponse(BaseModel):
|
| 140 |
+
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
| 141 |
+
parameters: dict
|
| 142 |
+
info: str
|
| 143 |
+
|
| 144 |
+
class ExtrasBaseRequest(BaseModel):
|
| 145 |
+
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
| 146 |
+
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
| 147 |
+
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
| 148 |
+
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
| 149 |
+
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
| 150 |
+
upscaling_resize: float = Field(default=2, title="Upscaling Factor", gt=0, description="By how much to upscale the image, only used when resize_mode=0.")
|
| 151 |
+
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
| 152 |
+
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
| 153 |
+
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
| 154 |
+
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
| 155 |
+
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
| 156 |
+
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
| 157 |
+
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
| 158 |
+
|
| 159 |
+
class ExtraBaseResponse(BaseModel):
|
| 160 |
+
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
| 161 |
+
|
| 162 |
+
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
| 163 |
+
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
| 164 |
+
|
| 165 |
+
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
| 166 |
+
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
| 167 |
+
|
| 168 |
+
class FileData(BaseModel):
|
| 169 |
+
data: str = Field(title="File data", description="Base64 representation of the file")
|
| 170 |
+
name: str = Field(title="File name")
|
| 171 |
+
|
| 172 |
+
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
| 173 |
+
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
| 174 |
+
|
| 175 |
+
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
| 176 |
+
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
| 177 |
+
|
| 178 |
+
class PNGInfoRequest(BaseModel):
|
| 179 |
+
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
| 180 |
+
|
| 181 |
+
class PNGInfoResponse(BaseModel):
|
| 182 |
+
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
| 183 |
+
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
|
| 184 |
+
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
|
| 185 |
+
|
| 186 |
+
class ProgressRequest(BaseModel):
|
| 187 |
+
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
| 188 |
+
|
| 189 |
+
class ProgressResponse(BaseModel):
|
| 190 |
+
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
| 191 |
+
eta_relative: float = Field(title="ETA in secs")
|
| 192 |
+
state: dict = Field(title="State", description="The current state snapshot")
|
| 193 |
+
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
| 194 |
+
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
| 195 |
+
|
| 196 |
+
class InterrogateRequest(BaseModel):
|
| 197 |
+
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
| 198 |
+
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
| 199 |
+
|
| 200 |
+
class InterrogateResponse(BaseModel):
|
| 201 |
+
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
| 202 |
+
|
| 203 |
+
class TrainResponse(BaseModel):
|
| 204 |
+
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
| 205 |
+
|
| 206 |
+
class CreateResponse(BaseModel):
|
| 207 |
+
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
| 208 |
+
|
| 209 |
+
fields = {}
|
| 210 |
+
for key, metadata in opts.data_labels.items():
|
| 211 |
+
value = opts.data.get(key)
|
| 212 |
+
optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
|
| 213 |
+
|
| 214 |
+
if metadata is not None:
|
| 215 |
+
fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
|
| 216 |
+
else:
|
| 217 |
+
fields.update({key: (Optional[optType], Field())})
|
| 218 |
+
|
| 219 |
+
OptionsModel = create_model("Options", **fields)
|
| 220 |
+
|
| 221 |
+
flags = {}
|
| 222 |
+
_options = vars(parser)['_option_string_actions']
|
| 223 |
+
for key in _options:
|
| 224 |
+
if(_options[key].dest != 'help'):
|
| 225 |
+
flag = _options[key]
|
| 226 |
+
_type = str
|
| 227 |
+
if _options[key].default is not None:
|
| 228 |
+
_type = type(_options[key].default)
|
| 229 |
+
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
| 230 |
+
|
| 231 |
+
FlagsModel = create_model("Flags", **flags)
|
| 232 |
+
|
| 233 |
+
class SamplerItem(BaseModel):
|
| 234 |
+
name: str = Field(title="Name")
|
| 235 |
+
aliases: list[str] = Field(title="Aliases")
|
| 236 |
+
options: dict[str, str] = Field(title="Options")
|
| 237 |
+
|
| 238 |
+
class SchedulerItem(BaseModel):
|
| 239 |
+
name: str = Field(title="Name")
|
| 240 |
+
label: str = Field(title="Label")
|
| 241 |
+
aliases: Optional[list[str]] = Field(title="Aliases")
|
| 242 |
+
default_rho: Optional[float] = Field(title="Default Rho")
|
| 243 |
+
need_inner_model: Optional[bool] = Field(title="Needs Inner Model")
|
| 244 |
+
|
| 245 |
+
class UpscalerItem(BaseModel):
|
| 246 |
+
name: str = Field(title="Name")
|
| 247 |
+
model_name: Optional[str] = Field(title="Model Name")
|
| 248 |
+
model_path: Optional[str] = Field(title="Path")
|
| 249 |
+
model_url: Optional[str] = Field(title="URL")
|
| 250 |
+
scale: Optional[float] = Field(title="Scale")
|
| 251 |
+
|
| 252 |
+
class LatentUpscalerModeItem(BaseModel):
|
| 253 |
+
name: str = Field(title="Name")
|
| 254 |
+
|
| 255 |
+
class SDModelItem(BaseModel):
|
| 256 |
+
title: str = Field(title="Title")
|
| 257 |
+
model_name: str = Field(title="Model Name")
|
| 258 |
+
hash: Optional[str] = Field(title="Short hash")
|
| 259 |
+
sha256: Optional[str] = Field(title="sha256 hash")
|
| 260 |
+
filename: str = Field(title="Filename")
|
| 261 |
+
config: Optional[str] = Field(title="Config file")
|
| 262 |
+
|
| 263 |
+
class SDVaeItem(BaseModel):
|
| 264 |
+
model_name: str = Field(title="Model Name")
|
| 265 |
+
filename: str = Field(title="Filename")
|
| 266 |
+
|
| 267 |
+
class HypernetworkItem(BaseModel):
|
| 268 |
+
name: str = Field(title="Name")
|
| 269 |
+
path: Optional[str] = Field(title="Path")
|
| 270 |
+
|
| 271 |
+
class FaceRestorerItem(BaseModel):
|
| 272 |
+
name: str = Field(title="Name")
|
| 273 |
+
cmd_dir: Optional[str] = Field(title="Path")
|
| 274 |
+
|
| 275 |
+
class RealesrganItem(BaseModel):
|
| 276 |
+
name: str = Field(title="Name")
|
| 277 |
+
path: Optional[str] = Field(title="Path")
|
| 278 |
+
scale: Optional[int] = Field(title="Scale")
|
| 279 |
+
|
| 280 |
+
class PromptStyleItem(BaseModel):
|
| 281 |
+
name: str = Field(title="Name")
|
| 282 |
+
prompt: Optional[str] = Field(title="Prompt")
|
| 283 |
+
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class EmbeddingItem(BaseModel):
|
| 287 |
+
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
| 288 |
+
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
| 289 |
+
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
| 290 |
+
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
| 291 |
+
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
| 292 |
+
|
| 293 |
+
class EmbeddingsResponse(BaseModel):
|
| 294 |
+
loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
| 295 |
+
skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
| 296 |
+
|
| 297 |
+
class MemoryResponse(BaseModel):
|
| 298 |
+
ram: dict = Field(title="RAM", description="System memory stats")
|
| 299 |
+
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class ScriptsList(BaseModel):
|
| 303 |
+
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
| 304 |
+
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class ScriptArg(BaseModel):
|
| 308 |
+
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
|
| 309 |
+
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
|
| 310 |
+
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
| 311 |
+
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
| 312 |
+
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
| 313 |
+
choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class ScriptInfo(BaseModel):
|
| 317 |
+
name: str = Field(default=None, title="Name", description="Script name")
|
| 318 |
+
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
| 319 |
+
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
| 320 |
+
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
| 321 |
+
|
| 322 |
+
class ExtensionItem(BaseModel):
|
| 323 |
+
name: str = Field(title="Name", description="Extension name")
|
| 324 |
+
remote: str = Field(title="Remote", description="Extension Repository URL")
|
| 325 |
+
branch: str = Field(title="Branch", description="Extension Repository Branch")
|
| 326 |
+
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
| 327 |
+
version: str = Field(title="Version", description="Extension Version")
|
| 328 |
+
commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
|
| 329 |
+
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
modules/cache.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import os.path
|
| 4 |
+
import threading
|
| 5 |
+
|
| 6 |
+
import diskcache
|
| 7 |
+
import tqdm
|
| 8 |
+
|
| 9 |
+
from modules.paths import data_path, script_path
|
| 10 |
+
|
| 11 |
+
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
| 12 |
+
cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache"))
|
| 13 |
+
caches = {}
|
| 14 |
+
cache_lock = threading.Lock()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def dump_cache():
|
| 18 |
+
"""old function for dumping cache to disk; does nothing since diskcache."""
|
| 19 |
+
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_cache(subsection: str) -> diskcache.Cache:
|
| 24 |
+
return diskcache.Cache(
|
| 25 |
+
os.path.join(cache_dir, subsection),
|
| 26 |
+
size_limit=2**32, # 4 GB, culling oldest first
|
| 27 |
+
disk_min_file_size=2**18, # keep up to 256KB in Sqlite
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def convert_old_cached_data():
|
| 32 |
+
try:
|
| 33 |
+
with open(cache_filename, "r", encoding="utf8") as file:
|
| 34 |
+
data = json.load(file)
|
| 35 |
+
except FileNotFoundError:
|
| 36 |
+
return
|
| 37 |
+
except Exception:
|
| 38 |
+
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
| 39 |
+
print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
total_count = sum(len(keyvalues) for keyvalues in data.values())
|
| 43 |
+
|
| 44 |
+
with tqdm.tqdm(total=total_count, desc="converting cache") as progress:
|
| 45 |
+
for subsection, keyvalues in data.items():
|
| 46 |
+
cache_obj = caches.get(subsection)
|
| 47 |
+
if cache_obj is None:
|
| 48 |
+
cache_obj = make_cache(subsection)
|
| 49 |
+
caches[subsection] = cache_obj
|
| 50 |
+
|
| 51 |
+
for key, value in keyvalues.items():
|
| 52 |
+
cache_obj[key] = value
|
| 53 |
+
progress.update(1)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def cache(subsection):
|
| 57 |
+
"""
|
| 58 |
+
Retrieves or initializes a cache for a specific subsection.
|
| 59 |
+
|
| 60 |
+
Parameters:
|
| 61 |
+
subsection (str): The subsection identifier for the cache.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
diskcache.Cache: The cache data for the specified subsection.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
cache_obj = caches.get(subsection)
|
| 68 |
+
if not cache_obj:
|
| 69 |
+
with cache_lock:
|
| 70 |
+
if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):
|
| 71 |
+
convert_old_cached_data()
|
| 72 |
+
|
| 73 |
+
cache_obj = caches.get(subsection)
|
| 74 |
+
if not cache_obj:
|
| 75 |
+
cache_obj = make_cache(subsection)
|
| 76 |
+
caches[subsection] = cache_obj
|
| 77 |
+
|
| 78 |
+
return cache_obj
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def cached_data_for_file(subsection, title, filename, func):
|
| 82 |
+
"""
|
| 83 |
+
Retrieves or generates data for a specific file, using a caching mechanism.
|
| 84 |
+
|
| 85 |
+
Parameters:
|
| 86 |
+
subsection (str): The subsection of the cache to use.
|
| 87 |
+
title (str): The title of the data entry in the subsection of the cache.
|
| 88 |
+
filename (str): The path to the file to be checked for modifications.
|
| 89 |
+
func (callable): A function that generates the data if it is not available in the cache.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
dict or None: The cached or generated data, or None if data generation fails.
|
| 93 |
+
|
| 94 |
+
The `cached_data_for_file` function implements a caching mechanism for data stored in files.
|
| 95 |
+
It checks if the data associated with the given `title` is present in the cache and compares the
|
| 96 |
+
modification time of the file with the cached modification time. If the file has been modified,
|
| 97 |
+
the cache is considered invalid and the data is regenerated using the provided `func`.
|
| 98 |
+
Otherwise, the cached data is returned.
|
| 99 |
+
|
| 100 |
+
If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
|
| 101 |
+
or cached data is returned as a dictionary.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
existing_cache = cache(subsection)
|
| 105 |
+
ondisk_mtime = os.path.getmtime(filename)
|
| 106 |
+
|
| 107 |
+
entry = existing_cache.get(title)
|
| 108 |
+
if entry:
|
| 109 |
+
cached_mtime = entry.get("mtime", 0)
|
| 110 |
+
if ondisk_mtime > cached_mtime:
|
| 111 |
+
entry = None
|
| 112 |
+
|
| 113 |
+
if not entry or 'value' not in entry:
|
| 114 |
+
value = func()
|
| 115 |
+
if value is None:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
entry = {'mtime': ondisk_mtime, 'value': value}
|
| 119 |
+
existing_cache[title] = entry
|
| 120 |
+
|
| 121 |
+
dump_cache()
|
| 122 |
+
|
| 123 |
+
return entry['value']
|
modules/call_queue.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from functools import wraps
|
| 3 |
+
import html
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from modules import shared, progress, errors, devices, fifo_lock, profiling
|
| 7 |
+
|
| 8 |
+
queue_lock = fifo_lock.FIFOLock()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def wrap_queued_call(func):
|
| 12 |
+
def f(*args, **kwargs):
|
| 13 |
+
with queue_lock:
|
| 14 |
+
res = func(*args, **kwargs)
|
| 15 |
+
|
| 16 |
+
return res
|
| 17 |
+
|
| 18 |
+
return f
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
| 22 |
+
@wraps(func)
|
| 23 |
+
def f(*args, **kwargs):
|
| 24 |
+
|
| 25 |
+
# if the first argument is a string that says "task(...)", it is treated as a job id
|
| 26 |
+
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
|
| 27 |
+
id_task = args[0]
|
| 28 |
+
progress.add_task_to_queue(id_task)
|
| 29 |
+
else:
|
| 30 |
+
id_task = None
|
| 31 |
+
|
| 32 |
+
with queue_lock:
|
| 33 |
+
shared.state.begin(job=id_task)
|
| 34 |
+
progress.start_task(id_task)
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
res = func(*args, **kwargs)
|
| 38 |
+
progress.record_results(id_task, res)
|
| 39 |
+
finally:
|
| 40 |
+
progress.finish_task(id_task)
|
| 41 |
+
|
| 42 |
+
shared.state.end()
|
| 43 |
+
|
| 44 |
+
return res
|
| 45 |
+
|
| 46 |
+
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
| 50 |
+
@wraps(func)
|
| 51 |
+
def f(*args, **kwargs):
|
| 52 |
+
try:
|
| 53 |
+
res = func(*args, **kwargs)
|
| 54 |
+
finally:
|
| 55 |
+
shared.state.skipped = False
|
| 56 |
+
shared.state.interrupted = False
|
| 57 |
+
shared.state.stopping_generation = False
|
| 58 |
+
shared.state.job_count = 0
|
| 59 |
+
shared.state.job = ""
|
| 60 |
+
return res
|
| 61 |
+
|
| 62 |
+
return wrap_gradio_call_no_job(f, extra_outputs, add_stats)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):
|
| 66 |
+
@wraps(func)
|
| 67 |
+
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
| 68 |
+
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
| 69 |
+
if run_memmon:
|
| 70 |
+
shared.mem_mon.monitor()
|
| 71 |
+
t = time.perf_counter()
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
res = list(func(*args, **kwargs))
|
| 75 |
+
except Exception as e:
|
| 76 |
+
# When printing out our debug argument list,
|
| 77 |
+
# do not print out more than a 100 KB of text
|
| 78 |
+
max_debug_str_len = 131072
|
| 79 |
+
message = "Error completing request"
|
| 80 |
+
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
|
| 81 |
+
if len(arg_str) > max_debug_str_len:
|
| 82 |
+
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
|
| 83 |
+
errors.report(f"{message}\n{arg_str}", exc_info=True)
|
| 84 |
+
|
| 85 |
+
if extra_outputs_array is None:
|
| 86 |
+
extra_outputs_array = [None, '']
|
| 87 |
+
|
| 88 |
+
error_message = f'{type(e).__name__}: {e}'
|
| 89 |
+
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
| 90 |
+
|
| 91 |
+
devices.torch_gc()
|
| 92 |
+
|
| 93 |
+
if not add_stats:
|
| 94 |
+
return tuple(res)
|
| 95 |
+
|
| 96 |
+
elapsed = time.perf_counter() - t
|
| 97 |
+
elapsed_m = int(elapsed // 60)
|
| 98 |
+
elapsed_s = elapsed % 60
|
| 99 |
+
elapsed_text = f"{elapsed_s:.1f} sec."
|
| 100 |
+
if elapsed_m > 0:
|
| 101 |
+
elapsed_text = f"{elapsed_m} min. "+elapsed_text
|
| 102 |
+
|
| 103 |
+
if run_memmon:
|
| 104 |
+
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
| 105 |
+
active_peak = mem_stats['active_peak']
|
| 106 |
+
reserved_peak = mem_stats['reserved_peak']
|
| 107 |
+
sys_peak = mem_stats['system_peak']
|
| 108 |
+
sys_total = mem_stats['total']
|
| 109 |
+
sys_pct = sys_peak/max(sys_total, 1) * 100
|
| 110 |
+
|
| 111 |
+
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
| 112 |
+
toltip_r = "Reserved: total amount of video memory allocated by the Torch library "
|
| 113 |
+
toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
|
| 114 |
+
|
| 115 |
+
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
| 116 |
+
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
| 117 |
+
text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
|
| 118 |
+
|
| 119 |
+
vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
|
| 120 |
+
else:
|
| 121 |
+
vram_html = ''
|
| 122 |
+
|
| 123 |
+
if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
|
| 124 |
+
profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
|
| 125 |
+
else:
|
| 126 |
+
profiling_html = ''
|
| 127 |
+
|
| 128 |
+
# last item is always HTML
|
| 129 |
+
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
|
| 130 |
+
|
| 131 |
+
return tuple(res)
|
| 132 |
+
|
| 133 |
+
return f
|
| 134 |
+
|
modules/cmd_args.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
| 5 |
+
|
| 6 |
+
parser = argparse.ArgumentParser()
|
| 7 |
+
|
| 8 |
+
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
|
| 9 |
+
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
| 10 |
+
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
| 11 |
+
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
| 12 |
+
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
| 13 |
+
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
| 14 |
+
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
| 15 |
+
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
| 16 |
+
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
| 17 |
+
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
| 18 |
+
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
| 19 |
+
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
| 20 |
+
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
| 21 |
+
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
| 22 |
+
parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
| 23 |
+
parser.add_argument("--models-dir", type=normalized_filepath, default=None, help="base path where models are stored; overrides --data-dir")
|
| 24 |
+
parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",)
|
| 25 |
+
parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
| 26 |
+
parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints")
|
| 27 |
+
parser.add_argument("--vae-dir", type=normalized_filepath, default=None, help="Path to directory with VAE files")
|
| 28 |
+
parser.add_argument("--gfpgan-dir", type=normalized_filepath, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
| 29 |
+
parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN model file name", default=None)
|
| 30 |
+
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
| 31 |
+
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
| 32 |
+
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
| 33 |
+
parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything")
|
| 34 |
+
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
| 35 |
+
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
| 36 |
+
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
| 37 |
+
parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
| 38 |
+
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
| 39 |
+
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
| 40 |
+
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
| 41 |
+
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
| 42 |
+
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
| 43 |
+
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
| 44 |
+
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
| 45 |
+
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
|
| 46 |
+
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
| 47 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
| 48 |
+
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
| 49 |
+
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
|
| 50 |
+
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
|
| 51 |
+
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
| 52 |
+
parser.add_argument("--codeformer-models-path", type=normalized_filepath, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
| 53 |
+
parser.add_argument("--gfpgan-models-path", type=normalized_filepath, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
| 54 |
+
parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
| 55 |
+
parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
| 56 |
+
parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
| 57 |
+
parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT'))
|
| 58 |
+
parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s).", default=None)
|
| 59 |
+
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
| 60 |
+
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
| 61 |
+
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
| 62 |
+
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
| 63 |
+
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
|
| 64 |
+
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
|
| 65 |
+
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
| 66 |
+
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
| 67 |
+
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
| 68 |
+
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
|
| 69 |
+
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
| 70 |
+
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
| 71 |
+
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
| 72 |
+
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
| 73 |
+
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
| 74 |
+
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
| 75 |
+
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
|
| 76 |
+
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
| 77 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
| 78 |
+
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
| 79 |
+
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
| 80 |
+
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
| 81 |
+
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
| 82 |
+
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
|
| 83 |
+
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
|
| 84 |
+
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
|
| 85 |
+
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
| 86 |
+
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
| 87 |
+
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
| 88 |
+
parser.add_argument("--gradio-auth-path", type=normalized_filepath, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
| 89 |
+
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
| 90 |
+
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
| 91 |
+
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
| 92 |
+
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
| 93 |
+
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
| 94 |
+
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
| 95 |
+
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
| 96 |
+
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
| 97 |
+
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
| 98 |
+
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
| 99 |
+
parser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
| 100 |
+
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
| 101 |
+
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
| 102 |
+
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
| 103 |
+
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
| 104 |
+
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
| 105 |
+
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
| 106 |
+
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
| 107 |
+
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
| 108 |
+
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
| 109 |
+
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
| 110 |
+
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
| 111 |
+
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
| 112 |
+
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
| 113 |
+
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
| 114 |
+
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
| 115 |
+
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
|
| 116 |
+
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
| 117 |
+
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
| 118 |
+
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
| 119 |
+
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
| 120 |
+
parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
|
| 121 |
+
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
| 122 |
+
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
| 123 |
+
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
| 124 |
+
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
| 125 |
+
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
| 126 |
+
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
| 127 |
+
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
| 128 |
+
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
|
modules/codeformer_model.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from modules import (
|
| 8 |
+
devices,
|
| 9 |
+
errors,
|
| 10 |
+
face_restoration,
|
| 11 |
+
face_restoration_utils,
|
| 12 |
+
modelloader,
|
| 13 |
+
shared,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
| 19 |
+
model_download_name = 'codeformer-v0.1.0.pth'
|
| 20 |
+
|
| 21 |
+
# used by e.g. postprocessing_codeformer.py
|
| 22 |
+
codeformer: face_restoration.FaceRestoration | None = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
| 26 |
+
def name(self):
|
| 27 |
+
return "CodeFormer"
|
| 28 |
+
|
| 29 |
+
def load_net(self) -> torch.Module:
|
| 30 |
+
for model_path in modelloader.load_models(
|
| 31 |
+
model_path=self.model_path,
|
| 32 |
+
model_url=model_url,
|
| 33 |
+
command_path=self.model_path,
|
| 34 |
+
download_name=model_download_name,
|
| 35 |
+
ext_filter=['.pth'],
|
| 36 |
+
):
|
| 37 |
+
return modelloader.load_spandrel_model(
|
| 38 |
+
model_path,
|
| 39 |
+
device=devices.device_codeformer,
|
| 40 |
+
expected_architecture='CodeFormer',
|
| 41 |
+
).model
|
| 42 |
+
raise ValueError("No codeformer model found")
|
| 43 |
+
|
| 44 |
+
def get_device(self):
|
| 45 |
+
return devices.device_codeformer
|
| 46 |
+
|
| 47 |
+
def restore(self, np_image, w: float | None = None):
|
| 48 |
+
if w is None:
|
| 49 |
+
w = getattr(shared.opts, "code_former_weight", 0.5)
|
| 50 |
+
|
| 51 |
+
def restore_face(cropped_face_t):
|
| 52 |
+
assert self.net is not None
|
| 53 |
+
return self.net(cropped_face_t, weight=w, adain=True)[0]
|
| 54 |
+
|
| 55 |
+
return self.restore_with_helper(np_image, restore_face)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def setup_model(dirname: str) -> None:
|
| 59 |
+
global codeformer
|
| 60 |
+
try:
|
| 61 |
+
codeformer = FaceRestorerCodeFormer(dirname)
|
| 62 |
+
shared.face_restorers.append(codeformer)
|
| 63 |
+
except Exception:
|
| 64 |
+
errors.report("Error setting up CodeFormer", exc_info=True)
|
modules/config_states.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Supports saving and restoring webui and extensions from a known working set of commits
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import tqdm
|
| 8 |
+
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import git
|
| 11 |
+
|
| 12 |
+
from modules import shared, extensions, errors
|
| 13 |
+
from modules.paths_internal import script_path, config_states_dir
|
| 14 |
+
|
| 15 |
+
all_config_states = {}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def list_config_states():
|
| 19 |
+
global all_config_states
|
| 20 |
+
|
| 21 |
+
all_config_states.clear()
|
| 22 |
+
os.makedirs(config_states_dir, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
config_states = []
|
| 25 |
+
for filename in os.listdir(config_states_dir):
|
| 26 |
+
if filename.endswith(".json"):
|
| 27 |
+
path = os.path.join(config_states_dir, filename)
|
| 28 |
+
try:
|
| 29 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 30 |
+
j = json.load(f)
|
| 31 |
+
assert "created_at" in j, '"created_at" does not exist'
|
| 32 |
+
j["filepath"] = path
|
| 33 |
+
config_states.append(j)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
print(f'[ERROR]: Config states {path}, {e}')
|
| 36 |
+
|
| 37 |
+
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
| 38 |
+
|
| 39 |
+
for cs in config_states:
|
| 40 |
+
timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
| 41 |
+
name = cs.get("name", "Config")
|
| 42 |
+
full_name = f"{name}: {timestamp}"
|
| 43 |
+
all_config_states[full_name] = cs
|
| 44 |
+
|
| 45 |
+
return all_config_states
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_webui_config():
|
| 49 |
+
webui_repo = None
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
if os.path.exists(os.path.join(script_path, ".git")):
|
| 53 |
+
webui_repo = git.Repo(script_path)
|
| 54 |
+
except Exception:
|
| 55 |
+
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
| 56 |
+
|
| 57 |
+
webui_remote = None
|
| 58 |
+
webui_commit_hash = None
|
| 59 |
+
webui_commit_date = None
|
| 60 |
+
webui_branch = None
|
| 61 |
+
if webui_repo and not webui_repo.bare:
|
| 62 |
+
try:
|
| 63 |
+
webui_remote = next(webui_repo.remote().urls, None)
|
| 64 |
+
head = webui_repo.head.commit
|
| 65 |
+
webui_commit_date = webui_repo.head.commit.committed_date
|
| 66 |
+
webui_commit_hash = head.hexsha
|
| 67 |
+
webui_branch = webui_repo.active_branch.name
|
| 68 |
+
|
| 69 |
+
except Exception:
|
| 70 |
+
webui_remote = None
|
| 71 |
+
|
| 72 |
+
return {
|
| 73 |
+
"remote": webui_remote,
|
| 74 |
+
"commit_hash": webui_commit_hash,
|
| 75 |
+
"commit_date": webui_commit_date,
|
| 76 |
+
"branch": webui_branch,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_extension_config():
|
| 81 |
+
ext_config = {}
|
| 82 |
+
|
| 83 |
+
for ext in extensions.extensions:
|
| 84 |
+
ext.read_info_from_repo()
|
| 85 |
+
|
| 86 |
+
entry = {
|
| 87 |
+
"name": ext.name,
|
| 88 |
+
"path": ext.path,
|
| 89 |
+
"enabled": ext.enabled,
|
| 90 |
+
"is_builtin": ext.is_builtin,
|
| 91 |
+
"remote": ext.remote,
|
| 92 |
+
"commit_hash": ext.commit_hash,
|
| 93 |
+
"commit_date": ext.commit_date,
|
| 94 |
+
"branch": ext.branch,
|
| 95 |
+
"have_info_from_repo": ext.have_info_from_repo
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
ext_config[ext.name] = entry
|
| 99 |
+
|
| 100 |
+
return ext_config
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_config():
|
| 104 |
+
creation_time = datetime.now().timestamp()
|
| 105 |
+
webui_config = get_webui_config()
|
| 106 |
+
ext_config = get_extension_config()
|
| 107 |
+
|
| 108 |
+
return {
|
| 109 |
+
"created_at": creation_time,
|
| 110 |
+
"webui": webui_config,
|
| 111 |
+
"extensions": ext_config
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def restore_webui_config(config):
|
| 116 |
+
print("* Restoring webui state...")
|
| 117 |
+
|
| 118 |
+
if "webui" not in config:
|
| 119 |
+
print("Error: No webui data saved to config")
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
webui_config = config["webui"]
|
| 123 |
+
|
| 124 |
+
if "commit_hash" not in webui_config:
|
| 125 |
+
print("Error: No commit saved to webui config")
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
webui_commit_hash = webui_config.get("commit_hash", None)
|
| 129 |
+
webui_repo = None
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
if os.path.exists(os.path.join(script_path, ".git")):
|
| 133 |
+
webui_repo = git.Repo(script_path)
|
| 134 |
+
except Exception:
|
| 135 |
+
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
webui_repo.git.fetch(all=True)
|
| 140 |
+
webui_repo.git.reset(webui_commit_hash, hard=True)
|
| 141 |
+
print(f"* Restored webui to commit {webui_commit_hash}.")
|
| 142 |
+
except Exception:
|
| 143 |
+
errors.report(f"Error restoring webui to commit{webui_commit_hash}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def restore_extension_config(config):
|
| 147 |
+
print("* Restoring extension state...")
|
| 148 |
+
|
| 149 |
+
if "extensions" not in config:
|
| 150 |
+
print("Error: No extension data saved to config")
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
ext_config = config["extensions"]
|
| 154 |
+
|
| 155 |
+
results = []
|
| 156 |
+
disabled = []
|
| 157 |
+
|
| 158 |
+
for ext in tqdm.tqdm(extensions.extensions):
|
| 159 |
+
if ext.is_builtin:
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
ext.read_info_from_repo()
|
| 163 |
+
current_commit = ext.commit_hash
|
| 164 |
+
|
| 165 |
+
if ext.name not in ext_config:
|
| 166 |
+
ext.disabled = True
|
| 167 |
+
disabled.append(ext.name)
|
| 168 |
+
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
| 169 |
+
continue
|
| 170 |
+
|
| 171 |
+
entry = ext_config[ext.name]
|
| 172 |
+
|
| 173 |
+
if "commit_hash" in entry and entry["commit_hash"]:
|
| 174 |
+
try:
|
| 175 |
+
ext.fetch_and_reset_hard(entry["commit_hash"])
|
| 176 |
+
ext.read_info_from_repo()
|
| 177 |
+
if current_commit != entry["commit_hash"]:
|
| 178 |
+
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
| 179 |
+
except Exception as ex:
|
| 180 |
+
results.append((ext, current_commit[:8], False, ex))
|
| 181 |
+
else:
|
| 182 |
+
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
| 183 |
+
|
| 184 |
+
if not entry.get("enabled", False):
|
| 185 |
+
ext.disabled = True
|
| 186 |
+
disabled.append(ext.name)
|
| 187 |
+
else:
|
| 188 |
+
ext.disabled = False
|
| 189 |
+
|
| 190 |
+
shared.opts.disabled_extensions = disabled
|
| 191 |
+
shared.opts.save(shared.config_filename)
|
| 192 |
+
|
| 193 |
+
print("* Finished restoring extensions. Results:")
|
| 194 |
+
for ext, prev_commit, success, result in results:
|
| 195 |
+
if success:
|
| 196 |
+
print(f" + {ext.name}: {prev_commit} -> {result}")
|
| 197 |
+
else:
|
| 198 |
+
print(f" ! {ext.name}: FAILURE ({result})")
|
modules/dat_model.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from modules import modelloader, errors
|
| 4 |
+
from modules.shared import cmd_opts, opts
|
| 5 |
+
from modules.upscaler import Upscaler, UpscalerData
|
| 6 |
+
from modules.upscaler_utils import upscale_with_model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class UpscalerDAT(Upscaler):
|
| 10 |
+
def __init__(self, user_path):
|
| 11 |
+
self.name = "DAT"
|
| 12 |
+
self.user_path = user_path
|
| 13 |
+
self.scalers = []
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
| 17 |
+
name = modelloader.friendly_name(file)
|
| 18 |
+
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
|
| 19 |
+
self.scalers.append(scaler_data)
|
| 20 |
+
|
| 21 |
+
for model in get_dat_models(self):
|
| 22 |
+
if model.name in opts.dat_enabled_models:
|
| 23 |
+
self.scalers.append(model)
|
| 24 |
+
|
| 25 |
+
def do_upscale(self, img, path):
|
| 26 |
+
try:
|
| 27 |
+
info = self.load_model(path)
|
| 28 |
+
except Exception:
|
| 29 |
+
errors.report(f"Unable to load DAT model {path}", exc_info=True)
|
| 30 |
+
return img
|
| 31 |
+
|
| 32 |
+
model_descriptor = modelloader.load_spandrel_model(
|
| 33 |
+
info.local_data_path,
|
| 34 |
+
device=self.device,
|
| 35 |
+
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
| 36 |
+
expected_architecture="DAT",
|
| 37 |
+
)
|
| 38 |
+
return upscale_with_model(
|
| 39 |
+
model_descriptor,
|
| 40 |
+
img,
|
| 41 |
+
tile_size=opts.DAT_tile,
|
| 42 |
+
tile_overlap=opts.DAT_tile_overlap,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
def load_model(self, path):
|
| 46 |
+
for scaler in self.scalers:
|
| 47 |
+
if scaler.data_path == path:
|
| 48 |
+
if scaler.local_data_path.startswith("http"):
|
| 49 |
+
scaler.local_data_path = modelloader.load_file_from_url(
|
| 50 |
+
scaler.data_path,
|
| 51 |
+
model_dir=self.model_download_path,
|
| 52 |
+
)
|
| 53 |
+
if not os.path.exists(scaler.local_data_path):
|
| 54 |
+
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
|
| 55 |
+
return scaler
|
| 56 |
+
raise ValueError(f"Unable to find model info: {path}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_dat_models(scaler):
|
| 60 |
+
return [
|
| 61 |
+
UpscalerData(
|
| 62 |
+
name="DAT x2",
|
| 63 |
+
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
|
| 64 |
+
scale=2,
|
| 65 |
+
upscaler=scaler,
|
| 66 |
+
),
|
| 67 |
+
UpscalerData(
|
| 68 |
+
name="DAT x3",
|
| 69 |
+
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
|
| 70 |
+
scale=3,
|
| 71 |
+
upscaler=scaler,
|
| 72 |
+
),
|
| 73 |
+
UpscalerData(
|
| 74 |
+
name="DAT x4",
|
| 75 |
+
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
|
| 76 |
+
scale=4,
|
| 77 |
+
upscaler=scaler,
|
| 78 |
+
),
|
| 79 |
+
]
|
modules/deepbooru.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
| 8 |
+
|
| 9 |
+
re_special = re.compile(r'([\\()])')
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DeepDanbooru:
|
| 13 |
+
def __init__(self):
|
| 14 |
+
self.model = None
|
| 15 |
+
|
| 16 |
+
def load(self):
|
| 17 |
+
if self.model is not None:
|
| 18 |
+
return
|
| 19 |
+
|
| 20 |
+
files = modelloader.load_models(
|
| 21 |
+
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
| 22 |
+
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
| 23 |
+
ext_filter=[".pt"],
|
| 24 |
+
download_name='model-resnet_custom_v3.pt',
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
self.model = deepbooru_model.DeepDanbooruModel()
|
| 28 |
+
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
| 29 |
+
|
| 30 |
+
self.model.eval()
|
| 31 |
+
self.model.to(devices.cpu, devices.dtype)
|
| 32 |
+
|
| 33 |
+
def start(self):
|
| 34 |
+
self.load()
|
| 35 |
+
self.model.to(devices.device)
|
| 36 |
+
|
| 37 |
+
def stop(self):
|
| 38 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
| 39 |
+
self.model.to(devices.cpu)
|
| 40 |
+
devices.torch_gc()
|
| 41 |
+
|
| 42 |
+
def tag(self, pil_image):
|
| 43 |
+
self.start()
|
| 44 |
+
res = self.tag_multi(pil_image)
|
| 45 |
+
self.stop()
|
| 46 |
+
|
| 47 |
+
return res
|
| 48 |
+
|
| 49 |
+
def tag_multi(self, pil_image, force_disable_ranks=False):
|
| 50 |
+
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
| 51 |
+
use_spaces = shared.opts.deepbooru_use_spaces
|
| 52 |
+
use_escape = shared.opts.deepbooru_escape
|
| 53 |
+
alpha_sort = shared.opts.deepbooru_sort_alpha
|
| 54 |
+
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
| 55 |
+
|
| 56 |
+
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
| 57 |
+
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
| 58 |
+
|
| 59 |
+
with torch.no_grad(), devices.autocast():
|
| 60 |
+
x = torch.from_numpy(a).to(devices.device, devices.dtype)
|
| 61 |
+
y = self.model(x)[0].detach().cpu().numpy()
|
| 62 |
+
|
| 63 |
+
probability_dict = {}
|
| 64 |
+
|
| 65 |
+
for tag, probability in zip(self.model.tags, y):
|
| 66 |
+
if probability < threshold:
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
if tag.startswith("rating:"):
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
probability_dict[tag] = probability
|
| 73 |
+
|
| 74 |
+
if alpha_sort:
|
| 75 |
+
tags = sorted(probability_dict)
|
| 76 |
+
else:
|
| 77 |
+
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
| 78 |
+
|
| 79 |
+
res = []
|
| 80 |
+
|
| 81 |
+
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
| 82 |
+
|
| 83 |
+
for tag in [x for x in tags if x not in filtertags]:
|
| 84 |
+
probability = probability_dict[tag]
|
| 85 |
+
tag_outformat = tag
|
| 86 |
+
if use_spaces:
|
| 87 |
+
tag_outformat = tag_outformat.replace('_', ' ')
|
| 88 |
+
if use_escape:
|
| 89 |
+
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
| 90 |
+
if include_ranks:
|
| 91 |
+
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
| 92 |
+
|
| 93 |
+
res.append(tag_outformat)
|
| 94 |
+
|
| 95 |
+
return ", ".join(res)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
model = DeepDanbooru()
|
modules/deepbooru_model.py
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from modules import devices
|
| 6 |
+
|
| 7 |
+
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DeepDanbooruModel(nn.Module):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
super(DeepDanbooruModel, self).__init__()
|
| 13 |
+
|
| 14 |
+
self.tags = []
|
| 15 |
+
|
| 16 |
+
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
| 17 |
+
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
| 18 |
+
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
| 19 |
+
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
| 20 |
+
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
| 21 |
+
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
| 22 |
+
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
| 23 |
+
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
| 24 |
+
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
| 25 |
+
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
| 26 |
+
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
| 27 |
+
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
| 28 |
+
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
| 29 |
+
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
| 30 |
+
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
| 31 |
+
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 32 |
+
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 33 |
+
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 34 |
+
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 35 |
+
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 36 |
+
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 37 |
+
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 38 |
+
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 39 |
+
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 40 |
+
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 41 |
+
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 42 |
+
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 43 |
+
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 44 |
+
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 45 |
+
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 46 |
+
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 47 |
+
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 48 |
+
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 49 |
+
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 50 |
+
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
| 51 |
+
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
| 52 |
+
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
| 53 |
+
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
| 54 |
+
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
| 55 |
+
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
| 56 |
+
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 57 |
+
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 58 |
+
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 59 |
+
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 60 |
+
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 61 |
+
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 62 |
+
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 63 |
+
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 64 |
+
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 65 |
+
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 66 |
+
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 67 |
+
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 68 |
+
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 69 |
+
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 70 |
+
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 71 |
+
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 72 |
+
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 73 |
+
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 74 |
+
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 75 |
+
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 76 |
+
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 77 |
+
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 78 |
+
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 79 |
+
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 80 |
+
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 81 |
+
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 82 |
+
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 83 |
+
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 84 |
+
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 85 |
+
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 86 |
+
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 87 |
+
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 88 |
+
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 89 |
+
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 90 |
+
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 91 |
+
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 92 |
+
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 93 |
+
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 94 |
+
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 95 |
+
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 96 |
+
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 97 |
+
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 98 |
+
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 99 |
+
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 100 |
+
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 101 |
+
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 102 |
+
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 103 |
+
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 104 |
+
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 105 |
+
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 106 |
+
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 107 |
+
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 108 |
+
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 109 |
+
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 110 |
+
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 111 |
+
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 112 |
+
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 113 |
+
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 114 |
+
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 115 |
+
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
| 116 |
+
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 117 |
+
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
| 118 |
+
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 119 |
+
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 120 |
+
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 121 |
+
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 122 |
+
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 123 |
+
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 124 |
+
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 125 |
+
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 126 |
+
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 127 |
+
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 128 |
+
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 129 |
+
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 130 |
+
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 131 |
+
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 132 |
+
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 133 |
+
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 134 |
+
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 135 |
+
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 136 |
+
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 137 |
+
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 138 |
+
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 139 |
+
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 140 |
+
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 141 |
+
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 142 |
+
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 143 |
+
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 144 |
+
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 145 |
+
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 146 |
+
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 147 |
+
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 148 |
+
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 149 |
+
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 150 |
+
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 151 |
+
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 152 |
+
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 153 |
+
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 154 |
+
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 155 |
+
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 156 |
+
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 157 |
+
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 158 |
+
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 159 |
+
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 160 |
+
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 161 |
+
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 162 |
+
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 163 |
+
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 164 |
+
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 165 |
+
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 166 |
+
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 167 |
+
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 168 |
+
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 169 |
+
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 170 |
+
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 171 |
+
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 172 |
+
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
| 173 |
+
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
| 174 |
+
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
| 175 |
+
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
| 176 |
+
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
| 177 |
+
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
| 178 |
+
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
| 179 |
+
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
| 180 |
+
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
| 181 |
+
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
| 182 |
+
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
| 183 |
+
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
| 184 |
+
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
| 185 |
+
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
| 186 |
+
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
| 187 |
+
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
| 188 |
+
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
| 189 |
+
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
| 190 |
+
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
| 191 |
+
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
| 192 |
+
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
| 193 |
+
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
| 194 |
+
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
| 195 |
+
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
| 196 |
+
|
| 197 |
+
def forward(self, *inputs):
|
| 198 |
+
t_358, = inputs
|
| 199 |
+
t_359 = t_358.permute(*[0, 3, 1, 2])
|
| 200 |
+
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
| 201 |
+
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
| 202 |
+
t_361 = F.relu(t_360)
|
| 203 |
+
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
| 204 |
+
t_362 = self.n_MaxPool_0(t_361)
|
| 205 |
+
t_363 = self.n_Conv_1(t_362)
|
| 206 |
+
t_364 = self.n_Conv_2(t_362)
|
| 207 |
+
t_365 = F.relu(t_364)
|
| 208 |
+
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
| 209 |
+
t_366 = self.n_Conv_3(t_365_padded)
|
| 210 |
+
t_367 = F.relu(t_366)
|
| 211 |
+
t_368 = self.n_Conv_4(t_367)
|
| 212 |
+
t_369 = torch.add(t_368, t_363)
|
| 213 |
+
t_370 = F.relu(t_369)
|
| 214 |
+
t_371 = self.n_Conv_5(t_370)
|
| 215 |
+
t_372 = F.relu(t_371)
|
| 216 |
+
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
| 217 |
+
t_373 = self.n_Conv_6(t_372_padded)
|
| 218 |
+
t_374 = F.relu(t_373)
|
| 219 |
+
t_375 = self.n_Conv_7(t_374)
|
| 220 |
+
t_376 = torch.add(t_375, t_370)
|
| 221 |
+
t_377 = F.relu(t_376)
|
| 222 |
+
t_378 = self.n_Conv_8(t_377)
|
| 223 |
+
t_379 = F.relu(t_378)
|
| 224 |
+
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
| 225 |
+
t_380 = self.n_Conv_9(t_379_padded)
|
| 226 |
+
t_381 = F.relu(t_380)
|
| 227 |
+
t_382 = self.n_Conv_10(t_381)
|
| 228 |
+
t_383 = torch.add(t_382, t_377)
|
| 229 |
+
t_384 = F.relu(t_383)
|
| 230 |
+
t_385 = self.n_Conv_11(t_384)
|
| 231 |
+
t_386 = self.n_Conv_12(t_384)
|
| 232 |
+
t_387 = F.relu(t_386)
|
| 233 |
+
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
| 234 |
+
t_388 = self.n_Conv_13(t_387_padded)
|
| 235 |
+
t_389 = F.relu(t_388)
|
| 236 |
+
t_390 = self.n_Conv_14(t_389)
|
| 237 |
+
t_391 = torch.add(t_390, t_385)
|
| 238 |
+
t_392 = F.relu(t_391)
|
| 239 |
+
t_393 = self.n_Conv_15(t_392)
|
| 240 |
+
t_394 = F.relu(t_393)
|
| 241 |
+
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
| 242 |
+
t_395 = self.n_Conv_16(t_394_padded)
|
| 243 |
+
t_396 = F.relu(t_395)
|
| 244 |
+
t_397 = self.n_Conv_17(t_396)
|
| 245 |
+
t_398 = torch.add(t_397, t_392)
|
| 246 |
+
t_399 = F.relu(t_398)
|
| 247 |
+
t_400 = self.n_Conv_18(t_399)
|
| 248 |
+
t_401 = F.relu(t_400)
|
| 249 |
+
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
| 250 |
+
t_402 = self.n_Conv_19(t_401_padded)
|
| 251 |
+
t_403 = F.relu(t_402)
|
| 252 |
+
t_404 = self.n_Conv_20(t_403)
|
| 253 |
+
t_405 = torch.add(t_404, t_399)
|
| 254 |
+
t_406 = F.relu(t_405)
|
| 255 |
+
t_407 = self.n_Conv_21(t_406)
|
| 256 |
+
t_408 = F.relu(t_407)
|
| 257 |
+
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
| 258 |
+
t_409 = self.n_Conv_22(t_408_padded)
|
| 259 |
+
t_410 = F.relu(t_409)
|
| 260 |
+
t_411 = self.n_Conv_23(t_410)
|
| 261 |
+
t_412 = torch.add(t_411, t_406)
|
| 262 |
+
t_413 = F.relu(t_412)
|
| 263 |
+
t_414 = self.n_Conv_24(t_413)
|
| 264 |
+
t_415 = F.relu(t_414)
|
| 265 |
+
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
| 266 |
+
t_416 = self.n_Conv_25(t_415_padded)
|
| 267 |
+
t_417 = F.relu(t_416)
|
| 268 |
+
t_418 = self.n_Conv_26(t_417)
|
| 269 |
+
t_419 = torch.add(t_418, t_413)
|
| 270 |
+
t_420 = F.relu(t_419)
|
| 271 |
+
t_421 = self.n_Conv_27(t_420)
|
| 272 |
+
t_422 = F.relu(t_421)
|
| 273 |
+
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
| 274 |
+
t_423 = self.n_Conv_28(t_422_padded)
|
| 275 |
+
t_424 = F.relu(t_423)
|
| 276 |
+
t_425 = self.n_Conv_29(t_424)
|
| 277 |
+
t_426 = torch.add(t_425, t_420)
|
| 278 |
+
t_427 = F.relu(t_426)
|
| 279 |
+
t_428 = self.n_Conv_30(t_427)
|
| 280 |
+
t_429 = F.relu(t_428)
|
| 281 |
+
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
| 282 |
+
t_430 = self.n_Conv_31(t_429_padded)
|
| 283 |
+
t_431 = F.relu(t_430)
|
| 284 |
+
t_432 = self.n_Conv_32(t_431)
|
| 285 |
+
t_433 = torch.add(t_432, t_427)
|
| 286 |
+
t_434 = F.relu(t_433)
|
| 287 |
+
t_435 = self.n_Conv_33(t_434)
|
| 288 |
+
t_436 = F.relu(t_435)
|
| 289 |
+
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
| 290 |
+
t_437 = self.n_Conv_34(t_436_padded)
|
| 291 |
+
t_438 = F.relu(t_437)
|
| 292 |
+
t_439 = self.n_Conv_35(t_438)
|
| 293 |
+
t_440 = torch.add(t_439, t_434)
|
| 294 |
+
t_441 = F.relu(t_440)
|
| 295 |
+
t_442 = self.n_Conv_36(t_441)
|
| 296 |
+
t_443 = self.n_Conv_37(t_441)
|
| 297 |
+
t_444 = F.relu(t_443)
|
| 298 |
+
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
| 299 |
+
t_445 = self.n_Conv_38(t_444_padded)
|
| 300 |
+
t_446 = F.relu(t_445)
|
| 301 |
+
t_447 = self.n_Conv_39(t_446)
|
| 302 |
+
t_448 = torch.add(t_447, t_442)
|
| 303 |
+
t_449 = F.relu(t_448)
|
| 304 |
+
t_450 = self.n_Conv_40(t_449)
|
| 305 |
+
t_451 = F.relu(t_450)
|
| 306 |
+
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
| 307 |
+
t_452 = self.n_Conv_41(t_451_padded)
|
| 308 |
+
t_453 = F.relu(t_452)
|
| 309 |
+
t_454 = self.n_Conv_42(t_453)
|
| 310 |
+
t_455 = torch.add(t_454, t_449)
|
| 311 |
+
t_456 = F.relu(t_455)
|
| 312 |
+
t_457 = self.n_Conv_43(t_456)
|
| 313 |
+
t_458 = F.relu(t_457)
|
| 314 |
+
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
| 315 |
+
t_459 = self.n_Conv_44(t_458_padded)
|
| 316 |
+
t_460 = F.relu(t_459)
|
| 317 |
+
t_461 = self.n_Conv_45(t_460)
|
| 318 |
+
t_462 = torch.add(t_461, t_456)
|
| 319 |
+
t_463 = F.relu(t_462)
|
| 320 |
+
t_464 = self.n_Conv_46(t_463)
|
| 321 |
+
t_465 = F.relu(t_464)
|
| 322 |
+
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
| 323 |
+
t_466 = self.n_Conv_47(t_465_padded)
|
| 324 |
+
t_467 = F.relu(t_466)
|
| 325 |
+
t_468 = self.n_Conv_48(t_467)
|
| 326 |
+
t_469 = torch.add(t_468, t_463)
|
| 327 |
+
t_470 = F.relu(t_469)
|
| 328 |
+
t_471 = self.n_Conv_49(t_470)
|
| 329 |
+
t_472 = F.relu(t_471)
|
| 330 |
+
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
| 331 |
+
t_473 = self.n_Conv_50(t_472_padded)
|
| 332 |
+
t_474 = F.relu(t_473)
|
| 333 |
+
t_475 = self.n_Conv_51(t_474)
|
| 334 |
+
t_476 = torch.add(t_475, t_470)
|
| 335 |
+
t_477 = F.relu(t_476)
|
| 336 |
+
t_478 = self.n_Conv_52(t_477)
|
| 337 |
+
t_479 = F.relu(t_478)
|
| 338 |
+
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
| 339 |
+
t_480 = self.n_Conv_53(t_479_padded)
|
| 340 |
+
t_481 = F.relu(t_480)
|
| 341 |
+
t_482 = self.n_Conv_54(t_481)
|
| 342 |
+
t_483 = torch.add(t_482, t_477)
|
| 343 |
+
t_484 = F.relu(t_483)
|
| 344 |
+
t_485 = self.n_Conv_55(t_484)
|
| 345 |
+
t_486 = F.relu(t_485)
|
| 346 |
+
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
| 347 |
+
t_487 = self.n_Conv_56(t_486_padded)
|
| 348 |
+
t_488 = F.relu(t_487)
|
| 349 |
+
t_489 = self.n_Conv_57(t_488)
|
| 350 |
+
t_490 = torch.add(t_489, t_484)
|
| 351 |
+
t_491 = F.relu(t_490)
|
| 352 |
+
t_492 = self.n_Conv_58(t_491)
|
| 353 |
+
t_493 = F.relu(t_492)
|
| 354 |
+
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
| 355 |
+
t_494 = self.n_Conv_59(t_493_padded)
|
| 356 |
+
t_495 = F.relu(t_494)
|
| 357 |
+
t_496 = self.n_Conv_60(t_495)
|
| 358 |
+
t_497 = torch.add(t_496, t_491)
|
| 359 |
+
t_498 = F.relu(t_497)
|
| 360 |
+
t_499 = self.n_Conv_61(t_498)
|
| 361 |
+
t_500 = F.relu(t_499)
|
| 362 |
+
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
| 363 |
+
t_501 = self.n_Conv_62(t_500_padded)
|
| 364 |
+
t_502 = F.relu(t_501)
|
| 365 |
+
t_503 = self.n_Conv_63(t_502)
|
| 366 |
+
t_504 = torch.add(t_503, t_498)
|
| 367 |
+
t_505 = F.relu(t_504)
|
| 368 |
+
t_506 = self.n_Conv_64(t_505)
|
| 369 |
+
t_507 = F.relu(t_506)
|
| 370 |
+
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
| 371 |
+
t_508 = self.n_Conv_65(t_507_padded)
|
| 372 |
+
t_509 = F.relu(t_508)
|
| 373 |
+
t_510 = self.n_Conv_66(t_509)
|
| 374 |
+
t_511 = torch.add(t_510, t_505)
|
| 375 |
+
t_512 = F.relu(t_511)
|
| 376 |
+
t_513 = self.n_Conv_67(t_512)
|
| 377 |
+
t_514 = F.relu(t_513)
|
| 378 |
+
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
| 379 |
+
t_515 = self.n_Conv_68(t_514_padded)
|
| 380 |
+
t_516 = F.relu(t_515)
|
| 381 |
+
t_517 = self.n_Conv_69(t_516)
|
| 382 |
+
t_518 = torch.add(t_517, t_512)
|
| 383 |
+
t_519 = F.relu(t_518)
|
| 384 |
+
t_520 = self.n_Conv_70(t_519)
|
| 385 |
+
t_521 = F.relu(t_520)
|
| 386 |
+
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
| 387 |
+
t_522 = self.n_Conv_71(t_521_padded)
|
| 388 |
+
t_523 = F.relu(t_522)
|
| 389 |
+
t_524 = self.n_Conv_72(t_523)
|
| 390 |
+
t_525 = torch.add(t_524, t_519)
|
| 391 |
+
t_526 = F.relu(t_525)
|
| 392 |
+
t_527 = self.n_Conv_73(t_526)
|
| 393 |
+
t_528 = F.relu(t_527)
|
| 394 |
+
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
| 395 |
+
t_529 = self.n_Conv_74(t_528_padded)
|
| 396 |
+
t_530 = F.relu(t_529)
|
| 397 |
+
t_531 = self.n_Conv_75(t_530)
|
| 398 |
+
t_532 = torch.add(t_531, t_526)
|
| 399 |
+
t_533 = F.relu(t_532)
|
| 400 |
+
t_534 = self.n_Conv_76(t_533)
|
| 401 |
+
t_535 = F.relu(t_534)
|
| 402 |
+
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
| 403 |
+
t_536 = self.n_Conv_77(t_535_padded)
|
| 404 |
+
t_537 = F.relu(t_536)
|
| 405 |
+
t_538 = self.n_Conv_78(t_537)
|
| 406 |
+
t_539 = torch.add(t_538, t_533)
|
| 407 |
+
t_540 = F.relu(t_539)
|
| 408 |
+
t_541 = self.n_Conv_79(t_540)
|
| 409 |
+
t_542 = F.relu(t_541)
|
| 410 |
+
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
| 411 |
+
t_543 = self.n_Conv_80(t_542_padded)
|
| 412 |
+
t_544 = F.relu(t_543)
|
| 413 |
+
t_545 = self.n_Conv_81(t_544)
|
| 414 |
+
t_546 = torch.add(t_545, t_540)
|
| 415 |
+
t_547 = F.relu(t_546)
|
| 416 |
+
t_548 = self.n_Conv_82(t_547)
|
| 417 |
+
t_549 = F.relu(t_548)
|
| 418 |
+
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
| 419 |
+
t_550 = self.n_Conv_83(t_549_padded)
|
| 420 |
+
t_551 = F.relu(t_550)
|
| 421 |
+
t_552 = self.n_Conv_84(t_551)
|
| 422 |
+
t_553 = torch.add(t_552, t_547)
|
| 423 |
+
t_554 = F.relu(t_553)
|
| 424 |
+
t_555 = self.n_Conv_85(t_554)
|
| 425 |
+
t_556 = F.relu(t_555)
|
| 426 |
+
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
| 427 |
+
t_557 = self.n_Conv_86(t_556_padded)
|
| 428 |
+
t_558 = F.relu(t_557)
|
| 429 |
+
t_559 = self.n_Conv_87(t_558)
|
| 430 |
+
t_560 = torch.add(t_559, t_554)
|
| 431 |
+
t_561 = F.relu(t_560)
|
| 432 |
+
t_562 = self.n_Conv_88(t_561)
|
| 433 |
+
t_563 = F.relu(t_562)
|
| 434 |
+
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
| 435 |
+
t_564 = self.n_Conv_89(t_563_padded)
|
| 436 |
+
t_565 = F.relu(t_564)
|
| 437 |
+
t_566 = self.n_Conv_90(t_565)
|
| 438 |
+
t_567 = torch.add(t_566, t_561)
|
| 439 |
+
t_568 = F.relu(t_567)
|
| 440 |
+
t_569 = self.n_Conv_91(t_568)
|
| 441 |
+
t_570 = F.relu(t_569)
|
| 442 |
+
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
| 443 |
+
t_571 = self.n_Conv_92(t_570_padded)
|
| 444 |
+
t_572 = F.relu(t_571)
|
| 445 |
+
t_573 = self.n_Conv_93(t_572)
|
| 446 |
+
t_574 = torch.add(t_573, t_568)
|
| 447 |
+
t_575 = F.relu(t_574)
|
| 448 |
+
t_576 = self.n_Conv_94(t_575)
|
| 449 |
+
t_577 = F.relu(t_576)
|
| 450 |
+
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
| 451 |
+
t_578 = self.n_Conv_95(t_577_padded)
|
| 452 |
+
t_579 = F.relu(t_578)
|
| 453 |
+
t_580 = self.n_Conv_96(t_579)
|
| 454 |
+
t_581 = torch.add(t_580, t_575)
|
| 455 |
+
t_582 = F.relu(t_581)
|
| 456 |
+
t_583 = self.n_Conv_97(t_582)
|
| 457 |
+
t_584 = F.relu(t_583)
|
| 458 |
+
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
| 459 |
+
t_585 = self.n_Conv_98(t_584_padded)
|
| 460 |
+
t_586 = F.relu(t_585)
|
| 461 |
+
t_587 = self.n_Conv_99(t_586)
|
| 462 |
+
t_588 = self.n_Conv_100(t_582)
|
| 463 |
+
t_589 = torch.add(t_587, t_588)
|
| 464 |
+
t_590 = F.relu(t_589)
|
| 465 |
+
t_591 = self.n_Conv_101(t_590)
|
| 466 |
+
t_592 = F.relu(t_591)
|
| 467 |
+
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
| 468 |
+
t_593 = self.n_Conv_102(t_592_padded)
|
| 469 |
+
t_594 = F.relu(t_593)
|
| 470 |
+
t_595 = self.n_Conv_103(t_594)
|
| 471 |
+
t_596 = torch.add(t_595, t_590)
|
| 472 |
+
t_597 = F.relu(t_596)
|
| 473 |
+
t_598 = self.n_Conv_104(t_597)
|
| 474 |
+
t_599 = F.relu(t_598)
|
| 475 |
+
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
| 476 |
+
t_600 = self.n_Conv_105(t_599_padded)
|
| 477 |
+
t_601 = F.relu(t_600)
|
| 478 |
+
t_602 = self.n_Conv_106(t_601)
|
| 479 |
+
t_603 = torch.add(t_602, t_597)
|
| 480 |
+
t_604 = F.relu(t_603)
|
| 481 |
+
t_605 = self.n_Conv_107(t_604)
|
| 482 |
+
t_606 = F.relu(t_605)
|
| 483 |
+
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
| 484 |
+
t_607 = self.n_Conv_108(t_606_padded)
|
| 485 |
+
t_608 = F.relu(t_607)
|
| 486 |
+
t_609 = self.n_Conv_109(t_608)
|
| 487 |
+
t_610 = torch.add(t_609, t_604)
|
| 488 |
+
t_611 = F.relu(t_610)
|
| 489 |
+
t_612 = self.n_Conv_110(t_611)
|
| 490 |
+
t_613 = F.relu(t_612)
|
| 491 |
+
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
| 492 |
+
t_614 = self.n_Conv_111(t_613_padded)
|
| 493 |
+
t_615 = F.relu(t_614)
|
| 494 |
+
t_616 = self.n_Conv_112(t_615)
|
| 495 |
+
t_617 = torch.add(t_616, t_611)
|
| 496 |
+
t_618 = F.relu(t_617)
|
| 497 |
+
t_619 = self.n_Conv_113(t_618)
|
| 498 |
+
t_620 = F.relu(t_619)
|
| 499 |
+
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
| 500 |
+
t_621 = self.n_Conv_114(t_620_padded)
|
| 501 |
+
t_622 = F.relu(t_621)
|
| 502 |
+
t_623 = self.n_Conv_115(t_622)
|
| 503 |
+
t_624 = torch.add(t_623, t_618)
|
| 504 |
+
t_625 = F.relu(t_624)
|
| 505 |
+
t_626 = self.n_Conv_116(t_625)
|
| 506 |
+
t_627 = F.relu(t_626)
|
| 507 |
+
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
| 508 |
+
t_628 = self.n_Conv_117(t_627_padded)
|
| 509 |
+
t_629 = F.relu(t_628)
|
| 510 |
+
t_630 = self.n_Conv_118(t_629)
|
| 511 |
+
t_631 = torch.add(t_630, t_625)
|
| 512 |
+
t_632 = F.relu(t_631)
|
| 513 |
+
t_633 = self.n_Conv_119(t_632)
|
| 514 |
+
t_634 = F.relu(t_633)
|
| 515 |
+
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
| 516 |
+
t_635 = self.n_Conv_120(t_634_padded)
|
| 517 |
+
t_636 = F.relu(t_635)
|
| 518 |
+
t_637 = self.n_Conv_121(t_636)
|
| 519 |
+
t_638 = torch.add(t_637, t_632)
|
| 520 |
+
t_639 = F.relu(t_638)
|
| 521 |
+
t_640 = self.n_Conv_122(t_639)
|
| 522 |
+
t_641 = F.relu(t_640)
|
| 523 |
+
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
| 524 |
+
t_642 = self.n_Conv_123(t_641_padded)
|
| 525 |
+
t_643 = F.relu(t_642)
|
| 526 |
+
t_644 = self.n_Conv_124(t_643)
|
| 527 |
+
t_645 = torch.add(t_644, t_639)
|
| 528 |
+
t_646 = F.relu(t_645)
|
| 529 |
+
t_647 = self.n_Conv_125(t_646)
|
| 530 |
+
t_648 = F.relu(t_647)
|
| 531 |
+
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
| 532 |
+
t_649 = self.n_Conv_126(t_648_padded)
|
| 533 |
+
t_650 = F.relu(t_649)
|
| 534 |
+
t_651 = self.n_Conv_127(t_650)
|
| 535 |
+
t_652 = torch.add(t_651, t_646)
|
| 536 |
+
t_653 = F.relu(t_652)
|
| 537 |
+
t_654 = self.n_Conv_128(t_653)
|
| 538 |
+
t_655 = F.relu(t_654)
|
| 539 |
+
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
| 540 |
+
t_656 = self.n_Conv_129(t_655_padded)
|
| 541 |
+
t_657 = F.relu(t_656)
|
| 542 |
+
t_658 = self.n_Conv_130(t_657)
|
| 543 |
+
t_659 = torch.add(t_658, t_653)
|
| 544 |
+
t_660 = F.relu(t_659)
|
| 545 |
+
t_661 = self.n_Conv_131(t_660)
|
| 546 |
+
t_662 = F.relu(t_661)
|
| 547 |
+
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
| 548 |
+
t_663 = self.n_Conv_132(t_662_padded)
|
| 549 |
+
t_664 = F.relu(t_663)
|
| 550 |
+
t_665 = self.n_Conv_133(t_664)
|
| 551 |
+
t_666 = torch.add(t_665, t_660)
|
| 552 |
+
t_667 = F.relu(t_666)
|
| 553 |
+
t_668 = self.n_Conv_134(t_667)
|
| 554 |
+
t_669 = F.relu(t_668)
|
| 555 |
+
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
| 556 |
+
t_670 = self.n_Conv_135(t_669_padded)
|
| 557 |
+
t_671 = F.relu(t_670)
|
| 558 |
+
t_672 = self.n_Conv_136(t_671)
|
| 559 |
+
t_673 = torch.add(t_672, t_667)
|
| 560 |
+
t_674 = F.relu(t_673)
|
| 561 |
+
t_675 = self.n_Conv_137(t_674)
|
| 562 |
+
t_676 = F.relu(t_675)
|
| 563 |
+
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
| 564 |
+
t_677 = self.n_Conv_138(t_676_padded)
|
| 565 |
+
t_678 = F.relu(t_677)
|
| 566 |
+
t_679 = self.n_Conv_139(t_678)
|
| 567 |
+
t_680 = torch.add(t_679, t_674)
|
| 568 |
+
t_681 = F.relu(t_680)
|
| 569 |
+
t_682 = self.n_Conv_140(t_681)
|
| 570 |
+
t_683 = F.relu(t_682)
|
| 571 |
+
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
| 572 |
+
t_684 = self.n_Conv_141(t_683_padded)
|
| 573 |
+
t_685 = F.relu(t_684)
|
| 574 |
+
t_686 = self.n_Conv_142(t_685)
|
| 575 |
+
t_687 = torch.add(t_686, t_681)
|
| 576 |
+
t_688 = F.relu(t_687)
|
| 577 |
+
t_689 = self.n_Conv_143(t_688)
|
| 578 |
+
t_690 = F.relu(t_689)
|
| 579 |
+
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
| 580 |
+
t_691 = self.n_Conv_144(t_690_padded)
|
| 581 |
+
t_692 = F.relu(t_691)
|
| 582 |
+
t_693 = self.n_Conv_145(t_692)
|
| 583 |
+
t_694 = torch.add(t_693, t_688)
|
| 584 |
+
t_695 = F.relu(t_694)
|
| 585 |
+
t_696 = self.n_Conv_146(t_695)
|
| 586 |
+
t_697 = F.relu(t_696)
|
| 587 |
+
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
| 588 |
+
t_698 = self.n_Conv_147(t_697_padded)
|
| 589 |
+
t_699 = F.relu(t_698)
|
| 590 |
+
t_700 = self.n_Conv_148(t_699)
|
| 591 |
+
t_701 = torch.add(t_700, t_695)
|
| 592 |
+
t_702 = F.relu(t_701)
|
| 593 |
+
t_703 = self.n_Conv_149(t_702)
|
| 594 |
+
t_704 = F.relu(t_703)
|
| 595 |
+
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
| 596 |
+
t_705 = self.n_Conv_150(t_704_padded)
|
| 597 |
+
t_706 = F.relu(t_705)
|
| 598 |
+
t_707 = self.n_Conv_151(t_706)
|
| 599 |
+
t_708 = torch.add(t_707, t_702)
|
| 600 |
+
t_709 = F.relu(t_708)
|
| 601 |
+
t_710 = self.n_Conv_152(t_709)
|
| 602 |
+
t_711 = F.relu(t_710)
|
| 603 |
+
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
| 604 |
+
t_712 = self.n_Conv_153(t_711_padded)
|
| 605 |
+
t_713 = F.relu(t_712)
|
| 606 |
+
t_714 = self.n_Conv_154(t_713)
|
| 607 |
+
t_715 = torch.add(t_714, t_709)
|
| 608 |
+
t_716 = F.relu(t_715)
|
| 609 |
+
t_717 = self.n_Conv_155(t_716)
|
| 610 |
+
t_718 = F.relu(t_717)
|
| 611 |
+
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
| 612 |
+
t_719 = self.n_Conv_156(t_718_padded)
|
| 613 |
+
t_720 = F.relu(t_719)
|
| 614 |
+
t_721 = self.n_Conv_157(t_720)
|
| 615 |
+
t_722 = torch.add(t_721, t_716)
|
| 616 |
+
t_723 = F.relu(t_722)
|
| 617 |
+
t_724 = self.n_Conv_158(t_723)
|
| 618 |
+
t_725 = self.n_Conv_159(t_723)
|
| 619 |
+
t_726 = F.relu(t_725)
|
| 620 |
+
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
| 621 |
+
t_727 = self.n_Conv_160(t_726_padded)
|
| 622 |
+
t_728 = F.relu(t_727)
|
| 623 |
+
t_729 = self.n_Conv_161(t_728)
|
| 624 |
+
t_730 = torch.add(t_729, t_724)
|
| 625 |
+
t_731 = F.relu(t_730)
|
| 626 |
+
t_732 = self.n_Conv_162(t_731)
|
| 627 |
+
t_733 = F.relu(t_732)
|
| 628 |
+
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
| 629 |
+
t_734 = self.n_Conv_163(t_733_padded)
|
| 630 |
+
t_735 = F.relu(t_734)
|
| 631 |
+
t_736 = self.n_Conv_164(t_735)
|
| 632 |
+
t_737 = torch.add(t_736, t_731)
|
| 633 |
+
t_738 = F.relu(t_737)
|
| 634 |
+
t_739 = self.n_Conv_165(t_738)
|
| 635 |
+
t_740 = F.relu(t_739)
|
| 636 |
+
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
| 637 |
+
t_741 = self.n_Conv_166(t_740_padded)
|
| 638 |
+
t_742 = F.relu(t_741)
|
| 639 |
+
t_743 = self.n_Conv_167(t_742)
|
| 640 |
+
t_744 = torch.add(t_743, t_738)
|
| 641 |
+
t_745 = F.relu(t_744)
|
| 642 |
+
t_746 = self.n_Conv_168(t_745)
|
| 643 |
+
t_747 = self.n_Conv_169(t_745)
|
| 644 |
+
t_748 = F.relu(t_747)
|
| 645 |
+
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
| 646 |
+
t_749 = self.n_Conv_170(t_748_padded)
|
| 647 |
+
t_750 = F.relu(t_749)
|
| 648 |
+
t_751 = self.n_Conv_171(t_750)
|
| 649 |
+
t_752 = torch.add(t_751, t_746)
|
| 650 |
+
t_753 = F.relu(t_752)
|
| 651 |
+
t_754 = self.n_Conv_172(t_753)
|
| 652 |
+
t_755 = F.relu(t_754)
|
| 653 |
+
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
| 654 |
+
t_756 = self.n_Conv_173(t_755_padded)
|
| 655 |
+
t_757 = F.relu(t_756)
|
| 656 |
+
t_758 = self.n_Conv_174(t_757)
|
| 657 |
+
t_759 = torch.add(t_758, t_753)
|
| 658 |
+
t_760 = F.relu(t_759)
|
| 659 |
+
t_761 = self.n_Conv_175(t_760)
|
| 660 |
+
t_762 = F.relu(t_761)
|
| 661 |
+
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
| 662 |
+
t_763 = self.n_Conv_176(t_762_padded)
|
| 663 |
+
t_764 = F.relu(t_763)
|
| 664 |
+
t_765 = self.n_Conv_177(t_764)
|
| 665 |
+
t_766 = torch.add(t_765, t_760)
|
| 666 |
+
t_767 = F.relu(t_766)
|
| 667 |
+
t_768 = self.n_Conv_178(t_767)
|
| 668 |
+
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
| 669 |
+
t_770 = torch.squeeze(t_769, 3)
|
| 670 |
+
t_770 = torch.squeeze(t_770, 2)
|
| 671 |
+
t_771 = torch.sigmoid(t_770)
|
| 672 |
+
return t_771
|
| 673 |
+
|
| 674 |
+
def load_state_dict(self, state_dict, **kwargs):
|
| 675 |
+
self.tags = state_dict.get('tags', [])
|
| 676 |
+
|
| 677 |
+
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
| 678 |
+
|
modules/devices.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import contextlib
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from modules import errors, shared, npu_specific
|
| 7 |
+
|
| 8 |
+
if sys.platform == "darwin":
|
| 9 |
+
from modules import mac_specific
|
| 10 |
+
|
| 11 |
+
if shared.cmd_opts.use_ipex:
|
| 12 |
+
from modules import xpu_specific
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def has_xpu() -> bool:
|
| 16 |
+
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def has_mps() -> bool:
|
| 20 |
+
if sys.platform != "darwin":
|
| 21 |
+
return False
|
| 22 |
+
else:
|
| 23 |
+
return mac_specific.has_mps
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def cuda_no_autocast(device_id=None) -> bool:
|
| 27 |
+
if device_id is None:
|
| 28 |
+
device_id = get_cuda_device_id()
|
| 29 |
+
return (
|
| 30 |
+
torch.cuda.get_device_capability(device_id) == (7, 5)
|
| 31 |
+
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_cuda_device_id():
|
| 36 |
+
return (
|
| 37 |
+
int(shared.cmd_opts.device_id)
|
| 38 |
+
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
| 39 |
+
else 0
|
| 40 |
+
) or torch.cuda.current_device()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_cuda_device_string():
|
| 44 |
+
if shared.cmd_opts.device_id is not None:
|
| 45 |
+
return f"cuda:{shared.cmd_opts.device_id}"
|
| 46 |
+
|
| 47 |
+
return "cuda"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_optimal_device_name():
|
| 51 |
+
if torch.cuda.is_available():
|
| 52 |
+
return get_cuda_device_string()
|
| 53 |
+
|
| 54 |
+
if has_mps():
|
| 55 |
+
return "mps"
|
| 56 |
+
|
| 57 |
+
if has_xpu():
|
| 58 |
+
return xpu_specific.get_xpu_device_string()
|
| 59 |
+
|
| 60 |
+
if npu_specific.has_npu:
|
| 61 |
+
return npu_specific.get_npu_device_string()
|
| 62 |
+
|
| 63 |
+
return "cpu"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_optimal_device():
|
| 67 |
+
return torch.device(get_optimal_device_name())
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_device_for(task):
|
| 71 |
+
if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
|
| 72 |
+
return cpu
|
| 73 |
+
|
| 74 |
+
return get_optimal_device()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def torch_gc():
|
| 78 |
+
|
| 79 |
+
if torch.cuda.is_available():
|
| 80 |
+
with torch.cuda.device(get_cuda_device_string()):
|
| 81 |
+
torch.cuda.empty_cache()
|
| 82 |
+
torch.cuda.ipc_collect()
|
| 83 |
+
|
| 84 |
+
if has_mps():
|
| 85 |
+
mac_specific.torch_mps_gc()
|
| 86 |
+
|
| 87 |
+
if has_xpu():
|
| 88 |
+
xpu_specific.torch_xpu_gc()
|
| 89 |
+
|
| 90 |
+
if npu_specific.has_npu:
|
| 91 |
+
torch_npu_set_device()
|
| 92 |
+
npu_specific.torch_npu_gc()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def torch_npu_set_device():
|
| 96 |
+
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
| 97 |
+
if npu_specific.has_npu:
|
| 98 |
+
torch.npu.set_device(0)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def enable_tf32():
|
| 102 |
+
if torch.cuda.is_available():
|
| 103 |
+
|
| 104 |
+
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
| 105 |
+
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
| 106 |
+
if cuda_no_autocast():
|
| 107 |
+
torch.backends.cudnn.benchmark = True
|
| 108 |
+
|
| 109 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 110 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
errors.run(enable_tf32, "Enabling TF32")
|
| 114 |
+
|
| 115 |
+
cpu: torch.device = torch.device("cpu")
|
| 116 |
+
fp8: bool = False
|
| 117 |
+
# Force fp16 for all models in inference. No casting during inference.
|
| 118 |
+
# This flag is controlled by "--precision half" command line arg.
|
| 119 |
+
force_fp16: bool = False
|
| 120 |
+
device: torch.device = None
|
| 121 |
+
device_interrogate: torch.device = None
|
| 122 |
+
device_gfpgan: torch.device = None
|
| 123 |
+
device_esrgan: torch.device = None
|
| 124 |
+
device_codeformer: torch.device = None
|
| 125 |
+
dtype: torch.dtype = torch.float16
|
| 126 |
+
dtype_vae: torch.dtype = torch.float16
|
| 127 |
+
dtype_unet: torch.dtype = torch.float16
|
| 128 |
+
dtype_inference: torch.dtype = torch.float16
|
| 129 |
+
unet_needs_upcast = False
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def cond_cast_unet(input):
|
| 133 |
+
if force_fp16:
|
| 134 |
+
return input.to(torch.float16)
|
| 135 |
+
return input.to(dtype_unet) if unet_needs_upcast else input
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def cond_cast_float(input):
|
| 139 |
+
return input.float() if unet_needs_upcast else input
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
nv_rng = None
|
| 143 |
+
patch_module_list = [
|
| 144 |
+
torch.nn.Linear,
|
| 145 |
+
torch.nn.Conv2d,
|
| 146 |
+
torch.nn.MultiheadAttention,
|
| 147 |
+
torch.nn.GroupNorm,
|
| 148 |
+
torch.nn.LayerNorm,
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def manual_cast_forward(target_dtype):
|
| 153 |
+
def forward_wrapper(self, *args, **kwargs):
|
| 154 |
+
if any(
|
| 155 |
+
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
| 156 |
+
for arg in args
|
| 157 |
+
):
|
| 158 |
+
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
| 159 |
+
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
| 160 |
+
|
| 161 |
+
org_dtype = target_dtype
|
| 162 |
+
for param in self.parameters():
|
| 163 |
+
if param.dtype != target_dtype:
|
| 164 |
+
org_dtype = param.dtype
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
+
if org_dtype != target_dtype:
|
| 168 |
+
self.to(target_dtype)
|
| 169 |
+
result = self.org_forward(*args, **kwargs)
|
| 170 |
+
if org_dtype != target_dtype:
|
| 171 |
+
self.to(org_dtype)
|
| 172 |
+
|
| 173 |
+
if target_dtype != dtype_inference:
|
| 174 |
+
if isinstance(result, tuple):
|
| 175 |
+
result = tuple(
|
| 176 |
+
i.to(dtype_inference)
|
| 177 |
+
if isinstance(i, torch.Tensor)
|
| 178 |
+
else i
|
| 179 |
+
for i in result
|
| 180 |
+
)
|
| 181 |
+
elif isinstance(result, torch.Tensor):
|
| 182 |
+
result = result.to(dtype_inference)
|
| 183 |
+
return result
|
| 184 |
+
return forward_wrapper
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@contextlib.contextmanager
|
| 188 |
+
def manual_cast(target_dtype):
|
| 189 |
+
applied = False
|
| 190 |
+
for module_type in patch_module_list:
|
| 191 |
+
if hasattr(module_type, "org_forward"):
|
| 192 |
+
continue
|
| 193 |
+
applied = True
|
| 194 |
+
org_forward = module_type.forward
|
| 195 |
+
if module_type == torch.nn.MultiheadAttention:
|
| 196 |
+
module_type.forward = manual_cast_forward(torch.float32)
|
| 197 |
+
else:
|
| 198 |
+
module_type.forward = manual_cast_forward(target_dtype)
|
| 199 |
+
module_type.org_forward = org_forward
|
| 200 |
+
try:
|
| 201 |
+
yield None
|
| 202 |
+
finally:
|
| 203 |
+
if applied:
|
| 204 |
+
for module_type in patch_module_list:
|
| 205 |
+
if hasattr(module_type, "org_forward"):
|
| 206 |
+
module_type.forward = module_type.org_forward
|
| 207 |
+
delattr(module_type, "org_forward")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def autocast(disable=False):
|
| 211 |
+
if disable:
|
| 212 |
+
return contextlib.nullcontext()
|
| 213 |
+
|
| 214 |
+
if force_fp16:
|
| 215 |
+
# No casting during inference if force_fp16 is enabled.
|
| 216 |
+
# All tensor dtype conversion happens before inference.
|
| 217 |
+
return contextlib.nullcontext()
|
| 218 |
+
|
| 219 |
+
if fp8 and device==cpu:
|
| 220 |
+
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
| 221 |
+
|
| 222 |
+
if fp8 and dtype_inference == torch.float32:
|
| 223 |
+
return manual_cast(dtype)
|
| 224 |
+
|
| 225 |
+
if dtype == torch.float32 or dtype_inference == torch.float32:
|
| 226 |
+
return contextlib.nullcontext()
|
| 227 |
+
|
| 228 |
+
if has_xpu() or has_mps() or cuda_no_autocast():
|
| 229 |
+
return manual_cast(dtype)
|
| 230 |
+
|
| 231 |
+
return torch.autocast("cuda")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def without_autocast(disable=False):
|
| 235 |
+
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class NansException(Exception):
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def test_for_nans(x, where):
|
| 243 |
+
if shared.cmd_opts.disable_nan_check:
|
| 244 |
+
return
|
| 245 |
+
|
| 246 |
+
if not torch.isnan(x[(0, ) * len(x.shape)]):
|
| 247 |
+
return
|
| 248 |
+
|
| 249 |
+
if where == "unet":
|
| 250 |
+
message = "A tensor with NaNs was produced in Unet."
|
| 251 |
+
|
| 252 |
+
if not shared.cmd_opts.no_half:
|
| 253 |
+
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
| 254 |
+
|
| 255 |
+
elif where == "vae":
|
| 256 |
+
message = "A tensor with NaNs was produced in VAE."
|
| 257 |
+
|
| 258 |
+
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
|
| 259 |
+
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
| 260 |
+
else:
|
| 261 |
+
message = "A tensor with NaNs was produced."
|
| 262 |
+
|
| 263 |
+
message += " Use --disable-nan-check commandline argument to disable this check."
|
| 264 |
+
|
| 265 |
+
raise NansException(message)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@lru_cache
|
| 269 |
+
def first_time_calculation():
|
| 270 |
+
"""
|
| 271 |
+
just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
|
| 272 |
+
spends about 2.7 seconds doing that, at least with NVidia.
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
x = torch.zeros((1, 1)).to(device, dtype)
|
| 276 |
+
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
| 277 |
+
linear(x)
|
| 278 |
+
|
| 279 |
+
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
| 280 |
+
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
| 281 |
+
conv2d(x)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def force_model_fp16():
|
| 285 |
+
"""
|
| 286 |
+
ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
|
| 287 |
+
force conversion of input to float32. If force_fp16 is enabled, we need to
|
| 288 |
+
prevent this casting.
|
| 289 |
+
"""
|
| 290 |
+
assert force_fp16
|
| 291 |
+
import sgm.modules.diffusionmodules.util as sgm_util
|
| 292 |
+
import ldm.modules.diffusionmodules.util as ldm_util
|
| 293 |
+
sgm_util.GroupNorm32 = torch.nn.GroupNorm
|
| 294 |
+
ldm_util.GroupNorm32 = torch.nn.GroupNorm
|
| 295 |
+
print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")
|
modules/errors.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import textwrap
|
| 3 |
+
import traceback
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
exception_records = []
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def format_traceback(tb):
|
| 10 |
+
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def format_exception(e, tb):
|
| 14 |
+
return {"exception": str(e), "traceback": format_traceback(tb)}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_exceptions():
|
| 18 |
+
try:
|
| 19 |
+
return list(reversed(exception_records))
|
| 20 |
+
except Exception as e:
|
| 21 |
+
return str(e)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def record_exception():
|
| 25 |
+
_, e, tb = sys.exc_info()
|
| 26 |
+
if e is None:
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
if exception_records and exception_records[-1] == e:
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
exception_records.append(format_exception(e, tb))
|
| 33 |
+
|
| 34 |
+
if len(exception_records) > 5:
|
| 35 |
+
exception_records.pop(0)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def report(message: str, *, exc_info: bool = False) -> None:
|
| 39 |
+
"""
|
| 40 |
+
Print an error message to stderr, with optional traceback.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
record_exception()
|
| 44 |
+
|
| 45 |
+
for line in message.splitlines():
|
| 46 |
+
print("***", line, file=sys.stderr)
|
| 47 |
+
if exc_info:
|
| 48 |
+
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
|
| 49 |
+
print("---", file=sys.stderr)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def print_error_explanation(message):
|
| 53 |
+
record_exception()
|
| 54 |
+
|
| 55 |
+
lines = message.strip().split("\n")
|
| 56 |
+
max_len = max([len(x) for x in lines])
|
| 57 |
+
|
| 58 |
+
print('=' * max_len, file=sys.stderr)
|
| 59 |
+
for line in lines:
|
| 60 |
+
print(line, file=sys.stderr)
|
| 61 |
+
print('=' * max_len, file=sys.stderr)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def display(e: Exception, task, *, full_traceback=False):
|
| 65 |
+
record_exception()
|
| 66 |
+
|
| 67 |
+
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
| 68 |
+
te = traceback.TracebackException.from_exception(e)
|
| 69 |
+
if full_traceback:
|
| 70 |
+
# include frames leading up to the try-catch block
|
| 71 |
+
te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
|
| 72 |
+
print(*te.format(), sep="", file=sys.stderr)
|
| 73 |
+
|
| 74 |
+
message = str(e)
|
| 75 |
+
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
| 76 |
+
print_error_explanation("""
|
| 77 |
+
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
|
| 78 |
+
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
| 79 |
+
""")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
already_displayed = {}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def display_once(e: Exception, task):
|
| 86 |
+
record_exception()
|
| 87 |
+
|
| 88 |
+
if task in already_displayed:
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
display(e, task)
|
| 92 |
+
|
| 93 |
+
already_displayed[task] = 1
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def run(code, task):
|
| 97 |
+
try:
|
| 98 |
+
code()
|
| 99 |
+
except Exception as e:
|
| 100 |
+
display(task, e)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def check_versions():
|
| 104 |
+
from packaging import version
|
| 105 |
+
from modules import shared
|
| 106 |
+
|
| 107 |
+
import torch
|
| 108 |
+
import gradio
|
| 109 |
+
|
| 110 |
+
expected_torch_version = "2.1.2"
|
| 111 |
+
expected_xformers_version = "0.0.23.post1"
|
| 112 |
+
expected_gradio_version = "3.41.2"
|
| 113 |
+
|
| 114 |
+
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
| 115 |
+
print_error_explanation(f"""
|
| 116 |
+
You are running torch {torch.__version__}.
|
| 117 |
+
The program is tested to work with torch {expected_torch_version}.
|
| 118 |
+
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
| 119 |
+
Beware that this will cause a lot of large files to be downloaded, as well as
|
| 120 |
+
there are reports of issues with training tab on the latest version.
|
| 121 |
+
|
| 122 |
+
Use --skip-version-check commandline argument to disable this check.
|
| 123 |
+
""".strip())
|
| 124 |
+
|
| 125 |
+
if shared.xformers_available:
|
| 126 |
+
import xformers
|
| 127 |
+
|
| 128 |
+
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
| 129 |
+
print_error_explanation(f"""
|
| 130 |
+
You are running xformers {xformers.__version__}.
|
| 131 |
+
The program is tested to work with xformers {expected_xformers_version}.
|
| 132 |
+
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
| 133 |
+
|
| 134 |
+
Use --skip-version-check commandline argument to disable this check.
|
| 135 |
+
""".strip())
|
| 136 |
+
|
| 137 |
+
if gradio.__version__ != expected_gradio_version:
|
| 138 |
+
print_error_explanation(f"""
|
| 139 |
+
You are running gradio {gradio.__version__}.
|
| 140 |
+
The program is designed to work with gradio {expected_gradio_version}.
|
| 141 |
+
Using a different version of gradio is extremely likely to break the program.
|
| 142 |
+
|
| 143 |
+
Reasons why you have the mismatched gradio version can be:
|
| 144 |
+
- you use --skip-install flag.
|
| 145 |
+
- you use webui.py to start the program instead of launch.py.
|
| 146 |
+
- an extension installs the incompatible gradio version.
|
| 147 |
+
|
| 148 |
+
Use --skip-version-check commandline argument to disable this check.
|
| 149 |
+
""".strip())
|
| 150 |
+
|
modules/esrgan_model.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import modelloader, devices, errors
|
| 2 |
+
from modules.shared import opts
|
| 3 |
+
from modules.upscaler import Upscaler, UpscalerData
|
| 4 |
+
from modules.upscaler_utils import upscale_with_model
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class UpscalerESRGAN(Upscaler):
|
| 8 |
+
def __init__(self, dirname):
|
| 9 |
+
self.name = "ESRGAN"
|
| 10 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
| 11 |
+
self.model_name = "ESRGAN_4x"
|
| 12 |
+
self.scalers = []
|
| 13 |
+
self.user_path = dirname
|
| 14 |
+
super().__init__()
|
| 15 |
+
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
| 16 |
+
scalers = []
|
| 17 |
+
if len(model_paths) == 0:
|
| 18 |
+
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
| 19 |
+
scalers.append(scaler_data)
|
| 20 |
+
for file in model_paths:
|
| 21 |
+
if file.startswith("http"):
|
| 22 |
+
name = self.model_name
|
| 23 |
+
else:
|
| 24 |
+
name = modelloader.friendly_name(file)
|
| 25 |
+
|
| 26 |
+
scaler_data = UpscalerData(name, file, self, 4)
|
| 27 |
+
self.scalers.append(scaler_data)
|
| 28 |
+
|
| 29 |
+
def do_upscale(self, img, selected_model):
|
| 30 |
+
try:
|
| 31 |
+
model = self.load_model(selected_model)
|
| 32 |
+
except Exception:
|
| 33 |
+
errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
|
| 34 |
+
return img
|
| 35 |
+
model.to(devices.device_esrgan)
|
| 36 |
+
return esrgan_upscale(model, img)
|
| 37 |
+
|
| 38 |
+
def load_model(self, path: str):
|
| 39 |
+
if path.startswith("http"):
|
| 40 |
+
# TODO: this doesn't use `path` at all?
|
| 41 |
+
filename = modelloader.load_file_from_url(
|
| 42 |
+
url=self.model_url,
|
| 43 |
+
model_dir=self.model_download_path,
|
| 44 |
+
file_name=f"{self.model_name}.pth",
|
| 45 |
+
)
|
| 46 |
+
else:
|
| 47 |
+
filename = path
|
| 48 |
+
|
| 49 |
+
return modelloader.load_spandrel_model(
|
| 50 |
+
filename,
|
| 51 |
+
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
| 52 |
+
expected_architecture='ESRGAN',
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def esrgan_upscale(model, img):
|
| 57 |
+
return upscale_with_model(
|
| 58 |
+
model,
|
| 59 |
+
img,
|
| 60 |
+
tile_size=opts.ESRGAN_tile,
|
| 61 |
+
tile_overlap=opts.ESRGAN_tile_overlap,
|
| 62 |
+
)
|
modules/extensions.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import configparser
|
| 4 |
+
import dataclasses
|
| 5 |
+
import os
|
| 6 |
+
import threading
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
from modules import shared, errors, cache, scripts
|
| 10 |
+
from modules.gitpython_hack import Repo
|
| 11 |
+
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
| 12 |
+
|
| 13 |
+
extensions: list[Extension] = []
|
| 14 |
+
extension_paths: dict[str, Extension] = {}
|
| 15 |
+
loaded_extensions: dict[str, Exception] = {}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
os.makedirs(extensions_dir, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def active():
|
| 22 |
+
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
| 23 |
+
return []
|
| 24 |
+
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
|
| 25 |
+
return [x for x in extensions if x.enabled and x.is_builtin]
|
| 26 |
+
else:
|
| 27 |
+
return [x for x in extensions if x.enabled]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
|
| 31 |
+
class CallbackOrderInfo:
|
| 32 |
+
name: str
|
| 33 |
+
before: list
|
| 34 |
+
after: list
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ExtensionMetadata:
|
| 38 |
+
filename = "metadata.ini"
|
| 39 |
+
config: configparser.ConfigParser
|
| 40 |
+
canonical_name: str
|
| 41 |
+
requires: list
|
| 42 |
+
|
| 43 |
+
def __init__(self, path, canonical_name):
|
| 44 |
+
self.config = configparser.ConfigParser()
|
| 45 |
+
|
| 46 |
+
filepath = os.path.join(path, self.filename)
|
| 47 |
+
# `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
|
| 48 |
+
# so no need to check whether the file exists beforehand.
|
| 49 |
+
try:
|
| 50 |
+
self.config.read(filepath)
|
| 51 |
+
except Exception:
|
| 52 |
+
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
| 53 |
+
|
| 54 |
+
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
| 55 |
+
self.canonical_name = canonical_name.lower().strip()
|
| 56 |
+
|
| 57 |
+
self.requires = None
|
| 58 |
+
|
| 59 |
+
def get_script_requirements(self, field, section, extra_section=None):
|
| 60 |
+
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
| 61 |
+
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
|
| 62 |
+
reads more requirements from [extra_section] if specified."""
|
| 63 |
+
|
| 64 |
+
x = self.config.get(section, field, fallback='')
|
| 65 |
+
|
| 66 |
+
if extra_section:
|
| 67 |
+
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
| 68 |
+
|
| 69 |
+
listed_requirements = self.parse_list(x.lower())
|
| 70 |
+
res = []
|
| 71 |
+
|
| 72 |
+
for requirement in listed_requirements:
|
| 73 |
+
loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions)
|
| 74 |
+
relevant_requirement = next(loaded_requirements, requirement)
|
| 75 |
+
res.append(relevant_requirement)
|
| 76 |
+
|
| 77 |
+
return res
|
| 78 |
+
|
| 79 |
+
def parse_list(self, text):
|
| 80 |
+
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
| 81 |
+
|
| 82 |
+
if not text:
|
| 83 |
+
return []
|
| 84 |
+
|
| 85 |
+
# both "," and " " are accepted as separator
|
| 86 |
+
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
| 87 |
+
|
| 88 |
+
def list_callback_order_instructions(self):
|
| 89 |
+
for section in self.config.sections():
|
| 90 |
+
if not section.startswith("callbacks/"):
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
callback_name = section[10:]
|
| 94 |
+
|
| 95 |
+
if not callback_name.startswith(self.canonical_name):
|
| 96 |
+
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
|
| 97 |
+
continue
|
| 98 |
+
|
| 99 |
+
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
|
| 100 |
+
after = self.parse_list(self.config.get(section, 'After', fallback=''))
|
| 101 |
+
|
| 102 |
+
yield CallbackOrderInfo(callback_name, before, after)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Extension:
|
| 106 |
+
lock = threading.Lock()
|
| 107 |
+
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
| 108 |
+
metadata: ExtensionMetadata
|
| 109 |
+
|
| 110 |
+
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
|
| 111 |
+
self.name = name
|
| 112 |
+
self.path = path
|
| 113 |
+
self.enabled = enabled
|
| 114 |
+
self.status = ''
|
| 115 |
+
self.can_update = False
|
| 116 |
+
self.is_builtin = is_builtin
|
| 117 |
+
self.commit_hash = ''
|
| 118 |
+
self.commit_date = None
|
| 119 |
+
self.version = ''
|
| 120 |
+
self.branch = None
|
| 121 |
+
self.remote = None
|
| 122 |
+
self.have_info_from_repo = False
|
| 123 |
+
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
| 124 |
+
self.canonical_name = metadata.canonical_name
|
| 125 |
+
|
| 126 |
+
def to_dict(self):
|
| 127 |
+
return {x: getattr(self, x) for x in self.cached_fields}
|
| 128 |
+
|
| 129 |
+
def from_dict(self, d):
|
| 130 |
+
for field in self.cached_fields:
|
| 131 |
+
setattr(self, field, d[field])
|
| 132 |
+
|
| 133 |
+
def read_info_from_repo(self):
|
| 134 |
+
if self.is_builtin or self.have_info_from_repo:
|
| 135 |
+
return
|
| 136 |
+
|
| 137 |
+
def read_from_repo():
|
| 138 |
+
with self.lock:
|
| 139 |
+
if self.have_info_from_repo:
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
self.do_read_info_from_repo()
|
| 143 |
+
|
| 144 |
+
return self.to_dict()
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
| 148 |
+
self.from_dict(d)
|
| 149 |
+
except FileNotFoundError:
|
| 150 |
+
pass
|
| 151 |
+
self.status = 'unknown' if self.status == '' else self.status
|
| 152 |
+
|
| 153 |
+
def do_read_info_from_repo(self):
|
| 154 |
+
repo = None
|
| 155 |
+
try:
|
| 156 |
+
if os.path.exists(os.path.join(self.path, ".git")):
|
| 157 |
+
repo = Repo(self.path)
|
| 158 |
+
except Exception:
|
| 159 |
+
errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
|
| 160 |
+
|
| 161 |
+
if repo is None or repo.bare:
|
| 162 |
+
self.remote = None
|
| 163 |
+
else:
|
| 164 |
+
try:
|
| 165 |
+
self.remote = next(repo.remote().urls, None)
|
| 166 |
+
commit = repo.head.commit
|
| 167 |
+
self.commit_date = commit.committed_date
|
| 168 |
+
if repo.active_branch:
|
| 169 |
+
self.branch = repo.active_branch.name
|
| 170 |
+
self.commit_hash = commit.hexsha
|
| 171 |
+
self.version = self.commit_hash[:8]
|
| 172 |
+
|
| 173 |
+
except Exception:
|
| 174 |
+
errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
|
| 175 |
+
self.remote = None
|
| 176 |
+
|
| 177 |
+
self.have_info_from_repo = True
|
| 178 |
+
|
| 179 |
+
def list_files(self, subdir, extension):
|
| 180 |
+
dirpath = os.path.join(self.path, subdir)
|
| 181 |
+
if not os.path.isdir(dirpath):
|
| 182 |
+
return []
|
| 183 |
+
|
| 184 |
+
res = []
|
| 185 |
+
for filename in sorted(os.listdir(dirpath)):
|
| 186 |
+
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
| 187 |
+
|
| 188 |
+
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
| 189 |
+
|
| 190 |
+
return res
|
| 191 |
+
|
| 192 |
+
def check_updates(self):
|
| 193 |
+
repo = Repo(self.path)
|
| 194 |
+
branch_name = f'{repo.remote().name}/{self.branch}'
|
| 195 |
+
for fetch in repo.remote().fetch(dry_run=True):
|
| 196 |
+
if self.branch and fetch.name != branch_name:
|
| 197 |
+
continue
|
| 198 |
+
if fetch.flags != fetch.HEAD_UPTODATE:
|
| 199 |
+
self.can_update = True
|
| 200 |
+
self.status = "new commits"
|
| 201 |
+
return
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
origin = repo.rev_parse(branch_name)
|
| 205 |
+
if repo.head.commit != origin:
|
| 206 |
+
self.can_update = True
|
| 207 |
+
self.status = "behind HEAD"
|
| 208 |
+
return
|
| 209 |
+
except Exception:
|
| 210 |
+
self.can_update = False
|
| 211 |
+
self.status = "unknown (remote error)"
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
self.can_update = False
|
| 215 |
+
self.status = "latest"
|
| 216 |
+
|
| 217 |
+
def fetch_and_reset_hard(self, commit=None):
|
| 218 |
+
repo = Repo(self.path)
|
| 219 |
+
if commit is None:
|
| 220 |
+
commit = f'{repo.remote().name}/{self.branch}'
|
| 221 |
+
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
| 222 |
+
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
| 223 |
+
repo.git.fetch(all=True)
|
| 224 |
+
repo.git.reset(commit, hard=True)
|
| 225 |
+
self.have_info_from_repo = False
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def list_extensions():
|
| 229 |
+
extensions.clear()
|
| 230 |
+
extension_paths.clear()
|
| 231 |
+
loaded_extensions.clear()
|
| 232 |
+
|
| 233 |
+
if shared.cmd_opts.disable_all_extensions:
|
| 234 |
+
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
| 235 |
+
elif shared.opts.disable_all_extensions == "all":
|
| 236 |
+
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
| 237 |
+
elif shared.cmd_opts.disable_extra_extensions:
|
| 238 |
+
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
|
| 239 |
+
elif shared.opts.disable_all_extensions == "extra":
|
| 240 |
+
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# scan through extensions directory and load metadata
|
| 244 |
+
for dirname in [extensions_builtin_dir, extensions_dir]:
|
| 245 |
+
if not os.path.isdir(dirname):
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
for extension_dirname in sorted(os.listdir(dirname)):
|
| 249 |
+
path = os.path.join(dirname, extension_dirname)
|
| 250 |
+
if not os.path.isdir(path):
|
| 251 |
+
continue
|
| 252 |
+
|
| 253 |
+
canonical_name = extension_dirname
|
| 254 |
+
metadata = ExtensionMetadata(path, canonical_name)
|
| 255 |
+
|
| 256 |
+
# check for duplicated canonical names
|
| 257 |
+
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
|
| 258 |
+
if already_loaded_extension is not None:
|
| 259 |
+
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
is_builtin = dirname == extensions_builtin_dir
|
| 263 |
+
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
| 264 |
+
extensions.append(extension)
|
| 265 |
+
extension_paths[extension.path] = extension
|
| 266 |
+
loaded_extensions[canonical_name] = extension
|
| 267 |
+
|
| 268 |
+
for extension in extensions:
|
| 269 |
+
extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
|
| 270 |
+
|
| 271 |
+
# check for requirements
|
| 272 |
+
for extension in extensions:
|
| 273 |
+
if not extension.enabled:
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
for req in extension.metadata.requires:
|
| 277 |
+
required_extension = loaded_extensions.get(req)
|
| 278 |
+
if required_extension is None:
|
| 279 |
+
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
| 280 |
+
continue
|
| 281 |
+
|
| 282 |
+
if not required_extension.enabled:
|
| 283 |
+
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def find_extension(filename):
|
| 288 |
+
parentdir = os.path.dirname(os.path.realpath(filename))
|
| 289 |
+
|
| 290 |
+
while parentdir != filename:
|
| 291 |
+
extension = extension_paths.get(parentdir)
|
| 292 |
+
if extension is not None:
|
| 293 |
+
return extension
|
| 294 |
+
|
| 295 |
+
filename = parentdir
|
| 296 |
+
parentdir = os.path.dirname(filename)
|
| 297 |
+
|
| 298 |
+
return None
|
| 299 |
+
|
modules/extra_networks.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import logging
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
from modules import errors
|
| 8 |
+
|
| 9 |
+
extra_network_registry = {}
|
| 10 |
+
extra_network_aliases = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def initialize():
|
| 14 |
+
extra_network_registry.clear()
|
| 15 |
+
extra_network_aliases.clear()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def register_extra_network(extra_network):
|
| 19 |
+
extra_network_registry[extra_network.name] = extra_network
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def register_extra_network_alias(extra_network, alias):
|
| 23 |
+
extra_network_aliases[alias] = extra_network
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def register_default_extra_networks():
|
| 27 |
+
from modules.extra_networks_hypernet import ExtraNetworkHypernet
|
| 28 |
+
register_extra_network(ExtraNetworkHypernet())
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ExtraNetworkParams:
|
| 32 |
+
def __init__(self, items=None):
|
| 33 |
+
self.items = items or []
|
| 34 |
+
self.positional = []
|
| 35 |
+
self.named = {}
|
| 36 |
+
|
| 37 |
+
for item in self.items:
|
| 38 |
+
parts = item.split('=', 2) if isinstance(item, str) else [item]
|
| 39 |
+
if len(parts) == 2:
|
| 40 |
+
self.named[parts[0]] = parts[1]
|
| 41 |
+
else:
|
| 42 |
+
self.positional.append(item)
|
| 43 |
+
|
| 44 |
+
def __eq__(self, other):
|
| 45 |
+
return self.items == other.items
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ExtraNetwork:
|
| 49 |
+
def __init__(self, name):
|
| 50 |
+
self.name = name
|
| 51 |
+
|
| 52 |
+
def activate(self, p, params_list):
|
| 53 |
+
"""
|
| 54 |
+
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
| 55 |
+
Passes arguments related to this extra network in params_list.
|
| 56 |
+
User passes arguments by specifying this in his prompt:
|
| 57 |
+
|
| 58 |
+
<name:arg1:arg2:arg3>
|
| 59 |
+
|
| 60 |
+
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
| 61 |
+
separated by colon.
|
| 62 |
+
|
| 63 |
+
Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
|
| 64 |
+
in this case, all effects of this extra networks should be disabled.
|
| 65 |
+
|
| 66 |
+
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
| 67 |
+
|
| 68 |
+
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
| 69 |
+
|
| 70 |
+
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
| 71 |
+
|
| 72 |
+
params_list will be:
|
| 73 |
+
|
| 74 |
+
[
|
| 75 |
+
ExtraNetworkParams(items=["agm", "1.1"]),
|
| 76 |
+
ExtraNetworkParams(items=["ray"])
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
"""
|
| 80 |
+
raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
def deactivate(self, p):
|
| 83 |
+
"""
|
| 84 |
+
Called at the end of processing for housekeeping. No need to do anything here.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def lookup_extra_networks(extra_network_data):
|
| 91 |
+
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
|
| 92 |
+
|
| 93 |
+
Example input:
|
| 94 |
+
{
|
| 95 |
+
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
|
| 96 |
+
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
| 97 |
+
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
Example output:
|
| 101 |
+
|
| 102 |
+
{
|
| 103 |
+
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
| 104 |
+
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
| 105 |
+
}
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
res = {}
|
| 109 |
+
|
| 110 |
+
for extra_network_name, extra_network_args in list(extra_network_data.items()):
|
| 111 |
+
extra_network = extra_network_registry.get(extra_network_name, None)
|
| 112 |
+
alias = extra_network_aliases.get(extra_network_name, None)
|
| 113 |
+
|
| 114 |
+
if alias is not None and extra_network is None:
|
| 115 |
+
extra_network = alias
|
| 116 |
+
|
| 117 |
+
if extra_network is None:
|
| 118 |
+
logging.info(f"Skipping unknown extra network: {extra_network_name}")
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
res.setdefault(extra_network, []).extend(extra_network_args)
|
| 122 |
+
|
| 123 |
+
return res
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def activate(p, extra_network_data):
|
| 127 |
+
"""call activate for extra networks in extra_network_data in specified order, then call
|
| 128 |
+
activate for all remaining registered networks with an empty argument list"""
|
| 129 |
+
|
| 130 |
+
activated = []
|
| 131 |
+
|
| 132 |
+
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
extra_network.activate(p, extra_network_args)
|
| 136 |
+
activated.append(extra_network)
|
| 137 |
+
except Exception as e:
|
| 138 |
+
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
|
| 139 |
+
|
| 140 |
+
for extra_network_name, extra_network in extra_network_registry.items():
|
| 141 |
+
if extra_network in activated:
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
extra_network.activate(p, [])
|
| 146 |
+
except Exception as e:
|
| 147 |
+
errors.display(e, f"activating extra network {extra_network_name}")
|
| 148 |
+
|
| 149 |
+
if p.scripts is not None:
|
| 150 |
+
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def deactivate(p, extra_network_data):
|
| 154 |
+
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
| 155 |
+
deactivate for all remaining registered networks"""
|
| 156 |
+
|
| 157 |
+
data = lookup_extra_networks(extra_network_data)
|
| 158 |
+
|
| 159 |
+
for extra_network in data:
|
| 160 |
+
try:
|
| 161 |
+
extra_network.deactivate(p)
|
| 162 |
+
except Exception as e:
|
| 163 |
+
errors.display(e, f"deactivating extra network {extra_network.name}")
|
| 164 |
+
|
| 165 |
+
for extra_network_name, extra_network in extra_network_registry.items():
|
| 166 |
+
if extra_network in data:
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
extra_network.deactivate(p)
|
| 171 |
+
except Exception as e:
|
| 172 |
+
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def parse_prompt(prompt):
|
| 179 |
+
res = defaultdict(list)
|
| 180 |
+
|
| 181 |
+
def found(m):
|
| 182 |
+
name = m.group(1)
|
| 183 |
+
args = m.group(2)
|
| 184 |
+
|
| 185 |
+
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
| 186 |
+
|
| 187 |
+
return ""
|
| 188 |
+
|
| 189 |
+
prompt = re.sub(re_extra_net, found, prompt)
|
| 190 |
+
|
| 191 |
+
return prompt, res
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def parse_prompts(prompts):
|
| 195 |
+
res = []
|
| 196 |
+
extra_data = None
|
| 197 |
+
|
| 198 |
+
for prompt in prompts:
|
| 199 |
+
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
| 200 |
+
|
| 201 |
+
if extra_data is None:
|
| 202 |
+
extra_data = parsed_extra_data
|
| 203 |
+
|
| 204 |
+
res.append(updated_prompt)
|
| 205 |
+
|
| 206 |
+
return res, extra_data
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_user_metadata(filename, lister=None):
|
| 210 |
+
if filename is None:
|
| 211 |
+
return {}
|
| 212 |
+
|
| 213 |
+
basename, ext = os.path.splitext(filename)
|
| 214 |
+
metadata_filename = basename + '.json'
|
| 215 |
+
|
| 216 |
+
metadata = {}
|
| 217 |
+
try:
|
| 218 |
+
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
|
| 219 |
+
if exists:
|
| 220 |
+
with open(metadata_filename, "r", encoding="utf8") as file:
|
| 221 |
+
metadata = json.load(file)
|
| 222 |
+
except Exception as e:
|
| 223 |
+
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
| 224 |
+
|
| 225 |
+
return metadata
|
modules/extra_networks_hypernet.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import extra_networks, shared
|
| 2 |
+
from modules.hypernetworks import hypernetwork
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super().__init__('hypernet')
|
| 8 |
+
|
| 9 |
+
def activate(self, p, params_list):
|
| 10 |
+
additional = shared.opts.sd_hypernetwork
|
| 11 |
+
|
| 12 |
+
if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
|
| 13 |
+
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
| 14 |
+
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
| 15 |
+
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
| 16 |
+
|
| 17 |
+
names = []
|
| 18 |
+
multipliers = []
|
| 19 |
+
for params in params_list:
|
| 20 |
+
assert params.items
|
| 21 |
+
|
| 22 |
+
names.append(params.items[0])
|
| 23 |
+
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
| 24 |
+
|
| 25 |
+
hypernetwork.load_hypernetworks(names, multipliers)
|
| 26 |
+
|
| 27 |
+
def deactivate(self, p):
|
| 28 |
+
pass
|
modules/extras.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import shutil
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
|
| 10 |
+
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
| 11 |
+
from modules.ui_common import plaintext_to_html
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import safetensors.torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def run_pnginfo(image):
|
| 17 |
+
if image is None:
|
| 18 |
+
return '', '', ''
|
| 19 |
+
|
| 20 |
+
geninfo, items = images.read_info_from_image(image)
|
| 21 |
+
items = {**{'parameters': geninfo}, **items}
|
| 22 |
+
|
| 23 |
+
info = ''
|
| 24 |
+
for key, text in items.items():
|
| 25 |
+
info += f"""
|
| 26 |
+
<div>
|
| 27 |
+
<p><b>{plaintext_to_html(str(key))}</b></p>
|
| 28 |
+
<p>{plaintext_to_html(str(text))}</p>
|
| 29 |
+
</div>
|
| 30 |
+
""".strip()+"\n"
|
| 31 |
+
|
| 32 |
+
if len(info) == 0:
|
| 33 |
+
message = "Nothing found in the image."
|
| 34 |
+
info = f"<div><p>{message}<p></div>"
|
| 35 |
+
|
| 36 |
+
return '', geninfo, info
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_config(ckpt_result, config_source, a, b, c):
|
| 40 |
+
def config(x):
|
| 41 |
+
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
| 42 |
+
return res if res != shared.sd_default_config else None
|
| 43 |
+
|
| 44 |
+
if config_source == 0:
|
| 45 |
+
cfg = config(a) or config(b) or config(c)
|
| 46 |
+
elif config_source == 1:
|
| 47 |
+
cfg = config(b)
|
| 48 |
+
elif config_source == 2:
|
| 49 |
+
cfg = config(c)
|
| 50 |
+
else:
|
| 51 |
+
cfg = None
|
| 52 |
+
|
| 53 |
+
if cfg is None:
|
| 54 |
+
return
|
| 55 |
+
|
| 56 |
+
filename, _ = os.path.splitext(ckpt_result)
|
| 57 |
+
checkpoint_filename = filename + ".yaml"
|
| 58 |
+
|
| 59 |
+
print("Copying config:")
|
| 60 |
+
print(" from:", cfg)
|
| 61 |
+
print(" to:", checkpoint_filename)
|
| 62 |
+
shutil.copyfile(cfg, checkpoint_filename)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def to_half(tensor, enable):
|
| 69 |
+
if enable and tensor.dtype == torch.float:
|
| 70 |
+
return tensor.half()
|
| 71 |
+
|
| 72 |
+
return tensor
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
| 76 |
+
metadata = {}
|
| 77 |
+
|
| 78 |
+
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
| 79 |
+
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
| 80 |
+
if checkpoint_info is None:
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
metadata.update(checkpoint_info.metadata)
|
| 84 |
+
|
| 85 |
+
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
| 89 |
+
shared.state.begin(job="model-merge")
|
| 90 |
+
|
| 91 |
+
def fail(message):
|
| 92 |
+
shared.state.textinfo = message
|
| 93 |
+
shared.state.end()
|
| 94 |
+
return [*[gr.update() for _ in range(4)], message]
|
| 95 |
+
|
| 96 |
+
def weighted_sum(theta0, theta1, alpha):
|
| 97 |
+
return ((1 - alpha) * theta0) + (alpha * theta1)
|
| 98 |
+
|
| 99 |
+
def get_difference(theta1, theta2):
|
| 100 |
+
return theta1 - theta2
|
| 101 |
+
|
| 102 |
+
def add_difference(theta0, theta1_2_diff, alpha):
|
| 103 |
+
return theta0 + (alpha * theta1_2_diff)
|
| 104 |
+
|
| 105 |
+
def filename_weighted_sum():
|
| 106 |
+
a = primary_model_info.model_name
|
| 107 |
+
b = secondary_model_info.model_name
|
| 108 |
+
Ma = round(1 - multiplier, 2)
|
| 109 |
+
Mb = round(multiplier, 2)
|
| 110 |
+
|
| 111 |
+
return f"{Ma}({a}) + {Mb}({b})"
|
| 112 |
+
|
| 113 |
+
def filename_add_difference():
|
| 114 |
+
a = primary_model_info.model_name
|
| 115 |
+
b = secondary_model_info.model_name
|
| 116 |
+
c = tertiary_model_info.model_name
|
| 117 |
+
M = round(multiplier, 2)
|
| 118 |
+
|
| 119 |
+
return f"{a} + {M}({b} - {c})"
|
| 120 |
+
|
| 121 |
+
def filename_nothing():
|
| 122 |
+
return primary_model_info.model_name
|
| 123 |
+
|
| 124 |
+
theta_funcs = {
|
| 125 |
+
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
| 126 |
+
"Add difference": (filename_add_difference, get_difference, add_difference),
|
| 127 |
+
"No interpolation": (filename_nothing, None, None),
|
| 128 |
+
}
|
| 129 |
+
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
| 130 |
+
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
| 131 |
+
|
| 132 |
+
if not primary_model_name:
|
| 133 |
+
return fail("Failed: Merging requires a primary model.")
|
| 134 |
+
|
| 135 |
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
| 136 |
+
|
| 137 |
+
if theta_func2 and not secondary_model_name:
|
| 138 |
+
return fail("Failed: Merging requires a secondary model.")
|
| 139 |
+
|
| 140 |
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
| 141 |
+
|
| 142 |
+
if theta_func1 and not tertiary_model_name:
|
| 143 |
+
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
| 144 |
+
|
| 145 |
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
| 146 |
+
|
| 147 |
+
result_is_inpainting_model = False
|
| 148 |
+
result_is_instruct_pix2pix_model = False
|
| 149 |
+
|
| 150 |
+
if theta_func2:
|
| 151 |
+
shared.state.textinfo = "Loading B"
|
| 152 |
+
print(f"Loading {secondary_model_info.filename}...")
|
| 153 |
+
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
| 154 |
+
else:
|
| 155 |
+
theta_1 = None
|
| 156 |
+
|
| 157 |
+
if theta_func1:
|
| 158 |
+
shared.state.textinfo = "Loading C"
|
| 159 |
+
print(f"Loading {tertiary_model_info.filename}...")
|
| 160 |
+
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
| 161 |
+
|
| 162 |
+
shared.state.textinfo = 'Merging B and C'
|
| 163 |
+
shared.state.sampling_steps = len(theta_1.keys())
|
| 164 |
+
for key in tqdm.tqdm(theta_1.keys()):
|
| 165 |
+
if key in checkpoint_dict_skip_on_merge:
|
| 166 |
+
continue
|
| 167 |
+
|
| 168 |
+
if 'model' in key:
|
| 169 |
+
if key in theta_2:
|
| 170 |
+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
| 171 |
+
theta_1[key] = theta_func1(theta_1[key], t2)
|
| 172 |
+
else:
|
| 173 |
+
theta_1[key] = torch.zeros_like(theta_1[key])
|
| 174 |
+
|
| 175 |
+
shared.state.sampling_step += 1
|
| 176 |
+
del theta_2
|
| 177 |
+
|
| 178 |
+
shared.state.nextjob()
|
| 179 |
+
|
| 180 |
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
| 181 |
+
print(f"Loading {primary_model_info.filename}...")
|
| 182 |
+
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
| 183 |
+
|
| 184 |
+
print("Merging...")
|
| 185 |
+
shared.state.textinfo = 'Merging A and B'
|
| 186 |
+
shared.state.sampling_steps = len(theta_0.keys())
|
| 187 |
+
for key in tqdm.tqdm(theta_0.keys()):
|
| 188 |
+
if theta_1 and 'model' in key and key in theta_1:
|
| 189 |
+
|
| 190 |
+
if key in checkpoint_dict_skip_on_merge:
|
| 191 |
+
continue
|
| 192 |
+
|
| 193 |
+
a = theta_0[key]
|
| 194 |
+
b = theta_1[key]
|
| 195 |
+
|
| 196 |
+
# this enables merging an inpainting model (A) with another one (B);
|
| 197 |
+
# where normal model would have 4 channels, for latenst space, inpainting model would
|
| 198 |
+
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
| 199 |
+
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
| 200 |
+
if a.shape[1] == 4 and b.shape[1] == 9:
|
| 201 |
+
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
| 202 |
+
if a.shape[1] == 4 and b.shape[1] == 8:
|
| 203 |
+
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
| 204 |
+
|
| 205 |
+
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
| 206 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
| 207 |
+
result_is_instruct_pix2pix_model = True
|
| 208 |
+
else:
|
| 209 |
+
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
| 210 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
| 211 |
+
result_is_inpainting_model = True
|
| 212 |
+
else:
|
| 213 |
+
theta_0[key] = theta_func2(a, b, multiplier)
|
| 214 |
+
|
| 215 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
| 216 |
+
|
| 217 |
+
shared.state.sampling_step += 1
|
| 218 |
+
|
| 219 |
+
del theta_1
|
| 220 |
+
|
| 221 |
+
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
| 222 |
+
if bake_in_vae_filename is not None:
|
| 223 |
+
print(f"Baking in VAE from {bake_in_vae_filename}")
|
| 224 |
+
shared.state.textinfo = 'Baking in VAE'
|
| 225 |
+
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
| 226 |
+
|
| 227 |
+
for key in vae_dict.keys():
|
| 228 |
+
theta_0_key = 'first_stage_model.' + key
|
| 229 |
+
if theta_0_key in theta_0:
|
| 230 |
+
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
| 231 |
+
|
| 232 |
+
del vae_dict
|
| 233 |
+
|
| 234 |
+
if save_as_half and not theta_func2:
|
| 235 |
+
for key in theta_0.keys():
|
| 236 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
| 237 |
+
|
| 238 |
+
if discard_weights:
|
| 239 |
+
regex = re.compile(discard_weights)
|
| 240 |
+
for key in list(theta_0):
|
| 241 |
+
if re.search(regex, key):
|
| 242 |
+
theta_0.pop(key, None)
|
| 243 |
+
|
| 244 |
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
| 245 |
+
|
| 246 |
+
filename = filename_generator() if custom_name == '' else custom_name
|
| 247 |
+
filename += ".inpainting" if result_is_inpainting_model else ""
|
| 248 |
+
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
| 249 |
+
filename += "." + checkpoint_format
|
| 250 |
+
|
| 251 |
+
output_modelname = os.path.join(ckpt_dir, filename)
|
| 252 |
+
|
| 253 |
+
shared.state.nextjob()
|
| 254 |
+
shared.state.textinfo = "Saving"
|
| 255 |
+
print(f"Saving to {output_modelname}...")
|
| 256 |
+
|
| 257 |
+
metadata = {}
|
| 258 |
+
|
| 259 |
+
if save_metadata and copy_metadata_fields:
|
| 260 |
+
if primary_model_info:
|
| 261 |
+
metadata.update(primary_model_info.metadata)
|
| 262 |
+
if secondary_model_info:
|
| 263 |
+
metadata.update(secondary_model_info.metadata)
|
| 264 |
+
if tertiary_model_info:
|
| 265 |
+
metadata.update(tertiary_model_info.metadata)
|
| 266 |
+
|
| 267 |
+
if save_metadata:
|
| 268 |
+
try:
|
| 269 |
+
metadata.update(json.loads(metadata_json))
|
| 270 |
+
except Exception as e:
|
| 271 |
+
errors.display(e, "readin metadata from json")
|
| 272 |
+
|
| 273 |
+
metadata["format"] = "pt"
|
| 274 |
+
|
| 275 |
+
if save_metadata and add_merge_recipe:
|
| 276 |
+
merge_recipe = {
|
| 277 |
+
"type": "webui", # indicate this model was merged with webui's built-in merger
|
| 278 |
+
"primary_model_hash": primary_model_info.sha256,
|
| 279 |
+
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
| 280 |
+
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
| 281 |
+
"interp_method": interp_method,
|
| 282 |
+
"multiplier": multiplier,
|
| 283 |
+
"save_as_half": save_as_half,
|
| 284 |
+
"custom_name": custom_name,
|
| 285 |
+
"config_source": config_source,
|
| 286 |
+
"bake_in_vae": bake_in_vae,
|
| 287 |
+
"discard_weights": discard_weights,
|
| 288 |
+
"is_inpainting": result_is_inpainting_model,
|
| 289 |
+
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
sd_merge_models = {}
|
| 293 |
+
|
| 294 |
+
def add_model_metadata(checkpoint_info):
|
| 295 |
+
checkpoint_info.calculate_shorthash()
|
| 296 |
+
sd_merge_models[checkpoint_info.sha256] = {
|
| 297 |
+
"name": checkpoint_info.name,
|
| 298 |
+
"legacy_hash": checkpoint_info.hash,
|
| 299 |
+
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
| 303 |
+
|
| 304 |
+
add_model_metadata(primary_model_info)
|
| 305 |
+
if secondary_model_info:
|
| 306 |
+
add_model_metadata(secondary_model_info)
|
| 307 |
+
if tertiary_model_info:
|
| 308 |
+
add_model_metadata(tertiary_model_info)
|
| 309 |
+
|
| 310 |
+
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
| 311 |
+
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
| 312 |
+
|
| 313 |
+
_, extension = os.path.splitext(output_modelname)
|
| 314 |
+
if extension.lower() == ".safetensors":
|
| 315 |
+
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
|
| 316 |
+
else:
|
| 317 |
+
torch.save(theta_0, output_modelname)
|
| 318 |
+
|
| 319 |
+
sd_models.list_models()
|
| 320 |
+
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
| 321 |
+
if created_model:
|
| 322 |
+
created_model.calculate_shorthash()
|
| 323 |
+
|
| 324 |
+
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
| 325 |
+
|
| 326 |
+
print(f"Checkpoint saved to {output_modelname}.")
|
| 327 |
+
shared.state.textinfo = "Checkpoint saved"
|
| 328 |
+
shared.state.end()
|
| 329 |
+
|
| 330 |
+
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
modules/face_restoration.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import shared
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FaceRestoration:
|
| 5 |
+
def name(self):
|
| 6 |
+
return "None"
|
| 7 |
+
|
| 8 |
+
def restore(self, np_image):
|
| 9 |
+
return np_image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def restore_faces(np_image):
|
| 13 |
+
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
|
| 14 |
+
if len(face_restorers) == 0:
|
| 15 |
+
return np_image
|
| 16 |
+
|
| 17 |
+
face_restorer = face_restorers[0]
|
| 18 |
+
|
| 19 |
+
return face_restorer.restore(np_image)
|
modules/face_restoration_utils.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from functools import cached_property
|
| 6 |
+
from typing import TYPE_CHECKING, Callable
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from modules import devices, errors, face_restoration, shared
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
|
| 21 |
+
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
|
| 22 |
+
assert img.shape[2] == 3, "image must be RGB"
|
| 23 |
+
if img.dtype == "float64":
|
| 24 |
+
img = img.astype("float32")
|
| 25 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 26 |
+
return torch.from_numpy(img.transpose(2, 0, 1)).float()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
|
| 30 |
+
"""
|
| 31 |
+
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
|
| 32 |
+
"""
|
| 33 |
+
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
| 34 |
+
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
|
| 35 |
+
assert tensor.dim() == 3, "tensor must be RGB"
|
| 36 |
+
img_np = tensor.numpy().transpose(1, 2, 0)
|
| 37 |
+
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
|
| 38 |
+
return np.squeeze(img_np, axis=2)
|
| 39 |
+
return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def create_face_helper(device) -> FaceRestoreHelper:
|
| 43 |
+
from facexlib.detection import retinaface
|
| 44 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 45 |
+
if hasattr(retinaface, 'device'):
|
| 46 |
+
retinaface.device = device
|
| 47 |
+
return FaceRestoreHelper(
|
| 48 |
+
upscale_factor=1,
|
| 49 |
+
face_size=512,
|
| 50 |
+
crop_ratio=(1, 1),
|
| 51 |
+
det_model='retinaface_resnet50',
|
| 52 |
+
save_ext='png',
|
| 53 |
+
use_parse=True,
|
| 54 |
+
device=device,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def restore_with_face_helper(
|
| 59 |
+
np_image: np.ndarray,
|
| 60 |
+
face_helper: FaceRestoreHelper,
|
| 61 |
+
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
| 62 |
+
) -> np.ndarray:
|
| 63 |
+
"""
|
| 64 |
+
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
|
| 65 |
+
|
| 66 |
+
`restore_face` should take a cropped face image and return a restored face image.
|
| 67 |
+
"""
|
| 68 |
+
from torchvision.transforms.functional import normalize
|
| 69 |
+
np_image = np_image[:, :, ::-1]
|
| 70 |
+
original_resolution = np_image.shape[0:2]
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
logger.debug("Detecting faces...")
|
| 74 |
+
face_helper.clean_all()
|
| 75 |
+
face_helper.read_image(np_image)
|
| 76 |
+
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
| 77 |
+
face_helper.align_warp_face()
|
| 78 |
+
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
|
| 79 |
+
for cropped_face in face_helper.cropped_faces:
|
| 80 |
+
cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
|
| 81 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
| 82 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
cropped_face_t = restore_face(cropped_face_t)
|
| 87 |
+
devices.torch_gc()
|
| 88 |
+
except Exception:
|
| 89 |
+
errors.report('Failed face-restoration inference', exc_info=True)
|
| 90 |
+
|
| 91 |
+
restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
|
| 92 |
+
restored_face = (restored_face * 255.0).astype('uint8')
|
| 93 |
+
face_helper.add_restored_face(restored_face)
|
| 94 |
+
|
| 95 |
+
logger.debug("Merging restored faces into image")
|
| 96 |
+
face_helper.get_inverse_affine(None)
|
| 97 |
+
img = face_helper.paste_faces_to_input_image()
|
| 98 |
+
img = img[:, :, ::-1]
|
| 99 |
+
if original_resolution != img.shape[0:2]:
|
| 100 |
+
img = cv2.resize(
|
| 101 |
+
img,
|
| 102 |
+
(0, 0),
|
| 103 |
+
fx=original_resolution[1] / img.shape[1],
|
| 104 |
+
fy=original_resolution[0] / img.shape[0],
|
| 105 |
+
interpolation=cv2.INTER_LINEAR,
|
| 106 |
+
)
|
| 107 |
+
logger.debug("Face restoration complete")
|
| 108 |
+
finally:
|
| 109 |
+
face_helper.clean_all()
|
| 110 |
+
return img
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class CommonFaceRestoration(face_restoration.FaceRestoration):
|
| 114 |
+
net: torch.Module | None
|
| 115 |
+
model_url: str
|
| 116 |
+
model_download_name: str
|
| 117 |
+
|
| 118 |
+
def __init__(self, model_path: str):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.net = None
|
| 121 |
+
self.model_path = model_path
|
| 122 |
+
os.makedirs(model_path, exist_ok=True)
|
| 123 |
+
|
| 124 |
+
@cached_property
|
| 125 |
+
def face_helper(self) -> FaceRestoreHelper:
|
| 126 |
+
return create_face_helper(self.get_device())
|
| 127 |
+
|
| 128 |
+
def send_model_to(self, device):
|
| 129 |
+
if self.net:
|
| 130 |
+
logger.debug("Sending %s to %s", self.net, device)
|
| 131 |
+
self.net.to(device)
|
| 132 |
+
if self.face_helper:
|
| 133 |
+
logger.debug("Sending face helper to %s", device)
|
| 134 |
+
self.face_helper.face_det.to(device)
|
| 135 |
+
self.face_helper.face_parse.to(device)
|
| 136 |
+
|
| 137 |
+
def get_device(self):
|
| 138 |
+
raise NotImplementedError("get_device must be implemented by subclasses")
|
| 139 |
+
|
| 140 |
+
def load_net(self) -> torch.Module:
|
| 141 |
+
raise NotImplementedError("load_net must be implemented by subclasses")
|
| 142 |
+
|
| 143 |
+
def restore_with_helper(
|
| 144 |
+
self,
|
| 145 |
+
np_image: np.ndarray,
|
| 146 |
+
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
| 147 |
+
) -> np.ndarray:
|
| 148 |
+
try:
|
| 149 |
+
if self.net is None:
|
| 150 |
+
self.net = self.load_net()
|
| 151 |
+
except Exception:
|
| 152 |
+
logger.warning("Unable to load face-restoration model", exc_info=True)
|
| 153 |
+
return np_image
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
self.send_model_to(self.get_device())
|
| 157 |
+
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
| 158 |
+
finally:
|
| 159 |
+
if shared.opts.face_restoration_unload:
|
| 160 |
+
self.send_model_to(devices.cpu)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def patch_facexlib(dirname: str) -> None:
|
| 164 |
+
import facexlib.detection
|
| 165 |
+
import facexlib.parsing
|
| 166 |
+
|
| 167 |
+
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
|
| 168 |
+
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
|
| 169 |
+
|
| 170 |
+
def update_kwargs(kwargs):
|
| 171 |
+
return dict(kwargs, save_dir=dirname, model_dir=None)
|
| 172 |
+
|
| 173 |
+
def facex_load_file_from_url(**kwargs):
|
| 174 |
+
return det_facex_load_file_from_url(**update_kwargs(kwargs))
|
| 175 |
+
|
| 176 |
+
def facex_load_file_from_url2(**kwargs):
|
| 177 |
+
return par_facex_load_file_from_url(**update_kwargs(kwargs))
|
| 178 |
+
|
| 179 |
+
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
| 180 |
+
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
modules/fifo_lock.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import collections
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
| 6 |
+
class FIFOLock(object):
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self._lock = threading.Lock()
|
| 9 |
+
self._inner_lock = threading.Lock()
|
| 10 |
+
self._pending_threads = collections.deque()
|
| 11 |
+
|
| 12 |
+
def acquire(self, blocking=True):
|
| 13 |
+
with self._inner_lock:
|
| 14 |
+
lock_acquired = self._lock.acquire(False)
|
| 15 |
+
if lock_acquired:
|
| 16 |
+
return True
|
| 17 |
+
elif not blocking:
|
| 18 |
+
return False
|
| 19 |
+
|
| 20 |
+
release_event = threading.Event()
|
| 21 |
+
self._pending_threads.append(release_event)
|
| 22 |
+
|
| 23 |
+
release_event.wait()
|
| 24 |
+
return self._lock.acquire()
|
| 25 |
+
|
| 26 |
+
def release(self):
|
| 27 |
+
with self._inner_lock:
|
| 28 |
+
if self._pending_threads:
|
| 29 |
+
release_event = self._pending_threads.popleft()
|
| 30 |
+
release_event.set()
|
| 31 |
+
|
| 32 |
+
self._lock.release()
|
| 33 |
+
|
| 34 |
+
__enter__ = acquire
|
| 35 |
+
|
| 36 |
+
def __exit__(self, t, v, tb):
|
| 37 |
+
self.release()
|
modules/gfpgan_model.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from modules import (
|
| 9 |
+
devices,
|
| 10 |
+
errors,
|
| 11 |
+
face_restoration,
|
| 12 |
+
face_restoration_utils,
|
| 13 |
+
modelloader,
|
| 14 |
+
shared,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
| 19 |
+
model_download_name = "GFPGANv1.4.pth"
|
| 20 |
+
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
| 24 |
+
def name(self):
|
| 25 |
+
return "GFPGAN"
|
| 26 |
+
|
| 27 |
+
def get_device(self):
|
| 28 |
+
return devices.device_gfpgan
|
| 29 |
+
|
| 30 |
+
def load_net(self) -> torch.Module:
|
| 31 |
+
for model_path in modelloader.load_models(
|
| 32 |
+
model_path=self.model_path,
|
| 33 |
+
model_url=model_url,
|
| 34 |
+
command_path=self.model_path,
|
| 35 |
+
download_name=model_download_name,
|
| 36 |
+
ext_filter=['.pth'],
|
| 37 |
+
):
|
| 38 |
+
if 'GFPGAN' in os.path.basename(model_path):
|
| 39 |
+
return modelloader.load_spandrel_model(
|
| 40 |
+
model_path,
|
| 41 |
+
device=self.get_device(),
|
| 42 |
+
expected_architecture='GFPGAN',
|
| 43 |
+
).model
|
| 44 |
+
raise ValueError("No GFPGAN model found")
|
| 45 |
+
|
| 46 |
+
def restore(self, np_image):
|
| 47 |
+
def restore_face(cropped_face_t):
|
| 48 |
+
assert self.net is not None
|
| 49 |
+
return self.net(cropped_face_t, return_rgb=False)[0]
|
| 50 |
+
|
| 51 |
+
return self.restore_with_helper(np_image, restore_face)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def gfpgan_fix_faces(np_image):
|
| 55 |
+
if gfpgan_face_restorer:
|
| 56 |
+
return gfpgan_face_restorer.restore(np_image)
|
| 57 |
+
logger.warning("GFPGAN face restorer not set up")
|
| 58 |
+
return np_image
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def setup_model(dirname: str) -> None:
|
| 62 |
+
global gfpgan_face_restorer
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
face_restoration_utils.patch_facexlib(dirname)
|
| 66 |
+
gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
|
| 67 |
+
shared.face_restorers.append(gfpgan_face_restorer)
|
| 68 |
+
except Exception:
|
| 69 |
+
errors.report("Error setting up GFPGAN", exc_info=True)
|
modules/gitpython_hack.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
import git
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Git(git.Git):
|
| 10 |
+
"""
|
| 11 |
+
Git subclassed to never use persistent processes.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
|
| 15 |
+
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
|
| 16 |
+
|
| 17 |
+
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
|
| 18 |
+
ret = subprocess.check_output(
|
| 19 |
+
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
|
| 20 |
+
input=self._prepare_ref(ref),
|
| 21 |
+
cwd=self._working_dir,
|
| 22 |
+
timeout=2,
|
| 23 |
+
)
|
| 24 |
+
return self._parse_object_header(ret)
|
| 25 |
+
|
| 26 |
+
def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
|
| 27 |
+
# Not really streaming, per se; this buffers the entire object in memory.
|
| 28 |
+
# Shouldn't be a problem for our use case, since we're only using this for
|
| 29 |
+
# object headers (commit objects).
|
| 30 |
+
ret = subprocess.check_output(
|
| 31 |
+
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
|
| 32 |
+
input=self._prepare_ref(ref),
|
| 33 |
+
cwd=self._working_dir,
|
| 34 |
+
timeout=30,
|
| 35 |
+
)
|
| 36 |
+
bio = io.BytesIO(ret)
|
| 37 |
+
hexsha, typename, size = self._parse_object_header(bio.readline())
|
| 38 |
+
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Repo(git.Repo):
|
| 42 |
+
GitCommandWrapperType = Git
|
modules/gradio_extensons.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from modules import scripts, ui_tempdir, patches
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def add_classes_to_gradio_component(comp):
|
| 7 |
+
"""
|
| 8 |
+
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
| 12 |
+
|
| 13 |
+
if getattr(comp, 'multiselect', False):
|
| 14 |
+
comp.elem_classes.append('multiselect')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def IOComponent_init(self, *args, **kwargs):
|
| 18 |
+
self.webui_tooltip = kwargs.pop('tooltip', None)
|
| 19 |
+
|
| 20 |
+
if scripts.scripts_current is not None:
|
| 21 |
+
scripts.scripts_current.before_component(self, **kwargs)
|
| 22 |
+
|
| 23 |
+
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
| 24 |
+
|
| 25 |
+
res = original_IOComponent_init(self, *args, **kwargs)
|
| 26 |
+
|
| 27 |
+
add_classes_to_gradio_component(self)
|
| 28 |
+
|
| 29 |
+
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
| 30 |
+
|
| 31 |
+
if scripts.scripts_current is not None:
|
| 32 |
+
scripts.scripts_current.after_component(self, **kwargs)
|
| 33 |
+
|
| 34 |
+
return res
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def Block_get_config(self):
|
| 38 |
+
config = original_Block_get_config(self)
|
| 39 |
+
|
| 40 |
+
webui_tooltip = getattr(self, 'webui_tooltip', None)
|
| 41 |
+
if webui_tooltip:
|
| 42 |
+
config["webui_tooltip"] = webui_tooltip
|
| 43 |
+
|
| 44 |
+
config.pop('example_inputs', None)
|
| 45 |
+
|
| 46 |
+
return config
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def BlockContext_init(self, *args, **kwargs):
|
| 50 |
+
if scripts.scripts_current is not None:
|
| 51 |
+
scripts.scripts_current.before_component(self, **kwargs)
|
| 52 |
+
|
| 53 |
+
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
| 54 |
+
|
| 55 |
+
res = original_BlockContext_init(self, *args, **kwargs)
|
| 56 |
+
|
| 57 |
+
add_classes_to_gradio_component(self)
|
| 58 |
+
|
| 59 |
+
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
| 60 |
+
|
| 61 |
+
if scripts.scripts_current is not None:
|
| 62 |
+
scripts.scripts_current.after_component(self, **kwargs)
|
| 63 |
+
|
| 64 |
+
return res
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def Blocks_get_config_file(self, *args, **kwargs):
|
| 68 |
+
config = original_Blocks_get_config_file(self, *args, **kwargs)
|
| 69 |
+
|
| 70 |
+
for comp_config in config["components"]:
|
| 71 |
+
if "example_inputs" in comp_config:
|
| 72 |
+
comp_config["example_inputs"] = {"serialized": []}
|
| 73 |
+
|
| 74 |
+
return config
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
| 78 |
+
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
| 79 |
+
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
| 80 |
+
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
ui_tempdir.install_ui_tempdir_override()
|
modules/hashes.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
+
from modules import shared
|
| 5 |
+
import modules.cache
|
| 6 |
+
|
| 7 |
+
dump_cache = modules.cache.dump_cache
|
| 8 |
+
cache = modules.cache.cache
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def calculate_sha256(filename):
|
| 12 |
+
hash_sha256 = hashlib.sha256()
|
| 13 |
+
blksize = 1024 * 1024
|
| 14 |
+
|
| 15 |
+
with open(filename, "rb") as f:
|
| 16 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
| 17 |
+
hash_sha256.update(chunk)
|
| 18 |
+
|
| 19 |
+
return hash_sha256.hexdigest()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def sha256_from_cache(filename, title, use_addnet_hash=False):
|
| 23 |
+
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
| 24 |
+
try:
|
| 25 |
+
ondisk_mtime = os.path.getmtime(filename)
|
| 26 |
+
except FileNotFoundError:
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
if title not in hashes:
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
cached_sha256 = hashes[title].get("sha256", None)
|
| 33 |
+
cached_mtime = hashes[title].get("mtime", 0)
|
| 34 |
+
|
| 35 |
+
if ondisk_mtime > cached_mtime or cached_sha256 is None:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
return cached_sha256
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def sha256(filename, title, use_addnet_hash=False):
|
| 42 |
+
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
| 43 |
+
|
| 44 |
+
sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
|
| 45 |
+
if sha256_value is not None:
|
| 46 |
+
return sha256_value
|
| 47 |
+
|
| 48 |
+
if shared.cmd_opts.no_hashing:
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
print(f"Calculating sha256 for {filename}: ", end='')
|
| 52 |
+
if use_addnet_hash:
|
| 53 |
+
with open(filename, "rb") as file:
|
| 54 |
+
sha256_value = addnet_hash_safetensors(file)
|
| 55 |
+
else:
|
| 56 |
+
sha256_value = calculate_sha256(filename)
|
| 57 |
+
print(f"{sha256_value}")
|
| 58 |
+
|
| 59 |
+
hashes[title] = {
|
| 60 |
+
"mtime": os.path.getmtime(filename),
|
| 61 |
+
"sha256": sha256_value,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
dump_cache()
|
| 65 |
+
|
| 66 |
+
return sha256_value
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def addnet_hash_safetensors(b):
|
| 70 |
+
"""kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
|
| 71 |
+
hash_sha256 = hashlib.sha256()
|
| 72 |
+
blksize = 1024 * 1024
|
| 73 |
+
|
| 74 |
+
b.seek(0)
|
| 75 |
+
header = b.read(8)
|
| 76 |
+
n = int.from_bytes(header, "little")
|
| 77 |
+
|
| 78 |
+
offset = n + 8
|
| 79 |
+
b.seek(offset)
|
| 80 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
| 81 |
+
hash_sha256.update(chunk)
|
| 82 |
+
|
| 83 |
+
return hash_sha256.hexdigest()
|
| 84 |
+
|
modules/hat_model.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from modules import modelloader, devices
|
| 5 |
+
from modules.shared import opts
|
| 6 |
+
from modules.upscaler import Upscaler, UpscalerData
|
| 7 |
+
from modules.upscaler_utils import upscale_with_model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UpscalerHAT(Upscaler):
|
| 11 |
+
def __init__(self, dirname):
|
| 12 |
+
self.name = "HAT"
|
| 13 |
+
self.scalers = []
|
| 14 |
+
self.user_path = dirname
|
| 15 |
+
super().__init__()
|
| 16 |
+
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
| 17 |
+
name = modelloader.friendly_name(file)
|
| 18 |
+
scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
|
| 19 |
+
scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
|
| 20 |
+
self.scalers.append(scaler_data)
|
| 21 |
+
|
| 22 |
+
def do_upscale(self, img, selected_model):
|
| 23 |
+
try:
|
| 24 |
+
model = self.load_model(selected_model)
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
|
| 27 |
+
return img
|
| 28 |
+
model.to(devices.device_esrgan) # TODO: should probably be device_hat
|
| 29 |
+
return upscale_with_model(
|
| 30 |
+
model,
|
| 31 |
+
img,
|
| 32 |
+
tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
|
| 33 |
+
tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def load_model(self, path: str):
|
| 37 |
+
if not os.path.isfile(path):
|
| 38 |
+
raise FileNotFoundError(f"Model file {path} not found")
|
| 39 |
+
return modelloader.load_spandrel_model(
|
| 40 |
+
path,
|
| 41 |
+
device=devices.device_esrgan, # TODO: should probably be device_hat
|
| 42 |
+
expected_architecture='HAT',
|
| 43 |
+
)
|
modules/hypernetworks/hypernetwork.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import glob
|
| 3 |
+
import html
|
| 4 |
+
import os
|
| 5 |
+
import inspect
|
| 6 |
+
from contextlib import closing
|
| 7 |
+
|
| 8 |
+
import modules.textual_inversion.dataset
|
| 9 |
+
import torch
|
| 10 |
+
import tqdm
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
from ldm.util import default
|
| 13 |
+
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
| 14 |
+
from modules.textual_inversion import textual_inversion, saving_settings
|
| 15 |
+
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
| 16 |
+
from torch import einsum
|
| 17 |
+
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
| 18 |
+
|
| 19 |
+
from collections import deque
|
| 20 |
+
from statistics import stdev, mean
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
| 24 |
+
|
| 25 |
+
class HypernetworkModule(torch.nn.Module):
|
| 26 |
+
activation_dict = {
|
| 27 |
+
"linear": torch.nn.Identity,
|
| 28 |
+
"relu": torch.nn.ReLU,
|
| 29 |
+
"leakyrelu": torch.nn.LeakyReLU,
|
| 30 |
+
"elu": torch.nn.ELU,
|
| 31 |
+
"swish": torch.nn.Hardswish,
|
| 32 |
+
"tanh": torch.nn.Tanh,
|
| 33 |
+
"sigmoid": torch.nn.Sigmoid,
|
| 34 |
+
}
|
| 35 |
+
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
| 36 |
+
|
| 37 |
+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
| 38 |
+
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
| 39 |
+
super().__init__()
|
| 40 |
+
|
| 41 |
+
self.multiplier = 1.0
|
| 42 |
+
|
| 43 |
+
assert layer_structure is not None, "layer_structure must not be None"
|
| 44 |
+
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
| 45 |
+
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
| 46 |
+
|
| 47 |
+
linears = []
|
| 48 |
+
for i in range(len(layer_structure) - 1):
|
| 49 |
+
|
| 50 |
+
# Add a fully-connected layer
|
| 51 |
+
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
| 52 |
+
|
| 53 |
+
# Add an activation func except last layer
|
| 54 |
+
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
| 55 |
+
pass
|
| 56 |
+
elif activation_func in self.activation_dict:
|
| 57 |
+
linears.append(self.activation_dict[activation_func]())
|
| 58 |
+
else:
|
| 59 |
+
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
| 60 |
+
|
| 61 |
+
# Add layer normalization
|
| 62 |
+
if add_layer_norm:
|
| 63 |
+
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
| 64 |
+
|
| 65 |
+
# Everything should be now parsed into dropout structure, and applied here.
|
| 66 |
+
# Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
|
| 67 |
+
if dropout_structure is not None and dropout_structure[i+1] > 0:
|
| 68 |
+
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
|
| 69 |
+
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
|
| 70 |
+
# Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
|
| 71 |
+
|
| 72 |
+
self.linear = torch.nn.Sequential(*linears)
|
| 73 |
+
|
| 74 |
+
if state_dict is not None:
|
| 75 |
+
self.fix_old_state_dict(state_dict)
|
| 76 |
+
self.load_state_dict(state_dict)
|
| 77 |
+
else:
|
| 78 |
+
for layer in self.linear:
|
| 79 |
+
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
| 80 |
+
w, b = layer.weight.data, layer.bias.data
|
| 81 |
+
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
| 82 |
+
normal_(w, mean=0.0, std=0.01)
|
| 83 |
+
normal_(b, mean=0.0, std=0)
|
| 84 |
+
elif weight_init == 'XavierUniform':
|
| 85 |
+
xavier_uniform_(w)
|
| 86 |
+
zeros_(b)
|
| 87 |
+
elif weight_init == 'XavierNormal':
|
| 88 |
+
xavier_normal_(w)
|
| 89 |
+
zeros_(b)
|
| 90 |
+
elif weight_init == 'KaimingUniform':
|
| 91 |
+
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
| 92 |
+
zeros_(b)
|
| 93 |
+
elif weight_init == 'KaimingNormal':
|
| 94 |
+
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
| 95 |
+
zeros_(b)
|
| 96 |
+
else:
|
| 97 |
+
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
| 98 |
+
devices.torch_npu_set_device()
|
| 99 |
+
self.to(devices.device)
|
| 100 |
+
|
| 101 |
+
def fix_old_state_dict(self, state_dict):
|
| 102 |
+
changes = {
|
| 103 |
+
'linear1.bias': 'linear.0.bias',
|
| 104 |
+
'linear1.weight': 'linear.0.weight',
|
| 105 |
+
'linear2.bias': 'linear.1.bias',
|
| 106 |
+
'linear2.weight': 'linear.1.weight',
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
for fr, to in changes.items():
|
| 110 |
+
x = state_dict.get(fr, None)
|
| 111 |
+
if x is None:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
del state_dict[fr]
|
| 115 |
+
state_dict[to] = x
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
| 119 |
+
|
| 120 |
+
def trainables(self):
|
| 121 |
+
layer_structure = []
|
| 122 |
+
for layer in self.linear:
|
| 123 |
+
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
| 124 |
+
layer_structure += [layer.weight, layer.bias]
|
| 125 |
+
return layer_structure
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
| 129 |
+
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
| 130 |
+
if layer_structure is None:
|
| 131 |
+
layer_structure = [1, 2, 1]
|
| 132 |
+
if not use_dropout:
|
| 133 |
+
return [0] * len(layer_structure)
|
| 134 |
+
dropout_values = [0]
|
| 135 |
+
dropout_values.extend([0.3] * (len(layer_structure) - 3))
|
| 136 |
+
if last_layer_dropout:
|
| 137 |
+
dropout_values.append(0.3)
|
| 138 |
+
else:
|
| 139 |
+
dropout_values.append(0)
|
| 140 |
+
dropout_values.append(0)
|
| 141 |
+
return dropout_values
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Hypernetwork:
|
| 145 |
+
filename = None
|
| 146 |
+
name = None
|
| 147 |
+
|
| 148 |
+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
| 149 |
+
self.filename = None
|
| 150 |
+
self.name = name
|
| 151 |
+
self.layers = {}
|
| 152 |
+
self.step = 0
|
| 153 |
+
self.sd_checkpoint = None
|
| 154 |
+
self.sd_checkpoint_name = None
|
| 155 |
+
self.layer_structure = layer_structure
|
| 156 |
+
self.activation_func = activation_func
|
| 157 |
+
self.weight_init = weight_init
|
| 158 |
+
self.add_layer_norm = add_layer_norm
|
| 159 |
+
self.use_dropout = use_dropout
|
| 160 |
+
self.activate_output = activate_output
|
| 161 |
+
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
| 162 |
+
self.dropout_structure = kwargs.get('dropout_structure', None)
|
| 163 |
+
if self.dropout_structure is None:
|
| 164 |
+
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
| 165 |
+
self.optimizer_name = None
|
| 166 |
+
self.optimizer_state_dict = None
|
| 167 |
+
self.optional_info = None
|
| 168 |
+
|
| 169 |
+
for size in enable_sizes or []:
|
| 170 |
+
self.layers[size] = (
|
| 171 |
+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
| 172 |
+
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
| 173 |
+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
| 174 |
+
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
| 175 |
+
)
|
| 176 |
+
self.eval()
|
| 177 |
+
|
| 178 |
+
def weights(self):
|
| 179 |
+
res = []
|
| 180 |
+
for layers in self.layers.values():
|
| 181 |
+
for layer in layers:
|
| 182 |
+
res += layer.parameters()
|
| 183 |
+
return res
|
| 184 |
+
|
| 185 |
+
def train(self, mode=True):
|
| 186 |
+
for layers in self.layers.values():
|
| 187 |
+
for layer in layers:
|
| 188 |
+
layer.train(mode=mode)
|
| 189 |
+
for param in layer.parameters():
|
| 190 |
+
param.requires_grad = mode
|
| 191 |
+
|
| 192 |
+
def to(self, device):
|
| 193 |
+
for layers in self.layers.values():
|
| 194 |
+
for layer in layers:
|
| 195 |
+
layer.to(device)
|
| 196 |
+
|
| 197 |
+
return self
|
| 198 |
+
|
| 199 |
+
def set_multiplier(self, multiplier):
|
| 200 |
+
for layers in self.layers.values():
|
| 201 |
+
for layer in layers:
|
| 202 |
+
layer.multiplier = multiplier
|
| 203 |
+
|
| 204 |
+
return self
|
| 205 |
+
|
| 206 |
+
def eval(self):
|
| 207 |
+
for layers in self.layers.values():
|
| 208 |
+
for layer in layers:
|
| 209 |
+
layer.eval()
|
| 210 |
+
for param in layer.parameters():
|
| 211 |
+
param.requires_grad = False
|
| 212 |
+
|
| 213 |
+
def save(self, filename):
|
| 214 |
+
state_dict = {}
|
| 215 |
+
optimizer_saved_dict = {}
|
| 216 |
+
|
| 217 |
+
for k, v in self.layers.items():
|
| 218 |
+
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
| 219 |
+
|
| 220 |
+
state_dict['step'] = self.step
|
| 221 |
+
state_dict['name'] = self.name
|
| 222 |
+
state_dict['layer_structure'] = self.layer_structure
|
| 223 |
+
state_dict['activation_func'] = self.activation_func
|
| 224 |
+
state_dict['is_layer_norm'] = self.add_layer_norm
|
| 225 |
+
state_dict['weight_initialization'] = self.weight_init
|
| 226 |
+
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
| 227 |
+
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
| 228 |
+
state_dict['activate_output'] = self.activate_output
|
| 229 |
+
state_dict['use_dropout'] = self.use_dropout
|
| 230 |
+
state_dict['dropout_structure'] = self.dropout_structure
|
| 231 |
+
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
| 232 |
+
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
| 233 |
+
|
| 234 |
+
if self.optimizer_name is not None:
|
| 235 |
+
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
| 236 |
+
|
| 237 |
+
torch.save(state_dict, filename)
|
| 238 |
+
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
| 239 |
+
optimizer_saved_dict['hash'] = self.shorthash()
|
| 240 |
+
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
| 241 |
+
torch.save(optimizer_saved_dict, filename + '.optim')
|
| 242 |
+
|
| 243 |
+
def load(self, filename):
|
| 244 |
+
self.filename = filename
|
| 245 |
+
if self.name is None:
|
| 246 |
+
self.name = os.path.splitext(os.path.basename(filename))[0]
|
| 247 |
+
|
| 248 |
+
state_dict = torch.load(filename, map_location='cpu')
|
| 249 |
+
|
| 250 |
+
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
| 251 |
+
self.optional_info = state_dict.get('optional_info', None)
|
| 252 |
+
self.activation_func = state_dict.get('activation_func', None)
|
| 253 |
+
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
| 254 |
+
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
| 255 |
+
self.dropout_structure = state_dict.get('dropout_structure', None)
|
| 256 |
+
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
| 257 |
+
self.activate_output = state_dict.get('activate_output', True)
|
| 258 |
+
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
| 259 |
+
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
| 260 |
+
if self.dropout_structure is None:
|
| 261 |
+
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
| 262 |
+
|
| 263 |
+
if shared.opts.print_hypernet_extra:
|
| 264 |
+
if self.optional_info is not None:
|
| 265 |
+
print(f" INFO:\n {self.optional_info}\n")
|
| 266 |
+
|
| 267 |
+
print(f" Layer structure: {self.layer_structure}")
|
| 268 |
+
print(f" Activation function: {self.activation_func}")
|
| 269 |
+
print(f" Weight initialization: {self.weight_init}")
|
| 270 |
+
print(f" Layer norm: {self.add_layer_norm}")
|
| 271 |
+
print(f" Dropout usage: {self.use_dropout}" )
|
| 272 |
+
print(f" Activate last layer: {self.activate_output}")
|
| 273 |
+
print(f" Dropout structure: {self.dropout_structure}")
|
| 274 |
+
|
| 275 |
+
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
| 276 |
+
|
| 277 |
+
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
| 278 |
+
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
| 279 |
+
else:
|
| 280 |
+
self.optimizer_state_dict = None
|
| 281 |
+
if self.optimizer_state_dict:
|
| 282 |
+
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
| 283 |
+
if shared.opts.print_hypernet_extra:
|
| 284 |
+
print("Loaded existing optimizer from checkpoint")
|
| 285 |
+
print(f"Optimizer name is {self.optimizer_name}")
|
| 286 |
+
else:
|
| 287 |
+
self.optimizer_name = "AdamW"
|
| 288 |
+
if shared.opts.print_hypernet_extra:
|
| 289 |
+
print("No saved optimizer exists in checkpoint")
|
| 290 |
+
|
| 291 |
+
for size, sd in state_dict.items():
|
| 292 |
+
if type(size) == int:
|
| 293 |
+
self.layers[size] = (
|
| 294 |
+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
| 295 |
+
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
| 296 |
+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
| 297 |
+
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
self.name = state_dict.get('name', self.name)
|
| 301 |
+
self.step = state_dict.get('step', 0)
|
| 302 |
+
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
| 303 |
+
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
| 304 |
+
self.eval()
|
| 305 |
+
|
| 306 |
+
def shorthash(self):
|
| 307 |
+
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
| 308 |
+
|
| 309 |
+
return sha256[0:10] if sha256 else None
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def list_hypernetworks(path):
|
| 313 |
+
res = {}
|
| 314 |
+
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
|
| 315 |
+
name = os.path.splitext(os.path.basename(filename))[0]
|
| 316 |
+
# Prevent a hypothetical "None.pt" from being listed.
|
| 317 |
+
if name != "None":
|
| 318 |
+
res[name] = filename
|
| 319 |
+
return res
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def load_hypernetwork(name):
|
| 323 |
+
path = shared.hypernetworks.get(name, None)
|
| 324 |
+
|
| 325 |
+
if path is None:
|
| 326 |
+
return None
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
hypernetwork = Hypernetwork()
|
| 330 |
+
hypernetwork.load(path)
|
| 331 |
+
return hypernetwork
|
| 332 |
+
except Exception:
|
| 333 |
+
errors.report(f"Error loading hypernetwork {path}", exc_info=True)
|
| 334 |
+
return None
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def load_hypernetworks(names, multipliers=None):
|
| 338 |
+
already_loaded = {}
|
| 339 |
+
|
| 340 |
+
for hypernetwork in shared.loaded_hypernetworks:
|
| 341 |
+
if hypernetwork.name in names:
|
| 342 |
+
already_loaded[hypernetwork.name] = hypernetwork
|
| 343 |
+
|
| 344 |
+
shared.loaded_hypernetworks.clear()
|
| 345 |
+
|
| 346 |
+
for i, name in enumerate(names):
|
| 347 |
+
hypernetwork = already_loaded.get(name, None)
|
| 348 |
+
if hypernetwork is None:
|
| 349 |
+
hypernetwork = load_hypernetwork(name)
|
| 350 |
+
|
| 351 |
+
if hypernetwork is None:
|
| 352 |
+
continue
|
| 353 |
+
|
| 354 |
+
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
| 355 |
+
shared.loaded_hypernetworks.append(hypernetwork)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
| 359 |
+
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
| 360 |
+
|
| 361 |
+
if hypernetwork_layers is None:
|
| 362 |
+
return context_k, context_v
|
| 363 |
+
|
| 364 |
+
if layer is not None:
|
| 365 |
+
layer.hyper_k = hypernetwork_layers[0]
|
| 366 |
+
layer.hyper_v = hypernetwork_layers[1]
|
| 367 |
+
|
| 368 |
+
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
|
| 369 |
+
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
|
| 370 |
+
return context_k, context_v
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def apply_hypernetworks(hypernetworks, context, layer=None):
|
| 374 |
+
context_k = context
|
| 375 |
+
context_v = context
|
| 376 |
+
for hypernetwork in hypernetworks:
|
| 377 |
+
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
| 378 |
+
|
| 379 |
+
return context_k, context_v
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
|
| 383 |
+
h = self.heads
|
| 384 |
+
|
| 385 |
+
q = self.to_q(x)
|
| 386 |
+
context = default(context, x)
|
| 387 |
+
|
| 388 |
+
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
| 389 |
+
k = self.to_k(context_k)
|
| 390 |
+
v = self.to_v(context_v)
|
| 391 |
+
|
| 392 |
+
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
| 393 |
+
|
| 394 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 395 |
+
|
| 396 |
+
if mask is not None:
|
| 397 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
| 398 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
| 399 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
| 400 |
+
sim.masked_fill_(~mask, max_neg_value)
|
| 401 |
+
|
| 402 |
+
# attention, what we cannot get enough of
|
| 403 |
+
attn = sim.softmax(dim=-1)
|
| 404 |
+
|
| 405 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
| 406 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
| 407 |
+
return self.to_out(out)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def stack_conds(conds):
|
| 411 |
+
if len(conds) == 1:
|
| 412 |
+
return torch.stack(conds)
|
| 413 |
+
|
| 414 |
+
# same as in reconstruct_multicond_batch
|
| 415 |
+
token_count = max([x.shape[0] for x in conds])
|
| 416 |
+
for i in range(len(conds)):
|
| 417 |
+
if conds[i].shape[0] != token_count:
|
| 418 |
+
last_vector = conds[i][-1:]
|
| 419 |
+
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
| 420 |
+
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
| 421 |
+
|
| 422 |
+
return torch.stack(conds)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def statistics(data):
|
| 426 |
+
if len(data) < 2:
|
| 427 |
+
std = 0
|
| 428 |
+
else:
|
| 429 |
+
std = stdev(data)
|
| 430 |
+
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
|
| 431 |
+
recent_data = data[-32:]
|
| 432 |
+
if len(recent_data) < 2:
|
| 433 |
+
std = 0
|
| 434 |
+
else:
|
| 435 |
+
std = stdev(recent_data)
|
| 436 |
+
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
| 437 |
+
return total_information, recent_information
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
| 441 |
+
# Remove illegal characters from name.
|
| 442 |
+
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
| 443 |
+
assert name, "Name cannot be empty!"
|
| 444 |
+
|
| 445 |
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
| 446 |
+
if not overwrite_old:
|
| 447 |
+
assert not os.path.exists(fn), f"file {fn} already exists"
|
| 448 |
+
|
| 449 |
+
if type(layer_structure) == str:
|
| 450 |
+
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
| 451 |
+
|
| 452 |
+
if use_dropout and dropout_structure and type(dropout_structure) == str:
|
| 453 |
+
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
| 454 |
+
else:
|
| 455 |
+
dropout_structure = [0] * len(layer_structure)
|
| 456 |
+
|
| 457 |
+
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
| 458 |
+
name=name,
|
| 459 |
+
enable_sizes=[int(x) for x in enable_sizes],
|
| 460 |
+
layer_structure=layer_structure,
|
| 461 |
+
activation_func=activation_func,
|
| 462 |
+
weight_init=weight_init,
|
| 463 |
+
add_layer_norm=add_layer_norm,
|
| 464 |
+
use_dropout=use_dropout,
|
| 465 |
+
dropout_structure=dropout_structure
|
| 466 |
+
)
|
| 467 |
+
hypernet.save(fn)
|
| 468 |
+
|
| 469 |
+
shared.reload_hypernetworks()
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
|
| 473 |
+
from modules import images, processing
|
| 474 |
+
|
| 475 |
+
save_hypernetwork_every = save_hypernetwork_every or 0
|
| 476 |
+
create_image_every = create_image_every or 0
|
| 477 |
+
template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
| 478 |
+
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
| 479 |
+
template_file = template_file.path
|
| 480 |
+
|
| 481 |
+
path = shared.hypernetworks.get(hypernetwork_name, None)
|
| 482 |
+
hypernetwork = Hypernetwork()
|
| 483 |
+
hypernetwork.load(path)
|
| 484 |
+
shared.loaded_hypernetworks = [hypernetwork]
|
| 485 |
+
|
| 486 |
+
shared.state.job = "train-hypernetwork"
|
| 487 |
+
shared.state.textinfo = "Initializing hypernetwork training..."
|
| 488 |
+
shared.state.job_count = steps
|
| 489 |
+
|
| 490 |
+
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
| 491 |
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
| 492 |
+
|
| 493 |
+
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
| 494 |
+
unload = shared.opts.unload_models_when_training
|
| 495 |
+
|
| 496 |
+
if save_hypernetwork_every > 0:
|
| 497 |
+
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
| 498 |
+
os.makedirs(hypernetwork_dir, exist_ok=True)
|
| 499 |
+
else:
|
| 500 |
+
hypernetwork_dir = None
|
| 501 |
+
|
| 502 |
+
if create_image_every > 0:
|
| 503 |
+
images_dir = os.path.join(log_directory, "images")
|
| 504 |
+
os.makedirs(images_dir, exist_ok=True)
|
| 505 |
+
else:
|
| 506 |
+
images_dir = None
|
| 507 |
+
|
| 508 |
+
checkpoint = sd_models.select_checkpoint()
|
| 509 |
+
|
| 510 |
+
initial_step = hypernetwork.step or 0
|
| 511 |
+
if initial_step >= steps:
|
| 512 |
+
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
| 513 |
+
return hypernetwork, filename
|
| 514 |
+
|
| 515 |
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
| 516 |
+
|
| 517 |
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
| 518 |
+
if clip_grad:
|
| 519 |
+
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
| 520 |
+
|
| 521 |
+
if shared.opts.training_enable_tensorboard:
|
| 522 |
+
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
| 523 |
+
|
| 524 |
+
# dataset loading may take a while, so input validations and early returns should be done before this
|
| 525 |
+
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
| 526 |
+
|
| 527 |
+
pin_memory = shared.opts.pin_memory
|
| 528 |
+
|
| 529 |
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
| 530 |
+
|
| 531 |
+
if shared.opts.save_training_settings_to_txt:
|
| 532 |
+
saved_params = dict(
|
| 533 |
+
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
| 534 |
+
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
| 535 |
+
)
|
| 536 |
+
saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
| 537 |
+
|
| 538 |
+
latent_sampling_method = ds.latent_sampling_method
|
| 539 |
+
|
| 540 |
+
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
| 541 |
+
|
| 542 |
+
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
| 543 |
+
|
| 544 |
+
if unload:
|
| 545 |
+
shared.parallel_processing_allowed = False
|
| 546 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
| 547 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
| 548 |
+
|
| 549 |
+
weights = hypernetwork.weights()
|
| 550 |
+
hypernetwork.train()
|
| 551 |
+
|
| 552 |
+
# Here we use optimizer from saved HN, or we can specify as UI option.
|
| 553 |
+
if hypernetwork.optimizer_name in optimizer_dict:
|
| 554 |
+
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
| 555 |
+
optimizer_name = hypernetwork.optimizer_name
|
| 556 |
+
else:
|
| 557 |
+
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
| 558 |
+
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
| 559 |
+
optimizer_name = 'AdamW'
|
| 560 |
+
|
| 561 |
+
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
| 562 |
+
try:
|
| 563 |
+
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
| 564 |
+
except RuntimeError as e:
|
| 565 |
+
print("Cannot resume from saved optimizer!")
|
| 566 |
+
print(e)
|
| 567 |
+
|
| 568 |
+
scaler = torch.cuda.amp.GradScaler()
|
| 569 |
+
|
| 570 |
+
batch_size = ds.batch_size
|
| 571 |
+
gradient_step = ds.gradient_step
|
| 572 |
+
# n steps = batch_size * gradient_step * n image processed
|
| 573 |
+
steps_per_epoch = len(ds) // batch_size // gradient_step
|
| 574 |
+
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
| 575 |
+
loss_step = 0
|
| 576 |
+
_loss_step = 0 #internal
|
| 577 |
+
# size = len(ds.indexes)
|
| 578 |
+
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
| 579 |
+
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
| 580 |
+
# losses = torch.zeros((size,))
|
| 581 |
+
# previous_mean_losses = [0]
|
| 582 |
+
# previous_mean_loss = 0
|
| 583 |
+
# print("Mean loss of {} elements".format(size))
|
| 584 |
+
|
| 585 |
+
steps_without_grad = 0
|
| 586 |
+
|
| 587 |
+
last_saved_file = "<none>"
|
| 588 |
+
last_saved_image = "<none>"
|
| 589 |
+
forced_filename = "<none>"
|
| 590 |
+
|
| 591 |
+
pbar = tqdm.tqdm(total=steps - initial_step)
|
| 592 |
+
try:
|
| 593 |
+
sd_hijack_checkpoint.add()
|
| 594 |
+
|
| 595 |
+
for _ in range((steps-initial_step) * gradient_step):
|
| 596 |
+
if scheduler.finished:
|
| 597 |
+
break
|
| 598 |
+
if shared.state.interrupted:
|
| 599 |
+
break
|
| 600 |
+
for j, batch in enumerate(dl):
|
| 601 |
+
# works as a drop_last=True for gradient accumulation
|
| 602 |
+
if j == max_steps_per_epoch:
|
| 603 |
+
break
|
| 604 |
+
scheduler.apply(optimizer, hypernetwork.step)
|
| 605 |
+
if scheduler.finished:
|
| 606 |
+
break
|
| 607 |
+
if shared.state.interrupted:
|
| 608 |
+
break
|
| 609 |
+
|
| 610 |
+
if clip_grad:
|
| 611 |
+
clip_grad_sched.step(hypernetwork.step)
|
| 612 |
+
|
| 613 |
+
with devices.autocast():
|
| 614 |
+
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
| 615 |
+
if use_weight:
|
| 616 |
+
w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
| 617 |
+
if tag_drop_out != 0 or shuffle_tags:
|
| 618 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
| 619 |
+
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
| 620 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
| 621 |
+
else:
|
| 622 |
+
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
| 623 |
+
if use_weight:
|
| 624 |
+
loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
|
| 625 |
+
del w
|
| 626 |
+
else:
|
| 627 |
+
loss = shared.sd_model.forward(x, c)[0] / gradient_step
|
| 628 |
+
del x
|
| 629 |
+
del c
|
| 630 |
+
|
| 631 |
+
_loss_step += loss.item()
|
| 632 |
+
scaler.scale(loss).backward()
|
| 633 |
+
|
| 634 |
+
# go back until we reach gradient accumulation steps
|
| 635 |
+
if (j + 1) % gradient_step != 0:
|
| 636 |
+
continue
|
| 637 |
+
loss_logging.append(_loss_step)
|
| 638 |
+
if clip_grad:
|
| 639 |
+
clip_grad(weights, clip_grad_sched.learn_rate)
|
| 640 |
+
|
| 641 |
+
scaler.step(optimizer)
|
| 642 |
+
scaler.update()
|
| 643 |
+
hypernetwork.step += 1
|
| 644 |
+
pbar.update()
|
| 645 |
+
optimizer.zero_grad(set_to_none=True)
|
| 646 |
+
loss_step = _loss_step
|
| 647 |
+
_loss_step = 0
|
| 648 |
+
|
| 649 |
+
steps_done = hypernetwork.step + 1
|
| 650 |
+
|
| 651 |
+
epoch_num = hypernetwork.step // steps_per_epoch
|
| 652 |
+
epoch_step = hypernetwork.step % steps_per_epoch
|
| 653 |
+
|
| 654 |
+
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
| 655 |
+
pbar.set_description(description)
|
| 656 |
+
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
| 657 |
+
# Before saving, change name to match current checkpoint.
|
| 658 |
+
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
| 659 |
+
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
| 660 |
+
hypernetwork.optimizer_name = optimizer_name
|
| 661 |
+
if shared.opts.save_optimizer_state:
|
| 662 |
+
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
| 663 |
+
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
| 664 |
+
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
if shared.opts.training_enable_tensorboard:
|
| 669 |
+
epoch_num = hypernetwork.step // len(ds)
|
| 670 |
+
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
| 671 |
+
mean_loss = sum(loss_logging) / len(loss_logging)
|
| 672 |
+
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
| 673 |
+
|
| 674 |
+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
| 675 |
+
"loss": f"{loss_step:.7f}",
|
| 676 |
+
"learn_rate": scheduler.learn_rate
|
| 677 |
+
})
|
| 678 |
+
|
| 679 |
+
if images_dir is not None and steps_done % create_image_every == 0:
|
| 680 |
+
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
| 681 |
+
last_saved_image = os.path.join(images_dir, forced_filename)
|
| 682 |
+
hypernetwork.eval()
|
| 683 |
+
rng_state = torch.get_rng_state()
|
| 684 |
+
cuda_rng_state = None
|
| 685 |
+
if torch.cuda.is_available():
|
| 686 |
+
cuda_rng_state = torch.cuda.get_rng_state_all()
|
| 687 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
| 688 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
| 689 |
+
|
| 690 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
| 691 |
+
sd_model=shared.sd_model,
|
| 692 |
+
do_not_save_grid=True,
|
| 693 |
+
do_not_save_samples=True,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
p.disable_extra_networks = True
|
| 697 |
+
|
| 698 |
+
if preview_from_txt2img:
|
| 699 |
+
p.prompt = preview_prompt
|
| 700 |
+
p.negative_prompt = preview_negative_prompt
|
| 701 |
+
p.steps = preview_steps
|
| 702 |
+
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
| 703 |
+
p.cfg_scale = preview_cfg_scale
|
| 704 |
+
p.seed = preview_seed
|
| 705 |
+
p.width = preview_width
|
| 706 |
+
p.height = preview_height
|
| 707 |
+
else:
|
| 708 |
+
p.prompt = batch.cond_text[0]
|
| 709 |
+
p.steps = 20
|
| 710 |
+
p.width = training_width
|
| 711 |
+
p.height = training_height
|
| 712 |
+
|
| 713 |
+
preview_text = p.prompt
|
| 714 |
+
|
| 715 |
+
with closing(p):
|
| 716 |
+
processed = processing.process_images(p)
|
| 717 |
+
image = processed.images[0] if len(processed.images) > 0 else None
|
| 718 |
+
|
| 719 |
+
if unload:
|
| 720 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
| 721 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
| 722 |
+
torch.set_rng_state(rng_state)
|
| 723 |
+
if torch.cuda.is_available():
|
| 724 |
+
torch.cuda.set_rng_state_all(cuda_rng_state)
|
| 725 |
+
hypernetwork.train()
|
| 726 |
+
if image is not None:
|
| 727 |
+
shared.state.assign_current_image(image)
|
| 728 |
+
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
| 729 |
+
textual_inversion.tensorboard_add_image(tensorboard_writer,
|
| 730 |
+
f"Validation at epoch {epoch_num}", image,
|
| 731 |
+
hypernetwork.step)
|
| 732 |
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
| 733 |
+
last_saved_image += f", prompt: {preview_text}"
|
| 734 |
+
|
| 735 |
+
shared.state.job_no = hypernetwork.step
|
| 736 |
+
|
| 737 |
+
shared.state.textinfo = f"""
|
| 738 |
+
<p>
|
| 739 |
+
Loss: {loss_step:.7f}<br/>
|
| 740 |
+
Step: {steps_done}<br/>
|
| 741 |
+
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
| 742 |
+
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
| 743 |
+
Last saved image: {html.escape(last_saved_image)}<br/>
|
| 744 |
+
</p>
|
| 745 |
+
"""
|
| 746 |
+
except Exception:
|
| 747 |
+
errors.report("Exception in training hypernetwork", exc_info=True)
|
| 748 |
+
finally:
|
| 749 |
+
pbar.leave = False
|
| 750 |
+
pbar.close()
|
| 751 |
+
hypernetwork.eval()
|
| 752 |
+
sd_hijack_checkpoint.remove()
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
| 757 |
+
hypernetwork.optimizer_name = optimizer_name
|
| 758 |
+
if shared.opts.save_optimizer_state:
|
| 759 |
+
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
| 760 |
+
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
| 761 |
+
|
| 762 |
+
del optimizer
|
| 763 |
+
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
| 764 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
| 765 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
| 766 |
+
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
| 767 |
+
|
| 768 |
+
return hypernetwork, filename
|
| 769 |
+
|
| 770 |
+
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
| 771 |
+
old_hypernetwork_name = hypernetwork.name
|
| 772 |
+
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
| 773 |
+
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
| 774 |
+
try:
|
| 775 |
+
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
| 776 |
+
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
| 777 |
+
hypernetwork.name = hypernetwork_name
|
| 778 |
+
hypernetwork.save(filename)
|
| 779 |
+
except:
|
| 780 |
+
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
| 781 |
+
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
| 782 |
+
hypernetwork.name = old_hypernetwork_name
|
| 783 |
+
raise
|
modules/hypernetworks/ui.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import html
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import modules.hypernetworks.hypernetwork
|
| 5 |
+
from modules import devices, sd_hijack, shared
|
| 6 |
+
|
| 7 |
+
not_available = ["hardswish", "multiheadattention"]
|
| 8 |
+
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
| 12 |
+
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
| 13 |
+
|
| 14 |
+
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def train_hypernetwork(*args):
|
| 18 |
+
shared.loaded_hypernetworks = []
|
| 19 |
+
|
| 20 |
+
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
sd_hijack.undo_optimizations()
|
| 24 |
+
|
| 25 |
+
hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
|
| 26 |
+
|
| 27 |
+
res = f"""
|
| 28 |
+
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
| 29 |
+
Hypernetwork saved to {html.escape(filename)}
|
| 30 |
+
"""
|
| 31 |
+
return res, ""
|
| 32 |
+
except Exception:
|
| 33 |
+
raise
|
| 34 |
+
finally:
|
| 35 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
| 36 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
| 37 |
+
sd_hijack.apply_optimizations()
|
| 38 |
+
|
modules/images.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import functools
|
| 5 |
+
import pytz
|
| 6 |
+
import io
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
from collections import namedtuple
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import piexif
|
| 14 |
+
import piexif.helper
|
| 15 |
+
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
|
| 16 |
+
# pillow_avif needs to be imported somewhere in code for it to work
|
| 17 |
+
import pillow_avif # noqa: F401
|
| 18 |
+
import string
|
| 19 |
+
import json
|
| 20 |
+
import hashlib
|
| 21 |
+
|
| 22 |
+
from modules import sd_samplers, shared, script_callbacks, errors
|
| 23 |
+
from modules.paths_internal import roboto_ttf_file
|
| 24 |
+
from modules.shared import opts
|
| 25 |
+
|
| 26 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_font(fontsize: int):
|
| 30 |
+
try:
|
| 31 |
+
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
| 32 |
+
except Exception:
|
| 33 |
+
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def image_grid(imgs, batch_size=1, rows=None):
|
| 37 |
+
if rows is None:
|
| 38 |
+
if opts.n_rows > 0:
|
| 39 |
+
rows = opts.n_rows
|
| 40 |
+
elif opts.n_rows == 0:
|
| 41 |
+
rows = batch_size
|
| 42 |
+
elif opts.grid_prevent_empty_spots:
|
| 43 |
+
rows = math.floor(math.sqrt(len(imgs)))
|
| 44 |
+
while len(imgs) % rows != 0:
|
| 45 |
+
rows -= 1
|
| 46 |
+
else:
|
| 47 |
+
rows = math.sqrt(len(imgs))
|
| 48 |
+
rows = round(rows)
|
| 49 |
+
if rows > len(imgs):
|
| 50 |
+
rows = len(imgs)
|
| 51 |
+
|
| 52 |
+
cols = math.ceil(len(imgs) / rows)
|
| 53 |
+
|
| 54 |
+
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
| 55 |
+
script_callbacks.image_grid_callback(params)
|
| 56 |
+
|
| 57 |
+
w, h = map(max, zip(*(img.size for img in imgs)))
|
| 58 |
+
grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
| 59 |
+
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=grid_background_color)
|
| 60 |
+
|
| 61 |
+
for i, img in enumerate(params.imgs):
|
| 62 |
+
img_w, img_h = img.size
|
| 63 |
+
w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2
|
| 64 |
+
grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset))
|
| 65 |
+
|
| 66 |
+
return grid
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
|
| 70 |
+
@property
|
| 71 |
+
def tile_count(self) -> int:
|
| 72 |
+
"""
|
| 73 |
+
The total number of tiles in the grid.
|
| 74 |
+
"""
|
| 75 |
+
return sum(len(row[2]) for row in self.tiles)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
|
| 79 |
+
w, h = image.size
|
| 80 |
+
|
| 81 |
+
non_overlap_width = tile_w - overlap
|
| 82 |
+
non_overlap_height = tile_h - overlap
|
| 83 |
+
|
| 84 |
+
cols = math.ceil((w - overlap) / non_overlap_width)
|
| 85 |
+
rows = math.ceil((h - overlap) / non_overlap_height)
|
| 86 |
+
|
| 87 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
| 88 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
| 89 |
+
|
| 90 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
| 91 |
+
for row in range(rows):
|
| 92 |
+
row_images = []
|
| 93 |
+
|
| 94 |
+
y = int(row * dy)
|
| 95 |
+
|
| 96 |
+
if y + tile_h >= h:
|
| 97 |
+
y = h - tile_h
|
| 98 |
+
|
| 99 |
+
for col in range(cols):
|
| 100 |
+
x = int(col * dx)
|
| 101 |
+
|
| 102 |
+
if x + tile_w >= w:
|
| 103 |
+
x = w - tile_w
|
| 104 |
+
|
| 105 |
+
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
| 106 |
+
|
| 107 |
+
row_images.append([x, tile_w, tile])
|
| 108 |
+
|
| 109 |
+
grid.tiles.append([y, tile_h, row_images])
|
| 110 |
+
|
| 111 |
+
return grid
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def combine_grid(grid):
|
| 115 |
+
def make_mask_image(r):
|
| 116 |
+
r = r * 255 / grid.overlap
|
| 117 |
+
r = r.astype(np.uint8)
|
| 118 |
+
return Image.fromarray(r, 'L')
|
| 119 |
+
|
| 120 |
+
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
| 121 |
+
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
| 122 |
+
|
| 123 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
| 124 |
+
for y, h, row in grid.tiles:
|
| 125 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
| 126 |
+
for x, w, tile in row:
|
| 127 |
+
if x == 0:
|
| 128 |
+
combined_row.paste(tile, (0, 0))
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
| 132 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
| 133 |
+
|
| 134 |
+
if y == 0:
|
| 135 |
+
combined_image.paste(combined_row, (0, 0))
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
| 139 |
+
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
| 140 |
+
|
| 141 |
+
return combined_image
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class GridAnnotation:
|
| 145 |
+
def __init__(self, text='', is_active=True):
|
| 146 |
+
self.text = text
|
| 147 |
+
self.is_active = is_active
|
| 148 |
+
self.size = None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
| 152 |
+
|
| 153 |
+
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
|
| 154 |
+
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
|
| 155 |
+
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
| 156 |
+
|
| 157 |
+
def wrap(drawing, text, font, line_length):
|
| 158 |
+
lines = ['']
|
| 159 |
+
for word in text.split():
|
| 160 |
+
line = f'{lines[-1]} {word}'.strip()
|
| 161 |
+
if drawing.textlength(line, font=font) <= line_length:
|
| 162 |
+
lines[-1] = line
|
| 163 |
+
else:
|
| 164 |
+
lines.append(word)
|
| 165 |
+
return lines
|
| 166 |
+
|
| 167 |
+
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
| 168 |
+
for line in lines:
|
| 169 |
+
fnt = initial_fnt
|
| 170 |
+
fontsize = initial_fontsize
|
| 171 |
+
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
| 172 |
+
fontsize -= 1
|
| 173 |
+
fnt = get_font(fontsize)
|
| 174 |
+
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
| 175 |
+
|
| 176 |
+
if not line.is_active:
|
| 177 |
+
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
| 178 |
+
|
| 179 |
+
draw_y += line.size[1] + line_spacing
|
| 180 |
+
|
| 181 |
+
fontsize = (width + height) // 25
|
| 182 |
+
line_spacing = fontsize // 2
|
| 183 |
+
|
| 184 |
+
fnt = get_font(fontsize)
|
| 185 |
+
|
| 186 |
+
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
| 187 |
+
|
| 188 |
+
cols = im.width // width
|
| 189 |
+
rows = im.height // height
|
| 190 |
+
|
| 191 |
+
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
| 192 |
+
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
| 193 |
+
|
| 194 |
+
calc_img = Image.new("RGB", (1, 1), color_background)
|
| 195 |
+
calc_d = ImageDraw.Draw(calc_img)
|
| 196 |
+
|
| 197 |
+
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
| 198 |
+
items = [] + texts
|
| 199 |
+
texts.clear()
|
| 200 |
+
|
| 201 |
+
for line in items:
|
| 202 |
+
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
| 203 |
+
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
| 204 |
+
|
| 205 |
+
for line in texts:
|
| 206 |
+
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
| 207 |
+
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
| 208 |
+
line.allowed_width = allowed_width
|
| 209 |
+
|
| 210 |
+
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
| 211 |
+
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
| 212 |
+
|
| 213 |
+
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
| 214 |
+
|
| 215 |
+
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
|
| 216 |
+
|
| 217 |
+
for row in range(rows):
|
| 218 |
+
for col in range(cols):
|
| 219 |
+
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
| 220 |
+
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
| 221 |
+
|
| 222 |
+
d = ImageDraw.Draw(result)
|
| 223 |
+
|
| 224 |
+
for col in range(cols):
|
| 225 |
+
x = pad_left + (width + margin) * col + width / 2
|
| 226 |
+
y = pad_top / 2 - hor_text_heights[col] / 2
|
| 227 |
+
|
| 228 |
+
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
| 229 |
+
|
| 230 |
+
for row in range(rows):
|
| 231 |
+
x = pad_left / 2
|
| 232 |
+
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
| 233 |
+
|
| 234 |
+
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
| 235 |
+
|
| 236 |
+
return result
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
| 240 |
+
prompts = all_prompts[1:]
|
| 241 |
+
boundary = math.ceil(len(prompts) / 2)
|
| 242 |
+
|
| 243 |
+
prompts_horiz = prompts[:boundary]
|
| 244 |
+
prompts_vert = prompts[boundary:]
|
| 245 |
+
|
| 246 |
+
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
| 247 |
+
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
| 248 |
+
|
| 249 |
+
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
| 253 |
+
"""
|
| 254 |
+
Resizes an image with the specified resize_mode, width, and height.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
resize_mode: The mode to use when resizing the image.
|
| 258 |
+
0: Resize the image to the specified width and height.
|
| 259 |
+
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
| 260 |
+
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
| 261 |
+
im: The image to resize.
|
| 262 |
+
width: The width to resize the image to.
|
| 263 |
+
height: The height to resize the image to.
|
| 264 |
+
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
| 268 |
+
|
| 269 |
+
def resize(im, w, h):
|
| 270 |
+
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
|
| 271 |
+
return im.resize((w, h), resample=LANCZOS)
|
| 272 |
+
|
| 273 |
+
scale = max(w / im.width, h / im.height)
|
| 274 |
+
|
| 275 |
+
if scale > 1.0:
|
| 276 |
+
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
| 277 |
+
if len(upscalers) == 0:
|
| 278 |
+
upscaler = shared.sd_upscalers[0]
|
| 279 |
+
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
|
| 280 |
+
else:
|
| 281 |
+
upscaler = upscalers[0]
|
| 282 |
+
|
| 283 |
+
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
| 284 |
+
|
| 285 |
+
if im.width != w or im.height != h:
|
| 286 |
+
im = im.resize((w, h), resample=LANCZOS)
|
| 287 |
+
|
| 288 |
+
return im
|
| 289 |
+
|
| 290 |
+
if resize_mode == 0:
|
| 291 |
+
res = resize(im, width, height)
|
| 292 |
+
|
| 293 |
+
elif resize_mode == 1:
|
| 294 |
+
ratio = width / height
|
| 295 |
+
src_ratio = im.width / im.height
|
| 296 |
+
|
| 297 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
| 298 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
| 299 |
+
|
| 300 |
+
resized = resize(im, src_w, src_h)
|
| 301 |
+
res = Image.new("RGB", (width, height))
|
| 302 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
| 303 |
+
|
| 304 |
+
else:
|
| 305 |
+
ratio = width / height
|
| 306 |
+
src_ratio = im.width / im.height
|
| 307 |
+
|
| 308 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
| 309 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
| 310 |
+
|
| 311 |
+
resized = resize(im, src_w, src_h)
|
| 312 |
+
res = Image.new("RGB", (width, height))
|
| 313 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
| 314 |
+
|
| 315 |
+
if ratio < src_ratio:
|
| 316 |
+
fill_height = height // 2 - src_h // 2
|
| 317 |
+
if fill_height > 0:
|
| 318 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
| 319 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
| 320 |
+
elif ratio > src_ratio:
|
| 321 |
+
fill_width = width // 2 - src_w // 2
|
| 322 |
+
if fill_width > 0:
|
| 323 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
| 324 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
| 325 |
+
|
| 326 |
+
return res
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
if not shared.cmd_opts.unix_filenames_sanitization:
|
| 330 |
+
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
| 331 |
+
else:
|
| 332 |
+
invalid_filename_chars = '/'
|
| 333 |
+
invalid_filename_prefix = ' '
|
| 334 |
+
invalid_filename_postfix = ' .'
|
| 335 |
+
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
| 336 |
+
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
| 337 |
+
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
| 338 |
+
max_filename_part_length = shared.cmd_opts.filenames_max_length
|
| 339 |
+
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def sanitize_filename_part(text, replace_spaces=True):
|
| 343 |
+
if text is None:
|
| 344 |
+
return None
|
| 345 |
+
|
| 346 |
+
if replace_spaces:
|
| 347 |
+
text = text.replace(' ', '_')
|
| 348 |
+
|
| 349 |
+
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
| 350 |
+
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
| 351 |
+
text = text.rstrip(invalid_filename_postfix)
|
| 352 |
+
return text
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@functools.cache
|
| 356 |
+
def get_scheduler_str(sampler_name, scheduler_name):
|
| 357 |
+
"""Returns {Scheduler} if the scheduler is applicable to the sampler"""
|
| 358 |
+
if scheduler_name == 'Automatic':
|
| 359 |
+
config = sd_samplers.find_sampler_config(sampler_name)
|
| 360 |
+
scheduler_name = config.options.get('scheduler', 'Automatic')
|
| 361 |
+
return scheduler_name.capitalize()
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@functools.cache
|
| 365 |
+
def get_sampler_scheduler_str(sampler_name, scheduler_name):
|
| 366 |
+
"""Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler"""
|
| 367 |
+
return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}'
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def get_sampler_scheduler(p, sampler):
|
| 371 |
+
"""Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'"""
|
| 372 |
+
if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'):
|
| 373 |
+
if sampler:
|
| 374 |
+
sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler)
|
| 375 |
+
else:
|
| 376 |
+
sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler)
|
| 377 |
+
return sanitize_filename_part(sampler_scheduler, replace_spaces=False)
|
| 378 |
+
return NOTHING_AND_SKIP_PREVIOUS_TEXT
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class FilenameGenerator:
|
| 382 |
+
replacements = {
|
| 383 |
+
'basename': lambda self: self.basename or 'img',
|
| 384 |
+
'seed': lambda self: self.seed if self.seed is not None else '',
|
| 385 |
+
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
| 386 |
+
'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
|
| 387 |
+
'steps': lambda self: self.p and self.p.steps,
|
| 388 |
+
'cfg': lambda self: self.p and self.p.cfg_scale,
|
| 389 |
+
'width': lambda self: self.image.width,
|
| 390 |
+
'height': lambda self: self.image.height,
|
| 391 |
+
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
| 392 |
+
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
| 393 |
+
'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True),
|
| 394 |
+
'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False),
|
| 395 |
+
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
| 396 |
+
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
|
| 397 |
+
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
| 398 |
+
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
| 399 |
+
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
| 400 |
+
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
|
| 401 |
+
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
|
| 402 |
+
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
|
| 403 |
+
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
| 404 |
+
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
| 405 |
+
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
| 406 |
+
'prompt_words': lambda self: self.prompt_words(),
|
| 407 |
+
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
|
| 408 |
+
'batch_size': lambda self: self.p.batch_size,
|
| 409 |
+
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
| 410 |
+
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
| 411 |
+
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
| 412 |
+
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
| 413 |
+
'user': lambda self: self.p.user,
|
| 414 |
+
'vae_filename': lambda self: self.get_vae_filename(),
|
| 415 |
+
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
|
| 416 |
+
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
|
| 417 |
+
}
|
| 418 |
+
default_time_format = '%Y%m%d%H%M%S'
|
| 419 |
+
|
| 420 |
+
def __init__(self, p, seed, prompt, image, zip=False, basename=""):
|
| 421 |
+
self.p = p
|
| 422 |
+
self.seed = seed
|
| 423 |
+
self.prompt = prompt
|
| 424 |
+
self.image = image
|
| 425 |
+
self.zip = zip
|
| 426 |
+
self.basename = basename
|
| 427 |
+
|
| 428 |
+
def get_vae_filename(self):
|
| 429 |
+
"""Get the name of the VAE file."""
|
| 430 |
+
|
| 431 |
+
import modules.sd_vae as sd_vae
|
| 432 |
+
|
| 433 |
+
if sd_vae.loaded_vae_file is None:
|
| 434 |
+
return "NoneType"
|
| 435 |
+
|
| 436 |
+
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
| 437 |
+
split_file_name = file_name.split('.')
|
| 438 |
+
if len(split_file_name) > 1 and split_file_name[0] == '':
|
| 439 |
+
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
| 440 |
+
else:
|
| 441 |
+
return split_file_name[0]
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def hasprompt(self, *args):
|
| 445 |
+
lower = self.prompt.lower()
|
| 446 |
+
if self.p is None or self.prompt is None:
|
| 447 |
+
return None
|
| 448 |
+
outres = ""
|
| 449 |
+
for arg in args:
|
| 450 |
+
if arg != "":
|
| 451 |
+
division = arg.split("|")
|
| 452 |
+
expected = division[0].lower()
|
| 453 |
+
default = division[1] if len(division) > 1 else ""
|
| 454 |
+
if lower.find(expected) >= 0:
|
| 455 |
+
outres = f'{outres}{expected}'
|
| 456 |
+
else:
|
| 457 |
+
outres = outres if default == "" else f'{outres}{default}'
|
| 458 |
+
return sanitize_filename_part(outres)
|
| 459 |
+
|
| 460 |
+
def prompt_no_style(self):
|
| 461 |
+
if self.p is None or self.prompt is None:
|
| 462 |
+
return None
|
| 463 |
+
|
| 464 |
+
prompt_no_style = self.prompt
|
| 465 |
+
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
|
| 466 |
+
if style:
|
| 467 |
+
for part in style.split("{prompt}"):
|
| 468 |
+
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
| 469 |
+
|
| 470 |
+
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
| 471 |
+
|
| 472 |
+
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
|
| 473 |
+
|
| 474 |
+
def prompt_words(self):
|
| 475 |
+
words = [x for x in re_nonletters.split(self.prompt or "") if x]
|
| 476 |
+
if len(words) == 0:
|
| 477 |
+
words = ["empty"]
|
| 478 |
+
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
|
| 479 |
+
|
| 480 |
+
def datetime(self, *args):
|
| 481 |
+
time_datetime = datetime.datetime.now()
|
| 482 |
+
|
| 483 |
+
time_format = args[0] if (args and args[0] != "") else self.default_time_format
|
| 484 |
+
try:
|
| 485 |
+
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
| 486 |
+
except pytz.exceptions.UnknownTimeZoneError:
|
| 487 |
+
time_zone = None
|
| 488 |
+
|
| 489 |
+
time_zone_time = time_datetime.astimezone(time_zone)
|
| 490 |
+
try:
|
| 491 |
+
formatted_time = time_zone_time.strftime(time_format)
|
| 492 |
+
except (ValueError, TypeError):
|
| 493 |
+
formatted_time = time_zone_time.strftime(self.default_time_format)
|
| 494 |
+
|
| 495 |
+
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
| 496 |
+
|
| 497 |
+
def image_hash(self, *args):
|
| 498 |
+
length = int(args[0]) if (args and args[0] != "") else None
|
| 499 |
+
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
|
| 500 |
+
|
| 501 |
+
def string_hash(self, text, *args):
|
| 502 |
+
length = int(args[0]) if (args and args[0] != "") else 8
|
| 503 |
+
return hashlib.sha256(text.encode()).hexdigest()[0:length]
|
| 504 |
+
|
| 505 |
+
def apply(self, x):
|
| 506 |
+
res = ''
|
| 507 |
+
|
| 508 |
+
for m in re_pattern.finditer(x):
|
| 509 |
+
text, pattern = m.groups()
|
| 510 |
+
|
| 511 |
+
if pattern is None:
|
| 512 |
+
res += text
|
| 513 |
+
continue
|
| 514 |
+
|
| 515 |
+
pattern_args = []
|
| 516 |
+
while True:
|
| 517 |
+
m = re_pattern_arg.match(pattern)
|
| 518 |
+
if m is None:
|
| 519 |
+
break
|
| 520 |
+
|
| 521 |
+
pattern, arg = m.groups()
|
| 522 |
+
pattern_args.insert(0, arg)
|
| 523 |
+
|
| 524 |
+
fun = self.replacements.get(pattern.lower())
|
| 525 |
+
if fun is not None:
|
| 526 |
+
try:
|
| 527 |
+
replacement = fun(self, *pattern_args)
|
| 528 |
+
except Exception:
|
| 529 |
+
replacement = None
|
| 530 |
+
errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
|
| 531 |
+
|
| 532 |
+
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
| 533 |
+
continue
|
| 534 |
+
elif replacement is not None:
|
| 535 |
+
res += text + str(replacement)
|
| 536 |
+
continue
|
| 537 |
+
|
| 538 |
+
res += f'{text}[{pattern}]'
|
| 539 |
+
|
| 540 |
+
return res
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def get_next_sequence_number(path, basename):
|
| 544 |
+
"""
|
| 545 |
+
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
| 546 |
+
|
| 547 |
+
The sequence starts at 0.
|
| 548 |
+
"""
|
| 549 |
+
result = -1
|
| 550 |
+
if basename != '':
|
| 551 |
+
basename = f"{basename}-"
|
| 552 |
+
|
| 553 |
+
prefix_length = len(basename)
|
| 554 |
+
for p in os.listdir(path):
|
| 555 |
+
if p.startswith(basename):
|
| 556 |
+
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
| 557 |
+
try:
|
| 558 |
+
result = max(int(parts[0]), result)
|
| 559 |
+
except ValueError:
|
| 560 |
+
pass
|
| 561 |
+
|
| 562 |
+
return result + 1
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
|
| 566 |
+
"""
|
| 567 |
+
Saves image to filename, including geninfo as text information for generation info.
|
| 568 |
+
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
|
| 569 |
+
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
|
| 570 |
+
"""
|
| 571 |
+
|
| 572 |
+
if extension is None:
|
| 573 |
+
extension = os.path.splitext(filename)[1]
|
| 574 |
+
|
| 575 |
+
image_format = Image.registered_extensions()[extension]
|
| 576 |
+
|
| 577 |
+
if extension.lower() == '.png':
|
| 578 |
+
existing_pnginfo = existing_pnginfo or {}
|
| 579 |
+
if opts.enable_pnginfo:
|
| 580 |
+
existing_pnginfo[pnginfo_section_name] = geninfo
|
| 581 |
+
|
| 582 |
+
if opts.enable_pnginfo:
|
| 583 |
+
pnginfo_data = PngImagePlugin.PngInfo()
|
| 584 |
+
for k, v in (existing_pnginfo or {}).items():
|
| 585 |
+
pnginfo_data.add_text(k, str(v))
|
| 586 |
+
else:
|
| 587 |
+
pnginfo_data = None
|
| 588 |
+
|
| 589 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
| 590 |
+
|
| 591 |
+
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
| 592 |
+
if image.mode == 'RGBA':
|
| 593 |
+
image = image.convert("RGB")
|
| 594 |
+
elif image.mode == 'I;16':
|
| 595 |
+
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
| 596 |
+
|
| 597 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
| 598 |
+
|
| 599 |
+
if opts.enable_pnginfo and geninfo is not None:
|
| 600 |
+
exif_bytes = piexif.dump({
|
| 601 |
+
"Exif": {
|
| 602 |
+
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
| 603 |
+
},
|
| 604 |
+
})
|
| 605 |
+
|
| 606 |
+
piexif.insert(exif_bytes, filename)
|
| 607 |
+
elif extension.lower() == '.avif':
|
| 608 |
+
if opts.enable_pnginfo and geninfo is not None:
|
| 609 |
+
exif_bytes = piexif.dump({
|
| 610 |
+
"Exif": {
|
| 611 |
+
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
| 612 |
+
},
|
| 613 |
+
})
|
| 614 |
+
else:
|
| 615 |
+
exif_bytes = None
|
| 616 |
+
|
| 617 |
+
image.save(filename,format=image_format, quality=opts.jpeg_quality, exif=exif_bytes)
|
| 618 |
+
elif extension.lower() == ".gif":
|
| 619 |
+
image.save(filename, format=image_format, comment=geninfo)
|
| 620 |
+
else:
|
| 621 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
| 625 |
+
"""Save an image.
|
| 626 |
+
|
| 627 |
+
Args:
|
| 628 |
+
image (`PIL.Image`):
|
| 629 |
+
The image to be saved.
|
| 630 |
+
path (`str`):
|
| 631 |
+
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
| 632 |
+
basename (`str`):
|
| 633 |
+
The base filename which will be applied to `filename pattern`.
|
| 634 |
+
seed, prompt, short_filename,
|
| 635 |
+
extension (`str`):
|
| 636 |
+
Image file extension, default is `png`.
|
| 637 |
+
pngsectionname (`str`):
|
| 638 |
+
Specify the name of the section which `info` will be saved in.
|
| 639 |
+
info (`str` or `PngImagePlugin.iTXt`):
|
| 640 |
+
PNG info chunks.
|
| 641 |
+
existing_info (`dict`):
|
| 642 |
+
Additional PNG info. `existing_info == {pngsectionname: info, ...}`
|
| 643 |
+
no_prompt:
|
| 644 |
+
TODO I don't know its meaning.
|
| 645 |
+
p (`StableDiffusionProcessing`)
|
| 646 |
+
forced_filename (`str`):
|
| 647 |
+
If specified, `basename` and filename pattern will be ignored.
|
| 648 |
+
save_to_dirs (bool):
|
| 649 |
+
If true, the image will be saved into a subdirectory of `path`.
|
| 650 |
+
|
| 651 |
+
Returns: (fullfn, txt_fullfn)
|
| 652 |
+
fullfn (`str`):
|
| 653 |
+
The full path of the saved imaged.
|
| 654 |
+
txt_fullfn (`str` or None):
|
| 655 |
+
If a text file is saved for this image, this will be its full path. Otherwise None.
|
| 656 |
+
"""
|
| 657 |
+
namegen = FilenameGenerator(p, seed, prompt, image, basename=basename)
|
| 658 |
+
|
| 659 |
+
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
|
| 660 |
+
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
|
| 661 |
+
print('Image dimensions too large; saving as PNG')
|
| 662 |
+
extension = "png"
|
| 663 |
+
|
| 664 |
+
if save_to_dirs is None:
|
| 665 |
+
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
| 666 |
+
|
| 667 |
+
if save_to_dirs:
|
| 668 |
+
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
|
| 669 |
+
path = os.path.join(path, dirname)
|
| 670 |
+
|
| 671 |
+
os.makedirs(path, exist_ok=True)
|
| 672 |
+
|
| 673 |
+
if forced_filename is None:
|
| 674 |
+
if short_filename or seed is None:
|
| 675 |
+
file_decoration = ""
|
| 676 |
+
elif opts.save_to_dirs:
|
| 677 |
+
file_decoration = opts.samples_filename_pattern or "[seed]"
|
| 678 |
+
else:
|
| 679 |
+
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
| 680 |
+
|
| 681 |
+
file_decoration = namegen.apply(file_decoration) + suffix
|
| 682 |
+
|
| 683 |
+
add_number = opts.save_images_add_number or file_decoration == ''
|
| 684 |
+
|
| 685 |
+
if file_decoration != "" and add_number:
|
| 686 |
+
file_decoration = f"-{file_decoration}"
|
| 687 |
+
|
| 688 |
+
if add_number:
|
| 689 |
+
basecount = get_next_sequence_number(path, basename)
|
| 690 |
+
fullfn = None
|
| 691 |
+
for i in range(500):
|
| 692 |
+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
| 693 |
+
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
| 694 |
+
if not os.path.exists(fullfn):
|
| 695 |
+
break
|
| 696 |
+
else:
|
| 697 |
+
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
| 698 |
+
else:
|
| 699 |
+
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
| 700 |
+
|
| 701 |
+
pnginfo = existing_info or {}
|
| 702 |
+
if info is not None:
|
| 703 |
+
pnginfo[pnginfo_section_name] = info
|
| 704 |
+
|
| 705 |
+
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
| 706 |
+
script_callbacks.before_image_saved_callback(params)
|
| 707 |
+
|
| 708 |
+
image = params.image
|
| 709 |
+
fullfn = params.filename
|
| 710 |
+
info = params.pnginfo.get(pnginfo_section_name, None)
|
| 711 |
+
|
| 712 |
+
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
| 713 |
+
"""
|
| 714 |
+
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
| 715 |
+
"""
|
| 716 |
+
temp_file_path = f"{filename_without_extension}.tmp"
|
| 717 |
+
|
| 718 |
+
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
| 719 |
+
|
| 720 |
+
filename = filename_without_extension + extension
|
| 721 |
+
if shared.opts.save_images_replace_action != "Replace":
|
| 722 |
+
n = 0
|
| 723 |
+
while os.path.exists(filename):
|
| 724 |
+
n += 1
|
| 725 |
+
filename = f"{filename_without_extension}-{n}{extension}"
|
| 726 |
+
os.replace(temp_file_path, filename)
|
| 727 |
+
|
| 728 |
+
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
| 729 |
+
if hasattr(os, 'statvfs'):
|
| 730 |
+
max_name_len = os.statvfs(path).f_namemax
|
| 731 |
+
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
| 732 |
+
params.filename = fullfn_without_extension + extension
|
| 733 |
+
fullfn = params.filename
|
| 734 |
+
_atomically_save_image(image, fullfn_without_extension, extension)
|
| 735 |
+
|
| 736 |
+
image.already_saved_as = fullfn
|
| 737 |
+
|
| 738 |
+
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
| 739 |
+
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
| 740 |
+
ratio = image.width / image.height
|
| 741 |
+
resize_to = None
|
| 742 |
+
if oversize and ratio > 1:
|
| 743 |
+
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
|
| 744 |
+
elif oversize:
|
| 745 |
+
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
|
| 746 |
+
|
| 747 |
+
if resize_to is not None:
|
| 748 |
+
try:
|
| 749 |
+
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
|
| 750 |
+
image = image.resize(resize_to, LANCZOS)
|
| 751 |
+
except Exception:
|
| 752 |
+
image = image.resize(resize_to)
|
| 753 |
+
try:
|
| 754 |
+
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
| 755 |
+
except Exception as e:
|
| 756 |
+
errors.display(e, "saving image as downscaled JPG")
|
| 757 |
+
|
| 758 |
+
if opts.save_txt and info is not None:
|
| 759 |
+
txt_fullfn = f"{fullfn_without_extension}.txt"
|
| 760 |
+
with open(txt_fullfn, "w", encoding="utf8") as file:
|
| 761 |
+
file.write(f"{info}\n")
|
| 762 |
+
else:
|
| 763 |
+
txt_fullfn = None
|
| 764 |
+
|
| 765 |
+
script_callbacks.image_saved_callback(params)
|
| 766 |
+
|
| 767 |
+
return fullfn, txt_fullfn
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
IGNORED_INFO_KEYS = {
|
| 771 |
+
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
| 772 |
+
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
| 773 |
+
'icc_profile', 'chromaticity', 'photoshop',
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
| 778 |
+
items = (image.info or {}).copy()
|
| 779 |
+
|
| 780 |
+
geninfo = items.pop('parameters', None)
|
| 781 |
+
|
| 782 |
+
if "exif" in items:
|
| 783 |
+
exif_data = items["exif"]
|
| 784 |
+
try:
|
| 785 |
+
exif = piexif.load(exif_data)
|
| 786 |
+
except OSError:
|
| 787 |
+
# memory / exif was not valid so piexif tried to read from a file
|
| 788 |
+
exif = None
|
| 789 |
+
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
| 790 |
+
try:
|
| 791 |
+
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
| 792 |
+
except ValueError:
|
| 793 |
+
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
| 794 |
+
|
| 795 |
+
if exif_comment:
|
| 796 |
+
geninfo = exif_comment
|
| 797 |
+
elif "comment" in items: # for gif
|
| 798 |
+
if isinstance(items["comment"], bytes):
|
| 799 |
+
geninfo = items["comment"].decode('utf8', errors="ignore")
|
| 800 |
+
else:
|
| 801 |
+
geninfo = items["comment"]
|
| 802 |
+
|
| 803 |
+
for field in IGNORED_INFO_KEYS:
|
| 804 |
+
items.pop(field, None)
|
| 805 |
+
|
| 806 |
+
if items.get("Software", None) == "NovelAI":
|
| 807 |
+
try:
|
| 808 |
+
json_info = json.loads(items["Comment"])
|
| 809 |
+
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
| 810 |
+
|
| 811 |
+
geninfo = f"""{items["Description"]}
|
| 812 |
+
Negative prompt: {json_info["uc"]}
|
| 813 |
+
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
| 814 |
+
except Exception:
|
| 815 |
+
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
| 816 |
+
|
| 817 |
+
return geninfo, items
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
def image_data(data):
|
| 821 |
+
import gradio as gr
|
| 822 |
+
|
| 823 |
+
try:
|
| 824 |
+
image = read(io.BytesIO(data))
|
| 825 |
+
textinfo, _ = read_info_from_image(image)
|
| 826 |
+
return textinfo, None
|
| 827 |
+
except Exception:
|
| 828 |
+
pass
|
| 829 |
+
|
| 830 |
+
try:
|
| 831 |
+
text = data.decode('utf8')
|
| 832 |
+
assert len(text) < 10000
|
| 833 |
+
return text, None
|
| 834 |
+
|
| 835 |
+
except Exception:
|
| 836 |
+
pass
|
| 837 |
+
|
| 838 |
+
return gr.update(), None
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def flatten(img, bgcolor):
|
| 842 |
+
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
|
| 843 |
+
|
| 844 |
+
if img.mode == "RGBA":
|
| 845 |
+
background = Image.new('RGBA', img.size, bgcolor)
|
| 846 |
+
background.paste(img, mask=img)
|
| 847 |
+
img = background
|
| 848 |
+
|
| 849 |
+
return img.convert('RGB')
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def read(fp, **kwargs):
|
| 853 |
+
image = Image.open(fp, **kwargs)
|
| 854 |
+
image = fix_image(image)
|
| 855 |
+
|
| 856 |
+
return image
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def fix_image(image: Image.Image):
|
| 860 |
+
if image is None:
|
| 861 |
+
return None
|
| 862 |
+
|
| 863 |
+
try:
|
| 864 |
+
image = ImageOps.exif_transpose(image)
|
| 865 |
+
image = fix_png_transparency(image)
|
| 866 |
+
except Exception:
|
| 867 |
+
pass
|
| 868 |
+
|
| 869 |
+
return image
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
def fix_png_transparency(image: Image.Image):
|
| 873 |
+
if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
|
| 874 |
+
return image
|
| 875 |
+
|
| 876 |
+
image = image.convert("RGBA")
|
| 877 |
+
return image
|
modules/img2img.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from contextlib import closing
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
from modules import images
|
| 10 |
+
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
| 11 |
+
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
| 12 |
+
from modules.shared import opts, state
|
| 13 |
+
from modules.sd_models import get_closet_checkpoint_match
|
| 14 |
+
import modules.shared as shared
|
| 15 |
+
import modules.processing as processing
|
| 16 |
+
from modules.ui import plaintext_to_html
|
| 17 |
+
import modules.scripts
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
| 21 |
+
output_dir = output_dir.strip()
|
| 22 |
+
processing.fix_seed(p)
|
| 23 |
+
|
| 24 |
+
if isinstance(input, str):
|
| 25 |
+
batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
| 26 |
+
else:
|
| 27 |
+
batch_images = [os.path.abspath(x.name) for x in input]
|
| 28 |
+
|
| 29 |
+
is_inpaint_batch = False
|
| 30 |
+
if inpaint_mask_dir:
|
| 31 |
+
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
| 32 |
+
is_inpaint_batch = bool(inpaint_masks)
|
| 33 |
+
|
| 34 |
+
if is_inpaint_batch:
|
| 35 |
+
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
| 36 |
+
|
| 37 |
+
print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
| 38 |
+
|
| 39 |
+
state.job_count = len(batch_images) * p.n_iter
|
| 40 |
+
|
| 41 |
+
# extract "default" params to use in case getting png info fails
|
| 42 |
+
prompt = p.prompt
|
| 43 |
+
negative_prompt = p.negative_prompt
|
| 44 |
+
seed = p.seed
|
| 45 |
+
cfg_scale = p.cfg_scale
|
| 46 |
+
sampler_name = p.sampler_name
|
| 47 |
+
steps = p.steps
|
| 48 |
+
override_settings = p.override_settings
|
| 49 |
+
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
| 50 |
+
batch_results = None
|
| 51 |
+
discard_further_results = False
|
| 52 |
+
for i, image in enumerate(batch_images):
|
| 53 |
+
state.job = f"{i+1} out of {len(batch_images)}"
|
| 54 |
+
if state.skipped:
|
| 55 |
+
state.skipped = False
|
| 56 |
+
|
| 57 |
+
if state.interrupted or state.stopping_generation:
|
| 58 |
+
break
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
img = images.read(image)
|
| 62 |
+
except UnidentifiedImageError as e:
|
| 63 |
+
print(e)
|
| 64 |
+
continue
|
| 65 |
+
# Use the EXIF orientation of photos taken by smartphones.
|
| 66 |
+
img = ImageOps.exif_transpose(img)
|
| 67 |
+
|
| 68 |
+
if to_scale:
|
| 69 |
+
p.width = int(img.width * scale_by)
|
| 70 |
+
p.height = int(img.height * scale_by)
|
| 71 |
+
|
| 72 |
+
p.init_images = [img] * p.batch_size
|
| 73 |
+
|
| 74 |
+
image_path = Path(image)
|
| 75 |
+
if is_inpaint_batch:
|
| 76 |
+
# try to find corresponding mask for an image using simple filename matching
|
| 77 |
+
if len(inpaint_masks) == 1:
|
| 78 |
+
mask_image_path = inpaint_masks[0]
|
| 79 |
+
else:
|
| 80 |
+
# try to find corresponding mask for an image using simple filename matching
|
| 81 |
+
mask_image_dir = Path(inpaint_mask_dir)
|
| 82 |
+
masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
|
| 83 |
+
|
| 84 |
+
if len(masks_found) == 0:
|
| 85 |
+
print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
# it should contain only 1 matching mask
|
| 89 |
+
# otherwise user has many masks with the same name but different extensions
|
| 90 |
+
mask_image_path = masks_found[0]
|
| 91 |
+
|
| 92 |
+
mask_image = images.read(mask_image_path)
|
| 93 |
+
p.image_mask = mask_image
|
| 94 |
+
|
| 95 |
+
if use_png_info:
|
| 96 |
+
try:
|
| 97 |
+
info_img = img
|
| 98 |
+
if png_info_dir:
|
| 99 |
+
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
| 100 |
+
info_img = images.read(info_img_path)
|
| 101 |
+
geninfo, _ = images.read_info_from_image(info_img)
|
| 102 |
+
parsed_parameters = parse_generation_parameters(geninfo)
|
| 103 |
+
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
| 104 |
+
except Exception:
|
| 105 |
+
parsed_parameters = {}
|
| 106 |
+
|
| 107 |
+
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
|
| 108 |
+
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
|
| 109 |
+
p.seed = int(parsed_parameters.get("Seed", seed))
|
| 110 |
+
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
|
| 111 |
+
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
| 112 |
+
p.steps = int(parsed_parameters.get("Steps", steps))
|
| 113 |
+
|
| 114 |
+
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
|
| 115 |
+
if model_info is not None:
|
| 116 |
+
p.override_settings['sd_model_checkpoint'] = model_info.name
|
| 117 |
+
elif sd_model_checkpoint_override:
|
| 118 |
+
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
|
| 119 |
+
else:
|
| 120 |
+
p.override_settings.pop("sd_model_checkpoint", None)
|
| 121 |
+
|
| 122 |
+
if output_dir:
|
| 123 |
+
p.outpath_samples = output_dir
|
| 124 |
+
p.override_settings['save_to_dirs'] = False
|
| 125 |
+
p.override_settings['save_images_replace_action'] = "Add number suffix"
|
| 126 |
+
if p.n_iter > 1 or p.batch_size > 1:
|
| 127 |
+
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
| 128 |
+
else:
|
| 129 |
+
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
| 130 |
+
|
| 131 |
+
proc = modules.scripts.scripts_img2img.run(p, *args)
|
| 132 |
+
|
| 133 |
+
if proc is None:
|
| 134 |
+
p.override_settings.pop('save_images_replace_action', None)
|
| 135 |
+
proc = process_images(p)
|
| 136 |
+
|
| 137 |
+
if not discard_further_results and proc:
|
| 138 |
+
if batch_results:
|
| 139 |
+
batch_results.images.extend(proc.images)
|
| 140 |
+
batch_results.infotexts.extend(proc.infotexts)
|
| 141 |
+
else:
|
| 142 |
+
batch_results = proc
|
| 143 |
+
|
| 144 |
+
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
|
| 145 |
+
discard_further_results = True
|
| 146 |
+
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
|
| 147 |
+
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
|
| 148 |
+
|
| 149 |
+
return batch_results
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):
|
| 153 |
+
override_settings = create_override_settings_dict(override_settings_texts)
|
| 154 |
+
|
| 155 |
+
is_batch = mode == 5
|
| 156 |
+
|
| 157 |
+
if mode == 0: # img2img
|
| 158 |
+
image = init_img
|
| 159 |
+
mask = None
|
| 160 |
+
elif mode == 1: # img2img sketch
|
| 161 |
+
image = sketch
|
| 162 |
+
mask = None
|
| 163 |
+
elif mode == 2: # inpaint
|
| 164 |
+
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
| 165 |
+
mask = processing.create_binary_mask(mask)
|
| 166 |
+
elif mode == 3: # inpaint sketch
|
| 167 |
+
image = inpaint_color_sketch
|
| 168 |
+
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
| 169 |
+
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
| 170 |
+
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
| 171 |
+
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
| 172 |
+
blur = ImageFilter.GaussianBlur(mask_blur)
|
| 173 |
+
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
| 174 |
+
elif mode == 4: # inpaint upload mask
|
| 175 |
+
image = init_img_inpaint
|
| 176 |
+
mask = init_mask_inpaint
|
| 177 |
+
else:
|
| 178 |
+
image = None
|
| 179 |
+
mask = None
|
| 180 |
+
|
| 181 |
+
image = images.fix_image(image)
|
| 182 |
+
mask = images.fix_image(mask)
|
| 183 |
+
|
| 184 |
+
if selected_scale_tab == 1 and not is_batch:
|
| 185 |
+
assert image, "Can't scale by because no image is selected"
|
| 186 |
+
|
| 187 |
+
width = int(image.width * scale_by)
|
| 188 |
+
height = int(image.height * scale_by)
|
| 189 |
+
|
| 190 |
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
| 191 |
+
|
| 192 |
+
p = StableDiffusionProcessingImg2Img(
|
| 193 |
+
sd_model=shared.sd_model,
|
| 194 |
+
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
|
| 195 |
+
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
| 196 |
+
prompt=prompt,
|
| 197 |
+
negative_prompt=negative_prompt,
|
| 198 |
+
styles=prompt_styles,
|
| 199 |
+
batch_size=batch_size,
|
| 200 |
+
n_iter=n_iter,
|
| 201 |
+
cfg_scale=cfg_scale,
|
| 202 |
+
width=width,
|
| 203 |
+
height=height,
|
| 204 |
+
init_images=[image],
|
| 205 |
+
mask=mask,
|
| 206 |
+
mask_blur=mask_blur,
|
| 207 |
+
inpainting_fill=inpainting_fill,
|
| 208 |
+
resize_mode=resize_mode,
|
| 209 |
+
denoising_strength=denoising_strength,
|
| 210 |
+
image_cfg_scale=image_cfg_scale,
|
| 211 |
+
inpaint_full_res=inpaint_full_res,
|
| 212 |
+
inpaint_full_res_padding=inpaint_full_res_padding,
|
| 213 |
+
inpainting_mask_invert=inpainting_mask_invert,
|
| 214 |
+
override_settings=override_settings,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
p.scripts = modules.scripts.scripts_img2img
|
| 218 |
+
p.script_args = args
|
| 219 |
+
|
| 220 |
+
p.user = request.username
|
| 221 |
+
|
| 222 |
+
if shared.opts.enable_console_prompts:
|
| 223 |
+
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
| 224 |
+
|
| 225 |
+
with closing(p):
|
| 226 |
+
if is_batch:
|
| 227 |
+
if img2img_batch_source_type == "upload":
|
| 228 |
+
assert isinstance(img2img_batch_upload, list) and img2img_batch_upload
|
| 229 |
+
output_dir = ""
|
| 230 |
+
inpaint_mask_dir = ""
|
| 231 |
+
png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else ""
|
| 232 |
+
processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir)
|
| 233 |
+
else: # "from dir"
|
| 234 |
+
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
| 235 |
+
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
| 236 |
+
|
| 237 |
+
if processed is None:
|
| 238 |
+
processed = Processed(p, [], p.seed, "")
|
| 239 |
+
else:
|
| 240 |
+
processed = modules.scripts.scripts_img2img.run(p, *args)
|
| 241 |
+
if processed is None:
|
| 242 |
+
processed = process_images(p)
|
| 243 |
+
|
| 244 |
+
shared.total_tqdm.clear()
|
| 245 |
+
|
| 246 |
+
generation_info_js = processed.js()
|
| 247 |
+
if opts.samples_log_stdout:
|
| 248 |
+
print(generation_info_js)
|
| 249 |
+
|
| 250 |
+
if opts.do_not_show_images:
|
| 251 |
+
processed.images = []
|
| 252 |
+
|
| 253 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
modules/import_hook.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
| 4 |
+
if "--xformers" not in "".join(sys.argv):
|
| 5 |
+
sys.modules["xformers"] = None
|
| 6 |
+
|
| 7 |
+
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
| 8 |
+
# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
|
| 9 |
+
try:
|
| 10 |
+
import torchvision.transforms.functional_tensor # noqa: F401
|
| 11 |
+
except ImportError:
|
| 12 |
+
try:
|
| 13 |
+
import torchvision.transforms.functional as functional
|
| 14 |
+
sys.modules["torchvision.transforms.functional_tensor"] = functional
|
| 15 |
+
except ImportError:
|
| 16 |
+
pass # shrug...
|
modules/infotext_utils.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from modules.paths import data_path
|
| 11 |
+
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
| 15 |
+
|
| 16 |
+
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
| 17 |
+
re_param = re.compile(re_param_code)
|
| 18 |
+
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
| 19 |
+
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
| 20 |
+
type_of_gr_update = type(gr.update())
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ParamBinding:
|
| 24 |
+
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
| 25 |
+
self.paste_button = paste_button
|
| 26 |
+
self.tabname = tabname
|
| 27 |
+
self.source_text_component = source_text_component
|
| 28 |
+
self.source_image_component = source_image_component
|
| 29 |
+
self.source_tabname = source_tabname
|
| 30 |
+
self.override_settings_component = override_settings_component
|
| 31 |
+
self.paste_field_names = paste_field_names or []
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PasteField(tuple):
|
| 35 |
+
def __new__(cls, component, target, *, api=None):
|
| 36 |
+
return super().__new__(cls, (component, target))
|
| 37 |
+
|
| 38 |
+
def __init__(self, component, target, *, api=None):
|
| 39 |
+
super().__init__()
|
| 40 |
+
|
| 41 |
+
self.api = api
|
| 42 |
+
self.component = component
|
| 43 |
+
self.label = target if isinstance(target, str) else None
|
| 44 |
+
self.function = target if callable(target) else None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
paste_fields: dict[str, dict] = {}
|
| 48 |
+
registered_param_bindings: list[ParamBinding] = []
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def reset():
|
| 52 |
+
paste_fields.clear()
|
| 53 |
+
registered_param_bindings.clear()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def quote(text):
|
| 57 |
+
if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
|
| 58 |
+
return text
|
| 59 |
+
|
| 60 |
+
return json.dumps(text, ensure_ascii=False)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def unquote(text):
|
| 64 |
+
if len(text) == 0 or text[0] != '"' or text[-1] != '"':
|
| 65 |
+
return text
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
return json.loads(text)
|
| 69 |
+
except Exception:
|
| 70 |
+
return text
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def image_from_url_text(filedata):
|
| 74 |
+
if filedata is None:
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
| 78 |
+
filedata = filedata[0]
|
| 79 |
+
|
| 80 |
+
if type(filedata) == dict and filedata.get("is_file", False):
|
| 81 |
+
filename = filedata["name"]
|
| 82 |
+
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
| 83 |
+
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
| 84 |
+
|
| 85 |
+
filename = filename.rsplit('?', 1)[0]
|
| 86 |
+
return images.read(filename)
|
| 87 |
+
|
| 88 |
+
if type(filedata) == list:
|
| 89 |
+
if len(filedata) == 0:
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
filedata = filedata[0]
|
| 93 |
+
|
| 94 |
+
if filedata.startswith("data:image/png;base64,"):
|
| 95 |
+
filedata = filedata[len("data:image/png;base64,"):]
|
| 96 |
+
|
| 97 |
+
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
| 98 |
+
image = images.read(io.BytesIO(filedata))
|
| 99 |
+
return image
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
| 103 |
+
|
| 104 |
+
if fields:
|
| 105 |
+
for i in range(len(fields)):
|
| 106 |
+
if not isinstance(fields[i], PasteField):
|
| 107 |
+
fields[i] = PasteField(*fields[i])
|
| 108 |
+
|
| 109 |
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
| 110 |
+
|
| 111 |
+
# backwards compatibility for existing extensions
|
| 112 |
+
import modules.ui
|
| 113 |
+
if tabname == 'txt2img':
|
| 114 |
+
modules.ui.txt2img_paste_fields = fields
|
| 115 |
+
elif tabname == 'img2img':
|
| 116 |
+
modules.ui.img2img_paste_fields = fields
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def create_buttons(tabs_list):
|
| 120 |
+
buttons = {}
|
| 121 |
+
for tab in tabs_list:
|
| 122 |
+
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
|
| 123 |
+
return buttons
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def bind_buttons(buttons, send_image, send_generate_info):
|
| 127 |
+
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
| 128 |
+
for tabname, button in buttons.items():
|
| 129 |
+
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
| 130 |
+
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
| 131 |
+
|
| 132 |
+
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def register_paste_params_button(binding: ParamBinding):
|
| 136 |
+
registered_param_bindings.append(binding)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def connect_paste_params_buttons():
|
| 140 |
+
for binding in registered_param_bindings:
|
| 141 |
+
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
| 142 |
+
fields = paste_fields[binding.tabname]["fields"]
|
| 143 |
+
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
|
| 144 |
+
|
| 145 |
+
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
| 146 |
+
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
| 147 |
+
|
| 148 |
+
if binding.source_image_component and destination_image_component:
|
| 149 |
+
need_send_dementions = destination_width_component and binding.tabname != 'inpaint'
|
| 150 |
+
if isinstance(binding.source_image_component, gr.Gallery):
|
| 151 |
+
func = send_image_and_dimensions if need_send_dementions else image_from_url_text
|
| 152 |
+
jsfunc = "extract_image_from_gallery"
|
| 153 |
+
else:
|
| 154 |
+
func = send_image_and_dimensions if need_send_dementions else lambda x: x
|
| 155 |
+
jsfunc = None
|
| 156 |
+
|
| 157 |
+
binding.paste_button.click(
|
| 158 |
+
fn=func,
|
| 159 |
+
_js=jsfunc,
|
| 160 |
+
inputs=[binding.source_image_component],
|
| 161 |
+
outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component],
|
| 162 |
+
show_progress=False,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
if binding.source_text_component is not None and fields is not None:
|
| 166 |
+
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
| 167 |
+
|
| 168 |
+
if binding.source_tabname is not None and fields is not None:
|
| 169 |
+
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
|
| 170 |
+
binding.paste_button.click(
|
| 171 |
+
fn=lambda *x: x,
|
| 172 |
+
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
| 173 |
+
outputs=[field for field, name in fields if name in paste_field_names],
|
| 174 |
+
show_progress=False,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
binding.paste_button.click(
|
| 178 |
+
fn=None,
|
| 179 |
+
_js=f"switch_to_{binding.tabname}",
|
| 180 |
+
inputs=None,
|
| 181 |
+
outputs=None,
|
| 182 |
+
show_progress=False,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def send_image_and_dimensions(x):
|
| 187 |
+
if isinstance(x, Image.Image):
|
| 188 |
+
img = x
|
| 189 |
+
else:
|
| 190 |
+
img = image_from_url_text(x)
|
| 191 |
+
|
| 192 |
+
if shared.opts.send_size and isinstance(img, Image.Image):
|
| 193 |
+
w = img.width
|
| 194 |
+
h = img.height
|
| 195 |
+
else:
|
| 196 |
+
w = gr.update()
|
| 197 |
+
h = gr.update()
|
| 198 |
+
|
| 199 |
+
return img, w, h
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def restore_old_hires_fix_params(res):
|
| 203 |
+
"""for infotexts that specify old First pass size parameter, convert it into
|
| 204 |
+
width, height, and hr scale"""
|
| 205 |
+
|
| 206 |
+
firstpass_width = res.get('First pass size-1', None)
|
| 207 |
+
firstpass_height = res.get('First pass size-2', None)
|
| 208 |
+
|
| 209 |
+
if shared.opts.use_old_hires_fix_width_height:
|
| 210 |
+
hires_width = int(res.get("Hires resize-1", 0))
|
| 211 |
+
hires_height = int(res.get("Hires resize-2", 0))
|
| 212 |
+
|
| 213 |
+
if hires_width and hires_height:
|
| 214 |
+
res['Size-1'] = hires_width
|
| 215 |
+
res['Size-2'] = hires_height
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
if firstpass_width is None or firstpass_height is None:
|
| 219 |
+
return
|
| 220 |
+
|
| 221 |
+
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
|
| 222 |
+
width = int(res.get("Size-1", 512))
|
| 223 |
+
height = int(res.get("Size-2", 512))
|
| 224 |
+
|
| 225 |
+
if firstpass_width == 0 or firstpass_height == 0:
|
| 226 |
+
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
| 227 |
+
|
| 228 |
+
res['Size-1'] = firstpass_width
|
| 229 |
+
res['Size-2'] = firstpass_height
|
| 230 |
+
res['Hires resize-1'] = width
|
| 231 |
+
res['Hires resize-2'] = height
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
|
| 235 |
+
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
| 236 |
+
```
|
| 237 |
+
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
| 238 |
+
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
|
| 239 |
+
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
returns a dict with field values
|
| 243 |
+
"""
|
| 244 |
+
if skip_fields is None:
|
| 245 |
+
skip_fields = shared.opts.infotext_skip_pasting
|
| 246 |
+
|
| 247 |
+
res = {}
|
| 248 |
+
|
| 249 |
+
prompt = ""
|
| 250 |
+
negative_prompt = ""
|
| 251 |
+
|
| 252 |
+
done_with_prompt = False
|
| 253 |
+
|
| 254 |
+
*lines, lastline = x.strip().split("\n")
|
| 255 |
+
if len(re_param.findall(lastline)) < 3:
|
| 256 |
+
lines.append(lastline)
|
| 257 |
+
lastline = ''
|
| 258 |
+
|
| 259 |
+
for line in lines:
|
| 260 |
+
line = line.strip()
|
| 261 |
+
if line.startswith("Negative prompt:"):
|
| 262 |
+
done_with_prompt = True
|
| 263 |
+
line = line[16:].strip()
|
| 264 |
+
if done_with_prompt:
|
| 265 |
+
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
| 266 |
+
else:
|
| 267 |
+
prompt += ("" if prompt == "" else "\n") + line
|
| 268 |
+
|
| 269 |
+
for k, v in re_param.findall(lastline):
|
| 270 |
+
try:
|
| 271 |
+
if v[0] == '"' and v[-1] == '"':
|
| 272 |
+
v = unquote(v)
|
| 273 |
+
|
| 274 |
+
m = re_imagesize.match(v)
|
| 275 |
+
if m is not None:
|
| 276 |
+
res[f"{k}-1"] = m.group(1)
|
| 277 |
+
res[f"{k}-2"] = m.group(2)
|
| 278 |
+
else:
|
| 279 |
+
res[k] = v
|
| 280 |
+
except Exception:
|
| 281 |
+
print(f"Error parsing \"{k}: {v}\"")
|
| 282 |
+
|
| 283 |
+
# Extract styles from prompt
|
| 284 |
+
if shared.opts.infotext_styles != "Ignore":
|
| 285 |
+
found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
|
| 286 |
+
|
| 287 |
+
same_hr_styles = True
|
| 288 |
+
if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
|
| 289 |
+
hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
|
| 290 |
+
hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
|
| 291 |
+
if same_hr_styles := found_styles == hr_found_styles:
|
| 292 |
+
res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
|
| 293 |
+
res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
|
| 294 |
+
|
| 295 |
+
if same_hr_styles:
|
| 296 |
+
prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
|
| 297 |
+
if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
|
| 298 |
+
res['Styles array'] = found_styles
|
| 299 |
+
|
| 300 |
+
res["Prompt"] = prompt
|
| 301 |
+
res["Negative prompt"] = negative_prompt
|
| 302 |
+
|
| 303 |
+
# Missing CLIP skip means it was set to 1 (the default)
|
| 304 |
+
if "Clip skip" not in res:
|
| 305 |
+
res["Clip skip"] = "1"
|
| 306 |
+
|
| 307 |
+
hypernet = res.get("Hypernet", None)
|
| 308 |
+
if hypernet is not None:
|
| 309 |
+
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
| 310 |
+
|
| 311 |
+
if "Hires resize-1" not in res:
|
| 312 |
+
res["Hires resize-1"] = 0
|
| 313 |
+
res["Hires resize-2"] = 0
|
| 314 |
+
|
| 315 |
+
if "Hires sampler" not in res:
|
| 316 |
+
res["Hires sampler"] = "Use same sampler"
|
| 317 |
+
|
| 318 |
+
if "Hires schedule type" not in res:
|
| 319 |
+
res["Hires schedule type"] = "Use same scheduler"
|
| 320 |
+
|
| 321 |
+
if "Hires checkpoint" not in res:
|
| 322 |
+
res["Hires checkpoint"] = "Use same checkpoint"
|
| 323 |
+
|
| 324 |
+
if "Hires prompt" not in res:
|
| 325 |
+
res["Hires prompt"] = ""
|
| 326 |
+
|
| 327 |
+
if "Hires negative prompt" not in res:
|
| 328 |
+
res["Hires negative prompt"] = ""
|
| 329 |
+
|
| 330 |
+
if "Mask mode" not in res:
|
| 331 |
+
res["Mask mode"] = "Inpaint masked"
|
| 332 |
+
|
| 333 |
+
if "Masked content" not in res:
|
| 334 |
+
res["Masked content"] = 'original'
|
| 335 |
+
|
| 336 |
+
if "Inpaint area" not in res:
|
| 337 |
+
res["Inpaint area"] = "Whole picture"
|
| 338 |
+
|
| 339 |
+
if "Masked area padding" not in res:
|
| 340 |
+
res["Masked area padding"] = 32
|
| 341 |
+
|
| 342 |
+
restore_old_hires_fix_params(res)
|
| 343 |
+
|
| 344 |
+
# Missing RNG means the default was set, which is GPU RNG
|
| 345 |
+
if "RNG" not in res:
|
| 346 |
+
res["RNG"] = "GPU"
|
| 347 |
+
|
| 348 |
+
if "Schedule type" not in res:
|
| 349 |
+
res["Schedule type"] = "Automatic"
|
| 350 |
+
|
| 351 |
+
if "Schedule max sigma" not in res:
|
| 352 |
+
res["Schedule max sigma"] = 0
|
| 353 |
+
|
| 354 |
+
if "Schedule min sigma" not in res:
|
| 355 |
+
res["Schedule min sigma"] = 0
|
| 356 |
+
|
| 357 |
+
if "Schedule rho" not in res:
|
| 358 |
+
res["Schedule rho"] = 0
|
| 359 |
+
|
| 360 |
+
if "VAE Encoder" not in res:
|
| 361 |
+
res["VAE Encoder"] = "Full"
|
| 362 |
+
|
| 363 |
+
if "VAE Decoder" not in res:
|
| 364 |
+
res["VAE Decoder"] = "Full"
|
| 365 |
+
|
| 366 |
+
if "FP8 weight" not in res:
|
| 367 |
+
res["FP8 weight"] = "Disable"
|
| 368 |
+
|
| 369 |
+
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
| 370 |
+
res["Cache FP16 weight for LoRA"] = False
|
| 371 |
+
|
| 372 |
+
prompt_attention = prompt_parser.parse_prompt_attention(prompt)
|
| 373 |
+
prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
|
| 374 |
+
prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
|
| 375 |
+
if "Emphasis" not in res and prompt_uses_emphasis:
|
| 376 |
+
res["Emphasis"] = "Original"
|
| 377 |
+
|
| 378 |
+
if "Refiner switch by sampling steps" not in res:
|
| 379 |
+
res["Refiner switch by sampling steps"] = False
|
| 380 |
+
|
| 381 |
+
infotext_versions.backcompat(res)
|
| 382 |
+
|
| 383 |
+
for key in skip_fields:
|
| 384 |
+
res.pop(key, None)
|
| 385 |
+
|
| 386 |
+
return res
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
infotext_to_setting_name_mapping = [
|
| 390 |
+
|
| 391 |
+
]
|
| 392 |
+
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
|
| 393 |
+
Example content:
|
| 394 |
+
|
| 395 |
+
infotext_to_setting_name_mapping = [
|
| 396 |
+
('Conditional mask weight', 'inpainting_mask_weight'),
|
| 397 |
+
('Model hash', 'sd_model_checkpoint'),
|
| 398 |
+
('ENSD', 'eta_noise_seed_delta'),
|
| 399 |
+
('Schedule type', 'k_sched_type'),
|
| 400 |
+
]
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def create_override_settings_dict(text_pairs):
|
| 405 |
+
"""creates processing's override_settings parameters from gradio's multiselect
|
| 406 |
+
|
| 407 |
+
Example input:
|
| 408 |
+
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
| 409 |
+
|
| 410 |
+
Example output:
|
| 411 |
+
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
res = {}
|
| 415 |
+
|
| 416 |
+
params = {}
|
| 417 |
+
for pair in text_pairs:
|
| 418 |
+
k, v = pair.split(":", maxsplit=1)
|
| 419 |
+
|
| 420 |
+
params[k] = v.strip()
|
| 421 |
+
|
| 422 |
+
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
| 423 |
+
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
| 424 |
+
value = params.get(param_name, None)
|
| 425 |
+
|
| 426 |
+
if value is None:
|
| 427 |
+
continue
|
| 428 |
+
|
| 429 |
+
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
| 430 |
+
|
| 431 |
+
return res
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_override_settings(params, *, skip_fields=None):
|
| 435 |
+
"""Returns a list of settings overrides from the infotext parameters dictionary.
|
| 436 |
+
|
| 437 |
+
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
|
| 438 |
+
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
|
| 439 |
+
|
| 440 |
+
It checks for conditions before adding an override:
|
| 441 |
+
- ignores settings that match the current value
|
| 442 |
+
- ignores parameter keys present in skip_fields argument.
|
| 443 |
+
|
| 444 |
+
Example input:
|
| 445 |
+
{"Clip skip": "2"}
|
| 446 |
+
|
| 447 |
+
Example output:
|
| 448 |
+
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
res = []
|
| 452 |
+
|
| 453 |
+
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
| 454 |
+
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
| 455 |
+
if param_name in (skip_fields or {}):
|
| 456 |
+
continue
|
| 457 |
+
|
| 458 |
+
v = params.get(param_name, None)
|
| 459 |
+
if v is None:
|
| 460 |
+
continue
|
| 461 |
+
|
| 462 |
+
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
| 463 |
+
continue
|
| 464 |
+
|
| 465 |
+
v = shared.opts.cast_value(setting_name, v)
|
| 466 |
+
current_value = getattr(shared.opts, setting_name, None)
|
| 467 |
+
|
| 468 |
+
if v == current_value:
|
| 469 |
+
continue
|
| 470 |
+
|
| 471 |
+
res.append((param_name, setting_name, v))
|
| 472 |
+
|
| 473 |
+
return res
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
| 477 |
+
def paste_func(prompt):
|
| 478 |
+
if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
|
| 479 |
+
filename = os.path.join(data_path, "params.txt")
|
| 480 |
+
try:
|
| 481 |
+
with open(filename, "r", encoding="utf8") as file:
|
| 482 |
+
prompt = file.read()
|
| 483 |
+
except OSError:
|
| 484 |
+
pass
|
| 485 |
+
|
| 486 |
+
params = parse_generation_parameters(prompt)
|
| 487 |
+
script_callbacks.infotext_pasted_callback(prompt, params)
|
| 488 |
+
res = []
|
| 489 |
+
|
| 490 |
+
for output, key in paste_fields:
|
| 491 |
+
if callable(key):
|
| 492 |
+
try:
|
| 493 |
+
v = key(params)
|
| 494 |
+
except Exception:
|
| 495 |
+
errors.report(f"Error executing {key}", exc_info=True)
|
| 496 |
+
v = None
|
| 497 |
+
else:
|
| 498 |
+
v = params.get(key, None)
|
| 499 |
+
|
| 500 |
+
if v is None:
|
| 501 |
+
res.append(gr.update())
|
| 502 |
+
elif isinstance(v, type_of_gr_update):
|
| 503 |
+
res.append(v)
|
| 504 |
+
else:
|
| 505 |
+
try:
|
| 506 |
+
valtype = type(output.value)
|
| 507 |
+
|
| 508 |
+
if valtype == bool and v == "False":
|
| 509 |
+
val = False
|
| 510 |
+
elif valtype == int:
|
| 511 |
+
val = float(v)
|
| 512 |
+
else:
|
| 513 |
+
val = valtype(v)
|
| 514 |
+
|
| 515 |
+
res.append(gr.update(value=val))
|
| 516 |
+
except Exception:
|
| 517 |
+
res.append(gr.update())
|
| 518 |
+
|
| 519 |
+
return res
|
| 520 |
+
|
| 521 |
+
if override_settings_component is not None:
|
| 522 |
+
already_handled_fields = {key: 1 for _, key in paste_fields}
|
| 523 |
+
|
| 524 |
+
def paste_settings(params):
|
| 525 |
+
vals = get_override_settings(params, skip_fields=already_handled_fields)
|
| 526 |
+
|
| 527 |
+
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
|
| 528 |
+
|
| 529 |
+
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
|
| 530 |
+
|
| 531 |
+
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
| 532 |
+
|
| 533 |
+
button.click(
|
| 534 |
+
fn=paste_func,
|
| 535 |
+
inputs=[input_comp],
|
| 536 |
+
outputs=[x[0] for x in paste_fields],
|
| 537 |
+
show_progress=False,
|
| 538 |
+
)
|
| 539 |
+
button.click(
|
| 540 |
+
fn=None,
|
| 541 |
+
_js=f"recalculate_prompts_{tabname}",
|
| 542 |
+
inputs=[],
|
| 543 |
+
outputs=[],
|
| 544 |
+
show_progress=False,
|
| 545 |
+
)
|
| 546 |
+
|
modules/infotext_versions.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import shared
|
| 2 |
+
from packaging import version
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
v160 = version.parse("1.6.0")
|
| 7 |
+
v170_tsnr = version.parse("v1.7.0-225")
|
| 8 |
+
v180 = version.parse("1.8.0")
|
| 9 |
+
v180_hr_styles = version.parse("1.8.0-139")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_version(text):
|
| 13 |
+
if text is None:
|
| 14 |
+
return None
|
| 15 |
+
|
| 16 |
+
m = re.match(r'([^-]+-[^-]+)-.*', text)
|
| 17 |
+
if m:
|
| 18 |
+
text = m.group(1)
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
return version.parse(text)
|
| 22 |
+
except Exception:
|
| 23 |
+
return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def backcompat(d):
|
| 27 |
+
"""Checks infotext Version field, and enables backwards compatibility options according to it."""
|
| 28 |
+
|
| 29 |
+
if not shared.opts.auto_backcompat:
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
ver = parse_version(d.get("Version"))
|
| 33 |
+
if ver is None:
|
| 34 |
+
return
|
| 35 |
+
|
| 36 |
+
if ver < v160 and '[' in d.get('Prompt', ''):
|
| 37 |
+
d["Old prompt editing timelines"] = True
|
| 38 |
+
|
| 39 |
+
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
|
| 40 |
+
d["Pad conds v0"] = True
|
| 41 |
+
|
| 42 |
+
if ver < v170_tsnr:
|
| 43 |
+
d["Downcast alphas_cumprod"] = True
|
| 44 |
+
|
| 45 |
+
if ver < v180 and d.get('Refiner'):
|
| 46 |
+
d["Refiner switch by sampling steps"] = True
|
modules/initialize.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import warnings
|
| 6 |
+
from threading import Thread
|
| 7 |
+
|
| 8 |
+
from modules.timer import startup_timer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def imports():
|
| 12 |
+
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
| 13 |
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
| 14 |
+
|
| 15 |
+
import torch # noqa: F401
|
| 16 |
+
startup_timer.record("import torch")
|
| 17 |
+
import pytorch_lightning # noqa: F401
|
| 18 |
+
startup_timer.record("import torch")
|
| 19 |
+
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
| 20 |
+
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
| 21 |
+
|
| 22 |
+
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
| 23 |
+
import gradio # noqa: F401
|
| 24 |
+
startup_timer.record("import gradio")
|
| 25 |
+
|
| 26 |
+
from modules import paths, timer, import_hook, errors # noqa: F401
|
| 27 |
+
startup_timer.record("setup paths")
|
| 28 |
+
|
| 29 |
+
import ldm.modules.encoders.modules # noqa: F401
|
| 30 |
+
startup_timer.record("import ldm")
|
| 31 |
+
|
| 32 |
+
import sgm.modules.encoders.modules # noqa: F401
|
| 33 |
+
startup_timer.record("import sgm")
|
| 34 |
+
|
| 35 |
+
from modules import shared_init
|
| 36 |
+
shared_init.initialize()
|
| 37 |
+
startup_timer.record("initialize shared")
|
| 38 |
+
|
| 39 |
+
from modules import processing, gradio_extensons, ui # noqa: F401
|
| 40 |
+
startup_timer.record("other imports")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def check_versions():
|
| 44 |
+
from modules.shared_cmd_options import cmd_opts
|
| 45 |
+
|
| 46 |
+
if not cmd_opts.skip_version_check:
|
| 47 |
+
from modules import errors
|
| 48 |
+
errors.check_versions()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def initialize():
|
| 52 |
+
from modules import initialize_util
|
| 53 |
+
initialize_util.fix_torch_version()
|
| 54 |
+
initialize_util.fix_pytorch_lightning()
|
| 55 |
+
initialize_util.fix_asyncio_event_loop_policy()
|
| 56 |
+
initialize_util.validate_tls_options()
|
| 57 |
+
initialize_util.configure_sigint_handler()
|
| 58 |
+
initialize_util.configure_opts_onchange()
|
| 59 |
+
|
| 60 |
+
from modules import sd_models
|
| 61 |
+
sd_models.setup_model()
|
| 62 |
+
startup_timer.record("setup SD model")
|
| 63 |
+
|
| 64 |
+
from modules.shared_cmd_options import cmd_opts
|
| 65 |
+
|
| 66 |
+
from modules import codeformer_model
|
| 67 |
+
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
|
| 68 |
+
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
|
| 69 |
+
startup_timer.record("setup codeformer")
|
| 70 |
+
|
| 71 |
+
from modules import gfpgan_model
|
| 72 |
+
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
|
| 73 |
+
startup_timer.record("setup gfpgan")
|
| 74 |
+
|
| 75 |
+
initialize_rest(reload_script_modules=False)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def initialize_rest(*, reload_script_modules=False):
|
| 79 |
+
"""
|
| 80 |
+
Called both from initialize() and when reloading the webui.
|
| 81 |
+
"""
|
| 82 |
+
from modules.shared_cmd_options import cmd_opts
|
| 83 |
+
|
| 84 |
+
from modules import sd_samplers
|
| 85 |
+
sd_samplers.set_samplers()
|
| 86 |
+
startup_timer.record("set samplers")
|
| 87 |
+
|
| 88 |
+
from modules import extensions
|
| 89 |
+
extensions.list_extensions()
|
| 90 |
+
startup_timer.record("list extensions")
|
| 91 |
+
|
| 92 |
+
from modules import initialize_util
|
| 93 |
+
initialize_util.restore_config_state_file()
|
| 94 |
+
startup_timer.record("restore config state file")
|
| 95 |
+
|
| 96 |
+
from modules import shared, upscaler, scripts
|
| 97 |
+
if cmd_opts.ui_debug_mode:
|
| 98 |
+
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
| 99 |
+
scripts.load_scripts()
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
from modules import sd_models
|
| 103 |
+
sd_models.list_models()
|
| 104 |
+
startup_timer.record("list SD models")
|
| 105 |
+
|
| 106 |
+
from modules import localization
|
| 107 |
+
localization.list_localizations(cmd_opts.localizations_dir)
|
| 108 |
+
startup_timer.record("list localizations")
|
| 109 |
+
|
| 110 |
+
with startup_timer.subcategory("load scripts"):
|
| 111 |
+
scripts.load_scripts()
|
| 112 |
+
|
| 113 |
+
if reload_script_modules and shared.opts.enable_reloading_ui_scripts:
|
| 114 |
+
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
| 115 |
+
importlib.reload(module)
|
| 116 |
+
startup_timer.record("reload script modules")
|
| 117 |
+
|
| 118 |
+
from modules import modelloader
|
| 119 |
+
modelloader.load_upscalers()
|
| 120 |
+
startup_timer.record("load upscalers")
|
| 121 |
+
|
| 122 |
+
from modules import sd_vae
|
| 123 |
+
sd_vae.refresh_vae_list()
|
| 124 |
+
startup_timer.record("refresh VAE")
|
| 125 |
+
|
| 126 |
+
from modules import textual_inversion
|
| 127 |
+
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
| 128 |
+
startup_timer.record("refresh textual inversion templates")
|
| 129 |
+
|
| 130 |
+
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
| 131 |
+
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
| 132 |
+
sd_hijack.list_optimizers()
|
| 133 |
+
startup_timer.record("scripts list_optimizers")
|
| 134 |
+
|
| 135 |
+
from modules import sd_unet
|
| 136 |
+
sd_unet.list_unets()
|
| 137 |
+
startup_timer.record("scripts list_unets")
|
| 138 |
+
|
| 139 |
+
def load_model():
|
| 140 |
+
"""
|
| 141 |
+
Accesses shared.sd_model property to load model.
|
| 142 |
+
After it's available, if it has been loaded before this access by some extension,
|
| 143 |
+
its optimization may be None because the list of optimizers has not been filled
|
| 144 |
+
by that time, so we apply optimization again.
|
| 145 |
+
"""
|
| 146 |
+
from modules import devices
|
| 147 |
+
devices.torch_npu_set_device()
|
| 148 |
+
|
| 149 |
+
shared.sd_model # noqa: B018
|
| 150 |
+
|
| 151 |
+
if sd_hijack.current_optimizer is None:
|
| 152 |
+
sd_hijack.apply_optimizations()
|
| 153 |
+
|
| 154 |
+
devices.first_time_calculation()
|
| 155 |
+
if not shared.cmd_opts.skip_load_model_at_start:
|
| 156 |
+
Thread(target=load_model).start()
|
| 157 |
+
|
| 158 |
+
from modules import shared_items
|
| 159 |
+
shared_items.reload_hypernetworks()
|
| 160 |
+
startup_timer.record("reload hypernetworks")
|
| 161 |
+
|
| 162 |
+
from modules import ui_extra_networks
|
| 163 |
+
ui_extra_networks.initialize()
|
| 164 |
+
ui_extra_networks.register_default_pages()
|
| 165 |
+
|
| 166 |
+
from modules import extra_networks
|
| 167 |
+
extra_networks.initialize()
|
| 168 |
+
extra_networks.register_default_extra_networks()
|
| 169 |
+
startup_timer.record("initialize extra networks")
|
modules/initialize_util.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import signal
|
| 4 |
+
import sys
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
from modules.timer import startup_timer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def gradio_server_name():
|
| 11 |
+
from modules.shared_cmd_options import cmd_opts
|
| 12 |
+
|
| 13 |
+
if cmd_opts.server_name:
|
| 14 |
+
return cmd_opts.server_name
|
| 15 |
+
else:
|
| 16 |
+
return "0.0.0.0" if cmd_opts.listen else None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def fix_torch_version():
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
| 23 |
+
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
| 24 |
+
torch.__long_version__ = torch.__version__
|
| 25 |
+
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
| 26 |
+
|
| 27 |
+
def fix_pytorch_lightning():
|
| 28 |
+
# Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
|
| 29 |
+
if 'pytorch_lightning.utilities.distributed' not in sys.modules:
|
| 30 |
+
import pytorch_lightning
|
| 31 |
+
# Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
|
| 32 |
+
print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
|
| 33 |
+
sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
|
| 34 |
+
|
| 35 |
+
def fix_asyncio_event_loop_policy():
|
| 36 |
+
"""
|
| 37 |
+
The default `asyncio` event loop policy only automatically creates
|
| 38 |
+
event loops in the main threads. Other threads must create event
|
| 39 |
+
loops explicitly or `asyncio.get_event_loop` (and therefore
|
| 40 |
+
`.IOLoop.current`) will fail. Installing this policy allows event
|
| 41 |
+
loops to be created automatically on any thread, matching the
|
| 42 |
+
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import asyncio
|
| 46 |
+
|
| 47 |
+
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
| 48 |
+
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
| 49 |
+
# interface for composing policies so pick the right base.
|
| 50 |
+
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
| 51 |
+
else:
|
| 52 |
+
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
| 53 |
+
|
| 54 |
+
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
| 55 |
+
"""Event loop policy that allows loop creation on any thread.
|
| 56 |
+
Usage::
|
| 57 |
+
|
| 58 |
+
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
| 62 |
+
try:
|
| 63 |
+
return super().get_event_loop()
|
| 64 |
+
except (RuntimeError, AssertionError):
|
| 65 |
+
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
| 66 |
+
# and changed to a RuntimeError in 3.4.3.
|
| 67 |
+
# "There is no current event loop in thread %r"
|
| 68 |
+
loop = self.new_event_loop()
|
| 69 |
+
self.set_event_loop(loop)
|
| 70 |
+
return loop
|
| 71 |
+
|
| 72 |
+
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def restore_config_state_file():
|
| 76 |
+
from modules import shared, config_states
|
| 77 |
+
|
| 78 |
+
config_state_file = shared.opts.restore_config_state_file
|
| 79 |
+
if config_state_file == "":
|
| 80 |
+
return
|
| 81 |
+
|
| 82 |
+
shared.opts.restore_config_state_file = ""
|
| 83 |
+
shared.opts.save(shared.config_filename)
|
| 84 |
+
|
| 85 |
+
if os.path.isfile(config_state_file):
|
| 86 |
+
print(f"*** About to restore extension state from file: {config_state_file}")
|
| 87 |
+
with open(config_state_file, "r", encoding="utf-8") as f:
|
| 88 |
+
config_state = json.load(f)
|
| 89 |
+
config_states.restore_extension_config(config_state)
|
| 90 |
+
startup_timer.record("restore extension config")
|
| 91 |
+
elif config_state_file:
|
| 92 |
+
print(f"!!! Config state backup not found: {config_state_file}")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def validate_tls_options():
|
| 96 |
+
from modules.shared_cmd_options import cmd_opts
|
| 97 |
+
|
| 98 |
+
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
if not os.path.exists(cmd_opts.tls_keyfile):
|
| 103 |
+
print("Invalid path to TLS keyfile given")
|
| 104 |
+
if not os.path.exists(cmd_opts.tls_certfile):
|
| 105 |
+
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
| 106 |
+
except TypeError:
|
| 107 |
+
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
| 108 |
+
print("TLS setup invalid, running webui without TLS")
|
| 109 |
+
else:
|
| 110 |
+
print("Running with TLS")
|
| 111 |
+
startup_timer.record("TLS")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_gradio_auth_creds():
|
| 115 |
+
"""
|
| 116 |
+
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
| 117 |
+
an iterable of (username, password) tuples.
|
| 118 |
+
"""
|
| 119 |
+
from modules.shared_cmd_options import cmd_opts
|
| 120 |
+
|
| 121 |
+
def process_credential_line(s):
|
| 122 |
+
s = s.strip()
|
| 123 |
+
if not s:
|
| 124 |
+
return None
|
| 125 |
+
return tuple(s.split(':', 1))
|
| 126 |
+
|
| 127 |
+
if cmd_opts.gradio_auth:
|
| 128 |
+
for cred in cmd_opts.gradio_auth.split(','):
|
| 129 |
+
cred = process_credential_line(cred)
|
| 130 |
+
if cred:
|
| 131 |
+
yield cred
|
| 132 |
+
|
| 133 |
+
if cmd_opts.gradio_auth_path:
|
| 134 |
+
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
| 135 |
+
for line in file.readlines():
|
| 136 |
+
for cred in line.strip().split(','):
|
| 137 |
+
cred = process_credential_line(cred)
|
| 138 |
+
if cred:
|
| 139 |
+
yield cred
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def dumpstacks():
|
| 143 |
+
import threading
|
| 144 |
+
import traceback
|
| 145 |
+
|
| 146 |
+
id2name = {th.ident: th.name for th in threading.enumerate()}
|
| 147 |
+
code = []
|
| 148 |
+
for threadId, stack in sys._current_frames().items():
|
| 149 |
+
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
| 150 |
+
for filename, lineno, name, line in traceback.extract_stack(stack):
|
| 151 |
+
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
| 152 |
+
if line:
|
| 153 |
+
code.append(" " + line.strip())
|
| 154 |
+
|
| 155 |
+
print("\n".join(code))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def configure_sigint_handler():
|
| 159 |
+
# make the program just exit at ctrl+c without waiting for anything
|
| 160 |
+
|
| 161 |
+
from modules import shared
|
| 162 |
+
|
| 163 |
+
def sigint_handler(sig, frame):
|
| 164 |
+
print(f'Interrupted with signal {sig} in {frame}')
|
| 165 |
+
|
| 166 |
+
if shared.opts.dump_stacks_on_signal:
|
| 167 |
+
dumpstacks()
|
| 168 |
+
|
| 169 |
+
os._exit(0)
|
| 170 |
+
|
| 171 |
+
if not os.environ.get("COVERAGE_RUN"):
|
| 172 |
+
# Don't install the immediate-quit handler when running under coverage,
|
| 173 |
+
# as then the coverage report won't be generated.
|
| 174 |
+
signal.signal(signal.SIGINT, sigint_handler)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def configure_opts_onchange():
|
| 178 |
+
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
| 179 |
+
from modules.call_queue import wrap_queued_call
|
| 180 |
+
|
| 181 |
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
| 182 |
+
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
| 183 |
+
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
| 184 |
+
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
| 185 |
+
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
| 186 |
+
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
| 187 |
+
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
| 188 |
+
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
| 189 |
+
startup_timer.record("opts onchange")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def setup_middleware(app):
|
| 193 |
+
from starlette.middleware.gzip import GZipMiddleware
|
| 194 |
+
|
| 195 |
+
app.middleware_stack = None # reset current middleware to allow modifying user provided list
|
| 196 |
+
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
| 197 |
+
configure_cors_middleware(app)
|
| 198 |
+
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def configure_cors_middleware(app):
|
| 202 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 203 |
+
from modules.shared_cmd_options import cmd_opts
|
| 204 |
+
|
| 205 |
+
cors_options = {
|
| 206 |
+
"allow_methods": ["*"],
|
| 207 |
+
"allow_headers": ["*"],
|
| 208 |
+
"allow_credentials": True,
|
| 209 |
+
}
|
| 210 |
+
if cmd_opts.cors_allow_origins:
|
| 211 |
+
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
|
| 212 |
+
if cmd_opts.cors_allow_origins_regex:
|
| 213 |
+
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
|
| 214 |
+
app.add_middleware(CORSMiddleware, **cors_options)
|
| 215 |
+
|
modules/interrogate.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.hub
|
| 9 |
+
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 12 |
+
|
| 13 |
+
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
|
| 14 |
+
|
| 15 |
+
blip_image_eval_size = 384
|
| 16 |
+
clip_model_name = 'ViT-L/14'
|
| 17 |
+
|
| 18 |
+
Category = namedtuple("Category", ["name", "topn", "items"])
|
| 19 |
+
|
| 20 |
+
re_topn = re.compile(r"\.top(\d+)$")
|
| 21 |
+
|
| 22 |
+
def category_types():
|
| 23 |
+
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def download_default_clip_interrogate_categories(content_dir):
|
| 27 |
+
print("Downloading CLIP categories...")
|
| 28 |
+
|
| 29 |
+
tmpdir = f"{content_dir}_tmp"
|
| 30 |
+
category_types = ["artists", "flavors", "mediums", "movements"]
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
os.makedirs(tmpdir, exist_ok=True)
|
| 34 |
+
for category_type in category_types:
|
| 35 |
+
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
| 36 |
+
os.rename(tmpdir, content_dir)
|
| 37 |
+
|
| 38 |
+
except Exception as e:
|
| 39 |
+
errors.display(e, "downloading default CLIP interrogate categories")
|
| 40 |
+
finally:
|
| 41 |
+
if os.path.exists(tmpdir):
|
| 42 |
+
os.removedirs(tmpdir)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class InterrogateModels:
|
| 46 |
+
blip_model = None
|
| 47 |
+
clip_model = None
|
| 48 |
+
clip_preprocess = None
|
| 49 |
+
dtype = None
|
| 50 |
+
running_on_cpu = None
|
| 51 |
+
|
| 52 |
+
def __init__(self, content_dir):
|
| 53 |
+
self.loaded_categories = None
|
| 54 |
+
self.skip_categories = []
|
| 55 |
+
self.content_dir = content_dir
|
| 56 |
+
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
| 57 |
+
|
| 58 |
+
def categories(self):
|
| 59 |
+
if not os.path.exists(self.content_dir):
|
| 60 |
+
download_default_clip_interrogate_categories(self.content_dir)
|
| 61 |
+
|
| 62 |
+
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
| 63 |
+
return self.loaded_categories
|
| 64 |
+
|
| 65 |
+
self.loaded_categories = []
|
| 66 |
+
|
| 67 |
+
if os.path.exists(self.content_dir):
|
| 68 |
+
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
| 69 |
+
category_types = []
|
| 70 |
+
for filename in Path(self.content_dir).glob('*.txt'):
|
| 71 |
+
category_types.append(filename.stem)
|
| 72 |
+
if filename.stem in self.skip_categories:
|
| 73 |
+
continue
|
| 74 |
+
m = re_topn.search(filename.stem)
|
| 75 |
+
topn = 1 if m is None else int(m.group(1))
|
| 76 |
+
with open(filename, "r", encoding="utf8") as file:
|
| 77 |
+
lines = [x.strip() for x in file.readlines()]
|
| 78 |
+
|
| 79 |
+
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
| 80 |
+
|
| 81 |
+
return self.loaded_categories
|
| 82 |
+
|
| 83 |
+
def create_fake_fairscale(self):
|
| 84 |
+
class FakeFairscale:
|
| 85 |
+
def checkpoint_wrapper(self):
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
|
| 89 |
+
|
| 90 |
+
def load_blip_model(self):
|
| 91 |
+
self.create_fake_fairscale()
|
| 92 |
+
import models.blip
|
| 93 |
+
|
| 94 |
+
files = modelloader.load_models(
|
| 95 |
+
model_path=os.path.join(paths.models_path, "BLIP"),
|
| 96 |
+
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
|
| 97 |
+
ext_filter=[".pth"],
|
| 98 |
+
download_name='model_base_caption_capfilt_large.pth',
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
| 102 |
+
blip_model.eval()
|
| 103 |
+
|
| 104 |
+
return blip_model
|
| 105 |
+
|
| 106 |
+
def load_clip_model(self):
|
| 107 |
+
import clip
|
| 108 |
+
|
| 109 |
+
if self.running_on_cpu:
|
| 110 |
+
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
| 111 |
+
else:
|
| 112 |
+
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
| 113 |
+
|
| 114 |
+
model.eval()
|
| 115 |
+
model = model.to(devices.device_interrogate)
|
| 116 |
+
|
| 117 |
+
return model, preprocess
|
| 118 |
+
|
| 119 |
+
def load(self):
|
| 120 |
+
if self.blip_model is None:
|
| 121 |
+
self.blip_model = self.load_blip_model()
|
| 122 |
+
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
| 123 |
+
self.blip_model = self.blip_model.half()
|
| 124 |
+
|
| 125 |
+
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
| 126 |
+
|
| 127 |
+
if self.clip_model is None:
|
| 128 |
+
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
| 129 |
+
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
| 130 |
+
self.clip_model = self.clip_model.half()
|
| 131 |
+
|
| 132 |
+
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
| 133 |
+
|
| 134 |
+
self.dtype = torch_utils.get_param(self.clip_model).dtype
|
| 135 |
+
|
| 136 |
+
def send_clip_to_ram(self):
|
| 137 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
| 138 |
+
if self.clip_model is not None:
|
| 139 |
+
self.clip_model = self.clip_model.to(devices.cpu)
|
| 140 |
+
|
| 141 |
+
def send_blip_to_ram(self):
|
| 142 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
| 143 |
+
if self.blip_model is not None:
|
| 144 |
+
self.blip_model = self.blip_model.to(devices.cpu)
|
| 145 |
+
|
| 146 |
+
def unload(self):
|
| 147 |
+
self.send_clip_to_ram()
|
| 148 |
+
self.send_blip_to_ram()
|
| 149 |
+
|
| 150 |
+
devices.torch_gc()
|
| 151 |
+
|
| 152 |
+
def rank(self, image_features, text_array, top_count=1):
|
| 153 |
+
import clip
|
| 154 |
+
|
| 155 |
+
devices.torch_gc()
|
| 156 |
+
|
| 157 |
+
if shared.opts.interrogate_clip_dict_limit != 0:
|
| 158 |
+
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
| 159 |
+
|
| 160 |
+
top_count = min(top_count, len(text_array))
|
| 161 |
+
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
| 162 |
+
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
| 163 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
| 164 |
+
|
| 165 |
+
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
|
| 166 |
+
for i in range(image_features.shape[0]):
|
| 167 |
+
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
| 168 |
+
similarity /= image_features.shape[0]
|
| 169 |
+
|
| 170 |
+
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
| 171 |
+
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
| 172 |
+
|
| 173 |
+
def generate_caption(self, pil_image):
|
| 174 |
+
gpu_image = transforms.Compose([
|
| 175 |
+
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
| 176 |
+
transforms.ToTensor(),
|
| 177 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
| 178 |
+
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
| 182 |
+
|
| 183 |
+
return caption[0]
|
| 184 |
+
|
| 185 |
+
def interrogate(self, pil_image):
|
| 186 |
+
res = ""
|
| 187 |
+
shared.state.begin(job="interrogate")
|
| 188 |
+
try:
|
| 189 |
+
lowvram.send_everything_to_cpu()
|
| 190 |
+
devices.torch_gc()
|
| 191 |
+
|
| 192 |
+
self.load()
|
| 193 |
+
|
| 194 |
+
caption = self.generate_caption(pil_image)
|
| 195 |
+
self.send_blip_to_ram()
|
| 196 |
+
devices.torch_gc()
|
| 197 |
+
|
| 198 |
+
res = caption
|
| 199 |
+
|
| 200 |
+
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
| 201 |
+
|
| 202 |
+
with torch.no_grad(), devices.autocast():
|
| 203 |
+
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
| 204 |
+
|
| 205 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 206 |
+
|
| 207 |
+
for cat in self.categories():
|
| 208 |
+
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
| 209 |
+
for match, score in matches:
|
| 210 |
+
if shared.opts.interrogate_return_ranks:
|
| 211 |
+
res += f", ({match}:{score/100:.3f})"
|
| 212 |
+
else:
|
| 213 |
+
res += f", {match}"
|
| 214 |
+
|
| 215 |
+
except Exception:
|
| 216 |
+
errors.report("Error interrogating", exc_info=True)
|
| 217 |
+
res += "<error>"
|
| 218 |
+
|
| 219 |
+
self.unload()
|
| 220 |
+
shared.state.end()
|
| 221 |
+
|
| 222 |
+
return res
|
modules/launch_utils.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# this scripts installs necessary requirements and launches main program in webui.py
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
import subprocess
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import sys
|
| 8 |
+
import importlib.util
|
| 9 |
+
import importlib.metadata
|
| 10 |
+
import platform
|
| 11 |
+
import json
|
| 12 |
+
import shlex
|
| 13 |
+
from functools import lru_cache
|
| 14 |
+
|
| 15 |
+
from modules import cmd_args, errors
|
| 16 |
+
from modules.paths_internal import script_path, extensions_dir
|
| 17 |
+
from modules.timer import startup_timer
|
| 18 |
+
from modules import logging_config
|
| 19 |
+
|
| 20 |
+
args, _ = cmd_args.parser.parse_known_args()
|
| 21 |
+
logging_config.setup_logging(args.loglevel)
|
| 22 |
+
|
| 23 |
+
python = sys.executable
|
| 24 |
+
git = os.environ.get('GIT', "git")
|
| 25 |
+
index_url = os.environ.get('INDEX_URL', "")
|
| 26 |
+
dir_repos = "repositories"
|
| 27 |
+
|
| 28 |
+
# Whether to default to printing command output
|
| 29 |
+
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
| 30 |
+
|
| 31 |
+
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def check_python_version():
|
| 35 |
+
is_windows = platform.system() == "Windows"
|
| 36 |
+
major = sys.version_info.major
|
| 37 |
+
minor = sys.version_info.minor
|
| 38 |
+
micro = sys.version_info.micro
|
| 39 |
+
|
| 40 |
+
if is_windows:
|
| 41 |
+
supported_minors = [10]
|
| 42 |
+
else:
|
| 43 |
+
supported_minors = [7, 8, 9, 10, 11]
|
| 44 |
+
|
| 45 |
+
if not (major == 3 and minor in supported_minors):
|
| 46 |
+
import modules.errors
|
| 47 |
+
|
| 48 |
+
modules.errors.print_error_explanation(f"""
|
| 49 |
+
INCOMPATIBLE PYTHON VERSION
|
| 50 |
+
|
| 51 |
+
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
| 52 |
+
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
| 53 |
+
or any other error regarding unsuccessful package (library) installation,
|
| 54 |
+
please downgrade (or upgrade) to the latest version of 3.10 Python
|
| 55 |
+
and delete current Python and "venv" folder in WebUI's directory.
|
| 56 |
+
|
| 57 |
+
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
| 58 |
+
|
| 59 |
+
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
| 60 |
+
|
| 61 |
+
Use --skip-python-version-check to suppress this warning.
|
| 62 |
+
""")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@lru_cache()
|
| 66 |
+
def commit_hash():
|
| 67 |
+
try:
|
| 68 |
+
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
| 69 |
+
except Exception:
|
| 70 |
+
return "<none>"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@lru_cache()
|
| 74 |
+
def git_tag():
|
| 75 |
+
try:
|
| 76 |
+
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
| 77 |
+
except Exception:
|
| 78 |
+
try:
|
| 79 |
+
|
| 80 |
+
changelog_md = os.path.join(script_path, "CHANGELOG.md")
|
| 81 |
+
with open(changelog_md, "r", encoding="utf-8") as file:
|
| 82 |
+
line = next((line.strip() for line in file if line.strip()), "<none>")
|
| 83 |
+
line = line.replace("## ", "")
|
| 84 |
+
return line
|
| 85 |
+
except Exception:
|
| 86 |
+
return "<none>"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
| 90 |
+
if desc is not None:
|
| 91 |
+
print(desc)
|
| 92 |
+
|
| 93 |
+
run_kwargs = {
|
| 94 |
+
"args": command,
|
| 95 |
+
"shell": True,
|
| 96 |
+
"env": os.environ if custom_env is None else custom_env,
|
| 97 |
+
"encoding": 'utf8',
|
| 98 |
+
"errors": 'ignore',
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
if not live:
|
| 102 |
+
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
| 103 |
+
|
| 104 |
+
result = subprocess.run(**run_kwargs)
|
| 105 |
+
|
| 106 |
+
if result.returncode != 0:
|
| 107 |
+
error_bits = [
|
| 108 |
+
f"{errdesc or 'Error running command'}.",
|
| 109 |
+
f"Command: {command}",
|
| 110 |
+
f"Error code: {result.returncode}",
|
| 111 |
+
]
|
| 112 |
+
if result.stdout:
|
| 113 |
+
error_bits.append(f"stdout: {result.stdout}")
|
| 114 |
+
if result.stderr:
|
| 115 |
+
error_bits.append(f"stderr: {result.stderr}")
|
| 116 |
+
raise RuntimeError("\n".join(error_bits))
|
| 117 |
+
|
| 118 |
+
return (result.stdout or "")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def is_installed(package):
|
| 122 |
+
try:
|
| 123 |
+
dist = importlib.metadata.distribution(package)
|
| 124 |
+
except importlib.metadata.PackageNotFoundError:
|
| 125 |
+
try:
|
| 126 |
+
spec = importlib.util.find_spec(package)
|
| 127 |
+
except ModuleNotFoundError:
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
return spec is not None
|
| 131 |
+
|
| 132 |
+
return dist is not None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def repo_dir(name):
|
| 136 |
+
return os.path.join(script_path, dir_repos, name)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def run_pip(command, desc=None, live=default_command_live):
|
| 140 |
+
if args.skip_install:
|
| 141 |
+
return
|
| 142 |
+
|
| 143 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
| 144 |
+
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def check_run_python(code: str) -> bool:
|
| 148 |
+
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
| 149 |
+
return result.returncode == 0
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def git_fix_workspace(dir, name):
|
| 153 |
+
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
| 154 |
+
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
| 159 |
+
try:
|
| 160 |
+
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
| 161 |
+
except RuntimeError:
|
| 162 |
+
if not autofix:
|
| 163 |
+
raise
|
| 164 |
+
|
| 165 |
+
print(f"{errdesc}, attempting autofix...")
|
| 166 |
+
git_fix_workspace(dir, name)
|
| 167 |
+
|
| 168 |
+
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def git_clone(url, dir, name, commithash=None):
|
| 172 |
+
# TODO clone into temporary dir and move if successful
|
| 173 |
+
|
| 174 |
+
if os.path.exists(dir):
|
| 175 |
+
if commithash is None:
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
| 179 |
+
if current_hash == commithash:
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
| 183 |
+
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
| 184 |
+
|
| 185 |
+
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
| 186 |
+
|
| 187 |
+
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
| 188 |
+
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
| 193 |
+
except RuntimeError:
|
| 194 |
+
shutil.rmtree(dir, ignore_errors=True)
|
| 195 |
+
raise
|
| 196 |
+
|
| 197 |
+
if commithash is not None:
|
| 198 |
+
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def git_pull_recursive(dir):
|
| 202 |
+
for subdir, _, _ in os.walk(dir):
|
| 203 |
+
if os.path.exists(os.path.join(subdir, '.git')):
|
| 204 |
+
try:
|
| 205 |
+
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
| 206 |
+
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
| 207 |
+
except subprocess.CalledProcessError as e:
|
| 208 |
+
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def version_check(commit):
|
| 212 |
+
try:
|
| 213 |
+
import requests
|
| 214 |
+
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
| 215 |
+
if commit != "<none>" and commits['commit']['sha'] != commit:
|
| 216 |
+
print("--------------------------------------------------------")
|
| 217 |
+
print("| You are not up to date with the most recent release. |")
|
| 218 |
+
print("| Consider running `git pull` to update. |")
|
| 219 |
+
print("--------------------------------------------------------")
|
| 220 |
+
elif commits['commit']['sha'] == commit:
|
| 221 |
+
print("You are up to date with the most recent release.")
|
| 222 |
+
else:
|
| 223 |
+
print("Not a git clone, can't perform version check.")
|
| 224 |
+
except Exception as e:
|
| 225 |
+
print("version check failed", e)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def run_extension_installer(extension_dir):
|
| 229 |
+
path_installer = os.path.join(extension_dir, "install.py")
|
| 230 |
+
if not os.path.isfile(path_installer):
|
| 231 |
+
return
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
env = os.environ.copy()
|
| 235 |
+
env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
| 236 |
+
|
| 237 |
+
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
|
| 238 |
+
if stdout:
|
| 239 |
+
print(stdout)
|
| 240 |
+
except Exception as e:
|
| 241 |
+
errors.report(str(e))
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def list_extensions(settings_file):
|
| 245 |
+
settings = {}
|
| 246 |
+
|
| 247 |
+
try:
|
| 248 |
+
with open(settings_file, "r", encoding="utf8") as file:
|
| 249 |
+
settings = json.load(file)
|
| 250 |
+
except FileNotFoundError:
|
| 251 |
+
pass
|
| 252 |
+
except Exception:
|
| 253 |
+
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
| 254 |
+
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
| 255 |
+
|
| 256 |
+
disabled_extensions = set(settings.get('disabled_extensions', []))
|
| 257 |
+
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
| 258 |
+
|
| 259 |
+
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
| 260 |
+
return []
|
| 261 |
+
|
| 262 |
+
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def run_extensions_installers(settings_file):
|
| 266 |
+
if not os.path.isdir(extensions_dir):
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
with startup_timer.subcategory("run extensions installers"):
|
| 270 |
+
for dirname_extension in list_extensions(settings_file):
|
| 271 |
+
logging.debug(f"Installing {dirname_extension}")
|
| 272 |
+
|
| 273 |
+
path = os.path.join(extensions_dir, dirname_extension)
|
| 274 |
+
|
| 275 |
+
if os.path.isdir(path):
|
| 276 |
+
run_extension_installer(path)
|
| 277 |
+
startup_timer.record(dirname_extension)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def requirements_met(requirements_file):
|
| 284 |
+
"""
|
| 285 |
+
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
| 286 |
+
are already installed. Returns True if so, False if not installed or parsing fails.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
import importlib.metadata
|
| 290 |
+
import packaging.version
|
| 291 |
+
|
| 292 |
+
with open(requirements_file, "r", encoding="utf8") as file:
|
| 293 |
+
for line in file:
|
| 294 |
+
if line.strip() == "":
|
| 295 |
+
continue
|
| 296 |
+
|
| 297 |
+
m = re.match(re_requirement, line)
|
| 298 |
+
if m is None:
|
| 299 |
+
return False
|
| 300 |
+
|
| 301 |
+
package = m.group(1).strip()
|
| 302 |
+
version_required = (m.group(2) or "").strip()
|
| 303 |
+
|
| 304 |
+
if version_required == "":
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
try:
|
| 308 |
+
version_installed = importlib.metadata.version(package)
|
| 309 |
+
except Exception:
|
| 310 |
+
return False
|
| 311 |
+
|
| 312 |
+
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
| 313 |
+
return False
|
| 314 |
+
|
| 315 |
+
return True
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def prepare_environment():
|
| 319 |
+
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
| 320 |
+
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
| 321 |
+
if args.use_ipex:
|
| 322 |
+
if platform.system() == "Windows":
|
| 323 |
+
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
| 324 |
+
# This is NOT an Intel official release so please use it at your own risk!!
|
| 325 |
+
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
|
| 326 |
+
#
|
| 327 |
+
# Strengths (over official IPEX 2.0.110 windows release):
|
| 328 |
+
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
|
| 329 |
+
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
|
| 330 |
+
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
|
| 331 |
+
# Limitation:
|
| 332 |
+
# - Only works for python 3.10
|
| 333 |
+
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
|
| 334 |
+
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
|
| 335 |
+
else:
|
| 336 |
+
# Using official IPEX release for linux since it's already an AOT build.
|
| 337 |
+
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
|
| 338 |
+
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
|
| 339 |
+
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
| 340 |
+
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
| 341 |
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
| 342 |
+
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
| 343 |
+
|
| 344 |
+
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
| 345 |
+
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
| 346 |
+
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
| 347 |
+
|
| 348 |
+
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
| 349 |
+
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
| 350 |
+
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
| 351 |
+
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
| 352 |
+
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
| 353 |
+
|
| 354 |
+
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
| 355 |
+
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
| 356 |
+
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
| 357 |
+
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
| 358 |
+
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
| 362 |
+
os.remove(os.path.join(script_path, "tmp", "restart"))
|
| 363 |
+
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
| 364 |
+
except OSError:
|
| 365 |
+
pass
|
| 366 |
+
|
| 367 |
+
if not args.skip_python_version_check:
|
| 368 |
+
check_python_version()
|
| 369 |
+
|
| 370 |
+
startup_timer.record("checks")
|
| 371 |
+
|
| 372 |
+
commit = commit_hash()
|
| 373 |
+
tag = git_tag()
|
| 374 |
+
startup_timer.record("git version info")
|
| 375 |
+
|
| 376 |
+
print(f"Python {sys.version}")
|
| 377 |
+
print(f"Version: {tag}")
|
| 378 |
+
print(f"Commit hash: {commit}")
|
| 379 |
+
|
| 380 |
+
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
| 381 |
+
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
| 382 |
+
startup_timer.record("install torch")
|
| 383 |
+
|
| 384 |
+
if args.use_ipex:
|
| 385 |
+
args.skip_torch_cuda_test = True
|
| 386 |
+
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
| 387 |
+
raise RuntimeError(
|
| 388 |
+
'Torch is not able to use GPU; '
|
| 389 |
+
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
| 390 |
+
)
|
| 391 |
+
startup_timer.record("torch GPU test")
|
| 392 |
+
|
| 393 |
+
if not is_installed("clip"):
|
| 394 |
+
run_pip(f"install {clip_package}", "clip")
|
| 395 |
+
startup_timer.record("install clip")
|
| 396 |
+
|
| 397 |
+
if not is_installed("open_clip"):
|
| 398 |
+
run_pip(f"install {openclip_package}", "open_clip")
|
| 399 |
+
startup_timer.record("install open_clip")
|
| 400 |
+
|
| 401 |
+
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
| 402 |
+
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
| 403 |
+
startup_timer.record("install xformers")
|
| 404 |
+
|
| 405 |
+
if not is_installed("ngrok") and args.ngrok:
|
| 406 |
+
run_pip("install ngrok", "ngrok")
|
| 407 |
+
startup_timer.record("install ngrok")
|
| 408 |
+
|
| 409 |
+
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
| 410 |
+
|
| 411 |
+
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
| 412 |
+
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
| 413 |
+
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
| 414 |
+
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
| 415 |
+
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
| 416 |
+
|
| 417 |
+
startup_timer.record("clone repositores")
|
| 418 |
+
|
| 419 |
+
if not os.path.isfile(requirements_file):
|
| 420 |
+
requirements_file = os.path.join(script_path, requirements_file)
|
| 421 |
+
|
| 422 |
+
if not requirements_met(requirements_file):
|
| 423 |
+
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
| 424 |
+
startup_timer.record("install requirements")
|
| 425 |
+
|
| 426 |
+
if not os.path.isfile(requirements_file_for_npu):
|
| 427 |
+
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
| 428 |
+
|
| 429 |
+
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
| 430 |
+
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
| 431 |
+
startup_timer.record("install requirements_for_npu")
|
| 432 |
+
|
| 433 |
+
if not args.skip_install:
|
| 434 |
+
run_extensions_installers(settings_file=args.ui_settings_file)
|
| 435 |
+
|
| 436 |
+
if args.update_check:
|
| 437 |
+
version_check(commit)
|
| 438 |
+
startup_timer.record("check version")
|
| 439 |
+
|
| 440 |
+
if args.update_all_extensions:
|
| 441 |
+
git_pull_recursive(extensions_dir)
|
| 442 |
+
startup_timer.record("update extensions")
|
| 443 |
+
|
| 444 |
+
if "--exit" in sys.argv:
|
| 445 |
+
print("Exiting because of --exit argument")
|
| 446 |
+
exit(0)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def configure_for_tests():
|
| 450 |
+
if "--api" not in sys.argv:
|
| 451 |
+
sys.argv.append("--api")
|
| 452 |
+
if "--ckpt" not in sys.argv:
|
| 453 |
+
sys.argv.append("--ckpt")
|
| 454 |
+
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
| 455 |
+
if "--skip-torch-cuda-test" not in sys.argv:
|
| 456 |
+
sys.argv.append("--skip-torch-cuda-test")
|
| 457 |
+
if "--disable-nan-check" not in sys.argv:
|
| 458 |
+
sys.argv.append("--disable-nan-check")
|
| 459 |
+
|
| 460 |
+
os.environ['COMMANDLINE_ARGS'] = ""
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def start():
|
| 464 |
+
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
|
| 465 |
+
import webui
|
| 466 |
+
if '--nowebui' in sys.argv:
|
| 467 |
+
webui.api_only()
|
| 468 |
+
else:
|
| 469 |
+
webui.webui()
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def dump_sysinfo():
|
| 473 |
+
from modules import sysinfo
|
| 474 |
+
import datetime
|
| 475 |
+
|
| 476 |
+
text = sysinfo.get()
|
| 477 |
+
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
| 478 |
+
|
| 479 |
+
with open(filename, "w", encoding="utf8") as file:
|
| 480 |
+
file.write(text)
|
| 481 |
+
|
| 482 |
+
return filename
|
modules/localization.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from modules import errors, scripts
|
| 5 |
+
|
| 6 |
+
localizations = {}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def list_localizations(dirname):
|
| 10 |
+
localizations.clear()
|
| 11 |
+
|
| 12 |
+
for file in os.listdir(dirname):
|
| 13 |
+
fn, ext = os.path.splitext(file)
|
| 14 |
+
if ext.lower() != ".json":
|
| 15 |
+
continue
|
| 16 |
+
|
| 17 |
+
localizations[fn] = [os.path.join(dirname, file)]
|
| 18 |
+
|
| 19 |
+
for file in scripts.list_scripts("localizations", ".json"):
|
| 20 |
+
fn, ext = os.path.splitext(file.filename)
|
| 21 |
+
if fn not in localizations:
|
| 22 |
+
localizations[fn] = []
|
| 23 |
+
localizations[fn].append(file.path)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def localization_js(current_localization_name: str) -> str:
|
| 27 |
+
fns = localizations.get(current_localization_name, None)
|
| 28 |
+
data = {}
|
| 29 |
+
if fns is not None:
|
| 30 |
+
for fn in fns:
|
| 31 |
+
try:
|
| 32 |
+
with open(fn, "r", encoding="utf8") as file:
|
| 33 |
+
data.update(json.load(file))
|
| 34 |
+
except Exception:
|
| 35 |
+
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
| 36 |
+
|
| 37 |
+
return f"window.localization = {json.dumps(data)}"
|
modules/logging_config.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TqdmLoggingHandler(logging.Handler):
|
| 9 |
+
def __init__(self, fallback_handler: logging.Handler):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.fallback_handler = fallback_handler
|
| 12 |
+
|
| 13 |
+
def emit(self, record):
|
| 14 |
+
try:
|
| 15 |
+
# If there are active tqdm progress bars,
|
| 16 |
+
# attempt to not interfere with them.
|
| 17 |
+
if tqdm._instances:
|
| 18 |
+
tqdm.write(self.format(record))
|
| 19 |
+
else:
|
| 20 |
+
self.fallback_handler.emit(record)
|
| 21 |
+
except Exception:
|
| 22 |
+
self.fallback_handler.emit(record)
|
| 23 |
+
|
| 24 |
+
except ImportError:
|
| 25 |
+
TqdmLoggingHandler = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def setup_logging(loglevel):
|
| 29 |
+
if loglevel is None:
|
| 30 |
+
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
| 31 |
+
|
| 32 |
+
if not loglevel:
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
if logging.root.handlers:
|
| 36 |
+
# Already configured, do not interfere
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
formatter = logging.Formatter(
|
| 40 |
+
'%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
| 41 |
+
'%Y-%m-%d %H:%M:%S',
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
if os.environ.get("SD_WEBUI_RICH_LOG"):
|
| 45 |
+
from rich.logging import RichHandler
|
| 46 |
+
handler = RichHandler()
|
| 47 |
+
else:
|
| 48 |
+
handler = logging.StreamHandler()
|
| 49 |
+
handler.setFormatter(formatter)
|
| 50 |
+
|
| 51 |
+
if TqdmLoggingHandler:
|
| 52 |
+
handler = TqdmLoggingHandler(handler)
|
| 53 |
+
|
| 54 |
+
handler.setFormatter(formatter)
|
| 55 |
+
|
| 56 |
+
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
| 57 |
+
logging.root.setLevel(log_level)
|
| 58 |
+
logging.root.addHandler(handler)
|
modules/lowvram.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from modules import devices, shared
|
| 5 |
+
|
| 6 |
+
module_in_gpu = None
|
| 7 |
+
cpu = torch.device("cpu")
|
| 8 |
+
|
| 9 |
+
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
|
| 10 |
+
|
| 11 |
+
def send_everything_to_cpu():
|
| 12 |
+
global module_in_gpu
|
| 13 |
+
|
| 14 |
+
if module_in_gpu is not None:
|
| 15 |
+
module_in_gpu.to(cpu)
|
| 16 |
+
|
| 17 |
+
module_in_gpu = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def is_needed(sd_model):
|
| 21 |
+
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def apply(sd_model):
|
| 25 |
+
enable = is_needed(sd_model)
|
| 26 |
+
shared.parallel_processing_allowed = not enable
|
| 27 |
+
|
| 28 |
+
if enable:
|
| 29 |
+
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
| 30 |
+
else:
|
| 31 |
+
sd_model.lowvram = False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def setup_for_low_vram(sd_model, use_medvram):
|
| 35 |
+
if getattr(sd_model, 'lowvram', False):
|
| 36 |
+
return
|
| 37 |
+
|
| 38 |
+
sd_model.lowvram = True
|
| 39 |
+
|
| 40 |
+
parents = {}
|
| 41 |
+
|
| 42 |
+
def send_me_to_gpu(module, _):
|
| 43 |
+
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
| 44 |
+
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
| 45 |
+
be in CPU
|
| 46 |
+
"""
|
| 47 |
+
global module_in_gpu
|
| 48 |
+
|
| 49 |
+
module = parents.get(module, module)
|
| 50 |
+
|
| 51 |
+
if module_in_gpu == module:
|
| 52 |
+
return
|
| 53 |
+
|
| 54 |
+
if module_in_gpu is not None:
|
| 55 |
+
module_in_gpu.to(cpu)
|
| 56 |
+
|
| 57 |
+
module.to(devices.device)
|
| 58 |
+
module_in_gpu = module
|
| 59 |
+
|
| 60 |
+
# see below for register_forward_pre_hook;
|
| 61 |
+
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
| 62 |
+
# useless here, and we just replace those methods
|
| 63 |
+
|
| 64 |
+
first_stage_model = sd_model.first_stage_model
|
| 65 |
+
first_stage_model_encode = sd_model.first_stage_model.encode
|
| 66 |
+
first_stage_model_decode = sd_model.first_stage_model.decode
|
| 67 |
+
|
| 68 |
+
def first_stage_model_encode_wrap(x):
|
| 69 |
+
send_me_to_gpu(first_stage_model, None)
|
| 70 |
+
return first_stage_model_encode(x)
|
| 71 |
+
|
| 72 |
+
def first_stage_model_decode_wrap(z):
|
| 73 |
+
send_me_to_gpu(first_stage_model, None)
|
| 74 |
+
return first_stage_model_decode(z)
|
| 75 |
+
|
| 76 |
+
to_remain_in_cpu = [
|
| 77 |
+
(sd_model, 'first_stage_model'),
|
| 78 |
+
(sd_model, 'depth_model'),
|
| 79 |
+
(sd_model, 'embedder'),
|
| 80 |
+
(sd_model, 'model'),
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
is_sdxl = hasattr(sd_model, 'conditioner')
|
| 84 |
+
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
| 85 |
+
|
| 86 |
+
if hasattr(sd_model, 'medvram_fields'):
|
| 87 |
+
to_remain_in_cpu = sd_model.medvram_fields()
|
| 88 |
+
elif is_sdxl:
|
| 89 |
+
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
| 90 |
+
elif is_sd2:
|
| 91 |
+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
| 92 |
+
else:
|
| 93 |
+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
| 94 |
+
|
| 95 |
+
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
| 96 |
+
stored = []
|
| 97 |
+
for obj, field in to_remain_in_cpu:
|
| 98 |
+
module = getattr(obj, field, None)
|
| 99 |
+
stored.append(module)
|
| 100 |
+
setattr(obj, field, None)
|
| 101 |
+
|
| 102 |
+
# send the model to GPU.
|
| 103 |
+
sd_model.to(devices.device)
|
| 104 |
+
|
| 105 |
+
# put modules back. the modules will be in CPU.
|
| 106 |
+
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
| 107 |
+
setattr(obj, field, module)
|
| 108 |
+
|
| 109 |
+
# register hooks for those the first three models
|
| 110 |
+
if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
|
| 111 |
+
for module in sd_model.cond_stage_model.medvram_modules():
|
| 112 |
+
if isinstance(module, ModuleWithParent):
|
| 113 |
+
parent = module.parent
|
| 114 |
+
module = module.module
|
| 115 |
+
else:
|
| 116 |
+
parent = None
|
| 117 |
+
|
| 118 |
+
if module:
|
| 119 |
+
module.register_forward_pre_hook(send_me_to_gpu)
|
| 120 |
+
|
| 121 |
+
if parent:
|
| 122 |
+
parents[module] = parent
|
| 123 |
+
|
| 124 |
+
elif is_sdxl:
|
| 125 |
+
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
| 126 |
+
elif is_sd2:
|
| 127 |
+
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
| 128 |
+
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
|
| 129 |
+
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
|
| 130 |
+
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
|
| 131 |
+
else:
|
| 132 |
+
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
| 133 |
+
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
| 134 |
+
|
| 135 |
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
| 136 |
+
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
| 137 |
+
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
| 138 |
+
if getattr(sd_model, 'depth_model', None) is not None:
|
| 139 |
+
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
| 140 |
+
if getattr(sd_model, 'embedder', None) is not None:
|
| 141 |
+
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
| 142 |
+
|
| 143 |
+
if use_medvram:
|
| 144 |
+
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
| 145 |
+
else:
|
| 146 |
+
diff_model = sd_model.model.diffusion_model
|
| 147 |
+
|
| 148 |
+
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
| 149 |
+
# so that only one of them is in GPU at a time
|
| 150 |
+
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
| 151 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
| 152 |
+
sd_model.model.to(devices.device)
|
| 153 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
| 154 |
+
|
| 155 |
+
# install hooks for bits of third model
|
| 156 |
+
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
| 157 |
+
for block in diff_model.input_blocks:
|
| 158 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
| 159 |
+
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
| 160 |
+
for block in diff_model.output_blocks:
|
| 161 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def is_enabled(sd_model):
|
| 165 |
+
return sd_model.lowvram
|
modules/mac_specific.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import platform
|
| 6 |
+
from modules.sd_hijack_utils import CondFunc
|
| 7 |
+
from packaging import version
|
| 8 |
+
from modules import shared
|
| 9 |
+
|
| 10 |
+
log = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
| 14 |
+
# use check `getattr` and try it for compatibility.
|
| 15 |
+
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
| 16 |
+
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
| 17 |
+
def check_for_mps() -> bool:
|
| 18 |
+
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
| 19 |
+
if not getattr(torch, 'has_mps', False):
|
| 20 |
+
return False
|
| 21 |
+
try:
|
| 22 |
+
torch.zeros(1).to(torch.device("mps"))
|
| 23 |
+
return True
|
| 24 |
+
except Exception:
|
| 25 |
+
return False
|
| 26 |
+
else:
|
| 27 |
+
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
has_mps = check_for_mps()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def torch_mps_gc() -> None:
|
| 34 |
+
try:
|
| 35 |
+
if shared.state.current_latent is not None:
|
| 36 |
+
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
| 37 |
+
return
|
| 38 |
+
from torch.mps import empty_cache
|
| 39 |
+
empty_cache()
|
| 40 |
+
except Exception:
|
| 41 |
+
log.warning("MPS garbage collection failed", exc_info=True)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
| 45 |
+
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
| 46 |
+
if input.device.type == 'mps':
|
| 47 |
+
output_dtype = kwargs.get('dtype', input.dtype)
|
| 48 |
+
if output_dtype == torch.int64:
|
| 49 |
+
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
| 50 |
+
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
| 51 |
+
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
| 52 |
+
return cumsum_func(input, *args, **kwargs)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
| 56 |
+
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
| 57 |
+
try:
|
| 58 |
+
return orig_func(*args, **kwargs)
|
| 59 |
+
except RuntimeError as e:
|
| 60 |
+
if "not implemented for" in str(e) and "Half" in str(e):
|
| 61 |
+
input_tensor = args[0]
|
| 62 |
+
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
| 63 |
+
else:
|
| 64 |
+
print(f"An unexpected RuntimeError occurred: {str(e)}")
|
| 65 |
+
|
| 66 |
+
if has_mps:
|
| 67 |
+
if platform.mac_ver()[0].startswith("13.2."):
|
| 68 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
| 69 |
+
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
| 70 |
+
|
| 71 |
+
if version.parse(torch.__version__) < version.parse("1.13"):
|
| 72 |
+
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
| 73 |
+
|
| 74 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
| 75 |
+
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
| 76 |
+
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
| 77 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
| 78 |
+
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
| 79 |
+
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
| 80 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
| 81 |
+
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
| 82 |
+
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
| 83 |
+
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
| 84 |
+
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
| 85 |
+
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
| 86 |
+
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
| 87 |
+
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
| 88 |
+
|
| 89 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
| 90 |
+
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
| 91 |
+
|
| 92 |
+
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
| 93 |
+
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
| 94 |
+
|
| 95 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
| 96 |
+
if platform.processor() == 'i386':
|
| 97 |
+
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
| 98 |
+
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
modules/masking.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageFilter, ImageOps
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_crop_region_v2(mask, pad=0):
|
| 5 |
+
"""
|
| 6 |
+
Finds a rectangular region that contains all masked ares in a mask.
|
| 7 |
+
Returns None if mask is completely black mask (all 0)
|
| 8 |
+
|
| 9 |
+
Parameters:
|
| 10 |
+
mask: PIL.Image.Image L mode or numpy 1d array
|
| 11 |
+
pad: int number of pixels that the region will be extended on all sides
|
| 12 |
+
Returns: (x1, y1, x2, y2) | None
|
| 13 |
+
|
| 14 |
+
Introduced post 1.9.0
|
| 15 |
+
"""
|
| 16 |
+
mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
| 17 |
+
if box := mask.getbbox():
|
| 18 |
+
x1, y1, x2, y2 = box
|
| 19 |
+
return (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])) if pad else box
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_crop_region(mask, pad=0):
|
| 23 |
+
"""
|
| 24 |
+
Same function as get_crop_region_v2 but handles completely black mask (all 0) differently
|
| 25 |
+
when mask all black still return coordinates but the coordinates may be invalid ie x2>x1 or y2>y1
|
| 26 |
+
Notes: it is possible for the coordinates to be "valid" again if pad size is sufficiently large
|
| 27 |
+
(mask_size.x-pad, mask_size.y-pad, pad, pad)
|
| 28 |
+
|
| 29 |
+
Extension developer should use get_crop_region_v2 instead unless for compatibility considerations.
|
| 30 |
+
"""
|
| 31 |
+
mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
| 32 |
+
if box := get_crop_region_v2(mask, pad):
|
| 33 |
+
return box
|
| 34 |
+
x1, y1 = mask.size
|
| 35 |
+
x2 = y2 = 0
|
| 36 |
+
return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
| 40 |
+
"""expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
|
| 41 |
+
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
|
| 42 |
+
|
| 43 |
+
x1, y1, x2, y2 = crop_region
|
| 44 |
+
|
| 45 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
| 46 |
+
ratio_processing = processing_width / processing_height
|
| 47 |
+
|
| 48 |
+
if ratio_crop_region > ratio_processing:
|
| 49 |
+
desired_height = (x2 - x1) / ratio_processing
|
| 50 |
+
desired_height_diff = int(desired_height - (y2-y1))
|
| 51 |
+
y1 -= desired_height_diff//2
|
| 52 |
+
y2 += desired_height_diff - desired_height_diff//2
|
| 53 |
+
if y2 >= image_height:
|
| 54 |
+
diff = y2 - image_height
|
| 55 |
+
y2 -= diff
|
| 56 |
+
y1 -= diff
|
| 57 |
+
if y1 < 0:
|
| 58 |
+
y2 -= y1
|
| 59 |
+
y1 -= y1
|
| 60 |
+
if y2 >= image_height:
|
| 61 |
+
y2 = image_height
|
| 62 |
+
else:
|
| 63 |
+
desired_width = (y2 - y1) * ratio_processing
|
| 64 |
+
desired_width_diff = int(desired_width - (x2-x1))
|
| 65 |
+
x1 -= desired_width_diff//2
|
| 66 |
+
x2 += desired_width_diff - desired_width_diff//2
|
| 67 |
+
if x2 >= image_width:
|
| 68 |
+
diff = x2 - image_width
|
| 69 |
+
x2 -= diff
|
| 70 |
+
x1 -= diff
|
| 71 |
+
if x1 < 0:
|
| 72 |
+
x2 -= x1
|
| 73 |
+
x1 -= x1
|
| 74 |
+
if x2 >= image_width:
|
| 75 |
+
x2 = image_width
|
| 76 |
+
|
| 77 |
+
return x1, y1, x2, y2
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def fill(image, mask):
|
| 81 |
+
"""fills masked regions with colors from image using blur. Not extremely effective."""
|
| 82 |
+
|
| 83 |
+
image_mod = Image.new('RGBA', (image.width, image.height))
|
| 84 |
+
|
| 85 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
| 86 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
| 87 |
+
|
| 88 |
+
image_masked = image_masked.convert('RGBa')
|
| 89 |
+
|
| 90 |
+
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
| 91 |
+
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
| 92 |
+
for _ in range(repeats):
|
| 93 |
+
image_mod.alpha_composite(blurred)
|
| 94 |
+
|
| 95 |
+
return image_mod.convert("RGB")
|
| 96 |
+
|
modules/memmon.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MemUsageMonitor(threading.Thread):
|
| 9 |
+
run_flag = None
|
| 10 |
+
device = None
|
| 11 |
+
disabled = False
|
| 12 |
+
opts = None
|
| 13 |
+
data = None
|
| 14 |
+
|
| 15 |
+
def __init__(self, name, device, opts):
|
| 16 |
+
threading.Thread.__init__(self)
|
| 17 |
+
self.name = name
|
| 18 |
+
self.device = device
|
| 19 |
+
self.opts = opts
|
| 20 |
+
|
| 21 |
+
self.daemon = True
|
| 22 |
+
self.run_flag = threading.Event()
|
| 23 |
+
self.data = defaultdict(int)
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
self.cuda_mem_get_info()
|
| 27 |
+
torch.cuda.memory_stats(self.device)
|
| 28 |
+
except Exception as e: # AMD or whatever
|
| 29 |
+
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
| 30 |
+
self.disabled = True
|
| 31 |
+
|
| 32 |
+
def cuda_mem_get_info(self):
|
| 33 |
+
index = self.device.index if self.device.index is not None else torch.cuda.current_device()
|
| 34 |
+
return torch.cuda.mem_get_info(index)
|
| 35 |
+
|
| 36 |
+
def run(self):
|
| 37 |
+
if self.disabled:
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
while True:
|
| 41 |
+
self.run_flag.wait()
|
| 42 |
+
|
| 43 |
+
torch.cuda.reset_peak_memory_stats()
|
| 44 |
+
self.data.clear()
|
| 45 |
+
|
| 46 |
+
if self.opts.memmon_poll_rate <= 0:
|
| 47 |
+
self.run_flag.clear()
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
+
self.data["min_free"] = self.cuda_mem_get_info()[0]
|
| 51 |
+
|
| 52 |
+
while self.run_flag.is_set():
|
| 53 |
+
free, total = self.cuda_mem_get_info()
|
| 54 |
+
self.data["min_free"] = min(self.data["min_free"], free)
|
| 55 |
+
|
| 56 |
+
time.sleep(1 / self.opts.memmon_poll_rate)
|
| 57 |
+
|
| 58 |
+
def dump_debug(self):
|
| 59 |
+
print(self, 'recorded data:')
|
| 60 |
+
for k, v in self.read().items():
|
| 61 |
+
print(k, -(v // -(1024 ** 2)))
|
| 62 |
+
|
| 63 |
+
print(self, 'raw torch memory stats:')
|
| 64 |
+
tm = torch.cuda.memory_stats(self.device)
|
| 65 |
+
for k, v in tm.items():
|
| 66 |
+
if 'bytes' not in k:
|
| 67 |
+
continue
|
| 68 |
+
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
|
| 69 |
+
|
| 70 |
+
print(torch.cuda.memory_summary())
|
| 71 |
+
|
| 72 |
+
def monitor(self):
|
| 73 |
+
self.run_flag.set()
|
| 74 |
+
|
| 75 |
+
def read(self):
|
| 76 |
+
if not self.disabled:
|
| 77 |
+
free, total = self.cuda_mem_get_info()
|
| 78 |
+
self.data["free"] = free
|
| 79 |
+
self.data["total"] = total
|
| 80 |
+
|
| 81 |
+
torch_stats = torch.cuda.memory_stats(self.device)
|
| 82 |
+
self.data["active"] = torch_stats["active.all.current"]
|
| 83 |
+
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
| 84 |
+
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
|
| 85 |
+
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
| 86 |
+
self.data["system_peak"] = total - self.data["min_free"]
|
| 87 |
+
|
| 88 |
+
return self.data
|
| 89 |
+
|
| 90 |
+
def stop(self):
|
| 91 |
+
self.run_flag.clear()
|
| 92 |
+
return self.read()
|
modules/modelloader.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
from urllib.parse import urlparse
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from modules import shared
|
| 12 |
+
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
import spandrel
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_file_from_url(
|
| 21 |
+
url: str,
|
| 22 |
+
*,
|
| 23 |
+
model_dir: str,
|
| 24 |
+
progress: bool = True,
|
| 25 |
+
file_name: str | None = None,
|
| 26 |
+
hash_prefix: str | None = None,
|
| 27 |
+
) -> str:
|
| 28 |
+
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
| 29 |
+
|
| 30 |
+
Returns the path to the downloaded file.
|
| 31 |
+
"""
|
| 32 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 33 |
+
if not file_name:
|
| 34 |
+
parts = urlparse(url)
|
| 35 |
+
file_name = os.path.basename(parts.path)
|
| 36 |
+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
| 37 |
+
if not os.path.exists(cached_file):
|
| 38 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
| 39 |
+
from torch.hub import download_url_to_file
|
| 40 |
+
download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
|
| 41 |
+
return cached_file
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
|
| 45 |
+
"""
|
| 46 |
+
A one-and done loader to try finding the desired models in specified directories.
|
| 47 |
+
|
| 48 |
+
@param download_name: Specify to download from model_url immediately.
|
| 49 |
+
@param model_url: If no other models are found, this will be downloaded on upscale.
|
| 50 |
+
@param model_path: The location to store/find models in.
|
| 51 |
+
@param command_path: A command-line argument to search for models in first.
|
| 52 |
+
@param ext_filter: An optional list of filename extensions to filter by
|
| 53 |
+
@param hash_prefix: the expected sha256 of the model_url
|
| 54 |
+
@return: A list of paths containing the desired model(s)
|
| 55 |
+
"""
|
| 56 |
+
output = []
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
places = []
|
| 60 |
+
|
| 61 |
+
if command_path is not None and command_path != model_path:
|
| 62 |
+
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
| 63 |
+
if os.path.exists(pretrained_path):
|
| 64 |
+
print(f"Appending path: {pretrained_path}")
|
| 65 |
+
places.append(pretrained_path)
|
| 66 |
+
elif os.path.exists(command_path):
|
| 67 |
+
places.append(command_path)
|
| 68 |
+
|
| 69 |
+
places.append(model_path)
|
| 70 |
+
|
| 71 |
+
for place in places:
|
| 72 |
+
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
| 73 |
+
if os.path.islink(full_path) and not os.path.exists(full_path):
|
| 74 |
+
print(f"Skipping broken symlink: {full_path}")
|
| 75 |
+
continue
|
| 76 |
+
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
| 77 |
+
continue
|
| 78 |
+
if full_path not in output:
|
| 79 |
+
output.append(full_path)
|
| 80 |
+
|
| 81 |
+
if model_url is not None and len(output) == 0:
|
| 82 |
+
if download_name is not None:
|
| 83 |
+
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix))
|
| 84 |
+
else:
|
| 85 |
+
output.append(model_url)
|
| 86 |
+
|
| 87 |
+
except Exception:
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
return output
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def friendly_name(file: str):
|
| 94 |
+
if file.startswith("http"):
|
| 95 |
+
file = urlparse(file).path
|
| 96 |
+
|
| 97 |
+
file = os.path.basename(file)
|
| 98 |
+
model_name, extension = os.path.splitext(file)
|
| 99 |
+
return model_name
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_upscalers():
|
| 103 |
+
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
| 104 |
+
# so we'll try to import any _model.py files before looking in __subclasses__
|
| 105 |
+
modules_dir = os.path.join(shared.script_path, "modules")
|
| 106 |
+
for file in os.listdir(modules_dir):
|
| 107 |
+
if "_model.py" in file:
|
| 108 |
+
model_name = file.replace("_model.py", "")
|
| 109 |
+
full_model = f"modules.{model_name}_model"
|
| 110 |
+
try:
|
| 111 |
+
importlib.import_module(full_model)
|
| 112 |
+
except Exception:
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
data = []
|
| 116 |
+
commandline_options = vars(shared.cmd_opts)
|
| 117 |
+
|
| 118 |
+
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
| 119 |
+
# up with two copies of those classes. The newest copy will always be the last in the list,
|
| 120 |
+
# so we go from end to beginning and ignore duplicates
|
| 121 |
+
used_classes = {}
|
| 122 |
+
for cls in reversed(Upscaler.__subclasses__()):
|
| 123 |
+
classname = str(cls)
|
| 124 |
+
if classname not in used_classes:
|
| 125 |
+
used_classes[classname] = cls
|
| 126 |
+
|
| 127 |
+
for cls in reversed(used_classes.values()):
|
| 128 |
+
name = cls.__name__
|
| 129 |
+
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
| 130 |
+
commandline_model_path = commandline_options.get(cmd_name, None)
|
| 131 |
+
scaler = cls(commandline_model_path)
|
| 132 |
+
scaler.user_path = commandline_model_path
|
| 133 |
+
scaler.model_download_path = commandline_model_path or scaler.model_path
|
| 134 |
+
data += scaler.scalers
|
| 135 |
+
|
| 136 |
+
shared.sd_upscalers = sorted(
|
| 137 |
+
data,
|
| 138 |
+
# Special case for UpscalerNone keeps it at the beginning of the list.
|
| 139 |
+
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# None: not loaded, False: failed to load, True: loaded
|
| 143 |
+
_spandrel_extra_init_state = None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _init_spandrel_extra_archs() -> None:
|
| 147 |
+
"""
|
| 148 |
+
Try to initialize `spandrel_extra_archs` (exactly once).
|
| 149 |
+
"""
|
| 150 |
+
global _spandrel_extra_init_state
|
| 151 |
+
if _spandrel_extra_init_state is not None:
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
import spandrel
|
| 156 |
+
import spandrel_extra_arches
|
| 157 |
+
spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
|
| 158 |
+
_spandrel_extra_init_state = True
|
| 159 |
+
except Exception:
|
| 160 |
+
logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
|
| 161 |
+
_spandrel_extra_init_state = False
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def load_spandrel_model(
|
| 165 |
+
path: str | os.PathLike,
|
| 166 |
+
*,
|
| 167 |
+
device: str | torch.device | None,
|
| 168 |
+
prefer_half: bool = False,
|
| 169 |
+
dtype: str | torch.dtype | None = None,
|
| 170 |
+
expected_architecture: str | None = None,
|
| 171 |
+
) -> spandrel.ModelDescriptor:
|
| 172 |
+
global _spandrel_extra_init_state
|
| 173 |
+
|
| 174 |
+
import spandrel
|
| 175 |
+
_init_spandrel_extra_archs()
|
| 176 |
+
|
| 177 |
+
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
| 178 |
+
arch = model_descriptor.architecture
|
| 179 |
+
if expected_architecture and arch.name != expected_architecture:
|
| 180 |
+
logger.warning(
|
| 181 |
+
f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
|
| 182 |
+
)
|
| 183 |
+
half = False
|
| 184 |
+
if prefer_half:
|
| 185 |
+
if model_descriptor.supports_half:
|
| 186 |
+
model_descriptor.model.half()
|
| 187 |
+
half = True
|
| 188 |
+
else:
|
| 189 |
+
logger.info("Model %s does not support half precision, ignoring --half", path)
|
| 190 |
+
if dtype:
|
| 191 |
+
model_descriptor.model.to(dtype=dtype)
|
| 192 |
+
model_descriptor.model.eval()
|
| 193 |
+
logger.debug(
|
| 194 |
+
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
| 195 |
+
arch, path, device, half, dtype,
|
| 196 |
+
)
|
| 197 |
+
return model_descriptor
|
modules/models/diffusion/ddpm_edit.py
ADDED
|
@@ -0,0 +1,1460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
wild mixture of
|
| 3 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 4 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| 5 |
+
https://github.com/CompVis/taming-transformers
|
| 6 |
+
-- merci
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
| 10 |
+
# See more details in LICENSE.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pytorch_lightning as pl
|
| 16 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 17 |
+
from einops import rearrange, repeat
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from functools import partial
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from torchvision.utils import make_grid
|
| 22 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
| 23 |
+
|
| 24 |
+
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
| 25 |
+
from ldm.modules.ema import LitEma
|
| 26 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
| 27 |
+
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
| 28 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
| 29 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from ldm.models.autoencoder import VQModelInterface
|
| 33 |
+
except Exception:
|
| 34 |
+
class VQModelInterface:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
__conditioning_keys__ = {'concat': 'c_concat',
|
| 38 |
+
'crossattn': 'c_crossattn',
|
| 39 |
+
'adm': 'y'}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def disabled_train(self, mode=True):
|
| 43 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
| 44 |
+
does not change anymore."""
|
| 45 |
+
return self
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def uniform_on_device(r1, r2, shape, device):
|
| 49 |
+
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DDPM(pl.LightningModule):
|
| 53 |
+
# classic DDPM with Gaussian diffusion, in image space
|
| 54 |
+
def __init__(self,
|
| 55 |
+
unet_config,
|
| 56 |
+
timesteps=1000,
|
| 57 |
+
beta_schedule="linear",
|
| 58 |
+
loss_type="l2",
|
| 59 |
+
ckpt_path=None,
|
| 60 |
+
ignore_keys=None,
|
| 61 |
+
load_only_unet=False,
|
| 62 |
+
monitor="val/loss",
|
| 63 |
+
use_ema=True,
|
| 64 |
+
first_stage_key="image",
|
| 65 |
+
image_size=256,
|
| 66 |
+
channels=3,
|
| 67 |
+
log_every_t=100,
|
| 68 |
+
clip_denoised=True,
|
| 69 |
+
linear_start=1e-4,
|
| 70 |
+
linear_end=2e-2,
|
| 71 |
+
cosine_s=8e-3,
|
| 72 |
+
given_betas=None,
|
| 73 |
+
original_elbo_weight=0.,
|
| 74 |
+
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
| 75 |
+
l_simple_weight=1.,
|
| 76 |
+
conditioning_key=None,
|
| 77 |
+
parameterization="eps", # all assuming fixed variance schedules
|
| 78 |
+
scheduler_config=None,
|
| 79 |
+
use_positional_encodings=False,
|
| 80 |
+
learn_logvar=False,
|
| 81 |
+
logvar_init=0.,
|
| 82 |
+
load_ema=True,
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
|
| 86 |
+
self.parameterization = parameterization
|
| 87 |
+
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
| 88 |
+
self.cond_stage_model = None
|
| 89 |
+
self.clip_denoised = clip_denoised
|
| 90 |
+
self.log_every_t = log_every_t
|
| 91 |
+
self.first_stage_key = first_stage_key
|
| 92 |
+
self.image_size = image_size # try conv?
|
| 93 |
+
self.channels = channels
|
| 94 |
+
self.use_positional_encodings = use_positional_encodings
|
| 95 |
+
self.model = DiffusionWrapper(unet_config, conditioning_key)
|
| 96 |
+
count_params(self.model, verbose=True)
|
| 97 |
+
self.use_ema = use_ema
|
| 98 |
+
|
| 99 |
+
self.use_scheduler = scheduler_config is not None
|
| 100 |
+
if self.use_scheduler:
|
| 101 |
+
self.scheduler_config = scheduler_config
|
| 102 |
+
|
| 103 |
+
self.v_posterior = v_posterior
|
| 104 |
+
self.original_elbo_weight = original_elbo_weight
|
| 105 |
+
self.l_simple_weight = l_simple_weight
|
| 106 |
+
|
| 107 |
+
if monitor is not None:
|
| 108 |
+
self.monitor = monitor
|
| 109 |
+
|
| 110 |
+
if self.use_ema and load_ema:
|
| 111 |
+
self.model_ema = LitEma(self.model)
|
| 112 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
| 113 |
+
|
| 114 |
+
if ckpt_path is not None:
|
| 115 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
| 116 |
+
|
| 117 |
+
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
| 118 |
+
if self.use_ema and not load_ema:
|
| 119 |
+
self.model_ema = LitEma(self.model)
|
| 120 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
| 121 |
+
|
| 122 |
+
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
| 123 |
+
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
| 124 |
+
|
| 125 |
+
self.loss_type = loss_type
|
| 126 |
+
|
| 127 |
+
self.learn_logvar = learn_logvar
|
| 128 |
+
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
|
| 129 |
+
if self.learn_logvar:
|
| 130 |
+
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
| 134 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
| 135 |
+
if exists(given_betas):
|
| 136 |
+
betas = given_betas
|
| 137 |
+
else:
|
| 138 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
| 139 |
+
cosine_s=cosine_s)
|
| 140 |
+
alphas = 1. - betas
|
| 141 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 142 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
| 143 |
+
|
| 144 |
+
timesteps, = betas.shape
|
| 145 |
+
self.num_timesteps = int(timesteps)
|
| 146 |
+
self.linear_start = linear_start
|
| 147 |
+
self.linear_end = linear_end
|
| 148 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
| 149 |
+
|
| 150 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
| 151 |
+
|
| 152 |
+
self.register_buffer('betas', to_torch(betas))
|
| 153 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 154 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
| 155 |
+
|
| 156 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 157 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
| 158 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
| 159 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
| 160 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
| 161 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
| 162 |
+
|
| 163 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 164 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
| 165 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
| 166 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 167 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
| 168 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 169 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
| 170 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
| 171 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
| 172 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
| 173 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
| 174 |
+
|
| 175 |
+
if self.parameterization == "eps":
|
| 176 |
+
lvlb_weights = self.betas ** 2 / (
|
| 177 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
| 178 |
+
elif self.parameterization == "x0":
|
| 179 |
+
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
| 180 |
+
else:
|
| 181 |
+
raise NotImplementedError("mu not supported")
|
| 182 |
+
# TODO how to choose this term
|
| 183 |
+
lvlb_weights[0] = lvlb_weights[1]
|
| 184 |
+
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
| 185 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
| 186 |
+
|
| 187 |
+
@contextmanager
|
| 188 |
+
def ema_scope(self, context=None):
|
| 189 |
+
if self.use_ema:
|
| 190 |
+
self.model_ema.store(self.model.parameters())
|
| 191 |
+
self.model_ema.copy_to(self.model)
|
| 192 |
+
if context is not None:
|
| 193 |
+
print(f"{context}: Switched to EMA weights")
|
| 194 |
+
try:
|
| 195 |
+
yield None
|
| 196 |
+
finally:
|
| 197 |
+
if self.use_ema:
|
| 198 |
+
self.model_ema.restore(self.model.parameters())
|
| 199 |
+
if context is not None:
|
| 200 |
+
print(f"{context}: Restored training weights")
|
| 201 |
+
|
| 202 |
+
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
| 203 |
+
ignore_keys = ignore_keys or []
|
| 204 |
+
|
| 205 |
+
sd = torch.load(path, map_location="cpu")
|
| 206 |
+
if "state_dict" in list(sd.keys()):
|
| 207 |
+
sd = sd["state_dict"]
|
| 208 |
+
keys = list(sd.keys())
|
| 209 |
+
|
| 210 |
+
# Our model adds additional channels to the first layer to condition on an input image.
|
| 211 |
+
# For the first layer, copy existing channel weights and initialize new channel weights to zero.
|
| 212 |
+
input_keys = [
|
| 213 |
+
"model.diffusion_model.input_blocks.0.0.weight",
|
| 214 |
+
"model_ema.diffusion_modelinput_blocks00weight",
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
self_sd = self.state_dict()
|
| 218 |
+
for input_key in input_keys:
|
| 219 |
+
if input_key not in sd or input_key not in self_sd:
|
| 220 |
+
continue
|
| 221 |
+
|
| 222 |
+
input_weight = self_sd[input_key]
|
| 223 |
+
|
| 224 |
+
if input_weight.size() != sd[input_key].size():
|
| 225 |
+
print(f"Manual init: {input_key}")
|
| 226 |
+
input_weight.zero_()
|
| 227 |
+
input_weight[:, :4, :, :].copy_(sd[input_key])
|
| 228 |
+
ignore_keys.append(input_key)
|
| 229 |
+
|
| 230 |
+
for k in keys:
|
| 231 |
+
for ik in ignore_keys:
|
| 232 |
+
if k.startswith(ik):
|
| 233 |
+
print(f"Deleting key {k} from state_dict.")
|
| 234 |
+
del sd[k]
|
| 235 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
| 236 |
+
sd, strict=False)
|
| 237 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
| 238 |
+
if missing:
|
| 239 |
+
print(f"Missing Keys: {missing}")
|
| 240 |
+
if unexpected:
|
| 241 |
+
print(f"Unexpected Keys: {unexpected}")
|
| 242 |
+
|
| 243 |
+
def q_mean_variance(self, x_start, t):
|
| 244 |
+
"""
|
| 245 |
+
Get the distribution q(x_t | x_0).
|
| 246 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
| 247 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 248 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
| 249 |
+
"""
|
| 250 |
+
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
|
| 251 |
+
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 252 |
+
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 253 |
+
return mean, variance, log_variance
|
| 254 |
+
|
| 255 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 256 |
+
return (
|
| 257 |
+
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
| 258 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def q_posterior(self, x_start, x_t, t):
|
| 262 |
+
posterior_mean = (
|
| 263 |
+
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 264 |
+
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 265 |
+
)
|
| 266 |
+
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
| 267 |
+
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 268 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 269 |
+
|
| 270 |
+
def p_mean_variance(self, x, t, clip_denoised: bool):
|
| 271 |
+
model_out = self.model(x, t)
|
| 272 |
+
if self.parameterization == "eps":
|
| 273 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
| 274 |
+
elif self.parameterization == "x0":
|
| 275 |
+
x_recon = model_out
|
| 276 |
+
if clip_denoised:
|
| 277 |
+
x_recon.clamp_(-1., 1.)
|
| 278 |
+
|
| 279 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
| 280 |
+
return model_mean, posterior_variance, posterior_log_variance
|
| 281 |
+
|
| 282 |
+
@torch.no_grad()
|
| 283 |
+
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
| 284 |
+
b, *_, device = *x.shape, x.device
|
| 285 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
| 286 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
| 287 |
+
# no noise when t == 0
|
| 288 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
| 289 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 290 |
+
|
| 291 |
+
@torch.no_grad()
|
| 292 |
+
def p_sample_loop(self, shape, return_intermediates=False):
|
| 293 |
+
device = self.betas.device
|
| 294 |
+
b = shape[0]
|
| 295 |
+
img = torch.randn(shape, device=device)
|
| 296 |
+
intermediates = [img]
|
| 297 |
+
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
| 298 |
+
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
| 299 |
+
clip_denoised=self.clip_denoised)
|
| 300 |
+
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
| 301 |
+
intermediates.append(img)
|
| 302 |
+
if return_intermediates:
|
| 303 |
+
return img, intermediates
|
| 304 |
+
return img
|
| 305 |
+
|
| 306 |
+
@torch.no_grad()
|
| 307 |
+
def sample(self, batch_size=16, return_intermediates=False):
|
| 308 |
+
image_size = self.image_size
|
| 309 |
+
channels = self.channels
|
| 310 |
+
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
| 311 |
+
return_intermediates=return_intermediates)
|
| 312 |
+
|
| 313 |
+
def q_sample(self, x_start, t, noise=None):
|
| 314 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 315 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 316 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
| 317 |
+
|
| 318 |
+
def get_loss(self, pred, target, mean=True):
|
| 319 |
+
if self.loss_type == 'l1':
|
| 320 |
+
loss = (target - pred).abs()
|
| 321 |
+
if mean:
|
| 322 |
+
loss = loss.mean()
|
| 323 |
+
elif self.loss_type == 'l2':
|
| 324 |
+
if mean:
|
| 325 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
| 326 |
+
else:
|
| 327 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
| 328 |
+
else:
|
| 329 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
| 330 |
+
|
| 331 |
+
return loss
|
| 332 |
+
|
| 333 |
+
def p_losses(self, x_start, t, noise=None):
|
| 334 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 335 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 336 |
+
model_out = self.model(x_noisy, t)
|
| 337 |
+
|
| 338 |
+
loss_dict = {}
|
| 339 |
+
if self.parameterization == "eps":
|
| 340 |
+
target = noise
|
| 341 |
+
elif self.parameterization == "x0":
|
| 342 |
+
target = x_start
|
| 343 |
+
else:
|
| 344 |
+
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
| 345 |
+
|
| 346 |
+
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
| 347 |
+
|
| 348 |
+
log_prefix = 'train' if self.training else 'val'
|
| 349 |
+
|
| 350 |
+
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
|
| 351 |
+
loss_simple = loss.mean() * self.l_simple_weight
|
| 352 |
+
|
| 353 |
+
loss_vlb = (self.lvlb_weights[t] * loss).mean()
|
| 354 |
+
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
|
| 355 |
+
|
| 356 |
+
loss = loss_simple + self.original_elbo_weight * loss_vlb
|
| 357 |
+
|
| 358 |
+
loss_dict.update({f'{log_prefix}/loss': loss})
|
| 359 |
+
|
| 360 |
+
return loss, loss_dict
|
| 361 |
+
|
| 362 |
+
def forward(self, x, *args, **kwargs):
|
| 363 |
+
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
|
| 364 |
+
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
| 365 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 366 |
+
return self.p_losses(x, t, *args, **kwargs)
|
| 367 |
+
|
| 368 |
+
def get_input(self, batch, k):
|
| 369 |
+
return batch[k]
|
| 370 |
+
|
| 371 |
+
def shared_step(self, batch):
|
| 372 |
+
x = self.get_input(batch, self.first_stage_key)
|
| 373 |
+
loss, loss_dict = self(x)
|
| 374 |
+
return loss, loss_dict
|
| 375 |
+
|
| 376 |
+
def training_step(self, batch, batch_idx):
|
| 377 |
+
loss, loss_dict = self.shared_step(batch)
|
| 378 |
+
|
| 379 |
+
self.log_dict(loss_dict, prog_bar=True,
|
| 380 |
+
logger=True, on_step=True, on_epoch=True)
|
| 381 |
+
|
| 382 |
+
self.log("global_step", self.global_step,
|
| 383 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
| 384 |
+
|
| 385 |
+
if self.use_scheduler:
|
| 386 |
+
lr = self.optimizers().param_groups[0]['lr']
|
| 387 |
+
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
| 388 |
+
|
| 389 |
+
return loss
|
| 390 |
+
|
| 391 |
+
@torch.no_grad()
|
| 392 |
+
def validation_step(self, batch, batch_idx):
|
| 393 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
| 394 |
+
with self.ema_scope():
|
| 395 |
+
_, loss_dict_ema = self.shared_step(batch)
|
| 396 |
+
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
| 397 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
| 398 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
| 399 |
+
|
| 400 |
+
def on_train_batch_end(self, *args, **kwargs):
|
| 401 |
+
if self.use_ema:
|
| 402 |
+
self.model_ema(self.model)
|
| 403 |
+
|
| 404 |
+
def _get_rows_from_list(self, samples):
|
| 405 |
+
n_imgs_per_row = len(samples)
|
| 406 |
+
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
|
| 407 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
| 408 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
| 409 |
+
return denoise_grid
|
| 410 |
+
|
| 411 |
+
@torch.no_grad()
|
| 412 |
+
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
| 413 |
+
log = {}
|
| 414 |
+
x = self.get_input(batch, self.first_stage_key)
|
| 415 |
+
N = min(x.shape[0], N)
|
| 416 |
+
n_row = min(x.shape[0], n_row)
|
| 417 |
+
x = x.to(self.device)[:N]
|
| 418 |
+
log["inputs"] = x
|
| 419 |
+
|
| 420 |
+
# get diffusion row
|
| 421 |
+
diffusion_row = []
|
| 422 |
+
x_start = x[:n_row]
|
| 423 |
+
|
| 424 |
+
for t in range(self.num_timesteps):
|
| 425 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
| 426 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
| 427 |
+
t = t.to(self.device).long()
|
| 428 |
+
noise = torch.randn_like(x_start)
|
| 429 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 430 |
+
diffusion_row.append(x_noisy)
|
| 431 |
+
|
| 432 |
+
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
|
| 433 |
+
|
| 434 |
+
if sample:
|
| 435 |
+
# get denoise row
|
| 436 |
+
with self.ema_scope("Plotting"):
|
| 437 |
+
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
|
| 438 |
+
|
| 439 |
+
log["samples"] = samples
|
| 440 |
+
log["denoise_row"] = self._get_rows_from_list(denoise_row)
|
| 441 |
+
|
| 442 |
+
if return_keys:
|
| 443 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
| 444 |
+
return log
|
| 445 |
+
else:
|
| 446 |
+
return {key: log[key] for key in return_keys}
|
| 447 |
+
return log
|
| 448 |
+
|
| 449 |
+
def configure_optimizers(self):
|
| 450 |
+
lr = self.learning_rate
|
| 451 |
+
params = list(self.model.parameters())
|
| 452 |
+
if self.learn_logvar:
|
| 453 |
+
params = params + [self.logvar]
|
| 454 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 455 |
+
return opt
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class LatentDiffusion(DDPM):
|
| 459 |
+
"""main class"""
|
| 460 |
+
def __init__(self,
|
| 461 |
+
first_stage_config,
|
| 462 |
+
cond_stage_config,
|
| 463 |
+
num_timesteps_cond=None,
|
| 464 |
+
cond_stage_key="image",
|
| 465 |
+
cond_stage_trainable=False,
|
| 466 |
+
concat_mode=True,
|
| 467 |
+
cond_stage_forward=None,
|
| 468 |
+
conditioning_key=None,
|
| 469 |
+
scale_factor=1.0,
|
| 470 |
+
scale_by_std=False,
|
| 471 |
+
load_ema=True,
|
| 472 |
+
*args, **kwargs):
|
| 473 |
+
self.num_timesteps_cond = default(num_timesteps_cond, 1)
|
| 474 |
+
self.scale_by_std = scale_by_std
|
| 475 |
+
assert self.num_timesteps_cond <= kwargs['timesteps']
|
| 476 |
+
# for backwards compatibility after implementation of DiffusionWrapper
|
| 477 |
+
if conditioning_key is None:
|
| 478 |
+
conditioning_key = 'concat' if concat_mode else 'crossattn'
|
| 479 |
+
if cond_stage_config == '__is_unconditional__':
|
| 480 |
+
conditioning_key = None
|
| 481 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
| 482 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
| 483 |
+
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
| 484 |
+
self.concat_mode = concat_mode
|
| 485 |
+
self.cond_stage_trainable = cond_stage_trainable
|
| 486 |
+
self.cond_stage_key = cond_stage_key
|
| 487 |
+
try:
|
| 488 |
+
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
| 489 |
+
except Exception:
|
| 490 |
+
self.num_downs = 0
|
| 491 |
+
if not scale_by_std:
|
| 492 |
+
self.scale_factor = scale_factor
|
| 493 |
+
else:
|
| 494 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
| 495 |
+
self.instantiate_first_stage(first_stage_config)
|
| 496 |
+
self.instantiate_cond_stage(cond_stage_config)
|
| 497 |
+
self.cond_stage_forward = cond_stage_forward
|
| 498 |
+
self.clip_denoised = False
|
| 499 |
+
self.bbox_tokenizer = None
|
| 500 |
+
|
| 501 |
+
self.restarted_from_ckpt = False
|
| 502 |
+
if ckpt_path is not None:
|
| 503 |
+
self.init_from_ckpt(ckpt_path, ignore_keys)
|
| 504 |
+
self.restarted_from_ckpt = True
|
| 505 |
+
|
| 506 |
+
if self.use_ema and not load_ema:
|
| 507 |
+
self.model_ema = LitEma(self.model)
|
| 508 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
| 509 |
+
|
| 510 |
+
def make_cond_schedule(self, ):
|
| 511 |
+
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
| 512 |
+
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
| 513 |
+
self.cond_ids[:self.num_timesteps_cond] = ids
|
| 514 |
+
|
| 515 |
+
@rank_zero_only
|
| 516 |
+
@torch.no_grad()
|
| 517 |
+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
| 518 |
+
# only for very first batch
|
| 519 |
+
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
|
| 520 |
+
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
|
| 521 |
+
# set rescale weight to 1./std of encodings
|
| 522 |
+
print("### USING STD-RESCALING ###")
|
| 523 |
+
x = super().get_input(batch, self.first_stage_key)
|
| 524 |
+
x = x.to(self.device)
|
| 525 |
+
encoder_posterior = self.encode_first_stage(x)
|
| 526 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
| 527 |
+
del self.scale_factor
|
| 528 |
+
self.register_buffer('scale_factor', 1. / z.flatten().std())
|
| 529 |
+
print(f"setting self.scale_factor to {self.scale_factor}")
|
| 530 |
+
print("### USING STD-RESCALING ###")
|
| 531 |
+
|
| 532 |
+
def register_schedule(self,
|
| 533 |
+
given_betas=None, beta_schedule="linear", timesteps=1000,
|
| 534 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
| 535 |
+
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
| 536 |
+
|
| 537 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
| 538 |
+
if self.shorten_cond_schedule:
|
| 539 |
+
self.make_cond_schedule()
|
| 540 |
+
|
| 541 |
+
def instantiate_first_stage(self, config):
|
| 542 |
+
model = instantiate_from_config(config)
|
| 543 |
+
self.first_stage_model = model.eval()
|
| 544 |
+
self.first_stage_model.train = disabled_train
|
| 545 |
+
for param in self.first_stage_model.parameters():
|
| 546 |
+
param.requires_grad = False
|
| 547 |
+
|
| 548 |
+
def instantiate_cond_stage(self, config):
|
| 549 |
+
if not self.cond_stage_trainable:
|
| 550 |
+
if config == "__is_first_stage__":
|
| 551 |
+
print("Using first stage also as cond stage.")
|
| 552 |
+
self.cond_stage_model = self.first_stage_model
|
| 553 |
+
elif config == "__is_unconditional__":
|
| 554 |
+
print(f"Training {self.__class__.__name__} as an unconditional model.")
|
| 555 |
+
self.cond_stage_model = None
|
| 556 |
+
# self.be_unconditional = True
|
| 557 |
+
else:
|
| 558 |
+
model = instantiate_from_config(config)
|
| 559 |
+
self.cond_stage_model = model.eval()
|
| 560 |
+
self.cond_stage_model.train = disabled_train
|
| 561 |
+
for param in self.cond_stage_model.parameters():
|
| 562 |
+
param.requires_grad = False
|
| 563 |
+
else:
|
| 564 |
+
assert config != '__is_first_stage__'
|
| 565 |
+
assert config != '__is_unconditional__'
|
| 566 |
+
model = instantiate_from_config(config)
|
| 567 |
+
self.cond_stage_model = model
|
| 568 |
+
|
| 569 |
+
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
|
| 570 |
+
denoise_row = []
|
| 571 |
+
for zd in tqdm(samples, desc=desc):
|
| 572 |
+
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
| 573 |
+
force_not_quantize=force_no_decoder_quantization))
|
| 574 |
+
n_imgs_per_row = len(denoise_row)
|
| 575 |
+
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
| 576 |
+
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
| 577 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
| 578 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
| 579 |
+
return denoise_grid
|
| 580 |
+
|
| 581 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
| 582 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
| 583 |
+
z = encoder_posterior.sample()
|
| 584 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
| 585 |
+
z = encoder_posterior
|
| 586 |
+
else:
|
| 587 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
| 588 |
+
return self.scale_factor * z
|
| 589 |
+
|
| 590 |
+
def get_learned_conditioning(self, c):
|
| 591 |
+
if self.cond_stage_forward is None:
|
| 592 |
+
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
| 593 |
+
c = self.cond_stage_model.encode(c)
|
| 594 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
| 595 |
+
c = c.mode()
|
| 596 |
+
else:
|
| 597 |
+
c = self.cond_stage_model(c)
|
| 598 |
+
else:
|
| 599 |
+
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
| 600 |
+
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
| 601 |
+
return c
|
| 602 |
+
|
| 603 |
+
def meshgrid(self, h, w):
|
| 604 |
+
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
|
| 605 |
+
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
|
| 606 |
+
|
| 607 |
+
arr = torch.cat([y, x], dim=-1)
|
| 608 |
+
return arr
|
| 609 |
+
|
| 610 |
+
def delta_border(self, h, w):
|
| 611 |
+
"""
|
| 612 |
+
:param h: height
|
| 613 |
+
:param w: width
|
| 614 |
+
:return: normalized distance to image border,
|
| 615 |
+
wtith min distance = 0 at border and max dist = 0.5 at image center
|
| 616 |
+
"""
|
| 617 |
+
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
|
| 618 |
+
arr = self.meshgrid(h, w) / lower_right_corner
|
| 619 |
+
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
|
| 620 |
+
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
|
| 621 |
+
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
|
| 622 |
+
return edge_dist
|
| 623 |
+
|
| 624 |
+
def get_weighting(self, h, w, Ly, Lx, device):
|
| 625 |
+
weighting = self.delta_border(h, w)
|
| 626 |
+
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
|
| 627 |
+
self.split_input_params["clip_max_weight"], )
|
| 628 |
+
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
|
| 629 |
+
|
| 630 |
+
if self.split_input_params["tie_braker"]:
|
| 631 |
+
L_weighting = self.delta_border(Ly, Lx)
|
| 632 |
+
L_weighting = torch.clip(L_weighting,
|
| 633 |
+
self.split_input_params["clip_min_tie_weight"],
|
| 634 |
+
self.split_input_params["clip_max_tie_weight"])
|
| 635 |
+
|
| 636 |
+
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
|
| 637 |
+
weighting = weighting * L_weighting
|
| 638 |
+
return weighting
|
| 639 |
+
|
| 640 |
+
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
|
| 641 |
+
"""
|
| 642 |
+
:param x: img of size (bs, c, h, w)
|
| 643 |
+
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
| 644 |
+
"""
|
| 645 |
+
bs, nc, h, w = x.shape
|
| 646 |
+
|
| 647 |
+
# number of crops in image
|
| 648 |
+
Ly = (h - kernel_size[0]) // stride[0] + 1
|
| 649 |
+
Lx = (w - kernel_size[1]) // stride[1] + 1
|
| 650 |
+
|
| 651 |
+
if uf == 1 and df == 1:
|
| 652 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
| 653 |
+
unfold = torch.nn.Unfold(**fold_params)
|
| 654 |
+
|
| 655 |
+
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
|
| 656 |
+
|
| 657 |
+
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
|
| 658 |
+
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
|
| 659 |
+
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
|
| 660 |
+
|
| 661 |
+
elif uf > 1 and df == 1:
|
| 662 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
| 663 |
+
unfold = torch.nn.Unfold(**fold_params)
|
| 664 |
+
|
| 665 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
|
| 666 |
+
dilation=1, padding=0,
|
| 667 |
+
stride=(stride[0] * uf, stride[1] * uf))
|
| 668 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
|
| 669 |
+
|
| 670 |
+
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
|
| 671 |
+
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
|
| 672 |
+
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
|
| 673 |
+
|
| 674 |
+
elif df > 1 and uf == 1:
|
| 675 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
| 676 |
+
unfold = torch.nn.Unfold(**fold_params)
|
| 677 |
+
|
| 678 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
|
| 679 |
+
dilation=1, padding=0,
|
| 680 |
+
stride=(stride[0] // df, stride[1] // df))
|
| 681 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
|
| 682 |
+
|
| 683 |
+
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
|
| 684 |
+
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
|
| 685 |
+
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
|
| 686 |
+
|
| 687 |
+
else:
|
| 688 |
+
raise NotImplementedError
|
| 689 |
+
|
| 690 |
+
return fold, unfold, normalization, weighting
|
| 691 |
+
|
| 692 |
+
@torch.no_grad()
|
| 693 |
+
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
| 694 |
+
cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
|
| 695 |
+
x = super().get_input(batch, k)
|
| 696 |
+
if bs is not None:
|
| 697 |
+
x = x[:bs]
|
| 698 |
+
x = x.to(self.device)
|
| 699 |
+
encoder_posterior = self.encode_first_stage(x)
|
| 700 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
| 701 |
+
cond_key = cond_key or self.cond_stage_key
|
| 702 |
+
xc = super().get_input(batch, cond_key)
|
| 703 |
+
if bs is not None:
|
| 704 |
+
xc["c_crossattn"] = xc["c_crossattn"][:bs]
|
| 705 |
+
xc["c_concat"] = xc["c_concat"][:bs]
|
| 706 |
+
cond = {}
|
| 707 |
+
|
| 708 |
+
# To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
|
| 709 |
+
random = torch.rand(x.size(0), device=x.device)
|
| 710 |
+
prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
|
| 711 |
+
input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
|
| 712 |
+
|
| 713 |
+
null_prompt = self.get_learned_conditioning([""])
|
| 714 |
+
cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
|
| 715 |
+
cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
|
| 716 |
+
|
| 717 |
+
out = [z, cond]
|
| 718 |
+
if return_first_stage_outputs:
|
| 719 |
+
xrec = self.decode_first_stage(z)
|
| 720 |
+
out.extend([x, xrec])
|
| 721 |
+
if return_original_cond:
|
| 722 |
+
out.append(xc)
|
| 723 |
+
return out
|
| 724 |
+
|
| 725 |
+
@torch.no_grad()
|
| 726 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
| 727 |
+
if predict_cids:
|
| 728 |
+
if z.dim() == 4:
|
| 729 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
| 730 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
| 731 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
| 732 |
+
|
| 733 |
+
z = 1. / self.scale_factor * z
|
| 734 |
+
|
| 735 |
+
if hasattr(self, "split_input_params"):
|
| 736 |
+
if self.split_input_params["patch_distributed_vq"]:
|
| 737 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 738 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 739 |
+
uf = self.split_input_params["vqf"]
|
| 740 |
+
bs, nc, h, w = z.shape
|
| 741 |
+
if ks[0] > h or ks[1] > w:
|
| 742 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
| 743 |
+
print("reducing Kernel")
|
| 744 |
+
|
| 745 |
+
if stride[0] > h or stride[1] > w:
|
| 746 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
| 747 |
+
print("reducing stride")
|
| 748 |
+
|
| 749 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
| 750 |
+
|
| 751 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
| 752 |
+
# 1. Reshape to img shape
|
| 753 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 754 |
+
|
| 755 |
+
# 2. apply model loop over last dim
|
| 756 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 757 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
| 758 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
| 759 |
+
for i in range(z.shape[-1])]
|
| 760 |
+
else:
|
| 761 |
+
|
| 762 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
| 763 |
+
for i in range(z.shape[-1])]
|
| 764 |
+
|
| 765 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
| 766 |
+
o = o * weighting
|
| 767 |
+
# Reverse 1. reshape to img shape
|
| 768 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 769 |
+
# stitch crops together
|
| 770 |
+
decoded = fold(o)
|
| 771 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
| 772 |
+
return decoded
|
| 773 |
+
else:
|
| 774 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 775 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 776 |
+
else:
|
| 777 |
+
return self.first_stage_model.decode(z)
|
| 778 |
+
|
| 779 |
+
else:
|
| 780 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 781 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 782 |
+
else:
|
| 783 |
+
return self.first_stage_model.decode(z)
|
| 784 |
+
|
| 785 |
+
# same as above but without decorator
|
| 786 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
| 787 |
+
if predict_cids:
|
| 788 |
+
if z.dim() == 4:
|
| 789 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
| 790 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
| 791 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
| 792 |
+
|
| 793 |
+
z = 1. / self.scale_factor * z
|
| 794 |
+
|
| 795 |
+
if hasattr(self, "split_input_params"):
|
| 796 |
+
if self.split_input_params["patch_distributed_vq"]:
|
| 797 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 798 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 799 |
+
uf = self.split_input_params["vqf"]
|
| 800 |
+
bs, nc, h, w = z.shape
|
| 801 |
+
if ks[0] > h or ks[1] > w:
|
| 802 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
| 803 |
+
print("reducing Kernel")
|
| 804 |
+
|
| 805 |
+
if stride[0] > h or stride[1] > w:
|
| 806 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
| 807 |
+
print("reducing stride")
|
| 808 |
+
|
| 809 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
| 810 |
+
|
| 811 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
| 812 |
+
# 1. Reshape to img shape
|
| 813 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 814 |
+
|
| 815 |
+
# 2. apply model loop over last dim
|
| 816 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 817 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
| 818 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
| 819 |
+
for i in range(z.shape[-1])]
|
| 820 |
+
else:
|
| 821 |
+
|
| 822 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
| 823 |
+
for i in range(z.shape[-1])]
|
| 824 |
+
|
| 825 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
| 826 |
+
o = o * weighting
|
| 827 |
+
# Reverse 1. reshape to img shape
|
| 828 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 829 |
+
# stitch crops together
|
| 830 |
+
decoded = fold(o)
|
| 831 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
| 832 |
+
return decoded
|
| 833 |
+
else:
|
| 834 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 835 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 836 |
+
else:
|
| 837 |
+
return self.first_stage_model.decode(z)
|
| 838 |
+
|
| 839 |
+
else:
|
| 840 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 841 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 842 |
+
else:
|
| 843 |
+
return self.first_stage_model.decode(z)
|
| 844 |
+
|
| 845 |
+
@torch.no_grad()
|
| 846 |
+
def encode_first_stage(self, x):
|
| 847 |
+
if hasattr(self, "split_input_params"):
|
| 848 |
+
if self.split_input_params["patch_distributed_vq"]:
|
| 849 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 850 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 851 |
+
df = self.split_input_params["vqf"]
|
| 852 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
| 853 |
+
bs, nc, h, w = x.shape
|
| 854 |
+
if ks[0] > h or ks[1] > w:
|
| 855 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
| 856 |
+
print("reducing Kernel")
|
| 857 |
+
|
| 858 |
+
if stride[0] > h or stride[1] > w:
|
| 859 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
| 860 |
+
print("reducing stride")
|
| 861 |
+
|
| 862 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
| 863 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
| 864 |
+
# Reshape to img shape
|
| 865 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 866 |
+
|
| 867 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
| 868 |
+
for i in range(z.shape[-1])]
|
| 869 |
+
|
| 870 |
+
o = torch.stack(output_list, axis=-1)
|
| 871 |
+
o = o * weighting
|
| 872 |
+
|
| 873 |
+
# Reverse reshape to img shape
|
| 874 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 875 |
+
# stitch crops together
|
| 876 |
+
decoded = fold(o)
|
| 877 |
+
decoded = decoded / normalization
|
| 878 |
+
return decoded
|
| 879 |
+
|
| 880 |
+
else:
|
| 881 |
+
return self.first_stage_model.encode(x)
|
| 882 |
+
else:
|
| 883 |
+
return self.first_stage_model.encode(x)
|
| 884 |
+
|
| 885 |
+
def shared_step(self, batch, **kwargs):
|
| 886 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
| 887 |
+
loss = self(x, c)
|
| 888 |
+
return loss
|
| 889 |
+
|
| 890 |
+
def forward(self, x, c, *args, **kwargs):
|
| 891 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 892 |
+
if self.model.conditioning_key is not None:
|
| 893 |
+
assert c is not None
|
| 894 |
+
if self.cond_stage_trainable:
|
| 895 |
+
c = self.get_learned_conditioning(c)
|
| 896 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
| 897 |
+
tc = self.cond_ids[t].to(self.device)
|
| 898 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
| 899 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
| 900 |
+
|
| 901 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
| 902 |
+
|
| 903 |
+
if isinstance(cond, dict):
|
| 904 |
+
# hybrid case, cond is expected to be a dict
|
| 905 |
+
pass
|
| 906 |
+
else:
|
| 907 |
+
if not isinstance(cond, list):
|
| 908 |
+
cond = [cond]
|
| 909 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
| 910 |
+
cond = {key: cond}
|
| 911 |
+
|
| 912 |
+
if hasattr(self, "split_input_params"):
|
| 913 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
| 914 |
+
assert not return_ids
|
| 915 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 916 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 917 |
+
|
| 918 |
+
h, w = x_noisy.shape[-2:]
|
| 919 |
+
|
| 920 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
| 921 |
+
|
| 922 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
| 923 |
+
# Reshape to img shape
|
| 924 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 925 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
| 926 |
+
|
| 927 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
| 928 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
| 929 |
+
c_key = next(iter(cond.keys())) # get key
|
| 930 |
+
c = next(iter(cond.values())) # get value
|
| 931 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
| 932 |
+
c = c[0] # get element
|
| 933 |
+
|
| 934 |
+
c = unfold(c)
|
| 935 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 936 |
+
|
| 937 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
| 938 |
+
|
| 939 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
| 940 |
+
assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
|
| 941 |
+
|
| 942 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
| 943 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
| 944 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
| 945 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
| 946 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
| 947 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
| 948 |
+
rescale_latent = 2 ** (num_downs)
|
| 949 |
+
|
| 950 |
+
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
| 951 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
| 952 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
| 953 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
| 954 |
+
for patch_nr in range(z.shape[-1])]
|
| 955 |
+
|
| 956 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
| 957 |
+
patch_limits = [(x_tl, y_tl,
|
| 958 |
+
rescale_latent * ks[0] / full_img_w,
|
| 959 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
| 960 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
| 961 |
+
|
| 962 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
| 963 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
| 964 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
| 965 |
+
print(patch_limits_tknzd[0].shape)
|
| 966 |
+
# cut tknzd crop position from conditioning
|
| 967 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
| 968 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
| 969 |
+
print(cut_cond.shape)
|
| 970 |
+
|
| 971 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
| 972 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
| 973 |
+
print(adapted_cond.shape)
|
| 974 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
| 975 |
+
print(adapted_cond.shape)
|
| 976 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
| 977 |
+
print(adapted_cond.shape)
|
| 978 |
+
|
| 979 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
| 980 |
+
|
| 981 |
+
else:
|
| 982 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
| 983 |
+
|
| 984 |
+
# apply model by loop over crops
|
| 985 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
| 986 |
+
assert not isinstance(output_list[0],
|
| 987 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
| 988 |
+
|
| 989 |
+
o = torch.stack(output_list, axis=-1)
|
| 990 |
+
o = o * weighting
|
| 991 |
+
# Reverse reshape to img shape
|
| 992 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 993 |
+
# stitch crops together
|
| 994 |
+
x_recon = fold(o) / normalization
|
| 995 |
+
|
| 996 |
+
else:
|
| 997 |
+
x_recon = self.model(x_noisy, t, **cond)
|
| 998 |
+
|
| 999 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
| 1000 |
+
return x_recon[0]
|
| 1001 |
+
else:
|
| 1002 |
+
return x_recon
|
| 1003 |
+
|
| 1004 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 1005 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
| 1006 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 1007 |
+
|
| 1008 |
+
def _prior_bpd(self, x_start):
|
| 1009 |
+
"""
|
| 1010 |
+
Get the prior KL term for the variational lower-bound, measured in
|
| 1011 |
+
bits-per-dim.
|
| 1012 |
+
This term can't be optimized, as it only depends on the encoder.
|
| 1013 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 1014 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
| 1015 |
+
"""
|
| 1016 |
+
batch_size = x_start.shape[0]
|
| 1017 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
| 1018 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 1019 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
| 1020 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
| 1021 |
+
|
| 1022 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
| 1023 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 1024 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 1025 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
| 1026 |
+
|
| 1027 |
+
loss_dict = {}
|
| 1028 |
+
prefix = 'train' if self.training else 'val'
|
| 1029 |
+
|
| 1030 |
+
if self.parameterization == "x0":
|
| 1031 |
+
target = x_start
|
| 1032 |
+
elif self.parameterization == "eps":
|
| 1033 |
+
target = noise
|
| 1034 |
+
else:
|
| 1035 |
+
raise NotImplementedError()
|
| 1036 |
+
|
| 1037 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
| 1038 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
| 1039 |
+
|
| 1040 |
+
logvar_t = self.logvar[t].to(self.device)
|
| 1041 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
| 1042 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
| 1043 |
+
if self.learn_logvar:
|
| 1044 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
| 1045 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
| 1046 |
+
|
| 1047 |
+
loss = self.l_simple_weight * loss.mean()
|
| 1048 |
+
|
| 1049 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
| 1050 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
| 1051 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
| 1052 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
| 1053 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
| 1054 |
+
|
| 1055 |
+
return loss, loss_dict
|
| 1056 |
+
|
| 1057 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
| 1058 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
| 1059 |
+
t_in = t
|
| 1060 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
| 1061 |
+
|
| 1062 |
+
if score_corrector is not None:
|
| 1063 |
+
assert self.parameterization == "eps"
|
| 1064 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
| 1065 |
+
|
| 1066 |
+
if return_codebook_ids:
|
| 1067 |
+
model_out, logits = model_out
|
| 1068 |
+
|
| 1069 |
+
if self.parameterization == "eps":
|
| 1070 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
| 1071 |
+
elif self.parameterization == "x0":
|
| 1072 |
+
x_recon = model_out
|
| 1073 |
+
else:
|
| 1074 |
+
raise NotImplementedError()
|
| 1075 |
+
|
| 1076 |
+
if clip_denoised:
|
| 1077 |
+
x_recon.clamp_(-1., 1.)
|
| 1078 |
+
if quantize_denoised:
|
| 1079 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
| 1080 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
| 1081 |
+
if return_codebook_ids:
|
| 1082 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
| 1083 |
+
elif return_x0:
|
| 1084 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
| 1085 |
+
else:
|
| 1086 |
+
return model_mean, posterior_variance, posterior_log_variance
|
| 1087 |
+
|
| 1088 |
+
@torch.no_grad()
|
| 1089 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
| 1090 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
| 1091 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
| 1092 |
+
b, *_, device = *x.shape, x.device
|
| 1093 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
| 1094 |
+
return_codebook_ids=return_codebook_ids,
|
| 1095 |
+
quantize_denoised=quantize_denoised,
|
| 1096 |
+
return_x0=return_x0,
|
| 1097 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
| 1098 |
+
if return_codebook_ids:
|
| 1099 |
+
raise DeprecationWarning("Support dropped.")
|
| 1100 |
+
model_mean, _, model_log_variance, logits = outputs
|
| 1101 |
+
elif return_x0:
|
| 1102 |
+
model_mean, _, model_log_variance, x0 = outputs
|
| 1103 |
+
else:
|
| 1104 |
+
model_mean, _, model_log_variance = outputs
|
| 1105 |
+
|
| 1106 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
| 1107 |
+
if noise_dropout > 0.:
|
| 1108 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 1109 |
+
# no noise when t == 0
|
| 1110 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
| 1111 |
+
|
| 1112 |
+
if return_codebook_ids:
|
| 1113 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
| 1114 |
+
if return_x0:
|
| 1115 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
| 1116 |
+
else:
|
| 1117 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 1118 |
+
|
| 1119 |
+
@torch.no_grad()
|
| 1120 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
| 1121 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
| 1122 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
| 1123 |
+
log_every_t=None):
|
| 1124 |
+
if not log_every_t:
|
| 1125 |
+
log_every_t = self.log_every_t
|
| 1126 |
+
timesteps = self.num_timesteps
|
| 1127 |
+
if batch_size is not None:
|
| 1128 |
+
b = batch_size if batch_size is not None else shape[0]
|
| 1129 |
+
shape = [batch_size] + list(shape)
|
| 1130 |
+
else:
|
| 1131 |
+
b = batch_size = shape[0]
|
| 1132 |
+
if x_T is None:
|
| 1133 |
+
img = torch.randn(shape, device=self.device)
|
| 1134 |
+
else:
|
| 1135 |
+
img = x_T
|
| 1136 |
+
intermediates = []
|
| 1137 |
+
if cond is not None:
|
| 1138 |
+
if isinstance(cond, dict):
|
| 1139 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
| 1140 |
+
[x[:batch_size] for x in cond[key]] for key in cond}
|
| 1141 |
+
else:
|
| 1142 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
| 1143 |
+
|
| 1144 |
+
if start_T is not None:
|
| 1145 |
+
timesteps = min(timesteps, start_T)
|
| 1146 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
| 1147 |
+
total=timesteps) if verbose else reversed(
|
| 1148 |
+
range(0, timesteps))
|
| 1149 |
+
if type(temperature) == float:
|
| 1150 |
+
temperature = [temperature] * timesteps
|
| 1151 |
+
|
| 1152 |
+
for i in iterator:
|
| 1153 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
| 1154 |
+
if self.shorten_cond_schedule:
|
| 1155 |
+
assert self.model.conditioning_key != 'hybrid'
|
| 1156 |
+
tc = self.cond_ids[ts].to(cond.device)
|
| 1157 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
| 1158 |
+
|
| 1159 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
| 1160 |
+
clip_denoised=self.clip_denoised,
|
| 1161 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
| 1162 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
| 1163 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
| 1164 |
+
if mask is not None:
|
| 1165 |
+
assert x0 is not None
|
| 1166 |
+
img_orig = self.q_sample(x0, ts)
|
| 1167 |
+
img = img_orig * mask + (1. - mask) * img
|
| 1168 |
+
|
| 1169 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
| 1170 |
+
intermediates.append(x0_partial)
|
| 1171 |
+
if callback:
|
| 1172 |
+
callback(i)
|
| 1173 |
+
if img_callback:
|
| 1174 |
+
img_callback(img, i)
|
| 1175 |
+
return img, intermediates
|
| 1176 |
+
|
| 1177 |
+
@torch.no_grad()
|
| 1178 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
| 1179 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
| 1180 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
| 1181 |
+
log_every_t=None):
|
| 1182 |
+
|
| 1183 |
+
if not log_every_t:
|
| 1184 |
+
log_every_t = self.log_every_t
|
| 1185 |
+
device = self.betas.device
|
| 1186 |
+
b = shape[0]
|
| 1187 |
+
if x_T is None:
|
| 1188 |
+
img = torch.randn(shape, device=device)
|
| 1189 |
+
else:
|
| 1190 |
+
img = x_T
|
| 1191 |
+
|
| 1192 |
+
intermediates = [img]
|
| 1193 |
+
if timesteps is None:
|
| 1194 |
+
timesteps = self.num_timesteps
|
| 1195 |
+
|
| 1196 |
+
if start_T is not None:
|
| 1197 |
+
timesteps = min(timesteps, start_T)
|
| 1198 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
| 1199 |
+
range(0, timesteps))
|
| 1200 |
+
|
| 1201 |
+
if mask is not None:
|
| 1202 |
+
assert x0 is not None
|
| 1203 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
| 1204 |
+
|
| 1205 |
+
for i in iterator:
|
| 1206 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
| 1207 |
+
if self.shorten_cond_schedule:
|
| 1208 |
+
assert self.model.conditioning_key != 'hybrid'
|
| 1209 |
+
tc = self.cond_ids[ts].to(cond.device)
|
| 1210 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
| 1211 |
+
|
| 1212 |
+
img = self.p_sample(img, cond, ts,
|
| 1213 |
+
clip_denoised=self.clip_denoised,
|
| 1214 |
+
quantize_denoised=quantize_denoised)
|
| 1215 |
+
if mask is not None:
|
| 1216 |
+
img_orig = self.q_sample(x0, ts)
|
| 1217 |
+
img = img_orig * mask + (1. - mask) * img
|
| 1218 |
+
|
| 1219 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
| 1220 |
+
intermediates.append(img)
|
| 1221 |
+
if callback:
|
| 1222 |
+
callback(i)
|
| 1223 |
+
if img_callback:
|
| 1224 |
+
img_callback(img, i)
|
| 1225 |
+
|
| 1226 |
+
if return_intermediates:
|
| 1227 |
+
return img, intermediates
|
| 1228 |
+
return img
|
| 1229 |
+
|
| 1230 |
+
@torch.no_grad()
|
| 1231 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
| 1232 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
| 1233 |
+
mask=None, x0=None, shape=None,**kwargs):
|
| 1234 |
+
if shape is None:
|
| 1235 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
| 1236 |
+
if cond is not None:
|
| 1237 |
+
if isinstance(cond, dict):
|
| 1238 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
| 1239 |
+
[x[:batch_size] for x in cond[key]] for key in cond}
|
| 1240 |
+
else:
|
| 1241 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
| 1242 |
+
return self.p_sample_loop(cond,
|
| 1243 |
+
shape,
|
| 1244 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
| 1245 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
| 1246 |
+
mask=mask, x0=x0)
|
| 1247 |
+
|
| 1248 |
+
@torch.no_grad()
|
| 1249 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
| 1250 |
+
|
| 1251 |
+
if ddim:
|
| 1252 |
+
ddim_sampler = DDIMSampler(self)
|
| 1253 |
+
shape = (self.channels, self.image_size, self.image_size)
|
| 1254 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
| 1255 |
+
shape,cond,verbose=False,**kwargs)
|
| 1256 |
+
|
| 1257 |
+
else:
|
| 1258 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
| 1259 |
+
return_intermediates=True,**kwargs)
|
| 1260 |
+
|
| 1261 |
+
return samples, intermediates
|
| 1262 |
+
|
| 1263 |
+
|
| 1264 |
+
@torch.no_grad()
|
| 1265 |
+
def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
| 1266 |
+
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
|
| 1267 |
+
plot_diffusion_rows=False, **kwargs):
|
| 1268 |
+
|
| 1269 |
+
use_ddim = False
|
| 1270 |
+
|
| 1271 |
+
log = {}
|
| 1272 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
| 1273 |
+
return_first_stage_outputs=True,
|
| 1274 |
+
force_c_encode=True,
|
| 1275 |
+
return_original_cond=True,
|
| 1276 |
+
bs=N, uncond=0)
|
| 1277 |
+
N = min(x.shape[0], N)
|
| 1278 |
+
n_row = min(x.shape[0], n_row)
|
| 1279 |
+
log["inputs"] = x
|
| 1280 |
+
log["reals"] = xc["c_concat"]
|
| 1281 |
+
log["reconstruction"] = xrec
|
| 1282 |
+
if self.model.conditioning_key is not None:
|
| 1283 |
+
if hasattr(self.cond_stage_model, "decode"):
|
| 1284 |
+
xc = self.cond_stage_model.decode(c)
|
| 1285 |
+
log["conditioning"] = xc
|
| 1286 |
+
elif self.cond_stage_key in ["caption"]:
|
| 1287 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
| 1288 |
+
log["conditioning"] = xc
|
| 1289 |
+
elif self.cond_stage_key == 'class_label':
|
| 1290 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
| 1291 |
+
log['conditioning'] = xc
|
| 1292 |
+
elif isimage(xc):
|
| 1293 |
+
log["conditioning"] = xc
|
| 1294 |
+
if ismap(xc):
|
| 1295 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
| 1296 |
+
|
| 1297 |
+
if plot_diffusion_rows:
|
| 1298 |
+
# get diffusion row
|
| 1299 |
+
diffusion_row = []
|
| 1300 |
+
z_start = z[:n_row]
|
| 1301 |
+
for t in range(self.num_timesteps):
|
| 1302 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
| 1303 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
| 1304 |
+
t = t.to(self.device).long()
|
| 1305 |
+
noise = torch.randn_like(z_start)
|
| 1306 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
| 1307 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
| 1308 |
+
|
| 1309 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
| 1310 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
| 1311 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
| 1312 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
| 1313 |
+
log["diffusion_row"] = diffusion_grid
|
| 1314 |
+
|
| 1315 |
+
if sample:
|
| 1316 |
+
# get denoise row
|
| 1317 |
+
with self.ema_scope("Plotting"):
|
| 1318 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
| 1319 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
| 1320 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
| 1321 |
+
x_samples = self.decode_first_stage(samples)
|
| 1322 |
+
log["samples"] = x_samples
|
| 1323 |
+
if plot_denoise_rows:
|
| 1324 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
| 1325 |
+
log["denoise_row"] = denoise_grid
|
| 1326 |
+
|
| 1327 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
| 1328 |
+
self.first_stage_model, IdentityFirstStage):
|
| 1329 |
+
# also display when quantizing x0 while sampling
|
| 1330 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
| 1331 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
| 1332 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
| 1333 |
+
quantize_denoised=True)
|
| 1334 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
| 1335 |
+
# quantize_denoised=True)
|
| 1336 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1337 |
+
log["samples_x0_quantized"] = x_samples
|
| 1338 |
+
|
| 1339 |
+
if inpaint:
|
| 1340 |
+
# make a simple center square
|
| 1341 |
+
h, w = z.shape[2], z.shape[3]
|
| 1342 |
+
mask = torch.ones(N, h, w).to(self.device)
|
| 1343 |
+
# zeros will be filled in
|
| 1344 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
| 1345 |
+
mask = mask[:, None, ...]
|
| 1346 |
+
with self.ema_scope("Plotting Inpaint"):
|
| 1347 |
+
|
| 1348 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
| 1349 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
| 1350 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1351 |
+
log["samples_inpainting"] = x_samples
|
| 1352 |
+
log["mask"] = mask
|
| 1353 |
+
|
| 1354 |
+
# outpaint
|
| 1355 |
+
with self.ema_scope("Plotting Outpaint"):
|
| 1356 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
| 1357 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
| 1358 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1359 |
+
log["samples_outpainting"] = x_samples
|
| 1360 |
+
|
| 1361 |
+
if plot_progressive_rows:
|
| 1362 |
+
with self.ema_scope("Plotting Progressives"):
|
| 1363 |
+
img, progressives = self.progressive_denoising(c,
|
| 1364 |
+
shape=(self.channels, self.image_size, self.image_size),
|
| 1365 |
+
batch_size=N)
|
| 1366 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
| 1367 |
+
log["progressive_row"] = prog_row
|
| 1368 |
+
|
| 1369 |
+
if return_keys:
|
| 1370 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
| 1371 |
+
return log
|
| 1372 |
+
else:
|
| 1373 |
+
return {key: log[key] for key in return_keys}
|
| 1374 |
+
return log
|
| 1375 |
+
|
| 1376 |
+
def configure_optimizers(self):
|
| 1377 |
+
lr = self.learning_rate
|
| 1378 |
+
params = list(self.model.parameters())
|
| 1379 |
+
if self.cond_stage_trainable:
|
| 1380 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
| 1381 |
+
params = params + list(self.cond_stage_model.parameters())
|
| 1382 |
+
if self.learn_logvar:
|
| 1383 |
+
print('Diffusion model optimizing logvar')
|
| 1384 |
+
params.append(self.logvar)
|
| 1385 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 1386 |
+
if self.use_scheduler:
|
| 1387 |
+
assert 'target' in self.scheduler_config
|
| 1388 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
| 1389 |
+
|
| 1390 |
+
print("Setting up LambdaLR scheduler...")
|
| 1391 |
+
scheduler = [
|
| 1392 |
+
{
|
| 1393 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
| 1394 |
+
'interval': 'step',
|
| 1395 |
+
'frequency': 1
|
| 1396 |
+
}]
|
| 1397 |
+
return [opt], scheduler
|
| 1398 |
+
return opt
|
| 1399 |
+
|
| 1400 |
+
@torch.no_grad()
|
| 1401 |
+
def to_rgb(self, x):
|
| 1402 |
+
x = x.float()
|
| 1403 |
+
if not hasattr(self, "colorize"):
|
| 1404 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
| 1405 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
| 1406 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
| 1407 |
+
return x
|
| 1408 |
+
|
| 1409 |
+
|
| 1410 |
+
class DiffusionWrapper(pl.LightningModule):
|
| 1411 |
+
def __init__(self, diff_model_config, conditioning_key):
|
| 1412 |
+
super().__init__()
|
| 1413 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
| 1414 |
+
self.conditioning_key = conditioning_key
|
| 1415 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
| 1416 |
+
|
| 1417 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
| 1418 |
+
if self.conditioning_key is None:
|
| 1419 |
+
out = self.diffusion_model(x, t)
|
| 1420 |
+
elif self.conditioning_key == 'concat':
|
| 1421 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
| 1422 |
+
out = self.diffusion_model(xc, t)
|
| 1423 |
+
elif self.conditioning_key == 'crossattn':
|
| 1424 |
+
cc = torch.cat(c_crossattn, 1)
|
| 1425 |
+
out = self.diffusion_model(x, t, context=cc)
|
| 1426 |
+
elif self.conditioning_key == 'hybrid':
|
| 1427 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
| 1428 |
+
cc = torch.cat(c_crossattn, 1)
|
| 1429 |
+
out = self.diffusion_model(xc, t, context=cc)
|
| 1430 |
+
elif self.conditioning_key == 'adm':
|
| 1431 |
+
cc = c_crossattn[0]
|
| 1432 |
+
out = self.diffusion_model(x, t, y=cc)
|
| 1433 |
+
else:
|
| 1434 |
+
raise NotImplementedError()
|
| 1435 |
+
|
| 1436 |
+
return out
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
class Layout2ImgDiffusion(LatentDiffusion):
|
| 1440 |
+
# TODO: move all layout-specific hacks to this class
|
| 1441 |
+
def __init__(self, cond_stage_key, *args, **kwargs):
|
| 1442 |
+
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
| 1443 |
+
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
| 1444 |
+
|
| 1445 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
| 1446 |
+
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
| 1447 |
+
|
| 1448 |
+
key = 'train' if self.training else 'validation'
|
| 1449 |
+
dset = self.trainer.datamodule.datasets[key]
|
| 1450 |
+
mapper = dset.conditional_builders[self.cond_stage_key]
|
| 1451 |
+
|
| 1452 |
+
bbox_imgs = []
|
| 1453 |
+
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
|
| 1454 |
+
for tknzd_bbox in batch[self.cond_stage_key][:N]:
|
| 1455 |
+
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
|
| 1456 |
+
bbox_imgs.append(bboximg)
|
| 1457 |
+
|
| 1458 |
+
cond_img = torch.stack(bbox_imgs, dim=0)
|
| 1459 |
+
logs['bbox_image'] = cond_img
|
| 1460 |
+
return logs
|
modules/models/diffusion/uni_pc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sampler import UniPCSampler # noqa: F401
|
modules/models/diffusion/uni_pc/sampler.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
| 6 |
+
from modules import shared, devices
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class UniPCSampler(object):
|
| 10 |
+
def __init__(self, model, **kwargs):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.model = model
|
| 13 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
| 14 |
+
self.before_sample = None
|
| 15 |
+
self.after_sample = None
|
| 16 |
+
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
| 17 |
+
|
| 18 |
+
def register_buffer(self, name, attr):
|
| 19 |
+
if type(attr) == torch.Tensor:
|
| 20 |
+
if attr.device != devices.device:
|
| 21 |
+
attr = attr.to(devices.device)
|
| 22 |
+
setattr(self, name, attr)
|
| 23 |
+
|
| 24 |
+
def set_hooks(self, before_sample, after_sample, after_update):
|
| 25 |
+
self.before_sample = before_sample
|
| 26 |
+
self.after_sample = after_sample
|
| 27 |
+
self.after_update = after_update
|
| 28 |
+
|
| 29 |
+
@torch.no_grad()
|
| 30 |
+
def sample(self,
|
| 31 |
+
S,
|
| 32 |
+
batch_size,
|
| 33 |
+
shape,
|
| 34 |
+
conditioning=None,
|
| 35 |
+
callback=None,
|
| 36 |
+
normals_sequence=None,
|
| 37 |
+
img_callback=None,
|
| 38 |
+
quantize_x0=False,
|
| 39 |
+
eta=0.,
|
| 40 |
+
mask=None,
|
| 41 |
+
x0=None,
|
| 42 |
+
temperature=1.,
|
| 43 |
+
noise_dropout=0.,
|
| 44 |
+
score_corrector=None,
|
| 45 |
+
corrector_kwargs=None,
|
| 46 |
+
verbose=True,
|
| 47 |
+
x_T=None,
|
| 48 |
+
log_every_t=100,
|
| 49 |
+
unconditional_guidance_scale=1.,
|
| 50 |
+
unconditional_conditioning=None,
|
| 51 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 52 |
+
**kwargs
|
| 53 |
+
):
|
| 54 |
+
if conditioning is not None:
|
| 55 |
+
if isinstance(conditioning, dict):
|
| 56 |
+
ctmp = conditioning[list(conditioning.keys())[0]]
|
| 57 |
+
while isinstance(ctmp, list):
|
| 58 |
+
ctmp = ctmp[0]
|
| 59 |
+
cbs = ctmp.shape[0]
|
| 60 |
+
if cbs != batch_size:
|
| 61 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 62 |
+
|
| 63 |
+
elif isinstance(conditioning, list):
|
| 64 |
+
for ctmp in conditioning:
|
| 65 |
+
if ctmp.shape[0] != batch_size:
|
| 66 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 67 |
+
|
| 68 |
+
else:
|
| 69 |
+
if conditioning.shape[0] != batch_size:
|
| 70 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 71 |
+
|
| 72 |
+
# sampling
|
| 73 |
+
C, H, W = shape
|
| 74 |
+
size = (batch_size, C, H, W)
|
| 75 |
+
# print(f'Data shape for UniPC sampling is {size}')
|
| 76 |
+
|
| 77 |
+
device = self.model.betas.device
|
| 78 |
+
if x_T is None:
|
| 79 |
+
img = torch.randn(size, device=device)
|
| 80 |
+
else:
|
| 81 |
+
img = x_T
|
| 82 |
+
|
| 83 |
+
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
| 84 |
+
|
| 85 |
+
# SD 1.X is "noise", SD 2.X is "v"
|
| 86 |
+
model_type = "v" if self.model.parameterization == "v" else "noise"
|
| 87 |
+
|
| 88 |
+
model_fn = model_wrapper(
|
| 89 |
+
lambda x, t, c: self.model.apply_model(x, t, c),
|
| 90 |
+
ns,
|
| 91 |
+
model_type=model_type,
|
| 92 |
+
guidance_type="classifier-free",
|
| 93 |
+
#condition=conditioning,
|
| 94 |
+
#unconditional_condition=unconditional_conditioning,
|
| 95 |
+
guidance_scale=unconditional_guidance_scale,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
| 99 |
+
x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
| 100 |
+
|
| 101 |
+
return x.to(device), None
|
modules/models/diffusion/uni_pc/uni_pc.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import tqdm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NoiseScheduleVP:
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
schedule='discrete',
|
| 10 |
+
betas=None,
|
| 11 |
+
alphas_cumprod=None,
|
| 12 |
+
continuous_beta_0=0.1,
|
| 13 |
+
continuous_beta_1=20.,
|
| 14 |
+
):
|
| 15 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
| 16 |
+
|
| 17 |
+
***
|
| 18 |
+
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
| 19 |
+
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
|
| 20 |
+
***
|
| 21 |
+
|
| 22 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
| 23 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
| 24 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
| 25 |
+
|
| 26 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
| 27 |
+
sigma_t = self.marginal_std(t)
|
| 28 |
+
lambda_t = self.marginal_lambda(t)
|
| 29 |
+
|
| 30 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
| 31 |
+
|
| 32 |
+
t = self.inverse_lambda(lambda_t)
|
| 33 |
+
|
| 34 |
+
===============================================================
|
| 35 |
+
|
| 36 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
| 37 |
+
|
| 38 |
+
1. For discrete-time DPMs:
|
| 39 |
+
|
| 40 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
| 41 |
+
t_i = (i + 1) / N
|
| 42 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
| 43 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
| 47 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
| 48 |
+
|
| 49 |
+
Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
| 50 |
+
|
| 51 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
| 52 |
+
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
| 53 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
| 54 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
| 55 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
| 56 |
+
and
|
| 57 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
2. For continuous-time DPMs:
|
| 61 |
+
|
| 62 |
+
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
| 63 |
+
schedule are the default settings in DDPM and improved-DDPM:
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
| 67 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
| 68 |
+
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
| 69 |
+
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
| 70 |
+
T: A `float` number. The ending time of the forward process.
|
| 71 |
+
|
| 72 |
+
===============================================================
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
| 76 |
+
'linear' or 'cosine' for continuous-time DPMs.
|
| 77 |
+
Returns:
|
| 78 |
+
A wrapper object of the forward SDE (VP type).
|
| 79 |
+
|
| 80 |
+
===============================================================
|
| 81 |
+
|
| 82 |
+
Example:
|
| 83 |
+
|
| 84 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
| 85 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
| 86 |
+
|
| 87 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
| 88 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
| 89 |
+
|
| 90 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
| 91 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
| 92 |
+
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
if schedule not in ['discrete', 'linear', 'cosine']:
|
| 96 |
+
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
| 97 |
+
|
| 98 |
+
self.schedule = schedule
|
| 99 |
+
if schedule == 'discrete':
|
| 100 |
+
if betas is not None:
|
| 101 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
| 102 |
+
else:
|
| 103 |
+
assert alphas_cumprod is not None
|
| 104 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
| 105 |
+
self.total_N = len(log_alphas)
|
| 106 |
+
self.T = 1.
|
| 107 |
+
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
|
| 108 |
+
self.log_alpha_array = log_alphas.reshape((1, -1,))
|
| 109 |
+
else:
|
| 110 |
+
self.total_N = 1000
|
| 111 |
+
self.beta_0 = continuous_beta_0
|
| 112 |
+
self.beta_1 = continuous_beta_1
|
| 113 |
+
self.cosine_s = 0.008
|
| 114 |
+
self.cosine_beta_max = 999.
|
| 115 |
+
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
| 116 |
+
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
| 117 |
+
self.schedule = schedule
|
| 118 |
+
if schedule == 'cosine':
|
| 119 |
+
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
| 120 |
+
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
| 121 |
+
self.T = 0.9946
|
| 122 |
+
else:
|
| 123 |
+
self.T = 1.
|
| 124 |
+
|
| 125 |
+
def marginal_log_mean_coeff(self, t):
|
| 126 |
+
"""
|
| 127 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
| 128 |
+
"""
|
| 129 |
+
if self.schedule == 'discrete':
|
| 130 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
|
| 131 |
+
elif self.schedule == 'linear':
|
| 132 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
| 133 |
+
elif self.schedule == 'cosine':
|
| 134 |
+
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
| 135 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
| 136 |
+
return log_alpha_t
|
| 137 |
+
|
| 138 |
+
def marginal_alpha(self, t):
|
| 139 |
+
"""
|
| 140 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
| 141 |
+
"""
|
| 142 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
| 143 |
+
|
| 144 |
+
def marginal_std(self, t):
|
| 145 |
+
"""
|
| 146 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
| 147 |
+
"""
|
| 148 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
| 149 |
+
|
| 150 |
+
def marginal_lambda(self, t):
|
| 151 |
+
"""
|
| 152 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
| 153 |
+
"""
|
| 154 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
| 155 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
| 156 |
+
return log_mean_coeff - log_std
|
| 157 |
+
|
| 158 |
+
def inverse_lambda(self, lamb):
|
| 159 |
+
"""
|
| 160 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
| 161 |
+
"""
|
| 162 |
+
if self.schedule == 'linear':
|
| 163 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
| 164 |
+
Delta = self.beta_0**2 + tmp
|
| 165 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
| 166 |
+
elif self.schedule == 'discrete':
|
| 167 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
| 168 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
|
| 169 |
+
return t.reshape((-1,))
|
| 170 |
+
else:
|
| 171 |
+
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
| 172 |
+
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
| 173 |
+
t = t_fn(log_alpha)
|
| 174 |
+
return t
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def model_wrapper(
|
| 178 |
+
model,
|
| 179 |
+
noise_schedule,
|
| 180 |
+
model_type="noise",
|
| 181 |
+
model_kwargs=None,
|
| 182 |
+
guidance_type="uncond",
|
| 183 |
+
#condition=None,
|
| 184 |
+
#unconditional_condition=None,
|
| 185 |
+
guidance_scale=1.,
|
| 186 |
+
classifier_fn=None,
|
| 187 |
+
classifier_kwargs=None,
|
| 188 |
+
):
|
| 189 |
+
"""Create a wrapper function for the noise prediction model.
|
| 190 |
+
|
| 191 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
| 192 |
+
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
| 193 |
+
|
| 194 |
+
We support four types of the diffusion model by setting `model_type`:
|
| 195 |
+
|
| 196 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
| 197 |
+
|
| 198 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
| 199 |
+
|
| 200 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
| 201 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
| 202 |
+
|
| 203 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
| 204 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
| 205 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
| 206 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
| 207 |
+
|
| 208 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
| 209 |
+
Note that the score function and the noise prediction model follows a simple relationship:
|
| 210 |
+
```
|
| 211 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
| 215 |
+
1. "uncond": unconditional sampling by DPMs.
|
| 216 |
+
The input `model` has the following format:
|
| 217 |
+
``
|
| 218 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 219 |
+
``
|
| 220 |
+
|
| 221 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
| 222 |
+
The input `model` has the following format:
|
| 223 |
+
``
|
| 224 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 225 |
+
``
|
| 226 |
+
|
| 227 |
+
The input `classifier_fn` has the following format:
|
| 228 |
+
``
|
| 229 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
| 230 |
+
``
|
| 231 |
+
|
| 232 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
| 233 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
| 234 |
+
|
| 235 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
| 236 |
+
The input `model` has the following format:
|
| 237 |
+
``
|
| 238 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
| 239 |
+
``
|
| 240 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
| 241 |
+
|
| 242 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
| 243 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
| 247 |
+
or continuous-time labels (i.e. epsilon to T).
|
| 248 |
+
|
| 249 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
| 250 |
+
``
|
| 251 |
+
def model_fn(x, t_continuous) -> noise:
|
| 252 |
+
t_input = get_model_input_time(t_continuous)
|
| 253 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
| 254 |
+
``
|
| 255 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
| 256 |
+
|
| 257 |
+
===============================================================
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
model: A diffusion model with the corresponding format described above.
|
| 261 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
| 262 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
| 263 |
+
"noise" or "x_start" or "v" or "score".
|
| 264 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
| 265 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
| 266 |
+
"uncond" or "classifier" or "classifier-free".
|
| 267 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
| 268 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
| 269 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
| 270 |
+
Only used for "classifier-free" guidance type.
|
| 271 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
| 272 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
| 273 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
| 274 |
+
Returns:
|
| 275 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
model_kwargs = model_kwargs or {}
|
| 279 |
+
classifier_kwargs = classifier_kwargs or {}
|
| 280 |
+
|
| 281 |
+
def get_model_input_time(t_continuous):
|
| 282 |
+
"""
|
| 283 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
| 284 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
| 285 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
| 286 |
+
"""
|
| 287 |
+
if noise_schedule.schedule == 'discrete':
|
| 288 |
+
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
| 289 |
+
else:
|
| 290 |
+
return t_continuous
|
| 291 |
+
|
| 292 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
| 293 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 294 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 295 |
+
t_input = get_model_input_time(t_continuous)
|
| 296 |
+
if cond is None:
|
| 297 |
+
output = model(x, t_input, None, **model_kwargs)
|
| 298 |
+
else:
|
| 299 |
+
output = model(x, t_input, cond, **model_kwargs)
|
| 300 |
+
if model_type == "noise":
|
| 301 |
+
return output
|
| 302 |
+
elif model_type == "x_start":
|
| 303 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 304 |
+
dims = x.dim()
|
| 305 |
+
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
| 306 |
+
elif model_type == "v":
|
| 307 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 308 |
+
dims = x.dim()
|
| 309 |
+
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
| 310 |
+
elif model_type == "score":
|
| 311 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 312 |
+
dims = x.dim()
|
| 313 |
+
return -expand_dims(sigma_t, dims) * output
|
| 314 |
+
|
| 315 |
+
def cond_grad_fn(x, t_input, condition):
|
| 316 |
+
"""
|
| 317 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
| 318 |
+
"""
|
| 319 |
+
with torch.enable_grad():
|
| 320 |
+
x_in = x.detach().requires_grad_(True)
|
| 321 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
| 322 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
| 323 |
+
|
| 324 |
+
def model_fn(x, t_continuous, condition, unconditional_condition):
|
| 325 |
+
"""
|
| 326 |
+
The noise prediction model function that is used for DPM-Solver.
|
| 327 |
+
"""
|
| 328 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 329 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 330 |
+
if guidance_type == "uncond":
|
| 331 |
+
return noise_pred_fn(x, t_continuous)
|
| 332 |
+
elif guidance_type == "classifier":
|
| 333 |
+
assert classifier_fn is not None
|
| 334 |
+
t_input = get_model_input_time(t_continuous)
|
| 335 |
+
cond_grad = cond_grad_fn(x, t_input, condition)
|
| 336 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 337 |
+
noise = noise_pred_fn(x, t_continuous)
|
| 338 |
+
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
| 339 |
+
elif guidance_type == "classifier-free":
|
| 340 |
+
if guidance_scale == 1. or unconditional_condition is None:
|
| 341 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
| 342 |
+
else:
|
| 343 |
+
x_in = torch.cat([x] * 2)
|
| 344 |
+
t_in = torch.cat([t_continuous] * 2)
|
| 345 |
+
if isinstance(condition, dict):
|
| 346 |
+
assert isinstance(unconditional_condition, dict)
|
| 347 |
+
c_in = {}
|
| 348 |
+
for k in condition:
|
| 349 |
+
if isinstance(condition[k], list):
|
| 350 |
+
c_in[k] = [torch.cat([
|
| 351 |
+
unconditional_condition[k][i],
|
| 352 |
+
condition[k][i]]) for i in range(len(condition[k]))]
|
| 353 |
+
else:
|
| 354 |
+
c_in[k] = torch.cat([
|
| 355 |
+
unconditional_condition[k],
|
| 356 |
+
condition[k]])
|
| 357 |
+
elif isinstance(condition, list):
|
| 358 |
+
c_in = []
|
| 359 |
+
assert isinstance(unconditional_condition, list)
|
| 360 |
+
for i in range(len(condition)):
|
| 361 |
+
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
| 362 |
+
else:
|
| 363 |
+
c_in = torch.cat([unconditional_condition, condition])
|
| 364 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
| 365 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
| 366 |
+
|
| 367 |
+
assert model_type in ["noise", "x_start", "v"]
|
| 368 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
| 369 |
+
return model_fn
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class UniPC:
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
model_fn,
|
| 376 |
+
noise_schedule,
|
| 377 |
+
predict_x0=True,
|
| 378 |
+
thresholding=False,
|
| 379 |
+
max_val=1.,
|
| 380 |
+
variant='bh1',
|
| 381 |
+
condition=None,
|
| 382 |
+
unconditional_condition=None,
|
| 383 |
+
before_sample=None,
|
| 384 |
+
after_sample=None,
|
| 385 |
+
after_update=None
|
| 386 |
+
):
|
| 387 |
+
"""Construct a UniPC.
|
| 388 |
+
|
| 389 |
+
We support both data_prediction and noise_prediction.
|
| 390 |
+
"""
|
| 391 |
+
self.model_fn_ = model_fn
|
| 392 |
+
self.noise_schedule = noise_schedule
|
| 393 |
+
self.variant = variant
|
| 394 |
+
self.predict_x0 = predict_x0
|
| 395 |
+
self.thresholding = thresholding
|
| 396 |
+
self.max_val = max_val
|
| 397 |
+
self.condition = condition
|
| 398 |
+
self.unconditional_condition = unconditional_condition
|
| 399 |
+
self.before_sample = before_sample
|
| 400 |
+
self.after_sample = after_sample
|
| 401 |
+
self.after_update = after_update
|
| 402 |
+
|
| 403 |
+
def dynamic_thresholding_fn(self, x0, t=None):
|
| 404 |
+
"""
|
| 405 |
+
The dynamic thresholding method.
|
| 406 |
+
"""
|
| 407 |
+
dims = x0.dim()
|
| 408 |
+
p = self.dynamic_thresholding_ratio
|
| 409 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 410 |
+
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
| 411 |
+
x0 = torch.clamp(x0, -s, s) / s
|
| 412 |
+
return x0
|
| 413 |
+
|
| 414 |
+
def model(self, x, t):
|
| 415 |
+
cond = self.condition
|
| 416 |
+
uncond = self.unconditional_condition
|
| 417 |
+
if self.before_sample is not None:
|
| 418 |
+
x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
|
| 419 |
+
res = self.model_fn_(x, t, cond, uncond)
|
| 420 |
+
if self.after_sample is not None:
|
| 421 |
+
x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
|
| 422 |
+
|
| 423 |
+
if isinstance(res, tuple):
|
| 424 |
+
# (None, pred_x0)
|
| 425 |
+
res = res[1]
|
| 426 |
+
|
| 427 |
+
return res
|
| 428 |
+
|
| 429 |
+
def noise_prediction_fn(self, x, t):
|
| 430 |
+
"""
|
| 431 |
+
Return the noise prediction model.
|
| 432 |
+
"""
|
| 433 |
+
return self.model(x, t)
|
| 434 |
+
|
| 435 |
+
def data_prediction_fn(self, x, t):
|
| 436 |
+
"""
|
| 437 |
+
Return the data prediction model (with thresholding).
|
| 438 |
+
"""
|
| 439 |
+
noise = self.noise_prediction_fn(x, t)
|
| 440 |
+
dims = x.dim()
|
| 441 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
| 442 |
+
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
| 443 |
+
if self.thresholding:
|
| 444 |
+
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
| 445 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 446 |
+
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
| 447 |
+
x0 = torch.clamp(x0, -s, s) / s
|
| 448 |
+
return x0
|
| 449 |
+
|
| 450 |
+
def model_fn(self, x, t):
|
| 451 |
+
"""
|
| 452 |
+
Convert the model to the noise prediction model or the data prediction model.
|
| 453 |
+
"""
|
| 454 |
+
if self.predict_x0:
|
| 455 |
+
return self.data_prediction_fn(x, t)
|
| 456 |
+
else:
|
| 457 |
+
return self.noise_prediction_fn(x, t)
|
| 458 |
+
|
| 459 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
| 460 |
+
"""Compute the intermediate time steps for sampling.
|
| 461 |
+
"""
|
| 462 |
+
if skip_type == 'logSNR':
|
| 463 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
| 464 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
| 465 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
| 466 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 467 |
+
elif skip_type == 'time_uniform':
|
| 468 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
| 469 |
+
elif skip_type == 'time_quadratic':
|
| 470 |
+
t_order = 2
|
| 471 |
+
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
| 472 |
+
return t
|
| 473 |
+
else:
|
| 474 |
+
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
| 475 |
+
|
| 476 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
| 477 |
+
"""
|
| 478 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
| 479 |
+
"""
|
| 480 |
+
if order == 3:
|
| 481 |
+
K = steps // 3 + 1
|
| 482 |
+
if steps % 3 == 0:
|
| 483 |
+
orders = [3,] * (K - 2) + [2, 1]
|
| 484 |
+
elif steps % 3 == 1:
|
| 485 |
+
orders = [3,] * (K - 1) + [1]
|
| 486 |
+
else:
|
| 487 |
+
orders = [3,] * (K - 1) + [2]
|
| 488 |
+
elif order == 2:
|
| 489 |
+
if steps % 2 == 0:
|
| 490 |
+
K = steps // 2
|
| 491 |
+
orders = [2,] * K
|
| 492 |
+
else:
|
| 493 |
+
K = steps // 2 + 1
|
| 494 |
+
orders = [2,] * (K - 1) + [1]
|
| 495 |
+
elif order == 1:
|
| 496 |
+
K = steps
|
| 497 |
+
orders = [1,] * steps
|
| 498 |
+
else:
|
| 499 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
| 500 |
+
if skip_type == 'logSNR':
|
| 501 |
+
# To reproduce the results in DPM-Solver paper
|
| 502 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
| 503 |
+
else:
|
| 504 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
|
| 505 |
+
return timesteps_outer, orders
|
| 506 |
+
|
| 507 |
+
def denoise_to_zero_fn(self, x, s):
|
| 508 |
+
"""
|
| 509 |
+
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
| 510 |
+
"""
|
| 511 |
+
return self.data_prediction_fn(x, s)
|
| 512 |
+
|
| 513 |
+
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
|
| 514 |
+
if len(t.shape) == 0:
|
| 515 |
+
t = t.view(-1)
|
| 516 |
+
if 'bh' in self.variant:
|
| 517 |
+
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 518 |
+
else:
|
| 519 |
+
assert self.variant == 'vary_coeff'
|
| 520 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 521 |
+
|
| 522 |
+
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
| 523 |
+
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
| 524 |
+
ns = self.noise_schedule
|
| 525 |
+
assert order <= len(model_prev_list)
|
| 526 |
+
|
| 527 |
+
# first compute rks
|
| 528 |
+
t_prev_0 = t_prev_list[-1]
|
| 529 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
| 530 |
+
lambda_t = ns.marginal_lambda(t)
|
| 531 |
+
model_prev_0 = model_prev_list[-1]
|
| 532 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
| 533 |
+
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
| 534 |
+
alpha_t = torch.exp(log_alpha_t)
|
| 535 |
+
|
| 536 |
+
h = lambda_t - lambda_prev_0
|
| 537 |
+
|
| 538 |
+
rks = []
|
| 539 |
+
D1s = []
|
| 540 |
+
for i in range(1, order):
|
| 541 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
| 542 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
| 543 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
| 544 |
+
rk = (lambda_prev_i - lambda_prev_0) / h
|
| 545 |
+
rks.append(rk)
|
| 546 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
| 547 |
+
|
| 548 |
+
rks.append(1.)
|
| 549 |
+
rks = torch.tensor(rks, device=x.device)
|
| 550 |
+
|
| 551 |
+
K = len(rks)
|
| 552 |
+
# build C matrix
|
| 553 |
+
C = []
|
| 554 |
+
|
| 555 |
+
col = torch.ones_like(rks)
|
| 556 |
+
for k in range(1, K + 1):
|
| 557 |
+
C.append(col)
|
| 558 |
+
col = col * rks / (k + 1)
|
| 559 |
+
C = torch.stack(C, dim=1)
|
| 560 |
+
|
| 561 |
+
if len(D1s) > 0:
|
| 562 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 563 |
+
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
| 564 |
+
A_p = C_inv_p
|
| 565 |
+
|
| 566 |
+
if use_corrector:
|
| 567 |
+
#print('using corrector')
|
| 568 |
+
C_inv = torch.linalg.inv(C)
|
| 569 |
+
A_c = C_inv
|
| 570 |
+
|
| 571 |
+
hh = -h if self.predict_x0 else h
|
| 572 |
+
h_phi_1 = torch.expm1(hh)
|
| 573 |
+
h_phi_ks = []
|
| 574 |
+
factorial_k = 1
|
| 575 |
+
h_phi_k = h_phi_1
|
| 576 |
+
for k in range(1, K + 2):
|
| 577 |
+
h_phi_ks.append(h_phi_k)
|
| 578 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
| 579 |
+
factorial_k *= (k + 1)
|
| 580 |
+
|
| 581 |
+
model_t = None
|
| 582 |
+
if self.predict_x0:
|
| 583 |
+
x_t_ = (
|
| 584 |
+
sigma_t / sigma_prev_0 * x
|
| 585 |
+
- alpha_t * h_phi_1 * model_prev_0
|
| 586 |
+
)
|
| 587 |
+
# now predictor
|
| 588 |
+
x_t = x_t_
|
| 589 |
+
if len(D1s) > 0:
|
| 590 |
+
# compute the residuals for predictor
|
| 591 |
+
for k in range(K - 1):
|
| 592 |
+
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
| 593 |
+
# now corrector
|
| 594 |
+
if use_corrector:
|
| 595 |
+
model_t = self.model_fn(x_t, t)
|
| 596 |
+
D1_t = (model_t - model_prev_0)
|
| 597 |
+
x_t = x_t_
|
| 598 |
+
k = 0
|
| 599 |
+
for k in range(K - 1):
|
| 600 |
+
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
| 601 |
+
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
| 602 |
+
else:
|
| 603 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
| 604 |
+
x_t_ = (
|
| 605 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
| 606 |
+
- (sigma_t * h_phi_1) * model_prev_0
|
| 607 |
+
)
|
| 608 |
+
# now predictor
|
| 609 |
+
x_t = x_t_
|
| 610 |
+
if len(D1s) > 0:
|
| 611 |
+
# compute the residuals for predictor
|
| 612 |
+
for k in range(K - 1):
|
| 613 |
+
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
| 614 |
+
# now corrector
|
| 615 |
+
if use_corrector:
|
| 616 |
+
model_t = self.model_fn(x_t, t)
|
| 617 |
+
D1_t = (model_t - model_prev_0)
|
| 618 |
+
x_t = x_t_
|
| 619 |
+
k = 0
|
| 620 |
+
for k in range(K - 1):
|
| 621 |
+
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
| 622 |
+
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
| 623 |
+
return x_t, model_t
|
| 624 |
+
|
| 625 |
+
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
| 626 |
+
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
| 627 |
+
ns = self.noise_schedule
|
| 628 |
+
assert order <= len(model_prev_list)
|
| 629 |
+
dims = x.dim()
|
| 630 |
+
|
| 631 |
+
# first compute rks
|
| 632 |
+
t_prev_0 = t_prev_list[-1]
|
| 633 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
| 634 |
+
lambda_t = ns.marginal_lambda(t)
|
| 635 |
+
model_prev_0 = model_prev_list[-1]
|
| 636 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
| 637 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
| 638 |
+
alpha_t = torch.exp(log_alpha_t)
|
| 639 |
+
|
| 640 |
+
h = lambda_t - lambda_prev_0
|
| 641 |
+
|
| 642 |
+
rks = []
|
| 643 |
+
D1s = []
|
| 644 |
+
for i in range(1, order):
|
| 645 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
| 646 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
| 647 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
| 648 |
+
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
| 649 |
+
rks.append(rk)
|
| 650 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
| 651 |
+
|
| 652 |
+
rks.append(1.)
|
| 653 |
+
rks = torch.tensor(rks, device=x.device)
|
| 654 |
+
|
| 655 |
+
R = []
|
| 656 |
+
b = []
|
| 657 |
+
|
| 658 |
+
hh = -h[0] if self.predict_x0 else h[0]
|
| 659 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 660 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 661 |
+
|
| 662 |
+
factorial_i = 1
|
| 663 |
+
|
| 664 |
+
if self.variant == 'bh1':
|
| 665 |
+
B_h = hh
|
| 666 |
+
elif self.variant == 'bh2':
|
| 667 |
+
B_h = torch.expm1(hh)
|
| 668 |
+
else:
|
| 669 |
+
raise NotImplementedError()
|
| 670 |
+
|
| 671 |
+
for i in range(1, order + 1):
|
| 672 |
+
R.append(torch.pow(rks, i - 1))
|
| 673 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 674 |
+
factorial_i *= (i + 1)
|
| 675 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 676 |
+
|
| 677 |
+
R = torch.stack(R)
|
| 678 |
+
b = torch.tensor(b, device=x.device)
|
| 679 |
+
|
| 680 |
+
# now predictor
|
| 681 |
+
use_predictor = len(D1s) > 0 and x_t is None
|
| 682 |
+
if len(D1s) > 0:
|
| 683 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 684 |
+
if x_t is None:
|
| 685 |
+
# for order 2, we use a simplified version
|
| 686 |
+
if order == 2:
|
| 687 |
+
rhos_p = torch.tensor([0.5], device=b.device)
|
| 688 |
+
else:
|
| 689 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
| 690 |
+
else:
|
| 691 |
+
D1s = None
|
| 692 |
+
|
| 693 |
+
if use_corrector:
|
| 694 |
+
#print('using corrector')
|
| 695 |
+
# for order 1, we use a simplified version
|
| 696 |
+
if order == 1:
|
| 697 |
+
rhos_c = torch.tensor([0.5], device=b.device)
|
| 698 |
+
else:
|
| 699 |
+
rhos_c = torch.linalg.solve(R, b)
|
| 700 |
+
|
| 701 |
+
model_t = None
|
| 702 |
+
if self.predict_x0:
|
| 703 |
+
x_t_ = (
|
| 704 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
| 705 |
+
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
if x_t is None:
|
| 709 |
+
if use_predictor:
|
| 710 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
| 711 |
+
else:
|
| 712 |
+
pred_res = 0
|
| 713 |
+
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
| 714 |
+
|
| 715 |
+
if use_corrector:
|
| 716 |
+
model_t = self.model_fn(x_t, t)
|
| 717 |
+
if D1s is not None:
|
| 718 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
| 719 |
+
else:
|
| 720 |
+
corr_res = 0
|
| 721 |
+
D1_t = (model_t - model_prev_0)
|
| 722 |
+
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
| 723 |
+
else:
|
| 724 |
+
x_t_ = (
|
| 725 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
| 726 |
+
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
| 727 |
+
)
|
| 728 |
+
if x_t is None:
|
| 729 |
+
if use_predictor:
|
| 730 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
| 731 |
+
else:
|
| 732 |
+
pred_res = 0
|
| 733 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
|
| 734 |
+
|
| 735 |
+
if use_corrector:
|
| 736 |
+
model_t = self.model_fn(x_t, t)
|
| 737 |
+
if D1s is not None:
|
| 738 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
| 739 |
+
else:
|
| 740 |
+
corr_res = 0
|
| 741 |
+
D1_t = (model_t - model_prev_0)
|
| 742 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
| 743 |
+
return x_t, model_t
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
| 747 |
+
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
| 748 |
+
atol=0.0078, rtol=0.05, corrector=False,
|
| 749 |
+
):
|
| 750 |
+
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
| 751 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
| 752 |
+
device = x.device
|
| 753 |
+
if method == 'multistep':
|
| 754 |
+
assert steps >= order, "UniPC order must be < sampling steps"
|
| 755 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
| 756 |
+
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
| 757 |
+
assert timesteps.shape[0] - 1 == steps
|
| 758 |
+
with torch.no_grad():
|
| 759 |
+
vec_t = timesteps[0].expand((x.shape[0]))
|
| 760 |
+
model_prev_list = [self.model_fn(x, vec_t)]
|
| 761 |
+
t_prev_list = [vec_t]
|
| 762 |
+
with tqdm.tqdm(total=steps) as pbar:
|
| 763 |
+
# Init the first `order` values by lower order multistep DPM-Solver.
|
| 764 |
+
for init_order in range(1, order):
|
| 765 |
+
vec_t = timesteps[init_order].expand(x.shape[0])
|
| 766 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
| 767 |
+
if model_x is None:
|
| 768 |
+
model_x = self.model_fn(x, vec_t)
|
| 769 |
+
if self.after_update is not None:
|
| 770 |
+
self.after_update(x, model_x)
|
| 771 |
+
model_prev_list.append(model_x)
|
| 772 |
+
t_prev_list.append(vec_t)
|
| 773 |
+
pbar.update()
|
| 774 |
+
|
| 775 |
+
for step in range(order, steps + 1):
|
| 776 |
+
vec_t = timesteps[step].expand(x.shape[0])
|
| 777 |
+
if lower_order_final:
|
| 778 |
+
step_order = min(order, steps + 1 - step)
|
| 779 |
+
else:
|
| 780 |
+
step_order = order
|
| 781 |
+
#print('this step order:', step_order)
|
| 782 |
+
if step == steps:
|
| 783 |
+
#print('do not run corrector at the last step')
|
| 784 |
+
use_corrector = False
|
| 785 |
+
else:
|
| 786 |
+
use_corrector = True
|
| 787 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
| 788 |
+
if self.after_update is not None:
|
| 789 |
+
self.after_update(x, model_x)
|
| 790 |
+
for i in range(order - 1):
|
| 791 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
| 792 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
| 793 |
+
t_prev_list[-1] = vec_t
|
| 794 |
+
# We do not need to evaluate the final model value.
|
| 795 |
+
if step < steps:
|
| 796 |
+
if model_x is None:
|
| 797 |
+
model_x = self.model_fn(x, vec_t)
|
| 798 |
+
model_prev_list[-1] = model_x
|
| 799 |
+
pbar.update()
|
| 800 |
+
else:
|
| 801 |
+
raise NotImplementedError()
|
| 802 |
+
if denoise_to_zero:
|
| 803 |
+
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
| 804 |
+
return x
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
#############################################################
|
| 808 |
+
# other utility functions
|
| 809 |
+
#############################################################
|
| 810 |
+
|
| 811 |
+
def interpolate_fn(x, xp, yp):
|
| 812 |
+
"""
|
| 813 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
| 814 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
| 815 |
+
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
| 816 |
+
|
| 817 |
+
Args:
|
| 818 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
| 819 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
| 820 |
+
yp: PyTorch tensor with shape [C, K].
|
| 821 |
+
Returns:
|
| 822 |
+
The function values f(x), with shape [N, C].
|
| 823 |
+
"""
|
| 824 |
+
N, K = x.shape[0], xp.shape[1]
|
| 825 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
| 826 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
| 827 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
| 828 |
+
cand_start_idx = x_idx - 1
|
| 829 |
+
start_idx = torch.where(
|
| 830 |
+
torch.eq(x_idx, 0),
|
| 831 |
+
torch.tensor(1, device=x.device),
|
| 832 |
+
torch.where(
|
| 833 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
| 834 |
+
),
|
| 835 |
+
)
|
| 836 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
| 837 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
| 838 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
| 839 |
+
start_idx2 = torch.where(
|
| 840 |
+
torch.eq(x_idx, 0),
|
| 841 |
+
torch.tensor(0, device=x.device),
|
| 842 |
+
torch.where(
|
| 843 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
| 844 |
+
),
|
| 845 |
+
)
|
| 846 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
| 847 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
| 848 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
| 849 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
| 850 |
+
return cand
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def expand_dims(v, dims):
|
| 854 |
+
"""
|
| 855 |
+
Expand the tensor `v` to the dim `dims`.
|
| 856 |
+
|
| 857 |
+
Args:
|
| 858 |
+
`v`: a PyTorch tensor with shape [N].
|
| 859 |
+
`dim`: a `int`.
|
| 860 |
+
Returns:
|
| 861 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
| 862 |
+
"""
|
| 863 |
+
return v[(...,) + (None,)*(dims - 1)]
|