Update app.py
Browse files
app.py
CHANGED
|
@@ -6,42 +6,84 @@ import torch
|
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
import numpy as np
|
| 8 |
import random
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
sys.path.append("diffusion-point-cloud")
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
| 16 |
-
from models.vae_gaussian import
|
| 17 |
-
from models.vae_flow import
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def seed_all(seed):
|
| 46 |
torch.manual_seed(seed)
|
| 47 |
np.random.seed(seed)
|
|
@@ -56,182 +98,231 @@ def normalize_point_clouds(pcs, mode):
|
|
| 56 |
shift = pc.mean(dim=0).reshape(1, 3)
|
| 57 |
scale = pc.flatten().std().reshape(1, 1)
|
| 58 |
elif mode == 'shape_bbox':
|
| 59 |
-
pc_max, _ = pc.max(dim=0, keepdim=True)
|
| 60 |
-
pc_min, _ = pc.min(dim=0, keepdim=True)
|
| 61 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
| 62 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
| 63 |
-
else: # Fallback
|
| 64 |
-
shift = 0
|
| 65 |
-
scale = 1
|
| 66 |
|
| 67 |
-
# Prevent division by zero or very small scale
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
pc = (pc - shift) / scale
|
| 72 |
-
pcs[i] = pc
|
| 73 |
return pcs
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
if Seed is None:
|
| 78 |
-
Seed = 777
|
| 79 |
-
seed_all(int(Seed))
|
| 80 |
|
| 81 |
-
# ---
|
| 82 |
actual_args = None
|
|
|
|
| 83 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
| 84 |
actual_args = ckpt['args']
|
| 85 |
-
print("Using 'args' found in checkpoint.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
else:
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
'model': ckpt.get('model', default_model_type),
|
| 101 |
-
'latent_dim': ckpt.get('latent_dim', default_latent_dim),
|
| 102 |
-
'hyper': ckpt.get('hyper', default_hyper),
|
| 103 |
-
'residual': ckpt.get('residual', default_residual),
|
| 104 |
-
'flow_depth': ckpt.get('flow_depth', default_flow_depth),
|
| 105 |
-
'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
|
| 106 |
-
'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
|
| 107 |
-
'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
|
| 108 |
-
})()
|
| 109 |
-
|
| 110 |
-
# Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
|
| 111 |
-
# These are parameters for the .sample() method, not necessarily model construction.
|
| 112 |
-
# The original training args might not have included these if they were fixed in the sampling script.
|
| 113 |
-
|
| 114 |
-
# Default values for sampling parameters if not present in actual_args
|
| 115 |
-
default_num_points_sampling = 2048
|
| 116 |
-
default_flexibility_sampling = 0.0
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
if not hasattr(actual_args, 'num_points'):
|
| 119 |
-
print(
|
| 120 |
-
setattr(actual_args, 'num_points',
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
# Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
|
| 127 |
-
# This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
|
| 128 |
-
if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
|
| 129 |
-
print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
|
| 130 |
-
setattr(actual_args, 'residual', True) # Default for GaussianVAE
|
| 131 |
|
| 132 |
-
# --- MODIFICATION END ---
|
| 133 |
|
|
|
|
|
|
|
| 134 |
if actual_args.model == 'gaussian':
|
| 135 |
-
model = GaussianVAE(actual_args).to(
|
| 136 |
elif actual_args.model == 'flow':
|
| 137 |
-
model = FlowVAE(actual_args).to(
|
| 138 |
else:
|
| 139 |
-
raise ValueError(f"Unknown model type: {actual_args.model}")
|
| 140 |
|
| 141 |
model.load_state_dict(ckpt['state_dict'])
|
| 142 |
model.eval()
|
| 143 |
|
|
|
|
| 144 |
gen_pcs = []
|
| 145 |
with torch.no_grad():
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
|
| 149 |
gen_pcs.append(x.detach().cpu())
|
| 150 |
-
|
| 151 |
gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
|
| 152 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
| 153 |
|
| 154 |
return gen_pcs_normalized[0]
|
| 155 |
-
def generate(seed, value):
|
| 156 |
-
if value == "Airplane":
|
| 157 |
-
ckpt = ckpt_airplane
|
| 158 |
-
elif value == "Chair":
|
| 159 |
-
ckpt = ckpt_chair
|
| 160 |
-
else:
|
| 161 |
-
# Default case or handle error
|
| 162 |
-
# For now, defaulting to airplane if 'value' is unexpected
|
| 163 |
-
print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
|
| 164 |
-
ckpt = ckpt_airplane
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
)
|
| 184 |
-
],
|
| 185 |
-
layout=dict(
|
| 186 |
-
scene=dict(
|
| 187 |
-
xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
| 188 |
-
yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
| 189 |
-
zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
| 190 |
-
aspectmode='data' # Ensures proportional axes
|
| 191 |
-
),
|
| 192 |
-
margin=dict(l=0, r=0, b=0, t=40), # Adjust margins
|
| 193 |
-
title=f"Generated {value} (Seed: {current_seed})"
|
| 194 |
)
|
| 195 |
-
)
|
| 196 |
-
return fig
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
- Adding new models for new type objects
|
| 206 |
-
- New Customization
|
| 207 |
|
| 208 |
-
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
'''
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
with gr.Row():
|
| 219 |
-
gr.Markdown(markdown)
|
| 220 |
-
with gr.Row():
|
| 221 |
-
seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Set initial value
|
| 222 |
-
model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Set initial value
|
| 223 |
|
| 224 |
-
btn = gr.Button(value="Generate Point Cloud")
|
| 225 |
-
point_cloud_plot = gr.Plot() # Changed variable name for clarity
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
|
|
|
| 230 |
if __name__ == "__main__":
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
|
| 236 |
|
| 237 |
-
demo
|
|
|
|
|
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
import numpy as np
|
| 8 |
import random
|
| 9 |
+
import tempfile # For creating temporary files for download
|
| 10 |
+
import traceback # For detailed error logging
|
| 11 |
+
|
| 12 |
+
# --- Environment Setup ---
|
| 13 |
+
# Suppress TensorFlow oneDNN optimization messages if TensorFlow is inadvertently imported by a dependency
|
| 14 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
| 15 |
+
# Clone the repository only if the directory doesn't exist
|
| 16 |
+
if not os.path.exists("diffusion-point-cloud"):
|
| 17 |
+
print("Cloning diffusion-point-cloud repository...")
|
| 18 |
+
os.system("git clone https://github.com/luost26/diffusion-point-cloud")
|
| 19 |
+
else:
|
| 20 |
+
print("diffusion-point-cloud repository already exists.")
|
| 21 |
sys.path.append("diffusion-point-cloud")
|
| 22 |
|
| 23 |
+
# --- Model Imports ---
|
| 24 |
+
try:
|
| 25 |
+
from models.vae_gaussian import GaussianVAE
|
| 26 |
+
from models.vae_flow import FlowVAE
|
| 27 |
+
except ImportError as e:
|
| 28 |
+
print(f"CRITICAL Error importing models: {e}")
|
| 29 |
+
print("Please ensure 'diffusion-point-cloud' directory is in sys.path and contains the model definitions.")
|
| 30 |
+
sys.exit(1)
|
| 31 |
+
|
| 32 |
+
# --- Model Checkpoint Paths and Loading ---
|
| 33 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 34 |
+
print(f"Using device: {DEVICE.upper()}")
|
| 35 |
+
|
| 36 |
+
MODEL_CONFIGS = {
|
| 37 |
+
"Airplane": {
|
| 38 |
+
"path_function": lambda: hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"),
|
| 39 |
+
"expected_model_type": "gaussian",
|
| 40 |
+
"default_args": {
|
| 41 |
+
'model': "gaussian", # Should match expected_model_type
|
| 42 |
+
'latent_dim': 128,
|
| 43 |
+
'hyper': None,
|
| 44 |
+
'residual': True,
|
| 45 |
+
'num_points': 2048, # For sampling
|
| 46 |
+
# 'flexibility' will be taken from UI
|
| 47 |
+
}
|
| 48 |
+
},
|
| 49 |
+
"Chair": {
|
| 50 |
+
"path_function": lambda: "./GEN_chair.pt",
|
| 51 |
+
"expected_model_type": "gaussian", # Assuming Gaussian for chair as well
|
| 52 |
+
"default_args": {
|
| 53 |
+
'model': "gaussian",
|
| 54 |
+
'latent_dim': 128,
|
| 55 |
+
'hyper': None,
|
| 56 |
+
'residual': True,
|
| 57 |
+
'num_points': 2048,
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
# To add more models:
|
| 61 |
+
# "YourModelName": {
|
| 62 |
+
# "path_function": lambda: "path/to/your/model.pt",
|
| 63 |
+
# "expected_model_type": "gaussian", # or "flow"
|
| 64 |
+
# "default_args": { ... } # Model-specific defaults
|
| 65 |
+
# }
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Load checkpoints
|
| 70 |
+
LOADED_CHECKPOINTS = {}
|
| 71 |
+
for model_name, config in MODEL_CONFIGS.items():
|
| 72 |
+
model_path = "" # Initialize for error message
|
| 73 |
+
try:
|
| 74 |
+
model_path = config["path_function"]()
|
| 75 |
+
if model_name == "Chair" and not os.path.exists(model_path): # Specific check for local file
|
| 76 |
+
print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
|
| 77 |
+
LOADED_CHECKPOINTS[model_name] = None
|
| 78 |
+
continue
|
| 79 |
+
print(f"Loading checkpoint for {model_name} from '{model_path}'...")
|
| 80 |
+
LOADED_CHECKPOINTS[model_name] = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=False)
|
| 81 |
+
print(f"Successfully loaded {model_name}.")
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
|
| 84 |
+
LOADED_CHECKPOINTS[model_name] = None
|
| 85 |
+
|
| 86 |
+
# --- Helper Functions ---
|
| 87 |
def seed_all(seed):
|
| 88 |
torch.manual_seed(seed)
|
| 89 |
np.random.seed(seed)
|
|
|
|
| 98 |
shift = pc.mean(dim=0).reshape(1, 3)
|
| 99 |
scale = pc.flatten().std().reshape(1, 1)
|
| 100 |
elif mode == 'shape_bbox':
|
| 101 |
+
pc_max, _ = pc.max(dim=0, keepdim=True)
|
| 102 |
+
pc_min, _ = pc.min(dim=0, keepdim=True)
|
| 103 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
| 104 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
| 105 |
+
else: # Fallback
|
| 106 |
+
shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
|
| 107 |
+
scale = torch.ones_like(pc.flatten().std().reshape(1, 1))
|
| 108 |
|
| 109 |
+
if scale.abs().item() < 1e-8: # Prevent division by zero or very small scale
|
| 110 |
+
scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)
|
| 111 |
+
|
| 112 |
+
pcs[i] = (pc - shift) / scale
|
|
|
|
|
|
|
| 113 |
return pcs
|
| 114 |
|
| 115 |
+
# --- Core Prediction Logic ---
|
| 116 |
+
def predict(seed_val, selected_model_name, flexibility_val):
|
| 117 |
+
seed_all(int(seed_val))
|
| 118 |
+
|
| 119 |
+
ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
|
| 120 |
+
if ckpt is None:
|
| 121 |
+
raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")
|
| 122 |
|
| 123 |
+
model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
# --- Argument Handling for Model Instantiation and Sampling ---
|
| 126 |
actual_args = None
|
| 127 |
+
# Prioritize args from checkpoint if available and seems valid
|
| 128 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
| 129 |
actual_args = ckpt['args']
|
| 130 |
+
print(f"Using 'args' found in checkpoint for {selected_model_name}.")
|
| 131 |
+
# Augment with model-specific defaults if attributes are missing from ckpt['args']
|
| 132 |
+
for key, default_value in model_specific_defaults.items():
|
| 133 |
+
if not hasattr(actual_args, key):
|
| 134 |
+
print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
|
| 135 |
+
setattr(actual_args, key, default_value)
|
| 136 |
else:
|
| 137 |
+
print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing mock_args from defaults.")
|
| 138 |
+
# Fallback: construct args using model_specific_defaults, trying to get values from top-level of ckpt
|
| 139 |
+
actual_args_dict = {}
|
| 140 |
+
for key, default_value in model_specific_defaults.items():
|
| 141 |
+
# Try to get from ckpt top-level first, then use the model-specific default
|
| 142 |
+
actual_args_dict[key] = ckpt.get(key, default_value)
|
| 143 |
+
actual_args = type('Args', (), actual_args_dict)()
|
| 144 |
+
|
| 145 |
+
# Ensure essential attributes for model construction and sampling are present on actual_args
|
| 146 |
+
# These might have been set by defaults above, but good to double check or enforce
|
| 147 |
+
if not hasattr(actual_args, 'model'): # Critical
|
| 148 |
+
raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
|
| 149 |
+
if not hasattr(actual_args, 'latent_dim'): setattr(actual_args, 'latent_dim', 128) # A common default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
if actual_args.model == 'gaussian':
|
| 152 |
+
if not hasattr(actual_args, 'residual'):
|
| 153 |
+
print("Setting default 'residual=True' for GaussianVAE.")
|
| 154 |
+
setattr(actual_args, 'residual', True)
|
| 155 |
+
elif actual_args.model == 'flow': # Parameters for FlowVAE
|
| 156 |
+
if not hasattr(actual_args, 'flow_depth'): setattr(actual_args, 'flow_depth', 10)
|
| 157 |
+
if not hasattr(actual_args, 'flow_hidden_dim'): setattr(actual_args, 'flow_hidden_dim', 256)
|
| 158 |
+
|
| 159 |
+
# Sampling parameters
|
| 160 |
if not hasattr(actual_args, 'num_points'):
|
| 161 |
+
print("Setting default 'num_points=2048' for sampling.")
|
| 162 |
+
setattr(actual_args, 'num_points', 2048)
|
| 163 |
|
| 164 |
+
# Use flexibility from UI slider, this overrides any 'flexibility' in args
|
| 165 |
+
setattr(actual_args, 'flexibility', flexibility_val)
|
| 166 |
+
print(f"Using flexibility: {actual_args.flexibility} for sampling.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
|
|
|
| 168 |
|
| 169 |
+
# --- Model Instantiation ---
|
| 170 |
+
model = None
|
| 171 |
if actual_args.model == 'gaussian':
|
| 172 |
+
model = GaussianVAE(actual_args).to(DEVICE)
|
| 173 |
elif actual_args.model == 'flow':
|
| 174 |
+
model = FlowVAE(actual_args).to(DEVICE)
|
| 175 |
else:
|
| 176 |
+
raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")
|
| 177 |
|
| 178 |
model.load_state_dict(ckpt['state_dict'])
|
| 179 |
model.eval()
|
| 180 |
|
| 181 |
+
# --- Point Cloud Generation ---
|
| 182 |
gen_pcs = []
|
| 183 |
with torch.no_grad():
|
| 184 |
+
z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
|
| 185 |
+
x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
|
|
|
|
| 186 |
gen_pcs.append(x.detach().cpu())
|
| 187 |
+
|
| 188 |
gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
|
| 189 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
| 190 |
|
| 191 |
return gen_pcs_normalized[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
|
| 194 |
+
# --- Gradio Interface Function ---
|
| 195 |
+
def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
|
| 196 |
+
error_message = ""
|
| 197 |
+
figure_plot = None
|
| 198 |
+
download_file_path = None
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
if seed is None:
|
| 202 |
+
seed = random.randint(0, 2**16 - 1)
|
| 203 |
+
seed = int(seed)
|
| 204 |
+
|
| 205 |
+
if not model_choice:
|
| 206 |
+
error_message = "Please choose a model type."
|
| 207 |
+
# Return empty plot and no file if model not chosen
|
| 208 |
+
return go.Figure(), None, error_message
|
| 209 |
+
|
| 210 |
+
print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")
|
| 211 |
+
|
| 212 |
+
points = predict(seed, model_choice, flexibility)
|
| 213 |
+
|
| 214 |
+
# Create Plotly figure
|
| 215 |
+
figure_plot = go.Figure(
|
| 216 |
+
data=[
|
| 217 |
+
go.Scatter3d(
|
| 218 |
+
x=points[:, 0], y=points[:, 1], z=points[:, 2],
|
| 219 |
+
mode='markers',
|
| 220 |
+
marker=dict(size=marker_size, color=point_color_hex) # Use hex color directly
|
| 221 |
+
)
|
| 222 |
+
],
|
| 223 |
+
layout=dict(
|
| 224 |
+
title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
|
| 225 |
+
scene=dict(
|
| 226 |
+
xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
| 227 |
+
yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
| 228 |
+
zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
| 229 |
+
aspectmode='data'
|
| 230 |
+
),
|
| 231 |
+
margin=dict(l=0, r=0, b=0, t=40)
|
| 232 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
)
|
|
|
|
|
|
|
| 234 |
|
| 235 |
+
# Prepare file for download
|
| 236 |
+
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
|
| 237 |
+
for point in points:
|
| 238 |
+
tmp_file.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
|
| 239 |
+
download_file_path = tmp_file.name
|
| 240 |
+
print(f"Point cloud saved for download at: {download_file_path}")
|
| 241 |
+
|
| 242 |
+
except ValueError as ve:
|
| 243 |
+
error_message = f"Configuration Error: {str(ve)}"
|
| 244 |
+
print(error_message)
|
| 245 |
+
except AttributeError as ae:
|
| 246 |
+
error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint might be missing expected parameters or they are incompatible."
|
| 247 |
+
print(error_message)
|
| 248 |
+
except Exception as e:
|
| 249 |
+
error_message = f"An unexpected error occurred: {str(e)}"
|
| 250 |
+
print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")
|
| 251 |
+
|
| 252 |
+
# Ensure we always return three values, even on error
|
| 253 |
+
if figure_plot is None: figure_plot = go.Figure() # Empty plot on error
|
| 254 |
+
return figure_plot, download_file_path, error_message
|
| 255 |
|
| 256 |
+
# --- Gradio UI Definition ---
|
| 257 |
+
available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
|
| 258 |
+
if not available_models:
|
| 259 |
+
print("CRITICAL: No models were loaded successfully. The application may not function as expected.")
|
| 260 |
|
| 261 |
+
markdown_description = f'''
|
| 262 |
+
# Diffusion Probabilistic Models for 3D Point Cloud Generation
|
|
|
|
|
|
|
| 263 |
|
| 264 |
+
[CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)
|
| 265 |
|
| 266 |
+
This demo allows you to generate 3D point clouds using pre-trained models.
|
| 267 |
+
- Adjust the **Seed** for different random initializations.
|
| 268 |
+
- Choose a **Model Type** (e.g., Airplane, Chair).
|
| 269 |
+
- Control **Sampling Flexibility**: Lower values tend towards the mean shape, higher values increase diversity.
|
| 270 |
+
- Customize **Point Color** and **Marker Size**.
|
| 271 |
+
|
| 272 |
+
Running on: **{DEVICE.upper()}**
|
| 273 |
'''
|
| 274 |
+
if "Chair" in MODEL_CONFIGS and "Chair" not in available_models: # Check if Chair was intended but failed to load
|
| 275 |
+
markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it's in the root directory if you intend to use it."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
|
|
|
|
|
|
| 277 |
|
| 278 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 279 |
+
gr.Markdown(markdown_description)
|
| 280 |
+
|
| 281 |
+
with gr.Row():
|
| 282 |
+
with gr.Column(scale=1): # Controls Column
|
| 283 |
+
model_dropdown = gr.Dropdown(choices=available_models, label="Choose Model Type", value=available_models[0] if available_models else None)
|
| 284 |
+
seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
|
| 285 |
+
flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)
|
| 286 |
+
|
| 287 |
+
with gr.Row():
|
| 288 |
+
color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B") # Default orange
|
| 289 |
+
marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)
|
| 290 |
+
|
| 291 |
+
generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")
|
| 292 |
+
|
| 293 |
+
with gr.Column(scale=2): # Output Column
|
| 294 |
+
plot_output = gr.Plot(label="Generated Point Cloud")
|
| 295 |
+
file_download_output = gr.File(label="Download Point Cloud (.xyz)")
|
| 296 |
+
error_display = gr.Markdown("") # For displaying error messages
|
| 297 |
+
|
| 298 |
+
generate_btn.click(
|
| 299 |
+
fn=generate_gradio,
|
| 300 |
+
inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
|
| 301 |
+
outputs=[plot_output, file_download_output, error_display]
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if available_models:
|
| 305 |
+
example_list = [
|
| 306 |
+
[777, available_models[0], 0.0, "#EE4B2B", 2],
|
| 307 |
+
[1234, available_models[0], 0.5, "#1E90FF", 3], # DodgerBlue
|
| 308 |
+
]
|
| 309 |
+
if len(available_models) > 1: # If Chair (or another model) is available
|
| 310 |
+
example_list.append([100, available_models[1], 0.2, "#32CD32", 2.5]) # LimeGreen
|
| 311 |
+
|
| 312 |
+
gr.Examples(
|
| 313 |
+
examples=example_list,
|
| 314 |
+
inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
|
| 315 |
+
outputs=[plot_output, file_download_output, error_display],
|
| 316 |
+
fn=generate_gradio,
|
| 317 |
+
cache_examples=False, # Generation is fast enough, no need to cache potentially large plots
|
| 318 |
+
)
|
| 319 |
|
| 320 |
+
# --- Application Launch ---
|
| 321 |
if __name__ == "__main__":
|
| 322 |
+
if not available_models:
|
| 323 |
+
print("No models available to run the Gradio demo. You might want to check checkpoint paths and errors above.")
|
| 324 |
+
# Optionally, you could still launch a limited UI that just shows an error.
|
| 325 |
+
# For now, we'll just print and let it potentially launch an empty UI if Gradio is set up.
|
|
|
|
| 326 |
|
| 327 |
+
print("Launching Gradio demo...")
|
| 328 |
+
demo.launch() # Add share=True if you want a public link when running locally
|