jwyang committed · Commit eb1d5d5 · 1 Parent(s): 8424dda

support arbitary size
Files changed:
- app.py (+3 -3)
- model/__pycache__/__init__.cpython-39.pyc (+0 -0)
- model/__pycache__/model.cpython-39.pyc (+0 -0)
- model/__pycache__/templates.cpython-39.pyc (+0 -0)
- model/image_encoder/__pycache__/__init__.cpython-39.pyc (+0 -0)
- model/image_encoder/__pycache__/build.cpython-39.pyc (+0 -0)
- model/image_encoder/__pycache__/focalnet.cpython-39.pyc (+0 -0)
- model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc (+0 -0)
- model/image_encoder/swin_transformer.py (+71 -25)
- model/model.py (+2 -2)
- model/text_encoder/__pycache__/__init__.cpython-39.pyc (+0 -0)
- model/text_encoder/__pycache__/build.cpython-39.pyc (+0 -0)
- model/text_encoder/__pycache__/hf_model.cpython-39.pyc (+0 -0)
- model/text_encoder/__pycache__/registry.cpython-39.pyc (+0 -0)
- model/text_encoder/__pycache__/transformer.cpython-39.pyc (+0 -0)
app.py CHANGED

@@ -118,13 +118,13 @@ def recognize_image(image, texts):
     text_embeddings = model.get_text_embeddings(texts.split(';'))

     # compute output
-    feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
+    feat_img, feat_map, H, W = model.encode_image(img_t.unsqueeze(0), output_map=True)
     output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
     prediction = output.softmax(-1).flatten()

     # generate feat map given the top matched texts
     output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
-    output_map = output_map.view(1, 1,
+    output_map = output_map.view(1, 1, H, W)

     output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
     output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()

@@ -142,10 +142,10 @@ gr.Interface(
     fn=recognize_image,
     inputs=["image", "text"],
     outputs=[
-        label,
         gr.outputs.Image(
             type="pil",
             label="zero-shot heat map"),
+        label
     ],
     examples=[
         ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
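For context: with arbitrary input sizes the final Swin feature grid is no longer a fixed shape, so the flattened map has to be reshaped with the H and W returned at runtime before upsampling. Below is a minimal, self-contained sketch of that reshape-and-upsample step; the sizes and tensors are illustrative stand-ins, not values taken from app.py.

import torch
import torch.nn as nn

# Illustrative: a 480x640 input with an overall stride of 32 gives a 15x20 feature grid,
# rather than the 7x7 grid a fixed 224x224 input would produce.
H, W = 15, 20
feat_map = torch.rand(1, 512, H * W)        # B, C, L -- flattened feature map (stand-in)
text_emb = torch.rand(1, 512)               # embedding of the top-matched text (stand-in)

heat = (feat_map * text_emb.unsqueeze(-1)).sum(1).softmax(-1)   # B, L
heat = heat.view(1, 1, H, W)                                    # reshape with the runtime grid size
heat = nn.Upsample(size=(480, 640), mode='bilinear')(heat)      # back to image resolution
print(heat.shape)                                               # torch.Size([1, 1, 480, 640])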
model/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/__init__.cpython-39.pyc and b/model/__pycache__/__init__.cpython-39.pyc differ

model/__pycache__/model.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/model.cpython-39.pyc and b/model/__pycache__/model.cpython-39.pyc differ

model/__pycache__/templates.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/templates.cpython-39.pyc and b/model/__pycache__/templates.cpython-39.pyc differ

model/image_encoder/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/__init__.cpython-39.pyc and b/model/image_encoder/__pycache__/__init__.cpython-39.pyc differ

model/image_encoder/__pycache__/build.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/build.cpython-39.pyc and b/model/image_encoder/__pycache__/build.cpython-39.pyc differ

model/image_encoder/__pycache__/focalnet.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/focalnet.cpython-39.pyc and b/model/image_encoder/__pycache__/focalnet.cpython-39.pyc differ

model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc and b/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc differ
model/image_encoder/swin_transformer.py CHANGED

@@ -4,9 +4,10 @@
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Ze Liu
 # --------------------------------------------------------
-
+import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_

@@ -230,38 +231,51 @@ class SwinTransformerBlock(nn.Module):

         self.register_buffer("attn_mask", attn_mask)

-    def forward(self, x):
-        H, W = self.input_resolution
+    def forward(self, x, Ph, Pw, attn_mask):
+        # H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        # assert L == H * W, "input feature has wrong size"

         shortcut = x
         x = self.norm1(x)
-        x = x.view(B, H, W, C)
+        x = x.view(B, Ph, Pw, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - Pw % self.window_size) % self.window_size
+        pad_b = (self.window_size - Ph % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape

         # cyclic shift
         if self.shift_size > 0:
             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = attn_mask
         else:
             shifted_x = x
+            attn_mask = None

         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

         # reverse cyclic shift
         if self.shift_size > 0:
             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
         else:
             x = shifted_x
-        x = x.view(B, H * W, C)
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :Ph, :Pw, :].contiguous()
+
+        x = x.view(B, Ph * Pw, C)

         # FFN
         x = shortcut + self.drop_path(x)
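The key change above is that the block no longer assumes the patch grid divides evenly by the window size: it pads the grid up to the next multiple, attends, then crops the padding off. A standalone sketch of that arithmetic, with illustrative sizes (window_size=7, a 15x20 grid) rather than values taken from the model:

import torch
import torch.nn.functional as F

window_size = 7
B, C = 1, 96
Ph, Pw = 15, 20                                           # arbitrary patch grid (illustrative)

x = torch.rand(B, Ph, Pw, C)
pad_r = (window_size - Pw % window_size) % window_size    # 1 -> width padded 20 -> 21
pad_b = (window_size - Ph % window_size) % window_size    # 6 -> height padded 15 -> 21
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))                  # F.pad pads trailing dims first: C, then W, then H
_, Hp, Wp, _ = x.shape
print(Hp, Wp)                                             # 21 21 -- both divisible by window_size

# ... window partition / attention / window reverse would run on the padded Hp x Wp grid ...

x = x[:, :Ph, :Pw, :].contiguous()                        # crop the padding off again
print(x.shape)                                            # torch.Size([1, 15, 20, 96])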
@@ -304,16 +318,20 @@ class PatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
         self.norm = norm_layer(4 * dim)

-    def forward(self, x):
+    def forward(self, x, Ph, Pw):
         """
         x: B, H*W, C
         """
-        H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+        # assert L == H * W, "input feature has wrong size"
+        # assert Ph % 2 == 0 and Pw % 2 == 0, f"x size ({Ph}*{Pw}) are not even."

-        x = x.view(B, H, W, C)
+        x = x.view(B, Ph, Pw, C)
+
+        # padding
+        pad_input = (Ph % 2 == 1) or (Pw % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, Pw % 2, 0, Ph % 2))

         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
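Similarly, PatchMerging above replaces the even-size assertion with a one-row/one-column pad so the 2x2 merge also works on odd grids. A tiny sketch with illustrative sizes:

import torch
import torch.nn.functional as F

B, C = 1, 96
Ph, Pw = 15, 20                                  # odd height, even width (illustrative)
x = torch.rand(B, Ph, Pw, C)

if (Ph % 2 == 1) or (Pw % 2 == 1):
    x = F.pad(x, (0, 0, 0, Pw % 2, 0, Ph % 2))   # pad width by Pw % 2, height by Ph % 2

x0 = x[:, 0::2, 0::2, :]                         # one of the four 2x2 sub-grids
print(x.shape, x0.shape)                         # torch.Size([1, 16, 20, 96]) torch.Size([1, 8, 10, 96])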
@@ -366,6 +384,8 @@ class BasicLayer(nn.Module):
         self.input_resolution = input_resolution
         self.depth = depth
         self.use_checkpoint = use_checkpoint
+        self.window_size = window_size
+        self.shift_size = window_size // 2

         # build blocks
         self.blocks = nn.ModuleList([
@@ -385,15 +405,39 @@
         else:
             self.downsample = None

-    def forward(self, x):
+    def forward(self, x, Ph, Pw):
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(Ph / self.window_size)) * self.window_size
+        Wp = int(np.ceil(Pw / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
         for blk in self.blocks:
             if self.use_checkpoint:
                 x = checkpoint.checkpoint(blk, x)
             else:
-                x = blk(x)
+                x = blk(x, Ph, Pw, attn_mask)
         if self.downsample is not None:
-            x = self.downsample(x)
-        return x
+            x = self.downsample(x, Ph, Pw)
+            Ph, Pw = (Ph + 1) // 2, (Pw + 1) // 2
+        return x, Ph, Pw

     def extra_repr(self) -> str:
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
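Because the padded grid size now depends on the input, the shifted-window attention mask is rebuilt inside BasicLayer.forward instead of being precomputed for a fixed resolution. The sketch below reproduces that mask construction in isolation; window_partition is inlined here and the sizes are illustrative:

import numpy as np
import torch

window_size, shift_size = 7, 3
Ph, Pw = 15, 20                                    # illustrative patch grid

Hp = int(np.ceil(Ph / window_size)) * window_size  # 21
Wp = int(np.ceil(Pw / window_size)) * window_size  # 21

# label each region of the shifted grid with an integer id
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# inlined window_partition: (1, Hp, Wp, 1) -> (nW, window_size*window_size)
mask = img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
mask_windows = mask.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

# tokens coming from different regions must not attend to each other
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)                             # torch.Size([9, 49, 49]) -- nW x tokens x tokens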
@@ -440,12 +484,14 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Ph, Pw = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
-        return x
+        return x, Ph, Pw

     def flops(self):
         Ho, Wo = self.patches_resolution
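With the size assertion commented out, the patch grid is simply read off the projected tensor. A minimal sketch; the conv below is an illustrative stand-in for self.proj (a Swin-style 4x4, stride-4 patch embedding), not the model's actual layer:

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)   # stand-in for PatchEmbed.proj

x = torch.rand(1, 3, 480, 644)                     # arbitrary input size, no 224x224 assert
x = proj(x)
Ph, Pw = x.shape[2:]                               # runtime patch grid: 120, 161
x = x.flatten(2).transpose(1, 2)                   # B, Ph*Pw, C
print(Ph, Pw, x.shape)                             # 120 161 torch.Size([1, 19320, 96])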
@@ -558,20 +604,20 @@ class SwinTransformer(nn.Module):
         return {'relative_position_bias_table'}

     def forward_features(self, x, output_map=False):
-        x = self.patch_embed(x)
+        x, Ph, Pw = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
         x = self.pos_drop(x)

         for layer in self.layers:
-            x = layer(x)
+            x, Ph, Pw = layer(x, Ph, Pw)

         x_map = self.norm(x).transpose(1, 2)  # B C L
         x = self.avgpool(x_map)  # B C 1
         x = torch.flatten(x, 1)

         if output_map:
-            return x, x_map
+            return x, x_map, Ph, Pw
         else:
             return x
model/model.py CHANGED

@@ -156,7 +156,7 @@ class UniCLModel(nn.Module):
     def encode_image(self, image, norm=True, output_map=False):
         x = self.image_encoder.forward_features(image, output_map=output_map)
         if output_map:
-            x, x_map = x
+            x, x_map, H, W = x

         x = x @ self.image_projection

@@ -169,7 +169,7 @@ class UniCLModel(nn.Module):
         x_map = x_map / x_map.norm(dim=1, keepdim=True)

         if output_map:
-            return x, x_map
+            return x, x_map, H, W
         else:
             return x
model/text_encoder/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/__init__.cpython-39.pyc and b/model/text_encoder/__pycache__/__init__.cpython-39.pyc differ

model/text_encoder/__pycache__/build.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/build.cpython-39.pyc and b/model/text_encoder/__pycache__/build.cpython-39.pyc differ

model/text_encoder/__pycache__/hf_model.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/hf_model.cpython-39.pyc and b/model/text_encoder/__pycache__/hf_model.cpython-39.pyc differ

model/text_encoder/__pycache__/registry.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/registry.cpython-39.pyc and b/model/text_encoder/__pycache__/registry.cpython-39.pyc differ

model/text_encoder/__pycache__/transformer.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/transformer.cpython-39.pyc and b/model/text_encoder/__pycache__/transformer.cpython-39.pyc differ