Spaces:
Running
on
L4
Running
on
L4
LIU, Zichen
commited on
Commit
·
ca78dbf
1
Parent(s):
79ecf3f
update
Browse files
MagicQuill/magic_utils.py
CHANGED
|
@@ -110,14 +110,12 @@ def draw_contour(img, mask):
|
|
| 110 |
img_np = img_np.astype(np.uint8)
|
| 111 |
img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 112 |
|
| 113 |
-
# 膨胀掩码
|
| 114 |
kernel = np.ones((5, 5), np.uint8)
|
| 115 |
mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
|
| 116 |
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 117 |
for contour in contours:
|
| 118 |
-
cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10)
|
| 119 |
img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 120 |
-
# 转换回tensor
|
| 121 |
transform = transforms.ToTensor()
|
| 122 |
img_tensor = transform(img_np)
|
| 123 |
|
|
@@ -128,7 +126,6 @@ def draw_contour(img, mask):
|
|
| 128 |
def get_colored_contour(img1, img2, threshold=10):
|
| 129 |
diff = torch.abs(img1 - img2).float()
|
| 130 |
diff_gray = torch.mean(diff, dim=-1)
|
| 131 |
-
# 阈值处理以生成二进制掩码
|
| 132 |
mask = diff_gray > threshold
|
| 133 |
|
| 134 |
return draw_contour(img2, mask), mask
|
|
@@ -153,9 +150,7 @@ def rgb_to_name(rgb_tuple):
|
|
| 153 |
def find_different_colors(img1, img2, threshold=10):
|
| 154 |
img1 = img1.to(torch.uint8)
|
| 155 |
img2 = img2.to(torch.uint8)
|
| 156 |
-
# 计算图像之间的绝对差异
|
| 157 |
diff = torch.abs(img1 - img2).float().mean(dim=-1)
|
| 158 |
-
# 找到大于阈值的差异区域
|
| 159 |
diff_mask = diff > threshold
|
| 160 |
diff_indices = torch.nonzero(diff_mask, as_tuple=True)
|
| 161 |
|
|
@@ -165,14 +160,10 @@ def find_different_colors(img1, img2, threshold=10):
|
|
| 165 |
else:
|
| 166 |
sampled_diff_indices = diff_indices
|
| 167 |
|
| 168 |
-
# 提取不同区域的颜色
|
| 169 |
diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
|
| 170 |
-
# 将颜色值转换为颜色名称
|
| 171 |
color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
|
| 172 |
name_counter = Counter(color_names)
|
| 173 |
-
# 过滤出现超过10次的颜色
|
| 174 |
filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
|
| 175 |
-
# 按出现次数从大到小排序
|
| 176 |
sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
|
| 177 |
if len(sorted_color_names) >= 3:
|
| 178 |
return "colorful"
|
|
@@ -183,19 +174,15 @@ def get_bounding_box_from_mask(mask, padded=False):
|
|
| 183 |
# Ensure the mask is a binary mask (0s and 1s)
|
| 184 |
mask = mask.squeeze()
|
| 185 |
rows, cols = torch.where(mask > 0.5)
|
| 186 |
-
# If there are no '1's in the mask, return None or an appropriate bounding box like (0,0,0,0)
|
| 187 |
if len(rows) == 0 or len(cols) == 0:
|
| 188 |
return (0, 0, 0, 0)
|
| 189 |
height, width = mask.shape
|
| 190 |
if padded:
|
| 191 |
padded_size = max(width, height)
|
| 192 |
-
# 检查填充发生在哪个方向
|
| 193 |
if width < height:
|
| 194 |
-
# 宽度较小,填充发生在宽度上
|
| 195 |
offset_x = (padded_size - width) / 2
|
| 196 |
offset_y = 0
|
| 197 |
else:
|
| 198 |
-
# 高度较小,填充发生在高度上
|
| 199 |
offset_y = (padded_size - height) / 2
|
| 200 |
offset_x = 0
|
| 201 |
# Find the bounding box coordinates
|
|
|
|
| 110 |
img_np = img_np.astype(np.uint8)
|
| 111 |
img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 112 |
|
|
|
|
| 113 |
kernel = np.ones((5, 5), np.uint8)
|
| 114 |
mask_dilated = cv2.dilate(mask_np, kernel, iterations=3)
|
| 115 |
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 116 |
for contour in contours:
|
| 117 |
+
cv2.drawContours(img_bgr, [contour], -1, (0, 0, 255), thickness=10)
|
| 118 |
img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
| 119 |
transform = transforms.ToTensor()
|
| 120 |
img_tensor = transform(img_np)
|
| 121 |
|
|
|
|
| 126 |
def get_colored_contour(img1, img2, threshold=10):
|
| 127 |
diff = torch.abs(img1 - img2).float()
|
| 128 |
diff_gray = torch.mean(diff, dim=-1)
|
|
|
|
| 129 |
mask = diff_gray > threshold
|
| 130 |
|
| 131 |
return draw_contour(img2, mask), mask
|
|
|
|
| 150 |
def find_different_colors(img1, img2, threshold=10):
|
| 151 |
img1 = img1.to(torch.uint8)
|
| 152 |
img2 = img2.to(torch.uint8)
|
|
|
|
| 153 |
diff = torch.abs(img1 - img2).float().mean(dim=-1)
|
|
|
|
| 154 |
diff_mask = diff > threshold
|
| 155 |
diff_indices = torch.nonzero(diff_mask, as_tuple=True)
|
| 156 |
|
|
|
|
| 160 |
else:
|
| 161 |
sampled_diff_indices = diff_indices
|
| 162 |
|
|
|
|
| 163 |
diff_colors = img2[sampled_diff_indices[0], sampled_diff_indices[1], :]
|
|
|
|
| 164 |
color_names = [rgb_to_name(tuple(color)) for color in diff_colors]
|
| 165 |
name_counter = Counter(color_names)
|
|
|
|
| 166 |
filtered_colors = {name: count for name, count in name_counter.items() if count > 10}
|
|
|
|
| 167 |
sorted_color_names = [name for name, count in sorted(filtered_colors.items(), key=lambda item: item[1], reverse=True)]
|
| 168 |
if len(sorted_color_names) >= 3:
|
| 169 |
return "colorful"
|
|
|
|
| 174 |
# Ensure the mask is a binary mask (0s and 1s)
|
| 175 |
mask = mask.squeeze()
|
| 176 |
rows, cols = torch.where(mask > 0.5)
|
|
|
|
| 177 |
if len(rows) == 0 or len(cols) == 0:
|
| 178 |
return (0, 0, 0, 0)
|
| 179 |
height, width = mask.shape
|
| 180 |
if padded:
|
| 181 |
padded_size = max(width, height)
|
|
|
|
| 182 |
if width < height:
|
|
|
|
| 183 |
offset_x = (padded_size - width) / 2
|
| 184 |
offset_y = 0
|
| 185 |
else:
|
|
|
|
| 186 |
offset_y = (padded_size - height) / 2
|
| 187 |
offset_x = 0
|
| 188 |
# Find the bounding box coordinates
|
MagicQuill/scribble_color_edit.py
CHANGED
|
@@ -53,7 +53,6 @@ class ScribbleColorEditModel():
|
|
| 53 |
self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
|
| 54 |
if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
|
| 55 |
self.load_models(base_model_version, dtype)
|
| 56 |
-
# 根据基础模型版本加载相应的 ControlNet&BrushNet 模型
|
| 57 |
positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
|
| 58 |
negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
|
| 59 |
# Grow Mask for Color Editing
|
|
@@ -90,9 +89,7 @@ class ScribbleColorEditModel():
|
|
| 90 |
bool_remove_mask_resized = (remove_mask_resized > 0.5)
|
| 91 |
|
| 92 |
if stroke_as_edge == "enable":
|
| 93 |
-
# 将remove_mask区域的像素变成黑色
|
| 94 |
lineart_output[bool_remove_mask_resized] = 0.0
|
| 95 |
-
# 将add_mask区域的像素变成白色
|
| 96 |
lineart_output[bool_add_mask_resized] = 1.0
|
| 97 |
else:
|
| 98 |
lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
|
|
@@ -101,7 +98,7 @@ class ScribbleColorEditModel():
|
|
| 101 |
# BrushNet
|
| 102 |
model, positive, negative, latent = self.brushnet_node.model_update(
|
| 103 |
model=self.model,
|
| 104 |
-
vae=self.vae,
|
| 105 |
image=image,
|
| 106 |
mask=mask,
|
| 107 |
brushnet=self.brushnet,
|
|
|
|
| 53 |
self.model, self.clip, self.vae = self.checkpoint_loader.load_checkpoint(ckpt_name)
|
| 54 |
if not hasattr(self, 'edge_controlnet') or not hasattr(self, 'color_controlnet') or not hasattr(self, 'brushnet'):
|
| 55 |
self.load_models(base_model_version, dtype)
|
|
|
|
| 56 |
positive = self.clip_text_encoder.encode(self.clip, positive_prompt)[0]
|
| 57 |
negative = self.clip_text_encoder.encode(self.clip, negative_prompt)[0]
|
| 58 |
# Grow Mask for Color Editing
|
|
|
|
| 89 |
bool_remove_mask_resized = (remove_mask_resized > 0.5)
|
| 90 |
|
| 91 |
if stroke_as_edge == "enable":
|
|
|
|
| 92 |
lineart_output[bool_remove_mask_resized] = 0.0
|
|
|
|
| 93 |
lineart_output[bool_add_mask_resized] = 1.0
|
| 94 |
else:
|
| 95 |
lineart_output[bool_remove_mask_resized & ~bool_add_mask_resized] = 0.0
|
|
|
|
| 98 |
# BrushNet
|
| 99 |
model, positive, negative, latent = self.brushnet_node.model_update(
|
| 100 |
model=self.model,
|
| 101 |
+
vae=self.vae,
|
| 102 |
image=image,
|
| 103 |
mask=mask,
|
| 104 |
brushnet=self.brushnet,
|