Spaces:
Sleeping
Sleeping
fix!: new model
Browse files
models.py
CHANGED
|
@@ -135,6 +135,76 @@ def ResNet101(num_classes=1000):
|
|
| 135 |
def ResNet152(num_classes=1000):
    """Build a ResNet-152: Bottleneck blocks in a [3, 8, 36, 3] stage layout."""
    stage_blocks = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_blocks, num_classes)
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
class ClassifierHead(nn.Module):
|
| 140 |
def __init__(self, in_features, num_classes):
|
|
@@ -165,6 +235,9 @@ class ResNetUNet(ResNet):
|
|
| 165 |
def __init__(self, block, num_blocks, num_classes=1000):
|
| 166 |
super().__init__(block, num_blocks, num_classes)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
# Calculate encoder channel sizes
|
| 169 |
self.enc_channels = [
|
| 170 |
64,
|
|
@@ -174,23 +247,14 @@ class ResNetUNet(ResNet):
|
|
| 174 |
512 * block.expansion,
|
| 175 |
]
|
| 176 |
|
| 177 |
-
# Replace t_max_avg_pooling with standard avgpool
|
| 178 |
in_features = 512 * block.expansion
|
| 179 |
self.classifier_head = ClassifierHead(in_features, num_classes)
|
| 180 |
|
| 181 |
-
|
| 182 |
-
self.decoder5 = nn.Sequential(
|
| 183 |
-
nn.Conv2d(2048 + 1024, 1024, 3, padding=1),
|
| 184 |
-
nn.BatchNorm2d(1024),
|
| 185 |
-
nn.ReLU(inplace=True),
|
| 186 |
-
nn.Conv2d(1024, 512, 3, padding=1),
|
| 187 |
-
nn.BatchNorm2d(512),
|
| 188 |
-
nn.ReLU(inplace=True),
|
| 189 |
-
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 190 |
-
)
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
nn.BatchNorm2d(512),
|
| 195 |
nn.ReLU(inplace=True),
|
| 196 |
nn.Conv2d(512, 256, 3, padding=1),
|
|
@@ -199,8 +263,8 @@ class ResNetUNet(ResNet):
|
|
| 199 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 200 |
)
|
| 201 |
|
| 202 |
-
self.
|
| 203 |
-
nn.Conv2d(256 +
|
| 204 |
nn.BatchNorm2d(256),
|
| 205 |
nn.ReLU(inplace=True),
|
| 206 |
nn.Conv2d(256, 128, 3, padding=1),
|
|
@@ -209,8 +273,8 @@ class ResNetUNet(ResNet):
|
|
| 209 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 210 |
)
|
| 211 |
|
| 212 |
-
self.
|
| 213 |
-
nn.Conv2d(128 + 64, 128, 3, padding=1),
|
| 214 |
nn.BatchNorm2d(128),
|
| 215 |
nn.ReLU(inplace=True),
|
| 216 |
nn.Conv2d(128, 64, 3, padding=1),
|
|
@@ -219,8 +283,18 @@ class ResNetUNet(ResNet):
|
|
| 219 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 220 |
)
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
self.final_conv = nn.Sequential(
|
| 223 |
-
nn.Conv2d(64, 32, 3, padding=1),
|
| 224 |
nn.BatchNorm2d(32),
|
| 225 |
nn.ReLU(inplace=True),
|
| 226 |
nn.Conv2d(32, 1, 1),
|
|
@@ -265,6 +339,8 @@ class ResNetUNet(ResNet):
|
|
| 265 |
seg_out, size=input_size, mode="bilinear", align_corners=True
|
| 266 |
)
|
| 267 |
|
|
|
|
|
|
|
| 268 |
# Use segmentation to mask features before classification
|
| 269 |
# Upsample segmentation mask to match feature size
|
| 270 |
attention_mask = F.interpolate(
|
|
@@ -272,15 +348,13 @@ class ResNetUNet(ResNet):
|
|
| 272 |
)
|
| 273 |
|
| 274 |
# Apply attention mask to features
|
| 275 |
-
attended_features =
|
| 276 |
|
| 277 |
-
# Use new classifier head
|
| 278 |
cls_out = self.classifier_head(attended_features)
|
| 279 |
|
| 280 |
return cls_out, seg_out
|
| 281 |
|
| 282 |
|
| 283 |
-
# Helper functions without K and T parameters
|
| 284 |
def ResNet18UNet(num_classes=1000):
    """Build a ResNet-18 encoder U-Net: BasicBlock in a [2, 2, 2, 2] layout."""
    stage_blocks = [2, 2, 2, 2]
    return ResNetUNet(BasicBlock, stage_blocks, num_classes)
