Spaces:
Runtime error
Runtime error
Update model_new.py
Browse files- model_new.py +20 -8
model_new.py
CHANGED
|
@@ -180,21 +180,33 @@ class Model:
|
|
| 180 |
top_k: int,
|
| 181 |
top_p: int,
|
| 182 |
seed: int,
|
|
|
|
|
|
|
| 183 |
) -> list[PIL.Image.Image]:
|
| 184 |
-
image = resize_image_to_16_multiple(image, 'depth')
|
| 185 |
-
W, H = image.size
|
| 186 |
-
print(W, H)
|
| 187 |
self.gpt_model_canny.to('cpu')
|
| 188 |
self.t5_model.model.to(self.device)
|
| 189 |
self.gpt_model_depth.to(self.device)
|
| 190 |
self.get_control_depth.model.to(self.device)
|
| 191 |
self.vq_model.to(self.device)
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
condition_img =
|
| 195 |
-
|
| 196 |
-
condition_img =
|
| 197 |
-
|
| 198 |
prompts = [prompt] * 2
|
| 199 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 200 |
|
|
|
|
| 180 |
top_k: int,
|
| 181 |
top_p: int,
|
| 182 |
seed: int,
|
| 183 |
+
control_strength: float,
|
| 184 |
+
preprocessor_name: str
|
| 185 |
) -> list[PIL.Image.Image]:
|
|
|
|
|
|
|
|
|
|
| 186 |
self.gpt_model_canny.to('cpu')
|
| 187 |
self.t5_model.model.to(self.device)
|
| 188 |
self.gpt_model_depth.to(self.device)
|
| 189 |
self.get_control_depth.model.to(self.device)
|
| 190 |
self.vq_model.to(self.device)
|
| 191 |
+
if isinstance(image, np.ndarray):
|
| 192 |
+
image = Image.fromarray(image)
|
| 193 |
+
origin_W, origin_H = image.size
|
| 194 |
+
# print(image)
|
| 195 |
+
if preprocessor_name == 'depth':
|
| 196 |
+
self.preprocessor.load("Depth")
|
| 197 |
+
condition_img = self.preprocessor(
|
| 198 |
+
image=image,
|
| 199 |
+
image_resolution=512,
|
| 200 |
+
detect_resolution=512,
|
| 201 |
+
)
|
| 202 |
+
elif preprocessor_name == 'No preprocess':
|
| 203 |
+
condition_img = image
|
| 204 |
+
condition_img = condition_img.resize((512,512))
|
| 205 |
+
W, H = condition_img.size
|
| 206 |
|
| 207 |
+
condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(2,1,1,1)
|
| 208 |
+
condition_img = condition_img.to(self.device)
|
| 209 |
+
condition_img = 2*(condition_img/255 - 0.5)
|
|
|
|
| 210 |
prompts = [prompt] * 2
|
| 211 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
| 212 |
|