import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
import models # assuming models.py is in your PYTHONPATH or same dir
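# NOTE: models.py is expected to expose the detector classes referenced in
# build_model() below (Det_F3_Net, resnet50_npr, Det_STIL, XCLIP_DeMamba,
# CLIP_DeMamba, XCLIP, CLIP_Base, ViT_B_MINTIME).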
# -------- Build model ----------
def build_model(model_name):
    if model_name == 'F3Net':
        model = models.Det_F3_Net()
    elif model_name == 'NPR':
        model = models.resnet50_npr()
    elif model_name == 'STIL':
        model = models.Det_STIL()
    elif model_name == 'XCLIP_DeMamba':
        model = models.XCLIP_DeMamba()
    elif model_name == 'CLIP_DeMamba':
        model = models.CLIP_DeMamba()
    elif model_name == 'XCLIP':
        model = models.XCLIP()
    elif model_name == 'CLIP':
        model = models.CLIP_Base()
    elif model_name == 'ViT_B_MINTIME':
        model = models.ViT_B_MINTIME()
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model
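# Quick sanity check (hypothetical usage): confirm the expected
# [B, T, C, H, W] input shape before running a full evaluation.
#
#   model = build_model("XCLIP_DeMamba")
#   dummy = torch.randn(1, 8, 3, 224, 224)  # one 8-frame clip
#   with torch.no_grad():
#       print(model(dummy).shape)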
# --------- Evaluation loop -------------
def eval_on_frames(model, frames_dir, device):
    model.eval()
    # Standard ImageNet preprocessing (matches the pretrained backbones)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Collect all frame images recursively under frames_dir
    frame_paths = []
    for root, _, files in os.walk(frames_dir):
        for f in files:
            if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                frame_paths.append(os.path.join(root, f))
    frame_paths.sort()
    results = []
    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            img = Image.open(fp).convert("RGB")
            # Transform and add batch dimension
            x = transform(img).unsqueeze(0).to(device)  # [1, C, H, W]
            # Add temporal dimension expected by DeMamba/XCLIP (T=8).
            # Note: repeating one frame yields a static clip; see
            # eval_on_clips below for a variant that samples real frames.
            x = x.unsqueeze(1).repeat(1, 8, 1, 1, 1)  # [1, 8, C, H, W]
            # Forward pass; sigmoid of the first output logit as the score
            logit = model(x)
            prob = torch.sigmoid(logit[:, 0]).item()
            pred_label = int(prob > 0.5)
            results.append({
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": pred_label
            })
    return pd.DataFrame(results)
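# A clip-level variant (sketch, assuming each immediate subdirectory of
# frames_dir holds the frames of one video): sample num_frames evenly
# spaced frames per video and stack them into a single [1, T, C, H, W]
# clip, rather than repeating one frame. Adjust the grouping to match
# your directory layout.
def eval_on_clips(model, frames_dir, device, num_frames=8):
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    results = []
    with torch.no_grad():
        for video in sorted(os.listdir(frames_dir)):
            video_dir = os.path.join(frames_dir, video)
            if not os.path.isdir(video_dir):
                continue
            frames = sorted(
                f for f in os.listdir(video_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )
            if len(frames) < num_frames:
                continue  # too few frames to form a clip
            # Evenly spaced frame indices across the video
            idxs = np.linspace(0, len(frames) - 1, num_frames).astype(int)
            clip = torch.stack([
                transform(Image.open(os.path.join(video_dir, frames[i])).convert("RGB"))
                for i in idxs
            ]).unsqueeze(0).to(device)  # [1, T, C, H, W]
            prob = torch.sigmoid(model(clip)[:, 0]).item()
            results.append({
                "video": video,
                "predicted_prob": prob,
                "predicted_label": int(prob > 0.5),
            })
    return pd.DataFrame(results)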
if __name__ == "__main__":
    # ----- config -----
    model_name = "XCLIP_DeMamba"  # change if needed
    model_path = "./results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # ---- load model ----
    model = build_model(model_name).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    # Some checkpoints store weights under "model_state_dict"; otherwise
    # fall back to treating the checkpoint itself as the state dict.
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded model weights from {model_path}")
    # ---- evaluate ----
    df_results = eval_on_frames(model, frames_dir, device)
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")
    # ---- optional: basic metrics if you have GT ----
    if "label" in df_results.columns:
        y_true = df_results["label"].values
        y_pred = df_results["predicted_label"].values
        y_prob = df_results["predicted_prob"].values
        acc = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)
        print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")