Update app.py
Browse files
app.py
CHANGED
|
@@ -79,54 +79,58 @@ def predict(Seed, ckpt):
|
|
| 79 |
seed_all(int(Seed))
|
| 80 |
|
| 81 |
# --- MODIFICATION START ---
|
| 82 |
-
|
| 83 |
-
# The key might be 'args', 'config', or something similar.
|
| 84 |
-
# We need to inspect the actual keys of a loaded ckpt if this doesn't work.
|
| 85 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
| 86 |
actual_args = ckpt['args']
|
| 87 |
print("Using 'args' found in checkpoint.")
|
| 88 |
else:
|
| 89 |
-
#
|
| 90 |
-
# This part needs to be more robust and include all necessary defaults
|
| 91 |
print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# or by inspecting a correctly loaded checkpoint from the original repo.
|
| 95 |
default_latent_dim = 128
|
| 96 |
-
default_hyper = None
|
| 97 |
-
default_residual = True
|
| 98 |
default_flow_depth = 10
|
| 99 |
default_flow_hidden_dim = 256
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
# Try to get values from ckpt if they exist at the top level
|
| 105 |
-
# (some checkpoints might store them flatly instead of under an 'args' key)
|
| 106 |
-
model_type = ckpt.get('model', default_model_type) # Check if 'model' key exists directly
|
| 107 |
-
latent_dim = ckpt.get('latent_dim', default_latent_dim)
|
| 108 |
-
hyper = ckpt.get('hyper', default_hyper)
|
| 109 |
-
residual = ckpt.get('residual', default_residual)
|
| 110 |
-
flow_depth = ckpt.get('flow_depth', default_flow_depth)
|
| 111 |
-
flow_hidden_dim = ckpt.get('flow_hidden_dim', default_flow_hidden_dim)
|
| 112 |
-
num_points_to_generate = ckpt.get('num_points', default_num_points)
|
| 113 |
-
flexibility = ckpt.get('flexibility', default_flexibility)
|
| 114 |
-
|
| 115 |
-
# Create the mock_args object
|
| 116 |
actual_args = type('Args', (), {
|
| 117 |
-
'model':
|
| 118 |
-
'latent_dim': latent_dim,
|
| 119 |
-
'hyper': hyper,
|
| 120 |
-
'residual': residual,
|
| 121 |
-
'flow_depth': flow_depth,
|
| 122 |
-
'flow_hidden_dim': flow_hidden_dim,
|
| 123 |
-
'num_points':
|
| 124 |
-
'flexibility': flexibility
|
| 125 |
-
# Add any other attributes that models might expect from 'args'
|
| 126 |
})()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# --- MODIFICATION END ---
|
| 128 |
|
| 129 |
-
# Now use actual_args to instantiate models
|
| 130 |
if actual_args.model == 'gaussian':
|
| 131 |
model = GaussianVAE(actual_args).to(device)
|
| 132 |
elif actual_args.model == 'flow':
|
|
@@ -139,6 +143,7 @@ def predict(Seed, ckpt):
|
|
| 139 |
|
| 140 |
gen_pcs = []
|
| 141 |
with torch.no_grad():
|
|
|
|
| 142 |
z = torch.randn([1, actual_args.latent_dim]).to(device)
|
| 143 |
x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
|
| 144 |
gen_pcs.append(x.detach().cpu())
|
|
@@ -147,7 +152,6 @@ def predict(Seed, ckpt):
|
|
| 147 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
| 148 |
|
| 149 |
return gen_pcs_normalized[0]
|
| 150 |
-
|
| 151 |
def generate(seed, value):
|
| 152 |
if value == "Airplane":
|
| 153 |
ckpt = ckpt_airplane
|
|
|
|
| 79 |
seed_all(int(Seed))
|
| 80 |
|
| 81 |
# --- MODIFICATION START ---
|
| 82 |
+
actual_args = None
|
|
|
|
|
|
|
| 83 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
| 84 |
actual_args = ckpt['args']
|
| 85 |
print("Using 'args' found in checkpoint.")
|
| 86 |
else:
|
| 87 |
+
# This fallback should ideally not be hit if 'args' is usually present
|
|
|
|
| 88 |
print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
|
| 89 |
+
# Define all necessary defaults if we have to construct from scratch
|
| 90 |
+
default_model_type = 'gaussian'
|
|
|
|
| 91 |
default_latent_dim = 128
|
| 92 |
+
default_hyper = None
|
| 93 |
+
default_residual = True
|
| 94 |
default_flow_depth = 10
|
| 95 |
default_flow_hidden_dim = 256
|
| 96 |
+
default_num_points = 2048 # Default for sampling
|
| 97 |
+
default_flexibility = 0.0 # Default for sampling
|
| 98 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
actual_args = type('Args', (), {
|
| 100 |
+
'model': ckpt.get('model', default_model_type),
|
| 101 |
+
'latent_dim': ckpt.get('latent_dim', default_latent_dim),
|
| 102 |
+
'hyper': ckpt.get('hyper', default_hyper),
|
| 103 |
+
'residual': ckpt.get('residual', default_residual),
|
| 104 |
+
'flow_depth': ckpt.get('flow_depth', default_flow_depth),
|
| 105 |
+
'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
|
| 106 |
+
'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
|
| 107 |
+
'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
|
|
|
|
| 108 |
})()
|
| 109 |
+
|
| 110 |
+
# Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
|
| 111 |
+
# These are parameters for the .sample() method, not necessarily model construction.
|
| 112 |
+
# The original training args might not have included these if they were fixed in the sampling script.
|
| 113 |
+
|
| 114 |
+
# Default values for sampling parameters if not present in actual_args
|
| 115 |
+
default_num_points_sampling = 2048
|
| 116 |
+
default_flexibility_sampling = 0.0
|
| 117 |
+
|
| 118 |
+
if not hasattr(actual_args, 'num_points'):
|
| 119 |
+
print(f"Attribute 'num_points' not found in actual_args. Setting default: {default_num_points_sampling}")
|
| 120 |
+
setattr(actual_args, 'num_points', default_num_points_sampling)
|
| 121 |
+
|
| 122 |
+
if not hasattr(actual_args, 'flexibility'):
|
| 123 |
+
print(f"Attribute 'flexibility' not found in actual_args. Setting default: {default_flexibility_sampling}")
|
| 124 |
+
setattr(actual_args, 'flexibility', default_flexibility_sampling)
|
| 125 |
+
|
| 126 |
+
# Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
|
| 127 |
+
# This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
|
| 128 |
+
if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
|
| 129 |
+
print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
|
| 130 |
+
setattr(actual_args, 'residual', True) # Default for GaussianVAE
|
| 131 |
+
|
| 132 |
# --- MODIFICATION END ---
|
| 133 |
|
|
|
|
| 134 |
if actual_args.model == 'gaussian':
|
| 135 |
model = GaussianVAE(actual_args).to(device)
|
| 136 |
elif actual_args.model == 'flow':
|
|
|
|
| 143 |
|
| 144 |
gen_pcs = []
|
| 145 |
with torch.no_grad():
|
| 146 |
+
# Use the (potentially now augmented) actual_args for sampling
|
| 147 |
z = torch.randn([1, actual_args.latent_dim]).to(device)
|
| 148 |
x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
|
| 149 |
gen_pcs.append(x.detach().cpu())
|
|
|
|
| 152 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
| 153 |
|
| 154 |
return gen_pcs_normalized[0]
|
|
|
|
| 155 |
def generate(seed, value):
|
| 156 |
if value == "Airplane":
|
| 157 |
ckpt = ckpt_airplane
|