Spaces: Running on Zero
Working New 3d Model and Height Map
Files changed:
- app.py +103 -27
- trellis/renderers/gaussian_render.py +0 -1
- trellis/utils/render_utils.py +61 -2
- utils/depth_estimation.py +4 -2
- utils/image_utils.py +2 -1
app.py
CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
-import os
 import spaces
+import os
 import numpy as np
 os.environ['SPCONV_ALGO'] = 'native'
 from typing import *

@@ -16,7 +16,7 @@ from tempfile import NamedTemporaryFile
 import atexit
 import random
 #import accelerate
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, DPTImageProcessor, DPTForDepthEstimation
 from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils

@@ -100,6 +100,8 @@ from utils.version_info import (
     #release_torch_resources,
     #get_torch_info
 )
+#from utils.depth_estimation import (get_depth_map_from_state)
+
 
 
 input_image_palette = []

@@ -675,8 +677,61 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
     return gs, mesh, name
 
-@spaces.GPU(
-def
+@spaces.GPU()
+def depth_process_image(image_path, resized_width=800, z_scale=208):
+    """
+    Processes the input image to generate a depth map.
+
+    Args:
+        image_path (str): The file path to the input image.
+        resized_width (int, optional): The width to which the image is resized. Defaults to 800.
+        z_scale (int, optional): Z-axis scale factor. Defaults to 208.
+
+    Returns:
+        list: A list containing the depth image.
+    """
+    image_path = Path(image_path)
+    if not image_path.exists():
+        raise ValueError("Image file not found")
+
+    # Load and resize the image
+    image_raw = Image.open(image_path).convert("RGB")
+    print(f"Original size: {image_raw.size}")
+    resized_height = int(resized_width * image_raw.size[1] / image_raw.size[0])
+    image = image_raw.resize((resized_width, resized_height), Image.Resampling.LANCZOS)
+    print(f"Resized size: {image.size}")
+
+    # Prepare image for the model
+    encoding = image_processor(image, return_tensors="pt")
+
+    # Perform depth estimation
+    with torch.no_grad():
+        outputs = depth_model(**encoding)
+        predicted_depth = outputs.predicted_depth
+
+    # Interpolate depth to match the image size
+    prediction = torch.nn.functional.interpolate(
+        predicted_depth.unsqueeze(1),
+        size=(image.height, image.width),
+        mode="bicubic",
+        align_corners=False,
+    ).squeeze()
+
+    # Normalize the depth image to 8-bit
+    if torch.cuda.is_available():
+        prediction = prediction.numpy()
+    else:
+        prediction = prediction.cpu().numpy()
+    depth_min, depth_max = prediction.min(), prediction.max()
+    depth_image = ((prediction - depth_min) / (depth_max - depth_min) * 255).astype("uint8")
+    img = Image.fromarray(depth_image)
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    return img
+
+def generate_3d_asset_part1(depth_image_source, randomize_seed, seed, input_image, output_image, overlay_image, bordered_image_output, progress=gr.Progress(track_tqdm=True)):
     # Choose the image based on source
     if depth_image_source == "Input Image":
         image_path = input_image

@@ -695,41 +750,45 @@ def generate_3d_asset(depth_image_source, randomize_seed, seed, input_image, out
 
     # Determine the final seed using default MAX_SEED from constants
     final_seed = np.random.randint(0, constants.MAX_SEED) if randomize_seed else seed
+    # Process the image for depth estimation
+    depth_img = depth_process_image(image_path, resized_width=1536, z_scale=332)
+    depth_img = resize_image_with_aspect_ratio(depth_img, 1536, 1536)
+
+    return depth_img, image_path, output_name, final_seed
 
+@spaces.GPU(duration=150,progress=gr.Progress(track_tqdm=True))
+def generate_3d_asset_part2(depth_img, image_path, output_name, seed, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
     # Open image using standardized defaults
     image_raw = Image.open(image_path).convert("RGB")
-
+    resized_image = resize_image_with_aspect_ratio(image_raw, 1536, 1536)
+    depth_img = Image.open(depth_img).convert("RGBA")
     # Preprocess and run the Trellis pipeline with fixed sampler settings
-
-    # dict: The information of the generated 3D model.
-    # str: The path to the video of the 3D model.
-    processed_image = TRELLIS_PIPELINE.preprocess_image(image_raw, max_resolution=1536)
+    processed_image = TRELLIS_PIPELINE.preprocess_image(resized_image, max_resolution=1536)
     outputs = TRELLIS_PIPELINE.run(
         processed_image,
-        seed=
+        seed=seed,
         formats=["gaussian", "mesh"],
         preprocess_image=False,
         sparse_structure_sampler_params={
-            "steps":
+            "steps": 15,
             "cfg_strength": 7.5,
         },
         slat_sampler_params={
-            "steps":
+            "steps": 15,
            "cfg_strength": 3.0,
         },
     )
 
     # Validate the mesh
     mesh = outputs['mesh'][0]
-
-    if
+    meshisdict = isinstance(mesh, dict)
+    if meshisdict:
         vertices = mesh['vertices']
         faces = mesh['faces']
     else:
         vertices = mesh.vertices
         faces = mesh.faces
 
-    # Check mesh properties
     print(f"Mesh vertices: {vertices.shape}, faces: {faces.shape}")
     if faces.max() >= vertices.shape[0]:
         raise ValueError(f"Invalid mesh: face index {faces.max()} exceeds vertex count {vertices.shape[0]}")

@@ -738,23 +797,31 @@ def generate_3d_asset(depth_image_source, randomize_seed, seed, input_image, out
     if not vertices.is_cuda or not faces.is_cuda:
         raise ValueError("Mesh data must be on GPU")
     if vertices.dtype != torch.float32 or faces.dtype != torch.int32:
-
+        if meshisdict:
+            mesh['faces'] = faces.to(torch.int32)
+            mesh['vertices'] = vertices.to(torch.float32)
+        else:
+            mesh.faces = faces.to(torch.int32)
+            mesh.vertices = vertices.to(torch.float32)
 
-    # Save the video to a temporary file
     user_dir = os.path.join(constants.TMPDIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
 
-    video = render_utils.render_video(outputs['gaussian'][0], resolution=576, num_frames=
-
-    depth_snapshot = snapshot_results['depth'][0]
-    video_geo = render_utils.render_video(outputs['mesh'][0], resolution=576, num_frames=60, r=1)['normal']
+    video = render_utils.render_video(outputs['gaussian'][0], resolution=576, num_frames=64, r=1, fov=45)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], resolution=576, num_frames=64, r=1, fov=45)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, f'{output_name}.mp4')
-    imageio.mimsave(video_path, video, fps=
+    imageio.mimsave(video_path, video, fps=8)
+
+    #snapshot_results = render_utils.render_snapshot_depth(outputs['mesh'][0], resolution=1280, r=1, fov=80)
+    #depth_snapshot = Image.fromarray(snapshot_results['normal'][0]).convert("L")
+    depth_snapshot = depth_img
+
     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], output_name)
     torch.cuda.empty_cache()
     return [state, video_path, depth_snapshot]
 
+
 @spaces.GPU(duration=90,progress=gr.Progress(track_tqdm=True))
 def extract_glb(
     state: dict,

@@ -1028,7 +1095,7 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
             # Gallery from PRE_RENDERED_IMAGES GOES HERE
             prerendered_image_gallery = gr.Gallery(label="Image Gallery", show_label=True, value=build_prerendered_images_by_quality(3,'thumbnail'), elem_id="gallery", elem_classes="solid", type="filepath", columns=[3], rows=[3], preview=False ,object_fit="contain", height="auto", format="png",allow_preview=False)
         with gr.Column():
-            image_guidance_stength = gr.Slider(label="Image Guidance Strength (prompt percentage)", minimum=0, maximum=1.0, value=0.85, step=0.01, interactive=True)
+            image_guidance_stength = gr.Slider(label="Image Guidance Strength (prompt percentage)", minimum=0, maximum=1.0, value=0.85, step=0.01, interactive=True)
            replace_input_image_button = gr.Button(
                "Replace Input Image",
                elem_id="prerendered_replace_input_image_button",

@@ -1106,7 +1173,7 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
         with gr.Row():
             with gr.Column():
                 # Use standard seed settings only
-                seed_3d = gr.Slider(0, constants.MAX_SEED, label="Seed (3D Generation)", value=0, step=1)
+                seed_3d = gr.Slider(0, constants.MAX_SEED, label="Seed (3D Generation)", value=0, step=1, randomize=True)
                 randomize_seed_3d = gr.Checkbox(label="Randomize Seed (3D Generation)", value=True)
             with gr.Column():
                 depth_image_source = gr.Radio(

@@ -1116,11 +1183,11 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
                 )
         with gr.Row():
             generate_3d_asset_button = gr.Button("Generate 3D Asset", elem_classes="solid", variant="secondary")
+        with gr.Row():
+            depth_output = gr.Image(label="Depth Map", image_mode="L", elem_classes="centered solid imgcontainer", format="PNG", type="filepath", key="DepthOutput",interactive=False, show_download_button=True, show_fullscreen_button=True, show_share_button=True, height=400)
         with gr.Row():
             # For display: video output and 3D model preview (GLTF)
             video_output = gr.Video(label="3D Asset Video", autoplay=True, loop=True, height=400)
-        with gr.Row():
-            depth_output = gr.Image(label="Depth Map", image_mode="L", elem_classes="centered solid imgcontainer", format="PNG", type="filepath", key="DepthOutput",interactive=False, show_download_button=True, show_fullscreen_button=True, show_share_button=True)
         with gr.Accordion("GLB Extraction Settings", open=False):
             with gr.Row():
                 mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)

@@ -1134,6 +1201,8 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
             model_file = gr.File(label="3D GLTF", elem_classes="solid small centered")
         is_multiimage = gr.State(False)
         output_buf = gr.State()
+        ddd_image_path = gr.State("./images/images/Beeuty-1.png")
+        ddd_file_name = gr.State("Hexagon_file")
         with gr.Row():
             gr.Examples(examples=[
                 ["assets//examples//hex_map_p1.png", False, True, -32,-31,80,80,-1.8,0,35,0,1,"#FFD0D0", 15],

@@ -1245,8 +1314,13 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
 
     # Chain the buttons
     generate_3d_asset_button.click(
-        fn=
+        fn=generate_3d_asset_part1,
         inputs=[depth_image_source, randomize_seed_3d, seed_3d, input_image, output_image, overlay_image, bordered_image_output],
+        outputs=[depth_output, ddd_image_path, ddd_file_name, seed_3d ],
+        scroll_to_output=True
+    ).then(
+        fn=generate_3d_asset_part2,
+        inputs=[depth_output, ddd_image_path, ddd_file_name, seed_3d ],
         outputs=[output_buf, video_output, depth_output],
         scroll_to_output=True
     ).then(

@@ -1293,6 +1367,8 @@ if __name__ == "__main__":
 
     #-------------- ------------------------------------------------MODEL INITIALIZATION------------------------------------------------------------#
     # Load models once during module import
+    image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
+    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large", ignore_mismatched_sizes=True)
    TRELLIS_PIPELINE = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    TRELLIS_PIPELINE.cuda()
    try:
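Note on the split above: generate_3d_asset is now two chained steps, so the quick depth-map pass and the heavy TRELLIS pass can run under separate @spaces.GPU budgets, with the intermediate results handed between them through output components. A minimal sketch of that pattern follows; the stub functions and component names are illustrative only, not the app's actual ones, and the spaces decorators only take effect inside a ZeroGPU Space. In the real diff the hand-off goes through the depth_output image (type="filepath"), the ddd_image_path / ddd_file_name states, and the seed_3d slider.

import gradio as gr
import spaces

@spaces.GPU()  # short default budget: the quick depth-map step
def part1_stub(seed):
    # stand-in for generate_3d_asset_part1: produce the intermediate artifacts
    return f"depth_{int(seed)}.png", int(seed)

@spaces.GPU(duration=150)  # longer budget: the heavy 3D-generation step
def part2_stub(depth_path, seed):
    # stand-in for generate_3d_asset_part2: consume what part 1 produced
    return f"asset built from {depth_path} with seed {seed}"

with gr.Blocks() as demo:
    seed = gr.Number(value=0, label="Seed")
    depth_path = gr.Textbox(label="Depth map path")
    seed_state = gr.State()
    result = gr.Textbox(label="Result")
    generate = gr.Button("Generate")
    # part 1 fills the components that part 2 then reads back as its inputs
    generate.click(
        fn=part1_stub, inputs=[seed], outputs=[depth_path, seed_state], scroll_to_output=True
    ).then(
        fn=part2_stub, inputs=[depth_path, seed_state], outputs=[result]
    )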
trellis/renderers/gaussian_render.py
CHANGED
@@ -11,7 +11,6 @@
 
 import torch
 import math
-from easydict import EasyDict as edict
 import numpy as np
 from ..representations.gaussian import Gaussian
 from .sh_utils import eval_sh
trellis/utils/render_utils.py
CHANGED
@@ -67,6 +67,53 @@ def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=N
     else:
         raise ValueError(f'Unsupported sample type: {type(sample)}')
 
+    rets = {}
+    for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
+        if not isinstance(sample, MeshExtractResult):
+            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
+            if 'color' not in rets: rets['color'] = []
+            # if 'depth' not in rets: rets['depth'] = []
+            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            # if 'percent_depth' in res:
+            #     rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
+            # elif 'depth' in res:
+            #     rets['depth'].append(res['depth'].detach().cpu().numpy())
+            # else:
+            #     rets['depth'].append(None)
+        else:
+            res = renderer.render(sample, extr, intr)
+            if 'normal' not in rets: rets['normal'] = []
+            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+
+    return rets
+
+def render_frames_depth(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
+    if isinstance(sample, Octree):
+        renderer = OctreeRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+        renderer.pipe.primitive = sample.primitive
+    elif isinstance(sample, Gaussian):
+        renderer = GaussianRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 1)
+        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+        renderer.pipe.use_mip_gaussian = True
+    elif isinstance(sample, MeshExtractResult):
+        renderer = MeshRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 1)
+        renderer.rendering_options.far = options.get('far', 100)
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+    else:
+        raise ValueError(f'Unsupported sample type: {type(sample)}')
+
     rets = {}
     for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
         if not isinstance(sample, MeshExtractResult):

@@ -82,11 +129,15 @@ def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=N
                 rets['depth'].append(None)
         else:
             res = renderer.render(sample, extr, intr)
+            if 'depth' not in rets: rets['depth'] = []
             if 'normal' not in rets: rets['normal'] = []
             rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            if 'depth' in res:
+                rets['depth'].append(np.clip(res['depth'].detach().cpu().numpy(), 0, 255).astype(np.uint8))
+            else:
+                rets['depth'].append(None)
     return rets
 
-
 def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
     yaws = torch.linspace(0, 2 * 3.1415, num_frames)
     pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))

@@ -107,10 +158,18 @@ def render_multiview(sample, resolution=512, nviews=30):
     return res['color'], extrinsics, intrinsics
 
 
-def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=
+def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=2, fov=60, **kwargs):
     yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
     yaw_offset = offset[0]
     yaw = [y + yaw_offset for y in yaw]
     pitch = [offset[1] for _ in range(4)]
     extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
     return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+def render_snapshot_depth(samples, resolution=512, bg_color=(0, 0, 0), offset=(0, np.pi/2), r=2, fov=90, **kwargs):
+    yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+    yaw_offset = offset[0]
+    yaw = [y + yaw_offset for y in yaw]
+    pitch = [offset[1] for _ in range(4)]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+    return render_frames_depth(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
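The new render_frames_depth / render_snapshot_depth pair mirrors render_frames / render_snapshot but also collects a 'depth' channel for mesh samples (the Gaussian depth collection stays commented out). A small usage sketch, assuming a MeshExtractResult already on the GPU (for example outputs['mesh'][0] from the pipeline in app.py) and reusing the resolution/r/fov values from the commented-out call in app.py:

import numpy as np
from PIL import Image
from trellis.utils import render_utils

def save_depth_snapshot(mesh_result, path="depth_snapshot.png"):
    # Render four axis-aligned snapshots and keep the first view's depth buffer.
    snap = render_utils.render_snapshot_depth(mesh_result, resolution=1280, r=1, fov=80)
    depth = snap['depth'][0]  # one uint8 depth array per view, or None if the renderer gave no depth
    if depth is None:
        raise ValueError("Renderer returned no depth for this sample")
    Image.fromarray(np.squeeze(depth)).convert("L").save(path)
    return path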
utils/depth_estimation.py
CHANGED
@@ -12,6 +12,8 @@ from utils.image_utils import (
     resize_image_with_aspect_ratio
 )
 from utils.constants import TMPDIR
+from easydict import EasyDict as edict
+
 
 # Load models once during module import
 image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")

@@ -258,10 +260,10 @@ def depth_process_image(image_path, resized_width=800, z_scale=208):
         torch.cuda.ipc_collect()
     return [img, gltf_path, gltf_path]
 
-def get_depth_map_from_state(state):
+def get_depth_map_from_state(state, image_height=1024, image_width=1024):
     from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
 
-    settings = GaussianRasterizationSettings(image_height=
+    settings = GaussianRasterizationSettings(image_height=image_height, image_width=image_width, kernel_size=0.01,bg=(0.0, 0.0, 0.0))
     rasterizer = GaussianRasterizer(settings)
     # Assume state has necessary data like means3D, scales, etc.
     rendered_image, rendered_depth, _, _, _, _ = rasterizer(means3D=state["means3D"], means2D=state["means2D"], shs=state["shs"], colors_precomp=state["colors_precomp"], opacities=state["opacities"], scales=state["scales"], rotations=state["rotations"], cov3D_precomp=state["cov3D_precomp"])
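get_depth_map_from_state now takes explicit image_height / image_width instead of a hard-coded rasterization size. A hypothetical call site, assuming the state dict already carries the rasterizer inputs the function unpacks (means3D, means2D, shs, colors_precomp, opacities, scales, rotations, cov3D_precomp) and that its return value is consumed as before (the diff shows only the rasterizer call):

from PIL import Image
from utils.depth_estimation import get_depth_map_from_state

def depth_matching_image(state, image_path):
    # Size the Gaussian rasterization to the source image rather than a fixed 1024x1024.
    img = Image.open(image_path)
    return get_depth_map_from_state(state, image_height=img.height, image_width=img.width)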
utils/image_utils.py
CHANGED
@@ -276,7 +276,7 @@ def resize_image_with_aspect_ratio(image, target_width, target_height):
     original_width, original_height = image.size
     target_aspect = target_width / target_height
     original_aspect = original_width / original_height
-
+    #print(f"Original size: {image.size}\ntarget_aspect: {target_aspect}\noriginal_aspect: {original_aspect}\n")
     # Decide whether to fit width or height
     if original_aspect > target_aspect:
         # Image is wider than target aspect ratio

@@ -289,6 +289,7 @@ def resize_image_with_aspect_ratio(image, target_width, target_height):
 
     # Resize the image
     resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+    #print(f"Resized size: {resized_image.size}\n")
 
     # Create a new image with target dimensions and black background
     new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
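Apart from the commented-out prints, resize_image_with_aspect_ratio is what app.py now uses to letterbox both the depth map and the raw input onto a square black canvas before TRELLIS preprocessing. A short usage sketch, assuming the function returns the padded RGB image (only part of its body appears in this diff) and using the example image referenced earlier in the diff:

from PIL import Image
from utils.image_utils import resize_image_with_aspect_ratio

# Pad a non-square image onto a 1536x1536 black canvas, as generate_3d_asset_part1/part2 do.
src = Image.open("assets/examples/hex_map_p1.png").convert("RGB")
padded = resize_image_with_aspect_ratio(src, 1536, 1536)
padded.save("hex_map_p1_1536.png")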