"""
utils.py

Helper functions for image postprocessing, including EXIF removal, noise
addition, color correction, and Fourier spectrum matching.
"""
from typing import Optional

import numpy as np
from scipy.ndimage import gaussian_filter1d

# Pillow is only needed by a few helpers (EXIF stripping, the CLAHE fallback
# and reference resizing), so it is guarded the same way as OpenCV below.
try:
    from PIL import Image, ImageOps
    _HAS_PIL = True
except ImportError:
    Image = None
    ImageOps = None
    _HAS_PIL = False

try:
    import cv2
    _HAS_CV2 = True
except Exception:
    cv2 = None
    _HAS_CV2 = False

# Width (in normalised radius units) of the raised-cosine roll-off that fades
# the spectral correction out just above the cutoff radius.
_ROLLOFF_WIDTH = 0.05


def _require_pil() -> None:
    """Raise ImportError when a Pillow-dependent code path is hit without Pillow."""
    if not _HAS_PIL:
        raise ImportError("Pillow (PIL) is required for this operation")


def _legacy_rng(seed):
    """Return a random source that does not reseed numpy's global RNG.

    With a seed, a private ``RandomState`` is returned; it produces the same
    stream as ``np.random.seed(seed)`` followed by ``np.random.<dist>`` calls.
    Without a seed, the ``np.random`` module (global state) is used.
    """
    return np.random if seed is None else np.random.RandomState(seed)


def remove_exif_pil(img: "Image.Image") -> "Image.Image":
    """Return a copy of *img* rebuilt from its raw pixel bytes.

    The new image is created from mode, size and pixel data only, so nothing
    stored in the original object's metadata is carried over. For palette
    images the palette is copied as well, because the pixel bytes are only
    indices into it.
    """
    _require_pil()
    new = Image.frombytes(img.mode, img.size, img.tobytes())
    if img.mode in ("P", "PA"):
        palette = img.getpalette()
        if palette is not None:
            new.putpalette(palette)
    return new


def add_gaussian_noise(img_arr: np.ndarray, std_frac=0.02, seed=None) -> np.ndarray:
    """Add zero-mean Gaussian noise and return a clipped ``uint8`` array.

    Args:
        img_arr: Input array; values are treated as being on a 0-255 scale.
        std_frac: Noise standard deviation as a fraction of 255.
        seed: Optional seed. When given, a private generator is used so the
            global numpy RNG state is left untouched.

    Returns:
        Array of the same shape as *img_arr*, dtype ``uint8``.
    """
    rng = _legacy_rng(seed)
    std = std_frac * 255.0
    noise = rng.normal(loc=0.0, scale=std, size=img_arr.shape)
    out = img_arr.astype(np.float32) + noise
    return np.clip(out, 0, 255).astype(np.uint8)


