File size: 2,788 Bytes
5facae9 542e6d3 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 9e99f59 9954323 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class Exp:
    """
    Configuration class for the page element model.

    This class contains all configuration parameters for the YOLOX-based
    page element detection model, including architecture settings, inference
    parameters, and class-specific thresholds.

    The network itself is not created in ``__init__``; it is built on the
    first call to ``get_model`` and cached on the instance as ``self.model``.
    """

    def __init__(self) -> None:
        """Initialize the configuration with default parameters."""
        self.name: str = "page-element-v3"
        self.ckpt: str = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters (forwarded to the backbone and head
        # constructors in ``get_model``).
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: List[str] = [
            "table",
            "chart",
            "title",
            "infographic",
            "text",
            "header_footer",
        ]
        # Derived from ``labels`` so the two cannot drift apart.
        self.num_classes: int = len(self.labels)

        # Inference parameters
        self.size: Tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # NMS & thresholding. These can be updated by the caller after
        # construction.
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.5
        self.class_agnostic: bool = True
        # Keyed by label name, so the ordering here need not match ``labels``.
        self.thresholds_per_class: Dict[str, float] = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "text": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model.

        On the first call (or whenever ``self.model`` is missing or ``None``)
        a YOLOX model is built from the configured ``depth``, ``width``,
        ``act`` and ``num_classes`` and cached as ``self.model``. Subsequent
        calls, or a model assigned to ``self.model`` beforehand, reuse the
        cached instance.

        In every case, each ``nn.BatchNorm2d`` layer of the model has its
        ``eps`` set to 1e-3 and its ``momentum`` set to 0.03 before the model
        is returned.

        Returns:
            nn.Module: The YOLOX model with configured parameters.
        """
        # Build model
        if getattr(self, "model", None) is None:
            # Imported only when a model actually has to be built, so that an
            # already-cached or pre-assigned model can be returned without
            # requiring the ``yolox`` module to be importable.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters. ``modules()`` already yields every
        # nested submodule (including the root), so a single pass is enough.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model
|