Spaces: Running on Zero
Update app.py
Browse files
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
 import time
 import traceback
 import spaces
-from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
+from torchvision.models import convnext_base, ConvNeXt_Base_Weights
 from torchvision.ops import nms, box_iou
 import torch.nn.functional as F
 from torchvision import transforms
@@ -98,29 +98,61 @@ class MultiHeadAttention(nn.Module):
         return out
 
 class BaseModel(nn.Module):
+
     def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
         super().__init__()
         self.device = device
-        self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
-        self.feature_dim = self.backbone.classifier[1].in_features
-        self.backbone.classifier = nn.Identity()
 
+        # 1. Initialize the backbone
+        self.backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
+        self.backbone.classifier = nn.Identity()  # remove the original classifier
+
+        # 2. Determine the actual feature dimension with a test input
+        with torch.no_grad():  # no gradient computation needed
+            dummy_input = torch.randn(1, 3, 224, 224)  # create a sample input
+            features = self.backbone(dummy_input)
+            if len(features.shape) > 2:  # if the features are multi-dimensional
+                features = features.mean([-2, -1])  # apply global average pooling
+            self.feature_dim = features.shape[1]  # record the correct feature dimension
+
+        print(f"Feature Dim: {self.feature_dim}")  # helps with debugging
+
+        # 3. Set up the multi-head attention layer
         self.num_heads = max(1, min(8, self.feature_dim // 64))
         self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
 
+        # 4. Set up the classifier
         self.classifier = nn.Sequential(
             nn.LayerNorm(self.feature_dim),
             nn.Dropout(0.3),
             nn.Linear(self.feature_dim, num_classes)
         )
 
-        self.to(device)
-
     def forward(self, x):
+        """
+        Forward pass of the model.
+        Args:
+            x (Tensor): input image tensor of shape [batch_size, channels, height, width]
+        Returns:
+            Tuple[Tensor, Tensor]: classification logits and attended features
+        """
         x = x.to(self.device)
+
+        # 1. Extract base features
         features = self.backbone(x)
+
+        # 2. Handle the feature dimensions
+        if len(features.shape) > 2:
+            # if the features have shape [batch_size, channels, height, width],
+            # convert them to [batch_size, channels]
+            features = features.mean([-2, -1])  # global average pooling
+
+        # 3. Apply the attention mechanism
         attended_features = self.attention(features)
+
+        # 4. Final classification
         logits = self.classifier(attended_features)
+
         return logits, attended_features
 
 
@@ -179,7 +211,7 @@ class ModelManager:
         ).to(self.device)
 
         checkpoint = torch.load(
-            '
+            'ConvNextBase_best_model_dog.pth',
             map_location=self.device  # ensure the checkpoint is loaded onto the correct device
         )
         self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
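A note on the dummy-input probe added in __init__: with the whole classifier replaced by nn.Identity, torchvision's ConvNeXt still applies its internal average pool, so the probe comes back as [1, 1024, 1, 1] and the dimension check collapses it to [1, 1024]. A minimal standalone sketch of the same technique (variable names here are illustrative, not from app.py):

import torch
import torch.nn as nn
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
backbone.classifier = nn.Identity()  # strip the pretrained classification head

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))  # sample-input probe
    if feats.dim() > 2:               # ConvNeXt returns [1, 1024, 1, 1] here
        feats = feats.mean([-2, -1])  # global average pooling -> [1, 1024]

feature_dim = feats.shape[1]
print(feature_dim)  # 1024 for ConvNeXt-Base

Probing the output shape instead of reading classifier[1].in_features is what makes the class backbone-agnostic; the old EfficientNet-specific line was the part that broke when the backbone changed.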
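With feature_dim = 1024, the head heuristic resolves to max(1, min(8, 1024 // 64)) = 8, i.e. 128 dimensions per head, and 1024 divides evenly by 8, which any multi-head split requires. A quick check of the arithmetic, using PyTorch's built-in nn.MultiheadAttention as a stand-in for the custom MultiHeadAttention class (an assumption; the real layer is defined in app.py):

import torch
import torch.nn as nn

feature_dim = 1024
num_heads = max(1, min(8, feature_dim // 64))  # min(8, 16) -> 8
assert feature_dim % num_heads == 0            # 128 dims per head

# nn.MultiheadAttention stands in for the custom MultiHeadAttention
attn = nn.MultiheadAttention(feature_dim, num_heads=num_heads, batch_first=True)
x = torch.randn(2, 1, feature_dim)             # [batch, seq_len=1, dim]
out, _ = attn(x, x, x)                         # self-attention
print(out.shape)                               # torch.Size([2, 1, 1024])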
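On the loading side, map_location keeps the checkpoint on whatever device the manager selected, and strict=False lets the state dict apply even though backbone keys changed from EfficientNet to ConvNeXt. A hedged sketch of inspecting what actually loaded (the import path and num_classes=120 are assumed placeholders, not taken from app.py):

import torch
from app import BaseModel  # app.py in this Space (assumed import path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BaseModel(num_classes=120, device=device)  # 120 is a placeholder count

checkpoint = torch.load('ConvNextBase_best_model_dog.pth', map_location=device)
missing, unexpected = model.load_state_dict(checkpoint['base_model'], strict=False)
print(missing)     # keys the checkpoint did not provide
print(unexpected)  # checkpoint keys the model does not use

Worth noting that strict=False silently skips mismatched keys, so printing missing_keys/unexpected_keys is a cheap sanity check that the ConvNeXt checkpoint really matches the rebuilt model.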