Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
import cv2
import torch
import segmentation_models_pytorch as smp

from download import attempt_download_from_hub
from dataloader import *
def unet_prediction(input_path, model_path, output_path="output.png"):
    """Segment one image with the pretrained U-Net and save the mask as an image.

    Args:
        input_path: Path of the image to segment; passed unchanged to the
            project's ``Dataset`` class (presumably provided by ``dataloader``).
        model_path: Model identifier/path given to ``attempt_download_from_hub``;
            the local path it returns is loaded with ``torch.load``.
        output_path: File the predicted mask is written to. Defaults to
            ``"output.png"``, the name the original code always used.

    Returns:
        ``output_path``, the path of the written mask image.

    Raises:
        OSError: If OpenCV reports that the mask could not be written.
    """
    # Use the GPU when present; hard-coding "cuda" crashes on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_path = attempt_download_from_hub(model_path)
    # The checkpoint is a whole pickled model object (it exposes .predict), not
    # a state_dict, so weights_only must be False (torch >= 2.6 defaults to
    # True). map_location keeps the weights on the same device as the input.
    # NOTE(review): unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted repository.
    best_model = torch.load(model_path, map_location=device, weights_only=False)
    best_model = best_model.to(device)

    # NOTE(review): encoder name/weights must match how the model was trained.
    preprocessing_fn = smp.encoders.get_preprocessing_fn('efficientnet-b6', 'imagenet')
    test_dataset = Dataset(
        input_path,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
    )
    image = test_dataset.get()

    # unsqueeze(0) adds a leading batch dimension of size 1.
    x_tensor = torch.from_numpy(image).to(device).unsqueeze(0)
    with torch.no_grad():
        pr_mask = best_model.predict(x_tensor)

    # Round to 0/1 and scale to 0/255 (presumably the model emits values in
    # [0, 1] -- confirm). clip + uint8 reproduces OpenCV's saturating
    # conversion explicitly instead of handing imwrite a float array.
    mask = (pr_mask.squeeze().cpu().numpy().round() * 255).clip(0, 255).astype("uint8")

    if not cv2.imwrite(output_path, mask):
        raise OSError(f"Failed to write predicted mask to {output_path!r}")
    return output_path
 
			

