telecomadm1145 commited on
Commit
77d9851
·
verified ·
1 Parent(s): f8b2050

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -21,7 +21,7 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
21
  REPO_ID = "telecomadm1145/swin-ai-detection"
22
  HF_FILENAME = "swin_classifier_stage1_swin_large.pth"
23
  LOCAL_CKPT_DIR = "./checkpoints"
24
- MODEL_NAME = "swin_large_patch4_window7_224" # ← 使用 large
25
  NUM_CLASSES = 2
26
  SEED = 4421
27
  dropout_rate = 0.1
@@ -42,15 +42,15 @@ class SwinClassifier(nn.Module):
42
  self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
43
 
44
  self.classifier = nn.Sequential(
45
- #nn.Dropout(dropout_rate),
46
  nn.Linear(self.backbone.num_features, 512),
47
  nn.BatchNorm1d(512),
48
  nn.ReLU(),
49
- #nn.Dropout(dropout_rate * 0.7),
50
  nn.Linear(512, 128),
51
  nn.BatchNorm1d(128),
52
  nn.ReLU(),
53
- #nn.Dropout(dropout_rate * 0.5),
54
  nn.Linear(128, num_classes)
55
  )
56
 
 
21
  REPO_ID = "telecomadm1145/swin-ai-detection"
22
  HF_FILENAME = "swin_classifier_stage1_swin_large.pth"
23
  LOCAL_CKPT_DIR = "./checkpoints"
24
+ MODEL_NAME = "swin_large_patch4_window12_384" # ← 使用 large
25
  NUM_CLASSES = 2
26
  SEED = 4421
27
  dropout_rate = 0.1
 
42
  self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
43
 
44
  self.classifier = nn.Sequential(
45
+ nn.Dropout(dropout_rate),
46
  nn.Linear(self.backbone.num_features, 512),
47
  nn.BatchNorm1d(512),
48
  nn.ReLU(),
49
+ nn.Dropout(dropout_rate * 0.7),
50
  nn.Linear(512, 128),
51
  nn.BatchNorm1d(128),
52
  nn.ReLU(),
53
+ nn.Dropout(dropout_rate * 0.5),
54
  nn.Linear(128, num_classes)
55
  )
56