# AIDetectV2 / app.py
# telecomadm1145's picture
# Update app.py
# d717f85 verified
# -*- coding: utf-8 -*-
"""
-------------------------------------------------------------------
• V2.5-CAFormer : 4-class (photo / anime × AI / Non-AI)
• V-CAFormer-B36 (新) : 2-class (AI vs. Non-AI, BCE Head)
-------------------------------------------------------------------
说明:
本脚本已根据请求精简,仅支持上述两个 CAFormer 模型。
Swin、V3-Emb 和 MoE 模型的代码已被移除。
"""
import os, torch, timm, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torchvision import transforms
# Default Hugging Face repository; it hosts the V2.5-CAFormer (4-class) weights.
REPO_ID = "telecomadm1145/swin-ai-detection"

# Checkpoint file name for every selectable model. The key order matters:
# the UI dropdown is populated from these keys.
HF_FILENAMES = {
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    # The 2-class V-CAFormer-B36 weights live in their own repo (see CKPT_META).
    "V-CAFormer-B36": "pytorch_model.bin",
}

# Per-checkpoint metadata consumed by load_model() and predict().
CKPT_META = {
    # Original 4-class CAFormer.
    "V2.5-CAFormer": {
        "n_cls": 4,
        # "v7" selects the matching custom head in TimmClassifierWithHead.
        "head": "v7",
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        # Weights are stored in the default repository.
        "repo_id": REPO_ID,
        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
    },
    # Newer 2-class CAFormer trained with a BCE objective.
    "V-CAFormer-B36": {
        "n_cls": 2,
        # Selects the TimmBCEClassifierModel wrapper.
        "head": "timm_bce_classifier",
        # timm architecture name. The original source kept
        # "hf-hub:animetimm/caformer_b36.dbv4-full" here as a commented-out
        # alternative annotated "model used during training".
        "backbone_timm_name": "caformer_b36.sail_in22k_ft_in1k_384",
        # Repository that stores the weights for this model.
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v1",
        # Passed as num_classes to timm.create_model (single BCE logit).
        "num_classes_timm": 1,
        # Label order corresponds to [P(Non-AI), P(AI)].
        "names": ["Non-AI Generated", "AI Generated"],
    },
}

# Model loaded at start-up and pre-selected in the UI.
DEFAULT_CKPT = "V-CAFormer-B36"
LOCAL_CKPT_DIR = "./checkpoints"

SEED = 4421
# Base dropout probability used by the custom classifier heads.
DROP_RATE = 0.1
# NOTE(review): DROPOUT_RATE is not referenced anywhere in the visible code.
DROPOUT_RATE = 0.1

# Kept as a plain string because it is passed directly to the weight loaders.
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# Lazily populated by load_model(): the active network, its checkpoint name
# and the CKPT_META entry describing it.
model = None
current_ckpt = None
current_meta = None
# --- TIMM BCE classifier wrapper (used by V-CAFormer-B36) ---
class TimmBCEClassifierModel(nn.Module):
    """Wrap a single-logit timm model and expose two-class probabilities.

    The wrapped network is created with ``num_classes=num_classes_timm``
    (1 by default, i.e. one logit as used with BCEWithLogitsLoss). The
    forward pass turns that logit into ``[P(class 0), P(class 1)]``.
    """

    def __init__(self, timm_model_name, num_classes_timm=1):
        """Create the (randomly initialised) timm network and its data config.

        Args:
            timm_model_name: Architecture name handed to ``timm.create_model``.
            num_classes_timm: Output width requested from timm.
        """
        super().__init__()
        # Weights are not fetched here (pretrained=False); the caller loads
        # a checkpoint into ``self.backbone`` afterwards.
        net = timm.create_model(
            timm_model_name,
            pretrained=False,
            num_classes=num_classes_timm,
        )
        self.backbone = net
        self.data_config = timm.data.resolve_data_config({}, model=net)

    def forward(self, x):
        """Return a tensor whose columns are ``[P(Non-AI), P(AI)]``.

        Per the original author's note, training used real = 1 (Non-AI) and
        ai = 0, so the sigmoid of the logit is the Non-AI probability and
        its complement is the AI probability.
        """
        p_non_ai = self.backbone(x).sigmoid()
        # Concatenate along dim 1 so the columns line up with the label order
        # expected by the UI.
        return torch.cat((p_non_ai, 1 - p_non_ai), dim=1)
