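"""Gradio demo for LVM image-sequence generation, mounted on a FastAPI app and served with uvicorn."""
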
import os
import re
from io import BytesIO

import gradio as gr
import mlxu
import numpy as np
import uvicorn
from fastapi import FastAPI
from natsort import natsorted
from PIL import Image

from .inference import MultiProcessInferenceModel


FLAGS, _ = mlxu.define_flags_with_default(
    host='0.0.0.0',
    port=5007,
    dtype='float16',
    checkpoint='',
    torch_devices='',
    context_frames=16,
)
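
# Example launch command (illustrative only: the module path `lvm.demo` and the
# checkpoint location are placeholders, not values taken from this file). The
# relative import above means the script must be run as a module from inside
# its package, and --checkpoint is the only flag that must be set.
#
#   python -m lvm.demo \
#       --checkpoint=/path/to/lvm_checkpoint \
#       --torch_devices='cuda:0,cuda:1' \
#       --host=0.0.0.0 --port=5007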


def natural_sort_key(s):
    # Dependency-free natural-sort key for plain sorted(): split the name into
    # text and number chunks so that e.g. 'frame_2' orders before 'frame_10'.
    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]


def load_example_image_groups(directory):
    # Each subdirectory of `directory` becomes one named example group holding
    # its .png/.jpg/.jpeg images in natural filename order.
    example_groups = {}
    for subdir in os.listdir(directory):
        subdir_path = os.path.join(directory, subdir)
        if os.path.isdir(subdir_path):
            example_groups[subdir] = []
            images = [
                f for f in os.listdir(subdir_path)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
            images = natsorted(images)  # natsorted sorts naturally on its own; no custom key needed
            for filename in images:
                img = Image.open(os.path.join(subdir_path, filename))
                example_groups[subdir].append(img)
    return example_groups
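
# Illustrative layout expected by load_example_image_groups (directory and file
# names here are hypothetical): each subdirectory becomes one example group,
# with its frames kept in natural filename order.
#
#   demo_images/
#       sketch_to_photo/
#           frame_1.png
#           frame_2.png
#           frame_10.png   # sorts after frame_2 thanks to natural sorting
#       depth_sequence/
#           000.jpg
#           001.jpg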


def main(_):
    assert FLAGS.checkpoint != '', 'Please specify a model checkpoint with --checkpoint'

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=True,
    )

    # 256x24 black-and-white checkerboard strip (matching the 256-pixel frame
    # height), used as a visual separator between frames in the output image.
    checkerboard_r1 = np.concatenate(
        [np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1
    )
    checkerboard_r2 = np.concatenate(
        [np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1
    )
    checkerboard = np.concatenate(
        [checkerboard_r1, checkerboard_r2] * 16, axis=0
    ).astype(np.float32)

    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
        assert len(input_images) > 0
        # Resize every input frame to 256x256 RGB and normalize to [0, 1].
        input_images = [
            np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
            for img in input_images
        ]
        input_images = np.stack(input_images, axis=0)
        output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
        # Concatenate each candidate's generated frames horizontally, separated
        # by the checkerboard strip, and convert back to a PIL image.
        generated_images = []
        for candidate in output_images:
            concatenated_image = []
            for i, img in enumerate(candidate):
                concatenated_image.append(img)
                if i < len(candidate) - 1:
                    concatenated_image.append(checkerboard)
            generated_images.append(
                Image.fromarray(
                    (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
                )
            )
        return generated_images
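
    # Assumed contract of MultiProcessInferenceModel.__call__, inferred from the
    # call above rather than documented here: it takes a batch of frame sequences
    # (each an array of shape (T, 256, 256, 3) with values in [0, 1]) plus
    # n_new_frames, n_candidates, temperature and top_p, and returns, per input
    # sequence, n_candidates candidate continuations, each a list of new frames
    # in the same normalized format. Each output strip is therefore
    # 256 * k + 24 * (k - 1) pixels wide for k generated frames.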

    with gr.Blocks(css="""
        .small-button {
            padding: 5px 10px;
            min-width: 80px;
        }
        .large-gallery img {
            width: 100%;
            height: auto;
            max-height: 150px;
        }
    """) as demo:
        with gr.Column():
            image_list = gr.State([])
            gr.Markdown('# LVM Demo')
            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
            gr.Markdown('## Inputs')
            with gr.Row():
                upload_drag = gr.File(
                    type='binary',
                    file_types=['image'],
                    file_count='multiple',
                )
                with gr.Column():
                    gen_length_slider = gr.Slider(
                        label='Generation length',
                        minimum=1,
                        maximum=32,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    n_candidates_slider = gr.Slider(
                        label='Number of candidates',
                        minimum=1,
                        maximum=10,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    temp_slider = gr.Slider(
                        label='Temperature',
                        minimum=0,
                        maximum=2.0,
                        value=1.0,
                        interactive=True,
                    )
                    top_p_slider = gr.Slider(
                        label='Top p',
                        minimum=0,
                        maximum=1.0,
                        value=0.9,
                        interactive=True,
                    )
                    clear_btn = gr.Button(
                        value='Clear',
                        elem_classes=['small-button'],
                    )
                    generate_btn = gr.Button(
                        value='Generate',
                        interactive=False,
                        elem_classes=['small-button'],
                    )
            input_gallery = gr.Gallery(
                columns=7,
                rows=1,
                object_fit='scale-down',
            )
            gr.Markdown('## Outputs')
            output_gallery = gr.Gallery(
                columns=4,
                object_fit='scale-down',
            )
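
            # The callbacks below return dicts keyed by output components;
            # Gradio matches them against each event's `outputs` list, so a
            # callback only updates the components it names.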

            def upload_image_fn(files, images):
                for file in files:
                    images.append(Image.open(BytesIO(file)))
                return {
                    upload_drag: None,
                    image_list: images,
                    input_gallery: images,
                    generate_btn: gr.update(interactive=True),
                }

            def clear_fn():
                return {
                    image_list: [],
                    input_gallery: [],
                    generate_btn: gr.update(interactive=False),
                    output_gallery: [],
                }

            def disable_generate_btn():
                return {
                    generate_btn: gr.update(interactive=False),
                }

            def generate_fn(images, n_candidates, gen_length, temperature, top_p):
                new_images = generate_images(
                    images,
                    gen_length,
                    n_candidates=n_candidates,
                    temperature=temperature,
                    top_p=top_p,
                )
                return {
                    output_gallery: new_images,
                    generate_btn: gr.update(interactive=True),
                }

            upload_drag.upload(
                upload_image_fn,
                inputs=[upload_drag, image_list],
                outputs=[upload_drag, image_list, input_gallery, generate_btn],
            )
            clear_btn.click(
                clear_fn,
                inputs=None,
                outputs=[image_list, input_gallery, generate_btn, output_gallery],
            )
            # Disable the Generate button first, then run generation and
            # re-enable it, so a slow generation cannot be triggered twice.
            generate_btn.click(
                disable_generate_btn,
                inputs=None,
                outputs=[generate_btn],
            ).then(
                generate_fn,
                inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
                outputs=[output_gallery, generate_btn],
            )

            # Example image groups are loaded from a hard-coded directory; each
            # subdirectory gets an "Add <group>" button plus a preview gallery.
            example_groups = load_example_image_groups('/home/yutongbai/demo_images')

            def add_image_group_fn(group_name, images):
                new_images = images + example_groups[group_name]
                return {
                    image_list: new_images,
                    input_gallery: new_images,
                    generate_btn: gr.update(interactive=True),
                }

            for group_name, group_images in example_groups.items():
                with gr.Row():
                    with gr.Column(scale=3):
                        add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
                    with gr.Column(scale=7):
                        group_gallery = gr.Gallery(
                            value=[Image.fromarray(np.array(img)) for img in group_images],
                            columns=5,
                            rows=1,
                            object_fit='scale-down',
                            label=group_name,
                            elem_classes=['large-gallery'],
                        )
                add_button.click(
                    add_image_group_fn,
                    inputs=[gr.State(group_name), image_list],
                    outputs=[image_list, input_gallery, generate_btn],
                )

    app = FastAPI()
    app = gr.mount_gradio_app(app, demo, '/')
    uvicorn.run(app, host=FLAGS.host, port=FLAGS.port)


if __name__ == "__main__":
    mlxu.run(main)