Spaces:
Running
Running
| import json | |
| import os | |
| import subprocess | |
| import time | |
| import uuid | |
| import zipfile | |
| from dataclasses import fields | |
| from urllib.request import urlretrieve | |
| import gradio as gr | |
| import torch.multiprocessing as mp | |
| import transformers | |
| from legogpt.models import LegoGPT, LegoGPTConfig | |
| def setup(): | |
| # Set up Gurobi licence | |
| licence_filename = 'gurobi.lic' | |
| licence_lines = [] | |
| for secret_name in ['WLSACCESSID', 'WLSSECRET', 'LICENSEID']: | |
| secret = os.environ.get(secret_name) | |
| if not secret: | |
| raise ValueError(f'Env variable {secret_name} not found. Please set it in the Hugging Face Space settings.') | |
| licence_lines.append(f'{secret_name}={secret}\n') | |
| with open(licence_filename, 'w') as f: | |
| f.writelines(licence_lines) | |
| os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_filename) | |
| # Download LDraw part library and set LDraw library path | |
| ldraw_zip_url = 'https://library.ldraw.org/library/updates/complete.zip' | |
| ldraw_zip_filename = 'complete.zip' | |
| urlretrieve(ldraw_zip_url, ldraw_zip_filename) | |
| with zipfile.ZipFile(ldraw_zip_filename) as zip_ref: | |
| zip_ref.extractall() | |
| os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw') | |
def main():
    """
    Build the LegoGPT Gradio demo and launch it.

    When the ``IS_HF_SPACE`` environment variable equals ``'1'``, ``setup()`` is run first.
    The queue's default concurrency limit is read from ``CONCURRENCY_LIMIT`` (2 when unset).
    """
    running_in_space = os.environ.get('IS_HF_SPACE') == '1'
    if running_in_space:
        print('Running in Hugging Face Space, setting up environment...')
        setup()

    config = LegoGPTConfig(max_regenerations=5)
    lego_generator = LegoGenerator(LegoGPT(config))

    # Input components; defaults and help texts come from the model config
    prompt_box = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
    temperature_slider = gr.Slider(
        0.01,
        2.0,
        value=config.temperature,
        step=0.01,
        label='Temperature',
        info=get_help_string('temperature'),
    )
    seed_box = gr.Number(value=42, label='Seed', info='Random seed for generation.', precision=0, step=1)
    bricks_box = gr.Number(
        value=config.max_bricks,
        label='Max bricks',
        info=get_help_string('max_bricks'),
        precision=0,
        minimum=1,
        step=1,
    )
    rejections_box = gr.Number(
        value=config.max_brick_rejections,
        label='Max brick rejections',
        info=get_help_string('max_brick_rejections'),
        precision=0,
        minimum=0,
        step=1,
    )
    regenerations_box = gr.Number(
        value=config.max_regenerations,
        label='Max regenerations',
        info=get_help_string('max_regenerations'),
        precision=0,
        minimum=0,
        step=1,
    )

    # Output components
    result_image = gr.Image(label='Output image', format='png')
    result_text = gr.Textbox(
        label='Output LEGO bricks',
        lines=5,
        max_lines=5,
        show_copy_button=True,
        info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
             '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).',
    )

    demo_description = (
        'Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
        'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
        '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
        'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.'
    )

    # Main interface: the prompt is the primary input, the sampling settings are additional inputs
    demo = gr.Interface(
        fn=lego_generator.generate_lego_subprocess,
        title='LegoGPT Demo',
        description=demo_description,
        inputs=[prompt_box],
        additional_inputs=[temperature_slider, seed_box, bricks_box, rejections_box, regenerations_box],
        outputs=[result_image, result_text],
        flagging_mode='never',
    )

    # Examples section: clicking an example shows its stored image and brick text
    with demo:
        with gr.Row():
            examples = get_examples()
            hidden_name = gr.Textbox(visible=False, label='Name')
            hidden_image = gr.Image(visible=False, label='Result')

            example_rows = []
            for name, example in examples.items():
                example_rows.append(
                    [name, example['prompt'], example['temperature'], example['seed'], example['output_img']]
                )

            gr.Examples(
                examples=example_rows,
                inputs=[hidden_name, prompt_box, temperature_slider, seed_box, hidden_image],
                outputs=[result_image, result_text],
                # args[0] is the example name, args[-1] the example's stored image path
                fn=lambda *args: (args[-1], examples[args[0]]['output_txt']),
                run_on_click=True,
            )

    limit_setting = os.environ.get('CONCURRENCY_LIMIT')
    if limit_setting is None:
        concurrency_limit = 2
    else:
        concurrency_limit = int(limit_setting)
    demo.queue(default_concurrency_limit=concurrency_limit)
    demo.launch(share=True)
