Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import random | |
| from collections import OrderedDict | |
| import torch, numpy as np | |
| from PIL import Image | |
| from scepter.modules.model.registry import MODELS | |
| from scepter.modules.utils.config import Config | |
| from scepter.modules.utils.distribute import we | |
| from .registry import BaseInference, INFERENCES | |
| from .utils import ACEPlusImageProcessor | |
class ACEInference(BaseInference):
    """Inference wrapper around the registered ACE model (reuses the ldm code).

    The instance owns the model pipeline built from ``cfg.MODEL`` and an
    ``ACEPlusImageProcessor`` used for pre- and post-processing. Calling the
    instance runs one edit/generation and returns the result together with
    visualisations of the tensors that were actually fed to the model.
    """

    def __init__(self, cfg, logger=None):
        """Build the pipeline, the image processor and the default sample args.

        Args:
            cfg: Config providing ``MODEL``, ``MAX_SEQ_LEN`` and
                ``SAMPLE_ARGS``, plus an optional ``DTYPE`` naming a
                ``torch`` dtype attribute (default ``"bfloat16"``).
            logger: Optional logger passed on to the base class.
        """
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        # Mapping-like sample args contribute their 'DEFAULT' entry (or None);
        # plain values are stored as-is. Keys are lower-cased.
        self.input = {
            k.lower(): dict(v).get('DEFAULT', None)
            if isinstance(v, (dict, OrderedDict, Config)) else v
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @staticmethod
    def _tensor_to_pil(tensor):
        """Convert a (C, H, W) tensor to an 8-bit PIL image.

        Values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1] before
        scaling to 0..255 (presumably the input is normalised to [-1, 1] --
        confirm against ``ACEPlusImageProcessor.preprocess``).
        """
        scaled = torch.clamp(tensor / 2 + 0.5, min=0.0, max=1.0) * 255
        return Image.fromarray(
            scaled.float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))

    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        """Run one inference pass and return the result with its inputs.

        Args:
            reference_image, edit_image, edit_mask: Raw inputs handed to
                ``image_processor.preprocess`` unchanged.
            prompt: A single prompt string or a list of prompts.
            output_height, output_width: Requested output size, forwarded to
                preprocessing.
            sampler: Sampler name forwarded to the pipeline.
            sample_steps: Number of sampling steps.
            guide_scale: Guidance scale forwarded to the pipeline.
            seed: Random seed; ``None`` or a negative value draws a random
                seed in ``[0, 2**24 - 1]``.
            repainting_scale, use_change, keep_pixels, keep_pixels_rate:
                Forwarded to preprocessing.
            edit_type, lora_path, **kwargs: Accepted for interface
                compatibility; not used by this method.

        Returns:
            A 5-tuple ``(output, edit_image, change_image, mask, seed)``:
            the post-processed first generated image, PIL renderings of the
            preprocessed edit image, change image (``None`` when
            preprocessing produced none) and mask, and the seed actually used.
        """
        # Convert the input info to the input of ldm.
        if isinstance(prompt, str):
            prompt = [prompt]
        if seed is None or seed < 0:
            seed = random.randint(0, 2 ** 24 - 1)

        (image, mask, change_image, _content_image,
         out_h, out_w, slice_w) = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)

        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

        # The pipeline takes one extra level of list nesting for these inputs.
        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                # Honour the caller's choice; this used to be hard-coded.
                sampler=sampler,
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )

        imgs = [
            x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
            for x_i in out_image
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]

        edit_image = self._tensor_to_pil(image[0])
        # Preprocessing may yield no change image; report None instead of
        # failing on arithmetic with None.
        if change_image[0] is None:
            change_image = None
        else:
            change_image = self._tensor_to_pil(change_image[0])
        mask = Image.fromarray(
            (mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))

        return (self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h),
                edit_image, change_image, mask, seed)