Fioceen committed on
Commit
1778b43
·
1 Parent(s): 3066342

Fix ComfyUI Node, Fix AWB implementation

Browse files
Files changed (3) hide show
  1. image_postprocess/processor.py +20 -10
  2. nodes.py +41 -55
  3. run.py +6 -4
image_postprocess/processor.py CHANGED
@@ -59,14 +59,20 @@ def process_image(path_in, path_out, args):
59
  img = Image.open(path_in).convert('RGB')
60
  arr = np.array(img)
61
 
62
- # --- Auto white-balance using reference (if provided) ---
63
- if args.ref:
64
- try:
65
- ref_img_awb = Image.open(args.ref).convert('RGB')
66
- ref_arr_awb = np.array(ref_img_awb)
67
- arr = auto_white_balance_ref(arr, ref_arr_awb)
68
- except Exception as e:
69
- print(f"Warning: failed to load AWB reference '{args.ref}': {e}. Skipping AWB.")
 
 
 
 
 
 
70
 
71
  # apply CLAHE color correction (contrast)
72
  arr = clahe_color_correction(arr, clip_limit=args.clahe_clip, tile_grid_size=(args.tile, args.tile))
@@ -128,7 +134,11 @@ def build_argparser():
128
  p = argparse.ArgumentParser(description="Image postprocessing pipeline with camera simulation and LUT support")
129
  p.add_argument('input', help='Input image path')
130
  p.add_argument('output', help='Output image path')
131
- p.add_argument('--ref', help='Optional reference image for auto white-balance (applied before CLAHE)', default=None)
 
 
 
 
132
  p.add_argument('--noise-std', type=float, default=0.02, help='Gaussian noise std fraction of 255 (0-0.1)')
133
  p.add_argument('--clahe-clip', type=float, default=2.0, help='CLAHE clip limit')
134
  p.add_argument('--tile', type=int, default=8, help='CLAHE tile grid size')
@@ -173,4 +183,4 @@ if __name__ == "__main__":
173
  print("Input not found:", args.input)
174
  raise SystemExit(2)
175
  process_image(args.input, args.output, args)
176
- print("Saved:", args.output)
 
59
  img = Image.open(path_in).convert('RGB')
60
  arr = np.array(img)
61
 
62
+ # --- Auto white-balance (if enabled) ---
63
+ if args.awb:
64
+ if args.ref:
65
+ try:
66
+ ref_img_awb = Image.open(args.ref).convert('RGB')
67
+ ref_arr_awb = np.array(ref_img_awb)
68
+ arr = auto_white_balance_ref(arr, ref_arr_awb)
69
+ except Exception as e:
70
+ print(f"Warning: failed to load AWB reference '{args.ref}': {e}. Skipping AWB.")
71
+ else:
72
+ print("Applying AWB using grey-world assumption...")
73
+ # Assuming auto_white_balance_ref with a None reference
74
+ # triggers the grey-world algorithm as described.
75
+ arr = auto_white_balance_ref(arr, None)
76
 
77
  # apply CLAHE color correction (contrast)
78
  arr = clahe_color_correction(arr, clip_limit=args.clahe_clip, tile_grid_size=(args.tile, args.tile))
 
134
  p = argparse.ArgumentParser(description="Image postprocessing pipeline with camera simulation and LUT support")
135
  p.add_argument('input', help='Input image path')
136
  p.add_argument('output', help='Output image path')
137
+
138
+ # AWB Options
139
+ p.add_argument('--awb', action='store_true', help='Enable automatic white balancing. Uses grey-world if --ref is not provided.')
140
+ p.add_argument('--ref', help='Optional reference image for auto white-balance (only used if --awb is enabled)', default=None)
141
+
142
  p.add_argument('--noise-std', type=float, default=0.02, help='Gaussian noise std fraction of 255 (0-0.1)')
143
  p.add_argument('--clahe-clip', type=float, default=2.0, help='CLAHE clip limit')
144
  p.add_argument('--tile', type=int, default=8, help='CLAHE tile grid size')
 
183
  print("Input not found:", args.input)