class LegoGenerator:
    """
    Generates a LEGO structure from a text prompt with a LegoGPT model and renders it to an image.
    """

    def __init__(self, model: LegoGPT):
        """
        :param model: The LegoGPT model used for generation.
        """
        self.ctx = mp.get_context('spawn')
        self.model = model

    def generate_lego(
            self,
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure for the prompt, save it as an LDR file and render it to a PNG.

        Any setting passed as None leaves the corresponding model attribute (or the random seed) untouched.

        :param prompt: Text prompt passed to the model.
        :param temperature: Value for the model's ``temperature`` attribute.
        :param seed: Random seed passed to ``transformers.set_seed``.
        :param max_bricks: Value for the model's ``max_bricks`` attribute.
        :param max_brick_rejections: Value for the model's ``max_brick_rejections`` attribute.
        :param max_regenerations: Value for the model's ``max_regenerations`` attribute.
        :return: Tuple of (path of the rendered image, the ``'lego'`` entry of the model output).
        """
        # Apply only the settings that were actually provided
        overrides = {
            'temperature': temperature,
            'max_bricks': max_bricks,
            'max_brick_rejections': max_brick_rejections,
            'max_regenerations': max_regenerations,
        }
        for attr_name, attr_value in overrides.items():
            if attr_value is not None:
                setattr(self.model, attr_name, attr_value)
        if seed is not None:
            transformers.set_seed(seed)

        # Run the model
        print(f'Generating LEGO for prompt: "{prompt}"')
        generation_start = time.time()
        output = self.model(prompt)

        # Save the structure as an LDR file under a unique name
        output_dir = os.path.abspath('out')
        output_uuid = str(uuid.uuid4())
        os.makedirs(output_dir, exist_ok=True)
        ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
        with open(ldr_filename, 'w') as ldr_file:
            ldr_file.write(output['lego'].to_ldr())
        print(f'Finished generation in {time.time() - generation_start:.1f}s!')

        # Render the LDR file to an image
        print('Rendering image...')
        render_start = time.time()
        img_filename = os.path.join(output_dir, f'{output_uuid}.png')
        # The render runs as a separate process to prevent issues with Blender
        render_command = ['python', 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename]
        subprocess.run(render_command, check=True)
        print(f'Finished rendering in {time.time() - render_start:.1f}s!')

        return img_filename, output['lego']

    def generate_lego_subprocess(self, *args):
        """
        Run ``generate_lego`` in a one-worker process pool so that multiple requests can be handled concurrently.

        :param args: Positional arguments forwarded to ``generate_lego``.
        :return: The return value of ``generate_lego``.
        """
        with self.ctx.Pool(1) as pool:
            results = pool.starmap(self.generate_lego, [args])
            return results[0]
def get_help_string(field_name: str) -> str:
    """
    Look up the help text stored in the metadata of a LegoGPTConfig field.

    :param field_name: Name of a field in LegoGPTConfig.
    :return: Help string for the field, i.e. the ``'help'`` entry of the field's metadata.
    :raises ValueError: If LegoGPTConfig has no field named ``field_name``.
    :raises KeyError: If the field exists but its metadata has no ``'help'`` entry.
    """
    for data_field in fields(LegoGPTConfig):
        if data_field.name == field_name:
            return data_field.metadata['help']
    # Fail with a descriptive error rather than letting a bare StopIteration escape from next()
    raise ValueError(f'LegoGPTConfig has no field named {field_name!r}.')
def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]:
    """
    Load the demo examples from ``examples.json`` inside ``example_dir``.

    Each example's ``'output_img'`` entry is rewritten in place to be a path inside ``example_dir``.

    :param example_dir: Directory containing ``examples.json`` and the example images.
        The default is resolved once, when the function is defined.
    :return: Mapping from example name to that example's dictionary as stored in the JSON file.
    """
    examples_file = os.path.join(example_dir, 'examples.json')
    # Read as UTF-8 explicitly so the result does not depend on the locale's default encoding
    with open(examples_file, encoding='utf-8') as f:
        examples = json.load(f)

    for example in examples.values():
        example['output_img'] = os.path.join(example_dir, example['output_img'])
    return examples
# Script entry point: build and launch the demo only when this file is executed directly.
if __name__ == '__main__':
    main()