Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """ | |
| ------------------------------------------------------------------- | |
| • V2.5-CAFormer : 4-class (photo / anime × AI / Non-AI) | |
| • V-CAFormer-B36 (新) : 2-class (AI vs. Non-AI, BCE Head) | |
| ------------------------------------------------------------------- | |
| 说明: | |
| 本脚本已根据请求精简,仅支持上述两个 CAFormer 模型。 | |
| Swin、V3-Emb 和 MoE 模型的代码已被移除。 | |
| """ | |
| import os, torch, timm, numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from torchvision import transforms | |
# Repository hosting the 4-class V2.5-CAFormer checkpoint; also the fallback
# repository for any model whose metadata does not name one.
REPO_ID = "telecomadm1145/swin-ai-detection"

# Checkpoint file name to download for each selectable model.
HF_FILENAMES = {
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    # The 2-class V-CAFormer-B36 lives in its own repository (see CKPT_META).
    "V-CAFormer-B36": "pytorch_model.bin",
}

# Per-model metadata consumed by load_model() and predict().
CKPT_META = {
    # Original 4-class CAFormer; "v7" selects the custom head built by
    # TimmClassifierWithHead.
    "V2.5-CAFormer": dict(
        n_cls=4,
        head="v7",
        backbone="caformer_b36.sail_in22k_ft_in1k_384",
        repo_id=REPO_ID,
        names=["non_ai", "ai", "ani_non_ai", "ani_ai"],
    ),
    # Newer 2-class CAFormer trained with a BCE objective; handled by
    # TimmBCEClassifierModel.
    "V-CAFormer-B36": dict(
        n_cls=2,
        head="timm_bce_classifier",
        # The author left "hf-hub:animetimm/caformer_b36.dbv4-full" commented
        # out next to this value, annotated "model used during training".
        backbone_timm_name="caformer_b36.sail_in22k_ft_in1k_384",
        # Repository that stores the weights.
        repo_id="telecomadm1145/danbooru-real-vs-ai-caformer-b36-v1",
        # Passed as num_classes to timm.create_model (single BCE logit).
        num_classes_timm=1,
        # Order matches the wrapper output: [P(Non-AI), P(AI)].
        names=["Non-AI Generated", "AI Generated"],
    ),
}

# Model loaded at start-up and preselected in the UI.
DEFAULT_CKPT = "V-CAFormer-B36"
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1
DROPOUT_RATE = 0.1

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed both RNGs so repeated runs start from the same random state.
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# Lazily populated by load_model(): the active network, its key in CKPT_META,
# and the matching metadata entry.
model = None
current_ckpt = None
current_meta = None
# --- TIMM BCE classifier wrapper (used by V-CAFormer-B36) ---
class TimmBCEClassifierModel(nn.Module):
    """Expose a single-logit timm model as a two-class probability model.

    The wrapped network is created with ``num_classes=num_classes_timm``
    (1 for a BCEWithLogitsLoss-style head) and without pretrained weights;
    the caller is expected to load a state dict into ``self.backbone``.

    Attributes:
        backbone: The timm model producing the raw logit.
        data_config: Preprocessing configuration resolved by timm for the
            backbone; read by ``build_transform``.
    """

    def __init__(self, timm_model_name, num_classes_timm=1):
        super().__init__()
        net = timm.create_model(
            timm_model_name,
            pretrained=False,
            num_classes=num_classes_timm,
        )
        self.backbone = net
        self.data_config = timm.data.resolve_data_config({}, model=net)

    def forward(self, x):
        """Return ``[P(Non-AI), P(AI)]`` concatenated along dim 1.

        Per the original author's note, training labelled real (Non-AI)
        images as 1 and AI images as 0, so the sigmoid of the logit is the
        Non-AI probability and its complement is the AI probability.
        """
        p_non_ai = torch.sigmoid(self.backbone(x))
        return torch.cat((p_non_ai, 1 - p_non_ai), dim=1)
