Spaces:
Runtime error
Runtime error
| from timm.models import create_model | |
| from .swin_transformer import SwinTransformer | |
| from . import focalnet | |
def build_model(config, is_pretrained=False):
    """Build an image backbone from a model config.

    The backbone family is chosen by substring match on ``config.TYPE``,
    tested in this order: ``"swin"``, ``"focal"``, ``"vit"``, then a generic
    ``timm.create_model`` fallback. The first matching substring wins, so a
    name containing several of these keys takes the earliest branch.

    Args:
        config: Config object exposing ``TYPE`` plus whatever attributes the
            selected branch reads (e.g. ``IMG_SIZE``, ``DROP_RATE``,
            ``DROP_PATH_RATE``, ``SWIN.*``, ``FOCAL.*``, or ``DATA.IMG_SIZE``
            / ``MODEL.NUM_CLASSES`` for the timm branches).
        is_pretrained: Forwarded as ``pretrained=`` to ``create_model`` in the
            ``"vit"`` and fallback branches. Previously these branches
            referenced an undefined name ``is_pretrained`` and raised
            ``NameError``. The Swin and Focal branches ignore this flag: Swin
            is built directly and Focal always passes ``pretrained=False``.
            Defaults to False.

    Returns:
        The constructed model instance.
    """
    model_type = config.TYPE
    print(f"Creating model: {model_type}")

    if "swin" in model_type:
        # num_classes=0 presumably builds a headless feature extractor;
        # confirm against SwinTransformer's constructor.
        model = SwinTransformer(
            num_classes=0,
            img_size=config.IMG_SIZE,
            patch_size=config.SWIN.PATCH_SIZE,
            in_chans=config.SWIN.IN_CHANS,
            embed_dim=config.SWIN.EMBED_DIM,
            depths=config.SWIN.DEPTHS,
            num_heads=config.SWIN.NUM_HEADS,
            window_size=config.SWIN.WINDOW_SIZE,
            mlp_ratio=config.SWIN.MLP_RATIO,
            qkv_bias=config.SWIN.QKV_BIAS,
            qk_scale=config.SWIN.QK_SCALE,
            drop_rate=config.DROP_RATE,
            drop_path_rate=config.DROP_PATH_RATE,
            ape=config.SWIN.APE,
            patch_norm=config.SWIN.PATCH_NORM,
            use_checkpoint=False,
        )
    elif "focal" in model_type:
        # NOTE(review): resolving a focal model by name through timm
        # presumably relies on the module-level `from . import focalnet`
        # registering those architectures; verify.
        model = create_model(
            model_type,
            pretrained=False,
            img_size=config.IMG_SIZE,
            num_classes=0,
            drop_path_rate=config.DROP_PATH_RATE,
            use_conv_embed=config.FOCAL.USE_CONV_EMBED,
            use_layerscale=config.FOCAL.USE_LAYERSCALE,
            use_postln=config.FOCAL.USE_POSTLN,
        )
    elif "vit" in model_type:
        # NOTE(review): this branch reads config.DATA / config.MODEL while the
        # branches above read top-level config.IMG_SIZE etc.; confirm the
        # config passed here actually has DATA and MODEL sections.
        model = create_model(
            model_type,
            pretrained=is_pretrained,
            img_size=config.DATA.IMG_SIZE,
            num_classes=config.MODEL.NUM_CLASSES,
        )
    else:
        # Generic timm fallback. "resnet" models previously had their own
        # branch with exactly this call, so they are handled here too.
        model = create_model(
            model_type,
            pretrained=is_pretrained,
            num_classes=config.MODEL.NUM_CLASSES,
        )
    return model