File size: 975 Bytes
c458c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from transformers import ConvNextForImageClassification, AutoConfig, AutoImageProcessor

# Convert a raw ConvNeXt ``.pth`` state dict into a Hugging Face model
# directory (weights + config + image processor) loadable with
# ``from_pretrained(save_dir)``.
ckpt_path = "./ConvNextmodel.pth"
base_model = "facebook/convnext-tiny-224"   # the model you started from
num_labels = 7
save_dir = "./convnext-tiny-224-7cls"


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Only keys that *start* with *prefix* are rewritten; keys that merely
    contain it further in (e.g. ``"head.module.w"``) are left untouched.
    Keys without the prefix pass through unchanged, so this is a no-op on a
    state dict that was never wrapped.

    Args:
        state_dict: Mapping of parameter name -> value.
        prefix: Leading key prefix to drop. The default ``"module."`` is the
            prefix typically added when a model is saved while wrapped in
            ``DataParallel`` / ``DistributedDataParallel``.

    Returns:
        A new ``dict`` with the same values and normalized keys.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Rebuild the classifier from *ckpt_path* and save it to *save_dir*."""
    # 1️⃣ Load the raw state dict.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth cannot run arbitrary code on load. A checkpoint saved
    # as a whole pickled model (torch.save(model)) will be rejected; re-save
    # it with torch.save(model.state_dict()) instead.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    state_dict = strip_module_prefix(state_dict)

    # 2️⃣ Rebuild config and model with a fresh num_labels-way head.
    config = AutoConfig.from_pretrained(base_model)
    # NOTE(review): changing num_labels resets id2label/label2id to generic
    # LABEL_i names; set config.id2label here if real class names are known.
    config.num_labels = num_labels
    model = ConvNextForImageClassification(config)
    # strict=True raises RuntimeError listing any missing, unexpected or
    # size-mismatched keys, so getting past this line means an exact match.
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded {len(state_dict)} entries from {ckpt_path} (exact key match)")

    # 3️⃣ Save in HF format
    model.save_pretrained(save_dir)

    # 4️⃣ Also save processor (for transforms)
    processor = AutoImageProcessor.from_pretrained(base_model)
    processor.save_pretrained(save_dir)


if __name__ == "__main__":
    main()