Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from io import BytesIO | |
| from huggingface_hub import hf_hub_download | |
| from pathlib import Path | |
| from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness | |
| from src.network.model import RealESRGAN | |
| from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError,PostprocessingError | |
# NOTE(review): this definition rebinds the name `ModelLoadError` that is also
# imported from `src.app.exceptions` at the top of the file, so code below
# raises this local class — confirm whether the shadowing is intentional.
class ModelLoadError(Exception):
    """Raised when the model or its weights cannot be loaded."""
class InferencePipeline:
    """Super-resolution inference pipeline driven by a configuration dictionary.

    A run validates the input, detects DICOM content, preprocesses the image,
    calls the model's ``predict`` and optionally post-processes and resizes
    the result.
    """

    # A DICOM file starts with a 128-byte preamble followed by the 4-byte
    # marker b"DICM", i.e. the marker occupies bytes 128..131.
    _DICOM_PREAMBLE_SIZE = 128
    _DICOM_MAGIC = b"DICM"
    _DICOM_EXTENSIONS = ('.dcm', '.dicom')
    _SUPPORTED_EXTENSIONS = ('.png', '.jpeg', '.jpg', '.dcm', '.dicom')

    def __init__(self, config):
        """
        Initialize the inference pipeline using configuration.

        Args:
            config: Configuration dictionary. ``config["model"]`` may provide
                ``device`` (default ``"cuda"``), ``scale`` (default 4) and
                ``source`` (default ``"local"``). With source
                ``"huggingface"`` the keys ``repo_id`` and ``filename`` are
                required; otherwise ``weights`` is required.

        Raises:
            ModelLoadError: If the weights cannot be located or loaded.
        """
        self.config = config
        model_config = config["model"]

        preferred_device = model_config.get("device", "cuda")
        # Match "cuda" as well as indexed variants such as "cuda:0".
        wants_cuda = str(preferred_device).startswith("cuda")
        if wants_cuda and not torch.cuda.is_available():
            print("[Warning] CUDA requested but not available. Falling back to CPU.")
            self.device = "cpu"
        else:
            self.device = preferred_device

        self.scale = model_config.get("scale", 4)
        model_source = model_config.get("source", "local")
        self.model = RealESRGAN(self.device, scale=self.scale)
        print(f"Using device: {self.device}")

        try:
            if model_source == "huggingface":
                repo_id = model_config["repo_id"]
                filename = model_config["filename"]
                local_path = hf_hub_download(repo_id=repo_id, filename=filename)
            else:
                local_path = model_config["weights"]
            self.load_weights(local_path)
        except Exception as e:
            raise ModelLoadError(f"Failed to load the model: {str(e)}") from e

    def load_weights(self, model_weights):
        """
        Load the model weights.

        Args:
            model_weights: Path to the model weights file.

        Raises:
            ModelLoadError: If the file is missing or loading fails.
        """
        try:
            self.model.load_weights(model_weights)
        except FileNotFoundError as e:
            raise ModelLoadError(f"Model weights not found at '{model_weights}'.") from e
        except Exception as e:
            raise ModelLoadError(f"Error loading weights: {str(e)}") from e

    @staticmethod
    def _normalize_to_uint8(array):
        """Scale ``array`` so its maximum maps to 255 and cast it to uint8.

        An array whose maximum is 0 is returned as all zeros instead of being
        divided by zero (which would produce NaN and an undefined cast).
        """
        peak = np.max(array)
        if peak == 0:
            return np.zeros(np.shape(array), dtype=np.uint8)
        return ((array / peak) * 255).astype(np.uint8)

    def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False):
        """
        Preprocess the input image.

        Args:
            image_path_or_bytes: File path, raw bytes, or BytesIO holding the
                image. Raw bytes are wrapped in a BytesIO before reading.
            apply_pre_contrast_adjustment: If True, pass the image through
                ``enhance_exposure`` before sharpening.
            is_dicom: Boolean indicating if the input is a DICOM file.

        Returns:
            tuple: ``(img, size)`` where ``img`` is the preprocessed RGB PIL
            image and ``size`` is ``img.size``.

        Raises:
            PreprocessingError: If any preprocessing step fails.
        """
        try:
            if isinstance(image_path_or_bytes, (bytes, bytearray)):
                image_path_or_bytes = BytesIO(image_path_or_bytes)

            if is_dicom:
                img = read_xray(image_path_or_bytes)
            else:
                img = Image.open(image_path_or_bytes)

            if apply_pre_contrast_adjustment:
                img = enhance_exposure(np.array(img))

            # Array results are rescaled to 8 bits and turned back into a
            # PIL image before the PIL-based steps below.
            if isinstance(img, np.ndarray):
                img = Image.fromarray(self._normalize_to_uint8(img))

            if img.mode != 'RGB':
                img = img.convert('RGB')

            sharpen_config = self.config["preprocessing"]["unsharping_mask"]
            img = unsharp_masking(
                img,
                sharpen_config.get("kernel_size", 7),
                sharpen_config.get("strength", 2),
            )

            brightness_config = self.config["preprocessing"]["brightness"]
            img = increase_brightness(
                img,
                brightness_config.get("factor", 1.2),
            )

            if img.mode != 'RGB':
                img = img.convert('RGB')
            return img, img.size
        except Exception as e:
            raise PreprocessingError(f"Error during preprocessing: {str(e)}") from e

    def postprocess(self, image_array):
        """
        Postprocess the output from the model with ``apply_clahe``.

        Args:
            image_array: Output of the model, forwarded to ``apply_clahe``.

        Returns:
            The result of ``apply_clahe`` using the configured ``clipLimit``
            (default 2.0) and ``tileGridSize`` (default ``[16, 16]``).

        Raises:
            PostprocessingError: If postprocessing fails.
        """
        try:
            clahe_config = self.config["postprocessing"]["clahe"]
            return apply_clahe(
                image_array,
                clahe_config.get("clipLimit", 2.0),
                tuple(clahe_config.get("tileGridSize", [16, 16])),
            )
        except Exception as e:
            raise PostprocessingError(f"Error during postprocessing: {str(e)}") from e

    def is_dicom(self, file_path_or_bytes):
        """
        Check if the input is a DICOM file.

        A string path counts as DICOM when its extension is ``.dcm`` or
        ``.dicom``; otherwise (and for in-memory data) the ``DICM`` marker
        must be present at bytes 128..131.

        Args:
            file_path_or_bytes (str or bytes or BytesIO): Path to the file,
                byte content, or BytesIO object.

        Returns:
            bool: True if the input is a DICOM file, False otherwise
            (including when the check itself fails).
        """
        header_size = self._DICOM_PREAMBLE_SIZE + len(self._DICOM_MAGIC)
        try:
            if isinstance(file_path_or_bytes, str):
                if Path(file_path_or_bytes).suffix.lower() in self._DICOM_EXTENSIONS:
                    return True
                with open(file_path_or_bytes, 'rb') as file:
                    header = file.read(header_size)
            elif isinstance(file_path_or_bytes, BytesIO):
                file_path_or_bytes.seek(0)
                header = file_path_or_bytes.read(header_size)
                file_path_or_bytes.seek(0)  # Reset the stream position
            elif isinstance(file_path_or_bytes, (bytes, bytearray)):
                header = file_path_or_bytes[:header_size]
            else:
                return False
        except Exception as e:
            # Best-effort detection: report the problem and treat as non-DICOM.
            print(f"Error during DICOM validation: {e}")
            return False
        # Compare at the fixed offset; a trailing slice would misfire on
        # inputs shorter than the full header.
        return header[self._DICOM_PREAMBLE_SIZE:header_size] == self._DICOM_MAGIC

    def validate_input(self, input_data):
        """
        Validate the input data to ensure it is suitable for processing.

        Args:
            input_data: Path to the input file, bytes content, or BytesIO object.

        Returns:
            bool: True if the input is valid.

        Raises:
            InputError: If the path does not exist, the extension is
                unsupported, the in-memory data is empty, or the type is
                unsupported.
        """
        if isinstance(input_data, str):
            path = Path(input_data)
            if not path.exists():
                raise InputError(f"Input file '{input_data}' does not exist.")
            file_extension = path.suffix.lower()
            if file_extension not in self._SUPPORTED_EXTENSIONS:
                raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.")
        elif isinstance(input_data, BytesIO):
            if input_data.getbuffer().nbytes == 0:
                raise InputError("Input BytesIO data is empty.")
        elif isinstance(input_data, (bytes, bytearray)):
            # Raw byte content is advertised as supported, so accept it here.
            if len(input_data) == 0:
                raise InputError("Input byte content is empty.")
        else:
            raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.")
        return True

    def infer(self, input_image):
        """
        Perform inference on a single image.

        Args:
            input_image: Image to be processed; converted with ``np.array``
                before being passed to the model.

        Returns:
            The value returned by ``self.model.predict``.

        Raises:
            InferenceError: If the model call fails.
        """
        try:
            input_array = np.array(input_image)
            return self.model.predict(input_array)
        except Exception as e:
            raise InferenceError(f"Error during inference: {str(e)}") from e

    def run(self, input_path, apply_pre_contrast_adjustment=True, apply_clahe_postprocess=False, return_original_size=True):
        """
        Process a single image and return the super-resolved result.

        Args:
            input_path: File path, raw bytes, or BytesIO holding the input image.
            apply_pre_contrast_adjustment: Forwarded to ``preprocess``.
            apply_clahe_postprocess: Boolean indicating if CLAHE should be
                applied as post-processing.
            return_original_size: If True, resize the result to the size of
                the preprocessed input image.

        Returns:
            The super-resolved image.

        Raises:
            InputError: If the input fails validation.
        """
        self.validate_input(input_path)

        # Normalize raw bytes to a stream so every later step sees one of the
        # two input kinds (path or BytesIO) it already handles.
        if isinstance(input_path, (bytes, bytearray)):
            input_path = BytesIO(input_path)

        is_dicom = self.is_dicom(input_path)
        img, original_size = self.preprocess(
            input_path,
            is_dicom=is_dicom,
            apply_pre_contrast_adjustment=apply_pre_contrast_adjustment,
        )
        if img is None:
            raise InputError("Invalid Input")

        sr_image = self.infer(img)
        if apply_clahe_postprocess:
            sr_image = self.postprocess(sr_image)
        if return_original_size:
            sr_image = resize_pil_image(sr_image, target_shape=original_size)
        return sr_image