184
  raise SystemExit(2)
185
  process_image(args.input, args.output, args)
186
+ print("Saved:", args.output)
nodes.py CHANGED
@@ -13,6 +13,7 @@ except Exception as e:
13
  else:
14
  IMPORT_ERROR = None
15
 
 
16
 
17
  class NovaNodes:
18
  """
@@ -34,6 +35,9 @@ class NovaNodes:
34
  "required": {
35
  "image": ("IMAGE",),
36
 
 
 
 
37
  # EXIF
38
  "apply_exif_o": ("BOOLEAN", {"default": True}),
39
 
@@ -63,9 +67,9 @@ class NovaNodes:
63
  "apply_chromatic_aberration_o": ("BOOLEAN", {"default": True}),
64
  "ca_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1}),
65
 
66
- # Banding
67
  "apply_banding_o": ("BOOLEAN", {"default": True}),
68
- "banding_levels": ("INT", {"default": 64, "min": 2, "max": 256, "step": 1}),
69
 
70
  # Motion blur
71
  "apply_motion_blur_o": ("BOOLEAN", {"default": True}),
@@ -82,12 +86,13 @@ class NovaNodes:
82
  "iso_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 16.0, "step": 0.1}),
83
  "read_noise": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 50.0, "step": 0.1}),
84
 
85
- # LUT (new)
86
- "lut": ("STRING", {"default": ""}),
87
  "lut_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
88
  },
89
  "optional": {
90
- "ref_image": ("IMAGE",),
 
91
  }
92
  }
93
 
@@ -96,7 +101,8 @@ class NovaNodes:
96
  FUNCTION = "process"
97
  CATEGORY = "postprocessing"
98
 
99
- def process(self, image, ref_image=None,
 
100
  apply_exif_o=True,
101
  noise_std_frac=0.015,
102
  hot_pixel_prob=1e-6,
@@ -115,7 +121,7 @@ class NovaNodes:
115
  apply_chromatic_aberration_o=True,
116
  ca_shift=1.0,
117
  apply_banding_o=True,
118
- banding_levels=64,
119
  apply_motion_blur_o=True,
120
  motion_blur_ksize=7,
121
  apply_jpeg_cycles_o=True,
@@ -135,55 +141,27 @@ class NovaNodes:
135
 
136
  def to_pil_from_any(inp):
137
  """Convert a torch tensor / numpy array of many shapes into a PIL RGB Image."""
138
- # get numpy
139
  if isinstance(inp, torch.Tensor):
140
  arr = inp.detach().cpu().numpy()
141
  else:
142
  arr = np.asarray(inp)
143
-
144
- # remove leading batch dimension if present
145
  if arr.ndim == 4 and arr.shape[0] == 1:
146
  arr = arr[0]
147
-
148
- # CHW -> HWC
149
  if arr.ndim == 3 and arr.shape[0] in (1, 3):
150
  arr = np.transpose(arr, (1, 2, 0))
151
-
152
- # if still 3D and last dim is channel (H,W,C) but C==1 or 3: OK
153
  if arr.ndim == 2:
154
- # grayscale HxW -> make HxWx1
155
  arr = arr[:, :, None]
156
-
157
- # Now arr should be H x W x C
158
- if arr.ndim != 3:
159
- # try permutations heuristically (rare)
160
- for perm in [(1, 2, 0), (2, 0, 1), (0, 2, 1)]:
161
- try:
162
- cand = np.transpose(arr, perm)
163
- if cand.ndim == 3:
164
- arr = cand
165
- break
166
- except Exception:
167
- pass
168
-
169
  if arr.ndim != 3:
170
  raise TypeError(f"Cannot convert array to HWC image, final ndim={arr.ndim}, shape={arr.shape}")
171
-
172
- # Normalize numeric range to 0..255 uint8
173
  if np.issubdtype(arr.dtype, np.floating):
174
- # assume floats are 0..1 if max <= 1.0
175
  if arr.max() <= 1.0:
176
  arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
177
  else:
178
  arr = np.clip(arr, 0, 255).astype(np.uint8)
179
  else:
180
  arr = arr.astype(np.uint8)
181
-
182
- # If single channel, replicate to 3 channels (we want RGB files)
183
  if arr.shape[2] == 1:
184
  arr = np.repeat(arr, 3, axis=2)
185
-
186
- # finally create PIL
187
  return Image.fromarray(arr)
188
 
189
  try:
@@ -194,25 +172,35 @@ class NovaNodes:
194
  pil_img.save(input_path)
195
  tmp_files.append(input_path)
196
 
197
- # ---- Reference image for AWB and FFT if present ----
198
- ref_path = None
199
- if ref_image is not None:
200
- pil_ref = to_pil_from_any(ref_image[0])
201
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_ref:
202
- ref_path = tmp_ref.name
203
- pil_ref.save(ref_path)
204
- tmp_files.append(ref_path)
 
 
 
 
 
 
 
 
 
205
 
206
  # ---- Output path ----
207
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_output:
208
  output_path = tmp_output.name
209
  tmp_files.append(output_path)
210
 
211
- # Prepare args for process_image (keeping your names)
212
  args = SimpleNamespace(
213
  input=input_path,
214
  output=output_path,
215
- ref=ref_path, # Used for AWB if provided
 
216
  noise_std=noise_std_frac,
217
  hot_pixel_prob=hot_pixel_prob,
218
  perturb=perturb_mag_frac,
@@ -224,21 +212,20 @@ class NovaNodes:
224
  fft_alpha=fourier_alpha,
225
  radial_smooth=fourier_radial_smooth,
226
  fft_mode=fourier_mode,
227
- fft_ref=ref_path, # Used for FFT if provided
228
  vignette_strength=vignette_strength if apply_vignette_o else 0.0,
229
  chroma_strength=ca_shift if apply_chromatic_aberration_o else 0.0,
230
- banding_strength=1.0 if apply_banding_o else 0.0,
231
  motion_blur_kernel=motion_blur_ksize if apply_motion_blur_o else 1,
232
  jpeg_cycles=jpeg_cycles if apply_jpeg_cycles_o else 1,
233
  jpeg_qmin=jpeg_quality,
234
  jpeg_qmax=jpeg_quality,
235
  sim_camera=sim_camera,
236
- no_no_bayer=enable_bayer,
237
  iso_scale=iso_scale,
238
  read_noise=read_noise,
239
  seed=None,
240
  cutoff=0.25,
241
- # LUT fields (new)
242
  lut=(lut if lut != "" else None),
243
  lut_strength=lut_strength,
244
  )
@@ -246,9 +233,9 @@ class NovaNodes:
246
  # ---- Run the processing function ----
247
  process_image(input_path, output_path, args)
248
 
249
- # ---- Load result (force RGB to avoid unexpected single-channel shapes) ----
250
  output_img = Image.open(output_path).convert("RGB")
251
- img_out = np.array(output_img) # H x W x 3, uint8
252
 
253
  # ---- EXIF insertion (optional) ----
254
  new_exif = ""
@@ -261,11 +248,10 @@ class NovaNodes:
261
  new_exif = ""
262
 
263
  # ---- Convert to FOOLAI-style tensor: (1, H, W, C), float32 in [0,1] ----
264
- img_float = img_out.astype(np.float32) / 255.0 # H x W x C
265
- tensor_out = torch.from_numpy(img_float).to(dtype=torch.float32).unsqueeze(0) # 1 x H x W x C
266
  tensor_out = torch.clamp(tensor_out, 0.0, 1.0)
267
 
268
- # Return the same format FOOLAI uses: (tensor, exif_string)
269
  return (tensor_out, new_exif)
270
 
271
  finally:
@@ -314,4 +300,4 @@ NODE_CLASS_MAPPINGS = {
314
  }
315
  NODE_DISPLAY_NAME_MAPPINGS = {
316
  "NovaNodes": "Image Postprocess (NOVA NODES)",
317
- }
 
13
  else:
14
  IMPORT_ERROR = None
15
 
16
+ lut_extensions = ['png','npy','cube']
17
 
18
  class NovaNodes:
19
  """
 
