Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- encoding: utf-8 -*- | |
| # Copyright (c) Megvii Inc. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from .darknet import Darknet | |
| from .network_blocks import BaseConv | |
class YOLOFPN(nn.Module):
    """
    YOLOFPN module. Darknet 53 is the default backbone of this model.

    The neck selects three backbone feature maps (via ``in_features``, ordered
    shallow to deep) and fuses them top-down:

    1. The deepest map is reduced with a 1x1 conv and upsampled 2x.
    2. It is concatenated with the middle map and refined.
    3. That result is fused the same way with the shallowest map.
    """

    def __init__(
        self,
        depth=53,
        in_features=("dark3", "dark4", "dark5"),
    ):
        """
        Args:
            depth (int): depth passed to ``Darknet`` to build the backbone.
            in_features (Sequence[str]): keys of the backbone output mapping
                to use, ordered shallow to deep. Exactly three are expected,
                because ``forward`` unpacks three feature maps.
        """
        super().__init__()
        self.backbone = Darknet(depth)
        # Copy into a per-instance list. The old list default was shared by
        # every instance, so mutating one model's in_features leaked into
        # all later models.
        self.in_features = list(in_features)

        # The hard-coded channel counts below imply the deepest selected map
        # has 512 channels, the middle one 512 and the shallowest 256.
        # NOTE(review): confirm against Darknet's outputs for ``depth``.

        # out 1: 1x1-reduce the deepest map 512 -> 256 before upsampling,
        # then refine [upsampled (256) ++ middle map (512)].
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1 = self._make_embedding([256, 512], 512 + 256)

        # out 2: 1x1-reduce the branch-1 output 256 -> 128 before upsampling,
        # then refine [upsampled (128) ++ shallowest map (256)].
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2 = self._make_embedding([128, 256], 256 + 128)

        # upsample: doubles H and W so the reduced map can be concatenated
        # with the next-shallower map along the channel axis.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def _make_cbl(self, _in, _out, ks):
        """
        Build one ``BaseConv`` block with stride 1 and ``act="lrelu"``
        (presumably leaky ReLU; see ``BaseConv``).

        Args:
            _in (int): input channels.
            _out (int): output channels.
            ks (int): kernel size.
        """
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

    def _make_embedding(self, filters_list, in_filters):
        """
        Build the five-layer refinement stack applied after a concatenation.

        Layers alternate 1x1 and 3x3 ``BaseConv`` blocks:
        in_filters -> f0 (1x1) -> f1 (3x3) -> f0 (1x1) -> f1 (3x3) -> f0 (1x1),
        where ``f0 = filters_list[0]`` and ``f1 = filters_list[1]``.

        Args:
            filters_list (Sequence[int]): the two alternating channel widths.
            in_filters (int): channels of the incoming concatenated tensor.

        Returns:
            nn.Sequential: the stack; its output has ``filters_list[0]``
            channels.
        """
        f0, f1 = filters_list[0], filters_list[1]
        return nn.Sequential(
            self._make_cbl(in_filters, f0, 1),
            self._make_cbl(f0, f1, 3),
            self._make_cbl(f1, f0, 1),
            self._make_cbl(f0, f1, 3),
            self._make_cbl(f1, f0, 1),
        )

    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
        """
        Load backbone weights from a checkpoint file.

        Args:
            filename (str): path to a state dict for ``self.backbone``.
                Tensors are mapped to CPU while loading.
        """
        # Log before the potentially slow (or failing) load, not after it.
        print("loading pretrained weights...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        with open(filename, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        self.backbone.load_state_dict(state_dict)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input image.

        Returns:
            Tuple[Tensor]: FPN output features, ordered shallow to deep:
                (refined shallow map, refined middle map, deepest backbone
                map as returned by the backbone).
        """
        # backbone: pick the three requested maps, shallow -> deep.
        out_features = self.backbone(inputs)
        x2, x1, x0 = [out_features[f] for f in self.in_features]

        # yolo branch 1: reduce + upsample the deepest map, fuse with the
        # middle map. x1 must be exactly 2x x0 spatially for the concat.
        x1_in = self.upsample(self.out1_cbl(x0))
        out_dark4 = self.out1(torch.cat([x1_in, x1], 1))

        # yolo branch 2: reduce + upsample branch-1 output, fuse with the
        # shallowest map (again requires a 2x spatial ratio).
        x2_in = self.upsample(self.out2_cbl(out_dark4))
        out_dark3 = self.out2(torch.cat([x2_in, x2], 1))

        return (out_dark3, out_dark4, x0)