Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Simple 3D skeleton motion visualizer for HumanML3D motion data. | |
| Usage: python visualize.py <motion.pt> [--output output.mp4] [--fps 20] [--show] | |
| """ | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from matplotlib.animation import FuncAnimation, FFMpegWriter | |
| from pathlib import Path | |
# Kinematic chains of the HumanML3D skeleton (22 joints).
# Each entry is an ordered run of joint indices; the visualizer draws one
# bone between every pair of consecutive indices in a chain.
# Joint labels follow mld/utils/joints.py and datasets/HumanML3D/paramUtil.py.
SKELETON_CHAINS = [
    # Body: root -> BP -> BT -> BLN -> BMN -> BUN (head)
    [0, 3, 6, 9, 12, 15],
    # Left arm: BLN -> LSI -> LS -> LE -> LW
    [9, 14, 17, 19, 21],
    # Right arm: BLN -> RSI -> RS -> RE -> RW
    [9, 13, 16, 18, 20],
    # Left leg: root -> LH -> LK -> LMrot -> LF
    [0, 2, 5, 8, 11],
    # Right leg: root -> RH -> RK -> RMrot -> RF
    [0, 1, 4, 7, 10],
]
def load_motion(pt_path: str) -> np.ndarray:
    """
    Load motion data from a .pt file and remap its axes for plotting.

    The file is expected to hold joint positions shaped (frames, joints, 3)
    in HumanML3D convention: the last axis is (x, y, z) with Y vertical.
    The returned array swaps the last two components so the vertical axis
    becomes Z, which is the vertical axis of the 3D plot.

    Args:
        pt_path: Path to a .pt file written with torch.save.

    Returns:
        A new numpy array shaped (frames, joints, 3) with the last axis
        ordered (x, z, y) relative to the stored data.

    Raises:
        ValueError: If the loaded object is not a 3-D array whose last
            axis has length 3.
    """
    # NOTE(security): torch.load unpickles the file contents, so only
    # trusted motion files should be opened here.
    loaded = torch.load(pt_path, map_location="cpu")
    if isinstance(loaded, torch.Tensor):
        # detach() first: .numpy() raises on tensors that require grad.
        motion = loaded.detach().cpu().numpy()
    else:
        motion = np.asarray(loaded)

    if motion.ndim != 3 or motion.shape[2] != 3:
        raise ValueError(
            f"Expected motion of shape (frames, joints, 3), got {motion.shape}"
        )

    print(f"Loaded motion: {motion.shape}")
    print(f" Frames: {motion.shape[0]}")
    print(f" Joints: {motion.shape[1]}")
    print(f" Dimensions: {motion.shape[2]}")

    # Remap axes: HumanML3D (x, y, z) -> visualization (x, z, y).
    # Indexing with a list produces a new array, so the loaded data is
    # left untouched and no explicit copy is needed.
    return motion[:, :, [0, 2, 1]]
def setup_3d_plot():
    """
    Create a 10x10 inch figure containing a single labelled 3D axes.

    Returns:
        A (figure, axes) tuple.
    """
    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection="3d")

    # Label each axis with its own letter.
    labellers = (
        (axes.set_xlabel, "X"),
        (axes.set_ylabel, "Y"),
        (axes.set_zlabel, "Z"),
    )
    for set_label, text in labellers:
        set_label(text)

    return figure, axes
def update_frame(frame_idx: int, motion: np.ndarray, ax, lines: list, points: list):
    """
    Redraw the 3D axes for a single animation frame.

    The axes are cleared and then repopulated with the ground plane, the
    skeleton bones and the joint markers of ``motion[frame_idx]``.

    Args:
        frame_idx: Index of the frame to draw.
        motion: Joint positions indexed as motion[frame, joint, axis],
            with the third component used as height.
        ax: The 3D axes to draw on.
        lines: Not used by this function; kept in the signature.
        points: Not used by this function; kept in the signature.

    Returns:
        A one-element tuple containing ``ax``.
    """
    ax.clear()
    pose = motion[frame_idx]

    # Axis limits come from the whole sequence so the view stays fixed
    # from frame to frame.
    flat = motion.reshape(-1, 3)
    lows = flat.min(axis=0)
    highs = flat.max(axis=0)
    pad = 0.5
    x_bounds = [lows[0] - pad, highs[0] + pad]
    y_bounds = [lows[1] - pad, highs[1] + pad]
    z_bounds = [0, highs[2] + pad]  # height axis starts at the ground (0)
    ax.set_xlim(x_bounds)
    ax.set_ylim(y_bounds)
    ax.set_zlim(z_bounds)

    # Labels and title have to be restored after every clear().
    ax.set_xlabel("X", fontsize=10)
    ax.set_ylabel("Y", fontsize=10)
    ax.set_zlabel("Z (Height)", fontsize=10)
    ax.set_title(f"Frame {frame_idx + 1}/{len(motion)}", fontsize=14, pad=20)

    # Slightly elevated, rotated camera.
    ax.view_init(elev=15, azim=45)

    # Translucent ground plane at z = 0 spanning the x/y limits.
    ground_x, ground_y = np.meshgrid(
        np.linspace(x_bounds[0], x_bounds[1], 2),
        np.linspace(y_bounds[0], y_bounds[1], 2),
    )
    ax.plot_surface(
        ground_x, ground_y, np.zeros_like(ground_x), alpha=0.1, color="gray"
    )

    # One colour per kinematic chain; the first chain (body) is drawn thicker.
    palette = ["red", "blue", "green", "cyan", "magenta"]
    joint_count = len(pose)
    for chain_no, chain in enumerate(SKELETON_CHAINS):
        bone_color = palette[chain_no % len(palette)]
        bone_width = 4.0 if chain_no == 0 else 3.0
        for start, end in zip(chain, chain[1:]):
            # Skip bones that reference joints missing from this pose.
            if start >= joint_count or end >= joint_count:
                continue
            segment = pose[[start, end]]
            ax.plot(
                segment[:, 0],
                segment[:, 1],
                segment[:, 2],
                color=bone_color,
                linewidth=bone_width,
                alpha=0.8,
            )

    # Joint markers.
    ax.scatter(
        pose[:, 0],
        pose[:, 1],
        pose[:, 2],
        c="darkred",
        marker="o",
        s=50,
        alpha=0.9,
        edgecolors="black",
        linewidth=0.5,
    )

    ax.grid(True, alpha=0.3)
    return (ax,)
def create_video_from_joints(
    joints: torch.Tensor | np.ndarray, output_path: str, fps: int = 20
) -> str:
    """
    Create a 3D skeleton animation directly from a joint tensor or array.

    Args:
        joints: Joint positions as torch.Tensor or np.ndarray shaped
            (frames, joints, 3), last axis ordered (x, y, z) with Y vertical.
        output_path: Path to save the video to.
        fps: Frames per second for the video; must be positive.

    Returns:
        The output path as a string.

    Raises:
        ValueError: If ``fps`` is not positive or ``joints`` is not a 3-D
            array whose last axis has length 3.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    if isinstance(joints, torch.Tensor):
        # detach() first: .numpy() raises on tensors that require grad.
        joints = joints.detach().cpu().numpy()
    else:
        joints = np.asarray(joints)

    if joints.ndim != 3 or joints.shape[2] != 3:
        raise ValueError(
            f"Expected joints of shape (frames, joints, 3), got {joints.shape}"
        )

    # Remap axes for visualization: (x, y, z) -> (x, z, y). Indexing with a
    # list yields a new array, so the caller's data is not modified.
    motion = joints[:, :, [0, 2, 1]]

    fig, ax = setup_3d_plot()
    try:
        anim = FuncAnimation(
            fig,
            update_frame,
            frames=len(motion),
            fargs=(motion, ax, [], []),
            interval=1000 / fps,
            blit=False,
            repeat=True,
        )
        # Encode the animation with FFMpeg.
        writer = FFMpegWriter(fps=fps, bitrate=1800, codec="libx264")
        anim.save(str(output_path), writer=writer, dpi=100)
    finally:
        # Release the figure even when rendering or encoding fails.
        plt.close(fig)

    return str(output_path)
