Update app.py
app.py CHANGED

@@ -42,10 +42,13 @@ CKPT_META = {
         "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
     "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
         "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
+    # --- MODIFIED: Added specific keys for the new loading logic ---
     "EmbeddingClassifier": {
         "n_cls": 2,
         "head": "embedding_classifier",
-        "
+        "timm_model_name": "swinv2_base_window8_256.ms_in1k",
+        "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
+        "backbone_filename": "model.safetensors",
         "names": ["Non-AI Generated", "AI Generated"]
     }
 }
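Note: the new keys (timm_model_name, backbone_repo_id, backbone_filename) are consumed only by the embedding_classifier branch of load_model further down. A minimal sketch of a startup check that would catch a missing key early; the helper name and REQUIRED_EMBEDDING_KEYS constant are illustrative additions, not part of app.py:

# Hypothetical sanity check; not part of app.py.
REQUIRED_EMBEDDING_KEYS = ("timm_model_name", "backbone_repo_id", "backbone_filename")

def validate_ckpt_meta(ckpt_meta: dict) -> None:
    for name, meta in ckpt_meta.items():
        if meta.get("head") == "embedding_classifier":
            missing = [k for k in REQUIRED_EMBEDDING_KEYS if k not in meta]
            if missing:
                raise KeyError(f"{name}: missing meta keys {missing}")
        # Every entry should pair n_cls with a matching names list for gr.Label.
        if len(meta["names"]) != meta["n_cls"]:
            raise ValueError(f"{name}: names/n_cls mismatch")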
@@ -60,7 +63,6 @@ print(f"Using device: {device}")
 model, current_ckpt = None, None
 current_meta = None
 
-# --- EmbeddingClassifier Model ---
 class EmbeddingClassifier(nn.Module):
     def __init__(self, input_dim=1024, hidden_dim1=512, hidden_dim2=256, output_dim=1):
         super().__init__()
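Note: the body of EmbeddingClassifier.net is outside this hunk, so its exact layers are not visible here. A minimal sketch consistent with the constructor defaults (1024 -> 512 -> 256 -> 1) and with the probability-style output that later hunks assume; the layer choices and the final Sigmoid are assumptions, and the class name is illustrative:

import torch.nn as nn

class EmbeddingClassifierSketch(nn.Module):
    """Illustrative only: mirrors the constructor signature visible in the diff."""
    def __init__(self, input_dim=1024, hidden_dim1=512, hidden_dim2=256, output_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1), nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2), nn.ReLU(),
            nn.Linear(hidden_dim2, output_dim),
            nn.Sigmoid(),  # assumed, since the output is treated as a probability downstream
        )

    def forward(self, x):
        return self.net(x)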
@@ -79,28 +81,32 @@ class EmbeddingClassifier(nn.Module):
     def forward(self, x):
         return self.net(x)
 
+# MODIFIED: Changed __init__ to accept timm_model_name and use pretrained=False
 class EmbeddingClassifierModel(nn.Module):
-    def __init__(self,
+    def __init__(self, timm_model_name, num_classes):
         super().__init__()
-        self.backbone = timm.create_model(
+        self.backbone = timm.create_model(timm_model_name, pretrained=False, num_classes=0)
         self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
-
-        # that indicates if an image is real (Non-AI). So, output_dim is fixed to 1.
-        self.classifier = EmbeddingClassifier(input_dim=1024, hidden_dim1=512, hidden_dim2=256, output_dim=1)
+        self.classifier = EmbeddingClassifier(input_dim=self.backbone.num_features)
 
     def forward(self, x):
         features = self.backbone(x)
-        # The classifier returns a single
-
-        logit_class0 = self.classifier(features)
+        # The classifier returns a single value (probability of being Non-AI)
+        prob_class0 = self.classifier(features)
 
-        # To maintain compatibility with the `predict` function which expects multi-class
-        # and applies softmax, we construct a 2-class
-        #
-
+        # To maintain compatibility with the `predict` function which expects multi-class outputs
+        # and applies softmax, we construct a 2-class output.
+        # prob_class1 is simply 1 - prob_class0
+        prob_class1 = 1 - prob_class0
+
+        # The final output is for ["Non-AI", "AI"], i.e., [prob_class0, prob_class1].
+        # The softmax in predict() will be applied to this, so we should return logits.
+        # However, since the original output is a sigmoid, we can work with probabilities
+        # and just return them directly. The gr.Label will normalize this.
+        # A simpler way is to construct logits that would result in these probabilities.
+        # Let's stick to the original logic's output format.
+        return torch.cat([prob_class0, prob_class1], dim=1)
 
-        # The final logits are for ["Non-AI", "AI"], i.e., [logit_class0, logit_class1].
-        return torch.cat([logit_class0, logit_class1], dim=1)
 
 # --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
 class DINOv2Classifier_WeightedPool(nn.Module):
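Note: since prob_class0 is already a probability, concatenating [p, 1 - p] yields a vector that sums to 1; applying softmax on top of it again would squash the scores (for example, softmax([0.9, 0.1]) is roughly [0.69, 0.31]), which is why the predict() change in the last hunk skips softmax for this head. A small check of that behaviour, with illustrative values standing in for real model outputs:

import torch
import torch.nn.functional as F

p0 = torch.tensor([[0.9]])            # sigmoid output: P(Non-AI), illustrative value
two_class = torch.cat([p0, 1 - p0], dim=1)

print(two_class)                      # tensor([[0.9000, 0.1000]]) -- already sums to 1
print(F.softmax(two_class, dim=1))    # tensor([[0.6900, 0.3100]]) -- softmax would distort it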
@@ -234,22 +240,48 @@ def load_model(ckpt_name: str):
     meta = CKPT_META[ckpt_name]
     ckpt_filename = HF_FILENAMES[ckpt_name]
 
-    ckpt_file = hf_hub_download(
-        repo_id=REPO_ID,
-        filename=ckpt_filename,
-        local_dir=LOCAL_CKPT_DIR, force_download=False
-    )
-    print(f"Checkpoint: {ckpt_file}")
-
-    # Build model structure based on model_type or head
+    # --- MODIFIED: Special handling for EmbeddingClassifier ---
     head_version = meta.get("head", "v4")
     if head_version == "embedding_classifier":
+        # 1. Create the model structure with a non-pretrained backbone
+        print(f"Creating backbone: {meta['timm_model_name']}")
         model = EmbeddingClassifierModel(
-
+            timm_model_name=meta["timm_model_name"],
             num_classes=meta["n_cls"]
         ).to(device)
+
+        # 2. Download and load backbone weights from SmilingWolf's repo
+        print(f"Loading backbone weights from {meta['backbone_repo_id']}...")
+        backbone_ckpt_file = hf_hub_download(
+            repo_id=meta["backbone_repo_id"],
+            filename=meta["backbone_filename"],
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        backbone_state = load_file(backbone_ckpt_file, device=device)
+        model.backbone.load_state_dict(backbone_state)
+        print("✅ Backbone weights loaded.")
+
+        # 3. Download and load classifier (head) weights from the main repo
+        print(f"Loading classifier head weights from {REPO_ID}...")
+        classifier_ckpt_file = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=ckpt_filename,  # This is 'swinv2_v3_v1.pth'
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
+        model.classifier.load_state_dict(classifier_state)
+        print("✅ Classifier head weights loaded.")
+
     else:
-        #
+        # --- Original logic for all other models ---
+        ckpt_file = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=ckpt_filename,
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        print(f"Checkpoint: {ckpt_file}")
+
+        # Build model structure based on model_type or head
         model_type = meta.get("model_type")
         if model_type == "dinov2_weighted_pool":
             model = DINOv2Classifier_WeightedPool(
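Note: loading model.safetensors from SmilingWolf/wd-swinv2-tagger-v3 into a timm backbone created with num_classes=0 assumes the tensor names in that file line up with timm's swinv2 state dict, and that any tagger-head parameters are absent or tolerated. A hedged diagnostic to inspect the overlap before committing to a strict load_state_dict; the helper name is illustrative and not part of app.py:

from safetensors.torch import load_file

def report_key_overlap(model_module, ckpt_path):
    """Diagnostic only: compares checkpoint keys with the module's expected keys."""
    ckpt_keys = set(load_file(ckpt_path).keys())
    model_keys = set(model_module.state_dict().keys())
    print("missing from checkpoint:", sorted(model_keys - ckpt_keys)[:10])
    print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys)[:10])
    # If only head parameters are unexpected, load_state_dict(..., strict=False) may suffice.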
@@ -269,18 +301,12 @@ def load_model(ckpt_name: str):
             head_version=head_version
         ).to(device)
 
-
-
-
-
-
-
-    # Load state dict
-    if head_version == "embedding_classifier":
-        # For EmbeddingClassifierModel, we need to load the state dict for the classifier part
-        # Assuming the checkpoint only contains the classifier state dict
-        model.classifier.load_state_dict(state)
-    else:
+        # Compatible load for .pth and .safetensors
+        if ckpt_filename.endswith(".safetensors"):
+            state = load_file(ckpt_file, device=device)
+        else:
+            state = torch.load(ckpt_file, map_location=device, weights_only=False)
+
         model.load_state_dict(state.get("model_state_dict", state), strict=True)
 
     model.eval()
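Note: weights_only=False keeps compatibility with checkpoints that pickle arbitrary Python objects; when a .pth file is known to contain only tensors, weights_only=True (available in recent PyTorch releases) is the safer call. A minimal sketch of that stricter variant with a fallback; the helper name is illustrative, not part of app.py:

import torch

def load_checkpoint(path, device):
    """Prefer the safer weights_only load; fall back for pickled checkpoints."""
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except Exception:
        # Older checkpoints that pickle non-tensor objects need the permissive path.
        return torch.load(path, map_location=device, weights_only=False)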
@@ -318,7 +344,13 @@ def predict(image: Image.Image,
     tfm = build_transform(False, interpolation)
 
     inp = tfm(image).unsqueeze(0).to(device)
-    probs = F.softmax(model(inp), dim=1)[0].cpu()
+    # MODIFIED: For EmbeddingClassifier, the output is already probabilities, no need for softmax.
+    # For others, softmax is needed.
+    if current_meta["head"] == "embedding_classifier":
+        probs = model(inp)[0].cpu()
+    else:
+        probs = F.softmax(model(inp), dim=1)[0].cpu()
+
     class_names = current_meta["names"]
     # Ensure gr.Label displays correctly for both 2-class and 4-class models
     return {class_names[i]: float(probs[i])
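Note: both branches now hand gr.Label a dict of class names to probabilities that sum to 1: the embedding head because it already returns [p, 1 - p], the other heads because softmax normalises their logits. A quick self-check with stand-in tensors (illustrative values, not real model outputs):

import torch
import torch.nn.functional as F

class_names = ["Non-AI Generated", "AI Generated"]

embedding_out = torch.tensor([[0.9, 0.1]])   # already probabilities
logit_out = torch.tensor([[2.0, -1.0]])      # raw logits -> need softmax

for probs in (embedding_out[0], F.softmax(logit_out, dim=1)[0]):
    label_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
    assert abs(sum(label_dict.values()) - 1.0) < 1e-6
    print(label_dict)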