import os
from PIL import Image, ImageOps
import random
import spaces
import cv2
from diffusers import StableVideoDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import XFormersAttnProcessor
from diffusers.utils import export_to_gif
import gradio as gr
import numpy as np
from safetensors import safe_open
from segment_anything import build_sam, SamPredictor
from tqdm import tqdm
import torch

from svd import (
    UNetDragSpatioTemporalConditionModel,
    AllToFirstXFormersAttnProcessor,
)
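
# Note: `spaces` is the Hugging Face ZeroGPU helper package (this demo runs on a ZeroGPU
# Space) and is not referenced elsewhere in this file; `svd` is the local module bundled
# with this Space that provides the drag-conditioned UNet and the all-to-first spatial
# attention processor used below.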

TITLE = '''Puppet-Master: Scaling Interactive Video Generation as a Motion Prior for Part-Level Dynamics'''

DESCRIPTION = """
<div>
Try <a href='https://vgg-puppetmaster.github.io/'><b>Puppet-Master</b></a> yourself to animate your favorite objects in seconds!
</div>
<div>
Please give us a ⭐ on <a href='https://github.com/RuiningLi/puppet-master'>GitHub</a> if you like our work!
</div>
"""

INSTRUCTION = '''
2 steps to get started:
- Upload an image of a dynamic object.
- Add one or more drags on the object to specify the desired part-level interactions.

How to add drags:
- To add a drag, first click its starting point and then its ending point on the Input Image (leftmost).
- You can add up to 5 drags.
- After every click, the drags are visualized on the Image with Drags (second from left).
- If the last drag is incomplete (you specified the starting point but not the ending point), it is simply ignored.
- To retry, click the [x] button in the top-right corner of the input image and start over, even if you just want to try a different set of drags.
- Have fun dragging!

Then, you will be prompted to verify the object segmentation. Once you confirm that the segmentation is decent, the animation will be generated in seconds!

Tips:
- We found that a classifier-free guidance weight of ~5.0 works best.
- Try changing the random seed to get different results.
'''

PREPROCESS_INSTRUCTION = '''
Segmentation is needed if it is not already provided through an alpha channel in the input image.
You don't need to tick this box if you have chosen one of the example images.
If you have uploaded an image of your own, you will most likely need to tick this box.
You should verify that the preprocessed image is object-centric (i.e., it clearly contains a single object) and has a white background.
'''

def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
    batch_size = video.shape[0]
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)
    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
    return outputs
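
# center_and_square_image (below) recenters the segmented object: it takes the bounding box
# of the non-transparent pixels, pads the image with white, crops a square whose side is
# `scale_factor` times the larger box dimension centered on the object, resizes the crop to
# 256x256, and converts each drag point from pixel coordinates to coordinates normalized to
# [0, 1] within the new crop.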

def center_and_square_image(pil_image_rgba, drags, scale_factor):
    image = pil_image_rgba
    alpha = np.array(image)[:, :, 3]  # Extract the alpha channel
    foreground_coords = np.argwhere(alpha > 0)
    y_min, x_min = foreground_coords.min(axis=0)
    y_max, x_max = foreground_coords.max(axis=0)
    cy, cx = (y_min + y_max) // 2, (x_min + x_max) // 2
    crop_height, crop_width = y_max - y_min + 1, x_max - x_min + 1
    side_length = int(max(crop_height, crop_width) * scale_factor)
    padded_image = ImageOps.expand(
        image,
        (side_length // 2, side_length // 2, side_length // 2, side_length // 2),
        fill=(255, 255, 255, 255)
    )
    left, top = cx, cy

    new_drags = []
    for d in drags:
        x, y = d
        new_x, new_y = (x + side_length // 2 - cx) / side_length, (y + side_length // 2 - cy) / side_length
        new_drags.append((new_x, new_y))

    # Crop or pad the image as needed to make it centered around (cx, cy)
    image = padded_image.crop((left, top, left + side_length, top + side_length))
    # Resize the image to 256x256
    image = image.resize((256, 256), Image.Resampling.LANCZOS)
    return image, new_drags
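
# Both checkpoints are expected under ./ckpts next to this file: the SAM ViT-H weights
# (sam_vit_h_4b8939.pth) for interactive segmentation, and the drag-conditioned UNet
# weights (model.safetensors).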

def sam_init():
    sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "sam_vit_h_4b8939.pth")
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to("cuda"))
    return predictor

def model_init():
    model_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "model.safetensors")
    state_dict = {}
    with safe_open(model_checkpoint, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)

    model = UNetDragSpatioTemporalConditionModel(num_drags=5)
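    # Spatial self-attention (attn1) layers get AllToFirstXFormersAttnProcessor (judging by
    # its name, each frame attends to the first / reference frame's keys and values), while
    # cross-attention (attn2) and all temporal attention layers keep the standard
    # XFormersAttnProcessor.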
    attn_processors_dict = {
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.0.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.0.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.0.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.0.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.1.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "down_blocks.2.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.2.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.1.attentions.2.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.2.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.2.attentions.2.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.1.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.1.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.2.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "up_blocks.3.attentions.2.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "mid_block.attentions.0.transformer_blocks.0.attn1.processor": AllToFirstXFormersAttnProcessor(),
        "mid_block.attentions.0.transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
        "mid_block.attentions.0.temporal_transformer_blocks.0.attn1.processor": XFormersAttnProcessor(),
        "mid_block.attentions.0.temporal_transformer_blocks.0.attn2.processor": XFormersAttnProcessor(),
    }
    model.set_attn_processor(attn_processors_dict)
    model.load_state_dict(state_dict, strict=True)
    return model.to("cuda")
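
# Model and pipeline setup: SAM, the drag-conditioned UNet, and the SVD pipeline (whose
# VAE, image encoder, video processor, and scheduler are reused below) are all loaded
# once at import time and kept on the GPU.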
sam_predictor = sam_init()
model = model_init()
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    torch_dtype=torch.float16, variant="fp16", map_location="cpu"
)
pipe = pipe.to("cuda")
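
# sam_segment prompts SAM with the drag starting points as positive point prompts and
# uses the last of the multimask outputs as the alpha channel before recentering the
# object with center_and_square_image.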

def sam_segment(input_image, drags, foreground_points=None, scale_factor=2.2):
    image = np.asarray(input_image)
    sam_predictor.set_image(image)
    with torch.no_grad():
        masks_bbox, _, _ = sam_predictor.predict(
            point_coords=foreground_points if foreground_points is not None else None,
            point_labels=np.ones(len(foreground_points)) if foreground_points is not None else None,
            multimask_output=True
        )
    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    out_image[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    out_image, new_drags = center_and_square_image(Image.fromarray(out_image, mode="RGBA"), drags, scale_factor)
    return out_image, new_drags
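
# get_point is the click handler for the input image: even-numbered clicks mark a drag's
# starting point and odd-numbered clicks its ending point; each completed pair is drawn
# as an arrow on the preview image.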

def get_point(img, sel_pix, evt: gr.SelectData):
    sel_pix.append(evt.index)
    points = []
    img = np.array(img)
    height = img.shape[0]
    arrow_width_large = 7 * height // 256
    arrow_width_small = 3 * height // 256
    circle_size = 5 * height // 256
    with_alpha = img.shape[2] == 4
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 1:
            cv2.circle(img, tuple(point), circle_size, (0, 0, 255, 255) if with_alpha else (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), circle_size, (255, 0, 0, 255) if with_alpha else (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            # Draw a black arrow with a thinner yellow core for each completed drag.
            cv2.arrowedLine(img, points[0], points[1], (0, 0, 0, 255) if with_alpha else (0, 0, 0), arrow_width_large)
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 0, 255) if with_alpha else (255, 255, 0), arrow_width_small)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

def clear_drag():
    return []

def preprocess_image(img, chk_group, drags):
    if img is None:
        gr.Warning("No image is specified. Please specify an image before preprocessing.")
        return None, drags

    if drags is None or len(drags) == 0:
        foreground_points = None
    else:
        foreground_points = np.array([drags[i] for i in range(0, len(drags), 2)])
    if len(drags) == 0:
        gr.Warning("No drags are specified. We recommend specifying the drags before preprocessing.")

    new_drags = drags
    if "Preprocess with Segmentation" in chk_group:
        img_np = np.array(img)
        rgb_img = img_np[..., :3]
        img, new_drags = sam_segment(
            rgb_img,
            drags,
            foreground_points=foreground_points,
        )
    else:
        new_drags = [(d[0] / img.width, d[1] / img.height) for d in drags]

    # Composite the RGBA image onto a white background using its alpha channel.
    img = np.array(img).astype(np.float32)
    processed_img = img[..., :3] * img[..., 3:] / 255. + 255. * (1 - img[..., 3:] / 255.)
    image_pil = Image.fromarray(processed_img.astype(np.uint8), mode="RGB")
    processed_img = image_pil.resize((256, 256), Image.LANCZOS)
    return processed_img, new_drags
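
# sample_from_noise runs the denoising loop manually instead of calling the pipeline: it
# applies classifier-free guidance with a weight that increases linearly from min_guidance
# to max_guidance across the 14 frames (SVD's frame-wise guidance scheme), using zeroed
# image latents, image embeddings, and drags as the unconditional branch.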

def sample_from_noise(model, scheduler, cond_latent, cond_embedding, drags,
                      min_guidance=1.0, max_guidance=3.0, num_inference_steps=50):
    model.eval()
    scheduler.set_timesteps(num_inference_steps, device=cond_latent.device)
    timesteps = scheduler.timesteps.to(cond_latent.device)
    do_classifier_free_guidance = max_guidance > 1.0
    latents = torch.randn((1, 14, 4, 32, 32)).to(cond_latent) * scheduler.init_noise_sigma
    guidance_scale = torch.linspace(min_guidance, max_guidance, 14).unsqueeze(0).to(cond_latent)[..., None, None, None]
    for i, t in tqdm(enumerate(timesteps)):
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = model(
                latent_model_input,
                t,
                image_latents=torch.cat([cond_latent, torch.zeros_like(cond_latent)]) if do_classifier_free_guidance else cond_latent,
                encoder_hidden_states=torch.cat([cond_embedding, torch.zeros_like(cond_embedding)]) if do_classifier_free_guidance else cond_embedding,
                added_time_ids=None,  # dummy
                drags=torch.cat([drags, torch.zeros_like(drags)]) if do_classifier_free_guidance else drags,
            )
        if do_classifier_free_guidance:
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents
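
# Drag encoding used by generate_image: a (14, 5, 4) tensor, i.e. 14 frames and up to 5
# drags, where [..., :2] holds a drag's normalized starting point (repeated on every frame)
# and [..., 2:] holds its position at that frame, linearly interpolated from the start
# point (frame 0) to the end point (frame 13). Unused drag slots stay all-zero.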

def generate_image(img_cond, seed, cfg_scale, drags_list):
    if img_cond is None:
        gr.Warning("Please preprocess the image first.")
        return None

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    img_cond_pil = Image.fromarray(img_cond)
    img_cond_preprocessed = pipe.video_processor.preprocess(img_cond_pil, height=256, width=256)
    img_cond_preprocessed = img_cond_preprocessed.to(device="cuda", dtype=torch.float16)
    latent_dist = pipe.vae.encode(img_cond_preprocessed).latent_dist
    embeddings = pipe._encode_image(img_cond_pil, device="cuda", num_videos_per_prompt=1, do_classifier_free_guidance=False)

    drags = torch.zeros(14, 5, 4)
    for i in range(0, len(drags_list), 2):
        start_point, end_point = drags_list[i:i+2]
        drag_idx = i // 2
        drags[:, drag_idx, :2] = torch.Tensor(start_point)
        drags[0, drag_idx, 2:] = torch.Tensor(start_point)
        drags[-1, drag_idx, 2:] = torch.Tensor(end_point)
        if drag_idx == 4:
            break
    frame_indices = torch.arange(1, 13).unsqueeze(-1).unsqueeze(-1)
    t = frame_indices.float() / 13.0  # Normalize time to [0, 1]
    drags[1:-1, :, 2:] = drags[0, :, 2:] * (1 - t) + drags[-1, :, 2:] * t
    drags = drags[None].to(device="cuda")

    batch = dict(
        drags=drags,
        cond_embedding=embeddings.to(dtype=torch.float32),
        cond_latent=latent_dist.mean.to(dtype=torch.float32),
    )
    with torch.no_grad():
        latents = sample_from_noise(
            model,
            pipe.scheduler,
            **batch,
            max_guidance=cfg_scale,
            num_inference_steps=50,
        )
        frames = pipe.vae.decode(latents.flatten(0, 1).to(torch.float16) / 0.18215, num_frames=14).sample.float()
    frames = tensor2vid(frames.view(-1, 14, 3, 256, 256).permute(0, 2, 1, 3, 4), pipe.video_processor, output_type="pil")[0]

    # Draw the drags on the first frame so the exported GIF starts by showing them.
    frame_with_drag = np.ascontiguousarray(np.array(frames[0]))
    for i in range(0, len(drags_list), 2):
        drag_idx = i // 2
        start_point, end_point = drags_list[i:i+2]
        start_point = (int(start_point[0] * 256), int(start_point[1] * 256))
        end_point = (int(end_point[0] * 256), int(end_point[1] * 256))
        frame_with_drag = cv2.arrowedLine(frame_with_drag, start_point, end_point, (0, 0, 0), 4)
        frame_with_drag = cv2.arrowedLine(frame_with_drag, start_point, end_point, (255, 255, 0), 2)
        if drag_idx == 4:
            break
    # Hold the annotated first frame for a few extra frames before the animation plays.
    frames = [Image.fromarray(frame_with_drag)] * 5 + frames

    save_dir = os.path.join(os.path.dirname(__file__), "outputs")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_id = len(os.listdir(save_dir))
    save_path = os.path.join(save_dir, f"{save_id:05d}.gif")
    export_to_gif(frames, save_path)
    return save_path
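
# Gradio UI: clicking the input image appends points to the `drags` state and updates the
# drag preview; the preprocess button segments/recenters the image; the generate button
# runs the sampler and returns the path of the exported GIF.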

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# " + TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        gr.Markdown(INSTRUCTION)

    drags = gr.State(value=[])

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            input_image = gr.Image(
                interactive=True,
                type='pil',
                image_mode="RGBA",
                width=256,
                show_label=True,
                label="Input Image",
            )
            example_folder = os.path.join(os.path.dirname(__file__), "./example_images")
            example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
            gr.Examples(
                examples=example_fns,
                inputs=[input_image],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30
            )
            input_image.change(
                fn=clear_drag,
                outputs=[drags],
            )

        with gr.Column(scale=1):
            drag_image = gr.Image(
                type="numpy",
                label="Image with Drags",
                interactive=False,
                width=256,
                image_mode="RGB",
            )
            input_image.select(
                fn=get_point,
                inputs=[input_image, drags],
                outputs=[drag_image],
            )

        with gr.Column(scale=1):
            processed_image = gr.Image(
                type='numpy',
                label="Processed Image",
                interactive=False,
                width=256,
                height=256,
                image_mode='RGB',
            )
            processed_image_highres = gr.Image(type='pil', image_mode='RGB', visible=False)
            with gr.Accordion('Advanced preprocessing options', open=True):
                with gr.Row():
                    with gr.Column():
                        preprocess_chk_group = gr.CheckboxGroup(
                            ['Preprocess with Segmentation'],
                            label='Segment',
                            info=PREPROCESS_INSTRUCTION
                        )
            preprocess_button = gr.Button(
                value="Preprocess Input Image",
            )
            preprocess_button.click(
                fn=preprocess_image,
                inputs=[input_image, preprocess_chk_group, drags],
                outputs=[processed_image, drags],
                queue=True,
            )

        with gr.Column(scale=1):
            generated_gif = gr.Image(
                type="filepath",
                label="Generated GIF",
                interactive=False,
                height=256,
                width=256,
                image_mode="RGB",
            )
            with gr.Accordion('Advanced generation options', open=True):
                with gr.Row():
                    with gr.Column():
                        seed = gr.Slider(label="seed", value=0, minimum=0, maximum=10000, step=1, randomize=False)
                        cfg_scale = gr.Slider(
                            label="classifier-free guidance weight",
                            value=5, minimum=1, maximum=10, step=0.1
                        )
            generate_button = gr.Button(
                value="Generate Image",
            )
            generate_button.click(
                fn=generate_image,
                inputs=[processed_image, seed, cfg_scale, drags],
                outputs=[generated_gif],
            )

demo.launch(share=True)