ResearcherXman committed · Commit ec7fc1c · 1 Parent(s): f4fab1d
support lcm and multi-controlnets

Files changed:
- app.py (+284, -128)
- controlnet_util.py (+39, -0)
- gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png (+0, -3)
- gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png (+0, -3)
- gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png (+0, -3)
- gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png (+0, -3)
- gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png (+0, -3)
- gradio_cached_examples/25/log.csv (+0, -6)
- ip_adapter/attention_processor.py (+146, -8)
- model_util.py (+472, -0)
- pipeline_stable_diffusion_xl_instantid.py → pipeline_stable_diffusion_xl_instantid_full.py (+102, -21)
- requirements.txt (+7, -3)
    	
app.py CHANGED
Removed lines, grouped by hunk (… marks text that was cut off):

@@ -1,23 +1,34 @@
-import …
-import random
-import …
-import …
-from PIL import Image
-from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
-device = …

@@ -33,69 +44,120 @@ hf_hub_download(
-app = FaceAnalysis(…
-face_adapter = "./checkpoints/ip-adapter.bin"
-controlnet_path = "./checkpoints/ControlNetModel"
-# Load pipeline
-    controlnet=…
-    torch_dtype=…
-pipe.cuda()
-pipe.load_ip_adapter_instantid(face_adapter)
-pipe.image_proj_model.to("cuda")
-pipe.unet.to("cuda")

@@ -103,50 +165,33 @@ def get_example():
-def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
-    stickwidth = 4
-    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
-    kps = np.array(kps)
-
-    w, h = image_pil.size
-    out_img = np.zeros([h, w, 3])
-
-    for i in range(len(limbSeq)):
-        index = limbSeq[i]
-        color = color_list[index[0]]
-
-        x = kps[index][:, 0]
-        y = kps[index][:, 1]
-        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
-        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
-        polygon = cv2.ellipse2Poly(
-            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
-        )
-        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
-    out_img = (out_img * 0.6).astype(np.uint8)
-
-    for idx_kp, kp in enumerate(kps):
-        color = color_list[idx_kp]
-        x, y = kp
-        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
-
-    out_img_pil = Image.fromarray(out_img.astype(np.uint8))
-    return out_img_pil

@@ -172,21 +217,18 @@ def resize_img(
-        res[…
-def check_input_image(face_image):
-    if face_image is None:
-        raise gr.Error("Cannot find any input face image! Please upload the face image")

@@ -194,14 +236,41 @@ def generate_image(
-    enhance_face_region,

@@ -209,7 +278,7 @@ def generate_image(
-    face_image = resize_img(face_image)

@@ -217,23 +286,31 @@ def generate_image(
-        raise gr.Error(…
-    face_info = sorted(…
-        pose_image = resize_img(pose_image)
-            raise gr.Error(…

@@ -249,6 +326,28 @@ def generate_image(
(additions only)

@@ -259,9 +358,9 @@ def generate_image(
-        image=…
-        controlnet_conditioning_scale=…

@@ -271,8 +370,7 @@ def generate_image(
-### Description

@@ -281,12 +379,12 @@ description = r"""
-1. Upload a …
-2. (…
-3. …
-4. …
-5. …
-"""

@@ -295,10 +393,10 @@ article = r"""
-… (4 truncated lines)

@@ -308,10 +406,10 @@ If you have any questions, please feel free to open an issue or directly reach u…
-1. If you're …
-2. If the …
-3. If text control is not as expected, decrease …
-4. …

@@ -324,27 +422,39 @@ with gr.Blocks(css=css) as demo:
-… (5 truncated lines)
-                info="Give simple prompt is enough to achieve good face …
-… (2 truncated lines)
-                label="IdentityNet strength (for …

@@ -357,26 +467,51 @@
-… (1 truncated line)
-                value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, …
-                minimum=…
-                value=30,
-                maximum=…
-                value=5,

@@ -385,18 +520,31 @@
-        with gr.Column():
-… (1 truncated line)
-            usage_tips = gr.Markdown(…
-            queue=False,
-            api_name=False,

@@ -404,11 +552,6 @@
-            fn=check_input_image,
-            inputs=face_file,
-            queue=False,
-            api_name=False,
-        ).success(

@@ -416,21 +559,34 @@
-            enhance_face_region,
-            outputs=[…
-        inputs=[face_file, prompt, style, negative_prompt],
-        outputs=[output_image, usage_tips],
Updated file content (new version, by hunk; lines marked + were added, unmarked lines are unchanged context):

@@ -1,23 +1,34 @@
+import os
 import cv2
+import torch
+import random
 import numpy as np
+
 import PIL
+from PIL import Image
+
+import diffusers
 from diffusers.utils import load_image
+from diffusers.models import ControlNetModel
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+from huggingface_hub import hf_hub_download
+
 from insightface.app import FaceAnalysis

 from style_template import styles
+from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
+from model_util import load_models_xl, get_torch_device
+from controlnet_util import openpose, get_depth_map, get_canny_image
+
+import gradio as gr
+
+import spaces

 # global variable
 MAX_SEED = np.iinfo(np.int32).max
+device = get_torch_device()
+dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 STYLE_NAMES = list(styles.keys())
 DEFAULT_STYLE_NAME = "Watercolor"
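The device and dtype now come from model_util (added in this commit as model_util.py, whose body is not shown in this diff). A minimal sketch of what a get_torch_device helper along these lines could look like; the function name is taken from the import above, the body is an assumption:

import torch

def get_torch_device():
    # Assumed behaviour: prefer CUDA, fall back to Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_torch_device()
# Mirrors the dtype choice above: half precision only when running on CUDA.
dtype = torch.float16 if device.type == "cuda" else torch.float32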
@@ -33,69 +44,120 @@ hf_hub_download(
 hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")

 # Load face encoder
+app = FaceAnalysis(
+    name="antelopev2",
+    root="./",
+    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+)
 app.prepare(ctx_id=0, det_size=(640, 640))

 # Path to InstantID models
+face_adapter = f"./checkpoints/ip-adapter.bin"
+controlnet_path = f"./checkpoints/ControlNetModel"

+# Load pipeline face ControlNetModel
+controlnet_identitynet = ControlNetModel.from_pretrained(
+    controlnet_path, torch_dtype=dtype
+)

+# controlnet-pose
+controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
+controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
+controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
+
+controlnet_pose = ControlNetModel.from_pretrained(
+    controlnet_pose_model, torch_dtype=dtype
+).to(device)
+controlnet_canny = ControlNetModel.from_pretrained(
+    controlnet_canny_model, torch_dtype=dtype
+).to(device)
+controlnet_depth = ControlNetModel.from_pretrained(
+    controlnet_depth_model, torch_dtype=dtype
+).to(device)
+
+controlnet_map = {
+    "pose": controlnet_pose,
+    "canny": controlnet_canny,
+    "depth": controlnet_depth,
+}
+controlnet_map_fn = {
+    "pose": openpose,
+    "canny": get_canny_image,
+    "depth": get_depth_map,
+}
+
+pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"
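controlnet_util.py (also added in this commit, +39 lines, not shown here) supplies the openpose, get_canny_image, and get_depth_map preprocessors referenced in controlnet_map_fn above. As a rough illustration only, a Canny preprocessor with the same interface (PIL image in, 3-channel PIL image out) could look like the sketch below; the real helper may differ:

import cv2
import numpy as np
from PIL import Image

def get_canny_image(image, low_threshold=100, high_threshold=200):
    # PIL (RGB) -> NumPy, run Canny edge detection, and return a 3-channel PIL
    # image sized like the input so it can be fed to the canny ControlNet.
    arr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    edges = cv2.Canny(arr, low_threshold, high_threshold)
    return Image.fromarray(np.stack([edges, edges, edges], axis=-1))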
 pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+    pretrained_model_name_or_path,
+    controlnet=[controlnet_identitynet],
+    torch_dtype=dtype,
     safety_checker=None,
     feature_extractor=None,
+).to(device)
+
+pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
+    pipe.scheduler.config
 )

+pipe.load_ip_adapter_instantid(face_adapter)
+# load and disable LCM
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+pipe.disable_lora()
+
+def toggle_lcm_ui(value):
+    if value:
+        return (
+            gr.update(minimum=0, maximum=100, step=1, value=5),
+            gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5),
+        )
+    else:
+        return (
+            gr.update(minimum=5, maximum=100, step=1, value=30),
+            gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5),
+        )
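The LCM-LoRA is loaded once and kept disabled by default; toggle_lcm_ui only adjusts the slider ranges in the UI, while the actual switch happens inside generate_image (shown further down). A condensed sketch of that switch, using the 5-step / 1.5-guidance and 30-step / 5.0-guidance defaults from toggle_lcm_ui; the helper name configure_for_lcm is illustrative, not part of the demo:

import diffusers

def configure_for_lcm(pipe, enable_lcm):
    # Returns kwargs for the pipeline call; mirrors the branch in generate_image.
    if enable_lcm:
        pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_lora()  # turn the preloaded LCM-LoRA back on
        return {"num_inference_steps": 5, "guidance_scale": 1.5}
    pipe.disable_lora()
    pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    return {"num_inference_steps": 30, "guidance_scale": 5.0}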
 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     return seed

 def remove_tips():
     return gr.update(visible=False)

 def get_example():
     case = [
         [
             "./examples/yann-lecun_resize.jpg",
+            None,
             "a man",
             "Snow",
             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
         ],
         [
             "./examples/musk_resize.jpeg",
+            "./examples/poses/pose2.jpg",
+            "a man flying in the sky in Mars",
             "Mars",
             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
         ],
         [
             "./examples/sam_resize.png",
+            "./examples/poses/pose4.jpg",
+            "a man doing a silly pose wearing a suite",
             "Jungle",
             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
         ],
         [
             "./examples/schmidhuber_resize.png",
+            "./examples/poses/pose3.jpg",
+            "a man sit on a chair",
             "Neon",
             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
         ],
         [
             "./examples/kaifu_resize.png",
+            "./examples/poses/pose.jpg",
             "a man",
             "Vibrant Color",
             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
@@ -103,50 +165,33 @@ def get_example():
     ]
     return case

+def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
+    return generate_image(
+        face_file,
+        pose_file,
+        prompt,
+        negative_prompt,
+        style,
+        20,  # num_steps
+        0.8,  # identitynet_strength_ratio
+        0.8,  # adapter_strength_ratio
+        0.4,  # pose_strength
+        0.3,  # canny_strength
+        0.5,  # depth_strength
+        ["pose", "canny"],  # controlnet_selection
+        5.0,  # guidance_scale
+        42,  # seed
+        "EulerDiscreteScheduler",  # scheduler
+        False,  # enable_LCM
+        True,  # enable_Face_Region
+    )

 def convert_from_cv2_to_image(img: np.ndarray) -> Image:
     return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

 def convert_from_image_to_cv2(img: Image) -> np.ndarray:
     return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

 def resize_img(
     input_image,
     max_side=1280,
@@ -172,21 +217,18 @@ def resize_img(
         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
         offset_x = (max_side - w_resize_new) // 2
         offset_y = (max_side - h_resize_new) // 2
+        res[
+            offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
+        ] = np.array(input_image)
         input_image = Image.fromarray(res)
     return input_image

+def apply_style(
+    style_name: str, positive: str, negative: str = ""
+) -> tuple[str, str]:
     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
     return p.replace("{prompt}", positive), n + " " + negative

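apply_style looks a template up in styles (from style_template.py) and substitutes the user prompt. A small self-contained sketch of how it behaves, with a made-up "Watercolor" template pair; the real templates live in style_template.py:

styles = {
    "Watercolor": (
        "watercolor painting of {prompt}, soft washes of color",
        "photo, photorealistic, 3d render",
    ),
}
DEFAULT_STYLE_NAME = "Watercolor"

def apply_style(style_name, positive, negative=""):
    # Fall back to the default template, then fill in the user prompt.
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive), n + " " + negative

positive, negative = apply_style("Watercolor", "a man", "lowres")
# positive == "watercolor painting of a man, soft washes of color"
# negative == "photo, photorealistic, 3d render lowres"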
 @spaces.GPU
 def generate_image(
     face_image_path,
@@ -194,14 +236,41 @@ def generate_image(
     prompt,
     negative_prompt,
     style_name,
     num_steps,
     identitynet_strength_ratio,
     adapter_strength_ratio,
+    pose_strength,
+    canny_strength,
+    depth_strength,
+    controlnet_selection,
     guidance_scale,
     seed,
+    scheduler,
+    enable_LCM,
+    enhance_face_region,
     progress=gr.Progress(track_tqdm=True),
 ):
+
+    if enable_LCM:
+        pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
+        pipe.enable_lora()
+    else:
+        pipe.disable_lora()
+        scheduler_class_name = scheduler.split("-")[0]
+
+        add_kwargs = {}
+        if len(scheduler.split("-")) > 1:
+            add_kwargs["use_karras_sigmas"] = True
+        if len(scheduler.split("-")) > 2:
+            add_kwargs["algorithm_type"] = "sde-dpmsolver++"
+        scheduler = getattr(diffusers, scheduler_class_name)
+        pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
+
+    if face_image_path is None:
+        raise gr.Error(
+            f"Cannot find any input face image! Please upload the face image"
+        )
+
     if prompt is None:
         prompt = "a person"

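In the non-LCM branch above, the scheduler dropdown value is parsed by splitting on "-": the first token is a diffusers class name, a second token turns on Karras sigmas, and a third selects the SDE DPM-Solver++ variant. A standalone sketch of the same resolution logic; the example names are plausible dropdown values, not taken from this diff, which does not show the dropdown choices:

import diffusers

def resolve_scheduler(pipe, choice):
    # e.g. "EulerDiscreteScheduler", "DPMSolverMultistepScheduler-Karras",
    # "DPMSolverMultistepScheduler-Karras-SDE"
    parts = choice.split("-")
    add_kwargs = {}
    if len(parts) > 1:
        add_kwargs["use_karras_sigmas"] = True
    if len(parts) > 2:
        add_kwargs["algorithm_type"] = "sde-dpmsolver++"
    scheduler_cls = getattr(diffusers, parts[0])
    return scheduler_cls.from_config(pipe.scheduler.config, **add_kwargs)

# Usage: pipe.scheduler = resolve_scheduler(pipe, "DPMSolverMultistepScheduler-Karras")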
@@ -209,7 +278,7 @@ def generate_image(
     prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)

     face_image = load_image(face_image_path)
+    face_image = resize_img(face_image, max_side=1024)
     face_image_cv2 = convert_from_image_to_cv2(face_image)
     height, width, _ = face_image_cv2.shape

@@ -217,23 +286,31 @@ def generate_image(
     face_info = app.get(face_image_cv2)

     if len(face_info) == 0:
+        raise gr.Error(
+            f"Unable to detect a face in the image. Please upload a different photo with a clear face."
+        )

+    face_info = sorted(
+        face_info,
+        key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
+    )[
         -1
     ]  # only use the maximum face
     face_emb = face_info["embedding"]
     face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
+    img_controlnet = face_image
     if pose_image_path is not None:
         pose_image = load_image(pose_image_path)
+        pose_image = resize_img(pose_image, max_side=1024)
+        img_controlnet = pose_image
         pose_image_cv2 = convert_from_image_to_cv2(pose_image)

         face_info = app.get(pose_image_cv2)

         if len(face_info) == 0:
+            raise gr.Error(
+                f"Cannot find any face in the reference image! Please upload another person image"
+            )

         face_info = face_info[-1]
         face_kps = draw_kps(pose_image, face_info["kps"])
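The sorted(...) call above ranks detected faces with the key (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1]; because multiplication binds tighter than subtraction, that expression evaluates as width * y2 - y1 rather than width * height. A small sketch of selecting the largest detected face by explicit bounding-box area, assuming insightface-style faces whose bbox is [x1, y1, x2, y2]:

def largest_face(face_info):
    # insightface returns face dicts with "bbox" = [x1, y1, x2, y2]
    def bbox_area(face):
        x1, y1, x2, y2 = face["bbox"]
        return (x2 - x1) * (y2 - y1)
    return max(face_info, key=bbox_area)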
@@ -249,6 +326,28 @@ def generate_image(
     else:
         control_mask = None

+    if len(controlnet_selection) > 0:
+        controlnet_scales = {
+            "pose": pose_strength,
+            "canny": canny_strength,
+            "depth": depth_strength,
+        }
+        pipe.controlnet = MultiControlNetModel(
+            [controlnet_identitynet]
+            + [controlnet_map[s] for s in controlnet_selection]
+        )
+        control_scales = [float(identitynet_strength_ratio)] + [
+            controlnet_scales[s] for s in controlnet_selection
+        ]
+        control_images = [face_kps] + [
+            controlnet_map_fn[s](img_controlnet).resize((width, height))
+            for s in controlnet_selection
+        ]
+    else:
+        pipe.controlnet = controlnet_identitynet
+        control_scales = float(identitynet_strength_ratio)
+        control_images = face_kps
+
     generator = torch.Generator(device=device).manual_seed(seed)

     print("Start inference...")
@@ -259,9 +358,9 @@ def generate_image(
         prompt=prompt,
         negative_prompt=negative_prompt,
         image_embeds=face_emb,
+        image=control_images,
         control_mask=control_mask,
+        controlnet_conditioning_scale=control_scales,
         num_inference_steps=num_steps,
         guidance_scale=guidance_scale,
         height=height,
@@ -271,8 +370,7 @@ def generate_image(

     return images[0], gr.update(visible=True)

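The branch above is the core of the multi-ControlNet support: the IdentityNet ControlNet always comes first, and each selected extra ControlNet contributes one conditioning image (produced by controlnet_map_fn) and one strength value, in the same order. A condensed sketch of how the lists line up for a "pose + canny" selection; the strength values are the defaults used in run_for_examples and are illustrative:

# Pairing for controlnet_selection = ["pose", "canny"]:
#   pipe.controlnet -> MultiControlNetModel([identitynet, pose_cn, canny_cn])
#   image           -> [face_kps, openpose(img), get_canny_image(img)]
#   scale           -> [identitynet_strength, pose_strength, canny_strength]
selection = ["pose", "canny"]
scales = {"pose": 0.4, "canny": 0.3, "depth": 0.5}
control_scales = [0.8] + [scales[s] for s in selection]  # 0.8 = IdentityNet strength
print(control_scales)  # [0.8, 0.4, 0.3]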
| 373 | 
            +
            # Description
         | 
|  | |
| 374 | 
             
            title = r"""
         | 
| 375 | 
             
            <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
         | 
| 376 | 
             
            """
         | 
|  | |
| 379 | 
             
            <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
         | 
| 380 |  | 
| 381 | 
             
            How to use:<br>
         | 
| 382 | 
            +
            1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring.
         | 
| 383 | 
            +
            2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose.
         | 
| 384 | 
            +
            3. (Optional) You can select multiple ControlNet models to control the generation process. The default is to use the IdentityNet only. The ControlNet models include pose skeleton, canny, and depth. You can adjust the strength of each ControlNet model to control the generation process.
         | 
| 385 | 
            +
            4. Enter a text prompt, as done in normal text-to-image models.
         | 
| 386 | 
            +
            5. Click the <b>Submit</b> button to begin customization.
         | 
| 387 | 
            +
            6. Share your customized photo with your friends and enjoy! 😊"""
         | 
| 388 |  | 
| 389 | 
             
            article = r"""
         | 
| 390 | 
             
            ---
         | 
|  | |
| 393 | 
             
            If our work is helpful for your research or applications, please cite us via:
         | 
| 394 | 
             
            ```bibtex
         | 
| 395 | 
             
            @article{wang2024instantid,
         | 
| 396 | 
            +
            title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
         | 
| 397 | 
            +
            author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
         | 
| 398 | 
            +
            journal={arXiv preprint arXiv:2401.07519},
         | 
| 399 | 
            +
            year={2024}
         | 
| 400 | 
             
            }
         | 
| 401 | 
             
            ```
         | 
| 402 | 
             
            📧 **Contact**
         | 
|  | |
| 406 |  | 
| 407 | 
             
            tips = r"""
         | 
| 408 | 
             
### Usage tips for InstantID
         | 
| 409 | 
            +
            1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."    
         | 
| 410 | 
            +
            2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
         | 
| 411 | 
            +
            3. If you find that text control is not as expected, decrease Adapter strength.
         | 
| 412 | 
            +
4. If you find that the realistic style is not good enough, go to our GitHub repo and use a more realistic base model.
         | 
| 413 | 
             
            """
         | 
| 414 |  | 
| 415 | 
             
            css = """
         | 
|  | |
| 422 |  | 
| 423 | 
             
                with gr.Row():
         | 
| 424 | 
             
                    with gr.Column():
         | 
| 425 | 
            +
                        with gr.Row(equal_height=True):
         | 
| 426 | 
            +
                            # upload face image
         | 
| 427 | 
            +
                            face_file = gr.Image(
         | 
| 428 | 
            +
                                label="Upload a photo of your face", type="filepath"
         | 
| 429 | 
            +
                            )
         | 
| 430 | 
            +
                            # optional: upload a reference pose image
         | 
| 431 | 
            +
                            pose_file = gr.Image(
         | 
| 432 | 
            +
                                label="Upload a reference pose image (Optional)",
         | 
| 433 | 
            +
                                type="filepath",
         | 
| 434 | 
            +
                            )
         | 
| 435 |  | 
| 436 | 
             
                        # prompt
         | 
| 437 | 
             
                        prompt = gr.Textbox(
         | 
| 438 | 
             
                            label="Prompt",
         | 
| 439 | 
            +
                            info="Give simple prompt is enough to achieve good face fidelity",
         | 
| 440 | 
             
                            placeholder="A photo of a person",
         | 
| 441 | 
             
                            value="",
         | 
| 442 | 
             
                        )
         | 
| 443 |  | 
| 444 | 
             
                        submit = gr.Button("Submit", variant="primary")
         | 
| 445 | 
            +
                        enable_LCM = gr.Checkbox(
         | 
| 446 | 
            +
                            label="Enable Fast Inference with LCM", value=enable_lcm_arg,
         | 
| 447 | 
            +
                            info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces",
         | 
| 448 | 
            +
                        )
         | 
| 449 | 
            +
                        style = gr.Dropdown(
         | 
| 450 | 
            +
                            label="Style template",
         | 
| 451 | 
            +
                            choices=STYLE_NAMES,
         | 
| 452 | 
            +
                            value=DEFAULT_STYLE_NAME,
         | 
| 453 | 
            +
                        )
         | 
| 454 |  | 
| 455 | 
             
                        # strength
         | 
| 456 | 
             
                        identitynet_strength_ratio = gr.Slider(
         | 
| 457 | 
            +
                            label="IdentityNet strength (for fidelity)",
         | 
| 458 | 
             
                            minimum=0,
         | 
| 459 | 
             
                            maximum=1.5,
         | 
| 460 | 
             
                            step=0.05,
         | 
|  | |
| 467 | 
             
                            step=0.05,
         | 
| 468 | 
             
                            value=0.80,
         | 
| 469 | 
             
                        )
         | 
| 470 | 
            +
                        with gr.Accordion("Controlnet"):
         | 
| 471 | 
            +
                            controlnet_selection = gr.CheckboxGroup(
         | 
| 472 | 
            +
                                ["pose", "canny", "depth"], label="Controlnet", value=["pose"],
         | 
| 473 | 
            +
                                info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
         | 
| 474 | 
            +
                            )
         | 
| 475 | 
            +
                            pose_strength = gr.Slider(
         | 
| 476 | 
            +
                                label="Pose strength",
         | 
| 477 | 
            +
                                minimum=0,
         | 
| 478 | 
            +
                                maximum=1.5,
         | 
| 479 | 
            +
                                step=0.05,
         | 
| 480 | 
            +
                                value=0.40,
         | 
| 481 | 
            +
                            )
         | 
| 482 | 
            +
                            canny_strength = gr.Slider(
         | 
| 483 | 
            +
                                label="Canny strength",
         | 
| 484 | 
            +
                                minimum=0,
         | 
| 485 | 
            +
                                maximum=1.5,
         | 
| 486 | 
            +
                                step=0.05,
         | 
| 487 | 
            +
                                value=0.40,
         | 
| 488 | 
            +
                            )
         | 
| 489 | 
            +
                            depth_strength = gr.Slider(
         | 
| 490 | 
            +
                                label="Depth strength",
         | 
| 491 | 
            +
                                minimum=0,
         | 
| 492 | 
            +
                                maximum=1.5,
         | 
| 493 | 
            +
                                step=0.05,
         | 
| 494 | 
            +
                                value=0.40,
         | 
| 495 | 
            +
                            )
         | 
| 496 | 
             
                        with gr.Accordion(open=False, label="Advanced Options"):
         | 
| 497 | 
             
                            negative_prompt = gr.Textbox(
         | 
| 498 | 
             
                                label="Negative Prompt",
         | 
| 499 | 
             
                                placeholder="low quality",
         | 
| 500 | 
            +
                                value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
         | 
| 501 | 
             
                            )
         | 
| 502 | 
             
                            num_steps = gr.Slider(
         | 
| 503 | 
             
                                label="Number of sample steps",
         | 
| 504 | 
            +
                                minimum=1,
         | 
| 505 | 
             
                                maximum=100,
         | 
| 506 | 
             
                                step=1,
         | 
| 507 | 
            +
                                value=5 if enable_lcm_arg else 30,
         | 
| 508 | 
             
                            )
         | 
| 509 | 
             
                            guidance_scale = gr.Slider(
         | 
| 510 | 
             
                                label="Guidance scale",
         | 
| 511 | 
             
                                minimum=0.1,
         | 
| 512 | 
            +
                                maximum=20.0,
         | 
| 513 | 
             
                                step=0.1,
         | 
| 514 | 
            +
                                value=0.0 if enable_lcm_arg else 5.0,
         | 
| 515 | 
             
                            )
         | 
| 516 | 
             
                            seed = gr.Slider(
         | 
| 517 | 
             
                                label="Seed",
         | 
|  | |
| 520 | 
             
                                step=1,
         | 
| 521 | 
             
                                value=42,
         | 
| 522 | 
             
                            )
         | 
| 523 | 
            +
                            schedulers = [
         | 
| 524 | 
            +
                                "DEISMultistepScheduler",
         | 
| 525 | 
            +
                                "HeunDiscreteScheduler",
         | 
| 526 | 
            +
                                "EulerDiscreteScheduler",
         | 
| 527 | 
            +
                                "DPMSolverMultistepScheduler",
         | 
| 528 | 
            +
                                "DPMSolverMultistepScheduler-Karras",
         | 
| 529 | 
            +
                                "DPMSolverMultistepScheduler-Karras-SDE",
         | 
| 530 | 
            +
                            ]
         | 
| 531 | 
            +
                            scheduler = gr.Dropdown(
         | 
| 532 | 
            +
                                label="Schedulers",
         | 
| 533 | 
            +
                                choices=schedulers,
         | 
| 534 | 
            +
                                value="EulerDiscreteScheduler",
         | 
| 535 | 
            +
                            )
         | 
| 536 | 
             
                            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
         | 
| 537 | 
             
                            enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
         | 
| 538 |  | 
| 539 | 
            +
                    with gr.Column(scale=1):
         | 
| 540 | 
            +
                        gallery = gr.Image(label="Generated Images")
         | 
| 541 | 
            +
                        usage_tips = gr.Markdown(
         | 
| 542 | 
            +
                            label="InstantID Usage Tips", value=tips, visible=False
         | 
| 543 | 
            +
                        )
         | 
| 544 |  | 
| 545 | 
             
                    submit.click(
         | 
| 546 | 
             
                        fn=remove_tips,
         | 
| 547 | 
             
                        outputs=usage_tips,
         | 
|  | |
|  | |
| 548 | 
             
                    ).then(
         | 
| 549 | 
             
                        fn=randomize_seed_fn,
         | 
| 550 | 
             
                        inputs=[seed, randomize_seed],
         | 
|  | |
| 552 | 
             
                        queue=False,
         | 
| 553 | 
             
                        api_name=False,
         | 
| 554 | 
             
                    ).then(
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 555 | 
             
                        fn=generate_image,
         | 
| 556 | 
             
                        inputs=[
         | 
| 557 | 
             
                            face_file,
         | 
|  | |
| 559 | 
             
                            prompt,
         | 
| 560 | 
             
                            negative_prompt,
         | 
| 561 | 
             
                            style,
         | 
|  | |
| 562 | 
             
                            num_steps,
         | 
| 563 | 
             
                            identitynet_strength_ratio,
         | 
| 564 | 
             
                            adapter_strength_ratio,
         | 
| 565 | 
            +
                            pose_strength,
         | 
| 566 | 
            +
                            canny_strength,
         | 
| 567 | 
            +
                            depth_strength,
         | 
| 568 | 
            +
                            controlnet_selection,
         | 
| 569 | 
             
                            guidance_scale,
         | 
| 570 | 
             
                            seed,
         | 
| 571 | 
            +
                            scheduler,
         | 
| 572 | 
            +
                            enable_LCM,
         | 
| 573 | 
            +
                            enhance_face_region,
         | 
| 574 | 
             
                        ],
         | 
| 575 | 
            +
                        outputs=[gallery, usage_tips],
         | 
| 576 | 
            +
                    )
         | 
| 577 | 
            +
             | 
| 578 | 
            +
                    enable_LCM.input(
         | 
| 579 | 
            +
                        fn=toggle_lcm_ui,
         | 
| 580 | 
            +
                        inputs=[enable_LCM],
         | 
| 581 | 
            +
                        outputs=[num_steps, guidance_scale],
         | 
| 582 | 
            +
                        queue=False,
         | 
| 583 | 
             
                    )
         | 
| 584 |  | 
| 585 | 
             
                gr.Examples(
         | 
| 586 | 
             
                    examples=get_example(),
         | 
| 587 | 
            +
                    inputs=[face_file, pose_file, prompt, style, negative_prompt],
         | 
|  | |
| 588 | 
             
                    fn=run_for_examples,
         | 
| 589 | 
            +
                    outputs=[gallery, usage_tips],
         | 
| 590 | 
             
                    cache_examples=True,
         | 
| 591 | 
             
                )
         | 
| 592 |  | 
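The scheduler dropdown and the LCM checkbox above imply a mapping from UI values to diffusers scheduler configurations inside `generate_image`; that mapping is not visible in this hunk. Below is a minimal sketch of one plausible implementation; the helper name `apply_scheduler` and the choice of the `latent-consistency/lcm-lora-sdxl` LCM-LoRA are assumptions, not code from this commit.

```python
# Plausible sketch only: translate the dropdown string into a diffusers
# scheduler, or switch the pipeline into LCM mode when the checkbox is set.
import diffusers
from diffusers import LCMScheduler


def apply_scheduler(pipe, scheduler_name: str, enable_lcm: bool):
    if enable_lcm:
        # LCM fast inference is usually paired with the SDXL LCM-LoRA weights.
        # (A full implementation would also unload the LoRA when LCM is off.)
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
        return pipe

    # Names such as "DPMSolverMultistepScheduler-Karras-SDE" encode extra options.
    class_name, *variant = scheduler_name.split("-")
    scheduler_cls = getattr(diffusers, class_name)
    kwargs = {}
    if "Karras" in variant:
        kwargs["use_karras_sigmas"] = True
    if "SDE" in variant:
        kwargs["algorithm_type"] = "sde-dpmsolver++"
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config, **kwargs)
    return pipe
```

This matches the defaults the UI switches to when LCM is enabled (`num_steps=5`, `guidance_scale=0.0`), which is the usual LCM operating point.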
    	
        controlnet_util.py
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            from controlnet_aux import OpenposeDetector
         | 
| 5 | 
            +
            from model_util import get_torch_device
         | 
| 6 | 
            +
            import cv2
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            from transformers import DPTImageProcessor, DPTForDepthEstimation
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            device = get_torch_device()
         | 
| 12 | 
            +
            depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
         | 
| 13 | 
            +
            feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
         | 
| 14 | 
            +
            openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            def get_depth_map(image):
         | 
| 17 | 
            +
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
         | 
| 18 | 
            +
                with torch.no_grad(), torch.autocast("cuda"):
         | 
| 19 | 
            +
                    depth_map = depth_estimator(image).predicted_depth
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                depth_map = torch.nn.functional.interpolate(
         | 
| 22 | 
            +
                    depth_map.unsqueeze(1),
         | 
| 23 | 
            +
                    size=(1024, 1024),
         | 
| 24 | 
            +
                    mode="bicubic",
         | 
| 25 | 
            +
                    align_corners=False,
         | 
| 26 | 
            +
                )
         | 
| 27 | 
            +
                depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
         | 
| 28 | 
            +
                depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
         | 
| 29 | 
            +
                depth_map = (depth_map - depth_min) / (depth_max - depth_min)
         | 
| 30 | 
            +
                image = torch.cat([depth_map] * 3, dim=1)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
         | 
| 33 | 
            +
                image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
         | 
| 34 | 
            +
                return image
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def get_canny_image(image, t1=100, t2=200):
         | 
| 37 | 
            +
                image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
         | 
| 38 | 
            +
                edges = cv2.Canny(image, t1, t2)
         | 
| 39 | 
            +
                return Image.fromarray(edges, "L")
         | 
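These helpers produce the extra conditioning images that the updated `app.py` hands to the pipeline as `image=control_images`, with one matching entry in `controlnet_conditioning_scale=control_scales` per selected ControlNet. A usage sketch with illustrative variable names and a placeholder file path (not taken from the commit):

```python
# Illustrative only: build multi-ControlNet inputs from a single source image.
from PIL import Image

from controlnet_util import get_canny_image, get_depth_map, openpose

source = Image.open("face.jpg").convert("RGB")            # placeholder path
strengths = {"pose": 0.40, "canny": 0.40, "depth": 0.40}  # UI slider defaults
selected = ["pose", "canny"]                              # CheckboxGroup value

preprocessors = {
    "pose": openpose,        # OpenposeDetector instances are callable
    "canny": get_canny_image,
    "depth": get_depth_map,
}

control_images = [preprocessors[name](source) for name in selected]
control_scales = [strengths[name] for name in selected]
# The two lists stay aligned index for index when passed to the pipeline call.
```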
    	
        gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png
    DELETED
    
    	
        gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png
    DELETED
    
    	
        gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png
    DELETED
    
    	
        gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png
    DELETED
    
    	
        gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png
    DELETED
    
    	
        gradio_cached_examples/25/log.csv
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
| 1 | 
            -
            Generated Image,Usage tips of InstantID,flag,username,timestamp
         | 
| 2 | 
            -
            "{""path"":""gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:55:38.846769
         | 
| 3 | 
            -
            "{""path"":""gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:11.432078
         | 
| 4 | 
            -
            "{""path"":""gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:45.563918
         | 
| 5 | 
            -
            "{""path"":""gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:20.321876
         | 
| 6 | 
            -
            "{""path"":""gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:53.871716
         | 
    	
        ip_adapter/attention_processor.py
    CHANGED
    
    | @@ -10,14 +10,11 @@ try: | |
| 10 | 
             
            except Exception as e:
         | 
| 11 | 
             
                xformers_available = False
         | 
| 12 |  | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
             
            class RegionControler(object):
         | 
| 16 | 
             
                def __init__(self) -> None:
         | 
| 17 | 
             
                    self.prompt_image_conditioning = []
         | 
| 18 | 
             
            region_control = RegionControler()
         | 
| 19 |  | 
| 20 | 
            -
             | 
| 21 | 
             
            class AttnProcessor(nn.Module):
         | 
| 22 | 
             
                r"""
         | 
| 23 | 
             
                Default processor for performing attention-related computations.
         | 
| @@ -29,7 +26,7 @@ class AttnProcessor(nn.Module): | |
| 29 | 
             
                ):
         | 
| 30 | 
             
                    super().__init__()
         | 
| 31 |  | 
| 32 | 
            -
    def __call__( | 
| 33 | 
             
                    self,
         | 
| 34 | 
             
                    attn,
         | 
| 35 | 
             
                    hidden_states,
         | 
| @@ -115,7 +112,7 @@ class IPAttnProcessor(nn.Module): | |
| 115 | 
             
                    self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 116 | 
             
                    self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 117 |  | 
| 118 | 
            -
    def __call__( | 
| 119 | 
             
                    self,
         | 
| 120 | 
             
                    attn,
         | 
| 121 | 
             
                    hidden_states,
         | 
| @@ -180,7 +177,7 @@ class IPAttnProcessor(nn.Module): | |
| 180 | 
             
                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
         | 
| 181 | 
             
                        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
         | 
| 182 | 
             
                    ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
         | 
| 183 | 
            -
             | 
| 184 | 
             
                    # region control
         | 
| 185 | 
             
                    if len(region_control.prompt_image_conditioning) == 1:
         | 
| 186 | 
             
                        region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
         | 
| @@ -190,7 +187,7 @@ class IPAttnProcessor(nn.Module): | |
| 190 | 
             
                            mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
         | 
| 191 | 
             
                        else:
         | 
| 192 | 
             
                            mask = torch.ones_like(ip_hidden_states)
         | 
| 193 | 
            -
                        ip_hidden_states = ip_hidden_states * mask | 
| 194 |  | 
| 195 | 
             
                    hidden_states = hidden_states + self.scale * ip_hidden_states
         | 
| 196 |  | 
| @@ -233,7 +230,7 @@ class AttnProcessor2_0(torch.nn.Module): | |
| 233 | 
             
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 234 | 
             
                        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 235 |  | 
| 236 | 
            -
    def __call__( | 
| 237 | 
             
                    self,
         | 
| 238 | 
             
                    attn,
         | 
| 239 | 
             
                    hidden_states,
         | 
| @@ -305,4 +302,145 @@ class AttnProcessor2_0(torch.nn.Module): | |
| 305 |  | 
| 306 | 
             
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 307 |  | 
| 308 | 
             
                    return hidden_states
         | 
|  | |
| 10 | 
             
            except Exception as e:
         | 
| 11 | 
             
                xformers_available = False
         | 
| 12 |  | 
|  | |
|  | |
| 13 | 
             
            class RegionControler(object):
         | 
| 14 | 
             
                def __init__(self) -> None:
         | 
| 15 | 
             
                    self.prompt_image_conditioning = []
         | 
| 16 | 
             
            region_control = RegionControler()
         | 
| 17 |  | 
|  | |
| 18 | 
             
            class AttnProcessor(nn.Module):
         | 
| 19 | 
             
                r"""
         | 
| 20 | 
             
                Default processor for performing attention-related computations.
         | 
|  | |
| 26 | 
             
                ):
         | 
| 27 | 
             
                    super().__init__()
         | 
| 28 |  | 
| 29 | 
            +
                def forward(
         | 
| 30 | 
             
                    self,
         | 
| 31 | 
             
                    attn,
         | 
| 32 | 
             
                    hidden_states,
         | 
|  | |
| 112 | 
             
                    self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 113 | 
             
                    self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 114 |  | 
| 115 | 
            +
                def forward(
         | 
| 116 | 
             
                    self,
         | 
| 117 | 
             
                    attn,
         | 
| 118 | 
             
                    hidden_states,
         | 
|  | |
| 177 | 
             
                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
         | 
| 178 | 
             
                        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
         | 
| 179 | 
             
                    ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
         | 
| 180 | 
            +
             | 
| 181 | 
             
                    # region control
         | 
| 182 | 
             
                    if len(region_control.prompt_image_conditioning) == 1:
         | 
| 183 | 
             
                        region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
         | 
|  | |
| 187 | 
             
                            mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
         | 
| 188 | 
             
                        else:
         | 
| 189 | 
             
                            mask = torch.ones_like(ip_hidden_states)
         | 
| 190 | 
            +
                        ip_hidden_states = ip_hidden_states * mask     
         | 
| 191 |  | 
| 192 | 
             
                    hidden_states = hidden_states + self.scale * ip_hidden_states
         | 
| 193 |  | 
|  | |
| 230 | 
             
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 231 | 
             
                        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 232 |  | 
| 233 | 
            +
                def forward(
         | 
| 234 | 
             
                    self,
         | 
| 235 | 
             
                    attn,
         | 
| 236 | 
             
                    hidden_states,
         | 
|  | |
| 302 |  | 
| 303 | 
             
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 304 |  | 
| 305 | 
            +
                    return hidden_states
         | 
| 306 | 
            +
             | 
| 307 | 
            +
            class IPAttnProcessor2_0(torch.nn.Module):
         | 
| 308 | 
            +
                r"""
         | 
| 309 | 
            +
    Attention processor for IP-Adapter (PyTorch 2.0).
         | 
| 310 | 
            +
                Args:
         | 
| 311 | 
            +
                    hidden_size (`int`):
         | 
| 312 | 
            +
                        The hidden size of the attention layer.
         | 
| 313 | 
            +
                    cross_attention_dim (`int`):
         | 
| 314 | 
            +
                        The number of channels in the `encoder_hidden_states`.
         | 
| 315 | 
            +
                    scale (`float`, defaults to 1.0):
         | 
| 316 | 
            +
            The weight scale of the image prompt.
         | 
| 317 | 
            +
        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
         | 
| 318 | 
            +
                        The context length of the image features.
         | 
| 319 | 
            +
                """
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
         | 
| 322 | 
            +
                    super().__init__()
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 325 | 
            +
                        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    self.hidden_size = hidden_size
         | 
| 328 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 329 | 
            +
                    self.scale = scale
         | 
| 330 | 
            +
                    self.num_tokens = num_tokens
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 333 | 
            +
                    self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                def forward(
         | 
| 336 | 
            +
                    self,
         | 
| 337 | 
            +
                    attn,
         | 
| 338 | 
            +
                    hidden_states,
         | 
| 339 | 
            +
                    encoder_hidden_states=None,
         | 
| 340 | 
            +
                    attention_mask=None,
         | 
| 341 | 
            +
                    temb=None,
         | 
| 342 | 
            +
                ):
         | 
| 343 | 
            +
                    residual = hidden_states
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 346 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    if input_ndim == 4:
         | 
| 351 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 352 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 355 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 356 | 
            +
                    )
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    if attention_mask is not None:
         | 
| 359 | 
            +
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 360 | 
            +
                        # scaled_dot_product_attention expects attention_mask shape to be
         | 
| 361 | 
            +
                        # (batch, heads, source_length, target_length)
         | 
| 362 | 
            +
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                    if attn.group_norm is not None:
         | 
| 365 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    if encoder_hidden_states is None:
         | 
| 370 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 371 | 
            +
                    else:
         | 
| 372 | 
            +
                        # get encoder_hidden_states, ip_hidden_states
         | 
| 373 | 
            +
                        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
         | 
| 374 | 
            +
                        encoder_hidden_states, ip_hidden_states = (
         | 
| 375 | 
            +
                            encoder_hidden_states[:, :end_pos, :],
         | 
| 376 | 
            +
                            encoder_hidden_states[:, end_pos:, :],
         | 
| 377 | 
            +
                        )
         | 
| 378 | 
            +
                        if attn.norm_cross:
         | 
| 379 | 
            +
                            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 382 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    inner_dim = key.shape[-1]
         | 
| 385 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 390 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
         | 
| 393 | 
            +
                    # TODO: add support for attn.scale when we move to Torch 2.1
         | 
| 394 | 
            +
                    hidden_states = F.scaled_dot_product_attention(
         | 
| 395 | 
            +
                        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         | 
| 396 | 
            +
                    )
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 399 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                    # for ip-adapter
         | 
| 402 | 
            +
                    ip_key = self.to_k_ip(ip_hidden_states)
         | 
| 403 | 
            +
                    ip_value = self.to_v_ip(ip_hidden_states)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 406 | 
            +
                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
         | 
| 409 | 
            +
                    # TODO: add support for attn.scale when we move to Torch 2.1
         | 
| 410 | 
            +
                    ip_hidden_states = F.scaled_dot_product_attention(
         | 
| 411 | 
            +
                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
         | 
| 412 | 
            +
                    )
         | 
| 413 | 
            +
                    with torch.no_grad():
         | 
| 414 | 
            +
            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
         | 
| 415 | 
            +
                        #print(self.attn_map.shape)
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 418 | 
            +
                    ip_hidden_states = ip_hidden_states.to(query.dtype)
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                    # region control
         | 
| 421 | 
            +
                    if len(region_control.prompt_image_conditioning) == 1:
         | 
| 422 | 
            +
                        region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
         | 
| 423 | 
            +
                        if region_mask is not None:
         | 
| 424 | 
            +
                            h, w = region_mask.shape[:2]
         | 
| 425 | 
            +
                            ratio = (h * w / query.shape[1]) ** 0.5
         | 
| 426 | 
            +
                            mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
         | 
| 427 | 
            +
                        else:
         | 
| 428 | 
            +
                            mask = torch.ones_like(ip_hidden_states)
         | 
| 429 | 
            +
                        ip_hidden_states = ip_hidden_states * mask
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                    hidden_states = hidden_states + self.scale * ip_hidden_states
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    # linear proj
         | 
| 434 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 435 | 
            +
                    # dropout
         | 
| 436 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                    if input_ndim == 4:
         | 
| 439 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    if attn.residual_connection:
         | 
| 442 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 445 | 
            +
             | 
| 446 | 
             
                    return hidden_states
         | 
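Because these processors subclass `torch.nn.Module`, renaming `__call__` to `forward` keeps them compatible with diffusers' attention dispatch: calling the processor still routes through `forward`. The new `IPAttnProcessor2_0` is installed per attention layer, with cross-attention layers receiving the IP-Adapter variant. The sketch below shows the usual IP-Adapter wiring on an SDXL UNet; it is an assumed usage pattern rather than code from this commit, and the `num_tokens=16` and `scale=0.8` defaults are assumptions.

```python
# Assumed wiring sketch: put IPAttnProcessor2_0 on cross-attention layers and
# keep the plain AttnProcessor2_0 on self-attention layers of the UNet.
from ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0


def set_ip_adapter(unet, num_tokens=16, scale=0.8):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 is self-attention (no image tokens); attn2 is cross-attention.
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            ).to(unet.device, dtype=unet.dtype)
    unet.set_attn_processor(attn_procs)
```

Here `scale` is the per-layer `self.scale` that weights `ip_hidden_states` above, which is the quantity the adapter-strength slider in the demo is meant to adjust.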
    	
        model_util.py
    ADDED
    
    | @@ -0,0 +1,472 @@ | |
| 1 | 
            +
            from typing import Literal, Union, Optional, Tuple, List
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
         | 
| 5 | 
            +
            from diffusers import (
         | 
| 6 | 
            +
                UNet2DConditionModel,
         | 
| 7 | 
            +
                SchedulerMixin,
         | 
| 8 | 
            +
                StableDiffusionPipeline,
         | 
| 9 | 
            +
                StableDiffusionXLPipeline,
         | 
| 10 | 
            +
                AutoencoderKL,
         | 
| 11 | 
            +
            )
         | 
| 12 | 
            +
            from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
         | 
| 13 | 
            +
                convert_ldm_unet_checkpoint,
         | 
| 14 | 
            +
            )
         | 
| 15 | 
            +
            from safetensors.torch import load_file
         | 
| 16 | 
            +
            from diffusers.schedulers import (
         | 
| 17 | 
            +
                DDIMScheduler,
         | 
| 18 | 
            +
                DDPMScheduler,
         | 
| 19 | 
            +
                LMSDiscreteScheduler,
         | 
| 20 | 
            +
                EulerDiscreteScheduler,
         | 
| 21 | 
            +
                EulerAncestralDiscreteScheduler,
         | 
| 22 | 
            +
                UniPCMultistepScheduler,
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from omegaconf import OmegaConf
         | 
| 26 | 
            +
             | 
| 27 | 
            +
# Model parameters for the Diffusers version of Stable Diffusion
         | 
| 28 | 
            +
            NUM_TRAIN_TIMESTEPS = 1000
         | 
| 29 | 
            +
            BETA_START = 0.00085
         | 
| 30 | 
            +
            BETA_END = 0.0120
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            UNET_PARAMS_MODEL_CHANNELS = 320
         | 
| 33 | 
            +
            UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
         | 
| 34 | 
            +
            UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
         | 
| 35 | 
            +
            UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
         | 
| 36 | 
            +
            UNET_PARAMS_IN_CHANNELS = 4
         | 
| 37 | 
            +
            UNET_PARAMS_OUT_CHANNELS = 4
         | 
| 38 | 
            +
            UNET_PARAMS_NUM_RES_BLOCKS = 2
         | 
| 39 | 
            +
            UNET_PARAMS_CONTEXT_DIM = 768
         | 
| 40 | 
            +
            UNET_PARAMS_NUM_HEADS = 8
         | 
| 41 | 
            +
            # UNET_PARAMS_USE_LINEAR_PROJECTION = False
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            VAE_PARAMS_Z_CHANNELS = 4
         | 
| 44 | 
            +
            VAE_PARAMS_RESOLUTION = 256
         | 
| 45 | 
            +
            VAE_PARAMS_IN_CHANNELS = 3
         | 
| 46 | 
            +
            VAE_PARAMS_OUT_CH = 3
         | 
| 47 | 
            +
            VAE_PARAMS_CH = 128
         | 
| 48 | 
            +
            VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
         | 
| 49 | 
            +
            VAE_PARAMS_NUM_RES_BLOCKS = 2
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # V2
         | 
| 52 | 
            +
            V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
         | 
| 53 | 
            +
            V2_UNET_PARAMS_CONTEXT_DIM = 1024
         | 
| 54 | 
            +
            # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
         | 
| 57 | 
            +
            TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
         | 
| 67 | 
            +
    # Handle models whose text encoder is stored in a different format (keys without 'text_model')
         | 
| 68 | 
            +
                TEXT_ENCODER_KEY_REPLACEMENTS = [
         | 
| 69 | 
            +
                    (
         | 
| 70 | 
            +
                        "cond_stage_model.transformer.embeddings.",
         | 
| 71 | 
            +
                        "cond_stage_model.transformer.text_model.embeddings.",
         | 
| 72 | 
            +
                    ),
         | 
| 73 | 
            +
                    (
         | 
| 74 | 
            +
                        "cond_stage_model.transformer.encoder.",
         | 
| 75 | 
            +
                        "cond_stage_model.transformer.text_model.encoder.",
         | 
| 76 | 
            +
                    ),
         | 
| 77 | 
            +
                    (
         | 
| 78 | 
            +
                        "cond_stage_model.transformer.final_layer_norm.",
         | 
| 79 | 
            +
                        "cond_stage_model.transformer.text_model.final_layer_norm.",
         | 
| 80 | 
            +
                    ),
         | 
| 81 | 
            +
                ]
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                if ckpt_path.endswith(".safetensors"):
         | 
| 84 | 
            +
                    checkpoint = None
         | 
| 85 | 
            +
        state_dict = load_file(ckpt_path)  # , device)  # passing the device here may cause an error
         | 
| 86 | 
            +
                else:
         | 
| 87 | 
            +
                    checkpoint = torch.load(ckpt_path, map_location=device)
         | 
| 88 | 
            +
                    if "state_dict" in checkpoint:
         | 
| 89 | 
            +
                        state_dict = checkpoint["state_dict"]
         | 
| 90 | 
            +
                    else:
         | 
| 91 | 
            +
                        state_dict = checkpoint
         | 
| 92 | 
            +
                        checkpoint = None
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                key_reps = []
         | 
| 95 | 
            +
                for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
         | 
| 96 | 
            +
                    for key in state_dict.keys():
         | 
| 97 | 
            +
                        if key.startswith(rep_from):
         | 
| 98 | 
            +
                            new_key = rep_to + key[len(rep_from) :]
         | 
| 99 | 
            +
                            key_reps.append((key, new_key))
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                for key, new_key in key_reps:
         | 
| 102 | 
            +
                    state_dict[new_key] = state_dict[key]
         | 
| 103 | 
            +
                    del state_dict[key]
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                return checkpoint, state_dict
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
         | 
| 109 | 
            +
                """
         | 
| 110 | 
            +
    Creates a config for the diffusers UNet based on the config of the LDM model.
         | 
| 111 | 
            +
                """
         | 
| 112 | 
            +
                # unet_params = original_config.model.params.unet_config.params
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                block_out_channels = [
         | 
| 115 | 
            +
                    UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
         | 
| 116 | 
            +
                ]
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                down_block_types = []
         | 
| 119 | 
            +
                resolution = 1
         | 
| 120 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 121 | 
            +
                    block_type = (
         | 
| 122 | 
            +
                        "CrossAttnDownBlock2D"
         | 
| 123 | 
            +
                        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
         | 
| 124 | 
            +
                        else "DownBlock2D"
         | 
| 125 | 
            +
                    )
         | 
| 126 | 
            +
                    down_block_types.append(block_type)
         | 
| 127 | 
            +
                    if i != len(block_out_channels) - 1:
         | 
| 128 | 
            +
                        resolution *= 2
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                up_block_types = []
         | 
| 131 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 132 | 
            +
                    block_type = (
         | 
| 133 | 
            +
                        "CrossAttnUpBlock2D"
         | 
| 134 | 
            +
                        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
         | 
| 135 | 
            +
                        else "UpBlock2D"
         | 
| 136 | 
            +
                    )
         | 
| 137 | 
            +
                    up_block_types.append(block_type)
         | 
| 138 | 
            +
                    resolution //= 2
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                config = dict(
         | 
| 141 | 
            +
                    sample_size=UNET_PARAMS_IMAGE_SIZE,
         | 
| 142 | 
            +
                    in_channels=UNET_PARAMS_IN_CHANNELS,
         | 
| 143 | 
            +
                    out_channels=UNET_PARAMS_OUT_CHANNELS,
         | 
| 144 | 
            +
                    down_block_types=tuple(down_block_types),
         | 
| 145 | 
            +
                    up_block_types=tuple(up_block_types),
         | 
| 146 | 
            +
                    block_out_channels=tuple(block_out_channels),
         | 
| 147 | 
            +
                    layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
         | 
| 148 | 
            +
                    cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
         | 
| 149 | 
            +
                    if not v2
         | 
| 150 | 
            +
                    else V2_UNET_PARAMS_CONTEXT_DIM,
         | 
| 151 | 
            +
                    attention_head_dim=UNET_PARAMS_NUM_HEADS
         | 
| 152 | 
            +
                    if not v2
         | 
| 153 | 
            +
                    else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
         | 
| 154 | 
            +
                    # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
         | 
| 155 | 
            +
                )
         | 
| 156 | 
            +
                if v2 and use_linear_projection_in_v2:
         | 
| 157 | 
            +
                    config["use_linear_projection"] = True
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                return config
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
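For reference, a minimal usage sketch (not part of the commit) of how this config builder can feed a bare diffusers UNet, mirroring what load_checkpoint_model() below does. It assumes it runs alongside model_util.py so the helper and its constants are in scope:

from diffusers import UNet2DConditionModel

# Build an SD 1.x-style UNet config and instantiate an uninitialized UNet from it.
unet_config = create_unet_diffusers_config(v2=False, use_linear_projection_in_v2=False)
unet_config["class_embed_type"] = None
unet_config["addition_embed_type"] = None
unet = UNet2DConditionModel(**unet_config)
print(unet.config.cross_attention_dim)  # matches UNET_PARAMS_CONTEXT_DIM for v2=False
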
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # default is clip skip 2
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    return tokenizer, text_encoder, unet, vae


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=True if v2 else False,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
    unet_config["class_embed_type"] = None
    unet_config["addition_embed_type"] = None
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet, vae


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:  # diffusers
        tokenizer, text_encoder, unet, vae = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    if scheduler_name:
        scheduler = create_noise_scheduler(
            scheduler_name,
            prediction_type="v_prediction" if v_pred else "epsilon",
        )
    else:
        scheduler = None

    return tokenizer, text_encoder, unet, scheduler, vae

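A minimal usage sketch (not part of the commit): load SD 1.5 components with the diffusers-format loader; the model id is illustrative, and load_models() above dispatches here automatically for any path that does not end in .ckpt/.safetensors:

import torch

tokenizer, text_encoder, unet, vae = load_diffusers_model(
    "runwayml/stable-diffusion-v1-5",  # example model id
    v2=False,
    clip_skip=None,
    weight_dtype=torch.float16,
)
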
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet

    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    return tokenizers, text_encoders, unet, vae


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    if len(text_encoders) == 2:
        text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet, vae


def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    else:  # diffusers
        (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
    else:
        scheduler = None

    return tokenizers, text_encoders, unet, scheduler, vae

             | 
| 371 | 
            +
            def create_noise_scheduler(
         | 
| 372 | 
            +
                scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
         | 
| 373 | 
            +
                noise_scheduler_kwargs=None,
         | 
| 374 | 
            +
                prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
         | 
| 375 | 
            +
            ) -> SchedulerMixin:
         | 
| 376 | 
            +
                name = scheduler_name.lower().replace(" ", "_")
         | 
| 377 | 
            +
                if name.lower() == "ddim":
         | 
| 378 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
         | 
| 379 | 
            +
                    scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
         | 
| 380 | 
            +
                elif name.lower() == "ddpm":
         | 
| 381 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
         | 
| 382 | 
            +
                    scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
         | 
| 383 | 
            +
                elif name.lower() == "lms":
         | 
| 384 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
         | 
| 385 | 
            +
                    scheduler = LMSDiscreteScheduler(
         | 
| 386 | 
            +
                        **OmegaConf.to_container(noise_scheduler_kwargs)
         | 
| 387 | 
            +
                    )
         | 
| 388 | 
            +
                elif name.lower() == "euler_a":
         | 
| 389 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
         | 
| 390 | 
            +
                    scheduler = EulerAncestralDiscreteScheduler(
         | 
| 391 | 
            +
                        **OmegaConf.to_container(noise_scheduler_kwargs)
         | 
| 392 | 
            +
                    )
         | 
| 393 | 
            +
                elif name.lower() == "euler":
         | 
| 394 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
         | 
| 395 | 
            +
                    scheduler = EulerDiscreteScheduler(
         | 
| 396 | 
            +
                        **OmegaConf.to_container(noise_scheduler_kwargs)
         | 
| 397 | 
            +
                    )
         | 
| 398 | 
            +
                elif name.lower() == "unipc":
         | 
| 399 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
         | 
| 400 | 
            +
                    scheduler = UniPCMultistepScheduler(
         | 
| 401 | 
            +
                        **OmegaConf.to_container(noise_scheduler_kwargs)
         | 
| 402 | 
            +
                    )
         | 
| 403 | 
            +
                else:
         | 
| 404 | 
            +
                    raise ValueError(f"Unknown scheduler name: {name}")
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                return scheduler
         | 
| 407 | 
            +
             | 
| 408 | 
            +
             | 
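A small sketch (not part of the commit) of the expected calling convention; note that, as written above, the kwargs must be an OmegaConf object, and the prediction_type argument is accepted but not forwarded into the scheduler kwargs:

from omegaconf import OmegaConf

ddpm = create_noise_scheduler(
    "ddpm",
    noise_scheduler_kwargs=OmegaConf.create({"num_train_timesteps": 1000}),
)
print(type(ddpm).__name__)  # DDPMScheduler
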
def torch_gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


from enum import Enum


class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2


cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False


def is_intel_xpu():
    global cpu_state
    global xpu_available
    if cpu_state == CPUState.GPU:
        if xpu_available:
            return True
    return False


try:
    import intel_extension_for_pytorch as ipex

    if torch.xpu.is_available():
        xpu_available = True
except:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except:
    pass


def get_torch_device():
    global directml_enabled
    global cpu_state
    if directml_enabled:
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    else:
        if is_intel_xpu():
            return torch.device("xpu")
        else:
            return torch.device(torch.cuda.current_device())
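A short usage sketch (not part of the commit) of the device and memory helpers, as an app might call them between generations:

# Resolve the best available backend (CUDA, Intel XPU, Apple MPS or CPU) once,
# then reclaim cached GPU memory after each generation.
device = get_torch_device()
print(f"running on: {device}")

# ... run a pipeline call here ...

torch_gc()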
    	
        pipeline_stable_diffusion_xl_instantid.py → pipeline_stable_diffusion_xl_instantid_full.py
    RENAMED
    
@@ -22,7 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTokenizer
 
 from diffusers.image_processor import PipelineImageInput
 
@@ -41,8 +40,12 @@ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 
 from ip_adapter.resampler import Resampler
+from ip_adapter.utils import is_torch2_available
 
-
+if is_torch2_available():
+    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+else:
+    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
 from ip_adapter.attention_processor import region_control
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -102,7 +105,7 @@ EXAMPLE_DOC_STRING = """
         ```
 """
 
-
+from transformers import CLIPTokenizer
 from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
 class LongPromptWeight(object):
 
@@ -482,6 +485,34 @@ class LongPromptWeight(object):
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         return prompt_embeds
 
+def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
+
+    stickwidth = 4
+    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+    kps = np.array(kps)
+
+    w, h = image_pil.size
+    out_img = np.zeros([h, w, 3])
+
+    for i in range(len(limbSeq)):
+        index = limbSeq[i]
+        color = color_list[index[0]]
+
+        x = kps[index][:, 0]
+        y = kps[index][:, 1]
+        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+    out_img = (out_img * 0.6).astype(np.uint8)
+
+    for idx_kp, kp in enumerate(kps):
+        color = color_list[idx_kp]
+        x, y = kp
+        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+    return out_img_pil
 
 class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
 
@@ -567,7 +598,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
 
-    def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
+    def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
 
         if isinstance(prompt_image_emb, torch.Tensor):
             prompt_image_emb = prompt_image_emb.clone().detach()
@@ -583,6 +614,11 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
             prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
 
         prompt_image_emb = self.image_proj_model(prompt_image_emb)
+
+        bs_embed, seq_len, _ = prompt_image_emb.shape
+        prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         return prompt_image_emb
 
     @torch.no_grad()
@@ -623,7 +659,13 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+
+        # IP adapter
+        ip_adapter_scale=None,
+
+        # Enhance Face Region
         control_mask = None,
+
         **kwargs,
     ):
         r"""
@@ -758,6 +800,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                 otherwise a `tuple` is returned containing the output images.
         """
+
         lpw = LongPromptWeight()
 
         callback = kwargs.pop("callback", None)
@@ -789,6 +832,10 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
             )
+
+        # 0. set ip_adapter_scale
+        if ip_adapter_scale is not None:
+            self.set_ip_adapter_scale(ip_adapter_scale)
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -851,6 +898,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
         # 3.2 Encode image prompt
         prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
                                                          device,
+                                                         num_images_per_prompt,
                                                          self.unet.dtype,
                                                          self.do_classifier_free_guidance)
 
@@ -1031,24 +1079,57 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
-                    (previous single-ControlNet call, 10 lines, not rendered in the diff view)
+                    if isinstance(self.controlnet, MultiControlNetModel):
+                        down_block_res_samples_list, mid_block_res_sample_list = [], []
+                        for control_index in range(len(self.controlnet.nets)):
+                            controlnet = self.controlnet.nets[control_index]
+                            if control_index == 0:
+                                # assume the first controlnet is IdentityNet
+                                controlnet_prompt_embeds = prompt_image_emb
+                            else:
+                                controlnet_prompt_embeds = prompt_embeds
+                            down_block_res_samples, mid_block_res_sample = controlnet(control_model_input,
+                                                                                      t,
+                                                                                      encoder_hidden_states=controlnet_prompt_embeds,
+                                                                                      controlnet_cond=image[control_index],
+                                                                                      conditioning_scale=cond_scale[control_index],
+                                                                                      guess_mode=guess_mode,
+                                                                                      added_cond_kwargs=controlnet_added_cond_kwargs,
+                                                                                      return_dict=False)
+
+                            # controlnet mask
+                            if control_index == 0 and control_mask_wight_image_list is not None:
+                                down_block_res_samples = [
+                                    down_block_res_sample * mask_weight
+                                    for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
+                                ]
+                                mid_block_res_sample *= control_mask_wight_image_list[-1]
+
+                            down_block_res_samples_list.append(down_block_res_samples)
+                            mid_block_res_sample_list.append(mid_block_res_sample)
+
+                        mid_block_res_sample = torch.stack(mid_block_res_sample_list).sum(dim=0)
+                        down_block_res_samples = [torch.stack(down_block_res_samples).sum(dim=0) for down_block_res_samples in
+                                                  zip(*down_block_res_samples_list)]
+                    else:
+                        down_block_res_samples, mid_block_res_sample = self.controlnet(
+                            control_model_input,
+                            t,
+                            encoder_hidden_states=prompt_image_emb,
+                            controlnet_cond=image,
+                            conditioning_scale=cond_scale,
+                            guess_mode=guess_mode,
+                            added_cond_kwargs=controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
 
-                    (previous mask-weighting block, 7 lines, not rendered in the diff view)
+                        # controlnet mask
+                        if control_mask_wight_image_list is not None:
+                            down_block_res_samples = [
+                                down_block_res_sample * mask_weight
+                                for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
+                            ]
+                            mid_block_res_sample *= control_mask_wight_image_list[-1]
 
                     if guess_mode and self.do_classifier_free_guidance:
                         # Infer ControlNet only for the conditional batch.
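To make the new options concrete, a hypothetical call sketch (not part of the commit): it assumes `pipe` is a StableDiffusionXLInstantIDPipeline built around a MultiControlNetModel whose first net is the InstantID IdentityNet, and that `face_emb`, `face_kps`, `pose_image` and `face_mask` already exist; only ip_adapter_scale and control_mask are the arguments added by this commit, the rest mirror the standard SDXL ControlNet call.

images = pipe(
    prompt="analog film photo of a man in a city",
    image_embeds=face_emb,                     # InstantID face embedding
    image=[face_kps, pose_image],              # one condition image per controlnet
    controlnet_conditioning_scale=[0.8, 0.4],  # IdentityNet first, extra controlnet second
    ip_adapter_scale=0.8,                      # new: per-call IP-Adapter scale
    control_mask=face_mask,                    # new: restrict IdentityNet to the face region
    num_inference_steps=30,
).images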
    	
        requirements.txt
    CHANGED
    
@@ -1,7 +1,7 @@
-diffusers==0.25.
+diffusers==0.25.1
 torch==2.0.0
 torchvision==0.15.1
-transformers==4.
+transformers==4.37.1
 accelerate
 safetensors
 einops
@@ -11,4 +11,8 @@ omegaconf
 peft
 huggingface-hub==0.20.2
 opencv-python
-insightface
+insightface
+gradio
+controlnet_aux
+gdown
+peft
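A quick sanity-check sketch (not part of the commit) for confirming that the Space resolved the new pins:

import diffusers
import transformers

assert diffusers.__version__ == "0.25.1"
assert transformers.__version__ == "4.37.1"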