Spaces:
Running
on
Zero
Running
on
Zero
Upload 29 files
Browse files- .gitignore +232 -0
- .gitmodules +0 -0
- LICENSE +54 -0
- app.py +586 -0
- gradio_app.py +568 -0
- gradio_edit.py +563 -0
- infer_withanyone.py +309 -0
- nohup.out +2 -0
- requirements.txt +24 -0
- util.py +411 -0
- withanyone/flux/__pycache__/math.cpython-310.pyc +0 -0
- withanyone/flux/__pycache__/model.cpython-310.pyc +0 -0
- withanyone/flux/__pycache__/pipeline.cpython-310.pyc +0 -0
- withanyone/flux/__pycache__/sampling.cpython-310.pyc +0 -0
- withanyone/flux/__pycache__/util.cpython-310.pyc +0 -0
- withanyone/flux/math.py +49 -0
- withanyone/flux/model.py +610 -0
- withanyone/flux/modules/__pycache__/autoencoder.cpython-310.pyc +0 -0
- withanyone/flux/modules/__pycache__/conditioner.cpython-310.pyc +0 -0
- withanyone/flux/modules/__pycache__/layers.cpython-310.pyc +0 -0
- withanyone/flux/modules/autoencoder.py +327 -0
- withanyone/flux/modules/conditioner.py +53 -0
- withanyone/flux/modules/layers.py +530 -0
- withanyone/flux/pipeline.py +406 -0
- withanyone/flux/sampling.py +171 -0
- withanyone/flux/util.py +518 -0
- withanyone/utils/convert_yaml_to_args_file.py +22 -0
.gitignore
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# Ruff stuff:
|
| 171 |
+
.ruff_cache/
|
| 172 |
+
|
| 173 |
+
# PyPI configuration file
|
| 174 |
+
.pypirc
|
| 175 |
+
|
| 176 |
+
# User config files
|
| 177 |
+
.vscode/
|
| 178 |
+
output/
|
| 179 |
+
|
| 180 |
+
# ckpt
|
| 181 |
+
*.bin
|
| 182 |
+
*.pt
|
| 183 |
+
*.pth
|
| 184 |
+
ckpts/
|
| 185 |
+
ckpt-*
|
| 186 |
+
ckpts/*
|
| 187 |
+
|
| 188 |
+
# legacy code
|
| 189 |
+
legacy/
|
| 190 |
+
legacy/*
|
| 191 |
+
|
| 192 |
+
# wandb
|
| 193 |
+
wandb/
|
| 194 |
+
wandb/*
|
| 195 |
+
|
| 196 |
+
# arcface models
|
| 197 |
+
models/
|
| 198 |
+
|
| 199 |
+
# debug
|
| 200 |
+
debug*
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
data_single/
|
| 204 |
+
data_single_10/
|
| 205 |
+
lora_attampt/
|
| 206 |
+
lora_attampt/*
|
| 207 |
+
|
| 208 |
+
*.safetensors
|
| 209 |
+
*.ckpt
|
| 210 |
+
|
| 211 |
+
.output/
|
| 212 |
+
for_bbox/
|
| 213 |
+
|
| 214 |
+
# data
|
| 215 |
+
data/
|
| 216 |
+
datasets/
|
| 217 |
+
|
| 218 |
+
nohup.out
|
| 219 |
+
|
| 220 |
+
10**
|
| 221 |
+
temp_generated.png
|
| 222 |
+
|
| 223 |
+
facenet_pytorch/
|
| 224 |
+
facenet_pytorch/*
|
| 225 |
+
|
| 226 |
+
# AdaFace/
|
| 227 |
+
# AdaFace/*
|
| 228 |
+
|
| 229 |
+
pretrained/
|
| 230 |
+
|
| 231 |
+
git_backup/
|
| 232 |
+
git_backup/*
|
.gitmodules
ADDED
|
File without changes
|
LICENSE
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FLUX.1 [dev] Non-Commercial License v1.1.1
|
| 2 |
+
|
| 3 |
+
Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models and models denoted as FLUX.1 [dev], including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, and FLUX.1 Kontext [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). Note that we may also make available certain elements of what is included in the definition of “FLUX.1 [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements.
|
| 4 |
+
|
| 5 |
+
By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
- a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
|
| 9 |
+
- b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
|
| 10 |
+
- c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX.1 [dev] Model, Derivatives, or FLUX Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune or distill other models for commercial use, in each case is not a Non-Commercial Purpose.
|
| 11 |
+
- d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
|
| 12 |
+
- e. “you” or “your” means the individual or entity entering into this License with Company.
|
| 13 |
+
|
| 14 |
+
2. License Grant.
|
| 15 |
+
- a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein regarding the FLUX.1 [dev] Model also apply to any Derivative you create or that are created on your behalf.
|
| 16 |
+
- b. Non-Commercial Use Only. You may only access, use, Distribute, or create Derivatives of the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If you want to use a FLUX.1 [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license.
|
| 17 |
+
- c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
|
| 18 |
+
- d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model or the FLUX.1 Kontext [dev] Model.
|
| 19 |
+
- e. You may access, use, Distribute, or create Output of the FLUX.1 [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX.1 [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX.1 [dev] Model (“FLUX Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law.
|
| 20 |
+
|
| 21 |
+
3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
|
| 22 |
+
- a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
|
| 23 |
+
- b. you must prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
|
| 24 |
+
|
| 25 |
+
“The FLUX.1 [dev] Model is licensed by Black Forest Labs Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs Inc.
|
| 26 |
+
IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
|
| 27 |
+
|
| 28 |
+
- c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
|
| 29 |
+
|
| 30 |
+
4. Restrictions. You will not, and will not permit, assist or cause any third party to
|
| 31 |
+
- a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third party’s legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images; or (viii) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing;
|
| 32 |
+
- b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
|
| 33 |
+
- c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model;
|
| 34 |
+
- d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License;
|
| 35 |
+
- e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
|
| 36 |
+
- f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods.
|
| 37 |
+
|
| 38 |
+
5. DISCLAIMERS. THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
|
| 39 |
+
|
| 40 |
+
6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, FLUX CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
|
| 41 |
+
|
| 42 |
+
7. INDEMNIFICATION. You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any FLUX Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c) your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
|
| 43 |
+
|
| 44 |
+
8. Termination; Survival.
|
| 45 |
+
- a. This License will automatically terminate upon any breach by you of the terms of this License.
|
| 46 |
+
- b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
|
| 47 |
+
- c. If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model, any Derivative, or FLUX Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
|
| 48 |
+
- d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model, any Derivatives, and any FLUX Content Filters. The following sections survive termination of this License 2(c), 2(d), 4-11.
|
| 49 |
+
|
| 50 |
+
9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
|
| 51 |
+
|
| 52 |
+
10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
|
| 53 |
+
|
| 54 |
+
11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter.
|
app.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Fudan University. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Literal, Optional
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image, ImageDraw
|
| 16 |
+
|
| 17 |
+
from withanyone.flux.pipeline import WithAnyonePipeline
|
| 18 |
+
from util import extract_moref, face_preserving_resize
|
| 19 |
+
import insightface
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def captioner(prompt: str, num_person = 1) -> List[List[float]]:
|
| 23 |
+
# use random choose for testing
|
| 24 |
+
# within 512
|
| 25 |
+
if num_person == 1:
|
| 26 |
+
bbox_choices = [
|
| 27 |
+
# expanded, centered and quadrant placements
|
| 28 |
+
[96, 96, 288, 288],
|
| 29 |
+
[128, 128, 320, 320],
|
| 30 |
+
[160, 96, 352, 288],
|
| 31 |
+
[96, 160, 288, 352],
|
| 32 |
+
[208, 96, 400, 288],
|
| 33 |
+
[96, 208, 288, 400],
|
| 34 |
+
[192, 160, 368, 336],
|
| 35 |
+
[64, 128, 224, 320],
|
| 36 |
+
[288, 128, 448, 320],
|
| 37 |
+
[128, 256, 320, 448],
|
| 38 |
+
[80, 80, 240, 272],
|
| 39 |
+
[196, 196, 380, 380],
|
| 40 |
+
# originals
|
| 41 |
+
[100, 100, 300, 300],
|
| 42 |
+
[150, 50, 450, 350],
|
| 43 |
+
[200, 100, 500, 400],
|
| 44 |
+
[250, 150, 512, 450],
|
| 45 |
+
]
|
| 46 |
+
return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
|
| 47 |
+
elif num_person == 2:
|
| 48 |
+
# realistic side-by-side rows (no vertical stacks or diagonals)
|
| 49 |
+
bbox_choices = [
|
| 50 |
+
[[64, 112, 224, 304], [288, 112, 448, 304]],
|
| 51 |
+
[[48, 128, 208, 320], [304, 128, 464, 320]],
|
| 52 |
+
[[32, 144, 192, 336], [320, 144, 480, 336]],
|
| 53 |
+
[[80, 96, 240, 288], [272, 96, 432, 288]],
|
| 54 |
+
[[80, 160, 240, 352], [272, 160, 432, 352]],
|
| 55 |
+
[[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
|
| 56 |
+
[[96, 160, 256, 352], [288, 160, 448, 352]],
|
| 57 |
+
[[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
|
| 58 |
+
[[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
|
| 59 |
+
[[48, 120, 232, 328], [280, 120, 464, 328]],
|
| 60 |
+
[[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
|
| 61 |
+
[[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
|
| 62 |
+
[[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
|
| 63 |
+
[[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
|
| 64 |
+
[[80, 64, 240, 256], [272, 64, 432, 256]], # top row
|
| 65 |
+
[[96, 176, 256, 368], [288, 176, 448, 368]],
|
| 66 |
+
]
|
| 67 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 68 |
+
|
| 69 |
+
elif num_person == 3:
|
| 70 |
+
# Non-overlapping 3-person layouts within 512x512
|
| 71 |
+
bbox_choices = [
|
| 72 |
+
[[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
|
| 73 |
+
[[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
|
| 74 |
+
[[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
|
| 75 |
+
[[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
|
| 76 |
+
[[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
|
| 77 |
+
[[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
|
| 78 |
+
[[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
|
| 79 |
+
[[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
|
| 80 |
+
]
|
| 81 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 82 |
+
elif num_person == 4:
|
| 83 |
+
# Non-overlapping 4-person layouts within 512x512
|
| 84 |
+
bbox_choices = [
|
| 85 |
+
[[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
|
| 86 |
+
[[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
|
| 87 |
+
[[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
|
| 88 |
+
[[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
|
| 89 |
+
[[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
|
| 90 |
+
[[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
|
| 91 |
+
[[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
|
| 92 |
+
[[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
|
| 93 |
+
]
|
| 94 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class FaceExtractor:
|
| 100 |
+
def __init__(self, model_path="./"):
|
| 101 |
+
try:
|
| 102 |
+
self.model = insightface.app.FaceAnalysis(name = "antelopev2", root=model_path, providers=['CUDAExecutionProvider'])
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"Error loading insightface model: {e}. There might be an issue with the directory structure. Trying to fix it...")
|
| 105 |
+
antelopev2_nested_path = os.path.join(model_path, "models", "antelopev2", "antelopev2")
|
| 106 |
+
print(f"Checking for nested path: {antelopev2_nested_path}")
|
| 107 |
+
if os.path.exists(antelopev2_nested_path):
|
| 108 |
+
import subprocess
|
| 109 |
+
print("Detected nested antelopev2 directory, fixing directory structure...")
|
| 110 |
+
# Change to the model_path directory to execute commands
|
| 111 |
+
current_dir = os.getcwd()
|
| 112 |
+
os.chdir(model_path)
|
| 113 |
+
# Execute the commands as specified by the user
|
| 114 |
+
subprocess.run(["mv", "models/antelopev2/", "models/antelopev2_"])
|
| 115 |
+
subprocess.run(["mv", "models/antelopev2_/antelopev2/", "models/antelopev2/"])
|
| 116 |
+
# Return to the original directory
|
| 117 |
+
os.chdir(current_dir)
|
| 118 |
+
print("Directory structure fixed.")
|
| 119 |
+
self.model = insightface.app.FaceAnalysis(name="antelopev2", root="./")
|
| 120 |
+
self.model.prepare(ctx_id=0)
|
| 121 |
+
|
| 122 |
+
def extract(self, image: Image.Image):
|
| 123 |
+
"""Extract single face and embedding from an image"""
|
| 124 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 125 |
+
res = self.model.get(image_np)
|
| 126 |
+
if len(res) == 0:
|
| 127 |
+
return None, None
|
| 128 |
+
res = res[0]
|
| 129 |
+
bbox = res["bbox"]
|
| 130 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 131 |
+
return moref[0], res["embedding"]
|
| 132 |
+
|
| 133 |
+
def extract_refs(self, image: Image.Image):
|
| 134 |
+
"""Extract multiple faces and embeddings from an image"""
|
| 135 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 136 |
+
res = self.model.get(image_np)
|
| 137 |
+
if len(res) == 0:
|
| 138 |
+
return None, None, None
|
| 139 |
+
ref_imgs = []
|
| 140 |
+
arcface_embeddings = []
|
| 141 |
+
bboxes = []
|
| 142 |
+
for r in res:
|
| 143 |
+
bbox = r["bbox"]
|
| 144 |
+
bboxes.append(bbox)
|
| 145 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 146 |
+
ref_imgs.append(moref[0])
|
| 147 |
+
arcface_embeddings.append(r["embedding"])
|
| 148 |
+
|
| 149 |
+
# Convert bboxes to the correct format
|
| 150 |
+
new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
|
| 151 |
+
return ref_imgs, arcface_embeddings, new_bboxes, new_img
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
|
| 155 |
+
"""Resize bounding box coordinates while preserving aspect ratio"""
|
| 156 |
+
x1, y1, x2, y2 = bbox
|
| 157 |
+
|
| 158 |
+
# Calculate scaling factors
|
| 159 |
+
width_scale = new_width / ori_width
|
| 160 |
+
height_scale = new_height / ori_height
|
| 161 |
+
|
| 162 |
+
# Use minimum scaling factor to preserve aspect ratio
|
| 163 |
+
min_scale = min(width_scale, height_scale)
|
| 164 |
+
|
| 165 |
+
# Calculate offsets for centering the scaled box
|
| 166 |
+
width_offset = (new_width - ori_width * min_scale) / 2
|
| 167 |
+
height_offset = (new_height - ori_height * min_scale) / 2
|
| 168 |
+
|
| 169 |
+
# Scale and adjust coordinates
|
| 170 |
+
new_x1 = int(x1 * min_scale + width_offset)
|
| 171 |
+
new_y1 = int(y1 * min_scale + height_offset)
|
| 172 |
+
new_x2 = int(x2 * min_scale + width_offset)
|
| 173 |
+
new_y2 = int(y2 * min_scale + height_offset)
|
| 174 |
+
|
| 175 |
+
return [new_x1, new_y1, new_x2, new_y2]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def draw_bboxes_on_image(image, bboxes):
|
| 179 |
+
"""Draw bounding boxes on image for visualization"""
|
| 180 |
+
if bboxes is None:
|
| 181 |
+
return image
|
| 182 |
+
|
| 183 |
+
# Create a copy to draw on
|
| 184 |
+
img_draw = image.copy()
|
| 185 |
+
draw = ImageDraw.Draw(img_draw)
|
| 186 |
+
|
| 187 |
+
# Draw each bbox with a different color
|
| 188 |
+
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
|
| 189 |
+
|
| 190 |
+
for i, bbox in enumerate(bboxes):
|
| 191 |
+
color = colors[i % len(colors)]
|
| 192 |
+
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
| 193 |
+
# Draw rectangle
|
| 194 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
|
| 195 |
+
# Draw label
|
| 196 |
+
draw.text((x1, y1-15), f"Face {i+1}", fill=color)
|
| 197 |
+
|
| 198 |
+
return img_draw
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def create_demo(
|
| 202 |
+
model_type: str = "flux-dev",
|
| 203 |
+
ipa_path: str = "./ckpt/ipa.safetensors",
|
| 204 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 205 |
+
offload: bool = False,
|
| 206 |
+
lora_rank: int = 64,
|
| 207 |
+
additional_lora_ckpt: Optional[str] = None,
|
| 208 |
+
lora_scale: float = 1.0,
|
| 209 |
+
clip_path: str = "openai/clip-vit-large-patch14",
|
| 210 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders",
|
| 211 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev",
|
| 212 |
+
):
|
| 213 |
+
|
| 214 |
+
face_extractor = FaceExtractor()
|
| 215 |
+
# Initialize pipeline and face extractor
|
| 216 |
+
pipeline = WithAnyonePipeline(
|
| 217 |
+
model_type,
|
| 218 |
+
ipa_path,
|
| 219 |
+
device,
|
| 220 |
+
offload,
|
| 221 |
+
only_lora=True,
|
| 222 |
+
no_lora=True,
|
| 223 |
+
lora_rank=lora_rank,
|
| 224 |
+
additional_lora_ckpt=additional_lora_ckpt,
|
| 225 |
+
lora_weight=lora_scale,
|
| 226 |
+
face_extractor=face_extractor,
|
| 227 |
+
clip_path=clip_path,
|
| 228 |
+
t5_path=t5_path,
|
| 229 |
+
flux_path=flux_path,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Add project badges
|
| 235 |
+
# badges_text = r"""
|
| 236 |
+
# <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
| 237 |
+
# <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
|
| 238 |
+
# <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
|
| 239 |
+
# <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
|
| 240 |
+
# </div>
|
| 241 |
+
# """.strip()
|
| 242 |
+
|
| 243 |
+
def parse_bboxes(bbox_text):
|
| 244 |
+
"""Parse bounding box text input"""
|
| 245 |
+
if not bbox_text or bbox_text.strip() == "":
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
try:
|
| 249 |
+
bboxes = []
|
| 250 |
+
lines = bbox_text.strip().split("\n")
|
| 251 |
+
for line in lines:
|
| 252 |
+
if not line.strip():
|
| 253 |
+
continue
|
| 254 |
+
coords = [float(x) for x in line.strip().split(",")]
|
| 255 |
+
if len(coords) != 4:
|
| 256 |
+
raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
|
| 257 |
+
bboxes.append(coords)
|
| 258 |
+
# print(f"\nParsed bboxes: {bboxes}\n")
|
| 259 |
+
return bboxes
|
| 260 |
+
except Exception as e:
|
| 261 |
+
raise gr.Error(f"Invalid bbox format: {e}")
|
| 262 |
+
|
| 263 |
+
def extract_from_multi_person(multi_person_image):
|
| 264 |
+
"""Extract references and bboxes from a multi-person image"""
|
| 265 |
+
if multi_person_image is None:
|
| 266 |
+
return None, None, None, None
|
| 267 |
+
|
| 268 |
+
# Convert from numpy to PIL if needed
|
| 269 |
+
if isinstance(multi_person_image, np.ndarray):
|
| 270 |
+
multi_person_image = Image.fromarray(multi_person_image)
|
| 271 |
+
|
| 272 |
+
ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(multi_person_image)
|
| 273 |
+
|
| 274 |
+
if ref_imgs is None or len(ref_imgs) == 0:
|
| 275 |
+
raise gr.Error("No faces detected in the multi-person image")
|
| 276 |
+
|
| 277 |
+
# Limit to max 4 faces
|
| 278 |
+
ref_imgs = ref_imgs[:4]
|
| 279 |
+
arcface_embeddings = arcface_embeddings[:4]
|
| 280 |
+
bboxes = bboxes[:4]
|
| 281 |
+
|
| 282 |
+
# Create visualization with bboxes
|
| 283 |
+
viz_image = draw_bboxes_on_image(new_img, bboxes)
|
| 284 |
+
|
| 285 |
+
# Format bboxes as string for display
|
| 286 |
+
bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
|
| 287 |
+
|
| 288 |
+
return ref_imgs, arcface_embeddings, bboxes, viz_image
|
| 289 |
+
|
| 290 |
+
def process_and_generate(
|
| 291 |
+
prompt,
|
| 292 |
+
width, height,
|
| 293 |
+
guidance, num_steps, seed,
|
| 294 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 295 |
+
manual_bboxes_text,
|
| 296 |
+
multi_person_image,
|
| 297 |
+
# use_text_prompt,
|
| 298 |
+
# id_weight,
|
| 299 |
+
siglip_weight
|
| 300 |
+
):
|
| 301 |
+
# Collect and validate reference images
|
| 302 |
+
ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
|
| 303 |
+
|
| 304 |
+
if not ref_images:
|
| 305 |
+
raise gr.Error("At least one reference image is required")
|
| 306 |
+
|
| 307 |
+
# Process reference images to extract face and embeddings
|
| 308 |
+
ref_imgs = []
|
| 309 |
+
arcface_embeddings = []
|
| 310 |
+
|
| 311 |
+
# Modified bbox handling logic
|
| 312 |
+
if multi_person_image is not None:
|
| 313 |
+
# Extract from multi-person image mode
|
| 314 |
+
extracted_refs, extracted_embeddings, bboxes_, _ = extract_from_multi_person(multi_person_image)
|
| 315 |
+
if extracted_refs is None:
|
| 316 |
+
raise gr.Error("Failed to extract faces from the multi-person image")
|
| 317 |
+
|
| 318 |
+
print("bboxes from multi-person image:", bboxes_)
|
| 319 |
+
# need to resize bboxes from 512 512 to width height
|
| 320 |
+
bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
|
| 321 |
+
|
| 322 |
+
else:
|
| 323 |
+
# Parse manual bboxes
|
| 324 |
+
bboxes_ = parse_bboxes(manual_bboxes_text)
|
| 325 |
+
|
| 326 |
+
# If no manual bboxes provided, use automatic captioner
|
| 327 |
+
if bboxes_ is None:
|
| 328 |
+
print("No multi-person image or manual bboxes provided. Using automatic captioner.")
|
| 329 |
+
# Generate automatic bboxes based on image dimensions
|
| 330 |
+
bboxes__ = captioner(prompt, num_person=len(ref_images))
|
| 331 |
+
# resize to width height
|
| 332 |
+
bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes__]
|
| 333 |
+
print("Automatically generated bboxes:", bboxes_)
|
| 334 |
+
|
| 335 |
+
bboxes = [bboxes_] # 伪装batch输入
|
| 336 |
+
# else:
|
| 337 |
+
# Manual mode: process each reference image
|
| 338 |
+
for img in ref_images:
|
| 339 |
+
if isinstance(img, np.ndarray):
|
| 340 |
+
img = Image.fromarray(img)
|
| 341 |
+
|
| 342 |
+
ref_img, embedding = face_extractor.extract(img)
|
| 343 |
+
if ref_img is None or embedding is None:
|
| 344 |
+
raise gr.Error("Failed to extract face from one of the reference images")
|
| 345 |
+
|
| 346 |
+
ref_imgs.append(ref_img)
|
| 347 |
+
arcface_embeddings.append(embedding)
|
| 348 |
+
|
| 349 |
+
# pad arcface_embeddings to 4 if less than 4
|
| 350 |
+
# while len(arcface_embeddings) < 4:
|
| 351 |
+
# arcface_embeddings.append(np.zeros_like(arcface_embeddings[0]))
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if bboxes is None:
|
| 355 |
+
raise gr.Error("Either provide manual bboxes or a multi-person image for bbox extraction")
|
| 356 |
+
|
| 357 |
+
if len(bboxes[0]) != len(ref_imgs):
|
| 358 |
+
raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
|
| 359 |
+
|
| 360 |
+
# Convert arcface embeddings to tensor
|
| 361 |
+
arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
|
| 362 |
+
arcface_embeddings = torch.stack(arcface_embeddings).to(device)
|
| 363 |
+
|
| 364 |
+
# Generate image
|
| 365 |
+
final_prompt = prompt
|
| 366 |
+
|
| 367 |
+
print(f"Generating image of size {width}x{height} with bboxes: {bboxes} ")
|
| 368 |
+
|
| 369 |
+
if seed < 0:
|
| 370 |
+
seed = np.random.randint(0, 1000000)
|
| 371 |
+
|
| 372 |
+
image_gen = pipeline(
|
| 373 |
+
prompt=final_prompt,
|
| 374 |
+
width=width,
|
| 375 |
+
height=height,
|
| 376 |
+
guidance=guidance,
|
| 377 |
+
num_steps=num_steps,
|
| 378 |
+
seed=seed if seed > 0 else None,
|
| 379 |
+
ref_imgs=ref_imgs,
|
| 380 |
+
arcface_embeddings=arcface_embeddings,
|
| 381 |
+
bboxes=bboxes,
|
| 382 |
+
id_weight = 1 - siglip_weight,
|
| 383 |
+
siglip_weight=siglip_weight,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Save temp file for download
|
| 387 |
+
temp_path = "temp_generated.png"
|
| 388 |
+
image_gen.save(temp_path)
|
| 389 |
+
|
| 390 |
+
# draw bboxes on the generated image for debug
|
| 391 |
+
debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
|
| 392 |
+
|
| 393 |
+
return image_gen, debug_face, temp_path
|
| 394 |
+
|
| 395 |
+
def update_bbox_display(multi_person_image):
|
| 396 |
+
if multi_person_image is None:
|
| 397 |
+
return None, gr.update(visible=True), gr.update(visible=False)
|
| 398 |
+
|
| 399 |
+
try:
|
| 400 |
+
_, _, _, viz_image = extract_from_multi_person(multi_person_image)
|
| 401 |
+
return viz_image, gr.update(visible=False), gr.update(visible=True)
|
| 402 |
+
except Exception as e:
|
| 403 |
+
return None, gr.update(visible=True), gr.update(visible=False)
|
| 404 |
+
|
| 405 |
+
# Create Gradio interface
|
| 406 |
+
with gr.Blocks() as demo:
|
| 407 |
+
gr.Markdown("# WithAnyone Demo")
|
| 408 |
+
# gr.Markdown(badges_text)
|
| 409 |
+
|
| 410 |
+
with gr.Row():
|
| 411 |
+
|
| 412 |
+
with gr.Column():
|
| 413 |
+
# Input controls
|
| 414 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 415 |
+
with gr.Row():
|
| 416 |
+
with gr.Column():
|
| 417 |
+
siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Spiritual Resemblance <--> Formal Resemblance")
|
| 418 |
+
with gr.Row():
|
| 419 |
+
prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
|
| 420 |
+
# use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
with gr.Row():
|
| 424 |
+
# Image generation settings
|
| 425 |
+
with gr.Column():
|
| 426 |
+
width = gr.Slider(512, 1024, 768, step=64, label="Generation Width")
|
| 427 |
+
height = gr.Slider(512, 1024, 768, step=64, label="Generation Height")
|
| 428 |
+
|
| 429 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 430 |
+
with gr.Row():
|
| 431 |
+
num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
|
| 432 |
+
guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
|
| 433 |
+
seed = gr.Number(-1, label="Seed (-1 for random)")
|
| 434 |
+
|
| 435 |
+
# start_at = gr.Slider(0, 50, 0, step=1, label="Start Identity at Step")
|
| 436 |
+
# end_at = gr.Number(-1, label="End Identity at Step (-1 for last)")
|
| 437 |
+
|
| 438 |
+
# with gr.Row():
|
| 439 |
+
# # skip_every = gr.Number(-1, label="Skip Identity Every N Steps (-1 for no skip)")
|
| 440 |
+
|
| 441 |
+
# siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Siglip Weight")
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
with gr.Row():
|
| 445 |
+
with gr.Column():
|
| 446 |
+
# Reference image inputs
|
| 447 |
+
gr.Markdown("### Face References (1-4 required)")
|
| 448 |
+
ref_img1 = gr.Image(label="Reference 1", type="pil")
|
| 449 |
+
ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
|
| 450 |
+
ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
|
| 451 |
+
ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
|
| 452 |
+
|
| 453 |
+
with gr.Column():
|
| 454 |
+
# Bounding box inputs
|
| 455 |
+
gr.Markdown("### Mask Configuration (Option 1: Automatic)")
|
| 456 |
+
multi_person_image = gr.Image(label="Multi-person image (for automatic bbox extraction)", type="pil")
|
| 457 |
+
bbox_preview = gr.Image(label="Detected Faces", type="pil")
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
gr.Markdown("### Mask Configuration (Option 2: Manual)")
|
| 461 |
+
manual_bbox_input = gr.Textbox(
|
| 462 |
+
label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
|
| 463 |
+
lines=4,
|
| 464 |
+
placeholder="100,100,200,200\n300,100,400,200"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# generate_btn = gr.Button("Generate", variant="primary")
|
| 472 |
+
|
| 473 |
+
with gr.Column():
|
| 474 |
+
# Output display
|
| 475 |
+
output_image = gr.Image(label="Generated Image")
|
| 476 |
+
debug_face = gr.Image(label="Debug. Faces are expected to be generated in these boxes")
|
| 477 |
+
download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
|
| 478 |
+
|
| 479 |
+
# Examples section
|
| 480 |
+
with gr.Row():
|
| 481 |
+
|
| 482 |
+
gr.Markdown("""
|
| 483 |
+
# Example Configurations
|
| 484 |
+
|
| 485 |
+
### Tips for Better Results
|
| 486 |
+
Be prepared for the first few runs as it may not be very satisfying.
|
| 487 |
+
|
| 488 |
+
- Provide detailed prompts describing the identity. WithAnyone is "controllable", so it needs more information to be controlled. Here are something that might go wrong if not specified:
|
| 489 |
+
- Skin color (generally the race is fine, but for asain descent, if not specified, it may generate darker skin tone);
|
| 490 |
+
- Age (e.g., intead of "a man", try "a young man". If not specified, it may generate an older figure);
|
| 491 |
+
- Body build;
|
| 492 |
+
- Hairstyle;
|
| 493 |
+
- Accessories (glasses, hats, earrings, etc.);
|
| 494 |
+
- Makeup
|
| 495 |
+
- Use the slider to balance between "Resemblance in Spirit" and "Resemblance in Form" according to your needs. If you want to preserve more details in the reference image, move the slider to the right; if you want more freedom and creativity, move it to the left.
|
| 496 |
+
- Try it with LoRAs from community. They are usually fantastic.
|
| 497 |
+
""")
|
| 498 |
+
with gr.Row():
|
| 499 |
+
examples = gr.Examples(
|
| 500 |
+
examples=[
|
| 501 |
+
[
|
| 502 |
+
"a highly detailed portrait of a woman shown in profile. Her long, dark hair flows elegantly, intricately decorated with an abundant array of colorful flowers—ranging from soft light pinks and vibrant light oranges to delicate greyish blues—and lush green leaves, giving a sense of natural beauty and charm. Her bright blue eyes are striking, and her lips are painted a vivid red, adding to her alluring appearance. She is clad in an ornate garment with intricate floral patterns in warm hues like pink and orange, featuring exquisite detailing that speaks of fine craftsmanship. Around her neck, she wears a decorative choker with intricate designs, and dangling from her ears are beautiful blue teardrop earrings that catch the light. The background is filled with a profusion of flowers in various shades, creating a rich, vibrant, and romantic atmosphere that complements the woman's elegant and enchanting look.", # prompt
|
| 503 |
+
1024, 1024, # width, height
|
| 504 |
+
4.0, 25, 42, # guidance, num_steps, seed
|
| 505 |
+
"assets/ref1.jpg", None, None, None, # ref images
|
| 506 |
+
"240,180,540,500", None, # manual_bbox_input, multi_person_image
|
| 507 |
+
# True, # use_text_prompt
|
| 508 |
+
0.0, # siglip_weight
|
| 509 |
+
],
|
| 510 |
+
[
|
| 511 |
+
"High resolution anfd extremely detailed image of two elegant ladies enjoying a serene afternoon in a quaint Parisian café. They both wear fashionable trench coats and stylish berets, exuding an air of sophistication. One lady gently sips on a cappuccino, while her companion reads an intriguing novel with a subtle smile. The café is framed by charming antique furniture and vintage posters adorning the walls. Soft, warm light filters through a window, casting delicate shadows and creating a cozy, inviting atmosphere. Captured from a slightly elevated angle, the composition highlights the warmth of the scene in a gentle watercolor illustrative style. ", # prompt
|
| 512 |
+
1024, 1024, # width, height
|
| 513 |
+
4.0, 25, 42, # guidance, num_steps, seed
|
| 514 |
+
"assets/ref1.jpg", "assets/ref2.jpg", None, None, # ref images
|
| 515 |
+
"248,172,428,498\n554,128,728,464", None, # manual_bbox_input, multi_person_image
|
| 516 |
+
# True, # use_text_prompt
|
| 517 |
+
0.0, # siglip_weight
|
| 518 |
+
]
|
| 519 |
+
],
|
| 520 |
+
inputs=[
|
| 521 |
+
prompt, width, height, guidance, num_steps, seed,
|
| 522 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 523 |
+
manual_bbox_input, multi_person_image,
|
| 524 |
+
siglip_weight
|
| 525 |
+
],
|
| 526 |
+
label="Click to load example configurations"
|
| 527 |
+
)
|
| 528 |
+
# Set up event handlers
|
| 529 |
+
multi_person_image.change(
|
| 530 |
+
fn=update_bbox_display,
|
| 531 |
+
inputs=[multi_person_image],
|
| 532 |
+
outputs=[bbox_preview, manual_bbox_input, bbox_preview]
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
generate_btn.click(
|
| 536 |
+
fn=process_and_generate,
|
| 537 |
+
inputs=[
|
| 538 |
+
prompt, width, height, guidance, num_steps, seed,
|
| 539 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 540 |
+
manual_bbox_input, multi_person_image,
|
| 541 |
+
siglip_weight
|
| 542 |
+
],
|
| 543 |
+
outputs=[output_image,debug_face, download_btn]
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
return demo
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
if __name__ == "__main__":
|
| 550 |
+
from transformers import HfArgumentParser
|
| 551 |
+
|
| 552 |
+
@dataclasses.dataclass
|
| 553 |
+
class AppArgs:
|
| 554 |
+
model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
|
| 555 |
+
device: Literal["cuda", "cpu"] = (
|
| 556 |
+
"cuda" if torch.cuda.is_available()
|
| 557 |
+
else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 558 |
+
else "cpu"
|
| 559 |
+
)
|
| 560 |
+
offload: bool = False
|
| 561 |
+
lora_rank: int = 64
|
| 562 |
+
port: int = 7860
|
| 563 |
+
additional_lora: str = None
|
| 564 |
+
lora_scale: float = 1.0
|
| 565 |
+
ipa_path: str = "WithAnyone/WithAnyone"
|
| 566 |
+
clip_path: str = "openai/clip-vit-large-patch14"
|
| 567 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders"
|
| 568 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev"
|
| 569 |
+
|
| 570 |
+
parser = HfArgumentParser([AppArgs])
|
| 571 |
+
args = parser.parse_args_into_dataclasses()[0]
|
| 572 |
+
|
| 573 |
+
demo = create_demo(
|
| 574 |
+
args.model_type,
|
| 575 |
+
args.ipa_path,
|
| 576 |
+
args.device,
|
| 577 |
+
args.offload,
|
| 578 |
+
args.lora_rank,
|
| 579 |
+
args.additional_lora,
|
| 580 |
+
args.lora_scale,
|
| 581 |
+
args.clip_path,
|
| 582 |
+
args.t5_path,
|
| 583 |
+
args.flux_path,
|
| 584 |
+
)
|
| 585 |
+
demo.launch(server_port=args.port)
|
| 586 |
+
|
gradio_app.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Fudan University. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Literal, Optional
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image, ImageDraw
|
| 16 |
+
|
| 17 |
+
from withanyone.flux.pipeline import WithAnyonePipeline
|
| 18 |
+
from util import extract_moref, face_preserving_resize
|
| 19 |
+
import insightface
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def captioner(prompt: str, num_person = 1) -> List[List[float]]:
|
| 23 |
+
# use random choose for testing
|
| 24 |
+
# within 512
|
| 25 |
+
if num_person == 1:
|
| 26 |
+
bbox_choices = [
|
| 27 |
+
# expanded, centered and quadrant placements
|
| 28 |
+
[96, 96, 288, 288],
|
| 29 |
+
[128, 128, 320, 320],
|
| 30 |
+
[160, 96, 352, 288],
|
| 31 |
+
[96, 160, 288, 352],
|
| 32 |
+
[208, 96, 400, 288],
|
| 33 |
+
[96, 208, 288, 400],
|
| 34 |
+
[192, 160, 368, 336],
|
| 35 |
+
[64, 128, 224, 320],
|
| 36 |
+
[288, 128, 448, 320],
|
| 37 |
+
[128, 256, 320, 448],
|
| 38 |
+
[80, 80, 240, 272],
|
| 39 |
+
[196, 196, 380, 380],
|
| 40 |
+
# originals
|
| 41 |
+
[100, 100, 300, 300],
|
| 42 |
+
[150, 50, 450, 350],
|
| 43 |
+
[200, 100, 500, 400],
|
| 44 |
+
[250, 150, 512, 450],
|
| 45 |
+
]
|
| 46 |
+
return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
|
| 47 |
+
elif num_person == 2:
|
| 48 |
+
# realistic side-by-side rows (no vertical stacks or diagonals)
|
| 49 |
+
bbox_choices = [
|
| 50 |
+
[[64, 112, 224, 304], [288, 112, 448, 304]],
|
| 51 |
+
[[48, 128, 208, 320], [304, 128, 464, 320]],
|
| 52 |
+
[[32, 144, 192, 336], [320, 144, 480, 336]],
|
| 53 |
+
[[80, 96, 240, 288], [272, 96, 432, 288]],
|
| 54 |
+
[[80, 160, 240, 352], [272, 160, 432, 352]],
|
| 55 |
+
[[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
|
| 56 |
+
[[96, 160, 256, 352], [288, 160, 448, 352]],
|
| 57 |
+
[[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
|
| 58 |
+
[[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
|
| 59 |
+
[[48, 120, 232, 328], [280, 120, 464, 328]],
|
| 60 |
+
[[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
|
| 61 |
+
[[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
|
| 62 |
+
[[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
|
| 63 |
+
[[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
|
| 64 |
+
[[80, 64, 240, 256], [272, 64, 432, 256]], # top row
|
| 65 |
+
[[96, 176, 256, 368], [288, 176, 448, 368]],
|
| 66 |
+
]
|
| 67 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 68 |
+
|
| 69 |
+
elif num_person == 3:
|
| 70 |
+
# Non-overlapping 3-person layouts within 512x512
|
| 71 |
+
bbox_choices = [
|
| 72 |
+
[[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
|
| 73 |
+
[[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
|
| 74 |
+
[[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
|
| 75 |
+
[[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
|
| 76 |
+
[[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
|
| 77 |
+
[[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
|
| 78 |
+
[[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
|
| 79 |
+
[[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
|
| 80 |
+
]
|
| 81 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 82 |
+
elif num_person == 4:
|
| 83 |
+
# Non-overlapping 4-person layouts within 512x512
|
| 84 |
+
bbox_choices = [
|
| 85 |
+
[[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
|
| 86 |
+
[[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
|
| 87 |
+
[[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
|
| 88 |
+
[[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
|
| 89 |
+
[[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
|
| 90 |
+
[[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
|
| 91 |
+
[[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
|
| 92 |
+
[[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
|
| 93 |
+
]
|
| 94 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class FaceExtractor:
|
| 100 |
+
def __init__(self, model_path="./"):
|
| 101 |
+
self.model = insightface.app.FaceAnalysis(name="antelopev2", root="./")
|
| 102 |
+
self.model.prepare(ctx_id=0)
|
| 103 |
+
|
| 104 |
+
def extract(self, image: Image.Image):
|
| 105 |
+
"""Extract single face and embedding from an image"""
|
| 106 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 107 |
+
res = self.model.get(image_np)
|
| 108 |
+
if len(res) == 0:
|
| 109 |
+
return None, None
|
| 110 |
+
res = res[0]
|
| 111 |
+
bbox = res["bbox"]
|
| 112 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 113 |
+
return moref[0], res["embedding"]
|
| 114 |
+
|
| 115 |
+
def extract_refs(self, image: Image.Image):
|
| 116 |
+
"""Extract multiple faces and embeddings from an image"""
|
| 117 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 118 |
+
res = self.model.get(image_np)
|
| 119 |
+
if len(res) == 0:
|
| 120 |
+
return None, None, None
|
| 121 |
+
ref_imgs = []
|
| 122 |
+
arcface_embeddings = []
|
| 123 |
+
bboxes = []
|
| 124 |
+
for r in res:
|
| 125 |
+
bbox = r["bbox"]
|
| 126 |
+
bboxes.append(bbox)
|
| 127 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 128 |
+
ref_imgs.append(moref[0])
|
| 129 |
+
arcface_embeddings.append(r["embedding"])
|
| 130 |
+
|
| 131 |
+
# Convert bboxes to the correct format
|
| 132 |
+
new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
|
| 133 |
+
return ref_imgs, arcface_embeddings, new_bboxes, new_img
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
|
| 137 |
+
"""Resize bounding box coordinates while preserving aspect ratio"""
|
| 138 |
+
x1, y1, x2, y2 = bbox
|
| 139 |
+
|
| 140 |
+
# Calculate scaling factors
|
| 141 |
+
width_scale = new_width / ori_width
|
| 142 |
+
height_scale = new_height / ori_height
|
| 143 |
+
|
| 144 |
+
# Use minimum scaling factor to preserve aspect ratio
|
| 145 |
+
min_scale = min(width_scale, height_scale)
|
| 146 |
+
|
| 147 |
+
# Calculate offsets for centering the scaled box
|
| 148 |
+
width_offset = (new_width - ori_width * min_scale) / 2
|
| 149 |
+
height_offset = (new_height - ori_height * min_scale) / 2
|
| 150 |
+
|
| 151 |
+
# Scale and adjust coordinates
|
| 152 |
+
new_x1 = int(x1 * min_scale + width_offset)
|
| 153 |
+
new_y1 = int(y1 * min_scale + height_offset)
|
| 154 |
+
new_x2 = int(x2 * min_scale + width_offset)
|
| 155 |
+
new_y2 = int(y2 * min_scale + height_offset)
|
| 156 |
+
|
| 157 |
+
return [new_x1, new_y1, new_x2, new_y2]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def draw_bboxes_on_image(image, bboxes):
|
| 161 |
+
"""Draw bounding boxes on image for visualization"""
|
| 162 |
+
if bboxes is None:
|
| 163 |
+
return image
|
| 164 |
+
|
| 165 |
+
# Create a copy to draw on
|
| 166 |
+
img_draw = image.copy()
|
| 167 |
+
draw = ImageDraw.Draw(img_draw)
|
| 168 |
+
|
| 169 |
+
# Draw each bbox with a different color
|
| 170 |
+
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
|
| 171 |
+
|
| 172 |
+
for i, bbox in enumerate(bboxes):
|
| 173 |
+
color = colors[i % len(colors)]
|
| 174 |
+
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
| 175 |
+
# Draw rectangle
|
| 176 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
|
| 177 |
+
# Draw label
|
| 178 |
+
draw.text((x1, y1-15), f"Face {i+1}", fill=color)
|
| 179 |
+
|
| 180 |
+
return img_draw
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def create_demo(
|
| 184 |
+
model_type: str = "flux-dev",
|
| 185 |
+
ipa_path: str = "./ckpt/ipa.safetensors",
|
| 186 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 187 |
+
offload: bool = False,
|
| 188 |
+
lora_rank: int = 64,
|
| 189 |
+
additional_lora_ckpt: Optional[str] = None,
|
| 190 |
+
lora_scale: float = 1.0,
|
| 191 |
+
clip_path: str = "openai/clip-vit-large-patch14",
|
| 192 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders",
|
| 193 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev",
|
| 194 |
+
):
|
| 195 |
+
|
| 196 |
+
face_extractor = FaceExtractor()
|
| 197 |
+
# Initialize pipeline and face extractor
|
| 198 |
+
pipeline = WithAnyonePipeline(
|
| 199 |
+
model_type,
|
| 200 |
+
ipa_path,
|
| 201 |
+
device,
|
| 202 |
+
offload,
|
| 203 |
+
only_lora=True,
|
| 204 |
+
no_lora=True,
|
| 205 |
+
lora_rank=lora_rank,
|
| 206 |
+
additional_lora_ckpt=additional_lora_ckpt,
|
| 207 |
+
lora_weight=lora_scale,
|
| 208 |
+
face_extractor=face_extractor,
|
| 209 |
+
clip_path=clip_path,
|
| 210 |
+
t5_path=t5_path,
|
| 211 |
+
flux_path=flux_path,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# Add project badges
|
| 217 |
+
# badges_text = r"""
|
| 218 |
+
# <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
| 219 |
+
# <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
|
| 220 |
+
# <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
|
| 221 |
+
# <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
|
| 222 |
+
# </div>
|
| 223 |
+
# """.strip()
|
| 224 |
+
|
| 225 |
+
def parse_bboxes(bbox_text):
|
| 226 |
+
"""Parse bounding box text input"""
|
| 227 |
+
if not bbox_text or bbox_text.strip() == "":
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
try:
|
| 231 |
+
bboxes = []
|
| 232 |
+
lines = bbox_text.strip().split("\n")
|
| 233 |
+
for line in lines:
|
| 234 |
+
if not line.strip():
|
| 235 |
+
continue
|
| 236 |
+
coords = [float(x) for x in line.strip().split(",")]
|
| 237 |
+
if len(coords) != 4:
|
| 238 |
+
raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
|
| 239 |
+
bboxes.append(coords)
|
| 240 |
+
# print(f"\nParsed bboxes: {bboxes}\n")
|
| 241 |
+
return bboxes
|
| 242 |
+
except Exception as e:
|
| 243 |
+
raise gr.Error(f"Invalid bbox format: {e}")
|
| 244 |
+
|
| 245 |
+
def extract_from_multi_person(multi_person_image):
|
| 246 |
+
"""Extract references and bboxes from a multi-person image"""
|
| 247 |
+
if multi_person_image is None:
|
| 248 |
+
return None, None, None, None
|
| 249 |
+
|
| 250 |
+
# Convert from numpy to PIL if needed
|
| 251 |
+
if isinstance(multi_person_image, np.ndarray):
|
| 252 |
+
multi_person_image = Image.fromarray(multi_person_image)
|
| 253 |
+
|
| 254 |
+
ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(multi_person_image)
|
| 255 |
+
|
| 256 |
+
if ref_imgs is None or len(ref_imgs) == 0:
|
| 257 |
+
raise gr.Error("No faces detected in the multi-person image")
|
| 258 |
+
|
| 259 |
+
# Limit to max 4 faces
|
| 260 |
+
ref_imgs = ref_imgs[:4]
|
| 261 |
+
arcface_embeddings = arcface_embeddings[:4]
|
| 262 |
+
bboxes = bboxes[:4]
|
| 263 |
+
|
| 264 |
+
# Create visualization with bboxes
|
| 265 |
+
viz_image = draw_bboxes_on_image(new_img, bboxes)
|
| 266 |
+
|
| 267 |
+
# Format bboxes as string for display
|
| 268 |
+
bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
|
| 269 |
+
|
| 270 |
+
return ref_imgs, arcface_embeddings, bboxes, viz_image
|
| 271 |
+
|
| 272 |
+
def process_and_generate(
|
| 273 |
+
prompt,
|
| 274 |
+
width, height,
|
| 275 |
+
guidance, num_steps, seed,
|
| 276 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 277 |
+
manual_bboxes_text,
|
| 278 |
+
multi_person_image,
|
| 279 |
+
# use_text_prompt,
|
| 280 |
+
# id_weight,
|
| 281 |
+
siglip_weight
|
| 282 |
+
):
|
| 283 |
+
# Collect and validate reference images
|
| 284 |
+
ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
|
| 285 |
+
|
| 286 |
+
if not ref_images:
|
| 287 |
+
raise gr.Error("At least one reference image is required")
|
| 288 |
+
|
| 289 |
+
# Process reference images to extract face and embeddings
|
| 290 |
+
ref_imgs = []
|
| 291 |
+
arcface_embeddings = []
|
| 292 |
+
|
| 293 |
+
# Modified bbox handling logic
|
| 294 |
+
if multi_person_image is not None:
|
| 295 |
+
# Extract from multi-person image mode
|
| 296 |
+
extracted_refs, extracted_embeddings, bboxes_, _ = extract_from_multi_person(multi_person_image)
|
| 297 |
+
if extracted_refs is None:
|
| 298 |
+
raise gr.Error("Failed to extract faces from the multi-person image")
|
| 299 |
+
|
| 300 |
+
print("bboxes from multi-person image:", bboxes_)
|
| 301 |
+
# need to resize bboxes from 512 512 to width height
|
| 302 |
+
bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
|
| 303 |
+
|
| 304 |
+
else:
|
| 305 |
+
# Parse manual bboxes
|
| 306 |
+
bboxes_ = parse_bboxes(manual_bboxes_text)
|
| 307 |
+
|
| 308 |
+
# If no manual bboxes provided, use automatic captioner
|
| 309 |
+
if bboxes_ is None:
|
| 310 |
+
print("No multi-person image or manual bboxes provided. Using automatic captioner.")
|
| 311 |
+
# Generate automatic bboxes based on image dimensions
|
| 312 |
+
bboxes__ = captioner(prompt, num_person=len(ref_images))
|
| 313 |
+
# resize to width height
|
| 314 |
+
bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes__]
|
| 315 |
+
print("Automatically generated bboxes:", bboxes_)
|
| 316 |
+
|
| 317 |
+
bboxes = [bboxes_] # 伪装batch输入
|
| 318 |
+
# else:
|
| 319 |
+
# Manual mode: process each reference image
|
| 320 |
+
for img in ref_images:
|
| 321 |
+
if isinstance(img, np.ndarray):
|
| 322 |
+
img = Image.fromarray(img)
|
| 323 |
+
|
| 324 |
+
ref_img, embedding = face_extractor.extract(img)
|
| 325 |
+
if ref_img is None or embedding is None:
|
| 326 |
+
raise gr.Error("Failed to extract face from one of the reference images")
|
| 327 |
+
|
| 328 |
+
ref_imgs.append(ref_img)
|
| 329 |
+
arcface_embeddings.append(embedding)
|
| 330 |
+
|
| 331 |
+
# pad arcface_embeddings to 4 if less than 4
|
| 332 |
+
# while len(arcface_embeddings) < 4:
|
| 333 |
+
# arcface_embeddings.append(np.zeros_like(arcface_embeddings[0]))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
if bboxes is None:
|
| 337 |
+
raise gr.Error("Either provide manual bboxes or a multi-person image for bbox extraction")
|
| 338 |
+
|
| 339 |
+
if len(bboxes[0]) != len(ref_imgs):
|
| 340 |
+
raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
|
| 341 |
+
|
| 342 |
+
# Convert arcface embeddings to tensor
|
| 343 |
+
arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
|
| 344 |
+
arcface_embeddings = torch.stack(arcface_embeddings).to(device)
|
| 345 |
+
|
| 346 |
+
# Generate image
|
| 347 |
+
final_prompt = prompt
|
| 348 |
+
|
| 349 |
+
print(f"Generating image of size {width}x{height} with bboxes: {bboxes} ")
|
| 350 |
+
|
| 351 |
+
if seed < 0:
|
| 352 |
+
seed = np.random.randint(0, 1000000)
|
| 353 |
+
|
| 354 |
+
image_gen = pipeline(
|
| 355 |
+
prompt=final_prompt,
|
| 356 |
+
width=width,
|
| 357 |
+
height=height,
|
| 358 |
+
guidance=guidance,
|
| 359 |
+
num_steps=num_steps,
|
| 360 |
+
seed=seed if seed > 0 else None,
|
| 361 |
+
ref_imgs=ref_imgs,
|
| 362 |
+
arcface_embeddings=arcface_embeddings,
|
| 363 |
+
bboxes=bboxes,
|
| 364 |
+
id_weight = 1 - siglip_weight,
|
| 365 |
+
siglip_weight=siglip_weight,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Save temp file for download
|
| 369 |
+
temp_path = "temp_generated.png"
|
| 370 |
+
image_gen.save(temp_path)
|
| 371 |
+
|
| 372 |
+
# draw bboxes on the generated image for debug
|
| 373 |
+
debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
|
| 374 |
+
|
| 375 |
+
return image_gen, debug_face, temp_path
|
| 376 |
+
|
| 377 |
+
def update_bbox_display(multi_person_image):
|
| 378 |
+
if multi_person_image is None:
|
| 379 |
+
return None, gr.update(visible=True), gr.update(visible=False)
|
| 380 |
+
|
| 381 |
+
try:
|
| 382 |
+
_, _, _, viz_image = extract_from_multi_person(multi_person_image)
|
| 383 |
+
return viz_image, gr.update(visible=False), gr.update(visible=True)
|
| 384 |
+
except Exception as e:
|
| 385 |
+
return None, gr.update(visible=True), gr.update(visible=False)
|
| 386 |
+
|
| 387 |
+
# Create Gradio interface
|
| 388 |
+
with gr.Blocks() as demo:
|
| 389 |
+
gr.Markdown("# WithAnyone Demo")
|
| 390 |
+
# gr.Markdown(badges_text)
|
| 391 |
+
|
| 392 |
+
with gr.Row():
|
| 393 |
+
|
| 394 |
+
with gr.Column():
|
| 395 |
+
# Input controls
|
| 396 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 397 |
+
with gr.Row():
|
| 398 |
+
with gr.Column():
|
| 399 |
+
siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Spiritual Resemblance <--> Formal Resemblance")
|
| 400 |
+
with gr.Row():
|
| 401 |
+
prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
|
| 402 |
+
# use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
with gr.Row():
|
| 406 |
+
# Image generation settings
|
| 407 |
+
with gr.Column():
|
| 408 |
+
width = gr.Slider(512, 1024, 768, step=64, label="Generation Width")
|
| 409 |
+
height = gr.Slider(512, 1024, 768, step=64, label="Generation Height")
|
| 410 |
+
|
| 411 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 412 |
+
with gr.Row():
|
| 413 |
+
num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
|
| 414 |
+
guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
|
| 415 |
+
seed = gr.Number(-1, label="Seed (-1 for random)")
|
| 416 |
+
|
| 417 |
+
# start_at = gr.Slider(0, 50, 0, step=1, label="Start Identity at Step")
|
| 418 |
+
# end_at = gr.Number(-1, label="End Identity at Step (-1 for last)")
|
| 419 |
+
|
| 420 |
+
# with gr.Row():
|
| 421 |
+
# # skip_every = gr.Number(-1, label="Skip Identity Every N Steps (-1 for no skip)")
|
| 422 |
+
|
| 423 |
+
# siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Siglip Weight")
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
with gr.Row():
|
| 427 |
+
with gr.Column():
|
| 428 |
+
# Reference image inputs
|
| 429 |
+
gr.Markdown("### Face References (1-4 required)")
|
| 430 |
+
ref_img1 = gr.Image(label="Reference 1", type="pil")
|
| 431 |
+
ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
|
| 432 |
+
ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
|
| 433 |
+
ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
|
| 434 |
+
|
| 435 |
+
with gr.Column():
|
| 436 |
+
# Bounding box inputs
|
| 437 |
+
gr.Markdown("### Mask Configuration (Option 1: Automatic)")
|
| 438 |
+
multi_person_image = gr.Image(label="Multi-person image (for automatic bbox extraction)", type="pil")
|
| 439 |
+
bbox_preview = gr.Image(label="Detected Faces", type="pil")
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
gr.Markdown("### Mask Configuration (Option 2: Manual)")
|
| 443 |
+
manual_bbox_input = gr.Textbox(
|
| 444 |
+
label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
|
| 445 |
+
lines=4,
|
| 446 |
+
placeholder="100,100,200,200\n300,100,400,200"
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
# generate_btn = gr.Button("Generate", variant="primary")
|
| 454 |
+
|
| 455 |
+
with gr.Column():
|
| 456 |
+
# Output display
|
| 457 |
+
output_image = gr.Image(label="Generated Image")
|
| 458 |
+
debug_face = gr.Image(label="Debug. Faces are expected to be generated in these boxes")
|
| 459 |
+
download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
|
| 460 |
+
|
| 461 |
+
# Examples section
|
| 462 |
+
with gr.Row():
|
| 463 |
+
|
| 464 |
+
gr.Markdown("""
|
| 465 |
+
# Example Configurations
|
| 466 |
+
|
| 467 |
+
### Tips for Better Results
|
| 468 |
+
Be prepared for the first few runs as it may not be very satisfying.
|
| 469 |
+
|
| 470 |
+
- Provide detailed prompts describing the identity. WithAnyone is "controllable", so it needs more information to be controlled. Here are something that might go wrong if not specified:
|
| 471 |
+
- Skin color (generally the race is fine, but for asain descent, if not specified, it may generate darker skin tone);
|
| 472 |
+
- Age (e.g., intead of "a man", try "a young man". If not specified, it may generate an older figure);
|
| 473 |
+
- Body build;
|
| 474 |
+
- Hairstyle;
|
| 475 |
+
- Accessories (glasses, hats, earrings, etc.);
|
| 476 |
+
- Makeup
|
| 477 |
+
- Use the slider to balance between "Resemblance in Spirit" and "Resemblance in Form" according to your needs. If you want to preserve more details in the reference image, move the slider to the right; if you want more freedom and creativity, move it to the left.
|
| 478 |
+
- Try it with LoRAs from community. They are usually fantastic.
|
| 479 |
+
""")
|
| 480 |
+
with gr.Row():
|
| 481 |
+
examples = gr.Examples(
|
| 482 |
+
examples=[
|
| 483 |
+
[
|
| 484 |
+
"a highly detailed portrait of a woman shown in profile. Her long, dark hair flows elegantly, intricately decorated with an abundant array of colorful flowers—ranging from soft light pinks and vibrant light oranges to delicate greyish blues—and lush green leaves, giving a sense of natural beauty and charm. Her bright blue eyes are striking, and her lips are painted a vivid red, adding to her alluring appearance. She is clad in an ornate garment with intricate floral patterns in warm hues like pink and orange, featuring exquisite detailing that speaks of fine craftsmanship. Around her neck, she wears a decorative choker with intricate designs, and dangling from her ears are beautiful blue teardrop earrings that catch the light. The background is filled with a profusion of flowers in various shades, creating a rich, vibrant, and romantic atmosphere that complements the woman's elegant and enchanting look.", # prompt
|
| 485 |
+
1024, 1024, # width, height
|
| 486 |
+
4.0, 25, 42, # guidance, num_steps, seed
|
| 487 |
+
"assets/ref1.jpg", None, None, None, # ref images
|
| 488 |
+
"240,180,540,500", None, # manual_bbox_input, multi_person_image
|
| 489 |
+
# True, # use_text_prompt
|
| 490 |
+
0.0, # siglip_weight
|
| 491 |
+
],
|
| 492 |
+
[
|
| 493 |
+
"High resolution anfd extremely detailed image of two elegant ladies enjoying a serene afternoon in a quaint Parisian café. They both wear fashionable trench coats and stylish berets, exuding an air of sophistication. One lady gently sips on a cappuccino, while her companion reads an intriguing novel with a subtle smile. The café is framed by charming antique furniture and vintage posters adorning the walls. Soft, warm light filters through a window, casting delicate shadows and creating a cozy, inviting atmosphere. Captured from a slightly elevated angle, the composition highlights the warmth of the scene in a gentle watercolor illustrative style. ", # prompt
|
| 494 |
+
1024, 1024, # width, height
|
| 495 |
+
4.0, 25, 42, # guidance, num_steps, seed
|
| 496 |
+
"assets/ref1.jpg", "assets/ref2.jpg", None, None, # ref images
|
| 497 |
+
"248,172,428,498\n554,128,728,464", None, # manual_bbox_input, multi_person_image
|
| 498 |
+
# True, # use_text_prompt
|
| 499 |
+
0.0, # siglip_weight
|
| 500 |
+
]
|
| 501 |
+
],
|
| 502 |
+
inputs=[
|
| 503 |
+
prompt, width, height, guidance, num_steps, seed,
|
| 504 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 505 |
+
manual_bbox_input, multi_person_image,
|
| 506 |
+
siglip_weight
|
| 507 |
+
],
|
| 508 |
+
label="Click to load example configurations"
|
| 509 |
+
)
|
| 510 |
+
# Set up event handlers
|
| 511 |
+
multi_person_image.change(
|
| 512 |
+
fn=update_bbox_display,
|
| 513 |
+
inputs=[multi_person_image],
|
| 514 |
+
outputs=[bbox_preview, manual_bbox_input, bbox_preview]
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
generate_btn.click(
|
| 518 |
+
fn=process_and_generate,
|
| 519 |
+
inputs=[
|
| 520 |
+
prompt, width, height, guidance, num_steps, seed,
|
| 521 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 522 |
+
manual_bbox_input, multi_person_image,
|
| 523 |
+
siglip_weight
|
| 524 |
+
],
|
| 525 |
+
outputs=[output_image,debug_face, download_btn]
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
return demo
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if __name__ == "__main__":
|
| 532 |
+
from transformers import HfArgumentParser
|
| 533 |
+
|
| 534 |
+
@dataclasses.dataclass
|
| 535 |
+
class AppArgs:
|
| 536 |
+
model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
|
| 537 |
+
device: Literal["cuda", "cpu"] = (
|
| 538 |
+
"cuda" if torch.cuda.is_available()
|
| 539 |
+
else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 540 |
+
else "cpu"
|
| 541 |
+
)
|
| 542 |
+
offload: bool = False
|
| 543 |
+
lora_rank: int = 64
|
| 544 |
+
port: int = 7860
|
| 545 |
+
additional_lora: str = None
|
| 546 |
+
lora_scale: float = 1.0
|
| 547 |
+
ipa_path: str = "./ckpt/ipa.safetensors"
|
| 548 |
+
clip_path: str = "openai/clip-vit-large-patch14"
|
| 549 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders"
|
| 550 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev"
|
| 551 |
+
|
| 552 |
+
parser = HfArgumentParser([AppArgs])
|
| 553 |
+
args = parser.parse_args_into_dataclasses()[0]
|
| 554 |
+
|
| 555 |
+
demo = create_demo(
|
| 556 |
+
args.model_type,
|
| 557 |
+
args.ipa_path,
|
| 558 |
+
args.device,
|
| 559 |
+
args.offload,
|
| 560 |
+
args.lora_rank,
|
| 561 |
+
args.additional_lora,
|
| 562 |
+
args.lora_scale,
|
| 563 |
+
args.clip_path,
|
| 564 |
+
args.t5_path,
|
| 565 |
+
args.flux_path,
|
| 566 |
+
)
|
| 567 |
+
demo.launch(server_port=args.port)
|
| 568 |
+
|
gradio_edit.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Fudan University. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Literal, Optional, Tuple, Union
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image, ImageDraw, ImageFilter
|
| 17 |
+
from PIL.JpegImagePlugin import JpegImageFile
|
| 18 |
+
|
| 19 |
+
from withanyone_kontext_s.flux.pipeline import WithAnyonePipeline
|
| 20 |
+
from util import extract_moref, face_preserving_resize
|
| 21 |
+
import insightface
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def blur_faces_in_image(img, json_data, face_size_threshold=100, blur_radius=15):
|
| 25 |
+
"""
|
| 26 |
+
Blurs facial areas directly in the original image for privacy protection.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
img: PIL Image or image data
|
| 30 |
+
json_data: JSON object with 'bboxes' and 'crop' information
|
| 31 |
+
face_size_threshold: Minimum size for faces to be considered (default: 100 pixels)
|
| 32 |
+
blur_radius: Strength of the blur effect (higher = more blurred)
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
PIL Image with faces blurred
|
| 36 |
+
"""
|
| 37 |
+
# Ensure img is a PIL Image
|
| 38 |
+
if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor) and not isinstance(img, JpegImageFile):
|
| 39 |
+
img = Image.open(BytesIO(img))
|
| 40 |
+
|
| 41 |
+
new_bboxes = json_data['bboxes']
|
| 42 |
+
# crop = json_data['crop'] if 'crop' in json_data else [0, 0, img.width, img.height]
|
| 43 |
+
|
| 44 |
+
# # Recalculate bounding boxes based on crop info
|
| 45 |
+
# new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
|
| 46 |
+
|
| 47 |
+
# Check face sizes and filter out faces that are too small
|
| 48 |
+
valid_bboxes = []
|
| 49 |
+
for bbox in new_bboxes:
|
| 50 |
+
x1, y1, x2, y2 = bbox
|
| 51 |
+
if x2 - x1 >= face_size_threshold and y2 - y1 >= face_size_threshold:
|
| 52 |
+
valid_bboxes.append(bbox)
|
| 53 |
+
|
| 54 |
+
# If no valid faces found, return original image
|
| 55 |
+
if not valid_bboxes:
|
| 56 |
+
return img
|
| 57 |
+
|
| 58 |
+
# Create a copy of the original image to modify
|
| 59 |
+
blurred_img = img.copy()
|
| 60 |
+
|
| 61 |
+
# Process each face
|
| 62 |
+
for bbox in valid_bboxes:
|
| 63 |
+
# Convert coordinates to integers
|
| 64 |
+
x1, y1, x2, y2 = map(int, bbox)
|
| 65 |
+
|
| 66 |
+
# Ensure coordinates are within image boundaries
|
| 67 |
+
img_width, img_height = img.size
|
| 68 |
+
x1 = max(0, x1)
|
| 69 |
+
y1 = max(0, y1)
|
| 70 |
+
x2 = min(img_width, x2)
|
| 71 |
+
y2 = min(img_height, y2)
|
| 72 |
+
|
| 73 |
+
# Extract the face region
|
| 74 |
+
face_region = img.crop((x1, y1, x2, y2))
|
| 75 |
+
|
| 76 |
+
# Apply blur to the face region
|
| 77 |
+
blurred_face = face_region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
| 78 |
+
|
| 79 |
+
# Paste the blurred face back into the image
|
| 80 |
+
blurred_img.paste(blurred_face, (x1, y1))
|
| 81 |
+
|
| 82 |
+
return blurred_img
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def captioner(prompt: str, num_person = 1) -> List[List[float]]:
|
| 86 |
+
# use random choose for testing
|
| 87 |
+
# within 512
|
| 88 |
+
if num_person == 1:
|
| 89 |
+
bbox_choices = [
|
| 90 |
+
# expanded, centered and quadrant placements
|
| 91 |
+
[96, 96, 288, 288],
|
| 92 |
+
[128, 128, 320, 320],
|
| 93 |
+
[160, 96, 352, 288],
|
| 94 |
+
[96, 160, 288, 352],
|
| 95 |
+
[208, 96, 400, 288],
|
| 96 |
+
[96, 208, 288, 400],
|
| 97 |
+
[192, 160, 368, 336],
|
| 98 |
+
[64, 128, 224, 320],
|
| 99 |
+
[288, 128, 448, 320],
|
| 100 |
+
[128, 256, 320, 448],
|
| 101 |
+
[80, 80, 240, 272],
|
| 102 |
+
[196, 196, 380, 380],
|
| 103 |
+
# originals
|
| 104 |
+
[100, 100, 300, 300],
|
| 105 |
+
[150, 50, 450, 350],
|
| 106 |
+
[200, 100, 500, 400],
|
| 107 |
+
[250, 150, 512, 450],
|
| 108 |
+
]
|
| 109 |
+
return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
|
| 110 |
+
elif num_person == 2:
|
| 111 |
+
# realistic side-by-side rows (no vertical stacks or diagonals)
|
| 112 |
+
bbox_choices = [
|
| 113 |
+
[[64, 112, 224, 304], [288, 112, 448, 304]],
|
| 114 |
+
[[48, 128, 208, 320], [304, 128, 464, 320]],
|
| 115 |
+
[[32, 144, 192, 336], [320, 144, 480, 336]],
|
| 116 |
+
[[80, 96, 240, 288], [272, 96, 432, 288]],
|
| 117 |
+
[[80, 160, 240, 352], [272, 160, 432, 352]],
|
| 118 |
+
[[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
|
| 119 |
+
[[96, 160, 256, 352], [288, 160, 448, 352]],
|
| 120 |
+
[[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
|
| 121 |
+
[[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
|
| 122 |
+
[[48, 120, 232, 328], [280, 120, 464, 328]],
|
| 123 |
+
[[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
|
| 124 |
+
[[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
|
| 125 |
+
[[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
|
| 126 |
+
[[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
|
| 127 |
+
[[80, 64, 240, 256], [272, 64, 432, 256]], # top row
|
| 128 |
+
[[96, 176, 256, 368], [288, 176, 448, 368]],
|
| 129 |
+
]
|
| 130 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 131 |
+
|
| 132 |
+
elif num_person == 3:
|
| 133 |
+
# Non-overlapping 3-person layouts within 512x512
|
| 134 |
+
bbox_choices = [
|
| 135 |
+
[[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
|
| 136 |
+
[[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
|
| 137 |
+
[[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
|
| 138 |
+
[[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
|
| 139 |
+
[[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
|
| 140 |
+
[[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
|
| 141 |
+
[[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
|
| 142 |
+
[[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
|
| 143 |
+
]
|
| 144 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 145 |
+
elif num_person == 4:
|
| 146 |
+
# Non-overlapping 4-person layouts within 512x512
|
| 147 |
+
bbox_choices = [
|
| 148 |
+
[[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
|
| 149 |
+
[[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
|
| 150 |
+
[[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
|
| 151 |
+
[[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
|
| 152 |
+
[[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
|
| 153 |
+
[[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
|
| 154 |
+
[[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
|
| 155 |
+
[[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
|
| 156 |
+
]
|
| 157 |
+
return bbox_choices[np.random.randint(0, len(bbox_choices))]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class FaceExtractor:
|
| 161 |
+
def __init__(self, model_path="./"):
|
| 162 |
+
self.model = insightface.app.FaceAnalysis(name="antelopev2", root="./")
|
| 163 |
+
self.model.prepare(ctx_id=0)
|
| 164 |
+
|
| 165 |
+
def extract(self, image: Image.Image):
|
| 166 |
+
"""Extract single face and embedding from an image"""
|
| 167 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 168 |
+
res = self.model.get(image_np)
|
| 169 |
+
if len(res) == 0:
|
| 170 |
+
return None, None
|
| 171 |
+
res = res[0]
|
| 172 |
+
bbox = res["bbox"]
|
| 173 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 174 |
+
return moref[0], res["embedding"]
|
| 175 |
+
|
| 176 |
+
def extract_refs(self, image: Image.Image):
|
| 177 |
+
"""Extract multiple faces and embeddings from an image"""
|
| 178 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 179 |
+
res = self.model.get(image_np)
|
| 180 |
+
if len(res) == 0:
|
| 181 |
+
return None, None, None, None
|
| 182 |
+
ref_imgs = []
|
| 183 |
+
arcface_embeddings = []
|
| 184 |
+
bboxes = []
|
| 185 |
+
for r in res:
|
| 186 |
+
bbox = r["bbox"]
|
| 187 |
+
bboxes.append(bbox)
|
| 188 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 189 |
+
ref_imgs.append(moref[0])
|
| 190 |
+
arcface_embeddings.append(r["embedding"])
|
| 191 |
+
|
| 192 |
+
# Convert bboxes to the correct format
|
| 193 |
+
new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
|
| 194 |
+
return ref_imgs, arcface_embeddings, new_bboxes, new_img
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
|
| 198 |
+
"""Resize bounding box coordinates while preserving aspect ratio"""
|
| 199 |
+
x1, y1, x2, y2 = bbox
|
| 200 |
+
|
| 201 |
+
# Calculate scaling factors
|
| 202 |
+
width_scale = new_width / ori_width
|
| 203 |
+
height_scale = new_height / ori_height
|
| 204 |
+
|
| 205 |
+
# Use minimum scaling factor to preserve aspect ratio
|
| 206 |
+
min_scale = min(width_scale, height_scale)
|
| 207 |
+
|
| 208 |
+
# Calculate offsets for centering the scaled box
|
| 209 |
+
width_offset = (new_width - ori_width * min_scale) / 2
|
| 210 |
+
height_offset = (new_height - ori_height * min_scale) / 2
|
| 211 |
+
|
| 212 |
+
# Scale and adjust coordinates
|
| 213 |
+
new_x1 = int(x1 * min_scale + width_offset)
|
| 214 |
+
new_y1 = int(y1 * min_scale + height_offset)
|
| 215 |
+
new_x2 = int(x2 * min_scale + width_offset)
|
| 216 |
+
new_y2 = int(y2 * min_scale + height_offset)
|
| 217 |
+
|
| 218 |
+
return [new_x1, new_y1, new_x2, new_y2]
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def draw_bboxes_on_image(image, bboxes):
|
| 222 |
+
"""Draw bounding boxes on image for visualization"""
|
| 223 |
+
if bboxes is None:
|
| 224 |
+
return image
|
| 225 |
+
|
| 226 |
+
# Create a copy to draw on
|
| 227 |
+
img_draw = image.copy()
|
| 228 |
+
draw = ImageDraw.Draw(img_draw)
|
| 229 |
+
|
| 230 |
+
# Draw each bbox with a different color
|
| 231 |
+
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
|
| 232 |
+
|
| 233 |
+
for i, bbox in enumerate(bboxes):
|
| 234 |
+
color = colors[i % len(colors)]
|
| 235 |
+
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
| 236 |
+
# Draw rectangle
|
| 237 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
|
| 238 |
+
# Draw label
|
| 239 |
+
draw.text((x1, y1-15), f"Face {i+1}", fill=color)
|
| 240 |
+
|
| 241 |
+
return img_draw
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def create_demo(
|
| 245 |
+
model_type: str = "flux-dev",
|
| 246 |
+
ipa_path: str = "./ckpt/ipa.safetensors",
|
| 247 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 248 |
+
offload: bool = False,
|
| 249 |
+
lora_rank: int = 64,
|
| 250 |
+
additional_lora_ckpt: Optional[str] = None,
|
| 251 |
+
lora_scale: float = 1.0,
|
| 252 |
+
clip_path: str = "openai/clip-vit-large-patch14",
|
| 253 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders",
|
| 254 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev",
|
| 255 |
+
):
|
| 256 |
+
|
| 257 |
+
face_extractor = FaceExtractor()
|
| 258 |
+
# Initialize pipeline and face extractor
|
| 259 |
+
pipeline = WithAnyonePipeline(
|
| 260 |
+
model_type,
|
| 261 |
+
ipa_path,
|
| 262 |
+
device,
|
| 263 |
+
offload,
|
| 264 |
+
only_lora=True,
|
| 265 |
+
no_lora=True,
|
| 266 |
+
lora_rank=lora_rank,
|
| 267 |
+
additional_lora_ckpt=additional_lora_ckpt,
|
| 268 |
+
lora_weight=lora_scale,
|
| 269 |
+
face_extractor=face_extractor,
|
| 270 |
+
clip_path=clip_path,
|
| 271 |
+
t5_path=t5_path,
|
| 272 |
+
flux_path=flux_path,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def parse_bboxes(bbox_text):
|
| 277 |
+
"""Parse bounding box text input"""
|
| 278 |
+
if not bbox_text or bbox_text.strip() == "":
|
| 279 |
+
return None
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
bboxes = []
|
| 283 |
+
lines = bbox_text.strip().split("\n")
|
| 284 |
+
for line in lines:
|
| 285 |
+
if not line.strip():
|
| 286 |
+
continue
|
| 287 |
+
coords = [float(x) for x in line.strip().split(",")]
|
| 288 |
+
if len(coords) != 4:
|
| 289 |
+
raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
|
| 290 |
+
bboxes.append(coords)
|
| 291 |
+
return bboxes
|
| 292 |
+
except Exception as e:
|
| 293 |
+
raise gr.Error(f"Invalid bbox format: {e}")
|
| 294 |
+
|
| 295 |
+
def extract_from_base_image(base_img):
|
| 296 |
+
"""Extract references and bboxes from the base image"""
|
| 297 |
+
if base_img is None:
|
| 298 |
+
return None, None, None, None
|
| 299 |
+
|
| 300 |
+
# Convert from numpy to PIL if needed
|
| 301 |
+
if isinstance(base_img, np.ndarray):
|
| 302 |
+
base_img = Image.fromarray(base_img)
|
| 303 |
+
|
| 304 |
+
ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(base_img)
|
| 305 |
+
|
| 306 |
+
if ref_imgs is None or len(ref_imgs) == 0:
|
| 307 |
+
raise gr.Error("No faces detected in the base image")
|
| 308 |
+
|
| 309 |
+
# Limit to max 4 faces
|
| 310 |
+
ref_imgs = ref_imgs[:4]
|
| 311 |
+
arcface_embeddings = arcface_embeddings[:4]
|
| 312 |
+
bboxes = bboxes[:4]
|
| 313 |
+
|
| 314 |
+
# Create visualization with bboxes
|
| 315 |
+
viz_image = draw_bboxes_on_image(new_img, bboxes)
|
| 316 |
+
|
| 317 |
+
# Format bboxes as string for display
|
| 318 |
+
bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
|
| 319 |
+
|
| 320 |
+
return ref_imgs, arcface_embeddings, bboxes, viz_image, bbox_text
|
| 321 |
+
|
| 322 |
+
def process_and_generate(
|
| 323 |
+
prompt,
|
| 324 |
+
guidance, num_steps, seed,
|
| 325 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 326 |
+
base_img,
|
| 327 |
+
manual_bboxes_text,
|
| 328 |
+
use_text_prompt,
|
| 329 |
+
siglip_weight
|
| 330 |
+
):
|
| 331 |
+
# Validate base_img is provided
|
| 332 |
+
if base_img is None:
|
| 333 |
+
raise gr.Error("Base image is required")
|
| 334 |
+
|
| 335 |
+
# Convert numpy to PIL if needed
|
| 336 |
+
if isinstance(base_img, np.ndarray):
|
| 337 |
+
base_img = Image.fromarray(base_img)
|
| 338 |
+
|
| 339 |
+
# Get dimensions from base_img
|
| 340 |
+
width, height = base_img.size
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# Collect and validate reference images
|
| 344 |
+
ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
|
| 345 |
+
|
| 346 |
+
if not ref_images:
|
| 347 |
+
raise gr.Error("At least one reference image is required")
|
| 348 |
+
|
| 349 |
+
# Process reference images to extract face and embeddings
|
| 350 |
+
ref_imgs = []
|
| 351 |
+
arcface_embeddings = []
|
| 352 |
+
|
| 353 |
+
# Extract bboxes from the base image
|
| 354 |
+
extracted_refs, extracted_embeddings, bboxes_, _, _ = extract_from_base_image(base_img)
|
| 355 |
+
bboxes__ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
|
| 356 |
+
if extracted_refs is None:
|
| 357 |
+
raise gr.Error("No faces detected in the base image. Please provide a different base image with clear faces.")
|
| 358 |
+
|
| 359 |
+
# Create blurred canvas by blurring faces in the base image
|
| 360 |
+
blurred_canvas = blur_faces_in_image(base_img, {'bboxes': bboxes__})
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
bboxes = [bboxes__] # Wrap in list for batch input format
|
| 364 |
+
|
| 365 |
+
# Process each reference image
|
| 366 |
+
for img in ref_images:
|
| 367 |
+
if isinstance(img, np.ndarray):
|
| 368 |
+
img = Image.fromarray(img)
|
| 369 |
+
|
| 370 |
+
ref_img, embedding = face_extractor.extract(img)
|
| 371 |
+
if ref_img is None or embedding is None:
|
| 372 |
+
raise gr.Error("Failed to extract face from one of the reference images")
|
| 373 |
+
|
| 374 |
+
ref_imgs.append(ref_img)
|
| 375 |
+
arcface_embeddings.append(embedding)
|
| 376 |
+
|
| 377 |
+
if len(bboxes[0]) != len(ref_imgs):
|
| 378 |
+
raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
|
| 379 |
+
|
| 380 |
+
# Convert arcface embeddings to tensor
|
| 381 |
+
arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
|
| 382 |
+
arcface_embeddings = torch.stack(arcface_embeddings).to(device)
|
| 383 |
+
|
| 384 |
+
# Generate image
|
| 385 |
+
final_prompt = prompt if use_text_prompt else ""
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
if seed < 0:
|
| 389 |
+
seed = np.random.randint(0, 1000000)
|
| 390 |
+
|
| 391 |
+
image_gen = pipeline(
|
| 392 |
+
prompt=final_prompt,
|
| 393 |
+
width=width,
|
| 394 |
+
height=height,
|
| 395 |
+
guidance=guidance,
|
| 396 |
+
num_steps=num_steps,
|
| 397 |
+
seed=seed if seed > 0 else None,
|
| 398 |
+
ref_imgs=ref_imgs,
|
| 399 |
+
img_cond=blurred_canvas, # Pass the blurred canvas image
|
| 400 |
+
arcface_embeddings=arcface_embeddings,
|
| 401 |
+
bboxes=bboxes,
|
| 402 |
+
max_num_ids=len(ref_imgs),
|
| 403 |
+
siglip_weight=0,
|
| 404 |
+
id_weight=1, # only arcface supported now
|
| 405 |
+
arc_only=True,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Save temp file for download
|
| 409 |
+
temp_path = "temp_generated.png"
|
| 410 |
+
image_gen.save(temp_path)
|
| 411 |
+
|
| 412 |
+
# draw bboxes on the generated image for debug
|
| 413 |
+
debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
|
| 414 |
+
|
| 415 |
+
return image_gen, debug_face, temp_path
|
| 416 |
+
|
| 417 |
+
def update_bbox_display(base_img):
|
| 418 |
+
if base_img is None:
|
| 419 |
+
return None, None
|
| 420 |
+
|
| 421 |
+
try:
|
| 422 |
+
_, _, _, viz_image, bbox_text = extract_from_base_image(base_img)
|
| 423 |
+
return viz_image, bbox_text
|
| 424 |
+
except Exception as e:
|
| 425 |
+
return None, None
|
| 426 |
+
|
| 427 |
+
# Create Gradio interface
|
| 428 |
+
with gr.Blocks() as demo:
|
| 429 |
+
gr.Markdown("# WithAnyone Kontext Demo")
|
| 430 |
+
|
| 431 |
+
with gr.Row():
|
| 432 |
+
|
| 433 |
+
with gr.Column():
|
| 434 |
+
# Input controls
|
| 435 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 436 |
+
siglip_weight = 0.0
|
| 437 |
+
with gr.Row():
|
| 438 |
+
prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
|
| 439 |
+
use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
|
| 440 |
+
|
| 441 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 442 |
+
with gr.Row():
|
| 443 |
+
num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
|
| 444 |
+
guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
|
| 445 |
+
seed = gr.Number(-1, label="Seed (-1 for random)")
|
| 446 |
+
|
| 447 |
+
with gr.Row():
|
| 448 |
+
with gr.Column():
|
| 449 |
+
# Reference image inputs
|
| 450 |
+
gr.Markdown("### Face References (1-4 required)")
|
| 451 |
+
ref_img1 = gr.Image(label="Reference 1", type="pil")
|
| 452 |
+
ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
|
| 453 |
+
ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
|
| 454 |
+
ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
|
| 455 |
+
|
| 456 |
+
with gr.Column():
|
| 457 |
+
# Base image input - combines the previous canvas and multi-person image
|
| 458 |
+
gr.Markdown("### Base Image (Required)")
|
| 459 |
+
base_img = gr.Image(label="Base Image - faces will be detected and replaced", type="pil")
|
| 460 |
+
|
| 461 |
+
bbox_preview = gr.Image(label="Detected Faces", type="pil")
|
| 462 |
+
|
| 463 |
+
gr.Markdown("### Manual Bounding Box Override (Optional)")
|
| 464 |
+
manual_bbox_input = gr.Textbox(
|
| 465 |
+
label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
|
| 466 |
+
lines=4,
|
| 467 |
+
placeholder="100,100,200,200\n300,100,400,200"
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
with gr.Column():
|
| 472 |
+
# Output display
|
| 473 |
+
output_image = gr.Image(label="Generated Image")
|
| 474 |
+
debug_face = gr.Image(label="Debug: Faces are expected to be generated in these boxes")
|
| 475 |
+
download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
|
| 476 |
+
|
| 477 |
+
# Examples section
|
| 478 |
+
with gr.Row():
|
| 479 |
+
|
| 480 |
+
gr.Markdown("""
|
| 481 |
+
# Example Configurations
|
| 482 |
+
|
| 483 |
+
### Tips for Better Results
|
| 484 |
+
- Base image is required - faces in this image will be detected, blurred, and then replaced
|
| 485 |
+
- Provide clear reference images with visible faces
|
| 486 |
+
- Use detailed prompts describing the desired output
|
| 487 |
+
- Adjust the resemblance slider based on your needs - more to the right for closer facial resemblance
|
| 488 |
+
""")
|
| 489 |
+
with gr.Row():
|
| 490 |
+
examples = gr.Examples(
|
| 491 |
+
examples=[
|
| 492 |
+
[
|
| 493 |
+
"", # prompt
|
| 494 |
+
4.0, 25, 42, # guidance, num_steps, seed
|
| 495 |
+
"assets/ref3.jpg", "assets/ref1.jpg", None, None, # ref images
|
| 496 |
+
"assets/canvas.jpg", # base image
|
| 497 |
+
False, # use_text_prompt
|
| 498 |
+
]
|
| 499 |
+
],
|
| 500 |
+
inputs=[
|
| 501 |
+
prompt, guidance, num_steps, seed,
|
| 502 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 503 |
+
base_img, use_text_prompt
|
| 504 |
+
],
|
| 505 |
+
label="Click to load example configurations"
|
| 506 |
+
)
|
| 507 |
+
# Set up event handlers
|
| 508 |
+
base_img.change(
|
| 509 |
+
fn=update_bbox_display,
|
| 510 |
+
inputs=[base_img],
|
| 511 |
+
outputs=[bbox_preview, manual_bbox_input]
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
generate_btn.click(
|
| 515 |
+
fn=process_and_generate,
|
| 516 |
+
inputs=[
|
| 517 |
+
prompt, guidance, num_steps, seed,
|
| 518 |
+
ref_img1, ref_img2, ref_img3, ref_img4,
|
| 519 |
+
base_img, use_text_prompt,
|
| 520 |
+
],
|
| 521 |
+
outputs=[output_image, debug_face, download_btn]
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
return demo
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
if __name__ == "__main__":
|
| 528 |
+
from transformers import HfArgumentParser
|
| 529 |
+
|
| 530 |
+
@dataclasses.dataclass
|
| 531 |
+
class AppArgs:
|
| 532 |
+
model_type: Literal["flux-dev", "flux-kontext", "flux-schnell"] = "flux-kontext"
|
| 533 |
+
device: Literal["cuda", "cpu"] = (
|
| 534 |
+
"cuda" if torch.cuda.is_available()
|
| 535 |
+
else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 536 |
+
else "cpu"
|
| 537 |
+
)
|
| 538 |
+
offload: bool = False
|
| 539 |
+
lora_rank: int = 64
|
| 540 |
+
port: int = 7860
|
| 541 |
+
additional_lora: str = None
|
| 542 |
+
lora_scale: float = 1.0
|
| 543 |
+
ipa_path: str = "./ckpt/ipa.safetensors"
|
| 544 |
+
clip_path: str = "openai/clip-vit-large-patch14"
|
| 545 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders"
|
| 546 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev"
|
| 547 |
+
|
| 548 |
+
parser = HfArgumentParser([AppArgs])
|
| 549 |
+
args = parser.parse_args_into_dataclasses()[0]
|
| 550 |
+
|
| 551 |
+
demo = create_demo(
|
| 552 |
+
args.model_type,
|
| 553 |
+
args.ipa_path,
|
| 554 |
+
args.device,
|
| 555 |
+
args.offload,
|
| 556 |
+
args.lora_rank,
|
| 557 |
+
args.additional_lora,
|
| 558 |
+
args.lora_scale,
|
| 559 |
+
args.clip_path,
|
| 560 |
+
args.t5_path,
|
| 561 |
+
args.flux_path,
|
| 562 |
+
)
|
| 563 |
+
demo.launch(server_port=args.port)
|
infer_withanyone.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Fudan University. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import dataclasses
|
| 7 |
+
from typing import Literal
|
| 8 |
+
|
| 9 |
+
from accelerate import Accelerator
|
| 10 |
+
from transformers import HfArgumentParser
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import json
|
| 13 |
+
import itertools
|
| 14 |
+
|
| 15 |
+
from withanyone.flux.pipeline import WithAnyonePipeline
|
| 16 |
+
|
| 17 |
+
from util import extract_moref, general_face_preserving_resize, horizontal_concat, extract_object, FaceExtractor
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
import random
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from transformers import AutoModelForImageSegmentation
|
| 27 |
+
from torch.cuda.amp import autocast
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
BACK_UP_BBOXES_DOUBLE = [
|
| 31 |
+
|
| 32 |
+
[[100,100,200,200], [300,100,400,200]], # 2 faces
|
| 33 |
+
[[150,100,250,200], [300,100,400,200]]
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
BACK_UP_BBOXES = [ # for single face
|
| 37 |
+
[[150,100,250,200]],
|
| 38 |
+
[[100,100,200,200]],
|
| 39 |
+
[[200,100,300,200]],
|
| 40 |
+
[[250,100,350,200]],
|
| 41 |
+
[[300,100,400,200]],
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclasses.dataclass
|
| 50 |
+
class InferenceArgs:
|
| 51 |
+
prompt: str | None = None
|
| 52 |
+
image_paths: list[str] | None = None
|
| 53 |
+
eval_json_path: str | None = None
|
| 54 |
+
offload: bool = False
|
| 55 |
+
num_images_per_prompt: int = 1
|
| 56 |
+
model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
|
| 57 |
+
width: int = 512
|
| 58 |
+
height: int = 512
|
| 59 |
+
ref_size: int = -1
|
| 60 |
+
num_steps: int = 25
|
| 61 |
+
guidance: float = 4
|
| 62 |
+
seed: int = 1234
|
| 63 |
+
save_path: str = "output/inference"
|
| 64 |
+
only_lora: bool = True
|
| 65 |
+
concat_refs: bool = False
|
| 66 |
+
lora_rank: int = 64
|
| 67 |
+
data_resolution: int = 512
|
| 68 |
+
save_iter: str = "500"
|
| 69 |
+
use_rec: bool = False
|
| 70 |
+
drop_text: bool = False
|
| 71 |
+
use_matting: bool = False
|
| 72 |
+
id_weight: float = 1.0
|
| 73 |
+
siglip_weight: float = 1.0
|
| 74 |
+
bbox_from_json: bool = True
|
| 75 |
+
data_root: str = "./"
|
| 76 |
+
# for lora
|
| 77 |
+
additional_lora: str | None = None
|
| 78 |
+
trigger: str = ""
|
| 79 |
+
lora_weight: float = 1.0
|
| 80 |
+
|
| 81 |
+
# path to the ipa model
|
| 82 |
+
ipa_path: str = "./ckpt/ipa.safetensors"
|
| 83 |
+
clip_path: str = "openai/clip-vit-large-patch14"
|
| 84 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders"
|
| 85 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev"
|
| 86 |
+
siglip_path: str = "google/siglip-base-patch16-256-i18n"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def main(args: InferenceArgs):
|
| 91 |
+
accelerator = Accelerator()
|
| 92 |
+
|
| 93 |
+
face_extractor = FaceExtractor()
|
| 94 |
+
|
| 95 |
+
pipeline = WithAnyonePipeline(
|
| 96 |
+
args.model_type,
|
| 97 |
+
args.ipa_path,
|
| 98 |
+
accelerator.device,
|
| 99 |
+
args.offload,
|
| 100 |
+
only_lora=args.only_lora,
|
| 101 |
+
face_extractor=face_extractor,
|
| 102 |
+
additional_lora_ckpt=args.additional_lora,
|
| 103 |
+
lora_weight=args.lora_weight,
|
| 104 |
+
clip_path=args.clip_path,
|
| 105 |
+
t5_path=args.t5_path,
|
| 106 |
+
flux_path=args.flux_path,
|
| 107 |
+
siglip_path=args.siglip_path,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if args.use_matting:
|
| 113 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True).to('cuda', dtype=torch.bfloat16)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
assert args.prompt is not None or args.eval_json_path is not None, \
|
| 117 |
+
"Please provide either prompt or eval_json_path"
|
| 118 |
+
|
| 119 |
+
# if args.eval_json_path is not None:
|
| 120 |
+
assert args.eval_json_path is not None, "Please provide eval_json_path. This script only supports batch inference."
|
| 121 |
+
with open(args.eval_json_path, "rt") as f:
|
| 122 |
+
data_dicts = json.load(f)
|
| 123 |
+
data_root = args.data_root
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
metadata = {}
|
| 128 |
+
for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt)):
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
|
| 132 |
+
continue
|
| 133 |
+
# check if exist, if this image is already generated, skip it
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# if any of the images are None, skip this image
|
| 138 |
+
if not os.path.exists(os.path.join(data_root, data_dict["image_paths"][0])):
|
| 139 |
+
print(f"Image {i} does not exist, skipping...")
|
| 140 |
+
print("path:", os.path.join(data_root, data_dict["image_paths"][0]))
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# extract bbox
|
| 145 |
+
|
| 146 |
+
ori_img_path = data_dict.get("ori_img_path", None)
|
| 147 |
+
# ori_img = Image.open(os.path.join(data_root, data_dict["ori_img_path"]))
|
| 148 |
+
|
| 149 |
+
# basename = data_dict["ori_img_path"].split(".")[0].replace("/", "_")
|
| 150 |
+
if ori_img_path is None:
|
| 151 |
+
basename = f"{i}_{j}"
|
| 152 |
+
else:
|
| 153 |
+
basename = data_dict["ori_img_path"].split(".")[0].replace("/", "_")
|
| 154 |
+
ori_img = Image.open(os.path.join(data_root, ori_img_path))
|
| 155 |
+
bboxes = None
|
| 156 |
+
print("Processing image", i, basename)
|
| 157 |
+
if not args.bbox_from_json:
|
| 158 |
+
if ori_img_path is None:
|
| 159 |
+
print(f"Image {i} has no ori_img_path, cannot extract bbox, skipping...")
|
| 160 |
+
continue
|
| 161 |
+
ori_img = Image.open(os.path.join(data_root, ori_img_path))
|
| 162 |
+
bboxes = face_extractor.locate_bboxes(ori_img)
|
| 163 |
+
# cut bbox length to num of imgae_paths
|
| 164 |
+
if bboxes is not None and len(bboxes) > len(data_dict["image_paths"]):
|
| 165 |
+
bboxes = bboxes[:len(data_dict["image_paths"])]
|
| 166 |
+
elif bboxes is not None and len(bboxes) < len(data_dict["image_paths"]):
|
| 167 |
+
print(f"Image {i} has less faces than image_paths, continuing...")
|
| 168 |
+
continue
|
| 169 |
+
else:
|
| 170 |
+
json_file_path = os.path.join(data_root, basename + ".json")
|
| 171 |
+
if os.path.exists(json_file_path):
|
| 172 |
+
with open(json_file_path, "r") as f:
|
| 173 |
+
json_data = json.load(f)
|
| 174 |
+
old_bboxes = json_data.get("bboxes", None)
|
| 175 |
+
|
| 176 |
+
if old_bboxes is None:
|
| 177 |
+
print(f"Image {i} has no bboxes in json file, using backup bboxes...")
|
| 178 |
+
# v202 -> 2 faces v200_single -> 1 face
|
| 179 |
+
if "v202" in args.eval_json_path:
|
| 180 |
+
old_bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
|
| 181 |
+
elif "v200_single" in args.eval_json_path:
|
| 182 |
+
old_bboxes = random.choice(BACK_UP_BBOXES)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def recalculate_bbox( bbox, crop):
|
| 186 |
+
"""
|
| 187 |
+
The image is cropped, so we need to recalculate the bbox.
|
| 188 |
+
bbox: [x1, y1, x2, y2]
|
| 189 |
+
crop: [x1c, y1c, x2c, y2c]
|
| 190 |
+
we just need to minus x1c and y1c from x1, y1,
|
| 191 |
+
"""
|
| 192 |
+
x1, y1, x2, y2 = bbox
|
| 193 |
+
x1c, y1c, x2c, y2c = crop
|
| 194 |
+
return [x1-x1c, y1-y1c, x2-x1c, y2-y1c]
|
| 195 |
+
crop = json_data.get("crop", None)
|
| 196 |
+
rec_bboxes = [
|
| 197 |
+
recalculate_bbox(bbox, crop) if crop is not None else bbox for bbox in old_bboxes]
|
| 198 |
+
# face_preserving_resize(image, bboxes, 512)
|
| 199 |
+
if ori_img_path is not None:
|
| 200 |
+
_, bboxes = general_face_preserving_resize(ori_img, rec_bboxes, 512)
|
| 201 |
+
# else we consider the provided bbox is already in target size
|
| 202 |
+
else:
|
| 203 |
+
bboxes = rec_bboxes
|
| 204 |
+
|
| 205 |
+
if bboxes is None:
|
| 206 |
+
|
| 207 |
+
print(f"Image {i} has no face, bboxes are None, using backup bboxes..., basename: {basename}")
|
| 208 |
+
|
| 209 |
+
bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
|
| 210 |
+
print(f"Use backup bboxes: {bboxes}")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
ref_imgs = []
|
| 214 |
+
arcface_embeddings = []
|
| 215 |
+
if not args.use_rec:
|
| 216 |
+
break_flag = False
|
| 217 |
+
for img_path in data_dict["image_paths"]:
|
| 218 |
+
img = Image.open(os.path.join(data_root, img_path))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
ref_img, arcface_embedding = face_extractor.extract(img)
|
| 222 |
+
|
| 223 |
+
if ref_img is not None and arcface_embedding is not None:
|
| 224 |
+
if args.use_matting:
|
| 225 |
+
ref_img, _ = extract_object(birefnet, ref_img)
|
| 226 |
+
ref_imgs.append(ref_img)
|
| 227 |
+
arcface_embeddings.append(arcface_embedding)
|
| 228 |
+
else:
|
| 229 |
+
print(f"Image {i} has no face, skipping...")
|
| 230 |
+
break_flag = True
|
| 231 |
+
break
|
| 232 |
+
if break_flag:
|
| 233 |
+
continue
|
| 234 |
+
else:
|
| 235 |
+
ref_imgs, arcface_embeddings = face_extractor.extract_refs(ori_img)
|
| 236 |
+
|
| 237 |
+
if ref_imgs is None or arcface_embeddings is None:
|
| 238 |
+
print(f"Image {i} has no face, skipping...")
|
| 239 |
+
continue
|
| 240 |
+
|
| 241 |
+
if args.use_matting:
|
| 242 |
+
ref_imgs = [extract_object(birefnet, ref_img)[0] for ref_img in ref_imgs]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# arcface to tensor
|
| 246 |
+
arcface_embeddings = [torch.tensor(arcface_embedding) for arcface_embedding in arcface_embeddings]
|
| 247 |
+
arcface_embeddings = torch.stack(arcface_embeddings).to(accelerator.device)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# check, if any of the images are None, if so, skip this image
|
| 251 |
+
if any(ref_img is None for ref_img in ref_imgs):
|
| 252 |
+
print(f"Image {i}: failed to extract face, skipping...")
|
| 253 |
+
continue
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if args.ref_size==-1:
|
| 257 |
+
args.ref_size = 512 if len(ref_imgs)==1 else 320
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
if args.trigger != "" and args.trigger is not None:
|
| 261 |
+
data_dict["prompt"] = args.trigger + " " + data_dict["prompt"]
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
image_gen = pipeline(
|
| 265 |
+
prompt=data_dict["prompt"] if not args.drop_text else "",
|
| 266 |
+
width=args.width,
|
| 267 |
+
height=args.height,
|
| 268 |
+
guidance=args.guidance,
|
| 269 |
+
num_steps=args.num_steps,
|
| 270 |
+
seed=args.seed,
|
| 271 |
+
ref_imgs=ref_imgs,
|
| 272 |
+
arcface_embeddings=arcface_embeddings,
|
| 273 |
+
bboxes=[bboxes],
|
| 274 |
+
id_weight=args.id_weight,
|
| 275 |
+
siglip_weight=args.siglip_weight,
|
| 276 |
+
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if args.concat_refs:
|
| 281 |
+
image_gen = horizontal_concat([image_gen, *ref_imgs])
|
| 282 |
+
|
| 283 |
+
os.makedirs(args.save_path, exist_ok=True)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
save_path = os.path.join(args.save_path, basename)
|
| 287 |
+
os.makedirs(os.path.join(args.save_path, basename), exist_ok=True)
|
| 288 |
+
|
| 289 |
+
# save refs, image_gen and original image
|
| 290 |
+
for k, ref_img in enumerate(ref_imgs):
|
| 291 |
+
ref_img.save(os.path.join(save_path, f"ref_{k}.jpg"))
|
| 292 |
+
image_gen.save(os.path.join(save_path, f"out.jpg"))
|
| 293 |
+
# original image
|
| 294 |
+
ori_img = Image.open(os.path.join(data_root, data_dict["ori_img_path"])) if "ori_img_path" in data_dict else None
|
| 295 |
+
if ori_img is not None:
|
| 296 |
+
ori_img.save(os.path.join(save_path, f"ori.jpg"))
|
| 297 |
+
# save config
|
| 298 |
+
args_dict = vars(args)
|
| 299 |
+
args_dict['prompt'] = data_dict["prompt"]
|
| 300 |
+
args_dict["name"] = data_dict["name"] if "name" in data_dict else None
|
| 301 |
+
json.dump(args_dict, open(os.path.join(save_path, f"meta.json"), 'w'), indent=4, ensure_ascii=False)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
if __name__ == "__main__":
|
| 305 |
+
parser = HfArgumentParser([InferenceArgs])
|
| 306 |
+
args = parser.parse_args_into_dataclasses()[0]
|
| 307 |
+
main(args)
|
| 308 |
+
|
| 309 |
+
|
nohup.out
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Only 2 GPUs available, exiting.
|
| 2 |
+
Only 2 GPUs available, exiting.
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.6.0
|
| 2 |
+
einops
|
| 3 |
+
gradio
|
| 4 |
+
huggingface_hub
|
| 5 |
+
insightface
|
| 6 |
+
matplotlib
|
| 7 |
+
numpy
|
| 8 |
+
opencv-python
|
| 9 |
+
opencv-python-headless
|
| 10 |
+
optimum
|
| 11 |
+
optimum_quanto
|
| 12 |
+
Pillow
|
| 13 |
+
PyYAML
|
| 14 |
+
PyYAML
|
| 15 |
+
safetensors
|
| 16 |
+
seaborn
|
| 17 |
+
scikit-image
|
| 18 |
+
torch==2.5.1
|
| 19 |
+
torchvision==0.20.1
|
| 20 |
+
tqdm
|
| 21 |
+
transformers==4.45.2
|
| 22 |
+
onnxruntime
|
| 23 |
+
onnxruntime-gpu
|
| 24 |
+
sentencepiece
|
util.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Fudan University. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
import random
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import insightface
|
| 10 |
+
import torch
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from torch.cuda.amp import autocast
|
| 13 |
+
|
| 14 |
+
def face_preserving_resize(img, face_bboxes, target_size=512):
|
| 15 |
+
"""
|
| 16 |
+
Resize image while ensuring all faces are preserved in the output.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
img: PIL Image
|
| 20 |
+
face_bboxes: List of [x1, y1, x2, y2] face coordinates
|
| 21 |
+
target_size: Maximum dimension for resizing
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Tuple of (resized image, new_bboxes) or (None, None) if faces can't fit
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
x1_1, y1_1, x2_1, y2_1 = map(int, face_bboxes[0])
|
| 28 |
+
x1_2, y1_2, x2_2, y2_2 = map(int, face_bboxes[1])
|
| 29 |
+
min_x1 = min(x1_1, x1_2)
|
| 30 |
+
min_y1 = min(y1_1, y1_2)
|
| 31 |
+
max_x2 = max(x2_1, x2_2)
|
| 32 |
+
max_y2 = max(y2_1, y2_2)
|
| 33 |
+
# print("min_x1:", min_x1, "min_y1:", min_y1, "max_x2:", max_x2, "max_y2:", max_y2)
|
| 34 |
+
# if any of them is negative, we cannot resize (Idk why this happens)
|
| 35 |
+
if min_x1 < 0 or min_y1 < 0 or max_x2 < 0 or max_y2 < 0:
|
| 36 |
+
return None, None
|
| 37 |
+
|
| 38 |
+
# if face width is longer than the image height, or the face height is longer than the image width, we cannot resize
|
| 39 |
+
face_width = max_x2 - min_x1
|
| 40 |
+
face_height = max_y2 - min_y1
|
| 41 |
+
if face_width > img.height or face_height > img.width:
|
| 42 |
+
return None, None
|
| 43 |
+
|
| 44 |
+
# Create a copy of face_bboxes for transformation
|
| 45 |
+
new_bboxes = []
|
| 46 |
+
for bbox in face_bboxes:
|
| 47 |
+
new_bboxes.append(list(map(int, bbox)))
|
| 48 |
+
|
| 49 |
+
# Choose cropping strategy based on image aspect ratio
|
| 50 |
+
if img.width > img.height:
|
| 51 |
+
# We need to crop width to make a square
|
| 52 |
+
square_size = img.height
|
| 53 |
+
|
| 54 |
+
# Calculate valid horizontal crop range that preserves all faces
|
| 55 |
+
left_max = min_x1 # Leftmost position that includes leftmost face
|
| 56 |
+
right_min = max_x2 - square_size # Rightmost position that includes rightmost face
|
| 57 |
+
|
| 58 |
+
if right_min <= left_max:
|
| 59 |
+
# We can find a valid crop window
|
| 60 |
+
start = random.randint(int(right_min), int(left_max)) if right_min < left_max else int(right_min)
|
| 61 |
+
start = max(0, min(start, img.width - square_size)) # Ensure within image bounds
|
| 62 |
+
else:
|
| 63 |
+
# Faces are too far apart for square crop - use center of faces
|
| 64 |
+
face_center = (min_x1 + max_x2) // 2
|
| 65 |
+
start = max(0, min(face_center - (square_size // 2), img.width - square_size))
|
| 66 |
+
|
| 67 |
+
cropped_img = img.crop((start, 0, start + square_size, square_size))
|
| 68 |
+
|
| 69 |
+
# Adjust bounding box coordinates based on crop
|
| 70 |
+
for bbox in new_bboxes:
|
| 71 |
+
bbox[0] -= start # x1 adjustment
|
| 72 |
+
bbox[2] -= start # x2 adjustment
|
| 73 |
+
# y coordinates remain unchanged
|
| 74 |
+
else:
|
| 75 |
+
# We need to crop height to make a square
|
| 76 |
+
square_size = img.width
|
| 77 |
+
|
| 78 |
+
# Calculate valid vertical crop range that preserves all faces
|
| 79 |
+
top_max = min_y1 # Topmost position that includes topmost face
|
| 80 |
+
bottom_min = max_y2 - square_size # Bottommost position that includes bottommost face
|
| 81 |
+
|
| 82 |
+
if bottom_min <= top_max:
|
| 83 |
+
# We can find a valid crop window
|
| 84 |
+
start = random.randint(int(bottom_min), int(top_max)) if bottom_min < top_max else int(bottom_min)
|
| 85 |
+
start = max(0, min(start, img.height - square_size)) # Ensure within image bounds
|
| 86 |
+
else:
|
| 87 |
+
# Faces are too far apart for square crop - use center of faces
|
| 88 |
+
face_center = (min_y1 + max_y2) // 2
|
| 89 |
+
start = max(0, min(face_center - (square_size // 2), img.height - square_size))
|
| 90 |
+
|
| 91 |
+
cropped_img = img.crop((0, start, square_size, start + square_size))
|
| 92 |
+
|
| 93 |
+
# Adjust bounding box coordinates based on crop
|
| 94 |
+
for bbox in new_bboxes:
|
| 95 |
+
bbox[1] -= start # y1 adjustment
|
| 96 |
+
bbox[3] -= start # y2 adjustment
|
| 97 |
+
# x coordinates remain unchanged
|
| 98 |
+
|
| 99 |
+
# Calculate scale factor for resizing from square_size to target_size
|
| 100 |
+
scale_factor = target_size / square_size
|
| 101 |
+
|
| 102 |
+
# Adjust bounding boxes for the resize operation
|
| 103 |
+
for bbox in new_bboxes:
|
| 104 |
+
bbox[0] = int(bbox[0] * scale_factor)
|
| 105 |
+
bbox[1] = int(bbox[1] * scale_factor)
|
| 106 |
+
bbox[2] = int(bbox[2] * scale_factor)
|
| 107 |
+
bbox[3] = int(bbox[3] * scale_factor)
|
| 108 |
+
|
| 109 |
+
# Final resize to target size
|
| 110 |
+
resized_img = cropped_img.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
| 111 |
+
|
| 112 |
+
# Make sure all coordinates are within bounds (0 to target_size)
|
| 113 |
+
# for bbox in new_bboxes:
|
| 114 |
+
# bbox[0] = max(0, min(bbox[0], target_size - 1))
|
| 115 |
+
# bbox[1] = max(0, min(bbox[1], target_size - 1))
|
| 116 |
+
# bbox[2] = max(1, min(bbox[2], target_size))
|
| 117 |
+
# bbox[3] = max(1, min(bbox[3], target_size))
|
| 118 |
+
|
| 119 |
+
return resized_img, new_bboxes
|
| 120 |
+
|
| 121 |
+
def extract_moref(img, json_data, face_size_restriction=100):
|
| 122 |
+
"""
|
| 123 |
+
Extract faces from an image based on bounding boxes in JSON data.
|
| 124 |
+
Makes each face square and resizes to 512x512.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
img: PIL Image or image data
|
| 128 |
+
json_data: JSON object with 'bboxes' and 'crop' information
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
List of PIL Images, each 512x512, containing extracted faces
|
| 132 |
+
"""
|
| 133 |
+
# Ensure img is a PIL Image
|
| 134 |
+
try:
|
| 135 |
+
if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor) and not isinstance(img, JpegImageFile):
|
| 136 |
+
img = Image.open(BytesIO(img))
|
| 137 |
+
|
| 138 |
+
bboxes = json_data['bboxes']
|
| 139 |
+
# crop = json_data['crop']
|
| 140 |
+
# print("len of bboxes:", len(bboxes))
|
| 141 |
+
# Recalculate bounding boxes based on crop info
|
| 142 |
+
# new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
|
| 143 |
+
new_bboxes = bboxes
|
| 144 |
+
# any of the face is less than 100 * 100, we ignore this image
|
| 145 |
+
for bbox in new_bboxes:
|
| 146 |
+
x1, y1, x2, y2 = bbox
|
| 147 |
+
if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
|
| 148 |
+
return []
|
| 149 |
+
# print("len of new_bboxes:", len(new_bboxes))
|
| 150 |
+
faces = []
|
| 151 |
+
for bbox in new_bboxes:
|
| 152 |
+
# print("processing bbox")
|
| 153 |
+
# Convert coordinates to integers
|
| 154 |
+
x1, y1, x2, y2 = map(int, bbox)
|
| 155 |
+
|
| 156 |
+
# Calculate width and height
|
| 157 |
+
width = x2 - x1
|
| 158 |
+
height = y2 - y1
|
| 159 |
+
|
| 160 |
+
# Make the bounding box square by expanding the shorter dimension
|
| 161 |
+
if width > height:
|
| 162 |
+
# Height is shorter, expand it
|
| 163 |
+
diff = width - height
|
| 164 |
+
y1 -= diff // 2
|
| 165 |
+
y2 += diff - (diff // 2) # Handle odd differences
|
| 166 |
+
elif height > width:
|
| 167 |
+
# Width is shorter, expand it
|
| 168 |
+
diff = height - width
|
| 169 |
+
x1 -= diff // 2
|
| 170 |
+
x2 += diff - (diff // 2) # Handle odd differences
|
| 171 |
+
|
| 172 |
+
# Ensure coordinates are within image boundaries
|
| 173 |
+
img_width, img_height = img.size
|
| 174 |
+
x1 = max(0, x1)
|
| 175 |
+
y1 = max(0, y1)
|
| 176 |
+
x2 = min(img_width, x2)
|
| 177 |
+
y2 = min(img_height, y2)
|
| 178 |
+
|
| 179 |
+
# Extract face region
|
| 180 |
+
face_region = img.crop((x1, y1, x2, y2))
|
| 181 |
+
|
| 182 |
+
# Resize to 512x512
|
| 183 |
+
face_region = face_region.resize((512, 512), Image.LANCZOS)
|
| 184 |
+
|
| 185 |
+
faces.append(face_region)
|
| 186 |
+
# print("len of faces:", len(faces))
|
| 187 |
+
return faces
|
| 188 |
+
except Exception as e:
|
| 189 |
+
print(f"Error processing image: {e}")
|
| 190 |
+
return []
|
| 191 |
+
|
| 192 |
+
def general_face_preserving_resize(img, face_bboxes, target_size=512):
|
| 193 |
+
"""
|
| 194 |
+
Resize image while ensuring all faces are preserved in the output.
|
| 195 |
+
Handles any number of faces (1-5).
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
img: PIL Image
|
| 199 |
+
face_bboxes: List of [x1, y1, x2, y2] face coordinates
|
| 200 |
+
target_size: Maximum dimension for resizing
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tuple of (resized image, new_bboxes) or (None, None) if faces can't fit
|
| 204 |
+
"""
|
| 205 |
+
# Find bounding region containing all faces
|
| 206 |
+
if not face_bboxes:
|
| 207 |
+
print("Warning: No face bounding boxes provided.")
|
| 208 |
+
return None, None
|
| 209 |
+
|
| 210 |
+
min_x1 = min(bbox[0] for bbox in face_bboxes)
|
| 211 |
+
min_y1 = min(bbox[1] for bbox in face_bboxes)
|
| 212 |
+
max_x2 = max(bbox[2] for bbox in face_bboxes)
|
| 213 |
+
max_y2 = max(bbox[3] for bbox in face_bboxes)
|
| 214 |
+
|
| 215 |
+
# Check for negative coordinates
|
| 216 |
+
if min_x1 < 0 or min_y1 < 0 or max_x2 < 0 or max_y2 < 0:
|
| 217 |
+
# print("Warning: Negative coordinates found in face bounding boxes.")
|
| 218 |
+
# return None, None
|
| 219 |
+
min_x1 = max(min_x1, 0)
|
| 220 |
+
min_y1 = max(min_y1, 0)
|
| 221 |
+
|
| 222 |
+
# Check if faces fit within image
|
| 223 |
+
face_width = max_x2 - min_x1
|
| 224 |
+
face_height = max_y2 - min_y1
|
| 225 |
+
if face_width > img.height or face_height > img.width:
|
| 226 |
+
# print("Warning: Faces are too large for the image dimensions.")
|
| 227 |
+
# return None, None
|
| 228 |
+
# Instead of returning None, we will crop the image to fit the faces
|
| 229 |
+
max_x2 = min(max_x2, img.width)
|
| 230 |
+
max_y2 = min(max_y2, img.height)
|
| 231 |
+
min_x1 = max(min_x1, 0)
|
| 232 |
+
min_y1 = max(min_y1, 0)
|
| 233 |
+
# Create a copy of face_bboxes for transformation
|
| 234 |
+
new_bboxes = []
|
| 235 |
+
for bbox in face_bboxes:
|
| 236 |
+
new_bboxes.append(list(map(int, bbox)))
|
| 237 |
+
|
| 238 |
+
# Choose cropping strategy based on image aspect ratio
|
| 239 |
+
if img.width > img.height:
|
| 240 |
+
# Crop width to make a square
|
| 241 |
+
square_size = img.height
|
| 242 |
+
|
| 243 |
+
# Calculate valid horizontal crop range
|
| 244 |
+
left_max = min_x1
|
| 245 |
+
right_min = max_x2 - square_size
|
| 246 |
+
|
| 247 |
+
if right_min <= left_max:
|
| 248 |
+
# We can find a valid crop window
|
| 249 |
+
start = random.randint(int(right_min), int(left_max)) if right_min < left_max else int(right_min)
|
| 250 |
+
start = max(0, min(start, img.width - square_size))
|
| 251 |
+
else:
|
| 252 |
+
# Faces are too far apart - use center of faces
|
| 253 |
+
face_center = (min_x1 + max_x2) // 2
|
| 254 |
+
start = max(0, min(face_center - (square_size // 2), img.width - square_size))
|
| 255 |
+
|
| 256 |
+
cropped_img = img.crop((start, 0, start + square_size, square_size))
|
| 257 |
+
|
| 258 |
+
# Adjust bounding box coordinates
|
| 259 |
+
for bbox in new_bboxes:
|
| 260 |
+
bbox[0] -= start
|
| 261 |
+
bbox[2] -= start
|
| 262 |
+
else:
|
| 263 |
+
# Crop height to make a square
|
| 264 |
+
square_size = img.width
|
| 265 |
+
|
| 266 |
+
# Calculate valid vertical crop range
|
| 267 |
+
top_max = min_y1
|
| 268 |
+
bottom_min = max_y2 - square_size
|
| 269 |
+
|
| 270 |
+
if bottom_min <= top_max:
|
| 271 |
+
start = random.randint(int(bottom_min), int(top_max)) if bottom_min < top_max else int(bottom_min)
|
| 272 |
+
start = max(0, min(start, img.height - square_size))
|
| 273 |
+
else:
|
| 274 |
+
face_center = (min_y1 + max_y2) // 2
|
| 275 |
+
start = max(0, min(face_center - (square_size // 2), img.height - square_size))
|
| 276 |
+
|
| 277 |
+
cropped_img = img.crop((0, start, square_size, start + square_size))
|
| 278 |
+
|
| 279 |
+
# Adjust bounding box coordinates
|
| 280 |
+
for bbox in new_bboxes:
|
| 281 |
+
bbox[1] -= start
|
| 282 |
+
bbox[3] -= start
|
| 283 |
+
|
| 284 |
+
# Calculate scale factor and adjust bounding boxes
|
| 285 |
+
scale_factor = target_size / square_size
|
| 286 |
+
|
| 287 |
+
for bbox in new_bboxes:
|
| 288 |
+
bbox[0] = int(bbox[0] * scale_factor)
|
| 289 |
+
bbox[1] = int(bbox[1] * scale_factor)
|
| 290 |
+
bbox[2] = int(bbox[2] * scale_factor)
|
| 291 |
+
bbox[3] = int(bbox[3] * scale_factor)
|
| 292 |
+
|
| 293 |
+
# Final resize to target size
|
| 294 |
+
resized_img = cropped_img.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
| 295 |
+
|
| 296 |
+
# Make sure all coordinates are within bounds
|
| 297 |
+
for bbox in new_bboxes:
|
| 298 |
+
bbox[0] = max(0, min(bbox[0], target_size - 1))
|
| 299 |
+
bbox[1] = max(0, min(bbox[1], target_size - 1))
|
| 300 |
+
bbox[2] = max(1, min(bbox[2], target_size))
|
| 301 |
+
bbox[3] = max(1, min(bbox[3], target_size))
|
| 302 |
+
|
| 303 |
+
return resized_img, new_bboxes
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def horizontal_concat(images):
|
| 308 |
+
widths, heights = zip(*(img.size for img in images))
|
| 309 |
+
|
| 310 |
+
total_width = sum(widths)
|
| 311 |
+
max_height = max(heights)
|
| 312 |
+
|
| 313 |
+
new_im = Image.new('RGB', (total_width, max_height))
|
| 314 |
+
|
| 315 |
+
x_offset = 0
|
| 316 |
+
for img in images:
|
| 317 |
+
new_im.paste(img, (x_offset, 0))
|
| 318 |
+
x_offset += img.size[0]
|
| 319 |
+
|
| 320 |
+
return new_im
|
| 321 |
+
|
| 322 |
+
def extract_object(birefnet, image):
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if image.mode != 'RGB':
|
| 326 |
+
image = image.convert('RGB')
|
| 327 |
+
input_images = transforms.ToTensor()(image).unsqueeze(0).to('cuda', dtype=torch.bfloat16)
|
| 328 |
+
|
| 329 |
+
# Prediction
|
| 330 |
+
with torch.no_grad(), autocast(dtype=torch.bfloat16):
|
| 331 |
+
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
| 332 |
+
pred = preds[0].squeeze().float()
|
| 333 |
+
pred_pil = transforms.ToPILImage()(pred)
|
| 334 |
+
mask = pred_pil.resize(image.size)
|
| 335 |
+
|
| 336 |
+
# Create a binary mask (0 or 255)
|
| 337 |
+
binary_mask = mask.convert("L")
|
| 338 |
+
|
| 339 |
+
# Create a new image with black background
|
| 340 |
+
result = Image.new("RGB", image.size, (0, 0, 0))
|
| 341 |
+
|
| 342 |
+
# Paste the original image onto the black background using the mask
|
| 343 |
+
result.paste(image, (0, 0), binary_mask)
|
| 344 |
+
|
| 345 |
+
return result, mask
|
| 346 |
+
|
| 347 |
+
class FaceExtractor:
|
| 348 |
+
def __init__(self):
|
| 349 |
+
self.model = insightface.app.FaceAnalysis(name = "antelopev2", root="./")
|
| 350 |
+
self.model.prepare(ctx_id=0, det_thresh=0.4)
|
| 351 |
+
|
| 352 |
+
def extract(self, image: Image.Image):
|
| 353 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 354 |
+
res = self.model.get(image_np)
|
| 355 |
+
if len(res) == 0:
|
| 356 |
+
return None, None
|
| 357 |
+
res = res[0]
|
| 358 |
+
# print(res.keys())
|
| 359 |
+
bbox = res["bbox"]
|
| 360 |
+
# print("len(bbox)", len(bbox))
|
| 361 |
+
|
| 362 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 363 |
+
# print("len(moref)", len(moref))
|
| 364 |
+
return moref[0], res["embedding"]
|
| 365 |
+
|
| 366 |
+
def locate_bboxes(self, image: Image.Image):
|
| 367 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 368 |
+
res = self.model.get(image_np)
|
| 369 |
+
if len(res) == 0:
|
| 370 |
+
return None
|
| 371 |
+
bboxes = []
|
| 372 |
+
for r in res:
|
| 373 |
+
bbox = r["bbox"]
|
| 374 |
+
bboxes.append(bbox)
|
| 375 |
+
|
| 376 |
+
_, new_bboxes_ = general_face_preserving_resize(image, bboxes, 512)
|
| 377 |
+
|
| 378 |
+
# ensure the bbox is square
|
| 379 |
+
new_bboxes = []
|
| 380 |
+
for bbox in new_bboxes_:
|
| 381 |
+
x1, y1, x2, y2 = bbox
|
| 382 |
+
w = x2 - x1
|
| 383 |
+
h = y2 - y1
|
| 384 |
+
if w > h:
|
| 385 |
+
diff = w - h
|
| 386 |
+
y1 = max(0, y1 - diff // 2)
|
| 387 |
+
y2 = min(512, y2 + diff // 2 + diff % 2)
|
| 388 |
+
else:
|
| 389 |
+
diff = h - w
|
| 390 |
+
x1 = max(0, x1 - diff // 2)
|
| 391 |
+
x2 = min(512, x2 + diff // 2 + diff % 2)
|
| 392 |
+
new_bboxes.append([x1, y1, x2, y2])
|
| 393 |
+
|
| 394 |
+
return new_bboxes
|
| 395 |
+
def extract_refs(self, image: Image.Image):
|
| 396 |
+
"""
|
| 397 |
+
Extracts reference faces from the image.
|
| 398 |
+
Returns a list of reference images and their arcface embeddings.
|
| 399 |
+
"""
|
| 400 |
+
image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 401 |
+
res = self.model.get(image_np)
|
| 402 |
+
if len(res) == 0:
|
| 403 |
+
return None, None
|
| 404 |
+
ref_imgs = []
|
| 405 |
+
arcface_embeddings = []
|
| 406 |
+
for r in res:
|
| 407 |
+
bbox = r["bbox"]
|
| 408 |
+
moref = extract_moref(image, {"bboxes": [bbox]}, 1)
|
| 409 |
+
ref_imgs.append(moref[0])
|
| 410 |
+
arcface_embeddings.append(r["embedding"])
|
| 411 |
+
return ref_imgs, arcface_embeddings
|
withanyone/flux/__pycache__/math.cpython-310.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
withanyone/flux/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
withanyone/flux/__pycache__/pipeline.cpython-310.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
withanyone/flux/__pycache__/sampling.cpython-310.pyc
ADDED
|
Binary file (4.12 kB). View file
|
|
|
withanyone/flux/__pycache__/util.cpython-310.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
withanyone/flux/math.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
# a return class
|
| 16 |
+
@dataclass
|
| 17 |
+
class AttentionReturnQAndMAP:
|
| 18 |
+
result: Tensor
|
| 19 |
+
attention_map: Tensor
|
| 20 |
+
Q: Tensor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask = None, token_aug_idx = -1, text_length = None, image_length = None, return_map = False) -> Tensor:
|
| 25 |
+
q, k = apply_rope(q, k, pe)
|
| 26 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
|
| 27 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 28 |
+
|
| 29 |
+
return x
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
| 35 |
+
assert dim % 2 == 0
|
| 36 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
| 37 |
+
omega = 1.0 / (theta**scale)
|
| 38 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
| 39 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
| 40 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
| 41 |
+
return out.float()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| 45 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
| 46 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
| 47 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
| 48 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
| 49 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
withanyone/flux/model.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding, PerceiverAttentionCA
|
| 9 |
+
from transformers import AutoTokenizer, AutoProcessor, SiglipModel
|
| 10 |
+
import math
|
| 11 |
+
from transformers import AutoModelForImageSegmentation
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from torch.cuda.amp import autocast
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_person_cross_attention_mask_varlen(
|
| 21 |
+
batch_size, img_len, id_len,
|
| 22 |
+
bbox_lists, original_width, original_height,
|
| 23 |
+
max_num_ids=2, # Default to support 2 identities
|
| 24 |
+
vae_scale_factor=8, patch_size=2, num_heads = 24
|
| 25 |
+
):
|
| 26 |
+
"""
|
| 27 |
+
Create boolean attention masks limiting image tokens to interact only with corresponding person ID tokens
|
| 28 |
+
|
| 29 |
+
Parameters:
|
| 30 |
+
- batch_size: Number of samples in batch
|
| 31 |
+
- num_heads: Number of attention heads
|
| 32 |
+
- img_len: Length of image token sequence
|
| 33 |
+
- id_len: Length of EACH identity embedding (not total)
|
| 34 |
+
- bbox_lists: List where bbox_lists[i] contains all bboxes for batch i
|
| 35 |
+
Each batch may have a different number of bboxes/identities
|
| 36 |
+
- max_num_ids: Maximum number of identities to support (for padding)
|
| 37 |
+
- original_width/height: Original image dimensions
|
| 38 |
+
- vae_scale_factor: VAE downsampling factor (default 8)
|
| 39 |
+
- patch_size: Patch size for token creation (default 2)
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
- Boolean attention mask of shape [batch_size, num_heads, img_len, total_id_len]
|
| 43 |
+
"""
|
| 44 |
+
# Total length of ID tokens based on maximum number of identities
|
| 45 |
+
total_id_len = max_num_ids * id_len
|
| 46 |
+
|
| 47 |
+
# Initialize mask to block all attention
|
| 48 |
+
mask = torch.zeros((batch_size, num_heads, img_len, total_id_len), dtype=torch.bool)
|
| 49 |
+
|
| 50 |
+
# Calculate VAE dimensions
|
| 51 |
+
latent_width = original_width // vae_scale_factor
|
| 52 |
+
latent_height = original_height // vae_scale_factor
|
| 53 |
+
patches_width = latent_width // patch_size
|
| 54 |
+
patches_height = latent_height // patch_size
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Convert boundary box to token indices
|
| 59 |
+
def bbox_to_token_indices(bbox):
|
| 60 |
+
x1, y1, x2, y2 = bbox
|
| 61 |
+
|
| 62 |
+
# Convert to patch space coordinates
|
| 63 |
+
if isinstance(x1, torch.Tensor):
|
| 64 |
+
x1_patch = max(0, int(x1.item()) // vae_scale_factor // patch_size)
|
| 65 |
+
y1_patch = max(0, int(y1.item()) // vae_scale_factor // patch_size)
|
| 66 |
+
x2_patch = min(patches_width, math.ceil(int(x2.item()) / vae_scale_factor / patch_size))
|
| 67 |
+
y2_patch = min(patches_height, math.ceil(int(y2.item()) / vae_scale_factor / patch_size))
|
| 68 |
+
elif isinstance(x1, int):
|
| 69 |
+
x1_patch = max(0, x1 // vae_scale_factor // patch_size)
|
| 70 |
+
y1_patch = max(0, y1 // vae_scale_factor // patch_size)
|
| 71 |
+
x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
|
| 72 |
+
y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
|
| 73 |
+
elif isinstance(x1, float):
|
| 74 |
+
x1_patch = max(0, int(x1) // vae_scale_factor // patch_size)
|
| 75 |
+
y1_patch = max(0, int(y1) // vae_scale_factor // patch_size)
|
| 76 |
+
x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
|
| 77 |
+
y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
|
| 78 |
+
else:
|
| 79 |
+
raise TypeError(f"Unsupported type: {type(x1)}")
|
| 80 |
+
|
| 81 |
+
# Create list of all token indices in this region
|
| 82 |
+
indices = []
|
| 83 |
+
for y in range(y1_patch, y2_patch):
|
| 84 |
+
for x in range(x1_patch, x2_patch):
|
| 85 |
+
idx = y * patches_width + x
|
| 86 |
+
indices.append(idx)
|
| 87 |
+
|
| 88 |
+
return indices
|
| 89 |
+
|
| 90 |
+
for b in range(batch_size):
|
| 91 |
+
# Get all bboxes for this batch item
|
| 92 |
+
batch_bboxes = bbox_lists[b] if b < len(bbox_lists) else []
|
| 93 |
+
|
| 94 |
+
# Process each bbox in the batch up to max_num_ids
|
| 95 |
+
for identity_idx, bbox in enumerate(batch_bboxes[:max_num_ids]):
|
| 96 |
+
# Get image token indices for this bbox
|
| 97 |
+
image_indices = bbox_to_token_indices(bbox)
|
| 98 |
+
|
| 99 |
+
# Calculate ID token slice for this identity
|
| 100 |
+
id_start = identity_idx * id_len
|
| 101 |
+
id_end = id_start + id_len
|
| 102 |
+
id_slice = slice(id_start, id_end)
|
| 103 |
+
|
| 104 |
+
# Enable attention between this region's image tokens and the identity's tokens
|
| 105 |
+
for h in range(num_heads):
|
| 106 |
+
for idx in image_indices:
|
| 107 |
+
mask[b, h, idx, id_slice] = True
|
| 108 |
+
|
| 109 |
+
return mask
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# FFN
|
| 115 |
+
def FeedForward(dim, mult=4):
|
| 116 |
+
inner_dim = int(dim * mult)
|
| 117 |
+
return nn.Sequential(
|
| 118 |
+
nn.LayerNorm(dim),
|
| 119 |
+
nn.Linear(dim, inner_dim, bias=False),
|
| 120 |
+
nn.GELU(),
|
| 121 |
+
nn.Linear(inner_dim, dim, bias=False),
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@dataclass
|
| 127 |
+
class FluxParams:
|
| 128 |
+
in_channels: int
|
| 129 |
+
vec_in_dim: int
|
| 130 |
+
context_in_dim: int
|
| 131 |
+
hidden_size: int
|
| 132 |
+
mlp_ratio: float
|
| 133 |
+
num_heads: int
|
| 134 |
+
depth: int
|
| 135 |
+
depth_single_blocks: int
|
| 136 |
+
axes_dim: list[int]
|
| 137 |
+
theta: int
|
| 138 |
+
qkv_bias: bool
|
| 139 |
+
guidance_embed: bool
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class SiglipEmbedding(nn.Module):
|
| 143 |
+
def __init__(self, siglip_path = "google/siglip-base-patch16-256-i18n", use_matting=False):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
|
| 146 |
+
self.processor = AutoProcessor.from_pretrained(siglip_path)
|
| 147 |
+
self.model.to(torch.cuda.current_device())
|
| 148 |
+
|
| 149 |
+
# BiRefNet matting setup
|
| 150 |
+
self.use_matting = use_matting
|
| 151 |
+
if self.use_matting:
|
| 152 |
+
self.birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 153 |
+
'briaai/RMBG-2.0', trust_remote_code=True).to(torch.cuda.current_device(), dtype=torch.bfloat16)
|
| 154 |
+
# Apply half precision to the entire model after loading
|
| 155 |
+
self.matting_transform = transforms.Compose([
|
| 156 |
+
# transforms.Resize((512, 512)),
|
| 157 |
+
transforms.ToTensor(),
|
| 158 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
def apply_matting(self, image):
|
| 162 |
+
"""Apply BiRefNet matting to remove background from image"""
|
| 163 |
+
if not self.use_matting:
|
| 164 |
+
return image
|
| 165 |
+
|
| 166 |
+
# Convert to input format and move to GPU
|
| 167 |
+
input_image = self.matting_transform(image).unsqueeze(0).to(torch.cuda.current_device(), dtype=torch.bfloat16)
|
| 168 |
+
|
| 169 |
+
# Generate prediction
|
| 170 |
+
with torch.no_grad(), autocast(dtype=torch.bfloat16):
|
| 171 |
+
preds = self.birefnet(input_image)[-1].sigmoid().cpu()
|
| 172 |
+
|
| 173 |
+
# Process the mask
|
| 174 |
+
pred = preds[0].squeeze().float()
|
| 175 |
+
pred_pil = transforms.ToPILImage()(pred)
|
| 176 |
+
mask = pred_pil.resize(image.size)
|
| 177 |
+
binary_mask = mask.convert("L")
|
| 178 |
+
|
| 179 |
+
# Create a new image with black background
|
| 180 |
+
result = Image.new("RGB", image.size, (0, 0, 0))
|
| 181 |
+
result.paste(image, (0, 0), binary_mask)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
return result
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_id_embedding(self, refimage):
|
| 188 |
+
'''
|
| 189 |
+
refimage is a list (batch) of list (num of person) of PIL images
|
| 190 |
+
considering the whole batch, the number of person is fixed
|
| 191 |
+
'''
|
| 192 |
+
siglip_embedding = []
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if isinstance(refimage, list):
|
| 196 |
+
batch_size = len(refimage)
|
| 197 |
+
for batch_idx, refimage_batch in enumerate(refimage):
|
| 198 |
+
# Apply matting if enabled
|
| 199 |
+
if self.use_matting:
|
| 200 |
+
|
| 201 |
+
processed_images = [self.apply_matting(img) for img in refimage_batch]
|
| 202 |
+
else:
|
| 203 |
+
processed_images = refimage_batch
|
| 204 |
+
|
| 205 |
+
pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
|
| 206 |
+
# device
|
| 207 |
+
pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
|
| 208 |
+
last_hidden_state = self.model(pixel_values).last_hidden_state # 2, 256 768
|
| 209 |
+
# pooled_output = self.model(pixel_values).pooler_output # 2, 768
|
| 210 |
+
siglip_embedding.append(last_hidden_state)
|
| 211 |
+
# siglip_embedding.append(pooled_output) # 2, 768
|
| 212 |
+
siglip_embedding = torch.stack(siglip_embedding, dim=0) # shape ([batch_size, num_of_person, 256, 768])
|
| 213 |
+
|
| 214 |
+
if batch_size < 4:
|
| 215 |
+
# run additional times to avoid the first time cuda memory allocation overhead
|
| 216 |
+
for _ in range(4 - batch_size):
|
| 217 |
+
pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
|
| 218 |
+
# device
|
| 219 |
+
pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
|
| 220 |
+
last_hidden_state = self.model(pixel_values).last_hidden_state
|
| 221 |
+
|
| 222 |
+
elif isinstance(refimage, torch.Tensor):
|
| 223 |
+
# refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
|
| 224 |
+
batch_size, num_of_person, C, H, W = refimage.shape
|
| 225 |
+
refimage = refimage.view(batch_size * num_of_person, C, H, W)
|
| 226 |
+
refimage = refimage.to(torch.cuda.current_device(), dtype=torch.bfloat16)
|
| 227 |
+
last_hidden_state = self.model(refimage).last_hidden_state
|
| 228 |
+
siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
|
| 229 |
+
|
| 230 |
+
return siglip_embedding
|
| 231 |
+
|
| 232 |
+
def forward(self, refimage):
|
| 233 |
+
return self.get_id_embedding(refimage)
|
| 234 |
+
|
| 235 |
+
class Flux(nn.Module):
|
| 236 |
+
"""
|
| 237 |
+
Transformer model for flow matching on sequences.
|
| 238 |
+
"""
|
| 239 |
+
_supports_gradient_checkpointing = True
|
| 240 |
+
|
| 241 |
+
def __init__(self, params: FluxParams):
|
| 242 |
+
super().__init__()
|
| 243 |
+
|
| 244 |
+
self.params = params
|
| 245 |
+
self.in_channels = params.in_channels
|
| 246 |
+
self.out_channels = self.in_channels
|
| 247 |
+
if params.hidden_size % params.num_heads != 0:
|
| 248 |
+
raise ValueError(
|
| 249 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
| 250 |
+
)
|
| 251 |
+
pe_dim = params.hidden_size // params.num_heads
|
| 252 |
+
if sum(params.axes_dim) != pe_dim:
|
| 253 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
| 254 |
+
self.hidden_size = params.hidden_size
|
| 255 |
+
self.num_heads = params.num_heads
|
| 256 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
| 257 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
| 258 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
| 259 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
| 260 |
+
self.guidance_in = (
|
| 261 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
| 262 |
+
)
|
| 263 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
| 264 |
+
|
| 265 |
+
self.double_blocks = nn.ModuleList(
|
| 266 |
+
[
|
| 267 |
+
DoubleStreamBlock(
|
| 268 |
+
self.hidden_size,
|
| 269 |
+
self.num_heads,
|
| 270 |
+
mlp_ratio=params.mlp_ratio,
|
| 271 |
+
qkv_bias=params.qkv_bias,
|
| 272 |
+
)
|
| 273 |
+
for _ in range(params.depth)
|
| 274 |
+
]
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
self.single_blocks = nn.ModuleList(
|
| 278 |
+
[
|
| 279 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
| 280 |
+
for _ in range(params.depth_single_blocks)
|
| 281 |
+
]
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
| 285 |
+
self.gradient_checkpointing = False
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# use cross attention
|
| 291 |
+
self.ipa_arc = nn.ModuleList([
|
| 292 |
+
PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
|
| 293 |
+
for _ in range(self.params.depth_single_blocks + self.params.depth)
|
| 294 |
+
])
|
| 295 |
+
self.ipa_sig = nn.ModuleList([
|
| 296 |
+
PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
|
| 297 |
+
for _ in range(self.params.depth_single_blocks + self.params.depth)
|
| 298 |
+
])
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
self.arcface_in_arc = nn.Sequential(
|
| 303 |
+
nn.Linear(512, 4 * self.hidden_size, bias=True),
|
| 304 |
+
nn.GELU(),
|
| 305 |
+
nn.LayerNorm(4 * self.hidden_size),
|
| 306 |
+
nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
self.arcface_in_sig = nn.Sequential(
|
| 311 |
+
nn.Linear(512, 4 * self.hidden_size, bias=True),
|
| 312 |
+
nn.GELU(),
|
| 313 |
+
nn.LayerNorm(4 * self.hidden_size),
|
| 314 |
+
nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
self.siglip_in_sig = nn.Sequential(
|
| 318 |
+
nn.Linear(768, self.hidden_size, bias=True),
|
| 319 |
+
nn.GELU(),
|
| 320 |
+
nn.LayerNorm(self.hidden_size),
|
| 321 |
+
nn.Linear(self.hidden_size, self.hidden_size, bias=True),
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def lq_in_arc(self, txt_lq, siglip_embeddings, arcface_embeddings):
|
| 326 |
+
"""
|
| 327 |
+
Process the siglip and arcface embeddings.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
# shape of arcface: (num_refs, bs, 512)
|
| 331 |
+
arcface_embeddings = self.arcface_in_arc(arcface_embeddings)
|
| 332 |
+
# shape of arcface: (num_refs, bs, 4*hidden_size)
|
| 333 |
+
# 4*hidden_size -> 4 tokens of hidden_size
|
| 334 |
+
arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
|
| 335 |
+
# (num_ref, tokens, hidden_size) -> (bs, num_refs*tokens, hidden_size)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3) # (n, b, t, d) -> (b, n, t, d)
|
| 339 |
+
|
| 340 |
+
arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
return arcface_embeddings
|
| 345 |
+
|
| 346 |
+
def lq_in_sig(self, txt_lq, siglip_embeddings, arcface_embeddings):
|
| 347 |
+
"""
|
| 348 |
+
Process the siglip and arcface embeddings.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# shape of arcface: (num_refs, bs, 512)
|
| 353 |
+
arcface_embeddings = self.arcface_in_sig(arcface_embeddings)
|
| 354 |
+
|
| 355 |
+
arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
|
| 356 |
+
# (num_ref, tokens, hidden_size) -> (bs, num_refs*tokens, hidden_size)
|
| 357 |
+
|
| 358 |
+
arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3) # (n, b, t, d) -> (b, n, t, d)
|
| 359 |
+
|
| 360 |
+
siglip_embeddings = self.siglip_in_sig(siglip_embeddings) # (bs, num_refs, 256, 768) -> (bs, num_refs, 4*hidden_size)
|
| 361 |
+
|
| 362 |
+
# concat in token dimension
|
| 363 |
+
arcface_embeddings = torch.cat((siglip_embeddings, arcface_embeddings), dim=2) # (bs, num_refs, 4, hidden_size) cat (bs, num_refs, 4, hidden_size) -> (bs, num_refs, 8, hidden_size)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
|
| 367 |
+
return arcface_embeddings
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 372 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 373 |
+
module.gradient_checkpointing = value
|
| 374 |
+
|
| 375 |
+
@property
|
| 376 |
+
def attn_processors(self):
|
| 377 |
+
# set recursively
|
| 378 |
+
processors = {} # type: dict[str, nn.Module]
|
| 379 |
+
|
| 380 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
| 381 |
+
if hasattr(module, "set_processor"):
|
| 382 |
+
processors[f"{name}.processor"] = module.processor
|
| 383 |
+
|
| 384 |
+
for sub_name, child in module.named_children():
|
| 385 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 386 |
+
|
| 387 |
+
return processors
|
| 388 |
+
|
| 389 |
+
for name, module in self.named_children():
|
| 390 |
+
fn_recursive_add_processors(name, module, processors)
|
| 391 |
+
|
| 392 |
+
return processors
|
| 393 |
+
|
| 394 |
+
def set_attn_processor(self, processor):
|
| 395 |
+
r"""
|
| 396 |
+
Sets the attention processor to use to compute attention.
|
| 397 |
+
|
| 398 |
+
Parameters:
|
| 399 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 400 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 401 |
+
for **all** `Attention` layers.
|
| 402 |
+
|
| 403 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 404 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 405 |
+
|
| 406 |
+
"""
|
| 407 |
+
count = len(self.attn_processors.keys())
|
| 408 |
+
|
| 409 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 412 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 416 |
+
if hasattr(module, "set_processor"):
|
| 417 |
+
if not isinstance(processor, dict):
|
| 418 |
+
module.set_processor(processor)
|
| 419 |
+
else:
|
| 420 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 421 |
+
|
| 422 |
+
for sub_name, child in module.named_children():
|
| 423 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 424 |
+
|
| 425 |
+
for name, module in self.named_children():
|
| 426 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def forward(
|
| 431 |
+
self,
|
| 432 |
+
img: Tensor,
|
| 433 |
+
img_ids: Tensor,
|
| 434 |
+
txt: Tensor,
|
| 435 |
+
txt_ids: Tensor,
|
| 436 |
+
timesteps: Tensor,
|
| 437 |
+
y: Tensor,
|
| 438 |
+
guidance: Tensor | None = None,
|
| 439 |
+
siglip_embeddings: Tensor | None = None, # (bs, num_refs, 256, 768)
|
| 440 |
+
arcface_embeddings: Tensor | None = None, # (bs, num_refs, 512)
|
| 441 |
+
bbox_lists: list | None = None, # list of list of bboxes, bbox_lists[i] is for the i-th batch, each has different number of bboxes (ids), which should align with the dim1 of arcface_embeddings. This is used to replace bbox_A and bbox_B, which should be discarded, but remained for compatibility.
|
| 442 |
+
use_mask: bool = True,
|
| 443 |
+
id_weight: float = 1.0,
|
| 444 |
+
siglip_weight: float = 1.0,
|
| 445 |
+
siglip_mask = None,
|
| 446 |
+
arc_mask = None,
|
| 447 |
+
|
| 448 |
+
img_height: int = 512,
|
| 449 |
+
img_width: int = 512,
|
| 450 |
+
) -> Tensor:
|
| 451 |
+
if img.ndim != 3 or txt.ndim != 3:
|
| 452 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
| 453 |
+
|
| 454 |
+
# running on sequences img
|
| 455 |
+
img = self.img_in(img)
|
| 456 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
| 457 |
+
if self.params.guidance_embed:
|
| 458 |
+
if guidance is None:
|
| 459 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
| 460 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
| 461 |
+
vec = vec + self.vector_in(y)
|
| 462 |
+
txt = self.txt_in(txt)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
text_length = txt.shape[1]
|
| 468 |
+
img_length = img.shape[1]
|
| 469 |
+
|
| 470 |
+
img_end = img.shape[1]
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
use_ip = arcface_embeddings is not None
|
| 474 |
+
|
| 475 |
+
if use_ip:
|
| 476 |
+
|
| 477 |
+
id_embeddings = self.lq_in_arc(None, siglip_embeddings, arcface_embeddings)
|
| 478 |
+
siglip_embeddings = self.lq_in_sig(None, siglip_embeddings, arcface_embeddings)
|
| 479 |
+
|
| 480 |
+
text_length = txt.shape[1] # update text_length after adding learnable query
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# 8 tokens for arcface, 256 tokens for siglip
|
| 484 |
+
id_len = 8
|
| 485 |
+
siglip_len = 256 + 8
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
if bbox_lists is not None and use_mask and (arc_mask is None or siglip_mask is None):
|
| 490 |
+
arc_mask = create_person_cross_attention_mask_varlen(
|
| 491 |
+
batch_size=img.shape[0],
|
| 492 |
+
num_heads=self.params.num_heads,
|
| 493 |
+
# txt_len=text_length,
|
| 494 |
+
img_len=img_length,
|
| 495 |
+
id_len=id_len,
|
| 496 |
+
bbox_lists=bbox_lists,
|
| 497 |
+
max_num_ids=len(bbox_lists[0]),
|
| 498 |
+
original_width=img_width,
|
| 499 |
+
original_height= img_height,
|
| 500 |
+
).to(img.device)
|
| 501 |
+
siglip_mask = create_person_cross_attention_mask_varlen(
|
| 502 |
+
batch_size=img.shape[0],
|
| 503 |
+
num_heads=self.params.num_heads,
|
| 504 |
+
# txt_len=text_length,
|
| 505 |
+
img_len=img_length,
|
| 506 |
+
id_len=siglip_len,
|
| 507 |
+
bbox_lists=bbox_lists,
|
| 508 |
+
max_num_ids=len(bbox_lists[0]),
|
| 509 |
+
original_width=img_width,
|
| 510 |
+
original_height= img_height,
|
| 511 |
+
).to(img.device)
|
| 512 |
+
else:
|
| 513 |
+
arc_mask = None
|
| 514 |
+
siglip_mask = None
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# update text_ids and id_ids
|
| 519 |
+
txt_ids = torch.zeros((txt.shape[0], text_length, 3)).to(img_ids.device) # (bs, T, 3)
|
| 520 |
+
|
| 521 |
+
ids = torch.cat((txt_ids, img_ids), dim=1) # (bs, T + I + ID, 3)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
pe = self.pe_embedder(ids)
|
| 525 |
+
|
| 526 |
+
# ipa
|
| 527 |
+
ipa_idx = 0
|
| 528 |
+
|
| 529 |
+
for index_block, block in enumerate(self.double_blocks):
|
| 530 |
+
if self.training and self.gradient_checkpointing:
|
| 531 |
+
img, txt = torch.utils.checkpoint.checkpoint(
|
| 532 |
+
block,
|
| 533 |
+
img=img,
|
| 534 |
+
txt=txt,
|
| 535 |
+
vec=vec,
|
| 536 |
+
pe=pe,
|
| 537 |
+
# mask=mask,
|
| 538 |
+
text_length=text_length,
|
| 539 |
+
image_length=img_length,
|
| 540 |
+
# return_map = False,
|
| 541 |
+
use_reentrant=False,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
else:
|
| 547 |
+
img, txt= block(
|
| 548 |
+
img=img,
|
| 549 |
+
txt=txt,
|
| 550 |
+
vec=vec,
|
| 551 |
+
pe=pe,
|
| 552 |
+
text_length=text_length,
|
| 553 |
+
image_length=img_length,
|
| 554 |
+
# return_map=False,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
if use_ip:
|
| 559 |
+
|
| 560 |
+
img = img + id_weight * self.ipa_arc[ipa_idx](id_embeddings, img, mask=arc_mask) + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, img, mask=siglip_mask)
|
| 561 |
+
ipa_idx += 1
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
# for block in self.single_blocks:
|
| 570 |
+
img = torch.cat((txt, img), 1)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
for index_block, block in enumerate(self.single_blocks):
|
| 574 |
+
if self.training and self.gradient_checkpointing:
|
| 575 |
+
img = torch.utils.checkpoint.checkpoint(
|
| 576 |
+
block,
|
| 577 |
+
img, vec=vec, pe=pe, #mask=mask,
|
| 578 |
+
text_length=text_length,
|
| 579 |
+
image_length=img_length,
|
| 580 |
+
return_map=False,
|
| 581 |
+
use_reentrant=False
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
else:
|
| 585 |
+
img = block(img, vec=vec, pe=pe,text_length=text_length, image_length=img_length, return_map=False)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# IPA
|
| 591 |
+
if use_ip:
|
| 592 |
+
txt, real_img = img[:, :text_length, :], img[:, text_length:, :]
|
| 593 |
+
|
| 594 |
+
id_ca = id_weight * self.ipa_arc[ipa_idx](id_embeddings, real_img, mask=arc_mask) + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, real_img, mask=siglip_mask)
|
| 595 |
+
|
| 596 |
+
real_img = real_img + id_ca
|
| 597 |
+
img = torch.cat((txt, real_img), dim=1)
|
| 598 |
+
ipa_idx += 1
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
img = img[:, txt.shape[1] :, ...]
|
| 605 |
+
# index img
|
| 606 |
+
img = img[:, :img_end, ...]
|
| 607 |
+
|
| 608 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
| 609 |
+
|
| 610 |
+
return img
|
withanyone/flux/modules/__pycache__/autoencoder.cpython-310.pyc
ADDED
|
Binary file (9.09 kB). View file
|
|
|
withanyone/flux/modules/__pycache__/conditioner.cpython-310.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
withanyone/flux/modules/__pycache__/layers.cpython-310.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
withanyone/flux/modules/autoencoder.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class AutoEncoderParams:
|
| 25 |
+
resolution: int
|
| 26 |
+
in_channels: int
|
| 27 |
+
ch: int
|
| 28 |
+
out_ch: int
|
| 29 |
+
ch_mult: list[int]
|
| 30 |
+
num_res_blocks: int
|
| 31 |
+
z_channels: int
|
| 32 |
+
scale_factor: float
|
| 33 |
+
shift_factor: float
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def swish(x: Tensor) -> Tensor:
|
| 37 |
+
return x * torch.sigmoid(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AttnBlock(nn.Module):
|
| 41 |
+
def __init__(self, in_channels: int):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.in_channels = in_channels
|
| 44 |
+
|
| 45 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 46 |
+
|
| 47 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 48 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 49 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 50 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 51 |
+
|
| 52 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 53 |
+
h_ = self.norm(h_)
|
| 54 |
+
q = self.q(h_)
|
| 55 |
+
k = self.k(h_)
|
| 56 |
+
v = self.v(h_)
|
| 57 |
+
|
| 58 |
+
b, c, h, w = q.shape
|
| 59 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 60 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 61 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 62 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 63 |
+
|
| 64 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 65 |
+
|
| 66 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 67 |
+
return x + self.proj_out(self.attention(x))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ResnetBlock(nn.Module):
|
| 71 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.in_channels = in_channels
|
| 74 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 75 |
+
self.out_channels = out_channels
|
| 76 |
+
|
| 77 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 78 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 79 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 80 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 81 |
+
if self.in_channels != self.out_channels:
|
| 82 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
h = x
|
| 86 |
+
h = self.norm1(h)
|
| 87 |
+
h = swish(h)
|
| 88 |
+
h = self.conv1(h)
|
| 89 |
+
|
| 90 |
+
h = self.norm2(h)
|
| 91 |
+
h = swish(h)
|
| 92 |
+
h = self.conv2(h)
|
| 93 |
+
|
| 94 |
+
if self.in_channels != self.out_channels:
|
| 95 |
+
x = self.nin_shortcut(x)
|
| 96 |
+
|
| 97 |
+
return x + h
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Downsample(nn.Module):
|
| 101 |
+
def __init__(self, in_channels: int):
|
| 102 |
+
super().__init__()
|
| 103 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 104 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 105 |
+
|
| 106 |
+
def forward(self, x: Tensor):
|
| 107 |
+
pad = (0, 1, 0, 1)
|
| 108 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 109 |
+
x = self.conv(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Upsample(nn.Module):
|
| 114 |
+
def __init__(self, in_channels: int):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 117 |
+
|
| 118 |
+
def forward(self, x: Tensor):
|
| 119 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 120 |
+
x = self.conv(x)
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Encoder(nn.Module):
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
resolution: int,
|
| 128 |
+
in_channels: int,
|
| 129 |
+
ch: int,
|
| 130 |
+
ch_mult: list[int],
|
| 131 |
+
num_res_blocks: int,
|
| 132 |
+
z_channels: int,
|
| 133 |
+
):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.ch = ch
|
| 136 |
+
self.num_resolutions = len(ch_mult)
|
| 137 |
+
self.num_res_blocks = num_res_blocks
|
| 138 |
+
self.resolution = resolution
|
| 139 |
+
self.in_channels = in_channels
|
| 140 |
+
# downsampling
|
| 141 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 142 |
+
|
| 143 |
+
curr_res = resolution
|
| 144 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 145 |
+
self.in_ch_mult = in_ch_mult
|
| 146 |
+
self.down = nn.ModuleList()
|
| 147 |
+
block_in = self.ch
|
| 148 |
+
for i_level in range(self.num_resolutions):
|
| 149 |
+
block = nn.ModuleList()
|
| 150 |
+
attn = nn.ModuleList()
|
| 151 |
+
block_in = ch * in_ch_mult[i_level]
|
| 152 |
+
block_out = ch * ch_mult[i_level]
|
| 153 |
+
for _ in range(self.num_res_blocks):
|
| 154 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 155 |
+
block_in = block_out
|
| 156 |
+
down = nn.Module()
|
| 157 |
+
down.block = block
|
| 158 |
+
down.attn = attn
|
| 159 |
+
if i_level != self.num_resolutions - 1:
|
| 160 |
+
down.downsample = Downsample(block_in)
|
| 161 |
+
curr_res = curr_res // 2
|
| 162 |
+
self.down.append(down)
|
| 163 |
+
|
| 164 |
+
# middle
|
| 165 |
+
self.mid = nn.Module()
|
| 166 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 167 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 168 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 169 |
+
|
| 170 |
+
# end
|
| 171 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 172 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 173 |
+
|
| 174 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 175 |
+
# downsampling
|
| 176 |
+
hs = [self.conv_in(x)]
|
| 177 |
+
for i_level in range(self.num_resolutions):
|
| 178 |
+
for i_block in range(self.num_res_blocks):
|
| 179 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 180 |
+
if len(self.down[i_level].attn) > 0:
|
| 181 |
+
h = self.down[i_level].attn[i_block](h)
|
| 182 |
+
hs.append(h)
|
| 183 |
+
if i_level != self.num_resolutions - 1:
|
| 184 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 185 |
+
|
| 186 |
+
# middle
|
| 187 |
+
h = hs[-1]
|
| 188 |
+
h = self.mid.block_1(h)
|
| 189 |
+
h = self.mid.attn_1(h)
|
| 190 |
+
h = self.mid.block_2(h)
|
| 191 |
+
# end
|
| 192 |
+
h = self.norm_out(h)
|
| 193 |
+
h = swish(h)
|
| 194 |
+
h = self.conv_out(h)
|
| 195 |
+
return h
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Decoder(nn.Module):
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
ch: int,
|
| 202 |
+
out_ch: int,
|
| 203 |
+
ch_mult: list[int],
|
| 204 |
+
num_res_blocks: int,
|
| 205 |
+
in_channels: int,
|
| 206 |
+
resolution: int,
|
| 207 |
+
z_channels: int,
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.ch = ch
|
| 211 |
+
self.num_resolutions = len(ch_mult)
|
| 212 |
+
self.num_res_blocks = num_res_blocks
|
| 213 |
+
self.resolution = resolution
|
| 214 |
+
self.in_channels = in_channels
|
| 215 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 216 |
+
|
| 217 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 218 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 219 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 220 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 221 |
+
|
| 222 |
+
# z to block_in
|
| 223 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 224 |
+
|
| 225 |
+
# middle
|
| 226 |
+
self.mid = nn.Module()
|
| 227 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 228 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 229 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 230 |
+
|
| 231 |
+
# upsampling
|
| 232 |
+
self.up = nn.ModuleList()
|
| 233 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 234 |
+
block = nn.ModuleList()
|
| 235 |
+
attn = nn.ModuleList()
|
| 236 |
+
block_out = ch * ch_mult[i_level]
|
| 237 |
+
for _ in range(self.num_res_blocks + 1):
|
| 238 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 239 |
+
block_in = block_out
|
| 240 |
+
up = nn.Module()
|
| 241 |
+
up.block = block
|
| 242 |
+
up.attn = attn
|
| 243 |
+
if i_level != 0:
|
| 244 |
+
up.upsample = Upsample(block_in)
|
| 245 |
+
curr_res = curr_res * 2
|
| 246 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 247 |
+
|
| 248 |
+
# end
|
| 249 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 250 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 251 |
+
|
| 252 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 253 |
+
# z to block_in
|
| 254 |
+
h = self.conv_in(z)
|
| 255 |
+
|
| 256 |
+
# middle
|
| 257 |
+
h = self.mid.block_1(h)
|
| 258 |
+
h = self.mid.attn_1(h)
|
| 259 |
+
h = self.mid.block_2(h)
|
| 260 |
+
|
| 261 |
+
# upsampling
|
| 262 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 263 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 264 |
+
h = self.up[i_level].block[i_block](h)
|
| 265 |
+
if len(self.up[i_level].attn) > 0:
|
| 266 |
+
h = self.up[i_level].attn[i_block](h)
|
| 267 |
+
if i_level != 0:
|
| 268 |
+
h = self.up[i_level].upsample(h)
|
| 269 |
+
|
| 270 |
+
# end
|
| 271 |
+
h = self.norm_out(h)
|
| 272 |
+
h = swish(h)
|
| 273 |
+
h = self.conv_out(h)
|
| 274 |
+
return h
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class DiagonalGaussian(nn.Module):
|
| 278 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.sample = sample
|
| 281 |
+
self.chunk_dim = chunk_dim
|
| 282 |
+
|
| 283 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 284 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
| 285 |
+
if self.sample:
|
| 286 |
+
std = torch.exp(0.5 * logvar)
|
| 287 |
+
return mean + std * torch.randn_like(mean)
|
| 288 |
+
else:
|
| 289 |
+
return mean
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class AutoEncoder(nn.Module):
|
| 293 |
+
def __init__(self, params: AutoEncoderParams):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.encoder = Encoder(
|
| 296 |
+
resolution=params.resolution,
|
| 297 |
+
in_channels=params.in_channels,
|
| 298 |
+
ch=params.ch,
|
| 299 |
+
ch_mult=params.ch_mult,
|
| 300 |
+
num_res_blocks=params.num_res_blocks,
|
| 301 |
+
z_channels=params.z_channels,
|
| 302 |
+
)
|
| 303 |
+
self.decoder = Decoder(
|
| 304 |
+
resolution=params.resolution,
|
| 305 |
+
in_channels=params.in_channels,
|
| 306 |
+
ch=params.ch,
|
| 307 |
+
out_ch=params.out_ch,
|
| 308 |
+
ch_mult=params.ch_mult,
|
| 309 |
+
num_res_blocks=params.num_res_blocks,
|
| 310 |
+
z_channels=params.z_channels,
|
| 311 |
+
)
|
| 312 |
+
self.reg = DiagonalGaussian()
|
| 313 |
+
|
| 314 |
+
self.scale_factor = params.scale_factor
|
| 315 |
+
self.shift_factor = params.shift_factor
|
| 316 |
+
|
| 317 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 318 |
+
z = self.reg(self.encoder(x))
|
| 319 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 320 |
+
return z
|
| 321 |
+
|
| 322 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 323 |
+
z = z / self.scale_factor + self.shift_factor
|
| 324 |
+
return self.decoder(z)
|
| 325 |
+
|
| 326 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 327 |
+
return self.decode(self.encode(x))
|
withanyone/flux/modules/conditioner.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from torch import Tensor, nn
|
| 17 |
+
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
| 18 |
+
T5Tokenizer)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HFEmbedder(nn.Module):
|
| 22 |
+
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.is_clip = "clip" in version.lower()
|
| 25 |
+
self.max_length = max_length
|
| 26 |
+
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
| 27 |
+
|
| 28 |
+
if self.is_clip:
|
| 29 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
| 30 |
+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
| 31 |
+
else:
|
| 32 |
+
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
| 33 |
+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
| 34 |
+
|
| 35 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
| 36 |
+
|
| 37 |
+
def forward(self, text: list[str]) -> Tensor:
|
| 38 |
+
batch_encoding = self.tokenizer(
|
| 39 |
+
text,
|
| 40 |
+
truncation=True,
|
| 41 |
+
max_length=self.max_length,
|
| 42 |
+
return_length=False,
|
| 43 |
+
return_overflowing_tokens=False,
|
| 44 |
+
padding="max_length",
|
| 45 |
+
return_tensors="pt",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
outputs = self.hf_module(
|
| 49 |
+
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
| 50 |
+
attention_mask=None,
|
| 51 |
+
output_hidden_states=False,
|
| 52 |
+
)
|
| 53 |
+
return outputs[self.output_key]
|
withanyone/flux/modules/layers.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
|
| 10 |
+
# from ..math import attention, rope
|
| 11 |
+
from ..math import rope
|
| 12 |
+
from ..math import attention
|
| 13 |
+
# from ..math import attention
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
TOKEN_AUG_IDX = 2048
|
| 17 |
+
|
| 18 |
+
class EmbedND(nn.Module):
|
| 19 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.dim = dim
|
| 22 |
+
self.theta = theta
|
| 23 |
+
self.axes_dim = axes_dim
|
| 24 |
+
|
| 25 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 26 |
+
n_axes = ids.shape[-1]
|
| 27 |
+
emb = torch.cat(
|
| 28 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
| 29 |
+
dim=-3,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return emb.unsqueeze(1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 36 |
+
"""
|
| 37 |
+
Create sinusoidal timestep embeddings.
|
| 38 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 39 |
+
These may be fractional.
|
| 40 |
+
:param dim: the dimension of the output.
|
| 41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 43 |
+
"""
|
| 44 |
+
t = time_factor * t
|
| 45 |
+
half = dim // 2
|
| 46 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 47 |
+
t.device
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
args = t[:, None].float() * freqs[None]
|
| 51 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 52 |
+
if dim % 2:
|
| 53 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 54 |
+
if torch.is_floating_point(t):
|
| 55 |
+
embedding = embedding.to(t)
|
| 56 |
+
return embedding
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MLPEmbedder(nn.Module):
|
| 60 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
| 63 |
+
self.silu = nn.SiLU()
|
| 64 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 65 |
+
|
| 66 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 67 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 68 |
+
|
| 69 |
+
def reshape_tensor(x, heads):
|
| 70 |
+
# print("x in reshape_tensor", x.shape)
|
| 71 |
+
bs, length, width = x.shape
|
| 72 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
| 73 |
+
x = x.view(bs, length, heads, -1)
|
| 74 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
| 75 |
+
x = x.transpose(1, 2)
|
| 76 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
| 77 |
+
x = x.reshape(bs, heads, length, -1)
|
| 78 |
+
return x
|
| 79 |
+
class PerceiverAttentionCA(nn.Module):
|
| 80 |
+
def __init__(self, *, dim=3072, dim_head=64, heads=16, kv_dim=2048):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.scale = dim_head ** -0.5
|
| 83 |
+
self.dim_head = dim_head
|
| 84 |
+
self.heads = heads
|
| 85 |
+
inner_dim = dim_head * heads
|
| 86 |
+
|
| 87 |
+
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
|
| 88 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 89 |
+
|
| 90 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 91 |
+
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
|
| 92 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 93 |
+
|
| 94 |
+
def forward(self, x, latents, mask=None):
|
| 95 |
+
"""
|
| 96 |
+
Args:
|
| 97 |
+
x (torch.Tensor): image features
|
| 98 |
+
shape (b, n1, D)
|
| 99 |
+
latent (torch.Tensor): latent features
|
| 100 |
+
shape (b, n2, D)
|
| 101 |
+
"""
|
| 102 |
+
x = self.norm1(x)
|
| 103 |
+
latents = self.norm2(latents)
|
| 104 |
+
|
| 105 |
+
# print("x, latents in PerceiverAttentionCA", x.shape, latents.shape)
|
| 106 |
+
|
| 107 |
+
b, seq_len, _ = latents.shape
|
| 108 |
+
|
| 109 |
+
q = self.to_q(latents)
|
| 110 |
+
k, v = self.to_kv(x).chunk(2, dim=-1)
|
| 111 |
+
|
| 112 |
+
# print("q, k, v in PerceiverAttentionCA", q.shape, k.shape, v.shape)
|
| 113 |
+
|
| 114 |
+
q = reshape_tensor(q, self.heads)
|
| 115 |
+
k = reshape_tensor(k, self.heads)
|
| 116 |
+
v = reshape_tensor(v, self.heads)
|
| 117 |
+
|
| 118 |
+
# # attention
|
| 119 |
+
# scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
| 120 |
+
# weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
| 121 |
+
# print("is there any nan in weight:", torch.isnan(weight).any())
|
| 122 |
+
# if mask is not None:
|
| 123 |
+
# # Mask shape should be [batch_size, num_heads, q_len, kv_len]
|
| 124 |
+
# # weight = weight.masked_fill(mask == 0, float("-inf"))
|
| 125 |
+
# if mask.dtype == torch.bool:
|
| 126 |
+
# # Boolean mask: False values are masked out
|
| 127 |
+
# # print("Got boolean mask")
|
| 128 |
+
# weight = weight.masked_fill(~mask, -float('inf'))
|
| 129 |
+
# else:
|
| 130 |
+
# # Float mask: values are added directly to scores
|
| 131 |
+
# weight = weight + mask
|
| 132 |
+
# print("is there any nan in weight after mask:", torch.isnan(weight).any())
|
| 133 |
+
# weight = torch.softmax(weight, dim=-1)
|
| 134 |
+
# print("is there any nan in weight after softmax:", torch.isnan(weight).any())
|
| 135 |
+
# out = weight @ v
|
| 136 |
+
|
| 137 |
+
# use sdpa
|
| 138 |
+
# if mask is not None:
|
| 139 |
+
# print("mask shape in PerceiverAttentionCA", mask.shape)
|
| 140 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
| 141 |
+
|
| 142 |
+
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
|
| 143 |
+
|
| 144 |
+
return self.to_out(out)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class RMSNorm(torch.nn.Module):
|
| 150 |
+
def __init__(self, dim: int):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
| 153 |
+
|
| 154 |
+
def forward(self, x: Tensor):
|
| 155 |
+
x_dtype = x.dtype
|
| 156 |
+
x = x.float()
|
| 157 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
| 158 |
+
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class QKNorm(torch.nn.Module):
|
| 162 |
+
def __init__(self, dim: int):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.query_norm = RMSNorm(dim)
|
| 165 |
+
self.key_norm = RMSNorm(dim)
|
| 166 |
+
|
| 167 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 168 |
+
q = self.query_norm(q)
|
| 169 |
+
k = self.key_norm(k)
|
| 170 |
+
return q.to(v), k.to(v)
|
| 171 |
+
|
| 172 |
+
class LoRALinearLayer(nn.Module):
|
| 173 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
| 174 |
+
super().__init__()
|
| 175 |
+
|
| 176 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 177 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 178 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
| 179 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
| 180 |
+
self.network_alpha = network_alpha
|
| 181 |
+
self.rank = rank
|
| 182 |
+
|
| 183 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 184 |
+
nn.init.zeros_(self.up.weight)
|
| 185 |
+
|
| 186 |
+
def forward(self, hidden_states):
|
| 187 |
+
orig_dtype = hidden_states.dtype
|
| 188 |
+
dtype = self.down.weight.dtype
|
| 189 |
+
|
| 190 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 191 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 192 |
+
|
| 193 |
+
if self.network_alpha is not None:
|
| 194 |
+
up_hidden_states *= self.network_alpha / self.rank
|
| 195 |
+
|
| 196 |
+
return up_hidden_states.to(orig_dtype)
|
| 197 |
+
|
| 198 |
+
class FLuxSelfAttnProcessor:
|
| 199 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
| 200 |
+
qkv = attn.qkv(x)
|
| 201 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 202 |
+
q, k = attn.norm(q, k, v)
|
| 203 |
+
x = attention(q, k, v, pe=pe)
|
| 204 |
+
x = attn.proj(x)
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
+
class LoraFluxAttnProcessor(nn.Module):
|
| 208 |
+
|
| 209 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 212 |
+
self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 213 |
+
self.lora_weight = lora_weight
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
| 217 |
+
qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
|
| 218 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 219 |
+
q, k = attn.norm(q, k, v)
|
| 220 |
+
x = attention(q, k, v, pe=pe)
|
| 221 |
+
x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
|
| 222 |
+
return x
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class SelfAttention(nn.Module):
|
| 226 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.num_heads = num_heads
|
| 229 |
+
head_dim = dim // num_heads
|
| 230 |
+
|
| 231 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 232 |
+
self.norm = QKNorm(head_dim)
|
| 233 |
+
self.proj = nn.Linear(dim, dim)
|
| 234 |
+
def forward():
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
@dataclass
|
| 238 |
+
class ModulationOut:
|
| 239 |
+
shift: Tensor
|
| 240 |
+
scale: Tensor
|
| 241 |
+
gate: Tensor
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Modulation(nn.Module):
|
| 245 |
+
def __init__(self, dim: int, double: bool):
|
| 246 |
+
super().__init__()
|
| 247 |
+
self.is_double = double
|
| 248 |
+
self.multiplier = 6 if double else 3
|
| 249 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
| 250 |
+
|
| 251 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
| 252 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
| 253 |
+
|
| 254 |
+
return (
|
| 255 |
+
ModulationOut(*out[:3]),
|
| 256 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
class DoubleStreamBlockLoraProcessor(nn.Module):
|
| 260 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 263 |
+
self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 264 |
+
self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 265 |
+
self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
| 266 |
+
self.lora_weight = lora_weight
|
| 267 |
+
|
| 268 |
+
def forward(self, attn, img, txt, vec, pe, mask, text_length, image_length, **attention_kwargs):
|
| 269 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
| 270 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
| 271 |
+
|
| 272 |
+
# prepare image for attention
|
| 273 |
+
img_modulated = attn.img_norm1(img)
|
| 274 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 275 |
+
img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
|
| 276 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 277 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
| 278 |
+
|
| 279 |
+
# prepare txt for attention
|
| 280 |
+
txt_modulated = attn.txt_norm1(txt)
|
| 281 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 282 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
|
| 283 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 284 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 285 |
+
|
| 286 |
+
# run actual attention
|
| 287 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 288 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 289 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 290 |
+
|
| 291 |
+
attn1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX, text_length=text_length, image_length=image_length)
|
| 292 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
| 293 |
+
|
| 294 |
+
# calculate the img bloks
|
| 295 |
+
img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
|
| 296 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
| 297 |
+
|
| 298 |
+
# calculate the txt bloks
|
| 299 |
+
txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
|
| 300 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
return img, txt
|
| 304 |
+
|
| 305 |
+
class DoubleStreamBlockProcessor:
|
| 306 |
+
def __call__(self, attn, img, txt, vec, pe, mask, text_length, image_length, **attention_kwargs):
|
| 307 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
| 308 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
| 309 |
+
|
| 310 |
+
# prepare image for attention
|
| 311 |
+
img_modulated = attn.img_norm1(img)
|
| 312 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 313 |
+
img_qkv = attn.img_attn.qkv(img_modulated)
|
| 314 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
| 315 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
| 316 |
+
|
| 317 |
+
# prepare txt for attention
|
| 318 |
+
txt_modulated = attn.txt_norm1(txt)
|
| 319 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 320 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
| 321 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
| 322 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 323 |
+
|
| 324 |
+
# run actual attention
|
| 325 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 326 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 327 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
attn1 = attention(q, k, v, pe=pe, mask=attention_kwargs.get("mask"), token_aug_idx=TOKEN_AUG_IDX,text_length=text_length, image_length=image_length)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
| 334 |
+
|
| 335 |
+
# calculate the img bloks
|
| 336 |
+
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
| 337 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
| 338 |
+
|
| 339 |
+
# calculate the txt bloks
|
| 340 |
+
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
| 341 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
return img, txt
|
| 345 |
+
|
| 346 |
+
class DoubleStreamBlock(nn.Module):
|
| 347 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
| 348 |
+
super().__init__()
|
| 349 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 350 |
+
self.num_heads = num_heads
|
| 351 |
+
self.hidden_size = hidden_size
|
| 352 |
+
self.head_dim = hidden_size // num_heads
|
| 353 |
+
|
| 354 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
| 355 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 356 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
| 357 |
+
|
| 358 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 359 |
+
self.img_mlp = nn.Sequential(
|
| 360 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 361 |
+
nn.GELU(approximate="tanh"),
|
| 362 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
| 366 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 367 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
| 368 |
+
|
| 369 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 370 |
+
self.txt_mlp = nn.Sequential(
|
| 371 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
| 372 |
+
nn.GELU(approximate="tanh"),
|
| 373 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
| 374 |
+
)
|
| 375 |
+
processor = DoubleStreamBlockProcessor()
|
| 376 |
+
self.set_processor(processor)
|
| 377 |
+
|
| 378 |
+
def set_processor(self, processor) -> None:
|
| 379 |
+
self.processor = processor
|
| 380 |
+
|
| 381 |
+
def get_processor(self):
|
| 382 |
+
return self.processor
|
| 383 |
+
|
| 384 |
+
def forward(
|
| 385 |
+
self,
|
| 386 |
+
img: Tensor,
|
| 387 |
+
txt: Tensor,
|
| 388 |
+
vec: Tensor,
|
| 389 |
+
pe: Tensor,
|
| 390 |
+
image_proj: Tensor = None,
|
| 391 |
+
ip_scale: float =1.0,
|
| 392 |
+
mask: Tensor | None = None,
|
| 393 |
+
text_length: int = None,
|
| 394 |
+
image_length: int = None,
|
| 395 |
+
return_map: bool = False,
|
| 396 |
+
**attention_kwargs
|
| 397 |
+
) -> tuple[Tensor, Tensor]:
|
| 398 |
+
if image_proj is None:
|
| 399 |
+
|
| 400 |
+
return self.processor(self, img, txt, vec, pe, mask, text_length, image_length)
|
| 401 |
+
else:
|
| 402 |
+
|
| 403 |
+
return self.processor(self, img, txt, vec, pe, mask, text_length, image_length, image_proj, ip_scale)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class SingleStreamBlockLoraProcessor(nn.Module):
|
| 407 |
+
def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
|
| 408 |
+
super().__init__()
|
| 409 |
+
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
| 410 |
+
self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
|
| 411 |
+
self.lora_weight = lora_weight
|
| 412 |
+
|
| 413 |
+
def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, mask = None, text_length = None, image_length = None, return_map=False) -> Tensor:
|
| 414 |
+
|
| 415 |
+
mod, _ = attn.modulation(vec)
|
| 416 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
| 417 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
| 418 |
+
qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
|
| 419 |
+
|
| 420 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 421 |
+
q, k = attn.norm(q, k, v)
|
| 422 |
+
|
| 423 |
+
# compute attention
|
| 424 |
+
|
| 425 |
+
attn_1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX,text_length=text_length, image_length=image_length)
|
| 426 |
+
|
| 427 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 428 |
+
output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
|
| 429 |
+
output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
|
| 430 |
+
output = x + mod.gate * output
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
return output
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class SingleStreamBlockProcessor:
|
| 437 |
+
def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor, text_length, image_length, return_map=False, **attention_kwargs) -> Tensor:
|
| 438 |
+
|
| 439 |
+
mod, _ = attn.modulation(vec)
|
| 440 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
| 441 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
| 442 |
+
|
| 443 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
| 444 |
+
q, k = attn.norm(q, k, v)
|
| 445 |
+
|
| 446 |
+
# compute attention
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
attn_1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX,text_length=text_length, image_length=image_length)
|
| 450 |
+
|
| 451 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 452 |
+
output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
|
| 453 |
+
output = x + mod.gate * output
|
| 454 |
+
|
| 455 |
+
return output
|
| 456 |
+
|
| 457 |
+
class SingleStreamBlock(nn.Module):
|
| 458 |
+
"""
|
| 459 |
+
A DiT block with parallel linear layers as described in
|
| 460 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
def __init__(
|
| 464 |
+
self,
|
| 465 |
+
hidden_size: int,
|
| 466 |
+
num_heads: int,
|
| 467 |
+
mlp_ratio: float = 4.0,
|
| 468 |
+
qk_scale: float | None = None,
|
| 469 |
+
):
|
| 470 |
+
super().__init__()
|
| 471 |
+
self.hidden_dim = hidden_size
|
| 472 |
+
self.num_heads = num_heads
|
| 473 |
+
self.head_dim = hidden_size // num_heads
|
| 474 |
+
self.scale = qk_scale or self.head_dim**-0.5
|
| 475 |
+
|
| 476 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 477 |
+
# qkv and mlp_in
|
| 478 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
| 479 |
+
# proj and mlp_out
|
| 480 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
| 481 |
+
|
| 482 |
+
self.norm = QKNorm(self.head_dim)
|
| 483 |
+
|
| 484 |
+
self.hidden_size = hidden_size
|
| 485 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 486 |
+
|
| 487 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
| 488 |
+
self.modulation = Modulation(hidden_size, double=False)
|
| 489 |
+
|
| 490 |
+
processor = SingleStreamBlockProcessor()
|
| 491 |
+
self.set_processor(processor)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def set_processor(self, processor) -> None:
|
| 495 |
+
self.processor = processor
|
| 496 |
+
|
| 497 |
+
def get_processor(self):
|
| 498 |
+
return self.processor
|
| 499 |
+
|
| 500 |
+
def forward(
|
| 501 |
+
self,
|
| 502 |
+
x: Tensor,
|
| 503 |
+
vec: Tensor,
|
| 504 |
+
pe: Tensor,
|
| 505 |
+
image_proj: Tensor | None = None,
|
| 506 |
+
ip_scale: float = 1.0,
|
| 507 |
+
mask: Tensor | None = None,
|
| 508 |
+
text_length: int | None = None,
|
| 509 |
+
image_length: int | None = None,
|
| 510 |
+
return_map: bool = False,
|
| 511 |
+
) -> Tensor:
|
| 512 |
+
if image_proj is None:
|
| 513 |
+
return self.processor(self, x, vec, pe, mask, text_length=text_length, image_length=image_length)
|
| 514 |
+
else:
|
| 515 |
+
return self.processor(self, x, vec, pe, mask, image_proj, ip_scale, text_length=text_length, image_length=image_length)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class LastLayer(nn.Module):
|
| 520 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
| 521 |
+
super().__init__()
|
| 522 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 523 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 524 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
| 525 |
+
|
| 526 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
| 527 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
| 528 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
| 529 |
+
x = self.linear(x)
|
| 530 |
+
return x
|
withanyone/flux/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from typing import Literal
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from PIL import ExifTags, Image
|
| 22 |
+
import torchvision.transforms.functional as TVF
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
from withanyone.flux.modules.layers import (
|
| 26 |
+
DoubleStreamBlockLoraProcessor,
|
| 27 |
+
DoubleStreamBlockProcessor,
|
| 28 |
+
SingleStreamBlockLoraProcessor,
|
| 29 |
+
SingleStreamBlockProcessor,
|
| 30 |
+
)
|
| 31 |
+
from withanyone.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
| 32 |
+
from withanyone.flux.util import (
|
| 33 |
+
load_ae,
|
| 34 |
+
load_clip,
|
| 35 |
+
load_flow_model_no_lora,
|
| 36 |
+
load_flow_model_diffusers,
|
| 37 |
+
load_t5,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
from withanyone.flux.model import SiglipEmbedding, create_person_cross_attention_mask_varlen
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
|
| 44 |
+
|
| 45 |
+
image_w, image_h = raw_image.size
|
| 46 |
+
|
| 47 |
+
if image_w >= image_h:
|
| 48 |
+
new_w = long_size
|
| 49 |
+
new_h = int((long_size / image_w) * image_h)
|
| 50 |
+
else:
|
| 51 |
+
new_h = long_size
|
| 52 |
+
new_w = int((long_size / image_h) * image_w)
|
| 53 |
+
|
| 54 |
+
raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
|
| 55 |
+
target_w = new_w // 16 * 16
|
| 56 |
+
target_h = new_h // 16 * 16
|
| 57 |
+
|
| 58 |
+
left = (new_w - target_w) // 2
|
| 59 |
+
top = (new_h - target_h) // 2
|
| 60 |
+
right = left + target_w
|
| 61 |
+
bottom = top + target_h
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
raw_image = raw_image.crop((left, top, right, bottom))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
raw_image = raw_image.convert("RGB")
|
| 68 |
+
return raw_image
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
from io import BytesIO
|
| 72 |
+
import insightface
|
| 73 |
+
import numpy as np
|
| 74 |
+
class FaceExtractor:
|
| 75 |
+
def __init__(self, model_path = "./"):
|
| 76 |
+
|
| 77 |
+
self.model = insightface.app.FaceAnalysis(name = "antelopev2", root=model_path, providers=['CUDAExecutionProvider'])
|
| 78 |
+
self.model.prepare(ctx_id=0, det_thresh=0.45)
|
| 79 |
+
|
| 80 |
+
def extract_moref(self, img, bboxes, face_size_restriction=1):
|
| 81 |
+
"""
|
| 82 |
+
Extract faces from an image based on bounding boxes in JSON data.
|
| 83 |
+
Makes each face square and resizes to 512x512.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
img: PIL Image or image data
|
| 87 |
+
json_data: JSON object with 'bboxes' and 'crop' information
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List of PIL Images, each 512x512, containing extracted faces
|
| 91 |
+
"""
|
| 92 |
+
# Ensure img is a PIL Image
|
| 93 |
+
try:
|
| 94 |
+
if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor):
|
| 95 |
+
img = Image.open(BytesIO(img))
|
| 96 |
+
|
| 97 |
+
# bboxes = json_data['bboxes']
|
| 98 |
+
# crop = json_data['crop']
|
| 99 |
+
# print("len of bboxes:", len(bboxes))
|
| 100 |
+
# Recalculate bounding boxes based on crop info
|
| 101 |
+
# new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
|
| 102 |
+
new_bboxes = bboxes
|
| 103 |
+
# any of the face is less than 100 * 100, we ignore this image
|
| 104 |
+
for bbox in new_bboxes:
|
| 105 |
+
x1, y1, x2, y2 = bbox
|
| 106 |
+
if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
|
| 107 |
+
return []
|
| 108 |
+
# print("len of new_bboxes:", len(new_bboxes))
|
| 109 |
+
faces = []
|
| 110 |
+
for bbox in new_bboxes:
|
| 111 |
+
# print("processing bbox")
|
| 112 |
+
# Convert coordinates to integers
|
| 113 |
+
x1, y1, x2, y2 = map(int, bbox)
|
| 114 |
+
|
| 115 |
+
# Calculate width and height
|
| 116 |
+
width = x2 - x1
|
| 117 |
+
height = y2 - y1
|
| 118 |
+
|
| 119 |
+
# Make the bounding box square by expanding the shorter dimension
|
| 120 |
+
if width > height:
|
| 121 |
+
# Height is shorter, expand it
|
| 122 |
+
diff = width - height
|
| 123 |
+
y1 -= diff // 2
|
| 124 |
+
y2 += diff - (diff // 2) # Handle odd differences
|
| 125 |
+
elif height > width:
|
| 126 |
+
# Width is shorter, expand it
|
| 127 |
+
diff = height - width
|
| 128 |
+
x1 -= diff // 2
|
| 129 |
+
x2 += diff - (diff // 2) # Handle odd differences
|
| 130 |
+
|
| 131 |
+
# Ensure coordinates are within image boundaries
|
| 132 |
+
img_width, img_height = img.size
|
| 133 |
+
x1 = max(0, x1)
|
| 134 |
+
y1 = max(0, y1)
|
| 135 |
+
x2 = min(img_width, x2)
|
| 136 |
+
y2 = min(img_height, y2)
|
| 137 |
+
|
| 138 |
+
# Extract face region
|
| 139 |
+
face_region = img.crop((x1, y1, x2, y2))
|
| 140 |
+
|
| 141 |
+
# Resize to 512x512
|
| 142 |
+
face_region = face_region.resize((512, 512), Image.LANCZOS)
|
| 143 |
+
|
| 144 |
+
faces.append(face_region)
|
| 145 |
+
# print("len of faces:", len(faces))
|
| 146 |
+
return faces
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(f"Error processing image: {e}")
|
| 149 |
+
return []
|
| 150 |
+
|
| 151 |
+
def __call__(self, img):
|
| 152 |
+
# if np, get PIL, else, get np
|
| 153 |
+
if isinstance(img, torch.Tensor):
|
| 154 |
+
img_np = img.cpu().numpy()
|
| 155 |
+
img_pil = Image.fromarray(img_np)
|
| 156 |
+
elif isinstance(img, Image.Image):
|
| 157 |
+
img_pil = img
|
| 158 |
+
img_np = np.array(img)
|
| 159 |
+
elif isinstance(img, np.ndarray):
|
| 160 |
+
img_np = img
|
| 161 |
+
img_pil = Image.fromarray(img)
|
| 162 |
+
|
| 163 |
+
else:
|
| 164 |
+
raise ValueError("Unsupported image format. Please provide a PIL Image or numpy array.")
|
| 165 |
+
# Detect faces in the image
|
| 166 |
+
faces = self.model.get(img_np)
|
| 167 |
+
# use one
|
| 168 |
+
if len(faces) > 0:
|
| 169 |
+
bboxes = []
|
| 170 |
+
face = faces[0]
|
| 171 |
+
bbox = face.bbox.astype(int)
|
| 172 |
+
bboxes.append(bbox)
|
| 173 |
+
return self.extract_moref(img_pil, bboxes)[0]
|
| 174 |
+
else:
|
| 175 |
+
print("Warning: No faces detected in the image.")
|
| 176 |
+
return img_pil
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class WithAnyonePipeline:
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
model_type: str,
|
| 183 |
+
ipa_path: str,
|
| 184 |
+
device: torch.device,
|
| 185 |
+
offload: bool = False,
|
| 186 |
+
only_lora: bool = False,
|
| 187 |
+
no_lora: bool = False,
|
| 188 |
+
lora_rank: int = 16,
|
| 189 |
+
face_extractor = None,
|
| 190 |
+
additional_lora_ckpt: str = None,
|
| 191 |
+
lora_weight: float = 1.0,
|
| 192 |
+
clip_path: str = "openai/clip-vit-large-patch14",
|
| 193 |
+
t5_path: str = "xlabs-ai/xflux_text_encoders",
|
| 194 |
+
flux_path: str = "black-forest-labs/FLUX.1-dev",
|
| 195 |
+
siglip_path: str = "google/siglip-base-patch16-256-i18n",
|
| 196 |
+
|
| 197 |
+
):
|
| 198 |
+
self.device = device
|
| 199 |
+
self.offload = offload
|
| 200 |
+
self.model_type = model_type
|
| 201 |
+
|
| 202 |
+
self.clip = load_clip(clip_path, self.device)
|
| 203 |
+
self.t5 = load_t5(t5_path, self.device, max_length=512)
|
| 204 |
+
self.ae = load_ae(flux_path, model_type, device="cpu" if offload else self.device)
|
| 205 |
+
self.use_fp8 = "fp8" in model_type
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if additional_lora_ckpt is not None:
|
| 209 |
+
self.model = load_flow_model_diffusers(
|
| 210 |
+
model_type,
|
| 211 |
+
flux_path,
|
| 212 |
+
ipa_path,
|
| 213 |
+
device="cpu" if offload else self.device,
|
| 214 |
+
lora_rank=lora_rank,
|
| 215 |
+
use_fp8=self.use_fp8,
|
| 216 |
+
additional_lora_ckpt=additional_lora_ckpt,
|
| 217 |
+
lora_weight=lora_weight,
|
| 218 |
+
|
| 219 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 220 |
+
else:
|
| 221 |
+
self.model = load_flow_model_no_lora(
|
| 222 |
+
model_type,
|
| 223 |
+
flux_path,
|
| 224 |
+
ipa_path,
|
| 225 |
+
device="cpu" if offload else self.device,
|
| 226 |
+
use_fp8=self.use_fp8
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
if face_extractor is not None:
|
| 230 |
+
self.face_extractor = face_extractor
|
| 231 |
+
else:
|
| 232 |
+
self.face_extractor = FaceExtractor()
|
| 233 |
+
|
| 234 |
+
self.siglip = SiglipEmbedding(siglip_path=siglip_path)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def load_ckpt(self, ckpt_path):
|
| 238 |
+
if ckpt_path is not None:
|
| 239 |
+
from safetensors.torch import load_file as load_sft
|
| 240 |
+
print("Loading checkpoint to replace old keys")
|
| 241 |
+
# load_sft doesn't support torch.device
|
| 242 |
+
if ckpt_path.endswith('safetensors'):
|
| 243 |
+
sd = load_sft(ckpt_path, device='cpu')
|
| 244 |
+
missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
|
| 245 |
+
else:
|
| 246 |
+
dit_state = torch.load(ckpt_path, map_location='cpu')
|
| 247 |
+
sd = {}
|
| 248 |
+
for k in dit_state.keys():
|
| 249 |
+
sd[k.replace('module.','')] = dit_state[k]
|
| 250 |
+
missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
|
| 251 |
+
self.model.to(str(self.device))
|
| 252 |
+
print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def __call__(
|
| 257 |
+
self,
|
| 258 |
+
prompt: str,
|
| 259 |
+
width: int = 512,
|
| 260 |
+
height: int = 512,
|
| 261 |
+
guidance: float = 4,
|
| 262 |
+
num_steps: int = 50,
|
| 263 |
+
seed: int = 123456789,
|
| 264 |
+
**kwargs
|
| 265 |
+
):
|
| 266 |
+
width = 16 * (width // 16)
|
| 267 |
+
height = 16 * (height // 16)
|
| 268 |
+
|
| 269 |
+
device_type = self.device if isinstance(self.device, str) else self.device.type
|
| 270 |
+
if device_type == "mps":
|
| 271 |
+
device_type = "cpu" # for support macos mps
|
| 272 |
+
with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
|
| 273 |
+
return self.forward(
|
| 274 |
+
prompt,
|
| 275 |
+
width,
|
| 276 |
+
height,
|
| 277 |
+
guidance,
|
| 278 |
+
num_steps,
|
| 279 |
+
seed,
|
| 280 |
+
**kwargs
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@torch.inference_mode
|
| 286 |
+
def forward(
|
| 287 |
+
self,
|
| 288 |
+
prompt: str,
|
| 289 |
+
width: int,
|
| 290 |
+
height: int,
|
| 291 |
+
guidance: float,
|
| 292 |
+
num_steps: int,
|
| 293 |
+
seed: int,
|
| 294 |
+
ref_imgs: list[Image.Image] | None = None,
|
| 295 |
+
arcface_embeddings: list[torch.Tensor] = None,
|
| 296 |
+
bboxes = None,
|
| 297 |
+
id_weight: float = 1.0,
|
| 298 |
+
siglip_weight: float = 1.0,
|
| 299 |
+
):
|
| 300 |
+
x = get_noise(
|
| 301 |
+
1, height, width, device=self.device,
|
| 302 |
+
dtype=torch.bfloat16, seed=seed
|
| 303 |
+
)
|
| 304 |
+
timesteps = get_schedule(
|
| 305 |
+
num_steps,
|
| 306 |
+
(width // 8) * (height // 8) // (16 * 16),
|
| 307 |
+
shift=True,
|
| 308 |
+
)
|
| 309 |
+
if self.offload:
|
| 310 |
+
self.ae.encoder = self.ae.encoder.to(self.device)
|
| 311 |
+
|
| 312 |
+
if ref_imgs is None:
|
| 313 |
+
siglip_embeddings = None
|
| 314 |
+
else:
|
| 315 |
+
siglip_embeddings = self.siglip(ref_imgs).to(self.device, torch.bfloat16).permute(1,0,2,3)
|
| 316 |
+
# num_ref, (1), n, d
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
if arcface_embeddings is not None:
|
| 320 |
+
arcface_embeddings = arcface_embeddings.unsqueeze(1)
|
| 321 |
+
# num_ref, 1, 512
|
| 322 |
+
arcface_embeddings = arcface_embeddings.to(self.device, torch.bfloat16)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
if self.offload:
|
| 326 |
+
self.offload_model_to_cpu(self.ae.encoder)
|
| 327 |
+
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
inp_cond = prepare(t5=self.t5, clip=self.clip,img=x,prompt=prompt
|
| 331 |
+
)
|
| 332 |
+
if self.offload:
|
| 333 |
+
self.offload_model_to_cpu(self.t5, self.clip)
|
| 334 |
+
self.model = self.model.to(self.device)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
img = inp_cond["img"]
|
| 339 |
+
img_length = img.shape[1]
|
| 340 |
+
##### create mask for siglip and arcface #####
|
| 341 |
+
if bboxes is not None:
|
| 342 |
+
arc_mask = create_person_cross_attention_mask_varlen(
|
| 343 |
+
batch_size=img.shape[0],
|
| 344 |
+
# num_heads=self.params.num_heads,
|
| 345 |
+
# txt_len=text_length,
|
| 346 |
+
img_len=img_length,
|
| 347 |
+
id_len=8,
|
| 348 |
+
bbox_lists=bboxes,
|
| 349 |
+
max_num_ids=len(bboxes[0]),
|
| 350 |
+
original_width=width,
|
| 351 |
+
original_height= height,
|
| 352 |
+
).to(img.device)
|
| 353 |
+
siglip_mask = create_person_cross_attention_mask_varlen(
|
| 354 |
+
batch_size=img.shape[0],
|
| 355 |
+
# num_heads=self.params.num_heads,
|
| 356 |
+
# txt_len=text_length,
|
| 357 |
+
img_len=img_length,
|
| 358 |
+
id_len=256+8,
|
| 359 |
+
bbox_lists=bboxes,
|
| 360 |
+
max_num_ids=len(bboxes[0]),
|
| 361 |
+
original_width=width,
|
| 362 |
+
original_height= height,
|
| 363 |
+
).to(img.device)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
results = denoise(
|
| 370 |
+
self.model,
|
| 371 |
+
**inp_cond,
|
| 372 |
+
timesteps=timesteps,
|
| 373 |
+
guidance=guidance,
|
| 374 |
+
arcface_embeddings=arcface_embeddings,
|
| 375 |
+
siglip_embeddings=siglip_embeddings,
|
| 376 |
+
bboxes=bboxes,
|
| 377 |
+
id_weight=id_weight,
|
| 378 |
+
siglip_weight=siglip_weight,
|
| 379 |
+
img_height=height,
|
| 380 |
+
img_width=width,
|
| 381 |
+
arc_mask=arc_mask if bboxes is not None else None,
|
| 382 |
+
siglip_mask=siglip_mask if bboxes is not None else None,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
x = results
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
if self.offload:
|
| 389 |
+
self.offload_model_to_cpu(self.model)
|
| 390 |
+
self.ae.decoder.to(x.device)
|
| 391 |
+
x = unpack(x.float(), height, width)
|
| 392 |
+
x = self.ae.decode(x)
|
| 393 |
+
self.offload_model_to_cpu(self.ae.decoder)
|
| 394 |
+
|
| 395 |
+
x1 = x.clamp(-1, 1)
|
| 396 |
+
x1 = rearrange(x1[-1], "c h w -> h w c")
|
| 397 |
+
output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
return output_img
|
| 401 |
+
|
| 402 |
+
def offload_model_to_cpu(self, *models):
|
| 403 |
+
if not self.offload: return
|
| 404 |
+
for model in models:
|
| 405 |
+
model.cpu()
|
| 406 |
+
torch.cuda.empty_cache()
|
withanyone/flux/sampling.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from .model import Flux
|
| 12 |
+
from .modules.conditioner import HFEmbedder
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_noise(
|
| 16 |
+
num_samples: int,
|
| 17 |
+
height: int,
|
| 18 |
+
width: int,
|
| 19 |
+
device: torch.device,
|
| 20 |
+
dtype: torch.dtype,
|
| 21 |
+
seed: int,
|
| 22 |
+
):
|
| 23 |
+
return torch.randn(
|
| 24 |
+
num_samples,
|
| 25 |
+
16,
|
| 26 |
+
# allow for packing
|
| 27 |
+
2 * math.ceil(height / 16),
|
| 28 |
+
2 * math.ceil(width / 16),
|
| 29 |
+
device=device,
|
| 30 |
+
dtype=dtype,
|
| 31 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def prepare(
|
| 36 |
+
t5: HFEmbedder,
|
| 37 |
+
clip: HFEmbedder,
|
| 38 |
+
img: Tensor,
|
| 39 |
+
prompt: str | list[str],
|
| 40 |
+
) -> dict[str, Tensor]:
|
| 41 |
+
bs, c, h, w = img.shape
|
| 42 |
+
if bs == 1 and not isinstance(prompt, str):
|
| 43 |
+
bs = len(prompt)
|
| 44 |
+
|
| 45 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 46 |
+
if img.shape[0] == 1 and bs > 1:
|
| 47 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
| 48 |
+
|
| 49 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
| 50 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
| 51 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
| 52 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if isinstance(prompt, str):
|
| 56 |
+
prompt = [prompt]
|
| 57 |
+
txt = t5(prompt)
|
| 58 |
+
if txt.shape[0] == 1 and bs > 1:
|
| 59 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
| 60 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
| 61 |
+
|
| 62 |
+
vec = clip(prompt)
|
| 63 |
+
if vec.shape[0] == 1 and bs > 1:
|
| 64 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
return {
|
| 69 |
+
"img": img,
|
| 70 |
+
"img_ids": img_ids.to(img.device),
|
| 71 |
+
"txt": txt.to(img.device),
|
| 72 |
+
"txt_ids": txt_ids.to(img.device),
|
| 73 |
+
"vec": vec.to(img.device),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
| 80 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_lin_function(
|
| 84 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
| 85 |
+
):
|
| 86 |
+
m = (y2 - y1) / (x2 - x1)
|
| 87 |
+
b = y1 - m * x1
|
| 88 |
+
return lambda x: m * x + b
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_schedule(
|
| 92 |
+
num_steps: int,
|
| 93 |
+
image_seq_len: int,
|
| 94 |
+
base_shift: float = 0.5,
|
| 95 |
+
max_shift: float = 1.15,
|
| 96 |
+
shift: bool = True,
|
| 97 |
+
) -> list[float]:
|
| 98 |
+
# extra step for zero
|
| 99 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
| 100 |
+
|
| 101 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
| 102 |
+
if shift:
|
| 103 |
+
# eastimate mu based on linear estimation between two points
|
| 104 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
| 105 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
| 106 |
+
|
| 107 |
+
return timesteps.tolist()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def denoise(
|
| 111 |
+
model: Flux,
|
| 112 |
+
# model input
|
| 113 |
+
img: Tensor,
|
| 114 |
+
img_ids: Tensor,
|
| 115 |
+
txt: Tensor,
|
| 116 |
+
txt_ids: Tensor,
|
| 117 |
+
vec: Tensor,
|
| 118 |
+
|
| 119 |
+
timesteps: list[float],
|
| 120 |
+
guidance: float = 4.0,
|
| 121 |
+
|
| 122 |
+
arcface_embeddings = None,
|
| 123 |
+
siglip_embeddings = None,
|
| 124 |
+
bboxes: Tensor = None,
|
| 125 |
+
|
| 126 |
+
id_weight: float = 1.0, # weight for identity embeddings
|
| 127 |
+
siglip_weight: float = 1.0, # weight for siglip embeddings
|
| 128 |
+
img_height: int = 512,
|
| 129 |
+
img_width: int = 512,
|
| 130 |
+
arc_mask = None,
|
| 131 |
+
siglip_mask = None,
|
| 132 |
+
):
|
| 133 |
+
i = 0
|
| 134 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
| 135 |
+
for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
|
| 136 |
+
|
| 137 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
| 138 |
+
|
| 139 |
+
pred = model(
|
| 140 |
+
img=img,
|
| 141 |
+
img_ids=img_ids,
|
| 142 |
+
siglip_embeddings=siglip_embeddings,
|
| 143 |
+
txt=txt,
|
| 144 |
+
txt_ids=txt_ids,
|
| 145 |
+
y=vec,
|
| 146 |
+
timesteps=t_vec,
|
| 147 |
+
guidance=guidance_vec,
|
| 148 |
+
arcface_embeddings=arcface_embeddings,
|
| 149 |
+
bbox_lists=bboxes,
|
| 150 |
+
id_weight=id_weight,
|
| 151 |
+
siglip_weight=siglip_weight,
|
| 152 |
+
img_height=img_height,
|
| 153 |
+
img_width=img_width,
|
| 154 |
+
arc_mask=arc_mask,
|
| 155 |
+
siglip_mask=siglip_mask,
|
| 156 |
+
)
|
| 157 |
+
img = img + (t_prev - t_curr) * pred
|
| 158 |
+
i += 1
|
| 159 |
+
|
| 160 |
+
return img
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
| 164 |
+
return rearrange(
|
| 165 |
+
x,
|
| 166 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
| 167 |
+
h=math.ceil(height / 16),
|
| 168 |
+
w=math.ceil(width / 16),
|
| 169 |
+
ph=2,
|
| 170 |
+
pw=2,
|
| 171 |
+
)
|
withanyone/flux/util.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
from safetensors import safe_open
|
| 11 |
+
from safetensors.torch import load_file as load_sft
|
| 12 |
+
|
| 13 |
+
from withanyone.flux.model import Flux, FluxParams
|
| 14 |
+
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
| 15 |
+
from .modules.conditioner import HFEmbedder
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
from withanyone.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def c_crop(image):
|
| 24 |
+
width, height = image.size
|
| 25 |
+
new_size = min(width, height)
|
| 26 |
+
left = (width - new_size) / 2
|
| 27 |
+
top = (height - new_size) / 2
|
| 28 |
+
right = (width + new_size) / 2
|
| 29 |
+
bottom = (height + new_size) / 2
|
| 30 |
+
return image.crop((left, top, right, bottom))
|
| 31 |
+
|
| 32 |
+
def pad64(x):
|
| 33 |
+
return int(np.ceil(float(x) / 64.0) * 64 - x)
|
| 34 |
+
|
| 35 |
+
def HWC3(x):
|
| 36 |
+
assert x.dtype == np.uint8
|
| 37 |
+
if x.ndim == 2:
|
| 38 |
+
x = x[:, :, None]
|
| 39 |
+
assert x.ndim == 3
|
| 40 |
+
H, W, C = x.shape
|
| 41 |
+
assert C == 1 or C == 3 or C == 4
|
| 42 |
+
if C == 3:
|
| 43 |
+
return x
|
| 44 |
+
if C == 1:
|
| 45 |
+
return np.concatenate([x, x, x], axis=2)
|
| 46 |
+
if C == 4:
|
| 47 |
+
color = x[:, :, 0:3].astype(np.float32)
|
| 48 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
| 49 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
| 50 |
+
y = y.clip(0, 255).astype(np.uint8)
|
| 51 |
+
return y
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class ModelSpec:
|
| 55 |
+
params: FluxParams
|
| 56 |
+
ae_params: AutoEncoderParams
|
| 57 |
+
repo_id: str | None
|
| 58 |
+
repo_flow: str | None
|
| 59 |
+
repo_ae: str | None
|
| 60 |
+
repo_id_ae: str | None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
configs = {
|
| 64 |
+
"flux-dev": ModelSpec(
|
| 65 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
| 66 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 67 |
+
repo_flow="flux1-dev.safetensors",
|
| 68 |
+
repo_ae="ae.safetensors",
|
| 69 |
+
params=FluxParams(
|
| 70 |
+
in_channels=64,
|
| 71 |
+
vec_in_dim=768,
|
| 72 |
+
context_in_dim=4096,
|
| 73 |
+
hidden_size=3072,
|
| 74 |
+
mlp_ratio=4.0,
|
| 75 |
+
num_heads=24,
|
| 76 |
+
depth=19,
|
| 77 |
+
depth_single_blocks=38,
|
| 78 |
+
axes_dim=[16, 56, 56],
|
| 79 |
+
theta=10_000,
|
| 80 |
+
qkv_bias=True,
|
| 81 |
+
guidance_embed=True,
|
| 82 |
+
),
|
| 83 |
+
ae_params=AutoEncoderParams(
|
| 84 |
+
resolution=256,
|
| 85 |
+
in_channels=3,
|
| 86 |
+
ch=128,
|
| 87 |
+
out_ch=3,
|
| 88 |
+
ch_mult=[1, 2, 4, 4],
|
| 89 |
+
num_res_blocks=2,
|
| 90 |
+
z_channels=16,
|
| 91 |
+
scale_factor=0.3611,
|
| 92 |
+
shift_factor=0.1159,
|
| 93 |
+
),
|
| 94 |
+
),
|
| 95 |
+
"flux-dev-fp8": ModelSpec(
|
| 96 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
| 97 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 98 |
+
repo_flow="flux1-dev.safetensors",
|
| 99 |
+
repo_ae="ae.safetensors",
|
| 100 |
+
params=FluxParams(
|
| 101 |
+
in_channels=64,
|
| 102 |
+
vec_in_dim=768,
|
| 103 |
+
context_in_dim=4096,
|
| 104 |
+
hidden_size=3072,
|
| 105 |
+
mlp_ratio=4.0,
|
| 106 |
+
num_heads=24,
|
| 107 |
+
depth=19,
|
| 108 |
+
depth_single_blocks=38,
|
| 109 |
+
axes_dim=[16, 56, 56],
|
| 110 |
+
theta=10_000,
|
| 111 |
+
qkv_bias=True,
|
| 112 |
+
guidance_embed=True,
|
| 113 |
+
),
|
| 114 |
+
ae_params=AutoEncoderParams(
|
| 115 |
+
resolution=256,
|
| 116 |
+
in_channels=3,
|
| 117 |
+
ch=128,
|
| 118 |
+
out_ch=3,
|
| 119 |
+
ch_mult=[1, 2, 4, 4],
|
| 120 |
+
num_res_blocks=2,
|
| 121 |
+
z_channels=16,
|
| 122 |
+
scale_factor=0.3611,
|
| 123 |
+
shift_factor=0.1159,
|
| 124 |
+
),
|
| 125 |
+
),
|
| 126 |
+
"flux-krea": ModelSpec(
|
| 127 |
+
repo_id="black-forest-labs/FLUX.1-Krea-dev",
|
| 128 |
+
repo_id_ae="black-forest-labs/FLUX.1-Krea-dev",
|
| 129 |
+
repo_flow="flux1-krea-dev.safetensors",
|
| 130 |
+
repo_ae="ae.safetensors",
|
| 131 |
+
params=FluxParams(
|
| 132 |
+
in_channels=64,
|
| 133 |
+
vec_in_dim=768,
|
| 134 |
+
context_in_dim=4096,
|
| 135 |
+
hidden_size=3072,
|
| 136 |
+
mlp_ratio=4.0,
|
| 137 |
+
num_heads=24,
|
| 138 |
+
depth=19,
|
| 139 |
+
depth_single_blocks=38,
|
| 140 |
+
axes_dim=[16, 56, 56],
|
| 141 |
+
theta=10_000,
|
| 142 |
+
qkv_bias=True,
|
| 143 |
+
guidance_embed=True,
|
| 144 |
+
),
|
| 145 |
+
ae_params=AutoEncoderParams(
|
| 146 |
+
resolution=256,
|
| 147 |
+
in_channels=3,
|
| 148 |
+
ch=128,
|
| 149 |
+
out_ch=3,
|
| 150 |
+
ch_mult=[1, 2, 4, 4],
|
| 151 |
+
num_res_blocks=2,
|
| 152 |
+
z_channels=16,
|
| 153 |
+
scale_factor=0.3611,
|
| 154 |
+
shift_factor=0.1159,
|
| 155 |
+
),
|
| 156 |
+
),
|
| 157 |
+
"flux-schnell": ModelSpec(
|
| 158 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
| 159 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 160 |
+
repo_flow="flux1-schnell.safetensors",
|
| 161 |
+
repo_ae="ae.safetensors",
|
| 162 |
+
params=FluxParams(
|
| 163 |
+
in_channels=64,
|
| 164 |
+
vec_in_dim=768,
|
| 165 |
+
context_in_dim=4096,
|
| 166 |
+
hidden_size=3072,
|
| 167 |
+
mlp_ratio=4.0,
|
| 168 |
+
num_heads=24,
|
| 169 |
+
depth=19,
|
| 170 |
+
depth_single_blocks=38,
|
| 171 |
+
axes_dim=[16, 56, 56],
|
| 172 |
+
theta=10_000,
|
| 173 |
+
qkv_bias=True,
|
| 174 |
+
guidance_embed=False,
|
| 175 |
+
),
|
| 176 |
+
ae_params=AutoEncoderParams(
|
| 177 |
+
resolution=256,
|
| 178 |
+
in_channels=3,
|
| 179 |
+
ch=128,
|
| 180 |
+
out_ch=3,
|
| 181 |
+
ch_mult=[1, 2, 4, 4],
|
| 182 |
+
num_res_blocks=2,
|
| 183 |
+
z_channels=16,
|
| 184 |
+
scale_factor=0.3611,
|
| 185 |
+
shift_factor=0.1159,
|
| 186 |
+
),
|
| 187 |
+
),
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
| 192 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
| 193 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 194 |
+
print("\n" + "-" * 79 + "\n")
|
| 195 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 196 |
+
elif len(missing) > 0:
|
| 197 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 198 |
+
elif len(unexpected) > 0:
|
| 199 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 200 |
+
|
| 201 |
+
def load_from_repo_id(repo_id, checkpoint_name):
|
| 202 |
+
ckpt_path = hf_hub_download(repo_id, checkpoint_name)
|
| 203 |
+
sd = load_sft(ckpt_path, device='cpu')
|
| 204 |
+
return sd
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def load_flow_model_no_lora(
|
| 211 |
+
name: str,
|
| 212 |
+
path: str,
|
| 213 |
+
ipa_path: str ,
|
| 214 |
+
device: str | torch.device = "cuda",
|
| 215 |
+
hf_download: bool = True,
|
| 216 |
+
lora_rank: int = 16,
|
| 217 |
+
use_fp8: bool = False
|
| 218 |
+
):
|
| 219 |
+
# Loading Flux
|
| 220 |
+
print("Init model")
|
| 221 |
+
ckpt_path = path
|
| 222 |
+
if ckpt_path == "black-forest-labs/FLUX.1-dev" or (
|
| 223 |
+
ckpt_path is None
|
| 224 |
+
and configs[name].repo_id is not None
|
| 225 |
+
and configs[name].repo_flow is not None
|
| 226 |
+
and hf_download
|
| 227 |
+
):
|
| 228 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
| 229 |
+
print("Downloading checkpoint from HF:", ckpt_path)
|
| 230 |
+
else:
|
| 231 |
+
ckpt_path = os.path.join(path, "flux1-dev.safetensors") if path is not None else None
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
ipa_ckpt_path = ipa_path
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 240 |
+
model = Flux(configs[name].params)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# model = set_lora(model, lora_rank, device="meta" if ipa_ckpt_path is not None else device)
|
| 244 |
+
|
| 245 |
+
if ckpt_path is not None:
|
| 246 |
+
if ipa_ckpt_path == 'WithAnyone/WithAnyone':
|
| 247 |
+
ipa_ckpt_path = hf_hub_download("WithAnyone/WithAnyone", "withanyone.safetensors")
|
| 248 |
+
|
| 249 |
+
lora_sd = load_sft(ipa_ckpt_path, device=str(device)) if ipa_ckpt_path.endswith("safetensors")\
|
| 250 |
+
else torch.load(ipa_ckpt_path, map_location='cpu')
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
print("Loading main checkpoint")
|
| 254 |
+
# load_sft doesn't support torch.device
|
| 255 |
+
|
| 256 |
+
if ckpt_path.endswith('safetensors'):
|
| 257 |
+
if use_fp8:
|
| 258 |
+
print(
|
| 259 |
+
"####\n"
|
| 260 |
+
"We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
|
| 261 |
+
"we convert the fp8 checkpoint on flight from bf16 checkpoint\n"
|
| 262 |
+
"If your storage is constrained"
|
| 263 |
+
"you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
|
| 264 |
+
)
|
| 265 |
+
sd = load_sft(ckpt_path, device="cpu")
|
| 266 |
+
sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
|
| 267 |
+
else:
|
| 268 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# # Then proceed with the update
|
| 273 |
+
sd.update(lora_sd)
|
| 274 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 275 |
+
else:
|
| 276 |
+
dit_state = torch.load(ckpt_path, map_location='cpu')
|
| 277 |
+
sd = {}
|
| 278 |
+
for k in dit_state.keys():
|
| 279 |
+
sd[k.replace('module.','')] = dit_state[k]
|
| 280 |
+
sd.update(lora_sd)
|
| 281 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 282 |
+
model.to(str(device))
|
| 283 |
+
print_load_warning(missing, unexpected)
|
| 284 |
+
return model
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def merge_to_flux_model(
|
| 288 |
+
loading_device, working_device, flux_state_dict, model, ratio, merge_dtype, save_dtype, mem_eff_load_save=False
|
| 289 |
+
):
|
| 290 |
+
|
| 291 |
+
lora_name_to_module_key = {}
|
| 292 |
+
keys = list(flux_state_dict.keys())
|
| 293 |
+
for key in keys:
|
| 294 |
+
if key.endswith(".weight"):
|
| 295 |
+
module_name = ".".join(key.split(".")[:-1])
|
| 296 |
+
lora_name = "lora_unet" + "_" + module_name.replace(".", "_")
|
| 297 |
+
lora_name_to_module_key[lora_name] = key
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
print(f"loading: {model}")
|
| 301 |
+
lora_sd = load_sft(model, device=loading_device) if model.endswith("safetensors")\
|
| 302 |
+
else torch.load(model, map_location='cpu')
|
| 303 |
+
|
| 304 |
+
print(f"merging...")
|
| 305 |
+
for key in list(lora_sd.keys()):
|
| 306 |
+
if "lora_down" in key:
|
| 307 |
+
lora_name = key[: key.rfind(".lora_down")]
|
| 308 |
+
up_key = key.replace("lora_down", "lora_up")
|
| 309 |
+
alpha_key = key[: key.index("lora_down")] + "alpha"
|
| 310 |
+
|
| 311 |
+
if lora_name not in lora_name_to_module_key:
|
| 312 |
+
print(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
|
| 313 |
+
continue
|
| 314 |
+
|
| 315 |
+
down_weight = lora_sd.pop(key)
|
| 316 |
+
up_weight = lora_sd.pop(up_key)
|
| 317 |
+
|
| 318 |
+
dim = down_weight.size()[0]
|
| 319 |
+
alpha = lora_sd.pop(alpha_key, dim)
|
| 320 |
+
scale = alpha / dim
|
| 321 |
+
|
| 322 |
+
# W <- W + U * D
|
| 323 |
+
module_weight_key = lora_name_to_module_key[lora_name]
|
| 324 |
+
if module_weight_key not in flux_state_dict:
|
| 325 |
+
# weight = flux_file.get_tensor(module_weight_key)
|
| 326 |
+
print(f"no module found for LoRA weight: {module_weight_key}")
|
| 327 |
+
else:
|
| 328 |
+
weight = flux_state_dict[module_weight_key]
|
| 329 |
+
|
| 330 |
+
weight = weight.to(working_device, merge_dtype)
|
| 331 |
+
up_weight = up_weight.to(working_device, merge_dtype)
|
| 332 |
+
down_weight = down_weight.to(working_device, merge_dtype)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
if len(weight.size()) == 2:
|
| 336 |
+
# linear
|
| 337 |
+
weight = weight + ratio * (up_weight @ down_weight) * scale
|
| 338 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 339 |
+
# conv2d 1x1
|
| 340 |
+
weight = (
|
| 341 |
+
weight
|
| 342 |
+
+ ratio
|
| 343 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 344 |
+
* scale
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
# conv2d 3x3
|
| 348 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 349 |
+
weight = weight + ratio * conved * scale
|
| 350 |
+
|
| 351 |
+
flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
|
| 352 |
+
del up_weight
|
| 353 |
+
del down_weight
|
| 354 |
+
del weight
|
| 355 |
+
|
| 356 |
+
if len(lora_sd) > 0:
|
| 357 |
+
print(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
|
| 358 |
+
|
| 359 |
+
return flux_state_dict
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def load_flow_model_diffusers(
|
| 364 |
+
name: str,
|
| 365 |
+
path: str,
|
| 366 |
+
ipa_path: str ,
|
| 367 |
+
device: str | torch.device = "cuda",
|
| 368 |
+
hf_download: bool = True,
|
| 369 |
+
lora_rank: int = 16,
|
| 370 |
+
use_fp8: bool = False,
|
| 371 |
+
additional_lora_ckpt: str | None = None,
|
| 372 |
+
lora_weight: float = 1.0,
|
| 373 |
+
):
|
| 374 |
+
# Loading Flux
|
| 375 |
+
print("Init model")
|
| 376 |
+
|
| 377 |
+
ckpt_path = os.path.join(path, "flux1-dev.safetensors") if path is not None else None
|
| 378 |
+
print("Loading checkpoint from", ckpt_path)
|
| 379 |
+
if (
|
| 380 |
+
ckpt_path is None
|
| 381 |
+
and configs[name].repo_id is not None
|
| 382 |
+
and configs[name].repo_flow is not None
|
| 383 |
+
and hf_download
|
| 384 |
+
):
|
| 385 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
ipa_ckpt_path = ipa_path
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 393 |
+
model = Flux(configs[name].params)
|
| 394 |
+
|
| 395 |
+
# if additional_lora_ckpt is not None:
|
| 396 |
+
# model = set_lora(model, lora_rank, device="meta" if ipa_ckpt_path is not None else device)
|
| 397 |
+
assert additional_lora_ckpt is not None, "additional_lora_ckpt should have been provided. this must be a bug"
|
| 398 |
+
|
| 399 |
+
if ckpt_path is not None:
|
| 400 |
+
if ipa_ckpt_path == 'WithAnyone/WithAnyone':
|
| 401 |
+
ipa_ckpt_path = hf_hub_download("WithAnyone/WithAnyone", "withanyone.safetensors")
|
| 402 |
+
else:
|
| 403 |
+
lora_sd = load_sft(ipa_ckpt_path, device=str(device)) if ipa_ckpt_path.endswith("safetensors")\
|
| 404 |
+
else torch.load(ipa_ckpt_path, map_location='cpu')
|
| 405 |
+
|
| 406 |
+
extra_lora_path = additional_lora_ckpt
|
| 407 |
+
|
| 408 |
+
print("Loading main checkpoint")
|
| 409 |
+
# load_sft doesn't support torch.device
|
| 410 |
+
|
| 411 |
+
if ckpt_path.endswith('safetensors'):
|
| 412 |
+
if use_fp8:
|
| 413 |
+
print(
|
| 414 |
+
"####\n"
|
| 415 |
+
"We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
|
| 416 |
+
"we convert the fp8 checkpoint on flight from bf16 checkpoint\n"
|
| 417 |
+
"If your storage is constrained"
|
| 418 |
+
"you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
|
| 419 |
+
)
|
| 420 |
+
sd = load_sft(ckpt_path, device="cpu")
|
| 421 |
+
sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
|
| 422 |
+
else:
|
| 423 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 424 |
+
|
| 425 |
+
if extra_lora_path is not None:
|
| 426 |
+
print("Merging extra lora to main checkpoint")
|
| 427 |
+
lora_ckpt_path = extra_lora_path
|
| 428 |
+
sd = merge_to_flux_model("cpu", device, sd, lora_ckpt_path, lora_weight, torch.float8_e4m3fn if use_fp8 else torch.bfloat16, torch.float8_e4m3fn if use_fp8 else torch.bfloat16)
|
| 429 |
+
# # Then proceed with the update
|
| 430 |
+
sd.update(ipa_lora_sd)
|
| 431 |
+
|
| 432 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 433 |
+
model.to(str(device))
|
| 434 |
+
else:
|
| 435 |
+
dit_state = torch.load(ckpt_path, map_location='cpu')
|
| 436 |
+
sd = {}
|
| 437 |
+
for k in dit_state.keys():
|
| 438 |
+
sd[k.replace('module.','')] = dit_state[k]
|
| 439 |
+
|
| 440 |
+
if extra_lora_path is not None:
|
| 441 |
+
print("Merging extra lora to main checkpoint")
|
| 442 |
+
lora_ckpt_path = extra_lora_path
|
| 443 |
+
sd = merge_to_flux_model("cpu", device, sd, lora_ckpt_path, 1.0, torch.float8_e4m3fn if use_fp8 else torch.bfloat16, torch.float8_e4m3fn if use_fp8 else torch.bfloat16)
|
| 444 |
+
sd.update(ipa_lora_sd)
|
| 445 |
+
|
| 446 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
| 447 |
+
model.to(str(device))
|
| 448 |
+
print_load_warning(missing, unexpected)
|
| 449 |
+
|
| 450 |
+
return model
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def set_lora(
|
| 454 |
+
model: Flux,
|
| 455 |
+
lora_rank: int,
|
| 456 |
+
double_blocks_indices: list[int] | None = None,
|
| 457 |
+
single_blocks_indices: list[int] | None = None,
|
| 458 |
+
device: str | torch.device = "cpu",
|
| 459 |
+
) -> Flux:
|
| 460 |
+
double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
|
| 461 |
+
single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
|
| 462 |
+
else single_blocks_indices
|
| 463 |
+
|
| 464 |
+
lora_attn_procs = {}
|
| 465 |
+
with torch.device(device):
|
| 466 |
+
for name, attn_processor in model.attn_processors.items():
|
| 467 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 468 |
+
if match:
|
| 469 |
+
layer_index = int(match.group(1))
|
| 470 |
+
|
| 471 |
+
if name.startswith("double_blocks") and layer_index in double_blocks_indices:
|
| 472 |
+
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
|
| 473 |
+
elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
|
| 474 |
+
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
|
| 475 |
+
else:
|
| 476 |
+
lora_attn_procs[name] = attn_processor
|
| 477 |
+
model.set_attn_processor(lora_attn_procs)
|
| 478 |
+
return model
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def load_t5(t5_path, device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
| 484 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 485 |
+
version = t5_path
|
| 486 |
+
return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 487 |
+
|
| 488 |
+
def load_clip(clip_path, device: str | torch.device = "cuda") -> HFEmbedder:
|
| 489 |
+
version = clip_path
|
| 490 |
+
|
| 491 |
+
return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def load_ae(flux_path, name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
if flux_path == "black-forest-labs/FLUX.1-dev" or flux_path == "black-forest-labs/FLUX.1-schnell" or flux_path == "black-forest-labs/FLUX.1-Krea-dev" or flux_path == "black-forest-labs/FLUX.1-Kontext-dev":
|
| 498 |
+
ckpt_path = hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors")
|
| 499 |
+
else:
|
| 500 |
+
ckpt_path = os.path.join(flux_path, "ae.safetensors")
|
| 501 |
+
if not os.path.exists(ckpt_path):
|
| 502 |
+
# try diffusion_pytorch_model.safetensors
|
| 503 |
+
ckpt_path = os.path.join(flux_path, "vae", "ae.safetensors")
|
| 504 |
+
if not os.path.exists(ckpt_path):
|
| 505 |
+
raise FileNotFoundError(f"Cannot find ae checkpoint in {flux_path}/ae.safetensors or {flux_path}/vae/ae.safetensors")
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
# Loading the autoencoder
|
| 509 |
+
print("Init AE")
|
| 510 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
| 511 |
+
ae = AutoEncoder(configs[name].ae_params)
|
| 512 |
+
|
| 513 |
+
# if ckpt_path is not None:
|
| 514 |
+
assert ckpt_path is not None, "ckpt_path should have been provided. this must be a bug"
|
| 515 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 516 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
| 517 |
+
print_load_warning(missing, unexpected)
|
| 518 |
+
return ae
|
withanyone/utils/convert_yaml_to_args_file.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import yaml
|
| 5 |
+
|
| 6 |
+
parser = argparse.ArgumentParser()
|
| 7 |
+
parser.add_argument("--yaml", type=str, required=True)
|
| 8 |
+
parser.add_argument("--arg", type=str, required=True)
|
| 9 |
+
args = parser.parse_args()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
with open(args.yaml, "r") as f:
|
| 13 |
+
data = yaml.safe_load(f)
|
| 14 |
+
|
| 15 |
+
with open(args.arg, "w") as f:
|
| 16 |
+
for k, v in data.items():
|
| 17 |
+
if isinstance(v, list):
|
| 18 |
+
v = list(map(str, v))
|
| 19 |
+
v = " ".join(v)
|
| 20 |
+
if v is None:
|
| 21 |
+
continue
|
| 22 |
+
print(f"--{k} {v}", end=" ", file=f)
|