Spaces:
Paused
Paused
add download weights
Browse files
app.py
CHANGED
|
@@ -4,11 +4,12 @@ import os
|
|
| 4 |
import torch
|
| 5 |
import trimesh
|
| 6 |
import sys
|
| 7 |
-
sys.path.append("
|
| 8 |
from cube3d.inference.engine import EngineFast
|
| 9 |
from pathlib import Path
|
| 10 |
import uuid
|
| 11 |
import shutil
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
GLOBAL_STATE = {}
|
|
@@ -104,10 +105,17 @@ if __name__=="__main__":
|
|
| 104 |
)
|
| 105 |
|
| 106 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
engine_fast = EngineFast(
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
device=torch.device("cuda"),
|
| 112 |
)
|
| 113 |
GLOBAL_STATE["engine_fast"] = engine_fast
|
|
|
|
| 4 |
import torch
|
| 5 |
import trimesh
|
| 6 |
import sys
|
| 7 |
+
sys.path.append("cube")
|
| 8 |
from cube3d.inference.engine import EngineFast
|
| 9 |
from pathlib import Path
|
| 10 |
import uuid
|
| 11 |
import shutil
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
|
| 14 |
|
| 15 |
GLOBAL_STATE = {}
|
|
|
|
| 105 |
)
|
| 106 |
|
| 107 |
args = parser.parse_args()
|
| 108 |
+
snapshot_download(
|
| 109 |
+
repo_id="Roblox/cube3d-v0.1",
|
| 110 |
+
local_dir="./model_weights"
|
| 111 |
+
)
|
| 112 |
+
config_path = "./model_weights/shape_tokenizer.safetensors"
|
| 113 |
+
gpt_ckpt_path = "./model_weights/shape_gpt.safetensors"
|
| 114 |
+
shape_ckpt_path = "./model_weights/shape_tokenizer.safetensors"
|
| 115 |
engine_fast = EngineFast(
|
| 116 |
+
config_path,
|
| 117 |
+
gpt_ckpt_path,
|
| 118 |
+
shape_ckpt_path,
|
| 119 |
device=torch.device("cuda"),
|
| 120 |
)
|
| 121 |
GLOBAL_STATE["engine_fast"] = engine_fast
|