Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| # Implementation of 2D Rotary Position Embeddings (RoPE). | |
| # This module provides a clean implementation of 2D Rotary Position Embeddings, | |
| # which extends the original RoPE concept to handle 2D spatial positions. | |
| # Inspired by: | |
| # https://github.com/meta-llama/codellama/blob/main/llama/model.py | |
| # https://github.com/naver-ai/rope-vit | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Dict, Tuple | |
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    Positions are (y, x) integer coordinates enumerated in row-major order.
    A grid is built once per (height, width, device) and reused on later calls.

    Attributes:
        position_cache: Dictionary mapping (height, width, device) to the
            precomputed position tensor of shape (height*width, 2) that lives on
            that device.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int, torch.device], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor (a ``torch.device`` or
                anything ``torch.device(...)`` accepts, e.g. ``"cpu"``).

        Returns:
            Int64 tensor of shape (batch_size, height*width, 2) containing the
            y, x coordinates of each grid cell in row-major order, repeated for
            each batch item. The result is a fresh copy, so callers may modify
            it without affecting the cache.
        """
        # Normalize so that e.g. "cpu" and torch.device("cpu") share one entry.
        device = torch.device(device)
        # The device is part of the key: a grid cached on one device must never
        # be handed back for a request that targets a different device.
        cache_key = (height, width, device)
        positions = self.position_cache.get(cache_key)
        if positions is None:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            # Row-major enumeration: (0, 0), (0, 1), ..., (height-1, width-1).
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[cache_key] = positions
        # expand() only broadcasts a view; clone() materializes an independent
        # tensor so in-place edits by the caller cannot corrupt the cache.
        return positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. The feature dimension is split in half: the first half
    is rotated according to the y coordinate and the second half according to
    the x coordinate.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor intended to scale the computed frequencies.
            NOTE(review): stored but not read by any method of this class.
        frequency_cache: Cache mapping (dim, seq_len, device, dtype) to the
            precomputed (cosine, sine) tables.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Number of positions to tabulate (positions 0 .. seq_len-1).
            device: Target device for computations.
            dtype: Data type of the returned tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape (seq_len, dim). Column
            j and column j + dim/2 share the same frequency, which matches the
            half-split pairing used by ``_rotate_features``.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Frequency bands: base_frequency ** (-2i / dim) for i in [0, dim/2).
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)
            # Angle table: angles[p, i] = p * inv_freq[i].
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)
            angles = torch.cat((angles, angles), dim=-1)
            # Evaluate cos/sin in at least float32 and only cast the results.
            # Casting the angles to a low-precision dtype (fp16/bf16) first would
            # round large angles before the trig call and distort the embedding.
            compute_dtype = torch.promote_types(dtype, torch.float32)
            angles = angles.to(compute_dtype)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)
        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        The last dimension is split into halves (x1, x2) and recombined as
        (-x2, x1), i.e. a 90-degree rotation of each (x1[j], x2[j]) pair.

        Args:
            x: Input tensor to rotate; its last dimension should be even.

        Returns:
            Rotated feature tensor with the same shape as ``x``.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features of shape (batch_size, n_heads, n_tokens, d).
            positions: Integer position indices of shape (batch_size, n_tokens).
            cos_comp: Cosine table of shape (seq_len, d).
            sin_comp: Sine table of shape (seq_len, d).

        Returns:
            Tokens with applied rotary position embeddings (same shape as ``tokens``).
        """
        # Look up each token's row and insert a head axis so it broadcasts
        # across all heads: (batch, n_tokens, d) -> (batch, 1, n_tokens, d).
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4, so that each
                spatial axis receives an even number of features.
            positions: Integer tensor of shape (batch_size, n_tokens, 2) with
                non-negative y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs. Each half of the features is rotated pairwise, so each
        # half must itself be even; otherwise the cos/sin tables would be one
        # column wider than the features they are applied to.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Nothing to rotate (zero tokens or zero batch); positions.max() would
        # raise on an empty tensor.
        if positions.numel() == 0:
            return tokens.clone()

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Tables must cover every index that appears in positions.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)