Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,318 Bytes
9507532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
"""
Model Factory for MapAnything
"""
import importlib.util
import logging
import warnings
import numpy as np
from omegaconf import DictConfig, OmegaConf
# Core models that are always available
from mapanything.models.mapanything import (
MapAnything,
MapAnythingAblations,
ModularDUSt3R,
)
# Keep third-party import chatter quiet: the "dinov2" logger only emits WARNING and
# above (its DEBUG/INFO records are dropped), and the UserWarnings announcing whether
# xFormers is or is not available are ignored.
logging.getLogger("dinov2").setLevel(logging.WARNING)
for _xformers_notice in ("xFormers is available", "xFormers is not available"):
    warnings.filterwarnings("ignore", message=_xformers_notice, category=UserWarning)
# Drop the loop variable so it does not linger as a module attribute.
del _xformers_notice
def resolve_special_float(value):
    """
    OmegaConf resolver that converts special-float tokens into floats.

    ``init_model`` registers this function under the resolver name ``special_float``.

    Accepted inputs:
        * the strings "inf", "+inf", "-inf" and "nan" (case-insensitive, surrounding
          whitespace ignored);
        * an already-parsed non-finite float (inf, -inf or nan), which is returned
          unchanged. OmegaConf may parse an unquoted ``inf``/``nan`` resolver argument
          into a float before calling the resolver -- NOTE(review): confirm against the
          OmegaConf version in use.

    Args:
        value: The resolver argument.

    Returns:
        float: ``np.inf``, ``-np.inf`` or ``np.nan`` (or the given non-finite float).

    Raises:
        ValueError: If ``value`` is not one of the accepted inputs.
    """
    # Accept values that arrive already parsed as non-finite floats.
    if isinstance(value, float) and not np.isfinite(value):
        return value
    if isinstance(value, str):
        special_values = {
            "inf": np.inf,
            "+inf": np.inf,
            "-inf": -np.inf,
            "nan": np.nan,
        }
        token = value.strip().lower()
        if token in special_values:
            return special_values[token]
    raise ValueError(f"Unknown special float value: {value}")
def init_model(
    model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
):
    """
    Build a model from an OmegaConf configuration.

    The configuration is fully resolved into a plain container, and its entries are
    passed as keyword arguments to ``model_factory``.

    Args:
        model_str (str): Registry name of the model to create.
        model_config (DictConfig): OmegaConf configuration holding the constructor
            arguments.
        torch_hub_force_reload (bool): Whether to force reload relevant parts of the
            model from torch hub; forwarded to ``model_factory``.

    Returns:
        The model instance returned by ``model_factory``.
    """
    # Register the "special_float" resolver only if it is not registered yet, so that
    # interpolations using it can be resolved by to_container(resolve=True) below.
    if not OmegaConf.has_resolver("special_float"):
        OmegaConf.register_new_resolver("special_float", resolve_special_float)

    return model_factory(
        model_str,
        torch_hub_force_reload=torch_hub_force_reload,
        **OmegaConf.to_container(model_config, resolve=True),
    )
