File size: 15,334 Bytes
9b9f9d6
 
 
 
326875d
9b9f9d6
 
c7d61fe
 
9b9f9d6
c7d61fe
 
 
9b9f9d6
 
c7d61fe
9b9f9d6
 
 
 
 
 
 
 
c7d61fe
9b9f9d6
c7d61fe
9b9f9d6
326875d
c7d61fe
9b9f9d6
 
c7d61fe
 
 
9b9f9d6
c7d61fe
 
9b9f9d6
326875d
c7d61fe
9b9f9d6
c7d61fe
 
 
 
 
 
9b9f9d6
 
326875d
 
c7d61fe
9b9f9d6
 
 
 
c7d61fe
 
9b9f9d6
c7d61fe
9b9f9d6
 
326875d
9b9f9d6
 
 
 
 
c7d61fe
9b9f9d6
c7d61fe
9b9f9d6
c7d61fe
9b9f9d6
 
 
 
 
 
 
 
c7d61fe
9b9f9d6
 
 
 
 
 
 
 
 
 
 
 
c7d61fe
9b9f9d6
 
 
 
 
 
 
 
 
 
 
 
c7d61fe
9b9f9d6
 
c7d61fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b9f9d6
c7d61fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b9f9d6
 
 
 
 
 
 
 
c7d61fe
9b9f9d6
 
c7d61fe
326875d
 
 
 
c7d61fe
 
 
9b9f9d6
c7d61fe
326875d
c7d61fe
9b9f9d6
 
c7d61fe
 
 
 
9b9f9d6
c7d61fe
9b9f9d6
 
 
c7d61fe
 
326875d
 
9b9f9d6
326875d
c7d61fe
326875d
9094a6f
9b9f9d6
9094a6f
c7d61fe
 
9094a6f
 
c7d61fe
9b9f9d6
94778bf
 
 
c7d61fe
94778bf
c7d61fe
9094a6f
 
 
c7d61fe
 
e38dd62
3e0966c
e38dd62
f014c13
9b9f9d6
f014c13
d66a824
c7d61fe
 
 
 
 
9b9f9d6
 
c7d61fe
9b9f9d6
0f7bba6
c7d61fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9094a6f
c7d61fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94778bf
8b95acd
f014c13
94778bf
f844b83
9b9f9d6
c7d61fe
62bdf13
f844b83
9b9f9d6
f014c13
c7d61fe
 
 
f844b83
c7d61fe
 
f014c13
9b9f9d6
c7d61fe
 
 
 
 
f014c13
 
c7d61fe
 
 
 
 
5ca0f15
 
c7d61fe
 
 
e38dd62
5ca0f15
 