# --- Original SwinClassifier, renamed (used by V2.5-CAFormer) ---
class TimmClassifierWithHead(nn.Module):
    """A timm backbone followed by a custom MLP classification head.

    Three head layouts are supported, selected by ``head_version``:

    * ``"v7"``: one hidden layer of width 64 with GELU (used by V2.5-CAFormer).
    * ``"v5"``: hidden layers of width 512 and 128 with GELU.
    * anything else (V2 / V4): hidden layers of width 512 and 128 with ReLU.
    """

    def __init__(self, model_name, num_classes, pretrained=True,
                 head_version="v4"):
        """Build the backbone (without its own classifier) and the head.

        Args:
            model_name: Architecture name handed to ``timm.create_model``.
            num_classes: Width of the final linear layer.
            pretrained: Forwarded to ``timm.create_model``.
            head_version: Head layout selector, see the class docstring.
        """
        super().__init__()
        # num_classes=0 asks timm for the backbone without a classifier layer.
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)

        # Each layout is described by its hidden widths, its activation class
        # and the factor applied to DROP_RATE for the dropout that follows
        # each hidden layer.
        if head_version == "v7":
            hidden_widths, activation, drop_factors = (64,), nn.GELU, (0.8,)
        elif head_version == "v5":
            hidden_widths, activation, drop_factors = (512, 128), nn.GELU, (0.7, 0.5)
        else:
            hidden_widths, activation, drop_factors = (512, 128), nn.ReLU, (0.7, 0.5)

        # Layers are appended in exactly the order of the original literal
        # nn.Sequential definitions, so module indices (and therefore
        # state_dict keys) are unchanged.
        stack = [nn.Dropout(DROP_RATE)]
        in_width = self.backbone.num_features
        for out_width, factor in zip(hidden_widths, drop_factors):
            stack.append(nn.Linear(in_width, out_width))
            stack.append(nn.BatchNorm1d(out_width))
            stack.append(activation())
            stack.append(nn.Dropout(DROP_RATE * factor))
            in_width = out_width
        stack.append(nn.Linear(in_width, num_classes))
        self.classifier = nn.Sequential(*stack)

    def forward(self, x):
        """Return the multi-class logits for ``x``."""
        features = self.backbone(x)
        return self.classifier(features)
def _download_checkpoint(repo_id, filename):
    """Fetch ``filename`` from ``repo_id`` into LOCAL_CKPT_DIR and return its path."""
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=LOCAL_CKPT_DIR,
        force_download=False,
    )


def _read_state_dict(ckpt_file, ckpt_filename):
    """Read a checkpoint file and return the state dict stored in it."""
    if ckpt_filename.endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. This is only acceptable
        # because the repositories in CKPT_META are trusted; prefer
        # weights_only=True (or a .safetensors file) if the checkpoint allows.
        state = torch.load(ckpt_file, map_location=device, weights_only=False)
    # Some checkpoints wrap the weights under "model_state_dict".
    return state.get("model_state_dict", state)


