File size: 5,318 Bytes
9507532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Model Factory for MapAnything
"""

import importlib.util
import logging
import warnings

import numpy as np
from omegaconf import DictConfig, OmegaConf

# Core models that are always available
from mapanything.models.mapanything import (
    MapAnything,
    MapAnythingAblations,
    ModularDUSt3R,
)

# Quiet the "dinov2" logger below WARNING and ignore the xFormers availability
# UserWarnings. The filters are installed in the same order as listed here.
logging.getLogger("dinov2").setLevel(logging.WARNING)
for _xformers_msg in ("xFormers is available", "xFormers is not available"):
    warnings.filterwarnings("ignore", message=_xformers_msg, category=UserWarning)
# Drop the loop variable so it does not linger as a module-level name.
del _xformers_msg


def resolve_special_float(value):
    """
    Translate a special-float token into the matching numpy infinity.

    Used as the body of the ``special_float`` OmegaConf resolver.

    Args:
        value: Token to translate; ``"inf"`` and ``"-inf"`` are recognised.

    Returns:
        float: ``np.inf`` for ``"inf"``, ``-np.inf`` for ``"-inf"``.

    Raises:
        ValueError: If ``value`` is not one of the recognised tokens.
    """
    # Compare with == (not a dict lookup) so any value type is accepted for
    # comparison, exactly as a plain if/elif chain would.
    for token, infinity in (("inf", np.inf), ("-inf", -np.inf)):
        if value == token:
            return infinity
    raise ValueError(f"Unknown special float value: {value}")


def init_model(
    model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
):
    """
    Build a model from an OmegaConf configuration.

    The configuration is resolved into plain Python containers. Its top-level
    entries are forwarded as keyword arguments to ``model_factory``, together
    with ``torch_hub_force_reload``. The ``special_float`` resolver is
    registered on first use, so configs can use interpolations such as
    ``${special_float:inf}``.

    Args:
        model_str (str): Name of the model class to create.
        model_config (DictConfig): OmegaConf model configuration.
        torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub.

    Returns:
        The model instance produced by ``model_factory``.
    """
    # Only register when absent, so repeated calls do not re-register the resolver.
    if not OmegaConf.has_resolver("special_float"):
        OmegaConf.register_new_resolver("special_float", resolve_special_float)

    # resolve=True evaluates interpolations (including special_float) up front.
    # NOTE(review): a config key named "torch_hub_force_reload" or "model_str"
    # would collide with the explicit arguments below and raise TypeError.
    constructor_kwargs = OmegaConf.to_container(model_config, resolve=True)
    return model_factory(
        model_str,
        torch_hub_force_reload=torch_hub_force_reload,
        **constructor_kwargs,
    )


# Registry of constructible models, keyed by the name passed to model_factory.
# Core entries map "class" to a class object imported at the top of this module.
# External entries name a "module" (dotted path) and a "class_name" that
# model_factory imports on demand. Key order is preserved for listings.
MODEL_CONFIGS = {
    # Core models
    "mapanything": {"class": MapAnything},
    "mapanything_ablations": {"class": MapAnythingAblations},
    "modular_dust3r": {"class": ModularDUSt3R},
    # External models: (registry key, module path, wrapper class name)
    **{
        key: {"module": module_path, "class_name": class_name}
        for key, module_path, class_name in (
            ("anycalib", "mapanything.models.external.anycalib", "AnyCalibWrapper"),
            ("dust3r", "mapanything.models.external.dust3r", "DUSt3RBAWrapper"),
            ("mast3r", "mapanything.models.external.mast3r", "MASt3RSGAWrapper"),
            ("moge", "mapanything.models.external.moge", "MoGeWrapper"),
            ("must3r", "mapanything.models.external.must3r", "MUSt3RWrapper"),
            ("pi3", "mapanything.models.external.pi3", "Pi3Wrapper"),
            ("pow3r", "mapanything.models.external.pow3r", "Pow3RWrapper"),
            ("pow3r_ba", "mapanything.models.external.pow3r", "Pow3RBAWrapper"),
            ("vggt", "mapanything.models.external.vggt", "VGGTWrapper"),
        )
    },
    # Add other model classes here
}


def check_module_exists(module_path):
    """
    Check whether a module can be located, without executing the module itself.

    For dotted paths, ``importlib.util.find_spec`` imports the parent
    package(s) first, so those packages' ``__init__`` code does run. Only the
    target module is left unimported.

    Args:
        module_path (str): Dotted path of the module to check.

    Returns:
        bool: True if an import spec for the module is found, False otherwise
        (including when a parent package is missing or fails to import).
    """
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ImportError, ValueError):
        # find_spec raises ModuleNotFoundError (an ImportError) when a parent
        # package of a dotted path is missing. It raises ValueError when a module
        # already in sys.modules has __spec__ set to None. Both mean "cannot be
        # located", so honour the documented bool contract instead of raising.
        return False


def model_factory(model_str: str, **kwargs):
    """
    Model factory for MapAnything.

    Looks up ``model_str`` in ``MODEL_CONFIGS``. Core entries carry the class
    directly under ``"class"``. External entries give a ``"module"`` and a
    ``"class_name"``, which are imported only when that model is requested.

    Args:
        model_str (str): Name of the model to create (a key of ``MODEL_CONFIGS``).
        **kwargs: Additional keyword arguments to pass to the model constructor.

    Returns:
       nn.Module: An instance of the specified model.

    Raises:
        ValueError: If ``model_str`` is not registered, or its registry entry
            has neither a ``"class"`` nor a ``"module"`` key.
        ImportError: If an external model's module cannot be found, or the
            class cannot be imported from it.
    """
    if model_str not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
        )

    model_config = MODEL_CONFIGS[model_str]

    # Handle core models directly
    if "class" in model_config:
        model_class = model_config["class"]
    # Handle external models with dynamic imports
    elif "module" in model_config:
        module_path = model_config["module"]
        class_name = model_config["class_name"]

        # Fail early with an actionable message when the module is absent.
        if not check_module_exists(module_path):
            raise ImportError(
                f"Model '{model_str}' requires module '{module_path}' which is not installed. "
                f"Please install the corresponding submodule or package."
            )

        # Dynamically import the module and get the class
        try:
            module = importlib.import_module(module_path)
            model_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # Chain the original error explicitly so its cause and traceback
            # are reported as the direct cause of this ImportError.
            raise ImportError(
                f"Failed to import {class_name} from {module_path}: {e}"
            ) from e
    else:
        raise ValueError(f"Invalid model configuration for {model_str}")

    print(f"Initializing {model_class} with kwargs: {kwargs}")
    if model_str != "org_dust3r":
        return model_class(**kwargs)

    # NOTE(review): "org_dust3r" is not a key of MODEL_CONFIGS, so this branch is
    # unreachable unless the registry is extended at runtime. It passes a string
    # taken from kwargs to eval(), which executes arbitrary code. Never route
    # configs from untrusted sources here; consider removing this branch.
    eval_str = kwargs.get("model_eval_str", None)
    return eval(eval_str)


def get_available_models() -> list:
    """
    Return the names of all models registered with the MapAnything factory.

    Returns:
        list: The keys of ``MODEL_CONFIGS`` in registration order, as a new
        list that callers may modify without affecting the registry.
    """
    # Unpacking the dict iterates its keys, giving a fresh list copy.
    return [*MODEL_CONFIGS]


# Public API of this module. init_model is the config-driven entry point that
# resolves an OmegaConf config and delegates to model_factory.
__all__ = ["model_factory", "get_available_models", "init_model"]