8b95acd
3e0966c
b6adf0f
c7d61fe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
# -*- coding: utf-8 -*-
"""
-------------------------------------------------------------------
• V2.5-CAFormer                         : 4-class (photo / anime × AI / Non-AI)
• V-CAFormer-B36 (新)                   : 2-class (AI vs. Non-AI, CrossEntropy Head)
-------------------------------------------------------------------
说明:
- 移除插值选项与无用代码
- 添加 Grad-CAM 选项(默认不启用;启用时才计算梯度并输出热力图)
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# V2.5-CAFormer (4-class) lives in this repo
REPO_ID = "telecomadm1145/swin-ai-detection"
# Checkpoint filename on the Hub for each selectable model key.
HF_FILENAMES = {
    "V2.5-old": "caformer_b36_4class_96.safetensors",
    # V-CAFormer-B36 (2-class) lives in its own repo (see CKPT_META)
    "V2": "pytorch_model.bin",
}

# --- CKPT_META (updated for V-CAFormer-B36) ---
# Per-checkpoint metadata: class count, head type, backbone, source repo,
# and the human-readable class names (index order matters — see "names").
CKPT_META = {
    # Original 4-class CAFormer
    "V2.5-old": {
        "n_cls": 4,
        "head": "v7",  # uses TimmClassifierWithHead
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        "repo_id": REPO_ID,  # uses the default repo
        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
    },
    # New 2-class CrossEntropy CAFormer (matches train.py)
    "V2": {
        "n_cls": 2,
        "head": "timm_cross_entropy",  # uses the standard timm head
        "backbone_timm_name": "hf-hub:animetimm/caformer_b36.dbv4-full",
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v2",
        "num_classes_timm": 2,  # matches train.py (num_classes=2)
        # Matches train.py (index 0 = AI, index 1 = Real)
        "names": ["AI Generated", "Non-AI Generated"],
    },
}
# -----------------------------------------------

DEFAULT_CKPT = "V2"  # default model shown in the UI
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")
# Module-level state mutated by load_model(): active model + its metadata.
model, current_ckpt, current_meta = None, None, None


# --- 原始 TimmClassifierWithHead (用于 V2.5-CAFormer) ---
class TimmClassifierWithHead(nn.Module):
    """
    Wrapper that pairs a timm backbone (num_classes=0, i.e. feature mode)
    with a custom classification head (v4, v5, or v7).
    This class backs V2.5-CAFormer, which ships a 'v7' head.
    """

    def __init__(self, model_name, num_classes, pretrained=True, head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Resolved preprocessing config (mean/std/input size) for transforms.
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        feat_dim = self.backbone.num_features

        if head_version == "v7":
            # V7 (used by V2.5-CAFormer): single narrow 64-unit bottleneck.
            head_layers = [
                nn.Dropout(DROP_RATE),
                nn.Linear(feat_dim, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.8),
                nn.Linear(64, num_classes),
            ]
        else:
            # V5 and V2/V4 share the 512 -> 128 shape and only differ in
            # activation: GELU for v5, ReLU otherwise.
            act_cls = nn.GELU if head_version == "v5" else nn.ReLU
            head_layers = [
                nn.Dropout(DROP_RATE),
                nn.Linear(feat_dim, 512),
                nn.BatchNorm1d(512),
                act_cls(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                act_cls(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # Backbone features -> multi-class logits.
        return self.classifier(self.backbone(x))


# --- Grad-CAM 工具 ---
class GradCAM:
    """
    Minimal Grad-CAM helper: hooks one target layer, caches its forward
    activations and the gradients flowing back into them, and combines the
    two into a per-sample heatmap normalized to [0, 1].
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        # Cache the feature map on every forward pass; also hook the output
        # tensor itself so we capture the gradient w.r.t. that output.
        def fwd_hook(module, inp, out):
            self.activations = out

            def bwd_hook(grad):
                self.gradients = grad

            out.register_hook(bwd_hook)

        self.hooks.append(self.target_layer.register_forward_hook(fwd_hook))

    def remove_hooks(self):
        # Detach every registered hook and clear the list.
        while self.hooks:
            self.hooks.pop().remove()

    def compute_cam(self):
        """
        Combine the cached activations and gradients into a (B, H, W)
        tensor of CAMs, each normalized to the [0, 1] range.
        """
        acts = self.activations  # (B, C, H, W)
        grads = self.gradients   # (B, C, H, W)
        if acts is None or grads is None:
            raise RuntimeError("GradCAM: 未捕获到激活或梯度。")

        # Channel weights = global-average-pooled gradients.
        channel_w = grads.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = F.relu((channel_w * acts).sum(dim=1))       # (B, H, W)

        # Per-sample min/max normalization (clamped to avoid divide-by-zero).
        normalized = []
        for sample in cam:
            shifted = sample - sample.min()
            normalized.append(shifted / shifted.max().clamp(min=1e-6))
        return torch.stack(normalized, dim=0)


def find_last_conv_layer(module: nn.Module) -> nn.Module:
    """Return the last nn.Conv2d inside *module*, or None if it has none."""
    convs = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
    return convs[-1] if convs else None


def get_gradcam_target_layer() -> nn.Module:
    """
    Pick the convolution layer Grad-CAM should hook for the current model:
    the last Conv2d of the backbone (or of the whole timm model for the
    CrossEntropy variant). May return None when no Conv2d exists.
    """
    if model is None or current_meta is None:
        raise RuntimeError("模型尚未加载。")
    if current_meta.get("head", "v4") == "timm_cross_entropy":
        search_root = model
    else:
        search_root = model.backbone
    layer = find_last_conv_layer(search_root)
    if layer is None:
        print("⚠️ 未找到卷积层,Grad-CAM 将不可用。")
    return layer


