# eval2.py — last commit a551aaf ("Update eval2.py") by kalpitbcontrails (verified).
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
import models # assuming models.py is in your PYTHONPATH or same dir
# -------- Build a detector model by architecture name ----------
def build_model(model_name):
    """Construct the detector architecture registered under ``model_name``.

    Args:
        model_name: Architecture key, e.g. ``'XCLIP_DeMamba'`` or ``'NPR'``.

    Returns:
        A new instance from the project ``models`` module, built with the
        constructor's default arguments.

    Raises:
        ValueError: If ``model_name`` matches none of the registered keys.
    """
    # (public name, constructor attribute on the ``models`` module), checked in order.
    # Attributes are resolved lazily so only the selected constructor is touched.
    registry = (
        ('F3Net', 'Det_F3_Net'),
        ('NPR', 'resnet50_npr'),
        ('STIL', 'Det_STIL'),
        ('XCLIP_DeMamba', 'XCLIP_DeMamba'),
        ('CLIP_DeMamba', 'CLIP_DeMamba'),
        ('XCLIP', 'XCLIP'),
        ('CLIP', 'CLIP_Base'),
        ('ViT_B_MINTIME', 'ViT_B_MINTIME'),
    )
    for key, ctor_name in registry:
        if model_name == key:
            return getattr(models, ctor_name)()
    raise ValueError(f"Unknown model: {model_name}")
# --------- Evaluation loop -------------
def _collect_frame_paths(frames_dir):
    """Return a sorted list of .png/.jpg/.jpeg paths found under ``frames_dir``.

    The search is recursive (``os.walk``) and the extension test is
    case-insensitive.
    """
    frame_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk(frames_dir)
        for name in files
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]
    frame_paths.sort()
    return frame_paths


def _build_frame_transform():
    """Return the per-frame preprocessing: resize to 224x224, to tensor, normalize."""
    # Standard ImageNet mean/std.
    # NOTE(review): CLIP/XCLIP backbones are usually trained with CLIP's own
    # mean/std — confirm this matches the normalization used at training time.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def eval_on_frames(model, frames_dir, device, num_frames=8):
    """Score every image frame found under ``frames_dir`` with ``model``.

    Each frame is preprocessed individually and repeated ``num_frames`` times
    along a new temporal axis, so a single still image is fed to the model as
    a ``[1, num_frames, C, H, W]`` clip.

    Args:
        model: Torch module; put in eval mode and called once per frame.
        frames_dir: Directory searched recursively for .png/.jpg/.jpeg files.
        device: Device the input tensors are moved to.
        num_frames: Clip length built from each frame (default 8, the value
            previously hard-coded).

    Returns:
        DataFrame with one row per frame and columns:
        ``file_name`` (basename), ``predicted_prob`` (sigmoid of column 0 of
        the model output), ``predicted_label`` (1 if prob > 0.5 else 0) and
        ``file_path`` (path relative to ``frames_dir``, which disambiguates
        equal basenames in different subdirectories). When no frames are
        found the DataFrame is empty but has the same columns.

    Raises:
        ValueError: If ``num_frames`` is less than 1.
        FileNotFoundError: If ``frames_dir`` is not an existing directory.
    """
    columns = ["file_name", "predicted_prob", "predicted_label", "file_path"]
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    # os.walk silently yields nothing for a missing path; fail loudly instead
    # of writing an empty results file for a mistyped directory.
    if not os.path.isdir(frames_dir):
        raise FileNotFoundError(f"Frames directory not found: {frames_dir}")

    model.eval()
    frame_paths = _collect_frame_paths(frames_dir)
    if not frame_paths:
        # Keep the schema stable so downstream CSV/metrics code finds its columns.
        return pd.DataFrame(columns=columns)

    transform = _build_frame_transform()
    results = []
    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            # Context manager releases the underlying file handle promptly.
            with Image.open(fp) as raw:
                img = raw.convert("RGB")
            x = transform(img).unsqueeze(0).to(device)  # [1, C, H, W]
            # Repeat the still frame along a new time axis: [1, T, C, H, W].
            x = x.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
            logit = model(x)
            # NOTE(review): assumes a [batch, >=1] output whose column 0 is the
            # positive-class logit — confirm for each architecture.
            prob = torch.sigmoid(logit[:, 0]).item()
            results.append({
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": int(prob > 0.5),
                "file_path": os.path.relpath(fp, frames_dir),
            })
    return pd.DataFrame(results, columns=columns)
if __name__ == "__main__":
    # ----- config -----
    model_name = "XCLIP_DeMamba"  # change if needed
    model_path = "./results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"
    # Optional ground truth: a CSV with columns `file_name,label` (label 0/1),
    # joined on the frame basename. Leave as None to skip metrics.
    labels_csv = None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- load model ----
    model = build_model(model_name).to(device)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    # Accept both {"model_state_dict": ...} wrappers and bare state dicts.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    # strict=False tolerates partial matches, but silently ignored keys can leave
    # the model at its initial weights (e.g. a "module." prefix from DataParallel),
    # so report any mismatch.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {len(incompatible.missing_keys)} missing keys, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.unexpected_keys)} unexpected keys, "
              f"e.g. {incompatible.unexpected_keys[:5]}")
    print(f"Loaded model weights from {model_path}")

    # ---- evaluate ----
    df_results = eval_on_frames(model, frames_dir, device)
    if labels_csv is not None:
        labels = pd.read_csv(labels_csv, usecols=["file_name", "label"])
        # many_to_one: each frame name must appear at most once in the labels file.
        df_results = df_results.merge(labels, on="file_name", how="left",
                                      validate="many_to_one")
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")

    # ---- optional: basic metrics if you have GT ----
    # eval_on_frames does not emit "label"; it is present only after the merge above.
    if "label" in df_results.columns:
        scored = df_results.dropna(subset=["label"])
        if scored.empty:
            print("No labelled frames matched; skipping metrics.")
        else:
            y_true = scored["label"].astype(int).values
            y_pred = scored["predicted_label"].values
            y_prob = scored["predicted_prob"].values
            acc = accuracy_score(y_true, y_pred)
            if len(np.unique(y_true)) < 2:
                # AUC and AP are undefined when only one class is present.
                print(f"Accuracy: {acc:.4f} (AUC/AP skipped: only one class in labels)")
            else:
                auc = roc_auc_score(y_true, y_prob)
                ap = average_precision_score(y_true, y_prob)
                print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")