spuun committed on
Commit
f5faf92
·
verified ·
1 Parent(s): c9e9eb6

fix!: new model

Browse files
Files changed (1) hide show
  1. models.py +96 -22
models.py CHANGED
@@ -135,6 +135,76 @@ def ResNet101(num_classes=1000):
135
  def ResNet152(num_classes=1000):
136
  return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  class ClassifierHead(nn.Module):
140
  def __init__(self, in_features, num_classes):
@@ -165,6 +235,9 @@ class ResNetUNet(ResNet):
165
  def __init__(self, block, num_blocks, num_classes=1000):
166
  super().__init__(block, num_blocks, num_classes)
167
 
 
 
 
168
  # Calculate encoder channel sizes
169
  self.enc_channels = [
170
  64,
@@ -174,23 +247,14 @@ class ResNetUNet(ResNet):
174
  512 * block.expansion,
175
  ]
176
 
177
- # Replace t_max_avg_pooling with standard avgpool
178
  in_features = 512 * block.expansion
179
  self.classifier_head = ClassifierHead(in_features, num_classes)
180
 
181
- # Decoder layers remain the same
182
- self.decoder5 = nn.Sequential(
183
- nn.Conv2d(2048 + 1024, 1024, 3, padding=1),
184
- nn.BatchNorm2d(1024),
185
- nn.ReLU(inplace=True),
186
- nn.Conv2d(1024, 512, 3, padding=1),
187
- nn.BatchNorm2d(512),
188
- nn.ReLU(inplace=True),
189
- nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
190
- )
191
 
192
- self.decoder4 = nn.Sequential(
193
- nn.Conv2d(512 + 512, 512, 3, padding=1),
 
194
  nn.BatchNorm2d(512),
195
  nn.ReLU(inplace=True),
196
  nn.Conv2d(512, 256, 3, padding=1),
@@ -199,8 +263,8 @@ class ResNetUNet(ResNet):
199
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
200
  )
201
 
202
- self.decoder3 = nn.Sequential(
203
- nn.Conv2d(256 + 256, 256, 3, padding=1),
204
  nn.BatchNorm2d(256),
205
  nn.ReLU(inplace=True),
206
  nn.Conv2d(256, 128, 3, padding=1),
@@ -209,8 +273,8 @@ class ResNetUNet(ResNet):
209
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
210
  )
211
 
212
- self.decoder2 = nn.Sequential(
213
- nn.Conv2d(128 + 64, 128, 3, padding=1),
214
  nn.BatchNorm2d(128),
215
  nn.ReLU(inplace=True),
216
  nn.Conv2d(128, 64, 3, padding=1),
@@ -219,8 +283,18 @@ class ResNetUNet(ResNet):
219
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
220
  )
221
 
 
 
 
 
 
 
 
 
 
 
222
  self.final_conv = nn.Sequential(
223
- nn.Conv2d(64, 32, 3, padding=1),
224
  nn.BatchNorm2d(32),
225
  nn.ReLU(inplace=True),
226
  nn.Conv2d(32, 1, 1),
@@ -265,6 +339,8 @@ class ResNetUNet(ResNet):
265
  seg_out, size=input_size, mode="bilinear", align_corners=True
266
  )
267
 
 
 
268
  # Use segmentation to mask features before classification
269
  # Upsample segmentation mask to match feature size
270
  attention_mask = F.interpolate(
@@ -272,15 +348,13 @@ class ResNetUNet(ResNet):
272
  )
273
 
274
  # Apply attention mask to features
275
- attended_features = e5 * (0.25 + attention_mask)
276
 
277
- # Use new classifier head
278
  cls_out = self.classifier_head(attended_features)
279
 
280
  return cls_out, seg_out
281
 
282
 
283
- # Helper functions without K and T parameters
284
  def ResNet18UNet(num_classes=1000):
285
  return ResNetUNet(BasicBlock, [2, 2, 2, 2], num_classes)
286
 
@@ -298,4 +372,4 @@ def ResNet101UNet(num_classes=1000):
298
 
299
 
300
  def ResNet152UNet(num_classes=1000):
301
- return ResNetUNet(Bottleneck, [3, 8, 36, 3], num_classes)
 
135
  def ResNet152(num_classes=1000):
136
  return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
137
 
