Spaces:
Runtime error
Runtime error
Update model_new.py
Browse files- model_new.py +32 -40
model_new.py
CHANGED
|
@@ -12,33 +12,13 @@ from condition.canny import CannyDetector
|
|
| 12 |
import time
|
| 13 |
from autoregressive.models.generate import generate
|
| 14 |
from condition.midas.depth import MidasDetector
|
| 15 |
-
|
| 16 |
|
| 17 |
models = {
|
| 18 |
-
"
|
| 19 |
-
"depth": "checkpoints/
|
| 20 |
}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def resize_image_to_16_multiple(image, condition_type='canny'):
|
| 24 |
-
if isinstance(image, np.ndarray):
|
| 25 |
-
image = Image.fromarray(image)
|
| 26 |
-
# image = Image.open(image_path)
|
| 27 |
-
width, height = image.size
|
| 28 |
-
|
| 29 |
-
if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
|
| 30 |
-
new_width = (width + 31) // 32 * 32
|
| 31 |
-
new_height = (height + 31) // 32 * 32
|
| 32 |
-
else:
|
| 33 |
-
new_width = (width + 15) // 16 * 16
|
| 34 |
-
new_height = (height + 15) // 16 * 16
|
| 35 |
-
|
| 36 |
-
resized_image = image.resize((new_width, new_height))
|
| 37 |
-
return resized_image
|
| 38 |
-
|
| 39 |
-
|
| 40 |
class Model:
|
| 41 |
-
|
| 42 |
def __init__(self):
|
| 43 |
self.device = torch.device(
|
| 44 |
"cuda")
|
|
@@ -46,8 +26,9 @@ class Model:
|
|
| 46 |
self.task_name = ""
|
| 47 |
self.vq_model = self.load_vq()
|
| 48 |
self.t5_model = self.load_t5()
|
| 49 |
-
self.
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
def to(self, device):
|
| 53 |
self.gpt_model_canny.to('cuda')
|
|
@@ -67,19 +48,17 @@ class Model:
|
|
| 67 |
gpt_ckpt = models[condition_type]
|
| 68 |
# precision = torch.bfloat16
|
| 69 |
precision = torch.float32
|
| 70 |
-
latent_size =
|
| 71 |
gpt_model = GPT_models["GPT-XL"](
|
| 72 |
block_size=latent_size**2,
|
| 73 |
cls_token_num=120,
|
| 74 |
model_type='t2i',
|
| 75 |
condition_type=condition_type,
|
|
|
|
| 76 |
).to(device='cpu', dtype=precision)
|
| 77 |
-
|
| 78 |
model_weight = load_file(gpt_ckpt)
|
| 79 |
-
print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
|
| 80 |
gpt_model.load_state_dict(model_weight, strict=True)
|
| 81 |
gpt_model.eval()
|
| 82 |
-
print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
|
| 83 |
print("gpt model is loaded")
|
| 84 |
return gpt_model
|
| 85 |
|
|
@@ -109,22 +88,35 @@ class Model:
|
|
| 109 |
seed: int,
|
| 110 |
low_threshold: int,
|
| 111 |
high_threshold: int,
|
|
|
|
|
|
|
| 112 |
) -> list[PIL.Image.Image]:
|
| 113 |
-
print(image)
|
| 114 |
-
image = resize_image_to_16_multiple(image, 'canny')
|
| 115 |
-
W, H = image.size
|
| 116 |
-
print(W, H)
|
| 117 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 118 |
self.gpt_model_canny.to('cuda').to(torch.bfloat16)
|
| 119 |
self.vq_model.to('cuda')
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
condition_img = condition_img.to(self.device)
|
| 127 |
-
condition_img = 2
|
| 128 |
prompts = [prompt] * 2
|
| 129 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 130 |
|
|
|
|
| 12 |
import time
|
| 13 |
from autoregressive.models.generate import generate
|
| 14 |
from condition.midas.depth import MidasDetector
|
| 15 |
+
from preprocessor import Preprocessor
|
| 16 |
|
| 17 |
models = {
|
| 18 |
+
"edge": "checkpoints/edge_base.safetensors",
|
| 19 |
+
"depth": "checkpoints/depth_base.safetensors",
|
| 20 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class Model:
|
|
|
|
| 22 |
def __init__(self):
|
| 23 |
self.device = torch.device(
|
| 24 |
"cuda")
|
|
|
|
| 26 |
self.task_name = ""
|
| 27 |
self.vq_model = self.load_vq()
|
| 28 |
self.t5_model = self.load_t5()
|
| 29 |
+
self.gpt_model_edge = self.load_gpt(condition_type='edge')
|
| 30 |
+
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
| 31 |
+
self.preprocessor = Preprocessor()
|
| 32 |
|
| 33 |
def to(self, device):
|
| 34 |
self.gpt_model_canny.to('cuda')
|
|
|
|
| 48 |
gpt_ckpt = models[condition_type]
|
| 49 |
# precision = torch.bfloat16
|
| 50 |
precision = torch.float32
|
| 51 |
+
latent_size = 512 // 16
|
| 52 |
gpt_model = GPT_models["GPT-XL"](
|
| 53 |
block_size=latent_size**2,
|
| 54 |
cls_token_num=120,
|
| 55 |
model_type='t2i',
|
| 56 |
condition_type=condition_type,
|
| 57 |
+
adapter_size='base',
|
| 58 |
).to(device='cpu', dtype=precision)
|
|
|
|
| 59 |
model_weight = load_file(gpt_ckpt)
|
|
|
|
| 60 |
gpt_model.load_state_dict(model_weight, strict=True)
|
| 61 |
gpt_model.eval()
|
|
|
|
| 62 |
print("gpt model is loaded")
|
| 63 |
return gpt_model
|
| 64 |
|
|
|
|
| 88 |
seed: int,
|
| 89 |
low_threshold: int,
|
| 90 |
high_threshold: int,
|
| 91 |
+
control_strength: float,
|
| 92 |
+
preprocessor_name: str,
|
| 93 |
) -> list[PIL.Image.Image]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 95 |
self.gpt_model_canny.to('cuda').to(torch.bfloat16)
|
| 96 |
self.vq_model.to('cuda')
|
| 97 |
+
if isinstance(image, np.ndarray):
|
| 98 |
+
image = Image.fromarray(image)
|
| 99 |
+
origin_W, origin_H = image.size
|
| 100 |
+
if preprocessor_name == 'Canny':
|
| 101 |
+
self.preprocessor.load("Canny")
|
| 102 |
+
condition_img = self.preprocessor(
|
| 103 |
+
image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=512)
|
| 104 |
+
elif preprocessor_name == 'Hed':
|
| 105 |
+
self.preprocessor.load("HED")
|
| 106 |
+
condition_img = self.preprocessor(
|
| 107 |
+
image=image,image_resolution=512, detect_resolution=512)
|
| 108 |
+
elif preprocessor_name == 'Lineart':
|
| 109 |
+
self.preprocessor.load("Lineart")
|
| 110 |
+
condition_img = self.preprocessor(
|
| 111 |
+
image=image,image_resolution=512, detect_resolution=512)
|
| 112 |
+
elif preprocessor_name == 'No preprocess':
|
| 113 |
+
condition_img = image
|
| 114 |
+
condition_img = condition_img.resize((512,512))
|
| 115 |
+
W, H = condition_img.size
|
| 116 |
+
|
| 117 |
+
condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(2,1,1,1)
|
| 118 |
condition_img = condition_img.to(self.device)
|
| 119 |
+
condition_img = 2*(condition_img/255 - 0.5)
|
| 120 |
prompts = [prompt] * 2
|
| 121 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 122 |
|