Fioceen commited on
Commit
2a8ed16
·
1 Parent(s): 65c7c05

File modularization

Browse files
image_postprocess/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .autowb import auto_white_balance_ref
2
+ from .clahe import clahe_color_correction
3
+ from .color_lut import load_lut, apply_lut
4
+ from .exif import remove_exif_pil
5
+ from .fourier_pipeline import fourier_match_spectrum
6
+ from .gaussian_noise import add_gaussian_noise
7
+ from .perturbation import randomized_perturbation
8
+
9
+ __all__ = [
10
+ 'auto_white_balance_ref',
11
+ 'clahe_color_correction',
12
+ 'load_lut',
13
+ 'apply_lut',
14
+ 'remove_exif_pil',
15
+ 'fourier_match_spectrum',
16
+ 'add_gaussian_noise',
17
+ 'randomized_perturbation'
18
+ ]
image_postprocess/utils/autowb.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def auto_white_balance_ref(img_arr: np.ndarray, ref_img_arr: np.ndarray = None) -> np.ndarray:
4
+ """
5
+ Auto white-balance correction using a reference image.
6
+ If ref_img_arr is None, uses a gray-world assumption instead.
7
+ """
8
+ img = img_arr.astype(np.float32)
9
+
10
+ if ref_img_arr is not None:
11
+ ref = ref_img_arr.astype(np.float32)
12
+ ref_mean = ref.reshape(-1, 3).mean(axis=0)
13
+ else:
14
+ # Gray-world assumption: target is neutral gray
15
+ ref_mean = np.array([128.0, 128.0, 128.0], dtype=np.float32)
16
+
17
+ img_mean = img.reshape(-1, 3).mean(axis=0)
18
+
19
+ # Avoid divide-by-zero
20
+ eps = 1e-6
21
+ scale = (ref_mean + eps) / (img_mean + eps)
22
+
23
+ corrected = img * scale
24
+ corrected = np.clip(corrected, 0, 255).astype(np.uint8)
25
+
26
+ return corrected
image_postprocess/utils/clahe.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image, ImageOps
3
+
4
+ try:
5
+ import cv2
6
+ _HAS_CV2 = True
7
+ except Exception:
8
+ cv2 = None
9
+ _HAS_CV2 = False
10
+
11
+ def clahe_color_correction(img_arr: np.ndarray, clip_limit=2.0, tile_grid_size=(8,8)) -> np.ndarray:
12
+ if _HAS_CV2:
13
+ lab = cv2.cvtColor(img_arr, cv2.COLOR_RGB2LAB)
14
+ l, a, b = cv2.split(lab)
15
+ clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
16
+ l2 = clahe.apply(l)
17
+ lab2 = cv2.merge((l2, a, b))
18
+ out = cv2.cvtColor(lab2, cv2.COLOR_LAB2RGB)
19
+ return out
20
+ else:
21
+ pil = Image.fromarray(img_arr)
22
+ channels = pil.split()
23
+ new_ch = []
24
+ for ch in channels:
25
+ eq = ImageOps.equalize(ch)
26
+ new_ch.append(eq)
27
+ merged = Image.merge('RGB', new_ch)
28
+ return np.array(merged)
image_postprocess/utils/color_lut.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import re, os
3
+ from PIL import Image
4
+
5
+ def apply_1d_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
6
+ """
7
+ Apply a 1D LUT to an image.
8
+ - img_arr: HxWx3 uint8
9
+ - lut: either shape (256,) (applied equally to all channels), (256,3) (per-channel),
10
+ or (N,) / (N,3) (interpolated across [0..255])
11
+ - strength: 0..1 blending between original and LUT result
12
+ Returns uint8 array.
13
+ """
14
+ if img_arr.ndim != 3 or img_arr.shape[2] != 3:
15
+ raise ValueError("apply_1d_lut expects an HxWx3 image array")
16
+
17
+ # Normalize indices 0..255
18
+ arr = img_arr.astype(np.float32)
19
+ # Prepare LUT as float in 0..255 range if necessary
20
+ lut_arr = np.array(lut, dtype=np.float32)
21
+
22
+ # If single channel LUT (N,) expand to three channels
23
+ if lut_arr.ndim == 1:
24
+ lut_arr = np.stack([lut_arr, lut_arr, lut_arr], axis=1) # (N,3)
25
+
26
+ if lut_arr.shape[1] != 3:
27
+ raise ValueError("1D LUT must have shape (N,) or (N,3)")
28
+
29
+ # Build index positions in source LUT space (0..255)
30
+ N = lut_arr.shape[0]
31
+ src_positions = np.linspace(0, 255, N)
32
+
33
+ # Flatten and interpolate per channel
34
+ out = np.empty_like(arr)
35
+ for c in range(3):
36
+ channel = arr[..., c].ravel()
37
+ mapped = np.interp(channel, src_positions, lut_arr[:, c])
38
+ out[..., c] = mapped.reshape(arr.shape[0], arr.shape[1])
39
+
40
+ out = np.clip(out, 0, 255).astype(np.uint8)
41
+ if strength >= 1.0:
42
+ return out
43
+ else:
44
+ blended = ((1.0 - strength) * img_arr.astype(np.float32) + strength * out.astype(np.float32))
45
+ return np.clip(blended, 0, 255).astype(np.uint8)
46
+
47
+ def _trilinear_sample_lut(img_float: np.ndarray, lut: np.ndarray) -> np.ndarray:
48
+ """
49
+ Vectorized trilinear sampling of 3D LUT.
50
+ - img_float: HxWx3 floats in [0,1]
51
+ - lut: SxSxS x 3 floats in [0,1]
52
+ Returns HxWx3 floats in [0,1]
53
+ """
54
+ S = lut.shape[0]
55
+ if lut.shape[0] != lut.shape[1] or lut.shape[1] != lut.shape[2]:
56
+ raise ValueError("3D LUT must be cubic (SxSxSx3)")
57
+
58
+ # map [0,1] -> [0, S-1]
59
+ idx = img_float * (S - 1)
60
+ r_idx = idx[..., 0]
61
+ g_idx = idx[..., 1]
62
+ b_idx = idx[..., 2]
63
+
64
+ r0 = np.floor(r_idx).astype(np.int32)
65
+ g0 = np.floor(g_idx).astype(np.int32)
66
+ b0 = np.floor(b_idx).astype(np.int32)
67
+
68
+ r1 = np.clip(r0 + 1, 0, S - 1)
69
+ g1 = np.clip(g0 + 1, 0, S - 1)
70
+ b1 = np.clip(b0 + 1, 0, S - 1)
71
+
72
+ dr = (r_idx - r0)[..., None]
73
+ dg = (g_idx - g0)[..., None]
74
+ db = (b_idx - b0)[..., None]
75
+
76
+ # gather 8 corners: c000 ... c111
77
+ c000 = lut[r0, g0, b0]
78
+ c001 = lut[r0, g0, b1]
79
+ c010 = lut[r0, g1, b0]
80
+ c011 = lut[r0, g1, b1]
81
+ c100 = lut[r1, g0, b0]
82
+ c101 = lut[r1, g0, b1]
83
+ c110 = lut[r1, g1, b0]
84
+ c111 = lut[r1, g1, b1]
85
+
86
+ # interpolate along b
87
+ c00 = c000 * (1 - db) + c001 * db
88
+ c01 = c010 * (1 - db) + c011 * db
89
+ c10 = c100 * (1 - db) + c101 * db
90
+ c11 = c110 * (1 - db) + c111 * db
91
+
92
+ # interpolate along g
93
+ c0 = c00 * (1 - dg) + c01 * dg
94
+ c1 = c10 * (1 - dg) + c11 * dg
95
+
96
+ # interpolate along r
97
+ c = c0 * (1 - dr) + c1 * dr
98
+
99
+ return c # float in same range as lut (expected [0,1])
100
+
101
+ def apply_3d_lut(img_arr: np.ndarray, lut3d: np.ndarray, strength: float = 1.0) -> np.ndarray:
102
+ """
103
+ Apply a 3D LUT to the image.
104
+ - img_arr: HxWx3 uint8
105
+ - lut3d: SxSxSx3 float (expected range 0..1)
106
+ - strength: blending 0..1
107
+ Returns uint8 image.
108
+ """
109
+ if img_arr.ndim != 3 or img_arr.shape[2] != 3:
110
+ raise ValueError("apply_3d_lut expects an HxWx3 image array")
111
+
112
+ img_float = img_arr.astype(np.float32) / 255.0
113
+ sampled = _trilinear_sample_lut(img_float, lut3d) # HxWx3 floats in [0,1]
114
+ out = np.clip(sampled * 255.0, 0, 255).astype(np.uint8)
115
+ if strength >= 1.0:
116
+ return out
117
+ else:
118
+ blended = ((1.0 - strength) * img_arr.astype(np.float32) + strength * out.astype(np.float32))
119
+ return np.clip(blended, 0, 255).astype(np.uint8)
120
+
121
+ def apply_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
122
+ """
123
+ Auto-detect LUT type and apply.
124
+ - If lut.ndim in (1,2) treat as 1D LUT (per-channel if shape (N,3)).
125
+ - If lut.ndim == 4 treat as 3D LUT (SxSxSx3) in [0,1].
126
+ """
127
+ lut = np.array(lut)
128
+ if lut.ndim == 4 and lut.shape[3] == 3:
129
+ # 3D LUT (assumed normalized [0..1])
130
+ # If lut is in 0..255, normalize
131
+ if lut.dtype != np.float32 and lut.max() > 1.0:
132
+ lut = lut.astype(np.float32) / 255.0
133
+ return apply_3d_lut(img_arr, lut, strength=strength)
134
+ elif lut.ndim in (1, 2):
135
+ return apply_1d_lut(img_arr, lut, strength=strength)
136
+ else:
137
+ raise ValueError("Unsupported LUT shape: {}".format(lut.shape))
138
+
139
+ def load_cube_lut(path: str) -> np.ndarray:
140
+ """
141
+ Parse a .cube file and return a 3D LUT array of shape (S,S,S,3) with float values in [0,1].
142
+ Note: .cube file order sometimes varies; this function assumes standard ordering
143
+ where data lines are triples of floats and LUT_3D_SIZE specifies S.
144
+ """
145
+ with open(path, 'r', encoding='utf-8', errors='ignore') as f:
146
+ lines = [ln.strip() for ln in f if ln.strip() and not ln.strip().startswith('#')]
147
+
148
+ size = None
149
+ data = []
150
+ domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
151
+ domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
152
+
153
+ for ln in lines:
154
+ if ln.upper().startswith('LUT_3D_SIZE'):
155
+ parts = ln.split()
156
+ if len(parts) >= 2:
157
+ size = int(parts[1])
158
+ elif ln.upper().startswith('DOMAIN_MIN'):
159
+ parts = ln.split()
160
+ domain_min = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
161
+ elif ln.upper().startswith('DOMAIN_MAX'):
162
+ parts = ln.split()
163
+ domain_max = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
164
+ elif re.match(r'^-?\d+(\.\d+)?\s+-?\d+(\.\d+)?\s+-?\d+(\.\d+)?$', ln):
165
+ parts = [float(x) for x in ln.split()]
166
+ data.append(parts)
167
+
168
+ if size is None:
169
+ raise ValueError("LUT_3D_SIZE not found in .cube file: {}".format(path))
170
+
171
+ data = np.array(data, dtype=np.float32)
172
+ if data.shape[0] != size**3:
173
+ raise ValueError("Cube LUT data length does not match size^3 (got {}, expected {})".format(data.shape[0], size**3))
174
+
175
+ # Data ordering in many .cube files is: for r in 0..S-1: for g in 0..S-1: for b in 0..S-1: write RGB
176
+ # We'll reshape into (S,S,S,3) with indices [r,g,b]
177
+ lut = data.reshape((size, size, size, 3))
178
+ # Map domain_min..domain_max to 0..1 if domain specified (rare)
179
+ if not np.allclose(domain_min, [0.0, 0.0, 0.0]) or not np.allclose(domain_max, [1.0, 1.0, 1.0]):
180
+ # scale lut values from domain range into 0..1
181
+ lut = (lut - domain_min) / (domain_max - domain_min + 1e-12)
182
+ lut = np.clip(lut, 0.0, 1.0)
183
+ else:
184
+ # ensure LUT is in [0,1] if not already
185
+ if lut.max() > 1.0 + 1e-6:
186
+ lut = lut / 255.0
187
+ return lut.astype(np.float32)
188
+
189
+ def load_lut(path: str) -> np.ndarray:
190
+ """
191
+ Load a LUT from:
192
+ - .npy (numpy saved array)
193
+ - .cube (3D LUT)
194
+ - image (PNG/JPG) that is a 1D LUT strip (common 256x1 or 1x256)
195
+ Returns numpy array (1D, 2D, or 4D LUT).
196
+ """
197
+ ext = os.path.splitext(path)[1].lower()
198
+ if ext == '.npy':
199
+ return np.load(path)
200
+ elif ext == '.cube':
201
+ return load_cube_lut(path)
202
+ else:
203
+ # try interpreting as image-based 1D LUT
204
+ try:
205
+ im = Image.open(path).convert('RGB')
206
+ arr = np.array(im)
207
+ h, w = arr.shape[:2]
208
+ # 256x1 or 1x256 typical 1D LUT
209
+ if (w == 256 and h == 1) or (h == 256 and w == 1):
210
+ if h == 1:
211
+ lut = arr[0, :, :].astype(np.float32)
212
+ else:
213
+ lut = arr[:, 0, :].astype(np.float32)
214
+ return lut # shape (256,3)
215
+ # sometimes embedded as 512x16 or other tile layouts; attempt to flatten
216
+ # fallback: flatten and try to build (N,3)
217
+ flat = arr.reshape(-1, 3).astype(np.float32)
218
+ # if length is perfect power-of-two and <= 1024, assume 1D
219
+ L = flat.shape[0]
220
+ if L <= 4096:
221
+ return flat # (L,3)
222
+ raise ValueError("Image LUT not recognized size")
223
+ except Exception as e:
224
+ raise ValueError(f"Unsupported LUT file or parse error for {path}: {e}")
image_postprocess/utils/exif.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ def remove_exif_pil(img: Image.Image) -> Image.Image:
4
+ data = img.tobytes()
5
+ new = Image.frombytes(img.mode, img.size, data)
6
+ return new
image_postprocess/utils/fourier_pipeline.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.ndimage import gaussian_filter1d
3
+ from PIL import Image
4
+
5
+ def radial_profile(mag: np.ndarray, center=None, nbins=None):
6
+ h, w = mag.shape
7
+ if center is None:
8
+ cy, cx = h // 2, w // 2
9
+ else:
10
+ cy, cx = center
11
+
12
+ if nbins is None:
13
+ nbins = int(max(h, w) / 2)
14
+ nbins = max(1, int(nbins))
15
+
16
+ y = np.arange(h) - cy
17
+ x = np.arange(w) - cx
18
+ X, Y = np.meshgrid(x, y)
19
+ R = np.sqrt(X * X + Y * Y)
20
+
21
+ Rmax = R.max()
22
+ if Rmax <= 0:
23
+ Rnorm = R
24
+ else:
25
+ Rnorm = R / (Rmax + 1e-12)
26
+ Rnorm = np.minimum(Rnorm, 1.0 - 1e-12)
27
+
28
+ bin_edges = np.linspace(0.0, 1.0, nbins + 1)
29
+ bin_idx = np.digitize(Rnorm.ravel(), bin_edges) - 1
30
+ bin_idx = np.clip(bin_idx, 0, nbins - 1)
31
+
32
+ sums = np.bincount(bin_idx, weights=mag.ravel(), minlength=nbins)
33
+ counts = np.bincount(bin_idx, minlength=nbins)
34
+
35
+ radial_mean = np.zeros(nbins, dtype=np.float64)
36
+ nonzero = counts > 0
37
+ radial_mean[nonzero] = sums[nonzero] / counts[nonzero]
38
+
39
+ bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
40
+ return bin_centers, radial_mean
41
+
42
+ def fourier_match_spectrum(img_arr: np.ndarray,
43
+ ref_img_arr: np.ndarray = None,
44
+ mode='auto',
45
+ alpha=1.0,
46
+ cutoff=0.25,
47
+ strength=0.9,
48
+ randomness=0.05,
49
+ phase_perturb=0.08,
50
+ radial_smooth=5,
51
+ seed=None):
52
+ if seed is not None:
53
+ rng = np.random.default_rng(seed)
54
+ else:
55
+ rng = np.random.default_rng()
56
+
57
+ h, w = img_arr.shape[:2]
58
+ cy, cx = h // 2, w // 2
59
+ nbins = max(8, int(max(h, w) / 2))
60
+
61
+ if mode == 'auto':
62
+ mode = 'ref' if ref_img_arr is not None else 'model'
63
+
64
+ bin_centers_src = np.linspace(0.0, 1.0, nbins)
65
+
66
+ model_radial = None
67
+ if mode == 'model':
68
+ eps = 1e-8
69
+ model_radial = (1.0 / (bin_centers_src + eps)) ** (alpha / 2.0)
70
+ lf = max(1, nbins // 8)
71
+ model_radial = model_radial / (np.median(model_radial[:lf]) + 1e-12)
72
+ model_radial = gaussian_filter1d(model_radial, sigma=max(1, radial_smooth))
73
+
74
+ ref_radial = None
75
+ ref_bin_centers = None
76
+ if mode == 'ref' and ref_img_arr is not None:
77
+ if ref_img_arr.shape[0] != h or ref_img_arr.shape[1] != w:
78
+ ref_img = Image.fromarray(ref_img_arr).resize((w, h), resample=Image.BICUBIC)
79
+ ref_img_arr = np.array(ref_img)
80
+ ref_gray = np.mean(ref_img_arr.astype(np.float32), axis=2) if ref_img_arr.ndim == 3 else ref_img_arr.astype(np.float32)
81
+ Fref = np.fft.fftshift(np.fft.fft2(ref_gray))
82
+ Mref = np.abs(Fref)
83
+ ref_bin_centers, ref_radial = radial_profile(Mref, center=(h // 2, w // 2), nbins=nbins)
84
+ ref_radial = gaussian_filter1d(ref_radial, sigma=max(1, radial_smooth))
85
+
86
+ out = np.zeros_like(img_arr, dtype=np.float32)
87
+
88
+ y = np.linspace(-1, 1, h, endpoint=False)[:, None]
89
+ x = np.linspace(-1, 1, w, endpoint=False)[None, :]
90
+ r = np.sqrt(x * x + y * y)
91
+ r = np.clip(r, 0.0, 1.0 - 1e-6)
92
+
93
+ for c in range(img_arr.shape[2]):
94
+ channel = img_arr[:, :, c].astype(np.float32)
95
+ F = np.fft.fft2(channel)
96
+ Fshift = np.fft.fftshift(F)
97
+ mag = np.abs(Fshift)
98
+ phase = np.angle(Fshift)
99
+
100
+ bin_centers_src_calc, src_radial = radial_profile(mag, center=(h // 2, w // 2), nbins=nbins)
101
+ src_radial = gaussian_filter1d(src_radial, sigma=max(1, radial_smooth))
102
+ bin_centers_src = bin_centers_src_calc
103
+
104
+ if mode == 'ref' and ref_radial is not None:
105
+ ref_interp = np.interp(bin_centers_src, ref_bin_centers, ref_radial)
106
+ eps = 1e-8
107
+ ratio = (ref_interp + eps) / (src_radial + eps)
108
+ desired_radial = src_radial * ratio
109
+ elif mode == 'model' and model_radial is not None:
110
+ lf = max(1, nbins // 8)
111
+ scale = (np.median(src_radial[:lf]) + 1e-12) / (np.median(model_radial[:lf]) + 1e-12)
112
+ desired_radial = model_radial * scale
113
+ else:
114
+ desired_radial = src_radial.copy()
115
+
116
+ eps = 1e-8
117
+ multiplier_1d = (desired_radial + eps) / (src_radial + eps)
118
+ multiplier_1d = np.clip(multiplier_1d, 0.2, 5.0)
119
+ mult_2d = np.interp(r.ravel(), bin_centers_src, multiplier_1d).reshape(h, w)
120
+
121
+ edge = 0.05 + 0.02 * (1.0 - cutoff) if 'cutoff' in globals() else 0.05
122
+ edge = max(edge, 1e-6)
123
+ weight = np.where(r <= 0.25, 1.0,
124
+ np.where(r <= 0.25 + edge,
125
+ 0.5 * (1 + np.cos(np.pi * (r - 0.25) / edge)),
126
+ 0.0))
127
+
128
+ final_multiplier = 1.0 + (mult_2d - 1.0) * (weight * strength)
129
+
130
+ if randomness and randomness > 0.0:
131
+ noise = rng.normal(loc=1.0, scale=randomness, size=final_multiplier.shape)
132
+ final_multiplier *= (1.0 + (noise - 1.0) * weight)
133
+
134
+ mag2 = mag * final_multiplier
135
+
136
+ if phase_perturb and phase_perturb > 0.0:
137
+ phase_sigma = phase_perturb * np.clip((r - 0.25) / (1.0 - 0.25 + 1e-6), 0.0, 1.0)
138
+ phase_noise = rng.standard_normal(size=phase_sigma.shape) * phase_sigma
139
+ phase2 = phase + phase_noise
140
+ else:
141
+ phase2 = phase
142
+
143
+ Fshift2 = mag2 * np.exp(1j * phase2)
144
+ F_ishift = np.fft.ifftshift(Fshift2)
145
+ img_back = np.fft.ifft2(F_ishift)
146
+ img_back = np.real(img_back)
147
+
148
+ blended = (1.0 - strength) * channel + strength * img_back
149
+ out[:, :, c] = blended
150
+
151
+ out = np.clip(out, 0, 255).astype(np.uint8)
152
+ return out
image_postprocess/utils/gaussian_noise.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def add_gaussian_noise(img_arr: np.ndarray, std_frac=0.02, seed=None) -> np.ndarray:
4
+ if seed is not None:
5
+ np.random.seed(seed)
6
+ std = std_frac * 255.0
7
+ noise = np.random.normal(loc=0.0, scale=std, size=img_arr.shape)
8
+ out = img_arr.astype(np.float32) + noise
9
+ out = np.clip(out, 0, 255).astype(np.uint8)
10
+ return out
image_postprocess/utils/perturbation.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def randomized_perturbation(img_arr: np.ndarray, magnitude_frac=0.008, seed=None) -> np.ndarray:
4
+ if seed is not None:
5
+ np.random.seed(seed)
6
+ mag = magnitude_frac * 255.0
7
+ perturb = np.random.uniform(low=-mag, high=mag, size=img_arr.shape)
8
+ out = img_arr.astype(np.float32) + perturb
9
+ out = np.clip(out, 0, 255).astype(np.uint8)
10
+ return out
image_postprocess/utils_deprecated.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils.py
3
+
4
+ Helper functions for image postprocessing, including EXIF removal, noise addition,
5
+ color correction, and Fourier spectrum matching.
6
+ """
7
+ import os
8
+ import re
9
+ from PIL import Image, ImageOps
10
+ import numpy as np
11
+ try:
12
+ import cv2
13
+ _HAS_CV2 = True
14
+ except Exception:
15
+ cv2 = None
16
+ _HAS_CV2 = False
17
+ from scipy.ndimage import gaussian_filter1d
18
+
19
+ def remove_exif_pil(img: Image.Image) -> Image.Image:
20
+ data = img.tobytes()
21
+ new = Image.frombytes(img.mode, img.size, data)
22
+ return new
23
+
24
+ def add_gaussian_noise(img_arr: np.ndarray, std_frac=0.02, seed=None) -> np.ndarray:
25
+ if seed is not None:
26
+ np.random.seed(seed)
27
+ std = std_frac * 255.0
28
+ noise = np.random.normal(loc=0.0, scale=std, size=img_arr.shape)
29
+ out = img_arr.astype(np.float32) + noise
30
+ out = np.clip(out, 0, 255).astype(np.uint8)
31
+ return out
32
+
33
+ def clahe_color_correction(img_arr: np.ndarray, clip_limit=2.0, tile_grid_size=(8,8)) -> np.ndarray:
34
+ if _HAS_CV2:
35
+ lab = cv2.cvtColor(img_arr, cv2.COLOR_RGB2LAB)
36
+ l, a, b = cv2.split(lab)
37
+ clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
38
+ l2 = clahe.apply(l)
39
+ lab2 = cv2.merge((l2, a, b))
40
+ out = cv2.cvtColor(lab2, cv2.COLOR_LAB2RGB)
41
+ return out
42
+ else:
43
+ pil = Image.fromarray(img_arr)
44
+ channels = pil.split()
45
+ new_ch = []
46
+ for ch in channels:
47
+ eq = ImageOps.equalize(ch)
48
+ new_ch.append(eq)
49
+ merged = Image.merge('RGB', new_ch)
50
+ return np.array(merged)
51
+
52
+ def randomized_perturbation(img_arr: np.ndarray, magnitude_frac=0.008, seed=None) -> np.ndarray:
53
+ if seed is not None:
54
+ np.random.seed(seed)
55
+ mag = magnitude_frac * 255.0
56
+ perturb = np.random.uniform(low=-mag, high=mag, size=img_arr.shape)
57
+ out = img_arr.astype(np.float32) + perturb
58
+ out = np.clip(out, 0, 255).astype(np.uint8)
59
+ return out
60
+
61
+ def radial_profile(mag: np.ndarray, center=None, nbins=None):
62
+ h, w = mag.shape
63
+ if center is None:
64
+ cy, cx = h // 2, w // 2
65
+ else:
66
+ cy, cx = center
67
+
68
+ if nbins is None:
69
+ nbins = int(max(h, w) / 2)
70
+ nbins = max(1, int(nbins))
71
+
72
+ y = np.arange(h) - cy
73
+ x = np.arange(w) - cx
74
+ X, Y = np.meshgrid(x, y)
75
+ R = np.sqrt(X * X + Y * Y)
76
+
77
+ Rmax = R.max()
78
+ if Rmax <= 0:
79
+ Rnorm = R
80
+ else:
81
+ Rnorm = R / (Rmax + 1e-12)
82
+ Rnorm = np.minimum(Rnorm, 1.0 - 1e-12)
83
+
84
+ bin_edges = np.linspace(0.0, 1.0, nbins + 1)
85
+ bin_idx = np.digitize(Rnorm.ravel(), bin_edges) - 1
86
+ bin_idx = np.clip(bin_idx, 0, nbins - 1)
87
+
88
+ sums = np.bincount(bin_idx, weights=mag.ravel(), minlength=nbins)
89
+ counts = np.bincount(bin_idx, minlength=nbins)
90
+
91
+ radial_mean = np.zeros(nbins, dtype=np.float64)
92
+ nonzero = counts > 0
93
+ radial_mean[nonzero] = sums[nonzero] / counts[nonzero]
94
+
95
+ bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
96
+ return bin_centers, radial_mean
97
+
98
+ def fourier_match_spectrum(img_arr: np.ndarray,
99
+ ref_img_arr: np.ndarray = None,
100
+ mode='auto',
101
+ alpha=1.0,
102
+ cutoff=0.25,
103
+ strength=0.9,
104
+ randomness=0.05,
105
+ phase_perturb=0.08,
106
+ radial_smooth=5,
107
+ seed=None):
108
+ if seed is not None:
109
+ rng = np.random.default_rng(seed)
110
+ else:
111
+ rng = np.random.default_rng()
112
+
113
+ h, w = img_arr.shape[:2]
114
+ cy, cx = h // 2, w // 2
115
+ nbins = max(8, int(max(h, w) / 2))
116
+
117
+ if mode == 'auto':
118
+ mode = 'ref' if ref_img_arr is not None else 'model'
119
+
120
+ bin_centers_src = np.linspace(0.0, 1.0, nbins)
121
+
122
+ model_radial = None
123
+ if mode == 'model':
124
+ eps = 1e-8
125
+ model_radial = (1.0 / (bin_centers_src + eps)) ** (alpha / 2.0)
126
+ lf = max(1, nbins // 8)
127
+ model_radial = model_radial / (np.median(model_radial[:lf]) + 1e-12)
128
+ model_radial = gaussian_filter1d(model_radial, sigma=max(1, radial_smooth))
129
+
130
+ ref_radial = None
131
+ ref_bin_centers = None
132
+ if mode == 'ref' and ref_img_arr is not None:
133
+ if ref_img_arr.shape[0] != h or ref_img_arr.shape[1] != w:
134
+ ref_img = Image.fromarray(ref_img_arr).resize((w, h), resample=Image.BICUBIC)
135
+ ref_img_arr = np.array(ref_img)
136
+ ref_gray = np.mean(ref_img_arr.astype(np.float32), axis=2) if ref_img_arr.ndim == 3 else ref_img_arr.astype(np.float32)
137
+ Fref = np.fft.fftshift(np.fft.fft2(ref_gray))
138
+ Mref = np.abs(Fref)
139
+ ref_bin_centers, ref_radial = radial_profile(Mref, center=(h // 2, w // 2), nbins=nbins)
140
+ ref_radial = gaussian_filter1d(ref_radial, sigma=max(1, radial_smooth))
141
+
142
+ out = np.zeros_like(img_arr, dtype=np.float32)
143
+
144
+ y = np.linspace(-1, 1, h, endpoint=False)[:, None]
145
+ x = np.linspace(-1, 1, w, endpoint=False)[None, :]
146
+ r = np.sqrt(x * x + y * y)
147
+ r = np.clip(r, 0.0, 1.0 - 1e-6)
148
+
149
+ for c in range(img_arr.shape[2]):
150
+ channel = img_arr[:, :, c].astype(np.float32)
151
+ F = np.fft.fft2(channel)
152
+ Fshift = np.fft.fftshift(F)
153
+ mag = np.abs(Fshift)
154
+ phase = np.angle(Fshift)
155
+
156
+ bin_centers_src_calc, src_radial = radial_profile(mag, center=(h // 2, w // 2), nbins=nbins)
157
+ src_radial = gaussian_filter1d(src_radial, sigma=max(1, radial_smooth))
158
+ bin_centers_src = bin_centers_src_calc
159
+
160
+ if mode == 'ref' and ref_radial is not None:
161
+ ref_interp = np.interp(bin_centers_src, ref_bin_centers, ref_radial)
162
+ eps = 1e-8
163
+ ratio = (ref_interp + eps) / (src_radial + eps)
164
+ desired_radial = src_radial * ratio
165
+ elif mode == 'model' and model_radial is not None:
166
+ lf = max(1, nbins // 8)
167
+ scale = (np.median(src_radial[:lf]) + 1e-12) / (np.median(model_radial[:lf]) + 1e-12)
168
+ desired_radial = model_radial * scale
169
+ else:
170
+ desired_radial = src_radial.copy()
171
+
172
+ eps = 1e-8
173
+ multiplier_1d = (desired_radial + eps) / (src_radial + eps)
174
+ multiplier_1d = np.clip(multiplier_1d, 0.2, 5.0)
175
+ mult_2d = np.interp(r.ravel(), bin_centers_src, multiplier_1d).reshape(h, w)
176
+
177
+ edge = 0.05 + 0.02 * (1.0 - cutoff) if 'cutoff' in globals() else 0.05
178
+ edge = max(edge, 1e-6)
179
+ weight = np.where(r <= 0.25, 1.0,
180
+ np.where(r <= 0.25 + edge,
181
+ 0.5 * (1 + np.cos(np.pi * (r - 0.25) / edge)),
182
+ 0.0))
183
+
184
+ final_multiplier = 1.0 + (mult_2d - 1.0) * (weight * strength)
185
+
186
+ if randomness and randomness > 0.0:
187
+ noise = rng.normal(loc=1.0, scale=randomness, size=final_multiplier.shape)
188
+ final_multiplier *= (1.0 + (noise - 1.0) * weight)
189
+
190
+ mag2 = mag * final_multiplier
191
+
192
+ if phase_perturb and phase_perturb > 0.0:
193
+ phase_sigma = phase_perturb * np.clip((r - 0.25) / (1.0 - 0.25 + 1e-6), 0.0, 1.0)
194
+ phase_noise = rng.standard_normal(size=phase_sigma.shape) * phase_sigma
195
+ phase2 = phase + phase_noise
196
+ else:
197
+ phase2 = phase
198
+
199
+ Fshift2 = mag2 * np.exp(1j * phase2)
200
+ F_ishift = np.fft.ifftshift(Fshift2)
201
+ img_back = np.fft.ifft2(F_ishift)
202
+ img_back = np.real(img_back)
203
+
204
+ blended = (1.0 - strength) * channel + strength * img_back
205
+ out[:, :, c] = blended
206
+
207
+ out = np.clip(out, 0, 255).astype(np.uint8)
208
+ return out
209
+
210
+ def auto_white_balance_ref(img_arr: np.ndarray, ref_img_arr: np.ndarray = None) -> np.ndarray:
211
+ """
212
+ Auto white-balance correction using a reference image.
213
+ If ref_img_arr is None, uses a gray-world assumption instead.
214
+ """
215
+ img = img_arr.astype(np.float32)
216
+
217
+ if ref_img_arr is not None:
218
+ ref = ref_img_arr.astype(np.float32)
219
+ ref_mean = ref.reshape(-1, 3).mean(axis=0)
220
+ else:
221
+ # Gray-world assumption: target is neutral gray
222
+ ref_mean = np.array([128.0, 128.0, 128.0], dtype=np.float32)
223
+
224
+ img_mean = img.reshape(-1, 3).mean(axis=0)
225
+
226
+ # Avoid divide-by-zero
227
+ eps = 1e-6
228
+ scale = (ref_mean + eps) / (img_mean + eps)
229
+
230
+ corrected = img * scale
231
+ corrected = np.clip(corrected, 0, 255).astype(np.uint8)
232
+
233
+ return corrected
234
+
235
+ def apply_1d_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
236
+ """
237
+ Apply a 1D LUT to an image.
238
+ - img_arr: HxWx3 uint8
239
+ - lut: either shape (256,) (applied equally to all channels), (256,3) (per-channel),
240
+ or (N,) / (N,3) (interpolated across [0..255])
241
+ - strength: 0..1 blending between original and LUT result
242
+ Returns uint8 array.
243
+ """
244
+ if img_arr.ndim != 3 or img_arr.shape[2] != 3:
245
+ raise ValueError("apply_1d_lut expects an HxWx3 image array")
246
+
247
+ # Normalize indices 0..255
248
+ arr = img_arr.astype(np.float32)
249
+ # Prepare LUT as float in 0..255 range if necessary
250
+ lut_arr = np.array(lut, dtype=np.float32)
251
+
252
+ # If single channel LUT (N,) expand to three channels
253
+ if lut_arr.ndim == 1:
254
+ lut_arr = np.stack([lut_arr, lut_arr, lut_arr], axis=1) # (N,3)
255
+
256
+ if lut_arr.shape[1] != 3:
257
+ raise ValueError("1D LUT must have shape (N,) or (N,3)")
258
+
259
+ # Build index positions in source LUT space (0..255)
260
+ N = lut_arr.shape[0]
261
+ src_positions = np.linspace(0, 255, N)
262
+
263
+ # Flatten and interpolate per channel
264
+ out = np.empty_like(arr)
265
+ for c in range(3):
266
+ channel = arr[..., c].ravel()
267
+ mapped = np.interp(channel, src_positions, lut_arr[:, c])
268
+ out[..., c] = mapped.reshape(arr.shape[0], arr.shape[1])
269
+
270
+ out = np.clip(out, 0, 255).astype(np.uint8)
271
+ if strength >= 1.0:
272
+ return out
273
+ else:
274
+ blended = ((1.0 - strength) * img_arr.astype(np.float32) + strength * out.astype(np.float32))
275
+ return np.clip(blended, 0, 255).astype(np.uint8)
276
+
277
+ def _trilinear_sample_lut(img_float: np.ndarray, lut: np.ndarray) -> np.ndarray:
278
+ """
279
+ Vectorized trilinear sampling of 3D LUT.
280
+ - img_float: HxWx3 floats in [0,1]
281
+ - lut: SxSxS x 3 floats in [0,1]
282
+ Returns HxWx3 floats in [0,1]
283
+ """
284
+ S = lut.shape[0]
285
+ if lut.shape[0] != lut.shape[1] or lut.shape[1] != lut.shape[2]:
286
+ raise ValueError("3D LUT must be cubic (SxSxSx3)")
287
+
288
+ # map [0,1] -> [0, S-1]
289
+ idx = img_float * (S - 1)
290
+ r_idx = idx[..., 0]
291
+ g_idx = idx[..., 1]
292
+ b_idx = idx[..., 2]
293
+
294
+ r0 = np.floor(r_idx).astype(np.int32)
295
+ g0 = np.floor(g_idx).astype(np.int32)
296
+ b0 = np.floor(b_idx).astype(np.int32)
297
+
298
+ r1 = np.clip(r0 + 1, 0, S - 1)
299
+ g1 = np.clip(g0 + 1, 0, S - 1)
300
+ b1 = np.clip(b0 + 1, 0, S - 1)
301
+
302
+ dr = (r_idx - r0)[..., None]
303
+ dg = (g_idx - g0)[..., None]
304
+ db = (b_idx - b0)[..., None]
305
+
306
+ # gather 8 corners: c000 ... c111
307
+ c000 = lut[r0, g0, b0]
308
+ c001 = lut[r0, g0, b1]
309
+ c010 = lut[r0, g1, b0]
310
+ c011 = lut[r0, g1, b1]
311
+ c100 = lut[r1, g0, b0]
312
+ c101 = lut[r1, g0, b1]
313
+ c110 = lut[r1, g1, b0]
314
+ c111 = lut[r1, g1, b1]
315
+
316
+ # interpolate along b
317
+ c00 = c000 * (1 - db) + c001 * db
318
+ c01 = c010 * (1 - db) + c011 * db
319
+ c10 = c100 * (1 - db) + c101 * db
320
+ c11 = c110 * (1 - db) + c111 * db
321
+
322
+ # interpolate along g
323
+ c0 = c00 * (1 - dg) + c01 * dg
324
+ c1 = c10 * (1 - dg) + c11 * dg
325
+
326
+ # interpolate along r
327
+ c = c0 * (1 - dr) + c1 * dr
328
+
329
+ return c # float in same range as lut (expected [0,1])
330
+
331
+ def apply_3d_lut(img_arr: np.ndarray, lut3d: np.ndarray, strength: float = 1.0) -> np.ndarray:
332
+ """
333
+ Apply a 3D LUT to the image.
334
+ - img_arr: HxWx3 uint8
335
+ - lut3d: SxSxSx3 float (expected range 0..1)
336
+ - strength: blending 0..1
337
+ Returns uint8 image.
338
+ """
339
+ if img_arr.ndim != 3 or img_arr.shape[2] != 3:
340
+ raise ValueError("apply_3d_lut expects an HxWx3 image array")
341
+
342
+ img_float = img_arr.astype(np.float32) / 255.0
343
+ sampled = _trilinear_sample_lut(img_float, lut3d) # HxWx3 floats in [0,1]
344
+ out = np.clip(sampled * 255.0, 0, 255).astype(np.uint8)
345
+ if strength >= 1.0:
346
+ return out
347
+ else:
348
+ blended = ((1.0 - strength) * img_arr.astype(np.float32) + strength * out.astype(np.float32))
349
+ return np.clip(blended, 0, 255).astype(np.uint8)
350
+
351
+ def apply_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
352
+ """
353
+ Auto-detect LUT type and apply.
354
+ - If lut.ndim in (1,2) treat as 1D LUT (per-channel if shape (N,3)).
355
+ - If lut.ndim == 4 treat as 3D LUT (SxSxSx3) in [0,1].
356
+ """
357
+ lut = np.array(lut)
358
+ if lut.ndim == 4 and lut.shape[3] == 3:
359
+ # 3D LUT (assumed normalized [0..1])
360
+ # If lut is in 0..255, normalize
361
+ if lut.dtype != np.float32 and lut.max() > 1.0:
362
+ lut = lut.astype(np.float32) / 255.0
363
+ return apply_3d_lut(img_arr, lut, strength=strength)
364
+ elif lut.ndim in (1, 2):
365
+ return apply_1d_lut(img_arr, lut, strength=strength)
366
+ else:
367
+ raise ValueError("Unsupported LUT shape: {}".format(lut.shape))
368
+
369
+ def load_cube_lut(path: str) -> np.ndarray:
370
+ """
371
+ Parse a .cube file and return a 3D LUT array of shape (S,S,S,3) with float values in [0,1].
372
+ Note: .cube file order sometimes varies; this function assumes standard ordering
373
+ where data lines are triples of floats and LUT_3D_SIZE specifies S.
374
+ """
375
+ with open(path, 'r', encoding='utf-8', errors='ignore') as f:
376
+ lines = [ln.strip() for ln in f if ln.strip() and not ln.strip().startswith('#')]
377
+
378
+ size = None
379
+ data = []
380
+ domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
381
+ domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
382
+
383
+ for ln in lines:
384
+ if ln.upper().startswith('LUT_3D_SIZE'):
385
+ parts = ln.split()
386
+ if len(parts) >= 2:
387
+ size = int(parts[1])
388
+ elif ln.upper().startswith('DOMAIN_MIN'):
389
+ parts = ln.split()
390
+ domain_min = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
391
+ elif ln.upper().startswith('DOMAIN_MAX'):
392
+ parts = ln.split()
393
+ domain_max = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
394
+ elif re.match(r'^-?\d+(\.\d+)?\s+-?\d+(\.\d+)?\s+-?\d+(\.\d+)?$', ln):
395
+ parts = [float(x) for x in ln.split()]
396
+ data.append(parts)
397
+
398
+ if size is None:
399
+ raise ValueError("LUT_3D_SIZE not found in .cube file: {}".format(path))
400
+
401
+ data = np.array(data, dtype=np.float32)
402
+ if data.shape[0] != size**3:
403
+ raise ValueError("Cube LUT data length does not match size^3 (got {}, expected {})".format(data.shape[0], size**3))
404
+
405
+ # Data ordering in many .cube files is: for r in 0..S-1: for g in 0..S-1: for b in 0..S-1: write RGB
406
+ # We'll reshape into (S,S,S,3) with indices [r,g,b]
407
+ lut = data.reshape((size, size, size, 3))
408
+ # Map domain_min..domain_max to 0..1 if domain specified (rare)
409
+ if not np.allclose(domain_min, [0.0, 0.0, 0.0]) or not np.allclose(domain_max, [1.0, 1.0, 1.0]):
410
+ # scale lut values from domain range into 0..1
411
+ lut = (lut - domain_min) / (domain_max - domain_min + 1e-12)
412
+ lut = np.clip(lut, 0.0, 1.0)
413
+ else:
414
+ # ensure LUT is in [0,1] if not already
415
+ if lut.max() > 1.0 + 1e-6:
416
+ lut = lut / 255.0
417
+ return lut.astype(np.float32)
418
+
419
+ def load_lut(path: str) -> np.ndarray:
420
+ """
421
+ Load a LUT from:
422
+ - .npy (numpy saved array)
423
+ - .cube (3D LUT)
424
+ - image (PNG/JPG) that is a 1D LUT strip (common 256x1 or 1x256)
425
+ Returns numpy array (1D, 2D, or 4D LUT).
426
+ """
427
+ ext = os.path.splitext(path)[1].lower()
428
+ if ext == '.npy':
429
+ return np.load(path)
430
+ elif ext == '.cube':
431
+ return load_cube_lut(path)
432
+ else:
433
+ # try interpreting as image-based 1D LUT
434
+ try:
435
+ im = Image.open(path).convert('RGB')
436
+ arr = np.array(im)
437
+ h, w = arr.shape[:2]
438
+ # 256x1 or 1x256 typical 1D LUT
439
+ if (w == 256 and h == 1) or (h == 256 and w == 1):
440
+ if h == 1:
441
+ lut = arr[0, :, :].astype(np.float32)
442
+ else:
443
+ lut = arr[:, 0, :].astype(np.float32)
444
+ return lut # shape (256,3)
445
+ # sometimes embedded as 512x16 or other tile layouts; attempt to flatten
446
+ # fallback: flatten and try to build (N,3)
447
+ flat = arr.reshape(-1, 3).astype(np.float32)
448
+ # if length is perfect power-of-two and <= 1024, assume 1D
449
+ L = flat.shape[0]
450
+ if L <= 4096:
451
+ return flat # (L,3)
452
+ raise ValueError("Image LUT not recognized size")
453
+ except Exception as e:
454
+ raise ValueError(f"Unsupported LUT file or parse error for {path}: {e}")
455
+
456
+ # --- end appended LUT helpers
main_window.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MainWindow definition extracted from the original single-file GUI.
4
+ All GUI wiring, widgets, and the MainWindow class live here.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ from pathlib import Path
10
+ from PyQt5.QtWidgets import (
11
+ QApplication, QMainWindow, QWidget, QLabel, QPushButton, QFileDialog,
12
+ QHBoxLayout, QVBoxLayout, QFormLayout, QSlider, QSpinBox, QDoubleSpinBox,
13
+ QProgressBar, QMessageBox, QLineEdit, QComboBox, QCheckBox, QToolButton, QScrollArea
14
+ )
15
+ from PyQt5.QtCore import Qt
16
+ from PyQt5.QtGui import QPixmap
17
+ from worker import Worker
18
+ from analysis_panel import AnalysisPanel
19
+ from utils import qpixmap_from_path
20
+ from collapsible_box import CollapsibleBox
21
+ from theme import apply_dark_palette
22
+
23
+ class MainWindow(QMainWindow):
24
+ def __init__(self):
25
+ super().__init__()
26
+ self.setWindowTitle("Image Postprocess — GUI (Camera Simulator)")
27
+ self.setMinimumSize(1200, 760)
28
+
29
+ central = QWidget()
30
+ self.setCentralWidget(central)
31
+ main_h = QHBoxLayout(central)
32
+
33
+ # Left: previews & file selection
34
+ left_v = QVBoxLayout()
35
+ main_h.addLayout(left_v, 2)
36
+
37
+ # Input/Output collapsible
38
+ io_box = CollapsibleBox("Input / Output")
39
+ left_v.addWidget(io_box)
40
+ in_layout = QFormLayout()
41
+ io_container = QWidget()
42
+ io_container.setLayout(in_layout)
43
+ io_box.content_layout.addWidget(io_container)
44
+
45
+ self.input_line = QLineEdit()
46
+ self.input_btn = QPushButton("Choose Input")
47
+ self.input_btn.clicked.connect(self.choose_input)
48
+
49
+ self.ref_line = QLineEdit()
50
+ self.ref_btn = QPushButton("Choose AWB Reference (optional)")
51
+ self.ref_btn.clicked.connect(self.choose_ref)
52
+
53
+ self.fft_ref_line = QLineEdit()
54
+ self.fft_ref_btn = QPushButton("Choose FFT Reference (optional)")
55
+ self.fft_ref_btn.clicked.connect(self.choose_fft_ref)
56
+
57
+ self.output_line = QLineEdit()
58
+ self.output_btn = QPushButton("Choose Output")
59
+ self.output_btn.clicked.connect(self.choose_output)
60
+
61
+ in_layout.addRow(self.input_btn, self.input_line)
62
+ in_layout.addRow(self.ref_btn, self.ref_line)
63
+ in_layout.addRow(self.fft_ref_btn, self.fft_ref_line)
64
+ in_layout.addRow(self.output_btn, self.output_line)
65
+
66
+ # Previews
67
+ self.preview_in = QLabel(alignment=Qt.AlignCenter)
68
+ self.preview_in.setFixedSize(480, 300)
69
+ self.preview_in.setStyleSheet("background:#121213; border:1px solid #2b2b2b; color:#ddd; border-radius:6px")
70
+ self.preview_in.setText("Input preview")
71
+
72
+ self.preview_out = QLabel(alignment=Qt.AlignCenter)
73
+ self.preview_out.setFixedSize(480, 300)
74
+ self.preview_out.setStyleSheet("background:#121213; border:1px solid #2b2b2b; color:#ddd; border-radius:6px")
75
+ self.preview_out.setText("Output preview")
76
+
77
+ left_v.addWidget(self.preview_in)
78
+ left_v.addWidget(self.preview_out)
79
+
80
+ # Actions
81
+ actions_h = QHBoxLayout()
82
+ self.run_btn = QPushButton("Run — Process Image")
83
+ self.run_btn.clicked.connect(self.on_run)
84
+ self.open_out_btn = QPushButton("Open Output Folder")
85
+ self.open_out_btn.clicked.connect(self.open_output_folder)
86
+ actions_h.addWidget(self.run_btn)
87
+ actions_h.addWidget(self.open_out_btn)
88
+ left_v.addLayout(actions_h)
89
+
90
+ self.progress = QProgressBar()
91
+ self.progress.setTextVisible(True)
92
+ self.progress.setRange(0, 100)
93
+ self.progress.setValue(0)
94
+ left_v.addWidget(self.progress)
95
+
96
+ # Right: controls + analysis panels (with scroll area)
97
+ scroll_area = QScrollArea()
98
+ scroll_area.setWidgetResizable(True)
99
+ scroll_area.setStyleSheet("QScrollArea { border: none; }")
100
+ main_h.addWidget(scroll_area, 3)
101
+
102
+ scroll_widget = QWidget()
103
+ right_v = QVBoxLayout(scroll_widget)
104
+ scroll_area.setWidget(scroll_widget)
105
+
106
+ # Auto Mode toggle (keeps top-level quick switch visible)
107
+ self.auto_mode_chk = QCheckBox("Enable Auto Mode")
108
+ self.auto_mode_chk.setChecked(False)
109
+ self.auto_mode_chk.stateChanged.connect(self._on_auto_mode_toggled)
110
+ right_v.addWidget(self.auto_mode_chk)
111
+
112
+ # Make Auto Mode section collapsible
113
+ self.auto_box = CollapsibleBox("Auto Mode")
114
+ right_v.addWidget(self.auto_box)
115
+ auto_layout = QFormLayout()
116
+ auto_container = QWidget()
117
+ auto_container.setLayout(auto_layout)
118
+ self.auto_box.content_layout.addWidget(auto_container)
119
+
120
+ strength_layout = QHBoxLayout()
121
+ self.strength_slider = QSlider(Qt.Horizontal)
122
+ self.strength_slider.setRange(0, 100)
123
+ self.strength_slider.setValue(25)
124
+ self.strength_slider.valueChanged.connect(self._update_strength_label)
125
+ self.strength_label = QLabel("25")
126
+ self.strength_label.setFixedWidth(30)
127
+ strength_layout.addWidget(self.strength_slider)
128
+ strength_layout.addWidget(self.strength_label)
129
+
130
+ auto_layout.addRow("Aberration Strength", strength_layout)
131
+
132
+ # Parameters (Manual Mode) collapsible
133
+ self.params_box = CollapsibleBox("Parameters (Manual Mode)")
134
+ right_v.addWidget(self.params_box)
135
+ params_layout = QFormLayout()
136
+ params_container = QWidget()
137
+ params_container.setLayout(params_layout)
138
+ self.params_box.content_layout.addWidget(params_container)
139
+
140
+ # Noise-std
141
+ self.noise_spin = QDoubleSpinBox()
142
+ self.noise_spin.setRange(0.0, 0.1)
143
+ self.noise_spin.setSingleStep(0.001)
144
+ self.noise_spin.setValue(0.02)
145
+ self.noise_spin.setToolTip("Gaussian noise std fraction of 255")
146
+ params_layout.addRow("Noise std (0-0.1)", self.noise_spin)
147
+
148
+ # CLAHE-clip
149
+ self.clahe_spin = QDoubleSpinBox()
150
+ self.clahe_spin.setRange(0.1, 10.0)
151
+ self.clahe_spin.setSingleStep(0.1)
152
+ self.clahe_spin.setValue(2.0)
153
+ params_layout.addRow("CLAHE clip", self.clahe_spin)
154
+
155
+ # Tile
156
+ self.tile_spin = QSpinBox()
157
+ self.tile_spin.setRange(1, 64)
158
+ self.tile_spin.setValue(8)
159
+ params_layout.addRow("CLAHE tile", self.tile_spin)
160
+
161
+ # Cutoff
162
+ self.cutoff_spin = QDoubleSpinBox()
163
+ self.cutoff_spin.setRange(0.01, 1.0)
164
+ self.cutoff_spin.setSingleStep(0.01)
165
+ self.cutoff_spin.setValue(0.25)
166
+ params_layout.addRow("Fourier cutoff (0-1)", self.cutoff_spin)
167
+
168
+ # Fstrength
169
+ self.fstrength_spin = QDoubleSpinBox()
170
+ self.fstrength_spin.setRange(0.0, 1.0)
171
+ self.fstrength_spin.setSingleStep(0.01)
172
+ self.fstrength_spin.setValue(0.9)
173
+ params_layout.addRow("Fourier strength (0-1)", self.fstrength_spin)
174
+
175
+ # Randomness
176
+ self.randomness_spin = QDoubleSpinBox()
177
+ self.randomness_spin.setRange(0.0, 1.0)
178
+ self.randomness_spin.setSingleStep(0.01)
179
+ self.randomness_spin.setValue(0.05)
180
+ params_layout.addRow("Fourier randomness", self.randomness_spin)
181
+
182
+ # Phase_perturb
183
+ self.phase_perturb_spin = QDoubleSpinBox()
184
+ self.phase_perturb_spin.setRange(0.0, 1.0)
185
+ self.phase_perturb_spin.setSingleStep(0.001)
186
+ self.phase_perturb_spin.setValue(0.08)
187
+ self.phase_perturb_spin.setToolTip("Phase perturbation std (radians)")
188
+ params_layout.addRow("Phase perturb (rad)", self.phase_perturb_spin)
189
+
190
+ # Radial_smooth
191
+ self.radial_smooth_spin = QSpinBox()
192
+ self.radial_smooth_spin.setRange(0, 50)
193
+ self.radial_smooth_spin.setValue(5)
194
+ params_layout.addRow("Radial smooth (bins)", self.radial_smooth_spin)
195
+
196
+ # FFT_mode
197
+ self.fft_mode_combo = QComboBox()
198
+ self.fft_mode_combo.addItems(["auto", "ref", "model"])
199
+ self.fft_mode_combo.setCurrentText("auto")
200
+ params_layout.addRow("FFT mode", self.fft_mode_combo)
201
+
202
+ # FFT_alpha
203
+ self.fft_alpha_spin = QDoubleSpinBox()
204
+ self.fft_alpha_spin.setRange(0.1, 4.0)
205
+ self.fft_alpha_spin.setSingleStep(0.1)
206
+ self.fft_alpha_spin.setValue(1.0)
207
+ self.fft_alpha_spin.setToolTip("Alpha exponent for 1/f model when using model mode")
208
+ params_layout.addRow("FFT alpha (model)", self.fft_alpha_spin)
209
+
210
+ # Perturb
211
+ self.perturb_spin = QDoubleSpinBox()
212
+ self.perturb_spin.setRange(0.0, 0.05)
213
+ self.perturb_spin.setSingleStep(0.001)
214
+ self.perturb_spin.setValue(0.008)
215
+ params_layout.addRow("Pixel perturb", self.perturb_spin)
216
+
217
+ # Seed
218
+ self.seed_spin = QSpinBox()
219
+ self.seed_spin.setRange(0, 2 ** 31 - 1)
220
+ self.seed_spin.setValue(0)
221
+ params_layout.addRow("Seed (0=none)", self.seed_spin)
222
+
223
+ # AWB checkbox
224
+ self.awb_chk = QCheckBox("Enable auto white-balance (AWB)")
225
+ self.awb_chk.setChecked(False)
226
+ self.awb_chk.setToolTip("If checked, AWB is applied. If a reference image is chosen, it will be used; otherwise gray-world AWB is applied.")
227
+ params_layout.addRow(self.awb_chk)
228
+
229
+ # Camera simulator toggle
230
+ self.sim_camera_chk = QCheckBox("Enable camera pipeline simulation")
231
+ self.sim_camera_chk.setChecked(False)
232
+ self.sim_camera_chk.stateChanged.connect(self._on_sim_camera_toggled)
233
+ params_layout.addRow(self.sim_camera_chk)
234
+
235
+ # --- LUT support UI ---
236
+ self.lut_chk = QCheckBox("Enable LUT")
237
+ self.lut_chk.setChecked(False)
238
+ self.lut_chk.setToolTip("Enable applying a 1D/.npy/.cube LUT to the output image")
239
+ self.lut_chk.stateChanged.connect(self._on_lut_toggled)
240
+ params_layout.addRow(self.lut_chk)
241
+
242
+ # LUT chooser (hidden until checkbox checked)
243
+ self.lut_line = QLineEdit()
244
+ self.lut_btn = QPushButton("Choose LUT")
245
+ self.lut_btn.clicked.connect(self.choose_lut)
246
+ lut_box = QWidget()
247
+ lut_box_layout = QHBoxLayout()
248
+ lut_box_layout.setContentsMargins(0, 0, 0, 0)
249
+ lut_box.setLayout(lut_box_layout)
250
+ lut_box_layout.addWidget(self.lut_line)
251
+ lut_box_layout.addWidget(self.lut_btn)
252
+ self.lut_file_label = QLabel("LUT file (png/.npy/.cube)")
253
+ params_layout.addRow(self.lut_file_label, lut_box)
254
+
255
+ self.lut_strength_spin = QDoubleSpinBox()
256
+ self.lut_strength_spin.setRange(0.0, 1.0)
257
+ self.lut_strength_spin.setSingleStep(0.01)
258
+ self.lut_strength_spin.setValue(1.0)
259
+ self.lut_strength_spin.setToolTip("Blend strength for LUT (0.0 = no effect, 1.0 = full LUT)")
260
+ self.lut_strength_label = QLabel("LUT strength")
261
+ params_layout.addRow(self.lut_strength_label, self.lut_strength_spin)
262
+
263
+ # Initially hide LUT controls and their labels
264
+ self.lut_file_label.setVisible(False)
265
+ lut_box.setVisible(False)
266
+ self.lut_strength_label.setVisible(False)
267
+ self.lut_strength_spin.setVisible(False)
268
+
269
+ # Store all widgets that need their visibility toggled
270
+ self._lut_controls = (self.lut_file_label, lut_box, self.lut_strength_label, self.lut_strength_spin)
271
+
272
+ # Camera simulator collapsible group
273
+ self.camera_box = CollapsibleBox("Camera simulator options")
274
+ right_v.addWidget(self.camera_box)
275
+ cam_layout = QFormLayout()
276
+ cam_container = QWidget()
277
+ cam_container.setLayout(cam_layout)
278
+ self.camera_box.content_layout.addWidget(cam_container)
279
+
280
+ # Enable bayer
281
+ self.bayer_chk = QCheckBox("Enable Bayer / demosaic (RGGB)")
282
+ self.bayer_chk.setChecked(True)
283
+ cam_layout.addRow(self.bayer_chk)
284
+
285
+ # JPEG cycles
286
+ self.jpeg_cycles_spin = QSpinBox()
287
+ self.jpeg_cycles_spin.setRange(0, 10)
288
+ self.jpeg_cycles_spin.setValue(1)
289
+ cam_layout.addRow("JPEG cycles", self.jpeg_cycles_spin)
290
+
291
+ # JPEG quality min/max
292
+ self.jpeg_qmin_spin = QSpinBox()
293
+ self.jpeg_qmin_spin.setRange(1, 100)
294
+ self.jpeg_qmin_spin.setValue(88)
295
+ self.jpeg_qmax_spin = QSpinBox()
296
+ self.jpeg_qmax_spin.setRange(1, 100)
297
+ self.jpeg_qmax_spin.setValue(96)
298
+ qbox = QHBoxLayout()
299
+ qbox.addWidget(self.jpeg_qmin_spin)
300
+ qbox.addWidget(QLabel("to"))
301
+ qbox.addWidget(self.jpeg_qmax_spin)
302
+ cam_layout.addRow("JPEG quality (min to max)", qbox)
303
+
304
+ # Vignette strength
305
+ self.vignette_spin = QDoubleSpinBox()
306
+ self.vignette_spin.setRange(0.0, 1.0)
307
+ self.vignette_spin.setSingleStep(0.01)
308
+ self.vignette_spin.setValue(0.35)
309
+ cam_layout.addRow("Vignette strength", self.vignette_spin)
310
+
311
+ # Chromatic aberration strength
312
+ self.chroma_spin = QDoubleSpinBox()
313
+ self.chroma_spin.setRange(0.0, 10.0)
314
+ self.chroma_spin.setSingleStep(0.1)
315
+ self.chroma_spin.setValue(1.2)
316
+ cam_layout.addRow("Chromatic aberration (px)", self.chroma_spin)
317
+
318
+ # ISO scale
319
+ self.iso_spin = QDoubleSpinBox()
320
+ self.iso_spin.setRange(0.1, 16.0)
321
+ self.iso_spin.setSingleStep(0.1)
322
+ self.iso_spin.setValue(1.0)
323
+ cam_layout.addRow("ISO/exposure scale", self.iso_spin)
324
+
325
+ # Read noise
326
+ self.read_noise_spin = QDoubleSpinBox()
327
+ self.read_noise_spin.setRange(0.0, 50.0)
328
+ self.read_noise_spin.setSingleStep(0.1)
329
+ self.read_noise_spin.setValue(2.0)
330
+ cam_layout.addRow("Read noise (DN)", self.read_noise_spin)
331
+
332
+ # Hot pixel prob
333
+ self.hot_pixel_spin = QDoubleSpinBox()
334
+ self.hot_pixel_spin.setDecimals(9)
335
+ self.hot_pixel_spin.setRange(0.0, 1.0)
336
+ self.hot_pixel_spin.setSingleStep(1e-6)
337
+ self.hot_pixel_spin.setValue(1e-6)
338
+ cam_layout.addRow("Hot pixel prob", self.hot_pixel_spin)
339
+
340
+ # Banding strength
341
+ self.banding_spin = QDoubleSpinBox()
342
+ self.banding_spin.setRange(0.0, 1.0)
343
+ self.banding_spin.setSingleStep(0.01)
344
+ self.banding_spin.setValue(0.0)
345
+ cam_layout.addRow("Banding strength", self.banding_spin)
346
+
347
+ # Motion blur kernel
348
+ self.motion_blur_spin = QSpinBox()
349
+ self.motion_blur_spin.setRange(1, 51)
350
+ self.motion_blur_spin.setValue(1)
351
+ cam_layout.addRow("Motion blur kernel", self.motion_blur_spin)
352
+
353
+ self.camera_box.setVisible(False)
354
+
355
+ self.ref_hint = QLabel("AWB uses the 'AWB reference' chooser. FFT spectral matching uses the 'FFT Reference' chooser.")
356
+ right_v.addWidget(self.ref_hint)
357
+
358
+ self.analysis_input = AnalysisPanel(title="Input analysis")
359
+ self.analysis_output = AnalysisPanel(title="Output analysis")
360
+ right_v.addWidget(self.analysis_input)
361
+ right_v.addWidget(self.analysis_output)
362
+
363
+ right_v.addStretch(1)
364
+
365
+ # Status bar
366
+ self.status = QLabel("Ready")
367
+ self.status.setStyleSheet("color:#bdbdbd;padding:6px")
368
+ self.status.setAlignment(Qt.AlignLeft)
369
+ self.status.setFixedHeight(28)
370
+ self.status.setContentsMargins(6, 6, 6, 6)
371
+ self.statusBar().addWidget(self.status)
372
+
373
+ self.worker = None
374
+ self._on_auto_mode_toggled(self.auto_mode_chk.checkState())
375
+
376
+ def _on_sim_camera_toggled(self, state):
377
+ enabled = state == Qt.Checked
378
+ self.camera_box.setVisible(enabled)
379
+
380
+ def _on_auto_mode_toggled(self, state):
381
+ is_auto = (state == Qt.Checked)
382
+ self.auto_box.setVisible(is_auto)
383
+ self.params_box.setVisible(not is_auto)
384
+
385
+ def _update_strength_label(self, value):
386
+ self.strength_label.setText(str(value))
387
+
388
+ def choose_input(self):
389
+ path, _ = QFileDialog.getOpenFileName(self, "Choose input image", str(Path.home()), "Images (*.png *.jpg *.jpeg *.bmp *.tif)")
390
+ if path:
391
+ self.input_line.setText(path)
392
+ self.load_preview(self.preview_in, path)
393
+ self.analysis_input.update_from_path(path)
394
+ out_suggest = str(Path(path).with_name(Path(path).stem + "_out" + Path(path).suffix))
395
+ if not self.output_line.text():
396
+ self.output_line.setText(out_suggest)
397
+
398
+ def choose_ref(self):
399
+ path, _ = QFileDialog.getOpenFileName(self, "Choose AWB reference image", str(Path.home()), "Images (*.png *.jpg *.jpeg *.bmp *.tif)")
400
+ if path:
401
+ self.ref_line.setText(path)
402
+
403
+ def choose_fft_ref(self):
404
+ path, _ = QFileDialog.getOpenFileName(self, "Choose FFT reference image", str(Path.home()), "Images (*.png *.jpg *.jpeg *.bmp *.tif)")
405
+ if path:
406
+ self.fft_ref_line.setText(path)
407
+
408
+ def choose_output(self):
409
+ path, _ = QFileDialog.getSaveFileName(self, "Choose output path", str(Path.home()), "JPEG (*.jpg *.jpeg);;PNG (*.png);;TIFF (*.tif)")
410
+ if path:
411
+ self.output_line.setText(path)
412
+
413
+ def choose_lut(self):
414
+ path, _ = QFileDialog.getOpenFileName(self, "Choose LUT file", str(Path.home()), "LUTs (*.png *.npy *.cube);;All files (*)")
415
+ if path:
416
+ self.lut_line.setText(path)
417
+
418
+ def _on_lut_toggled(self, state):
419
+ visible = (state == Qt.Checked)
420
+ for w in self._lut_controls:
421
+ w.setVisible(visible)
422
+
423
+ def load_preview(self, widget: QLabel, path: str):
424
+ if not path or not os.path.exists(path):
425
+ widget.setText("No image")
426
+ widget.setPixmap(QPixmap())
427
+ return
428
+ pix = qpixmap_from_path(path, max_size=(widget.width(), widget.height()))
429
+ widget.setPixmap(pix)
430
+
431
+ def set_enabled_all(self, enabled: bool):
432
+ for w in self.findChildren((QPushButton, QDoubleSpinBox, QSpinBox, QLineEdit, QComboBox, QCheckBox, QSlider, QToolButton)):
433
+ w.setEnabled(enabled)
434
+
435
+ def on_run(self):
436
+ from types import SimpleNamespace
437
+ inpath = self.input_line.text().strip()
438
+ outpath = self.output_line.text().strip()
439
+ if not inpath or not os.path.exists(inpath):
440
+ QMessageBox.warning(self, "Missing input", "Please choose a valid input image.")
441
+ return
442
+ if not outpath:
443
+ QMessageBox.warning(self, "Missing output", "Please choose an output path.")
444
+ return
445
+
446
+ awb_ref_val = self.ref_line.text() or None
447
+ fft_ref_val = self.fft_ref_line.text() or None
448
+ args = SimpleNamespace()
449
+
450
+ if self.auto_mode_chk.isChecked():
451
+ strength = self.strength_slider.value() / 100.0
452
+ args.noise_std = strength * 0.04
453
+ args.clahe_clip = 1.0 + strength * 3.0
454
+ args.cutoff = max(0.01, 0.4 - strength * 0.3)
455
+ args.fstrength = strength * 0.95
456
+ args.phase_perturb = strength * 0.1
457
+ args.perturb = strength * 0.015
458
+ args.jpeg_cycles = int(strength * 2)
459
+ args.jpeg_qmin = max(1, int(95 - strength * 35))
460
+ args.jpeg_qmax = max(1, int(99 - strength * 25))
461
+ args.vignette_strength = strength * 0.6
462
+ args.chroma_strength = strength * 4.0
463
+ args.motion_blur_kernel = 1 + 2 * int(strength * 6)
464
+ args.banding_strength = strength * 0.1
465
+ args.tile = 8
466
+ args.randomness = 0.05
467
+ args.radial_smooth = 5
468
+ args.fft_mode = "auto"
469
+ args.fft_alpha = 1.0
470
+ args.alpha = 1.0
471
+ seed_val = int(self.seed_spin.value())
472
+ args.seed = None if seed_val == 0 else seed_val
473
+ args.sim_camera = bool(self.sim_camera_chk.isChecked())
474
+ args.no_no_bayer = True
475
+ args.iso_scale = 1.0
476
+ args.read_noise = 2.0
477
+ args.hot_pixel_prob = 1e-6
478
+ else:
479
+ seed_val = int(self.seed_spin.value())
480
+ args.seed = None if seed_val == 0 else seed_val
481
+ sim_camera = bool(self.sim_camera_chk.isChecked())
482
+ enable_bayer = bool(self.bayer_chk.isChecked())
483
+ args.noise_std = float(self.noise_spin.value())
484
+ args.clahe_clip = float(self.clahe_spin.value())
485
+ args.tile = int(self.tile_spin.value())
486
+ args.cutoff = float(self.cutoff_spin.value())
487
+ args.fstrength = float(self.fstrength_spin.value())
488
+ args.strength = float(self.fstrength_spin.value())
489
+ args.randomness = float(self.randomness_spin.value())
490
+ args.phase_perturb = float(self.phase_perturb_spin.value())
491
+ args.perturb = float(self.perturb_spin.value())
492
+ args.fft_mode = self.fft_mode_combo.currentText()
493
+ args.fft_alpha = float(self.fft_alpha_spin.value())
494
+ args.alpha = float(self.fft_alpha_spin.value())
495
+ args.radial_smooth = int(self.radial_smooth_spin.value())
496
+ args.sim_camera = sim_camera
497
+ args.no_no_bayer = bool(enable_bayer)
498
+ args.jpeg_cycles = int(self.jpeg_cycles_spin.value())
499
+ args.jpeg_qmin = int(self.jpeg_qmin_spin.value())
500
+ args.jpeg_qmax = int(self.jpeg_qmax_spin.value())
501
+ args.vignette_strength = float(self.vignette_spin.value())
502
+ args.chroma_strength = float(self.chroma_spin.value())
503
+ args.iso_scale = float(self.iso_spin.value())
504
+ args.read_noise = float(self.read_noise_spin.value())
505
+ args.hot_pixel_prob = float(self.hot_pixel_spin.value())
506
+ args.banding_strength = float(self.banding_spin.value())
507
+ args.motion_blur_kernel = int(self.motion_blur_spin.value())
508
+
509
+ # AWB handling to match the new --awb flag in the backend
510
+ if self.awb_chk.isChecked():
511
+ args.awb = True
512
+ args.ref = awb_ref_val # This can be the path or None (for grey-world)
513
+ else:
514
+ args.awb = False
515
+ args.ref = None
516
+
517
+ # FFT spectral matching reference
518
+ args.fft_ref = fft_ref_val
519
+
520
+ # LUT handling: only include if LUT checkbox is checked and a path is provided
521
+ if self.lut_chk.isChecked():
522
+ lut_path = self.lut_line.text().strip()
523
+ args.lut = lut_path if lut_path else None
524
+ args.lut_strength = float(self.lut_strength_spin.value())
525
+ else:
526
+ args.lut = None
527
+ args.lut_strength = 1.0
528
+
529
+ self.worker = Worker(inpath, outpath, args)
530
+ self.worker.finished.connect(self.on_finished)
531
+ self.worker.error.connect(self.on_error)
532
+ self.worker.started.connect(lambda: self.on_worker_started())
533
+ self.worker.start()
534
+
535
+ self.progress.setRange(0, 0)
536
+ self.status.setText("Processing...")
537
+ self.set_enabled_all(False)
538
+
539
+ def on_worker_started(self):
540
+ pass
541
+
542
+ def on_finished(self, outpath):
543
+ self.progress.setRange(0, 100)
544
+ self.progress.setValue(100)
545
+ self.status.setText("Done — saved to: " + outpath)
546
+ self.load_preview(self.preview_out, outpath)
547
+ self.analysis_output.update_from_path(outpath)
548
+ self.set_enabled_all(True)
549
+
550
+ def on_error(self, msg, traceback_text):
551
+ from PyQt5.QtWidgets import QDialog, QTextEdit
552
+ self.progress.setRange(0, 100)
553
+ self.progress.setValue(0)
554
+ self.status.setText("Error")
555
+
556
+ dialog = QDialog(self)
557
+ dialog.setWindowTitle("Processing Error")
558
+ dialog.setMinimumSize(700, 480)
559
+ layout = QVBoxLayout(dialog)
560
+
561
+ error_label = QLabel(f"Error: {msg}")
562
+ error_label.setWordWrap(True)
563
+ layout.addWidget(error_label)
564
+
565
+ traceback_edit = QTextEdit()
566
+ traceback_edit.setReadOnly(True)
567
+ traceback_edit.setText(traceback_text)
568
+ traceback_edit.setStyleSheet("font-family: monospace; font-size: 12px;")
569
+ layout.addWidget(traceback_edit)
570
+
571
+ ok_button = QPushButton("OK")
572
+ ok_button.clicked.connect(dialog.accept)
573
+ layout.addWidget(ok_button)
574
+
575
+ dialog.exec_()
576
+ self.set_enabled_all(True)
577
+
578
+ def open_output_folder(self):
579
+ out = self.output_line.text().strip()
580
+ if not out:
581
+ QMessageBox.information(self, "No output", "No output path set yet.")
582
+ return
583
+ folder = os.path.dirname(os.path.abspath(out))
584
+ if not os.path.exists(folder):
585
+ QMessageBox.warning(self, "Not found", "Output folder does not exist: " + folder)
586
+ return
587
+ if sys.platform.startswith('darwin'):
588
+ os.system(f'open "{folder}"')
589
+ elif os.name == 'nt':
590
+ os.startfile(folder)
591
+ else:
592
+ os.system(f'xdg-open "{folder}"')