def visualize_motion(
    pt_path: str, output_path: str | None = None, fps: int = 20, show: bool = False
) -> str:
    """
    Visualize motion from a .pt file (PyTorch tensor) and save it as a video.

    Args:
        pt_path: Path to the .pt motion file.
        output_path: Path to save the video; if None, the input path with
            its suffix replaced by ".mp4" is used.
        fps: Frames per second for the video; must be positive.
        show: If True, display the animation in a window after saving.

    Returns:
        Path to the generated video file, as a string.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # load_motion already returns a numpy array with axes remapped for plotting.
    motion = load_motion(pt_path)

    if output_path is None:
        output_path = Path(pt_path).with_suffix(".mp4")
    else:
        output_path = Path(output_path)

    print(f"\nCreating animation with {fps} FPS...")

    fig, ax = setup_3d_plot()
    try:
        anim = FuncAnimation(
            fig,
            update_frame,
            frames=len(motion),
            fargs=(motion, ax, [], []),
            interval=1000 / fps,
            blit=False,
            repeat=True,
        )

        # Encode the animation with FFMpeg.
        print(f"Saving video to: {output_path}")
        writer = FFMpegWriter(fps=fps, bitrate=1800, codec="libx264")
        anim.save(str(output_path), writer=writer, dpi=100)
        print("✓ Video saved successfully!")

        if show:
            plt.show()
    finally:
        # Release the figure even when rendering or encoding fails.
        plt.close(fig)

    return str(output_path)
def main() -> int:
    """
    Command-line entry point.

    Parses the arguments, checks that the input file exists and renders
    the motion to a video.

    Returns:
        0 on success, 1 if the input file is missing or rendering fails.
    """
    cli = argparse.ArgumentParser(
        description="Visualize HumanML3D motion data as 3D skeleton animation"
    )
    cli.add_argument("input", type=str, help="Path to input .pt motion file")
    cli.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="Path to output video file (default: input_name.mp4)",
    )
    cli.add_argument(
        "--fps",
        type=int,
        default=20,
        help="Frames per second for output video (default: 20)",
    )
    cli.add_argument(
        "--show",
        action="store_true",
        help="Display the animation in a window (in addition to saving)",
    )
    options = cli.parse_args()

    # Bail out early when the input file does not exist.
    if not Path(options.input).exists():
        print(f"Error: Input file not found: {options.input}")
        return 1

    # Top-level boundary: report any failure with its traceback and turn
    # it into a non-zero exit status.
    try:
        saved_to = visualize_motion(
            options.input,
            output_path=options.output,
            fps=options.fps,
            show=options.show,
        )
        print(f"\n✓ Done! Video saved to: {saved_to}")
    except Exception as err:
        print(f"\n✗ Error: {err}")
        import traceback

        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the builtin `exit`, which
    # is an interactive helper added by the `site` module and is not
    # available when Python runs without it (e.g. `python -S`).
    raise SystemExit(main())