# -*- coding: utf-8 -*-
"""
-------------------------------------------------------------------
• V2.5-CAFormer       : 4-class (photo / anime × AI / Non-AI)
• V-CAFormer-B36 (new): 2-class (AI vs. Non-AI, CrossEntropy head)
-------------------------------------------------------------------
Note: this script has been updated to match V-CAFormer-B36, which was trained
with CrossEntropyLoss (num_classes=2). TimmBCEClassifierModel has been removed.
"""
import os, torch, timm, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torchvision import transforms

# V2.5-CAFormer (4-class) lives in this repo
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    # V-CAFormer-B36 (2-class) lives in its own repo (see CKPT_META)
    "V-CAFormer-B36": "pytorch_model.bin",
}

# --- CKPT_META (updated for V-CAFormer-B36) ---
CKPT_META = {
    # Original 4-class CAFormer
    "V2.5-CAFormer": {
        "n_cls": 4,
        "head": "v7",  # uses TimmClassifierWithHead
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        "repo_id": REPO_ID,  # default repo
        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
    },
    # New 2-class CrossEntropy CAFormer (matches train.py)
    "V-CAFormer-B36": {
        "n_cls": 2,
        "head": "timm_cross_entropy",  # <-- updated: standard timm head
        "backbone_timm_name": "hf-hub:animetimm/caformer_b36.dbv4-full",  # <-- updated: matches train.py
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v1",  # repo hosting the weights
        "num_classes_timm": 2,  # <-- updated: matches train.py (num_classes=2)
        # Updated to match train.py (index 0 = AI, index 1 = Real)
        "names": ["AI Generated", "Non-AI Generated"],
    },
}
# -----------------------------------------------

DEFAULT_CKPT = "V-CAFormer-B36"  # <-- new default model
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1
DROPOUT_RATE = 0.1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED); np.random.seed(SEED)
print(f"Using device: {device}")

model, current_ckpt = None, None
current_meta = None

# --- (TimmBCEClassifierModel removed) ---

# --- Original TimmClassifierWithHead (used by V2.5-CAFormer) ---
class TimmClassifierWithHead(nn.Module):
    """
    Wrapper that combines a timm backbone with a custom classification head (v4, v5, v7).
    Used by V2.5-CAFormer (which uses the 'v7' head).
    """
    def __init__(self, model_name, num_classes, pretrained=True, head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)

        if head_version == "v7":
            # V7 (used by V2.5-CAFormer)
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.8),
                nn.Linear(64, num_classes),
            )
        elif head_version == "v5":
            # V5: 512-128, GELU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )
        else:
            # V2 / V4: 512-128, ReLU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        # Multi-class logits
        return self.classifier(self.backbone(x))
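
# Illustrative sketch (not part of the original app): a quick shape check for the
# head wrapper above. It builds the 4-class V2.5 architecture without pretrained
# weights and pushes a dummy batch through it; the input size comes from the
# resolved timm data_config. Call it manually to sanity-check the head wiring
# before loading a checkpoint.
def _sanity_check_head(backbone_name: str = CKPT_META["V2.5-CAFormer"]["backbone"]):
    m = TimmClassifierWithHead(backbone_name, num_classes=4,
                               pretrained=False, head_version="v7")
    _, h, w = m.data_config["input_size"]      # e.g. (3, 384, 384)
    m.eval()
    with torch.no_grad():
        logits = m(torch.zeros(2, 3, h, w))    # dummy batch of 2
    assert logits.shape == (2, 4)              # [batch, num_classes]
    return logits.shape
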
# --- load_model (updated) ---
def load_model(ckpt_name: str):
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return

    print(f"\n🔄 Switching to {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    head_version = meta.get("head", "v4")

    # Decide which repo to download from
    model_repo_id = meta.get("repo_id", REPO_ID)

    # --- V-CAFormer-B36 (CrossEntropy) loading path (matches train.py) ---
    if head_version == "timm_cross_entropy":
        print(f"Creating timm CrossEntropy model: {meta['backbone_timm_name']}")
        # Create the timm model directly, num_classes=2 (matches train.py)
        model = timm.create_model(
            meta["backbone_timm_name"],
            pretrained=False,  # we load our own weights below
            num_classes=meta["num_classes_timm"],  # should be 2
        ).to(device)
        # Attach data_config so build_transform() can use it
        model.data_config = timm.data.resolve_data_config({}, model=model)

        print(f"Loading weights from {model_repo_id} ...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,   # repo_id defined in CKPT_META
            filename=ckpt_filename,  # e.g., pytorch_model.bin
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )
        if ckpt_filename.endswith(".safetensors"):
            state = load_file(ckpt_file, device=device)
        else:
            state = torch.load(ckpt_file, map_location=device, weights_only=False)
        # The training script saves the state_dict directly
        model.load_state_dict(state.get("model_state_dict", state))
        print("✅ timm CrossEntropy model weights loaded.")

    # --- Original V2.5-CAFormer (4-class, custom head) loading path ---
    else:  # covers "v7", "v5", "v4"
        print(f"Loading standard timm + head weights from {model_repo_id} (head: {head_version}) ...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,
            filename=ckpt_filename,
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )
        print(f"Checkpoint: {ckpt_file}")
        # Use the TimmClassifierWithHead wrapper
        model = TimmClassifierWithHead(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=head_version,
        ).to(device)
        if ckpt_filename.endswith(".safetensors"):
            state = load_file(ckpt_file, device=device)
        else:
            state = torch.load(ckpt_file, map_location=device, weights_only=False)
        model.load_state_dict(state.get("model_state_dict", state), strict=True)

    model.eval()
    current_ckpt, current_meta = ckpt_name, meta
    print(f"✅ {ckpt_name} loaded (num classes = {meta['n_cls']}).")


def build_transform(is_training: bool, interpolation: str):
    if model is None:
        raise RuntimeError("Model has not been loaded yet.")
    # Both models carry their backbone's data_config
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    # Make sure the image is RGB
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        # Composite transparent PNG/WebP backgrounds onto white
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


# --- predict (updated) ---
@torch.no_grad()
def predict(image: Image.Image, ckpt_name: str, interpolation: str = "bicubic"):
    if image is None:
        return None
    load_model(ckpt_name)  # load the right model (CE or v7 head)

    # (The V3-Emb-specific padding and BGR channel swap have been removed)
    processed_image = pil_ensure_rgb(image)
    tfm = build_transform(False, interpolation)
    inp = tfm(processed_image).unsqueeze(0).to(device)

    # --- Output handling (simplified) ---
    # Both models (V2.5 and V-B36) now output raw logits:
    # V-B36: [Logit_AI, Logit_Real]  (from train.py, index 0 = AI, 1 = Real)
    # V2.5 : [Logit_non_ai, Logit_ai, Logit_ani_non_ai, Logit_ani_ai]
    logits = model(inp)

    # Softmax over the raw logits to get probabilities
    probs = F.softmax(logits, dim=1)[0].cpu()

    class_names = current_meta["names"]
    # class_names[i] corresponds to probs[i] for both models
    return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
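
# Illustrative sketch (not part of the original app): calling predict() without
# the Gradio UI. The file path below is hypothetical; any RGB image works. This
# mirrors what the 🚀 Run button does: load the selected checkpoint, preprocess,
# and return a {class_name: probability} dict.
def _predict_from_path(path: str, ckpt_name: str = DEFAULT_CKPT):
    img = Image.open(path)
    return predict(img, ckpt_name, interpolation="bicubic")

# Example (hypothetical path):
#   scores = _predict_from_path("examples/sample.jpg")
#   print(max(scores, key=scores.get), scores)
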
def launch():
    load_model(DEFAULT_CKPT)
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Image Detector")
        gr.Markdown(
            "Pick a model on the left, upload an image, "
            "then click **🚀 Run** to see the prediction.\n"
            "• **V-CAFormer-B36**: newest binary model (AI vs Non-AI), updated to match the CrossEntropy training.\n"
            "• **V2.5-CAFormer**: 4-class model (anime/photo × AI/Non-AI)."
        )
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 Run", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),  # auto-fills ["V2.5-CAFormer", "V-CAFormer-B36"]
                    value=DEFAULT_CKPT,
                    label="Model",
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic",
                    label="Resize interpolation",
                )
                in_img = gr.Image(type="pil", label="Upload image")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="Prediction")

        run_btn.click(
            predict,
            inputs=[in_img, sel_ckpt, sel_interp],
            outputs=[out_lbl],
        )

        if not os.path.exists("examples"):
            os.makedirs("examples")
        example_files = [
            os.path.join("examples", f) for f in os.listdir("examples")
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
                inputs=[in_img, sel_ckpt, sel_interp],
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )

    demo.launch()


if __name__ == "__main__":
    launch()
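
# Assumed usage (the file name and dependency list below are illustrative, not
# taken from the original repo):
#   pip install torch torchvision timm gradio pillow numpy huggingface_hub safetensors
#   python app.py
# On first run the selected checkpoint is downloaded into ./checkpoints, and the
# Gradio demo is served on Gradio's default local port (typically 7860).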