Commit ad7aaa6 · jwyang committed · 1 parent: a574e10

add heatmap visualization

Files changed:
- app.py (+16 -4)
- model/image_encoder/swin_transformer.py (+8 -4)
- model/model.py (+15 -4)
app.py CHANGED

@@ -118,11 +118,20 @@ def recognize_image(image, texts):
     text_embeddings = model.get_text_embeddings(texts.split(';'))
 
     # compute output
-    feat_img = model.encode_image(img_t.unsqueeze(0))
+    feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
     output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
     prediction = output.softmax(-1).flatten()
 
-    return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
+    # generate feat map given the top matched texts
+    output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
+    output_map = output_map.view(1, 1, 7, 7)
+
+    output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
+    output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
+    output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())
+    heatmap = show_cam_on_image(img_d, output_map, use_rgb=True)
+
+    return Image.fromarray(heatmap), {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
 
 
 image = gr.inputs.Image()

@@ -132,8 +141,11 @@ gr.Interface(
     description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
     fn=recognize_image,
     inputs=["image", "text"],
-    outputs=[
-        label,
+    outputs=[
+        gr.outputs.Image(
+            type="pil",
+            label="zero-shot heat map"),
+        label,
     ],
     examples=[
         ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
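Taken on its own, the new heat-map branch in recognize_image is a short piece of tensor algebra: score each spatial token against the top-matching text embedding, softmax over tokens, reshape to the 7x7 grid, and upsample to image resolution. Below is a minimal sketch of that computation on dummy tensors; the 224x224 input size, the 512-dim embedding, and the random tensors are illustrative assumptions, and the final show_cam_on_image overlay (from the pytorch-grad-cam package) is left out.

# Minimal sketch of the text-conditioned heat map on dummy tensors.
# Assumes a 224x224 input and Swin's final 7x7 token grid (49 tokens);
# dim=512 is a placeholder for the projected embedding size.
import torch
import torch.nn as nn

dim, num_texts = 512, 3
feat_map = torch.randn(1, dim, 49)             # B x C x L, as from encode_image(..., output_map=True)
text_embeddings = torch.randn(num_texts, dim)
prediction = torch.rand(num_texts).softmax(-1)

# similarity of every spatial token to the best-matching text, softmaxed over tokens
output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
output_map = output_map.view(1, 1, 7, 7)

# upsample the 7x7 grid to image resolution, then min-max normalize to [0, 1]
output_map = nn.Upsample(size=(224, 224), mode='bilinear')(output_map)
output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())
print(output_map.shape)  # (224, 224, 1), ready to overlay on the input image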
model/image_encoder/swin_transformer.py CHANGED

@@ -557,7 +557,7 @@ class SwinTransformer(nn.Module):
     def no_weight_decay_keywords(self):
         return {'relative_position_bias_table'}
 
-    def forward_features(self, x):
+    def forward_features(self, x, output_map=False):
         x = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed

@@ -566,10 +566,14 @@
         for layer in self.layers:
             x = layer(x)
 
-        x = self.norm(x)  # B L C
-        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x_map = self.norm(x).transpose(1, 2)  # B C L
+        x = self.avgpool(x_map)  # B C 1
         x = torch.flatten(x, 1)
-        return x
+
+        if output_map:
+            return x, x_map
+        else:
+            return x
 
     def forward(self, x):
         x = self.forward_features(x)
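The forward_features change keeps the normalized per-token features (transposed to B C L) alongside the usual pooled vector. A shape-only sketch under assumed sizes (768 channels after the last stage, a 7x7 grid; nn.AdaptiveAvgPool1d(1) is the pooling used in the stock SwinTransformer):

# Shape sketch of the new forward_features contract on dummy data.
# B=2 images, L=49 tokens (7x7), C=768 channels are assumed sizes.
import torch
import torch.nn as nn

B, L, C = 2, 49, 768
x = torch.randn(B, L, C)              # token sequence after the last Swin stage

norm = nn.LayerNorm(C)
avgpool = nn.AdaptiveAvgPool1d(1)

x_map = norm(x).transpose(1, 2)       # B C L: per-token features, kept for the heat map
x = torch.flatten(avgpool(x_map), 1)  # B C:   pooled global feature, unchanged behavior

print(x_map.shape, x.shape)           # torch.Size([2, 768, 49]) torch.Size([2, 768])

With output_map=False the method still returns only x, so existing callers such as forward are unaffected.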
model/model.py CHANGED

@@ -153,14 +153,25 @@ class UniCLModel(nn.Module):
         imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
         return imnet_text_embeddings
 
-    def encode_image(self, image, norm=True):
-        x = self.image_encoder.forward_features(image)
+    def encode_image(self, image, norm=True, output_map=False):
+        x = self.image_encoder.forward_features(image, output_map=output_map)
+        if output_map:
+            x, x_map = x
+
         x = x @ self.image_projection
 
+        if output_map:
+            x_map = self.image_projection.unsqueeze(0).transpose(1, 2) @ x_map
+
         if norm:
             x = x / x.norm(dim=-1, keepdim=True)
-
-        return x
+            if output_map:
+                x_map = x_map / x_map.norm(dim=1, keepdim=True)
+
+        if output_map:
+            return x, x_map
+        else:
+            return x
 
     def encode_text(self, text, norm=True):
         x = self.text_encoder(**text)
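encode_image reuses the pooled path's image_projection for every spatial token, so the token map lands in the same embedding space as the text features, and normalizing along dim=1 makes each token a unit vector there. A shape sketch with assumed sizes (C_img=768 and C_emb=512 are placeholders):

# Shape sketch of projecting the token map into the shared embedding space.
# C_img=768 (encoder channels) and C_emb=512 (joint space) are assumed sizes.
import torch

B, L, C_img, C_emb = 2, 49, 768, 512
image_projection = torch.randn(C_img, C_emb)  # same parameter as the pooled path
x_map = torch.randn(B, C_img, L)              # B C L from forward_features

# (1, C_emb, C_img) @ (B, C_img, L) -> (B, C_emb, L)
x_map = image_projection.unsqueeze(0).transpose(1, 2) @ x_map
x_map = x_map / x_map.norm(dim=1, keepdim=True)  # unit-normalize each token's embedding

print(x_map.shape)  # torch.Size([2, 512, 49])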