Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import torch.nn as nn | |
| from .dpt_head import DPTHead | |
| from .track_modules.base_track_predictor import BaseTrackerPredictor | |
class TrackHead(nn.Module):
    """
    Track head that pairs a DPT feature extractor with a BaseTrackerPredictor.

    Backbone tokens are first converted into feature maps by ``DPTHead``; the
    tracker then consumes those maps (and optional query points) and is asked
    to run a configurable number of refinement iterations.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        features: int = 128,
        iters: int = 4,
        predict_conf: bool = True,
        stride: int = 2,
        corr_levels: int = 7,
        corr_radius: int = 4,
        hidden_size: int = 384,
    ) -> None:
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Number of feature channels produced by the feature
                extractor; also used as the tracker's ``latent_dim``.
            iters (int): Default number of refinement iterations used by
                ``forward`` when the caller does not pass ``iters``.
            predict_conf (bool): Forwarded to the tracker; whether to predict
                confidence scores for tracked points.
            stride (int): Stride value forwarded to the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels.
            corr_radius (int): Radius for correlation computation.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size
        # Fallback iteration count for forward(); can be overridden per call.
        self.iters = iters

        # Feature extractor based on the DPT architecture: turns backbone
        # tokens into feature maps for tracking.
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # Only output features, no activation
            down_ratio=2,  # Reduces spatial dimensions by factor of 2
            pos_embed=False,
        )

        # Tracker module that consumes the feature maps and predicts point
        # trajectories (coordinates, visibility and, optionally, confidence).
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # Must match the feature extractor's channel count
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where
                B = batch size and S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track. Passed
                through to the tracker unchanged; how ``None`` is handled is
                defined by ``BaseTrackerPredictor``.
            iters (int, optional): Number of refinement iterations. If None,
                ``self.iters`` is used. An explicit ``0`` is respected.

        Returns:
            tuple: ``(coord_preds, vis_scores, conf_scores)`` exactly as
            returned by the tracker: predicted coordinates, visibility scores
            and confidence scores for the tracked points.

        Raises:
            ValueError: If ``images`` does not have exactly 5 dimensions.
        """
        # Explicit rank check instead of unpacking the shape into unused
        # names; the exception type (ValueError) is the same as before.
        rank = len(images.shape)
        if rank != 5:
            raise ValueError(
                f"images must have 5 dimensions (B, S, C, H, W), got {rank}"
            )

        # Extract features from tokens.
        # NOTE(review): feature maps are expected to be (B, S, C, H//2, W//2)
        # given down_ratio=2 -- confirm against DPTHead.
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Use the module default only when the caller did not specify a count;
        # `is None` (not truthiness) so that iters=0 is not overridden.
        if iters is None:
            iters = self.iters

        # Perform tracking using the extracted features.
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores