File size: 3,673 Bytes
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a551aaf
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

import models  # assuming models.py is in your PYTHONPATH or same dir

# -------- Build model as per your code ----------
def build_model(model_name):
    """Instantiate the detector registered under ``model_name``.

    The name is compared, in order, against a fixed list of supported
    identifiers. The first match selects a constructor on the ``models``
    module, which is called with no arguments.

    Args:
        model_name: One of 'F3Net', 'NPR', 'STIL', 'XCLIP_DeMamba',
            'CLIP_DeMamba', 'XCLIP', 'CLIP' or 'ViT_B_MINTIME'.

    Returns:
        The newly constructed model object.

    Raises:
        ValueError: If ``model_name`` matches none of the supported names.
    """
    # (public identifier, constructor attribute on ``models``), checked in order.
    registry = (
        ('F3Net', 'Det_F3_Net'),
        ('NPR', 'resnet50_npr'),
        ('STIL', 'Det_STIL'),
        ('XCLIP_DeMamba', 'XCLIP_DeMamba'),
        ('CLIP_DeMamba', 'CLIP_DeMamba'),
        ('XCLIP', 'XCLIP'),
        ('CLIP', 'CLIP_Base'),
        ('ViT_B_MINTIME', 'ViT_B_MINTIME'),
    )
    for known_name, factory_attr in registry:
        if model_name == known_name:
            # Resolve the constructor only once a name matched, then build.
            return getattr(models, factory_attr)()
    raise ValueError(f"Unknown model: {model_name}")


# --------- Evaluation loop -------------
def eval_on_frames(model, frames_dir, device, *, num_frames: int = 8,
                   threshold: float = 0.5):
    """Score every image under ``frames_dir`` with ``model``, one at a time.

    The directory is walked recursively; files ending in .png/.jpg/.jpeg
    (case-insensitive) are processed in sorted path order. Each image is
    resized to 224x224, normalised with the mean/std constants below, and
    repeated ``num_frames`` times along a new second axis so a single still
    is presented as a clip of shape [1, num_frames, C, H, W].

    Args:
        model: Callable module; ``model.eval()`` is invoked before scoring.
            Its output is indexed as ``logit[:, 0]`` and must reduce to a
            single scalar per image.
        frames_dir: Root directory to search for image files.
        device: Device the input tensor is moved to before the forward pass.
        num_frames: How many times the still is repeated along the temporal
            axis. Defaults to 8, the value previously hard-coded.
        threshold: A probability strictly above this value yields label 1.

    Returns:
        pandas.DataFrame with columns ``file_name``, ``predicted_prob`` and
        ``predicted_label``, one row per image. The columns are present even
        when no image is found. ``file_name`` is the basename only, so
        same-named frames from different sub-directories cannot be told
        apart in the output.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    model.eval()

    columns = ["file_name", "predicted_prob", "predicted_label"]
    image_exts = ('.png', '.jpg', '.jpeg')

    frame_paths = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(frames_dir)
        for name in files
        if name.lower().endswith(image_exts)
    )
    if not frame_paths:
        # Nothing to score: return a frame that still carries the schema so
        # downstream CSV writing produces a header instead of an empty file.
        return pd.DataFrame(columns=columns)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    results = []

    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            # The context manager closes the underlying file handle promptly;
            # convert() returns an independent, fully loaded RGB copy.
            with Image.open(fp) as raw:
                img = raw.convert("RGB")

            # Transform and add batch dimension
            x = transform(img).unsqueeze(0).to(device)  # [1, C, H, W]

            # Add the temporal dimension by repeating the still image
            x = x.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)  # [1, T, C, H, W]

            # Forward pass
            # NOTE(review): assumes the model returns a [1, K] tensor whose
            # first column is the logit of interest — confirm per model.
            logit = model(x)
            prob = torch.sigmoid(logit[:, 0]).item()
            pred_label = int(prob > threshold)

            results.append({
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": pred_label
            })

    return pd.DataFrame(results, columns=columns)


if __name__ == "__main__":
    # Script entry point: load a checkpoint into the selected model, score
    # every frame under ``frames_dir`` and write the predictions to a CSV.

    # ----- config -----
    model_name = "XCLIP_DeMamba"  # change if needed
    model_path = "./results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- load model ----
    model = build_model(model_name).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    # strict=False tolerates key mismatches, so report them explicitly:
    # otherwise a checkpoint that matches nothing would "load" without any
    # sign that the model is still running on its initial weights.
    load_report = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    if load_report.missing_keys:
        print(f"Warning: {len(load_report.missing_keys)} model keys missing "
              f"from checkpoint: {load_report.missing_keys}")
    if load_report.unexpected_keys:
        print(f"Warning: {len(load_report.unexpected_keys)} unexpected "
              f"checkpoint keys ignored: {load_report.unexpected_keys}")
    print(f"Loaded model weights from {model_path}")

    # ---- evaluate ----
    df_results = eval_on_frames(model, frames_dir, device)
    if df_results.empty:
        print(f"Warning: no image frames found under {frames_dir}")
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")

    # ---- optional: basic metrics if you have GT ----
    # NOTE(review): eval_on_frames in this file only emits file_name,
    # predicted_prob and predicted_label, so this branch does not run unless
    # a "label" column is merged into df_results beforehand.
    if "label" in df_results.columns:
        y_true = df_results["label"].values
        y_pred = df_results["predicted_label"].values
        y_prob = df_results["predicted_prob"].values

        acc = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)

        print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")