SerdarHelli commited on
Commit
5251218
·
verified ·
1 Parent(s): 62f828b

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +241 -235
src/pipeline.py CHANGED
@@ -1,236 +1,242 @@
1
- import torch
2
- from PIL import Image
3
- import numpy as np
4
- from io import BytesIO
5
- from huggingface_hub import hf_hub_download
6
- from pathlib import Path
7
-
8
- from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness
9
- from src.network.model import RealESRGAN
10
- from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError,PostprocessingError
11
-
12
- class ModelLoadError(Exception):
13
- pass
14
-
15
- class InferencePipeline:
16
- def __init__(self, config):
17
- """
18
- Initialize the inference pipeline using configuration.
19
-
20
- Args:
21
- config: Configuration dictionary.
22
- """
23
- self.config = config
24
- self.device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu")
25
- self.scale = config["model"].get("scale", 4)
26
-
27
- model_source = config["model"].get("source", "local")
28
- self.model = RealESRGAN(self.device, scale=self.scale)
29
-
30
- print(f"Using device: {self.device}")
31
-
32
- try:
33
- if model_source == "huggingface":
34
- repo_id = config["model"]["repo_id"]
35
- filename = config["model"]["filename"]
36
- local_path = hf_hub_download(repo_id=repo_id, filename=filename)
37
- self.load_weights(local_path)
38
- else:
39
- local_path = config["model"]["weights"]
40
- self.load_weights(local_path)
41
- except Exception as e:
42
- raise ModelLoadError(f"Failed to load the model: {str(e)}")
43
-
44
- def load_weights(self, model_weights):
45
- """
46
- Load the model weights.
47
-
48
- Args:
49
- model_weights: Path to the model weights file.
50
- """
51
- try:
52
- self.model.load_weights(model_weights)
53
- except FileNotFoundError:
54
- raise ModelLoadError(f"Model weights not found at '{model_weights}'.")
55
- except Exception as e:
56
- raise ModelLoadError(f"Error loading weights: {str(e)}")
57
- def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False):
58
- """
59
- Preprocess the input image.
60
-
61
- Args:
62
- image_path: Path to the input image file.
63
- is_dicom: Boolean indicating if the input is a DICOM file.
64
-
65
- Returns:
66
- PIL Image: Preprocessed image.
67
- """
68
- try:
69
- if is_dicom:
70
- img = read_xray(image_path_or_bytes)
71
- else:
72
- img = Image.open(image_path_or_bytes)
73
-
74
- if apply_pre_contrast_adjustment:
75
- img = enhance_exposure(np.array(img))
76
-
77
- if isinstance(img,np.ndarray):
78
- img = Image.fromarray(((img / np.max(img))*255).astype(np.uint8))
79
-
80
- if img.mode not in ['RGB']:
81
- img = img.convert('RGB')
82
-
83
- img = unsharp_masking(
84
- img,
85
- self.config["preprocessing"]["unsharping_mask"].get("kernel_size", 7),
86
- self.config["preprocessing"]["unsharping_mask"].get("strength", 2)
87
- )
88
- img = increase_brightness(
89
- img,
90
- self.config["preprocessing"]["brightness"].get("factor", 1.2),
91
- )
92
-
93
-
94
- if img.mode not in ['RGB']:
95
- img = img.convert('RGB')
96
-
97
-
98
- return img, img.size
99
- except Exception as e:
100
- raise PreprocessingError(f"Error during preprocessing: {str(e)}")
101
-
102
- def postprocess(self, image_array):
103
- """
104
- Postprocess the output from the model.
105
-
106
- Args:
107
- image_array: PIL.Image output from the model.
108
-
109
- Returns:
110
- PIL Image: Postprocessed image.
111
- """
112
- try:
113
- return apply_clahe(
114
- image_array,
115
- self.config["postprocessing"]["clahe"].get("clipLimit", 2.0),
116
- tuple(self.config["postprocessing"]["clahe"].get("tileGridSize", [16, 16]))
117
- )
118
- except Exception as e:
119
- raise PostprocessingError(f"Error during postprocessing: {str(e)}")
120
-
121
- def is_dicom(self, file_path_or_bytes):
122
- """
123
- Check if the input file is a DICOM file.
124
-
125
- Args:
126
- file_path_or_bytes (str or bytes or BytesIO): Path to the file, byte content, or BytesIO object.
127
-
128
- Returns:
129
- bool: True if the file is a DICOM file, False otherwise.
130
- """
131
- try:
132
- if isinstance(file_path_or_bytes, str):
133
- # Check the file extension
134
- file_extension = Path(file_path_or_bytes).suffix.lower()
135
- if file_extension in ['.dcm', '.dicom']:
136
- return True
137
-
138
- # Open the file and check the header
139
- with open(file_path_or_bytes, 'rb') as file:
140
- header = file.read(132)
141
- return header[-4:] == b'DICM'
142
-
143
- elif isinstance(file_path_or_bytes, BytesIO):
144
- file_path_or_bytes.seek(0)
145
- header = file_path_or_bytes.read(132)
146
- file_path_or_bytes.seek(0) # Reset the stream position
147
- return header[-4:] == b'DICM'
148
-
149
- elif isinstance(file_path_or_bytes, bytes):
150
- header = file_path_or_bytes[:132]
151
- return header[-4:] == b'DICM'
152
-
153
- except Exception as e:
154
- print(f"Error during DICOM validation: {e}")
155
- return False
156
-
157
- return False
158
-
159
- def validate_input(self, input_data):
160
- """
161
- Validate the input data to ensure it is suitable for processing.
162
-
163
- Args:
164
- input_data: Path to the input file, bytes content, or BytesIO object.
165
-
166
- Returns:
167
- bool: True if the input is valid, raises InputError otherwise.
168
- """
169
- if isinstance(input_data, str):
170
- # Check if the file exists
171
- if not Path(input_data).exists():
172
- raise InputError(f"Input file '{input_data}' does not exist.")
173
-
174
- # Check if the file type is supported
175
- file_extension = Path(input_data).suffix.lower()
176
- if file_extension not in ['.png', '.jpeg', '.jpg', '.dcm', '.dicom']:
177
- raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.")
178
-
179
- elif isinstance(input_data, BytesIO):
180
- # Check if BytesIO data is not empty
181
- if input_data.getbuffer().nbytes == 0:
182
- raise InputError("Input BytesIO data is empty.")
183
-
184
- else:
185
- raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.")
186
-
187
- return True
188
-
189
- def infer(self, input_image):
190
- """
191
- Perform inference on a single image.
192
-
193
- Args:
194
- input_image: PIL Image to be processed.
195
-
196
- Returns:
197
- PIL Image: Super-resolved image.
198
- """
199
- try:
200
- # Perform inference
201
- input_array = np.array(input_image)
202
- sr_array = self.model.predict(input_array)
203
- return sr_array
204
-
205
- except Exception as e:
206
- raise InferenceError(f"Error during inference: {str(e)}")
207
-
208
- def run(self, input_path, apply_pre_contrast_adjustment = True, apply_clahe_postprocess=False, return_original_size = True):
209
- """
210
- Process a single image and save the output.
211
-
212
- Args:
213
- input_path: Path to the input image file.
214
- is_dicom: Boolean indicating if the input is a DICOM file.
215
- apply_clahe_postprocess: Boolean indicating if CLAHE should be applied post-processing.
216
- """
217
- # Validate the input
218
- self.validate_input(input_path)
219
-
220
- is_dicom =self.is_dicom(input_path)
221
-
222
- img, original_size = self.preprocess(input_path, is_dicom=is_dicom, apply_pre_contrast_adjustment = apply_pre_contrast_adjustment)
223
-
224
- if img is None:
225
- raise InputError(f"Invalid Input")
226
-
227
-
228
- sr_image = self.infer(img)
229
-
230
- if apply_clahe_postprocess:
231
- sr_image = self.postprocess(sr_image)
232
-
233
- if return_original_size:
234
- sr_image = resize_pil_image(sr_image, target_shape = original_size)
235
-
 
 
 
 
 
 
236
  return sr_image
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ from io import BytesIO
5
+ from huggingface_hub import hf_hub_download
6
+ from pathlib import Path
7
+
8
+ from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness
9
+ from src.network.model import RealESRGAN
10
+ from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError,PostprocessingError
11
+
12
+ class ModelLoadError(Exception):
13
+ pass
14
+
15
+ class InferencePipeline:
16
+ def __init__(self, config):
17
+ """
18
+ Initialize the inference pipeline using configuration.
19
+
20
+ Args:
21
+ config: Configuration dictionary.
22
+ """
23
+ self.config = config
24
+ preferred_device = config["model"].get("device", "cuda")
25
+ if preferred_device == "cuda" and not torch.cuda.is_available():
26
+ print("[Warning] CUDA requested but not available. Falling back to CPU.")
27
+ self.device = "cpu"
28
+ else:
29
+ self.device = preferred_device
30
+
31
+ self.scale = config["model"].get("scale", 4)
32
+
33
+ model_source = config["model"].get("source", "local")
34
+ self.model = RealESRGAN(self.device, scale=self.scale)
35
+
36
+ print(f"Using device: {self.device}")
37
+
38
+ try:
39
+ if model_source == "huggingface":
40
+ repo_id = config["model"]["repo_id"]
41
+ filename = config["model"]["filename"]
42
+ local_path = hf_hub_download(repo_id=repo_id, filename=filename)
43
+ self.load_weights(local_path)
44
+ else:
45
+ local_path = config["model"]["weights"]
46
+ self.load_weights(local_path)
47
+ except Exception as e:
48
+ raise ModelLoadError(f"Failed to load the model: {str(e)}")
49
+
50
+ def load_weights(self, model_weights):
51
+ """
52
+ Load the model weights.
53
+
54
+ Args:
55
+ model_weights: Path to the model weights file.
56
+ """
57
+ try:
58
+ self.model.load_weights(model_weights)
59
+ except FileNotFoundError:
60
+ raise ModelLoadError(f"Model weights not found at '{model_weights}'.")
61
+ except Exception as e:
62
+ raise ModelLoadError(f"Error loading weights: {str(e)}")
63
+ def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False):
64
+ """
65
+ Preprocess the input image.
66
+
67
+ Args:
68
+ image_path: Path to the input image file.
69
+ is_dicom: Boolean indicating if the input is a DICOM file.
70
+
71
+ Returns:
72
+ PIL Image: Preprocessed image.
73
+ """
74
+ try:
75
+ if is_dicom:
76
+ img = read_xray(image_path_or_bytes)
77
+ else:
78
+ img = Image.open(image_path_or_bytes)
79
+
80
+ if apply_pre_contrast_adjustment:
81
+ img = enhance_exposure(np.array(img))
82
+
83
+ if isinstance(img,np.ndarray):
84
+ img = Image.fromarray(((img / np.max(img))*255).astype(np.uint8))
85
+
86
+ if img.mode not in ['RGB']:
87
+ img = img.convert('RGB')
88
+
89
+ img = unsharp_masking(
90
+ img,
91
+ self.config["preprocessing"]["unsharping_mask"].get("kernel_size", 7),
92
+ self.config["preprocessing"]["unsharping_mask"].get("strength", 2)
93
+ )
94
+ img = increase_brightness(
95
+ img,
96
+ self.config["preprocessing"]["brightness"].get("factor", 1.2),
97
+ )
98
+
99
+
100
+ if img.mode not in ['RGB']:
101
+ img = img.convert('RGB')
102
+
103
+
104
+ return img, img.size
105
+ except Exception as e:
106
+ raise PreprocessingError(f"Error during preprocessing: {str(e)}")
107
+
108
+ def postprocess(self, image_array):
109
+ """
110
+ Postprocess the output from the model.
111
+
112
+ Args:
113
+ image_array: PIL.Image output from the model.
114
+
115
+ Returns:
116
+ PIL Image: Postprocessed image.
117
+ """
118
+ try:
119
+ return apply_clahe(
120
+ image_array,
121
+ self.config["postprocessing"]["clahe"].get("clipLimit", 2.0),
122
+ tuple(self.config["postprocessing"]["clahe"].get("tileGridSize", [16, 16]))
123
+ )
124
+ except Exception as e:
125
+ raise PostprocessingError(f"Error during postprocessing: {str(e)}")
126
+
127
+ def is_dicom(self, file_path_or_bytes):
128
+ """
129
+ Check if the input file is a DICOM file.
130
+
131
+ Args:
132
+ file_path_or_bytes (str or bytes or BytesIO): Path to the file, byte content, or BytesIO object.
133
+
134
+ Returns:
135
+ bool: True if the file is a DICOM file, False otherwise.
136
+ """
137
+ try:
138
+ if isinstance(file_path_or_bytes, str):
139
+ # Check the file extension
140
+ file_extension = Path(file_path_or_bytes).suffix.lower()
141
+ if file_extension in ['.dcm', '.dicom']:
142
+ return True
143
+
144
+ # Open the file and check the header
145
+ with open(file_path_or_bytes, 'rb') as file:
146
+ header = file.read(132)
147
+ return header[-4:] == b'DICM'
148
+
149
+ elif isinstance(file_path_or_bytes, BytesIO):
150
+ file_path_or_bytes.seek(0)
151
+ header = file_path_or_bytes.read(132)
152
+ file_path_or_bytes.seek(0) # Reset the stream position
153
+ return header[-4:] == b'DICM'
154
+
155
+ elif isinstance(file_path_or_bytes, bytes):
156
+ header = file_path_or_bytes[:132]
157
+ return header[-4:] == b'DICM'
158
+
159
+ except Exception as e:
160
+ print(f"Error during DICOM validation: {e}")
161
+ return False
162
+
163
+ return False
164
+
165
+ def validate_input(self, input_data):
166
+ """
167
+ Validate the input data to ensure it is suitable for processing.
168
+
169
+ Args:
170
+ input_data: Path to the input file, bytes content, or BytesIO object.
171
+
172
+ Returns:
173
+ bool: True if the input is valid, raises InputError otherwise.
174
+ """
175
+ if isinstance(input_data, str):
176
+ # Check if the file exists
177
+ if not Path(input_data).exists():
178
+ raise InputError(f"Input file '{input_data}' does not exist.")
179
+
180
+ # Check if the file type is supported
181
+ file_extension = Path(input_data).suffix.lower()
182
+ if file_extension not in ['.png', '.jpeg', '.jpg', '.dcm', '.dicom']:
183
+ raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.")
184
+
185
+ elif isinstance(input_data, BytesIO):
186
+ # Check if BytesIO data is not empty
187
+ if input_data.getbuffer().nbytes == 0:
188
+ raise InputError("Input BytesIO data is empty.")
189
+
190
+ else:
191
+ raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.")
192
+
193
+ return True
194
+
195
+ def infer(self, input_image):
196
+ """
197
+ Perform inference on a single image.
198
+
199
+ Args:
200
+ input_image: PIL Image to be processed.
201
+
202
+ Returns:
203
+ PIL Image: Super-resolved image.
204
+ """
205
+ try:
206
+ # Perform inference
207
+ input_array = np.array(input_image)
208
+ sr_array = self.model.predict(input_array)
209
+ return sr_array
210
+
211
+ except Exception as e:
212
+ raise InferenceError(f"Error during inference: {str(e)}")
213
+
214
+ def run(self, input_path, apply_pre_contrast_adjustment = True, apply_clahe_postprocess=False, return_original_size = True):
215
+ """
216
+ Process a single image and save the output.
217
+
218
+ Args:
219
+ input_path: Path to the input image file.
220
+ is_dicom: Boolean indicating if the input is a DICOM file.
221
+ apply_clahe_postprocess: Boolean indicating if CLAHE should be applied post-processing.
222
+ """
223
+ # Validate the input
224
+ self.validate_input(input_path)
225
+
226
+ is_dicom =self.is_dicom(input_path)
227
+
228
+ img, original_size = self.preprocess(input_path, is_dicom=is_dicom, apply_pre_contrast_adjustment = apply_pre_contrast_adjustment)
229
+
230
+ if img is None:
231
+ raise InputError(f"Invalid Input")
232
+
233
+
234
+ sr_image = self.infer(img)
235
+
236
+ if apply_clahe_postprocess:
237
+ sr_image = self.postprocess(sr_image)
238
+
239
+ if return_original_size:
240
+ sr_image = resize_pil_image(sr_image, target_shape = original_size)
241
+
242
  return sr_image