Spaces:
Starting
Starting
Update gradio_app.py
Browse files- gradio_app.py +17 -15
gradio_app.py
CHANGED
|
@@ -67,11 +67,11 @@ def create_batch(input_image: Image) -> dict[str, Any]:
|
|
| 67 |
}
|
| 68 |
return batch
|
| 69 |
|
| 70 |
-
def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> str:
|
| 71 |
"""Generate image from prompt and convert to 3D model."""
|
| 72 |
try:
|
| 73 |
# Generate image using FLUX
|
| 74 |
-
generator = torch.Generator().manual_seed(seed)
|
| 75 |
generated_image = flux_pipe(
|
| 76 |
prompt=prompt,
|
| 77 |
width=width,
|
|
@@ -84,12 +84,13 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
| 84 |
# Convert PIL image to RGBA
|
| 85 |
input_image = generated_image.convert("RGBA")
|
| 86 |
|
| 87 |
-
# Remove background
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
# Auto crop
|
| 91 |
input_image = spar3d_utils.foreground_crop(
|
| 92 |
-
|
| 93 |
crop_ratio=1.3,
|
| 94 |
newsize=(COND_WIDTH, COND_HEIGHT),
|
| 95 |
no_crop=False
|
|
@@ -101,7 +102,7 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
| 101 |
|
| 102 |
# Generate mesh
|
| 103 |
with torch.no_grad():
|
| 104 |
-
with torch.autocast(device_type=
|
| 105 |
trimesh_mesh, _ = spar3d_model.generate_mesh(
|
| 106 |
batch,
|
| 107 |
1024, # texture_resolution
|
|
@@ -112,16 +113,17 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig
|
|
| 112 |
trimesh_mesh = trimesh_mesh[0]
|
| 113 |
|
| 114 |
# Export to GLB
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
-
return
|
| 119 |
|
| 120 |
except Exception as e:
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
|
|
|
| 125 |
demo = gr.Interface(
|
| 126 |
fn=generate_and_process_3d,
|
| 127 |
inputs=[
|
|
@@ -153,8 +155,8 @@ demo = gr.Interface(
|
|
| 153 |
],
|
| 154 |
outputs=[
|
| 155 |
gr.File(
|
| 156 |
-
label="Download
|
| 157 |
-
file_types=[".glb"]
|
| 158 |
),
|
| 159 |
gr.Image(
|
| 160 |
label="Generated Image",
|
|
@@ -166,4 +168,4 @@ demo = gr.Interface(
|
|
| 166 |
)
|
| 167 |
|
| 168 |
if __name__ == "__main__":
|
| 169 |
-
demo.launch()
|
|
|
|
| 67 |
}
|
| 68 |
return batch
|
| 69 |
|
| 70 |
+
def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str, Image.Image]:
|
| 71 |
"""Generate image from prompt and convert to 3D model."""
|
| 72 |
try:
|
| 73 |
# Generate image using FLUX
|
| 74 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 75 |
generated_image = flux_pipe(
|
| 76 |
prompt=prompt,
|
| 77 |
width=width,
|
|
|
|
| 84 |
# Convert PIL image to RGBA
|
| 85 |
input_image = generated_image.convert("RGBA")
|
| 86 |
|
| 87 |
+
# Remove background
|
| 88 |
+
rgba_image = bg_remover.process(input_image.convert("RGB"))
|
| 89 |
+
rgba_image.putalpha(255) # Add alpha channel
|
| 90 |
|
| 91 |
# Auto crop
|
| 92 |
input_image = spar3d_utils.foreground_crop(
|
| 93 |
+
rgba_image,
|
| 94 |
crop_ratio=1.3,
|
| 95 |
newsize=(COND_WIDTH, COND_HEIGHT),
|
| 96 |
no_crop=False
|
|
|
|
| 102 |
|
| 103 |
# Generate mesh
|
| 104 |
with torch.no_grad():
|
| 105 |
+
with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
|
| 106 |
trimesh_mesh, _ = spar3d_model.generate_mesh(
|
| 107 |
batch,
|
| 108 |
1024, # texture_resolution
|
|
|
|
| 113 |
trimesh_mesh = trimesh_mesh[0]
|
| 114 |
|
| 115 |
# Export to GLB
|
| 116 |
+
temp_dir = tempfile.mkdtemp()
|
| 117 |
+
output_path = os.path.join(temp_dir, 'output.glb')
|
| 118 |
+
trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
|
| 119 |
|
| 120 |
+
return output_path, generated_image
|
| 121 |
|
| 122 |
except Exception as e:
|
| 123 |
+
print(f"Error: {str(e)}")
|
| 124 |
+
return None, None
|
|
|
|
| 125 |
|
| 126 |
+
# Create Gradio interface
|
| 127 |
demo = gr.Interface(
|
| 128 |
fn=generate_and_process_3d,
|
| 129 |
inputs=[
|
|
|
|
| 155 |
],
|
| 156 |
outputs=[
|
| 157 |
gr.File(
|
| 158 |
+
label="Download 3D Model",
|
| 159 |
+
file_types=[".glb"]
|
| 160 |
),
|
| 161 |
gr.Image(
|
| 162 |
label="Generated Image",
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
if __name__ == "__main__":
|
| 171 |
+
demo.queue().launch()
|