# Orient-Anything-V2 / vision_tower.py
# import sys
# sys.path.append("..")
import torch
from torch import nn
import torch.nn.init as init
import torch.nn.functional as F
from paths import *
from typing import Dict, List, Optional, Set, Tuple, Union
import os
from contextlib import nullcontext
from vggt.models.vggt import VGGT
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.layers import Mlp
from vggt.layers.block import Block
from vggt.heads.head_act import activate_pose
class OriAny_CameraHead(nn.Module):
"""
CameraHead predicts camera parameters from token representations using iterative refinement.
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
"""
def __init__(
self,
dim_in: int = 2048,
trunk_depth: int = 4,
pose_encoding_type: str = "OriAny",
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
):
super().__init__()
if pose_encoding_type == "OriAny":
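            # 360 azimuth + 180 polar + 360 rotation bins, plus 2 extra logits
            # (presumably Orient-Anything's orientation-presence judgement).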
self.target_dim = 360+180+360+2
else:
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
self.trunk_depth = trunk_depth
# Build the trunk using a sequence of transformer blocks.
self.trunk = nn.Sequential(
*[
Block(
dim=dim_in,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
init_values=init_values,
)
for _ in range(trunk_depth)
]
)
# Normalizations for camera token and trunk output.
self.token_norm = nn.LayerNorm(dim_in)
self.trunk_norm = nn.LayerNorm(dim_in)
# Learnable empty camera pose token.
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
self.embed_pose = nn.Linear(self.target_dim, dim_in)
# Module for producing modulation parameters: shift, scale, and a gate.
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
# Adaptive layer normalization without affine parameters.
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
self.pose_branch = Mlp(
in_features=dim_in,
hidden_features=dim_in // 2,
out_features=self.target_dim,
drop=0,
)
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
"""
Forward pass to predict camera parameters.
Args:
aggregated_tokens_list (list): List of token tensors from the network;
the last tensor is used for prediction.
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
Returns:
list: A list of predicted camera encodings (post-activation) from each iteration.
"""
# Use tokens from the last block for camera prediction.
tokens = aggregated_tokens_list[-1]
# Extract the camera tokens
pose_tokens = tokens[:, :, 0]
pose_tokens = self.token_norm(pose_tokens)
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
return pred_pose_enc_list
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
"""
Iteratively refine camera pose predictions.
Args:
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
num_iterations (int): Number of refinement iterations.
Returns:
list: List of activated camera encodings from each iteration.
"""
B, S, C = pose_tokens.shape # S is expected to be 1.
pred_pose_enc = None
pred_pose_enc_list = []
for _ in range(num_iterations):
# Use a learned empty pose for the first iteration.
if pred_pose_enc is None:
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
else:
# Detach the previous prediction to avoid backprop through time.
pred_pose_enc = pred_pose_enc.detach()
module_input = self.embed_pose(pred_pose_enc)
# Generate modulation parameters and split them into shift, scale, and gate components.
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
# Adaptive layer normalization and modulation.
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
# Compute the delta update for the pose encoding.
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
if pred_pose_enc is None:
pred_pose_enc = pred_pose_enc_delta
else:
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
# Apply final activation functions for translation, quaternion, and field-of-view.
# activated_pose = activate_pose(
# pred_pose_enc,
# trans_act=self.trans_act,
# quat_act=self.quat_act,
# fl_act=self.fl_act,
# )
# pred_pose_enc_list.append(activated_pose)
pred_pose_enc_list.append(pred_pose_enc)
return pred_pose_enc_list
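
# --- Illustrative usage sketch (not part of the original file) ---
# Shows the shapes OriAny_CameraHead expects: a list of aggregated token tensors
# of shape [B, S, N, C] with the camera token at patch index 0, as indexed in
# forward() above. The batch/sequence/patch sizes below are placeholders.
def _example_camera_head() -> None:
    head = OriAny_CameraHead(dim_in=2048, trunk_depth=4)
    dummy_tokens = [torch.randn(2, 1, 5, 2048)]  # [B=2, S=1, N=5, C=2048]
    preds = head(dummy_tokens, num_iterations=4)
    # One prediction per refinement iteration, each of shape [B, S, 360+180+360+2].
    print(len(preds), preds[-1].shape)
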
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
Modulate the input tensor using scaling and shifting parameters.
"""
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
return x * (1 + scale) + shift
def load_patch_embed_weights(model, checkpoint_path):
    # 1. Load the checkpoint from disk.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # 2. Get the state_dict (fall back to the raw checkpoint if there is no "state_dict" key).
state_dict = checkpoint.get("state_dict", checkpoint)
    # 3. Keep only the parameters under aggregator.patch_embed.
patch_embed_state = {
k.replace("aggregator.patch_embed.", ""): v
for k, v in state_dict.items()
if k.startswith("aggregator.patch_embed.")
}
    # 4. Load them into the target module.
missing_keys, unexpected_keys = model.aggregator.patch_embed.load_state_dict(
patch_embed_state, strict=False
)
print("Loaded patch_embed weights.")
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
class VGGT_OriAny_Ref(nn.Module):
def __init__(self,
dtype,
out_dim,
nopretrain
) -> None:
super().__init__()
self.vggt = VGGT()
self.dtype = dtype
self.ref_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
self.ref_sampler.apply(init_weights)
self.tgt_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
self.tgt_sampler.apply(init_weights)
def forward(self, img_inputs):
device = self.get_device()
with torch.amp.autocast(device_type='cuda', dtype=self.dtype):
            if img_inputs.dim() == 4:
                # Promote [S, 3, H, W] to [1, S, 3, H, W] when the batch dimension is missing.
                img_inputs = img_inputs[None]
aggregated_tokens_list, ps_idx = self.vggt.aggregator(img_inputs)
# Predict Cameras
# pose_enc = self.oriany_camera_head(aggregated_tokens_list)[-1]
# Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
# extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
# Use tokens from the last block for camera prediction.
tokens = aggregated_tokens_list[-1]
# Extract the camera tokens
pose_tokens = tokens[:, :, 0]
# tokens = aggregated_tokens_list[-1]
B, S, C = pose_tokens.shape
            if S > 1:
                # Split off the first (reference) token from the remaining (target) tokens.
                ref_tokens = pose_tokens[:, 0, :]   # shape: (B, C)
                tgt_tokens = pose_tokens[:, 1:, :]  # shape: (B, S-1, C)
                # Project both down to the output dimension C'.
                ref_feat = self.ref_sampler(ref_tokens)                          # shape: (B, C')
                tgt_feat = self.tgt_sampler(tgt_tokens.reshape(B * (S - 1), C))  # shape: (B*(S-1), C')
                # Concatenate reference and target features along the sequence dimension.
                pose_enc = torch.cat([
                    ref_feat.unsqueeze(1),       # (B, 1, C')
                    tgt_feat.view(B, S - 1, -1)  # (B, S-1, C')
                ], dim=1)  # final shape: (B, S, C')
else:
pose_enc = self.ref_sampler(pose_tokens.view(B*S,C))
return pose_enc
def get_device(self):
return next(self.parameters()).device
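
# --- Illustrative usage sketch (not part of the original file) ---
# Sketches the expected shapes for VGGT_OriAny_Ref: images of shape [B, S, 3, H, W]
# (reference view first, target views after) and, for S > 1, an output of shape
# [B, S, out_dim]. A CUDA device is assumed because forward() autocasts with
# device_type='cuda'; out_dim=902 (= 360+180+360+2) and the 518x518 resolution are
# assumptions for the sketch, not values taken from this file.
def _example_vggt_oriany_ref() -> None:
    model = VGGT_OriAny_Ref(dtype=torch.bfloat16, out_dim=902, nopretrain=True).cuda()
    model.eval()  # BatchNorm1d in MLP_dim needs eval mode (or a larger batch) here
    images = torch.rand(1, 2, 3, 518, 518, device="cuda")  # [B=1, S=2, 3, H, W]
    with torch.no_grad():
        pose_enc = model(images)
    print(pose_enc.shape)  # expected: [1, 2, 902]
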
def init_weights(m):
if isinstance(m, nn.Linear):
init.xavier_uniform_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
def get_activation(activation):
if activation.lower() == 'gelu':
return nn.GELU()
elif activation.lower() == 'rrelu':
return nn.RReLU(inplace=True)
elif activation.lower() == 'selu':
return nn.SELU(inplace=True)
elif activation.lower() == 'silu':
return nn.SiLU(inplace=True)
elif activation.lower() == 'hardswish':
return nn.Hardswish(inplace=True)
elif activation.lower() == 'leakyrelu':
return nn.LeakyReLU(inplace=True)
elif activation.lower() == 'sigmoid':
return nn.Sigmoid()
elif activation.lower() == 'tanh':
return nn.Tanh()
else:
return nn.ReLU(inplace=True)
class MLP_dim(nn.Module):
def __init__(
self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
super().__init__()
self.act = get_activation(activation)
self.net1 = nn.Sequential(
nn.Linear(in_dim, int(out_dim), bias=bias),
nn.BatchNorm1d(int(out_dim)),
self.act
)
self.net2 = nn.Sequential(
nn.Linear(int(out_dim), out_dim, bias=bias),
nn.BatchNorm1d(out_dim)
)
def forward(self, x):
return self.net2(self.net1(x))
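
# --- Illustrative usage sketch (not part of the original file) ---
# MLP_dim is a two-layer projection (Linear -> BatchNorm1d -> activation, then
# Linear -> BatchNorm1d); it expects 2-D inputs of shape [N, in_dim], with N > 1
# when the module is in training mode (BatchNorm1d requires more than one sample).
# The dimensions below are placeholders.
def _example_mlp_dim() -> None:
    mlp = MLP_dim(in_dim=2048, out_dim=902, activation='gelu')
    x = torch.randn(4, 2048)
    print(mlp(x).shape)  # [4, 902]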