import os
import torch
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

import models  # assuming models.py is in your PYTHONPATH or same dir


# -------- Build model as per your code ----------
def build_model(model_name):
    """Instantiate the requested detector architecture from models.py."""
    if model_name == 'F3Net':
        model = models.Det_F3_Net()
    elif model_name == 'NPR':
        model = models.resnet50_npr()
    elif model_name == 'STIL':
        model = models.Det_STIL()
    elif model_name == 'XCLIP_DeMamba':
        model = models.XCLIP_DeMamba()
    elif model_name == 'CLIP_DeMamba':
        model = models.CLIP_DeMamba()
    elif model_name == 'XCLIP':
        model = models.XCLIP()
    elif model_name == 'CLIP':
        model = models.CLIP_Base()
    elif model_name == 'ViT_B_MINTIME':
        model = models.ViT_B_MINTIME()
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model


# --------- Evaluation loop -------------
def eval_on_frames(model, frames_dir, device):
    """Run the model on every image under frames_dir and return per-frame predictions."""
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Collect all image frames (recursively) and process them in a stable order
    frame_paths = []
    for root, _, files in os.walk(frames_dir):
        for f in files:
            if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                frame_paths.append(os.path.join(root, f))
    frame_paths.sort()

    results = []
    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            img = Image.open(fp).convert("RGB")

            # Transform and add batch dimension
            x = transform(img).unsqueeze(0).to(device)  # [1, C, H, W]

            # Add temporal dimension expected by DeMamba/XCLIP (T=8):
            # the single frame is repeated to form a pseudo-clip
            x = x.unsqueeze(1).repeat(1, 8, 1, 1, 1)  # [1, 8, C, H, W]

            # Forward pass; use the first logit as the per-frame detection score
            logit = model(x)
            prob = torch.sigmoid(logit[:, 0]).item()
            pred_label = int(prob > 0.5)

            results.append({
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": pred_label
            })

    return pd.DataFrame(results)


if __name__ == "__main__":
    # ----- config -----
    model_name = "XCLIP_DeMamba"  # change if needed
    model_path = "./results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- load model ----
    model = build_model(model_name).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    print(f"Loaded model weights from {model_path}")

    # ---- evaluate ----
    df_results = eval_on_frames(model, frames_dir, device)
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")

    # ---- optional: basic metrics if you have GT ----
    # Note: eval_on_frames does not produce a "label" column, so this block is
    # skipped unless you merge ground-truth labels into df_results yourself
    # (see score_against_ground_truth at the bottom of this file for a sketch).
    if "label" in df_results.columns:
        y_true = df_results["label"].values
        y_pred = df_results["predicted_label"].values
        y_prob = df_results["predicted_prob"].values

        acc = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)
        print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")
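

# ---- optional: scoring against ground truth (sketch) ----
# A minimal standalone sketch, assuming a hypothetical ground_truth.csv with
# "file_name" and "label" columns (both the file name and the label convention
# must match how the checkpoint was trained). It joins the saved per-frame
# predictions with those labels and reports the same metrics as the block
# above; call it after results.csv has been written, or adapt the merge into
# the __main__ block directly.
def score_against_ground_truth(results_csv="results.csv", gt_csv="ground_truth.csv"):
    """Join saved per-frame predictions with ground-truth labels and print metrics."""
    preds = pd.read_csv(results_csv)
    gt = pd.read_csv(gt_csv)  # hypothetical file; expected columns: file_name, label
    merged = preds.merge(gt[["file_name", "label"]], on="file_name", how="inner")

    acc = accuracy_score(merged["label"], merged["predicted_label"])
    auc = roc_auc_score(merged["label"], merged["predicted_prob"])
    ap = average_precision_score(merged["label"], merged["predicted_prob"])
    print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")
    return merged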