35
  "required": {
36
  "image": ("IMAGE",),
37
 
38
+ # AWB
39
+ "enable_awb": ("BOOLEAN", {"default": False}),
40
+
41
  # EXIF
42
  "apply_exif_o": ("BOOLEAN", {"default": True}),
43
 
 
67
  "apply_chromatic_aberration_o": ("BOOLEAN", {"default": True}),
68
  "ca_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1}),
69
 
70
+ # Banding (FIXED)
71
  "apply_banding_o": ("BOOLEAN", {"default": True}),
72
+ "banding_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
73
 
74
  # Motion blur
75
  "apply_motion_blur_o": ("BOOLEAN", {"default": True}),
 
86
  "iso_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 16.0, "step": 0.1}),
87
  "read_noise": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 50.0, "step": 0.1}),
88
 
89
+ # LUT
90
+ "lut": ("STRING", {"default": "X://insert/path/here.npy", "vhs_path_extensions": lut_extensions}),
91
  "lut_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
92
  },
93
  "optional": {
94
+ "awb_ref_image": ("IMAGE",),
95
+ "fft_ref_image": ("IMAGE",),
96
  }
97
  }
98
 
 
101
  FUNCTION = "process"
102
  CATEGORY = "postprocessing"
103
 
104
+ def process(self, image, awb_ref_image=None, fft_ref_image=None,
105
+ enable_awb=False,
106
  apply_exif_o=True,
107
  noise_std_frac=0.015,
108
  hot_pixel_prob=1e-6,
 
121
  apply_chromatic_aberration_o=True,
122
  ca_shift=1.0,
123
  apply_banding_o=True,
124
+ banding_strength=0.5,
125
  apply_motion_blur_o=True,
126
  motion_blur_ksize=7,
127
  apply_jpeg_cycles_o=True,
 
141
 
142
  def to_pil_from_any(inp):
143
  """Convert a torch tensor / numpy array of many shapes into a PIL RGB Image."""
 
144
  if isinstance(inp, torch.Tensor):
145
  arr = inp.detach().cpu().numpy()
146
  else:
147
  arr = np.asarray(inp)
 
 
148
  if arr.ndim == 4 and arr.shape[0] == 1:
149
  arr = arr[0]
 
 
150
  if arr.ndim == 3 and arr.shape[0] in (1, 3):
151
  arr = np.transpose(arr, (1, 2, 0))
 
 
152
  if arr.ndim == 2:
 
153
  arr = arr[:, :, None]
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  if arr.ndim != 3:
155
  raise TypeError(f"Cannot convert array to HWC image, final ndim={arr.ndim}, shape={arr.shape}")
 
 
156
  if np.issubdtype(arr.dtype, np.floating):
 
157
  if arr.max() <= 1.0:
158
  arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
159
  else:
160
  arr = np.clip(arr, 0, 255).astype(np.uint8)
161
  else:
162
  arr = arr.astype(np.uint8)
 
 
163
  if arr.shape[2] == 1:
164
  arr = np.repeat(arr, 3, axis=2)
 
 
165
  return Image.fromarray(arr)
166
 
167
  try:
 
172
  pil_img.save(input_path)
173
  tmp_files.append(input_path)
174
 
175
+ # ---- AWB reference image if present ----
176
+ awb_ref_path = None
177
+ if awb_ref_image is not None:
178
+ pil_ref_awb = to_pil_from_any(awb_ref_image[0])
179
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_ref_awb:
180
+ awb_ref_path = tmp_ref_awb.name
181
+ pil_ref_awb.save(awb_ref_path)
182
+ tmp_files.append(awb_ref_path)
183
+
184
+ # ---- FFT reference image if present ----
185
+ fft_ref_path = None
186
+ if fft_ref_image is not None:
187
+ pil_ref_fft = to_pil_from_any(fft_ref_image[0])
188
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_ref_fft:
189
+ fft_ref_path = tmp_ref_fft.name
190
+ pil_ref_fft.save(fft_ref_path)
191
+ tmp_files.append(fft_ref_path)
192
 
193
  # ---- Output path ----
194
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_output:
195
  output_path = tmp_output.name
196
  tmp_files.append(output_path)
197
 
198
+ # Prepare args for process_image
199
  args = SimpleNamespace(
200
  input=input_path,
201
  output=output_path,
202
+ awb=enable_awb, # Explicit AWB flag
203
+ ref=awb_ref_path,
204
  noise_std=noise_std_frac,
205
  hot_pixel_prob=hot_pixel_prob,
206
  perturb=perturb_mag_frac,
 
212
  fft_alpha=fourier_alpha,
213
  radial_smooth=fourier_radial_smooth,
214
  fft_mode=fourier_mode,
215
+ fft_ref=fft_ref_path,
216
  vignette_strength=vignette_strength if apply_vignette_o else 0.0,
217
  chroma_strength=ca_shift if apply_chromatic_aberration_o else 0.0,
218
+ banding_strength=banding_strength if apply_banding_o else 0.0,
219
  motion_blur_kernel=motion_blur_ksize if apply_motion_blur_o else 1,
220
  jpeg_cycles=jpeg_cycles if apply_jpeg_cycles_o else 1,
221
  jpeg_qmin=jpeg_quality,
222
  jpeg_qmax=jpeg_quality,
223
  sim_camera=sim_camera,
224
+ no_no_bayer=not enable_bayer, # FIX: Inverted logic corrected
225
  iso_scale=iso_scale,
226
  read_noise=read_noise,
227
  seed=None,
228
  cutoff=0.25,
 
229
  lut=(lut if lut != "" else None),
230
  lut_strength=lut_strength,
231
  )
 
233
  # ---- Run the processing function ----
234
  process_image(input_path, output_path, args)
235
 
236
+ # ---- Load result (force RGB) ----
237
  output_img = Image.open(output_path).convert("RGB")
238
+ img_out = np.array(output_img)
239
 
240
  # ---- EXIF insertion (optional) ----
241
  new_exif = ""
 
248
  new_exif = ""
249
 
250
  # ---- Convert to FOOLAI-style tensor: (1, H, W, C), float32 in [0,1] ----
251
+ img_float = img_out.astype(np.float32) / 255.0
252
+ tensor_out = torch.from_numpy(img_float).to(dtype=torch.float32).unsqueeze(0)
253
  tensor_out = torch.clamp(tensor_out, 0.0, 1.0)
254
 
 
255
  return (tensor_out, new_exif)
256
 
257
  finally:
 
300
  }
