```python
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image

# Checkpoint paths for the three SAM backbones.
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
}


def get_sam_predictor(model_type='vit_h', device=None, image=None):
    """Load a SAM checkpoint, wrap it in a SamPredictor, and optionally pre-embed an image."""
    if device is None and torch.cuda.is_available():
        device = 'cuda'
    elif device is None:
        device = 'cpu'
    # Build the SAM model for the requested backbone and move it to the target device.
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)
    predictor = SamPredictor(sam)
    if image is not None:
        predictor.set_image(image)
    return predictor


def sam_seg(predictor, input_img, input_points, input_labels):
    """Run point-prompted segmentation and return the input as an RGBA cutout of the best mask."""
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )
    # Keep the highest-scoring of the candidate masks.
    opt_idx = np.argmax(scores)
    mask = masks[opt_idx]
    # Compose an RGBA image: the original RGB channels plus the mask as the alpha channel.
    out_image = np.zeros((input_img.shape[0], input_img.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = input_img
    out_image[:, :, 3] = mask.astype(np.uint8) * 255
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')
```
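
A minimal sketch of how these two helpers might be called together, assuming an RGB image on disk and a single foreground click; the file path and click coordinates below are placeholders, not values from the original code:

```python
import numpy as np
from PIL import Image

# Load an image as an HxWx3 uint8 RGB array ('example.jpg' is a placeholder path).
image = np.array(Image.open('example.jpg').convert('RGB'))

# Build the predictor once; set_image() computes the image embedding up front.
predictor = get_sam_predictor(model_type='vit_b', image=image)

# One positive click (label 1 = foreground) at a hypothetical pixel (x=200, y=150).
input_points = np.array([[200, 150]])
input_labels = np.array([1])

# Returns an RGBA PIL image whose alpha channel is the best-scoring mask.
cutout = sam_seg(predictor, image, input_points, input_labels)
cutout.save('cutout.png')
```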