| import timm | |
| import torch.nn as nn | |
| from pathlib import Path | |
| from .utils import activations, forward_default, get_activation | |
| from ..external.next_vit.classification.nextvit import * | |
def forward_next_vit(pretrained, x):
    """Feed ``x`` through the hooked NextViT wrapper ``pretrained``.

    This is a thin adapter: it defers entirely to ``forward_default``,
    telling it to invoke the model's ``"forward"`` method, and returns
    whatever ``forward_default`` returns.
    """
    entry_point = "forward"
    result = forward_default(pretrained, x, entry_point)
    return result
def _make_next_vit_backbone(
    model,
    hooks=None,
):
    """Wrap ``model`` in a container module and attach activation hooks.

    Forward hooks produced by ``get_activation("1")`` .. ``get_activation("4")``
    are registered on ``model.features[hooks[0]]`` .. ``model.features[hooks[3]]``
    respectively. The ``activations`` object imported from ``.utils`` is
    exposed on the returned module as ``.activations``.

    Args:
        model: Network exposing an indexable ``features`` attribute
            (assumed to be a timm NextViT instance — confirm at call sites).
        hooks: Indices into ``model.features`` whose outputs are tapped.
            Only the first four entries are used. ``None`` (the default)
            selects ``[2, 6, 36, 39]``.

    Returns:
        An ``nn.Module`` holding the network as ``.model`` and the
        activation store as ``.activations``.

    Raises:
        ValueError: If fewer than four hook indices are supplied. The check
            runs before any hook is registered, so the model is left untouched.
    """
    # A None sentinel instead of a list-literal default: a literal default is
    # one list object shared by every call and is easy to mutate by accident.
    if hooks is None:
        hooks = [2, 6, 36, 39]
    if len(hooks) < 4:
        raise ValueError(
            f"expected at least 4 hook indices, got {len(hooks)}: {hooks!r}"
        )

    pretrained = nn.Module()
    pretrained.model = model

    # Activation keys "1".."4" label the tapped feature blocks in hook order.
    for stage, index in enumerate(hooks[:4], start=1):
        pretrained.model.features[index].register_forward_hook(
            get_activation(str(stage))
        )

    pretrained.activations = activations
    return pretrained
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Build a NextViT-Large model and wrap it with activation hooks.

    Creates the network via ``timm.create_model("nextvit_large")`` and passes
    it to ``_make_next_vit_backbone`` together with the hook indices.

    Args:
        hooks: Indices into ``model.features`` to tap. ``None`` selects the
            default ``[2, 6, 36, 39]``.

    Returns:
        The module returned by ``_make_next_vit_backbone``.
    """
    # NOTE(review): "nextvit_large" is presumably registered with timm by the
    # wildcard import from ..external.next_vit at the top of this file — confirm.
    # NOTE(review): no ``pretrained=`` flag is passed to timm here, so any
    # "6m" weights are presumably loaded later by the caller — confirm.
    model = timm.create_model("nextvit_large")

    # Identity test, not ``== None``: an array-like ``hooks`` would compare
    # elementwise and raise on truth-value evaluation.
    if hooks is None:
        hooks = [2, 6, 36, 39]

    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )