|
|
""" |
|
|
utils.py |
|
|
|
|
|
Helper functions for image postprocessing, including EXIF removal, noise addition, |
|
|
color correction, and Fourier spectrum matching. |
|
|
""" |
|
|
import os |
|
|
import re |
|
|
from PIL import Image, ImageOps |
|
|
import numpy as np |
|
|
try: |
|
|
import cv2 |
|
|
_HAS_CV2 = True |
|
|
except Exception: |
|
|
cv2 = None |
|
|
_HAS_CV2 = False |
|
|
from scipy.ndimage import gaussian_filter1d |
|
|
|
|
|
def remove_exif_pil(img: Image.Image) -> Image.Image:
    """Return a copy of *img* rebuilt from its raw pixel bytes only.

    The new image is created with ``Image.frombytes`` from the mode, size and
    pixel bytes of *img*; nothing else is copied explicitly, which is how
    EXIF and similar metadata get dropped.

    Args:
        img: Source PIL image.

    Returns:
        A new image with the same mode, size and pixel data as *img*.
    """
    stripped = Image.frombytes(img.mode, img.size, img.tobytes())

    # For palette images the raw bytes are only palette indices. Without
    # re-attaching the palette the copy would be rendered with wrong colours.
    # NOTE(review): palette transparency lives in ``img.info`` and is
    # intentionally not carried over together with the other metadata.
    if img.mode in ("P", "PA"):
        palette = img.getpalette()
        if palette is not None:
            stripped.putpalette(palette)

    return stripped
|
|
|
|
|
def add_gaussian_noise(img_arr: np.ndarray, std_frac=0.02, seed=None) -> np.ndarray:
    """Add zero-mean Gaussian noise to an image and return it as uint8.

    Args:
        img_arr: Image array; values are treated as being on a 0..255 scale.
        std_frac: Noise standard deviation expressed as a fraction of 255.
        seed: Optional seed. When given, the noise is drawn from a private
            ``np.random.RandomState(seed)`` so the result is reproducible and
            the process-wide NumPy random state is left untouched. When
            ``None``, the global NumPy generator is used.

    Returns:
        Array with the shape of *img_arr*, clipped to [0, 255] and cast to
        uint8.
    """
    # RandomState(seed) produces the same stream as np.random.seed(seed)
    # followed by np.random.normal(...), so seeded outputs are unchanged, but
    # other users of the global generator are no longer silently reseeded.
    rng = np.random if seed is None else np.random.RandomState(seed)

    sigma = std_frac * 255.0
    noise = rng.normal(loc=0.0, scale=sigma, size=img_arr.shape)
    noisy = img_arr.astype(np.float32) + noise
    return np.clip(noisy, 0, 255).astype(np.uint8)
|
|
|
|
|
def clahe_color_correction(img_arr: np.ndarray, clip_limit=2.0, tile_grid_size=(8,8)) -> np.ndarray:
    """Boost local contrast of an RGB image.

    With OpenCV available, CLAHE is applied to the L channel of the LAB
    representation and the image is converted back to RGB. Without OpenCV,
    every channel is histogram-equalised independently with PIL.

    Args:
        img_arr: RGB image array.
        clip_limit: CLAHE clip limit (used only on the OpenCV path).
        tile_grid_size: CLAHE tile grid size (used only on the OpenCV path).

    Returns:
        The contrast-corrected RGB image as an array.
    """
    if not _HAS_CV2:
        # Fallback: plain per-channel histogram equalisation.
        planes = Image.fromarray(img_arr).split()
        equalized = [ImageOps.equalize(plane) for plane in planes]
        return np.array(Image.merge('RGB', equalized))

    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(img_arr, cv2.COLOR_RGB2LAB))
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    # Only lightness is equalised; the two chroma planes pass through as-is.
    corrected_lab = cv2.merge((equalizer.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(corrected_lab, cv2.COLOR_LAB2RGB)
|
|
|
|
|
def randomized_perturbation(img_arr: np.ndarray, magnitude_frac=0.008, seed=None) -> np.ndarray:
    """Add small uniform random offsets to every sample of an image.

    Args:
        img_arr: Image array; values are treated as being on a 0..255 scale.
        magnitude_frac: Maximum absolute offset as a fraction of 255; offsets
            are drawn uniformly from ``[-mag, mag)``.
        seed: Optional seed. When given, offsets come from a private
            ``np.random.RandomState(seed)`` so the result is reproducible and
            the process-wide NumPy random state is left untouched. When
            ``None``, the global NumPy generator is used.

    Returns:
        Array with the shape of *img_arr*, clipped to [0, 255] and cast to
        uint8.
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed) plus
    # np.random.uniform(...), so seeded outputs are unchanged while the
    # global generator is no longer reseeded as a side effect.
    rng = np.random if seed is None else np.random.RandomState(seed)

    mag = magnitude_frac * 255.0
    offsets = rng.uniform(low=-mag, high=mag, size=img_arr.shape)
    perturbed = img_arr.astype(np.float32) + offsets
    return np.clip(perturbed, 0, 255).astype(np.uint8)
|
|
|
|
|
def radial_profile(mag: np.ndarray, center=None, nbins=None):
    """Average a 2-D array over concentric rings around a centre point.

    Distances from the centre are normalised by the largest distance present
    in the array and split into ``nbins`` equal-width bins on [0, 1].

    Args:
        mag: 2-D array of values to average.
        center: ``(row, col)`` of the ring centre; defaults to
            ``(h // 2, w // 2)``.
        nbins: Number of radial bins; defaults to ``max(h, w) / 2`` and is
            never less than 1.

    Returns:
        ``(bin_centers, radial_mean)``: normalised bin centres and the mean
        of *mag* inside each bin (0 for bins that contain no samples).
    """
    rows, cols = mag.shape
    cy, cx = (rows // 2, cols // 2) if center is None else center

    if nbins is None:
        nbins = int(max(rows, cols) / 2)
    nbins = max(1, int(nbins))

    # Distance of every sample from the centre, built by broadcasting a
    # column of row offsets against a row of column offsets.
    dy = (np.arange(rows) - cy)[:, None]
    dx = (np.arange(cols) - cx)[None, :]
    dist = np.sqrt(dx * dx + dy * dy)

    farthest = dist.max()
    if farthest <= 0:
        # Degenerate case (single sample at the centre): nothing to scale.
        scaled = dist
    else:
        # Keep values strictly below 1 so the farthest sample falls in the
        # last bin rather than past it.
        scaled = np.minimum(dist / (farthest + 1e-12), 1.0 - 1e-12)

    edges = np.linspace(0.0, 1.0, nbins + 1)
    labels = np.clip(np.digitize(scaled.ravel(), edges) - 1, 0, nbins - 1)

    totals = np.bincount(labels, weights=mag.ravel(), minlength=nbins)
    hits = np.bincount(labels, minlength=nbins)

    # Empty bins keep the initial 0 instead of producing a division by zero.
    profile = np.zeros(nbins, dtype=np.float64)
    np.divide(totals, hits, out=profile, where=hits > 0)

    centers = (edges[:-1] + edges[1:]) / 2.0
    return centers, profile
|
|
|
|
|
def fourier_match_spectrum(img_arr: np.ndarray,
                           ref_img_arr: np.ndarray = None,
                           mode='auto',
                           alpha=1.0,
                           cutoff=0.25,
                           strength=0.9,
                           randomness=0.05,
                           phase_perturb=0.08,
                           radial_smooth=5,
                           seed=None):
    """Reshape the radial Fourier magnitude spectrum of an image.

    Every channel is transformed with a 2-D FFT, its radially averaged
    magnitude is compared with a target profile, and the magnitude is scaled
    towards that target inside a frequency mask before transforming back and
    blending with the original channel.

    Args:
        img_arr: Image array, HxWxC or HxW (grayscale).
        ref_img_arr: Optional reference image whose radial spectrum is the
            target in ``'ref'`` mode. It is resized to the size of *img_arr*
            when the sizes differ.
        mode: ``'ref'``, ``'model'`` or ``'auto'`` (``'ref'`` when a reference
            is supplied, otherwise ``'model'``). Any other value, or ``'ref'``
            without a reference, leaves the magnitude profile unchanged.
        alpha: Exponent of the ``1 / f**(alpha / 2)`` target used in
            ``'model'`` mode.
        cutoff: Normalised radius (0..1) up to which the magnitude correction
            is applied at full weight; beyond it the weight rolls off with a
            raised cosine of width ``0.05 + 0.02 * (1 - cutoff)``. Phase noise
            grows linearly from this radius outwards. ``None`` means 0.25.
        strength: Blend factor, used both for the magnitude multiplier and
            for mixing the filtered channel with the original.
        randomness: Standard deviation of multiplicative noise applied to the
            multiplier inside the mask; 0 or ``None`` disables it.
        phase_perturb: Scale of the Gaussian phase noise; 0 or ``None``
            disables it.
        radial_smooth: Sigma (at least 1) of the Gaussian smoothing applied to
            radial profiles.
        seed: Seed for the private random generator.

    Returns:
        uint8 array with the same shape as *img_arr*.
    """
    rng = np.random.default_rng(seed)

    h, w = img_arr.shape[:2]
    nbins = max(8, int(max(h, w) / 2))
    sigma = max(1, radial_smooth)
    eps = 1e-8

    # The parameter used to be ignored (thresholds were hard-coded to 0.25),
    # so tolerate callers that pass None and keep the value inside [0, 1].
    cutoff = 0.25 if cutoff is None else float(np.clip(cutoff, 0.0, 1.0))

    if mode == 'auto':
        mode = 'ref' if ref_img_arr is not None else 'model'

    model_radial = None
    if mode == 'model':
        model_bins = np.linspace(0.0, 1.0, nbins)
        model_radial = (1.0 / (model_bins + eps)) ** (alpha / 2.0)
        lf = max(1, nbins // 8)
        # Normalise by the low-frequency median so only the shape matters.
        model_radial = model_radial / (np.median(model_radial[:lf]) + 1e-12)
        model_radial = gaussian_filter1d(model_radial, sigma=sigma)

    ref_radial = None
    ref_bin_centers = None
    if mode == 'ref' and ref_img_arr is not None:
        if ref_img_arr.shape[0] != h or ref_img_arr.shape[1] != w:
            ref_img = Image.fromarray(ref_img_arr).resize((w, h), resample=Image.BICUBIC)
            ref_img_arr = np.array(ref_img)
        if ref_img_arr.ndim == 3:
            ref_gray = np.mean(ref_img_arr.astype(np.float32), axis=2)
        else:
            ref_gray = ref_img_arr.astype(np.float32)
        Mref = np.abs(np.fft.fftshift(np.fft.fft2(ref_gray)))
        ref_bin_centers, ref_radial = radial_profile(Mref, center=(h // 2, w // 2), nbins=nbins)
        ref_radial = gaussian_filter1d(ref_radial, sigma=sigma)

    # Treat a grayscale image as a single-channel stack; squeezed on return.
    is_gray = img_arr.ndim == 2
    channels = img_arr[:, :, None] if is_gray else img_arr
    out = np.zeros(channels.shape, dtype=np.float32)

    # Normalised radius of every frequency sample (shifted spectrum layout).
    # NOTE(review): this radius is 1 at the edge along an axis, whereas
    # radial_profile normalises by the largest (corner) distance, so the two
    # scales differ; kept as-is to preserve existing behaviour.
    y = np.linspace(-1, 1, h, endpoint=False)[:, None]
    x = np.linspace(-1, 1, w, endpoint=False)[None, :]
    r = np.sqrt(x * x + y * y)
    r = np.clip(r, 0.0, 1.0 - 1e-6)

    # Channel-independent masks, computed once.
    edge = max(0.05 + 0.02 * (1.0 - cutoff), 1e-6)
    weight = np.where(r <= cutoff, 1.0,
                      np.where(r <= cutoff + edge,
                               0.5 * (1 + np.cos(np.pi * (r - cutoff) / edge)),
                               0.0))
    use_noise = bool(randomness and randomness > 0.0)
    use_phase = bool(phase_perturb and phase_perturb > 0.0)
    if use_phase:
        phase_sigma = phase_perturb * np.clip((r - cutoff) / (1.0 - cutoff + 1e-6), 0.0, 1.0)

    for c in range(channels.shape[2]):
        channel = channels[:, :, c].astype(np.float32)
        Fshift = np.fft.fftshift(np.fft.fft2(channel))
        mag = np.abs(Fshift)
        phase = np.angle(Fshift)

        bin_centers_src, src_radial = radial_profile(mag, center=(h // 2, w // 2), nbins=nbins)
        src_radial = gaussian_filter1d(src_radial, sigma=sigma)

        if mode == 'ref' and ref_radial is not None:
            ref_interp = np.interp(bin_centers_src, ref_bin_centers, ref_radial)
            ratio = (ref_interp + eps) / (src_radial + eps)
            desired_radial = src_radial * ratio
        elif mode == 'model' and model_radial is not None:
            lf = max(1, nbins // 8)
            # Scale the model so its low-frequency level matches the source.
            scale = (np.median(src_radial[:lf]) + 1e-12) / (np.median(model_radial[:lf]) + 1e-12)
            desired_radial = model_radial * scale
        else:
            desired_radial = src_radial.copy()

        multiplier_1d = (desired_radial + eps) / (src_radial + eps)
        # Bound the correction so no band is changed by more than 5x.
        multiplier_1d = np.clip(multiplier_1d, 0.2, 5.0)
        mult_2d = np.interp(r.ravel(), bin_centers_src, multiplier_1d).reshape(h, w)

        final_multiplier = 1.0 + (mult_2d - 1.0) * (weight * strength)

        if use_noise:
            noise = rng.normal(loc=1.0, scale=randomness, size=final_multiplier.shape)
            final_multiplier *= (1.0 + (noise - 1.0) * weight)

        mag2 = mag * final_multiplier

        if use_phase:
            phase2 = phase + rng.standard_normal(size=phase_sigma.shape) * phase_sigma
        else:
            phase2 = phase

        Fshift2 = mag2 * np.exp(1j * phase2)
        # The modified spectrum is generally not Hermitian; keep the real part.
        img_back = np.real(np.fft.ifft2(np.fft.ifftshift(Fshift2)))

        out[:, :, c] = (1.0 - strength) * channel + strength * img_back

    out = np.clip(out, 0, 255).astype(np.uint8)
    return out[:, :, 0] if is_gray else out
|
|
|
|
|
def auto_white_balance_ref(img_arr: np.ndarray, ref_img_arr: np.ndarray = None) -> np.ndarray:
    """
    White-balance an image by matching its per-channel means to a target.

    The target is the per-channel mean of ref_img_arr when one is given,
    otherwise a neutral mid-gray of (128, 128, 128). Each channel is
    multiplied by target_mean / image_mean, then the result is clipped to
    [0, 255] and cast to uint8. Both arrays are flattened to rows of three
    values, i.e. three-channel data is expected.
    """
    pixels = img_arr.astype(np.float32)

    if ref_img_arr is None:
        target = np.array([128.0, 128.0, 128.0], dtype=np.float32)
    else:
        target = ref_img_arr.astype(np.float32).reshape(-1, 3).mean(axis=0)

    current = pixels.reshape(-1, 3).mean(axis=0)

    # The small offset keeps the gain finite for an all-zero channel.
    tiny = 1e-6
    gains = (target + tiny) / (current + tiny)

    return np.clip(pixels * gains, 0, 255).astype(np.uint8)
|
|
|
|
|
def apply_1d_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
    """
    Map the channels of an image through a 1D lookup table.

    - img_arr: HxWx3 image (values on a 0..255 scale)
    - lut: shape (N,) to use one curve for all channels, or (N,3) for one
      curve per channel; the N entries are spread evenly over [0..255] and
      values in between are linearly interpolated
    - strength: values >= 1 return the LUT result; smaller values blend
      linearly between the original image and the LUT result
    Returns a uint8 array. Raises ValueError for a non-HxWx3 image or a LUT
    whose second dimension is not 3.
    """
    if not (img_arr.ndim == 3 and img_arr.shape[2] == 3):
        raise ValueError("apply_1d_lut expects an HxWx3 image array")

    source = img_arr.astype(np.float32)
    table = np.array(lut, dtype=np.float32)

    # A single curve is replicated so every channel has its own column.
    if table.ndim == 1:
        table = np.stack([table] * 3, axis=1)

    if table.shape[1] != 3:
        raise ValueError("1D LUT must have shape (N,) or (N,3)")

    knots = np.linspace(0, 255, table.shape[0])

    mapped = np.empty_like(source)
    for c in range(3):
        mapped[..., c] = np.interp(source[..., c], knots, table[:, c])

    quantized = np.clip(mapped, 0, 255).astype(np.uint8)
    if strength >= 1.0:
        return quantized

    mix = (1.0 - strength) * source + strength * quantized.astype(np.float32)
    return np.clip(mix, 0, 255).astype(np.uint8)
|
|
|
|
|
def _trilinear_sample_lut(img_float: np.ndarray, lut: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Vectorized trilinear sampling of 3D LUT. |
|
|
- img_float: HxWx3 floats in [0,1] |
|
|
- lut: SxSxS x 3 floats in [0,1] |
|
|
Returns HxWx3 floats in [0,1] |
|
|
""" |
|
|
S = lut.shape[0] |
|
|
if lut.shape[0] != lut.shape[1] or lut.shape[1] != lut.shape[2]: |
|
|
raise ValueError("3D LUT must be cubic (SxSxSx3)") |
|
|
|
|
|
|
|
|
idx = img_float * (S - 1) |
|
|
r_idx = idx[..., 0] |
|
|
g_idx = idx[..., 1] |
|
|
b_idx = idx[..., 2] |
|
|
|
|
|
r0 = np.floor(r_idx).astype(np.int32) |
|
|
g0 = np.floor(g_idx).astype(np.int32) |
|
|
b0 = np.floor(b_idx).astype(np.int32) |
|
|
|
|
|
r1 = np.clip(r0 + 1, 0, S - 1) |
|
|
g1 = np.clip(g0 + 1, 0, S - 1) |
|
|
b1 = np.clip(b0 + 1, 0, S - 1) |
|
|
|
|
|
dr = (r_idx - r0)[..., None] |
|
|
dg = (g_idx - g0)[..., None] |
|
|
db = (b_idx - b0)[..., None] |
|
|
|
|
|
|
|
|
c000 = lut[r0, g0, b0] |
|
|
c001 = lut[r0, g0, b1] |
|
|
c010 = lut[r0, g1, b0] |
|
|
c011 = lut[r0, g1, b1] |
|
|
c100 = lut[r1, g0, b0] |
|
|
c101 = lut[r1, g0, b1] |
|
|
c110 = lut[r1, g1, b0] |
|
|
c111 = lut[r1, g1, b1] |
|
|
|
|
|
|
|
|
c00 = c000 * (1 - db) + c001 * db |
|
|
c01 = c010 * (1 - db) + c011 * db |
|
|
c10 = c100 * (1 - db) + c101 * db |
|
|
c11 = c110 * (1 - db) + c111 * db |
|
|
|
|
|
|
|
|
c0 = c00 * (1 - dg) + c01 * dg |
|
|
c1 = c10 * (1 - dg) + c11 * dg |
|
|
|
|
|
|
|
|
c = c0 * (1 - dr) + c1 * dr |
|
|
|
|
|
return c |
|
|
|
|
|
def apply_3d_lut(img_arr: np.ndarray, lut3d: np.ndarray, strength: float = 1.0) -> np.ndarray:
    """
    Apply a 3D LUT to the image.

    - img_arr: HxWx3 image with values on a 0..255 scale
    - lut3d: SxSxSx3 float LUT (expected range 0..1)
    - strength: values >= 1 return the LUT result; smaller values blend
      linearly between the original image and the LUT result
    Returns a uint8 image. Raises ValueError when img_arr is not HxWx3.

    Results are rounded to the nearest integer before the uint8 cast. A plain
    cast truncates, so a value that interpolation leaves a hair below an
    integer (e.g. 127.99999) would drop a whole level and even an identity
    LUT would darken the image.
    """
    if img_arr.ndim != 3 or img_arr.shape[2] != 3:
        raise ValueError("apply_3d_lut expects an HxWx3 image array")

    original = img_arr.astype(np.float32)
    sampled = _trilinear_sample_lut(original / 255.0, lut3d)
    graded = np.clip(np.rint(sampled * 255.0), 0, 255).astype(np.uint8)

    if strength >= 1.0:
        return graded

    blended = (1.0 - strength) * original + strength * graded.astype(np.float32)
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
|
|
|
|
|
def apply_lut(img_arr: np.ndarray, lut: np.ndarray, strength: float = 1.0) -> np.ndarray:
    """
    Auto-detect the LUT type and apply it.

    - lut.ndim in (1, 2): treated as a 1D LUT (per-channel if shape (N,3)).
    - lut.ndim == 4 with a last axis of 3: treated as a 3D LUT (SxSxSx3).
      A 3D LUT whose maximum exceeds 1 (plus a small tolerance) is assumed to
      be on a 0..255 scale and is divided by 255, whatever its dtype.
    - strength: passed through to the selected LUT routine.
    Raises ValueError for any other LUT shape.
    """
    lut = np.asarray(lut)

    if lut.ndim == 4 and lut.shape[3] == 3:
        # Same "> 1 means 0..255" rule and tolerance that load_cube_lut uses.
        # Previously float32 LUTs skipped this step, so a float32 table on a
        # 0..255 scale was sampled as if it were already normalised.
        if lut.size and float(lut.max()) > 1.0 + 1e-6:
            lut = lut.astype(np.float32) / 255.0
        return apply_3d_lut(img_arr, lut, strength=strength)

    if lut.ndim in (1, 2):
        return apply_1d_lut(img_arr, lut, strength=strength)

    raise ValueError("Unsupported LUT shape: {}".format(lut.shape))
|
|
|
|
|
def load_cube_lut(path: str) -> np.ndarray:
    """
    Parse a .cube file and return a 3D LUT of shape (S,S,S,3), float32,
    indexed as lut[r, g, b].

    Recognised lines: LUT_3D_SIZE, DOMAIN_MIN, DOMAIN_MAX and data rows made
    of three numbers (plain decimals or scientific notation). Blank lines,
    '#' comments and anything else (e.g. TITLE) are skipped.

    The file is assumed to use the standard .cube ordering, in which the red
    index changes fastest and the blue index slowest; the table is transposed
    accordingly so that the first axis of the result is red.

    Values are rescaled by DOMAIN_MIN/DOMAIN_MAX and clipped to [0,1] when a
    non-default domain is declared; otherwise a table whose maximum exceeds 1
    is assumed to be on a 0..255 scale and divided by 255.

    Raises ValueError when LUT_3D_SIZE is missing or the number of data rows
    is not S**3.
    """
    # Accepts "1", "1.", ".5", "+0.25", "1.000000e+00", ...
    number = r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?'
    data_row = re.compile(r'^' + number + r'\s+' + number + r'\s+' + number + r'$')

    size = None
    rows = []
    domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)

    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw in f:
            ln = raw.strip()
            if not ln or ln.startswith('#'):
                continue
            keyword = ln.upper()
            parts = ln.split()
            if keyword.startswith('LUT_3D_SIZE'):
                if len(parts) >= 2:
                    size = int(parts[1])
            elif keyword.startswith('DOMAIN_MIN'):
                domain_min = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
            elif keyword.startswith('DOMAIN_MAX'):
                domain_max = np.array([float(p) for p in parts[1:4]], dtype=np.float32)
            elif data_row.match(ln):
                rows.append([float(p) for p in parts])

    if size is None:
        raise ValueError("LUT_3D_SIZE not found in .cube file: {}".format(path))

    data = np.array(rows, dtype=np.float32)
    if data.shape[0] != size**3:
        raise ValueError("Cube LUT data length does not match size^3 (got {}, expected {})".format(data.shape[0], size**3))

    # Row n of the file holds the entry for r = n % S, g = (n // S) % S,
    # b = n // S**2, so a plain reshape yields [b, g, r]; swap to [r, g, b].
    lut = data.reshape((size, size, size, 3)).transpose(2, 1, 0, 3)

    if not np.allclose(domain_min, [0.0, 0.0, 0.0]) or not np.allclose(domain_max, [1.0, 1.0, 1.0]):
        lut = (lut - domain_min) / (domain_max - domain_min + 1e-12)
        lut = np.clip(lut, 0.0, 1.0)
    elif lut.max() > 1.0 + 1e-6:
        lut = lut / 255.0

    return np.ascontiguousarray(lut, dtype=np.float32)
|
|
|
|
|
def load_lut(path: str) -> np.ndarray:
    """
    Load a LUT from:
    - .npy (numpy saved array, returned as stored)
    - .cube (3D LUT, parsed by load_cube_lut)
    - any other extension: opened as an image, converted to RGB and flattened
      row by row into an (N,3) float32 table; the usual layout is a 256x1 or
      1x256 strip. Images with more than 4096 pixels are rejected.
    Returns numpy array (1D, 2D, or 4D LUT).
    Raises ValueError when an image LUT cannot be opened or is too large.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.npy':
        return np.load(path)
    if ext == '.cube':
        return load_cube_lut(path)

    try:
        # The context manager closes the underlying file handle promptly.
        with Image.open(path) as im:
            arr = np.array(im.convert('RGB'))

        # Row-major flattening gives the same table for horizontal (1xN) and
        # vertical (Nx1) strips, so no special case is needed for them.
        flat = arr.reshape(-1, 3).astype(np.float32)
        if flat.shape[0] <= 4096:
            return flat
        raise ValueError("Image LUT not recognized size")
    except Exception as e:
        # Every failure on this path is reported as ValueError; the original
        # exception stays attached as the cause.
        raise ValueError(f"Unsupported LUT file or parse error for {path}: {e}") from e
|
|
|
|
|
|
|
|
|