aknapitsch user
initial commit of map anything demo
9507532
"""
Model Factory for MapAnything
"""
import importlib.util
import logging
import warnings
import numpy as np
from omegaconf import DictConfig, OmegaConf
# Core models that are always available
from mapanything.models.mapanything import (
MapAnything,
MapAnythingAblations,
ModularDUSt3R,
)
# Quiet noisy third-party startup output: cap the "dinov2" logger at WARNING
# and ignore the UserWarnings that announce whether xFormers is (not) available.
logging.getLogger("dinov2").setLevel(logging.WARNING)
for _xformers_message in ("xFormers is available", "xFormers is not available"):
    warnings.filterwarnings(
        "ignore", message=_xformers_message, category=UserWarning
    )
# Don't leave the loop variable behind as a module attribute.
del _xformers_message
def resolve_special_float(value):
    """
    Map the strings "inf" / "-inf" to numpy's positive / negative infinity.

    Registered as the OmegaConf resolver "special_float" by init_model.

    Args:
        value: The resolver argument; only "inf" and "-inf" are recognised.

    Returns:
        float: np.inf for "inf", -np.inf for "-inf".

    Raises:
        ValueError: For any other value.
    """
    for token, special in (("inf", np.inf), ("-inf", -np.inf)):
        if value == token:
            return special
    raise ValueError(f"Unknown special float value: {value}")
def init_model(
    model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
):
    """
    Build a model from an OmegaConf configuration.

    Registers the "special_float" resolver (once per process) so configs can use
    `${special_float:inf}`-style interpolations, resolves the whole config into
    plain Python containers, and forwards its entries as keyword arguments to
    model_factory.

    Args:
        model_str (str): Registry name of the model class to create.
        model_config (DictConfig): OmegaConf model configuration; its resolved
            entries become constructor keyword arguments.
        torch_hub_force_reload (bool): Whether to force reload relevant parts of
            the model from torch hub.

    Returns:
        The model instance produced by model_factory.
    """
    # Only register when missing, so repeated calls don't re-register the resolver.
    resolver_ready = OmegaConf.has_resolver("special_float")
    if not resolver_ready:
        OmegaConf.register_new_resolver("special_float", resolve_special_float)

    constructor_kwargs = OmegaConf.to_container(model_config, resolve=True)
    return model_factory(
        model_str,
        torch_hub_force_reload=torch_hub_force_reload,
        **constructor_kwargs,
    )
# Registry of model names understood by model_factory.
# - Entries with "class" hold a class object imported eagerly at module load.
# - Entries with "module" + "class_name" are imported lazily by model_factory,
#   so their (optional) dependencies are only needed when requested.
MODEL_CONFIGS = {
    # Core models
    "mapanything": {"class": MapAnything},
    "mapanything_ablations": {"class": MapAnythingAblations},
    "modular_dust3r": {"class": ModularDUSt3R},
    # External models
    "anycalib": dict(
        module="mapanything.models.external.anycalib",
        class_name="AnyCalibWrapper",
    ),
    "dust3r": dict(
        module="mapanything.models.external.dust3r",
        class_name="DUSt3RBAWrapper",
    ),
    "mast3r": dict(
        module="mapanything.models.external.mast3r",
        class_name="MASt3RSGAWrapper",
    ),
    "moge": dict(
        module="mapanything.models.external.moge",
        class_name="MoGeWrapper",
    ),
    "must3r": dict(
        module="mapanything.models.external.must3r",
        class_name="MUSt3RWrapper",
    ),
    "pi3": dict(
        module="mapanything.models.external.pi3",
        class_name="Pi3Wrapper",
    ),
    "pow3r": dict(
        module="mapanything.models.external.pow3r",
        class_name="Pow3RWrapper",
    ),
    "pow3r_ba": dict(
        module="mapanything.models.external.pow3r",
        class_name="Pow3RBAWrapper",
    ),
    "vggt": dict(
        module="mapanything.models.external.vggt",
        class_name="VGGTWrapper",
    ),
    # Add other model classes here
}
def check_module_exists(module_path):
    """
    Check whether a module can be located, without importing the module itself.

    For a dotted path, importlib.util.find_spec imports the parent packages
    in order to locate the submodule, so those parents *are* imported.

    Args:
        module_path (str): Dotted path of the module to check.

    Returns:
        bool: True if a module spec is found, False otherwise. This includes
        the case where a parent package of a dotted path is missing or fails
        to import.
    """
    try:
        return importlib.util.find_spec(module_path) is not None
    except ImportError:
        # find_spec raises ModuleNotFoundError (an ImportError subclass)
        # instead of returning None when a parent package cannot be imported.
        return False
def model_factory(model_str: str, **kwargs):
    """
    Model factory for MapAnything.

    Looks up model_str in MODEL_CONFIGS:
    - Core models carry their class directly under "class".
    - External models are imported on demand from "module" / "class_name".

    Args:
        model_str (str): Name of the model to create (a key of MODEL_CONFIGS).
        **kwargs: Additional keyword arguments to pass to the model constructor.

    Returns:
        nn.Module: An instance of the specified model.

    Raises:
        ValueError: If model_str is unknown, or if its registry entry has
            neither "class" nor "module". Also raised for "org_dust3r" when
            no "model_eval_str" keyword argument is given.
        ImportError: If an external model's module cannot be found, or if
            the class cannot be imported from it.
    """
    if model_str not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
        )

    model_config = MODEL_CONFIGS[model_str]

    if "class" in model_config:
        # Core models: the class object was imported at module load.
        model_class = model_config["class"]
    elif "module" in model_config:
        # External models: resolve the class dynamically.
        module_path = model_config["module"]
        class_name = model_config["class_name"]
        # Fail early with an actionable message when the module is absent.
        if not check_module_exists(module_path):
            raise ImportError(
                f"Model '{model_str}' requires module '{module_path}' which is not installed. "
                f"Please install the corresponding submodule or package."
            )
        try:
            module = importlib.import_module(module_path)
            model_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # Chain the original error so its traceback is not lost.
            raise ImportError(
                f"Failed to import {class_name} from {module_path}: {str(e)}"
            ) from e
    else:
        raise ValueError(f"Invalid model configuration for {model_str}")

    print(f"Initializing {model_class} with kwargs: {kwargs}")
    if model_str != "org_dust3r":
        return model_class(**kwargs)

    # NOTE(review): "org_dust3r" is not a key of MODEL_CONFIGS as shipped, so
    # this branch is only reachable if the registry is extended elsewhere.
    # It eval()s a caller-supplied string -- never feed it untrusted config.
    eval_str = kwargs.get("model_eval_str", None)
    if eval_str is None:
        raise ValueError(
            f"Model '{model_str}' requires a 'model_eval_str' keyword argument"
        )
    return eval(eval_str)
def get_available_models() -> list:
    """
    Return the names of every model registered in MODEL_CONFIGS.

    Returns:
        list: Registry keys in their definition order.
    """
    # Iterating a dict yields its keys, so list() on the dict is enough.
    return list(MODEL_CONFIGS)
# Public API of this module; init_model is the config-driven entry point.
__all__ = ["model_factory", "get_available_models", "init_model"]