# demo-ml-v3 / app.py
# Last commit: a48be64 by spuun ("fix: thresh")
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from safetensors.torch import load_model, save_model
from models import *
import os
class WasteClassifier:
    """Run a joint classification + segmentation model on single images.

    Args:
        model: Module whose forward pass returns a ``(class_logits, seg_mask)``
            tuple; see ``predict`` for how each output is consumed.
        class_names: Labels in the same order as the model's class logits.
        device: Device the preprocessed input tensor is moved to.
    """

    def __init__(self, model, class_names, device):
        self.model = model
        self.class_names = class_names
        self.device = device
        # Resize to 384x384, convert to a tensor, then normalize each channel
        # with the commonly used ImageNet mean/std values.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify ``image`` and return the model's segmentation mask for it.

        Args:
            image: A PIL image, or an array accepted by ``Image.fromarray``.
                It is converted to RGB first, so grayscale, palette and RGBA
                inputs work with the 3-channel normalization.

        Returns:
            dict with keys:
                ``"predicted_class"``: label with the highest probability.
                ``"confidence"``: that probability, as a Python float.
                ``"class_probabilities"``: every label mapped to its softmax
                    probability (Python floats).
                ``"segmentation_mask"``: float32 array taken from the model's
                    mask output at index ``[0, 0]`` (first image, first
                    channel). It is resized with nearest-neighbour sampling
                    back to the input's original size, so its shape is
                    ``(height, width)``. Values are not thresholded.
        """
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize() expects exactly three channels. 1-channel or 4-channel
        # inputs would otherwise fail inside the transform.
        if image.mode != "RGB":
            image = image.convert("RGB")
        original_size = image.size  # PIL reports (width, height)

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The model yields classification logits and a segmentation mask.
            outputs, seg_mask = self.model(img_tensor)

        probs = F.softmax(outputs, dim=1)[0].cpu().numpy()
        best = int(np.argmax(probs))

        # First image in the batch, first mask channel. Thresholding is left
        # to the caller.
        mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)
        mask = Image.fromarray(mask).resize(original_size, Image.NEAREST)

        return {
            "predicted_class": self.class_names[best],
            "confidence": float(probs[best]),
            "class_probabilities": {
                name: float(prob) for name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": np.array(mask),
        }
def _normalize_mask(mask):
    """Min-max scale ``mask`` to [0, 1]; a perfectly flat mask becomes all zeros."""
    lo, hi = mask.min(), mask.max()
    if hi > lo:
        return (mask - lo) / (hi - lo)
    # A constant mask has no foreground signal; also avoids dividing by zero.
    return np.zeros_like(mask)


def _format_results(results):
    """Render the prediction, its confidence and all class probabilities as text.

    Class probabilities are listed from most to least likely.
    """
    lines = [
        f"Predicted Class: {results['predicted_class']}",
        f"Confidence: {results['confidence']*100:.2f}%",
        "",
        "Class Probabilities:",
    ]
    ranked = sorted(
        results["class_probabilities"].items(), key=lambda kv: kv[1], reverse=True
    )
    lines.extend(f"{name}: {prob*100:.2f}%" for name, prob in ranked)
    return "\n".join(lines) + "\n"


def interface(classifier, mask_threshold=0.25, background_dim=0.3):
    """Build the Gradio demo around ``classifier``.

    Args:
        classifier: Object whose ``predict(image)`` returns the results dict
            produced by ``WasteClassifier.predict``.
        mask_threshold: Pixels whose min-max-normalized mask value is at or
            below this are treated as background in the overlay.
        background_dim: Factor that background pixels are multiplied by in
            the overlay (0 = black, 1 = unchanged).

    Returns:
        The configured ``gr.Interface``. It is not launched here.
    """

    def process_image(image):
        # Gradio can invoke this with None when no image has been uploaded.
        if image is None:
            raise gr.Error("Please upload an image first.")

        results = classifier.predict(image)
        image_np = np.array(image) if isinstance(image, Image.Image) else image

        # The threshold is relative to this image's own mask range, not to
        # the model's raw output scale.
        normalized_mask = _normalize_mask(results["segmentation_mask"])
        mask_viz = (normalized_mask * 255).astype(np.uint8)

        # Darken everything outside the foreground region.
        background = ~(normalized_mask > mask_threshold)
        overlay = image_np.copy()
        overlay[background] = (overlay[background] * background_dim).astype(np.uint8)

        return [_format_results(results), overlay, mask_viz]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
Upload an image of waste to classify it into different categories.
The model will predict the type of waste, show confidence scores for each category,
and display the segmented object along with its mask.
""",
        # Example images are offered only if they ship alongside the app.
        examples=(
            [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
            if os.path.exists("example1.jpg")
            else None
        ),
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label order must match the order of the model's class logits, because
# WasteClassifier.predict zips these names with the softmax probabilities.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

best_model = ResNet18UNet(num_classes=len(class_names)).to(device)
# The weights file is resolved relative to this script, not the current
# working directory.
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "ppn62p.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)

# Start the server only when run as a script. Importing this module still
# builds `demo` but does not block on launch().
if __name__ == "__main__":
    demo.launch()