# Registry of every model name accepted by model_factory. Each entry has one of two forms:
#   {"class": <class>}                          -> core model; class imported at module load
#   {"module": <dotted path>, "class_name": <n>} -> external model; model_factory imports
#                                                   <dotted path> on demand and reads <n>
# Insertion order is the order reported by get_available_models().
MODEL_CONFIGS = {
    # Core models
    "mapanything": dict(**{"class": MapAnything}),
    "mapanything_ablations": dict(**{"class": MapAnythingAblations}),
    "modular_dust3r": dict(**{"class": ModularDUSt3R}),
    # External models
    "anycalib": dict(
        module="mapanything.models.external.anycalib", class_name="AnyCalibWrapper"
    ),
    "dust3r": dict(
        module="mapanything.models.external.dust3r", class_name="DUSt3RBAWrapper"
    ),
    "mast3r": dict(
        module="mapanything.models.external.mast3r", class_name="MASt3RSGAWrapper"
    ),
    "moge": dict(
        module="mapanything.models.external.moge", class_name="MoGeWrapper"
    ),
    "must3r": dict(
        module="mapanything.models.external.must3r", class_name="MUSt3RWrapper"
    ),
    "pi3": dict(
        module="mapanything.models.external.pi3", class_name="Pi3Wrapper"
    ),
    "pow3r": dict(
        module="mapanything.models.external.pow3r", class_name="Pow3RWrapper"
    ),
    "pow3r_ba": dict(
        module="mapanything.models.external.pow3r", class_name="Pow3RBAWrapper"
    ),
    "vggt": dict(
        module="mapanything.models.external.vggt", class_name="VGGTWrapper"
    ),
    # Register additional models here using one of the two forms above.
}
def check_module_exists(module_path):
    """
    Check whether the import system can locate a module.

    The target module itself is not executed. For a dotted path such as ``a.b.c``,
    however, ``importlib.util.find_spec`` imports the parent packages (``a`` and
    ``a.b``) as part of the lookup.

    Args:
        module_path (str): Dotted path of the module to check.

    Returns:
        bool: True if a module spec is found. False if the module cannot be found, if a
        parent package is missing or fails to import, or if the module is already
        loaded without a ``__spec__``.
    """
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ImportError, ValueError):
        # find_spec raises ModuleNotFoundError (an ImportError subclass) when a parent
        # package is missing, instead of returning None. It raises ValueError when the
        # module is in sys.modules with __spec__ set to None. Report both as "not
        # importable" so callers get the documented bool.
        return False
def model_factory(model_str: str, **kwargs):
    """
    Model factory for MapAnything.

    Looks up ``model_str`` in ``MODEL_CONFIGS`` and instantiates the matching class.
    Core entries store the class object directly. External entries store a module path
    and class name, and are imported only when requested.

    Args:
        model_str (str): Name of the model to create (a key of ``MODEL_CONFIGS``).
        **kwargs: Additional keyword arguments to pass to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``model_str`` is not registered, or its registry entry has
            neither a "class" nor a "module" key.
        ImportError: If an external model's module cannot be found, or the module or
            class fails to import. The underlying error is chained as ``__cause__``.
    """
    if model_str not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
        )
    model_config = MODEL_CONFIGS[model_str]

    if "class" in model_config:
        # Core model: the class object is stored directly in the registry.
        model_class = model_config["class"]
    elif "module" in model_config:
        # External model: import lazily so its dependencies are only needed on use.
        module_path = model_config["module"]
        class_name = model_config["class_name"]

        if not check_module_exists(module_path):
            raise ImportError(
                f"Model '{model_str}' requires module '{module_path}' which is not installed. "
                "Please install the corresponding submodule or package."
            )

        try:
            module = importlib.import_module(module_path)
            model_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # Chain explicitly so the real failure (e.g. a missing transitive
            # dependency) shows up as the direct cause in the traceback.
            raise ImportError(
                f"Failed to import {class_name} from {module_path}: {e}"
            ) from e
    else:
        raise ValueError(f"Invalid model configuration for {model_str}")

    print(f"Initializing {model_class} with kwargs: {kwargs}")
    if model_str != "org_dust3r":
        return model_class(**kwargs)
    else:
        # NOTE(review): "org_dust3r" is not a key of MODEL_CONFIGS in this file, so this
        # branch is unreachable unless the registry is extended elsewhere. eval() runs
        # arbitrary code taken from kwargs["model_eval_str"] (with this function's
        # locals, e.g. kwargs and model_class, in scope). It is only acceptable for
        # fully trusted configurations and should not be fed external input.
        eval_str = kwargs.get("model_eval_str", None)
        return eval(eval_str)
def get_available_models() -> list:
    """
    List every model name that ``model_factory`` accepts.

    The names are the keys of ``MODEL_CONFIGS``, in registry order. Whether an external
    model's optional module is installed is not checked here. A new list is built on
    each call, so callers may modify it freely.

    Returns:
        list: Registered model names.
    """
    return [name for name in MODEL_CONFIGS]
# Public API exported by a star import of this module. init_model is the
# configuration-driven entry point alongside model_factory, so it is exported too.
__all__ = ["model_factory", "get_available_models", "init_model"]
|