telecomadm1145 committed
Commit 4e1bcbe · verified · Parent(s): 225693c

Update app.py

Files changed (1): app.py (+8, -95)
app.py CHANGED
@@ -106,73 +106,6 @@ def build_transform(is_training: bool, interpolation: str):
     cfg = model.data_config.copy()
     cfg.update(dict(interpolation=interpolation))
     return timm.data.create_transform(**cfg, is_training=is_training)
-# 4. Attention Hook & Visualization =========================================
-# ---------------------------------------------------------------------------
-
-def get_attention_map(module, inputs, output):
-    """
-    forward_hook: captures the attention weights after softmax, before dropout
-    inputs[0] : [num_windows*B, num_heads, N, N]  (N = win_size²)
-    """
-    global attention_maps
-    if inputs and isinstance(inputs[0], torch.Tensor):
-        # keeping the first image's first map is enough
-        attention_maps.append(inputs[0].detach().cpu())
-
-
-def create_attention_visualization(image_pil: Image.Image,
-                                   attn_map: torch.Tensor,
-                                   patch_size: int = 4) -> Image.Image:
-    """
-    1) Window attention → token-level heatmap (H_token × W_token)
-    2) Nearest-neighbor upsample to the original patch grid (96×96 for 384² input, patch_size=4)
-    3) Upsample to pixel resolution and blend with the input image
-    """
-    # -----------------------------------------------------------
-    # 1. Per-token "importance": mean over the head and query dims
-    #    attn_map: [num_windows, num_heads, N, N]  (batch is already 1)
-    attn_map = attn_map.mean(dim=1).mean(dim=2)           # → [num_windows, N]
-    attn_map = attn_map.clamp(min=0)
-
-    num_windows, N = attn_map.shape
-    win_size = int(math.sqrt(N))                          # 12
-    assert win_size * win_size == N, "N is not win_size²"
-
-    # -----------------------------------------------------------
-    # 2. Stitch a token-resolution heatmap heat_token (H_token × W_token)
-    #    token resolution = win_size × windows_per_dim
-    img_h, img_w = image_pil.size[1], image_pil.size[0]   # PIL size is (w, h)
-    num_patch_h, num_patch_w = img_h // patch_size, img_w // patch_size  # 96×96
-
-    win_per_row = int(round(num_patch_w / win_size))      # 8 for Stage1, 1 for Stage4
-    token_side = win_per_row * win_size                   # 96 or 12
-    heat_token = torch.zeros(token_side, token_side)
-
-    for idx in range(num_windows):
-        row_w = idx // win_per_row
-        col_w = idx % win_per_row
-        r0, r1 = row_w * win_size, (row_w + 1) * win_size
-        c0, c1 = col_w * win_size, (col_w + 1) * win_size
-        heat_token[r0:r1, c0:c1] = attn_map[idx].view(win_size, win_size)
-
-    # -----------------------------------------------------------
-    # 3. Normalize & upsample to the patch-grid size (96×96)
-    heat_token = heat_token.unsqueeze(0).unsqueeze(0)     # [1, 1, H_t, W_t]
-    heat_patch = F.interpolate(heat_token, size=(num_patch_h, num_patch_w),
-                               mode="nearest")[0, 0]      # [H_patch, W_patch]
-
-    heat_patch -= heat_patch.min()
-    heat_patch /= (heat_patch.max() + 1e-6)
-
-    # -----------------------------------------------------------
-    # 4. Convert to a pixel-level heatmap and blend
-    heat_np = heat_patch.numpy()
-    heat_img = Image.fromarray((plt.cm.viridis(heat_np)[:, :, :3] * 255).astype(np.uint8))
-    heat_img = heat_img.resize(image_pil.size, Image.BICUBIC)
-
-    blended = Image.blend(image_pil.convert("RGB"), heat_img, alpha=0.55)
-    return blended
-
 
 # ---------------------------------------------------------------------------
 # 5. Inference + optional attention visualization
@@ -181,38 +114,20 @@ def predict_and_visualize(image_pil: Image.Image,
                           interpolation: str = "bicubic",
                           show_attention: bool = True):
     if image_pil is None:
-        return None, None
+        return None
 
     load_model(ckpt_name)
 
-    global attention_maps
-    attention_maps = []
-
     transform = build_transform(is_training=False, interpolation=interpolation)
     input_tensor = transform(image_pil).unsqueeze(0).to(device)
 
-    hook_handle = None
-    if show_attention:
-        # --- FIX: Target the attn_drop layer inside the attention module ---
-        target_layer = model.backbone.layers[-1].blocks[-1].attn.attn_drop
-        hook_handle = target_layer.register_forward_hook(get_attention_map)
-
     with torch.no_grad():
         logits = model(input_tensor)
-
-    if hook_handle:
-        hook_handle.remove()
-
+
     probs = F.softmax(logits, dim=1)[0]
     confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
 
-    viz_image = None
-    if show_attention and attention_maps:
-        viz_image = create_attention_visualization(image_pil.copy(),
-                                                   attention_maps[0],
-                                                   patch_size=4)  # Swin-Large uses patch size 4
-
-    return confidences, viz_image
+    return confidences
 
 # ---------------------------------------------------------------------------
 # 6. Gradio UI
@@ -234,18 +149,16 @@ def launch_app():
                 ["bilinear", "bicubic", "nearest"], value="bicubic",
                 label="Resize Interpolation (Preprocessing)"
             )
-            viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
 
             in_img = gr.Image(type="pil", label="Upload an Image")
 
         with gr.Column(scale=2):
             out_lbl = gr.Label(num_top_classes=2, label="Predictions")
-            out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)
 
     run_btn.click(
         predict_and_visualize,
-        inputs=[in_img, model_choice, interp_choice, viz_checkbox],
-        outputs=[out_lbl, out_viz]
+        inputs=[in_img, model_choice, interp_choice],
+        outputs=[out_lbl]
     )
 
     # Create a dummy examples directory if it doesn't exist
@@ -258,9 +171,9 @@ def launch_app():
     example_files = [os.path.join(example_dir, f) for f in os.listdir(example_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
     if example_files:
         gr.Examples(
-            examples=[[f, DEFAULT_CKPT, "bicubic", True] for f in example_files],
-            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
-            outputs=[out_lbl, out_viz],
+            examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
+            inputs=[in_img, model_choice, interp_choice],
+            outputs=[out_lbl],
             fn=predict_and_visualize,
             cache_examples=False,
         )
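For reference, the unchanged build_transform context above follows timm's standard preprocessing recipe: copy the model's data config, override the interpolation, and hand everything to timm.data.create_transform. A minimal standalone sketch of the same pattern, assuming a timm Swin checkpoint; the model name is illustrative, not necessarily the one app.py loads, and resolve_model_data_config stands in for the app's model.data_config attribute:

    import timm
    import timm.data

    # Illustrative model; app.py resolves its own checkpoint via load_model().
    model = timm.create_model("swin_large_patch4_window12_384", pretrained=False)

    # Resolve the model's preprocessing config, then override the interpolation,
    # mirroring what build_transform does with model.data_config.
    cfg = timm.data.resolve_model_data_config(model)
    cfg.update(dict(interpolation="bicubic"))
    transform = timm.data.create_transform(**cfg, is_training=False)
    print(transform)  # typically Resize / CenterCrop / ToTensor / Normalize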
 
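The removed visualization relied on a PyTorch forward hook attached to the attn_drop module of the last Swin block, so that inputs[0] inside the hook is the post-softmax attention matrix. A minimal, self-contained sketch of that hook pattern, using a toy attention module rather than the app's Swin backbone (everything here is illustrative except the hook mechanics):

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        """Toy multi-head self-attention; stands in for a Swin window-attention block."""
        def __init__(self, dim=16, num_heads=2):
            super().__init__()
            self.num_heads = num_heads
            self.qkv = nn.Linear(dim, dim * 3)
            self.attn_drop = nn.Dropout(0.0)  # hook target, as in the removed code

        def forward(self, x):  # x: [B, N, C]
            B, N, C = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            def split(t):
                return t.view(B, N, self.num_heads, C // self.num_heads).transpose(1, 2)
            q, k, v = split(q), split(k), split(v)
            attn = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
            attn = self.attn_drop(attn.softmax(dim=-1))  # hook fires here
            return (attn @ v).transpose(1, 2).reshape(B, N, C)

    attention_maps = []

    def get_attention_map(module, inputs, output):
        # inputs[0] is the post-softmax attention: [B, num_heads, N, N]
        if inputs and isinstance(inputs[0], torch.Tensor):
            attention_maps.append(inputs[0].detach().cpu())

    m = ToyAttention()
    handle = m.attn_drop.register_forward_hook(get_attention_map)
    with torch.no_grad():
        m(torch.randn(1, 9, 16))
    handle.remove()
    print(attention_maps[0].shape)  # torch.Size([1, 2, 9, 9])

Hooking the dropout that directly follows the softmax is what made the captured tensor a proper attention distribution; hooking the attention module's output instead would yield the already value-weighted features.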