Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from safetensors.torch import load_model, save_model | |
| from models import * | |
| import os | |
class WasteClassifier:
    """Wrap a classification + segmentation model for single-image inference.

    The wrapped model is called on one preprocessed image tensor and must
    return a ``(class_logits, segmentation)`` pair.
    """

    def __init__(self, model, class_names, device):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: torch module that returns ``(class_logits, seg_mask)``.
            class_names: label for each logit index, in model output order.
            device: torch device the input tensor is moved to for inference.
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # Resize to the fixed 384x384 network input, then normalize with the
        # commonly used ImageNet per-channel mean/std.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify *image* and return its segmentation mask at input size.

        Args:
            image: a PIL image, or an array accepted by ``Image.fromarray``.

        Returns:
            dict with keys ``predicted_class`` (str), ``confidence`` (float),
            ``class_probabilities`` (class name -> float) and
            ``segmentation_mask`` (float32 array resized back to the input
            image's width/height, values not thresholded).
        """
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize above expects exactly three channels; grayscale, palette
        # or RGBA inputs would otherwise fail inside the transform.
        image = image.convert("RGB")
        original_size = image.size  # PIL reports (width, height)
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs, seg_mask = self.model(img_tensor)

        probs = F.softmax(outputs, dim=1)[0].cpu().numpy()
        pred_idx = int(np.argmax(probs))

        # First batch item, first channel — assumes the mask is laid out as
        # (batch, channel, H, W); TODO confirm against the model definition.
        mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)
        # Bring the raw mask back to the resolution of the uploaded image.
        mask = np.array(Image.fromarray(mask).resize(original_size, Image.NEAREST))

        return {
            "predicted_class": self.class_names[pred_idx],
            # Plain float, consistent with the per-class probabilities below.
            "confidence": float(probs[pred_idx]),
            "class_probabilities": {
                name: float(p) for name, p in zip(self.class_names, probs)
            },
            "segmentation_mask": mask,
        }
def _render_mask(image_np, mask, threshold=0.25, dim_factor=0.3):
    """Return ``(mask_viz, overlay)`` images for a raw segmentation mask.

    The mask is min-max scaled to [0, 1]. ``mask_viz`` is that scaled mask as
    uint8 grayscale; ``overlay`` is a copy of *image_np* in which every pixel
    whose scaled mask value is not above *threshold* is multiplied by
    *dim_factor*.
    """
    lo, hi = mask.min(), mask.max()
    if hi > lo:
        normalized = (mask - lo) / (hi - lo)
    else:
        # A perfectly flat mask carries no foreground signal (and would divide
        # by zero); render it as all black.
        normalized = np.zeros_like(mask)
    mask_viz = (normalized * 255).astype(np.uint8)

    background = ~(normalized > threshold)
    overlay = image_np.copy()
    overlay[background] = (overlay[background] * dim_factor).astype(np.uint8)
    return mask_viz, overlay


def _format_results(results):
    """Render the prediction dict as the multi-line text shown in the UI."""
    lines = [
        f"Predicted Class: {results['predicted_class']}",
        f"Confidence: {results['confidence']*100:.2f}%",
        "",
        "Class Probabilities:",
    ]
    ranked = sorted(
        results["class_probabilities"].items(), key=lambda kv: kv[1], reverse=True
    )
    lines.extend(f"{name}: {prob*100:.2f}%" for name, prob in ranked)
    return "\n".join(lines) + "\n"


def interface(classifier):
    """Build the Gradio UI around *classifier*.

    Args:
        classifier: object whose ``predict(image)`` returns a dict with
            ``predicted_class``, ``confidence``, ``class_probabilities`` and
            ``segmentation_mask`` keys.

    Returns:
        The (not yet launched) ``gr.Interface``.
    """

    def process_image(image):
        """Classify one upload and produce the text, overlay and mask outputs."""
        # Gradio passes None when the form is submitted without an image;
        # bail out instead of crashing inside the classifier.
        if image is None:
            return ["Please upload an image.", None, None]

        results = classifier.predict(image)
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        mask_viz, overlay = _render_mask(image_np, results["segmentation_mask"])
        return [_format_results(results), overlay, mask_viz]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
Upload an image of waste to classify it into different categories.
The model will predict the type of waste, show confidence scores for each category,
and display the segmented object along with its mask.
""",
        # Only offer bundled examples when they ship alongside the app.
        examples=(
            [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
            if os.path.exists("example1.jpg")
            else None
        ),
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Run inference on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label for each classifier output index; presumably this order must match the
# one used when the weights were trained — confirm before editing.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

best_model = ResNet18UNet(num_classes=len(class_names)).to(device)
# Resolve the weights relative to this file, not the current working directory.
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "ppn62p.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)

# Start the server only when executed as a script (e.g. `python app.py`), so the
# module can be imported — for example by Gradio's reload mode, which looks up
# the module-level `demo` — without blocking on a second launch.
if __name__ == "__main__":
    demo.launch()