301
  NODE_DISPLAY_NAME_MAPPINGS = {
302
  "NovaNodes": "Image Postprocess (NOVA NODES)",
303
+ }
run.py CHANGED
@@ -515,11 +515,13 @@ class MainWindow(QMainWindow):
515
  args.banding_strength = float(self.banding_spin.value())
516
  args.motion_blur_kernel = int(self.motion_blur_spin.value())
517
 
518
- # AWB handling: only apply if checkbox is checked
519
  if self.awb_chk.isChecked():
520
- args.ref = awb_ref_val # may be None -> the pipeline will fall back to gray-world
 
521
  else:
522
- args.ref = None
 
523
 
524
  # FFT spectral matching reference
525
  args.fft_ref = fft_ref_val
@@ -608,4 +610,4 @@ def main():
608
  sys.exit(app.exec_())
609
 
610
  if __name__ == '__main__':
611
- main()
 
515
  args.banding_strength = float(self.banding_spin.value())
516
  args.motion_blur_kernel = int(self.motion_blur_spin.value())
517
 
518
+ # AWB handling to match the new --awb flag in the backend
519
  if self.awb_chk.isChecked():
520
+ args.awb = True
521
+ args.ref = awb_ref_val # This can be the path or None (for grey-world)
522
  else:
523
+ args.awb = False
524
+ args.ref = None # Not strictly necessary as backend ignores it, but good practice
525
 
526
  # FFT spectral matching reference
527
  args.fft_ref = fft_ref_val
 
610
  sys.exit(app.exec_())
611
 
612
  if __name__ == '__main__':
613
+ main()