def build_transform(is_training: bool):
    """Build the input transform from the loaded model's data_config."""
    if model is None:
        raise RuntimeError("模型尚未加载。")
    # Use the model's own data_config defaults (interpolation/normalization).
    config = dict(model.data_config)
    return timm.data.create_transform(**config, is_training=is_training)


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Normalize any PIL image to RGB, flattening transparency onto white."""
    if image.mode not in ("RGB", "RGBA"):
        # Keep the alpha channel when the source format carries transparency.
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode == "RGBA":
        # Composite transparent PNG/WebP onto a white background.
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.alpha_composite(image)
        image = background.convert("RGB")
    return image


def tensor_to_pil_image(x: torch.Tensor, data_cfg: dict) -> Image.Image:
    """
    Undo the normalization on a network-input tensor and convert it to PIL,
    so the CAM overlay is drawn on exactly the image the model saw.
    x: (C, H, W)
    """
    mean_vals = data_cfg.get("mean", (0.0, 0.0, 0.0))
    std_vals = data_cfg.get("std", (1.0, 1.0, 1.0))
    mean = torch.tensor(mean_vals, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std = torch.tensor(std_vals, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    denorm = (x * std + mean).clamp(0, 1)
    hwc = (denorm.permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(hwc)


def overlay_cam_on_image(base_img: Image.Image, cam_2d: np.ndarray, alpha: float = 0.45) -> Image.Image:
    """
    Blend the CAM onto the input image as a simple red heat overlay.
    base_img: PIL (H, W, 3)
    cam_2d: numpy float32 in [0,1], shape (H, W)
    """
    width, height = base_img.size
    cam_2d = np.clip(cam_2d, 0.0, 1.0)
    # Resize the CAM up to the image resolution via an 8-bit grayscale image.
    cam_img = Image.fromarray((cam_2d * 255).astype(np.uint8), mode="L").resize((width, height), Image.BICUBIC)
    heat_strength = np.array(cam_img).astype(np.float32) / 255.0  # (H, W)

    base = np.array(base_img).astype(np.float32) / 255.0  # (H, W, 3)
    heat = np.zeros_like(base)
    heat[..., 0] = heat_strength  # red channel only
    # Alpha-blend image and heatmap.
    blended = np.clip((1 - alpha) * base + alpha * heat, 0.0, 1.0)
    return Image.fromarray((blended * 255).astype(np.uint8))


# --- load_model ---
def load_model(ckpt_name: str):
    """
    Load (or switch to) the checkpoint *ckpt_name*, updating the module-level
    `model`, `current_ckpt`, and `current_meta` globals. Downloads weights
    from the Hugging Face Hub on first use; no-op if already active.
    """
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return
    print(f"\n🔄 正在切换到 {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    head_version = meta.get("head", "v4")

    # Decide which Hub repo to download from.
    model_repo_id = meta.get("repo_id", REPO_ID)

    # --- V-CAFormer-B36 (CrossEntropy) loading path (mirrors train.py) ---
    if head_version == "timm_cross_entropy":
        print(f"创建 TIMM CrossEntropy 模型: {meta['backbone_timm_name']}")
        model = timm.create_model(
            meta["backbone_timm_name"],
            pretrained=False,
            num_classes=meta["num_classes_timm"],
        ).to(device)
        # Attach data_config so build_transform() can use it.
        model.data_config = timm.data.resolve_data_config({}, model=model)

        print(f"从 {model_repo_id} 加载权重...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,
            filename=ckpt_filename,
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )

        if ckpt_filename.endswith(".safetensors"):
            state = load_file(ckpt_file, device=device)
        else:
            state = torch.load(ckpt_file, map_location=device)

        # Training scripts may nest the weights under "model_state_dict".
        model.load_state_dict(state.get("model_state_dict", state))
        print("✅ TIMM CrossEntropy 模型权重加载完毕。")

    # --- Original V2.5-CAFormer (4-class, custom head) loading path ---
    else:  # covers "v7", "v5", "v4"
        print(f"从 {model_repo_id} 加载标准 TIMM + Head 权重 (Head: {head_version})...")
        ckpt_file = hf_hub_download(
            repo_id=model_repo_id,
            filename=ckpt_filename,
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )
        print(f"Checkpoint: {ckpt_file}")

        model = TimmClassifierWithHead(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=head_version,
        ).to(device)

        if ckpt_filename.endswith(".safetensors"):
            state = load_file(ckpt_file, device=device)
        else:
            state = torch.load(ckpt_file, map_location=device)

        model.load_state_dict(state.get("model_state_dict", state), strict=True)

    model.eval()
    current_ckpt, current_meta = ckpt_name, meta
    print(f"✅ {ckpt_name} 加载完毕 (分类数 = {meta['n_cls']})。")


# --- predict (adds the use_gradcam switch) ---
def predict(image: Image.Image, ckpt_name: str, use_gradcam: bool = False):
    """
    Run the selected model on *image* and return (label_dict, cam_image).
    cam_image is None unless Grad-CAM is enabled and a target layer exists;
    gradients are only computed on the Grad-CAM path.
    """
    if image is None:
        return None, None
    load_model(ckpt_name)  # load the right model (CE or v7 head)

    processed_image = pil_ensure_rgb(image)
    tfm = build_transform(is_training=False)
    inp = tfm(processed_image).unsqueeze(0).to(device)

    # Plain inference + probabilities (no gradients).
    if not use_gradcam:
        with torch.inference_mode():
            logits = model(inp)
            probs = F.softmax(logits, dim=1)[0].detach().cpu()
        class_names = current_meta["names"]
        out_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
        return out_dict, None

    # Grad-CAM enabled: only this branch computes gradients.
    target_layer = get_gradcam_target_layer()
    if target_layer is None:
        # No suitable conv layer found: return a normal prediction, no heatmap.
        with torch.inference_mode():
            logits = model(inp)
            probs = F.softmax(logits, dim=1)[0].detach().cpu()
        class_names = current_meta["names"]
        out_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
        return out_dict, None

    gradcam = GradCAM(model, target_layer=target_layer)
    try:
        with torch.enable_grad():
            logits = model(inp)
            probs = F.softmax(logits, dim=1)[0]
            pred_idx = int(torch.argmax(probs).item())

            # Backpropagate the top-class score down to the target layer.
            model.zero_grad(set_to_none=True)
            score = logits[:, pred_idx].sum()
            score.backward()

            # Build the CAM at the network-input resolution.
            cam = gradcam.compute_cam()[0]  # (H, W)
            cam_np = cam.detach().cpu().numpy()

        # Denormalize the network input back to a PIL image.
        input_pil = tensor_to_pil_image(inp[0], model.data_config)
        cam_vis = overlay_cam_on_image(input_pil, cam_np, alpha=0.45)

        # Prediction dict + heatmap.
        probs = probs.detach().cpu()
        class_names = current_meta["names"]
        out_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
        return out_dict, cam_vis
    finally:
        # Always detach the hooks so later predictions don't keep activations.
        gradcam.remove_hooks()


def launch():
    """Build and launch the Gradio UI (model picker, Grad-CAM toggle, examples)."""
    load_model(DEFAULT_CKPT)
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI 图像检测器")
        gr.Markdown("在左侧选择一个模型,上传一张图片,然后点击 **🚀 运行** 来查看预测结果。可选启用 Grad-CAM 查看热力图。")
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 运行", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT,
                    label="选择模型",
                )
                use_gradcam = gr.Checkbox(value=False, label="启用 Grad-CAM(可视化热力图)")
                in_img = gr.Image(type="pil", label="上传图片")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="预测结果")
                out_cam = gr.Image(label="Grad-CAM 热力图", type="pil")

        run_btn.click(predict, inputs=[in_img, sel_ckpt, use_gradcam], outputs=[out_lbl, out_cam])

        # Examples: auto-populated from any images in ./examples.
        if not os.path.exists("examples"):
            os.makedirs("examples")
        example_files = [
            os.path.join("examples", f)
            for f in os.listdir("examples")
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, False] for f in example_files],
                inputs=[in_img, sel_ckpt, use_gradcam],
                outputs=[out_lbl, out_cam],
                fn=predict,
                cache_examples=False,
            )
    demo.launch()


if __name__ == "__main__":
    launch()