def load_model(ckpt_name: str):
    """Make ``ckpt_name`` the active model, downloading weights if needed.

    Does nothing when ``ckpt_name`` is already active. Otherwise the network
    is built and its weights are loaded into a local variable; the globals
    ``model``, ``current_ckpt`` and ``current_meta`` are replaced together
    only after everything succeeded. If the download or the weight loading
    raises, the previously active model and its metadata are left untouched,
    so they can never get out of sync.

    Args:
        ckpt_name: A key of ``CKPT_META`` / ``HF_FILENAMES``.

    Raises:
        KeyError: If ``ckpt_name`` is not a known checkpoint.
    """
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return

    print(f"\n🔄 正在切换到 {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    head_version = meta.get("head", "v4")
    # Repository to download from; falls back to the default one.
    model_repo_id = meta.get("repo_id", REPO_ID)

    if head_version == "timm_bce_classifier":
        # --- CAFormer-B36 (BCE) ---
        print(f"创建 TIMM BCE 模型: {meta['backbone_timm_name']}")
        new_model = TimmBCEClassifierModel(
            timm_model_name=meta["backbone_timm_name"],
            num_classes_timm=meta["num_classes_timm"],
        ).to(device)
        print(f"从 {model_repo_id} 加载权重...")
        ckpt_file = _download_checkpoint(model_repo_id, ckpt_filename)
        # The checkpoint holds the bare timm model, so it is loaded into the
        # wrapper's backbone rather than into the wrapper itself.
        new_model.backbone.load_state_dict(
            _read_state_dict(ckpt_file, ckpt_filename)
        )
        print("✅ TIMM BCE 模型权重加载完毕。")
    else:
        # --- Original CAFormer (V2.5): timm backbone + custom head ---
        print(f"从 {model_repo_id} 加载标准 TIMM + Head 权重...")
        ckpt_file = _download_checkpoint(model_repo_id, ckpt_filename)
        print(f"Checkpoint: {ckpt_file}")
        new_model = TimmClassifierWithHead(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=head_version,
        ).to(device)
        new_model.load_state_dict(
            _read_state_dict(ckpt_file, ckpt_filename), strict=True
        )

    new_model.eval()
    # Publish the new state only now that loading has fully succeeded.
    model, current_ckpt, current_meta = new_model, ckpt_name, meta
    print(f"✅ {ckpt_name} 加载完毕 (分类数 = {meta['n_cls']})。")
def build_transform(is_training: bool, interpolation: str):
    """Create the timm preprocessing transform for the active model.

    The active model's ``data_config`` is used as the base configuration,
    with its interpolation mode replaced by ``interpolation``.

    Args:
        is_training: Forwarded to ``timm.data.create_transform``.
        interpolation: Interpolation mode name to use for resizing.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    if model is None:
        raise RuntimeError("模型尚未加载。")
    # Merge into a new dict so the model's own data_config is not modified.
    transform_cfg = {**model.data_config, "interpolation": interpolation}
    return timm.data.create_transform(is_training=is_training, **transform_cfg)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening transparency onto white.

    Images that carry transparency are first converted to RGBA and then
    composited over an opaque white canvas. Transparency is detected either
    through a ``transparency`` entry in ``image.info`` (e.g. palette images)
    or through an alpha band in the image itself, so modes such as ``LA``
    and ``PA`` are flattened onto white as well instead of having their
    alpha channel silently discarded.

    Args:
        image: Any PIL image.

    Returns:
        An image in mode ``RGB``; the input itself if it already was RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    if image.mode == "RGBA":
        # Transparent PNG/WebP backgrounds become white.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
@torch.no_grad()
def predict(image: Image.Image,
            ckpt_name: str,
            interpolation: str = "bicubic"):
    """Classify ``image`` with the selected checkpoint.

    Args:
        image: Input picture, or ``None`` when nothing was uploaded.
        ckpt_name: Checkpoint to use; it is loaded first if it is not active.
        interpolation: Interpolation mode for the preprocessing transform.

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a dict mapping each class
        name of the active checkpoint to its probability as a float.
    """
    if image is None:
        return None

    load_model(ckpt_name)
    uses_bce_head = current_meta["head"] == "timm_bce_classifier"

    # Both CAFormer models use the standard timm transform on an RGB image.
    rgb_image = pil_ensure_rgb(image)
    preprocess = build_transform(False, interpolation)
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    if uses_bce_head:
        # V-CAFormer-B36: the wrapper already returns [P(Non-AI), P(AI)].
        scores = model(batch).cpu()[0]
    else:
        # V2.5-CAFormer: the model returns logits, so apply softmax.
        scores = F.softmax(model(batch), dim=1)[0].cpu()

    labels = current_meta["names"]
    return {label: float(scores[idx]) for idx, label in enumerate(labels)}
def launch():
    """Load the default model, build the Gradio interface and start it.

    The interface has a run button, a model dropdown, an interpolation
    selector and an image input in one column and a label output in the
    other. Image files found in the local ``examples`` directory (created
    if missing) are offered as clickable examples.
    """
    load_model(DEFAULT_CKPT)

    intro_text = (
        "在左侧选择一个模型,上传一张图片,"
        "然后点击 **🚀 运行** 来查看预测结果。\n"
        "• **V-CAFormer-B36**: 最新的二分类模型 (AI vs Non-AI)。\n"
        "• **V2.5-CAFormer**: 4分类模型 (区分动漫/照片 x AI/Non-AI)。"
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI 图像检测器")
        gr.Markdown(intro_text)

        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 运行", variant="primary")
                # Choices come from HF_FILENAMES, in its key order.
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT,
                    label="选择模型",
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic",
                    label="图像缩放插值",
                )
                in_img = gr.Image(type="pil", label="上传图片")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="预测结果")

        ui_inputs = (in_img, sel_ckpt, sel_interp)
        run_btn.click(
            predict,
            inputs=list(ui_inputs),
            outputs=[out_lbl],
        )

        # Collect example images, creating the directory on first run.
        examples_dir = "examples"
        if not os.path.exists(examples_dir):
            os.makedirs(examples_dir)

        image_suffixes = ('.png', '.jpg', '.jpeg')
        example_rows = []
        for entry in os.listdir(examples_dir):
            if entry.lower().endswith(image_suffixes):
                example_rows.append(
                    [os.path.join(examples_dir, entry), DEFAULT_CKPT, "bicubic"]
                )

        if example_rows:
            gr.Examples(
                examples=example_rows,
                inputs=list(ui_inputs),
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )

    demo.launch()
# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    launch()