|
| 286 |
|
|
@@ -298,4 +372,4 @@ def ResNet101UNet(num_classes=1000):
|
|
| 298 |
|
| 299 |
|
| 300 |
def ResNet152UNet(num_classes=1000):
    """Build a ResNet-152 encoder U-Net: Bottleneck in a [3, 8, 36, 3] layout."""
    stage_blocks = [3, 8, 36, 3]
    return ResNetUNet(Bottleneck, stage_blocks, num_classes)
|
|
|
|
| 135 |
def ResNet152(num_classes=1000):
|
| 136 |
return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
|
| 137 |
|
| 138 |
+
import torch
|
| 139 |
+
import torch.nn as nn
|
| 140 |
+
import torch.nn.functional as F
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class SAM(nn.Module):
    """Spatial Attention Module (from CBAM).

    Gates each spatial location of the input by a sigmoid score computed
    from the channel-wise max and mean maps.

    Args:
        bias: whether the 7x7 gating convolution uses a bias term.
    """

    def __init__(self, bias=False):
        super().__init__()
        self.bias = bias
        # 7x7 conv maps the 2-channel [max, avg] descriptor to a 1-channel gate;
        # padding=3 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=7,
            stride=1,
            padding=3,
            dilation=1,
            bias=self.bias,
        )

    def forward(self, x):
        """Return x scaled per-pixel by the spatial attention gate.

        x: 4-D tensor (B, C, H, W); output has the same shape.
        """
        # keepdim=True yields (B, 1, H, W) directly; names avoid shadowing
        # the builtins max/avg-style identifiers.
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        concat = torch.cat((max_pool, avg_pool), dim=1)
        # torch.sigmoid: F.sigmoid is deprecated.
        gate = torch.sigmoid(self.conv(concat))
        return gate * x
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class CAM(nn.Module):
    """Channel Attention Module (from CBAM).

    Gates each channel of the input by a sigmoid score computed from a
    shared bottleneck MLP applied to global max- and avg-pooled features.

    Args:
        channels: number of input channels.
        r: reduction ratio for the bottleneck MLP (hidden size = channels // r).
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        # Shared two-layer MLP with a channels//r bottleneck.
        self.linear = nn.Sequential(
            nn.Linear(
                in_features=self.channels,
                out_features=self.channels // self.r,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                in_features=self.channels // self.r,
                out_features=self.channels,
                bias=True,
            ),
        )

    def forward(self, x):
        """Return x scaled per-channel by the channel attention gate.

        x: 4-D tensor (B, C, H, W); output has the same shape.
        """
        b, c, _, _ = x.size()
        # Names avoid shadowing the builtin max.
        max_pool = F.adaptive_max_pool2d(x, output_size=1)
        avg_pool = F.adaptive_avg_pool2d(x, output_size=1)
        # Same MLP on both pooled descriptors, summed before the gate.
        attn = self.linear(max_pool.view(b, c)) + self.linear(avg_pool.view(b, c))
        # torch.sigmoid: F.sigmoid is deprecated.
        gate = torch.sigmoid(attn).view(b, c, 1, 1)
        return gate * x
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Applies channel attention (CAM) followed by spatial attention (SAM)
    and adds the result back to the input as a residual.

    Args:
        channels: number of input channels, forwarded to CAM.
        r: reduction ratio, forwarded to CAM.
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        self.sam = SAM(bias=False)
        self.cam = CAM(channels=self.channels, r=self.r)

    def forward(self, x):
        """Return x refined by channel-then-spatial attention, plus x (residual)."""
        refined = self.sam(self.cam(x))
        return refined + x
|
| 208 |
|
| 209 |
class ClassifierHead(nn.Module):
|
| 210 |
def __init__(self, in_features, num_classes):
|
|
|
|
| 235 |
def __init__(self, block, num_blocks, num_classes=1000):
|
| 236 |
super().__init__(block, num_blocks, num_classes)
|
| 237 |
|
| 238 |
+
# Get the expansion factor
|
| 239 |
+
expansion = block.expansion
|
| 240 |
+
|
| 241 |
# Calculate encoder channel sizes
|
| 242 |
self.enc_channels = [
|
| 243 |
64,
|
|
|
|
| 247 |
512 * block.expansion,
|
| 248 |
]
|
| 249 |
|
|
|
|
| 250 |
in_features = 512 * block.expansion
|
| 251 |
self.classifier_head = ClassifierHead(in_features, num_classes)
|
| 252 |
|
| 253 |
+
self.cbam = CBAM(channels=512 * block.expansion, r=16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
# Calculate encoder channel sizes
|
| 256 |
+
self.decoder5 = nn.Sequential(
|
| 257 |
+
nn.Conv2d((512 * expansion) + (256 * expansion), 512, 3, padding=1),
|
| 258 |
nn.BatchNorm2d(512),
|
| 259 |
nn.ReLU(inplace=True),
|
| 260 |
nn.Conv2d(512, 256, 3, padding=1),
|
|
|
|
| 263 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 264 |
)
|
| 265 |
|
| 266 |
+
self.decoder4 = nn.Sequential(
|
| 267 |
+
nn.Conv2d(256 + (128 * expansion), 256, 3, padding=1),
|
| 268 |
nn.BatchNorm2d(256),
|
| 269 |
nn.ReLU(inplace=True),
|
| 270 |
nn.Conv2d(256, 128, 3, padding=1),
|
|
|
|
| 273 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 274 |
)
|
| 275 |
|
| 276 |
+
self.decoder3 = nn.Sequential(
|
| 277 |
+
nn.Conv2d(128 + (64 * expansion), 128, 3, padding=1),
|
| 278 |
nn.BatchNorm2d(128),
|
| 279 |
nn.ReLU(inplace=True),
|
| 280 |
nn.Conv2d(128, 64, 3, padding=1),
|
|
|
|
| 283 |
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 284 |
)
|
| 285 |
|
| 286 |
+
self.decoder2 = nn.Sequential(
|
| 287 |
+
nn.Conv2d(64 + 64, 64, 3, padding=1),
|
| 288 |
+
nn.BatchNorm2d(64),
|
| 289 |
+
nn.ReLU(inplace=True),
|
| 290 |
+
nn.Conv2d(64, 64, 3, padding=1),
|
| 291 |
+
nn.BatchNorm2d(64),
|
| 292 |
+
nn.ReLU(inplace=True),
|
| 293 |
+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
self.final_conv = nn.Sequential(
|
| 297 |
+
nn.Conv2d(64, 32, 3, padding=1),
|
| 298 |
nn.BatchNorm2d(32),
|
| 299 |
nn.ReLU(inplace=True),
|
| 300 |
nn.Conv2d(32, 1, 1),
|
|
|
|
| 339 |
seg_out, size=input_size, mode="bilinear", align_corners=True
|
| 340 |
)
|
| 341 |
|
| 342 |
+
attended_features = self.cbam(e5)
|
| 343 |
+
|
| 344 |
# Use segmentation to mask features before classification
|
| 345 |
# Upsample segmentation mask to match feature size
|
| 346 |
attention_mask = F.interpolate(
|
|
|
|
| 348 |
)
|
| 349 |
|
| 350 |
# Apply attention mask to features
|
| 351 |
+
attended_features = attended_features * (0.25 + attention_mask)
|
| 352 |
|
|
|
|
| 353 |
cls_out = self.classifier_head(attended_features)
|
| 354 |
|
| 355 |
return cls_out, seg_out
|
| 356 |
|
| 357 |
|
|
|
|
| 358 |
def ResNet18UNet(num_classes=1000):
|
| 359 |
return ResNetUNet(BasicBlock, [2, 2, 2, 2], num_classes)
|
| 360 |
|
|
|
|
| 372 |
|
| 373 |
|
| 374 |
def ResNet152UNet(num_classes=1000):
|
| 375 |
+
return ResNetUNet(Bottleneck, [3, 8, 36, 3], num_classes)
|