import json
import logging
import os
import re
from copy import deepcopy
from pathlib import Path

import torch

from .model import CLAP, convert_weights_to_fp16
from .openai import load_openai_model
from .pretrained import get_pretrained_url, download_pretrained
from .transform import image_transform

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # registry of model architecture configs, keyed by config file stem


def _natural_key(string_):
    # Sort key that orders embedded integers numerically, e.g. "x2" before "x10".
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        if os.path.basename(cf)[0] == ".":
            continue  # skip hidden files
        with open(cf, "r") as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {
        k: v
        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
    }


_rescan_model_configs()  # initial population of the model config registry
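

# Each JSON file under model_configs/ describes one architecture and is only
# registered if it carries the "embed_dim", "audio_cfg", and "text_cfg" keys.
# A minimal sketch of the expected shape (field values are illustrative, not
# taken from a real config):
#
#   {
#       "embed_dim": 512,
#       "audio_cfg": {"model_type": "HTSAT", "...": "..."},
#       "text_cfg": {"model_type": "transformer", "...": "..."}
#   }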


def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if skip_params:
        if next(iter(state_dict.items()))[0].startswith("module"):
            # strip the "module." prefix left behind by DataParallel/DDP wrappers
            state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
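

# A minimal usage sketch for load_state_dict ("epoch_10.pt" is a hypothetical
# checkpoint path); it accepts both raw state dicts and {"state_dict": ...}
# wrappers, and with skip_params=True strips DataParallel's "module." prefix:
#
#   sd = load_state_dict("epoch_10.pt", map_location="cpu")
#   model.load_state_dict(sd)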


def create_model(
    amodel_name: str,
    tmodel_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
    openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
    skip_params=True,
    pretrained_audio: str = "",
    pretrained_text: str = "",
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    amodel_name = amodel_name.replace("/", "-")  # for callers using "/" instead of "-"
    pretrained_orig = pretrained
    pretrained = pretrained.lower()
    if pretrained == "openai":
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        logging.info("Loading pretrained ViT-B-16 text encoder from OpenAI.")

        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model = load_openai_model(
            "ViT-B-16",
            model_cfg,
            device=device,
            jit=jit,
            cache_dir=openai_model_cache_dir,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )

        # AMP and fp32 training expect float32 weights
        if precision == "amp" or precision == "fp32":
            model = model.float()
    else:
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        if force_quick_gelu:
            model_cfg["quick_gelu"] = True

        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model_cfg["enable_fusion"] = enable_fusion
        model_cfg["fusion_type"] = fusion_type
        model = CLAP(**model_cfg)

        if pretrained:
            checkpoint_path = ""
            url = get_pretrained_url(amodel_name, pretrained)
            if url:
                checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
            elif os.path.exists(pretrained_orig):
                checkpoint_path = pretrained_orig
            if checkpoint_path:
                logging.info(
                    f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
                )
                ckpt = load_state_dict(checkpoint_path, skip_params=True)
                model.load_state_dict(ckpt)
            else:
                logging.warning(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )
                raise RuntimeError(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )

        if pretrained_audio:
            if amodel_name.startswith("PANN"):
                if "Cnn14_mAP" in pretrained_audio:  # official PANN checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["model"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            # remap "xxx" -> "audio_branch.xxx"
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key] = v
                elif os.path.basename(pretrained_audio).startswith("PANN"):
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            # remap "sed_model.xxx" -> "audio_branch.xxx"
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith("finetuned"):
                    # checkpoint already in CLAP key layout
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            elif amodel_name.startswith("HTSAT"):
                if "HTSAT_AudioSet_Saved" in pretrained_audio:  # official HTSAT checkpoint
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model") and (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            # remap "sed_model.xxx" -> "audio_branch.xxx"
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith("HTSAT"):
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            # remap "sed_model.xxx" -> "audio_branch.xxx"
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith("finetuned"):
                    # checkpoint already in CLAP key layout
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            else:
                raise ValueError(
                    f"Unsupported audio encoder for pretrained checkpoint: {amodel_name}"
                )

            model.load_state_dict(audio_ckpt, strict=False)
            logging.info(
                f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
            )
            param_names = [n for n, p in model.named_parameters()]
            for n in param_names:
                print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")

    model.to(device=device)
    if precision == "fp16":
        assert device.type != "cpu"
        convert_weights_to_fp16(model)

    if jit:
        model = torch.jit.script(model)

    return model, model_cfg
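

# A minimal usage sketch for create_model; "HTSAT-tiny" and "roberta" are
# illustrative names that must correspond to a JSON config under
# model_configs/ and a supported text branch, respectively:
#
#   model, model_cfg = create_model(
#       "HTSAT-tiny",
#       "roberta",
#       precision="fp32",
#       device=torch.device("cpu"),
#       enable_fusion=False,
#   )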


def create_model_and_transforms(
    amodel_name: str,
    tmodel_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
):
    # create_model returns (model, model_cfg); only the model is needed here
    model, _ = create_model(
        amodel_name,
        tmodel_name,
        pretrained,
        precision,
        device,
        jit,
        force_quick_gelu=force_quick_gelu,
    )
    # NOTE: assumes the model exposes model.visual.image_size for the image transforms
    preprocess_train = image_transform(model.visual.image_size, is_train=True)
    preprocess_val = image_transform(model.visual.image_size, is_train=False)
    return model, preprocess_train, preprocess_val


def list_models():
    """enumerate available model architectures based on config files"""
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """add model config path or file and update registry"""
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
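

# Example: registering an out-of-tree config directory ("/path/to/my_configs"
# is a hypothetical location); any *.json files inside it with the required
# keys become visible to list_models():
#
#   add_model_config("/path/to/my_configs")
#   print(list_models())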