|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import torch.nn.functional as F |
|
|
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score |
|
|
|
|
|
import models |
|
|
|
|
|
|
|
|
def build_model(model_name):
    """Instantiate the detector registered under ``model_name``.

    Args:
        model_name: Identifier of the architecture to build. The recognised
            identifiers are listed in ``constructor_names`` below.

    Returns:
        A freshly constructed model, created by calling the matching
        constructor on the ``models`` module with no arguments.

    Raises:
        ValueError: If ``model_name`` matches none of the known identifiers.
    """
    # (identifier, attribute name on ``models``) pairs. The constructor is
    # looked up only for the matching entry, so an architecture that is not
    # requested is never touched.
    constructor_names = (
        ('F3Net', 'Det_F3_Net'),
        ('NPR', 'resnet50_npr'),
        ('STIL', 'Det_STIL'),
        ('XCLIP_DeMamba', 'XCLIP_DeMamba'),
        ('CLIP_DeMamba', 'CLIP_DeMamba'),
        ('XCLIP', 'XCLIP'),
        ('CLIP', 'CLIP_Base'),
        ('ViT_B_MINTIME', 'ViT_B_MINTIME'),
    )

    for known_name, constructor_attr in constructor_names:
        if model_name == known_name:
            return getattr(models, constructor_attr)()

    raise ValueError(f"Unknown model: {model_name}")
|
|
|
|
|
|
|
|
|
|
|
def _collect_frame_paths(frames_dir):
    """Return the sorted paths of all .png/.jpg/.jpeg files under ``frames_dir``."""
    image_suffixes = ('.png', '.jpg', '.jpeg')
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(frames_dir)
        for name in names
        if name.lower().endswith(image_suffixes)
    )


def eval_on_frames(model, frames_dir, device, num_frames=8):
    """Run ``model`` on every image found under ``frames_dir``.

    Each image is converted to RGB, resized to 224x224, normalised, and
    repeated ``num_frames`` times along a new second axis, so the model
    receives a tensor of shape ``(1, num_frames, 3, 224, 224)``.

    Args:
        model: Callable module; it is switched to eval mode before inference.
            Its output is indexed as ``logit[:, 0]``, so it must return a
            tensor with at least two dimensions.
        frames_dir: Directory searched recursively for image files
            (``.png``, ``.jpg``, ``.jpeg``, case-insensitive).
        device: Device the input tensor is moved to before the forward pass.
        num_frames: Number of copies of the frame stacked into the clip
            (defaults to 8, the previously hard-coded value).

    Returns:
        A ``pandas.DataFrame`` with one row per image, in sorted path order,
        and the columns ``file_name``, ``predicted_prob`` and
        ``predicted_label``. The columns are present even when no image is
        found.
    """
    result_columns = ["file_name", "predicted_prob", "predicted_label"]

    model.eval()

    frame_paths = _collect_frame_paths(frames_dir)
    if not frame_paths:
        # Keep the schema stable so callers can index columns on empty input.
        return pd.DataFrame(columns=result_columns)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    results = []

    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            # Close the file handle promptly instead of relying on GC;
            # convert() returns an independent in-memory copy.
            with Image.open(fp) as raw_image:
                img = raw_image.convert("RGB")

            # (3, 224, 224) -> (1, 3, 224, 224)
            x = transform(img).unsqueeze(0).to(device)

            # Build a static clip by repeating the frame:
            # (1, 3, 224, 224) -> (1, num_frames, 3, 224, 224)
            x = x.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

            logit = model(x)
            # Sigmoid of the first output column is used as the score.
            # NOTE(review): which class a high score denotes is not visible
            # here - confirm against the training code.
            prob = torch.sigmoid(logit[:, 0]).item()
            pred_label = int(prob > 0.5)

            results.append({
                # NOTE(review): only the base name is stored, so same-named
                # files in different sub-directories are indistinguishable.
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": pred_label,
            })

    return pd.DataFrame(results, columns=result_columns)
|
|
|
|
|
|
|
|
def main():
    """Load a checkpoint, score every frame in ``frames_dir`` and save a CSV.

    The model name, checkpoint path, frame directory and output path are
    fixed in the body. Prints the selected device, any state-dict key
    mismatches, and the output location. Metrics are printed only when the
    result table has a ``label`` column.
    """
    model_name = "XCLIP_DeMamba"
    model_path = "./results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = build_model(model_name).to(device)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # strict=False tolerates key mismatches, which can leave layers with
    # their initial weights. Report them instead of discarding them silently.
    load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} parameter(s) missing "
              f"from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected key(s) "
              f"in checkpoint: {load_result.unexpected_keys}")
    print(f"Loaded model weights from {model_path}")

    df_results = eval_on_frames(model, frames_dir, device)
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")

    # NOTE(review): eval_on_frames, as defined in this file, emits only
    # file_name / predicted_prob / predicted_label, so this branch does not
    # run unless ground-truth labels are merged into df_results first.
    if "label" in df_results.columns:
        y_true = df_results["label"].values
        y_pred = df_results["predicted_label"].values
        y_prob = df_results["predicted_prob"].values

        acc = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)

        print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")


if __name__ == "__main__":
    main()
|
|
|