def clahe_color_correction(img_arr: np.ndarray, clip_limit=2.0, tile_grid_size=(8, 8)) -> np.ndarray:
    """Apply contrast equalisation to an RGB image array.

    With OpenCV available, CLAHE is applied to the L channel in LAB space and
    the result is converted back to RGB. Without OpenCV, each channel is
    histogram-equalised independently with Pillow and merged back as RGB.

    Args:
        img_arr: RGB image array (presumably ``uint8`` with 3 channels, which
            is what both code paths expect -- confirm against callers).
        clip_limit: CLAHE clip limit (OpenCV path only).
        tile_grid_size: CLAHE tile grid size (OpenCV path only).

    Returns:
        The corrected image as a numpy array.
    """
    if _HAS_CV2:
        lab = cv2.cvtColor(img_arr, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
        lab2 = cv2.merge((clahe.apply(l), a, b))
        return cv2.cvtColor(lab2, cv2.COLOR_LAB2RGB)

    _require_pil()
    pil = Image.fromarray(img_arr)
    equalized = [ImageOps.equalize(ch) for ch in pil.split()]
    return np.array(Image.merge('RGB', equalized))


def randomized_perturbation(img_arr: np.ndarray, magnitude_frac=0.008, seed=None) -> np.ndarray:
    """Add uniform noise in ``[-mag, mag]`` and return a clipped ``uint8`` array.

    Args:
        img_arr: Input array; values are treated as being on a 0-255 scale.
        magnitude_frac: Half-width of the uniform noise as a fraction of 255.
        seed: Optional seed. When given, a private generator is used so the
            global numpy RNG state is left untouched.

    Returns:
        Array of the same shape as *img_arr*, dtype ``uint8``.
    """
    rng = _legacy_rng(seed)
    mag = magnitude_frac * 255.0
    perturb = rng.uniform(low=-mag, high=mag, size=img_arr.shape)
    out = img_arr.astype(np.float32) + perturb
    return np.clip(out, 0, 255).astype(np.uint8)


def radial_profile(mag: np.ndarray, center=None, nbins=None):
    """Compute the radially averaged profile of a 2-D array.

    Radii are measured from *center* and normalised by the largest radius
    present in the grid, then split into *nbins* equal-width bins on [0, 1].

    Args:
        mag: 2-D array (e.g. a shifted FFT magnitude).
        center: ``(cy, cx)`` origin; defaults to ``(h // 2, w // 2)``.
        nbins: Number of radial bins; defaults to ``max(h, w) / 2`` and is
            forced to be at least 1.

    Returns:
        ``(bin_centers, radial_mean)``: the normalised bin centres and the
        mean of *mag* inside each bin (0.0 for empty bins).
    """
    h, w = mag.shape
    if center is None:
        cy, cx = h // 2, w // 2
    else:
        cy, cx = center
    if nbins is None:
        nbins = int(max(h, w) / 2)
    nbins = max(1, int(nbins))

    y = np.arange(h) - cy
    x = np.arange(w) - cx
    X, Y = np.meshgrid(x, y)
    R = np.sqrt(X * X + Y * Y)
    Rmax = R.max()
    if Rmax <= 0:
        Rnorm = R
    else:
        Rnorm = R / (Rmax + 1e-12)
    # Keep the outermost pixel strictly below 1.0 so it lands in the last bin.
    Rnorm = np.minimum(Rnorm, 1.0 - 1e-12)

    bin_edges = np.linspace(0.0, 1.0, nbins + 1)
    bin_idx = np.digitize(Rnorm.ravel(), bin_edges) - 1
    bin_idx = np.clip(bin_idx, 0, nbins - 1)

    sums = np.bincount(bin_idx, weights=mag.ravel(), minlength=nbins)
    counts = np.bincount(bin_idx, minlength=nbins)
    radial_mean = np.zeros(nbins, dtype=np.float64)
    nonzero = counts > 0
    radial_mean[nonzero] = sums[nonzero] / counts[nonzero]

    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    return bin_centers, radial_mean


def fourier_match_spectrum(img_arr: np.ndarray,
                           ref_img_arr: Optional[np.ndarray] = None,
                           mode='auto',
                           alpha=1.0,
                           cutoff=0.25,
                           strength=0.9,
                           randomness=0.05,
                           phase_perturb=0.08,
                           radial_smooth=5,
                           seed=None):
    """Reshape the radial Fourier magnitude spectrum of an image.

    Each channel is transformed with a 2-D FFT, its radially averaged
    magnitude is compared with a target profile, and the resulting per-radius
    gain (clipped to [0.2, 5.0]) is applied inside the low-frequency band
    ``r <= cutoff`` with a short raised-cosine roll-off above it. Optional
    multiplicative magnitude noise is applied inside that band and optional
    phase noise grows linearly above it. The inverse transform is blended
    with the original channel using *strength*.

    Args:
        img_arr: Source image, either ``(H, W, C)`` or ``(H, W)``.
        ref_img_arr: Optional reference image whose radial spectrum is the
            target in ``'ref'`` mode. It is resized to the source size with
            Pillow when the shapes differ.
        mode: ``'ref'``, ``'model'`` or ``'auto'`` (``'ref'`` when a
            reference is supplied, otherwise ``'model'``). Any other
            combination leaves the magnitude profile unchanged.
        alpha: Exponent of the power-law model; the model amplitude is
            ``(1 / f) ** (alpha / 2)``.
        cutoff: Normalised radius (0-1) up to which the spectral gain is
            applied at full weight; also where phase noise starts ramping up.
            Values outside [0, 1] are clamped.
        strength: Scales the gain inside the band and is the blend factor
            between the original and the reconstructed channel.
        randomness: Standard deviation of the multiplicative magnitude noise;
            falsy or non-positive disables it.
        phase_perturb: Maximum phase-noise standard deviation (radians);
            falsy or non-positive disables it.
        radial_smooth: Gaussian sigma (in bins, minimum 1) used to smooth the
            radial profiles.
        seed: Optional seed for the internal ``numpy`` generator.

    Returns:
        ``uint8`` array with the same shape as *img_arr*.
    """
    rng = np.random.default_rng(seed)

    # Treat grayscale input as a single-channel image and squeeze at the end.
    is_gray = img_arr.ndim == 2
    work = img_arr[:, :, None] if is_gray else img_arr

    h, w = work.shape[:2]
    center = (h // 2, w // 2)
    nbins = max(8, int(max(h, w) / 2))
    sigma = max(1, radial_smooth)
    cutoff = float(np.clip(cutoff, 0.0, 1.0))

    if mode == 'auto':
        mode = 'ref' if ref_img_arr is not None else 'model'

    # Power-law target profile, normalised by its low-frequency median.
    model_radial = None
    if mode == 'model':
        model_bins = np.linspace(0.0, 1.0, nbins)
        model_radial = (1.0 / (model_bins + 1e-8)) ** (alpha / 2.0)
        lf = max(1, nbins // 8)
        model_radial = model_radial / (np.median(model_radial[:lf]) + 1e-12)
        model_radial = gaussian_filter1d(model_radial, sigma=sigma)

    # Target profile measured from the (grayscale) reference image.
    ref_radial = None
    ref_bin_centers = None
    if mode == 'ref' and ref_img_arr is not None:
        if ref_img_arr.shape[0] != h or ref_img_arr.shape[1] != w:
            _require_pil()
            ref_img = Image.fromarray(ref_img_arr).resize((w, h), resample=Image.BICUBIC)
            ref_img_arr = np.array(ref_img)
        if ref_img_arr.ndim == 3:
            ref_gray = np.mean(ref_img_arr.astype(np.float32), axis=2)
        else:
            ref_gray = ref_img_arr.astype(np.float32)
        Mref = np.abs(np.fft.fftshift(np.fft.fft2(ref_gray)))
        ref_bin_centers, ref_radial = radial_profile(Mref, center=center, nbins=nbins)
        ref_radial = gaussian_filter1d(ref_radial, sigma=sigma)

    # Normalised radius of every frequency sample in the shifted spectrum.
    # NOTE(review): radial_profile() normalises by the corner distance while
    # this grid reaches 1.0 at the edge midpoints, so the two radius scales
    # differ slightly -- left as-is to preserve existing output.
    y = np.linspace(-1, 1, h, endpoint=False)[:, None]
    x = np.linspace(-1, 1, w, endpoint=False)[None, :]
    r = np.clip(np.sqrt(x * x + y * y), 0.0, 1.0 - 1e-6)

    # Channel-independent masks, built once: full weight inside the cutoff,
    # raised-cosine roll-off over _ROLLOFF_WIDTH, zero beyond.
    edge = _ROLLOFF_WIDTH
    weight = np.where(
        r <= cutoff,
        1.0,
        np.where(r <= cutoff + edge,
                 0.5 * (1 + np.cos(np.pi * (r - cutoff) / edge)),
                 0.0),
    )
    use_randomness = bool(randomness and randomness > 0.0)
    use_phase = bool(phase_perturb and phase_perturb > 0.0)
    phase_sigma = None
    if use_phase:
        # Phase noise is zero up to the cutoff and grows linearly to
        # phase_perturb at the highest frequencies.
        phase_sigma = phase_perturb * np.clip((r - cutoff) / (1.0 - cutoff + 1e-6), 0.0, 1.0)

    out = np.zeros_like(work, dtype=np.float32)
    eps = 1e-8
    for c in range(work.shape[2]):
        channel = work[:, :, c].astype(np.float32)
        Fshift = np.fft.fftshift(np.fft.fft2(channel))
        mag = np.abs(Fshift)
        phase = np.angle(Fshift)

        bin_centers_src, src_radial = radial_profile(mag, center=center, nbins=nbins)
        src_radial = gaussian_filter1d(src_radial, sigma=sigma)

        if mode == 'ref' and ref_radial is not None:
            ref_interp = np.interp(bin_centers_src, ref_bin_centers, ref_radial)
            ratio = (ref_interp + eps) / (src_radial + eps)
            desired_radial = src_radial * ratio
        elif mode == 'model' and model_radial is not None:
            # Rescale the model so its low-frequency level matches the source.
            lf = max(1, nbins // 8)
            scale = (np.median(src_radial[:lf]) + 1e-12) / (np.median(model_radial[:lf]) + 1e-12)
            desired_radial = model_radial * scale
        else:
            desired_radial = src_radial.copy()

        multiplier_1d = (desired_radial + eps) / (src_radial + eps)
        multiplier_1d = np.clip(multiplier_1d, 0.2, 5.0)
        mult_2d = np.interp(r.ravel(), bin_centers_src, multiplier_1d).reshape(h, w)

        final_multiplier = 1.0 + (mult_2d - 1.0) * (weight * strength)
        if use_randomness:
            noise = rng.normal(loc=1.0, scale=randomness, size=final_multiplier.shape)
            final_multiplier *= (1.0 + (noise - 1.0) * weight)
        mag2 = mag * final_multiplier

        if use_phase:
            phase_noise = rng.standard_normal(size=phase_sigma.shape) * phase_sigma
            phase2 = phase + phase_noise
        else:
            phase2 = phase

        Fshift2 = mag2 * np.exp(1j * phase2)
        img_back = np.real(np.fft.ifft2(np.fft.ifftshift(Fshift2)))
        out[:, :, c] = (1.0 - strength) * channel + strength * img_back

    out = np.clip(out, 0, 255).astype(np.uint8)
    return out[:, :, 0] if is_gray else out