import random

from einops import rearrange
from diffusers.models import AutoencoderKL
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from flux.sampling import prepare_modified
from flux.util import load_clip, load_t5, load_flow_model
from transport import Sampler, create_transport
from imgproc import to_rgb_if_rgba

def center_crop(image, target_size):
    width, height = image.size
    new_width, new_height = target_size
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height
    return image.crop((left, top, right, bottom))

def resize_with_aspect_ratio(img, resolution, divisible=16, aspect_ratio=None):
    """Resize an image while maintaining its aspect ratio, so that its area is close to
    resolution**2 and both dimensions are divisible by `divisible`.

    Args:
        img: PIL Image or torch.Tensor of shape (C, H, W) or (B, C, H, W)
        resolution: target resolution (the output area is approximately resolution**2)
        divisible: ensure output dimensions are divisible by this number
        aspect_ratio: aspect ratio (width / height) to enforce; if None, the input image's own ratio is used

    Returns:
        Resized image of the same type as the input
    """
    # Check input type and get dimensions
    is_tensor = isinstance(img, torch.Tensor)
    if is_tensor:
        if img.dim() == 3:
            c, h, w = img.shape
            batch_dim = False
        else:
            b, c, h, w = img.shape
            batch_dim = True
    else:
        w, h = img.size

    # Calculate new dimensions
    if aspect_ratio is None:
        aspect_ratio = w / h
    target_area = resolution * resolution
    new_h = int((target_area / aspect_ratio) ** 0.5)
    new_w = int(new_h * aspect_ratio)

    # Round both sides down to multiples of `divisible`
    new_w = max(new_w // divisible, 1) * divisible
    new_h = max(new_h // divisible, 1) * divisible

    # Resize according to the input type
    if is_tensor:
        # Use torch interpolation
        mode = 'bilinear'
        align_corners = False
        if batch_dim:
            return F.interpolate(img, size=(new_h, new_w),
                                 mode=mode, align_corners=align_corners)
        else:
            return F.interpolate(img.unsqueeze(0), size=(new_h, new_w),
                                 mode=mode, align_corners=align_corners).squeeze(0)
    else:
        # Use PIL LANCZOS resampling
        return img.resize((new_w, new_h), Image.LANCZOS)
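
# Example: a 1000x750 PIL image with resolution=384 has aspect_ratio=4/3,
# giving new_h=int(sqrt(147456/(4/3)))=332 and new_w=int(332*4/3)=442 before rounding,
# i.e. 432x320 after rounding both sides down to multiples of 16.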

class VisualClozeModel:
    def __init__(
            self, model_path, model_name="flux-dev-fill-lora", max_length=512, lora_rank=256,
            atol=1e-6, rtol=1e-3, solver='euler', time_shifting_factor=1,
            resolution=384, precision='bf16'):
        self.atol = atol
        self.rtol = rtol
        self.solver = solver
        self.time_shifting_factor = time_shifting_factor
        self.resolution = resolution
        self.precision = precision
        self.max_length = max_length
        self.lora_rank = lora_rank

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]

        # Initialize model
        print("Initializing model...")
        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)

        # Initialize VAE
        print("Initializing VAE...")
        self.ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
        self.ae.requires_grad_(False)

        # Initialize text encoders
        print("Initializing text encoders...")
        self.t5 = load_t5(self.device, max_length=self.max_length)
        self.clip = load_clip(self.device)

        self.model.eval().to(self.device, dtype=self.dtype)

        # Load model weights
        ckpt = torch.load(model_path)
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt

        # Initialize sampler
        transport = create_transport(
            "Linear",
            "velocity",
            do_shift=True,
        )
        self.sampler = Sampler(transport)
        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=30,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # Image transformation
        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda img: to_rgb_if_rgba(img)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.grid_h = None
        self.grid_w = None

    def set_grid_size(self, h, w):
        """Set grid size"""
        self.grid_h = h
        self.grid_w = w
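
    # grid_h x grid_w is the layout of the in-context grid passed to process_images():
    # every row is an in-context example except the last, which is the current query;
    # cells of the query row that are None are the targets to be generated.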

    def upsampling(self, image, target_size, cfg, upsampling_steps, upsampling_noise, generator, content_prompt):
        content_instruction = [
            "The content of the last image in the final row is: ",
            "The last image of the last row depicts: ",
            "In the final row, the last image shows: ",
            "The last image in the bottom row illustrates: ",
            "The content of the bottom-right image is: ",
            "The final image in the last row portrays: ",
            "The last image of the final row displays: ",
            "In the last row, the final image captures: ",
            "The bottom-right corner image presents: ",
            "The content of the last image in the concluding row is: ",
            "In the last row, ",
            "The editing instruction in the last row is: ",
        ]
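        # Strip any known grid-description prefix so that only the bare image
        # description remains as the prompt for upsampling.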
        for c in content_instruction:
            if content_prompt.startswith(c):
                content_prompt = content_prompt.replace(c, '')

        if target_size is None:
            aspect_ratio = 1
            target_area = 1024 * 1024
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)
        if target_size[0] * target_size[1] > 1024 * 1024:
            aspect_ratio = target_size[0] / target_size[1]
            target_area = 1024 * 1024
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=upsampling_steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=False,
            time_shifting_factor=1.0,
            strength=upsampling_noise
        )
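        # The sampler above is rebuilt for SDEdit-style refinement: strength=upsampling_noise,
        # and the starting latent (set below) blends random noise with the encoded
        # low-resolution result instead of being pure noise.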
        image = image.resize(((target_size[0] // 16) * 16, (target_size[1] // 16) * 16))

        processed_image = self.image_transform(image)
        processed_image = processed_image.to(self.device, non_blocking=True)
        blank = torch.zeros_like(processed_image, device=self.device, dtype=self.dtype)
        mask = torch.full((1, 1, processed_image.shape[1], processed_image.shape[2]), fill_value=1, device=self.device, dtype=self.dtype)

        with torch.no_grad():
            latent = self.ae.encode(processed_image[None].to(self.ae.dtype)).latent_dist.sample()
            blank = self.ae.encode(blank[None].to(self.ae.dtype)).latent_dist.sample()
            latent = (latent - self.ae.config.shift_factor) * self.ae.config.scaling_factor
            blank = (blank - self.ae.config.shift_factor) * self.ae.config.scaling_factor

        latent_h, latent_w = latent.shape[2:]
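        # Pack everything into the transformer's token layout: latents are split into
        # 2x2 patches (16*2*2 = 64 channels per token), and the pixel-space mask is folded
        # 8x8 down to the latent grid and then 2x2 into the same tokens (256 channels).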
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8)
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        latent = latent.to(self.dtype)
        blank = blank.to(self.dtype)
        latent = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        blank = rearrange(blank, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        img_cond = torch.cat((blank, mask), dim=-1)

        # Generate noise
        noise = torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=generator).to(self.dtype)
        x = [[noise]]
        inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[content_prompt], proportion_empty_prompts=0.0)
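        # SDEdit: start from a blend of noise and the encoded low-resolution latent;
        # higher upsampling_noise keeps more of the low-resolution content.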
| inp["img"] = inp["img"] * (1 - upsampling_noise) + latent * upsampling_noise | |
| model_kwargs = dict( | |
| txt=inp["txt"], | |
| txt_ids=inp["txt_ids"], | |
| txt_mask=inp["txt_mask"], | |
| y=inp["vec"], | |
| img_ids=inp["img_ids"], | |
| img_mask=inp["img_mask"], | |
| cond=img_cond, | |
| guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype), | |
| ) | |
| sample = self.sample_fn( | |
| inp["img"], self.model.forward, model_kwargs | |
| )[-1] | |
| sample = sample[:1] | |
| sample = rearrange(sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2) | |
| sample = self.ae.decode(sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0] | |
| sample = (sample + 1.0) / 2.0 | |
| sample.clamp_(0.0, 1.0) | |
| sample = sample[0] | |
| output_image = to_pil_image(sample.float()) | |
| return output_image | |

    def process_images(
            self,
            images: list[list[Image.Image]],
            prompts: list[str],
            seed: int = 0,
            cfg: int = 30,
            steps: int = 30,
            upsampling_steps: int = 10,
            upsampling_noise: float = 0.4,
            is_upsampling: bool = True):
        """
        Process a grid of images based on the provided text prompts and settings,
        with an optional SDEdit upsampling pass to improve resolution and detail.

        Parameters:
            images (list[list[Image.Image]]): A grid-layout image collection. Each row is an
                in-context example or the current query, and the current query must be the last row.
                Target images may be None; all other entries must be PIL images (Image.Image).
            prompts (list[str]): Three prompts: the layout prompt, the task prompt, and the content prompt.
            seed (int): Random seed for reproducibility; 0 means a random seed is drawn.
            cfg (int): The classifier-free guidance scale.
            steps (int): The number of sampling steps.
            upsampling_steps (int): The number of denoising steps used when upsampling.
            upsampling_noise (float): SDEdit strength used when upsampling. The generation starts from
                a partially noised version of the low-resolution result; the higher this value, the less
                noise is added and the more of that result is preserved. A value of 1 adds no noise.
            is_upsampling (bool, optional): Whether to apply SDEdit upsampling to the generated images.

        Returns:
            The processed images, with upsampling applied when `is_upsampling` is set.
        """
        if seed == 0:
            seed = random.randint(0, 2 ** 32 - 1)

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # Use class grid size
        grid_h, grid_w = self.grid_h, self.grid_w

        # Ensure all images are RGB mode or None
        for i in range(grid_h):
            images[i] = [img.convert("RGB") if img is not None else None for img in images[i]]

        # Adjust all image sizes
        resolution = self.resolution
        processed_images = []
        mask_position = []
        target_size = None
        upsampling_size = None
        for i in range(grid_h):
            # Find the size of the first non-empty image in this row
            reference_size = None
            for j in range(grid_w):
                if images[i][j] is not None:
                    if i == grid_h - 1 and upsampling_size is None:
                        upsampling_size = images[i][j].size
                    resized = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    reference_size = resized.size
                    if i == grid_h - 1 and target_size is None:
                        target_size = reference_size
                    break
            # Process all images in this row
            for j in range(grid_w):
                if images[i][j] is not None:
                    target = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    if target.width <= target.height:
                        target = target.resize((reference_size[0], int(reference_size[0] / target.width * target.height)))
                        target = center_crop(target, reference_size)
                    elif target.width > target.height:
                        target = target.resize((int(reference_size[1] / target.height * target.width), reference_size[1]))
                        target = center_crop(target, reference_size)
                    processed_images.append(target)
                    if i == grid_h - 1:
                        mask_position.append(0)
                else:
                    # If this row has a reference size, use it; otherwise use the default size
                    if reference_size:
                        blank = Image.new('RGB', reference_size, (0, 0, 0))
                    else:
                        blank = Image.new('RGB', (resolution, resolution), (0, 0, 0))
                    processed_images.append(blank)
                    if i == grid_h - 1:
                        mask_position.append(1)
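        # When the query row has more than one masked (None) cell, rescale all grid
        # images to a common width, with both sides rounded down to multiples of 16.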
        if len(mask_position) > 1 and sum(mask_position) > 1:
            if target_size is None:
                new_w = 384
            else:
                new_w = target_size[0]
            for i in range(len(processed_images)):
                if processed_images[i] is not None:
                    new_h = int(processed_images[i].height * (new_w / processed_images[i].width))
                    new_w = int(new_w / 16) * 16
                    new_h = int(new_h / 16) * 16
                    processed_images[i] = processed_images[i].resize((new_w, new_h))

        # Build grid image and mask
        with torch.autocast("cuda", self.dtype):
            grid_image = []
            fill_mask = []
            for i in range(grid_h):
                row_images = [self.image_transform(img) for img in processed_images[i * grid_w: (i + 1) * grid_w]]
                if i == grid_h - 1:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=m, device=self.device) for m in mask_position]
                else:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=0, device=self.device) for _ in mask_position]
                grid_image.append(torch.cat(row_images, dim=2).to(self.device, non_blocking=True))
                fill_mask.append(torch.cat(row_masks, dim=3))

            # Encode condition image
            with torch.no_grad():
                fill_cond = [self.ae.encode(img[None].to(self.ae.dtype)).latent_dist.sample()[0] for img in grid_image]
                fill_cond = [(img - self.ae.config.shift_factor) * self.ae.config.scaling_factor for img in fill_cond]

            # Rearrange mask
            fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) for mask in fill_mask]
            fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for mask in fill_mask]

            fill_cond = [img.to(self.dtype) for img in fill_cond]
            fill_cond = [rearrange(img.unsqueeze(0), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for img in fill_cond]

            fill_cond = torch.cat(fill_cond, dim=1)
            fill_mask = torch.cat(fill_mask, dim=1)
            img_cond = torch.cat((fill_cond, fill_mask), dim=-1)
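            # Per-token conditioning: 64 packed latent channels of the grid image
            # (masked cells are black placeholders) plus 256 packed mask channels.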
            # Generate sample
            noise = []
            sliced_subimage = []
            rng = torch.Generator(device=self.device).manual_seed(int(seed))
            for sub_img in grid_image:
                h, w = sub_img.shape[-2:]
                sliced_subimage.append((h, w))
                latent_w, latent_h = w // 8, h // 8
                noise.append(torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=rng).to(self.dtype))
            x = [noise]

            with torch.no_grad():
                inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[' '.join(prompts)], proportion_empty_prompts=0.0)

                model_kwargs = dict(
                    txt=inp["txt"],
                    txt_ids=inp["txt_ids"],
                    txt_mask=inp["txt_mask"],
                    y=inp["vec"],
                    img_ids=inp["img_ids"],
                    img_mask=inp["img_mask"],
                    cond=img_cond,
                    guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
                )

                samples = self.sample_fn(
                    inp["img"], self.model.forward, model_kwargs
                )[-1]

            # Get query row
            with torch.no_grad():
                samples = samples[:1]
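                # The token sequence concatenates all grid rows: (h/16)*(w/16) tokens
                # per row. Split it back per row, unpack the 2x2 patches, and decode
                # each row's latent separately with the VAE.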
                row_samples = []
                start = 0
                for size in sliced_subimage:
                    end = start + (size[0] * size[1] // 256)
                    latent_h = size[0] // 8
                    latent_w = size[1] // 8
                    row_sample = samples[:, start:end, :]
                    row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2)
                    row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
                    row_sample = (row_sample + 1.0) / 2.0
                    row_sample.clamp_(0.0, 1.0)
                    row_samples.append(row_sample[0])
                    start = end

        # Convert all samples to PIL images
        output_images = []
        for row_sample in row_samples:
            output_image = to_pil_image(row_sample.float())
            output_images.append(output_image)

        torch.cuda.empty_cache()

        # Crop out every cell of the query (last) row; the masked cell is additionally
        # upsampled when requested
        ret = []
        ret_w = output_images[-1].width
        ret_h = output_images[-1].height
        row_start = (grid_h - 1) * grid_w
        row_end = grid_h * grid_w
        for i in range(row_start, row_end):
            cropped = output_images[-1].crop(((i - row_start) * ret_w // self.grid_w, 0, ((i - row_start) + 1) * ret_w // self.grid_w, ret_h))
            ret.append(cropped)
            if mask_position[i - row_start] and is_upsampling:
                upsampled = self.upsampling(
                    cropped,
                    upsampling_size,
                    cfg,
                    upsampling_steps=upsampling_steps,
                    upsampling_noise=upsampling_noise,
                    generator=rng,
                    content_prompt=prompts[2],
                )
                ret.append(upsampled)
        return ret
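

# Illustrative usage sketch (not part of the original file): the checkpoint path, image
# files, and prompts below are placeholders. A grid with 2 rows and 3 columns is assumed,
# where the first row is an in-context example and the last row is the query whose final
# cell (None) is the image to generate.
if __name__ == "__main__":
    model = VisualClozeModel(model_path="path/to/visualcloze_checkpoint.pth")  # placeholder path
    model.set_grid_size(2, 3)

    grid = [
        [Image.open("example_cond_a.png"), Image.open("example_cond_b.png"), Image.open("example_target.png")],
        [Image.open("query_cond_a.png"), Image.open("query_cond_b.png"), None],  # None marks the cell to generate
    ]
    prompts = [
        "A grid layout with 2 rows and 3 columns, ...",             # layout prompt (placeholder)
        "Each row shows ...",                                       # task prompt (placeholder)
        "The content of the last image in the final row is: ...",   # content prompt (placeholder)
    ]

    outputs = model.process_images(grid, prompts, seed=0, cfg=30, steps=30, is_upsampling=True)
    outputs[-1].save("generated.png")  # the last returned image is the upsampled result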