ReyaLabColumbia commited on
Commit
ddfcf22
·
verified ·
1 Parent(s): 713e91d

Upload 2 files

Browse files
Files changed (2) hide show
  1. Colony_Analyzer_AI2_HF.py +305 -0
  2. app.py +2 -2
Colony_Analyzer_AI2_HF.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Mar 20 14:23:27 2025
5
+
6
+ @author: mattc
7
+ """
8
+
9
+ import os
10
+ import cv2
11
+
12
+ #this is the huggingface version
13
+ def cut_img(img):
14
+ img_map = {}
15
+ name = x.split(".")[0]
16
+ i_num = img.shape[0]/512
17
+ j_num = img.shape[1]/512
18
+ count = 1
19
+ for i in range(int(i_num)):
20
+ for j in range(int(j_num)):
21
+ img_map[count] = img[(512*i):(512*(i+1)), (512*j):(512*(j+1))]
22
+ count +=1
23
+ return(img_map)
24
+
25
+ import numpy as np
26
+
27
+ def stitch(img_map):
28
+ rows = [
29
+ np.hstack([img_map[1], img_map[2], img_map[3], img_map[4]]), # First row (images 0 to 3)
30
+ np.hstack([img_map[5], img_map[6], img_map[7], img_map[8]]), # Second row (images 4 to 7)
31
+ np.hstack([img_map[9], img_map[10], img_map[11], img_map[12]]) # Third row (images 8 to 11)
32
+ ]
33
+ # Stack rows vertically
34
+ return(np.vstack(rows))
35
+
36
+
37
+
38
+ from PIL import Image
39
+
40
+
41
+
42
+ import matplotlib.pyplot as plt
43
+
44
+ def visualize_segmentation(mask, image=0):
45
+ plt.figure(figsize=(10, 5))
46
+
47
+ if(not np.isscalar(image)):
48
+ # Show original image if it is entered
49
+ plt.subplot(1, 2, 1)
50
+ plt.imshow(image)
51
+ plt.title("Original Image")
52
+ plt.axis("off")
53
+
54
+ # Show segmentation mask
55
+ plt.subplot(1, 2, 2)
56
+ plt.imshow(mask, cmap="gray") # Show as grayscale
57
+ plt.title("Segmentation Mask")
58
+ plt.axis("off")
59
+
60
+ plt.show()
61
+
62
+ import torch
63
+ from transformers import SegformerForSemanticSegmentation
64
+ # Load fine-tuned model
65
+ model = SegformerForSemanticSegmentation.from_pretrained("ReyaLabColumbia/Segformer_Colony_Counter")
66
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
+ model.to(device)
68
+ model.eval() # Set to evaluation mode
69
+
70
+ # Load image processor
71
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
72
+ image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
73
+
74
+ def preprocess_image(image):
75
+ image = image.convert("RGB") # Open and convert to RGB
76
+ inputs = image_processor(image, return_tensors="pt") # Preprocess for model
77
+ return image, inputs["pixel_values"]
78
+
79
+ def postprocess_mask(logits):
80
+ mask = torch.argmax(logits, dim=1) # Take argmax across the class dimension
81
+ return mask.squeeze().cpu().numpy() # Convert to NumPy array
82
+
83
+
84
+ def eval_img(image):
85
+ # Load and preprocess image
86
+ image, pixel_values = preprocess_image(image)
87
+ pixel_values = pixel_values.to(device)
88
+ with torch.no_grad(): # No gradient calculation for inference
89
+ outputs = model(pixel_values=pixel_values) # Run model
90
+ logits = outputs.logits
91
+ # Convert logits to segmentation mask
92
+ segmentation_mask = postprocess_mask(logits)
93
+ #visualize_segmentation(segmentation_mask,image)
94
+ segmentation_mask = cv2.resize(segmentation_mask, (512, 512), interpolation=cv2.INTER_LINEAR_EXACT)
95
+ return(segmentation_mask)
96
+
97
+
98
+ # for x in img_map:
99
+ # mask = eval_img(img_map[x])
100
+ # cv2.imwrite(img_map[x], mask)
101
+ # del mask,x
102
+ # p = stitch(img_map)
103
+ # visualize_segmentation(p)
104
+
105
+ # num_colony = np.count_nonzero(p == 1) # Counts number of 1s
106
+ # num_necrosis = np.count_nonzero(p == 2)
107
+
108
+ # num_necrosis/num_colony
109
+
110
+ def find_colonies(mask, size_cutoff, circ_cutoff):
111
+ binary_mask = np.where(mask == 1, 255, 0).astype(np.uint8)
112
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
113
+ contoursf = []
114
+ for x in contours:
115
+ area = cv2.contourArea(x)
116
+ if (area < size_cutoff):
117
+ continue
118
+ perimeter = cv2.arcLength(x, True)
119
+
120
+ # Avoid division by zero
121
+ if perimeter == 0:
122
+ continue
123
+
124
+ # Calculate circularity
125
+ circularity = (4 * np.pi * area) / (perimeter ** 2)
126
+ if circularity >= circ_cutoff:
127
+ contoursf.append(x)
128
+ return(contoursf)
129
+
130
+ def find_necrosis(mask):
131
+ binary_mask = np.where(mask == 2, 255, 0).astype(np.uint8)
132
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
133
+ return(contours)
134
+
135
+ # contour_image = np.zeros_like(p)
136
+ # contours = find_necrosis(p)
137
+ # cv2.drawContours(contour_image, contours, -1, (255), 2)
138
+ # visualize_segmentation(contour_image)
139
+ import pandas as pd
140
+ def compute_centroid(contour):
141
+ M = cv2.moments(contour)
142
+ if M["m00"] == 0: # Avoid division by zero
143
+ return None
144
+ cx = int(M["m10"] / M["m00"])
145
+ cy = int(M["m01"] / M["m00"])
146
+ return (cx, cy)
147
+
148
+
149
+ def contours_overlap_using_mask(contour1, contour2, image_shape=(1536, 2048)):
150
+ """Check if two contours overlap using a bitwise AND mask."""
151
+ import numpy as np
152
+ import cv2
153
+ mask1 = np.zeros(image_shape, dtype=np.uint8)
154
+ mask2 = np.zeros(image_shape, dtype=np.uint8)
155
+
156
+
157
+ # Draw each contour as a white shape on its respective mask
158
+ cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
159
+ cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)
160
+
161
+
162
+ # Compute bitwise AND to find overlapping regions
163
+ overlap = cv2.bitwise_and(mask1, mask2)
164
+
165
+ return np.any(overlap)
166
+
167
+ def analyze_colonies(mask, size_cutoff, circ_cutoff):
168
+ colonies = find_colonies(mask, size_cutoff, circ_cutoff)
169
+ necrosis = find_necrosis(mask)
170
+
171
+ data = []
172
+
173
+ for colony in colonies:
174
+ colony_area = cv2.contourArea(colony)
175
+ centroid = compute_centroid(colony)
176
+ if colony_area <= 50:
177
+ continue
178
+
179
+ # Check if any necrosis contour is inside the colony
180
+ necrosis_area = 0
181
+ nec_list =[]
182
+ for nec in necrosis:
183
+ # Check if the first point of the necrosis contour is inside the colony
184
+ if contours_overlap_using_mask(colony, nec):
185
+ nec_area = cv2.contourArea(nec)
186
+ necrosis_area += nec_area
187
+ nec_list.append(nec)
188
+
189
+ data.append({
190
+ "colony_area": colony_area,
191
+ "necrosis_area": necrosis_area,
192
+ "centroid": centroid,
193
+ "percent_necrosis": necrosis_area/colony_area,
194
+ "contour": colony,
195
+ "nec_contours": nec_list
196
+ })
197
+
198
+ # Convert results to a DataFrame
199
+ df = pd.DataFrame(data)
200
+ df.index = range(1,len(df.index)+1)
201
+ return(df)
202
+
203
+
204
+ def contour_overlap(contour1, contour2, centroid1, centroid2, area1, area2, centroid_thresh=25, area_thresh = 85, img_shape = (1536, 2048)):
205
+ """
206
+ Determines the overlap between two contours.
207
+ Returns:
208
+ 0: No overlap
209
+ 1: Overlap but does not meet strict conditions
210
+ 2: Overlap >= 80% of the larger contour and centroids are close
211
+ """
212
+ # Create blank images
213
+ img1 = np.zeros(img_shape, dtype=np.uint8)
214
+ img2 = np.zeros(img_shape, dtype=np.uint8)
215
+
216
+ # Draw filled contours
217
+ cv2.drawContours(img1, [contour1], -1, 255, thickness=cv2.FILLED)
218
+ cv2.drawContours(img2, [contour2], -1, 255, thickness=cv2.FILLED)
219
+
220
+ # Compute overlap
221
+ intersection = cv2.bitwise_and(img1, img2)
222
+ intersection_area = np.count_nonzero(intersection)
223
+
224
+ if intersection_area == 0:
225
+ return 0 # No overlap
226
+
227
+ # Compute centroid distance
228
+ centroid_distance = float(np.sqrt(abs(centroid1[0]-centroid2[0])**2 + abs(centroid1[1]-centroid2[1])**2))
229
+ # Check percentage overlap relative to the larger contour
230
+ overlap_ratio = intersection_area/max(area1, area2)
231
+
232
+ if overlap_ratio >= area_thresh and centroid_distance <= centroid_thresh:
233
+ if area1 > area2:
234
+ return(2)
235
+ else:
236
+ return(3)
237
+ else:
238
+ return 1 # Some overlap but not meeting strict criteria
239
+
240
+ def compare_frames(frame1, frame2):
241
+ for i in range(1, len(frame1)+1):
242
+ for j in range(1, len(frame2)+1):
243
+ temp = contour_overlap(frame1.loc[i, "contour"], frame2.loc[j, "contour"], frame1.loc[i, "centroid"], frame2.loc[j, "centroid"], frame1.loc[i, "colony_area"], frame2.loc[j, "colony_area"])
244
+ if temp ==2:
245
+ frame2.loc[j,"exclude"] = True
246
+ elif temp ==3:
247
+ frame1.loc[j, "exclude"] = True
248
+ break
249
+ frame1 = frame1[frame1["exclude"]==False]
250
+ frame2 = frame2[frame2["exclude"]==False]
251
+ df = pd.concat([frame1, frame2], axis=0)
252
+ df.index = range(1,len(df.index)+1)
253
+ return(df)
254
+
255
+ def main(args):
256
+ min_size = args[1]
257
+ min_circ = args[2]
258
+ colonies = {}
259
+ img_map = cut_img(args[0])
260
+ for z in img_map:
261
+ mask = eval_img(img_map[z])
262
+ cv2.imwrite(img_map[z], mask)
263
+ del mask,z
264
+ p = stitch(img_map)
265
+ colonies = analyze_colonies(p, min_size, min_circ)
266
+
267
+ img = args[0]
268
+ img = cv2.copyMakeBorder(img,top=0, bottom=10,left=0,right=10, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
269
+
270
+
271
+ colonies = colonies.sort_values(by=["colony_area"], ascending=False)
272
+ colonies = colonies[colonies["colony_area"]>= min_size]
273
+ colonies.index = range(1,len(colonies.index)+1)
274
+
275
+ for i in range(len(colonies)):
276
+ cv2.drawContours(img, [list(colonies["contour"])[i]], -1, (0, 255, 0), 2)
277
+ cv2.drawContours(img, list(colonies['nec_contours'])[i], -1, (0, 0, 255), 2)
278
+ coords = list(list(colonies["centroid"])[i])
279
+ if coords[0] > 1950:
280
+ #if a colony is too close to the right edge, makes the label move to left
281
+ coords[0] = 1950
282
+ cv2.putText(img, str(colonies.index[i]), tuple(coords), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 2)
283
+ img = cv2.copyMakeBorder(img,top=10, bottom=0,left=10,right=0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
284
+ colonies = colonies.drop('contour', axis=1)
285
+ colonies = colonies.drop('nec_contours', axis=1)
286
+
287
+ colonies.insert(loc=0, column="Colony Number", value=[str(x) for x in range(1, len(colonies)+1)])
288
+ total_area_dark = sum(colonies['necrosis_area'])
289
+ total_area_light = sum(colonies['colony_area'])
290
+ ratio = total_area_dark/(abs(total_area_light)+1)
291
+
292
+ colonies.loc[len(colonies)+1] = ["Total", total_area_light, total_area_dark, None, ratio]
293
+ Parameters = pd.DataFrame({"Minimum colony size in pixels":[min_size], "Minimum colony circularity":[min_circ]})
294
+ with pd.ExcelWriter('results.xlsx') as writer:
295
+ colonies.to_excel(writer, sheet_name="Colony data", index=False)
296
+ Parameters.to_excel(writer, sheet_name="Parameters", index=False)
297
+ caption = np.ones((150, 2068, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
298
+ cv2.putText(caption, "Total area necrotic: "+str(total_area_dark)+ ", Total area living: "+str(total_area_light)+", Ratio: "+str(ratio), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
299
+
300
+
301
+
302
+ #cv2.imwrite('results.png', np.vstack((img, caption)))
303
+
304
+ return(np.vstack((img, caption)))
305
+
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import Colony_Analyzer_AI2 as analyzer
3
  from PIL import Image
4
  import cv2
5
  import numpy as np
@@ -25,7 +25,7 @@ iface = gr.Interface(
25
  gr.Number(label="Minimum Colony Size (pixels)", value=1000),
26
  gr.Number(label="Minimum Circularity", value=0.25)
27
  ],
28
- outputs=gr.Image(type="pil", label="Analyzed Image"),
29
  title="AI Colony Analyzer",
30
  description="Upload an image to run the colony analysis."
31
  )
 
1
  import gradio as gr
2
+ import Colony_Analyzer_AI2_HF as analyzer
3
  from PIL import Image
4
  import cv2
5
  import numpy as np
 
25
  gr.Number(label="Minimum Colony Size (pixels)", value=1000),
26
  gr.Number(label="Minimum Circularity", value=0.25)
27
  ],
28
+ outputs=[gr.Image(type="pil", label="Analyzed Image"), gr.File(label="Download results (Excel)")],
29
  title="AI Colony Analyzer",
30
  description="Upload an image to run the colony analysis."
31
  )