| import torch | |
| import torch.nn as nn | |
| from mono.utils.comm import get_func | |
class BaseDepthModel(nn.Module):
    """Wrap a depth-estimation pipeline that is selected by name from config.

    The pipeline is resolved at construction time: ``cfg.model.type`` is
    appended to ``'mono.model.model_pipelines.'``, that dotted path is looked
    up with ``get_func``, and the returned callable is invoked with ``cfg``.
    The resulting object is stored as ``self.depth_model``.
    """

    def __init__(self, cfg, **kwargs) -> None:
        """Resolve and build the configured depth pipeline.

        Args:
            cfg: Config object. Only ``cfg.model.type`` is read here; the whole
                ``cfg`` is forwarded to the pipeline builder.
            **kwargs: Accepted for call-site compatibility; not used.
        """
        super().__init__()
        pipeline_path = 'mono.model.model_pipelines.' + cfg.model.type
        build_pipeline = get_func(pipeline_path)
        self.depth_model = build_pipeline(cfg)

    def forward(self, data):
        """Run the wrapped pipeline with ``data`` unpacked as keyword arguments.

        Args:
            data: Mapping of keyword arguments passed straight through to
                ``self.depth_model``.

        Returns:
            A 3-tuple ``(prediction, confidence, result)``. ``result`` is the
            pipeline's raw return value. The first two entries are
            ``result['prediction']`` and ``result['confidence']``, so a
            missing key raises ``KeyError``.
        """
        result = self.depth_model(**data)
        prediction = result['prediction']
        confidence = result['confidence']
        return prediction, confidence, result

    def inference(self, data):
        """Call :meth:`forward` with autograd disabled and drop the raw output.

        This only disables gradient tracking. It does not switch the module to
        eval mode, so the caller is responsible for calling ``.eval()`` if
        that is needed.

        Returns:
            A 2-tuple ``(prediction, confidence)`` as produced by ``forward``.
        """
        with torch.no_grad():
            prediction, confidence, _raw = self.forward(data)
        return prediction, confidence