Upload 18 files
- app.py +61 -0
- configs/config.yaml +21 -0
- requirements.txt +12 -0
- src/__init__.py +3 -0
- src/app/__init__.py +0 -0
- src/app/config.py +20 -0
- src/app/exceptions.py +31 -0
- src/app/main.py +83 -0
- src/app/routes/__init__.py +0 -0
- src/app/routes/inference.py +87 -0
- src/network/__init__.py +0 -0
- src/network/arch_utils.py +197 -0
- src/network/model.py +65 -0
- src/network/rrdbnet_arch.py +121 -0
- src/network/utils.py +133 -0
- src/pipeline.py +236 -0
- src/preprocess.py +187 -0
- tests/test_inference_pipeline.py +204 -0
app.py
ADDED
@@ -0,0 +1,61 @@
import gradio as gr
from PIL import Image
from io import BytesIO
from src.pipeline import InferencePipeline
from src.app.config import load_config

# Load configuration and initialize the inference pipeline
config = load_config()
inference_pipeline = InferencePipeline(config)

def process_image_from_bytes(file, apply_clahe_postprocess, apply_pre_contrast_adjustment, return_original_size):
    """
    Process the uploaded file using the inference pipeline.

    Args:
        file: The uploaded image file (provided by gr.File).
        apply_clahe_postprocess: Boolean indicating if CLAHE postprocessing should be applied.
        apply_pre_contrast_adjustment: Boolean indicating if contrast adjustment should be applied before inference.
        return_original_size: Boolean indicating if the output should be resized to the input size.

    Returns:
        The processed image.
    """
    try:
        # Perform super-resolution
        sr_image = inference_pipeline.run(
            file,
            apply_pre_contrast_adjustment=apply_pre_contrast_adjustment,
            apply_clahe_postprocess=apply_clahe_postprocess,
            return_original_size=return_original_size
        )
        return sr_image
    except Exception as e:
        # Surface the failure in the UI instead of returning a string to an Image output
        raise gr.Error(f"An exception occurred: {str(e)}")

# Define the Gradio interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # X-Ray Image Super-Resolution-Denoiser Demo
        Upload an image (PNG, JPEG, or DICOM) and optionally apply CLAHE postprocessing.
        """)

        with gr.Row():
            file_input = gr.File(label="Upload Image (PNG, JPEG, or DICOM)")
            apply_clahe_checkbox = gr.Checkbox(label="Apply CLAHE Postprocessing", value=False)
            apply_pre_contrast_adjustment_checkbox = gr.Checkbox(label="Apply PreContrast Adjustment", value=False)
            return_original_size_checkbox = gr.Checkbox(label="Return Original Size", value=True)

        process_button = gr.Button("Process Image")
        output_image = gr.Image(label="Processed Image")

        process_button.click(
            process_image_from_bytes,
            inputs=[file_input, apply_clahe_checkbox, apply_pre_contrast_adjustment_checkbox, return_original_size_checkbox],
            outputs=output_image
        )

    return demo

# Launch the Gradio interface
demo = gradio_interface()
demo.launch(
    debug=True,
)
configs/config.yaml
ADDED
@@ -0,0 +1,21 @@
model:
  source: "huggingface"                    # Options: "huggingface" or "local"
  repo_id: "SerdarHelli/super_res_xray"    # Required if source is "huggingface"
  filename: "net_g.pth"                    # Model weights filename in HF repo
  weights: "/path/to/weights.pth"          # Optional if using local weights
  scale: 4
  device: "cuda"                           # Options: "cuda", "cpu"

preprocessing:
  unsharping_mask:
    kernel_size: 7
    strength: 2
  brightness:
    factor: 1.2

postprocessing:
  clahe:
    clipLimit: 2
    tileGridSize:
      - 16
      - 16
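For reference, a minimal sketch of how this config is consumed (not part of the commit; the input file name and the CPU override are assumptions, and the weights in repo_id must be downloadable):

# Illustrative only: load the YAML config and run the pipeline on a local file.
from src.app.config import load_config
from src.pipeline import InferencePipeline

config = load_config("configs/config.yaml")
config["model"]["device"] = "cpu"   # hypothetical override for machines without CUDA

pipeline = InferencePipeline(config)
sr_image = pipeline.run("example_xray.png", apply_clahe_postprocess=True)   # placeholder file name
sr_image.save("example_xray_sr.png")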
requirements.txt
ADDED
@@ -0,0 +1,12 @@
numpy == 1.26.4
opencv-python == 4.10.0.84
Pillow == 11.0.0
torch == 2.5.1
torchvision == 0.20.1
tqdm == 4.67.1
pydicom == 3.0.1
fastapi == 0.115.6
uvicorn == 0.34.0
scikit-image == 0.25.0
python-multipart == 0.0.20
huggingface-hub == 0.25.2
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .network.model import RealESRGAN
from .pipeline import InferencePipeline
from .preprocess import *
src/app/__init__.py
ADDED
File without changes
src/app/config.py
ADDED
@@ -0,0 +1,20 @@
import yaml

def load_config(config_path="configs/config.yaml"):
    """
    Load the configuration from a YAML file.

    Args:
        config_path: Path to the configuration file.

    Returns:
        dict: Configuration dictionary.
    """
    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        return config
    except FileNotFoundError:
        raise Exception(f"Configuration file '{config_path}' not found.")
    except yaml.YAMLError as e:
        raise Exception(f"Error parsing configuration file: {e}")
src/app/exceptions.py
ADDED
@@ -0,0 +1,31 @@
class ModelLoadError(Exception):
    """Raised when the model fails to load."""
    def __init__(self, message="Failed to load the model."):
        self.message = message
        super().__init__(self.message)


class PreprocessingError(Exception):
    """Raised when an error occurs during preprocessing."""
    def __init__(self, message="Error during image preprocessing."):
        self.message = message
        super().__init__(self.message)

class PostprocessingError(Exception):
    """Raised when an error occurs during postprocessing."""
    def __init__(self, message="Error during image postprocessing."):
        self.message = message
        super().__init__(self.message)

class InferenceError(Exception):
    """Raised when an error occurs during inference."""
    def __init__(self, message="Error during inference."):
        self.message = message
        super().__init__(self.message)


class InputError(Exception):
    """Raised when an error occurs during loading input."""
    def __init__(self, message="Error loading input."):
        self.message = message
        super().__init__(self.message)
src/app/main.py
ADDED
@@ -0,0 +1,83 @@
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from src.app.routes import inference
from src.app.exceptions import ModelLoadError, PreprocessingError, InferenceError, InputError, PostprocessingError
import torch
import os
import sys

app = FastAPI(title="Super Resolution Dental X-ray API", version="1.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"])

# Include API routes
app.include_router(inference.router, prefix="/inference", tags=["Inference"])

@app.get("/")
def read_root():
    return {"message": "Welcome to the Super Resolution Dental X-ray API"}

@app.get("/health", tags=["Health"])
async def health_check():
    """
    Health check endpoint to ensure the API and CUDA are running.

    Returns:
        dict: Status message indicating the API and CUDA availability.
    """
    def bash(command):
        return os.popen(command).read()

    # Check CUDA status and construct the response
    return {
        "status": "Healthy",
        "message": "API is running successfully.",
        "cuda": {
            "sys.version": sys.version,
            "torch.__version__": torch.__version__,
            "torch.cuda.is_available()": torch.cuda.is_available(),
            "torch.version.cuda": torch.version.cuda,
            "torch.backends.cudnn.version()": torch.backends.cudnn.version(),
            "torch.backends.cudnn.enabled": torch.backends.cudnn.enabled,
            "nvidia-smi": bash('nvidia-smi')
        }
    }

# Custom exception handlers
@app.exception_handler(ModelLoadError)
async def model_load_error_handler(request: Request, exc: ModelLoadError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "ModelLoadError", "message": exc.message},
    )

@app.exception_handler(PreprocessingError)
async def preprocessing_error_handler(request: Request, exc: PreprocessingError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PreprocessingError", "message": exc.message},
    )

@app.exception_handler(PostprocessingError)
async def postprocessing_error_handler(request: Request, exc: PostprocessingError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "PostprocessingError", "message": exc.message},
    )

@app.exception_handler(InferenceError)
async def inference_error_handler(request: Request, exc: InferenceError):
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "InferenceError", "message": exc.message},
    )


@app.exception_handler(InputError)
async def input_load_error_handler(request: Request, exc: InputError):
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"error": "InputError", "message": exc.message},
    )
src/app/routes/__init__.py
ADDED
File without changes
src/app/routes/inference.py
ADDED
@@ -0,0 +1,87 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask
from io import BytesIO
import os

from src.app.config import load_config
from src.pipeline import InferencePipeline

# Define the router
router = APIRouter()

# Load configuration and initialize the pipeline
config = load_config()
inference_pipeline = InferencePipeline(config)


@router.post("/predict")
async def process_image(
    file: UploadFile = File(...),
    apply_clahe_postprocess: bool = False,
    apply_pre_contrast_adjustment: bool = True,
    return_original_size: bool = True
):
    """
    API endpoint to process and super-resolve an image.

    Args:
        file: Image file to process (PNG, JPEG, or DICOM).
        apply_clahe_postprocess: Boolean indicating if CLAHE should be applied post-processing.
        apply_pre_contrast_adjustment: Boolean indicating if contrast adjustment should be applied before inference.
        return_original_size: Boolean indicating if the output should be resized to the input size.

    Returns:
        FileResponse: Processed image file or error message.
    """
    try:
        # Validate the boolean parameters
        if not isinstance(apply_clahe_postprocess, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'apply_clahe_postprocess' parameter must be a boolean."
            )
        if not isinstance(apply_pre_contrast_adjustment, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'apply_pre_contrast_adjustment' parameter must be a boolean."
            )
        if not isinstance(return_original_size, bool):
            raise HTTPException(
                status_code=400,
                detail="The 'return_original_size' parameter must be a boolean."
            )

        # Read the uploaded file into memory
        file_bytes = await file.read()

        # Perform inference with the pipeline
        sr_image = inference_pipeline.run(
            BytesIO(file_bytes),
            apply_clahe_postprocess=apply_clahe_postprocess,
            apply_pre_contrast_adjustment=apply_pre_contrast_adjustment,
            return_original_size=return_original_size
        )

        # Save the processed image to a temporary file
        output_file_path = "output_highres.png"
        sr_image.save(output_file_path, format="PNG")

        # Return the file as a response; the temporary file is removed after the
        # response has been sent (deleting it in a `finally` block would remove it
        # before FileResponse streams it).
        return FileResponse(
            path=output_file_path,
            media_type="image/png",
            filename="processed_image.png",
            background=BackgroundTask(os.remove, output_file_path)
        )

    except HTTPException as e:
        raise e

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during processing: {str(e)}"
        )
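For reference, a hedged client-side sketch of calling this route once the API is running (assumes uvicorn serving src.app.main:app on localhost:8000 and the requests package, neither of which is pinned in requirements.txt):

# Illustrative only: upload an image to /inference/predict and save the returned PNG.
import requests

with open("input_xray.png", "rb") as f:   # placeholder file name
    response = requests.post(
        "http://localhost:8000/inference/predict",
        files={"file": ("input_xray.png", f, "image/png")},
        params={"apply_clahe_postprocess": "false", "return_original_size": "true"},
    )

response.raise_for_status()
with open("output_sr.png", "wb") as out:
    out.write(response.content)   # the endpoint streams back a PNG file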
src/network/__init__.py
ADDED
File without changes
src/network/arch_utils.py
ADDED
@@ -0,0 +1,197 @@
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
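As a quick sanity check (not part of the commit), pixel_unshuffle is the exact inverse of nn.PixelShuffle, which is how RRDBNet handles scale 1 and 2 inputs:

# Illustrative only: pixel_unshuffle trades spatial size for channels; PixelShuffle undoes it.
import torch
from torch import nn
from src.network.arch_utils import pixel_unshuffle

x = torch.randn(1, 3, 64, 64)
down = pixel_unshuffle(x, scale=2)   # -> (1, 12, 32, 32)
up = nn.PixelShuffle(2)(down)        # -> (1, 3, 64, 64)
assert torch.equal(up, x)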
src/network/model.py
ADDED
@@ -0,0 +1,65 @@
import os
import torch
from PIL import Image
import numpy as np

from .rrdbnet_arch import RRDBNet
from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
    unpad_image


class RealESRGAN:
    def __init__(self, device, scale=4):
        self.device = device
        self.scale = scale
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=scale
        )

    def load_weights(self, model_path):
        loadnet = torch.load(model_path)
        if 'params' in loadnet:
            self.model.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            self.model.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            self.model.load_state_dict(loadnet, strict=True)
        self.model.eval()
        self.model.to(self.device)

    @torch.cuda.amp.autocast()
    def predict(self, lr_image, batch_size=4, patches_size=192,
                padding=24, pad_size=15):
        scale = self.scale
        device = self.device
        lr_image = np.array(lr_image)
        lr_image = pad_reflect(lr_image, pad_size)

        patches, p_shape = split_image_into_overlapping_patches(
            lr_image, patch_size=patches_size, padding_size=padding
        )
        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()

        with torch.no_grad():
            res = self.model(img[0:batch_size])
            for i in range(batch_size, img.shape[0], batch_size):
                res = torch.cat((res, self.model(img[i:i + batch_size])), 0)

        sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
        np_sr_image = stich_together(
            np_sr_image, padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape, padding_size=padding * scale
        )
        sr_img = (np_sr_image * 255).astype(np.uint8)
        sr_img = unpad_image(sr_img, pad_size * scale)
        sr_img = Image.fromarray(sr_img)

        return sr_img
src/network/rrdbnet_arch.py
ADDED
@@ -0,0 +1,121 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_utils import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 8:
            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
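A small shape check (not part of the commit; num_block is reduced from the default 23 just to keep the example fast):

# Illustrative only: with scale=4 the network upsamples by 4x in each spatial dimension.
import torch
from src.network.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=2, num_grow_ch=32, scale=4)
with torch.no_grad():
    out = net(torch.randn(1, 3, 48, 48))
print(out.shape)   # expected: torch.Size([1, 3, 192, 192])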
src/network/utils.py
ADDED
@@ -0,0 +1,133 @@
import numpy as np
import torch
from PIL import Image
import os
import io

def pad_reflect(image, pad_size):
    imsize = image.shape
    height, width = imsize[:2]
    new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
    new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image

    new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)    # top
    new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)    # bottom
    new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)        # left
    new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)      # right

    return new_img

def unpad_image(image, pad_size):
    return image[pad_size:-pad_size, pad_size:-pad_size, :]


def process_array(image_array, expand=True):
    """ Process a 3-dimensional array into a scaled, 4-dimensional batch of size 1. """

    image_batch = image_array / 255.0
    if expand:
        image_batch = np.expand_dims(image_batch, axis=0)
    return image_batch


def process_output(output_tensor):
    """ Transforms the 4-dimensional output tensor into a suitable image format. """

    sr_img = output_tensor.clip(0, 1) * 255
    sr_img = np.uint8(sr_img)
    return sr_img


def pad_patch(image_patch, padding_size, channel_last=True):
    """ Pads image_patch with padding_size edge values. """

    if channel_last:
        return np.pad(
            image_patch,
            ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
            'edge',
        )
    else:
        return np.pad(
            image_patch,
            ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
            'edge',
        )


def unpad_patches(image_patches, padding_size):
    return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]


def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
    """ Splits the image into partially overlapping patches.

    The patches overlap by padding_size pixels.
    Pads the image twice:
        - first to have a size multiple of the patch size,
        - then to have equal padding at the borders.

    Args:
        image_array: numpy array of the input image.
        patch_size: size of the patches from the original image (without padding).
        padding_size: size of the overlapping area.
    """

    xmax, ymax, _ = image_array.shape
    x_remainder = xmax % patch_size
    y_remainder = ymax % patch_size

    # modulo here is to avoid extending of patch_size instead of 0
    x_extend = (patch_size - x_remainder) % patch_size
    y_extend = (patch_size - y_remainder) % patch_size

    # make sure the image is divisible into regular patches
    extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')

    # add padding around the image to simplify computations
    padded_image = pad_patch(extended_image, padding_size, channel_last=True)

    xmax, ymax, _ = padded_image.shape
    patches = []

    x_lefts = range(padding_size, xmax - padding_size, patch_size)
    y_tops = range(padding_size, ymax - padding_size, patch_size)

    for x in x_lefts:
        for y in y_tops:
            x_left = x - padding_size
            y_top = y - padding_size
            x_right = x + patch_size + padding_size
            y_bottom = y + patch_size + padding_size
            patch = padded_image[x_left:x_right, y_top:y_bottom, :]
            patches.append(patch)

    return np.array(patches), padded_image.shape


def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
    """ Reconstruct the image from overlapping patches.

    After scaling, shapes and padding should be scaled too.

    Args:
        patches: patches obtained with split_image_into_overlapping_patches
        padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
        target_shape: shape of the final image
        padding_size: size of the overlapping area.
    """

    xmax, ymax, _ = padded_image_shape
    patches = unpad_patches(patches, padding_size)
    patch_size = patches.shape[1]
    n_patches_per_row = ymax // patch_size

    complete_image = np.zeros((xmax, ymax, 3))

    row = -1
    col = 0
    for i in range(len(patches)):
        if i % n_patches_per_row == 0:
            row += 1
            col = 0
        complete_image[
            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
        ] = patches[i]
        col += 1
    return complete_image[0: target_shape[0], 0: target_shape[1], :]
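A round-trip sketch (not part of the commit): splitting an image into overlapping patches and stitching them back should reproduce the input when no scaling happens in between:

# Illustrative only: split/stich round trip on a random uint8 image.
import numpy as np
from src.network.utils import split_image_into_overlapping_patches, stich_together

image = np.random.randint(0, 256, size=(100, 120, 3), dtype=np.uint8)
patches, padded_shape = split_image_into_overlapping_patches(image, patch_size=48, padding_size=8)
restored = stich_together(patches, padded_image_shape=padded_shape,
                          target_shape=image.shape, padding_size=8)
assert np.array_equal(restored.astype(np.uint8), image)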
src/pipeline.py
ADDED
@@ -0,0 +1,236 @@
import torch
from PIL import Image
import numpy as np
from io import BytesIO
from huggingface_hub import hf_hub_download
from pathlib import Path

from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness
from src.network.model import RealESRGAN
from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError, PostprocessingError


class InferencePipeline:
    def __init__(self, config):
        """
        Initialize the inference pipeline using configuration.

        Args:
            config: Configuration dictionary.
        """
        self.config = config
        self.device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.scale = config["model"].get("scale", 4)

        model_source = config["model"].get("source", "local")
        self.model = RealESRGAN(self.device, scale=self.scale)

        print(f"Using device: {self.device}")

        try:
            if model_source == "huggingface":
                repo_id = config["model"]["repo_id"]
                filename = config["model"]["filename"]
                local_path = hf_hub_download(repo_id=repo_id, filename=filename)
                self.load_weights(local_path)
            else:
                local_path = config["model"]["weights"]
                self.load_weights(local_path)
        except Exception as e:
            raise ModelLoadError(f"Failed to load the model: {str(e)}")

    def load_weights(self, model_weights):
        """
        Load the model weights.

        Args:
            model_weights: Path to the model weights file.
        """
        try:
            self.model.load_weights(model_weights)
        except FileNotFoundError:
            raise ModelLoadError(f"Model weights not found at '{model_weights}'.")
        except Exception as e:
            raise ModelLoadError(f"Error loading weights: {str(e)}")

    def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False):
        """
        Preprocess the input image.

        Args:
            image_path_or_bytes: Path to the input image file, or a file-like object.
            apply_pre_contrast_adjustment: Boolean indicating if contrast adjustment should be applied.
            is_dicom: Boolean indicating if the input is a DICOM file.

        Returns:
            PIL Image: Preprocessed image, along with its original size.
        """
        try:
            if is_dicom:
                img = read_xray(image_path_or_bytes)
            else:
                img = Image.open(image_path_or_bytes)

            if apply_pre_contrast_adjustment:
                img = enhance_exposure(np.array(img))

            if isinstance(img, np.ndarray):
                img = Image.fromarray(((img / np.max(img)) * 255).astype(np.uint8))

            if img.mode not in ['RGB']:
                img = img.convert('RGB')

            img = unsharp_masking(
                img,
                self.config["preprocessing"]["unsharping_mask"].get("kernel_size", 7),
                self.config["preprocessing"]["unsharping_mask"].get("strength", 2)
            )
            img = increase_brightness(
                img,
                self.config["preprocessing"]["brightness"].get("factor", 1.2),
            )

            if img.mode not in ['RGB']:
                img = img.convert('RGB')

            return img, img.size
        except Exception as e:
            raise PreprocessingError(f"Error during preprocessing: {str(e)}")

    def postprocess(self, image_array):
        """
        Postprocess the output from the model.

        Args:
            image_array: PIL.Image output from the model.

        Returns:
            PIL Image: Postprocessed image.
        """
        try:
            return apply_clahe(
                image_array,
                self.config["postprocessing"]["clahe"].get("clipLimit", 2.0),
                tuple(self.config["postprocessing"]["clahe"].get("tileGridSize", [16, 16]))
            )
        except Exception as e:
            raise PostprocessingError(f"Error during postprocessing: {str(e)}")

    def is_dicom(self, file_path_or_bytes):
        """
        Check if the input file is a DICOM file.

        Args:
            file_path_or_bytes (str or bytes or BytesIO): Path to the file, byte content, or BytesIO object.

        Returns:
            bool: True if the file is a DICOM file, False otherwise.
        """
        try:
            if isinstance(file_path_or_bytes, str):
                # Check the file extension
                file_extension = Path(file_path_or_bytes).suffix.lower()
                if file_extension in ['.dcm', '.dicom']:
                    return True

                # Open the file and check the header
                with open(file_path_or_bytes, 'rb') as file:
                    header = file.read(132)
                    return header[-4:] == b'DICM'

            elif isinstance(file_path_or_bytes, BytesIO):
                file_path_or_bytes.seek(0)
                header = file_path_or_bytes.read(132)
                file_path_or_bytes.seek(0)  # Reset the stream position
                return header[-4:] == b'DICM'

            elif isinstance(file_path_or_bytes, bytes):
                header = file_path_or_bytes[:132]
                return header[-4:] == b'DICM'

        except Exception as e:
            print(f"Error during DICOM validation: {e}")
            return False

        return False

    def validate_input(self, input_data):
        """
        Validate the input data to ensure it is suitable for processing.

        Args:
            input_data: Path to the input file, bytes content, or BytesIO object.

        Returns:
            bool: True if the input is valid, raises InputError otherwise.
        """
        if isinstance(input_data, str):
            # Check if the file exists
            if not Path(input_data).exists():
                raise InputError(f"Input file '{input_data}' does not exist.")

            # Check if the file type is supported
            file_extension = Path(input_data).suffix.lower()
            if file_extension not in ['.png', '.jpeg', '.jpg', '.dcm', '.dicom']:
                raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.")

        elif isinstance(input_data, BytesIO):
            # Check if BytesIO data is not empty
            if input_data.getbuffer().nbytes == 0:
                raise InputError("Input BytesIO data is empty.")

        else:
            raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.")

        return True

    def infer(self, input_image):
        """
        Perform inference on a single image.

        Args:
            input_image: PIL Image to be processed.

        Returns:
            PIL Image: Super-resolved image.
        """
        try:
            # Perform inference
            input_array = np.array(input_image)
            sr_array = self.model.predict(input_array)
            return sr_array

        except Exception as e:
            raise InferenceError(f"Error during inference: {str(e)}")

    def run(self, input_path, apply_pre_contrast_adjustment=True, apply_clahe_postprocess=False, return_original_size=True):
        """
        Process a single image and return the super-resolved output.

        Args:
            input_path: Path to the input image file, or a BytesIO object.
            apply_pre_contrast_adjustment: Boolean indicating if contrast adjustment should be applied before inference.
            apply_clahe_postprocess: Boolean indicating if CLAHE should be applied post-processing.
            return_original_size: Boolean indicating if the output should be resized to the input size.
        """
        # Validate the input
        self.validate_input(input_path)

        is_dicom = self.is_dicom(input_path)

        img, original_size = self.preprocess(input_path, is_dicom=is_dicom, apply_pre_contrast_adjustment=apply_pre_contrast_adjustment)

        if img is None:
            raise InputError("Invalid input.")

        sr_image = self.infer(img)

        if apply_clahe_postprocess:
            sr_image = self.postprocess(sr_image)

        if return_original_size:
            sr_image = resize_pil_image(sr_image, target_shape=original_size)

        return sr_image
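For reference, a hedged sketch of running the pipeline on an in-memory upload, mirroring what the FastAPI route does (the DICOM file name is a placeholder and must exist locally):

# Illustrative only: feed a DICOM file through the pipeline as BytesIO.
from io import BytesIO
from src.app.config import load_config
from src.pipeline import InferencePipeline

pipeline = InferencePipeline(load_config())
with open("scan.dcm", "rb") as f:   # placeholder file name
    sr = pipeline.run(BytesIO(f.read()), apply_pre_contrast_adjustment=True)
sr.save("scan_sr.png")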
src/preprocess.py
ADDED
@@ -0,0 +1,187 @@
| 1 |
+
import pydicom
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pydicom.pixels import apply_voi_lut
|
| 4 |
+
from skimage import exposure
|
| 5 |
+
from PIL import Image,ImageEnhance
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
def read_xray(path, voi_lut=True, fix_monochrome=True):
|
| 9 |
+
"""
|
| 10 |
+
Read and preprocess a DICOM X-ray image.
|
| 11 |
+
|
| 12 |
+
Parameters:
|
| 13 |
+
- path: Path to the DICOM file.
|
| 14 |
+
- voi_lut: Apply VOI LUT if available.
|
| 15 |
+
- fix_monochrome: Fix inverted monochrome images.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
- NumPy array: Preprocessed X-ray image.
|
| 19 |
+
"""
|
| 20 |
+
dicom = pydicom.dcmread(path)
|
| 21 |
+
|
| 22 |
+
# Apply VOI LUT if available
|
| 23 |
+
if voi_lut:
|
| 24 |
+
data = apply_voi_lut(dicom.pixel_array, dicom)
|
| 25 |
+
else:
|
| 26 |
+
data = dicom.pixel_array
|
| 27 |
+
|
| 28 |
+
# Fix inverted monochrome images
|
| 29 |
+
if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
|
| 30 |
+
data = np.amax(data) - data
    # Normalize data to start from 0
    data = data - np.min(data)

    return data

def resize_pil_image(image: Image.Image, target_shape: tuple) -> Image.Image:
    """
    Resizes a PIL image based on a target shape.

    Args:
        image: Input PIL image.
        target_shape: Desired shape for resizing. It can be a 2D tuple (height, width)
                      or a 3D tuple (height, width, channels), where channels are ignored.

    Returns:
        Resized PIL image.
    """
    # Convert the image to a numpy array
    np_image = np.array(image)

    # If the target shape is 2D (height, width)
    if len(target_shape) == 2:
        new_width, new_height = target_shape
    elif len(target_shape) == 3:
        # If the target shape is 3D (height, width, channels), only the first two dimensions are used
        new_width, new_height = target_shape[:2]
    else:
        raise ValueError("Target shape must be either 2D or 3D.")

    # Resize with PIL's built-in resizing (channels are not affected)
    pil_resized_image = Image.fromarray(np_image).resize((new_width, new_height), Image.LANCZOS)

    return pil_resized_image

def enhance_exposure(img):
    """
    Enhance image exposure using histogram equalization.

    Parameters:
    - img: Input image as a NumPy array.

    Returns:
    - PIL.Image: Exposure-enhanced image.
    """
    img = exposure.equalize_hist(img)
    img = exposure.equalize_adapthist(img / np.max(img))
    img = (img * 255).astype(np.uint8)
    return Image.fromarray(img)

def unsharp_masking(image, kernel_size=5, strength=0.25):
    """
    Apply unsharp masking to enhance image sharpness.

    Parameters:
    - image: Input image as a NumPy array or PIL.Image.
    - kernel_size: Size of the Gaussian blur kernel.
    - strength: Strength of the high-pass filter.

    Returns:
    - PIL.Image: Sharpened image.
    """
    image = np.array(image)

    # Convert to grayscale if needed
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    # Apply Gaussian blur and compute the high-pass component
    blurred = cv2.GaussianBlur(gray, (kernel_size, kernel_size), 0)
    high_pass = cv2.subtract(gray, blurred)

    # Add the weighted high-pass component back onto the original image
    sharpened = cv2.addWeighted(gray, 1, high_pass, strength, 0)

    return Image.fromarray(sharpened)

def increase_contrast(image: Image.Image, factor: float) -> Image.Image:
    """
    Increases the contrast of the input PIL image by a given factor.

    Args:
        image: Input PIL image.
        factor: Factor by which to increase the contrast.
                A factor of 1.0 means no change; values greater than 1.0 increase contrast,
                values between 0.0 and 1.0 decrease contrast.

    Returns:
        Image with increased contrast.
    """
    if image.mode not in ['RGB', 'L']:
        image = image.convert('RGB')

    enhancer = ImageEnhance.Contrast(image)
    image_enhanced = enhancer.enhance(factor)
    return image_enhanced

def increase_brightness(image: Image.Image, factor: float) -> Image.Image:
    """
    Increases the brightness of the input PIL image by a given factor.

    Args:
        image: Input PIL image.
        factor: Factor by which to increase the brightness.
                A factor of 1.0 means no change; values greater than 1.0 increase brightness,
                values between 0.0 and 1.0 decrease brightness.

    Returns:
        Image with increased brightness.
    """
    if image.mode not in ['RGB', 'L']:
        image = image.convert('RGB')

    enhancer = ImageEnhance.Brightness(image)
    image_enhanced = enhancer.enhance(factor)
    return image_enhanced

def apply_clahe(image, clipLimit=2.0, tileGridSize=(8, 8)):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    Parameters:
    - image: Input image as a PIL.Image.
    - clipLimit: Threshold for contrast limiting.
    - tileGridSize: Size of the grid for histogram equalization.

    Returns:
    - Processed image in the same format as the input (PIL.Image).
    """
    image_np = np.array(image)
    image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Apply CLAHE based on image type
    if len(image_np.shape) == 2:
        # Grayscale image
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        processed = clahe.apply(image_np)
    else:
        # Color image: apply CLAHE on the L channel in LAB space
        lab = cv2.cvtColor(image_np, cv2.COLOR_BGR2LAB)
        L, A, B = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        L_clahe = clahe.apply(L)
        lab_clahe = cv2.merge((L_clahe, A, B))
        processed = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)

    processed_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
    return Image.fromarray(processed_rgb)
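For orientation, below is a minimal usage sketch (not part of the uploaded files) showing how the helpers above might be chained on a single X-ray frame; the input filename, the parameter values, and the `src.preprocess` import path are illustrative assumptions.

# Hypothetical usage sketch; not part of the repository.
from PIL import Image

from src.preprocess import (
    unsharp_masking,
    increase_contrast,
    apply_clahe,
    resize_pil_image,
)

# Assumed local test image; any 8-bit PNG/JPEG X-ray frame would do.
img = Image.open("sample_xray.png").convert("RGB")
original_size = img.size  # PIL gives (width, height)

img = unsharp_masking(img, kernel_size=5, strength=0.25)   # returns a grayscale PIL image
img = img.convert("RGB")                                    # back to 3 channels for the CLAHE helper
img = increase_contrast(img, factor=1.2)                    # gentle global contrast boost
img = apply_clahe(img, clipLimit=2.0, tileGridSize=(8, 8))  # local contrast equalization
img = resize_pil_image(img, target_shape=original_size)     # (width, height), matching how target_shape is unpacked
img.save("sample_xray_preprocessed.png")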
tests/test_inference_pipeline.py
ADDED
@@ -0,0 +1,204 @@
import pytest
from fastapi.testclient import TestClient
import numpy as np
from PIL import Image
from io import BytesIO
import pydicom
from pydicom.dataset import Dataset, FileDataset
import tempfile
import os

import sys
from pathlib import Path

# Add the project root to the Python path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.app.main import app
from src.pipeline import InferencePipeline

# Initialize test client
client = TestClient(app)

@pytest.fixture
def pipeline_config():
    return {
        "model": {
            "weights": "weights/model.pth",
            "scale": 4,
            "device": "cpu"
        },
        "preprocessing": {
            "unsharping_mask": {
                "kernel_size": 7,
                "strength": 0.5
            }
        },
        "postprocessing": {
            "clahe": {
                "clipLimit": 2,
                "tileGridSize": [16, 16]
            }
        }
    }

@pytest.fixture
def pipeline(pipeline_config):
    return InferencePipeline(pipeline_config)

def create_dummy_dicom():
    """Create a dummy DICOM file for testing."""
    meta = Dataset()
    meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.2"
    meta.MediaStorageSOPInstanceUID = "1.2.3"
    meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

    ds = FileDataset("", {}, file_meta=meta, preamble=b"\x00" * 128)

    # Required patient and study information
    ds.PatientName = "Test"
    ds.PatientID = "12345"
    ds.Modality = "CT"
    ds.StudyInstanceUID = "1.2.3.4.5.6.7.8.9.10"
    ds.SeriesInstanceUID = "1.2.3.4.5.6.7.8.9.11"
    ds.SOPInstanceUID = "1.2.3.4.5.6.7.8.9.12"
    ds.StudyDate = "20240101"
    ds.StudyTime = "120000"
    ds.Manufacturer = "TestManufacturer"

    # Required image data information
    ds.PhotometricInterpretation = "MONOCHROME2"
    ds.Rows = 128
    ds.Columns = 128
    ds.BitsAllocated = 16
    ds.BitsStored = 16          # Bits actually stored per pixel
    ds.HighBit = 15             # Highest bit set
    ds.PixelRepresentation = 0  # Unsigned integer
    ds.SamplesPerPixel = 1      # Single-channel (grayscale)
    ds.PixelData = (np.random.rand(128, 128) * 65535).astype(np.uint16).tobytes()

    # Save to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".dcm")
    ds.save_as(temp_file.name)
    return temp_file.name


def test_is_dicom(pipeline):
    dicom_path = create_dummy_dicom()

    # Test with file path
    assert pipeline.is_dicom(dicom_path) is True

    # Test with BytesIO
    with open(dicom_path, "rb") as f:
        dicom_bytes = BytesIO(f.read())
    assert pipeline.is_dicom(dicom_bytes) is True

    # Test with invalid BytesIO (non-DICOM content)
    non_dicom_bytes = BytesIO()
    non_dicom_bytes.write(b"\x89PNG\r\n\x1a\n" + b"\x00" * 128)  # Write a PNG header instead of DICOM
    non_dicom_bytes.seek(0)
    assert pipeline.is_dicom(non_dicom_bytes) is False

    os.remove(dicom_path)


def test_is_dicom_with_raw_bytes(pipeline):
    dicom_path = create_dummy_dicom()

    # Test with file path
    assert pipeline.is_dicom(dicom_path) is True, "DICOM file path should be recognized as DICOM"

    # Test with BytesIO
    with open(dicom_path, "rb") as f:
        dicom_bytes = BytesIO(f.read())
    assert pipeline.is_dicom(dicom_bytes) is True, "BytesIO DICOM content should be recognized as DICOM"

    # Test with invalid BytesIO (non-DICOM content)
    non_dicom_bytes = BytesIO()
    non_dicom_bytes.write(b"\x89PNG\r\n\x1a\n" + b"\x00" * 128)  # Write a PNG header instead of DICOM
    non_dicom_bytes.seek(0)
    assert pipeline.is_dicom(non_dicom_bytes) is False, "Non-DICOM BytesIO should not be recognized as DICOM"

    # Test with invalid raw bytes
    invalid_raw_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 128
    assert pipeline.is_dicom(invalid_raw_bytes) is False, "Invalid raw bytes should not be recognized as DICOM"

    os.remove(dicom_path)


def test_preprocess_normal_image(pipeline):
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")

    # Test with BytesIO
    image_bytes = BytesIO()
    image.save(image_bytes, format="JPEG")
    image_bytes.seek(0)

    processed_image_bytes = pipeline.preprocess(image_bytes, is_dicom=False)
    assert isinstance(processed_image_bytes, Image.Image)

    # Test with file path
    temp_image_path = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg").name
    image.save(temp_image_path)

    processed_image_path = pipeline.preprocess(temp_image_path, is_dicom=False)
    assert isinstance(processed_image_path, Image.Image)

    os.remove(temp_image_path)


def test_infer(pipeline):
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")

    # Perform inference
    result = pipeline.infer(image)
    assert isinstance(result, Image.Image)


def test_postprocess(pipeline):
    image = Image.new("RGB", (128, 128), color="red")
    result = pipeline.postprocess(image)

    assert isinstance(result, Image.Image)


def test_api_predict_normal_image():
    # Create a dummy image
    image = Image.new("RGB", (128, 128), color="red")
    image_bytes = BytesIO()
    image.save(image_bytes, format="JPEG")
    image_bytes.seek(0)

    response = client.post(
        "/inference/predict",  # Route includes the router prefix
        files={"file": ("test.jpg", image_bytes, "image/jpeg")},
        data={"apply_clahe_postprocess": "false"}  # Ensure proper boolean conversion
    )
    assert response.status_code == 200, response.text
    assert response.headers["content-type"] == "image/png"


def test_api_predict_dicom():
    dicom_path = create_dummy_dicom()

    # Use BytesIO for testing
    with open(dicom_path, "rb") as f:
        dicom_bytes = BytesIO(f.read())

    response = client.post(
        "/inference/predict",  # Route includes the router prefix
        files={"file": ("test.dcm", dicom_bytes, "application/dicom")},
        data={"apply_clahe_postprocess": "false"}  # Ensure proper boolean conversion
    )
    assert response.status_code == 200, response.text
    assert response.headers["content-type"] == "image/png"

    os.remove(dicom_path)
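The suite can be run from the repository root with pytest (for example, `pytest tests/`). Below is a minimal sketch (not part of the uploaded files) that exercises the same preprocess, infer, and postprocess stages outside pytest; the config dict simply mirrors the pipeline_config fixture above and assumes the model weights actually exist at that path.

# Illustrative sketch only; mirrors what the unit tests above exercise.
from io import BytesIO

from PIL import Image

from src.pipeline import InferencePipeline

# Same values as the pipeline_config fixture; weights path assumed to be valid.
config = {
    "model": {"weights": "weights/model.pth", "scale": 4, "device": "cpu"},
    "preprocessing": {"unsharping_mask": {"kernel_size": 7, "strength": 0.5}},
    "postprocessing": {"clahe": {"clipLimit": 2, "tileGridSize": [16, 16]}},
}
pipeline = InferencePipeline(config)

# A dummy JPEG in memory, as in test_preprocess_normal_image.
buf = BytesIO()
Image.new("RGB", (128, 128), color="red").save(buf, format="JPEG")
buf.seek(0)

image = pipeline.preprocess(buf, is_dicom=False)  # decode + preprocessing
sr = pipeline.infer(image)                        # super-resolution forward pass
out = pipeline.postprocess(sr)                    # optional CLAHE post-step
out.save("output.png")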