SerdarHelli committed on
Commit 62f828b · verified · 1 Parent(s): 0b46a85

Upload 18 files
app.py ADDED
@@ -0,0 +1,61 @@

import gradio as gr
from PIL import Image
from io import BytesIO
from src.pipeline import InferencePipeline
from src.app.config import load_config

# Load configuration and initialize the inference pipeline
config = load_config()
inference_pipeline = InferencePipeline(config)

def process_image_from_bytes(file, apply_clahe_postprocess, apply_pre_contrast_adjustment, return_original_size):
    """
    Process the uploaded image file using the inference pipeline.

    Args:
        file: The uploaded image file.
        apply_clahe_postprocess: Whether to apply CLAHE postprocessing.
        apply_pre_contrast_adjustment: Whether to adjust contrast before inference.
        return_original_size: Whether to resize the output back to the input size.

    Returns:
        The processed image, or an error message string.
    """
    try:
        # Perform super-resolution
        sr_image = inference_pipeline.run(
            file,
            apply_pre_contrast_adjustment=apply_pre_contrast_adjustment,
            apply_clahe_postprocess=apply_clahe_postprocess,
            return_original_size=return_original_size,
        )
        return sr_image
    except Exception as e:
        return f"An exception occurred: {str(e)}"

# Define the Gradio interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # X-Ray Image Super-Resolution-Denoiser Demo
        Upload an image (PNG, JPEG, or DICOM) and optionally apply CLAHE postprocessing.
        """)

        with gr.Row():
            file_input = gr.File(label="Upload Image (PNG, JPEG, or DICOM)")
            apply_clahe_checkbox = gr.Checkbox(label="Apply CLAHE Postprocessing", value=False)
            apply_pre_contrast_adjustment_checkbox = gr.Checkbox(label="Apply Pre-Contrast Adjustment", value=False)
            return_original_size_checkbox = gr.Checkbox(label="Return Original Size", value=True)

        process_button = gr.Button("Process Image")
        output_image = gr.Image(label="Processed Image")

        process_button.click(
            process_image_from_bytes,
            inputs=[file_input, apply_clahe_checkbox, apply_pre_contrast_adjustment_checkbox, return_original_size_checkbox],
            outputs=output_image,
        )

    return demo

# Launch the Gradio interface
demo = gradio_interface()

demo.launch(
    debug=True,
)
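
Note: demo.launch() runs at import time, so importing app (e.g. from a test) starts the server. A minimal sketch of the usual guard, as an optional variant rather than part of this commit:

# Optional variant: guard the launch so `import app` stays side-effect free.
if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(debug=True)
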
configs/config.yaml ADDED
@@ -0,0 +1,21 @@

model:
  source: "huggingface"                  # Options: "huggingface" or "local"
  repo_id: "SerdarHelli/super_res_xray"  # Required if source is "huggingface"
  filename: "net_g.pth"                  # Model weights filename in HF repo
  weights: "/path/to/weights.pth"        # Optional if using local weights
  scale: 4
  device: "cuda"                         # Options: "cuda", "cpu"

preprocessing:
  unsharping_mask:
    kernel_size: 7
    strength: 2
  brightness:
    factor: 1.2

postprocessing:
  clahe:
    clipLimit: 2
    tileGridSize:
      - 16
      - 16
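
For reference, a minimal sketch of how these keys come back from the loader (nested dicts, with tileGridSize as a list):

# Minimal sketch: how the YAML above maps to Python once parsed.
import yaml

with open("configs/config.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["model"]["scale"] == 4
assert cfg["preprocessing"]["unsharping_mask"]["kernel_size"] == 7
assert tuple(cfg["postprocessing"]["clahe"]["tileGridSize"]) == (16, 16)
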
requirements.txt ADDED
@@ -0,0 +1,12 @@

numpy==1.26.4
opencv-python==4.10.0.84
Pillow==11.0.0
torch==2.5.1
torchvision==0.20.1
tqdm==4.67.1
pydicom==3.0.1
fastapi==0.115.6
uvicorn==0.34.0
scikit-image==0.25.0
python-multipart==0.0.20
huggingface-hub==0.25.2
src/__init__.py ADDED
@@ -0,0 +1,3 @@

from .network.model import RealESRGAN
from .pipeline import InferencePipeline
from .preprocess import *
src/app/__init__.py ADDED
File without changes
src/app/config.py ADDED
@@ -0,0 +1,20 @@

import yaml

def load_config(config_path="configs/config.yaml"):
    """
    Load the configuration from a YAML file.

    Args:
        config_path: Path to the configuration file.

    Returns:
        dict: Configuration dictionary.
    """
    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        return config
    except FileNotFoundError:
        raise Exception(f"Configuration file '{config_path}' not found.")
    except yaml.YAMLError as e:
        raise Exception(f"Error parsing configuration file: {e}")
src/app/exceptions.py ADDED
@@ -0,0 +1,31 @@

class ModelLoadError(Exception):
    """Raised when the model fails to load."""
    def __init__(self, message="Failed to load the model."):
        self.message = message
        super().__init__(self.message)


class PreprocessingError(Exception):
    """Raised when an error occurs during preprocessing."""
    def __init__(self, message="Error during image preprocessing."):
        self.message = message
        super().__init__(self.message)


class PostprocessingError(Exception):
    """Raised when an error occurs during postprocessing."""
    def __init__(self, message="Error during image postprocessing."):
        self.message = message
        super().__init__(self.message)


class InferenceError(Exception):
    """Raised when an error occurs during inference."""
    def __init__(self, message="Error during inference."):
        self.message = message
        super().__init__(self.message)


class InputError(Exception):
    """Raised when an error occurs while loading input."""
    def __init__(self, message="Error loading input."):
        self.message = message
        super().__init__(self.message)
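
These exceptions all expose a .message attribute, which the handlers in src/app/main.py serialize into JSON; a minimal sketch of the pattern:

# Minimal sketch: each exception carries .message for the API handlers.
from src.app.exceptions import InputError

try:
    raise InputError("Unsupported file type '.gif'.")
except InputError as exc:
    # main.py's handler returns {"error": "InputError", "message": exc.message}
    # with HTTP 400 when this propagates out of a route.
    print(exc.message)
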
src/app/main.py ADDED
@@ -0,0 +1,83 @@

from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from src.app.routes import inference
from src.app.exceptions import ModelLoadError, PreprocessingError, InferenceError, InputError, PostprocessingError
import torch
import os
import sys

app = FastAPI(title="Super Resolution Dental X-ray API", version="1.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"])

# Include API routes
app.include_router(inference.router, prefix="/inference", tags=["Inference"])

@app.get("/")
def read_root():
    return {"message": "Welcome to the Super Resolution Dental X-ray API"}

@app.get("/health", tags=["Health"])
async def health_check():
    """
    Health check endpoint to ensure the API and CUDA are running.

    Returns:
        dict: Status message indicating the API and CUDA availability.
    """
    def bash(command):
        return os.popen(command).read()

    # Report CUDA status alongside the API status
    return {
        "status": "Healthy",
        "message": "API is running successfully.",
        "cuda": {
            "sys.version": sys.version,
            "torch.__version__": torch.__version__,
            "torch.cuda.is_available()": torch.cuda.is_available(),
            "torch.version.cuda": torch.version.cuda,
            "torch.backends.cudnn.version()": torch.backends.cudnn.version(),
            "torch.backends.cudnn.enabled": torch.backends.cudnn.enabled,
            "nvidia-smi": bash('nvidia-smi')
        }
    }

# Custom exception handlers
@app.exception_handler(ModelLoadError)
async def model_load_error_handler(request: Request, exc: ModelLoadError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "ModelLoadError", "message": exc.message},
    )

@app.exception_handler(PreprocessingError)
async def preprocessing_error_handler(request: Request, exc: PreprocessingError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PreprocessingError", "message": exc.message},
    )

@app.exception_handler(PostprocessingError)
async def postprocessing_error_handler(request: Request, exc: PostprocessingError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PostprocessingError", "message": exc.message},
    )

@app.exception_handler(InferenceError)
async def inference_error_handler(request: Request, exc: InferenceError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "InferenceError", "message": exc.message},
    )

@app.exception_handler(InputError)
async def input_load_error_handler(request: Request, exc: InputError):
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"error": "InputError", "message": exc.message},
    )
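
For local runs, the app is served with uvicorn (pinned in requirements.txt); a minimal sketch, assuming the repository root as the working directory:

# Minimal sketch: serve the API locally.
# Shell equivalent: uvicorn src.app.main:app --host 0.0.0.0 --port 8000
import uvicorn

if __name__ == "__main__":
    uvicorn.run("src.app.main:app", host="0.0.0.0", port=8000)
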
src/app/routes/__init__.py ADDED
File without changes
src/app/routes/inference.py ADDED
@@ -0,0 +1,87 @@

import os
from io import BytesIO

from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask

from src.app.config import load_config
from src.pipeline import InferencePipeline

# Define the router
router = APIRouter()

# Load configuration and initialize the inference pipeline
config = load_config()
inference_pipeline = InferencePipeline(config)

@router.post("/predict")
async def process_image(
    file: UploadFile = File(...),
    apply_clahe_postprocess: bool = False,
    apply_pre_contrast_adjustment: bool = True,
    return_original_size: bool = True
):
    """
    API endpoint to process and super-resolve an image.

    Args:
        file: Image file to process (PNG, JPEG, or DICOM).
        apply_clahe_postprocess: Whether to apply CLAHE after inference.
        apply_pre_contrast_adjustment: Whether to adjust contrast before inference.
        return_original_size: Whether to resize the output back to the input size.

    Returns:
        FileResponse: Processed image file or error message.
    """
    output_file_path = "output_highres.png"
    try:
        # Validate the boolean parameters
        if not isinstance(apply_clahe_postprocess, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'apply_clahe_postprocess' parameter must be a boolean."
            )
        if not isinstance(apply_pre_contrast_adjustment, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'apply_pre_contrast_adjustment' parameter must be a boolean."
            )
        if not isinstance(return_original_size, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'return_original_size' parameter must be a boolean."
            )

        # Read the uploaded file into memory
        file_bytes = await file.read()

        # Perform inference with the pipeline
        sr_image = inference_pipeline.run(
            BytesIO(file_bytes),
            apply_clahe_postprocess=apply_clahe_postprocess,
            apply_pre_contrast_adjustment=apply_pre_contrast_adjustment,
            return_original_size=return_original_size
        )

        # Save the processed image to a temporary file
        sr_image.save(output_file_path, format="PNG")

        # Return the file as a response. The temporary file is removed only
        # after the response has been sent; deleting it in a `finally` block
        # would remove it before FileResponse streams it.
        return FileResponse(
            path=output_file_path,
            media_type="image/png",
            filename="processed_image.png",
            background=BackgroundTask(os.remove, output_file_path)
        )

    except HTTPException as e:
        raise e

    except Exception as e:
        # Cleanup the temporary file if it exists
        if os.path.exists(output_file_path):
            os.remove(output_file_path)
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during processing: {str(e)}"
        )
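
Since the three flags are scalar defaults, FastAPI treats them as query parameters; a minimal client sketch, assuming the requests package, a server on localhost:8000, and a local example.png (hypothetical path):

# Minimal client sketch for POST /inference/predict.
import requests

with open("example.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/inference/predict",
        files={"file": ("example.png", f, "image/png")},
        params={"apply_clahe_postprocess": "false", "return_original_size": "true"},
    )

resp.raise_for_status()
with open("processed_image.png", "wb") as out:
    out.write(resp.content)  # PNG bytes from the FileResponse
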
src/network/__init__.py ADDED
File without changes
src/network/arch_utils.py ADDED
@@ -0,0 +1,197 @@

import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid (indexing='ij' matches the pre-1.10 meshgrid default)
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x), indexing='ij')
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
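
pixel_unshuffle is the exact inverse of nn.PixelShuffle, which is what lets RRDBNet reuse the same trunk for scale 1 and 2; a minimal check:

# Minimal check: nn.PixelShuffle undoes pixel_unshuffle.
import torch
from torch import nn
from src.network.arch_utils import pixel_unshuffle

x = torch.randn(1, 3, 8, 8)
down = pixel_unshuffle(x, scale=2)   # (1, 12, 4, 4): spatial 2x down, channels 4x up
assert down.shape == (1, 12, 4, 4)
assert torch.allclose(nn.PixelShuffle(2)(down), x)
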
src/network/model.py ADDED
@@ -0,0 +1,65 @@

import os
import torch
from PIL import Image
import numpy as np

from .rrdbnet_arch import RRDBNet
from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
    unpad_image


class RealESRGAN:
    def __init__(self, device, scale=4):
        self.device = device
        self.scale = scale
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=scale
        )

    def load_weights(self, model_path):
        # map_location keeps CPU-only machines from failing on GPU-saved checkpoints
        loadnet = torch.load(model_path, map_location=self.device)
        if 'params' in loadnet:
            self.model.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            self.model.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            self.model.load_state_dict(loadnet, strict=True)
        self.model.eval()
        self.model.to(self.device)

    @torch.cuda.amp.autocast()
    def predict(self, lr_image, batch_size=4, patches_size=192,
                padding=24, pad_size=15):
        scale = self.scale
        device = self.device
        lr_image = np.array(lr_image)
        lr_image = pad_reflect(lr_image, pad_size)

        # Tile the image into overlapping patches so large inputs fit in memory
        patches, p_shape = split_image_into_overlapping_patches(
            lr_image, patch_size=patches_size, padding_size=padding
        )
        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()

        with torch.no_grad():
            res = self.model(img[0:batch_size])
            for i in range(batch_size, img.shape[0], batch_size):
                res = torch.cat((res, self.model(img[i:i + batch_size])), 0)

        sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        # Stitch the super-resolved patches back together at the scaled shape
        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
        np_sr_image = stich_together(
            np_sr_image, padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape, padding_size=padding * scale
        )
        sr_img = (np_sr_image * 255).astype(np.uint8)
        sr_img = unpad_image(sr_img, pad_size * scale)
        sr_img = Image.fromarray(sr_img)

        return sr_img
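
A minimal usage sketch, assuming local weights at weights/net_g.pth and an input at example.png (both hypothetical paths):

# Minimal usage sketch for RealESRGAN.
import torch
from PIL import Image
from src.network.model import RealESRGAN

device = "cuda" if torch.cuda.is_available() else "cpu"
model = RealESRGAN(device, scale=4)
model.load_weights("weights/net_g.pth")   # hypothetical path

lr = Image.open("example.png").convert("RGB")
sr = model.predict(lr)                    # PIL.Image at 4x the input resolution
sr.save("example_sr.png")
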
src/network/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@

import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_utils import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """Network consisting of Residual in Residual Dense Blocks, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce
    the spatial size and enlarge the channel size before feeding inputs into the
    main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 8:
            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
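
A minimal shape check (num_feat and num_block reduced so it runs quickly); for scale=4 the input passes through two nearest-neighbor 2x upsamples:

# Minimal shape check for RRDBNet (reduced width/depth for speed).
import torch
from src.network.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=16, num_block=2)
with torch.no_grad():
    out = net(torch.randn(1, 3, 32, 32))
assert out.shape == (1, 3, 128, 128)   # two 2x upsamples => 4x overall
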
src/network/utils.py ADDED
@@ -0,0 +1,133 @@

import numpy as np
import torch
from PIL import Image
import os
import io

def pad_reflect(image, pad_size):
    imsize = image.shape
    height, width = imsize[:2]
    new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
    new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image

    new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)    # top
    new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)    # bottom
    new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)        # left
    new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)      # right

    return new_img

def unpad_image(image, pad_size):
    return image[pad_size:-pad_size, pad_size:-pad_size, :]


def process_array(image_array, expand=True):
    """ Process a 3-dimensional array into a scaled, 4-dimensional batch of size 1. """

    image_batch = image_array / 255.0
    if expand:
        image_batch = np.expand_dims(image_batch, axis=0)
    return image_batch


def process_output(output_tensor):
    """ Transforms the 4-dimensional output tensor into a suitable image format. """

    sr_img = output_tensor.clip(0, 1) * 255
    sr_img = np.uint8(sr_img)
    return sr_img


def pad_patch(image_patch, padding_size, channel_last=True):
    """ Pads image_patch with padding_size edge values. """

    if channel_last:
        return np.pad(
            image_patch,
            ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
            'edge',
        )
    else:
        return np.pad(
            image_patch,
            ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
            'edge',
        )


def unpad_patches(image_patches, padding_size):
    return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]


def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
    """ Splits the image into partially overlapping patches.

    The patches overlap by padding_size pixels.
    Pads the image twice:
    - first to have a size that is a multiple of the patch size,
    - then to have equal padding at the borders.

    Args:
        image_array: numpy array of the input image.
        patch_size: size of the patches from the original image (without padding).
        padding_size: size of the overlapping area.
    """

    xmax, ymax, _ = image_array.shape
    x_remainder = xmax % patch_size
    y_remainder = ymax % patch_size

    # modulo here is to avoid extending by patch_size instead of 0
    x_extend = (patch_size - x_remainder) % patch_size
    y_extend = (patch_size - y_remainder) % patch_size

    # make sure the image is divisible into regular patches
    extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')

    # add padding around the image to simplify computations
    padded_image = pad_patch(extended_image, padding_size, channel_last=True)

    xmax, ymax, _ = padded_image.shape
    patches = []

    x_lefts = range(padding_size, xmax - padding_size, patch_size)
    y_tops = range(padding_size, ymax - padding_size, patch_size)

    for x in x_lefts:
        for y in y_tops:
            x_left = x - padding_size
            y_top = y - padding_size
            x_right = x + patch_size + padding_size
            y_bottom = y + patch_size + padding_size
            patch = padded_image[x_left:x_right, y_top:y_bottom, :]
            patches.append(patch)

    return np.array(patches), padded_image.shape


def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
    """ Reconstruct the image from overlapping patches.

    After scaling, shapes and padding should be scaled too.

    Args:
        patches: patches obtained with split_image_into_overlapping_patches.
        padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches.
        target_shape: shape of the final image.
        padding_size: size of the overlapping area.
    """

    xmax, ymax, _ = padded_image_shape
    patches = unpad_patches(patches, padding_size)
    patch_size = patches.shape[1]
    n_patches_per_row = ymax // patch_size

    complete_image = np.zeros((xmax, ymax, 3))

    row = -1
    col = 0
    for i in range(len(patches)):
        if i % n_patches_per_row == 0:
            row += 1
            col = 0
        complete_image[
            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
        ] = patches[i]
        col += 1
    return complete_image[0: target_shape[0], 0: target_shape[1], :]
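
split_image_into_overlapping_patches and stich_together are inverses when no scaling happens in between; a minimal round-trip check:

# Minimal round-trip check: split then stitch reproduces the image.
import numpy as np
from src.network.utils import split_image_into_overlapping_patches, stich_together

img = np.random.randint(0, 256, (100, 120, 3), dtype=np.uint8)
patches, padded_shape = split_image_into_overlapping_patches(img, patch_size=64, padding_size=8)
rebuilt = stich_together(patches, padded_image_shape=padded_shape,
                         target_shape=img.shape, padding_size=8)
assert np.array_equal(rebuilt.astype(np.uint8), img)
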
src/pipeline.py ADDED
@@ -0,0 +1,236 @@

import torch
from PIL import Image
import numpy as np
from io import BytesIO
from huggingface_hub import hf_hub_download
from pathlib import Path

from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness
from src.network.model import RealESRGAN
from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError, PostprocessingError


class InferencePipeline:
    def __init__(self, config):
        """
        Initialize the inference pipeline using configuration.

        Args:
            config: Configuration dictionary.
        """
        self.config = config
        self.device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.scale = config["model"].get("scale", 4)

        model_source = config["model"].get("source", "local")
        self.model = RealESRGAN(self.device, scale=self.scale)

        print(f"Using device: {self.device}")

        try:
            if model_source == "huggingface":
                repo_id = config["model"]["repo_id"]
                filename = config["model"]["filename"]
                local_path = hf_hub_download(repo_id=repo_id, filename=filename)
                self.load_weights(local_path)
            else:
                local_path = config["model"]["weights"]
                self.load_weights(local_path)
        except Exception as e:
            raise ModelLoadError(f"Failed to load the model: {str(e)}")

    def load_weights(self, model_weights):
        """
        Load the model weights.

        Args:
            model_weights: Path to the model weights file.
        """
        try:
            self.model.load_weights(model_weights)
        except FileNotFoundError:
            raise ModelLoadError(f"Model weights not found at '{model_weights}'.")
        except Exception as e:
            raise ModelLoadError(f"Error loading weights: {str(e)}")

    def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False):
        """
        Preprocess the input image.

        Args:
            image_path_or_bytes: Path to the input image file, or a BytesIO object.
            apply_pre_contrast_adjustment: Whether to equalize exposure first.
            is_dicom: Boolean indicating if the input is a DICOM file.

        Returns:
            tuple: (preprocessed PIL Image, original (width, height) size).
        """
        try:
            if is_dicom:
                img = read_xray(image_path_or_bytes)
            else:
                img = Image.open(image_path_or_bytes)

            if apply_pre_contrast_adjustment:
                img = enhance_exposure(np.array(img))

            if isinstance(img, np.ndarray):
                img = Image.fromarray(((img / np.max(img)) * 255).astype(np.uint8))

            if img.mode not in ['RGB']:
                img = img.convert('RGB')

            img = unsharp_masking(
                img,
                self.config["preprocessing"]["unsharping_mask"].get("kernel_size", 7),
                self.config["preprocessing"]["unsharping_mask"].get("strength", 2)
            )
            img = increase_brightness(
                img,
                self.config["preprocessing"]["brightness"].get("factor", 1.2),
            )

            if img.mode not in ['RGB']:
                img = img.convert('RGB')

            return img, img.size
        except Exception as e:
            raise PreprocessingError(f"Error during preprocessing: {str(e)}")

    def postprocess(self, image_array):
        """
        Postprocess the output from the model.

        Args:
            image_array: PIL.Image output from the model.

        Returns:
            PIL Image: Postprocessed image.
        """
        try:
            return apply_clahe(
                image_array,
                self.config["postprocessing"]["clahe"].get("clipLimit", 2.0),
                tuple(self.config["postprocessing"]["clahe"].get("tileGridSize", [16, 16]))
            )
        except Exception as e:
            raise PostprocessingError(f"Error during postprocessing: {str(e)}")

    def is_dicom(self, file_path_or_bytes):
        """
        Check if the input file is a DICOM file.

        Args:
            file_path_or_bytes (str or bytes or BytesIO): Path to the file, byte content, or BytesIO object.

        Returns:
            bool: True if the file is a DICOM file, False otherwise.
        """
        try:
            if isinstance(file_path_or_bytes, str):
                # Check the file extension
                file_extension = Path(file_path_or_bytes).suffix.lower()
                if file_extension in ['.dcm', '.dicom']:
                    return True

                # Open the file and check the header ("DICM" magic at byte 128)
                with open(file_path_or_bytes, 'rb') as file:
                    header = file.read(132)
                return header[-4:] == b'DICM'

            elif isinstance(file_path_or_bytes, BytesIO):
                file_path_or_bytes.seek(0)
                header = file_path_or_bytes.read(132)
                file_path_or_bytes.seek(0)  # Reset the stream position
                return header[-4:] == b'DICM'

            elif isinstance(file_path_or_bytes, bytes):
                header = file_path_or_bytes[:132]
                return header[-4:] == b'DICM'

        except Exception as e:
            print(f"Error during DICOM validation: {e}")
            return False

        return False

    def validate_input(self, input_data):
        """
        Validate the input data to ensure it is suitable for processing.

        Args:
            input_data: Path to the input file, bytes content, or BytesIO object.

        Returns:
            bool: True if the input is valid; raises InputError otherwise.
        """
        if isinstance(input_data, str):
            # Check if the file exists
            if not Path(input_data).exists():
                raise InputError(f"Input file '{input_data}' does not exist.")

            # Check if the file type is supported
            file_extension = Path(input_data).suffix.lower()
            if file_extension not in ['.png', '.jpeg', '.jpg', '.dcm', '.dicom']:
                raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.")

        elif isinstance(input_data, BytesIO):
            # Check if BytesIO data is not empty
            if input_data.getbuffer().nbytes == 0:
                raise InputError("Input BytesIO data is empty.")

        else:
            raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.")

        return True

    def infer(self, input_image):
        """
        Perform inference on a single image.

        Args:
            input_image: PIL Image to be processed.

        Returns:
            PIL Image: Super-resolved image.
        """
        try:
            # Perform inference
            input_array = np.array(input_image)
            sr_image = self.model.predict(input_array)
            return sr_image
        except Exception as e:
            raise InferenceError(f"Error during inference: {str(e)}")

    def run(self, input_path, apply_pre_contrast_adjustment=True, apply_clahe_postprocess=False, return_original_size=True):
        """
        Process a single image end to end.

        Args:
            input_path: Path to the input image file, or a BytesIO object.
            apply_pre_contrast_adjustment: Whether to equalize exposure before inference.
            apply_clahe_postprocess: Whether to apply CLAHE after inference.
            return_original_size: Whether to resize the output back to the input size.

        Returns:
            PIL Image: Super-resolved image.
        """
        # Validate the input
        self.validate_input(input_path)

        is_dicom = self.is_dicom(input_path)

        img, original_size = self.preprocess(input_path, is_dicom=is_dicom, apply_pre_contrast_adjustment=apply_pre_contrast_adjustment)

        if img is None:
            raise InputError("Invalid input.")

        sr_image = self.infer(img)

        if apply_clahe_postprocess:
            sr_image = self.postprocess(sr_image)

        if return_original_size:
            sr_image = resize_pil_image(sr_image, target_shape=original_size)

        return sr_image
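
A minimal end-to-end sketch, assuming configs/config.yaml resolves the weights (the huggingface source downloads them on first run) and a local example.png (hypothetical path):

# Minimal end-to-end sketch: config -> pipeline -> super-resolved PIL image.
from src.app.config import load_config
from src.pipeline import InferencePipeline

pipeline = InferencePipeline(load_config())
sr = pipeline.run("example.png",
                  apply_pre_contrast_adjustment=True,
                  apply_clahe_postprocess=False,
                  return_original_size=True)
sr.save("example_sr.png")
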
src/preprocess.py ADDED
@@ -0,0 +1,187 @@

import pydicom
import numpy as np
from pydicom.pixels import apply_voi_lut
from skimage import exposure
from PIL import Image, ImageEnhance
import cv2

def read_xray(path, voi_lut=True, fix_monochrome=True):
    """
    Read and preprocess a DICOM X-ray image.

    Parameters:
    - path: Path to the DICOM file.
    - voi_lut: Apply VOI LUT if available.
    - fix_monochrome: Fix inverted monochrome images.

    Returns:
    - NumPy array: Preprocessed X-ray image.
    """
    dicom = pydicom.dcmread(path)

    # Apply VOI LUT if available
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # Fix inverted monochrome images
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    # Normalize data to start from 0
    data = data - np.min(data)

    return data

def resize_pil_image(image: Image.Image, target_shape: tuple) -> Image.Image:
    """
    Resizes a PIL image to a target size.

    Args:
        image: Input PIL image.
        target_shape: Desired size as a 2D tuple (width, height), matching
            PIL's Image.size convention, or a 3D tuple whose third element
            (channels) is ignored.

    Returns:
        Resized PIL image.
    """
    if len(target_shape) == 2:
        new_width, new_height = target_shape
    elif len(target_shape) == 3:
        # Only the first two dimensions are used; channels are unaffected
        new_width, new_height = target_shape[:2]
    else:
        raise ValueError("Target shape must be either 2D or 3D.")

    # Resize with high-quality Lanczos resampling
    return image.resize((new_width, new_height), Image.LANCZOS)

def enhance_exposure(img):
    """
    Enhance image exposure using histogram equalization.

    Parameters:
    - img: Input image as a NumPy array.

    Returns:
    - PIL.Image: Exposure-enhanced image.
    """
    img = exposure.equalize_hist(img)
    img = exposure.equalize_adapthist(img / np.max(img))
    img = (img * 255).astype(np.uint8)
    return Image.fromarray(img)

def unsharp_masking(image, kernel_size=5, strength=0.25):
    """
    Apply unsharp masking to enhance image sharpness.

    Parameters:
    - image: Input image as a NumPy array or PIL.Image.
    - kernel_size: Size of the Gaussian blur kernel.
    - strength: Strength of the high-pass filter.

    Returns:
    - PIL.Image: Sharpened (grayscale) image.
    """
    image = np.array(image)

    # Convert to grayscale if needed
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    # Apply Gaussian blur and calculate the high-pass component
    blurred = cv2.GaussianBlur(gray, (kernel_size, kernel_size), 0)
    high_pass = cv2.subtract(gray, blurred)

    # Combine the high-pass component with the original image
    sharpened = cv2.addWeighted(gray, 1, high_pass, strength, 0)

    return Image.fromarray(sharpened)

def increase_contrast(image: Image.Image, factor: float) -> Image.Image:
    """
    Increases the contrast of the input PIL image by a given factor.

    Args:
        image: Input PIL image.
        factor: Factor by which to increase the contrast.
            A factor of 1.0 means no change; values greater than 1.0 increase contrast,
            values between 0.0 and 1.0 decrease contrast.

    Returns:
        Image with increased contrast.
    """
    if image.mode not in ['RGB', 'L']:
        image = image.convert('RGB')

    enhancer = ImageEnhance.Contrast(image)
    image_enhanced = enhancer.enhance(factor)
    return image_enhanced

def increase_brightness(image: Image.Image, factor: float) -> Image.Image:
    """
    Increases the brightness of the input PIL image by a given factor.

    Args:
        image: Input PIL image.
        factor: Factor by which to increase the brightness.
            A factor of 1.0 means no change; values greater than 1.0 increase brightness,
            values between 0.0 and 1.0 decrease brightness.

    Returns:
        Image with increased brightness.
    """
    if image.mode not in ['RGB', 'L']:
        image = image.convert('RGB')

    enhancer = ImageEnhance.Brightness(image)
    image_enhanced = enhancer.enhance(factor)
    return image_enhanced

def apply_clahe(image, clipLimit=2.0, tileGridSize=(8, 8)):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    Parameters:
    - image: Input image as a PIL.Image.
    - clipLimit: Threshold for contrast limiting.
    - tileGridSize: Size of the grid for histogram equalization.

    Returns:
    - Processed image in the same format as the input (PIL.Image).
    """
    image_np = np.array(image)

    # Grayscale image: apply CLAHE directly (cvtColor would fail on a 2D array)
    if image_np.ndim == 2:
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        return Image.fromarray(clahe.apply(image_np))

    # Color image: apply CLAHE on the L channel in LAB space
    image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    lab = cv2.cvtColor(image_np, cv2.COLOR_BGR2LAB)
    L, A, B = cv2.split(lab)

    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    L_clahe = clahe.apply(L)
    lab_clahe = cv2.merge((L_clahe, A, B))
    processed = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)

    processed_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
    return Image.fromarray(processed_rgb)
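
InferencePipeline.preprocess composes these helpers in order (exposure, unsharp mask, brightness); a minimal standalone sketch on a synthetic image:

# Minimal sketch: the preprocessing helpers on a synthetic gradient.
import numpy as np
from PIL import Image
from src.preprocess import unsharp_masking, increase_brightness, apply_clahe

gradient = np.tile(np.linspace(0, 255, 128, dtype=np.uint8), (128, 1))
img = Image.fromarray(gradient).convert("RGB")

img = unsharp_masking(img, kernel_size=7, strength=2)      # returns a grayscale PIL image
img = increase_brightness(img.convert("RGB"), factor=1.2)
img = apply_clahe(img, clipLimit=2.0, tileGridSize=(16, 16))
print(img.size, img.mode)                                  # (128, 128) RGB
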
tests/test_inference_pipeline.py ADDED
@@ -0,0 +1,204 @@

import pytest
from fastapi.testclient import TestClient
import numpy as np
from PIL import Image
from io import BytesIO
import pydicom
from pydicom.dataset import Dataset, FileDataset
import tempfile
import os

import sys
from pathlib import Path

# Add the project root to the Python path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.app.main import app
from src.pipeline import InferencePipeline

# Initialize test client
client = TestClient(app)

@pytest.fixture
def pipeline_config():
    return {
        "model": {
            "weights": "weights/model.pth",
            "scale": 4,
            "device": "cpu"
        },
        "preprocessing": {
            "unsharping_mask": {
                "kernel_size": 7,
                "strength": 0.5
            },
            "brightness": {       # required by InferencePipeline.preprocess
                "factor": 1.2
            }
        },
        "postprocessing": {
            "clahe": {
                "clipLimit": 2,
                "tileGridSize": [16, 16]
            }
        }
    }

@pytest.fixture
def pipeline(pipeline_config):
    return InferencePipeline(pipeline_config)

def create_dummy_dicom():
    """Create a dummy DICOM file for testing."""
    meta = Dataset()
    meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.2"
    meta.MediaStorageSOPInstanceUID = "1.2.3"
    meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

    ds = FileDataset("", {}, file_meta=meta, preamble=b"\x00" * 128)

    # Required Patient and Image Information
    ds.PatientName = "Test"
    ds.PatientID = "12345"
    ds.Modality = "CT"
    ds.StudyInstanceUID = "1.2.3.4.5.6.7.8.9.10"
    ds.SeriesInstanceUID = "1.2.3.4.5.6.7.8.9.11"
    ds.SOPInstanceUID = "1.2.3.4.5.6.7.8.9.12"
    ds.StudyDate = "20240101"
    ds.StudyTime = "120000"
    ds.Manufacturer = "TestManufacturer"

    # Required Image Data Information
    ds.PhotometricInterpretation = "MONOCHROME2"
    ds.Rows = 128
    ds.Columns = 128
    ds.BitsAllocated = 16
    ds.BitsStored = 16           # Bits Stored
    ds.HighBit = 15              # Highest bit set
    ds.PixelRepresentation = 0   # Unsigned integer
    ds.SamplesPerPixel = 1       # Single-channel (grayscale)
    ds.PixelData = (np.random.rand(128, 128) * 65535).astype(np.uint16).tobytes()

    # Save to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".dcm")
    ds.save_as(temp_file.name)
    return temp_file.name


def test_is_dicom(pipeline):
    dicom_path = create_dummy_dicom()

    # Test with file path
    assert pipeline.is_dicom(dicom_path) is True, "DICOM file path should be recognized as DICOM"

    # Test with BytesIO
    with open(dicom_path, "rb") as f:
        dicom_bytes = BytesIO(f.read())
    assert pipeline.is_dicom(dicom_bytes) is True, "BytesIO DICOM content should be recognized as DICOM"

    # Test with invalid BytesIO (non-DICOM content)
    non_dicom_bytes = BytesIO()
    non_dicom_bytes.write(b"\x89PNG\r\n\x1a\n" + b"\x00" * 128)  # Write invalid header
    non_dicom_bytes.seek(0)
    assert pipeline.is_dicom(non_dicom_bytes) is False, "Non-DICOM BytesIO should not be recognized as DICOM"

    # Test with invalid raw bytes
    invalid_raw_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 128
    assert pipeline.is_dicom(invalid_raw_bytes) is False, "Invalid raw bytes should not be recognized as DICOM"

    os.remove(dicom_path)


def test_preprocess_normal_image(pipeline):
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")

    # Test with BytesIO
    image_bytes = BytesIO()
    image.save(image_bytes, format="JPEG")
    image_bytes.seek(0)

    # preprocess returns (image, original_size)
    processed_image, _ = pipeline.preprocess(image_bytes, is_dicom=False)
    assert isinstance(processed_image, Image.Image)

    # Test with file path
    temp_image_path = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg").name
    image.save(temp_image_path)

    processed_image, _ = pipeline.preprocess(temp_image_path, is_dicom=False)
    assert isinstance(processed_image, Image.Image)

    os.remove(temp_image_path)


def test_infer(pipeline):
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")

    # Perform inference
    result = pipeline.infer(image)
    assert isinstance(result, Image.Image)


def test_postprocess(pipeline):
    image = Image.new("RGB", (128, 128), color="red")
    result = pipeline.postprocess(image)

    assert isinstance(result, Image.Image)


def test_api_predict_normal_image():
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")
    image_bytes = BytesIO()
    image.save(image_bytes, format="JPEG")
    image_bytes.seek(0)

    response = client.post(
        "/inference/predict",  # Adjusted to include the prefix
        files={"file": ("test.jpg", image_bytes, "image/jpeg")},
        data={"apply_clahe_postprocess": "false"}  # Ensure proper boolean conversion
    )
    assert response.status_code == 200, response.text
    assert response.headers["content-type"] == "image/png"


def test_api_predict_dicom():
    dicom_path = create_dummy_dicom()

    # Use BytesIO for testing
    with open(dicom_path, "rb") as f:
        dicom_bytes = BytesIO(f.read())

    response = client.post(
        "/inference/predict",  # Adjusted to include the prefix
        files={"file": ("test.dcm", dicom_bytes, "application/dicom")},
        data={"apply_clahe_postprocess": "false"}  # Ensure proper boolean conversion
    )
    assert response.status_code == 200, response.text
    assert response.headers["content-type"] == "image/png"

    os.remove(dicom_path)