Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation | |
| from diffusers import StableDiffusionInpaintPipeline # , DiffusionPipeline | |
# Run the inpainting pipeline on the GPU when one is visible, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Image segmentation: processor and model both come from the same Mask2Former
# checkpoint; segment_image() applies its panoptic post-processing.
# NOTE: the model is left on its default device (it is not moved to `device`),
# and segment_image() does not move its inputs either.
_SEG_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"
seg_processor = AutoImageProcessor.from_pretrained(_SEG_CHECKPOINT)
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(_SEG_CHECKPOINT)
def segment_image(image):
    """Run Mask2Former panoptic segmentation on ``image``.

    Args:
        image: input image handed to ``seg_processor``. Its ``.size`` is
            reversed to build the ``target_sizes`` entry, which yields
            ``(height, width)`` for a PIL image -- TODO confirm for other
            input types.

    Returns:
        ``(seg_prediction, segment_labels)`` where ``seg_prediction`` is the
        first (and only) dict returned by
        ``seg_processor.post_process_panoptic_segmentation`` and
        ``segment_labels`` maps each ``'id'`` found in
        ``seg_prediction['segments_info']`` to its class name from
        ``seg_model.config.id2label``.
    """
    model_inputs = seg_processor(image, return_tensors="pt")
    # Inference only -- no gradients needed.
    with torch.no_grad():
        outputs = seg_model(**model_inputs)

    # Post-processing is batched; a single image was passed, so take item 0.
    height_width = image.size[::-1]
    prediction = seg_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[height_width]
    )[0]

    id2label = seg_model.config.id2label
    labels_by_segment = {
        info['id']: id2label[info['label_id']]
        for info in prediction['segments_info']
    }
    return prediction, labels_by_segment
# Image inpainting pipeline
# Load the Stable Diffusion inpainting model once at import time.
# - On GPU, use float16 weights; on CPU, keep the bfloat16 weights the original
#   code chose. NOTE(review): if CPU inference hits an op that lacks bfloat16
#   support, switch the CPU dtype to torch.float32.
# - Place the pipeline with .to(device) rather than device_map="auto".
#   Whole-pipeline device_map="auto" is rejected by recent diffusers releases
#   (only "balanced" is accepted there) and also needs `accelerate`.
# NOTE(review): the "runwayml/..." Hub repo may no longer be available; if
# loading fails, point this at a maintained mirror of the same checkpoint.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16 if device == 'cuda' else torch.bfloat16,
).to(device)
def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
    """Inpaint the masked region of ``image`` with the Stable Diffusion pipeline.

    Inputs:
        image - input image (PIL image or torch tensor)
        mask - inpainting mask, same size as ``image`` and of the same type
            (PIL image or torch tensor)
        W - output width in pixels (passed to the pipeline as ``width``)
        H - output height in pixels (passed to the pipeline as ``height``)
            NOTE(review): Stable Diffusion presumably needs W and H to be
            multiples of 8 -- confirm against the pipeline's input checks.
        prompt - text prompt guiding what is painted into the masked region
        seed - seed for the ``torch.Generator`` handed to the pipeline
        guidance_scale - guidance scale forwarded to the pipeline
        num_samples - number of images generated for the prompt
            (``num_images_per_prompt``)
    Outputs:
        images - the ``.images`` attribute of the pipeline output
    """
    # Seeded generator on the module-level `device`.
    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,  # ensure mask is the same type as image
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    )
    return result.images