ReyaLabColumbia commited on
Commit
9aed5ca
·
verified ·
1 Parent(s): c76e306

Upload 3 files

Browse files
Files changed (3) hide show
  1. Colony_Analyzer_AI2.py +321 -0
  2. app.py +33 -0
  3. requirements.txt +6 -0
Colony_Analyzer_AI2.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Mar 20 14:23:27 2025
5
+
6
+ @author: mattc
7
+ """
8
+
9
+ import os
10
+ import cv2
11
+
12
+ # path = '/home/mattc/Documents/ColonyAssaySegformer/'
13
+ # file_list = os.listdir(path)
14
+ # file_list = [x for x in file_list if (x[-4::]==".tif" or x[-5::]==".tiff")]
15
+ def cut_img(path, x):
16
+ img_map = {}
17
+ img = cv2.imread(path + x)
18
+ name = x.split(".")[0]
19
+ i_num = img.shape[0]/512
20
+ j_num = img.shape[1]/512
21
+ count = 1
22
+ for i in range(int(i_num)):
23
+ for j in range(int(j_num)):
24
+ img2 = img[(512*i):(512*(i+1)), (512*j):(512*(j+1))]
25
+ cv2.imwrite(path+name+'_part'+str(count)+'.tif', img2)
26
+ img_map[count] = path+name+'_part'+str(count)+'.tif'
27
+ count +=1
28
+ return(img_map)
29
+
30
+ import numpy as np
31
+
32
+ def stitch(img_map):
33
+ for x in img_map:
34
+ temp = img_map[x]
35
+ img_map[x] = cv2.imread(temp)
36
+ if (img_map[x] is None):
37
+ img_map[x] = cv2.imread(temp, cv2.IMREAD_UNCHANGED)
38
+ os.remove(temp)
39
+ rows = [
40
+ np.hstack([img_map[1], img_map[2], img_map[3], img_map[4]]), # First row (images 0 to 3)
41
+ np.hstack([img_map[5], img_map[6], img_map[7], img_map[8]]), # Second row (images 4 to 7)
42
+ np.hstack([img_map[9], img_map[10], img_map[11], img_map[12]]) # Third row (images 8 to 11)
43
+ ]
44
+
45
+ # Stack rows vertically
46
+ return(np.vstack(rows))
47
+
48
+ #img_map = cut_img(path, file_list[0])
49
+
50
+
51
+ from PIL import Image
52
+
53
+
54
+
55
+ import matplotlib.pyplot as plt
56
+
57
+ def visualize_segmentation(mask, image=0):
58
+ plt.figure(figsize=(10, 5))
59
+
60
+ if(not np.isscalar(image)):
61
+ # Show original image if it is entered
62
+ plt.subplot(1, 2, 1)
63
+ plt.imshow(image)
64
+ plt.title("Original Image")
65
+ plt.axis("off")
66
+
67
+ # Show segmentation mask
68
+ plt.subplot(1, 2, 2)
69
+ plt.imshow(mask, cmap="gray") # Show as grayscale
70
+ plt.title("Segmentation Mask")
71
+ plt.axis("off")
72
+
73
+ plt.show()
74
+
75
+ import torch
76
+ from transformers import SegformerForSemanticSegmentation
77
+ # Load fine-tuned model
78
+ model = SegformerForSemanticSegmentation.from_pretrained("/home/mattc/Documents/ColonyAssaySegformer/segformer_colony_model_ternary_finished") # Adjust path
79
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
+ model.to(device)
81
+ model.eval() # Set to evaluation mode
82
+
83
+ # Load image processor
84
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
85
+ image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
86
+
87
+ def preprocess_image(image_path):
88
+ image = Image.open(image_path).convert("RGB") # Open and convert to RGB
89
+ inputs = image_processor(image, return_tensors="pt") # Preprocess for model
90
+ return image, inputs["pixel_values"]
91
+
92
+ def postprocess_mask(logits):
93
+ mask = torch.argmax(logits, dim=1) # Take argmax across the class dimension
94
+ return mask.squeeze().cpu().numpy() # Convert to NumPy array
95
+
96
+
97
+ def eval_img(image_path):
98
+ # Load and preprocess image
99
+ image, pixel_values = preprocess_image(image_path)
100
+ pixel_values = pixel_values.to(device)
101
+ with torch.no_grad(): # No gradient calculation for inference
102
+ outputs = model(pixel_values=pixel_values) # Run model
103
+ logits = outputs.logits
104
+ # Convert logits to segmentation mask
105
+ segmentation_mask = postprocess_mask(logits)
106
+ #visualize_segmentation(segmentation_mask,image)
107
+ segmentation_mask = cv2.resize(segmentation_mask, (512, 512), interpolation=cv2.INTER_LINEAR_EXACT)
108
+ return(segmentation_mask)
109
+
110
+
111
+ # for x in img_map:
112
+ # mask = eval_img(img_map[x])
113
+ # cv2.imwrite(img_map[x], mask)
114
+ # del mask,x
115
+ # p = stitch(img_map)
116
+ # visualize_segmentation(p)
117
+
118
+ # num_colony = np.count_nonzero(p == 1) # Counts number of 1s
119
+ # num_necrosis = np.count_nonzero(p == 2)
120
+
121
+ # num_necrosis/num_colony
122
+
123
+ def find_colonies(mask, size_cutoff, circ_cutoff):
124
+ binary_mask = np.where(mask == 1, 255, 0).astype(np.uint8)
125
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
126
+ contoursf = []
127
+ for x in contours:
128
+ area = cv2.contourArea(x)
129
+ if (area < size_cutoff):
130
+ continue
131
+ perimeter = cv2.arcLength(x, True)
132
+
133
+ # Avoid division by zero
134
+ if perimeter == 0:
135
+ continue
136
+
137
+ # Calculate circularity
138
+ circularity = (4 * np.pi * area) / (perimeter ** 2)
139
+ if circularity >= circ_cutoff:
140
+ contoursf.append(x)
141
+ return(contoursf)
142
+
143
+ def find_necrosis(mask):
144
+ binary_mask = np.where(mask == 2, 255, 0).astype(np.uint8)
145
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
146
+ return(contours)
147
+
148
+ # contour_image = np.zeros_like(p)
149
+ # contours = find_necrosis(p)
150
+ # cv2.drawContours(contour_image, contours, -1, (255), 2)
151
+ # visualize_segmentation(contour_image)
152
+ import pandas as pd
153
+ def compute_centroid(contour):
154
+ M = cv2.moments(contour)
155
+ if M["m00"] == 0: # Avoid division by zero
156
+ return None
157
+ cx = int(M["m10"] / M["m00"])
158
+ cy = int(M["m01"] / M["m00"])
159
+ return (cx, cy)
160
+
161
+
162
+ def contours_overlap_using_mask(contour1, contour2, image_shape=(1536, 2048)):
163
+ """Check if two contours overlap using a bitwise AND mask."""
164
+ import numpy as np
165
+ import cv2
166
+ mask1 = np.zeros(image_shape, dtype=np.uint8)
167
+ mask2 = np.zeros(image_shape, dtype=np.uint8)
168
+
169
+
170
+ # Draw each contour as a white shape on its respective mask
171
+ cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
172
+ cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)
173
+
174
+
175
+ # Compute bitwise AND to find overlapping regions
176
+ overlap = cv2.bitwise_and(mask1, mask2)
177
+
178
+ return np.any(overlap)
179
+
180
+ def analyze_colonies(mask, size_cutoff, circ_cutoff):
181
+ colonies = find_colonies(mask, size_cutoff, circ_cutoff)
182
+ necrosis = find_necrosis(mask)
183
+
184
+ data = []
185
+
186
+ for colony in colonies:
187
+ colony_area = cv2.contourArea(colony)
188
+ centroid = compute_centroid(colony)
189
+ if colony_area <= 50:
190
+ continue
191
+
192
+ # Check if any necrosis contour is inside the colony
193
+ necrosis_area = 0
194
+ nec_list =[]
195
+ for nec in necrosis:
196
+ # Check if the first point of the necrosis contour is inside the colony
197
+ if contours_overlap_using_mask(colony, nec):
198
+ nec_area = cv2.contourArea(nec)
199
+ necrosis_area += nec_area
200
+ nec_list.append(nec)
201
+
202
+ data.append({
203
+ "colony_area": colony_area,
204
+ "necrosis_area": necrosis_area,
205
+ "centroid": centroid,
206
+ "percent_necrosis": necrosis_area/colony_area,
207
+ "contour": colony,
208
+ "nec_contours": nec_list
209
+ })
210
+
211
+ # Convert results to a DataFrame
212
+ df = pd.DataFrame(data)
213
+ df.index = range(1,len(df.index)+1)
214
+ return(df)
215
+
216
+
217
+ def contour_overlap(contour1, contour2, centroid1, centroid2, area1, area2, centroid_thresh=25, area_thresh = 85, img_shape = (1536, 2048)):
218
+ """
219
+ Determines the overlap between two contours.
220
+ Returns:
221
+ 0: No overlap
222
+ 1: Overlap but does not meet strict conditions
223
+ 2: Overlap >= 80% of the larger contour and centroids are close
224
+ """
225
+ # Create blank images
226
+ img1 = np.zeros(img_shape, dtype=np.uint8)
227
+ img2 = np.zeros(img_shape, dtype=np.uint8)
228
+
229
+ # Draw filled contours
230
+ cv2.drawContours(img1, [contour1], -1, 255, thickness=cv2.FILLED)
231
+ cv2.drawContours(img2, [contour2], -1, 255, thickness=cv2.FILLED)
232
+
233
+ # Compute overlap
234
+ intersection = cv2.bitwise_and(img1, img2)
235
+ intersection_area = np.count_nonzero(intersection)
236
+
237
+ if intersection_area == 0:
238
+ return 0 # No overlap
239
+
240
+ # Compute centroid distance
241
+ centroid_distance = float(np.sqrt(abs(centroid1[0]-centroid2[0])**2 + abs(centroid1[1]-centroid2[1])**2))
242
+ # Check percentage overlap relative to the larger contour
243
+ overlap_ratio = intersection_area/max(area1, area2)
244
+
245
+ if overlap_ratio >= area_thresh and centroid_distance <= centroid_thresh:
246
+ if area1 > area2:
247
+ return(2)
248
+ else:
249
+ return(3)
250
+ else:
251
+ return 1 # Some overlap but not meeting strict criteria
252
+
253
+ def compare_frames(frame1, frame2):
254
+ for i in range(1, len(frame1)+1):
255
+ for j in range(1, len(frame2)+1):
256
+ temp = contour_overlap(frame1.loc[i, "contour"], frame2.loc[j, "contour"], frame1.loc[i, "centroid"], frame2.loc[j, "centroid"], frame1.loc[i, "colony_area"], frame2.loc[j, "colony_area"])
257
+ if temp ==2:
258
+ frame2.loc[j,"exclude"] = True
259
+ elif temp ==3:
260
+ frame1.loc[j, "exclude"] = True
261
+ break
262
+ frame1 = frame1[frame1["exclude"]==False]
263
+ frame2 = frame2[frame2["exclude"]==False]
264
+ df = pd.concat([frame1, frame2], axis=0)
265
+ df.index = range(1,len(df.index)+1)
266
+ return(df)
267
+
268
+ def main(args):
269
+ path = args[0]
270
+ file = args[1]
271
+ min_size = args[2]
272
+ min_circ = args[3]
273
+ colonies = {}
274
+ img_map = cut_img(path, file)
275
+ for z in img_map:
276
+ mask = eval_img(img_map[z])
277
+ cv2.imwrite(img_map[z], mask)
278
+ del mask,z
279
+ p = stitch(img_map)
280
+ colonies = analyze_colonies(p, min_size, min_circ)
281
+
282
+ img = cv2.imread(path + file)
283
+ img = cv2.copyMakeBorder(img,top=0, bottom=10,left=0,right=10, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
284
+
285
+
286
+ colonies = colonies.sort_values(by=["colony_area"], ascending=False)
287
+ colonies = colonies[colonies["colony_area"]>= min_size]
288
+ colonies.index = range(1,len(colonies.index)+1)
289
+
290
+ for i in range(len(colonies)):
291
+ cv2.drawContours(img, [list(colonies["contour"])[i]], -1, (0, 255, 0), 2)
292
+ cv2.drawContours(img, list(colonies['nec_contours'])[i], -1, (0, 0, 255), 2)
293
+ coords = list(list(colonies["centroid"])[i])
294
+ if coords[0] > 1950:
295
+ #if a colony is too close to the right edge, makes the label move to left
296
+ coords[0] = 1950
297
+ cv2.putText(img, str(colonies.index[i]), tuple(coords), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 2)
298
+ img = cv2.copyMakeBorder(img,top=10, bottom=0,left=10,right=0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
299
+ colonies = colonies.drop('contour', axis=1)
300
+ colonies = colonies.drop('nec_contours', axis=1)
301
+
302
+ colonies.insert(loc=0, column="Colony Number", value=[str(x) for x in range(1, len(colonies)+1)])
303
+ total_area_dark = sum(colonies['necrosis_area'])
304
+ total_area_light = sum(colonies['colony_area'])
305
+ ratio = total_area_dark/(abs(total_area_light)+1)
306
+
307
+ colonies.loc[len(colonies)+1] = ["Total", total_area_light, total_area_dark, None, ratio]
308
+ Parameters = pd.DataFrame({"Minimum colony size in pixels":[min_size], "Minimum colony circularity":[min_circ]})
309
+ file = file.split('.')[0]
310
+ with pd.ExcelWriter(path+file+'.xlsx') as writer:
311
+ colonies.to_excel(writer, sheet_name="Colony data", index=False)
312
+ Parameters.to_excel(writer, sheet_name="Parameters", index=False)
313
+ caption = np.ones((150, 2068, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
314
+ cv2.putText(caption, "Total area necrotic: "+str(total_area_dark)+ ", Total area living: "+str(total_area_light)+", Ratio: "+str(ratio), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
315
+
316
+
317
+
318
+ cv2.imwrite(path+file+'.png', np.vstack((img, caption)))
319
+
320
+ return(np.vstack((img, caption)))
321
+
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import Colony_Analyzer_AI2 as analyzer
3
+ from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
+
7
+ # Analysis function adapted from your Tkinter app
8
+ def analyze_image(image, min_size, circularity):
9
+ # Convert Gradio input image (PIL) to a compatible format
10
+ image = np.array(image)
11
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
12
+
13
+ # Assume your analyzer.main accepts [image, params] format, adjust as needed
14
+ processed_img = analyzer.main([image, min_size, circularity])
15
+
16
+ # Convert back to RGB for display
17
+ result = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
18
+ return Image.fromarray(result)
19
+
20
+ # Create Gradio interface
21
+ iface = gr.Interface(
22
+ fn=analyze_image,
23
+ inputs=[
24
+ gr.Image(type="pil", label="Upload Image"),
25
+ gr.Number(label="Minimum Colony Size (pixels)", value=1000),
26
+ gr.Number(label="Minimum Circularity", value=0.25)
27
+ ],
28
+ outputs=gr.Image(type="pil", label="Analyzed Image"),
29
+ title="AI Colony Analyzer",
30
+ description="Upload an image to run the colony analysis."
31
+ )
32
+
33
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ opencv-python
3
+ numpy
4
+ Pillow
5
+ torch
6
+ transformers