138
+ import torch
139
+ import torch.nn as nn
140
+ import torch.nn.functional as F
141
+
142
+
143
+ class SAM(nn.Module):
144
+ def __init__(self, bias=False):
145
+ super(SAM, self).__init__()
146
+ self.bias = bias
147
+ self.conv = nn.Conv2d(
148
+ in_channels=2,
149
+ out_channels=1,
150
+ kernel_size=7,
151
+ stride=1,
152
+ padding=3,
153
+ dilation=1,
154
+ bias=self.bias,
155
+ )
156
+
157
+ def forward(self, x):
158
+ max = torch.max(x, 1)[0].unsqueeze(1)
159
+ avg = torch.mean(x, 1).unsqueeze(1)
160
+ concat = torch.cat((max, avg), dim=1)
161
+ output = self.conv(concat)
162
+ output = F.sigmoid(output) * x
163
+ return output
164
+
165
+
166
+ class CAM(nn.Module):
167
+ def __init__(self, channels, r):
168
+ super(CAM, self).__init__()
169
+ self.channels = channels
170
+ self.r = r
171
+ self.linear = nn.Sequential(
172
+ nn.Linear(
173
+ in_features=self.channels,
174
+ out_features=self.channels // self.r,
175
+ bias=True,
176
+ ),
177
+ nn.ReLU(inplace=True),
178
+ nn.Linear(
179
+ in_features=self.channels // self.r,
180
+ out_features=self.channels,
181
+ bias=True,
182
+ ),
183
+ )
184
+
185
+ def forward(self, x):
186
+ max = F.adaptive_max_pool2d(x, output_size=1)
187
+ avg = F.adaptive_avg_pool2d(x, output_size=1)
188
+ b, c, _, _ = x.size()
189
+ linear_max = self.linear(max.view(b, c)).view(b, c, 1, 1)
190
+ linear_avg = self.linear(avg.view(b, c)).view(b, c, 1, 1)
191
+ output = linear_max + linear_avg
192
+ output = F.sigmoid(output) * x
193
+ return output
194
+
195
+
196
+ class CBAM(nn.Module):
197
+ def __init__(self, channels, r):
198
+ super(CBAM, self).__init__()
199
+ self.channels = channels
200
+ self.r = r
201
+ self.sam = SAM(bias=False)
202
+ self.cam = CAM(channels=self.channels, r=self.r)
203
+
204
+ def forward(self, x):
205
+ output = self.cam(x)
206
+ output = self.sam(output)
207
+ return output + x
208
 
209
  class ClassifierHead(nn.Module):
210
  def __init__(self, in_features, num_classes):
 
235
  def __init__(self, block, num_blocks, num_classes=1000):
236
  super().__init__(block, num_blocks, num_classes)
237
 
238
+ # Get the expansion factor
239
+ expansion = block.expansion
240
+
241
  # Calculate encoder channel sizes
242
  self.enc_channels = [
243
  64,
 
247
  512 * block.expansion,
248
  ]
249
 
 
250
  in_features = 512 * block.expansion
251
  self.classifier_head = ClassifierHead(in_features, num_classes)
252
 
253
+ self.cbam = CBAM(channels=512 * block.expansion, r=16)
 
 
 
 
 
 
 
 
 
254
 
255
+ # Calculate encoder channel sizes
256
+ self.decoder5 = nn.Sequential(
257
+ nn.Conv2d((512 * expansion) + (256 * expansion), 512, 3, padding=1),
258
  nn.BatchNorm2d(512),
259
  nn.ReLU(inplace=True),
260
  nn.Conv2d(512, 256, 3, padding=1),
 
263
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
264
  )
265
 
266
+ self.decoder4 = nn.Sequential(
267
+ nn.Conv2d(256 + (128 * expansion), 256, 3, padding=1),
268
  nn.BatchNorm2d(256),
269
  nn.ReLU(inplace=True),
270
  nn.Conv2d(256, 128, 3, padding=1),
 
273
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
274
  )
275
 
276
+ self.decoder3 = nn.Sequential(
277
+ nn.Conv2d(128 + (64 * expansion), 128, 3, padding=1),
278
  nn.BatchNorm2d(128),
279
  nn.ReLU(inplace=True),
280
  nn.Conv2d(128, 64, 3, padding=1),
 
283
  nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
284
  )
285
 
286
+ self.decoder2 = nn.Sequential(
287
+ nn.Conv2d(64 + 64, 64, 3, padding=1),
288
+ nn.BatchNorm2d(64),
289
+ nn.ReLU(inplace=True),
290
+ nn.Conv2d(64, 64, 3, padding=1),
291
+ nn.BatchNorm2d(64),
292
+ nn.ReLU(inplace=True),
293
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
294
+ )
295
+
296
  self.final_conv = nn.Sequential(
297
+ nn.Conv2d(64, 32, 3, padding=1),
298
  nn.BatchNorm2d(32),
299
  nn.ReLU(inplace=True),
300
  nn.Conv2d(32, 1, 1),
 
339
  seg_out, size=input_size, mode="bilinear", align_corners=True
340
  )
341
 
342
+ attended_features = self.cbam(e5)
343
+
344
  # Use segmentation to mask features before classification
345
  # Upsample segmentation mask to match feature size
346
  attention_mask = F.interpolate(
 
348
  )
349
 
350
  # Apply attention mask to features
351
+ attended_features = attended_features * (0.25 + attention_mask)
352
 
 
353
  cls_out = self.classifier_head(attended_features)
354
 
355
  return cls_out, seg_out
356
 
357
 
 
358
  def ResNet18UNet(num_classes=1000):
359
  return ResNetUNet(BasicBlock, [2, 2, 2, 2], num_classes)
360
 
 
372
 
373
 
374
  def ResNet152UNet(num_classes=1000):
375
+ return ResNetUNet(Bottleneck, [3, 8, 36, 3], num_classes)