from typing import Any, Callable, Dict, List, Optional, Union
import importlib
import inspect
import math
from pathlib import Path
import re
from collections import defaultdict
import cv2
import time
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch.autograd.function import Function
from diffusers import DiffusionPipeline


# Support for finding the region of each object in the prompt
def encode_region_map_sp(state, tokenizer, unet, width, height, scale_ratio=8, text_ids=None, do_classifier_free_guidance=True):
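    """Encode the region maps of one prompt into per-resolution weight tensors.

    ``state`` maps each keyword to a dict with its region ``map`` (array, the
    region being the pixels with values < 255), attention ``weight``, and a
    ``mask_outsides`` penalty for positions outside the region. For every down
    block of the UNet, each region map is resized to that block's latent
    resolution and added to a weight tensor at the positions where the
    keyword's token ids occur in the (un)conditional prompt. Returns a dict
    keyed by flattened spatial size (w * h) with tensors of shape
    (batch, w * h, num_text_tokens).
    """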
    if text_ids is None:
        return torch.Tensor(0)
    uncond, cond = text_ids[0], text_ids[1]
    w_tensors = dict()
    cond = cond.reshape(-1,).tolist() if isinstance(cond, (np.ndarray, torch.Tensor)) else None
    uncond = uncond.reshape(-1,).tolist() if isinstance(uncond, (np.ndarray, torch.Tensor)) else None
    # One weight map per UNet down block; the latent resolution halves each time
    for _ in unet.down_blocks:
        c = len(cond)
        w_r, h_r = int(math.ceil(width / scale_ratio)), int(math.ceil(height / scale_ratio))
        ret_cond_tensor = torch.zeros((1, w_r * h_r, c), dtype=torch.float32)
        ret_uncond_tensor = torch.zeros((1, w_r * h_r, c), dtype=torch.float32)
        if state is not None:
            for k, v in state.items():
                if v["map"] is None:
                    continue
                is_in = False
                k_as_tokens = tokenizer(
                    k,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    add_special_tokens=False,
                ).input_ids
                # Binarize the region map (region is where map < 255), resize it
                # to this block's latent resolution, scale by the region weight,
                # and penalize positions outside the region with -mask_outsides.
                region_map_resize = np.array(v["map"] < 255, dtype=np.uint8)
                region_map_resize = cv2.resize(region_map_resize, (w_r, h_r), interpolation=cv2.INTER_CUBIC)
                region_map_resize = (region_map_resize == np.max(region_map_resize)).astype(np.float32)
                region_map_resize = region_map_resize * float(v["weight"])
                region_map_resize[region_map_resize == 0] = -1 * float(v["mask_outsides"])
                # One column of weights per token of the keyword
                ret = torch.from_numpy(region_map_resize)
                ret = ret.reshape(-1, 1).repeat(1, len(k_as_tokens))
                if cond is not None:
                    for idx in range(len(cond)):
                        if cond[idx : idx + len(k_as_tokens)] == k_as_tokens:
                            is_in = True
                            ret_cond_tensor[0, :, idx : idx + len(k_as_tokens)] += ret
                if uncond is not None:
                    for idx in range(len(uncond)):
                        if uncond[idx : idx + len(k_as_tokens)] == k_as_tokens:
                            is_in = True
                            ret_uncond_tensor[0, :, idx : idx + len(k_as_tokens)] += ret
                if not is_in:
                    print(f"tokens {k_as_tokens} not found in text")
        w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor]) if do_classifier_free_guidance else ret_cond_tensor
        scale_ratio *= 2
    return w_tensors


def encode_region_map(
    pipe: DiffusionPipeline,
    state,
    width,
    height,
    num_images_per_prompt,
    text_ids=None,
):
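    """Encode region maps for a batch of prompts.

    Splits ``text_ids`` (negative and positive prompt token ids) per prompt,
    encodes each prompt's ``state`` with ``encode_region_map_sp``, then
    concatenates the per-prompt tensors along the batch axis and repeats them
    for ``num_images_per_prompt``.
    """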
    if text_ids is None:
        return torch.Tensor(0)
    negative_prompt_tokens_id, prompt_tokens_id = text_ids[0], text_ids[1]
    if prompt_tokens_id is None:
        return torch.Tensor(0)
    prompt_tokens_id = np.array(prompt_tokens_id)
    negative_prompt_tokens_id = np.array(negative_prompt_tokens_id) if negative_prompt_tokens_id is not None else None
    # Split the batch into per-prompt token id arrays
    number_prompt = prompt_tokens_id.shape[0]
    prompt_tokens_id = np.split(prompt_tokens_id, number_prompt)
    negative_prompt_tokens_id = np.split(negative_prompt_tokens_id, number_prompt) if negative_prompt_tokens_id is not None else None
    lst_prompt_map = []
    if not isinstance(state, list):
        state = [state]
    if len(state) < number_prompt:
        # Pad with None so every prompt has a (possibly empty) region state
        state = state + [None] * (number_prompt - len(state))
    for i in range(number_prompt):
        text_ids = [negative_prompt_tokens_id[i], prompt_tokens_id[i]] if negative_prompt_tokens_id is not None else [None, prompt_tokens_id[i]]
        region_map = encode_region_map_sp(
            state[i],
            pipe.tokenizer,
            pipe.unet,
            width,
            height,
            scale_ratio=pipe.vae_scale_factor,
            text_ids=text_ids,
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        )
        lst_prompt_map.append(region_map)
    # Merge the per-prompt dicts, grouping tensors that share a resolution key
    region_state_sp = {}
    for d in lst_prompt_map:
        for key, tensor in d.items():
            if key in region_state_sp:
                # Key already present: concatenate along the batch axis
                region_state_sp[key] = torch.cat((region_state_sp[key], tensor))
            else:
                # New key: take the tensor as-is
                region_state_sp[key] = tensor
    # Repeat each tensor for num_images_per_prompt
    region_state = {}
    for key, tensor in region_state_sp.items():
        # Repeat along the batch axis (dim 0)
        region_state[key] = tensor.repeat(num_images_per_prompt, 1, 1)
    return region_state
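

# A minimal usage sketch, not part of the original module: it assumes a
# Stable Diffusion v1.x checkpoint, a grayscale mask file "cat_mask.png", and
# the keyword/weights below, all of which are illustrative placeholders.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    # Placeholder checkpoint id; any SD v1.x checkpoint should work
    pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

    # The region is wherever the mask is darker than pure white (map < 255)
    mask = cv2.imread("cat_mask.png", cv2.IMREAD_GRAYSCALE)
    state = {"cat": {"map": mask, "weight": 1.5, "mask_outsides": 0.5}}

    # Tokenize positive and negative prompts the same way the pipeline does
    prompt_ids = pipe.tokenizer(
        "a photo of a cat on a bench",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    negative_ids = pipe.tokenizer(
        "",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
    ).input_ids

    weight_maps = encode_region_map_sp(
        state,
        pipe.tokenizer,
        pipe.unet,
        width=512,
        height=512,
        scale_ratio=pipe.vae_scale_factor,
        text_ids=[np.array(negative_ids), np.array(prompt_ids)],
        do_classifier_free_guidance=True,
    )
    # One weight tensor per down-block resolution, keyed by flattened size
    for size, tensor in weight_maps.items():
        print(size, tuple(tensor.shape))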