"""Convert a raw ConvNeXt PyTorch checkpoint into a Hugging Face model directory.

Steps performed at module level:
1. load the checkpoint at ``ckpt_path`` and normalise its parameter names;
2. rebuild ``ConvNextForImageClassification`` from ``base_model``'s config with
   ``num_labels`` outputs and load the weights strictly;
3. write the model to ``save_dir`` with ``save_pretrained``;
4. copy ``base_model``'s image processor into the same directory.
"""

import torch
from transformers import ConvNextForImageClassification, AutoConfig, AutoImageProcessor

ckpt_path = "./ConvNextmodel.pth"
base_model = "facebook/convnext-tiny-224"  # the model you started from
num_labels = 7

# Leading prefix typically added to parameter names by wrappers such as
# torch.nn.DataParallel.
_WRAPPER_PREFIX = "module."


def _strip_wrapper_prefix(state_dict):
    """Return *state_dict* with a leading ``module.`` removed from its keys.

    Only a true leading prefix is stripped; a key that merely contains
    ``module.`` further in is kept as-is. When no key carries the prefix the
    original mapping is returned unchanged (not copied).
    """
    if not any(key.startswith(_WRAPPER_PREFIX) for key in state_dict):
        return state_dict
    return {
        (key[len(_WRAPPER_PREFIX):] if key.startswith(_WRAPPER_PREFIX) else key): value
        for key, value in state_dict.items()
    }


# 1️⃣ Load the raw state dict
# NOTE(review): torch.load unpickles the file (unless weights_only=True is in
# effect for your torch version) - only load checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location="cpu")
# Some training loops save {"state_dict": {...}, ...} instead of the bare
# weights; unwrap that case so the strict load below sees parameter names.
if isinstance(state_dict, dict) and isinstance(state_dict.get("state_dict"), dict):
    state_dict = state_dict["state_dict"]
state_dict = _strip_wrapper_prefix(state_dict)

# 2️⃣ Rebuild config and model
config = AutoConfig.from_pretrained(base_model)
config.num_labels = num_labels
model = ConvNextForImageClassification(config)
# With strict=True, load_state_dict raises RuntimeError on any missing or
# unexpected key, so when this print is reached both lists are empty - it
# serves as a confirmation that every weight matched.
missing, unexpected = model.load_state_dict(state_dict, strict=True)
print("Missing:", missing, "Unexpected:", unexpected)

# 3️⃣ Save in HF format
save_dir = "./convnext-tiny-224-7cls"
model.save_pretrained(save_dir)

# 4️⃣ Also save processor (for transforms)
processor = AutoImageProcessor.from_pretrained(base_model)
processor.save_pretrained(save_dir)