import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from safetensors.torch import load_model, save_model

from models import *


class WasteClassifier:
    """Wraps a joint classification + segmentation model for single-image inference."""

    def __init__(self, model, class_names, device):
        self.model = model
        self.class_names = class_names
        self.device = device
        # Preprocessing: 384x384 input with ImageNet normalization statistics,
        # matching the transforms used during training.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        self.model.eval()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        original_size = image.size

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The model returns both classification logits and a segmentation mask.
            outputs, seg_mask = self.model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)

        probs = probabilities[0].cpu().numpy()
        pred_class = self.class_names[np.argmax(probs)]
        confidence = float(np.max(probs))

        # Take the first image in the batch, first channel of the mask.
        seg_mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)
        # seg_mask = (seg_mask >= 0.2).astype(np.float32)  # Optional hard threshold at 0.2

        # Resize the mask back to the original image size.
        seg_mask = Image.fromarray(seg_mask)
        seg_mask = seg_mask.resize(original_size, Image.NEAREST)
        seg_mask = np.array(seg_mask)

        results = {
            "predicted_class": pred_class,
            "confidence": confidence,
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": seg_mask,
        }
        return results
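
# Example usage (a minimal sketch, not part of the app): once a model and class list
# are available, a single image can be classified directly, e.g.
#
#     classifier = WasteClassifier(best_model, class_names, device)
#     results = classifier.predict(Image.open("some_image.jpg"))
#     print(results["predicted_class"], results["confidence"])
#
# "some_image.jpg" is a placeholder path, not a file shipped with this Space.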


def interface(classifier):
    def process_image(image):
        results = classifier.predict(image)

        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        mask = results["segmentation_mask"]

        # Normalize the mask to the 0-1 range based on its min/max values.
        if mask.max() > mask.min():  # Avoid division by zero if the mask is flat.
            normalized_mask = (mask - mask.min()) / (mask.max() - mask.min())
        else:
            normalized_mask = np.zeros_like(mask)  # A flat mask renders as all black.

        # Use the normalized mask for visualization.
        mask_viz = (normalized_mask * 255).astype(np.uint8)

        # Darken everything outside the thresholded mask to highlight the object.
        binary_for_overlay = normalized_mask > 0.25
        overlay = image_np.copy()
        overlay[~binary_for_overlay] = (overlay[~binary_for_overlay] * 0.3).astype(np.uint8)

        output_str = f"Predicted Class: {results['predicted_class']}\n"
        output_str += f"Confidence: {results['confidence'] * 100:.2f}%\n\n"
        output_str += "Class Probabilities:\n"
        sorted_probs = sorted(
            results["class_probabilities"].items(), key=lambda x: x[1], reverse=True
        )
        for class_name, prob in sorted_probs:
            output_str += f"{class_name}: {prob * 100:.2f}%\n"

        return [output_str, overlay, mask_viz]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
        Upload an image of waste to classify it into one of several categories.
        The model predicts the type of waste, shows confidence scores for each category,
        and displays the segmented object along with its mask.
        """,
        examples=(
            [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
            if os.path.exists("example1.jpg")
            else None
        ),
        analytics_enabled=False,
        theme="default",
    )
    return demo


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

best_model = ResNet18UNet(num_classes=len(class_names))
best_model = best_model.to(device)
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "ppn62p.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)
demo.launch()
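
# Note: on Hugging Face Spaces the default demo.launch() is typically sufficient.
# For local testing, Gradio's launch() also accepts options such as share=True
# (temporary public URL) or an explicit host/port; a sketch, kept commented out
# because the defaults are fine here:
#
#     demo.launch(server_name="0.0.0.0", server_port=7860)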