# Model config: a backbone with three parallel ViT branches plus shared
# interaction settings.
#
# Each branch is initialised from a DeiT checkpoint (see `pretrained`):
#   branch1 -> DeiT-Base  (embed_dim=768, 12 heads), img_size=128
#   branch2 -> DeiT-Small (embed_dim=384,  6 heads), img_size=192
#   branch3 -> DeiT-Tiny  (embed_dim=192,  3 heads), img_size=368
# i.e. the widest branch runs on the smallest input and the narrowest branch
# on the largest input.
model = dict(
    backbone=dict(
        # Settings declared once at backbone level rather than per branch.
        # NOTE(review): from the names these presumably configure the
        # deformable-attention interaction modules between branches
        # (`interact_attn_type='deform'`) -- confirm against the backbone
        # implementation.
        n_points=4,
        # 768, 384 and 192 are all divisible by 16, so every branch width
        # splits evenly across these heads.
        deform_num_heads=16,
        cffn_ratio=0.25,
        deform_ratio=0.5,
        with_cffn=True,
        interact_attn_type='deform',
        interaction_drop_path_rate=0.4,
        separate_head=True,
        branch1=dict(
            img_size=128,  # 128 / 16 = 8 patches per side
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=768,
            num_heads=12,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            # Larger stochastic-depth rate than branches 2/3 (0.05) for the
            # largest (Base-sized) branch.
            drop_path_rate=0.2,
            # One [i, i] pair for each of the 12 blocks (depth=12).
            # NOTE(review): presumably marks the block range after which an
            # interaction step runs -- confirm in the backbone code.
            interaction_indexes=[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5],
                                 [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11]],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_base_patch16_224-b5f2ef4d.pth",
        ),
        branch2=dict(
            img_size=192,  # 192 / 16 = 12 patches per side
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=384,
            num_heads=6,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            interaction_indexes=[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5],
                                 [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11]],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_small_patch16_224-cd65a155.pth",
        ),
        branch3=dict(
            img_size=368,  # 368 / 16 = 23 patches per side
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=192,
            num_heads=3,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            interaction_indexes=[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5],
                                 [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11]],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_tiny_patch16_224-a1311bcf.pth",
        ),
    ),
)