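# Hugging Face Space file view header (scraped page residue, kept as comments):
# author: kadirnar — commit 4792ad8 "Upload 6 files" — file size 2.79 kB
# (page links: raw · history · blame)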
import functools
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor
# Sample images bundled with the app (data/1.png ... data/4.png); they are
# paired with a model id to form the UI's example rows.
image_list = [f"data/{sample_number}.png" for sample_number in range(1, 5)]
def visualize_instance_seg_mask(mask, rng=None):
    """Colorize a 2-D label mask with one random RGB color per distinct label.

    Args:
        mask: 2-D array-like of labels, indexed as ``mask[row, col]``.
        rng: Optional object providing ``randint(a, b)`` (e.g. a
            ``random.Random`` instance) used to draw the colors. When omitted,
            the module-level ``random`` functions are used; pass a seeded
            ``random.Random`` for reproducible colors.

    Returns:
        A float64 array of shape ``(rows, cols, 3)`` with values in ``[0, 1]``:
        every pixel holds its label's color divided by 255.
    """
    if rng is None:
        rng = random
    mask = np.asarray(mask)
    # Sorted distinct labels; row i of the palette belongs to labels[i].
    labels = np.unique(mask)
    # Three randint draws per label, in sorted-label order. This is the same
    # call sequence as a per-label (r, g, b) tuple, so seeded runs reproduce
    # the same colors. reshape keeps shape (0, 3) for an empty mask.
    palette = np.array(
        [[rng.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)
    # searchsorted maps each pixel's label to its palette row in one
    # vectorized pass, instead of a Python loop over every pixel.
    palette_rows = np.searchsorted(labels, mask)
    return palette[palette_rows] / 255
@functools.lru_cache(maxsize=4)
def _load_segformer(model_id):
    """Load (once per ``model_id``) the Segformer model and its image processor.

    Loading weights from the Hub is the slowest step, so the pair is memoized
    instead of being reloaded on every request.
    """
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    try:
        # Use the checkpoint's own preprocessing config (resize/normalize).
        processor = SegformerImageProcessor.from_pretrained(model_id)
    except OSError:
        # The checkpoint ships no preprocessor config: fall back to the
        # default Segformer preprocessing settings.
        processor = SegformerImageProcessor()
    return model, processor


def _save_side_by_side(original, segmentation, output_path):
    """Plot ``original`` and ``segmentation`` side by side and save to ``output_path``."""
    fig = plt.figure(figsize=(10, 10))
    try:
        panels = ((original, "Original"), (segmentation, "Segmentation"))
        for position, (panel, panel_title) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 2, position)
            ax.imshow(panel)
            ax.set_title(panel_title)
            ax.axis("off")
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app accumulates one figure per request.
        plt.close(fig)


def Segformer_Segmentation(image_path, model_id):
    """Segment the image at ``image_path`` with ``model_id`` and save a comparison plot.

    The input image is plotted next to a randomly colored label map produced
    by ``visualize_instance_seg_mask``.

    Args:
        image_path: Path to the input image file.
        model_id: Hugging Face Hub id of a ``SegformerForSemanticSegmentation``
            checkpoint; its image-processor config is loaded from the same id.

    Returns:
        The path of the saved figure (always ``"output.png"``).
    """
    # NOTE(review): a fixed output path means concurrent calls would overwrite
    # each other's figure; confirm the app never runs requests in parallel.
    output_save = "output.png"
    # Close the file handle promptly and normalize to 3-channel RGB
    # (PNG inputs may be RGBA, palette, or grayscale).
    with Image.open(image_path) as opened:
        test_image = opened.convert("RGB")
    model, processor = _load_segformer(model_id)
    inputs = processor(images=test_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes expects (height, width); PIL's .size is (width, height).
    # Without it the map stays at the model's reduced logits resolution.
    segmentation_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    colored = visualize_instance_seg_mask(np.asarray(segmentation_map))
    _save_side_by_side(test_image, colored, output_save)
    return output_save
# Gradio input components: the image to segment and the checkpoint to use.
# These use the top-level component classes, like the output component
# below. The legacy ``gr.inputs`` namespace (and its ``default=`` argument)
# is deprecated in Gradio 3 and removed in Gradio 4.
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        choices=["deprem-ml/deprem_satellite_semantic_whu"],
        value="deprem-ml/deprem_satellite_semantic_whu",
        label="Model ID",
    ),
]
# Output component: the saved comparison figure, returned as a file path.
outputs = gr.Image(type="filepath", label="Segmentation")
# One example row per bundled sample image (the first four entries of
# image_list), each paired with the model offered in the dropdown.
examples = [
    [image_list[sample_index], "deprem-ml/deprem_satellite_semantic_whu"]
    for sample_index in range(4)
]
title = "Deprem ML - Segformer Semantic Segmentation"

# Build the web UI around Segformer_Segmentation. Per Gradio's docs,
# cache_examples=True precomputes the example outputs and serves the stored
# results when an example is clicked.
demo_app = gr.Interface(
    Segformer_Segmentation,
    inputs,
    outputs,
    examples=examples,
    title=title,
    cache_examples=True,
)

if __name__ == "__main__":
    # launch(enable_queue=...) is deprecated in Gradio 3 and removed in
    # Gradio 4. Enabling the request queue with queue() works on both.
    demo_app.queue().launch(debug=True)