# --- Formerly SwinClassifier (renamed); used by V2.5-CAFormer ---
class TimmClassifierWithHead(nn.Module):
    """A timm feature backbone followed by a custom MLP classification head.

    The backbone is created with ``num_classes=0`` so it returns pooled
    features. ``head_version`` selects the head layout:

    * ``"v7"``: one hidden layer of 64 units with GELU (used by V2.5-CAFormer).
    * ``"v5"``: hidden layers of 512 and 128 units with GELU.
    * anything else (V2 / V4): hidden layers of 512 and 128 units with ReLU.

    Every hidden stage is Dropout -> Linear -> BatchNorm1d -> activation, and
    the head ends with Dropout -> Linear(num_classes). Dropout probabilities
    are fractions of the module-level ``DROP_RATE``.
    """

    def __init__(self, model_name, num_classes, pretrained=True,
                 head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)

        # (hidden widths, dropout scale per stage + final stage, activation)
        head_specs = {
            "v7": ((64,), (1.0, 0.8), nn.GELU),
            "v5": ((512, 128), (1.0, 0.7, 0.5), nn.GELU),
        }
        fallback_spec = ((512, 128), (1.0, 0.7, 0.5), nn.ReLU)
        hidden_widths, drop_scales, activation = head_specs.get(
            head_version, fallback_spec
        )

        # Layers are appended in the same order as the original hand-written
        # nn.Sequential so parameter names in saved checkpoints still match.
        layers = []
        in_features = self.backbone.num_features
        for width, scale in zip(hidden_widths, drop_scales):
            layers.append(nn.Dropout(DROP_RATE * scale))
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.BatchNorm1d(width))
            layers.append(activation())
            in_features = width
        layers.append(nn.Dropout(DROP_RATE * drop_scales[-1]))
        layers.append(nn.Linear(in_features, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw multi-class logits for the batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)
def _read_state_dict(ckpt_file, ckpt_filename):
    """Read a checkpoint file and return the state dict it contains.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``;
    anything else goes through ``torch.load``. If the loaded mapping has a
    ``"model_state_dict"`` entry, that entry is returned, otherwise the
    mapping itself.
    """
    if ckpt_filename.endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary Python objects from a file downloaded from the Hub.
        # Kept as in the original because the checkpoint layout is not
        # visible here; switch to weights_only=True (or a .safetensors
        # export) once the file is confirmed to hold only tensors.
        state = torch.load(ckpt_file, map_location=device, weights_only=False)
    return state.get("model_state_dict", state)


def load_model(ckpt_name: str):
    """Make ``ckpt_name`` the active model, downloading weights if needed.

    Does nothing when ``ckpt_name`` is already active. Otherwise the network
    described by ``CKPT_META[ckpt_name]`` is built, its checkpoint is fetched
    from the Hugging Face Hub into ``LOCAL_CKPT_DIR``, the weights are loaded
    and the model is put in eval mode.

    The globals ``model``, ``current_ckpt`` and ``current_meta`` are updated
    together, and only after everything succeeded. If construction, download
    or weight loading raises, the previously active model stays in place, so
    a later call can never see a half-initialised network paired with stale
    metadata.

    Args:
        ckpt_name: Key into ``CKPT_META`` / ``HF_FILENAMES``.

    Raises:
        KeyError: If ``ckpt_name`` is not a known checkpoint.
    """
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return

    print(f"\n🔄 正在切换到 {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    head_version = meta.get("head", "v4")
    # Repository to download from; falls back to the default one.
    model_repo_id = meta.get("repo_id", REPO_ID)

    # Build into a local so a failure below leaves the globals untouched.
    if head_version == "timm_bce_classifier":
        # --- CAFormer-B36 (BCE) path ---
        print(f"创建 TIMM BCE 模型: {meta['backbone_timm_name']}")
        new_model = TimmBCEClassifierModel(
            timm_model_name=meta["backbone_timm_name"],
            num_classes_timm=meta["num_classes_timm"],
        ).to(device)
        print(f"从 {model_repo_id} 加载权重...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,
            filename=ckpt_filename,
            local_dir=LOCAL_CKPT_DIR, force_download=False
        )
        # The checkpoint holds the bare timm model, so it is loaded into the
        # wrapper's backbone rather than into the wrapper itself.
        new_model.backbone.load_state_dict(
            _read_state_dict(ckpt_file, ckpt_filename)
        )
        print("✅ TIMM BCE 模型权重加载完毕。")
    else:
        # --- Original CAFormer (V2.5) path: backbone + custom head ---
        print(f"从 {model_repo_id} 加载标准 TIMM + Head 权重...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,
            filename=ckpt_filename,
            local_dir=LOCAL_CKPT_DIR, force_download=False
        )
        print(f"Checkpoint: {ckpt_file}")
        new_model = TimmClassifierWithHead(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=head_version
        ).to(device)
        new_model.load_state_dict(
            _read_state_dict(ckpt_file, ckpt_filename), strict=True
        )

    new_model.eval()
    # Publish the fully initialised model and its metadata in one step.
    model, current_ckpt, current_meta = new_model, ckpt_name, meta
    print(f"✅ {ckpt_name} 加载完毕 (分类数 = {meta['n_cls']})。")
def build_transform(is_training: bool, interpolation: str):
    """Create the timm preprocessing transform for the active model.

    Starts from the active model's ``data_config`` and overrides its
    ``interpolation`` entry with the requested mode.

    Args:
        is_training: Forwarded to ``timm.data.create_transform``.
        interpolation: Resize interpolation name (e.g. ``"bicubic"``).

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    if model is None:
        raise RuntimeError("模型尚未加载。")
    # Work on a copy so the model's stored config is never mutated.
    transform_kwargs = dict(model.data_config)
    transform_kwargs["interpolation"] = interpolation
    return timm.data.create_transform(is_training=is_training, **transform_kwargs)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening transparency onto white.

    Images already in RGB mode are returned unchanged. Images in any other
    mode are converted; when they carry transparency (either an alpha band,
    as in ``LA`` / ``PA``, or a ``transparency`` entry in ``image.info``, as
    palette PNGs do), they go through RGBA and are composited over an opaque
    white canvas so transparent pixels become white rather than keeping
    whatever colour was stored underneath.

    Args:
        image: Any PIL image.

    Returns:
        An image whose mode is ``"RGB"``.
    """
    if image.mode not in ("RGB", "RGBA"):
        # An explicit alpha band counts as transparency too; checking only
        # the info key would silently drop the alpha of LA / PA images.
        has_transparency = (
            "A" in image.getbands() or "transparency" in image.info
        )
        image = image.convert("RGBA" if has_transparency else "RGB")
    if image.mode == "RGBA":
        # Flatten PNG/WebP transparency onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
def predict(image: Image.Image,
            ckpt_name: str,
            interpolation: str = "bicubic"):
    """Classify one image with the selected checkpoint.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        ckpt_name: Key of the model to use; it is loaded on demand.
        interpolation: Resize interpolation passed to the timm transform.

    Returns:
        A dict mapping each class name of the active model to its
        probability, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    load_model(ckpt_name)
    head_type = current_meta["head"]

    # Both CAFormer models use the standard timm eval transform.
    processed_image = pil_ensure_rgb(image)
    tfm = build_transform(False, interpolation)
    inp = tfm(processed_image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping so no graph is built or
    # kept alive for each request.
    with torch.no_grad():
        output = model(inp)
        if head_type == "timm_bce_classifier":
            # V-CAFormer-B36: the wrapper already returns [P(Non-AI), P(AI)].
            probs = output[0].cpu()
        else:
            # V2.5-CAFormer: raw 4-class logits, so softmax is applied here.
            probs = F.softmax(output, dim=1)[0].cpu()

    class_names = current_meta["names"]
    return {name: float(p) for name, p in zip(class_names, probs)}
def launch():
    """Load the default model, build the Gradio UI and start serving it.

    The UI offers a model dropdown, an interpolation selector, an image
    upload and a label output wired to ``predict``. Image files found in the
    local ``examples`` directory (created if missing) are offered as
    clickable examples.
    """
    load_model(DEFAULT_CKPT)
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI 图像检测器")
        gr.Markdown(
            "在左侧选择一个模型,上传一张图片,"
            "然后点击 **🚀 运行** 来查看预测结果。\n"
            "• **V-CAFormer-B36**: 最新的二分类模型 (AI vs Non-AI)。\n"
            "• **V2.5-CAFormer**: 4分类模型 (区分动漫/照片 x AI/Non-AI)。"
        )
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 运行", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT, label="选择模型"
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic", label="图像缩放插值"
                )
                in_img = gr.Image(type="pil", label="上传图片")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="预测结果")
        run_btn.click(
            predict,
            inputs=[in_img, sel_ckpt, sel_interp],
            outputs=[out_lbl]
        )

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs("examples", exist_ok=True)
        # os.listdir order is arbitrary; sort so the examples appear in a
        # stable order across restarts.
        example_files = [
            os.path.join("examples", fname)
            for fname in sorted(os.listdir("examples"))
            if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
                inputs=[in_img, sel_ckpt, sel_interp],
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )
    demo.launch()


if __name__ == "__main__":
    launch()