Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch
|
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
import numpy as np
|
| 8 |
import random
|
|
|
|
| 9 |
|
| 10 |
os.system("git clone https://github.com/luost26/diffusion-point-cloud")
|
| 11 |
sys.path.append("diffusion-point-cloud")
|
|
@@ -15,21 +16,38 @@ sys.path.append("diffusion-point-cloud")
|
|
| 15 |
from models.vae_gaussian import *
|
| 16 |
from models.vae_flow import *
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
device='cuda' if torch.cuda.is_available() else 'cpu'
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
ckpt_airplane = torch.load(airplane, map_location=torch.device(device), weights_only=False)
|
| 25 |
-
ckpt_chair = torch.load(chair,map_location=torch.device(device))
|
| 26 |
|
| 27 |
def seed_all(seed):
|
| 28 |
torch.manual_seed(seed)
|
| 29 |
np.random.seed(seed)
|
| 30 |
random.seed(seed)
|
| 31 |
|
| 32 |
-
def normalize_point_clouds(pcs,mode):
|
| 33 |
if mode is None:
|
| 34 |
return pcs
|
| 35 |
for i in range(pcs.size(0)):
|
|
@@ -38,100 +56,158 @@ def normalize_point_clouds(pcs,mode):
|
|
| 38 |
shift = pc.mean(dim=0).reshape(1, 3)
|
| 39 |
scale = pc.flatten().std().reshape(1, 1)
|
| 40 |
elif mode == 'shape_bbox':
|
| 41 |
-
pc_max, _ = pc.max(dim=0, keepdim=True)
|
| 42 |
-
pc_min, _ = pc.min(dim=0, keepdim=True)
|
| 43 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
| 44 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
pc = (pc - shift) / scale
|
| 46 |
pcs[i] = pc
|
| 47 |
return pcs
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def predict(Seed,ckpt):
|
| 53 |
-
if Seed==None:
|
| 54 |
-
Seed=777
|
| 55 |
-
seed_all(Seed)
|
| 56 |
-
|
| 57 |
-
if ckpt['args'].model == 'gaussian':
|
| 58 |
-
model = GaussianVAE(ckpt['args']).to(device)
|
| 59 |
-
elif ckpt['args'].model == 'flow':
|
| 60 |
-
model = FlowVAE(ckpt['args']).to(device)
|
| 61 |
-
|
| 62 |
-
model.load_state_dict(ckpt['state_dict'])
|
| 63 |
-
# Generate Point Clouds
|
| 64 |
-
gen_pcs = []
|
| 65 |
-
with torch.no_grad():
|
| 66 |
-
z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
|
| 67 |
-
x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
|
| 68 |
-
gen_pcs.append(x.detach().cpu())
|
| 69 |
-
gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
|
| 70 |
-
gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
|
| 71 |
-
|
| 72 |
-
return gen_pcs[0]
|
| 73 |
|
| 74 |
-
def
|
| 75 |
-
if
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
|
|
|
|
|
|
| 86 |
|
| 87 |
fig = go.Figure(
|
| 88 |
data=[
|
| 89 |
go.Scatter3d(
|
| 90 |
-
x=points[:,0], y=points[:,1], z=points[:,2],
|
| 91 |
mode='markers',
|
| 92 |
-
marker=dict(size=
|
| 93 |
)
|
| 94 |
],
|
| 95 |
layout=dict(
|
| 96 |
scene=dict(
|
| 97 |
-
xaxis=dict(visible=
|
| 98 |
-
yaxis=dict(visible=
|
| 99 |
-
zaxis=dict(visible=
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
)
|
| 103 |
return fig
|
| 104 |
-
|
| 105 |
-
markdown=f'''
|
| 106 |
-
# Diffusion Probabilistic Models for 3D Point Cloud Generation
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
'''
|
| 123 |
-
with gr.Blocks() as demo:
|
| 124 |
with gr.Column():
|
| 125 |
with gr.Row():
|
| 126 |
gr.Markdown(markdown)
|
| 127 |
with gr.Row():
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
#truncate_std = gr.Slider( minimum=1, maximum=2,label='Truncate Std')
|
| 131 |
|
| 132 |
-
btn = gr.Button(value="Generate")
|
| 133 |
-
|
| 134 |
-
demo.load(generate, [seed,value], point_cloud)
|
| 135 |
-
btn.click(generate, [seed,value], point_cloud)
|
| 136 |
|
| 137 |
-
demo.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
import numpy as np
|
| 8 |
import random
|
| 9 |
+
# import argparse # Not strictly needed for weights_only=False, but good practice if dealing with argparse.Namespace
|
| 10 |
|
| 11 |
os.system("git clone https://github.com/luost26/diffusion-point-cloud")
|
| 12 |
sys.path.append("diffusion-point-cloud")
|
|
|
|
| 16 |
from models.vae_gaussian import *
|
| 17 |
from models.vae_flow import *
|
| 18 |
|
| 19 |
+
airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
# IMPORTANT: GEN_chair.pt is NOT downloaded by this script. It must be placed
# manually in the directory this script is run from. The original repository
# (https://github.com/luost26/diffusion-point-cloud) mentions downloading
# checkpoints from Google Drive.
chair_model_path = "./GEN_chair.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- PyTorch 2.6+ loading considerations ---
# weights_only=False is passed explicitly because these checkpoints carry an
# argparse.Namespace next to the tensors.
# SECURITY: weights_only=False unpickles arbitrary Python objects, so only
# load checkpoint files from a source you trust.
# Stricter alternative: register the extra type once with
#     torch.serialization.add_safe_globals([argparse.Namespace])
# and drop weights_only=False so the default weights_only=True is used.
ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)

# The chair checkpoint is optional. Loading it unconditionally would abort
# startup when the file is absent, before the script's own "checkpoint not
# found" warning could ever be printed.
if os.path.exists(chair_model_path):
    ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device), weights_only=False)
else:
    print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found; the 'Chair' model is unavailable.")
    ckpt_chair = None
# --- End of PyTorch 2.6+ loading considerations ---
|
| 43 |
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def seed_all(seed):
    """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
    # Same order as before: torch first, then NumPy, then the stdlib RNG.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
| 49 |
|
| 50 |
+
def normalize_point_clouds(pcs, mode):
|
| 51 |
if mode is None:
|
| 52 |
return pcs
|
| 53 |
for i in range(pcs.size(0)):
|
|
|
|
| 56 |
shift = pc.mean(dim=0).reshape(1, 3)
|
| 57 |
scale = pc.flatten().std().reshape(1, 1)
|
| 58 |
elif mode == 'shape_bbox':
|
| 59 |
+
pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
|
| 60 |
+
pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
|
| 61 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
| 62 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
| 63 |
+
else: # Fallback if mode is not recognized, though your code doesn't use this branch with current inputs
|
| 64 |
+
shift = 0
|
| 65 |
+
scale = 1
|
| 66 |
+
|
| 67 |
+
# Prevent division by zero or very small scale
|
| 68 |
+
if scale < 1e-8:
|
| 69 |
+
scale = torch.tensor(1.0).reshape(1,1)
|
| 70 |
+
|
| 71 |
pc = (pc - shift) / scale
|
| 72 |
pcs[i] = pc
|
| 73 |
return pcs
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
def predict(Seed, ckpt):
    """Sample one point cloud from a loaded checkpoint and normalize it.

    Args:
        Seed: Value used to seed the RNGs via ``seed_all``; ``None`` falls
            back to 777. It is converted with ``int()``.
        ckpt: Checkpoint mapping as returned by ``torch.load``. It is read
            through ``ckpt['args']`` (``model``, ``latent_dim``,
            ``flexibility``) and ``ckpt['state_dict']``.

    Returns:
        The first generated point cloud after ``shape_bbox`` normalization.

    Raises:
        ValueError: If the checkpoint's model type is neither ``'gaussian'``
            nor ``'flow'``.
    """
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))  # Ensure Seed is an integer

    # ckpt is a mapping, so the entry must be looked up by key.
    # hasattr(ckpt, 'args') tests for an *attribute* and is always False for
    # a dict, which forced every checkpoint onto the fallback path below.
    args = ckpt.get('args')
    if args is None or not hasattr(args, 'model'):
        # Best-effort fallback for a checkpoint without usable args.
        print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
        model_type = 'gaussian'
        latent_dim = ckpt.get('latent_dim', 128)
        flexibility = ckpt.get('flexibility', 0.0)
        # NOTE(review): this minimal stand-in may lack attributes the model
        # constructor reads -- confirm against models/vae_gaussian.py.
        model_args = type('Args', (), {
            'latent_dim': latent_dim,
            'hyper': getattr(args, 'hyper', None),
        })()
    else:
        model_type = args.model
        latent_dim = args.latent_dim
        flexibility = args.flexibility
        # Hand the constructor the full saved configuration rather than a
        # partial mock, so no saved hyper-parameter is dropped.
        model_args = args

    if model_type == 'gaussian':
        model = GaussianVAE(model_args).to(device)
    elif model_type == 'flow':
        model = FlowVAE(model_args).to(device)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.load_state_dict(ckpt['state_dict'])
    model.eval()  # Set model to evaluation mode

    # Generate Point Clouds
    with torch.no_grad():
        z = torch.randn([1, latent_dim]).to(device)
        # Use the checkpoint's point count when present, else 2048.
        num_points_to_generate = getattr(args, 'num_points', 2048)
        x = model.sample(z, num_points_to_generate, flexibility=flexibility)

    # Keep only one point cloud; clone because normalize_point_clouds writes
    # into the tensor it is given.
    gen_pcs_tensor = x.detach().cpu()[:1]
    gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")

    return gen_pcs_normalized[0]
|
| 130 |
+
|
| 131 |
+
def generate(seed, value):
    """Build a 3D scatter figure of a point cloud sampled from the chosen model.

    Args:
        seed: Seed coming from the UI. ``None`` or ``0`` requests a fresh
            random seed, matching the slider label "Seed (0 for random)".
        value: ``"Airplane"`` or ``"Chair"``. Any other value falls back to
            the airplane checkpoint with a printed warning.

    Returns:
        A ``go.Figure`` holding one ``Scatter3d`` trace of the generated
        points; the title shows the model name and the seed actually used.

    Raises:
        gr.Error: If the selected checkpoint is not loaded (``None``).
    """
    if value == "Airplane":
        ckpt = ckpt_airplane
    elif value == "Chair":
        ckpt = ckpt_chair
    else:
        # For now, defaulting to airplane if 'value' is unexpected
        print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        ckpt = ckpt_airplane

    if ckpt is None:
        # Surface a readable message in the UI instead of an opaque failure
        # deeper inside predict().
        raise gr.Error(f"The checkpoint for '{value}' is not loaded.")

    colors = (238, 75, 43)  # RGB tuple for plotly

    # The slider can never deliver None, so 0 is the documented way to ask
    # for a random seed. SystemRandom is used because seed_all() re-seeds the
    # global `random` state on every call, which would make a seed drawn from
    # it fully determined by the previous request.
    if not seed:
        current_seed = random.SystemRandom().randint(0, 2**16 - 1)
    else:
        current_seed = int(seed)

    points = predict(current_seed, ckpt)

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                # plotly expects an rgb string
                marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})')
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                aspectmode='data'  # Ensures proportional axes
            ),
            margin=dict(l=0, r=0, b=0, t=40),  # Adjust margins
            title=f"Generated {value} (Seed: {current_seed})"
        )
    )
    return fig
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
+
markdown = f'''
|
| 175 |
+
# Diffusion Probabilistic Models for 3D Point Cloud Generation
|
| 176 |
+
|
| 177 |
+
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
|
| 178 |
|
| 179 |
+
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
|
| 180 |
+
### Future Work based on interest
|
| 181 |
+
- Adding new models for new type objects
|
| 182 |
+
- New Customization
|
|
|
|
| 183 |
|
| 184 |
+
It is running on **{device.upper()}**
|
|
|
|
| 185 |
|
| 186 |
+
---
|
| 187 |
+
**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
|
| 188 |
+
It is not downloaded automatically by this script.
|
| 189 |
+
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
|
| 190 |
+
---
|
| 191 |
'''
|
| 192 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 193 |
with gr.Column():
|
| 194 |
with gr.Row():
|
| 195 |
gr.Markdown(markdown)
|
| 196 |
with gr.Row():
|
| 197 |
+
seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Set initial value
|
| 198 |
+
model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Set initial value
|
|
|
|
| 199 |
|
| 200 |
+
btn = gr.Button(value="Generate Point Cloud")
|
| 201 |
+
point_cloud_plot = gr.Plot() # Changed variable name for clarity
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
# demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
|
| 204 |
+
btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
    # Tell the operator up front when the optional chair checkpoint is absent.
    chair_checkpoint_present = os.path.exists(chair_model_path)
    if not chair_checkpoint_present:
        missing_notice = (
            f"WARNING: Chair model checkpoint '{chair_model_path}' not found.",
            "The 'Chair' option in the UI may not work unless this file is present.",
            f"Please download it from the original project repository and place it at '{chair_model_path}'.",
        )
        for notice_line in missing_notice:
            print(notice_line)

    demo.launch()
|